In [None]:
import os

if os.path.basename(os.getcwd()) == 'notebooks':
    %cd ..

In [None]:
# Enable IPython's autoreload extension in mode 2 so edits to project modules
# (the `src.` package imported below) are reloaded before each cell runs,
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from src.data.load_data import get_crossval_datasets, create_dataloaders, get_unique_num
from src.models.simple_regressor import SimpleRegressorModelV1
from src.models.train import train_rsm
from src.utils.common import seed_everything
from src.models.evaluation import compute_metrics, make_markdown_table

import torch
from torch import nn
import numpy as np

In [None]:
randomizer_seed = 42
# Seed *before* building the cross-validation splits and dataloaders: previously
# the seed was set afterwards, so any randomness used while creating the splits
# was not reproducible across runs.
seed_everything(randomizer_seed)

dataset_splits = get_crossval_datasets()
# Materialise the (train, val) dataloader pairs so the training cell can be re-run
# even if create_dataloaders returns a one-shot iterator/generator
# (NOTE(review): return type not visible here; list() is harmless if it is already a list).
dataloaders = list(create_dataloaders(dataset_splits))

In [None]:
# Train one fresh model per cross-validation split and collect per-fold loss
# histories and validation metrics.
train_losses, val_losses = [], []
metrics = []

loss_fn = nn.MSELoss()

# These counts are the same for every fold, so look them up once instead of per split.
num_users = get_unique_num('user_id')
num_items = get_unique_num('item_id')

for i, (train_dataloader, val_dataloader) in enumerate(dataloaders, 1):
    print(f'Training on split #{i}...')
    # New model and optimizer per fold so folds never share weights or optimizer state.
    # NOTE(review): the meaning of the two 128 arguments is defined by SimpleRegressorModelV1 — confirm.
    model = SimpleRegressorModelV1(num_users, num_items, 128, 128)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)

    train_loss, val_loss = train_rsm(model, optimizer, loss_fn, train_dataloader, val_dataloader)
    train_losses.append(train_loss)
    val_losses.append(val_loss)

    metrics.append(compute_metrics(model, val_dataloader))
    metrics[-1]['Training loss'] = train_loss[-1]

if not train_losses:
    # Without this, the indexing below fails with an unhelpful IndexError.
    raise RuntimeError('No cross-validation splits were produced; re-run the data-loading cell.')

# Final-epoch loss of each fold. Indexing each history separately (instead of
# np.array(...)[:, -1]) also works when folds ran for different numbers of epochs.
last_tls = np.array([fold_losses[-1] for fold_losses in train_losses])
last_vls = np.array([fold_losses[-1] for fold_losses in val_losses])

# Losses are MSE (nn.MSELoss), so sqrt gives RMSE per fold before averaging.
print(f'\nTraining is over. Average training/validation RMSE: {np.mean(np.sqrt(last_tls)):0.2f} / {np.mean(np.sqrt(last_vls)):0.2f}')
print(f'Training losses: {last_tls}')
print(f'Validation losses: {last_vls}')

In [None]:
# Header lines for the summary, then the per-fold metrics table, which is
# rendered as the cell's output because it is the last expression.
print(f'Random seed: {randomizer_seed}', 'Summary on cross validation training:', sep='\n')

make_markdown_table(metrics)

In [None]:
# NOTE: `model` is whatever the cross-validation loop assigned last, i.e. the
# model trained on the final split only (not one fit on all data).
# Make sure the target directory exists first (no-op if it already does), in case
# save_model does not create missing directories itself.
os.makedirs('models', exist_ok=True)
model.save_model('models/model-128.pickle')