# This file is based on https://github.com/dmlc/dgl/tree/master/examples/pytorch/hgt

In [12]:
#Import the right libraries
import pandas as pd
import dgl
import numpy as np
import networkx as nx
import torch
from dgl.data.utils import save_graphs
from dgl.data.utils import load_graphs
import sklearn.metrics as sk

In [13]:
#Import the data
processed_company_data = pd.read_csv('processed_company_data.csv', sep=",")
processed_investor_data = pd.read_csv('processed_investor_data.csv', sep=",")
processed_relation_data = pd.read_csv('processed_relation_data.csv', sep=",")

# Start of the github file

In [14]:
import scipy.io
import urllib.request
import dgl
import math
import numpy as np
from hgt_model import *
import argparse

In [15]:
torch.manual_seed(0)

n_epoch = 200
n_hid = 256
n_inp = 256
clip = 1.0
max_lr = 1e-3

In [16]:
def get_n_params(model):
    pp=0
    for p in list(model.parameters()):
        nn=1
        for s in list(p.size()):
            nn = nn*s
        pp += nn
    return pp

In [17]:
def train(model, G):
    best_val_acc = torch.tensor(0)
    best_test_acc = torch.tensor(0)
    
    best_fbeta = torch.tensor(0)
    best_precision = torch.tensor(0)
    best_recall = torch.tensor(0)
    
    
    train_step = torch.tensor(0)
    for epoch in np.arange(n_epoch) + 1:
        model.train()
        logits, _ = model(G, 'investor/company')
        # The loss is computed only for labeled nodes.
#         loss = F.cross_entropy(logits[train_idx], labels[train_idx].to(device))
        loss = F.cross_entropy(logits[train_idx], labels[train_idx])
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()
        train_step += 1
        scheduler.step(train_step)
        if epoch % 5 == 0:
            model.eval()
            logits, _ = model(G, 'investor/company')
            pred   = logits.argmax(1).cpu()
            train_acc = (pred[train_idx] == labels[train_idx]).float().mean()
            val_acc   = (pred[val_idx]   == labels[val_idx]).float().mean()
            test_acc  = (pred[test_idx]  == labels[test_idx]).float().mean()
            
            fbeta = sk.fbeta_score(labels[test_idx], pred[test_idx], beta=0.2, average="macro")
            precision = sk.precision_score(labels[test_idx], pred[test_idx], average="macro")
            recall = sk.recall_score(labels[test_idx], pred[test_idx], average="macro")
            
            if best_val_acc < val_acc:
                best_val_acc = val_acc
                best_test_acc = test_acc
                
            if best_fbeta < fbeta:
                best_fbeta = fbeta
        
            if best_precision < precision:
                best_precision = precision
                
            if best_recall < recall:
                best_recall = recall
                
            print('Epoch: %d LR: %.5f Loss %.4f, Train Acc %.4f, Val Acc %.4f (Best %.4f), Test Acc %.4f (Best %.4f)' % (
                epoch,
                optimizer.param_groups[0]['lr'], 
                loss.item(),
                train_acc.item(),
                val_acc.item(),
                best_val_acc.item(),
                test_acc.item(),
                best_test_acc.item()
            ))
            print('Epoch: %d LR: %.5f Loss %.4f, fbeta %.4f (Best %.4f), precision %.4f (Best %.4f), recall %.4f (Best %.4f)' % (
                epoch,
                optimizer.param_groups[0]['lr'], 
                loss.item(),
                fbeta.item(),
                best_fbeta.item(),
                precision.item(),
                best_precision.item(),
                recall.item(),
                best_recall.item()
            ))

In [18]:
g = load_graphs("dgl_graph")
G = g[0][0]
print(G)

Graph(num_nodes={'company': 9779, 'investor/company': 7883},
      num_edges={('company', 'different_invested_by', 'investor/company'): 50087, ('investor/company', 'different_invests_in', 'company'): 50087, ('investor/company', 'same_invested_by', 'investor/company'): 5102, ('investor/company', 'same_invests_in', 'investor/company'): 5102},
      metagraph=[('company', 'investor/company', 'different_invested_by'), ('investor/company', 'company', 'different_invests_in'), ('investor/company', 'investor/company', 'same_invested_by'), ('investor/company', 'investor/company', 'same_invests_in')])


In [19]:
different_edge = ('investor/company', 'different_invests_in', 'company')
same_edge = ('investor/company', 'same_invests_in', 'investor/company')

different_labels = G[different_edge].edges()
same_labels = G[same_edge].edges()

pid = G[different_edge].edges()[0]
labels = G[different_edge].edges()[1]

print(pid)
print(labels)

tensor([   0,    0,    0,  ..., 7882, 7882, 7882])
tensor([9181, 9701, 9676,  ..., 7817, 7818, 7819])


In [20]:
# generate train/val/test split
shuffle = np.random.permutation(pid)
train_idx = torch.tensor(shuffle[0:800]).long()
val_idx = torch.tensor(shuffle[800:900]).long()
test_idx = torch.tensor(shuffle[900:]).long()

In [21]:
node_dict = {}
edge_dict = {}

for ntype in G.ntypes:
    node_dict[ntype] = len(node_dict)
for etype in G.etypes:
    edge_dict[etype] = len(edge_dict)
    G.edges[etype].data['id'] = torch.ones(G.number_of_edges(etype), dtype=torch.long) * edge_dict[etype] 

#     Random initialize input feature
for ntype in G.ntypes:
    emb = nn.Parameter(torch.Tensor(G.number_of_nodes(ntype), 256), requires_grad = False)
    nn.init.xavier_uniform_(emb)
    G.nodes[ntype].data['inp'] = emb
    

# G = G.to(device)

In [22]:
model = HGT(G,
            node_dict, edge_dict,
            n_inp=n_inp,
            n_hid=n_hid,
            n_out=labels.max().item()+1,
            n_layers=2,
            n_heads=4,
            use_norm = True)#.to(device)
optimizer = torch.optim.AdamW(model.parameters())
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, total_steps=n_epoch, max_lr = max_lr)
print('Training HGT with #param: %d' % (get_n_params(model)))
train(model, G)

Training HGT with #param: 3961687


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epoch: 5 LR: 0.00006 Loss 9.2663, Train Acc 0.0012, Val Acc 0.0000 (Best 0.0000), Test Acc 0.0002 (Best 0.0000)
Epoch: 5 LR: 0.00006 Loss 9.2663, fbeta 0.0000 (Best 0.0000), precision 0.0000 (Best 0.0000), recall 0.0005 (Best 0.0005)


  _warn_prf(average, modifier, msg_start, len(result))


Epoch: 10 LR: 0.00011 Loss 9.1559, Train Acc 0.0012, Val Acc 0.0000 (Best 0.0000), Test Acc 0.0001 (Best 0.0000)
Epoch: 10 LR: 0.00011 Loss 9.1559, fbeta 0.0000 (Best 0.0000), precision 0.0000 (Best 0.0000), recall 0.0003 (Best 0.0005)


  _warn_prf(average, modifier, msg_start, len(result))


Epoch: 15 LR: 0.00019 Loss 8.9594, Train Acc 0.0012, Val Acc 0.0000 (Best 0.0000), Test Acc 0.0001 (Best 0.0000)
Epoch: 15 LR: 0.00019 Loss 8.9594, fbeta 0.0000 (Best 0.0000), precision 0.0000 (Best 0.0000), recall 0.0003 (Best 0.0005)


  _warn_prf(average, modifier, msg_start, len(result))


Epoch: 20 LR: 0.00029 Loss 8.5726, Train Acc 0.0025, Val Acc 0.0000 (Best 0.0000), Test Acc 0.0010 (Best 0.0000)
Epoch: 20 LR: 0.00029 Loss 8.5726, fbeta 0.0000 (Best 0.0000), precision 0.0000 (Best 0.0000), recall 0.0005 (Best 0.0005)


  _warn_prf(average, modifier, msg_start, len(result))


Epoch: 25 LR: 0.00041 Loss 8.0230, Train Acc 0.0088, Val Acc 0.0100 (Best 0.0100), Test Acc 0.0057 (Best 0.0057)
Epoch: 25 LR: 0.00041 Loss 8.0230, fbeta 0.0000 (Best 0.0000), precision 0.0000 (Best 0.0000), recall 0.0005 (Best 0.0005)


  _warn_prf(average, modifier, msg_start, len(result))


Epoch: 30 LR: 0.00053 Loss 7.3633, Train Acc 0.0088, Val Acc 0.0100 (Best 0.0100), Test Acc 0.0050 (Best 0.0057)
Epoch: 30 LR: 0.00053 Loss 7.3633, fbeta 0.0000 (Best 0.0000), precision 0.0000 (Best 0.0000), recall 0.0003 (Best 0.0005)


  _warn_prf(average, modifier, msg_start, len(result))


Epoch: 35 LR: 0.00066 Loss 6.6787, Train Acc 0.0213, Val Acc 0.0300 (Best 0.0300), Test Acc 0.0145 (Best 0.0145)
Epoch: 35 LR: 0.00066 Loss 6.6787, fbeta 0.0002 (Best 0.0002), precision 0.0002 (Best 0.0002), recall 0.0009 (Best 0.0009)


  _warn_prf(average, modifier, msg_start, len(result))


Epoch: 40 LR: 0.00077 Loss 6.2567, Train Acc 0.0325, Val Acc 0.0400 (Best 0.0400), Test Acc 0.0256 (Best 0.0256)
Epoch: 40 LR: 0.00077 Loss 6.2567, fbeta 0.0001 (Best 0.0002), precision 0.0001 (Best 0.0002), recall 0.0009 (Best 0.0009)


  _warn_prf(average, modifier, msg_start, len(result))


Epoch: 45 LR: 0.00087 Loss 5.9613, Train Acc 0.0350, Val Acc 0.0500 (Best 0.0500), Test Acc 0.0285 (Best 0.0285)
Epoch: 45 LR: 0.00087 Loss 5.9613, fbeta 0.0001 (Best 0.0002), precision 0.0001 (Best 0.0002), recall 0.0011 (Best 0.0011)


  _warn_prf(average, modifier, msg_start, len(result))


Epoch: 50 LR: 0.00095 Loss 5.7155, Train Acc 0.0375, Val Acc 0.0500 (Best 0.0500), Test Acc 0.0293 (Best 0.0285)
Epoch: 50 LR: 0.00095 Loss 5.7155, fbeta 0.0004 (Best 0.0004), precision 0.0004 (Best 0.0004), recall 0.0013 (Best 0.0013)


  _warn_prf(average, modifier, msg_start, len(result))


Epoch: 55 LR: 0.00099 Loss 5.4621, Train Acc 0.0562, Val Acc 0.0500 (Best 0.0500), Test Acc 0.0400 (Best 0.0285)
Epoch: 55 LR: 0.00099 Loss 5.4621, fbeta 0.0014 (Best 0.0014), precision 0.0015 (Best 0.0015), recall 0.0024 (Best 0.0024)


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epoch: 60 LR: 0.00100 Loss 5.0523, Train Acc 0.2288, Val Acc 0.1200 (Best 0.1200), Test Acc 0.1035 (Best 0.1035)
Epoch: 60 LR: 0.00100 Loss 5.0523, fbeta 0.0177 (Best 0.0177), precision 0.0186 (Best 0.0186), recall 0.0185 (Best 0.0185)


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epoch: 65 LR: 0.00100 Loss 4.6988, Train Acc 0.2612, Val Acc 0.1200 (Best 0.1200), Test Acc 0.1132 (Best 0.1035)
Epoch: 65 LR: 0.00100 Loss 4.6988, fbeta 0.0163 (Best 0.0177), precision 0.0170 (Best 0.0186), recall 0.0182 (Best 0.0185)


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epoch: 70 LR: 0.00098 Loss 4.2536, Train Acc 0.5788, Val Acc 0.2100 (Best 0.2100), Test Acc 0.1923 (Best 0.1923)
Epoch: 70 LR: 0.00098 Loss 4.2536, fbeta 0.0645 (Best 0.0645), precision 0.0687 (Best 0.0687), recall 0.0619 (Best 0.0619)


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epoch: 75 LR: 0.00097 Loss 3.8273, Train Acc 0.7075, Val Acc 0.2200 (Best 0.2200), Test Acc 0.2155 (Best 0.2155)
Epoch: 75 LR: 0.00097 Loss 3.8273, fbeta 0.0707 (Best 0.0707), precision 0.0748 (Best 0.0748), recall 0.0756 (Best 0.0756)


  _warn_prf(average, modifier, msg_start, len(result))


Epoch: 80 LR: 0.00095 Loss 4.2980, Train Acc 0.3950, Val Acc 0.1500 (Best 0.2200), Test Acc 0.1473 (Best 0.2155)
Epoch: 80 LR: 0.00095 Loss 4.2980, fbeta 0.0324 (Best 0.0707), precision 0.0345 (Best 0.0748), recall 0.0321 (Best 0.0756)


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epoch: 85 LR: 0.00092 Loss 3.5791, Train Acc 0.8050, Val Acc 0.1800 (Best 0.2200), Test Acc 0.2383 (Best 0.2155)
Epoch: 85 LR: 0.00092 Loss 3.5791, fbeta 0.1032 (Best 0.1032), precision 0.1108 (Best 0.1108), recall 0.0890 (Best 0.0890)


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epoch: 90 LR: 0.00088 Loss 3.1699, Train Acc 0.9150, Val Acc 0.2300 (Best 0.2300), Test Acc 0.2605 (Best 0.2605)
Epoch: 90 LR: 0.00088 Loss 3.1699, fbeta 0.0876 (Best 0.1032), precision 0.0920 (Best 0.1108), recall 0.1021 (Best 0.1021)


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epoch: 95 LR: 0.00085 Loss 3.2760, Train Acc 0.9325, Val Acc 0.2300 (Best 0.2300), Test Acc 0.2641 (Best 0.2605)
Epoch: 95 LR: 0.00085 Loss 3.2760, fbeta 0.0837 (Best 0.1032), precision 0.0868 (Best 0.1108), recall 0.1046 (Best 0.1046)


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epoch: 100 LR: 0.00080 Loss 3.5633, Train Acc 0.9325, Val Acc 0.2400 (Best 0.2400), Test Acc 0.2645 (Best 0.2645)
Epoch: 100 LR: 0.00080 Loss 3.5633, fbeta 0.0677 (Best 0.1032), precision 0.0696 (Best 0.1108), recall 0.0990 (Best 0.1046)


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epoch: 105 LR: 0.00076 Loss 2.5093, Train Acc 0.9362, Val Acc 0.2400 (Best 0.2400), Test Acc 0.2650 (Best 0.2645)
Epoch: 105 LR: 0.00076 Loss 2.5093, fbeta 0.0669 (Best 0.1032), precision 0.0689 (Best 0.1108), recall 0.1021 (Best 0.1046)


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epoch: 110 LR: 0.00071 Loss 2.5027, Train Acc 0.9350, Val Acc 0.2300 (Best 0.2400), Test Acc 0.2653 (Best 0.2645)
Epoch: 110 LR: 0.00071 Loss 2.5027, fbeta 0.1033 (Best 0.1033), precision 0.1089 (Best 0.1108), recall 0.1059 (Best 0.1059)


  _warn_prf(average, modifier, msg_start, len(result))


Epoch: 115 LR: 0.00065 Loss 2.2479, Train Acc 0.9413, Val Acc 0.2300 (Best 0.2400), Test Acc 0.2673 (Best 0.2645)
Epoch: 115 LR: 0.00065 Loss 2.2479, fbeta 0.0985 (Best 0.1033), precision 0.1035 (Best 0.1108), recall 0.1063 (Best 0.1063)


  _warn_prf(average, modifier, msg_start, len(result))


Epoch: 120 LR: 0.00060 Loss 2.0784, Train Acc 0.9475, Val Acc 0.2400 (Best 0.2400), Test Acc 0.2704 (Best 0.2645)
Epoch: 120 LR: 0.00060 Loss 2.0784, fbeta 0.0930 (Best 0.1033), precision 0.0974 (Best 0.1108), recall 0.1068 (Best 0.1068)


  _warn_prf(average, modifier, msg_start, len(result))


Epoch: 125 LR: 0.00054 Loss 1.9300, Train Acc 0.9525, Val Acc 0.2400 (Best 0.2400), Test Acc 0.2729 (Best 0.2645)
Epoch: 125 LR: 0.00054 Loss 1.9300, fbeta 0.0898 (Best 0.1033), precision 0.0937 (Best 0.1108), recall 0.1074 (Best 0.1074)


  _warn_prf(average, modifier, msg_start, len(result))


Epoch: 130 LR: 0.00049 Loss 1.8491, Train Acc 0.9575, Val Acc 0.2400 (Best 0.2400), Test Acc 0.2763 (Best 0.2645)
Epoch: 130 LR: 0.00049 Loss 1.8491, fbeta 0.0852 (Best 0.1033), precision 0.0880 (Best 0.1108), recall 0.1081 (Best 0.1081)


  _warn_prf(average, modifier, msg_start, len(result))


Epoch: 135 LR: 0.00043 Loss 1.6498, Train Acc 0.9625, Val Acc 0.2400 (Best 0.2400), Test Acc 0.2779 (Best 0.2645)
Epoch: 135 LR: 0.00043 Loss 1.6498, fbeta 0.0919 (Best 0.1033), precision 0.0952 (Best 0.1108), recall 0.1086 (Best 0.1086)


  _warn_prf(average, modifier, msg_start, len(result))


Epoch: 140 LR: 0.00038 Loss 1.5395, Train Acc 0.9650, Val Acc 0.2400 (Best 0.2400), Test Acc 0.2798 (Best 0.2645)
Epoch: 140 LR: 0.00038 Loss 1.5395, fbeta 0.0876 (Best 0.1033), precision 0.0912 (Best 0.1108), recall 0.1089 (Best 0.1089)


  _warn_prf(average, modifier, msg_start, len(result))


Epoch: 145 LR: 0.00032 Loss 1.4262, Train Acc 0.9688, Val Acc 0.2400 (Best 0.2400), Test Acc 0.2793 (Best 0.2645)
Epoch: 145 LR: 0.00032 Loss 1.4262, fbeta 0.0763 (Best 0.1033), precision 0.0783 (Best 0.1108), recall 0.1090 (Best 0.1090)


  _warn_prf(average, modifier, msg_start, len(result))


Epoch: 150 LR: 0.00027 Loss 1.3207, Train Acc 0.9725, Val Acc 0.2400 (Best 0.2400), Test Acc 0.2804 (Best 0.2645)
Epoch: 150 LR: 0.00027 Loss 1.3207, fbeta 0.0806 (Best 0.1033), precision 0.0830 (Best 0.1108), recall 0.1095 (Best 0.1095)


  _warn_prf(average, modifier, msg_start, len(result))


Epoch: 155 LR: 0.00022 Loss 1.2257, Train Acc 0.9787, Val Acc 0.2500 (Best 0.2500), Test Acc 0.2803 (Best 0.2803)
Epoch: 155 LR: 0.00022 Loss 1.2257, fbeta 0.0757 (Best 0.1033), precision 0.0777 (Best 0.1108), recall 0.1097 (Best 0.1097)


  _warn_prf(average, modifier, msg_start, len(result))


Epoch: 160 LR: 0.00018 Loss 1.1345, Train Acc 0.9875, Val Acc 0.2600 (Best 0.2600), Test Acc 0.2821 (Best 0.2821)
Epoch: 160 LR: 0.00018 Loss 1.1345, fbeta 0.0776 (Best 0.1033), precision 0.0796 (Best 0.1108), recall 0.1112 (Best 0.1112)


  _warn_prf(average, modifier, msg_start, len(result))


Epoch: 165 LR: 0.00014 Loss 1.0661, Train Acc 0.9887, Val Acc 0.2700 (Best 0.2700), Test Acc 0.2831 (Best 0.2831)
Epoch: 165 LR: 0.00014 Loss 1.0661, fbeta 0.0770 (Best 0.1033), precision 0.0791 (Best 0.1108), recall 0.1113 (Best 0.1113)


  _warn_prf(average, modifier, msg_start, len(result))


Epoch: 170 LR: 0.00010 Loss 1.0085, Train Acc 0.9912, Val Acc 0.2700 (Best 0.2700), Test Acc 0.2831 (Best 0.2831)
Epoch: 170 LR: 0.00010 Loss 1.0085, fbeta 0.0744 (Best 0.1033), precision 0.0762 (Best 0.1108), recall 0.1115 (Best 0.1115)


  _warn_prf(average, modifier, msg_start, len(result))


Epoch: 175 LR: 0.00007 Loss 0.9665, Train Acc 0.9937, Val Acc 0.2700 (Best 0.2700), Test Acc 0.2843 (Best 0.2831)
Epoch: 175 LR: 0.00007 Loss 0.9665, fbeta 0.0759 (Best 0.1033), precision 0.0779 (Best 0.1108), recall 0.1117 (Best 0.1117)


  _warn_prf(average, modifier, msg_start, len(result))


Epoch: 180 LR: 0.00004 Loss 0.9391, Train Acc 0.9937, Val Acc 0.2700 (Best 0.2700), Test Acc 0.2852 (Best 0.2831)
Epoch: 180 LR: 0.00004 Loss 0.9391, fbeta 0.0745 (Best 0.1033), precision 0.0762 (Best 0.1108), recall 0.1119 (Best 0.1119)


  _warn_prf(average, modifier, msg_start, len(result))


Epoch: 185 LR: 0.00002 Loss 0.9141, Train Acc 0.9950, Val Acc 0.2700 (Best 0.2700), Test Acc 0.2856 (Best 0.2831)
Epoch: 185 LR: 0.00002 Loss 0.9141, fbeta 0.0736 (Best 0.1033), precision 0.0751 (Best 0.1108), recall 0.1120 (Best 0.1120)


  _warn_prf(average, modifier, msg_start, len(result))


Epoch: 190 LR: 0.00001 Loss 0.9051, Train Acc 0.9950, Val Acc 0.2700 (Best 0.2700), Test Acc 0.2856 (Best 0.2831)
Epoch: 190 LR: 0.00001 Loss 0.9051, fbeta 0.0739 (Best 0.1033), precision 0.0755 (Best 0.1108), recall 0.1120 (Best 0.1120)


  _warn_prf(average, modifier, msg_start, len(result))


Epoch: 195 LR: 0.00000 Loss 0.9006, Train Acc 0.9950, Val Acc 0.2700 (Best 0.2700), Test Acc 0.2856 (Best 0.2831)
Epoch: 195 LR: 0.00000 Loss 0.9006, fbeta 0.0738 (Best 0.1033), precision 0.0754 (Best 0.1108), recall 0.1120 (Best 0.1120)




Epoch: 200 LR: 0.00000 Loss 0.9020, Train Acc 0.9950, Val Acc 0.2700 (Best 0.2700), Test Acc 0.2856 (Best 0.2831)
Epoch: 200 LR: 0.00000 Loss 0.9020, fbeta 0.0738 (Best 0.1033), precision 0.0754 (Best 0.1108), recall 0.1120 (Best 0.1120)


  _warn_prf(average, modifier, msg_start, len(result))


In [23]:
# torch.save({
#             'model_state_dict': model.state_dict(),
#             'optimizer_state_dict': optimizer.state_dict(),
#             }, "model.pth")

In [24]:
# model = HeteroRGCN(G,
#                    in_size=args.n_inp,
#                    hidden_size=args.n_hid,
#                    out_size=labels.max().item()+1)#.to(device)
# optimizer = torch.optim.AdamW(model.parameters())
# scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, total_steps=args.n_epoch, max_lr = args.max_lr)
# print('Training RGCN with #param: %d' % (get_n_params(model)))
# train(model, G)

In [25]:
# model = HGT(G,
#             node_dict, edge_dict,
#             n_inp=args.n_inp,
#             n_hid=args.n_hid,
#             n_out=labels.max().item()+1,
#             n_layers=0,
#             n_heads=4)#.to(device)
# optimizer = torch.optim.AdamW(model.parameters())
# scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, total_steps=args.n_epoch, max_lr = args.max_lr)
# print('Training MLP with #param: %d' % (get_n_params(model)))
# train(model, G)