In [19]:
import torch
import random

import pandas as pd

from torch.utils.data import Dataset
import torch.nn.functional as F
import torch.nn as nn

from torch.optim import AdamW

import torch_geometric.transforms as T

from torch_geometric.data import Batch

from torch_geometric.nn import GCNConv, global_mean_pool
from torch_geometric.nn import global_add_pool
from torch_geometric.nn import GraphConv
from torch.utils.data import DataLoader

from pathlib import Path

from tqdm import tqdm


In [20]:
import sys
import os
# Make the repository root (the parent of the kernel's working directory)
# importable so the project packages imported below resolve.
# NOTE(review): assumes the kernel's cwd is the directory containing this
# notebook — TODO confirm if the notebook is launched from elsewhere.
cwd = os.getcwd()
parent_dir = os.path.dirname(cwd)
# Guard so re-running this cell does not append the same entry again.
if parent_dir not in sys.path:
    sys.path.append(parent_dir)
from DataPipeline.dataset import ZincSubgraphDatasetStep, custom_collate_passive_add_feature
from Model.GNN1 import ModelWithEdgeFeatures
from Model.metrics import pseudo_accuracy_metric, pseudo_recall_for_each_class, pseudo_precision_for_each_class

In [21]:
# Preprocessed ZINC graph file, resolved relative to the notebook directory.
DATA_ROOT = Path('..')
datapath = DATA_ROOT / 'DataPipeline/data/preprocessed_graph_no_I_Br_P.pt'
dataset = ZincSubgraphDatasetStep(
    data_path=datapath,
    GNN_type=1,
)

Dataset encoded with size 7


In [22]:
# Evaluation-only loader. The order does not affect the aggregated per-graph
# metrics, so shuffling only adds run-to-run nondeterminism. A fixed order
# keeps batches reproducible when comparing masked and unmasked evaluations.
# NOTE(review): the dataset itself may still sample randomly per item — TODO confirm.
BATCH_SIZE = 128
dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=0,
    collate_fn=custom_collate_passive_add_feature,
)

In [23]:
# Load the trained checkpoint for evaluation.
# encoding_size must match the dataset encoding size printed when the dataset
# was built ("Dataset encoded with size 7").
encoding_size = 7

model = ModelWithEdgeFeatures(
    num_classes=encoding_size,
    in_channels=encoding_size + 1,
    hidden_channels_list=[64, 128, 256, 512, 512],
    mlp_hidden_channels=512,
    edge_channels=4,
    use_dropout=False,
    size_info=True,
)
# The evaluation cells below never use the optimizer; its state is restored
# only to mirror the full checkpoint contents.
optimizer = AdamW(model.parameters(), lr=0.0001)

checkpoint_path = Path('..') / 'Train' / 'checkpoint_epoch_41_new_balanced_10-4.pt'
# The model is never moved off the CPU in this notebook, so map all stored
# tensors to the CPU. Otherwise a checkpoint saved from CUDA fails to load
# on a CPU-only machine.
checkpoint = torch.load(checkpoint_path, map_location='cpu')

model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

# Inference mode: BatchNorm uses running statistics (last expression shows the model repr).
model.eval()

ModelWithEdgeFeatures(
  (message_passing_layers): ModuleList(
    (0): CustomMessagePassingLayer()
    (1): CustomMessagePassingLayer()
    (2): CustomMessagePassingLayer()
    (3): CustomMessagePassingLayer()
    (4): CustomMessagePassingLayer()
  )
  (batch_norm_layers): ModuleList(
    (0): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (fc1): Linear(in_features=513, out_features=512, bias=True)
  (fc2): Linear(in_features=512, out_features=7, bias=True)
)

In [24]:
# Evaluate the raw (unmasked) model over the whole dataset and accumulate:
#   * the mean predicted class distribution and the mean label distribution,
#   * pseudo-accuracy, plus per-class pseudo-recall / pseudo-precision counts.
# The pseudo_* metrics are called with random=True, so results depend on RNG
# state. Seed here so the run is reproducible and comparable to any other
# evaluation that uses the same seed.
# NOTE(review): assumes the dataset/metrics draw from Python `random` and/or
# torch's RNG — seed numpy as well if either of them uses it.
EVAL_SEED = 0
random.seed(EVAL_SEED)
torch.manual_seed(EVAL_SEED)

with torch.no_grad():
    avg_output_vector = torch.zeros(encoding_size)
    avg_label_vector = torch.zeros(encoding_size)
    num_correct = 0
    num_correct_recall = torch.zeros(encoding_size)
    num_correct_precision = torch.zeros(encoding_size)
    count_per_class_recall = torch.zeros(encoding_size)
    count_per_class_precision = torch.zeros(encoding_size)
    total_graphs_processed = 0
    for batch in tqdm(dataloader):
        graph_batch = batch[0]
        terminal_node_infos = batch[1]
        num_graphs = graph_batch.num_graphs

        output = model(graph_batch)
        # Detach/move once and reuse for every metric below.
        output_cpu = output.detach().cpu()
        labels_cpu = terminal_node_infos.detach().cpu()

        # Weight each batch mean by its graph count, so dividing by the total
        # graph count at the end gives a per-graph mean. .float() keeps
        # .mean() valid even if the labels are stored as integers.
        avg_output_vector += output_cpu.float().mean(dim=0) * num_graphs
        avg_label_vector += labels_cpu.float().mean(dim=0) * num_graphs
        num_correct += pseudo_accuracy_metric(output_cpu, labels_cpu, random=True)

        # Each per-class metric returns (correct counts, per-class totals).
        recall_output = pseudo_recall_for_each_class(output_cpu, labels_cpu, random=True)
        precision_output = pseudo_precision_for_each_class(output_cpu, labels_cpu, random=True)
        num_correct_recall += recall_output[0]
        num_correct_precision += precision_output[0]
        count_per_class_recall += recall_output[1]
        count_per_class_precision += precision_output[1]
        total_graphs_processed += num_graphs

    if total_graphs_processed == 0:
        raise ValueError('dataloader yielded no graphs; cannot compute metrics')

    avg_correct = num_correct / total_graphs_processed
    # Classes with no support score 0 instead of nan (sklearn zero_division=0 convention).
    avg_correct_recall = torch.where(
        count_per_class_recall > 0,
        num_correct_recall / count_per_class_recall,
        torch.zeros_like(num_correct_recall),
    )
    avg_correct_precision = torch.where(
        count_per_class_precision > 0,
        num_correct_precision / count_per_class_precision,
        torch.zeros_like(num_correct_precision),
    )
    recall_plus_precision = avg_correct_recall + avg_correct_precision
    avg_f1 = torch.where(
        recall_plus_precision > 0,
        2 * (avg_correct_recall * avg_correct_precision) / recall_plus_precision,
        torch.zeros_like(recall_plus_precision),
    )
    # Normalize by what was actually accumulated, not len(dataset).
    avg_output_vector /= total_graphs_processed
    avg_label_vector /= total_graphs_processed

100%|██████████| 1845/1845 [08:28<00:00,  3.63it/s]


In [25]:
# Summary of the unmasked evaluation, one metric per line.
print(
    f'Accuracy: {avg_correct}',
    f'F1: {avg_f1}',
    f'Output vector: {avg_output_vector}',
    f'Label vector: {avg_label_vector}',
    f'Recall: {avg_correct_recall}',
    f'Precision: {avg_correct_precision}',
    sep='\n',
)

Accuracy: 0.6716246606555225
F1: tensor([0.7294, 0.4679, 0.5010, 0.1938, 0.1258, 0.0671, 0.7999])
Output vector: tensor([0.2673, 0.1013, 0.0738, 0.0392, 0.0684, 0.0627, 0.3873])
Label vector: tensor([0.3504, 0.0591, 0.0489, 0.0068, 0.0086, 0.0037, 0.5225])
Recall: tensor([0.6338, 0.6411, 0.6872, 0.7365, 0.6858, 0.8075, 0.6964])
Precision: tensor([0.8590, 0.3684, 0.3942, 0.1116, 0.0693, 0.0350, 0.9396])


In [12]:
# Per-class ratio of predicted frequency to true label frequency from the
# unmasked evaluation. A value > 1 means the model predicts that class more
# often than it occurs.
# NOTE: this reads avg_output_vector / avg_label_vector, which the masked
# evaluation cell overwrites. Compute it right after the unmasked run.
# Classes that never occur in the labels get a neutral ratio of 1 instead of
# inf; an inf ratio would zero out that class once output is divided by the mask.
mask = torch.where(
    avg_label_vector > 0,
    avg_output_vector / avg_label_vector,
    torch.ones_like(avg_output_vector),
)

In [13]:
# Inspect the per-class predicted/true frequency ratio.
mask

tensor([ 0.7804,  1.7018,  1.5616,  5.4898,  6.7228, 18.5459,  0.7393])

In [11]:
# Preview: the ratio raised to the 10th power (one of the candidate reweighting exponents).
torch.pow(mask, 10)

tensor([8.3821e-02, 2.0373e+02, 8.6224e+01, 2.4864e+07, 1.8858e+08, 4.8136e+12,
        4.8788e-02])

In [16]:
# Evaluate the model with class reweighting. Each predicted distribution is
# divided by mask ** MASK_POWER (mask = per-class predicted/true frequency
# ratio from the unmasked run) and renormalized, which shifts probability
# mass away from classes with mask > 1.
# NOTE: this cell overwrites avg_output_vector / avg_label_vector (and the
# other metric globals), which the `mask` cell reads. Do not recompute the
# mask after running this cell.
MASK_POWER = 2  # the result notes at the end compare several exponents

# The pseudo_* metrics use random=True. Seed so the run is reproducible; use
# the same seed for any evaluation this one is compared against.
# NOTE(review): assumes the dataset/metrics draw from Python `random` and/or
# torch's RNG — seed numpy as well if either of them uses it.
EVAL_SEED = 0
random.seed(EVAL_SEED)
torch.manual_seed(EVAL_SEED)

with torch.no_grad():
    avg_output_vector = torch.zeros(encoding_size)
    avg_label_vector = torch.zeros(encoding_size)
    num_correct = 0
    num_correct_recall = torch.zeros(encoding_size)
    num_correct_precision = torch.zeros(encoding_size)
    count_per_class_recall = torch.zeros(encoding_size)
    count_per_class_precision = torch.zeros(encoding_size)
    total_graphs_processed = 0
    for batch in tqdm(dataloader):
        graph_batch = batch[0]
        terminal_node_infos = batch[1]
        num_graphs = graph_batch.num_graphs

        output = model(graph_batch)
        output = output / (mask ** MASK_POWER)
        # Renormalize each row so it is a distribution again.
        output = output / output.sum(dim=1, keepdim=True)

        # Detach/move once and reuse for every metric below.
        output_cpu = output.detach().cpu()
        labels_cpu = terminal_node_infos.detach().cpu()

        # Weight each batch mean by its graph count, so dividing by the total
        # graph count at the end gives a per-graph mean. .float() keeps
        # .mean() valid even if the labels are stored as integers.
        avg_output_vector += output_cpu.float().mean(dim=0) * num_graphs
        avg_label_vector += labels_cpu.float().mean(dim=0) * num_graphs
        num_correct += pseudo_accuracy_metric(output_cpu, labels_cpu, random=True)

        # Each per-class metric returns (correct counts, per-class totals).
        recall_output = pseudo_recall_for_each_class(output_cpu, labels_cpu, random=True)
        precision_output = pseudo_precision_for_each_class(output_cpu, labels_cpu, random=True)
        num_correct_recall += recall_output[0]
        num_correct_precision += precision_output[0]
        count_per_class_recall += recall_output[1]
        count_per_class_precision += precision_output[1]
        total_graphs_processed += num_graphs

    if total_graphs_processed == 0:
        raise ValueError('dataloader yielded no graphs; cannot compute metrics')

    avg_correct = num_correct / total_graphs_processed
    # Classes with no support score 0 instead of nan (sklearn zero_division=0 convention).
    avg_correct_recall = torch.where(
        count_per_class_recall > 0,
        num_correct_recall / count_per_class_recall,
        torch.zeros_like(num_correct_recall),
    )
    avg_correct_precision = torch.where(
        count_per_class_precision > 0,
        num_correct_precision / count_per_class_precision,
        torch.zeros_like(num_correct_precision),
    )
    recall_plus_precision = avg_correct_recall + avg_correct_precision
    avg_f1 = torch.where(
        recall_plus_precision > 0,
        2 * (avg_correct_recall * avg_correct_precision) / recall_plus_precision,
        torch.zeros_like(recall_plus_precision),
    )
    # Normalize by what was actually accumulated, not len(dataset).
    avg_output_vector /= total_graphs_processed
    avg_label_vector /= total_graphs_processed

100%|██████████| 1845/1845 [08:33<00:00,  3.59it/s]


In [18]:
# Summary of the masked evaluation, one metric per line.
print(
    f'Accuracy: {avg_correct}',
    f'F1: {avg_f1}',
    f'Output vector: {avg_output_vector}',
    f'Label vector: {avg_label_vector}',
    f'Recall: {avg_correct_recall}',
    f'Precision: {avg_correct_precision}',
    sep='\n',
)

Accuracy: 0.7089197304726047
F1: tensor([0.7639, 0.4746, 0.4868, 0.1844, 0.1586, 0.0774, 0.8177])
Output vector: tensor([0.3045, 0.0947, 0.0686, 0.0417, 0.0354, 0.0460, 0.4091])
Label vector: tensor([0.3484, 0.0592, 0.0487, 0.0071, 0.0087, 0.0037, 0.5241])
Recall: tensor([0.7022, 0.6293, 0.6623, 0.7178, 0.5195, 0.6830, 0.7280])
Precision: tensor([0.8376, 0.3810, 0.3848, 0.1058, 0.0936, 0.0410, 0.9326])


In [None]:
Accuracy: 0.6895988005946204
F1: tensor([0.7499, 0.4692, 0.4869, 0.1846, 0.1447, 0.0667, 0.8065])
Output vector: tensor([0.2891, 0.0968, 0.0713, 0.0400, 0.0453, 0.0603, 0.3971])
Label vector: tensor([0.3508, 0.0585, 0.0475, 0.0069, 0.0085, 0.0038, 0.5240])
Recall: tensor([0.6728, 0.6324, 0.6764, 0.7225, 0.5676, 0.7804, 0.7087])
Precision: tensor([0.8470, 0.3729, 0.3803, 0.1058, 0.0829, 0.0348, 0.9356])


Accuracy: 0.6735304954747011
F1: tensor([0.7337, 0.4651, 0.4941, 0.1986, 0.1343, 0.0642, 0.7982])
Output vector: tensor([0.2725, 0.0998, 0.0752, 0.0367, 0.0566, 0.0710, 0.3882])
Label vector: tensor([0.3492, 0.0587, 0.0482, 0.0067, 0.0084, 0.0038, 0.5250])
Recall: tensor([0.6438, 0.6368, 0.6942, 0.7361, 0.6389, 0.8444, 0.6942])
Precision: tensor([0.8528, 0.3664, 0.3835, 0.1148, 0.0751, 0.0334, 0.9390])

Mask (power 1, i.e. output / mask)
Accuracy: 0.6895988005946204
F1: tensor([0.7499, 0.4692, 0.4869, 0.1846, 0.1447, 0.0667, 0.8065])
Output vector: tensor([0.2891, 0.0968, 0.0713, 0.0400, 0.0453, 0.0603, 0.3971])
Label vector: tensor([0.3508, 0.0585, 0.0475, 0.0069, 0.0085, 0.0038, 0.5240])
Recall: tensor([0.6728, 0.6324, 0.6764, 0.7225, 0.5676, 0.7804, 0.7087])
Precision: tensor([0.8470, 0.3729, 0.3803, 0.1058, 0.0829, 0.0348, 0.9356])

Mask**2
Accuracy: 0.6732679137885031
F1: tensor([0.5766, 0.4773, 0.4954, 0.2790, 0.1419, 0.0420, 0.8070])
Output vector: tensor([0.1796, 0.0639, 0.0854, 0.0297, 0.0330, 0.0866, 0.5217])
Label vector: tensor([0.1744, 0.0371, 0.0430, 0.0061, 0.0048, 0.0023, 0.7323])
Recall: tensor([0.5839, 0.6558, 0.7405, 0.8126, 0.5696, 0.8214, 0.6909])
Precision: tensor([0.5694, 0.3752, 0.3722, 0.1684, 0.0811, 0.0216, 0.9699])

Mask**5
Accuracy: 0.7199989835547631
F1: tensor([0.5928, 0.4800, 0.4933, 0.3029, 0.1850, 0.0450, 0.8344])
Output vector: tensor([0.2226, 0.0667, 0.0787, 0.0286, 0.0120, 0.0295, 0.5618])
Label vector: tensor([0.1749, 0.0374, 0.0425, 0.0066, 0.0049, 0.0021, 0.7315])
Recall: tensor([0.6713, 0.6734, 0.7118, 0.8031, 0.3182, 0.3393, 0.7375])
Precision: tensor([0.5307, 0.3729, 0.3774, 0.1866, 0.1304, 0.0241, 0.9606])

Mask**8
Accuracy: 0.7583189689857147
F1: tensor([0.6022, 0.4964, 0.5111, 0.4067, 0.2126, 0.0704, 0.8599])
Output vector: tensor([0.2364, 0.0612, 0.0732, 0.0168, 0.0041, 0.0054, 0.6030])
Label vector: tensor([0.1745, 0.0371, 0.0437, 0.0063, 0.0048, 0.0024, 0.7313])
Recall: tensor([0.7058, 0.6595, 0.6959, 0.7429, 0.1959, 0.1168, 0.7842])
Precision: tensor([0.5251, 0.3979, 0.4039, 0.2800, 0.2324, 0.0504, 0.9517])

Mask**10
Accuracy: 0.7701139689221869
F1: tensor([0.6046, 0.5083, 0.5108, 0.5206, 0.1421, 0.0286, 0.8693])
Output vector: tensor([0.2395, 0.0562, 0.0678, 0.0111, 0.0023, 0.0014, 0.6217])
Label vector: tensor([0.1754, 0.0374, 0.0424, 0.0063, 0.0048, 0.0022, 0.7315])
Recall: tensor([0.7114, 0.6383, 0.6786, 0.7198, 0.1029, 0.0208, 0.8040])
Precision: tensor([0.5257, 0.4223, 0.4096, 0.4078, 0.2295, 0.0457, 0.9461])