In [1]:
import torch
import random

import pandas as pd

from torch.utils.data import Dataset
import torch.nn.functional as F
import torch.nn as nn

from torch.optim import AdamW

import torch_geometric.transforms as T

from torch_geometric.data import Batch

from torch_geometric.nn import GCNConv, global_mean_pool
from torch_geometric.nn import global_add_pool
from torch_geometric.nn import GraphConv
from torch.utils.data import DataLoader

from pathlib import Path

from tqdm import tqdm


In [2]:
import sys
import os

# Put the parent of the current working directory on the import path so the
# project packages imported below (DataPipeline, Model) can be resolved.
cwd = os.getcwd()
parent_dir = os.path.dirname(cwd)
# Guard the append: re-running this cell must not pile up duplicate entries
# in sys.path.
if parent_dir not in sys.path:
    sys.path.append(parent_dir)

from DataPipeline.dataset import ZincSubgraphDataset, custom_collate
from Model.GNN1 import ModelWithEdgeFeatures
from Model.metrics import pseudo_accuracy_metric, pseudo_recall_for_each_class, pseudo_precision_for_each_class

In [3]:
# Path of the preprocessed graph file, expressed relative to the working directory.
datapath = Path('..').joinpath('DataPipeline/data/preprocessed_graph_no_I_Br_P.pt')
# Build the project dataset from that file.
dataset = ZincSubgraphDataset(data_path=datapath)

Dataset encoded with size 7


In [4]:
# Batches of 128 dataset items, assembled by the project's custom_collate and
# loaded in the main process (num_workers=0).
# NOTE(review): shuffle=True is kept as-is; the loader is only used to compute
# aggregate metrics below, so shuffle=False would presumably be equivalent and
# easier to reproduce -- confirm before changing.
dataloader = DataLoader(
    dataset,
    batch_size=128,
    shuffle=True,
    num_workers=0,
    collate_fn=custom_collate,
)

In [5]:
# Load checkpoint: rebuild the network and its optimizer with the architecture
# settings below, then restore both from the saved training state.
encoding_size = 7

model = ModelWithEdgeFeatures(
    num_classes=encoding_size,
    in_channels=encoding_size,
    hidden_channels_list=[64, 128, 256, 512, 512],
    mlp_hidden_channels=512,
    edge_channels=4,
    use_dropout=False,
)
optimizer = AdamW(model.parameters(), lr=0.0001)

checkpoint_path = Path('..') / 'Train' / 'checkpoint_epoch_96_balanced_10-4.pt'
# map_location='cpu': nothing in this notebook moves the model or the batches
# to another device, so the checkpoint tensors are remapped to the CPU. Without
# this, a checkpoint saved from a GPU cannot be loaded on a CPU-only machine.
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
checkpoint = torch.load(checkpoint_path, map_location='cpu')

model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

# Switch to evaluation mode. Kept as the last expression so the cell also
# displays the module structure.
model.eval()

ModelWithEdgeFeatures(
  (message_passing_layers): ModuleList(
    (0): CustomMessagePassingLayer()
    (1): CustomMessagePassingLayer()
    (2): CustomMessagePassingLayer()
    (3): CustomMessagePassingLayer()
    (4): CustomMessagePassingLayer()
  )
  (batch_norm_layers): ModuleList(
    (0): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (fc1): Linear(in_features=512, out_features=512, bias=True)
  (fc2): Linear(in_features=512, out_features=7, bias=True)
)

In [8]:
# Baseline evaluation: run the restored model over every batch and accumulate
#   * graph-weighted sums of the model outputs and of the labels,
#   * the pseudo-accuracy count,
#   * per-class numerators/denominators for pseudo recall and precision.
# The sums are turned into averages after the loop.
avg_output_vector = torch.zeros(encoding_size)
avg_label_vector = torch.zeros(encoding_size)
num_correct = 0
num_correct_recall = torch.zeros(encoding_size)
num_correct_precision = torch.zeros(encoding_size)
count_per_class_recall = torch.zeros(encoding_size)
count_per_class_precision = torch.zeros(encoding_size)
total_graphs_processed = 0

with torch.no_grad():
    for batch in tqdm(dataloader):
        # batch[0] is the batched graph object, batch[1] the matching labels.
        terminal_node_infos = batch[1]
        num_graphs = batch[0].num_graphs

        # Detach and move to the CPU once per batch instead of once per use.
        output = model(batch[0]).detach().cpu()
        labels = terminal_node_infos.detach().cpu()

        # Batch mean scaled by the batch's graph count gives this batch's
        # contribution to the running sums. Everything stays in torch so the
        # tensor accumulators are never mixed with NumPy arrays; .float()
        # keeps mean() valid should the labels be stored as integers.
        avg_output_vector += output.float().mean(dim=0) * num_graphs
        avg_label_vector += labels.float().mean(dim=0) * num_graphs

        num_correct += pseudo_accuracy_metric(output, labels, random=True)

        # Both helpers return a pair: index 0 is added to the "correct"
        # numerator, index 1 to the per-class count used as denominator.
        recall_output = pseudo_recall_for_each_class(output, labels, random=True)
        precision_output = pseudo_precision_for_each_class(output, labels, random=True)
        num_correct_recall += recall_output[0]
        num_correct_precision += precision_output[0]
        count_per_class_recall += recall_output[1]
        count_per_class_precision += precision_output[1]
        total_graphs_processed += num_graphs

avg_correct = num_correct / total_graphs_processed
avg_correct_recall = num_correct_recall / count_per_class_recall
avg_correct_precision = num_correct_precision / count_per_class_precision
avg_f1 = 2 * (avg_correct_recall * avg_correct_precision) / (avg_correct_recall + avg_correct_precision)
# Normalise by the weight that was actually accumulated (the number of graphs
# seen), the same denominator used for avg_correct, rather than len(dataset).
avg_output_vector /= total_graphs_processed
avg_label_vector /= total_graphs_processed

100%|██████████| 1845/1845 [08:07<00:00,  3.78it/s]


In [9]:
# Collect the summary lines first, then emit them in a single print call.
report_lines = [
    f'Accuracy: {avg_correct}',
    f'F1: {avg_f1}',
    f'Output vector: {avg_output_vector}',
    f'Label vector: {avg_label_vector}',
    f'Recall: {avg_correct_recall}',
    f'Precision: {avg_correct_precision}',
]
print('\n'.join(report_lines))

Accuracy: 0.6439731150234841
F1: tensor([0.5336, 0.4781, 0.4869, 0.3198, 0.1136, 0.0380, 0.7923])
Output vector: tensor([0.1577, 0.0630, 0.0918, 0.0257, 0.0530, 0.1071, 0.5017])
Label vector: tensor([0.1741, 0.0370, 0.0438, 0.0064, 0.0047, 0.0023, 0.7317])
Recall: tensor([0.5089, 0.6525, 0.7483, 0.7870, 0.6922, 0.9255, 0.6678])
Precision: tensor([0.5610, 0.3773, 0.3608, 0.2007, 0.0618, 0.0194, 0.9738])


In [10]:
# Per-class ratio of the average model output to the average label
# (element-wise true division of the two averaged vectors).
mask = torch.div(avg_output_vector, avg_label_vector)

In [18]:
# Display the per-class ratio of average output to average label.
mask

tensor([ 0.9056,  1.7045,  2.0966,  4.0114, 11.2091, 47.1996,  0.6856])

In [19]:
# Preview the mask raised element-wise to the 10th power.
mask.pow(10)

tensor([3.7096e-01, 2.0698e+02, 1.6414e+03, 1.0788e+06, 3.1311e+10, 5.4876e+16,
        2.2953e-02])

In [26]:
# Masked evaluation: same accumulation as the baseline run, but every model
# output is first divided by mask**MASK_EXPONENT and re-normalised so each row
# sums to 1 again.
# NOTE: this cell overwrites avg_output_vector / avg_label_vector, the tensors
# `mask` was derived from. Re-running the cell that computes `mask` after this
# one would therefore produce a different mask.
MASK_EXPONENT = 8  # strength of the correction; change here to try other powers

# Loop-invariant, so compute the divisor once instead of once per batch.
mask_divisor = mask ** MASK_EXPONENT

avg_output_vector = torch.zeros(encoding_size)
avg_label_vector = torch.zeros(encoding_size)
num_correct = 0
num_correct_recall = torch.zeros(encoding_size)
num_correct_precision = torch.zeros(encoding_size)
count_per_class_recall = torch.zeros(encoding_size)
count_per_class_precision = torch.zeros(encoding_size)
total_graphs_processed = 0

with torch.no_grad():
    for batch in tqdm(dataloader):
        # batch[0] is the batched graph object, batch[1] the matching labels.
        terminal_node_infos = batch[1]
        num_graphs = batch[0].num_graphs

        # Detach and move to the CPU once per batch instead of once per use.
        output = model(batch[0]).detach().cpu()
        labels = terminal_node_infos.detach().cpu()

        # Apply the per-class correction, then normalise each row to sum to 1.
        output = output / mask_divisor
        output = output / output.sum(dim=1, keepdim=True)

        # Batch mean scaled by the batch's graph count gives this batch's
        # contribution to the running sums. Everything stays in torch so the
        # tensor accumulators are never mixed with NumPy arrays; .float()
        # keeps mean() valid should the labels be stored as integers.
        avg_output_vector += output.float().mean(dim=0) * num_graphs
        avg_label_vector += labels.float().mean(dim=0) * num_graphs

        num_correct += pseudo_accuracy_metric(output, labels, random=True)

        # Both helpers return a pair: index 0 is added to the "correct"
        # numerator, index 1 to the per-class count used as denominator.
        recall_output = pseudo_recall_for_each_class(output, labels, random=True)
        precision_output = pseudo_precision_for_each_class(output, labels, random=True)
        num_correct_recall += recall_output[0]
        num_correct_precision += precision_output[0]
        count_per_class_recall += recall_output[1]
        count_per_class_precision += precision_output[1]
        total_graphs_processed += num_graphs

avg_correct = num_correct / total_graphs_processed
avg_correct_recall = num_correct_recall / count_per_class_recall
avg_correct_precision = num_correct_precision / count_per_class_precision
avg_f1 = 2 * (avg_correct_recall * avg_correct_precision) / (avg_correct_recall + avg_correct_precision)
# Normalise by the weight that was actually accumulated (the number of graphs
# seen), the same denominator used for avg_correct, rather than len(dataset).
avg_output_vector /= total_graphs_processed
avg_label_vector /= total_graphs_processed

100%|██████████| 1845/1845 [08:27<00:00,  3.63it/s]


In [27]:
# Collect the summary lines first, then emit them in a single print call.
report_lines = [
    f'Accuracy: {avg_correct}',
    f'F1: {avg_f1}',
    f'Output vector: {avg_output_vector}',
    f'Label vector: {avg_label_vector}',
    f'Recall: {avg_correct_recall}',
    f'Precision: {avg_correct_precision}',
]
print('\n'.join(report_lines))

Accuracy: 0.7583189689857147
F1: tensor([0.6022, 0.4964, 0.5111, 0.4067, 0.2126, 0.0704, 0.8599])
Output vector: tensor([0.2364, 0.0612, 0.0732, 0.0168, 0.0041, 0.0054, 0.6030])
Label vector: tensor([0.1745, 0.0371, 0.0437, 0.0063, 0.0048, 0.0024, 0.7313])
Recall: tensor([0.7058, 0.6595, 0.6959, 0.7429, 0.1959, 0.1168, 0.7842])
Precision: tensor([0.5251, 0.3979, 0.4039, 0.2800, 0.2324, 0.0504, 0.9517])


In [None]:
Accuracy: 0.6439731150234841
F1: tensor([0.5336, 0.4781, 0.4869, 0.3198, 0.1136, 0.0380, 0.7923])
Output vector: tensor([0.1577, 0.0630, 0.0918, 0.0257, 0.0530, 0.1071, 0.5017])
Label vector: tensor([0.1741, 0.0370, 0.0438, 0.0064, 0.0047, 0.0023, 0.7317])
Recall: tensor([0.5089, 0.6525, 0.7483, 0.7870, 0.6922, 0.9255, 0.6678])
Precision: tensor([0.5610, 0.3773, 0.3608, 0.2007, 0.0618, 0.0194, 0.9738])


Accuracy: 0.6439731150234841
F1: tensor([0.5336, 0.4781, 0.4869, 0.3198, 0.1136, 0.0380, 0.7923])
Output vector: tensor([0.1577, 0.0630, 0.0918, 0.0257, 0.0530, 0.1071, 0.5017])
Label vector: tensor([0.1741, 0.0370, 0.0438, 0.0064, 0.0047, 0.0023, 0.7317])
Recall: tensor([0.5089, 0.6525, 0.7483, 0.7870, 0.6922, 0.9255, 0.6678])
Precision: tensor([0.5610, 0.3773, 0.3608, 0.2007, 0.0618, 0.0194, 0.9738])


Mask**1 (mask applied without an exponent)
Accuracy: 0.6587285117124138
F1: tensor([0.5589, 0.4849, 0.4899, 0.3087, 0.1364, 0.0416, 0.7986])
Output vector: tensor([0.1706, 0.0635, 0.0879, 0.0284, 0.0419, 0.0975, 0.5103])
Label vector: tensor([0.1740, 0.0376, 0.0436, 0.0066, 0.0050, 0.0024, 0.7309])
Recall: tensor([0.5529, 0.6564, 0.7357, 0.8058, 0.6369, 0.8689, 0.6781])
Precision: tensor([0.5650, 0.3845, 0.3672, 0.1909, 0.0764, 0.0213, 0.9711])

Mask**2
Accuracy: 0.6732679137885031
F1: tensor([0.5766, 0.4773, 0.4954, 0.2790, 0.1419, 0.0420, 0.8070])
Output vector: tensor([0.1796, 0.0639, 0.0854, 0.0297, 0.0330, 0.0866, 0.5217])
Label vector: tensor([0.1744, 0.0371, 0.0430, 0.0061, 0.0048, 0.0023, 0.7323])
Recall: tensor([0.5839, 0.6558, 0.7405, 0.8126, 0.5696, 0.8214, 0.6909])
Precision: tensor([0.5694, 0.3752, 0.3722, 0.1684, 0.0811, 0.0216, 0.9699])

Mask**5
Accuracy: 0.7199989835547631
F1: tensor([0.5928, 0.4800, 0.4933, 0.3029, 0.1850, 0.0450, 0.8344])
Output vector: tensor([0.2226, 0.0667, 0.0787, 0.0286, 0.0120, 0.0295, 0.5618])
Label vector: tensor([0.1749, 0.0374, 0.0425, 0.0066, 0.0049, 0.0021, 0.7315])
Recall: tensor([0.6713, 0.6734, 0.7118, 0.8031, 0.3182, 0.3393, 0.7375])
Precision: tensor([0.5307, 0.3729, 0.3774, 0.1866, 0.1304, 0.0241, 0.9606])

Mask**8
Accuracy: 0.7583189689857147
F1: tensor([0.6022, 0.4964, 0.5111, 0.4067, 0.2126, 0.0704, 0.8599])
Output vector: tensor([0.2364, 0.0612, 0.0732, 0.0168, 0.0041, 0.0054, 0.6030])
Label vector: tensor([0.1745, 0.0371, 0.0437, 0.0063, 0.0048, 0.0024, 0.7313])
Recall: tensor([0.7058, 0.6595, 0.6959, 0.7429, 0.1959, 0.1168, 0.7842])
Precision: tensor([0.5251, 0.3979, 0.4039, 0.2800, 0.2324, 0.0504, 0.9517])

Mask**10
Accuracy: 0.7701139689221869
F1: tensor([0.6046, 0.5083, 0.5108, 0.5206, 0.1421, 0.0286, 0.8693])
Output vector: tensor([0.2395, 0.0562, 0.0678, 0.0111, 0.0023, 0.0014, 0.6217])
Label vector: tensor([0.1754, 0.0374, 0.0424, 0.0063, 0.0048, 0.0022, 0.7315])
Recall: tensor([0.7114, 0.6383, 0.6786, 0.7198, 0.1029, 0.0208, 0.8040])
Precision: tensor([0.5257, 0.4223, 0.4096, 0.4078, 0.2295, 0.0457, 0.9461])