## Imports

In [1]:
# Reload edited project modules (e.g. `owkin.*`) automatically before each cell runs,
# so changes to the package are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os
import sys
from pathlib import Path

# Put the parent of the current working directory on sys.path so the local
# `owkin` package imported below can be resolved (assumes the notebook is run
# from a subdirectory of the repository root — TODO confirm).
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)

import torch
from torch import nn

# Silence wandb console output. This is set before `owkin.training` is imported,
# which presumably imports wandb — keep this line above that import.
os.environ["WANDB_SILENT"] = "true"

# NOTE(review): `nn` and `RocLoss` are not used in any of the cells shown here.
from owkin.training import train, RocLoss
from owkin.dataset import build_dataset
from owkin.models.mono_models import MLP
from owkin.models.aggregators import SmoothMaxAggregator

In [3]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

## Training

In [4]:
# --- Mono model (MLP) architecture and its pre-training run ---
num_layers = 5
inside_dim = 512
mono_lr = 3e-6
mono_batch_size = 64
mono_nb_epochs = 2000

# --- Aggregator fine-tuning run ---
nb_epochs = 500
batch_size = 16

# Single-run defaults. The sweep loops further down rebind `val_center`,
# `normalizer_type` and `wd` as loop variables, so these values are shadowed there.
wd = 1e-1
normalizer_type = "expo_lambda"
val_center = "C_1"

In [5]:
# Hyper-parameter grid for the aggregator fine-tuning sweep in the next cell.
list_lr = [3e-7, 1e-7]
list_wd = [0, 1e-3, 3e-3, 1e-2]
# "None" here is the string "None", not the None object; it is passed as-is.
list_normalizer_type = [
    "None",
    "mean",
    "expo_lambda",
]
list_val_center = [
    "C_1",
    "C_2",
    "C_5",
]

In [6]:
import copy


def _best_checkpoint(run_dir: Path) -> Path:
    """Return the file in ``run_dir`` with the highest validation score.

    The score is parsed as ``float(name[-8:-3])`` from each file name, i.e. the
    five characters just before a three-character suffix such as ``.pt`` — TODO
    confirm this matches the naming scheme used when the checkpoints are saved.
    Only scores strictly greater than 0 are considered.

    Raises:
        FileNotFoundError: if ``run_dir`` does not exist, is empty, or contains no
            file with a positive score. Failing loudly avoids loading a checkpoint
            that belongs to a different configuration.
    """
    best_path, best_val_score = None, 0
    for model_path in run_dir.glob("*"):
        val_score = float(model_path.name[-8:-3])
        if val_score > best_val_score:
            best_path, best_val_score = model_path, val_score
    if best_path is None:
        raise FileNotFoundError(f"No checkpoint with a positive validation score in {run_dir}")
    return best_path


# Fine-tune a SmoothMaxAggregator on top of the best pretrained MLP checkpoint for
# every (validation centre, normalizer, weight decay, aggregator learning rate).
for val_center in list_val_center:
    for normalizer_type in list_normalizer_type:
        X_train, X_train_mean, y_train, X_val, X_val_mean, y_val, X_test, X_test_mean, df_test = build_dataset(normalizer_type=normalizer_type, val_center=val_center)
        for wd in list_wd:
            mono_model = MLP(num_layers=num_layers, inside_dim=inside_dim)
            # Directory holding the mono-model checkpoints for this configuration.
            run_name = Path(
                f"./../saved_models/{mono_model.name}/BCELoss/nt_{normalizer_type}/vc_{val_center}"
                f"/bs_{mono_batch_size}/wd_{wd:.0e}/lr_{mono_lr:.0e}"
            )
            best_path = _best_checkpoint(run_name)
            # map_location="cpu" lets checkpoints saved from a GPU load on a CPU-only
            # machine; load_state_dict then copies the values into mono_model's parameters.
            mono_model.load_state_dict(torch.load(best_path, map_location="cpu"))

            for lr in list_lr:
                # Wrap a fresh copy so every learning rate starts from the same pretrained
                # weights, rather than from weights already fine-tuned by the previous LR
                # (assumes the aggregator trains the wrapped module in place — TODO confirm).
                model = SmoothMaxAggregator(copy.deepcopy(mono_model))
                train(X_train,
                    y_train,
                    X_val,
                    y_val,
                    model,
                    learning_rate=lr,
                    weight_decay=wd,
                    val_center=val_center,
                    normalizer_type=normalizer_type,
                    batch_size=batch_size,
                    nb_epochs=nb_epochs,
                    use_wandb=True,
                    device=device,
                    exp_name=f"mono_lr_{mono_lr}"
                )

In [5]:
# Hyper-parameter grid for the two-stage (mono pre-training + aggregator) sweep below.
list_wd = [0, 1e-3, 3e-3, 1e-2]
# "None" here is the string "None", not the None object; it is passed as-is.
list_normalizer_type = [
    "None",
    "mean",
    "expo_lambda",
]
list_val_center = [
    "C_1",
    "C_2",
    "C_5",
]

In [6]:
# NOTE(review): this cell previously passed `learning_rate=lr`, but `lr` is never
# assigned here. It only worked because an earlier cell's `for lr in list_lr` loop
# leaked its last value (1e-7 under Restart & Run All); on a fresh kernel it raised
# NameError. The value is now pinned explicitly — TODO confirm the intended aggregator LR.
aggregator_lr = 1e-7

# For every (validation centre, normalizer, weight decay):
#   1. train a fresh MLP on the `*_mean` inputs returned by build_dataset;
#   2. reload its best checkpoint and fine-tune it wrapped in a SmoothMaxAggregator
#      on the full `X_train` / `X_val` inputs.
for val_center in list_val_center:
    for normalizer_type in list_normalizer_type:
        X_train, X_train_mean, y_train, X_val, X_val_mean, y_val, X_test, X_test_mean, df_test = build_dataset(normalizer_type=normalizer_type, val_center=val_center)
        for wd in list_wd:
            # Stage 1: mono model pre-training.
            mono_model = MLP(num_layers=num_layers, inside_dim=inside_dim)
            train(X_train_mean,
                y_train,
                X_val_mean,
                y_val,
                mono_model,
                learning_rate=mono_lr,
                weight_decay=wd,
                val_center=val_center,
                normalizer_type=normalizer_type,
                batch_size=mono_batch_size,
                nb_epochs=mono_nb_epochs,
                use_wandb=True,
                device=device,
            )

            # Stage 2: aggregator fine-tuning from the best mono checkpoint.
            # `best_path` is presumably set on the model by `train` — TODO confirm.
            # map_location="cpu" lets GPU-saved checkpoints load on a CPU-only machine;
            # load_state_dict copies the values into mono_model's own parameters.
            mono_model.load_state_dict(torch.load(mono_model.best_path, map_location="cpu"))
            model = SmoothMaxAggregator(mono_model)
            train(X_train,
                y_train,
                X_val,
                y_val,
                model,
                learning_rate=aggregator_lr,
                weight_decay=wd,
                val_center=val_center,
                normalizer_type=normalizer_type,
                batch_size=batch_size,
                nb_epochs=nb_epochs,
                use_wandb=True,
                device=device,
                exp_name=f"mono_lr_{mono_lr}"
            )

VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01693333333338766, max=1.0)…

VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016666666666666666, max=1.0…