In [None]:
import copy
from typing import Tuple, Any

import numpy as np
import pandas
import pandas as pd
import torch.nn as nn
import torch
import torchmetrics
import torch.nn.functional as F
from lightning.pytorch.utilities.types import TRAIN_DATALOADERS, EVAL_DATALOADERS
from torch.utils.data import Dataset, DataLoader
import lightning as pl
from sklearn.model_selection import train_test_split


In [None]:
from jass.game.const import card_strings

In [None]:
class TrumpDataset(Dataset):
    """In-memory dataset of (features, trump label) rows that is fetched one whole batch at a time.

    ``X`` holds one feature row per sample and ``y`` the matching class index. Only the batched
    ``__getitems__`` path is supported: the DataLoader passes it the full list of sampled indices, so a
    batch is built with a single fancy-indexing operation instead of one ``__getitem__`` call per sample.
    The returned tensors are already stacked, so the DataLoader needs a pass-through ``collate_fn``.

    Raises:
        ValueError: if ``X`` and ``y`` do not have the same number of rows.
    """

    def __init__(self, X, y):
        # Explicit raise instead of ``assert`` so the check still runs under ``python -O``.
        if X.shape[0] != y.shape[0]:
            raise ValueError(f"X y dim mismatch: {X.shape[0]} rows in X vs {y.shape[0]} in y")
        self.X = X
        self.y = y

    def __getitem__(self, item):
        # Per-sample access is deliberately unsupported because it is much slower than __getitems__.
        # Raise instead of ``assert False``, which ``python -O`` would strip, silently returning None.
        raise NotImplementedError("Inefficient __getitem__ called, use __getitems__ with a list of indices")

    def __getitems__(self, items):
        """Return ``(features, labels)`` for all indices in ``items`` as stacked tensors.

        Features are float32 because the linear layers need float inputs. Labels are int64 class
        indices rather than one-hot vectors: CrossEntropyLoss accepts class indices directly, which
        is more efficient than passing class probabilities.
        """
        # torch.tensor always copies, so the returned batch never aliases the dataset arrays.
        features = torch.tensor(self.X[items], dtype=torch.float32)
        labels = torch.tensor(self.y[items], dtype=torch.long)
        return features, labels

    def __len__(self):
        return self.X.shape[0]

In [None]:
class TrumpDataModule(pl.LightningDataModule):
    """Lightning data module for the historical trump-selection CSV.

    The ``card_strings`` columns plus ``FH`` are used as features and ``trump`` as the label. The rows
    are split into train/val/test sets stratified on the (forehand, trump) combination.

    Args:
        csv_path: path of the CSV file to read.
        test_split: fraction of all rows held out for the test set.
        val_split: fraction of the *remaining* (non-test) rows used for validation, i.e. the
            validation set is ``(1 - test_split) * val_split`` of all rows.
        batch_size: batch size of every dataloader.
        num_workers: number of DataLoader worker processes.
    """

    def __init__(self, csv_path: str, test_split: float, val_split: float, batch_size: int, num_workers: int):
        super().__init__()
        self.csv_path = csv_path
        self.test_split = test_split
        self.val_split = val_split
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.data = None
        self.promising_users = None  # not used by this class
        self.features = None
        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None

    def setup(self, stage: str):
        """Read the CSV and build the train/val/test datasets (shared by all stages)."""
        # Lightning calls setup once per stage (fit, validate, test, ...). The split is fixed by
        # random_state=42, so reading and splitting the CSV again would only rebuild identical datasets.
        if self.train_dataset is not None:
            return

        self.features = np.append(card_strings, ['FH'])
        self.data = pd.read_csv(self.csv_path)

        X = self.data[self.features].to_numpy()
        fh = self.data['FH'].to_numpy()
        y = self.data['trump'].to_numpy()

        # Stratification is required (otherwise torch's random_split would do). fh * 10 + y is used as
        # a combined (forehand, trump) key, which is presumably unique as long as trump labels are < 10.
        X_train, X_test, y_train, y_test, fh_train, _ = train_test_split(
            X, y, fh, test_size=self.test_split, stratify=(fh * 10 + y), random_state=42)
        X_train, X_val, y_train, y_val = train_test_split(
            X_train, y_train, test_size=self.val_split, stratify=(fh_train * 10 + y_train), random_state=42)

        self.train_dataset = TrumpDataset(X_train, y_train)
        self.val_dataset = TrumpDataset(X_val, y_val)
        self.test_dataset = TrumpDataset(X_test, y_test)

    def _dataloader(self, dataset, shuffle: bool) -> DataLoader:
        """Build a DataLoader for ``dataset`` that uses the batch-level ``collate`` pass-through."""
        return DataLoader(dataset, batch_size=self.batch_size, num_workers=self.num_workers,
                          shuffle=shuffle, collate_fn=self.collate)

    def train_dataloader(self):
        return self._dataloader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        return self._dataloader(self.val_dataset, shuffle=False)

    def test_dataloader(self):
        return self._dataloader(self.test_dataset, shuffle=False)

    @staticmethod
    def collate(batch):
        """Pass through the already-batched ``(features, labels)`` tuple from ``__getitems__``.

        The default collate expects a list of per-sample tuples to stack, as produced by repeated
        ``__getitem__`` calls. ``TrumpDataset.__getitems__`` already returns stacked tensors, so no
        further processing is needed.
        See https://pytorch.org/docs/stable/data.html#torch.utils.data._utils.collate.collate

        Raises:
            TypeError: if ``batch`` is not a 2-tuple of tensors.
        """
        # Explicit raise (not assert) so the sanity check survives ``python -O``.
        if not (isinstance(batch, tuple) and len(batch) == 2
                and all(isinstance(part, torch.Tensor) for part in batch)):
            raise TypeError("Did not get tensor from dataset, investigate and update collate")
        return batch


In [None]:
from jass.game.game_util import get_cards_encoded

In [None]:
class TrumpGrafDataModule(pl.LightningDataModule):
    """Lightning data module for the generated and then balanced graf-heuristic dataset.

    Samples could have been generated on the fly, but that would make balancing and the train/val
    split more complex, so pre-generated train and val parquet datasets are read instead.

    Args:
        train_path: parquet path of the training split.
        val_path: parquet path of the validation split.
        batch_size: batch size of every dataloader.
        num_workers: number of DataLoader worker processes.
    """

    def __init__(self, train_path: str, val_path: str, batch_size: int, num_workers: int):
        super().__init__()
        self.train_path = train_path
        self.val_path = val_path
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.features = None  # not used by this class
        self.train_dataset = None
        self.val_dataset = None

    def setup(self, stage: str):
        """Read both parquet splits and encode them into ``TrumpDataset`` objects."""
        # Lightning calls setup once per stage. The encoding below is slow and deterministic,
        # so skip it when the datasets already exist.
        if self.train_dataset is not None:
            return

        cols = [f"c{i}" for i in range(1, 10)] + ['fh', 'trump']
        train_df = pd.read_parquet(self.train_path, columns=cols)
        val_df = pd.read_parquet(self.val_path, columns=cols)
        X_train, y_train = self.to_one_hot(train_df)
        X_val, y_val = self.to_one_hot(val_df)

        self.train_dataset = TrumpDataset(X_train, y_train)
        self.val_dataset = TrumpDataset(X_val, y_val)

    def to_one_hot(self, df: pd.DataFrame):
        """Convert a graf DataFrame into ``(X, y)`` arrays.

        Every column except ``fh`` and ``trump`` is treated as a card column. The card columns of each
        row are encoded with ``get_cards_encoded``, and ``fh`` is appended as the last feature column.

        Returns:
            ``(X, y)``: ``X`` is the encoded cards followed by ``fh``; ``y`` holds the ``trump`` values.
        """
        non_card_cols = ['fh', 'trump']
        cards = df.drop(columns=non_card_cols).to_numpy()
        # Row-by-row Python call: slow, but the balanced dataset is small enough for this to be fine.
        one_hot = np.apply_along_axis(get_cards_encoded, 1, cards)
        features = np.column_stack((one_hot, df['fh'].to_numpy()))
        return features, df['trump'].to_numpy()

    def _dataloader(self, dataset, shuffle: bool) -> DataLoader:
        """Build a DataLoader for ``dataset`` that uses the batch-level ``collate`` pass-through."""
        return DataLoader(dataset, batch_size=self.batch_size, num_workers=self.num_workers,
                          shuffle=shuffle, collate_fn=self.collate)

    def train_dataloader(self):
        return self._dataloader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        return self._dataloader(self.val_dataset, shuffle=False)

    @staticmethod
    def collate(batch):
        """Pass through the already-batched ``(features, labels)`` tuple from ``__getitems__``.

        The default collate expects a list of per-sample tuples to stack, as produced by repeated
        ``__getitem__`` calls. ``TrumpDataset.__getitems__`` already returns stacked tensors, so no
        further processing is needed.
        See https://pytorch.org/docs/stable/data.html#torch.utils.data._utils.collate.collate

        Raises:
            TypeError: if ``batch`` is not a 2-tuple of tensors.
        """
        # Explicit raise (not assert) so the sanity check survives ``python -O``.
        if not (isinstance(batch, tuple) and len(batch) == 2
                and all(isinstance(part, torch.Tensor) for part in batch)):
            raise TypeError("Did not get tensor from dataset, investigate and update collate")
        return batch


In [None]:
class TrumpSelection(pl.LightningModule):
    """Multi-layer perceptron that predicts the trump choice from an encoded hand.

    ``n_layers`` Linear + ReLU layers of width ``hidden_dim`` are followed by a linear classifier
    that outputs raw logits for 7 classes. Training uses cross-entropy loss and Adam. Loss, accuracy,
    precision, recall and F1 are logged with the stage prefix (``train_``, ``val_``, ``test_``).

    Args:
        input_dim: number of input features.
        hidden_dim: width of every hidden layer.
        n_layers: number of hidden Linear + ReLU layers (values below 1 still create one).
        learning_rate: Adam learning rate, read in ``configure_optimizers``.
    """

    # Prefixes that the training/validation/test hooks pass to ``step``; each gets its own metrics.
    _STAGE_PREFIXES = ("train_", "val_", "test_")

    def __init__(self, input_dim: int, hidden_dim: int, n_layers: int, learning_rate: float):
        super().__init__()

        self.save_hyperparameters()

        n_classes = 7
        self.ll = nn.ModuleList(
            [nn.Linear(input_dim, hidden_dim)]
            + [nn.Linear(hidden_dim, hidden_dim) for _ in range(n_layers - 1)]
        )
        self.classifier = nn.Linear(hidden_dim, n_classes)
        self.criterion = nn.CrossEntropyLoss()

        # torchmetrics objects accumulate state across calls. With a single instance shared by all
        # stages, training updates end up in the same state as validation updates, which contaminates
        # "val_accuracy" (used for checkpointing and early stopping). The Lightning/torchmetrics docs
        # recommend separate instances per stage. Metric states are non-persistent by default, so this
        # does not change the state_dict and existing checkpoints still load.
        self.metrics = nn.ModuleDict(
            {prefix: self._build_metrics(n_classes) for prefix in self._STAGE_PREFIXES}
        )

        self.learning_rate = learning_rate

    @staticmethod
    def _build_metrics(n_classes: int) -> nn.ModuleDict:
        """Return a fresh set of multiclass metrics for a single stage."""
        return nn.ModuleDict(dict(
            accuracy=torchmetrics.Accuracy('multiclass', num_classes=n_classes),
            precision=torchmetrics.Precision('multiclass', num_classes=n_classes),
            recall=torchmetrics.Recall('multiclass', num_classes=n_classes),
            f1=torchmetrics.F1Score('multiclass', num_classes=n_classes),
        ))

    def forward(self, x):
        for layer in self.ll:
            x = F.relu(layer(x))
        # No softmax here: CrossEntropyLoss applies log-softmax internally for numerical stability.
        return self.classifier(x)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)

    def training_step(self, batch, _batch_idx):
        return self.step("train_", batch)

    def validation_step(self, batch, _batch_idx):
        return self.step("val_", batch)

    def test_step(self, batch, _batch_idx):
        return self.step("test_", batch)

    def step(self, prefix, batch):
        """Compute and log the loss and metrics for one ``(X, y)`` batch; return the loss."""
        X, y = batch
        logits = self(X)
        loss = self.criterion(logits, y)
        self.log(prefix + "loss", loss)

        # The model outputs logits. Many metrics can handle them, but to be safe (and efficient)
        # convert to probabilities once here.
        probabilities = F.softmax(logits, dim=-1)
        self._log_and_update_metrics(prefix, probabilities, y)

        return loss

    def _log_and_update_metrics(self, prefix, prediction, y):
        # Update only this stage's metric instances, then hand the Metric objects to Lightning,
        # which computes and resets them according to the stage's logging settings.
        for name, metric in self.metrics[prefix].items():
            metric(prediction, y)
            self.log(prefix + name, metric)


In [None]:
input_dim = 37  # one feature per card (36) plus the forehand (FH) flag

In [None]:
from lightning import seed_everything
from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint, EarlyStopping

def train_dm(dm: pl.LightningDataModule, hidden_dim: int, n_layers: int, learning_rate: float, batch_size: int, epochs: int):
    """Train a freshly initialised ``TrumpSelection`` model on ``dm``.

    ``batch_size`` is only recorded as a hyperparameter; the batch size actually used is the one
    configured on ``dm``. The best checkpoint by ``val_accuracy`` is kept, and training stops after
    5 validation checks without improvement.
    """
    # Record only the plain hyperparameters. The data module is not a hyperparameter, and
    # deep-copying locals() would also copy it along with any data it holds.
    hparams = dict(hidden_dim=hidden_dim, n_layers=n_layers, learning_rate=learning_rate,
                   batch_size=batch_size, epochs=epochs)
    seed_everything(42)
    model = TrumpSelection(input_dim, hidden_dim, n_layers, learning_rate)

    checkpoint_callback = ModelCheckpoint(monitor="val_accuracy", mode="max")
    # Patience counts validation checks without improvement; validation runs once per epoch, so that
    # means epochs. Monitoring val_accuracy slightly overtrains the model: val_accuracy keeps rising for
    # a bit while val_loss is already increasing again.
    # See https://stats.stackexchange.com/questions/282160/how-is-it-possible-that-validation-loss-is-increasing-while-validation-accuracy
    early_stopping = EarlyStopping(monitor="val_accuracy", mode="max", patience=5)

    trainer = pl.Trainer(max_epochs=epochs, callbacks=[LearningRateMonitor(logging_interval='step'), checkpoint_callback, early_stopping], profiler='simple', log_every_n_steps=5)

    trainer.logger.log_hyperparams(hparams)
    # Graph logging needs an example input to trace the model. TrumpSelection does not set
    # example_input_array, so pass a single dummy row explicitly.
    trainer.logger.log_graph(model, input_array=torch.zeros(1, input_dim))

    trainer.fit(model, dm)

In [None]:
def finetune_model(model: pl.LightningModule, dm: pl.LightningDataModule, learning_rate: float, batch_size: int, epochs: int, early_stop_patience: int):
    """Continue training an already (pre-)trained ``model`` on ``dm`` with ``learning_rate``.

    Args:
        model: the model to fine-tune, e.g. one restored with ``load_from_checkpoint``.
        dm: data module providing the train/val dataloaders.
        learning_rate: learning rate to fine-tune with; overrides the one stored on the model.
        batch_size: only recorded as a hyperparameter; the effective batch size is set on ``dm``.
        epochs: maximum number of epochs.
        early_stop_patience: validation checks without ``val_accuracy`` improvement before stopping.
    """
    # Build the logged hyperparameters explicitly. Deep-copying locals() would copy the whole model and
    # data module only to delete them again.
    hparams = dict(learning_rate=learning_rate, batch_size=batch_size, epochs=epochs,
                   early_stop_patience=early_stop_patience)
    seed_everything(42)

    # A model restored from a checkpoint keeps the learning rate it was created with, and
    # TrumpSelection.configure_optimizers reads ``self.learning_rate``. Without this override the
    # ``learning_rate`` argument would only be logged, never used by the optimizer.
    model.learning_rate = learning_rate
    if "learning_rate" in model.hparams:
        # keep the saved hyperparameters (and thus future checkpoints) in sync with the rate used
        model.hparams["learning_rate"] = learning_rate

    # if fine-tuning on an imbalanced dataset, monitor f1 instead of accuracy
    checkpoint_callback = ModelCheckpoint(monitor="val_accuracy", mode="max")
    early_stopping = EarlyStopping(monitor="val_accuracy", mode="max", patience=early_stop_patience)

    trainer = pl.Trainer(max_epochs=epochs, callbacks=[LearningRateMonitor(logging_interval='step'), checkpoint_callback, early_stopping], profiler='simple', log_every_n_steps=1)

    trainer.logger.log_hyperparams(hparams)

    trainer.fit(model, dm)

In [None]:
def pre_train(hidden_dim: int, n_layers: int, learning_rate: float, batch_size: int, epochs: int):
    """Pre-train a new model on the balanced graf-heuristic dataset (train/val parquet splits)."""
    graf_datamodule = TrumpGrafDataModule(
        "./data/graf-dataset-balanced/train/",
        "./data/graf-dataset-balanced/val/",
        batch_size=batch_size,
        num_workers=4,
    )
    train_dm(
        graf_datamodule,
        hidden_dim=hidden_dim,
        n_layers=n_layers,
        learning_rate=learning_rate,
        batch_size=batch_size,
        epochs=epochs,
    )

In [None]:
def fine_tune(checkpoint_path: str, learning_rate: float, batch_size: int, epochs: int, early_stop_patience: int):
    """Restore a model from ``checkpoint_path`` and fine-tune it on the top-250-players dataset."""
    pretrained = TrumpSelection.load_from_checkpoint(checkpoint_path)
    top_players_dm = TrumpDataModule(
        "./data/trump_top250_balanced.csv",
        test_split=.2,
        val_split=.2,
        batch_size=batch_size,
        num_workers=4,
    )
    finetune_model(
        pretrained,
        top_players_dm,
        learning_rate=learning_rate,
        batch_size=batch_size,
        epochs=epochs,
        early_stop_patience=early_stop_patience,
    )


In [None]:
# Hyperparameters for pre-training on the graf dataset.
# (The version_36 checkpoint described below was trained with lr=1e-4 for 50 epochs.)
hparams = {
    "hidden_dim": 50,
    "n_layers": 2,
    "learning_rate": 1e-3,
    "batch_size": 15000,
    "epochs": 100,
}

pre_train(**hparams)

_The_ model was pre-trained on the graf dataset for 50 epochs with lr=1e-4 and reached a val_accuracy of 1. Checkpoint version_36.

It was then fine-tuned with lr=1e-6 for 3894 epochs, until early stopping (patience = 250) kicked in. It reached a val_accuracy of 77%, but progress had clearly not stagnated yet. Checkpoint version_45.

I then continued fine-tuning from that checkpoint with lr 1e-7, max_epochs 10000 and early_stop_patience=1000. The learning rate was probably a bit too small: training progressed very slowly, ran through all 10k epochs in 1.2 h and reached a val_accuracy of 79%. Checkpoint version_46.

Progress still hadn't stagnated, so I started again from there with a slightly larger learning rate of 5e-7. That is 5 times as large, roughly halfway between 1e-7 and 1e-6 on a log scale; the exact logarithmic midpoint would be ≈3.2e-7. I also raised the epochs to 20k and lowered the early-stopping patience to 750. That didn't do what I'd hoped for: val_accuracy stopped improving and early stopping kicked in pretty soon.

One last try, again from version_46, with lr 1e-7 and a patience of 1000: checkpoint version_48. It reached 79.6% validation accuracy and does not seem fully stagnated. But the fact of the matter is that a patience of 1k epochs wasn't enough to keep it going, so I'll call it here.

Next steps:

- Try with a more complex model, but I suspect it won't be very beneficial.
- Once done with tuning, run the test set. Note that this only measures how well the model imitates the top 250 players, not how well it plays. (The best chess bots don't play like humans either. I'd love to use this argument, but it only holds if trump selection is later learned from self-play.)
- Extract the model definition to a python file, so it can be imported and used elsewhere
- Write an agent that uses the model definition and a checkpoint path to predict the best trump. Fallback can be graf if invalid but that hopefully doesn't happen.
- Compare this agent with the graf agent, and potentially with another model agent whose model was trained only on the training data, without pre-training. The card-playing strategy has to be the same, of course.

In [None]:
# Fine-tuning run; earlier checkpoints in this chain:
#   version_36 (pre-trained): ./lightning_logs/version_36/checkpoints/epoch=50-step=65535.ckpt
#   version_45 (fine-tuned):  ./lightning_logs/version_45/checkpoints/epoch=3644-step=10935.ckpt
hparams = {
    "checkpoint_path": "./lightning_logs/version_46/checkpoints/epoch=9674-step=29025.ckpt",
    "learning_rate": 1e-7,
    "batch_size": 3000,
    "epochs": 20000,
    "early_stop_patience": 1000,
}

fine_tune(**hparams)

For a given trump, the Graf heuristic is a linear regression from 36 values to 1: 36 weights are multiplied with the 36 hand values and summed to get the score. Equivalently, it is a linear transformation from 36 values to 6 values, one heuristic score per trump. At the end there is just an argmax over these scores, which a network would express as a softmax followed by an argmax. The only additional element is a threshold that the best score must exceed; otherwise the player pushes (Schieben).

What I'm trying to say is: a network that can model this heuristic would be very simple, and a network that performs better should not need to be much larger. However, if the model is supposed to imitate human players, as it would be if trained supervised on this data, then it will learn different things.

Another approach is to pre-train a network on the graf heuristic, since all possible card combinations and their respective graf choices can be generated. The network then already has a good basis. It could then be fine-tuned on the historical data from Swisslos, perhaps only on the very best-performing players, because much less data is required when performance is already good with just the heuristic. A big advantage is that we can also downsample to get balanced data, both for pre-training (there is a lot more data) and for fine-tuning (much less data is needed).

The only way to make sure one method actually outperforms another is to let the same card-playing bot play against itself with different trump-selection methods, e.g. one team ISMCTS with Graf and one team ISMCTS with deep learning trained purely on historical data.

Side note: with a complex architecture, this model overfits quickly (the validation loss rises again), yet, as you'll see, accuracy, recall, etc. keep increasing. This could be a side effect of the class imbalance or a sign that the model is becoming less confident in its predictions. See also [this Cross Validated question](https://stats.stackexchange.com/questions/282160/how-is-it-possible-that-validation-loss-is-increasing-while-validation-accuracy).