In [None]:
from pathlib import Path

import networkx as nx
import nest_asyncio
import numpy as np
from numpy.typing import NDArray
import optuna
from optuna.integration import PyTorchLightningPruningCallback
from optuna.visualization import plot_param_importances, plot_contour
import pandas as pd
from pandas import Index
from plotly.graph_objects import Figure
import lightning.pytorch as pl
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
from lightning.pytorch.loggers import WandbLogger
from sklearn.preprocessing import LabelEncoder, MinMaxScaler, OneHotEncoder
import torch
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.utils.convert import from_networkx

from modules import dataset, graph, models, paths, utils

In [None]:
# Patch asyncio so that an event loop can be started while another one is
# already running, which is the situation inside a notebook kernel.
# NOTE(review): which later call needs this is not visible here — confirm it is still required.
nest_asyncio.apply()

## Task 2

### Data

In [None]:
# Load the training and validation splits
train_set: pd.DataFrame = dataset.prepare_dataset('train')
valid_set: pd.DataFrame = dataset.prepare_dataset('valid')

# Separate the target column ('label') from the feature columns
train_y: pd.Series = train_set['label']
valid_y: pd.Series = valid_set['label']
train_x: pd.DataFrame = train_set.drop('label', axis = 'columns')
valid_x: pd.DataFrame = valid_set.drop('label', axis = 'columns')

In [None]:
# Create the similarity graph from the '*_length' columns of the training features.
# The '_length' marker is stripped from the column names before they are handed
# to the graph builder.
graph_creation_df: pd.DataFrame = (
    train_x
    .filter(regex = '_length$')
    .rename(columns = lambda column_name: column_name.replace('_length', ''))
)
similarity_graph: graph.SimilarityGraph = graph.SimilarityGraph(
    graph_creation_df,
    threshold = 0.6,
    connected = True,
    show = True,
)

In [None]:
# Preprocessing
# Every transformer is fitted on the training split only and then applied,
# unchanged, to both the training and the validation split.

# Column groups, decided by dtype on the training features
categorical_columns: Index = train_x.select_dtypes(exclude = 'number').columns
numerical_columns: Index = train_x.select_dtypes(include = 'number').columns

# One hot encode categorical columns
one_hot_encoder: OneHotEncoder = OneHotEncoder(sparse_output = False, handle_unknown = 'ignore').fit(train_x[categorical_columns])
encoded_columns: NDArray[str] = one_hot_encoder.get_feature_names_out(categorical_columns)  # type: ignore
train_x_encoded: pd.DataFrame = pd.DataFrame(
    one_hot_encoder.transform(train_x[categorical_columns]),    # type: ignore
    columns = encoded_columns,
)
valid_x_encoded: pd.DataFrame = pd.DataFrame(
    one_hot_encoder.transform(valid_x[categorical_columns]),    # type: ignore
    columns = encoded_columns,
)

# Scale numerical columns and append them to the encoded frames
scaler: MinMaxScaler = MinMaxScaler().fit(train_x[numerical_columns])
train_x_encoded[numerical_columns] = scaler.transform(train_x[numerical_columns])
valid_x_encoded[numerical_columns] = scaler.transform(valid_x[numerical_columns])

# Encode labels as integer class indices
label_encoder: LabelEncoder = LabelEncoder().fit(train_y)
train_y_encoded: torch.Tensor = torch.tensor(label_encoder.transform(train_y))
valid_y_encoded: torch.Tensor = torch.tensor(label_encoder.transform(valid_y))

In [None]:
# Split the data in global and local features.
# Global features: the one-hot encoded columns plus 'sitelinks_count'.
# Local features: every remaining column.
global_columns: list[str] = [*encoded_columns.tolist(), 'sitelinks_count']

global_train_x: pd.DataFrame = train_x_encoded.loc[:, global_columns]
local_train_x: pd.DataFrame = train_x_encoded.drop(columns = global_columns)

global_valid_x: pd.DataFrame = valid_x_encoded.loc[:, global_columns]
local_valid_x: pd.DataFrame = valid_x_encoded.drop(columns = global_columns)

In [None]:
# Give the data a suitable format for PyTorch Geometric

def _build_graph_dataset(graphs: list[nx.Graph], global_x_tensor: torch.Tensor, labels: torch.Tensor) -> list[Data]:
    """
    Convert per-sample graphs into PyTorch Geometric `Data` objects.

    Each graph is converted with `from_networkx`; the matching row of global
    features is attached as `x_fc` and the matching label as `y`.

    Args:
        graphs: One networkx graph per sample.
        global_x_tensor: Global features, one row per sample.
        labels: Encoded labels, one entry per sample.

    Returns:
        The list of `Data` objects, in the same order as the inputs.

    Raises:
        ValueError: If the three inputs do not have the same length. A plain
            `zip` would silently truncate and misalign samples and labels.
    """
    samples: list[Data] = []
    for graph_item, global_features, label in zip(graphs, global_x_tensor, labels, strict = True):
        sample: Data = from_networkx(graph_item)
        sample.x_fc = global_features
        sample.y = label
        samples.append(sample)
    return samples


training_graphs: list[nx.Graph] = similarity_graph.get_graphs(local_train_x)
validation_graphs: list[nx.Graph] = similarity_graph.get_graphs(local_valid_x)

global_train_x_tensor: torch.Tensor = torch.from_numpy(global_train_x.to_numpy(dtype = np.float32))
train_data: list[Data] = _build_graph_dataset(training_graphs, global_train_x_tensor, train_y_encoded)

global_valid_x_tensor: torch.Tensor = torch.from_numpy(global_valid_x.to_numpy(dtype = np.float32))
valid_data: list[Data] = _build_graph_dataset(validation_graphs, global_valid_x_tensor, valid_y_encoded)

# Create the dataloaders (only the training loader is shuffled)
train_loader: DataLoader = DataLoader(train_data, batch_size = 256, shuffle = True)
valid_loader: DataLoader = DataLoader(valid_data, batch_size = 256)

In [None]:
# Get the number of features for the model, read off the first training sample

first_sample: Data = train_data[0]
n_global_features: int = first_sample.x_fc.shape[0]      # length of the global feature vector
n_local_features: int = first_sample.x_graph.shape[1]    # assumes x_graph is (nodes, features) — TODO confirm
n_classes: int = len(label_encoder.classes_)

### Model

#### Hyperparameter Tuning

In [None]:
# Checkpoint callback created once, outside the objective, so the same instance
# is shared by every trial of the hyperparameter search and tracks the highest
# 'val_f1' it has seen.
best_overall_checkpoint: ModelCheckpoint = ModelCheckpoint(
    monitor = 'val_f1',
    mode = 'max',
    dirpath = paths.GRAPH_MODEL_DIR,
    filename = 'graph_{epoch}',
)

In [None]:
def objective(trial: optuna.Trial) -> float:
    """
    Objective function for Optuna to optimize the hyperparameters.

    Samples `inner_dim`, `depth` and `lr`, trains a `models.GraphNet` with early
    stopping on 'val_f1', and reports the best validation F1 of the run.

    Args:
        trial: The Optuna trial used to sample hyperparameters and report pruning.

    Returns:
        The maximum 'val_f1' recorded in the wandb run summary.

    Raises:
        optuna.TrialPruned: Presumably raised by the pruning callback when the
            trial is pruned; it is propagated unchanged so Optuna can record it.
    """

    # Hyperparameters
    inner_dim: int = trial.suggest_int('inner_dim', 8, 512, log = True)
    depth: int = trial.suggest_int('depth', 1, 5)
    lr: float = trial.suggest_float('lr', 1e-5, 1e-1, log = True)

    # Model
    model: models.GraphNet = models.GraphNet(fc_features = n_global_features,
                                             node_features = n_local_features,
                                             n_classes = n_classes,
                                             inner_dim = inner_dim,
                                             depth = depth,
                                             lr = lr
                                             )

    # Callbacks
    early_stopping: EarlyStopping = EarlyStopping(monitor = 'val_f1', mode = 'max', patience = 5)
    pruning_callback: PyTorchLightningPruningCallback = PyTorchLightningPruningCallback(trial, monitor = 'val_f1')

    # Wandb logger
    wandb_logger: WandbLogger = utils.configure_wandb_logger(project = 'Cultural classification on graphs',
                                                            name = f'trial_{trial.number}',
                                                            config = {'inner_dim': inner_dim,
                                                                      'depth': depth,
                                                                      'lr': lr
                                                                      }
                                                            )

    # Trainer (max_epochs = -1: no epoch limit, training ends through early stopping or pruning)
    trainer: pl.Trainer = pl.Trainer(max_epochs = -1,
                                     callbacks = [early_stopping, pruning_callback, best_overall_checkpoint],
                                     precision = '16-mixed',
                                     logger = wandb_logger,
                                     log_every_n_steps = len(train_loader),
                                     enable_progress_bar = False,
                                     enable_model_summary = False
                                     )

    try:
        # Train the model
        trainer.fit(model, train_loader, valid_loader)
        pruning_callback.check_pruned()

        # Evaluate the model on the best epoch
        # NOTE(review): assumes the logger is configured to track the max of 'val_f1' in the summary — confirm in utils
        f1: float = wandb_logger.experiment.summary['val_f1']['max']
    finally:
        # Close the wandb run even when the trial is pruned or training fails,
        # so an unfinished run is not left open for the next trial.
        wandb_logger.experiment.finish()

    return f1

In [None]:
# Optuna study: maximise the value returned by `objective` (best validation F1)
# NOTE(review): with only 2 trials the median pruner probably never has enough
# finished trials to prune anything — confirm n_trials is intentional.
median_pruner: optuna.pruners.MedianPruner = optuna.pruners.MedianPruner(n_warmup_steps = 5)
study: optuna.Study = optuna.create_study(
    direction = 'maximize',
    pruner = median_pruner,
)
study.optimize(objective, n_trials = 2)

In [None]:
# Rename the best model checkpoint
# The shared checkpoint callback and the number of the best trial are passed to
# the helper; the returned path is used below to load the model.
# NOTE(review): presumably the helper renames the file on disk to include the trial number — verify in utils.
best_model_path: Path = utils.rename_best_checkpoint(best_overall_checkpoint, study.best_trial.number)

In [None]:
# Plot parameter importances
# The figure is computed once; the original code called plot_param_importances
# twice and never used the first result.
param_fig: Figure = plot_param_importances(study)
param_fig.update_layout(autosize = False,
                        width = 1200,
                        height = 400
                        )
param_fig.show()

# Kept as an alias of the same figure so that any cell still referring to the
# old name keeps working.
param_importances_fig: Figure = param_fig

# Plot contour
contour_fig: Figure = plot_contour(study)
contour_fig.update_layout(autosize = False,
                          width = 1200,
                          height = 1200
                          )
contour_fig.show()

##### Best Model

In [None]:
# Take the hyperparameters of the best trial
best_params: dict[str, int|float] = study.best_trial.params

# Extract each value with its expected type (nothing is retrained here)
best_inner_dim: int
best_depth: int
best_inner_dim, best_depth = (int(best_params[key]) for key in ('inner_dim', 'depth'))
best_lr: float = best_params['lr']

# Report them
print(f"""Best hyperparameters:
      \tlayer dimension: {best_inner_dim}
      \tdepth: {best_depth}
      \tlearning rate: {best_lr:.3e}"""
      )

In [None]:
# Load the best model
# The original line was annotated as `models: models.GraphNet = ...`. At module
# level the assignment happens before the annotation is evaluated, so the
# annotation looked up `GraphNet` on the freshly loaded model instead of on the
# module and failed on interpreters that evaluate annotations eagerly. The
# annotation is therefore dropped.
# NOTE(review): this still rebinds `models` from the imported module to the
# loaded network, so this cell and the `objective` cell cannot be re-run
# afterwards without re-importing. The name is kept because later cells may
# rely on it; consider renaming it (e.g. `best_model`) across the notebook.
models = models.GraphNet.load_from_checkpoint(best_model_path)

# Evaluate the model on the validation set
trainer: pl.Trainer = pl.Trainer(max_epochs = -1, precision = '16-mixed', logger = False)
trainer.validate(models, valid_loader)

In [None]:
# Confusion matrix on the validation set
# The per-batch outputs of `predict` are concatenated and the highest-scoring
# class along dim 1 is taken as the prediction.
# NOTE(review): `models` is the loaded network here, not the imported module.
logits: torch.Tensor = torch.cat(trainer.predict(models, valid_loader))    # type: ignore
predictions_encoded: torch.Tensor = logits.argmax(dim = 1)
utils.plot_confusion_matrix(valid_y_encoded, predictions_encoded, label_encoder)

### Test