In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import random

import pandas as pd

from torch.utils.data import Dataset
import torch.nn.functional as F
import torch.nn as nn

from torch.optim import AdamW

import torch_geometric.transforms as T

from torch_geometric.data import Batch

from torch_geometric.nn import GCNConv, global_mean_pool
from torch_geometric.nn import global_add_pool
from torch_geometric.nn import GraphConv
from torch.utils.data import DataLoader

from pathlib import Path

from tqdm import tqdm

In [3]:
import sys
import os
cwd = os.getcwd()
parent_dir = os.path.dirname(cwd)
sys.path.append(parent_dir)
from DataPipeline.dataset import ZincSubgraphDatasetStep, custom_collate_GNN1
from Model.GNN1 import ModelWithEdgeFeatures
from Model.metrics import pseudo_accuracy_metric, pseudo_recall_for_each_class, pseudo_precision_for_each_class


In [4]:
datapath = Path('..') / 'DataPipeline/data/preprocessed_graph.pt'
dataset = ZincSubgraphDatasetStep(data_path = datapath)

In [5]:
dataloader = DataLoader(dataset, batch_size=128, shuffle=True, num_workers=0, collate_fn=custom_collate_GNN1)

In [6]:
avg_label_vector=[1.74596621e-01, 3.70247139e-02, 4.33084123e-02, 6.24962418e-03, 8.01747810e-06, 
                  4.66416789e-03, 2.08654868e-03, 6.59437574e-04, 4.00873905e-05, 7.31362370e-01]
class_weights_tensor = torch.FloatTensor(avg_label_vector)

# create weight tensor from frenquency  of each class
class_weights_tensor = torch.FloatTensor(avg_label_vector)
class_weights_tensor = 1/class_weights_tensor
class_weights_tensor = class_weights_tensor/class_weights_tensor.sum()


In [13]:
custom_weights = torch.Tensor([2.2688e-03, 1.0699e-02, 9.1468e-03, 6.3385e-02, 1.9231e-02, 8.4932e-02,
        1.8986e-01, 6.0071e-01, 1.9231e-02, 5.4163e-04]).to('cuda')

In [23]:
import torch
import torch.nn.functional as F

class FocalLoss(torch.nn.Module):
    def __init__(self, num_classes, alpha=None, gamma=2.0, reduction='mean'):
        super(FocalLoss, self).__init__()
        self.num_classes = num_classes
        self.gamma = gamma
        self.reduction = reduction
        
        if alpha is None:
            self.alpha = torch.ones(self.num_classes, 1) / num_classes
        else:
            self.alpha = torch.tensor(alpha).view(self.num_classes, 1)

    def forward(self, inputs, targets):
        log_softmax = F.log_softmax(inputs, dim=1)
        targets = targets.to(torch.float32)  # Convert targets to float32
        logpt = (log_softmax * targets).sum(dim=1, keepdim=True)
        pt = torch.exp(logpt)
        
        alpha_t = self.alpha.to(inputs.device).view(1, -1)
        alpha_t = (alpha_t * targets).sum(dim=1, keepdim=True)

        loss = -alpha_t * (1 - pt) ** self.gamma * logpt

        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        else:
            return loss

In [25]:
import numpy as np 
from sklearn.metrics import mean_squared_error
from tqdm.notebook import tqdm as tqdm_notebook


model = ModelWithEdgeFeatures(in_channels=10, hidden_channels_list=[32, 128, 128], mlp_hidden_channels=512, edge_channels=4, use_dropout=False)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)



from sklearn.utils import class_weight
import numpy as np


# Set up the optimizer and loss function
optimizer = AdamW(model.parameters(), lr=0.001)
#crossentropy
criterion = Cr

name = 'focal_non_weighted'

# Training function

from tqdm.notebook import tqdm as tqdm_notebook

def train(loader, epoch):
    model.train()
    total_loss = 0
    mse_sum = 0
    num_correct = 0
    num_correct_recall = torch.zeros(10)
    num_correct_precision = torch.zeros(10)
    count_per_class_recall = torch.zeros(10)
    count_per_class_precision = torch.zeros(10)
    progress_bar = tqdm_notebook(loader, desc="Training", unit="batch")

    avg_output_vector = np.zeros(10)  # Initialize the average output vector
    avg_label_vector = np.zeros(10)  # Initialize the average label vector
    total_graphs_processed = 0

    

    for batch_idx, batch in enumerate(progress_bar):
        data = batch[0]
        terminal_node_infos = batch[1]
        data = data.to(device)
        optimizer.zero_grad()
        out = model(data)
        terminal_node_infos = terminal_node_infos.to(device)

        loss = criterion(out, terminal_node_infos)
        num_correct += pseudo_accuracy_metric(out.detach().cpu(), terminal_node_infos.detach().cpu(), random=True)

        recall_output = pseudo_recall_for_each_class(out.detach().cpu(), terminal_node_infos.detach().cpu(), random=True)
        precision_output = pseudo_precision_for_each_class(out.detach().cpu(), terminal_node_infos.detach().cpu(), random=True)
        num_correct_recall += recall_output[0]
        num_correct_precision += precision_output[0]
        count_per_class_recall += recall_output[1]
        count_per_class_precision += precision_output[1]
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * data.num_graphs
        loss_value = total_loss / (data.num_graphs * (progress_bar.last_print_n + 1))

        # Compute MSE
        mse = mean_squared_error(terminal_node_infos.detach().cpu(), out.detach().cpu())
        mse_sum += mse * data.num_graphs
        mse_value = mse_sum / (data.num_graphs * (progress_bar.last_print_n + 1))

        # Update the average output vector
        avg_output_vector += out.detach().cpu().numpy().mean(axis=0) * data.num_graphs
        avg_label_vector += terminal_node_infos.detach().cpu().numpy().mean(axis=0) * data.num_graphs
        total_graphs_processed += data.num_graphs
        current_avg_output_vector = avg_output_vector / total_graphs_processed
        current_avg_label_vector = avg_label_vector / total_graphs_processed
        avg_correct = num_correct / total_graphs_processed
        avg_correct_recall = num_correct_recall / count_per_class_recall
        avg_correct_precision = num_correct_precision / count_per_class_precision
        progress_bar.set_postfix(loss=loss_value, mse=mse_value, avg_output_vector=current_avg_output_vector, 
                                 avg_label_vector=current_avg_label_vector, 
                                 avg_correct=avg_correct, num_correct=num_correct, 
                                 total_graphs_processed=total_graphs_processed, 
                                 avg_correct_precision=avg_correct_precision, 
                                 avg_correct_recall=avg_correct_recall, 
                                 count_per_class_precision=count_per_class_precision,
                                 count_per_class_recall=count_per_class_recall)


    return total_loss / len(loader.dataset), current_avg_label_vector, current_avg_output_vector, avg_correct

# Train the model

# Create a dataframe to save the training history
training_history = pd.DataFrame(columns=['epoch', 'loss', 'mse', 'avg_output_vector', 'avg_label_vector'])


n_epochs = 100
for epoch in range(1, n_epochs+1):
    loss, avg_label_vector, avg_output_vector, avg_correct = train(dataloader, epoch)
    training_history = training_history.append({'epoch': epoch, 'loss': loss, 'mse': mean_squared_error(avg_label_vector, avg_output_vector), 'avg_output_vector': avg_output_vector, 'avg_label_vector': avg_label_vector, 'avg_correct': avg_correct}, ignore_index=True)
    #save the model(all with optimizer step, the loss ) every 5 epochs

    save_every_n_epochs = 5
    if (epoch) % save_every_n_epochs == 0:
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            # Add any other relevant information you want to save here
        }
        torch.save(checkpoint, f'checkpoint_epoch_{epoch+1}_{name}.pt')
        
    #save the training history every 10 epochs
    if epoch % 1 == 0:
        training_history.to_csv(f"training_history_{name}.csv", index=False)
    print(f'Epoch: {epoch}, Loss: {loss:.8f}')

Training:   0%|          | 0/1949 [00:00<?, ?batch/s]

Epoch: 1, Loss: 0.10927753


  training_history = training_history.append({'epoch': epoch, 'loss': loss, 'mse': mean_squared_error(avg_label_vector, avg_output_vector), 'avg_output_vector': avg_output_vector, 'avg_label_vector': avg_label_vector, 'avg_correct': avg_correct}, ignore_index=True)


Training:   0%|          | 0/1949 [00:00<?, ?batch/s]

Epoch: 2, Loss: 0.10863678


  training_history = training_history.append({'epoch': epoch, 'loss': loss, 'mse': mean_squared_error(avg_label_vector, avg_output_vector), 'avg_output_vector': avg_output_vector, 'avg_label_vector': avg_label_vector, 'avg_correct': avg_correct}, ignore_index=True)


Training:   0%|          | 0/1949 [00:00<?, ?batch/s]

Epoch: 3, Loss: 0.10833022


  training_history = training_history.append({'epoch': epoch, 'loss': loss, 'mse': mean_squared_error(avg_label_vector, avg_output_vector), 'avg_output_vector': avg_output_vector, 'avg_label_vector': avg_label_vector, 'avg_correct': avg_correct}, ignore_index=True)


Training:   0%|          | 0/1949 [00:00<?, ?batch/s]

Epoch: 4, Loss: 0.10723583


  training_history = training_history.append({'epoch': epoch, 'loss': loss, 'mse': mean_squared_error(avg_label_vector, avg_output_vector), 'avg_output_vector': avg_output_vector, 'avg_label_vector': avg_label_vector, 'avg_correct': avg_correct}, ignore_index=True)


Training:   0%|          | 0/1949 [00:00<?, ?batch/s]

Epoch: 5, Loss: 0.10686403


  training_history = training_history.append({'epoch': epoch, 'loss': loss, 'mse': mean_squared_error(avg_label_vector, avg_output_vector), 'avg_output_vector': avg_output_vector, 'avg_label_vector': avg_label_vector, 'avg_correct': avg_correct}, ignore_index=True)


Training:   0%|          | 0/1949 [00:00<?, ?batch/s]

Epoch: 6, Loss: 0.10702270


  training_history = training_history.append({'epoch': epoch, 'loss': loss, 'mse': mean_squared_error(avg_label_vector, avg_output_vector), 'avg_output_vector': avg_output_vector, 'avg_label_vector': avg_label_vector, 'avg_correct': avg_correct}, ignore_index=True)


Training:   0%|          | 0/1949 [00:00<?, ?batch/s]

Epoch: 7, Loss: 0.10694779


  training_history = training_history.append({'epoch': epoch, 'loss': loss, 'mse': mean_squared_error(avg_label_vector, avg_output_vector), 'avg_output_vector': avg_output_vector, 'avg_label_vector': avg_label_vector, 'avg_correct': avg_correct}, ignore_index=True)


Training:   0%|          | 0/1949 [00:00<?, ?batch/s]

Epoch: 8, Loss: 0.10707904


  training_history = training_history.append({'epoch': epoch, 'loss': loss, 'mse': mean_squared_error(avg_label_vector, avg_output_vector), 'avg_output_vector': avg_output_vector, 'avg_label_vector': avg_label_vector, 'avg_correct': avg_correct}, ignore_index=True)


Training:   0%|          | 0/1949 [00:00<?, ?batch/s]

Epoch: 9, Loss: 0.10688002


  training_history = training_history.append({'epoch': epoch, 'loss': loss, 'mse': mean_squared_error(avg_label_vector, avg_output_vector), 'avg_output_vector': avg_output_vector, 'avg_label_vector': avg_label_vector, 'avg_correct': avg_correct}, ignore_index=True)


Training:   0%|          | 0/1949 [00:00<?, ?batch/s]

Epoch: 10, Loss: 0.10668473


  training_history = training_history.append({'epoch': epoch, 'loss': loss, 'mse': mean_squared_error(avg_label_vector, avg_output_vector), 'avg_output_vector': avg_output_vector, 'avg_label_vector': avg_label_vector, 'avg_correct': avg_correct}, ignore_index=True)


Training:   0%|          | 0/1949 [00:00<?, ?batch/s]

Epoch: 11, Loss: 0.10689466


  training_history = training_history.append({'epoch': epoch, 'loss': loss, 'mse': mean_squared_error(avg_label_vector, avg_output_vector), 'avg_output_vector': avg_output_vector, 'avg_label_vector': avg_label_vector, 'avg_correct': avg_correct}, ignore_index=True)


Training:   0%|          | 0/1949 [00:00<?, ?batch/s]

Epoch: 12, Loss: 0.10706576


  training_history = training_history.append({'epoch': epoch, 'loss': loss, 'mse': mean_squared_error(avg_label_vector, avg_output_vector), 'avg_output_vector': avg_output_vector, 'avg_label_vector': avg_label_vector, 'avg_correct': avg_correct}, ignore_index=True)


Training:   0%|          | 0/1949 [00:00<?, ?batch/s]

Epoch: 13, Loss: 0.10662932


  training_history = training_history.append({'epoch': epoch, 'loss': loss, 'mse': mean_squared_error(avg_label_vector, avg_output_vector), 'avg_output_vector': avg_output_vector, 'avg_label_vector': avg_label_vector, 'avg_correct': avg_correct}, ignore_index=True)


Training:   0%|          | 0/1949 [00:00<?, ?batch/s]

Epoch: 14, Loss: 0.10683687


  training_history = training_history.append({'epoch': epoch, 'loss': loss, 'mse': mean_squared_error(avg_label_vector, avg_output_vector), 'avg_output_vector': avg_output_vector, 'avg_label_vector': avg_label_vector, 'avg_correct': avg_correct}, ignore_index=True)


Training:   0%|          | 0/1949 [00:00<?, ?batch/s]

Epoch: 15, Loss: 0.10682390


  training_history = training_history.append({'epoch': epoch, 'loss': loss, 'mse': mean_squared_error(avg_label_vector, avg_output_vector), 'avg_output_vector': avg_output_vector, 'avg_label_vector': avg_label_vector, 'avg_correct': avg_correct}, ignore_index=True)


Training:   0%|          | 0/1949 [00:00<?, ?batch/s]

Epoch: 16, Loss: 0.10669981


  training_history = training_history.append({'epoch': epoch, 'loss': loss, 'mse': mean_squared_error(avg_label_vector, avg_output_vector), 'avg_output_vector': avg_output_vector, 'avg_label_vector': avg_label_vector, 'avg_correct': avg_correct}, ignore_index=True)


Training:   0%|          | 0/1949 [00:00<?, ?batch/s]

Epoch: 17, Loss: 0.10681152


  training_history = training_history.append({'epoch': epoch, 'loss': loss, 'mse': mean_squared_error(avg_label_vector, avg_output_vector), 'avg_output_vector': avg_output_vector, 'avg_label_vector': avg_label_vector, 'avg_correct': avg_correct}, ignore_index=True)


Training:   0%|          | 0/1949 [00:00<?, ?batch/s]

Epoch: 18, Loss: 0.10663911


  training_history = training_history.append({'epoch': epoch, 'loss': loss, 'mse': mean_squared_error(avg_label_vector, avg_output_vector), 'avg_output_vector': avg_output_vector, 'avg_label_vector': avg_label_vector, 'avg_correct': avg_correct}, ignore_index=True)


Training:   0%|          | 0/1949 [00:00<?, ?batch/s]

Epoch: 19, Loss: 0.10685808


  training_history = training_history.append({'epoch': epoch, 'loss': loss, 'mse': mean_squared_error(avg_label_vector, avg_output_vector), 'avg_output_vector': avg_output_vector, 'avg_label_vector': avg_label_vector, 'avg_correct': avg_correct}, ignore_index=True)


Training:   0%|          | 0/1949 [00:00<?, ?batch/s]

Epoch: 20, Loss: 0.10705748


  training_history = training_history.append({'epoch': epoch, 'loss': loss, 'mse': mean_squared_error(avg_label_vector, avg_output_vector), 'avg_output_vector': avg_output_vector, 'avg_label_vector': avg_label_vector, 'avg_correct': avg_correct}, ignore_index=True)


Training:   0%|          | 0/1949 [00:00<?, ?batch/s]

Epoch: 21, Loss: 0.10821025


  training_history = training_history.append({'epoch': epoch, 'loss': loss, 'mse': mean_squared_error(avg_label_vector, avg_output_vector), 'avg_output_vector': avg_output_vector, 'avg_label_vector': avg_label_vector, 'avg_correct': avg_correct}, ignore_index=True)


Training:   0%|          | 0/1949 [00:00<?, ?batch/s]

KeyboardInterrupt: 