# Configurations

In [None]:
import sys, os
sys.path.append(os.path.abspath(".."))

import torch
from dataset_helper import export_dataset, get_dataset, find_project_root, get_dataset_dir

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Number of dataset rows to randomly sample and export to the training text file
# (dataset_output_path below). If -1, use all rows.
max_dataset = 100
# max_dataset = -1

# Model parameters
block_size = 128  # context length in tokens; also the size of the positional embedding
batch_size = 256  # sequences per training step
n_embed = 192  # embedding / residual width
n_heads = 4  # attention heads per block; each head gets n_embed // n_heads dims
n_layers = 3  # number of transformer blocks
lr = 1e-4  # AdamW learning rate
max_iters = 5000  # number of training steps
eval_interval = 100  # print train/val loss every this many steps

# Output locations
onnx_output_name = "slm"  # written as slm.onnx (and slm_fp16.onnx)
dataset_output_path = f"{get_dataset_dir()}/slm_dataset.txt"
tokenizer_output_path = f"{find_project_root()}/Resources/ML/Tokenizer/action_tokenizer.json"

# Load Data

In [None]:
# Load data

# To use an existing dataset, use the dataset_dir param of get_dataset.
# The second return value is presumably the dataset directory; it is named
# dataset_dir rather than `dir` so the builtin dir() is not shadowed.
df, dataset_dir = get_dataset(prefer_local=False)

if max_dataset > -1:
    # DataFrame.sample raises ValueError when asked for more rows than exist
    # (replace=False), so cap the requested sample size at len(df).
    df = df.sample(min(max_dataset, len(df)))

export_dataset(df, dataset_output_path, format="txt", completion_mode="short")

print(f"Saved {len(df)} samples to {dataset_output_path}")

# Tokenization

In [None]:
from tokenizers import Tokenizer, models, pre_tokenizers, decoders, trainers

# Tokenization
# Byte-level BPE: byte-level pre-tokenizer and matching byte-level decoder.
tokenizer = Tokenizer(models.BPE())
tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=True)
tokenizer.decoder = decoders.ByteLevel()
# The trainer gives special tokens the first ids in list order, so <PAD> gets id 0.
specials = ["<PAD>"]
# Tokenization - Train
trainer = trainers.BpeTrainer(special_tokens=specials)
tokenizer.train([dataset_output_path], trainer)
tokenizer.add_special_tokens(specials)
# Tokenization - Save to file (create the target directory if it does not exist yet)
os.makedirs(os.path.dirname(tokenizer_output_path) or ".", exist_ok=True)
tokenizer.save(tokenizer_output_path)
print(f"✅ Tokenizer saved to {tokenizer_output_path}")


# Load Tokenizer
# NOTE(review): intentionally left open — the batching code iterates and rewinds
# this handle throughout training; it is never closed in this notebook.
data_file = open(dataset_output_path, "r", encoding="utf-8")

# Reload from disk so training uses exactly the serialized tokenizer.
tokenizer = Tokenizer.from_file(tokenizer_output_path)
vocab_size = tokenizer.get_vocab_size()

# Prepare Model

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
from tokenizers import Tokenizer

def line_token_stream(file):
    """Lazily tokenize ``file`` one line at a time.

    Yields, for every line read from ``file`` (trailing newline included when
    present), the list of BPE token ids produced by the module-level
    ``tokenizer``.
    """
    yield from (tokenizer.encode(line).ids for line in file)

token_stream = line_token_stream(data_file)

# Id used to right-pad short lines. <PAD> is the first special token given to
# the BPE trainer, so this is normally 0; fall back to 0 if it is missing.
pad_id = tokenizer.token_to_id("<PAD>")
if pad_id is None:
    pad_id = 0


def get_batch_linear():
    """Return the next ``batch_size`` dataset lines as an (inputs, targets) pair.

    Lines are consumed sequentially from ``data_file``. When the file is
    exhausted it is rewound and reading restarts from the first line. Each line
    is truncated or right-padded with ``pad_id`` to ``block_size + 1`` tokens,
    then shifted by one so ``y[i, t]`` is the token that follows ``x[i, t]``.

    Returns:
        (x, y): two ``(batch_size, block_size)`` long tensors on ``device``.

    Raises:
        RuntimeError: if ``data_file`` yields no lines at all.
    """
    global token_stream
    x_batch, y_batch = [], []
    window = block_size + 1

    while len(x_batch) < batch_size:
        try:
            tokens = next(token_stream)
        except StopIteration:
            # Restart from beginning
            data_file.seek(0)
            token_stream = line_token_stream(data_file)
            try:
                tokens = next(token_stream)
            except StopIteration:
                # An empty file would otherwise leak a bare StopIteration.
                raise RuntimeError(f"No lines to read from {data_file.name!r}") from None

        # Truncate long lines and right-pad short ones to exactly `window` tokens.
        tokens = tokens[:window] + [pad_id] * max(0, window - len(tokens))

        x_batch.append(tokens[:-1])
        y_batch.append(tokens[1:])

    return (
        torch.tensor(x_batch, dtype=torch.long, device=device),
        torch.tensor(y_batch, dtype=torch.long, device=device)
    )

# === Model ===
class Head(nn.Module):
    """One causal (masked) self-attention head.

    Projects inputs of width ``n_embed`` to key/query/value vectors of size
    ``head_size``; position t may only attend to positions <= t.
    """

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embed, head_size, bias=False)
        self.query = nn.Linear(n_embed, head_size, bias=False)
        self.value = nn.Linear(n_embed, head_size, bias=False)
        # Lower-triangular causal mask; a buffer so it moves with .to(device) and
        # is stored in the state_dict under "tril".
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

    def forward(self, x):
        """Attend over ``x`` of shape (B, T, n_embed) with T <= block_size.

        Returns a (B, T, head_size) tensor.
        """
        _, T, _ = x.shape
        k = self.key(x)
        q = self.query(x)
        # Scaled dot-product attention divides by sqrt(d_k), the key/query size
        # (head_size) — not by the model width of the input.
        wei = q @ k.transpose(-2, -1) / (k.shape[-1] ** 0.5)
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1)
        v = self.value(x)
        return wei @ v

class MultiHeadAttention(nn.Module):
    """Run ``num_heads`` causal attention heads in parallel and project to ``n_embed``."""

    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        # The concatenated head outputs are num_heads * head_size wide. That only
        # equals n_embed when n_embed is divisible by the head count, so size the
        # projection from the actual concat width (identical in the divisible case).
        self.proj = nn.Linear(num_heads * head_size, n_embed)

    def forward(self, x):
        # Concatenate per-head outputs along the feature axis, then mix them.
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        return self.proj(out)

class FeedForward(nn.Module):
    """Position-wise MLP: Linear(n_embed -> 4*n_embed), ReLU, Linear back to n_embed."""

    def __init__(self, n_embed):
        super().__init__()
        hidden = 4 * n_embed
        layers = [
            nn.Linear(n_embed, hidden),
            nn.ReLU(),
            nn.Linear(hidden, n_embed),
        ]
        # Kept as `net` (a Sequential) so state_dict keys stay net.0.* / net.2.*.
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)

class Block(nn.Module):
    """Pre-LayerNorm transformer block.

    Applies causal multi-head self-attention and then the feed-forward MLP, each
    on a LayerNorm-ed input and added back to the residual stream.
    """

    def __init__(self, n_embed, n_heads):
        super().__init__()
        # Submodules are created in the same order as before so parameter
        # initialisation and state_dict layout are unchanged.
        self.sa = MultiHeadAttention(n_heads, n_embed // n_heads)
        self.ffwd = FeedForward(n_embed)
        self.ln1 = nn.LayerNorm(n_embed)
        self.ln2 = nn.LayerNorm(n_embed)

    def forward(self, x):
        attn_out = self.sa(self.ln1(x))
        x = x + attn_out
        mlp_out = self.ffwd(self.ln2(x))
        return x + mlp_out

class GPT(nn.Module):
    """Decoder-only transformer language model.

    Sizes come from the module-level config: ``vocab_size``, ``n_embed``,
    ``block_size``, ``n_heads`` and ``n_layers``.
    """

    def __init__(self):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, n_embed)
        self.pos_embedding = nn.Embedding(block_size, n_embed)
        self.blocks = nn.Sequential(*[Block(n_embed, n_heads) for _ in range(n_layers)])
        self.ln_f = nn.LayerNorm(n_embed)
        self.head = nn.Linear(n_embed, vocab_size)

    def forward(self, idx, targets=None):
        """Compute next-token logits (and optionally the loss).

        Args:
            idx: (B, T) long tensor of token ids, T <= block_size (the
                positional embedding has block_size rows).
            targets: optional (B, T) long tensor of next-token ids.

        Returns:
            (logits, loss): logits is (B, T, vocab_size); loss is the mean
            cross-entropy over all positions, or None when targets is None.
        """
        _, T = idx.shape
        tok_emb = self.token_embedding(idx)
        # Create positions on the input's device rather than the global `device`,
        # so the model also works after being moved (e.g. to CPU).
        pos_emb = self.pos_embedding(torch.arange(T, device=idx.device))
        x = tok_emb + pos_emb
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.head(x)

        if targets is None:
            return logits, None

        # reshape (not view) also accepts non-contiguous tensors.
        loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens=50):
        """Append ``max_new_tokens`` sampled tokens to ``idx`` (B, T) and return it.

        Each step conditions on at most the last ``block_size`` tokens and samples
        from the softmax of the last position's logits. Runs without autograd
        tracking since it is inference only.
        """
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -block_size:]
            logits, _ = self(idx_cond)
            probs = F.softmax(logits[:, -1, :], dim=-1)
            next_idx = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, next_idx), dim=1)
        return idx

# Training and export to ONNX

In [None]:
# Training
model = GPT().to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

for step in tqdm(range(max_iters)):
    xb, yb = get_batch_linear()  # sequential batching
    logits, loss = model(xb, yb)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if step % eval_interval == 0:
        # NOTE: this "validation" batch comes from the same sequential stream as
        # the training batches (and advances it), so it is not a held-out split.
        # Evaluate in eval mode and without gradient tracking; nothing is backpropagated.
        model.eval()
        with torch.no_grad():
            val_x, val_y = get_batch_linear()
            _, val_loss = model(val_x, val_y)
        model.train()
        print(f"Step {step}: train loss {loss.item():.4f}, val loss {val_loss.item():.4f}")

# Generation
# Encode example prompt
context_ids = tokenizer.encode("BotPos=[2.23,2.25], BotRot=228, EnemyPos=[2.87,0.39], EnemyRot=87, AngleToEnemy=-29.68, AngleToEnemyScore=0.87, DistanceToEnemyScore=0.79, NearBorderArenaScore=0.42, FacingToArena=0.65.").ids
context = torch.tensor(context_ids, dtype=torch.long, device=device).unsqueeze(0)  # (1, T)

# Generate (sampled continuation) in eval mode without building an autograd graph.
model.eval()
with torch.no_grad():
    output_ids = model.generate(context, max_new_tokens=20)[0].tolist()

# Decode generated IDs back to text
output_text = tokenizer.decode(output_ids)
print(output_text)

model.eval()

# Dummy input for tracing: batch=1, 8 tokens. Batch and sequence axes are
# declared dynamic below, so the exported graph accepts other sizes too
# (the sequence length must stay <= block_size for the positional embedding).
dummy_input = torch.randint(0, vocab_size, (1, 8), dtype=torch.long, device=device)

torch.onnx.export(
    model,
    dummy_input,
    f"{onnx_output_name}.onnx",
    input_names=["input_ids"],
    output_names=["logits"],
    dynamic_axes={
        "input_ids": {0: "batch_size", 1: "sequence_length"},
        "logits": {0: "batch_size", 1: "sequence_length"}
    },
    opset_version=13
)

print(f"✅ Exported GPT model to {onnx_output_name}.onnx")

# Convert the exported model's float weights to float16 - OPTIONAL.
# (Half-precision conversion, not integer quantization.)
import onnx
from onnxconverter_common import float16

# Use a separate name so the trained PyTorch `model` is not overwritten.
onnx_model = onnx.load(f"{onnx_output_name}.onnx")
fp16_model = float16.convert_float_to_float16(onnx_model)
onnx.save(fp16_model, f"{onnx_output_name}_fp16.onnx")

# Testing

In [None]:
from time import sleep
from tokenizers import Tokenizer
import numpy as np
import onnxruntime as ort

# Re-read the trained BPE tokenizer from disk instead of reusing the in-memory one
tokenizer = Tokenizer.from_file(tokenizer_output_path)
vocab_size = tokenizer.get_vocab_size()

# Open the exported ONNX graph; inference is restricted to the CPU provider
session = ort.InferenceSession(
    f"{onnx_output_name}.onnx",
    providers=["CPUExecutionProvider"],
)

def generate_onnx(prompt, max_new_tokens=20, block_size=128):
    """Greedily extend ``prompt`` using the exported ONNX model in ``session``.

    Each step feeds the last ``block_size`` token ids to the model, takes the
    arg-max next token, appends it and echoes it as ``[token]``. Generation
    stops after ``max_new_tokens`` steps, or earlier when:
      * the model predicts the ``<PAD>`` id (not kept),
      * the predicted token is exactly a newline (not kept),
      * the predicted token contains a newline merged with other text (kept).

    Returns:
        The decoded prompt followed by the generated continuation.
    """
    # Encode with BPE tokenizer
    input_ids = tokenizer.encode(prompt).ids
    # Training batches are right-padded and the loss includes padded targets,
    # so the model can predict padding once a line is finished.
    pad_id = tokenizer.token_to_id("<PAD>")

    for _ in range(max_new_tokens):
        # Keep only the last block_size tokens (positional embedding limit)
        input_array = np.array([input_ids[-block_size:]], dtype=np.int64)

        # Run inference
        logits = session.run(None, {"input_ids": input_array})[0]  # shape: (1, seq_len, vocab_size)

        # Greedy pick from the last position's logits
        next_token_id = int(np.argmax(logits[0, -1]))
        pred = tokenizer.decode([next_token_id])

        # End of line: padding or a bare newline token
        if next_token_id == pad_id or pred == "\n":
            break

        input_ids.append(next_token_id)
        print(f"[{pred}]", end="", flush=True)

        # BPE may merge the newline with preceding text (e.g. "]\n"); the token
        # is kept, but the line is over.
        if "\n" in pred:
            break

    # Decode IDs back to string
    return tokenizer.decode(input_ids)

# Test run
prompt = (
    "BotPos=[2.23,2.25], BotRot=228, EnemyPos=[2.87,0.39], EnemyRot=87, AngleToEnemy=-29.68, AngleToEnemyScore=0.87, DistanceToEnemyScore=0.79, NearBorderArenaScore=0.42, FacingToArena=0.65. Suggested Action:"
)

output = generate_onnx(prompt, max_new_tokens=300)
print(f"\n\n🧠 Output {len(output)}:", output)