In [None]:
import copy
from typing import Tuple, Any

import numpy as np
import pandas
import pandas as pd
import torch.nn as nn
import torch
import torchmetrics
import torch.nn.functional as F
from lightning.pytorch.utilities.types import TRAIN_DATALOADERS, EVAL_DATALOADERS
from torch.utils.data import Dataset, DataLoader
import lightning as pl
from sklearn.model_selection import train_test_split


In [None]:
from jass.game.const import card_strings

In [None]:
def trump_to_one_hot(trump: int) -> np.ndarray:
    """One-hot encode a single trump id as a length-7 integer vector.

    Columns 0-5 are the regular trump choices (4 suits + une_ufe & obe_abe) and
    column 6 is PUSH. The dataset stores PUSH both as 6 and as 10, so 10 is
    folded onto column 6 before encoding.
    """
    n_classes = 4 + 2 + 1  # 4 suits + une_ufe & obe_abe + push
    column = 6 if trump == 10 else trump
    encoded = np.zeros(n_classes, dtype=int)
    encoded[column] = 1
    return encoded

In [None]:
def trump_to_one_hot(trumps: np.ndarray) -> np.ndarray:
    """One-hot encode a 1-D array of trump ids into an ``(n, 7)`` integer matrix.

    Columns 0-5 are the regular trump choices (4 suits + une_ufe & obe_abe) and
    column 6 is PUSH. The dataset stores PUSH both as 6 and as 10, so 10 is
    folded onto column 6.

    Args:
        trumps: 1-D array (or array-like) of integer trump ids. It is NOT modified.

    Returns:
        ``(len(trumps), 7)`` int array with exactly one 1 per row.
    """
    trumps = np.asarray(trumps)
    # correct PUSH down to 6 (there are both cases in the dataset).
    # np.where builds a new array, so the caller's labels keep their original codes
    # (an in-place `trumps[trumps == 10] = 6` would silently rewrite them).
    columns = np.where(trumps == 10, 6, trumps)
    n_rows = columns.shape[0]
    one_hot = np.zeros((n_rows, 4 + 2 + 1), dtype=int)  # 4 suits + une_ufe & obe_abe + push
    one_hot[np.arange(n_rows), columns] = 1
    return one_hot

In [None]:
# Scalar usage example, kept disabled: the vectorised trump_to_one_hot defined in the
# previous cell shadows the scalar version and does not accept a plain int.
# trump_to_one_hot(5)

In [None]:
# sanity check: 10 (the alternative PUSH code) must end up in the same column as 6
sample_trumps = np.array([4, 5, 1, 10])
trump_to_one_hot(sample_trumps)

In [None]:
class TrumpDataset(Dataset):
    """Map-style dataset over features ``X`` and trump labels ``y``, read in whole batches.

    Only ``__getitems__`` (batched access, as used by the DataLoader's fetcher when the
    dataset defines it) is supported; per-item access is rejected on purpose because
    building tensors one row at a time is too slow.

    Args:
        X: feature array with one row per sample (assumed numeric — TODO confirm dtype).
        y: trump-id array with the same number of rows as ``X``.

    Raises:
        ValueError: if ``X`` and ``y`` have a different number of rows.
    """

    def __init__(self, X, y):
        # An explicit raise (unlike `assert`) still fires under `python -O`.
        if X.shape[0] != y.shape[0]:
            raise ValueError("X y dim mismatch")
        self.X = X
        self.y = y

    def __getitem__(self, item):
        # Deliberately unsupported. With `assert False` this method would silently return
        # None when assertions are stripped (`python -O`), so raise instead.
        raise NotImplementedError("Inefficient __getitem__ called")

    def __getitems__(self, items):
        """Return ``(features, one_hot_targets)`` float32 tensors for a list of indices."""
        # linear layers and CrossEntropyLoss (with probability targets) both need float tensors
        features = torch.as_tensor(self.X[items], dtype=torch.float32)
        targets = torch.as_tensor(trump_to_one_hot(self.y[items]), dtype=torch.float32)
        return features, targets

    def __len__(self):
        return self.X.shape[0]

In [None]:
class TrumpDataModule(pl.LightningDataModule):
    """Lightning data module for the trump-selection CSV.

    Reads a headerless CSV whose columns are the card strings, ``FH``, ``user`` and
    ``trump``. It optionally keeps only rows from "promising" users and serves stratified
    train/val/test splits as ``TrumpDataset`` dataloaders.

    Args:
        csv_path: path to the headerless trump CSV.
        test_split: fraction of all rows held out as the test set.
        val_split: fraction of the remaining (non-test) rows held out for validation.
        batch_size: number of rows per batch.
        num_workers: number of DataLoader worker processes.
        promising_users_path: optional CSV with an ``id`` column; when given, only rows
            whose ``user`` value appears in that column are used.
    """

    def __init__(self, csv_path: str, test_split: float, val_split: float, batch_size: int, num_workers: int, promising_users_path=None):
        super().__init__()
        self.csv_path = csv_path
        self.promising_users_path = promising_users_path
        self.test_split = test_split
        self.val_split = val_split
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.data = None
        self.promising_users = None
        self.features = None
        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None

    def setup(self, stage: str):
        """Load the CSV and build the three splits; ``stage`` is ignored (all stages share them)."""
        # Lightning calls setup() again for later stages (e.g. 'test' after 'fit'). The split
        # uses a fixed random_state and does not depend on the stage, so build it only once.
        if self.train_dataset is not None:
            return

        self.features = np.append(card_strings, ['FH'])
        cols = np.append(self.features, ['user', 'trump'])
        self.data = pd.read_csv(self.csv_path, header=None, names=cols)

        if self.promising_users_path:
            # only use the promising users (what is promising is decided in the player-selection notebook)
            # currently not using the scores, could maybe use them with something like this
            # https://stackoverflow.com/a/77300557/10883465
            # but definitely not just *sample_weights, maybe * (1 + sample_weights/2), to just give
            # the good players a tiny importance boost.
            self.promising_users = pd.read_csv(self.promising_users_path)
            # `series in series` is not an element-wise membership test (pandas raises because
            # a Series is unhashable); isin builds the boolean row mask we actually want.
            self.data = self.data[self.data['user'].isin(self.promising_users['id'])]

        X = self.data[self.features].values
        y = self.data['trump'].values

        # we need stratification, otherwise torch's random_split would work too
        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=self.test_split, stratify=y, random_state=42)
        X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=self.val_split, stratify=y_train, random_state=42)

        self.train_dataset = TrumpDataset(X_train, y_train)
        self.val_dataset = TrumpDataset(X_val, y_val)
        self.test_dataset = TrumpDataset(X_test, y_test)

    def train_dataloader(self):
        # reshuffle the training rows every epoch so batches differ between epochs
        return self._make_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        # must serve the held-out validation split, not the training data
        return self._make_loader(self.val_dataset, shuffle=False)

    def test_dataloader(self):
        return self._make_loader(self.test_dataset, shuffle=False)

    def _make_loader(self, dataset, shuffle: bool) -> DataLoader:
        """Build a DataLoader with this module's batch size, worker count and batched collate."""
        return DataLoader(dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=shuffle, collate_fn=self.collate)

    @staticmethod
    def collate(batch):
        """Pass through the already-batched ``(features, targets)`` tuple from ``__getitems__``."""
        # the default collate thinks __getitems__ returns a list of tuples that need to be stacked to get the batched values
        # just like you would need to if you invoked __getitem__ multiple times and tried to batch that together.
        # however, __getitems__ already returns the tensor with the items stacked so no need for additional processing.
        # see https://pytorch.org/docs/stable/data.html#torch.utils.data._utils.collate.collate potentially
        assert type(batch) == tuple and type(batch[0]) == torch.Tensor and type(batch[1]) == torch.Tensor, "Did not get tensor from dataset, investigate and update collate"

        return batch


In [None]:
class TrumpSelection(pl.LightningModule):
    """Fully connected classifier mapping a hand encoding to logits over 7 trump classes.

    Args:
        input_dim: number of input features per sample.
        hidden_dim: width of every hidden layer.
        n_layers: number of hidden Linear + ReLU layers; must be at least 1.
        learning_rate: learning rate for Adam.

    Raises:
        ValueError: if ``n_layers`` is smaller than 1.
    """

    def __init__(self, input_dim: int, hidden_dim: int, n_layers: int, learning_rate: float):
        super().__init__()
        if n_layers < 1:
            # the first Linear(input_dim, hidden_dim) always exists, so fewer layers cannot be honoured
            raise ValueError(f"n_layers must be at least 1, got {n_layers}")

        self.save_hyperparameters()

        n_classes = 7  # 4 suits + une_ufe & obe_abe + push
        self.ll = nn.ModuleList([nn.Linear(input_dim, hidden_dim)] + [nn.Linear(hidden_dim, hidden_dim) for _ in range(n_layers-1)])
        self.classifier = nn.Linear(hidden_dim, n_classes)
        self.criterion = nn.CrossEntropyLoss()

        # One independent metric set per stage. torchmetrics objects accumulate state until they
        # are computed/reset, so a single instance shared by train/val/test would mix batches of
        # different stages into the same value. Keys are the step prefixes; a bare "train" key
        # would clash with nn.Module.train inside a ModuleDict.
        self.metrics = nn.ModuleDict({
            prefix: nn.ModuleDict(dict(
                accuracy=torchmetrics.Accuracy('multiclass', num_classes=n_classes),
                precision=torchmetrics.Precision('multiclass', num_classes=n_classes),
                recall=torchmetrics.Recall('multiclass', num_classes=n_classes),
                f1=torchmetrics.F1Score('multiclass', num_classes=n_classes),
            ))
            for prefix in ("train_", "val_", "test_")
        })

        self.learning_rate = learning_rate

    def forward(self, x):
        """Return raw class logits for a batch of feature rows."""
        for l in self.ll:
            x = l(x)
            x = F.relu(x)
        x = self.classifier(x)

        # no softmax here because CrossEntropyLoss does that internally for better numerical stability
        return x

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)

    def training_step(self, batch, _batch_idx):
        return self.step("train_", batch)

    def validation_step(self, batch, _batch_idx):
        return self.step("val_", batch)

    def test_step(self, batch, _batch_idx):
        return self.step("test_", batch)

    def step(self, prefix, batch):
        """Shared step: compute and log the loss, then update/log the ``prefix`` stage metrics.

        ``prefix`` must be one of "train_", "val_", "test_".
        """
        X, y = batch
        logits = self(X)
        # CrossEntropyLoss accepts class indices or per-class probabilities as target;
        # TrumpDataset yields one-hot float rows, i.e. the probability form.
        loss = self.criterion(logits, y)
        self.log(prefix + "loss", loss)

        # the metrics get probabilities instead of raw logits
        self._log_and_update_metrics(prefix, F.softmax(logits, dim=-1), y)

        return loss

    def _log_and_update_metrics(self, prefix, prediction, y):
        """Update and log the metrics belonging to the ``prefix`` stage.

        Multiclass torchmetrics expect integer class indices as target, so one-hot /
        probability targets (same rank as ``prediction``) are collapsed with argmax first.
        """
        target = y.argmax(dim=-1) if y.ndim == prediction.ndim else y
        for name, metric in self.metrics[prefix].items():
            metric(prediction, target)
            self.log(prefix + name, metric)


In [None]:
from lightning import seed_everything
from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint

def train(hidden_dim: int, n_layers: int, learning_rate: float, batch_size: int, epochs: int,
          csv_path: str = "./data/2018_10_18_trump.csv",
          promising_users_path="./data/promising_players.csv",
          num_workers: int = 4):
    """Train a TrumpSelection MLP on the trump CSV and log its hyperparameters.

    Args:
        hidden_dim: width of each hidden layer.
        n_layers: number of hidden layers.
        learning_rate: Adam learning rate.
        batch_size: rows per batch.
        epochs: maximum number of training epochs.
        csv_path: path to the headerless trump CSV.
        promising_users_path: CSV of promising user ids to filter on; ``None`` uses all users.
        num_workers: DataLoader worker processes.

    The checkpoint with the highest ``val_accuracy`` is kept by ModelCheckpoint.
    """
    # snapshot the arguments before any other local exists, so exactly the call's
    # parameters are logged as hyperparameters (all values are plain immutables, no deepcopy needed)
    hparams = dict(locals())
    # workers=True also derives distinct, reproducible seeds for the DataLoader workers
    seed_everything(42, workers=True)
    dm = TrumpDataModule(csv_path, test_split=.2, val_split=.2, num_workers=num_workers, batch_size=batch_size, promising_users_path=promising_users_path)
    input_dim = 36 + 1  # all cards + forehand
    model = TrumpSelection(input_dim, hidden_dim, n_layers, learning_rate)

    checkpoint_callback = ModelCheckpoint(monitor="val_accuracy", mode="max")

    trainer = pl.Trainer(max_epochs=epochs, callbacks=[LearningRateMonitor(logging_interval='step'), checkpoint_callback], profiler='simple', log_every_n_steps=5)

    # a Trainer created with logger=False (or without any logger backend) has no logger
    if trainer.logger is not None:
        trainer.logger.log_hyperparams(hparams)

    trainer.fit(model, dm)

In [None]:
# full training run with the chosen architecture and optimisation settings
train(hidden_dim=250, n_layers=10, learning_rate=1e-3, batch_size=2000, epochs=2500)