## Imports

In [1]:
# Reload edited modules (e.g. the local `owkin` package imported below) automatically
# before each cell executes, so code changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
from pathlib import Path
import glob
import os
import sys

# Put the parent of the current working directory on the import path (only once),
# so the local `owkin` package imported below can be found. This is relative to the
# kernel's working directory; assumes the notebook runs from a sub-folder of the
# project root — TODO confirm.
module_path = os.path.abspath(os.pardir)
if not module_path in sys.path if False else module_path not in sys.path:
    sys.path.append(module_path)

import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm

import copy
import pandas as pd
import numpy as np
import torch

from sklearn.metrics import roc_auc_score

# Silence wandb console output. Presumably this must be set before `owkin.training`
# (imported just below) initialises wandb — TODO confirm.
os.environ.update(WANDB_SILENT="true")

from owkin.training import train, RocLoss
from owkin.dataset import build_dataset
from owkin.models.mono_models import MLP, SVM
from owkin.models.aggregators import SmoothMaxAggregator, MaxAggregator

In [3]:
# Use the GPU when torch can see one; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

## Get the Data

In [4]:
# Unpack the splits produced by the project's dataset builder. Train and validation
# come with labels (`y_train`, `y_val`); the test split has no `y_test` here, only
# `df_test` (presumably metadata used for a submission file — TODO confirm).
(
    X_train, X_train_mean, y_train,
    X_val, X_val_mean, y_val,
    X_test, X_test_mean, df_test,
) = build_dataset()

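## Training

Grid search over weight decay and learning rates. For each setting, an MLP mono model is first trained on the `*_mean` features. A `SmoothMaxAggregator` built on top of that model's reloaded checkpoint is then trained on the full features (`X_train`, `X_val`) for each aggregator learning rate.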

In [5]:
# Architecture of the MLP mono model: number of layers and hidden width.
num_layers, inside_dim = 5, 512

In [5]:
# Optimisation settings. The `mono_*` values drive training of the MLP mono model;
# the others drive training of the SmoothMax aggregator.
mono_batch_size, mono_nb_epochs = 64, 5000
batch_size, nb_epochs = 16, 1000

In [7]:
# Hyper-parameter grid explored by the training loop.
list_wd = [  # weight decay, used for both the mono model and the aggregator
    0, 1e-3, 3e-3, 1e-2,
]
list_mono_lr = [  # learning rates for the MLP mono model
    1e-6, 3e-6, 1e-5,
]
list_lr = [  # learning rates for the SmoothMax aggregator
    1e-6, 3e-6,
]

In [9]:
# Two-stage grid search. For every (weight decay, mono learning rate) pair:
#   1. train an MLP mono model on the `*_mean` features, then reload the weights
#      stored at `mono_model.best_path` (presumably the best validation checkpoint
#      written by `train` — TODO confirm);
#   2. for every aggregator learning rate, wrap an independent copy of those weights
#      in a SmoothMaxAggregator and train it on the full `X_train` / `X_val` features.
for wd in list_wd:
    for mono_lr in list_mono_lr:
        mono_model = MLP(num_layers=num_layers, inside_dim=inside_dim)
        train(X_train_mean,
              y_train,
              X_val_mean,
              y_val,
              mono_model,
              learning_rate=mono_lr,
              weight_decay=wd,
              batch_size=mono_batch_size,
              nb_epochs=mono_nb_epochs,
              use_wandb=True,
              device=device)

        # map_location keeps the reload working even if the checkpoint was written
        # from a different device than the one selected above.
        mono_model.load_state_dict(torch.load(mono_model.best_path, map_location=device))
        for lr in list_lr:
            # Deep-copy so every aggregator run starts from the same mono-model weights.
            # Wrapping `mono_model` itself would carry the previous `lr` run's
            # updates into the next run and confound the comparison.
            model = SmoothMaxAggregator(copy.deepcopy(mono_model))
            train(X_train,
                  y_train,
                  X_val,
                  y_val,
                  model,
                  learning_rate=lr,
                  weight_decay=wd,
                  batch_size=batch_size,
                  nb_epochs=nb_epochs,
                  use_wandb=True,
                  device=device,
                  # Include every grid coordinate so run names are unique
                  # (previously only mono_lr was encoded, so runs collided).
                  exp_name=f"mono_lr_{mono_lr}_wd_{wd}_lr_{lr}")

VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016666666666666666, max=1.0…

VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016916666666899498, max=1.0…

VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016916666666899498, max=1.0…