# Task 2

In [None]:
from pathlib import Path

import networkx as nx
import nest_asyncio
import numpy as np
from numpy.typing import NDArray
import optuna
from optuna.integration import PyTorchLightningPruningCallback
from optuna.visualization import plot_param_importances, plot_contour
import pandas as pd
from pandas import Index
from plotly.graph_objects import Figure
import lightning.pytorch as pl
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
from lightning.pytorch.loggers import WandbLogger
from sklearn.preprocessing import LabelEncoder, MinMaxScaler, OneHotEncoder
import torch
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.utils.convert import from_networkx

from modules import dataset, graph, models_creation, paths
from modules.utils import model as model_utils

In [None]:
# Patch the notebook's event loop so that asyncio code can be run from within the notebook
nest_asyncio.apply()

## Data

### Basic Operations

In [None]:
# Load the three splits, in the order train -> validation -> test
train_set: pd.DataFrame
val_set: pd.DataFrame
test_set: pd.DataFrame
train_set, val_set, test_set = (dataset.extract_dataset(split_name) for split_name in ('train', 'validation', 'test'))

In [None]:
# Targets of the labelled splits
train_y: pd.Series = train_set['label']
val_y: pd.Series = val_set['label']

# Features: every column except the target. The test split is copied as a whole,
# since it is only used for prediction
train_x: pd.DataFrame = train_set.drop(columns = 'label')
val_x: pd.DataFrame = val_set.drop(columns = 'label')
test_x: pd.DataFrame = test_set.copy()

In [None]:
# Columns not used by this model: the '*_extract' columns (used in the previous task)
# plus the item, name, description and category columns
cols_to_drop: list[str] = train_set.filter(regex = '_extract$').columns.tolist() + ['item', 'name', 'description', 'category']
train_x = train_x.drop(columns = cols_to_drop)
val_x = val_x.drop(columns = cols_to_drop)
test_x = test_x.drop(columns = cols_to_drop)

### Preprocessing

In [None]:
# Feature preprocessing: non-numeric columns are one-hot encoded, numeric columns are min-max scaled.
# Both transformers are fitted on the training split only, then applied to every split.
categorical_columns: Index = train_x.select_dtypes(exclude = 'number').columns
numerical_columns: Index = train_x.select_dtypes(include = 'number').columns

one_hot_encoder: OneHotEncoder = OneHotEncoder(sparse_output = False, handle_unknown = 'ignore')
one_hot_encoder.fit(train_x[categorical_columns])
encoded_columns: NDArray[str] = one_hot_encoder.get_feature_names_out(categorical_columns)  # type: ignore

scaler: MinMaxScaler = MinMaxScaler()
scaler.fit(train_x[numerical_columns])


def _encode_features(frame: pd.DataFrame) -> pd.DataFrame:
    """
    Return a new frame holding the one-hot encoded categorical columns of `frame`,
    followed by its min-max scaled numerical columns (fresh positional index).
    """
    encoded: pd.DataFrame = pd.DataFrame(one_hot_encoder.transform(frame[categorical_columns]),   # type: ignore
                                         columns = encoded_columns
                                         )
    encoded[numerical_columns] = scaler.transform(frame[numerical_columns])
    return encoded


train_x_encoded: pd.DataFrame = _encode_features(train_x)
val_x_encoded: pd.DataFrame = _encode_features(val_x)
test_x_encoded: pd.DataFrame = _encode_features(test_x)

# Map the string labels to integer class indices (classes learned from the training split)
label_encoder: LabelEncoder = LabelEncoder()
train_y_encoded: torch.Tensor = torch.tensor(label_encoder.fit_transform(train_y))
val_y_encoded: torch.Tensor = torch.tensor(label_encoder.transform(val_y))

### Data Structuring

In [None]:
# Build the similarity graph from the '*_length' feature columns, with the '_length' text removed from their names
graph_creation_df: pd.DataFrame = (train_x
                                   .filter(regex = '_length$')
                                   .rename(columns = lambda column: column.replace('_length', ''))
                                   )
similarity_graph: graph.SimilarityGraph = graph.SimilarityGraph(graph_creation_df,
                                                                threshold = 0.6,
                                                                connected = True,
                                                                show = True
                                                                )

In [None]:
# Get the Dataloaders

def get_loader(x: pd.DataFrame, y: torch.Tensor|None, shuffle: bool = False, batch_size: int = 256) -> DataLoader:
    """
    Get the DataLoader for the given data.

    Each sample becomes one `Data` object: a graph returned by `similarity_graph.get_graphs`
    (converted with `from_networkx`), with that sample's global features attached as `x_fc`
    and, when labels are given, its label attached as `y`.

    Args:
        x: encoded feature frame. The one-hot encoded columns and 'sitelinks_count' are used as
            global features; all remaining columns are passed to the similarity graph.
        y: encoded labels, matched positionally with the rows of `x`, or None (e.g. for the test set).
        shuffle: whether the DataLoader reshuffles the samples every epoch.
        batch_size: number of samples per batch.

    Returns:
        A DataLoader over the list of `Data` objects.

    Raises:
        ValueError: if the numbers of graphs, global feature rows and labels do not match.
    """

    # Split the different kinds of columns: these go to the fully connected part as global features,
    # everything else is handed to the similarity graph
    global_columns: list[str] = encoded_columns.tolist() + ['sitelinks_count']

    global_data: torch.Tensor = torch.from_numpy(x[global_columns].to_numpy(dtype = np.float32))
    graphs: list[nx.Graph] = similarity_graph.get_graphs(x.drop(columns = global_columns))

    # Graphs, global features and labels are paired by position, so their lengths must agree
    # (assumes get_graphs returns one graph per row of x — TODO confirm)
    if len(graphs) != len(global_data):
        raise ValueError(f"Got {len(graphs)} graphs but {len(global_data)} rows of global features")
    if y is not None and len(y) != len(graphs):
        raise ValueError(f"Got {len(graphs)} graphs but {len(y)} labels")

    # Create the DataLoader
    data_list: list[Data] = []
    for i, (nx_graph, global_features) in enumerate(zip(graphs, global_data)):
        data: Data = from_networkx(nx_graph)
        data.x_fc = global_features
        if y is not None:
            data.y = y[i]
        data_list.append(data)

    return DataLoader(data_list, batch_size = batch_size, shuffle = shuffle)

train_loader: DataLoader = get_loader(train_x_encoded, train_y_encoded, shuffle = True)
val_loader: DataLoader = get_loader(val_x_encoded, val_y_encoded)
test_loader: DataLoader = get_loader(test_x_encoded, None)

In [None]:
# Get the number of features for the model, read from the first training sample

n_global_features: int = train_loader.dataset[0].x_fc.shape[0]     # length of the global feature vector
n_local_features: int = train_loader.dataset[0].x_graph.shape[1]   # presumably x_graph is (num_nodes, node_features) — TODO confirm
n_classes: int = len(label_encoder.classes_)

print(f"Number of global features: {n_global_features}")
print(f"Number of features for each node: {n_local_features}")

## Model

### Tuning and Training

In [None]:
# Checkpoint callback passed to the Trainer of every tuning trial in `objective`, so it keeps
# the best model (highest 'val_f1') obtained during the hyperparameter search
best_overall_checkpoint: ModelCheckpoint = ModelCheckpoint(dirpath = paths.GRAPH_MODEL_DIR,
                                                           filename = 'graph_{epoch}',
                                                           monitor = 'val_f1',
                                                           mode = 'max'
                                                           )

In [None]:
def objective(trial: optuna.Trial) -> float:
    """
    Objective function for Optuna to optimize the hyperparameters.

    Samples `inner_dim`, `depth` and `lr`, trains a GraphNet with them while logging to wandb,
    and returns the best validation F1 score stored in the wandb run summary.

    Args:
        trial: Optuna trial used to sample the hyperparameters and, through the pruning
            callback, to report intermediate 'val_f1' values.

    Returns:
        The 'max' entry of the 'val_f1' summary of the wandb run.

    Raises:
        optuna.TrialPruned: if the pruner stopped the trial, or if no F1 score can be read
            from the wandb summary.
    """

    # Hyperparameters
    inner_dim: int = trial.suggest_int('inner_dim', 32, 256, log = True)
    depth: int = trial.suggest_int('depth', 2, 5)
    lr: float = trial.suggest_float('lr', 1e-5, 1e-2, log = True)

    # Model
    model: models_creation.GraphNet = models_creation.GraphNet(fc_features = n_global_features,
                                                               node_features = n_local_features,
                                                               n_classes = n_classes,
                                                               inner_dim = inner_dim,
                                                               depth = depth,
                                                               lr = lr
                                                               )

    # Callbacks
    early_stopping: EarlyStopping = EarlyStopping(monitor = 'val_loss', patience = 10)
    pruning_callback: PyTorchLightningPruningCallback = PyTorchLightningPruningCallback(trial, monitor = 'val_f1')

    # Wandb logger
    wandb_logger: WandbLogger = model_utils.configure_wandb_logger(project = 'Cultural classification on graphs', name = f'trial_{trial.number}')

    # Trainer (best_overall_checkpoint is the same object for every trial)
    trainer: pl.Trainer = pl.Trainer(max_epochs = -1,
                                     callbacks = [early_stopping, pruning_callback, best_overall_checkpoint],
                                     precision = '16-mixed',
                                     logger = wandb_logger,
                                     log_every_n_steps = len(train_loader),
                                     enable_progress_bar = False,
                                     enable_model_summary = False
                                     )

    f1: float
    try:
        # Train the model; check_pruned() raises optuna.TrialPruned if the pruner stopped training
        trainer.fit(model, train_loader, val_loader)
        pruning_callback.check_pruned()

        # Evaluate the model on the best epoch
        try:
            f1 = wandb_logger.experiment.summary['val_f1']['max']
        except (KeyError, TypeError) as err:
            # KeyError: the metric or its 'max' entry is missing;
            # TypeError: e.g. 'val_f1' was stored as a plain number without a 'max' entry
            raise optuna.TrialPruned("No f1 score found in wandb logger. Probably something went wrong during training.") from err
    finally:
        # Close the wandb run on every exit path (success, pruned trial, or training error)
        wandb_logger.experiment.finish()

    return f1

In [None]:
# Optuna study maximising the value returned by `objective`.
# MedianPruner: no pruning during the first 10 trials, nor before step 5 of a trial
study: optuna.Study = optuna.create_study(
    direction = 'maximize',
    pruner = optuna.pruners.MedianPruner(n_warmup_steps = 5, n_startup_trials = 10),
)
study.optimize(objective, n_trials = 50, show_progress_bar = True)

In [None]:
# Rename the checkpoint kept by best_overall_checkpoint, passing the number of the best trial
# (the naming scheme itself is defined in model_utils.rename_best_checkpoint)
best_model_path: Path = model_utils.rename_best_checkpoint(
    best_overall_checkpoint,
    study.best_trial.number,
)

In [None]:
# Plot parameter importances.
# Computed only once: evaluating the importances is the expensive part, and with the default
# (unseeded) importance evaluator a second evaluation may even give slightly different values.
param_importances_fig: Figure = plot_param_importances(study)
param_fig: Figure = param_importances_fig    # same figure object; both names are kept
param_fig.update_layout(autosize = False,
                        width = 1200,
                        height = 400
                        )
param_fig.show()

# Plot contour
contour_fig: Figure = plot_contour(study)
contour_fig.update_layout(autosize = False,
                          width = 1200,
                          height = 1200
                          )
contour_fig.show()

In [None]:
# Extract the hyperparameters of the best Optuna trial and print them
best_params: dict[str, int|float] = study.best_trial.params
best_lr: float = best_params['lr']
best_depth: int = int(best_params['depth'])
best_inner_dim: int = int(best_params['inner_dim'])

print(f"""Best hyperparameters:
      \tlayer dimension: {best_inner_dim}
      \tdepth: {best_depth}
      \tlearning rate: {best_lr:.3e}"""
      )

### Results

#### Validation

In [None]:
# Reload the best checkpoint from tuning and run it on the validation split.
# This trainer is reused by the prediction cells below.
model: models_creation.GraphNet = models_creation.GraphNet.load_from_checkpoint(best_model_path)
trainer: pl.Trainer = pl.Trainer(logger = False,
                                 precision = '16-mixed',
                                 max_epochs = -1
                                 )
trainer.validate(model, val_loader)

In [None]:
# Confusion matrix on the validation split: the predicted class is the argmax of each row of logits
val_logits: torch.Tensor = torch.cat(trainer.predict(model, val_loader))    # type: ignore
val_predictions_encoded: torch.Tensor = val_logits.argmax(dim = 1)
model_utils.plot_confusion_matrix(val_y_encoded, val_predictions_encoded, label_encoder)

#### Test

In [None]:
# Predict the test split and map the class indices back to the original label strings
test_logits: torch.Tensor = torch.cat(trainer.predict(model, test_loader))    # type: ignore
test_predictions_encoded: torch.Tensor = test_logits.argmax(dim = 1)
test_predictions: NDArray[str] = label_encoder.inverse_transform(test_predictions_encoded)    # type: ignore

# Write item, name and predicted label to CSV, using the test set's index as the 'id' column
test_predictions_df: pd.DataFrame = pd.DataFrame({'item': test_set['item'],
                                                  'name': test_set['name'],
                                                  'label': test_predictions
                                                  },
                                                 index = test_set.index
                                                 )
test_predictions_df.to_csv(paths.GRAPH_PREDICTIONS, index_label = 'id')
print(f"Saved the predictions on test set to {paths.GRAPH_PREDICTIONS}")