In [None]:
# Mount Google Drive into the Colab VM; the project data is read from under this mount point.
from google.colab import drive

gdrive_mount_point = '/content/gdrive'
drive.mount(gdrive_mount_point, force_remount=True)


Mounted at /content/gdrive


In [None]:
################ Need to set project path here !!  #################
# Root directory of the project. After chdir-ing into it, the data is read from the relative
# 'graph_data/...' paths, and result/checkpoint paths are built as f'{projectpath}...',
# so the value must end with '/'.
projectpath = "/content/gdrive/MyDrive/GraphAttnProject/SpanTree [with start node]_[walklen=3]_[p=1,q=1]_[num_walks=50]/NIPS_Submission/"
if not projectpath.endswith('/'):
    projectpath += '/'

In [None]:
import os

# Run from the project directory so relative paths such as 'graph_data/...' resolve.
os.chdir(projectpath)
cwd = os.getcwd()
cwd

'/content/gdrive/MyDrive/GraphAttnProject/SpanTree [with start node]_[walklen=3]_[p=1,q=1]_[num_walks=50]/NIPS_Submission'

In [None]:
! pip install dgl
import dgl



Using backend: pytorch


# Load data

In [None]:

# Libraries used by the data-loading cells below.
from tqdm.notebook import tqdm, trange
import networkx as nx
import pickle
import numpy as np
import tensorflow as tf
import torch
# Report the installed TensorFlow version (tf is not otherwise referenced in the cells shown).
print(tf.__version__)

2.4.1


In [None]:
# Load all train and validation graphs (iterated and passed to dgl.from_networkx further down).
# NOTE: pickle.load can execute arbitrary code -- only load files from a trusted source.
with open('graph_data/train_graphs.pkl', 'rb') as fh:
    train_graphs = pickle.load(fh)
with open('graph_data/val_graphs.pkl', 'rb') as fh:
    val_graphs = pickle.load(fh)

# Load all labels
train_labels = np.load('graph_data/train_labels.npy')
val_labels = np.load('graph_data/val_labels.npy')

In [None]:
################# NEED TO SPECIFY THE RANDOM WALK LENGTH WE WANT TO USE ################
# GKAT kernels were generated from random walks of length 2 to 10; this notebook uses 6.
# The value selects which precomputed kernel / frequency-matrix files are loaded below.
walk_len = 6
if not 2 <= walk_len <= 10:
    raise ValueError(f'walk_len must be between 2 and 10 (precomputed kernels), got {walk_len}')
#########################################################################################

In [None]:
# Load the random-walk frequency matrices (raw data that could be used for random feature
# mapping). NOTE: they are not referenced by the training cells in this notebook.
# pickle.load can execute arbitrary code -- only load files from a trusted source.
with open(f'graph_data/GKAT_freq_mats_train_len={walk_len}.pkl', 'rb') as fh:
    train_freq_mat = pickle.load(fh)
with open(f'graph_data/GKAT_freq_mats_val_len={walk_len}.pkl', 'rb') as fh:
    val_freq_mat = pickle.load(fh)

In [None]:
# Load the pre-calculated GKAT kernels for the chosen walk length; they are later passed
# to the model as the attention reweighting (`counting_attn`).
# pickle.load can execute arbitrary code -- only load files from a trusted source.
with open(f'graph_data/GKAT_dot_kernels_train_len={walk_len}.pkl', 'rb') as fh:
    train_GKAT_kernel = pickle.load(fh)
with open(f'graph_data/GKAT_dot_kernels_val_len={walk_len}.pkl', 'rb') as fh:
    val_GKAT_kernel = pickle.load(fh)

In [None]:
# Train/validation kernels, unpacked in that order by the training loop and fed to the model.
GKAT_masking = list((train_GKAT_kernel, val_GKAT_kernel))


In [None]:
# Convert the networkx graphs to DGL graphs. Graphs that are already DGLGraph instances are
# kept as-is, so re-running this cell (which rebinds the same names) does not fail.
train_graphs = [g if isinstance(g, dgl.DGLGraph) else dgl.from_networkx(g) for g in train_graphs]
val_graphs = [g if isinstance(g, dgl.DGLGraph) else dgl.from_networkx(g) for g in val_graphs]
# Everything the training cell unpacks at the start of each repeat.
info = [train_graphs, train_labels, val_graphs, val_labels, GKAT_masking]

# START Training

In [None]:
import networkx as nx
import matplotlib.pyplot as plt 
import time
import random
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm, trange
import seaborn as sns

from random import shuffle
from multiprocessing import Pool
import multiprocessing
from functools import partial
from networkx.generators.classic import cycle_graph

import sys
import scipy
import scipy.sparse

#from CodeZip_ST import *

In [None]:
from prettytable import PrettyTable

# Prints a per-parameter breakdown of a model's trainable parameters (used in the training cell).
def count_parameters(model):
    """Tabulate the trainable parameters of ``model`` and return their total count.

    Parameters with ``requires_grad=False`` are skipped. A table of
    (parameter name, number of elements) is printed, followed by the total.
    """
    rows = [
        [name, tensor.numel()]
        for name, tensor in model.named_parameters()
        if tensor.requires_grad
    ]
    table = PrettyTable(["Modules", "Parameters"])
    for row in rows:
        table.add_row(row)
    total_params = sum(count for _, count in rows)
    print(table)
    print(f"Total Trainable Params: {total_params}")
    return total_params
    


# GKAT Testing

## GKAT model

In [None]:
class GKATLayer(nn.Module):
    """Single attention head whose attention weights are reweighted by a GKAT kernel.

    Node features are projected by a bias-free linear map. Every node pair (i, j)
    gets the logit LeakyReLU(Wh_i . a_l + Wh_j . a_r). The exponentiated logits are
    multiplied elementwise by ``counting_attn`` and row-normalised, and the result
    averages the projected features.

    Args:
        in_dim: size of the input node features.
        out_dim: size of the projected node features (and of the output).
        feat_drop: dropout probability applied to the input features.
        attn_drop: dropout probability applied to the raw pairwise logits
            (before the LeakyReLU).
        alpha: negative slope of the LeakyReLU.
        agg_activation: activation applied to the aggregated output; ``None``
            disables it.
    """

    def __init__(self,
                 in_dim,
                 out_dim,
                 feat_drop=0.,
                 attn_drop=0.,
                 alpha=0.2,
                 agg_activation=F.elu):
        super(GKATLayer, self).__init__()

        self.feat_drop = nn.Dropout(feat_drop)
        self.fc = nn.Linear(in_dim, out_dim, bias=False)
        # Attention vectors start at all-ones.
        self.attn_l = nn.Parameter(torch.ones(size=(out_dim, 1)))
        self.attn_r = nn.Parameter(torch.ones(size=(out_dim, 1)))
        self.attn_drop = nn.Dropout(attn_drop)
        self.activation = nn.LeakyReLU(alpha)
        # Not used by forward(); kept so existing attribute access keeps working.
        self.softmax = nn.Softmax(dim = 1)

        self.agg_activation=agg_activation

    def clean_data(self):
        """Pop cached tensors from the graph most recently passed to ``forward``.

        NOTE(review): ``forward`` never writes these ndata/edata keys, so calling this
        presumably raises KeyError -- confirm whether it is still needed.
        """
        ndata_names = ['ft', 'a1', 'a2']
        edata_names = ['a_drop']
        for name in ndata_names:
            self.g.ndata.pop(name)
        for name in edata_names:
            self.g.edata.pop(name)

    def forward(self, feat, bg, counting_attn):
        """Apply the attention head to one graph.

        Args:
            feat: node features; last dimension must be ``in_dim``. Rows are nodes (V).
            bg: the graph; only stored on ``self.g`` (read by ``clean_data``).
            counting_attn: kernel weights, assumed (V, V); entries equal to 0 block
                attention from the row node to the column node.

        Returns:
            (V, out_dim) tensor of aggregated node features.
        """
        self.g = bg
        h = self.feat_drop(feat)
        head_ft = self.fc(h).reshape((h.shape[0], -1))  # V x out_dim

        a1 = torch.mm(head_ft, self.attn_l)  # V x 1
        a2 = torch.mm(head_ft, self.attn_r)  # V x 1
        a = self.attn_drop(a1 + a2.transpose(0, 1))  # V x V pairwise logits
        a = self.activation(a)

        weights = counting_attn.float()
        # Shift every row by the max over the entries the kernel keeps. The shift is a
        # per-row constant, so the normalised weights are mathematically unchanged.
        # Taking the max over *all* nodes instead lets a large logit of a masked-out node
        # underflow every kept entry to exp(-big) == 0, giving an all-zero attention row.
        # Masked entries become -inf so exp() maps them to exactly 0 (no inf * 0 = NaN).
        masked_logits = a.masked_fill(weights == 0, float('-inf'))
        maxes = torch.max(masked_logits, 1, keepdim=True)[0]
        # A row with no kept entry has max -inf; shift by 0 to avoid -inf - (-inf) = NaN.
        maxes = torch.where(maxes == float('-inf'), torch.zeros_like(maxes), maxes)
        a_nomi = torch.mul(torch.exp(masked_logits - maxes), weights)
        a_deno = torch.sum(a_nomi, 1, keepdim=True)
        a_nor = a_nomi / (a_deno + 1e-9)  # epsilon keeps fully-masked rows at 0 instead of NaN

        ret = torch.mm(a_nor, head_ft)

        if self.agg_activation is not None:
            ret = self.agg_activation(ret)

        return ret

In [None]:
class GKATClassifier(nn.Module):
    """Two-layer GKAT graph classifier using node in-degree as the only input feature.

    Layer 1 runs ``num_heads`` GKATLayer heads (in_dim -> hidden_dim) and concatenates
    them; layer 2 is a single head (hidden_dim * num_heads -> hidden_dim). Node
    outputs are mean-pooled with ``dgl.mean_nodes`` and fed to a linear classifier.
    """

    def __init__(self, in_dim, hidden_dim, num_heads, n_classes, feat_drop_=0.,
                 attn_drop_=0.,):
        super(GKATClassifier, self).__init__()

        self.num_heads = num_heads
        self.hidden_dim = hidden_dim
        self.layers = nn.ModuleList([
            nn.ModuleList([GKATLayer(in_dim, hidden_dim, feat_drop = feat_drop_, attn_drop = attn_drop_, agg_activation=F.elu) for _ in range(num_heads)]),
            nn.ModuleList([GKATLayer(hidden_dim * num_heads, hidden_dim, feat_drop = feat_drop_, attn_drop = attn_drop_, agg_activation=F.elu) for _ in range(1)])
        ])
        self.classify = nn.Linear(hidden_dim * 1, n_classes)
        # Not used by forward(); kept so existing attribute access keeps working.
        self.softmax = nn.Softmax(dim = 1)

    def forward(self, bg, counting_attn, normalize = 'normal'):
        """Classify one graph.

        Args:
            bg: DGL graph; its node in-degrees are the input features.
            counting_attn: GKAT kernel passed unchanged to every attention head.
            normalize: ``'normal'`` standardises the degree features (population std);
                any other value uses raw degrees.

        Returns:
            Class logits from ``self.classify`` applied to the mean-pooled node features.
        """
        h = bg.in_degrees().view(-1, 1).float()  # V x 1: in-degree as node feature

        if normalize == 'normal':
            # Population std (ddof=0), matching np.std; stays on h's device.
            h = (h - h.mean()) / (h.std(unbiased=False) + 1e-9)

        for heads in self.layers:
            # Every head returns a V x hidden_dim tensor. Concatenating along dim 1 keeps
            # the node axis; squeezing here would drop it for single-node graphs.
            h = torch.cat([head(h, bg, counting_attn) for head in heads], dim=1)

        bg.ndata['h'] = h
        hg = dgl.mean_nodes(bg, 'h')
        return self.classify(hg)

In [None]:
# Hyper-parameters used for GKAT in this notebook.

method = 'GKAT'

runtimes = 15  # number of repeats; a fresh model is trained in each

# Seed the RNGs used in training (torch: Xavier init / dropout, numpy: mini-batch sampling,
# plus Python's `random`) so the whole sequence of repeats is reproducible. Repeats still
# differ from one another because the generators advance between them.
seed = 0
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

num_classes = 2
num_features = 4  # hidden dimension of each attention head
num_heads = 8  # number of attention heads in the first layer
num_layers = 2  # used in result/checkpoint paths; GKATClassifier itself is built with two layers

feature_drop = 0  # dropout on input node features
atten_drop = 0  # dropout on attention logits

epsilon = 1e-4  # stop when a new best val loss differs from the previous epoch's by less than this

start_tol = 499  # early-stopping checks only run once epoch > start_tol
tolerance = 80  # stop after this many consecutive epochs without a new best val loss
max_epoch = 500
batch_size = 128
learning_rate = 0.005

In [23]:
# Learning curves collected across all repeats (one inner list per repeat).
all_GKAT_train_losses = []
all_GKAT_train_acc = []
all_GKAT_val_losses = []
all_GKAT_val_acc = []

ckpt_file = f'results_{num_layers}layers/{method}/{method}_ckpt.pt'
results_dir = f'{projectpath}results_{num_layers}layers/'
# torch.save / np.save do not create missing directories; this also creates results_dir.
os.makedirs(os.path.dirname(f'{projectpath}{ckpt_file}'), exist_ok=True)


def _prediction_accuracy(logits, targets):
    """Return the fraction of predictions that match their binary target.

    Each element of ``logits`` is indexed as ``pred[0][0]`` / ``pred[0][1]``. Class 0 is
    predicted when score 0 > score 1, class 1 otherwise (ties count as class 1).
    """
    correct = 0
    for pred, target in zip(logits, targets):
        p0, p1 = pred[0][0], pred[0][1]
        if (p0 > p1 and target == 0) or (p0 <= p1 and target == 1):
            correct += 1
    return correct / len(logits)


def _save_history(path, history):
    """np.save a list of per-repeat learning curves.

    Curves of different lengths (runs that stopped at different epochs) cannot be
    stacked into a regular array by recent NumPy, so they are stored as a 1-D object
    array instead (load with ``np.load(path, allow_pickle=True)``).
    """
    if len({len(curve) for curve in history}) > 1:
        ragged = np.empty(len(history), dtype=object)
        for i, curve in enumerate(history):
            ragged[i] = curve
        np.save(path, ragged)
    else:
        np.save(path, history)


for runtime in trange(runtimes):

    train_graphs, train_labels, val_graphs, val_labels, GKAT_masking = info
    train_GKAT_masking, val_GKAT_masking = GKAT_masking
    val_targets = torch.tensor(val_labels).long()

    # Create model
    model = GKATClassifier(1, num_features, num_heads, num_classes, feat_drop_ = feature_drop, attn_drop_ = atten_drop)

    # Xavier-initialise every parameter with more than one dimension.
    for p in model.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)

    count_parameters(model)

    loss_func = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-5)
    model.train()

    epoch_train_losses_GKAT = []
    epoch_train_acc_GKAT = []
    epoch_val_losses_GKAT = []
    epoch_val_acc_GKAT = []

    num_batches = int(len(train_graphs)/batch_size)

    epoch = 0
    nan_found = 0
    tol = 0

    while True:
        epoch_loss = 0
        epoch_acc = 0
        batches_done = 0

        # ---- Training: num_batches mini-batches sampled without replacement ----
        for _ in range(num_batches):
            predictions = []
            labels = torch.empty(batch_size)
            rand_indices = np.random.choice(len(train_graphs), batch_size, replace=False)

            for b, idx in enumerate(rand_indices):
                pred = model(train_graphs[idx], torch.Tensor(train_GKAT_masking[idx]))
                if torch.isnan(pred).any():
                    nan_found = 1
                    break
                predictions.append(pred)
                labels[b] = train_labels[idx]

            if nan_found:
                print('NaN found.')
                break

            epoch_acc += _prediction_accuracy(predictions, labels)

            loss = loss_func(torch.cat(predictions, dim=0), labels.long())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            epoch_loss += loss.detach().item()
            batches_done += 1

        if nan_found:
            # The model has diverged; stop this repeat without evaluating it.
            break

        # Average over the mini-batches actually run (also safe when num_batches == 0).
        epoch_acc /= max(batches_done, 1)
        epoch_loss /= max(batches_done, 1)

        # ---- Validation: eval mode, no autograd graph ----
        model.eval()
        with torch.no_grad():
            predictions_val = [model(val_graphs[b], torch.Tensor(val_GKAT_masking[b])) for b in range(len(val_graphs))]
            val_acc = _prediction_accuracy(predictions_val, val_labels)
            val_loss = loss_func(torch.cat(predictions_val, dim=0), val_targets).item()
        model.train()

        status = 'Epoch {}, acc{:.2f}, loss {:.4f}, tol {}, val_acc{:.2f}, val_loss{:.4f}'.format(epoch, epoch_acc, epoch_loss, tol, val_acc, val_loss)
        # Checkpoint on the first epoch of each repeat (torch.save overwrites the previous
        # file), then whenever val loss and val acc are both at least as good as every
        # earlier epoch of this repeat.
        if (len(epoch_val_losses_GKAT) == 0
                or (np.min(epoch_val_losses_GKAT) >= val_loss and np.max(epoch_val_acc_GKAT) <= val_acc)):
            torch.save(model, f'{projectpath}{ckpt_file}')
            print(status + ' -- checkpoint saved')
        else:
            print(status)

        # ---- Stopping criteria (compared against earlier epochs only) ----
        stop_message = None
        if epoch > start_tol and epoch_val_losses_GKAT:
            if np.min(epoch_val_losses_GKAT) <= val_loss:
                tol += 1
                if tol == tolerance:
                    stop_message = 'Loss do not decrease'
            else:
                if np.abs(epoch_val_losses_GKAT[-1] - val_loss) < epsilon:
                    stop_message = 'Converge steadily'
                tol = 0

        # NOTE(review): with `>` the last trained epoch index is max_epoch + 1; confirm intent.
        if stop_message is None and epoch > max_epoch:
            stop_message = "Reach Max Epoch Number"

        # Record this epoch before (possibly) stopping so the final trained epoch is kept.
        epoch_train_acc_GKAT.append(epoch_acc)
        epoch_train_losses_GKAT.append(epoch_loss)
        epoch_val_acc_GKAT.append(val_acc)
        epoch_val_losses_GKAT.append(val_loss)

        if stop_message is not None:
            print(stop_message)
            break

        epoch += 1

    all_GKAT_train_acc.append(epoch_train_acc_GKAT)
    all_GKAT_train_losses.append(epoch_train_losses_GKAT)
    all_GKAT_val_acc.append(epoch_val_acc_GKAT)
    all_GKAT_val_losses.append(epoch_val_losses_GKAT)

    # Save the learning curves of the current repeat.
    np.save(f'{results_dir}epoch_train_acc_{method}_walklen{walk_len}_run{runtime}.npy', epoch_train_acc_GKAT)
    np.save(f'{results_dir}epoch_val_acc_{method}_walklen{walk_len}_run{runtime}.npy', epoch_val_acc_GKAT)
    np.save(f'{results_dir}epoch_train_losses_{method}_walklen{walk_len}_run{runtime}.npy', epoch_train_losses_GKAT)
    np.save(f'{results_dir}epoch_val_losses_{method}_walklen{walk_len}_run{runtime}.npy', epoch_val_losses_GKAT)

# Save the learning curves of all repeats.
_save_history(f'{results_dir}all_{method}_walklen{walk_len}_train_losses.npy', all_GKAT_train_losses)
_save_history(f'{results_dir}all_{method}_walklen{walk_len}_train_acc.npy', all_GKAT_train_acc)
_save_history(f'{results_dir}all_{method}_walklen{walk_len}_val_losses.npy', all_GKAT_val_losses)
_save_history(f'{results_dir}all_{method}_walklen{walk_len}_val_acc.npy', all_GKAT_val_acc)

HBox(children=(FloatProgress(value=0.0, max=15.0), HTML(value='')))



+----------------------+------------+
|       Modules        | Parameters |
+----------------------+------------+
|  layers.0.0.attn_l   |     4      |
|  layers.0.0.attn_r   |     4      |
| layers.0.0.fc.weight |     4      |
|  layers.0.1.attn_l   |     4      |
|  layers.0.1.attn_r   |     4      |
| layers.0.1.fc.weight |     4      |
|  layers.0.2.attn_l   |     4      |
|  layers.0.2.attn_r   |     4      |
| layers.0.2.fc.weight |     4      |
|  layers.0.3.attn_l   |     4      |
|  layers.0.3.attn_r   |     4      |
| layers.0.3.fc.weight |     4      |
|  layers.0.4.attn_l   |     4      |
|  layers.0.4.attn_r   |     4      |
| layers.0.4.fc.weight |     4      |
|  layers.0.5.attn_l   |     4      |
|  layers.0.5.attn_r   |     4      |
| layers.0.5.fc.weight |     4      |
|  layers.0.6.attn_l   |     4      |
|  layers.0.6.attn_r   |     4      |
| layers.0.6.fc.weight |     4      |
|  layers.0.7.attn_l   |     4      |
|  layers.0.7.attn_r   |     4      |
| layers.0.7

KeyboardInterrupt: ignored