In [None]:
import os
import sys

# Put the parent of the current working directory on sys.path so the local
# `reinforcement_yatzy` package is importable.
# NOTE(review): assumes the notebook is run from a directory one level below
# the repository root — confirm if the notebook is moved.
lib_path = os.path.abspath(os.pardir)
if lib_path not in sys.path:
    sys.path.append(lib_path)

In [None]:
from datetime import datetime
from pathlib import Path

import numpy as np
import pandas as pd
import torch
from torch import nn
from torch import optim
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm

from reinforcement_yatzy.nn_models.autoencoders.scoreboard_autoencoder import  ScoreboardAutoencoder
from reinforcement_yatzy.scoreboard_dataset.scoreboard_dataset import ScoreboardDataset
from reinforcement_yatzy.yatzy.base_player import ABCYatzyPlayer

In [None]:
def _run_train_epoch(
    autoencoder: nn.Module,
    loader: DataLoader,
    criterion: nn.Module,
    optimizer: optim.Optimizer,
    desc: str,
) -> float:
    '''Run one optimisation pass over ``loader`` and return the mean batch loss.

    Each batch is used both as the model input and as the reconstruction
    target. Returns NaN when ``loader`` yields no batches.
    '''
    autoencoder.train()
    sum_loss = 0.0
    with tqdm(loader, unit='batch', desc=desc) as tepoch:
        for i_batch, batch in enumerate(tepoch):
            optimizer.zero_grad()
            outputs = autoencoder(batch)
            loss = criterion(outputs, batch)
            loss.backward()
            optimizer.step()
            sum_loss += loss.item()
            # Running mean of the batch losses seen so far this epoch
            tepoch.set_postfix(loss=sum_loss / (i_batch + 1))

    n_batches = len(loader)
    return sum_loss / n_batches if n_batches else float('nan')


def _evaluate(
    autoencoder: nn.Module,
    loader: DataLoader,
    criterion: nn.Module,
) -> float:
    '''Return the mean batch reconstruction loss over ``loader`` without updating weights.

    Returns NaN when ``loader`` yields no batches.
    '''
    # eval() so that dropout / batch-norm layers (if any) behave deterministically
    autoencoder.eval()
    sum_loss = 0.0
    with torch.no_grad():
        for batch in loader:
            outputs = autoencoder(batch)
            # .item() accumulates plain Python floats instead of (device) tensors
            sum_loss += criterion(outputs, batch).item()

    n_batches = len(loader)
    return sum_loss / n_batches if n_batches else float('nan')


def _save_checkpoint(
    autoencoder: ScoreboardAutoencoder,
    save_path: Path,
    loss_path: Path,
    train_losses: np.ndarray,
    val_losses: np.ndarray,
) -> None:
    '''Write the encoder/decoder state dicts and the loss history recorded so far.'''
    save_path.parent.mkdir(parents=True, exist_ok=True)
    torch.save(autoencoder.encoder.state_dict(), save_path)
    decoder_path = save_path.with_stem(save_path.stem + '_decoder')
    torch.save(autoencoder.decoder.state_dict(), decoder_path)

    loss_path.parent.mkdir(parents=True, exist_ok=True)
    # np.savetxt prefixes the header with its default '# ' comment marker,
    # so np.loadtxt skips it when reading the file back.
    np.savetxt(
        loss_path,
        np.column_stack([train_losses, val_losses]),
        delimiter=',',
        header='train_loss,val_loss',
    )


def train_autoencoder(
    autoencoder: ScoreboardAutoencoder,
    train_loader: DataLoader,
    val_loader: DataLoader,
    n_epochs: int,
    loss_path: Path,
    n_rising_until_break: int,
    save_interval: int,
    save_path: Path,
):
    '''Train the autoencoder with holdout validation and early stopping.

    Each epoch runs one pass over ``train_loader`` (MSE reconstruction loss,
    Adam with lr=1e-3), followed by one pass over ``val_loader``. Training
    stops early once the validation loss has risen for
    ``n_rising_until_break`` consecutive epochs.

    Args:
        autoencoder: Model to train in place. Its ``encoder`` and ``decoder``
            submodules are saved as separate state dicts.
        train_loader: Batches used for gradient updates. Each batch is both
            the model input and its reconstruction target.
        val_loader: Batches used only to measure the validation loss.
        n_epochs: Maximum number of epochs to run.
        loss_path: File receiving one ``train_loss,val_loss`` row per
            completed epoch.
        n_rising_until_break: Number of consecutive epochs with a rising
            validation loss after which training stops.
        save_interval: Checkpoint on every epoch whose index is a multiple
            of this value (epoch 0 included). The final state is always
            saved as well.
        save_path: Path for the encoder weights. The decoder weights are
            written next to it with ``_decoder`` appended to the stem.

    Returns:
        The same ``autoencoder`` object, after training. Its train/eval mode
        is restored to what it was on entry.
    '''
    criterion = nn.MSELoss()
    optimizer = optim.Adam(autoencoder.parameters(), lr=1e-3)
    was_training = autoencoder.training

    train_losses = np.zeros([n_epochs])
    val_losses = np.zeros([n_epochs])
    n_epochs_done = 0
    n_epochs_val_rising = 0

    for i_epoch in range(n_epochs):
        avg_train_loss = _run_train_epoch(
            autoencoder, train_loader, criterion, optimizer, desc=f'Epoch {i_epoch}')
        avg_val_loss = _evaluate(autoencoder, val_loader, criterion)

        train_losses[i_epoch] = avg_train_loss
        val_losses[i_epoch] = avg_val_loss
        n_epochs_done = i_epoch + 1

        tqdm.write(
            f'\nEpoch {i_epoch} - Training Loss: {avg_train_loss:.2e} - Validation Loss: {avg_val_loss:.2e}\n')

        if i_epoch % save_interval == 0:
            # Only log epochs that actually ran (no trailing zero rows)
            _save_checkpoint(
                autoencoder, save_path, loss_path,
                train_losses[:n_epochs_done], val_losses[:n_epochs_done],
            )

        # Early stopping: count *consecutive* epochs with rising validation loss
        if i_epoch > 0 and val_losses[i_epoch] > val_losses[i_epoch - 1]:
            n_epochs_val_rising += 1
        else:
            n_epochs_val_rising = 0

        if n_epochs_val_rising >= n_rising_until_break:
            break

    # The last epoch run (or the one triggering early stopping) is generally
    # not a multiple of save_interval, so persist the final state explicitly.
    _save_checkpoint(
        autoencoder, save_path, loss_path,
        train_losses[:n_epochs_done], val_losses[:n_epochs_done],
    )
    autoencoder.train(was_training)

    print(f'Saved weights to {save_path}')

    return autoencoder

In [None]:
# --- Training configuration ---
# Seeds torch's global RNG. This makes DataLoader shuffling reproducible and
# presumably also the model's weight initialisation (TODO confirm that
# ScoreboardAutoencoder draws its initial weights from torch's RNG).
SEED = 0
torch.manual_seed(SEED)

epochs = 10

encoder_dims = [64, 64, 32, 16, 8]
latent_dim = 4

autoencoder = ScoreboardAutoencoder(
    n_entries=ABCYatzyPlayer.NUM_ENTRIES,
    latent_dim=latent_dim,
    mlp_dims=encoder_dims,
)

# --- Data: 95 % train / 5 % holdout validation ---
dataset_path = Path('..', 'datasets', '512k_scoreboards.csv')
dataset_df = pd.read_csv(dataset_path)
dataset = ScoreboardDataset(dataset_df)

# A dedicated seeded generator keeps the holdout split identical across runs
split_generator = torch.Generator().manual_seed(SEED)
train_set, val_set = random_split(dataset, [0.95, 0.05], generator=split_generator)
train_loader = DataLoader(train_set, batch_size=32, shuffle=True)
# Validation only averages losses over all batches, so order is irrelevant
val_loader = DataLoader(val_set, batch_size=32, shuffle=False)

# --- Output paths ---
# e.g. 'auto_encoder_[64_64_32_16_8_]_4'; the format is kept unchanged so
# previously saved weights/logs keep matching names.
dims_part = ''.join(f'{dim}_' for dim in encoder_dims)
param_specifier = f'auto_encoder_[{dims_part}]_{latent_dim}'

now = datetime.now()
formatted_time = now.strftime("%d-%m-%H-%M")

loss_file_path = Path(f'loss_log_{param_specifier}__{formatted_time}.csv')
loss_log_dir_path = Path('..', 'loss_logs')
loss_path = loss_log_dir_path / loss_file_path

weights_dir_path = Path('..', 'weights', 'autoencoder')
weights_path = param_specifier + '.pth'
save_path = weights_dir_path / weights_path

In [None]:
# Stop after the first epoch whose validation loss rises; checkpoint on
# epochs 0, 5, ... (every 5th epoch).
train_autoencoder(
    autoencoder=autoencoder,
    train_loader=train_loader,
    val_loader=val_loader,
    n_epochs=epochs,
    loss_path=loss_path,
    n_rising_until_break=1,
    save_interval=5,
    save_path=save_path,
)