# Preprocess Dataset

In [None]:
from sklearn.model_selection import train_test_split
from dataset_helper import get_dataset_dir, get_dataset, export_dataset

# Maximum number of dataset rows that will be exported to
# classification_dataset.jsonl.
# If -1, use all rows.
# max_dataset=100
max_dataset=-1
dataset_output_path = f"{get_dataset_dir()}/classification_dataset.jsonl"

# File names used later for the exported ONNX model and its label metadata
output_onnx_name = "ml_enhanced_actions.onnx"
output_labels_name = "ml_enhanced_actions_labels.json"

# Load data
# To use existing dataset, use dataset_dir param
# The second return value is bound to `dataset_dir` rather than `dir`, so the
# builtin dir() is not shadowed for the rest of the session.
df, dataset_dir = get_dataset()

if max_dataset>-1:
    # Clamp the request to the rows actually available: sampling more rows than
    # exist (without replacement) raises in pandas.
    # NOTE(review): assumes get_dataset() returns a pandas DataFrame - confirm.
    df = df.sample(min(max_dataset, len(df)))

# Write the (optionally down-sampled) rows as state/action JSONL records
export_dataset(
    df,
    dataset_output_path,
    format="jsonl_state_action",
    completion_mode="short",
    include_pos_rot=False
)

print(f"Saved {len(df)} samples to {dataset_output_path}")

# Prepare Dataset

In [None]:
import sys, os
sys.path.append(os.path.abspath(".."))

import numpy as np
import pandas as pd
import json
import re

print(f"Loading data from: {dataset_output_path}")

def parse_state(state_str):
    """Parse a state string into a dict mapping feature name to float value.

    Every ``name=number`` pair found in ``state_str`` is extracted; any other
    text (separators, trailing punctuation) is ignored. Numbers may carry a
    sign, a decimal point and an exponent, so a value such as ``1e-05`` is read
    in full instead of being cut off at the ``e``.

    Args:
        state_str: Text such as
            "AngleToEnemy=0.00, AngleToEnemyScore=1.00, FacingToArena=-1.00."

    Returns:
        dict of feature name -> float. Empty when no pair is found; if a name
        occurs more than once, the last occurrence wins.
    """
    # mantissa: "1", "1.5", "1." or ".5"; optional exponent: "e-05", "E3"
    pattern = r"(\w+)=([-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?)"
    return {name: float(value) for name, value in re.findall(pattern, state_str)}

def parse_action_enhanced(action_str):
    """Split an action string into skill flag, dash flag, movement and duration.

    The string is a comma-separated list of tokens, e.g. "SK, DS, FWD0.11",
    "TL0.85" or "SK". Tokens: SK=Skill, DS=Dash, FWD=Forward/Accelerate,
    TL=TurnLeft, TR=TurnRight. A movement token may be followed directly by
    its duration.

    Returns:
        (has_skill, has_dash, movement, duration) where the flags are 0/1 ints,
        movement is 'FWD', 'TL', 'TR' or the string "None", and duration is a
        float (0.0 when the movement token carries no number or there is no
        movement at all).

    Raises:
        ValueError: if the text after a movement prefix is not a valid float.
    """
    tokens = [piece.strip() for piece in action_str.split(',')]

    # Instant actions: flagged when any token contains the marker
    has_skill = int(any('SK' in token for token in tokens))
    has_dash = int(any('DS' in token for token in tokens))

    # Only the first token that names a movement is used
    for token in tokens:
        for prefix in ('FWD', 'TL', 'TR'):
            if token.startswith(prefix):
                amount = token[len(prefix):]
                return has_skill, has_dash, prefix, float(amount) if amount else 0.0

    # No movement token found
    return has_skill, has_dash, "None", 0.0

# Load and parse JSONL: one record per line with 'state' and 'action' fields
data = []
with open(dataset_output_path, 'r') as f:
    for line in f:
        # A positive max_dataset caps the number of parsed rows
        if max_dataset > 0 and len(data) >= max_dataset:
            break
        # Skip blank lines so a stray empty line does not crash json.loads
        if not line.strip():
            continue
        record = json.loads(line)
        state_features = parse_state(record['state'])
        has_skill, has_dash, movement, duration = parse_action_enhanced(record['action'])

        # One flat row: parsed state features followed by the four targets
        data.append({
            **state_features,
            'has_skill': has_skill,
            'has_dash': has_dash,
            'movement': movement,
            'duration': duration
        })

df = pd.DataFrame(data)
if df.empty:
    # Fail with a clear message instead of a KeyError on the column lookups below
    raise ValueError(f"No samples loaded from {dataset_output_path}")

print(f"Loaded {len(df)} samples")
print("\nSkill distribution:")
print(df['has_skill'].value_counts())
print("\nDash distribution:")
print(df['has_dash'].value_counts())
print("\nMovement distribution:")
print(df['movement'].value_counts())
print(f"\nDataFrame shape: {df.shape}")
print("\nSample:")
print(df.head())

# Training

In [None]:
import json
import tf2onnx

from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn.impute import SimpleImputer
from sklearn.metrics import classification_report

import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Dense, BatchNormalization, Dropout
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.losses import CategoricalCrossentropy

# State columns used as model inputs (this order is also saved in the metadata)
features = [
    "AngleToEnemy",
    "AngleToEnemyScore",
    "DistanceToEnemyScore",
    "NearBorderArenaScore",
    "FacingToArena",
]

# Replace missing feature values with the column mean
imputer = SimpleImputer(strategy="mean")
X = imputer.fit_transform(df[features].values)

# Targets for the 4 outputs: two binary flags and a duration as float32 ...
y_skill = df["has_skill"].values.astype("float32")
y_dash = df["has_dash"].values.astype("float32")
y_duration = df["duration"].values.astype("float32")

# ... and the movement label, integer-encoded then one-hot encoded
le_movement = LabelEncoder()
y_movement_encoded = le_movement.fit_transform(df["movement"])
y_movement_cat = to_categorical(y_movement_encoded)

# 80/20 split with a fixed seed; all arrays are split with the same row indices
(
    X_train, X_test,
    y_skill_train, y_skill_test,
    y_dash_train, y_dash_test,
    y_movement_train, y_movement_test,
    y_duration_train, y_duration_test,
) = train_test_split(
    X,
    y_skill,
    y_dash,
    y_movement_cat,
    y_duration,
    test_size=0.2,
    random_state=42,
)

print(f"Training samples: {len(X_train)}")
print(f"Test samples: {len(X_test)}")
print(f"Movement classes: {le_movement.classes_}")
print(f"Number of movement classes: {y_movement_cat.shape[1]}")

# Build model with 4 outputs: one shared dense trunk feeding four heads
# (skill, dash, movement, duration).
# The input width follows the number of feature columns in X; the input is
# named "input", the same name used for the ONNX input signature at export.
inputs = Input(shape=(X.shape[1], ), name="input")

# Shared layers
# Disabled variant of the trunk, kept for reference: it adds a second
# BatchNormalization and two Dropout layers.
# x = Dense(256, activation='relu')(inputs)
# x = BatchNormalization()(x)
# x = Dropout(0.3)(x)
# x = Dense(128, activation='relu')(x)
# x = BatchNormalization()(x)
# x = Dropout(0.2)(x)
# x = Dense(64, activation='relu')(x)
# x = Dense(32, activation='relu')(x)

# Active trunk: 256 -> 128 -> 64 -> 32 ReLU units, with batch normalization
# after the first layer only and no dropout.
x = Dense(256, activation='relu')(inputs)
x = BatchNormalization()(x)
x = Dense(128, activation='relu')(x)
x = Dense(64, activation='relu')(x)
x = Dense(32, activation='relu')(x)

# Output 1: Skill (binary) - single sigmoid unit
output_skill = Dense(1, activation='sigmoid', name="skill")(x)

# Output 2: Dash (binary) - single sigmoid unit
output_dash = Dense(1, activation='sigmoid', name="dash")(x)

# Output 3: Movement (multi-class) - one softmax unit per one-hot column
output_movement = Dense(y_movement_cat.shape[1], activation='softmax', name="movement")(x)

# Output 4: Duration (regression) - unbounded linear unit
output_duration = Dense(1, activation='linear', name="duration")(x)

# Compile model
# The loss and metric dictionaries are keyed by the head layer names above.
# No loss_weights are passed, so Keras applies its default weighting when
# combining the four head losses.
model = Model(inputs=inputs, outputs=[output_skill, output_dash, output_movement, output_duration])
model.compile(
    optimizer=Adam(learning_rate=0.0001),
    loss={
        "skill": "binary_crossentropy",
        "dash": "binary_crossentropy",
        # Label smoothing of 0.1 on the one-hot movement targets
        "movement": CategoricalCrossentropy(label_smoothing=0.1),
        "duration": "mae"
    },
    metrics={
        'skill': 'accuracy',
        'dash': 'accuracy',
        'movement': 'accuracy',
        'duration': 'mae'
    }
)

model.summary()

# Early stopping
# Stops training once the combined validation loss has not improved by at
# least min_delta for `patience` consecutive epochs, then restores the weights
# from the best epoch.
early_stop = EarlyStopping(
    monitor='val_loss',
    patience=10,
    min_delta=0.001,
    mode='min',
    restore_best_weights=True,
    verbose=1
)

# Train
# Targets are passed as dictionaries keyed by the output layer names.
# NOTE(review): validation_data is the test split (X_test / y_*_test), and the
# same split is used for the evaluation after training. Early stopping
# therefore selects weights on the data the final metrics are reported on;
# consider a separate validation split.
history = model.fit(
    X_train, 
    {
        "skill": y_skill_train, 
        "dash": y_dash_train,
        "movement": y_movement_train,
        "duration": y_duration_train
    },
    validation_data=(
        X_test, 
        {
            'skill': y_skill_test,
            'dash': y_dash_test,
            'movement': y_movement_test,
            'duration': y_duration_test
        }
    ),
    # Upper bound on epochs; early stopping may end training sooner
    epochs=100,
    batch_size=512,
    callbacks=[early_stop],
)

# Predict: one array per head, in the order the heads were passed to Model()
pred_skill, pred_dash, pred_movement_prob, pred_duration = model.predict(X_test)

# Convert predictions: threshold the sigmoid heads at 0.5, take the argmax of
# the softmax head, and recover the true class index from the one-hot targets
pred_skill_binary = (pred_skill > 0.5).astype(int).flatten()
pred_dash_binary = (pred_dash > 0.5).astype(int).flatten()
pred_movement = np.argmax(pred_movement_prob, axis=1)
true_movement = np.argmax(y_movement_test, axis=1)

# Evaluation
print("\n=== Skill Classification ===")
print(f"Accuracy: {(pred_skill_binary == y_skill_test).mean():.4f}")
print(f"Skill usage - True: {y_skill_test.sum()}, Predicted: {pred_skill_binary.sum()}")

print("\n=== Dash Classification ===")
print(f"Accuracy: {(pred_dash_binary == y_dash_test).mean():.4f}")
print(f"Dash usage - True: {y_dash_test.sum()}, Predicted: {pred_dash_binary.sum()}")

print("\n=== Movement Classification Report ===")
# Pass the full label set explicitly. Without `labels`, the report infers the
# classes from the test split, and raises if a movement class is absent from
# both the true and predicted test labels because the count no longer matches
# target_names. zero_division=0 keeps the reported 0.0 for such classes
# without emitting warnings.
movement_labels = np.arange(len(le_movement.classes_))
print(classification_report(
    true_movement,
    pred_movement,
    labels=movement_labels,
    target_names=le_movement.classes_,
    zero_division=0,
))

print("\n=== Duration MAE ===")
print(f"MAE: {np.mean(np.abs(pred_duration.flatten() - y_duration_test)):.4f}")

# Convert the Keras model to ONNX (opset 13) with a dynamic batch dimension
n_features = X.shape[1]
spec = (tf.TensorSpec((None, n_features), tf.float32, name="input"),)
onnx_model, _ = tf2onnx.convert.from_keras(model, input_signature=spec, opset=13)

# Write the serialized model to disk
serialized_model = onnx_model.SerializeToString()
with open(output_onnx_name, "wb") as onnx_file:
    onnx_file.write(serialized_model)

print(f"\nModel saved to {output_onnx_name}")

# Metadata needed to decode predictions: movement label order, input feature
# order, and the names of the output heads
metadata = dict(
    movement_classes=le_movement.classes_.tolist(),
    input_features=features,
    outputs=["skill", "dash", "movement", "duration"],
)

with open(output_labels_name, "w") as labels_file:
    json.dump(metadata, labels_file, indent=2)

print(f"Metadata saved to {output_labels_name}")
print(f"Movement classes: {metadata['movement_classes']}")

# Testing

In [None]:
import onnxruntime as ort
import numpy as np
import json

# Open the exported ONNX model for inference
session = ort.InferenceSession(output_onnx_name)

# Read back the metadata JSON written next to the model
with open(output_labels_name, 'r') as metadata_file:
    metadata = json.loads(metadata_file.read())

# Index i of the movement output corresponds to movement_classes[i]
movement_classes = metadata['movement_classes']
print(f"Movement classes: {movement_classes}")
print(f"Input features: {metadata['input_features']}")
print(f"Model outputs: {metadata['outputs']}")

# Sample input - features: [AngleToEnemy, AngleToEnemyScore, DistanceToEnemyScore, NearBorderArenaScore, FacingToArena]
sample = np.array([[
    0.0,    # AngleToEnemy
    1.0,    # AngleToEnemyScore (perfect alignment)
    0.9,    # DistanceToEnemyScore (very close)
    0.5,    # NearBorderArenaScore (mid arena)
    -0.5    # FacingToArena (facing inward)
]], dtype=np.float32)

# Get input & output names
input_name = session.get_inputs()[0].name
output_names = [o.name for o in session.get_outputs()]

print(f"\nONNX Input: {input_name}")
print(f"ONNX Outputs: {output_names}")

# Names given to the four model heads at training time
HEAD_NAMES = ("skill", "dash", "movement", "duration")


def unpack_outputs(outputs):
    """Return (skill_prob, dash_prob, movement_probs, duration) for batch row 0.

    `outputs` is the list returned by session.run(output_names, ...). When all
    four head names appear among the session's output names, each head is
    looked up by name, so the result does not depend on the order in which the
    converter emitted the outputs. Otherwise the positional order
    skill, dash, movement, duration is assumed.
    """
    if all(name in output_names for name in HEAD_NAMES):
        by_name = dict(zip(output_names, outputs))
        skill, dash, movement, duration = (by_name[name] for name in HEAD_NAMES)
    else:
        skill, dash, movement, duration = outputs[0], outputs[1], outputs[2], outputs[3]
    return skill[0][0], dash[0][0], movement[0], duration[0][0]


def build_action_parts(use_skill, use_dash, movement_action, duration):
    """Return the action tokens (e.g. ["SK", "FWD0.11"]) for one prediction.

    "SK" / "DS" are added when the corresponding flag is set; a movement token
    with the duration formatted to two decimals is added unless the movement
    class is the string "None".
    """
    parts = []
    if use_skill:
        parts.append("SK")
    if use_dash:
        parts.append("DS")
    if movement_action != "None":
        # NOTE(review): the duration head is linear, so this value is not
        # constrained to be non-negative - confirm the consumer tolerates that.
        parts.append(f"{movement_action}{duration:.2f}")
    return parts


# Run inference
outputs = session.run(output_names, {input_name: sample})

# Unpack outputs (skill, dash, movement, duration)
pred_skill_prob, pred_dash_prob, pred_movement_probs, pred_duration = unpack_outputs(outputs)

# Decode predictions
use_skill = pred_skill_prob > 0.5
use_dash = pred_dash_prob > 0.5
movement_index = np.argmax(pred_movement_probs)
movement_action = movement_classes[movement_index]

print("\n=== Inference Results ===")
print(f"Skill: {'Yes' if use_skill else 'No'} (confidence: {pred_skill_prob:.4f})")
print(f"Dash: {'Yes' if use_dash else 'No'} (confidence: {pred_dash_prob:.4f})")
print(f"Movement: {movement_action}")
print("  Movement probabilities:")
for i, cls in enumerate(movement_classes):
    print(f"    {cls}: {pred_movement_probs[i]:.4f}")
print(f"Duration: {pred_duration:.4f} seconds")

# Construct final action string ("Idle" when no token applies)
final_actions = build_action_parts(use_skill, use_dash, movement_action, pred_duration)
final_action_str = ", ".join(final_actions) if final_actions else "Idle"
print(f"\nFinal Action String: {final_action_str}")

# Test with different scenarios
print("\n=== Testing Multiple Scenarios ===")

# Each state lists the features in the same order as the sample above
test_scenarios = [
    {
        "name": "Perfect attack position",
        "state": [0.0, 1.0, 0.9, 0.3, -0.8]
    },
    {
        "name": "Need to turn left",
        "state": [45.0, 0.7, 0.5, 0.4, -0.5]
    },
    {
        "name": "Near border, turn right",
        "state": [-30.0, 0.8, 0.6, 0.8, 0.6]
    },
    {
        "name": "Enemy far away",
        "state": [0.0, 1.0, 0.2, 0.3, -0.9]
    }
]

for scenario in test_scenarios:
    sample = np.array([scenario["state"]], dtype=np.float32)
    outputs = session.run(output_names, {input_name: sample})

    pred_skill_prob, pred_dash_prob, pred_movement_probs, pred_duration = unpack_outputs(outputs)

    use_skill = pred_skill_prob > 0.5
    use_dash = pred_dash_prob > 0.5
    movement_index = np.argmax(pred_movement_probs)
    movement_action = movement_classes[movement_index]

    final_actions = build_action_parts(use_skill, use_dash, movement_action, pred_duration)
    final_action_str = ", ".join(final_actions) if final_actions else "Idle"

    print(f"\n{scenario['name']}:")
    print(f"  State: {scenario['state']}")
    print(f"  Action: {final_action_str}")