# Knowledge Distillation of an LLM Integration in GamePals

In [22]:
from doom.preprocessing.doom_game_state_perturbator import DoomGameStatePerturbator
from doom.utils.doom_game_state import DoomGameState, MonsterType, WeaponName, AimedAtType
from sklearn.cluster import DBSCAN
from dataclasses import dataclass
from collections import Counter
from typing import Iterable

import os
import numpy as np

## 1) Loading the Dataset

In [8]:
@dataclass
class GameStateEntry:
    """A parsed Doom game state tagged with a string identifier.

    Entries flow through every preprocessing stage of this notebook; the
    identifier records where the state came from.
    """

    # e.g. "state-12" when loaded, or "state-12-p3" for its 4th perturbation
    id: str
    # the game state snapshot itself
    state: DoomGameState

In [9]:
# Marker that precedes every serialized game state in the gamelog files;
# the JSON payload is the remainder of the line.
GAMESTATE_PREFIX = "[GS] GAMESTATE "
# Maximum number of game states carried into the rest of the pipeline.
MAX_GAME_STATES = 10_000


def _iter_logged_game_states(log_dir: str = "data/gamelogs") -> Iterable[DoomGameState]:
    """Yield every game state serialized in the gamelog files of *log_dir*.

    Only regular files are read, and they are visited in file-name order:
    ``os.scandir`` yields entries in an arbitrary, filesystem-dependent order,
    which would make the capped subset (and the ``state-{idx}`` ids derived
    from it) differ between machines. Each file is closed as soon as it has
    been read.
    """
    with os.scandir(log_dir) as entries:
        log_files = sorted((e for e in entries if e.is_file()), key=lambda e: e.name)
    for entry in log_files:
        with open(entry.path, encoding="utf-8") as log_file:
            for line in log_file:
                if line.startswith(GAMESTATE_PREFIX):
                    yield DoomGameState.model_validate_json(
                        line.removeprefix(GAMESTATE_PREFIX)
                    )


game_states = list(_iter_logged_game_states())
dataset = [
    GameStateEntry(
        id=f"state-{idx}",
        state=state
    )
    for idx, state in enumerate(game_states[:MAX_GAME_STATES])
]

print(f"Extracted {len(dataset)} game states from gameplay")

Extracted 10000 game states from gameplay


## 2) Preprocessing the Dataset

### 2.1) Eliminate uninteresting game states

In [11]:
def is_interesting(gs: DoomGameState) -> bool:
    """Decide whether a game state is worth keeping for distillation.

    A state qualifies when at least one monster is present, or when the
    entity the player is aiming at is interactable.
    """
    return len(gs.MONSTERS) > 0 or bool(gs.AIMED_AT.interactable)


# Drop every entry whose game state does not pass the interest filter,
# preserving the original order of the survivors.
dataset = list(filter(lambda entry: is_interesting(entry.state), dataset))

print(f"Reduced to {len(dataset)} interesting game states")

Reduced to 4672 interesting game states


### 2.2) Apply clustering to keep only unique situations

In [17]:
def bucket_distance(d: float) -> float:
    """Quantize a distance into three coarse bands.

    Returns 0.0 for distances below 256, 0.5 below 768, and 1.0 otherwise
    (including NaN, which fails every comparison).
    """
    for upper_bound, bucket in ((256, 0.0), (768, 0.5)):
        if d < upper_bound:
            return bucket
    return 1.0

def one_hot(value: str, vocab: list[str]) -> list[float]:
    """One-hot encode *value* against the ordered vocabulary *vocab*.

    The result has one slot per vocabulary entry. The slot of the first
    entry equal to *value* is 1.0 and all others are 0.0. A value absent from
    *vocab* yields an all-zero vector.
    """
    hit = vocab.index(value) if value in vocab else None
    return [1.0 if position == hit else 0.0 for position in range(len(vocab))]

def ammo_status(ammo: int) -> float:
    """Map an ammunition count onto a coarse four-level scale.

    Exactly 0 maps to 0.0. Any other count below 10 maps to 0.33, below 40 to
    0.66, and anything larger to 1.0. Negative counts fall in the 0.33 band.
    """
    if ammo == 0:
        return 0.0
    for upper_bound, level in ((10, 0.33), (40, 0.66)):
        if ammo < upper_bound:
            return level
    return 1.0

def to_feature_vector(gs: DoomGameState) -> np.ndarray:
    """Encode a game state as a fixed-length float32 vector for clustering.

    Layout, in order:
      #1 number of monsters (raw count, not normalized to [0, 1])
      #2 bucketed distance to the closest monster (1.0 when there are none)
      #3 one-hot of the most common monster type (len(MonsterType) slots,
         all zeros when there are no monsters)
      #4 coarse ammunition level of the currently selected inventory slot
      #5 one-hot of the current weapon name (len(WeaponName) slots)
      #6 whether the aimed-at entity is interactable (0.0 or 1.0)
      #7 one-hot of the aimed-at entity type (len(AimedAtType) slots); the
         value "none" is encoded when entityType is empty/None

    Any value missing from a one-hot vocabulary encodes as all zeros.
    """
    features: list[float] = []

    # Feature #1: number of monsters
    features.append(float(len(gs.MONSTERS)))

    if gs.MONSTERS:
        # Feature #2: distance to the closest monster, bucketed
        closest = min(m.distance for m in gs.MONSTERS)
        features.append(bucket_distance(closest))
        # Feature #3: OHE of the most common monster type. Counter.most_common
        # orders equal counts by first occurrence, so ties pick the earliest.
        types = [m.monsterType for m in gs.MONSTERS]
        common_type = Counter(types).most_common(1)[0][0]
        features.extend(one_hot(common_type, list(MonsterType)))
    else:
        # No monsters: use the "far" distance bucket and an all-zero type
        # one-hot so the vector length stays the same for every state.
        features.append(1.0)
        features.extend([0.0] * len(MonsterType))

    slot = gs.INVENTORY.inventorySlots[gs.INVENTORY.currentSlot]
    # Feature #4: ammunition level of the current slot
    features.append(ammo_status(slot.ammoCount))

    # Feature #5: OHE of the current weapon.
    # NOTE(review): the lowercased name only matches if WeaponName members
    # compare equal to lowercase strings (e.g. a str-valued enum); otherwise
    # this one-hot is always zero. Confirm against WeaponName.
    features.extend(one_hot(slot.weaponName.lower(), list(WeaponName)))

    # Feature #6: is aiming at an interactable entity
    features.append(float(gs.AIMED_AT.interactable))

    # Feature #7: OHE of the aimed-at entity type (the lowercase-matching
    # caveat noted for feature #5 applies here too, including for "none").
    aimed_type = gs.AIMED_AT.entityType.lower() if gs.AIMED_AT.entityType else "none"
    features.extend(one_hot(aimed_type, list(AimedAtType)))

    return np.array(features, dtype=np.float32)

In [25]:
# Deduplicate near-identical situations. DBSCAN groups states whose feature
# vectors are chained together within `eps`, and one representative per
# cluster is kept.
features = np.array([to_feature_vector(gse.state) for gse in dataset])

clustering = DBSCAN(
    eps=1e-2,
    min_samples=1,
    metric="euclidean",
)
if len(dataset) == 0:
    # An empty dataset gives a 1-D empty array, which DBSCAN rejects with a
    # ValueError, so there is nothing to cluster.
    labels = np.empty(0, dtype=int)
else:
    labels = clustering.fit_predict(features)

new_dataset = []
# np.unique returns the labels sorted, giving a documented, deterministic
# order for the representatives (set iteration order is not guaranteed).
for cluster_id in np.unique(labels):
    # min_samples counts the point itself, so with min_samples=1 every point
    # is a core point and the noise label -1 never occurs. The check only
    # matters if the DBSCAN parameters are changed.
    if cluster_id == -1:
        continue

    # Get elements of the given cluster
    cluster_indices = np.where(labels == cluster_id)[0]
    cluster_features = features[cluster_indices]

    # Representative: the member closest to the cluster centroid (argmin
    # returns the first such member on ties).
    centroid = cluster_features.mean(axis=0)
    distances = np.linalg.norm(cluster_features - centroid, axis=1)
    center_idx = cluster_indices[np.argmin(distances)]

    new_dataset.append(dataset[center_idx])

dataset = new_dataset

print(f"Reduced to {len(dataset)} cluster-center game states")

Reduced to 41 cluster-center game states


### 2.3) Generate perturbations of the game states

In [26]:
def yield_perturbations(gs: DoomGameState) -> Iterable[DoomGameState]:
    """Return the perturbed variants of *gs* generated by DoomGameStatePerturbator.

    Despite the name, this is not a generator. It returns whatever
    ``DoomGameStatePerturbator.perturbate`` returns, unchanged.
    """
    variants = DoomGameStatePerturbator.perturbate(gs)
    return variants

In [27]:
# Replace every cluster-center entry by its perturbed variants. Each variant
# is identified by its source entry id plus its perturbation index.
perturbed_dataset: list[GameStateEntry] = []
for source_entry in dataset:
    perturbed_dataset.extend(
        GameStateEntry(id=f"{source_entry.id}-p{idx}", state=new_gs)
        for idx, new_gs in enumerate(yield_perturbations(source_entry.state))
    )
dataset = perturbed_dataset

print(f"Applied perturbations and went up to {len(dataset)} game states")

Applied perturbations and went up to 146 game states


## 3) Generating User Commands using the Teacher Model