In [None]:
# link colab to google drive directory where this project data is placed
from google.colab import drive
drive.mount('/content/gdrive', force_remount=True)

import numpy as np
import tensorflow as tf
print(tf.__version__)


# set project path
projectpath = "/content/gdrive/My Drive/GraphAttnProject/ErdosRanyiSubmission/"

print(projectpath)
#print(datareadpath)


!pip install dgl


import os
os.chdir(projectpath)
os.getcwd()

from CodeZip_ER import *

from tqdm.notebook import tqdm, trange
import networkx as nx
import pickle
import torch
print(tf.__version__)

# Load Data

In [None]:
# Dataset folder under graph_data/ and results/ that every load/save path below is built from.
name = 'Caveman'
# Random-walk length; selects which pre-computed GKAT kernel file (..._len={walk_len}.pkl) is loaded.
walk_len = 4

In [None]:

# ---- model hyperparameters ----
num_classes = 2      # output size of every classifier head
num_features = 32    # hidden width; NOTE: the training loop overwrites this per method (9 or 32)
num_heads = 2        # attention heads per layer for the GAT / GKAT classifiers
feature_drop = 0     # dropout probability on input features of attention layers
atten_drop = 0       # dropout probability on attention logits
runtimes = 15        # number of independent repetitions of the whole experiment

# early-stopping: stop when the validation loss changes by less than this between epochs
epsilon = 1e-4

# ---- training schedule ----
start_tol = 499      # early-stopping checks only start once epoch > start_tol
tolerance = 80       # number of non-improving epochs (after start_tol) before stopping
max_epoch = 500      # training stops once epoch > max_epoch
batch_size = 128     # graphs sampled (without replacement) per optimisation step
learning_rate = 0.001  # Adam learning rate
h_size = 5           # number of node-feature columns (top neighbour degrees per node)
normalize = None     # forwarded to the classifiers; 'normal' enables per-node standardisation



In [None]:
def _load_pickle(path):
    """Unpickle and return the object stored at `path`, closing the file afterwards."""
    # NOTE(review): pickle.load can execute arbitrary code; only use with trusted files.
    with open(path, 'rb') as fh:
        return pickle.load(fh)


# load all train and validation graphs
train_graphs = _load_pickle(f'graph_data/{name}/train_graphs.pkl')
val_graphs = _load_pickle(f'graph_data/{name}/val_graphs.pkl')

# load all labels
train_labels = np.load(f'graph_data/{name}/train_labels.npy')
val_labels = np.load(f'graph_data/{name}/val_labels.npy')


# here we load the pre-calculated GKAT kernel
train_GKAT_kernel = _load_pickle(f'graph_data/{name}/GKAT_dot_kernels_train_len={walk_len}.pkl')
val_GKAT_kernel = _load_pickle(f'graph_data/{name}/GKAT_dot_kernels_val_len={walk_len}.pkl')

train_GAT_masking = _load_pickle(f'graph_data/{name}/GAT_masking_train.pkl')
val_GAT_masking = _load_pickle(f'graph_data/{name}/GAT_masking_val.pkl')

# the GKAT kernels are converted to torch tensors; the GAT masks are used as loaded
train_GKAT_kernel = [torch.from_numpy(g) for g in train_GKAT_kernel]
val_GKAT_kernel = [torch.from_numpy(g) for g in val_GKAT_kernel]


# Drop isolated (degree-0) nodes in place.
# NOTE(review): this changes node counts; presumably the pre-computed kernels/masks were
# built on the same isolate-free graphs -- confirm, otherwise their shapes will not match.
for bg in train_graphs:
    bg.remove_nodes_from(list(nx.isolates(bg)))
for bg in val_graphs:
    bg.remove_nodes_from(list(nx.isolates(bg)))


def generate_knn_degrees(bg, h_size):
    """Build per-node features from the degrees of each node's neighbours.

    For every node the degrees of its neighbours are sorted in descending order,
    truncated to the `h_size` largest and right-padded with zeros.

    Args:
        bg: graph exposing the NetworkX-style API used here
            (`nodes()`, `neighbors(node)` and an indexable `degree`).
        h_size: number of feature columns.

    Returns:
        np.ndarray of shape (number of nodes, h_size). Row r belongs to the r-th
        node in sorted label order.
    """
    # Rows are addressed by position in sorted label order rather than by the raw
    # label: labels stop being 0..n-1 once nodes have been removed from the graph,
    # and indexing by label would then overflow or write to the wrong row.
    # For labels 0..n-1 this is identical to indexing by label.
    # NOTE(review): sorted order is assumed to match the relabelling done by
    # dgl.from_networkx -- confirm against the DGL version in use.
    node_order = sorted(bg.nodes())
    row_of = {node: row for row, node in enumerate(node_order)}

    bg_h = np.zeros([len(node_order), h_size])
    degree_of = bg.degree

    for node in node_order:
        nbr_degrees = sorted((degree_of[nb] for nb in bg.neighbors(node)), reverse=True)

        # a node without neighbours gets a tiny positive first feature
        if not nbr_degrees:
            nbr_degrees.append(1e-3)

        bg_h[row_of[node]] = (nbr_degrees + h_size*[0])[:h_size]

    return bg_h


# NOTE(review): h_size is already set to 5 in the hyperparameter cell; this re-assignment
# duplicates it and the two must be kept in sync by hand.
h_size = 5
# per-graph node features, computed while the graphs are still NetworkX objects
train_h = [generate_knn_degrees(bg, h_size) for bg in train_graphs]
val_h = [generate_knn_degrees(bg, h_size) for bg in val_graphs]

# NOTE(review): from here on train_graphs / val_graphs hold DGL graphs, not NetworkX graphs.
# Later cells that pass them to NetworkX functions (e.g. cal_Laplacian) should be checked.
train_graphs = [ dgl.from_networkx(g) for g in train_graphs]
val_graphs = [ dgl.from_networkx(g) for g in val_graphs]


# [train, val] pairs; the training loop unpacks the pair matching the current method
GKAT_masking = [train_GKAT_kernel, val_GKAT_kernel]
GAT_masking = [train_GAT_masking, val_GAT_masking]

# GKAT and GAT

In [None]:

class GKATLayer(nn.Module):
    """Single attention head with additive (GAT-style) scores reweighted by a kernel/mask.

    NOTE(review): another class named ``GKATLayer`` is defined right after this one
    and rebinds the name, so this version is unused as the file stands.

    Args:
        in_dim: input feature size.
        out_dim: output feature size.
        feat_drop: dropout probability on the input features.
        attn_drop: dropout probability on the raw attention scores.
        alpha: negative slope of the LeakyReLU applied to the scores.
        agg_activation: activation applied to the aggregated output (None to skip).
    """

    def __init__(self, in_dim, out_dim, feat_drop=0., attn_drop=0., alpha=0.2, agg_activation=F.elu):
        super(GKATLayer, self).__init__()

        self.feat_drop = nn.Dropout(feat_drop)
        self.fc = nn.Linear(in_dim, out_dim, bias=False)
        self.attn_l = nn.Parameter(torch.ones(size=(out_dim, 1)))
        self.attn_r = nn.Parameter(torch.ones(size=(out_dim, 1)))
        self.attn_drop = nn.Dropout(attn_drop)
        self.activation = nn.LeakyReLU(alpha)
        self.softmax = nn.Softmax(dim = 1)  # not used in forward; kept for compatibility
        self.agg_activation=agg_activation

    def forward(self, feat, bg, counting_attn):
        """Return the attention-aggregated node features.

        Args:
            feat: node feature matrix with one row per node.
            bg: graph object; only stored on the module, not used in the computation.
            counting_attn: weights multiplied element-wise with the exponentiated
                scores (must broadcast against the V x V score matrix).
        """
        self.g = bg
        h = self.feat_drop(feat)
        head_ft = self.fc(h).reshape((h.shape[0], -1))

        a1 = torch.mm(head_ft, self.attn_l)    # V x 1
        a2 = torch.mm(head_ft, self.attn_r)    # V x 1
        a = self.attn_drop(a1 + a2.transpose(0, 1))   # V x V
        a = self.activation(a)

        # Shift every row by its maximum before exponentiating so that large scores
        # cannot overflow to inf (and then NaN after normalisation). The shift
        # cancels in the ratio below, up to the effect of the 1e-9 guard.
        a_ = a - torch.max(a, 1, keepdim=True)[0]
        a_nomi = torch.mul(torch.exp(a_), counting_attn.float())
        a_deno = torch.sum(a_nomi, 1, keepdim=True)
        a_nor = a_nomi/(a_deno+1e-9)

        ret = torch.mm(a_nor, head_ft)
        if self.agg_activation is not None:
            ret = self.agg_activation(ret)

        return ret



class GKATLayer(nn.Module):
    """One attention head: scaled dot-product scores reweighted by a precomputed kernel/mask.

    Args:
        in_dim: input feature size.
        out_dim: size of the query / key / value projections.
        feat_drop: dropout probability applied to the input features.
        attn_drop: dropout probability applied to the raw Q.K^T scores.
        alpha: accepted for signature compatibility; not used by this layer.
        agg_activation: activation applied to the aggregated output (None to skip).
    """

    def __init__(self, in_dim, out_dim, feat_drop=0., attn_drop=0., alpha=0.2, agg_activation=F.elu):
        super().__init__()

        # dropout rates are kept as plain numbers and applied functionally in forward
        self.feat_drop = feat_drop
        self.attn_drop = attn_drop

        # bias-free projections for queries, keys and values
        self.fc_Q = nn.Linear(in_dim, out_dim, bias=False)
        self.fc_K = nn.Linear(in_dim, out_dim, bias=False)
        self.fc_V = nn.Linear(in_dim, out_dim, bias=False)

        self.softmax = nn.Softmax(dim=1)  # not used in forward

        self.agg_activation = agg_activation

    def forward(self, feat, bg, counting_attn):
        """Aggregate node features with kernel-weighted attention.

        Args:
            feat: node feature matrix with one row per node.
            bg: graph object; not used by this layer.
            counting_attn: weights multiplied element-wise with the exponentiated
                scores (must broadcast against the V x V score matrix).

        Returns:
            Tensor with one row per node and `out_dim` columns.
        """
        n_nodes = feat.shape[0]
        dropped = F.dropout(feat, p=self.feat_drop, training=self.training)

        queries = self.fc_Q(dropped).reshape((n_nodes, -1))
        keys = self.fc_K(dropped).reshape((n_nodes, -1))
        values = self.fc_V(dropped).reshape((n_nodes, -1))

        # dropout acts on the unscaled scores, then they are scaled by sqrt(d)
        raw_scores = torch.matmul(queries, keys.transpose(0, 1))
        scores = F.dropout(raw_scores, p=self.attn_drop, training=self.training) / np.sqrt(queries.shape[1])

        # subtract the row maximum before exponentiating
        scores = scores - scores.max(dim=1, keepdim=True)[0]

        weighted = torch.exp(scores) * counting_attn.float()
        row_total = weighted.sum(dim=1, keepdim=True)
        attention = weighted / (row_total + 1e-9)

        aggregated = torch.mm(attention, values)
        if self.agg_activation is None:
            return aggregated
        return self.agg_activation(aggregated)



class GKATClassifier_ER(nn.Module):
    """Two-layer multi-head attention network followed by mean pooling and a linear classifier.

    Args:
        in_dim: number of input node features.
        hidden_dim: output size of each attention head.
        num_heads: heads per layer; head outputs are concatenated.
        n_classes: number of output logits.
        feat_drop_: feature dropout probability passed to every head.
        attn_drop_: attention dropout probability passed to every head.
    """

    def __init__(self, in_dim, hidden_dim, num_heads, n_classes, feat_drop_=0., attn_drop_=0.,):
        super(GKATClassifier_ER, self).__init__()

        self.num_heads = num_heads
        self.hidden_dim = hidden_dim
        self.layers = nn.ModuleList([
            nn.ModuleList([GKATLayer(in_dim, hidden_dim, feat_drop = feat_drop_, attn_drop = attn_drop_, agg_activation=F.elu) for _ in range(num_heads)]),
            nn.ModuleList([GKATLayer(hidden_dim * num_heads, hidden_dim, feat_drop = feat_drop_, attn_drop = attn_drop_, agg_activation=F.elu) for _ in range(num_heads)]), ])
        self.classify = nn.Linear(hidden_dim * num_heads, n_classes)
        self.softmax = nn.Softmax(dim = 1)  # not used in forward

    def forward(self, bg, bg_h, counting_attn, normalize = 'normal'):
        """Return class logits for one graph.

        Args:
            bg: DGL graph; receives the final node features in ``bg.ndata['h']``.
            bg_h: node feature matrix (array-like or tensor), one row per node.
            counting_attn: kernel/mask forwarded unchanged to every attention head.
            normalize: 'normal' standardises each node's feature vector
                (zero mean, unit population std); any other value skips it.
        """
        h = torch.as_tensor(bg_h).float()

        if normalize == 'normal':
            # Done entirely in torch: mixing a tensor with numpy arrays is
            # version-dependent, and a constant feature row (std == 0) would
            # otherwise divide by zero.
            mean_ = h.mean(dim=-1, keepdim=True)
            std_ = h.std(dim=-1, unbiased=False, keepdim=True)  # matches np.std (ddof=0)
            h = (h - mean_) / std_.clamp_min(1e-12)

        for gnn in self.layers:
            # concatenate head outputs along the feature axis; no squeeze, so a
            # single-node graph keeps its (1, D) shape
            h = torch.cat([att_head(h, bg, counting_attn) for att_head in gnn], dim=1)

        bg.ndata['h'] = h
        hg = dgl.mean_nodes(bg, 'h')

        return self.classify(hg)


# GCN

In [None]:

class GCNClassifier_ER(nn.Module):
    """Two GraphConv layers followed by mean pooling and a linear classifier.

    Args:
        in_dim: number of input node features.
        hidden_dim: width of both graph-convolution layers.
        n_classes: number of output logits.
    """

    def __init__(self, in_dim, hidden_dim, n_classes):
        super(GCNClassifier_ER, self).__init__()
        self.conv1 = GraphConv(in_dim, hidden_dim)
        self.conv2 = GraphConv(hidden_dim, hidden_dim)
        self.classify = nn.Linear(hidden_dim, n_classes)

    def forward(self, g, bg_h, normalize = 'normal'):
        """Return class logits for one graph.

        Args:
            g: DGL graph; receives the final node features in ``g.ndata['h']``.
            bg_h: node feature matrix (array-like or tensor), one row per node.
            normalize: 'normal' standardises each node's feature vector
                (zero mean, unit population std); any other value skips it.
        """
        h = torch.as_tensor(bg_h).float()

        if normalize == 'normal':
            # Done entirely in torch: mixing a tensor with numpy arrays is
            # version-dependent, and a constant feature row (std == 0) would
            # otherwise divide by zero.
            mean_ = h.mean(dim=-1, keepdim=True)
            std_ = h.std(dim=-1, unbiased=False, keepdim=True)  # matches np.std (ddof=0)
            h = (h - mean_) / std_.clamp_min(1e-12)

        # Perform graph convolution and activation function.
        h = F.relu(self.conv1(g, h))
        h = F.relu(self.conv2(g, h))
        g.ndata['h'] = h
        hg = dgl.mean_nodes(g, 'h')
        return self.classify(hg)
    

# SGC

In [None]:
def cal_Laplacian(graph):
    """Return the symmetric normalised Laplacian I - D^{-1/2} A D^{-1/2} as a dense array.

    Args:
        graph: NetworkX graph.

    Returns:
        np.ndarray of shape (N, N). (Earlier versions returned np.matrix via
        nx.to_numpy_matrix, which NetworkX 3.0 removed.)
    """
    # build the dense adjacency once and reuse it for the degrees and the product
    adjacency = nx.to_numpy_array(graph)
    degrees = adjacency.sum(axis=1)
    # the 1e-5 offset keeps zero-degree nodes from producing a division by zero
    d_inv_sqrt = np.diag((degrees + 1e-5) ** (-0.5))
    return np.identity(adjacency.shape[0]) - d_inv_sqrt @ adjacency @ d_inv_sqrt

def rescale_L(L, lmax=2):
    """Rescale Laplacian eigenvalues to [-1,1].

    Computes ``2 * L / lmax - I``, which maps a spectrum in [0, lmax] onto [-1, 1]
    (the domain of the Chebyshev polynomials).

    Args:
        L: square Laplacian (np.ndarray / np.matrix); it is not modified.
        lmax: largest eigenvalue of L.

    Returns:
        torch.Tensor (float64) of the same shape as L.
    """
    # Work on a float copy so the caller's array is left untouched.
    scaled = np.array(L, dtype=np.float64)
    M = scaled.shape[0]
    # Divide by lmax / 2 (not lmax * 2): the latter squeezes the spectrum into
    # [-1, -0.5] instead of spreading it over [-1, 1].
    scaled = scaled / (lmax / 2.0)
    scaled = torch.tensor(scaled)
    return scaled - torch.eye(M, dtype=scaled.dtype)

def lmax_L(L):
    """Compute largest Laplacian eigenvalue.

    Args:
        L: symmetric square matrix accepted by ``scipy.sparse.linalg.eigsh``.

    Returns:
        The eigenvalue of largest magnitude, as a scalar.
    """
    # Imported locally so the function does not depend on a `scipy` name (with the
    # sparse.linalg submodule loaded) happening to be in the notebook namespace.
    from scipy.sparse.linalg import eigsh
    return eigsh(L, k=1, which='LM', return_eigenvectors=False)[0]





# Pre-compute, for every graph, the normalised Laplacian, its largest eigenvalue and the
# rescaled Laplacian consumed by the Chebyshev network.
# NOTE(review): if the cells are run top to bottom, train_graphs / val_graphs are DGL graphs
# at this point, while cal_Laplacian calls NetworkX functions -- confirm this works with the
# installed versions or keep the NetworkX graphs around under a separate name.
train_L_original = [cal_Laplacian(bg) for bg in train_graphs]
val_L_original = [cal_Laplacian(bg) for bg in val_graphs]

train_L_max = [lmax_L(L) for L in train_L_original]
val_L_max = [lmax_L(L) for L in val_L_original]

# zip replaces the index loop, which bound the builtin name `iter` as a notebook global
train_L = [
    rescale_L(laplacian, lam_max)
    for laplacian, lam_max in tqdm(zip(train_L_original, train_L_max), total=len(train_L_original))
]

val_L = [
    rescale_L(laplacian, lam_max)
    for laplacian, lam_max in tqdm(zip(val_L_original, val_L_max), total=len(val_L_original))
]


class Graph_ConvNet_LeNet5(nn.Module):
    """Two Chebyshev graph-convolution layers followed by a linear layer and a mean over nodes.

    Args:
        net_parameters: sequence ``(h_size, CL1_F, CL1_K, CL2_F, CL2_K, FC1_F)`` --
            input feature count, then output width and Chebyshev order of each
            graph-convolution layer, then the width of the final linear layer.
    """

    def __init__(self, net_parameters):

        print('Graph ConvNet: LeNet5')

        super(Graph_ConvNet_LeNet5, self).__init__()

        # parameters
        h_size, CL1_F, CL1_K, CL2_F, CL2_K, FC1_F = net_parameters

        # graph CL1 -- fan-in is h_size*CL1_K, the true input width of the layer
        # (the scale was previously computed from CL1_K alone)
        self.cl1 = nn.Linear(h_size*CL1_K, CL1_F)
        self._init_linear(self.cl1, h_size*CL1_K, CL1_F)
        self.CL1_K = CL1_K; self.CL1_F = CL1_F

        # graph CL2
        self.cl2 = nn.Linear(CL2_K*CL1_F, CL2_F)
        self._init_linear(self.cl2, CL2_K*CL1_F, CL2_F)
        self.CL2_K = CL2_K; self.CL2_F = CL2_F

        # FC1
        self.fc1 = nn.Linear(CL2_F, FC1_F)
        self._init_linear(self.fc1, CL2_F, FC1_F)

        # nb of parameters
        nb_param = h_size* CL1_K* CL1_F + CL1_F          # CL1
        nb_param += CL2_K* CL1_F* CL2_F + CL2_F  # CL2
        nb_param += CL2_F* FC1_F + FC1_F        # FC1
        print('nb of parameters=',nb_param,'\n')

    @staticmethod
    def _init_linear(layer, Fin, Fout):
        """Uniform init of `layer.weight` in +-sqrt(2/(Fin+Fout)); bias set to zero."""
        scale = np.sqrt( 2.0/ (Fin+Fout) )
        layer.weight.data.uniform_(-scale, scale)
        layer.bias.data.fill_(0.0)

    def init_weights(self, W, Fin, Fout):
        """Fill tensor `W` in place with uniform values in +-sqrt(2/(Fin+Fout)) and return it."""
        scale = np.sqrt( 2.0/ (Fin+Fout) )
        # no_grad: in-place ops on a leaf tensor that requires grad raise otherwise
        with torch.no_grad():
            W.uniform_(-scale, scale)

        return W

    def graph_conv_cheby(self, x, cl, L, Fout, K):
        """Apply one Chebyshev graph convolution.

        Args:
            x: tensor of shape (B, V, Fin) -- batch, vertices, input features.
            cl: linear layer mapping Fin*K -> Fout.
            L: (V, V) rescaled Laplacian.
            Fout: number of output features.
            K: Chebyshev order / support size (>= 1).

        Returns:
            Tensor of shape (B, V, Fout).
        """
        B, V, Fin = (int(d) for d in x.size())

        # Run on whatever device the layer's weights live on instead of forcing
        # CUDA: inputs moved to the GPU while the module stays on the CPU cause a
        # device-mismatch error, and no GPU may be available at all.
        device = cl.weight.device
        L = torch.as_tensor(L).to(device=device, dtype=torch.double)

        # transform to Chebyshev basis (computed in double precision throughout)
        x0 = x.permute(1,2,0).contiguous().to(device=device, dtype=torch.double)  # V x Fin x B
        x0 = x0.view([V, Fin*B])            # V x Fin*B
        x = x0.unsqueeze(0)                 # 1 x V x Fin*B

        if K > 1:
            x1 = torch.mm(L, x0)                     # V x Fin*B
            x = torch.cat((x, x1.unsqueeze(0)), 0)   # 2 x V x Fin*B

        for k in range(2, K):
            # Chebyshev recurrence T_k = 2 L T_{k-1} - T_{k-2}
            x2 = 2 * torch.mm(L, x1) - x0
            x = torch.cat((x, x2.unsqueeze(0)), 0)   # (k+1) x V x Fin*B
            x0, x1 = x1, x2

        x = x.view([K, V, Fin, B])           # K x V x Fin x B
        x = x.permute(3,1,2,0).contiguous()  # B x V x Fin x K
        x = x.view([B*V, Fin*K])             # B*V x Fin*K

        # Compose linearly Fin features to get Fout features
        x = cl(x.float())                    # B*V x Fout
        x = x.view([B, V, Fout])             # B x V x Fout

        return x

    def forward(self, x, L):
        """Return per-graph outputs of shape (1, FC1_F) for node features `x` (V x Fin) and Laplacian `L`."""
        # graph CL1
        x = torch.as_tensor(x).unsqueeze(0)  # B=1 x V x Fin
        x = self.graph_conv_cheby(x, self.cl1, L, self.CL1_F, self.CL1_K)
        x = F.relu(x)

        # graph CL2
        x = self.graph_conv_cheby(x, self.cl2, L, self.CL2_F, self.CL2_K)
        x = F.relu(x)

        # FC1, then average over the vertex axis
        x = self.fc1(x)
        x = torch.mean(x, dim=1)

        return x

# Start Training

In [None]:


# Per-run histories: the training loop appends one list of per-epoch values
# for every (run, method) pair.
all_GGG_train_losses = []
all_GGG_train_acc = []
all_GGG_val_losses = []
all_GGG_val_acc = []
# NOTE(review): the two lists below are not written to anywhere in this section -- presumably
# filled by a later test-evaluation cell; confirm or remove.
GGG_test_acc_end = []
GGG_test_acc_ckpt = []




from prettytable import PrettyTable

def count_parameters(model):
    """Print a table of the model's trainable parameters and return their total count.

    Args:
        model: module exposing ``named_parameters()``.

    Returns:
        Total number of elements across all parameters with ``requires_grad``.
    """
    table = PrettyTable(["Modules", "Parameters"])
    total_params = 0
    for param_name, tensor in model.named_parameters():
        # frozen parameters are neither listed nor counted
        if tensor.requires_grad:
            n_elements = tensor.numel()
            table.add_row([param_name, n_elements])
            total_params += n_elements
    print(table)
    print(f"Total Trainable Params: {total_params}")
    return total_params
    




def _predict(model, method, idx, graphs, feats, masks, laplacians):
    """Run `model` on sample `idx` with the call signature that `method` expects.

    `masks` is only read for 'GAT' / 'GKAT' and `laplacians` only for 'ChebyGNN'.
    """
    if method == 'GCN':
        return model(graphs[idx], feats[idx][:, :h_size], normalize=normalize)
    if method in ('GAT', 'GKAT'):
        return model(graphs[idx], feats[idx][:, :h_size], masks[idx], normalize=normalize)
    if method == 'ChebyGNN':
        return model(feats[idx], laplacians[idx])
    raise ValueError(f'Unknown method: {method}')


# Train every method `runtimes` times; per-epoch metrics are saved as .npy files and the
# best model per method (by validation loss and accuracy) is checkpointed.
# NOTE(review): no random seed is set, so runs are not reproducible.
for runtime in trange(runtimes):

    for method in ['GAT', 'GKAT', 'GCN', 'ChebyGNN']:

        ckpt_file = f'results/{name}/ckpt/{method}__ckpt.pt'
        ckpt_path = f'{projectpath}{ckpt_file}'
        # results/{name}/ is the parent of ckpt/, so this also covers the np.save calls below
        os.makedirs(os.path.dirname(ckpt_path), exist_ok=True)

        # ---- build the model for this method ----
        train_GGG_masking, val_GGG_masking = None, None
        if method == 'GKAT':
            num_features = 9
            train_GGG_masking, val_GGG_masking = GKAT_masking
            model = GKATClassifier_ER(h_size, num_features, num_heads, num_classes, feat_drop_ = feature_drop, attn_drop_ = atten_drop)
        elif method == 'GAT':
            num_features = 9
            train_GGG_masking, val_GGG_masking = GAT_masking
            model = GKATClassifier_ER(h_size, num_features, num_heads, num_classes, feat_drop_ = feature_drop, attn_drop_ = atten_drop)
        elif method == 'GCN':
            num_features = 32
            model = GCNClassifier_ER(h_size, num_features, num_classes)
        elif method == 'ChebyGNN':
            CL1_F = 32
            CL1_K = 2
            CL2_F = 32
            CL2_K = 2
            FC1_F = 2
            net_parameters = [h_size, CL1_F, CL1_K, CL2_F, CL2_K, FC1_F]
            # instantiate the object net of the class
            model = Graph_ConvNet_LeNet5(net_parameters)

        # re-initialise every weight matrix (in-place variant; the non-underscore
        # name is a deprecated alias)
        for p in model.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

        count_parameters(model)

        loss_func = nn.CrossEntropyLoss()
        optimizer = optim.Adam(model.parameters(), lr=learning_rate)
        model.train()

        epoch_train_losses_GGG = []
        epoch_train_acc_GGG = []
        epoch_val_losses_GGG = []
        epoch_val_acc_GGG = []

        num_batches = int(len(train_graphs)/batch_size)

        epoch = 0
        nan_found = 0
        tol = 0

        while True:
            if nan_found:
                break

            epoch_loss = 0
            epoch_acc = 0
            batches_done = 0

            # ---- training ----
            for batch_idx in range(num_batches):
                rand_indices = np.random.choice(len(train_graphs), batch_size, replace=False)

                predictions = []
                for idx in rand_indices:
                    out = _predict(model, method, idx, train_graphs, train_h, train_GGG_masking, train_L)
                    predictions.append(out)
                    if torch.isnan(out).any():
                        nan_found = 1
                        break

                # A NaN batch is dropped entirely (no metric update, no optimiser
                # step) and training for this method stops after this epoch.
                if nan_found:
                    print('NaN found.')
                    break

                logits = torch.stack(predictions).reshape(len(predictions), -1)
                labels = torch.as_tensor(train_labels[rand_indices]).long()

                # class 0 only when logit 0 is strictly larger; ties go to class 1
                predicted = (logits[:, 0] <= logits[:, 1]).long()
                epoch_acc += (predicted == labels).sum().item() / len(predictions)

                loss = loss_func(logits, labels)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                epoch_loss += loss.detach().item()
                batches_done += 1

            # average over the batches that actually completed (guard against zero)
            epoch_acc /= max(batches_done, 1)
            epoch_loss /= max(batches_done, 1)

            # ---- validation: eval mode and no autograd graph ----
            model.eval()
            with torch.no_grad():
                predictions_val = [
                    _predict(model, method, b, val_graphs, val_h, val_GGG_masking, val_L)
                    for b in range(len(val_graphs))
                ]
                val_logits = torch.stack(predictions_val).reshape(len(predictions_val), -1)
                val_targets = torch.as_tensor(val_labels).long()
                val_predicted = (val_logits[:, 0] <= val_logits[:, 1]).long()
                val_acc = (val_predicted == val_targets).sum().item() / len(val_graphs)
                val_loss = loss_func(val_logits, val_targets).item()
            model.train()

            # ---- checkpointing ----
            summary = 'Epoch {}, acc{:.2f}, loss {:.4f}, tol {}, val_acc{:.2f}, val_loss{:.4f}'.format(epoch, epoch_acc, epoch_loss, tol, val_acc, val_loss)
            if len(epoch_val_losses_GGG) == 0:
                # first epoch of this run: discard any checkpoint left by a previous run
                try:
                    os.remove(ckpt_path)
                except OSError:
                    pass
                torch.save(model, ckpt_path)
                summary += ' -- checkpoint saved'
            elif (np.min(epoch_val_losses_GGG) >= val_loss) and (np.max(epoch_val_acc_GGG) <= val_acc):
                # best so far on both validation loss and validation accuracy
                torch.save(model, ckpt_path)
                summary += ' -- checkpoint saved'
            print(summary)

            # ---- early stopping (only active once epoch > start_tol) ----
            if epoch > start_tol:
                if np.min(epoch_val_losses_GGG) <= val_loss:
                    tol += 1
                    if tol == tolerance:
                        print('Loss do not decrease')
                        break
                else:
                    if np.abs(epoch_val_losses_GGG[-1] - val_loss)<epsilon:
                        print('Converge steadily')
                        break
                    tol = 0

            if epoch > max_epoch:
                print("Reach Max Epoch Number")
                break

            epoch += 1
            epoch_train_acc_GGG.append(epoch_acc)
            epoch_train_losses_GGG.append(epoch_loss)
            epoch_val_acc_GGG.append(val_acc)
            epoch_val_losses_GGG.append(val_loss)

        all_GGG_train_acc.append(epoch_train_acc_GGG)
        all_GGG_train_losses.append(epoch_train_losses_GGG)
        all_GGG_val_acc.append(epoch_val_acc_GGG)
        all_GGG_val_losses.append(epoch_val_losses_GGG)

        np.save(f'{projectpath}results/{name}/epoch_train_acc_{method}_run{runtime}.npy', epoch_train_acc_GGG)
        np.save(f'{projectpath}results/{name}/epoch_val_acc_{method}_run{runtime}.npy', epoch_val_acc_GGG)
        np.save(f'{projectpath}results/{name}/epoch_train_losses_{method}_run{runtime}.npy', epoch_train_losses_GGG)
        np.save(f'{projectpath}results/{name}/epoch_val_losses_{method}_run{runtime}.npy', epoch_val_losses_GGG)

