# Python Package Install
- Please use the command below to down these Python package if they are not installed in your virtual environment.

In [None]:
!pip install torch==1.8.0
!pip install numpy==1.22.3

# Package Import

In [1]:
import pickle
import torch
from torch import nn
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import math

  from .autonotebook import tqdm as notebook_tqdm


# Argument Settings

In [2]:
dataset = "DBLP" # DBLP ACM
num_layers = 3
epochs = 4
node_dim = 64
num_channels = 2
norm=True
lr = 0.005
weight_decay = 0.001
adaptive_lr = False

# Dataset Download
Please download datasets (DBLP, ACM, IMDB) from this [link](https://drive.google.com/file/d/13eC9gz8b9mLCPC_V1iHXJSKNHCTrTOYa/view?usp=sharing) and extract data.zip into data folder.

# Data Loading

In [3]:
with open('data/' + dataset + '/node_features.pkl', 'rb') as f:
    node_features = pickle.load(f)
with open('data/' + dataset+'/edges.pkl','rb') as f:
    edges = pickle.load(f)
with open('data/' + dataset+'/labels.pkl','rb') as f:
    labels = pickle.load(f)
num_nodes = edges[0].shape[0] # num_nodes

for i, edge in enumerate(edges):
    if i ==0:
        A = torch.from_numpy(edge.todense()).type(torch.FloatTensor).unsqueeze(-1)
    else:
        A = torch.cat([A, torch.from_numpy(edge.todense()).type(torch.FloatTensor).unsqueeze(-1)], dim=-1)
# A: [num_nodes, num_nodes, num_edges]

A = torch.cat([A, torch.eye(num_nodes).type(torch.FloatTensor).unsqueeze(-1)], dim=-1) # A: [num_nodes, num_nodes, num_edges + 1]

node_features = torch.from_numpy(node_features).type(torch.FloatTensor) # node_features: [num_nodes, node_embedding_input]
train_node = torch.from_numpy(np.array(labels[0])[:,0]).type(torch.LongTensor) # 800
train_target = torch.from_numpy(np.array(labels[0])[:,1]).type(torch.LongTensor) # 800
valid_node = torch.from_numpy(np.array(labels[1])[:,0]).type(torch.LongTensor) # 400
valid_target = torch.from_numpy(np.array(labels[1])[:,1]).type(torch.LongTensor) # 400
test_node = torch.from_numpy(np.array(labels[2])[:,0]).type(torch.LongTensor) # 2857
test_target = torch.from_numpy(np.array(labels[2])[:,1]).type(torch.LongTensor) # 2857

  edges = pickle.load(f)
  edges = pickle.load(f)


# Graph Transformers Neyworks Module

In [4]:

import torch
from torch import nn
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import math

class GTN(nn.Module):

    """
    The module to Compose Graph Transformers Neural Network .

    Input shape:
        A [num_nodes, num_nodes, num_edges + 1]
        X [num_nodes, node_embedding_input]
        target_x [num_target_node] DBLP 400
        target [num_target_node] DBLP 400

    Output shape:
        y: [num_target_nodes, num_node_classes]
        Ws: A set of [num_channels, num_edges + 1, 1, 1]
    """
    
    def __init__(self, num_edge, num_channels, w_in, w_out, num_class ,num_layers, norm):
        super(GTN, self).__init__()
        self.num_edge = num_edge # num_edge_classes 4
        self.num_channels = num_channels # num_channelsn 2
        self.w_in = w_in # node_embedding_input 334
        self.w_out = w_out # node_embedding_output 64
        self.num_class = num_class # num_node_classes 4
        self.num_layers = num_layers # num_layers 3
        self.is_norm = norm
        layers = []
        for i in range(num_layers):
            if i == 0:
                layers.append(GraphTransformersLayer(num_edge, num_channels, first=True))
            else:
                layers.append(GraphTransformersLayer(num_edge, num_channels, first=False))
        self.layers = nn.ModuleList(layers)
        self.weight = nn.Parameter(torch.Tensor(w_in, w_out)) # [node_embedding_input, node_embedding_output]
        nn.init.xavier_uniform_(self.weight)
        self.bias = nn.Parameter(torch.Tensor(w_out)) # [node_embedding_output]
        nn.init.zeros_(self.bias)
        self.loss = nn.CrossEntropyLoss()
        self.linear1 = nn.Linear(self.w_out*num_channels, self.w_out) # [node_embedding_output * num_channels, node_embedding_output]
        self.linear2 = nn.Linear(self.w_out, self.num_class) # [node_embedding_output, num_node_classes]

    def gcn_conv(self, X, H):
        """
        Input shape:
            X: [num_nodes, node_embedding_input]
            H: [num_nodes, num_nodes]

        Middle variant:
            X: [num_nodes, node_embedding_output]
            H: [num_nodes, num_nodes]

        Output shape:
            [num_nodes, node_embedding_output]
        """
        
        X = torch.mm(X, self.weight) # [num_nodes, node_embedding_output]
        H = self.norm(H, add=True)  # [num_nodes, num_nodes]
        return torch.mm(H.t(),X)

    def normalization(self, H):
        """
        Input shape:
            H: [num_channels, num_nodes, num_nodes]

        Output shape:
            H_: [num_channels, num_nodes, num_nodes]
        """

        for i in range(self.num_channels):
            if i==0:
                H_ = self.norm(H[i, :, :]).unsqueeze(0)
            else:
                H_ = torch.cat((H_, self.norm(H[i, :, :]).unsqueeze(0)), dim=0)
        return H_

    def norm(self, H, add=False):
        """
        Input shape:
            H: [num_nodes, num_nodes]

        Output shape:
            H_: [num_channels, num_nodes, num_nodes]
        """

        H = H.t()

        if add == False:
            H = H * ((torch.eye(H.shape[0])==0).type(torch.FloatTensor)) # H: [num_nodes, num_nodes]
        else:
            H = H * ((torch.eye(H.shape[0])==0).type(torch.FloatTensor)) + torch.eye(H.shape[0]).type(torch.FloatTensor) # H: [num_nodes, num_nodes]

        deg = torch.sum(H, dim=1) # deg: [num_nodes]
        deg_inv = deg.pow(-1) # deg_inv: [num_nodes]
        deg_inv[deg_inv == float('inf')] = 0 # deg_inv: [num_nodes]
        deg_inv = deg_inv * torch.eye(H.shape[0]).type(torch.FloatTensor) # deg_inv: [num_nodes, num_nodes]
        H = torch.mm(deg_inv, H) # deg_inv: [num_nodes, num_nodes]
        H = H.t() # H: [num_nodes, num_nodes]
        return H

    def forward(self, A, X, target_x, target):
        """
        Input shape:
            A [num_nodes, num_nodes, num_edges + 1]
            X [num_nodes, node_embedding_input]
            target_x [num_target_node]
            target [num_target_node]

        Output shape:
            y: [num_target_nodes, num_node_classes]
            Ws: A set of [num_channels, num_edges + 1, 1, 1]
        """

        A = A.unsqueeze(0).permute(0,3,1,2)  # [1, num_edges + 1, num_nodes, num_nodes]
        Ws = []
        for i in range(self.num_layers):
            if i == 0:
                H, W = self.layers[i](A) # H: [num_channels, num_nodes, num_nodes]
            else:
                H = self.normalization(H) # H: [num_channels, num_nodes, num_nodes]
                H, W = self.layers[i](A, H) # H: [num_channels, num_nodes, num_nodes]
            Ws.append(W)
        
        for i in range(self.num_channels):
            if i==0:
                # X: [num_nodes, node_embedding_input] H[i]: [num_nodes, num_nodes]
                X_ = F.relu(self.gcn_conv(X, H[i])) # X_: [num_nodes, node_embedding_output]
            else:
                # X: [num_nodes, node_embedding_input] H[i]: [num_nodes, num_nodes]
                X_tmp = F.relu(self.gcn_conv(X, H[i])) # X_tmp: [num_nodes, node_embedding_output]
                X_ = torch.cat((X_,X_tmp), dim=1) # X_: [num_nodes, node_embedding_output * num_channels]
        X_ = self.linear1(X_) # X_: [num_nodes, node_embedding_output]
        X_ = F.relu(X_) # X_: [num_nodes, node_embedding_output]
        y = self.linear2(X_[target_x]) # [num_target_nodes, num_node_classes]
        loss = self.loss(y, target)
        return loss, y, Ws


class GraphTransformersLayer(nn.Module):
    '''
    The module to Compose Graph Transformers Convolutional Neural Network Layer.
    This module is to finsh convilution of two layers.
    After being sent to this module, the meta-path will be extended by one unit。

    Input shape:
        A: [1, num_edges + 1, num_nodes, num_nodes]
        H: [num_channels, num_nodes, num_nodes]

    Output shape:
        H: [num_channels, num_nodes, num_nodes]
        W: [num_channels, num_edges + 1, 1, 1]
    '''

    def __init__(self, in_channels, out_channels, first=True):
        super(GraphTransformersLayer, self).__init__()

        # Parameter Setting
        self.in_channels = in_channels # num_edges
        self.out_channels = out_channels # num_channels
        self.first = first

        if self.first == True:
            self.conv1 = GraphTransformersConvModule(in_channels, out_channels)
            self.conv2 = GraphTransformersConvModule(in_channels, out_channels)
        else:
            self.conv1 = GraphTransformersConvModule(in_channels, out_channels)
    
    def forward(self, A, H_=None): # A: [1, num_edges + 1, num_nodes, num_nodes] H: [num_channels, num_nodes, num_nodes]
        if self.first == True:
            a = self.conv1(A) # [num_channels, num_nodes, num_nodes]
            b = self.conv2(A) # [num_channels, num_nodes, num_nodes]
            H = torch.bmm(a, b) # [num_channels, num_nodes, num_nodes]
            W = [(F.softmax(self.conv1.weight, dim=1)).detach(),(F.softmax(self.conv2.weight, dim=1)).detach()]
        else:
            a = self.conv1(A) # [num_channels, num_nodes, num_nodes]
            H = torch.bmm(H_, a) # [num_channels, num_nodes, num_nodes]
            W = [(F.softmax(self.conv1.weight, dim=1)).detach()]
        return H, W

class GraphTransformersConvModule(nn.Module):
    '''
    The module to do Graph Transformers Convolutional Neural Network.
    This is a backbone module to finish edge classes to channels.

    Input shape:
        A: [1, num_edges + 1, num_nodes, num_nodes]

    Output shape:
        A: [num_channels, num_nodes, num_nodes]
    '''

    def __init__(self, in_channels, out_channels):
        super(GraphTransformersConvModule, self).__init__()

        # Module Parameters Setting
        self.in_channels = in_channels # num_edges + 1
        self.out_channels = out_channels # num_channels
        self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels).unsqueeze(2).unsqueeze(3)) # [num_channels, num_edges + 1, 1, 1]
        nn.init.constant_(self.weight, 0.1)
        self.bias = None
        # print("GraphTransformersConvModule Start")

    def forward(self, A): # A: [1, num_edges + 1, num_nodes, num_nodes]
        A = torch.sum(A * F.softmax(self.weight, dim=1), dim=1) # A: [num_channels, num_nodes, num_nodes]
        return A


# Training Utils

In [5]:
def accuracy(pred, target):
    return (pred == target).sum().item() / target.numel()

def true_positive(pred, target, num_classes):
    out = []
    for i in range(num_classes):
        out.append(((pred == i) & (target == i)).sum())
    return torch.tensor(out)

def true_negative(pred, target, num_classes):
    out = []
    for i in range(num_classes):
        out.append(((pred != i) & (target != i)).sum())
    return torch.tensor(out)

def false_positive(pred, target, num_classes):
    out = []
    for i in range(num_classes):
        out.append(((pred == i) & (target != i)).sum())
    return torch.tensor(out)

def false_negative(pred, target, num_classes):
    out = []
    for i in range(num_classes):
        out.append(((pred != i) & (target == i)).sum())
    return torch.tensor(out)

def precision(pred, target, num_classes):
    tp = true_positive(pred, target, num_classes).to(torch.float)
    fp = false_positive(pred, target, num_classes).to(torch.float)
    out = tp / (tp + fp)
    out[torch.isnan(out)] = 0
    return out

def recall(pred, target, num_classes):
    tp = true_positive(pred, target, num_classes).to(torch.float)
    fn = false_negative(pred, target, num_classes).to(torch.float)
    out = tp / (tp + fn)
    out[torch.isnan(out)] = 0
    return out

def f1_score(pred, target, num_classes):
    prec = precision(pred, target, num_classes)
    rec = recall(pred, target, num_classes)
    score = 2 * (prec * rec) / (prec + rec)
    score[torch.isnan(score)] = 0
    return score

# Training

In [6]:
num_classes = torch.max(train_target).item() + 1
final_f1 = 0
for l in range(1):
    model = GTN(num_edge=A.shape[-1], # num_edges 4
                        num_channels=num_channels, # num_channels 2
                        w_in = node_features.shape[1], # node_embedding_input 334
                        w_out = node_dim,  # node_embedding_output 64
                        num_class=num_classes, # num_classes 4
                        num_layers=num_layers, # num_layers 3
                        norm=norm)
    if adaptive_lr == 'false':
        optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=0.001)
    else:
        optimizer = torch.optim.Adam([{'params':model.weight},
                                    {'params':model.linear1.parameters()},
                                    {'params':model.linear2.parameters()},
                                    {"params":model.layers.parameters(), "lr":0.5}
                                    ], lr=0.005, weight_decay=0.001)
    loss = nn.CrossEntropyLoss()
    # Train & Valid & Test
    best_val_loss = 2000
    best_test_loss = 20000
    best_train_loss = 20000
    best_train_f1 = 0
    best_val_f1 = 0
    best_test_f1 = 0
    
    for i in range(epochs):

        # To make sure that the learning rate is limited in a rellatively small range.
        for param_group in optimizer.param_groups:
            if param_group['lr'] > 0.005:
                param_group['lr'] = param_group['lr'] * 0.9

        print('Epoch: {}'.format(i))

        # Clean the gradient
        model.zero_grad()
        model.train()

        # Training
        loss, y_train, Ws = model(A, node_features, train_node, train_target) # 
        train_f1 = torch.mean(f1_score(torch.argmax(y_train.detach(), dim=1), train_target, num_classes=num_classes)).cpu().numpy()
        print('Train - Loss: {}, Macro_F1: {}'.format(loss.detach().cpu().numpy(), train_f1))

        # gradient backward
        loss.backward()
        optimizer.step()
        model.eval()

        # Validation
        with torch.no_grad():
            loss_val, y_valid, W_val = model.forward(A, node_features, valid_node, valid_target)
            val_f1 = torch.mean(f1_score(torch.argmax(y_valid, dim=1), valid_target, num_classes=num_classes)).cpu().numpy()
            print('Valid - Loss: {}, Macro_F1: {}'.format(loss_val.detach().cpu().numpy(), val_f1))

            loss_test, y_test, W_test = model.forward(A, node_features, test_node, test_target)
            test_f1 = torch.mean(f1_score(torch.argmax(y_test,dim=1), test_target, num_classes=num_classes)).cpu().numpy()
            print('Test - Loss: {}, Macro_F1: {}\n'.format(loss_test.detach().cpu().numpy(), test_f1))

        if val_f1 > best_val_f1:
            best_val_loss = loss_val.detach().cpu().numpy()
            best_test_loss = loss_test.detach().cpu().numpy()
            best_train_loss = loss.detach().cpu().numpy()
            best_train_f1 = train_f1
            best_val_f1 = val_f1
            best_test_f1 = test_f1 
            
    print('Train - Loss: {}, Macro_F1: {}'.format(best_train_loss, best_train_f1))
    print('Valid - Loss: {}, Macro_F1: {}'.format(best_val_loss, best_val_f1))
    print('Test - Loss: {}, Macro_F1: {}'.format(best_test_loss, best_test_f1))
    final_f1 += best_test_f1


Epoch: 0
Train - Loss: 1.386662483215332, Macro_F1: 0.19890807569026947
Valid - Loss: 1.351372241973877, Macro_F1: 0.29231950640678406
Test - Loss: 1.3450151681900024, Macro_F1: 0.2843223214149475

Epoch: 1
Train - Loss: 1.3486584424972534, Macro_F1: 0.29405418038368225
Valid - Loss: 1.2783275842666626, Macro_F1: 0.8466583490371704
Test - Loss: 1.2830886840820312, Macro_F1: 0.8195434212684631

Epoch: 2
Train - Loss: 1.274963140487671, Macro_F1: 0.8568823933601379
Valid - Loss: 1.1668626070022583, Macro_F1: 0.7784095406532288
Test - Loss: 1.1480616331100464, Macro_F1: 0.7526983618736267

Epoch: 3
Train - Loss: 1.1701091527938843, Macro_F1: 0.741525411605835
Valid - Loss: 1.06666898727417, Macro_F1: 0.671305239200592
Test - Loss: 1.1062723398208618, Macro_F1: 0.6496049761772156

Train - Loss: 1.3486584424972534, Macro_F1: 0.29405418038368225
Valid - Loss: 1.2783275842666626, Macro_F1: 0.8466583490371704
Test - Loss: 1.2830886840820312, Macro_F1: 0.8195434212684631


# Reference
- https://arxiv.org/pdf/1911.06455v2.pdf
- https://github.com/seongjunyun/Graph_Transformer_Networks
- https://www.youtube.com/watch?v=91yBCJIkpLc&t=304s
- https://distill.pub/2021/gnn-intro/
- https://distill.pub/2021/understanding-gnns/