This code is implemented based on [Graph_Transformer_Networks](https://github.com/seongjunyun/Graph_Transformer_Networks/tree/master).

## Installation

In [None]:
# Install torch geometric
import os
import torch

In [None]:
# Pin torch to 2.4.0 (presumably so that matching prebuilt torch-scatter /
# torch-sparse wheels can be installed in the next cell).
# NOTE(review): torch is already imported in an earlier cell, so this kernel
# likely keeps the previously loaded torch until it is restarted — confirm.
!pip install torch==2.4.0

In [None]:
torch_version = str(torch.__version__)
scatter_src = f"https://pytorch-geometric.com/whl/torch-{torch_version}.html"
sparse_src = f"https://pytorch-geometric.com/whl/torch-{torch_version}.html"
!pip install torch-scatter -f $scatter_src
!pip install torch-sparse -f $sparse_src
!pip install torch-geometric
# !pip install -q git+https://github.com/snap-stanford/deepsnap.git

In [None]:
# Check that torch_geometric imports, and show its version as the cell output.
import torch_geometric
torch_geometric.__version__

## Data Processing


## Model


In [None]:
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import math
from gcn import GCNConv
from torch_scatter import scatter_add
import torch_sparse

In [None]:
class GTConv(nn.Module):
    """Soft selection of input adjacency matrices (edge types).

    Holds a learnable ``(out_channels, in_channels)`` weight. In ``forward``,
    each output channel ``c`` produces the weighted sum
    ``sum_j softmax(weight, dim=1)[c, j] * A[j]`` as a coalesced sparse
    ``(edge_index, edge_value)`` pair.

    Args:
        in_channels: number of input adjacency matrices in ``A``.
        out_channels: number of soft-selected adjacency matrices produced.
        num_nodes: number of graph nodes (stored; ``forward`` uses the
            ``num_nodes`` argument it receives).
    """

    def __init__(self, in_channels, out_channels, num_nodes):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels))
        # No bias is used; the attribute is kept so reset_parameters stays generic.
        self.bias = None
        self.num_nodes = num_nodes
        self.reset_parameters()

    def reset_parameters(self):
        """Draw the selection weights from N(0, 0.01^2); init the bias if present."""
        nn.init.normal_(self.weight, std=0.01)
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, A, num_nodes, eval=False):
        """Return one soft-selected adjacency per output channel.

        Args:
            A: non-empty list of ``(edge_index, edge_value)`` sparse adjacencies.
            num_nodes: size of the (square) adjacency matrices.
            eval: unused; accepted so callers can pass it uniformly.

        Returns:
            List of ``out_channels`` coalesced ``(index, value)`` pairs; entries
            with the same index are summed (``op='add'``).

        Raises:
            ValueError: if ``A`` is empty.
        """
        if len(A) == 0:
            raise ValueError("GTConv.forward expects at least one adjacency matrix in A")
        edge_filter = F.softmax(self.weight, dim=1)
        # The concatenated indices are the same for every channel: build them once.
        total_edge_index = torch.cat([edge_index for edge_index, _ in A], dim=1).detach()
        results = []
        for channel_weights in edge_filter:
            # Scale each edge type's values by this channel's weight for it.
            total_edge_value = torch.cat(
                [edge_value * channel_weights[j] for j, (_, edge_value) in enumerate(A)]
            )
            index, value = torch_sparse.coalesce(
                total_edge_index, total_edge_value, m=num_nodes, n=num_nodes, op='add'
            )
            results.append((index, value))
        return results

In [None]:
class GTLayer(nn.Module):
    """One Graph Transformer layer: soft-select adjacencies and multiply them.

    For every channel ``c`` the layer returns the sparse product
    ``left[c] @ right[c]``:

    * first layer (``first=True``): ``left`` and ``right`` are two independent
      ``GTConv`` soft selections of the input adjacencies ``A``;
    * later layers: ``left`` is the previous layer's output ``H_`` and
      ``right`` is a fresh ``GTConv`` selection of ``A``.

    Args:
        in_channels: number of input adjacency matrices (edge types).
        out_channels: number of channels produced by each ``GTConv``.
        num_nodes: number of graph nodes.
        first: whether this is the first layer of the stack.
    """

    def __init__(self, in_channels, out_channels, num_nodes, first=True):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.first = first
        self.num_nodes = num_nodes
        self.conv1 = GTConv(in_channels, out_channels, num_nodes)
        if self.first:
            # The first layer has no previous output, so it also selects the right factor.
            self.conv2 = GTConv(in_channels, out_channels, num_nodes)

    def forward(self, A, num_nodes, H_=None, eval=False):
        """Compute the per-channel adjacency products.

        Args:
            A: list of ``(edge_index, edge_value)`` input adjacencies.
            num_nodes: size of the square sparse matrices being multiplied.
            H_: previous layer's per-channel ``(edge_index, edge_value)`` list;
                required when ``first`` is False.
            eval: forwarded to the ``GTConv`` calls.

        Returns:
            ``(H, W)``. ``H`` is a list of coalesced ``(edges, values)`` per
            channel. ``W`` is the softmax-normalised selection weights of the
            ``GTConv`` modules used (two for the first layer, one otherwise).

        Raises:
            ValueError: if ``first`` is False and ``H_`` is None.
        """
        if self.first:
            result_A = self.conv1(A, num_nodes, eval=eval)
            result_B = self.conv2(A, num_nodes, eval=eval)
            W = [F.softmax(self.conv1.weight, dim=1), F.softmax(self.conv2.weight, dim=1)]
        else:
            if H_ is None:
                raise ValueError("GTLayer.forward requires H_ when first=False")
            result_A = H_
            result_B = self.conv1(A, num_nodes, eval=eval)
            W = [F.softmax(self.conv1.weight, dim=1)]

        size = (num_nodes, num_nodes)
        H = []
        for i, (a_edge, a_value) in enumerate(result_A):
            b_edge, b_value = result_B[i]
            # Both factors are placed on the left factor's device before multiplying.
            mat_a = torch.sparse_coo_tensor(a_edge, a_value, size).to(a_edge.device)
            mat_b = torch.sparse_coo_tensor(b_edge, b_value, size).to(a_edge.device)
            mat = torch.sparse.mm(mat_a, mat_b).coalesce()
            H.append((mat.indices(), mat.values()))
        return H, W

In [None]:
class GTN(nn.Module):
    """Graph Transformer Network (GTN) node classifier.

    Stacks ``num_layers`` ``GTLayer`` modules that learn ``num_channels``
    adjacency matrices from the ``num_edge`` input edge types. The shared
    ``GCNConv`` is run over each learned adjacency. The per-channel node
    embeddings are then concatenated and classified with a linear layer.

    Args:
        num_edge: number of input adjacency matrices (edge types) in ``A``.
        num_channels: number of learned adjacency channels.
        w_in: dimension of the input node features ``X``.
        w_out: GCN output dimension per channel.
        num_class: number of output logits.
        num_nodes: default node count used when ``forward`` receives none.
        num_layers: number of stacked GT layers.
        args: configuration object. ``args.dataset`` selects the loss, and the
            object is forwarded to ``GCNConv``. When ``args`` is None (or has
            no ``dataset``), cross-entropy is used.
    """

    # Datasets for which a sigmoid + binary cross-entropy loss is configured.
    _BCE_DATASETS = ("PPI", "BOOK", "MUSIC")

    def __init__(self, num_edge, num_channels, w_in, w_out, num_class, num_nodes, num_layers, args=None):
        super().__init__()
        self.num_edge = num_edge
        self.num_channels = num_channels
        self.num_nodes = num_nodes
        self.w_in = w_in
        self.w_out = w_out
        self.num_class = num_class
        self.num_layers = num_layers
        self.args = args
        # Only the first layer selects both factors of its matrix product.
        self.layers = nn.ModuleList(
            [GTLayer(num_edge, num_channels, num_nodes, first=(i == 0)) for i in range(num_layers)]
        )
        if self._uses_bce(args):
            self.m = nn.Sigmoid()
            self.loss = nn.BCELoss()
        else:
            self.loss = nn.CrossEntropyLoss()
        self.gcn = GCNConv(in_channels=self.w_in, out_channels=w_out, args=args)
        self.linear = nn.Linear(self.w_out * self.num_channels, self.num_class)

    @staticmethod
    def _uses_bce(args):
        """Return True when ``args.dataset`` is one of the sigmoid/BCE datasets."""
        return getattr(args, "dataset", None) in GTN._BCE_DATASETS

    def normalization(self, H, num_nodes):
        """Row-normalise each channel: every value is divided by its row's weight sum."""
        norm_H = []
        for i in range(self.num_channels):
            edge, value = H[i]
            deg_row, _ = self.norm(edge.detach(), num_nodes, value)
            norm_H.append((edge, deg_row * value))
        return norm_H

    def norm(self, edge_index, num_nodes, edge_weight, improved=False, dtype=None):
        """Return per-edge inverse weighted degrees ``(1/deg[row], 1/deg[col])``.

        ``deg`` is the sum of ``edge_weight`` grouped by source (row) node;
        nodes with zero degree get 0 instead of inf.

        Args:
            edge_index: ``(2, E)`` index tensor (unpacked into ``row, col``).
            num_nodes: length of the degree vector.
            edge_weight: per-edge weights, or None for all-ones.
            improved: unused; accepted for interface compatibility.
            dtype: dtype of the all-ones weights when ``edge_weight`` is None.

        Raises:
            ValueError: if ``edge_weight`` does not have one entry per edge.
        """
        if edge_weight is None:
            edge_weight = torch.ones((edge_index.size(1),), dtype=dtype, device=edge_index.device)
        edge_weight = edge_weight.view(-1)
        if edge_weight.size(0) != edge_index.size(1):
            raise ValueError("edge_weight must have one entry per edge in edge_index")
        row, col = edge_index
        deg = scatter_add(edge_weight.clone(), row, dim=0, dim_size=num_nodes)
        deg_inv = deg.pow(-1)
        deg_inv[deg_inv == float('inf')] = 0
        return deg_inv[row], deg_inv[col]

    def forward(self, A, X, target_x, target, num_nodes=None, eval=False, node_labels=None):
        """Classify the ``target_x`` nodes.

        Args:
            A: list of ``(edge_index, edge_value)`` input adjacencies.
            X: node feature matrix passed to ``GCNConv``.
            target_x: index selecting the embedding rows to classify.
            target: labels for ``self.loss`` (ignored when ``eval`` is True).
            num_nodes: node count; defaults to ``self.num_nodes``.
            eval: when True, return only the logits.
            node_labels: unused; accepted for interface compatibility.

        Returns:
            ``y`` when ``eval`` is True. Otherwise ``(loss, y, Ws)``, where
            ``Ws`` holds each layer's softmax-normalised selection weights.
        """
        if num_nodes is None:
            num_nodes = self.num_nodes
        Ws = []
        for i in range(self.num_layers):
            if i == 0:
                H, W = self.layers[i](A, num_nodes, eval=eval)
            else:
                H, W = self.layers[i](A, num_nodes, H, eval=eval)
            H = self.normalization(H, num_nodes)
            Ws.append(W)

        channel_outputs = []
        for i in range(self.num_channels):
            edge_index, edge_weight = H[i]
            channel_outputs.append(
                F.relu(self.gcn(X, edge_index=edge_index.detach(), edge_weight=edge_weight))
            )
        X_ = torch.cat(channel_outputs, dim=1)

        y = self.linear(X_[target_x])
        if eval:
            return y
        # BCELoss needs probabilities, so apply the sigmoid for every BCE dataset,
        # matching the loss that __init__ configured.
        if self._uses_bce(self.args):
            loss = self.loss(self.m(y), target)
        else:
            loss = self.loss(y, target)
        return loss, y, Ws

## Train (work in progress: this section still needs further modifications)

In [None]:
# Hyper-parameters consumed by the training cell below, all read from `args`.
epochs, node_dim, num_channels, lr, weight_decay, num_layers = (
    args.epoch,
    args.node_dim,
    args.num_channels,
    args.lr,
    args.weight_decay,
    args.num_layers,
)

In [None]:
import copy  # used by the FastGTN pre-training branch (copy.deepcopy)

# NOTE(review): this cell still depends on names that no visible cell defines:
# `args`, `A`, `node_features`, `num_classes`, `num_nodes`, `num_edge_type`,
# `FastGTNs`, `f1_score`, `sk_f1_score` and the train/valid/test node and
# target tensors. They must be created (e.g. in the "Data Processing"
# section) before this cell can run.


def _forward(node_idx, node_target, epoch):
    """Run the current ``model`` on one split; only FastGTN receives ``epoch``."""
    if args.model == 'FastGTN':
        return model(A, node_features, node_idx, node_target, epoch=epoch)
    return model(A, node_features, node_idx, node_target)


def _f1_scores(y, target):
    """Return ``(macro_f1, micro_f1)`` of logits ``y`` against ``target``.

    On PPI the logits are thresholded at 0 and only micro-F1 is computed
    (macro-F1 is reported as 0.0); otherwise the arg-max class is scored.
    """
    y = y.detach()
    if args.dataset == 'PPI':
        y_pred = (y > 0).float().cpu()
        return 0.0, sk_f1_score(target.detach().cpu().numpy(), y_pred.numpy(), average='micro')
    macro = torch.mean(f1_score(torch.argmax(y, dim=1), target, num_classes=num_classes)).cpu().numpy()
    micro = sk_f1_score(target.detach().cpu(), np.argmax(y.cpu(), axis=1), average='micro')
    return macro, micro


final_f1 = []
final_micro_f1 = []

runs = args.runs
if args.pre_train:
    # Run 0 only pre-trains FastGTN layers that the following runs reuse.
    runs += 1
    pre_trained_fastGTNs = None
for run_idx in range(runs):
    # Initialise a fresh model for this run.
    if args.model == 'GTN':
        model = GTN(num_edge=len(A),
                    num_channels=num_channels,
                    w_in=node_features.shape[1],
                    w_out=node_dim,
                    num_class=num_classes,
                    num_layers=num_layers,
                    num_nodes=num_nodes,
                    args=args)
    elif args.model == 'FastGTN':
        if args.pre_train and run_idx == 1:
            # Snapshot the layers trained during the pre-training run.
            pre_trained_fastGTNs = [
                copy.deepcopy(model.fastGTNs[layer].layers)
                for layer in range(args.num_FastGTN_layers)
            ]
        while len(A) > num_edge_type:
            del A[-1]
        model = FastGTNs(num_edge_type=len(A),
                         w_in=node_features.shape[1],
                         num_class=num_classes,
                         num_nodes=node_features.shape[0],
                         args=args)
        if args.pre_train and run_idx > 0:
            for layer in range(args.num_FastGTN_layers):
                # Give each run its own copy so that training in one run does
                # not alter the pre-trained starting point of the next run.
                model.fastGTNs[layer].layers = copy.deepcopy(pre_trained_fastGTNs[layer])

    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    model.cuda()

    best_val_loss = 10000
    best_test_loss = 10000
    best_train_loss = 10000
    best_train_f1, best_micro_train_f1 = 0, 0
    best_val_f1, best_micro_val_f1 = 0, 0
    best_test_f1, best_micro_test_f1 = 0, 0

    for i in range(epochs):
        # Training step.
        model.zero_grad()
        model.train()
        loss, y_train, _ = _forward(train_node, train_target, i)
        train_f1, sk_train_f1 = _f1_scores(y_train, train_target)
        loss.backward()
        optimizer.step()

        # Validation and test evaluation without gradient tracking.
        model.eval()
        with torch.no_grad():
            val_loss, y_valid, _ = _forward(valid_node, valid_target, i)
            val_f1, sk_val_f1 = _f1_scores(y_valid, valid_target)
            test_loss, y_test, _ = _forward(test_node, test_target, i)
            test_f1, sk_test_f1 = _f1_scores(y_test, test_target)

        # Keep the epoch with the best validation micro-F1.
        if sk_val_f1 > best_micro_val_f1:
            best_val_loss = val_loss.detach().cpu().numpy()
            best_test_loss = test_loss.detach().cpu().numpy()
            best_train_loss = loss.detach().cpu().numpy()
            best_train_f1 = train_f1
            best_val_f1 = val_f1
            best_test_f1 = test_f1
            best_micro_train_f1 = sk_train_f1
            best_micro_val_f1 = sk_val_f1
            best_micro_test_f1 = sk_test_f1
    if run_idx == 0 and args.pre_train:
        # The pre-training run is not reported.
        continue
    print('Run {}'.format(run_idx))
    print('--------------------Best Result-------------------------')
    print('Train - Loss: {:.4f}, Macro_F1: {:.4f}, Micro_F1: {:.4f}'.format(best_train_loss, best_train_f1, best_micro_train_f1))
    print('Valid - Loss: {:.4f}, Macro_F1: {:.4f}, Micro_F1: {:.4f}'.format(best_val_loss, best_val_f1, best_micro_val_f1))
    print('Test - Loss: {:.4f}, Macro_F1: {:.4f}, Micro_F1: {:.4f}'.format(best_test_loss, best_test_f1, best_micro_test_f1))
    final_f1.append(best_test_f1)
    final_micro_f1.append(best_micro_test_f1)

print('--------------------Final Result-------------------------')
print('Test - Macro_F1: {:.4f}+{:.4f}, Micro_F1:{:.4f}+{:.4f}'.format(np.mean(final_f1), np.std(final_f1), np.mean(final_micro_f1), np.std(final_micro_f1)))