In [1]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch_geometric.data import Data
import sys
sys.path.extend([ '../', '../data'])
from data import Cora
import dataloader as dl
from model import Sage_En, Sage_Classifier, EdgePredictor
from smote import smote
from train import train_graph, test_graph, train_smote, test_smote, train_smote2, test_smote2

# Set device to GPU if available, else use CPU
device = torch.device("cuda")
print(f"Current device: {torch.cuda.get_device_name(torch.cuda.current_device())}" if torch.cuda.is_available() else "Current device: CPU")
torch.cuda.empty_cache()

Current device: NVIDIA RTX A6000


In [2]:
data_dir = '../data/cora'
data_obj = Cora(data_dir).load_data()
print(data_obj.validate(raise_on_error=True))
print(data_obj['x'], " Size = ", data_obj['x'].size())
print(data_obj.edge_index, " Size = ", data_obj.edge_index.size())
print(data_obj.y, " Size = ", data_obj.y.size())

True
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])  Size =  torch.Size([2708, 1433])
tensor([[   0,    0,    0,  ..., 1874, 1876, 1897],
        [  21,  905,  906,  ..., 2586, 1874, 2707]])  Size =  torch.Size([2, 5429])
tensor([6, 6, 1,  ..., 5, 5, 5])  Size =  torch.Size([2708])


In [3]:
lr = 0.001
num_epochs = 1000
weight_decay = 5e-4
hdim = 64
dropout = 0.1
im_class_num = 3
im_ratio = [0.5, 0.5, 0.5]
class_sample_num = 20
nclass = 7

In [4]:
c_train_num = dl.train_num(data_obj.y, im_class_num, class_sample_num, im_ratio)
print(c_train_num, sum(c_train_num))
train_idx, val_idx, test_idx, c_num_mat = dl.segregate(data_obj.y, c_train_num)
print("train_idx: ", train_idx, len(train_idx))
print("val_idx: ", val_idx, len(val_idx))
print("test_idx: ", test_idx, len(test_idx))
# print(c_num_mat)

[20, 20, 20, 20, 10, 10, 10] 110
0 818
1 217
2 426
3 298
4 351
5 180
6 418
train_idx:  [2179, 2442, 2097, 1133, 1684, 254, 2537, 1901, 2418, 1451, 1382, 1359, 435, 2320, 2099, 248, 2597, 1791, 2529, 2128, 1560, 2, 89, 1681, 1630, 295, 2076, 1689, 2625, 1145, 207, 1192, 1427, 2206, 408, 1128, 2494, 939, 2601, 175, 69, 557, 119, 2098, 516, 597, 251, 368, 1515, 1431, 816, 1682, 964, 278, 1932, 2520, 2511, 1735, 1201, 2203, 1497, 1096, 462, 140, 1312, 1719, 1548, 1749, 2299, 2202, 2432, 37, 1170, 1042, 1646, 1399, 345, 575, 136, 2335, 2016, 357, 1361, 347, 1147, 2349, 263, 710, 1487, 643, 2208, 1984, 330, 1729, 2575, 2568, 947, 471, 821, 1411, 1799, 1609, 1039, 1525, 7, 429, 478, 1094, 517, 2082] 110
val_idx:  [2491, 813, 887, 1352, 2170, 162, 1174, 2683, 1755, 1758, 741, 1840, 404, 786, 1125, 1463, 2078, 1202, 1227, 2042, 2286, 448, 1337, 2140, 1853, 655, 2656, 2317, 378, 2092, 1289, 297, 301, 2244, 846, 1947, 2626, 167, 434, 865, 1033, 1933, 324, 160, 1371, 2121, 4, 172, 428, 250, 366, 1

In [5]:
train_data = dl.dataloader(data_obj, train_idx)
val_data = dl.dataloader(data_obj, val_idx)
test_data = dl.dataloader(data_obj, test_idx)
print(train_data)
print(val_data)
print(test_data)

Data(x=[110, 1433], edge_index=[2, 8], y=[110])
Data(x=[175, 1433], edge_index=[2, 31], y=[175])
Data(x=[385, 1433], edge_index=[2, 172], y=[385])


In [6]:
encoder = Sage_En(train_data.x.shape[-1], hdim, dropout)
decoder = EdgePredictor(hdim)
#features = encoder(train_data)
classifier = Sage_Classifier(hdim, hdim, nclass, dropout)
#print(features.shape)

In [7]:
train_smote2(data_obj, encoder, classifier, decoder, num_epochs, lr, weight_decay, train_idx, val_idx, portion = 0, im_class_num = im_class_num)

torch.Size([2723, 64])
Epoch [1/1000], Loss: 1.9440, Accuracy: 0.1600, Edge Accuracy: 0.3962
Class 0:AUC-ROC- 0.4967, F1 Score- 0.0000; Class 1:AUC-ROC- 0.5976, F1 Score- 0.0000; Class 2:AUC-ROC- 0.4176, F1 Score- 0.0000; Class 3:AUC-ROC- 0.4233, F1 Score- 0.0000; Class 4:AUC-ROC- 0.4830, F1 Score- 0.0000; Class 5:AUC-ROC- 0.3267, F1 Score- 0.0000; Class 6:AUC-ROC- 0.4545, F1 Score- 0.0000; Macro-Average AUC-ROC: 0.4571,Macro-Average F1 Score: 0.0000
torch.Size([2708, 64])
Validation Loss: 1.9503, Validation Accuracy: 0.1486, Validation Edge Accuracy: 0.3983
Class 0:AUC-ROC- 0.5583, F1 Score- 0.0000; Class 1:AUC-ROC- 0.5009, F1 Score- 0.0000; Class 2:AUC-ROC- 0.5291, F1 Score- 0.0000; Class 3:AUC-ROC- 0.5584, F1 Score- 0.0000; Class 4:AUC-ROC- 0.5925, F1 Score- 0.0000; Class 5:AUC-ROC- 0.5427, F1 Score- 0.0000; Class 6:AUC-ROC- 0.6442, F1 Score- 0.0000; Macro-Average AUC-ROC: 0.5609,Macro-Average F1 Score: 0.0000
torch.Size([2723, 64])
Epoch [2/1000], Loss: 1.9326, Accuracy: 0.2240, Ed

In [8]:
test_smote2(data_obj, encoder, classifier, decoder, test_idx)

torch.Size([2708, 64])
Test Loss: 1.0895, Test Accuracy: 0.6545, Test Edge Accuracy: 0.4978
Class 0:AUC-ROC- 0.8861, F1 Score- 0.6320; Class 1:AUC-ROC- 0.9392, F1 Score- 0.6350; Class 2:AUC-ROC- 0.9120, F1 Score- 0.5565; Class 3:AUC-ROC- 0.9286, F1 Score- 0.6702; Class 4:AUC-ROC- 0.8446, F1 Score- 0.3214; Class 5:AUC-ROC- 0.9191, F1 Score- 0.5921; Class 6:AUC-ROC- 0.9662, F1 Score- 0.7978; Macro-Average AUC-ROC: 0.9137,Macro-Average F1 Score: 0.6007
