In [1]:
# Enable IPython's autoreload extension in mode 2, so that edits to the imported
# project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import random

import pandas as pd

from torch.utils.data import Dataset
import torch.nn.functional as F
import torch.nn as nn

from torch.optim import AdamW

import torch_geometric.transforms as T

from torch_geometric.data import Batch

from torch_geometric.nn import GCNConv, global_mean_pool
from torch_geometric.nn import global_add_pool
from torch_geometric.nn import GraphConv
from torch.utils.data import DataLoader

from pathlib import Path

from tqdm import tqdm

In [3]:
import sys
import os

# The project packages imported by this notebook are resolved through the two
# directories directly above the notebook's working directory.
cwd = os.getcwd()
parent_dir = os.path.dirname(cwd)
parent_parent_dir = os.path.dirname(parent_dir)

# Register each directory only once: re-running this cell (or a cwd whose
# parent and grand-parent coincide) must not keep growing sys.path with
# duplicate entries.
for _project_path in (parent_dir, parent_parent_dir):
    if _project_path not in sys.path:
        sys.path.append(_project_path)

from DataPipeline.dataset import ZincSubgraphDatasetStep, custom_collate_GNN3
from Model.GNN3 import ModelWithEdgeFeatures
from Model.metrics import pseudo_accuracy_metric, pseudo_recall_for_each_class, pseudo_precision_for_each_class, MaskedCrossEntropyLoss, pseudo_accuracy_metric_gnn3

In [4]:
# Location of the preprocessed graph file, expressed relative to the notebook's
# working directory.
datapath = Path('..').joinpath('../DataPipeline/data/preprocessed_graph_no_I_Br_P.pt')

# Build the step-wise subgraph dataset in its GNN3 configuration.
dataset = ZincSubgraphDatasetStep(GNN_type=3, data_path=datapath)

Dataset encoded with size 7


In [5]:
# Shuffled mini-batches of 128 samples, assembled by the project's GNN3 collate function.
loader = DataLoader(
    dataset,
    collate_fn=custom_collate_GNN3,
    shuffle=True,
    batch_size=128,
)

In [6]:
# Number of input features per node; matches the "Dataset encoded with size 7"
# message printed when the dataset is loaded.
encoding_size = 7

# Train on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = ModelWithEdgeFeatures(
    in_channels=encoding_size,
    hidden_channels_list=[64, 128, 128, 64, 32, 5],
    edge_channels=4,
    use_dropout=False,
).to(device)

optimizer = AdamW(model.parameters(), lr=1e-4)

# criterion = MaskedCrossEntropyLoss()  # not needed, there is a simpler way
criterion = nn.CrossEntropyLoss()

# Run tag embedded in the checkpoint and training-history file names.
name = 'GNN3'



In [7]:
from tqdm.notebook import tqdm as tqdm_notebook
import numpy as np

def _safe_ratio(numerator, denominator):
    """Return ``numerator / denominator``, or ``0.0`` when the denominator is zero.

    Assumes both arguments are scalar counts (Python numbers or one-element
    tensors) -- TODO confirm against ``pseudo_accuracy_metric_gnn3``.
    """
    return numerator / denominator if denominator else 0.0


def _cycle_scores(created, well_placed, well_typed, shouldnt_created, wanted):
    """Derive (precision, recall, recall_placed, recall_type, f1) from cycle counters.

    Every ratio falls back to ``0.0`` instead of dividing by zero, which happens
    whenever no cycle has been predicted or wanted yet.
    """
    precision = _safe_ratio(created, created + shouldnt_created)
    recall = _safe_ratio(created, wanted)
    recall_placed = _safe_ratio(well_placed, wanted)
    recall_type = _safe_ratio(well_typed, wanted)
    # Harmonic mean of precision and recall; undefined (reported as 0) if either is 0.
    f1_score = 2 / (1 / precision + 1 / recall) if precision and recall else 0.0
    return precision, recall, recall_placed, recall_type, f1_score


def train(loader, epoch):
    """Run one training epoch over ``loader`` and return its aggregate statistics.

    Uses the notebook-level ``model``, ``optimizer``, ``criterion`` and ``device``.

    Args:
        loader: Iterable of batches, each indexed as ``(data, node_labels, mask)``.
            ``len(loader.dataset)`` is used to normalise the epoch loss.
        epoch: Current epoch number. Not used by the computation; kept so the
            call signature stays unchanged.

    Returns:
        An 8-tuple ``(loss, avg_output_vector, avg_label_vector, pseudo_precision,
        pseudo_recall, pseudo_recall_placed, pseudo_recall_type, f1_score)``.
        ``loss`` is the graph-weighted sum of batch losses divided by
        ``len(loader.dataset)``; the two vectors are the masked sums of softmax
        outputs / labels divided by the number of graphs processed; the cycle
        scores are ``0.0`` whenever their denominator is zero.
    """
    model.train()
    total_loss = 0
    total_graphs_processed = 0

    progress_bar = tqdm_notebook(loader, desc="Training", unit="batch")

    # Running sums, kept on the CPU.
    num_output = torch.zeros(5)
    num_labels = torch.zeros(5)
    global_cycles_created = 0
    global_well_placed_cycles = 0
    global_well_type_cycles = 0
    global_cycles_shouldnt_created = 0
    global_num_wanted_cycles = 0

    for batch in progress_bar:
        data = batch[0].to(device)
        node_labels = batch[1].to(device)
        mask = batch[2].to(device)

        optimizer.zero_grad()
        out = model(data)

        # Only the nodes selected by the mask contribute to the loss.
        loss = criterion(out[mask], node_labels[mask])

        # Monitoring only: work on detached logits so the metric computation
        # does not extend the autograd graph.
        softmax_out = F.softmax(out.detach(), dim=1)
        (cycles_created, well_placed_cycles, well_type_cycles, _cycles_missed,
         cycles_shouldnt_created, num_wanted_cycles) = pseudo_accuracy_metric_gnn3(
            data, softmax_out, node_labels, mask)

        num_output += torch.sum(softmax_out[mask], dim=0).cpu()
        num_labels += torch.sum(node_labels[mask], dim=0).detach().cpu()
        global_cycles_created += cycles_created
        global_well_placed_cycles += well_placed_cycles
        global_well_type_cycles += well_type_cycles
        global_cycles_shouldnt_created += cycles_shouldnt_created
        global_num_wanted_cycles += num_wanted_cycles

        loss.backward()
        optimizer.step()

        total_loss += loss.item() * data.num_graphs
        total_graphs_processed += data.num_graphs
        # Running mean loss per graph over everything seen so far in this epoch.
        loss_value = total_loss / total_graphs_processed

        precision, recall, recall_placed, recall_type, f1_score = _cycle_scores(
            global_cycles_created,
            global_well_placed_cycles,
            global_well_type_cycles,
            global_cycles_shouldnt_created,
            global_num_wanted_cycles,
        )
        progress_bar.set_postfix(
            loss=loss_value,
            avg_num_output=num_output / total_graphs_processed,
            avg_num_labels=num_labels / total_graphs_processed,
            pseudo_precision=precision,
            pseudo_recall=recall,
            pseudo_recall_placed=recall_placed,
            pseudo_recall_type=recall_type,
            f1_score=f1_score,
        )

    precision, recall, recall_placed, recall_type, f1_score = _cycle_scores(
        global_cycles_created,
        global_well_placed_cycles,
        global_well_type_cycles,
        global_cycles_shouldnt_created,
        global_num_wanted_cycles,
    )
    return (
        total_loss / len(loader.dataset),
        num_output / total_graphs_processed,
        num_labels / total_graphs_processed,
        precision,
        recall,
        recall_placed,
        recall_type,
        f1_score,
    )

In [8]:
# Create a dataframe to save the training history (one row per epoch).
training_history = pd.DataFrame(columns=['epoch', 'loss', 'avg_output_vector', 'avg_label_vector','pseudo_precision', 'pseudo_recall' , 'pseudo_recall_placed', 'pseudo_recall_type', 'f1_score' ])
n_epochs = 300
save_every_n_epochs = 20

# torch.save does not create missing directories; make sure the target exists
# up front so the first checkpoint cannot fail after many epochs of training.
checkpoint_dir = Path('./history_training')
checkpoint_dir.mkdir(parents=True, exist_ok=True)


def _save_checkpoint(epoch):
    """Write the model and optimizer state for ``epoch`` to disk and return the saved dict."""
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        # Add any other relevant information you want to save here
    }
    # NOTE(review): the file name uses epoch+1 while the stored 'epoch' entry is
    # `epoch`. Kept as-is so previously written / downstream-loaded checkpoint
    # names stay valid -- confirm whether the off-by-one is intended.
    torch.save(checkpoint, checkpoint_dir / f'checkpoint_epoch_{epoch+1}_{name}.pt')
    return checkpoint


for epoch in range(1, n_epochs+1):
    loss, avg_output_vector, avg_label_vector,  pseudo_precision, pseudo_recall , pseudo_recall_placed, pseudo_recall_type, f1_score = train(loader, epoch)
    training_history.loc[epoch] = [epoch, loss, avg_output_vector, avg_label_vector, pseudo_precision, pseudo_recall , pseudo_recall_placed, pseudo_recall_type, f1_score]

    # Save the model and optimizer state every `save_every_n_epochs` epochs.
    if epoch % save_every_n_epochs == 0:
        checkpoint = _save_checkpoint(epoch)

    # Persist the training history after every epoch.
    training_history.to_csv(f"training_history_{name}.csv", index=False)
    print(f'Epoch: {epoch}, Loss: {loss:.8f}')

# Always save a checkpoint at the end of training.
checkpoint = _save_checkpoint(epoch)

Training:   0%|          | 0/1845 [00:00<?, ?batch/s]

Epoch: 1, Loss: 0.18322288


Training:   0%|          | 0/1845 [00:00<?, ?batch/s]

Epoch: 2, Loss: 0.06598326


Training:   0%|          | 0/1845 [00:00<?, ?batch/s]

Epoch: 3, Loss: 0.05588702


Training:   0%|          | 0/1845 [00:00<?, ?batch/s]

Epoch: 4, Loss: 0.05240611


Training:   0%|          | 0/1845 [00:00<?, ?batch/s]

Epoch: 5, Loss: 0.04994475


Training:   0%|          | 0/1845 [00:00<?, ?batch/s]

Epoch: 6, Loss: 0.04829497


Training:   0%|          | 0/1845 [00:00<?, ?batch/s]

Epoch: 7, Loss: 0.04825932


Training:   0%|          | 0/1845 [00:00<?, ?batch/s]

Epoch: 8, Loss: 0.04650305


Training:   0%|          | 0/1845 [00:00<?, ?batch/s]

Epoch: 9, Loss: 0.04675910


Training:   0%|          | 0/1845 [00:00<?, ?batch/s]

Epoch: 10, Loss: 0.04540644


Training:   0%|          | 0/1845 [00:00<?, ?batch/s]

Epoch: 11, Loss: 0.04576287


Training:   0%|          | 0/1845 [00:00<?, ?batch/s]

Epoch: 12, Loss: 0.04487950


Training:   0%|          | 0/1845 [00:00<?, ?batch/s]

Epoch: 13, Loss: 0.04489663


Training:   0%|          | 0/1845 [00:00<?, ?batch/s]

Epoch: 14, Loss: 0.04336099


Training:   0%|          | 0/1845 [00:00<?, ?batch/s]

Epoch: 15, Loss: 0.04410035


Training:   0%|          | 0/1845 [00:00<?, ?batch/s]

Epoch: 16, Loss: 0.04314968


Training:   0%|          | 0/1845 [00:00<?, ?batch/s]

Epoch: 17, Loss: 0.04343711


Training:   0%|          | 0/1845 [00:00<?, ?batch/s]

Epoch: 18, Loss: 0.04315982


Training:   0%|          | 0/1845 [00:00<?, ?batch/s]

Epoch: 19, Loss: 0.04266459


Training:   0%|          | 0/1845 [00:00<?, ?batch/s]

Epoch: 20, Loss: 0.04303292


Training:   0%|          | 0/1845 [00:00<?, ?batch/s]

Epoch: 21, Loss: 0.04258747


Training:   0%|          | 0/1845 [00:00<?, ?batch/s]

Epoch: 22, Loss: 0.04210604


Training:   0%|          | 0/1845 [00:00<?, ?batch/s]

Epoch: 23, Loss: 0.04181298


Training:   0%|          | 0/1845 [00:00<?, ?batch/s]

Epoch: 24, Loss: 0.04195220


Training:   0%|          | 0/1845 [00:00<?, ?batch/s]

Epoch: 25, Loss: 0.04125730


Training:   0%|          | 0/1845 [00:00<?, ?batch/s]

Epoch: 26, Loss: 0.04176331


Training:   0%|          | 0/1845 [00:00<?, ?batch/s]

Epoch: 27, Loss: 0.04140787


Training:   0%|          | 0/1845 [00:00<?, ?batch/s]

Epoch: 28, Loss: 0.04087946


Training:   0%|          | 0/1845 [00:00<?, ?batch/s]

Epoch: 29, Loss: 0.04086518


Training:   0%|          | 0/1845 [00:00<?, ?batch/s]

Epoch: 30, Loss: 0.04162320


Training:   0%|          | 0/1845 [00:00<?, ?batch/s]

Epoch: 31, Loss: 0.04125885


Training:   0%|          | 0/1845 [00:00<?, ?batch/s]

Epoch: 32, Loss: 0.04047946


Training:   0%|          | 0/1845 [00:00<?, ?batch/s]

Epoch: 33, Loss: 0.04114701


Training:   0%|          | 0/1845 [00:00<?, ?batch/s]

Epoch: 34, Loss: 0.04032207


Training:   0%|          | 0/1845 [00:00<?, ?batch/s]

Epoch: 35, Loss: 0.04046144


Training:   0%|          | 0/1845 [00:00<?, ?batch/s]

Epoch: 36, Loss: 0.04046370


Training:   0%|          | 0/1845 [00:00<?, ?batch/s]

Epoch: 37, Loss: 0.04049416


Training:   0%|          | 0/1845 [00:00<?, ?batch/s]

Epoch: 38, Loss: 0.04001467


Training:   0%|          | 0/1845 [00:00<?, ?batch/s]

Epoch: 39, Loss: 0.03995741


Training:   0%|          | 0/1845 [00:00<?, ?batch/s]

Epoch: 40, Loss: 0.03994932


Training:   0%|          | 0/1845 [00:00<?, ?batch/s]

Epoch: 41, Loss: 0.03960233


Training:   0%|          | 0/1845 [00:00<?, ?batch/s]

Epoch: 42, Loss: 0.04019780


Training:   0%|          | 0/1845 [00:00<?, ?batch/s]

Epoch: 43, Loss: 0.03997853


Training:   0%|          | 0/1845 [00:00<?, ?batch/s]

Epoch: 44, Loss: 0.04004659


Training:   0%|          | 0/1845 [00:00<?, ?batch/s]

Epoch: 45, Loss: 0.03972388


Training:   0%|          | 0/1845 [00:00<?, ?batch/s]

Epoch: 46, Loss: 0.03952745


Training:   0%|          | 0/1845 [00:00<?, ?batch/s]

Epoch: 47, Loss: 0.03956178


Training:   0%|          | 0/1845 [00:00<?, ?batch/s]

Epoch: 48, Loss: 0.03943227


Training:   0%|          | 0/1845 [00:00<?, ?batch/s]

Epoch: 49, Loss: 0.03930105


Training:   0%|          | 0/1845 [00:00<?, ?batch/s]

Epoch: 50, Loss: 0.03955404


Training:   0%|          | 0/1845 [00:00<?, ?batch/s]

Epoch: 51, Loss: 0.03977620


Training:   0%|          | 0/1845 [00:00<?, ?batch/s]

Epoch: 52, Loss: 0.03953062


Training:   0%|          | 0/1845 [00:00<?, ?batch/s]

Epoch: 53, Loss: 0.03912178


Training:   0%|          | 0/1845 [00:00<?, ?batch/s]

Epoch: 54, Loss: 0.03918123


Training:   0%|          | 0/1845 [00:00<?, ?batch/s]

Epoch: 55, Loss: 0.03949164


Training:   0%|          | 0/1845 [00:00<?, ?batch/s]

Epoch: 56, Loss: 0.03881284


Training:   0%|          | 0/1845 [00:00<?, ?batch/s]

Epoch: 57, Loss: 0.03787342


Training:   0%|          | 0/1845 [00:00<?, ?batch/s]

Epoch: 58, Loss: 0.03906821


Training:   0%|          | 0/1845 [00:00<?, ?batch/s]

Epoch: 59, Loss: 0.03866520


Training:   0%|          | 0/1845 [00:00<?, ?batch/s]

Epoch: 60, Loss: 0.03892916


Training:   0%|          | 0/1845 [00:00<?, ?batch/s]

Epoch: 61, Loss: 0.03831412


Training:   0%|          | 0/1845 [00:00<?, ?batch/s]

Epoch: 62, Loss: 0.03908969


Training:   0%|          | 0/1845 [00:00<?, ?batch/s]

Epoch: 63, Loss: 0.03839003


Training:   0%|          | 0/1845 [00:00<?, ?batch/s]

Epoch: 64, Loss: 0.03876386


Training:   0%|          | 0/1845 [00:00<?, ?batch/s]

Epoch: 65, Loss: 0.03890966


Training:   0%|          | 0/1845 [00:00<?, ?batch/s]

Epoch: 66, Loss: 0.03831052


Training:   0%|          | 0/1845 [00:00<?, ?batch/s]

Epoch: 67, Loss: 0.03890330


Training:   0%|          | 0/1845 [00:00<?, ?batch/s]

Epoch: 68, Loss: 0.03840472


Training:   0%|          | 0/1845 [00:00<?, ?batch/s]

Epoch: 69, Loss: 0.03882210


Training:   0%|          | 0/1845 [00:00<?, ?batch/s]

Epoch: 70, Loss: 0.03805419


Training:   0%|          | 0/1845 [00:00<?, ?batch/s]

Epoch: 71, Loss: 0.03866469


Training:   0%|          | 0/1845 [00:00<?, ?batch/s]

Epoch: 72, Loss: 0.03914735


Training:   0%|          | 0/1845 [00:00<?, ?batch/s]

Epoch: 73, Loss: 0.03761040


Training:   0%|          | 0/1845 [00:00<?, ?batch/s]

Epoch: 74, Loss: 0.03825606


Training:   0%|          | 0/1845 [00:00<?, ?batch/s]

Epoch: 75, Loss: 0.03822798


Training:   0%|          | 0/1845 [00:00<?, ?batch/s]

Epoch: 76, Loss: 0.03814391


Training:   0%|          | 0/1845 [00:00<?, ?batch/s]

Epoch: 77, Loss: 0.03841567


Training:   0%|          | 0/1845 [00:00<?, ?batch/s]

Epoch: 78, Loss: 0.03827727


Training:   0%|          | 0/1845 [00:00<?, ?batch/s]

Epoch: 79, Loss: 0.03832080


Training:   0%|          | 0/1845 [00:00<?, ?batch/s]

Epoch: 80, Loss: 0.03867932


Training:   0%|          | 0/1845 [00:00<?, ?batch/s]

Epoch: 81, Loss: 0.03776825


Training:   0%|          | 0/1845 [00:00<?, ?batch/s]

Epoch: 82, Loss: 0.03806107


Training:   0%|          | 0/1845 [00:00<?, ?batch/s]

Epoch: 83, Loss: 0.03783092


Training:   0%|          | 0/1845 [00:00<?, ?batch/s]

Epoch: 84, Loss: 0.03839565


Training:   0%|          | 0/1845 [00:00<?, ?batch/s]

Epoch: 85, Loss: 0.03841005


Training:   0%|          | 0/1845 [00:00<?, ?batch/s]

Epoch: 86, Loss: 0.03769400


Training:   0%|          | 0/1845 [00:00<?, ?batch/s]

Epoch: 87, Loss: 0.03808921


Training:   0%|          | 0/1845 [00:00<?, ?batch/s]

Epoch: 88, Loss: 0.03803125


Training:   0%|          | 0/1845 [00:00<?, ?batch/s]

Epoch: 89, Loss: 0.03805397


Training:   0%|          | 0/1845 [00:00<?, ?batch/s]

Epoch: 90, Loss: 0.03758237


Training:   0%|          | 0/1845 [00:00<?, ?batch/s]

Epoch: 91, Loss: 0.03752689


Training:   0%|          | 0/1845 [00:00<?, ?batch/s]

Epoch: 92, Loss: 0.03784851


Training:   0%|          | 0/1845 [00:00<?, ?batch/s]

Epoch: 93, Loss: 0.03771672


Training:   0%|          | 0/1845 [00:00<?, ?batch/s]

Epoch: 94, Loss: 0.03771747


Training:   0%|          | 0/1845 [00:00<?, ?batch/s]

Epoch: 95, Loss: 0.03789381


Training:   0%|          | 0/1845 [00:00<?, ?batch/s]

Epoch: 96, Loss: 0.03744129


Training:   0%|          | 0/1845 [00:00<?, ?batch/s]

Epoch: 97, Loss: 0.03723574


Training:   0%|          | 0/1845 [00:00<?, ?batch/s]

Epoch: 98, Loss: 0.03763801


Training:   0%|          | 0/1845 [00:00<?, ?batch/s]

Epoch: 99, Loss: 0.03728051


Training:   0%|          | 0/1845 [00:00<?, ?batch/s]

Epoch: 100, Loss: 0.03720540


Training:   0%|          | 0/1845 [00:00<?, ?batch/s]

Epoch: 101, Loss: 0.03726788


Training:   0%|          | 0/1845 [00:00<?, ?batch/s]

Epoch: 102, Loss: 0.03745510


Training:   0%|          | 0/1845 [00:00<?, ?batch/s]