In [None]:
!pip install torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric -f https://data.pyg.org/whl/torch-1.12.0+cu116.html

from bisect import bisect_left

import torch
import torch.nn.functional as F
from torch import nn
import tqdm
from sklearn.linear_model import SGDClassifier
from sklearn.metrics import f1_score
from sklearn.multioutput import MultiOutputClassifier
from torch.utils.data import Dataset
from torch.utils.data import DataLoader as DL

from torch_geometric.data import Batch
from torch_geometric.loader import DataLoader, LinkNeighborLoader
from torch_geometric.nn import GraphSAGE
from torch_geometric.datasets import TUDataset
import torch_geometric.transforms as T

# NCI1 graphs, node features row-normalised by T.NormalizeFeatures on access.
dataset = TUDataset(root='/tmp/NCI1', name='NCI1', transform=T.NormalizeFeatures())

# Fixed seed so the shuffle, and therefore the split below, is reproducible.
torch.manual_seed(12315)
dataset = dataset.shuffle()
dataset_length = len(dataset)

# Three disjoint parts in shuffled order: 40% DA_train, 30% D_aux, 30% DA_test.
DA_train = dataset[:int(0.4 * dataset_length)]
D_aux = dataset[int(0.4 * dataset_length):int(0.7 * dataset_length)]
DA_test = dataset[int(0.7 * dataset_length):]

# Number of quantile buckets used to discretise node / edge counts.
buck_num = 8

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in links: https://data.pyg.org/whl/torch-1.12.0+cu116.html
Collecting torch-scatter
  Downloading https://data.pyg.org/whl/torch-1.12.0%2Bcu116/torch_scatter-2.0.9-cp37-cp37m-linux_x86_64.whl (8.0 MB)
[K     |████████████████████████████████| 8.0 MB 73.2 MB/s 
[?25hCollecting torch-sparse
  Downloading https://data.pyg.org/whl/torch-1.12.0%2Bcu116/torch_sparse-0.6.15-cp37-cp37m-linux_x86_64.whl (3.5 MB)
[K     |████████████████████████████████| 3.5 MB 338 kB/s 
[?25hCollecting torch-cluster
  Downloading https://data.pyg.org/whl/torch-1.12.0%2Bcu116/torch_cluster-1.6.0-cp37-cp37m-linux_x86_64.whl (2.4 MB)
[K     |████████████████████████████████| 2.4 MB 302 kB/s 
[?25hCollecting torch-spline-conv
  Downloading https://data.pyg.org/whl/torch-1.12.0%2Bcu116/torch_spline_conv-1.2.1-cp37-cp37m-linux_x86_64.whl (706 kB)
[K     |████████████████████████████████| 706 kB 334 kB/s 


Downloading https://www.chrsmrrs.com/graphkerneldatasets/NCI1.zip
Extracting /tmp/NCI1/NCI1/NCI1.zip
Processing...
Done!


# Load the model

In [None]:
model_path = "NCI_model.pt"

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Load straight onto the selected device: without map_location, a checkpoint saved
# from a GPU cannot be loaded on a CPU-only runtime.
# NOTE: torch.load unpickles the stored object -- only load checkpoints you trust.
model_sage = torch.load(model_path, map_location=device)
model_sage = model_sage.to(device)

# Get graph embedding for D_aux

# Get labels for D_aux

# Label 1: the number of nodes of each graph

# Label 2: the number of edges of each graph

In [None]:
# Graph embedding for each D_aux graph = SUM-pool over the GNN's per-node outputs
# (sum, not mean), plus the node/edge counts that are used as labels later.
D_aux_graph_embedding = []
D_aux_num_of_nodes_raw_data = []
D_aux_num_of_edges_raw_data = []

# Inference only: eval() puts the loaded model in inference mode (matters for layers
# such as dropout / batch-norm, if it has any), and no_grad() stops every stored
# embedding from keeping the GNN's autograd graph alive.
model_sage.eval()
with torch.no_grad():
    for data in D_aux:
        data = data.to(device)
        test_output = model_sage(data.x, data.edge_index)
        D_aux_graph_embedding.append(test_output.sum(dim=0).cpu())
        D_aux_num_of_nodes_raw_data.append(data.num_nodes)
        D_aux_num_of_edges_raw_data.append(data.num_edges)

# Release the last graph's device tensors; both names only exist if D_aux was non-empty.
if len(D_aux) > 0:
    del test_output
    del data

# Split D_aux 70/30 into train and test parts (order is already random from the shuffle above)

In [None]:
print(len(D_aux_num_of_nodes_raw_data))


def _split_70_30(values):
    """Return (first 70 %, remaining 30 %) of `values`, keeping the original order."""
    cut = int(0.7 * len(values))
    return values[:cut], values[cut:]


D_aux_num_of_nodes_raw_data_train, D_aux_num_of_nodes_raw_data_test = _split_70_30(D_aux_num_of_nodes_raw_data)
D_aux_num_of_edges_raw_data_train, D_aux_num_of_edges_raw_data_test = _split_70_30(D_aux_num_of_edges_raw_data)
D_aux_graph_embedding_train, D_aux_graph_embedding_test = _split_70_30(D_aux_graph_embedding)

print(len(D_aux_num_of_nodes_raw_data_train))
print(len(D_aux_num_of_nodes_raw_data_test))

1233
863
370


# Quantile bucketing: map each count to one of `buck_num` buckets that hold roughly equal numbers of graphs


In [None]:
def split_buck(data, buck):
    """Assign every value in `data` to one of at most `buck` quantile buckets.

    The values are sorted and every ``round(len(data) / buck)``-th sorted value
    becomes an upper bucket edge. At most ``buck - 1`` such edges are kept, and the
    last bucket is closed with ``max(data) + 1`` so every value is covered. A value
    goes to the first bucket whose edge is >= the value, so tied values share a
    bucket (a bucket between two equal edges stays empty).

    Args:
        data: non-iterator sequence of comparable numbers (e.g. node or edge counts).
        buck: requested number of buckets; must be >= 1.

    Returns:
        List of bucket indices in ``range(buck)``, aligned with ``data``.

    Raises:
        ValueError: if ``buck < 1``.
    """
    if buck < 1:
        raise ValueError(f"buck must be >= 1, got {buck}")
    if len(data) == 0:
        return []

    n_max = max(data)
    sorted_values = sorted(data)

    # Target number of elements per bucket; at least 1 so tiny inputs still split.
    per_buck = max(1, round(len(data) / buck))

    # Every per_buck-th sorted value is an upper edge. Keeping at most buck - 1 of them
    # guarantees labels stay in range(buck) (rounding can otherwise yield extra edges);
    # the final edge n_max + 1 covers the maximum.
    split_pt = sorted_values[per_buck - 1::per_buck][:buck - 1]
    split_pt.append(n_max + 1)

    # bisect_left returns the first edge index with edge >= value.
    return [bisect_left(split_pt, value) for value in data]

In [None]:
def bucket_train_test(train_values, test_values, buck):
    """Bucket train and test counts using bucket boundaries fitted on the train split only.

    `split_buck` derives its quantile edges from whatever list it is given, so calling
    it separately on each split makes bucket k cover different count ranges in train
    and test. Here the train split defines the buckets, and every test value is mapped
    onto those same buckets.

    Args:
        train_values: counts of the training split (defines the buckets).
        test_values: counts of the held-out split.
        buck: requested number of buckets, passed to `split_buck`.

    Returns:
        (train_labels, test_labels): lists of bucket indices aligned with the inputs.

    Raises:
        ValueError: if `train_values` is empty but `test_values` is not.
    """
    train_labels = split_buck(train_values, buck)
    if not test_values:
        return train_labels, []
    if not train_labels:
        raise ValueError("cannot bucket test values without training values")

    # Upper edge of every non-empty train bucket = the largest train value it holds.
    # Buckets are contiguous ascending ranges, so these edges ascend with the label.
    upper_edge = {}
    for value, label in zip(train_values, train_labels):
        upper_edge[label] = max(value, upper_edge.get(label, value))
    labels = sorted(upper_edge)
    edges = [upper_edge[label] for label in labels]

    # First bucket whose edge is >= value; values above the train maximum go to the top bucket.
    last = len(labels) - 1
    test_labels = [labels[min(bisect_left(edges, value), last)] for value in test_values]
    return train_labels, test_labels


D_aux_nodes_train, D_aux_nodes_test = bucket_train_test(
    D_aux_num_of_nodes_raw_data_train, D_aux_num_of_nodes_raw_data_test, buck_num
)
D_aux_edges_train, D_aux_edges_test = bucket_train_test(
    D_aux_num_of_edges_raw_data_train, D_aux_num_of_edges_raw_data_test, buck_num
)

In [None]:
class TrainSet(Dataset):
    """In-memory dataset of (embedding, node-count bucket, edge-count bucket) triples.

    Args:
        X: list of equally-shaped embedding tensors; stacked into one tensor.
        num_nodes: per-sample node-count bucket labels.
        num_edges: per-sample edge-count bucket labels.

    Labels are stored as float32 tensors; consumers cast them to int64 class indices
    where a loss needs them.

    Raises:
        ValueError: if the three inputs do not have the same length.
    """

    def __init__(self, X, num_nodes, num_edges):
        if not (len(X) == len(num_nodes) == len(num_edges)):
            raise ValueError(
                f"length mismatch: {len(X)} embeddings, {len(num_nodes)} node labels, "
                f"{len(num_edges)} edge labels"
            )
        # detach(): the embeddings may still carry autograd history from the model that
        # produced them; detaching here means every sample/batch is a plain tensor and
        # a backward pass through the classifier never reaches that upstream graph.
        self.X = torch.stack(X).detach()
        self.num_nodes = torch.tensor(num_nodes, dtype=torch.float)
        self.num_edges = torch.tensor(num_edges, dtype=torch.float)

    def __getitem__(self, index):
        return self.X[index], self.num_nodes[index], self.num_edges[index]

    def __len__(self):
        return len(self.num_nodes)

In [None]:
# Training samples: embedding + node-bucket label + edge-bucket label, reshuffled every epoch.
mydataset = TrainSet(
    D_aux_graph_embedding_train,
    D_aux_nodes_train,
    D_aux_edges_train,
)
train_loader = DL(mydataset, shuffle=True, batch_size=10)

# Define the multi-task learning model

# The tasks share one feature network; each task has its own classification head on top of it

In [None]:
class Network(nn.Module):
    """Multi-task classifier over graph embeddings.

    A shared feature MLP feeds two independent heads that classify the node-count
    bucket and the edge-count bucket of a graph.

    Args:
        buck_num: number of buckets (output classes) per head.
        num_features: size of the input embedding (default 192); must match the width
            of the embeddings fed to the model.
        hidden: width of the two hidden layers of the shared feature net (default 256).
    """

    def __init__(self, buck_num, num_features=192, hidden=256):
        super().__init__()
        self.num_features = num_features
        self.num_buck = buck_num

        # Shared trunk: maps the embedding back to `num_features` dims for both heads.
        self.featureNet = nn.Sequential(
            nn.Linear(self.num_features, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, self.num_features),
            nn.ReLU(),
        )

        self.nodeNet = self._make_head()
        self.edgeNet = self._make_head()

    def _make_head(self):
        """Build one task-specific head producing raw logits over `num_buck` classes."""
        return nn.Sequential(
            nn.Linear(self.num_features, self.num_features),
            nn.ReLU(),
            nn.Linear(self.num_features, self.num_buck),
        )

    def forward(self, x):
        """Return ``(node_logits, edge_logits)`` for a batch of embeddings ``x``."""
        shared = self.featureNet(x)
        return self.nodeNet(shared), self.edgeNet(shared)

In [None]:
# Cross-entropy for both bucket-classification heads; Adam over all classifier weights.
criterion = nn.CrossEntropyLoss()
model = Network(buck_num).to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)

In [None]:
def train():
    """Run one optimisation epoch of `model` over `train_loader`.

    The loss per batch is the sum of the node-bucket and edge-bucket cross-entropies.

    Returns:
        Mean loss per training sample. Each batch's mean loss is weighted by its size,
        so the value is directly comparable with the per-sample loss returned by test().
        Returns 0.0 if the loader is empty.
    """
    model.train()
    total_loss = 0.0
    num_samples = 0

    for inputs, num_nodes, num_edges in tqdm.tqdm(train_loader):
        # detach() guards against inputs that still carry upstream autograd history.
        # Labels are stored as floats; CrossEntropyLoss needs int64 class indices.
        inputs = inputs.detach().to(device)
        node_labels = num_nodes.detach().long().to(device)
        edge_labels = num_edges.detach().long().to(device)

        optimizer.zero_grad()
        pred_nodes, pred_edges = model(inputs)
        loss = criterion(pred_nodes, node_labels) + criterion(pred_edges, edge_labels)
        loss.backward()
        optimizer.step()

        # criterion averages over the batch, so re-weight by batch size before summing.
        batch_size = inputs.size(0)
        total_loss += loss.item() * batch_size
        num_samples += batch_size

    return total_loss / max(num_samples, 1)

In [None]:
# Held-out samples, evaluated one at a time in a fixed order.
mytestdata = TrainSet(
    D_aux_graph_embedding_test,
    D_aux_nodes_test,
    D_aux_edges_test,
)
test_loader = DL(mytestdata, shuffle=False, batch_size=1)

In [None]:
@torch.no_grad()
def test():
    """Evaluate `model` on `test_loader` without tracking gradients.

    Returns:
        (mean loss per sample, node-bucket accuracy, edge-bucket accuracy).
        The loss is the sum of the node and edge cross-entropies. All three values
        are 0.0 if the loader is empty.
    """
    model.eval()
    total_loss = 0.0
    correct_nodes = 0
    correct_edges = 0
    num_samples = 0

    for inputs, num_nodes, num_edges in tqdm.tqdm(test_loader):
        # Labels are stored as floats; compare and compute the loss on int64 class indices.
        node_labels = num_nodes.long().to(device)
        edge_labels = num_edges.long().to(device)

        pred_nodes, pred_edges = model(inputs.to(device))

        # Element-wise comparison so the accuracy is correct for any batch size.
        correct_nodes += (pred_nodes.argmax(dim=1) == node_labels).sum().item()
        correct_edges += (pred_edges.argmax(dim=1) == edge_labels).sum().item()

        loss = criterion(pred_nodes, node_labels) + criterion(pred_edges, edge_labels)
        # criterion averages over the batch, so re-weight by batch size before summing.
        batch_size = inputs.size(0)
        total_loss += loss.item() * batch_size
        num_samples += batch_size

    if num_samples == 0:
        return 0.0, 0.0, 0.0
    return total_loss / num_samples, correct_nodes / num_samples, correct_edges / num_samples

In [None]:
# 60 epochs: one training pass, then a full evaluation, logged once per epoch.
for epoch in range(60):
    # Tuple elements evaluate left to right, so train() still runs before test().
    train_loss, (test_loss, test_nodes_accu, test_edges_accu) = train(), test()
    print(f'Epoch: {epoch:4d}, Train loss: {train_loss:.4f}, Test loss: {test_loss:.4f}, Test node accu: {test_nodes_accu:.4f}, Test edge accu: {test_edges_accu:.4f}')

100%|██████████| 87/87 [00:00<00:00, 243.46it/s]
100%|██████████| 370/370 [00:00<00:00, 1149.89it/s]


Epoch:    0, Train loss: 0.4093, Test loss: 4.0236, Test node accu: 0.1432, Test edge accu: 0.1486


100%|██████████| 87/87 [00:00<00:00, 271.45it/s]
100%|██████████| 370/370 [00:00<00:00, 1152.06it/s]


Epoch:    1, Train loss: 0.3936, Test loss: 3.9772, Test node accu: 0.1730, Test edge accu: 0.1838


100%|██████████| 87/87 [00:00<00:00, 272.59it/s]
100%|██████████| 370/370 [00:00<00:00, 1192.11it/s]


Epoch:    2, Train loss: 0.3847, Test loss: 3.9640, Test node accu: 0.1838, Test edge accu: 0.2135


100%|██████████| 87/87 [00:00<00:00, 267.21it/s]
100%|██████████| 370/370 [00:00<00:00, 1202.36it/s]


Epoch:    3, Train loss: 0.3748, Test loss: 3.9312, Test node accu: 0.1946, Test edge accu: 0.2297


100%|██████████| 87/87 [00:00<00:00, 283.69it/s]
100%|██████████| 370/370 [00:00<00:00, 1233.78it/s]


Epoch:    4, Train loss: 0.3687, Test loss: 3.9375, Test node accu: 0.2162, Test edge accu: 0.2432


100%|██████████| 87/87 [00:00<00:00, 281.15it/s]
100%|██████████| 370/370 [00:00<00:00, 1227.89it/s]


Epoch:    5, Train loss: 0.3591, Test loss: 3.8794, Test node accu: 0.2162, Test edge accu: 0.2270


100%|██████████| 87/87 [00:00<00:00, 290.58it/s]
100%|██████████| 370/370 [00:00<00:00, 1297.72it/s]


Epoch:    6, Train loss: 0.3526, Test loss: 3.8384, Test node accu: 0.2270, Test edge accu: 0.2405


100%|██████████| 87/87 [00:00<00:00, 296.02it/s]
100%|██████████| 370/370 [00:00<00:00, 1305.18it/s]


Epoch:    7, Train loss: 0.3437, Test loss: 3.7968, Test node accu: 0.2432, Test edge accu: 0.2757


100%|██████████| 87/87 [00:00<00:00, 295.62it/s]
100%|██████████| 370/370 [00:00<00:00, 1291.59it/s]


Epoch:    8, Train loss: 0.3383, Test loss: 3.7745, Test node accu: 0.2486, Test edge accu: 0.2730


100%|██████████| 87/87 [00:00<00:00, 289.26it/s]
100%|██████████| 370/370 [00:00<00:00, 1248.86it/s]


Epoch:    9, Train loss: 0.3281, Test loss: 3.8040, Test node accu: 0.2811, Test edge accu: 0.3000


100%|██████████| 87/87 [00:00<00:00, 184.08it/s]
100%|██████████| 370/370 [00:00<00:00, 1280.14it/s]


Epoch:   10, Train loss: 0.3195, Test loss: 3.7587, Test node accu: 0.2757, Test edge accu: 0.2946


100%|██████████| 87/87 [00:00<00:00, 178.19it/s]
100%|██████████| 370/370 [00:00<00:00, 1282.26it/s]


Epoch:   11, Train loss: 0.3115, Test loss: 3.7066, Test node accu: 0.2514, Test edge accu: 0.2946


100%|██████████| 87/87 [00:00<00:00, 253.63it/s]
100%|██████████| 370/370 [00:00<00:00, 919.53it/s]


Epoch:   12, Train loss: 0.3023, Test loss: 3.6186, Test node accu: 0.2865, Test edge accu: 0.2946


100%|██████████| 87/87 [00:00<00:00, 291.87it/s]
100%|██████████| 370/370 [00:00<00:00, 1284.29it/s]


Epoch:   13, Train loss: 0.2974, Test loss: 3.5917, Test node accu: 0.3162, Test edge accu: 0.3405


100%|██████████| 87/87 [00:00<00:00, 270.60it/s]
100%|██████████| 370/370 [00:00<00:00, 1262.66it/s]


Epoch:   14, Train loss: 0.2869, Test loss: 3.6227, Test node accu: 0.3054, Test edge accu: 0.2865


100%|██████████| 87/87 [00:00<00:00, 290.85it/s]
100%|██████████| 370/370 [00:00<00:00, 1302.41it/s]


Epoch:   15, Train loss: 0.2806, Test loss: 3.4946, Test node accu: 0.3216, Test edge accu: 0.3676


100%|██████████| 87/87 [00:00<00:00, 287.77it/s]
100%|██████████| 370/370 [00:00<00:00, 1277.44it/s]


Epoch:   16, Train loss: 0.2689, Test loss: 3.4607, Test node accu: 0.3486, Test edge accu: 0.3514


100%|██████████| 87/87 [00:00<00:00, 298.43it/s]
100%|██████████| 370/370 [00:00<00:00, 1276.64it/s]


Epoch:   17, Train loss: 0.2605, Test loss: 3.4523, Test node accu: 0.3216, Test edge accu: 0.3486


100%|██████████| 87/87 [00:00<00:00, 285.01it/s]
100%|██████████| 370/370 [00:00<00:00, 1309.63it/s]


Epoch:   18, Train loss: 0.2537, Test loss: 3.6128, Test node accu: 0.3270, Test edge accu: 0.3297


100%|██████████| 87/87 [00:00<00:00, 275.82it/s]
100%|██████████| 370/370 [00:00<00:00, 1274.16it/s]


Epoch:   19, Train loss: 0.2486, Test loss: 3.2912, Test node accu: 0.3838, Test edge accu: 0.3622


100%|██████████| 87/87 [00:00<00:00, 307.89it/s]
100%|██████████| 370/370 [00:00<00:00, 1285.13it/s]


Epoch:   20, Train loss: 0.2346, Test loss: 3.2765, Test node accu: 0.3757, Test edge accu: 0.3541


100%|██████████| 87/87 [00:00<00:00, 286.21it/s]
100%|██████████| 370/370 [00:00<00:00, 1274.38it/s]


Epoch:   21, Train loss: 0.2305, Test loss: 3.1966, Test node accu: 0.3459, Test edge accu: 0.3595


100%|██████████| 87/87 [00:00<00:00, 288.26it/s]
100%|██████████| 370/370 [00:00<00:00, 1267.30it/s]


Epoch:   22, Train loss: 0.2230, Test loss: 3.0931, Test node accu: 0.3919, Test edge accu: 0.3784


100%|██████████| 87/87 [00:00<00:00, 290.91it/s]
100%|██████████| 370/370 [00:00<00:00, 1300.19it/s]


Epoch:   23, Train loss: 0.2098, Test loss: 3.2978, Test node accu: 0.3703, Test edge accu: 0.3541


100%|██████████| 87/87 [00:00<00:00, 282.34it/s]
100%|██████████| 370/370 [00:00<00:00, 1280.55it/s]


Epoch:   24, Train loss: 0.2044, Test loss: 3.2134, Test node accu: 0.3649, Test edge accu: 0.3541


100%|██████████| 87/87 [00:00<00:00, 293.38it/s]
100%|██████████| 370/370 [00:00<00:00, 1262.87it/s]


Epoch:   25, Train loss: 0.2011, Test loss: 3.1360, Test node accu: 0.3946, Test edge accu: 0.3973


100%|██████████| 87/87 [00:00<00:00, 268.81it/s]
100%|██████████| 370/370 [00:00<00:00, 1274.20it/s]


Epoch:   26, Train loss: 0.1957, Test loss: 3.0461, Test node accu: 0.3973, Test edge accu: 0.4162


100%|██████████| 87/87 [00:00<00:00, 296.26it/s]
100%|██████████| 370/370 [00:00<00:00, 1262.93it/s]


Epoch:   27, Train loss: 0.1854, Test loss: 2.9878, Test node accu: 0.4108, Test edge accu: 0.4135


100%|██████████| 87/87 [00:00<00:00, 292.58it/s]
100%|██████████| 370/370 [00:00<00:00, 1287.80it/s]


Epoch:   28, Train loss: 0.1805, Test loss: 3.0729, Test node accu: 0.3838, Test edge accu: 0.4189


100%|██████████| 87/87 [00:00<00:00, 284.96it/s]
100%|██████████| 370/370 [00:00<00:00, 1282.89it/s]


Epoch:   29, Train loss: 0.1848, Test loss: 2.9620, Test node accu: 0.4135, Test edge accu: 0.4514


100%|██████████| 87/87 [00:00<00:00, 273.05it/s]
100%|██████████| 370/370 [00:00<00:00, 1288.49it/s]


Epoch:   30, Train loss: 0.1690, Test loss: 2.9290, Test node accu: 0.4297, Test edge accu: 0.4595


100%|██████████| 87/87 [00:00<00:00, 275.95it/s]
100%|██████████| 370/370 [00:00<00:00, 1271.09it/s]


Epoch:   31, Train loss: 0.1653, Test loss: 2.9612, Test node accu: 0.4243, Test edge accu: 0.4243


100%|██████████| 87/87 [00:00<00:00, 289.07it/s]
100%|██████████| 370/370 [00:00<00:00, 1243.63it/s]


Epoch:   32, Train loss: 0.1627, Test loss: 3.0205, Test node accu: 0.4568, Test edge accu: 0.4568


100%|██████████| 87/87 [00:00<00:00, 294.50it/s]
100%|██████████| 370/370 [00:00<00:00, 1272.09it/s]


Epoch:   33, Train loss: 0.1567, Test loss: 3.0623, Test node accu: 0.4297, Test edge accu: 0.4351


100%|██████████| 87/87 [00:00<00:00, 284.17it/s]
100%|██████████| 370/370 [00:00<00:00, 1292.52it/s]


Epoch:   34, Train loss: 0.1513, Test loss: 2.9440, Test node accu: 0.4243, Test edge accu: 0.4432


100%|██████████| 87/87 [00:00<00:00, 280.31it/s]
100%|██████████| 370/370 [00:00<00:00, 1304.17it/s]


Epoch:   35, Train loss: 0.1474, Test loss: 3.0896, Test node accu: 0.4027, Test edge accu: 0.4027


100%|██████████| 87/87 [00:00<00:00, 274.25it/s]
100%|██████████| 370/370 [00:00<00:00, 1284.72it/s]


Epoch:   36, Train loss: 0.1450, Test loss: 3.0539, Test node accu: 0.4243, Test edge accu: 0.4243


100%|██████████| 87/87 [00:00<00:00, 292.70it/s]
100%|██████████| 370/370 [00:00<00:00, 1208.71it/s]


Epoch:   37, Train loss: 0.1379, Test loss: 3.2879, Test node accu: 0.3973, Test edge accu: 0.4243


100%|██████████| 87/87 [00:00<00:00, 285.09it/s]
100%|██████████| 370/370 [00:00<00:00, 1316.43it/s]


Epoch:   38, Train loss: 0.1434, Test loss: 3.1030, Test node accu: 0.4162, Test edge accu: 0.4541


100%|██████████| 87/87 [00:00<00:00, 290.63it/s]
100%|██████████| 370/370 [00:00<00:00, 1286.16it/s]


Epoch:   39, Train loss: 0.1421, Test loss: 3.1492, Test node accu: 0.4135, Test edge accu: 0.4514


100%|██████████| 87/87 [00:00<00:00, 290.36it/s]
100%|██████████| 370/370 [00:00<00:00, 1302.33it/s]


Epoch:   40, Train loss: 0.1323, Test loss: 3.3560, Test node accu: 0.3811, Test edge accu: 0.4189


100%|██████████| 87/87 [00:00<00:00, 274.36it/s]
100%|██████████| 370/370 [00:00<00:00, 1274.47it/s]


Epoch:   41, Train loss: 0.1331, Test loss: 3.0928, Test node accu: 0.4486, Test edge accu: 0.4405


100%|██████████| 87/87 [00:00<00:00, 294.74it/s]
100%|██████████| 370/370 [00:00<00:00, 1247.29it/s]


Epoch:   42, Train loss: 0.1227, Test loss: 3.0333, Test node accu: 0.4297, Test edge accu: 0.4622


100%|██████████| 87/87 [00:00<00:00, 293.55it/s]
100%|██████████| 370/370 [00:00<00:00, 1237.56it/s]


Epoch:   43, Train loss: 0.1242, Test loss: 3.0508, Test node accu: 0.4432, Test edge accu: 0.4730


100%|██████████| 87/87 [00:00<00:00, 285.88it/s]
100%|██████████| 370/370 [00:00<00:00, 1248.40it/s]


Epoch:   44, Train loss: 0.1178, Test loss: 3.4758, Test node accu: 0.3973, Test edge accu: 0.4135


100%|██████████| 87/87 [00:00<00:00, 281.27it/s]
100%|██████████| 370/370 [00:00<00:00, 1160.56it/s]


Epoch:   45, Train loss: 0.1185, Test loss: 3.0668, Test node accu: 0.4378, Test edge accu: 0.4486


100%|██████████| 87/87 [00:00<00:00, 292.87it/s]
100%|██████████| 370/370 [00:00<00:00, 1259.68it/s]


Epoch:   46, Train loss: 0.1127, Test loss: 3.0163, Test node accu: 0.4649, Test edge accu: 0.4811


100%|██████████| 87/87 [00:00<00:00, 282.76it/s]
100%|██████████| 370/370 [00:00<00:00, 1148.58it/s]


Epoch:   47, Train loss: 0.1093, Test loss: 3.0868, Test node accu: 0.4459, Test edge accu: 0.4730


100%|██████████| 87/87 [00:00<00:00, 297.03it/s]
100%|██████████| 370/370 [00:00<00:00, 1251.91it/s]


Epoch:   48, Train loss: 0.1145, Test loss: 3.4134, Test node accu: 0.4622, Test edge accu: 0.4541


100%|██████████| 87/87 [00:00<00:00, 289.22it/s]
100%|██████████| 370/370 [00:00<00:00, 1286.22it/s]


Epoch:   49, Train loss: 0.1201, Test loss: 3.4474, Test node accu: 0.4676, Test edge accu: 0.4541


100%|██████████| 87/87 [00:00<00:00, 296.22it/s]
100%|██████████| 370/370 [00:00<00:00, 1281.89it/s]


Epoch:   50, Train loss: 0.1082, Test loss: 3.3662, Test node accu: 0.4514, Test edge accu: 0.4649


100%|██████████| 87/87 [00:00<00:00, 293.35it/s]
100%|██████████| 370/370 [00:00<00:00, 1249.39it/s]


Epoch:   51, Train loss: 0.1042, Test loss: 3.3434, Test node accu: 0.4378, Test edge accu: 0.4378


100%|██████████| 87/87 [00:00<00:00, 289.91it/s]
100%|██████████| 370/370 [00:00<00:00, 1202.49it/s]


Epoch:   52, Train loss: 0.1104, Test loss: 3.4774, Test node accu: 0.4270, Test edge accu: 0.4405


100%|██████████| 87/87 [00:00<00:00, 298.94it/s]
100%|██████████| 370/370 [00:00<00:00, 1263.52it/s]


Epoch:   53, Train loss: 0.1002, Test loss: 3.2313, Test node accu: 0.4459, Test edge accu: 0.4649


100%|██████████| 87/87 [00:00<00:00, 272.40it/s]
100%|██████████| 370/370 [00:00<00:00, 1298.78it/s]


Epoch:   54, Train loss: 0.1152, Test loss: 3.3610, Test node accu: 0.4405, Test edge accu: 0.4351


100%|██████████| 87/87 [00:00<00:00, 303.55it/s]
100%|██████████| 370/370 [00:00<00:00, 1270.53it/s]


Epoch:   55, Train loss: 0.0947, Test loss: 3.2210, Test node accu: 0.4486, Test edge accu: 0.4297


100%|██████████| 87/87 [00:00<00:00, 303.72it/s]
100%|██████████| 370/370 [00:00<00:00, 1317.30it/s]


Epoch:   56, Train loss: 0.0974, Test loss: 3.3011, Test node accu: 0.4486, Test edge accu: 0.4595


100%|██████████| 87/87 [00:00<00:00, 301.40it/s]
100%|██████████| 370/370 [00:00<00:00, 1266.48it/s]


Epoch:   57, Train loss: 0.0908, Test loss: 3.6604, Test node accu: 0.4270, Test edge accu: 0.4622


100%|██████████| 87/87 [00:00<00:00, 304.04it/s]
100%|██████████| 370/370 [00:00<00:00, 1314.33it/s]


Epoch:   58, Train loss: 0.0954, Test loss: 4.4064, Test node accu: 0.3459, Test edge accu: 0.4000


100%|██████████| 87/87 [00:00<00:00, 296.41it/s]
100%|██████████| 370/370 [00:00<00:00, 1306.90it/s]

Epoch:   59, Train loss: 0.0993, Test loss: 3.4746, Test node accu: 0.4486, Test edge accu: 0.4649





In [None]:
# Scratch: a random 1x8 score tensor and its arg-max class index.
a = torch.rand(1, 8)
print(a)
b = a.argmax(dim=1)

tensor([[0.0983, 0.7518, 0.1657, 0.4712, 0.6338, 0.2363, 0.5027, 0.2240]])


In [None]:
c = torch.tensor([6.0])  # float32 label tensor for the comparison experiments below

In [None]:
# Display the reference label tensor.
c

tensor([6.])

In [None]:
b.float()  # arg-max index cast to float32, the dtype of the stored labels

tensor([1.])

In [None]:
# `(tensor) is True` tests object identity; a Tensor is never the `True` singleton, so
# that form is always False. Reduce the element-wise comparison to a Python bool instead.
bool((b == c).all())

False

In [None]:
torch.equal(b.float(), c)  # exact shape + value match after casting the index to float32

False