In [None]:
from typing import Literal

import networkx as nx
import pandas as pd
import pytorch_lightning as pl
from sklearn.preprocessing import LabelEncoder
import torch
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.utils.convert import from_networkx

from modules import dataset, graph, model
from modules.model import LightningGCN

In [None]:
# Load the dataset

# Label encoder: fitted on the training labels only and then reused for the
# validation labels, so both splits share the same label -> class-index mapping.
label_encoder: LabelEncoder = LabelEncoder()

# Prepare the training set
train_set: pd.DataFrame = dataset.prepare_dataset('train')
train_x: pd.DataFrame = train_set.drop(columns=['label'])
# fit_transform == fit + transform in one call. Class-index targets are made
# explicitly int64 (torch.long), the conventional dtype for class indices.
train_y: torch.Tensor = torch.tensor(label_encoder.fit_transform(train_set['label']), dtype=torch.long)

# Prepare the validation set
valid_set: pd.DataFrame = dataset.prepare_dataset('valid')
valid_x: pd.DataFrame = valid_set.drop(columns=['label'])
# transform (not fit): LabelEncoder raises ValueError if the validation split
# contains a label that never appeared in the training split.
valid_y: torch.Tensor = torch.tensor(label_encoder.transform(valid_set['label']), dtype=torch.long)

# The model's input width is later taken from train_x, so the validation
# features must have exactly the same columns in the same order.
assert list(train_x.columns) == list(valid_x.columns), 'train/valid feature columns differ'

In [None]:
# Create the graphs

# Parameters
mode: Literal['iou', 'correlation', 'filtered correlation'] = 'iou'
threshold: float = 0.5  # forwarded as `similarity_threshold` to graph.get_similarity_graph

# Both splits are built with the same mode and threshold so their graphs are
# constructed consistently; save_fig is requested only for the training graph.
training_graph: nx.Graph = graph.get_similarity_graph(train_x, similarity_threshold=threshold, mode=mode, save_fig=True)
validation_graph: nx.Graph = graph.get_similarity_graph(valid_x, similarity_threshold=threshold, mode=mode)

In [None]:
# Create the dataloaders

def _to_labelled_data(nx_graph: 'nx.Graph', labels: torch.Tensor) -> Data:
    """Convert one split's similarity graph to a PyG ``Data`` object carrying node labels.

    Args:
        nx_graph: networkx graph for one split, expected to have one node per sample.
        labels: encoded class labels, one per sample of the split.

    Returns:
        The graph converted with ``from_networkx`` with ``y`` set to ``labels``.

    Raises:
        ValueError: if the graph's node count differs from the number of labels.
    """
    n_nodes = nx_graph.number_of_nodes()
    if n_nodes != len(labels):
        raise ValueError(f'graph has {n_nodes} nodes but {len(labels)} labels were given')
    data = from_networkx(nx_graph)
    # NOTE(review): from_networkx indexes nodes in nx_graph.nodes() iteration order;
    # this assumes that order matches the row order of the split's DataFrame
    # (and therefore of `labels`) — confirm in modules.graph.
    data.y = labels
    return data


# Each loader holds a single full graph (default batch_size=1 -> one batch per epoch).
train_data: Data = _to_labelled_data(training_graph, train_y)
train_loader: DataLoader = DataLoader([train_data], shuffle=True)

valid_data: Data = _to_labelled_data(validation_graph, valid_y)
valid_loader: DataLoader = DataLoader([valid_data])

In [None]:
# Model and trainer configuration
SEED: int = 42
HIDDEN_FEATURES: int = 64
MAX_EPOCHS: int = 100

# Seed Python, NumPy and torch RNGs before the weights are initialised so runs are reproducible.
pl.seed_everything(SEED, workers=True)

# LightningGCN is used directly (not through the `model` module) because the
# instance below is bound to the name `model`, which would otherwise shadow the
# module and make `model.LightningGCN` fail when this cell is re-run.
model = LightningGCN(
    input_feature=len(train_x.columns),
    hidden_feature=HIDDEN_FEATURES,
    # One output per class known to the label encoder instead of a hard-coded 3.
    output_feature=len(label_encoder.classes_),
)

trainer = pl.Trainer(
    max_epochs=MAX_EPOCHS,
    log_every_n_steps=1,  # log more often: each epoch contains only a single batch
)

In [None]:

# Train the model on the training-graph loader, handing Lightning the
# validation-graph loader for its validation loop.
trainer.fit(
    model=model,
    train_dataloaders=train_loader,
    val_dataloaders=valid_loader,
)