In [None]:
import networkx as nx
import nest_asyncio
import optuna
from optuna.integration import PyTorchLightningPruningCallback
from optuna.visualization import plot_param_importances, plot_contour
import pandas as pd
from plotly.graph_objects import Figure
import lightning.pytorch as pl
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
from lightning.pytorch.loggers import WandbLogger
from sklearn.preprocessing import LabelEncoder, MaxAbsScaler
import torch
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.utils.convert import from_networkx
import wandb

from modules import dataset, graph, model, paths

In [None]:
# Patch asyncio so an event loop can be re-entered from inside the notebook's own loop
# (NOTE(review): presumably required by one of the libraries used below — confirm which)
nest_asyncio.apply()

# Task 2

## Dataset

In [None]:
# Load the dataset and turn each split into scaled features + encoded labels.
# Both preprocessors are fitted on the training split only and then reused unchanged
# on the validation split, so no validation statistics leak into preprocessing.
scaler: MaxAbsScaler = MaxAbsScaler()
label_encoder: LabelEncoder = LabelEncoder()

# Training split: fit + transform
train_set: pd.DataFrame = dataset.get_page_len_dataset('train')
train_x: pd.DataFrame = (
    train_set
    .drop(columns=['label'])
    .pipe(lambda features: pd.DataFrame(scaler.fit_transform(features), columns=features.columns))
)
train_y: torch.Tensor = torch.tensor(label_encoder.fit_transform(train_set['label']))

# Validation split: align its columns with the training features, then transform only
valid_set: pd.DataFrame = dataset.get_page_len_dataset('valid')
valid_x: pd.DataFrame = (
    valid_set
    .drop(columns=['label'])
    .reindex(columns=train_x.columns)
    .pipe(lambda features: pd.DataFrame(scaler.transform(features), columns=features.columns))
)
valid_y: torch.Tensor = torch.tensor(label_encoder.transform(valid_set['label']))

In [None]:
# Create the graphs

# Fit the similarity graph on the scaled training features.
# NOTE(review): k=5 is forwarded to graph.SimilarityGraph (presumably neighbours per node)
# and show=True presumably displays it — confirm in modules.graph.
similarity_graph: graph.SimilarityGraph = graph.SimilarityGraph(train_x, k=5, show=True)

# Derive the graphs for each split (training first, then validation);
# missing feature values are replaced with 0 before being passed to get_graphs
training_graphs: list[nx.DiGraph]
validation_graphs: list[nx.DiGraph]
training_graphs, validation_graphs = (
    similarity_graph.get_graphs(split_x.fillna(0)) for split_x in (train_x, valid_x)
)

In [None]:
# Make the graphs compatible with PyTorch Geometric

def to_pyg_dataset(graphs: list[nx.DiGraph], labels: torch.Tensor) -> list[Data]:
    """
    Convert NetworkX graphs into PyTorch Geometric ``Data`` objects labelled per graph.

    Args:
        graphs: The NetworkX graphs of one split.
        labels: The encoded labels of the same split, aligned positionally with ``graphs``.

    Returns:
        One ``Data`` object per graph, with ``y`` set to the matching label.

    Raises:
        ValueError: If ``graphs`` and ``labels`` do not have the same length.
    """
    pyg_graphs: list[Data] = []
    # strict=True: a length mismatch means graphs and labels are misaligned, which a
    # plain zip would hide by silently dropping the unmatched tail
    for digraph, label in zip(graphs, labels, strict = True):
        data: Data = from_networkx(digraph)
        data.y = label
        pyg_graphs.append(data)
    return pyg_graphs


train_data: list[Data] = to_pyg_dataset(training_graphs, train_y)
valid_data: list[Data] = to_pyg_dataset(validation_graphs, valid_y)

## Model

### Hyperparameter tuning

In [None]:
def objective(trial: optuna.Trial) -> float:
    """
    Objective function for Optuna to optimize the hyperparameters.

    Samples a GCN configuration from ``trial`` and trains it on ``train_data``. Training
    stops when early stopping on ``val_loss`` triggers or when the pruner stops the trial.
    The function then returns the validation F1 score, which the study maximizes.

    Args:
        trial: The Optuna trial used to sample hyperparameters and to report
            intermediate values to the pruner.

    Returns:
        The ``val_f1`` metric returned by ``trainer.validate`` on ``valid_data``.

    Raises:
        optuna.TrialPruned: When the pruner decides to stop this trial
            (raised through the pruning callback).
    """

    # Hyperparameters
    first_layer_channels: int = trial.suggest_int('first_layer_channels', 32, 256, step = 32)
    lr: float = trial.suggest_float('lr', 1e-5, 1e-1, log = True)
    dropout: float = trial.suggest_float('dropout', 0., 0.5)

    # Data loaders
    train_loader: DataLoader = DataLoader(train_data, batch_size = 256, shuffle = True)
    valid_loader: DataLoader = DataLoader(valid_data, batch_size = 256)

    # Model
    gcn: model.GCN = model.GCN(first_layer_channels, lr = lr, dropout = dropout)

    # Callbacks
    early_stopping: EarlyStopping = EarlyStopping(monitor = 'val_loss', patience = 10)
    # The pruner interprets intermediate values using the study direction ('maximize').
    # It must therefore see the maximized metric (val_f1). Reporting val_loss would make
    # it treat low losses as bad and prune the best trials.
    pruning_callback: PyTorchLightningPruningCallback = PyTorchLightningPruningCallback(trial, monitor = 'val_f1')

    # Trainer
    trainer: pl.Trainer = pl.Trainer(max_epochs = -1,
                                     callbacks = [early_stopping, pruning_callback],
                                     logger = False,
                                     enable_progress_bar = False,
                                     enable_model_summary = False,
                                     enable_checkpointing = False,
                                     precision = '16-mixed'
                                     )

    # Train the model
    trainer.fit(gcn, train_loader, valid_loader)

    pruning_callback.check_pruned()

    # Evaluate the model
    # NOTE(review): this scores the weights from the last trained epoch. With early stopping
    # (patience=10) those typically come several epochs after the best val_loss. Consider
    # checkpointing the best epoch if the returned score should reflect it.
    f1: float = trainer.validate(gcn, valid_loader, verbose = False)[0]['val_f1']

    return f1

In [None]:
# Optuna study: maximize the validation F1 score returned by `objective`.
# The median pruner ignores the first 10 reported steps of each trial before pruning.
# NOTE(review): with n_jobs=-1, many of the 10 trials may run concurrently. The pruner and
# sampler then have few completed trials to compare against — confirm this is intended.
median_pruner = optuna.pruners.MedianPruner(n_warmup_steps=10)
study: optuna.Study = optuna.create_study(direction='maximize', pruner=median_pruner)
study.optimize(objective, n_trials=10, n_jobs=-1)

In [None]:
# Plot parameter importances
# Computing parameter importances requires evaluating the whole study, so build the figure once
param_fig: Figure = plot_param_importances(study)
param_fig.update_layout(autosize = False,
                        width = 1200,
                        height = 400
                        )
param_fig.show()

# Plot contour
contour_fig: Figure = plot_contour(study)
contour_fig.update_layout(autosize = False,
                          width = 1200,
                          height = 1200
                          )
contour_fig.show()

### Best Model

In [None]:
# Take the best hyperparameters found by the study and print them
best_params: dict[str, int|float] = study.best_trial.params

# Unpack them into the typed values used to retrain the model in the next cell
best_channels: int = int(best_params['first_layer_channels'])
best_lr: float = best_params['lr']
best_dropout: float = best_params['dropout']

print(f"Best hyperparameters:\n\tfirst layer channels: {best_channels}\n\tlearning rate: {best_lr:.3e}\n\tdropout: {best_dropout:.3e}")

In [None]:
# Retrain the model with the best hyperparameters

# Seed Python/NumPy/PyTorch RNGs so the retrained model (shuffling, weight init) is reproducible
pl.seed_everything(42, workers = True)

# Data loaders
train_loader: DataLoader = DataLoader(train_data,
                                      batch_size = 256,
                                      shuffle = True
                                      )
valid_loader: DataLoader = DataLoader(valid_data,
                                      batch_size = 256
                                      )

# Model
gcn: model.GCN = model.GCN(best_channels, lr = best_lr, dropout = best_dropout)

# Wandb logger
wandb_logger: WandbLogger = WandbLogger(project = 'MNLP_HW_1', name = 'best_model', save_dir = paths.DATA_DIR)

# Callbacks
early_stopping: EarlyStopping = EarlyStopping(monitor = 'val_loss', patience = 10)
checkpoint: ModelCheckpoint = ModelCheckpoint(monitor = 'val_loss')

# Trainer
trainer: pl.Trainer = pl.Trainer(max_epochs = -1,
                                 callbacks = [early_stopping, checkpoint],
                                 logger = wandb_logger,
                                 log_every_n_steps = len(train_loader)
                                 )

# Train the model
try:
    trainer.fit(gcn, train_loader, valid_loader)
finally:
    # Close wandb and remove the logger from the trainer, even if training fails.
    # Later trainer.validate calls then do not log to a finished run, and no run is left open.
    trainer.logger = None
    wandb.finish()

In [None]:
# Reload the checkpoint that `checkpoint` (monitoring 'val_loss') marked as best during
# retraining, then score it on the validation set.
# NOTE(review): GCN.load_from_checkpoint can only rebuild the model without extra arguments
# if GCN stores its constructor hyperparameters in the checkpoint — confirm in modules.model.
best_checkpoint_path: str = checkpoint.best_model_path
gcn: model.GCN = model.GCN.load_from_checkpoint(best_checkpoint_path)

trainer.validate(gcn, valid_loader)