In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer
from datasets import load_dataset, Dataset
import numpy as np

import pandas as pd

In [None]:
# Checkpoint to fine-tune and its matching tokenizer.
model_name = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Three output classes, one per label used in `target_mapping` ('l', 'w', 'd').
num_labels = 3
model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    num_labels=num_labels
)
# Use the GPU when one is present, but fall back to CPU instead of raising on
# machines without CUDA (same device policy as `evaluate_model`).
model = model.to("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
from quixo_env import *

In [None]:
# Load the train and test splits, in that order, from the working directory.
train_df, test_df = (
    pd.read_csv(csv_path)
    for csv_path in ('./quixo_statuses_train.csv', './quixo_statuses_test.csv')
)

In [None]:
# Inspect the training frame. NOTE(review): the `Unnamed: 0*` columns in the output
# look like index columns left over from earlier CSV round-trips — confirm.
train_df

Unnamed: 0.1,Unnamed: 0,prefix,source_state_depth,target
0,0,X O # #\nO O # X\n# # X X\nO O # #\n\nX\n,-1,d
1,1,X # # #\nO # O #\nO # # O\nO X # X\n\nX\n,10,l
2,2,# # # O\n# O # #\nO # X X\n# X X O\n\nX\n,5,w
3,3,# X X O\nO O # O\nO # X X\n# O O O\n\nX\n,2,l
4,4,# O X O\n# O X O\nO X X X\nO X X O\n\nO\n,0,l
...,...,...,...,...
299995,299995,O X # #\n# X # O\nO # # #\nX O X #\n\nO\n,13,w
299996,299996,O O X #\nO O O X\nO # X O\n# X O O\n\nX\n,2,l
299997,299997,X O X X\nO X X #\nX X # #\n# O O X\n\nO\n,2,l
299998,299998,X X O O\nX # O X\n# X # #\nX X # #\n\nO\n,2,l


In [None]:
# Wrap both pandas frames as `datasets.Dataset` objects (train first, then test).
train_dataset, test_dataset = map(Dataset.from_pandas, (train_df, test_df))

In [None]:
# Integer class id for each label letter found in the `target` column.
target_mapping = dict(l=0, w=1, d=2)


def preprocess_function(examples):
    """Tokenize a batch of examples and attach integer class labels.

    The `prefix` strings are tokenized and truncated/padded to exactly 128 tokens;
    each `target` letter is translated through `target_mapping` and stored
    under the `labels` key of the returned encoding.
    """
    encoded_batch = tokenizer(
        examples["prefix"],
        padding="max_length",
        truncation=True,
        max_length=128,
    )
    encoded_batch["labels"] = [target_mapping[letter] for letter in examples["target"]]
    return encoded_batch

# Apply the preprocessing to both splits in batched mode (train first, then test).
tokenized_train_dataset, tokenized_test_dataset = (
    split.map(preprocess_function, batched=True)
    for split in (train_dataset, test_dataset)
)

Map:   0%|          | 0/300000 [00:00<?, ? examples/s]

Map:   0%|          | 0/15000 [00:00<?, ? examples/s]

In [None]:
class LLMClassifierPlayer():
    """Player that chooses moves with a sequence-classification model.

    Every candidate next state is scored by the model, and the action leading to
    the state with the highest probability for class `negative_class_idx` is played.
    """

    # Class whose probability is maximised over the candidate states; 0 is the id
    # that `target_mapping` assigns to label 'l'.
    # NOTE(review): presumably the next state is judged from the opponent's side,
    # so a likely 'l' there is good for this player — confirm against quixo_env.
    negative_class_idx = 0

    def __init__(self, model, tokenizer):
        """Store the classifier and the tokenizer used to encode state strings."""
        self.model = model
        self.tokenizer = tokenizer

    def predict_classes(self, model, tokenizer, states, max_length=128):
        """Return the index of the state with the highest negative-class probability.

        Args:
            model: Classifier module exposing `.device`; called with the tokenizer
                output as keyword arguments and expected to return an object with
                a `.logits` tensor holding one row per state.
            tokenizer: Callable that encodes a list of strings into a mapping of
                tensors (invoked with `return_tensors="pt"`).
            states: List of state strings to score.
            max_length: Truncation length forwarded to the tokenizer.

        Returns:
            A 0-d NumPy integer array holding the position in `states` whose
            softmax probability for `negative_class_idx` is largest.
        """
        encodings = tokenizer(
            states,
            truncation=True,
            padding=True,
            max_length=max_length,
            return_tensors="pt"
        )

        device = model.device
        encodings = {k: v.to(device) for k, v in encodings.items()}

        # Score in eval mode so dropout cannot randomise the chosen move, then put
        # the module back into whatever mode the caller had it in.
        was_training = model.training
        model.eval()
        try:
            with torch.no_grad():
                outputs = model(**encodings)
                scores = torch.nn.functional.softmax(outputs.logits, dim=-1)[:, self.negative_class_idx]
        finally:
            model.train(was_training)

        return torch.argmax(scores).cpu().numpy()

    def get_action(self, env):
        """Pick the action whose resulting state scores highest.

        Args:
            env: Object whose `get_possible_next_states()` yields pairs indexed as
                `(state_text, action)`.

        Returns:
            The action (second element) of the best-scoring candidate.

        Raises:
            ValueError: If the environment offers no candidate states.
        """
        candidates = list(env.get_possible_next_states())
        if not candidates:
            raise ValueError("env returned no possible next states to choose from")
        action_idx = self.predict_classes(
            self.model, self.tokenizer, [candidate[0] for candidate in candidates]
        )
        return candidates[int(action_idx)][1]

In [None]:
training_args = TrainingArguments(
    # Output locations; reporting is set to 'none'.
    output_dir='./results',
    logging_dir='./logs',
    report_to='none',
    # Optimisation schedule: a single epoch with warmup and weight decay.
    num_train_epochs=1,
    warmup_steps=500,
    weight_decay=0.01,
    # Per-device batch sizes for training and evaluation.
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    # Log every 10 steps; evaluate and checkpoint once per epoch so that the best
    # checkpoint can be reloaded when training finishes.
    # NOTE(review): newer transformers releases rename `evaluation_strategy` to
    # `eval_strategy` — confirm against the installed version.
    logging_steps=10,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
)




In [None]:
# Fine-tune on the tokenized train split, evaluating on the tokenized test split.
trainer = Trainer(
    args=training_args,
    model=model,
    eval_dataset=tokenized_test_dataset,
    train_dataset=tokenized_train_dataset,
)
trainer.train()

In [None]:
from tqdm import tqdm

In [None]:
def _collate_tokenized(batch):
    """Stack a list of tokenized examples into one dict of batched tensors."""
    return {
        'input_ids': torch.stack([torch.tensor(item['input_ids']) for item in batch]),
        'attention_mask': torch.stack([torch.tensor(item['attention_mask']) for item in batch]),
        'labels': torch.tensor([item['labels'] for item in batch]),
    }


def evaluate_model(model, test_dataset, batch_size=64):
    """Run the classifier over a tokenized dataset and collect its predictions.

    The model is switched to eval mode and moved to CUDA when available
    (otherwise CPU); both changes persist after the call.

    Args:
        model: Classifier called with `input_ids` and `attention_mask`, returning
            an object with a `.logits` tensor.
        test_dataset: Indexable dataset whose items provide `input_ids`,
            `attention_mask` and `labels`.
        batch_size: Number of examples per forward pass (defaults to 64).

    Returns:
        A `(predictions, true_labels)` pair of equal-length lists holding the
        arg-max class id and the reference label for every example, in order.
    """
    model.eval()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    predictions = []
    true_labels = []

    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=batch_size,
        collate_fn=_collate_tokenized,
    )

    # No gradients are needed anywhere in the evaluation loop.
    with torch.no_grad():
        for batch in tqdm(test_loader):
            batch = {k: v.to(device) for k, v in batch.items()}

            # Labels are deliberately not passed: only the logits are needed.
            outputs = model(
                input_ids=batch['input_ids'],
                attention_mask=batch['attention_mask']
            )

            batch_predictions = torch.argmax(outputs.logits, dim=-1)
            predictions.extend(batch_predictions.cpu().numpy())
            true_labels.extend(batch['labels'].cpu().numpy())

    return predictions, true_labels

predictions, true_labels = evaluate_model(model, tokenized_test_dataset)


100%|██████████| 235/235 [00:25<00:00,  9.20it/s]


In [None]:
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix

# Fraction of test examples whose predicted class id equals the reference label.
accuracy = accuracy_score(y_true=true_labels, y_pred=predictions)
print(f"\nTest Accuracy: {accuracy:.4f}")


Test Accuracy: 0.8636


In [None]:
# Game environment; `QuixoEnv` is presumably provided by the `quixo_env` star import,
# and the meaning of its `None` argument is not visible here — TODO confirm.
env = QuixoEnv(None)
# player1 picks moves with the classifier model; player2 is a `RandomPlayer`
# (presumably also from `quixo_env`).
player1 = LLMClassifierPlayer(model, tokenizer)
player2 = RandomPlayer()

In [None]:
from collections import Counter

# Tally the outcomes of 500 games with the classifier player passed first.
results = Counter()
for _ in range(500):
    result = play_game(env, player1, player2)
    # A missing result is recorded under "#".
    outcome = str(result.value) if result is not None else "#"
    # Increment by key: Counter.update() on a string would count its individual
    # characters, which is only correct while every outcome is one character long.
    results[outcome] += 1

In [None]:
# Outcome counts for the run in which the classifier player was passed first to play_game.
results

Counter({'X': 487, 'O': 13})

In [None]:
from collections import Counter

# Tally the outcomes of 500 games with the random player passed first.
results = Counter()
for _ in range(500):
    result = play_game(env, player2, player1)
    # A missing result is recorded under "#".
    outcome = str(result.value) if result is not None else "#"
    # Increment by key: Counter.update() on a string would count its individual
    # characters, which is only correct while every outcome is one character long.
    results[outcome] += 1

In [None]:
# Outcome counts for the run in which the random player was passed first to play_game.
results

Counter({'O': 303, 'X': 197})