In [None]:
from typing import Literal

import networkx as nx
import nest_asyncio
import optuna
from optuna.integration import PyTorchLightningPruningCallback
from optuna.visualization import plot_param_importances, plot_contour
import pandas as pd
from plotly.graph_objects import Figure
import lightning.pytorch as pl
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
from lightning.pytorch.loggers import WandbLogger
from sklearn.preprocessing import LabelEncoder, StandardScaler
import torch
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.utils.convert import from_networkx
import wandb

from modules import dataset, graph, model

In [None]:
# Patch asyncio so that event loops can be nested: the notebook kernel already
# runs its own loop, which would otherwise prevent asyncio code from being run here
nest_asyncio.apply()

# Task 2

## Dataset

In [None]:
# Load the dataset

# Preprocessors: both are fitted on the training split only and then reused,
# unchanged, to transform the validation split
scaler: StandardScaler = StandardScaler()
label_encoder: LabelEncoder = LabelEncoder()

# Prepare the training set
train_set: pd.DataFrame = dataset.get_page_len_dataset('train')
train_features: pd.DataFrame = train_set.drop(columns=['label'])
train_x: pd.DataFrame = pd.DataFrame(
    scaler.fit_transform(train_features),
    columns=train_features.columns,
)
train_y: torch.Tensor = torch.tensor(label_encoder.fit_transform(train_set['label']))

# Prepare the validation set
valid_set: pd.DataFrame = dataset.get_page_len_dataset('valid')
# Align the validation columns (set and order) with the training ones
valid_features: pd.DataFrame = (
    valid_set
    .drop(columns=['label'])
    .reindex(columns=train_x.columns)
)
valid_x: pd.DataFrame = pd.DataFrame(
    scaler.transform(valid_features),
    columns=valid_features.columns,
)
valid_y: torch.Tensor = torch.tensor(label_encoder.transform(valid_set['label']))

In [None]:
# Create the graphs

# Parameters
mode: Literal['iou', 'correlation'] = 'correlation'
treshold: float = 0.5

# Settings shared by both splits, so that the two sets of graphs are built the same way
similarity_kwargs: dict = {'similarity_threshold': treshold, 'mode': mode}

# Only the training call passes `show=True`
training_graphs: list[nx.Graph] = graph.get_similarity_graphs(train_x, show=True, **similarity_kwargs)
validation_graphs: list[nx.Graph] = graph.get_similarity_graphs(valid_x, **similarity_kwargs)

In [None]:
# Make the graphs compatible with PyTorch Geometric

def _to_pyg_dataset(nx_graphs: list[nx.Graph], labels: torch.Tensor) -> list[Data]:
    """
    Convert NetworkX graphs into PyTorch Geometric `Data` objects.

    Args:
        nx_graphs: Graphs to convert, one per sample.
        labels: Labels paired positionally with `nx_graphs`; each one is stored
            as the `y` attribute of the matching `Data` object.

    Returns:
        The converted samples, in the same order as `nx_graphs`.
    """
    samples: list[Data] = []
    # The loop variable is deliberately NOT called `graph`: that name is the
    # imported `modules.graph` module, and rebinding it would break any later
    # (re-)run of the cell that calls `graph.get_similarity_graphs`
    for nx_graph, label in zip(nx_graphs, labels):
        sample: Data = from_networkx(nx_graph)
        sample.y = label
        samples.append(sample)
    return samples


train_data: list[Data] = _to_pyg_dataset(training_graphs, train_y)
valid_data: list[Data] = _to_pyg_dataset(validation_graphs, valid_y)

## Model

### Hyperparameter tuning

In [None]:
def objective(trial: optuna.Trial) -> float:
    """
    Objective function for Optuna to optimize the hyperparameters.

    Samples one hyperparameter configuration, trains a GCN on `train_data`
    with early stopping and Optuna pruning (both monitoring `val_loss`), and
    scores the trained model on `valid_data`.

    Args:
        trial: Optuna trial used to sample the hyperparameters and to report
            the intermediate `val_loss` values to the pruner.

    Returns:
        The `val_loss` measured on the validation set after training.

    Raises:
        optuna.TrialPruned: If the pruner decides to stop this trial early.
    """

    # Hyperparameters (the batch size is searched as a power of two: 16 to 128)
    batch_size: int = 2 ** trial.suggest_int('batch_size_exp', 4, 7)
    hidden_dim: int = trial.suggest_int('hidden_dim', 16, 128, step = 16)
    lr: float = trial.suggest_float('lr', 1e-5, 1e-1, log = True)
    dropout: float = trial.suggest_float('dropout', 0., 0.5)
    smoothing: float = trial.suggest_float('smoothing', 0., 0.4)

    # Data loaders
    train_loader: DataLoader = DataLoader(train_data,
                                          batch_size = batch_size,
                                          shuffle = True
                                          )
    valid_loader: DataLoader = DataLoader(valid_data,
                                          batch_size = batch_size,
                                          )

    # Model
    # NOTE(review): the first argument (1) is presumably the input feature size - confirm against model.GCN
    gcn: model.GCN = model.GCN(1, hidden_dim, lr = lr, dropout = dropout, smoothing = smoothing)

    # Callbacks
    early_stopping: EarlyStopping = EarlyStopping(monitor='val_loss', patience=5)
    pruning_callback: PyTorchLightningPruningCallback = PyTorchLightningPruningCallback(trial, monitor='val_loss')

    # Trainer
    # Checkpointing is disabled: trial models are never reloaded, and the default
    # checkpoint callback would otherwise write files for every trial (possibly
    # from several trials at once when the study runs with multiple jobs)
    trainer: pl.Trainer = pl.Trainer(callbacks = [early_stopping, pruning_callback],
                                     logger = False,
                                     enable_checkpointing = False,
                                     enable_progress_bar = False,
                                     enable_model_summary = False
                                     )

    # Train the model
    trainer.fit(gcn, train_loader, valid_loader)

    # Propagate a pruning decision taken during training to Optuna
    pruning_callback.check_pruned()

    # Evaluate the model
    loss: float = trainer.validate(gcn, valid_loader, verbose=False)[0]['val_loss']

    return loss

In [None]:
# Optuna study: minimizes the validation loss returned by `objective`
# NOTE(review): MedianPruner's default `n_startup_trials` is 5, so with only
# 3 trials no pruning is expected to happen - confirm this is intended
median_pruner: optuna.pruners.MedianPruner = optuna.pruners.MedianPruner()
study: optuna.Study = optuna.create_study(direction='minimize', pruner=median_pruner)
study.optimize(objective, n_trials=3, n_jobs=-1)

In [None]:
# Plot parameter importances (evaluated a single time and stored in one figure)
param_fig: Figure = plot_param_importances(study)
param_fig.update_layout(autosize=False,
                        width=1200,
                        height=400
                        )
param_fig.show()

# Plot contour
contour_fig: Figure = plot_contour(study)
contour_fig.update_layout(autosize=False,
                          width=1200,
                          height=1200
                          )
contour_fig.show()

### Best Model

In [None]:
# Take the best hyperparameters and print them
best_trial: optuna.trial.FrozenTrial = study.best_trial
best_params: dict[str, int|float] = best_trial.params

# Unpack them into the values used to retrain the final model
# (the batch size was searched as an exponent of two, so it is rebuilt here)
best_batch_size: int = int(2 ** best_params['batch_size_exp'])
best_hidden_dim: int = int(best_params['hidden_dim'])
best_lr: float = best_params['lr']
best_dropout: float = best_params['dropout']
best_smoothing: float = best_params['smoothing']

print(f"Best hyperparameters:\n\tbatch_size: {best_batch_size}\n\thidden_dim: {best_hidden_dim}\n\tlr: {best_lr:.3e}\n\tdropout: {best_dropout:.3e}\n\tsmoothing: {best_smoothing:.3e}")

In [None]:
# Retrain the model with the best hyperparameters

# Data loaders
train_loader: DataLoader = DataLoader(train_data,
                                      batch_size = best_batch_size,
                                      shuffle = True
                                      )
valid_loader: DataLoader = DataLoader(valid_data,
                                      batch_size = best_batch_size
                                      )

# Model
gcn: model.GCN = model.GCN(1, best_hidden_dim, lr = best_lr, dropout = best_dropout, smoothing = best_smoothing)

# Wandb logger
wandb_logger: WandbLogger = WandbLogger(project='MNLP_HW_1', name='best_model', dir='wandb')

# Callbacks
# Lightning appends the '.ckpt' extension on its own, so `filename` must not
# include it (otherwise the file is saved as 'best_model.ckpt.ckpt')
checkpoint: ModelCheckpoint = ModelCheckpoint(monitor='val_loss', filename='best_model')
early_stopping: EarlyStopping = EarlyStopping(monitor='val_loss', patience=5)

# Trainer
# Logging every `len(train_loader)` steps, i.e. the number of batches in one training epoch
trainer: pl.Trainer = pl.Trainer(callbacks = [early_stopping, checkpoint],
                                 logger = wandb_logger,
                                 log_every_n_steps = len(train_loader)
                                 )

# Train the model
trainer.fit(gcn, train_loader, valid_loader)

# Close wandb and remove the logger from the trainer
trainer.logger = None
wandb.finish()

In [None]:
# Load the best model from the path tracked by the `checkpoint` callback
best_ckpt_path: str = checkpoint.best_model_path
gcn: model.GCN = model.GCN.load_from_checkpoint(best_ckpt_path)

# Evaluate the restored model on the validation set
trainer.validate(gcn, valid_loader)