# AGS-GNN Graph Sampling

In [1]:
import sys
sys.path.append('../Submodular')

import DeviceDir

# Resolve the base directories and the compute device through the local
# DeviceDir helper; DIR and device are read as globals by the code below.
DIR, RESULTS_DIR = DeviceDir.get_directory()
device, NUM_PROCESSORS = DeviceDir.get_device()

In [2]:
from ipynb.fs.full.Dataset import get_data
from ipynb.fs.full.Dataset import datasets as available_datasets
from ipynb.fs.full.Utils import save_plot

In [3]:
import argparse
from argparse import ArgumentParser

#set default arguments here
def _str2bool(value):
    """Turn a command-line token into a real boolean.

    ``argparse`` with ``type=bool`` calls ``bool("False")``, which is ``True``
    for every non-empty string.  This converter accepts the usual spellings
    and rejects anything else with a proper argparse error.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


def get_configuration():
    """Parse the experiment options from ``sys.argv``.

    Returns:
        A ``(args, dict_args)`` pair: the ``argparse.Namespace`` and the same
        values as a plain dict (``vars(args)``).
    """
    parser = ArgumentParser()
    parser.add_argument('--epochs', type=int, default=1)
    # Boolean options go through _str2bool so that "--log_info False" is False.
    parser.add_argument('--log_info', type=_str2bool, default=True)
    parser.add_argument('--recompute', type=_str2bool, default=False)
    parser.add_argument('--pbar', type=_str2bool, default=False)
    parser.add_argument('--batch_size', type=int, default=2048)
    parser.add_argument('--learning_rate', type=float, default=0.01)
    parser.add_argument('--num_gpus', type=int, default=-1)
    parser.add_argument('--parallel_mode', type=str, default="dp", choices=['dp', 'ddp', 'ddp2'])
    parser.add_argument('--dataset', type=str, default="Cora", choices=available_datasets)
    # NOTE: passing --use_normalization on the command line turns it OFF
    # (store_false); the default is True.
    parser.add_argument('--use_normalization', action='store_false', default=True)
    parser.add_argument('-f') ##dummy for jupyternotebook
    args = parser.parse_args()

    dict_args = vars(args)

    return args, dict_args

# Parsed once at module level; `args` is read as a global by the functions below.
args, dict_args = get_configuration()

In [4]:
import random
import numpy as np
import torch

# Fixed seed applied to Python's, NumPy's and PyTorch's random generators.
seed = 123

random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
# Bare trailing expression; presumably there to keep the notebook cell from
# echoing the return value of torch.manual_seed -- confirm before removing.
None

## GNN model

In [5]:
import os
import math
import time
import torch_geometric
from torch_geometric.data import Data, Dataset
from torch_geometric.nn import GCNConv, GATConv, SAGEConv, GINConv, ChebConv
from torch_geometric.nn import GraphConv, TransformerConv
from torch_geometric.utils import degree
import torch.nn.functional as F
from sklearn.metrics import accuracy_score
from ipynb.fs.full.SpatialConv import SpatialConv
from tqdm import tqdm
import torch.nn as nn

## Homophilic GNN

In [6]:
class HomophilicNet(torch.nn.Module):
    """Two stacked ``GraphConv`` layers followed by two linear heads.

    The ReLU/dropout outputs of both convolution layers are concatenated.
    One head (``lin``) maps that concatenation to ``end_hidden`` features, the
    other (``lin2``) maps it to per-class log-probabilities.
    """

    def __init__(self, num_features, num_classes, hidden_channels, end_hidden):
        super().__init__()
        self.conv1 = GraphConv(num_features, hidden_channels)
        self.conv2 = GraphConv(hidden_channels, hidden_channels)

        # Both heads consume the concatenated outputs of the two conv layers.
        concat_width = 2 * hidden_channels
        self.lin = torch.nn.Linear(concat_width, end_hidden)
        self.lin2 = torch.nn.Linear(concat_width, num_classes)

    def set_aggr(self, aggr):
        """Assign ``aggr`` to the ``aggr`` attribute of both conv layers."""
        # NOTE(review): whether assigning `.aggr` after construction actually
        # changes the aggregation depends on the torch_geometric version -- confirm.
        for conv in (self.conv1, self.conv2):
            conv.aggr = aggr

    def forward(self, x0, edge_index, edge_weight=None):
        """Return ``(relu(lin(x)), log_softmax(lin2(x)))`` for every input row."""
        layer_outputs = []
        h = x0
        for conv in (self.conv1, self.conv2):
            h = F.relu(conv(h, edge_index, edge_weight))
            h = F.dropout(h, p=0.2, training=self.training)
            layer_outputs.append(h)

        x = torch.cat(layer_outputs, dim=-1)

        embedding = F.relu(self.lin(x))
        log_probs = self.lin2(x).log_softmax(dim=-1)
        return embedding, log_probs

## Combination Network

In [7]:
class AGSGNN(torch.nn.Module):
    """Two ``HomophilicNet`` branches fused by a single linear classifier.

    ``forward`` takes a pair of batches, runs one branch on each, and
    classifies the concatenated branch embeddings of the first ``batch_size``
    rows.  The ``N`` constructor argument is accepted but not used.
    """

    def __init__(self, num_features,num_classes, hidden_channels=256, dropout=0.5, N = 0):
        super().__init__()
        self.num_classes = num_classes

        # Each branch emits an embedding of half the hidden width.
        branch_width = int(hidden_channels / 2)
        self.gnn1 = HomophilicNet(num_features, num_classes, hidden_channels, branch_width)
        self.gnn2 = HomophilicNet(num_features, num_classes, hidden_channels, branch_width)
        self.p = dropout
        self.com_lin = nn.Linear(branch_width * 2, num_classes)

    def forward(self, batch_data):
        """Return ``(fused_log_probs, branch1_log_probs, branch2_log_probs)``.

        ``batch_data[0]`` and ``batch_data[1]`` must each expose ``x``,
        ``edge_index`` and ``weight``; ``batch_data[0].batch_size`` gives the
        number of leading rows that are fused.
        """
        first, second = batch_data[0], batch_data[1]

        emb1, branch1_log_probs = self.gnn1(first.x, first.edge_index, first.weight)
        emb2, branch2_log_probs = self.gnn2(second.x, second.edge_index, second.weight)

        emb1 = F.dropout(F.relu(emb1), p=self.p, training=self.training)
        emb2 = F.dropout(F.relu(emb2), p=self.p, training=self.training)

        # assumes the first `batch_size` rows of both batches are the same
        # seed nodes in the same order -- TODO confirm against the sampler
        n_seed = first.batch_size
        fused = torch.cat([emb1[:n_seed, :], emb2[:n_seed, :]], dim=-1)
        logits = self.com_lin(fused)

        return logits.log_softmax(dim=-1), branch1_log_probs, branch2_log_probs

## GNN Training and Testing

In [8]:
from ipynb.fs.full.AGSGraphSampler import AGSGraphSampler
from torch_geometric.loader import NeighborSampler, NeighborLoader

In [9]:
from collections import Counter
import random

def prediction(y_pred_seed, y_pred_hm, y_pred_ht):
    """Element-wise majority vote over three label tensors.

    With three voters there are only two cases: if the ``hm`` and ``ht``
    predictions agree, that label holds a majority; otherwise either the seed
    prediction matches one of them (majority) or all three differ (tie), and
    in both situations the seed prediction is chosen.  This is computed in one
    vectorized ``torch.where`` instead of a Python loop over every element.

    Args:
        y_pred_seed, y_pred_hm, y_pred_ht: 1-D label tensors of equal length.

    Returns:
        A CPU ``int64`` tensor holding the voted label for each position.
    """
    branches_agree = y_pred_hm == y_pred_ht
    voted = torch.where(branches_agree, y_pred_hm, y_pred_seed)
    return voted.to(torch.long).cpu()

# Small hand-made inputs for trying out prediction() by hand (the call itself
# is commented out further down).
y_pred_seed = torch.tensor([0, 1, 2, 2, 1])
y_pred_hm = torch.tensor([2, 1, 1, 0, 2])
y_pred_ht = torch.tensor([0, 2, 1, 1, 0])


# prediction(y_pred_seed,y_pred_hm,y_pred_ht)

In [10]:
def test(model, loader, mask, name='Train', channel='all'):
    """Compute the accuracy of ``model`` over the batches of ``loader``.

    Args:
        model: network exposing ``gnn1``/``gnn2`` (each with ``set_aggr``) and
            returning ``(out, out1, out2)`` when called on a pair of batches.
        loader: iterable of batches; each batch is fed to both branches.
        mask: boolean node mask; only used for the progress-bar total.
        name: ``'Train'`` selects ``train_mask``, ``'Validation'`` selects
            ``val_mask``, anything else selects ``test_mask``.
        channel: ``'sd'`` scores the fused output, ``'hm'`` the first branch,
            ``'ht'`` the second branch; any other value scores the majority
            vote of the three.

    Returns:
        The fraction of correctly classified masked seed nodes, or ``0.0``
        when no masked seed node was seen.
    """
    if args.log_info:
        # Tensor.sum() instead of the builtin sum(), which iterates element-wise.
        pbar = tqdm(total=int(mask.sum()))
        pbar.set_description(f'Evaluating {name}')

    model.eval()
    aggr = 'add' if args.use_normalization else 'mean'
    model.gnn1.set_aggr(aggr)
    model.gnn2.set_aggr(aggr)

    mask_attr = {'Train': 'train_mask', 'Validation': 'val_mask'}.get(name, 'test_mask')

    total_correct = 0
    total_examples = 0

    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)
            batch_size = batch.batch_size

            # Evaluation uses the plain edge weights (no edge_norm factor).
            batch.weight = batch.edge_weight

            # The same batch is fed to both branches.
            out, out1, out2 = model([batch, batch])

            y_pred_seed = out[:batch_size].argmax(dim=-1).cpu()
            y_pred_hm = out1[:batch_size].argmax(dim=-1).cpu()
            y_pred_ht = out2[:batch_size].argmax(dim=-1).cpu()

            # Predictions live on the CPU, so labels and mask are moved there
            # too; indexing a CPU tensor with a mask on another device can fail.
            y_true = batch.y[:batch_size].cpu()
            t_mask = getattr(batch, mask_attr)[:batch_size].cpu()

            if channel == 'sd':
                y_pred = y_pred_seed
            elif channel == 'hm':
                y_pred = y_pred_hm
            elif channel == 'ht':
                y_pred = y_pred_ht
            else:
                y_pred = prediction(y_pred_seed, y_pred_hm, y_pred_ht)

            correct = y_pred.eq(y_true)
            total_correct += correct[t_mask].sum().item()
            items = t_mask.sum().item()
            total_examples += items
            if args.log_info:
                pbar.update(items)

    if args.log_info:
        pbar.close()

    # Guard against an empty selection instead of raising ZeroDivisionError.
    return total_correct / max(total_examples, 1)


In [11]:
def _masked_nll(log_probs, targets, mask, node_norm=None):
    """Negative log-likelihood over the rows selected by ``mask``.

    Without ``node_norm`` the mean loss of the selected rows is returned.
    With ``node_norm`` each selected row's loss is multiplied by its
    ``node_norm`` entry and the products are summed.
    """
    if node_norm is None:
        return F.nll_loss(log_probs[mask], targets[mask])
    per_node = F.nll_loss(log_probs[mask], targets[mask], reduction='none')
    return (per_node * node_norm[mask]).sum()


#https://github.com/GraphSAINT/GraphSAINT/issues/11
def train(DATASET_NAME,model, data, dataset, epochs=1, channel='all', BATCH_SIZE=1024): #'all', 'hm', 'ht', 'sd'
    """Train ``model`` on AGSGraphSampler batches and evaluate periodically.

    Evaluation runs every 10th epoch and additionally on the final epoch, so
    at least one accuracy is recorded whenever ``epochs >= 1``.

    Args:
        DATASET_NAME: used to build the sampler cache directory under ``DIR``.
        model: an ``AGSGNN``-like module (``gnn1``, ``gnn2``, three outputs).
        data: graph object; ``data.edge_weight`` is overwritten in place with
            the inverse in-degree of each edge's target node.
        dataset: unused here; kept for interface compatibility.
        epochs: number of training epochs.
        channel: ``'sd'`` trains/scores the fused seed output, ``'hm'`` the
            first branch, ``'ht'`` the second branch; any other value uses the
            sum of the three losses and the majority vote for prediction.
        BATCH_SIZE: batch size handed to ``AGSGraphSampler``.

    Returns:
        ``(best_test_accuracy, max_iteration)`` where the first item is the
        maximum recorded test accuracy (``0.0`` if nothing was evaluated) and
        the second equals ``epochs``.
    """
    optimizer = torch.optim.Adam(model.parameters())

    _, col = data.edge_index
    data.edge_weight = 1. / degree(col, data.num_nodes)[col]  # Norm by in-degree.

    minibatch_size = BATCH_SIZE

    sampler_dir = DIR+'AGSGSAINTII/'+DATASET_NAME
    os.makedirs(sampler_dir, exist_ok=True)

    num_steps = math.ceil(data.num_nodes/minibatch_size)
    num_workers = 8 if data.num_nodes>50000 else 0

    sample_func = ['wrw', 'wrw']
    weight_func = [
        {'exact':False,'weight':'fastlink'}, #exact for exact size to the batch
        {'exact':False,'weight':'fastlink'}
    ]

    params = {'knn':{'metric':'cosine'},
              'submodular':{'metric':'cosine'},
              'link-nn':{'value':'min'},
              'link-sub':{'value':'max'},
              'disjoint':{'value':'mst'},
             }

    loader = AGSGraphSampler(
        data, batch_size=minibatch_size, walk_length=2, num_steps=num_steps, sample_coverage=100,
        num_workers=num_workers,log=args.log_info,save_dir=sampler_dir,recompute = args.recompute, shuffle = False,
        sample_func = sample_func, weight_func=weight_func, params=params)

    # Neighbor-sampling loaders used for evaluation on larger graphs.
    sample_batch_size = 2048
    train_loader = NeighborLoader(data, input_nodes=data.train_mask,num_neighbors=[8,4],
                                  batch_size=sample_batch_size, shuffle=False, num_workers=num_workers)
    val_loader = NeighborLoader(data,input_nodes=data.val_mask,num_neighbors=[8,4],
                                batch_size=sample_batch_size,shuffle=False, num_workers=num_workers)
    test_loader = NeighborLoader(data, input_nodes=data.test_mask,num_neighbors=[8,4],
                                 batch_size=sample_batch_size,shuffle=False, num_workers=num_workers)

    train_losses = []; val_accuracies = []; train_accuracies = []; test_accuracies = []
    max_iteration = epochs

    # Graphs below this node count are evaluated in one full-graph pass.
    th_node = 10000
    full_graph_eval = data.num_nodes < th_node

    if full_graph_eval:
        test_data = data.clone()
        test_data.weight = loader.edge_norm * test_data.edge_weight
        test_data.seed_node = torch.arange(test_data.num_nodes)
        test_data.batch_size = test_data.seed_node.shape[0]
        all_data = [test_data.to(device), test_data.to(device)]

        if args.log_info:
            print(all_data)

    aggr = 'add' if args.use_normalization else 'mean'

    for epoch in range(1, epochs+1):

        if args.log_info:
            pbar = tqdm(total=num_steps)
            pbar.set_description(f'Epoch {epoch:02d}')

        model.train()
        model.gnn1.set_aggr(aggr)
        model.gnn2.set_aggr(aggr)

        total_loss_seed = total_loss_hm = total_loss_ht = 0
        total_seed_example = total_hm_example = total_ht_example = 0

        for batch_data in loader:
            batch_data = [b_data.to(device) for b_data in batch_data]
            hm_batch, ht_batch = batch_data[0], batch_data[1]
            batch_size = hm_batch.batch_size
            mask = hm_batch.train_mask[:batch_size]
            y_true = hm_batch.y[:batch_size]

            if torch.sum(mask) == 0:
                print("no training mask in seed node")

            optimizer.zero_grad()

            if args.use_normalization:
                hm_batch.weight = hm_batch.edge_norm * hm_batch.edge_weight
                ht_batch.weight = ht_batch.edge_norm * ht_batch.edge_weight
                seed_norm = hm_batch.node_norm[:batch_size]
                hm_norm = hm_batch.node_norm
                ht_norm = ht_batch.node_norm
            else:
                hm_batch.weight = hm_batch.edge_weight
                ht_batch.weight = ht_batch.edge_weight
                seed_norm = hm_norm = ht_norm = None

            out, out1, out2 = model(batch_data)

            #--------- loss computation seed nodes --------- #
            loss_seed = _masked_nll(out[:batch_size], y_true, mask, seed_norm)
            #--------- loss computation first ('hm') branch --------- #
            loss_hm = _masked_nll(out1, hm_batch.y, hm_batch.train_mask, hm_norm)
            #--------- loss computation second ('ht') branch --------- #
            loss_ht = _masked_nll(out2, ht_batch.y, ht_batch.train_mask, ht_norm)

            if channel == 'sd':
                loss = loss_seed
            elif channel == 'hm':
                loss = loss_hm
            elif channel == 'ht':
                loss = loss_ht
            else:
                loss = loss_seed + loss_hm + loss_ht

            loss.backward()
            optimizer.step()

            # Tensor.sum() instead of the builtin sum(), which iterates element-wise.
            n_seed = mask.sum().item()
            n_hm = hm_batch.train_mask.sum().item()
            n_ht = ht_batch.train_mask.sum().item()

            # NOTE(review): the seed loss is scaled by batch_size but later
            # divided by the number of masked seed nodes -- confirm this is intended.
            total_loss_seed += loss_seed.item()*batch_size
            total_loss_hm += loss_hm.item()*n_hm
            total_loss_ht += loss_ht.item()*n_ht

            total_seed_example += n_seed
            total_hm_example += n_hm
            total_ht_example += n_ht

            if args.log_info:
                pbar.update(1)

        if args.log_info:
            pbar.close()

        total_loss_seed /= max(total_seed_example,1)
        total_loss_hm /= max(total_hm_example,1)
        total_loss_ht /= max(total_ht_example,1)

        total_loss = total_loss_seed+total_loss_hm+total_loss_ht
        total_examples = total_seed_example + total_hm_example + total_ht_example

        # Evaluate every 10th epoch and always on the last one, so short runs
        # (including the default epochs=1) still record an accuracy.
        if epoch % 10 != 0 and epoch != epochs:
            continue

        if args.log_info:
            print(f'Loss:{total_loss:0.4f} seed:{total_loss_seed:0.4f} hm:{total_loss_hm:0.4f} ht:{total_loss_ht:0.4f}')
            print(f'Example:{total_examples:1d} seed:{total_seed_example:1d} hm:{total_hm_example:1d} ht:{total_ht_example:1d}')

        if channel == 'sd':
            train_losses.append(total_loss_seed)
        elif channel == 'hm':
            train_losses.append(total_loss_hm)
        elif channel == 'ht':
            train_losses.append(total_loss_ht)
        else:
            train_losses.append(total_loss)

        model.eval()
        model.gnn1.set_aggr(aggr)
        model.gnn2.set_aggr(aggr)
        accs = [0,0,0]

        if full_graph_eval:
            with torch.no_grad():
                out, out1, out2 = model(all_data)

                y_pred_seed = out.argmax(dim=-1).cpu()
                y_pred_hm = out1.argmax(dim=-1).cpu()
                y_pred_ht = out2.argmax(dim=-1).cpu()

                if channel == 'sd':
                    y_pred = y_pred_seed
                elif channel == 'hm':
                    y_pred = y_pred_hm
                elif channel == 'ht':
                    y_pred = y_pred_ht
                else:
                    y_pred = prediction(y_pred_seed, y_pred_hm, y_pred_ht)

                correct = y_pred.eq(data.y.cpu())
                accs = []
                for _, split_mask in data('train_mask', 'val_mask', 'test_mask'):
                    accs.append(correct[split_mask].sum().item() / split_mask.sum().item())
        else:
            accs[0] = test(model, train_loader, data.train_mask, name='Train',channel=channel)
            accs[1] = test(model, val_loader, data.val_mask, name='Validation',channel=channel)
            accs[2] = test(model, test_loader, data.test_mask, name='Test',channel=channel)

        train_accuracies.append(accs[0])
        val_accuracies.append(accs[1])
        test_accuracies.append(accs[2])
        std_dev = np.std(train_losses[-5:])

        if args.log_info:
            print(f'Epoch: {epoch:03d}, Train Loss: {loss:.4f}, Train: {accs[0]:.4f}, Val: {accs[1]:.4f}, Test: {accs[2]:.4f}, Std dev: {std_dev:.4f}')

    if args.log_info:
        save_plot([train_losses, train_accuracies, val_accuracies, test_accuracies], labels=['Loss','Train','Validation','Test'], name='Results/AGSGSValidation', yname='Accuracy', xname='Epoch')
        print ("Best Validation Accuracy, ",max(val_accuracies, default=0.0))
        print ("Best Test Accuracy, ",max(test_accuracies, default=0.0))

    # default= keeps epochs <= 0 from raising on an empty list.
    return max(test_accuracies, default=0.0), max_iteration

In [12]:
def AGSGSperformanceSampler(DATASET_NAME,data, dataset, num_classes, epochs=1,channel='all'):
    """Build an ``AGSGNN`` for ``data``, train it and report the result.

    Returns:
        ``(accuracy, iterations, model)`` -- the two values returned by
        ``train`` in their original order, followed by the trained model.
        (The locals used to be named the other way round, which made the
        return statement read as ``itr, accuracy`` although the first element
        has always been the accuracy.)
    """
    BATCH_SIZE = 6000

    if args.log_info:
        print("BATCH SIZE: ", BATCH_SIZE)

    # The input width is taken from the feature matrix.
    model = AGSGNN(data.x.shape[1], num_classes, hidden_channels=256, N = BATCH_SIZE).to(device)

    if args.log_info:
        print(model)

    accuracy, itr = train(DATASET_NAME, model, data, dataset, epochs, channel, BATCH_SIZE)

    return accuracy, itr, model

In [13]:
from sklearn.decomposition import TruncatedSVD
import numpy as np

def adj_feature(data):
    """Return a truncated-SVD embedding of the symmetrized dense adjacency matrix.

    A dense ``num_nodes x num_nodes`` matrix is filled with ones for every
    edge in both directions, then reduced with ``TruncatedSVD`` to
    ``min(256, data.x.shape[1], data.num_nodes)`` components.  The result is
    returned as a ``torch.Tensor`` with one row per node.
    """
    node_count = data.num_nodes
    src, dst = data.edge_index[0], data.edge_index[1]

    # Dense adjacency: memory grows quadratically with the number of nodes.
    dense_adj = torch.zeros((node_count, node_count))
    dense_adj[src, dst] = 1
    dense_adj[dst, src] = 1

    rank = min(256, data.x.shape[1], node_count)
    reducer = TruncatedSVD(n_components=rank)

    return torch.Tensor(reducer.fit_transform(dense_adj))

# x = adj_feature(data)
# x.shape

In [14]:
# Override the parsed flags for the runs below: no logging / progress bars, and
# recompute=False is what gets forwarded to AGSGraphSampler in train().
args.log_info = False
args.recompute = False

# DATASET_NAME = 'Cora'
# data, dataset = get_data(DATASET_NAME,DIR=None, log=False, h_score=True, split_no=0)
# print(data)

# channel = 'all' #'all', 'hm', 'ht', 'sd'

# # if DATASET_NAME in ['Squirrel', 'Chameleon']:
# #     data.x = torch.cat((data.x, adj_feature(data)), dim=1)
# #     if args.log_info == True:
# #         print(data.x.shape)
    
# best_acc, num_iteration, _ =  AGSGSperformanceSampler(DATASET_NAME, data, dataset, dataset.num_classes,
#                                                       epochs=150, channel=channel)
# print(best_acc, num_iteration)

# Batch Experiments

In [None]:
def batch_experiments(num_run=1, channel='all', datasets=None):
    """Run ``num_run`` trainings per dataset and append a summary line to a file.

    For every dataset the name is written to the results file before training
    starts (so progress is visible while a long run is in flight) and the
    mean/std of accuracy and iteration count is appended afterwards.

    Args:
        num_run: number of runs per dataset; run ``i`` uses ``split_no=i`` and
            ``random_state=i``.
        channel: forwarded to ``AGSGSperformanceSampler`` and used in the
            results file name.
        datasets: optional iterable of dataset names; defaults to the built-in
            list of large benchmarks.
    """
    if datasets is None:
        datasets = [
            'genius',
            'pokec',
            'arxiv-year',
            'snap-patents',
            'twitch-gamers',
            'wiki',
            'AmazonProducts',
            'Yelp',
            'Reddit',
            'Reddit2',
        ]

    # Mutates the module-level args: batch runs are always silent.
    args.log_info = False

    filename = "Results/AGSGNN-SAINT-II-GS-"+channel+".txt"
    # Create the results directory up front so the first write cannot fail on it.
    os.makedirs(os.path.dirname(filename), exist_ok=True)

    for DATASET_NAME in datasets:
        print(DATASET_NAME, end=' ')

        with open(filename, 'a+') as result_file:
            result_file.write(f'{DATASET_NAME} ')

        accs = []
        itrs = []

        for i in range(num_run):
            data, dataset = get_data(DATASET_NAME, DIR=None, log=False, h_score=False, split_no=i,random_state=i)

            # Multi-column labels are collapsed to the index of their largest entry.
            if len(data.y.shape) > 1:
                data.y = data.y.argmax(dim=1)

            # The class count is always derived from the labels actually present.
            num_classes = torch.max(data.y).item()+1

            max_epochs = 150 if data.num_nodes<100000 else 50

            accuracy, itr, _ = AGSGSperformanceSampler(DATASET_NAME, data, dataset, num_classes,epochs=max_epochs, channel=channel)

            accs.append(accuracy)
            itrs.append(itr)

        summary = f'acc {np.mean(accs):0.4f} sd {np.std(accs):0.4f} itr {int(np.mean(itrs)):d} sd {int(np.std(itrs)):d}'
        print(summary)
        with open(filename, 'a+') as result_file:
            result_file.write(summary + '\n')
                
# Run the batch experiments once per dataset and report the wall-clock time.
start = time.time()
batch_experiments(num_run=1, channel='all')
end = time.time()
print("Time spent:", end-start)

genius acc 0.8003 sd 0.0000 itr 50 sd 0
pokec 

## Visualize representation

In [None]:
if __name__ == '__main__':
    # Placeholder entry point: nothing is executed.  The disabled snippet below
    # builds a 7-node toy graph and is kept for reference only.
    #
    #     n=7
    #     x = torch.Tensor([[1,0],[1,0],[1,0],[0,1],[0,1],[0,1],[0,1]])
    #     y = torch.LongTensor([0,0,0, 1, 1, 1, 1])
    #     edge_index = torch.LongTensor([[1,2],[1,4],[1,5],[2,1],[3,6],[3,7],[4,5],[4,1],[4,6],[4,7],[5,1],[5,4],[5,6],[6,3],[6,4],[6,5],[6,7],[7,3],[7,4],[7,6]]).T
    #     edge_index = edge_index-1
    #
    #     mask = torch.zeros(n, dtype=torch.bool)
    #     mask[[1,3]] = True
    #
    #     test_data = Data(x = x, y = y, edge_index = edge_index, train_mask = mask, test_mask = mask, val_mask = mask)
    #     print(test_data)
    pass

In [None]:
# class AGS_layer(torch.nn.Module):
#     def __init__(self, input_channels, output_channels, dropout=0.2):
#         super().__init__()
#         self.T = 3
#         self.p = dropout
#         #self.Aconv1 = GCNConv(input_channels, output_channels)        
#         #self.Sconv1 = SpatialConv(input_channels, output_channels)
        
#         self.Aconv1 = GCNConv(input_channels, output_channels)
#         self.Sconv1 = SpatialConv(input_channels, output_channels)
        
#         self.I1 = nn.Linear(input_channels, output_channels)
        
#         self.layer_norm_a1 =  nn.LayerNorm(output_channels)
#         self.layer_norm_s1 =  nn.LayerNorm(output_channels)
#         self.layer_norm_i1 =  nn.LayerNorm(output_channels)
        
#         self.alpha_a1 = nn.Linear(output_channels, 1)
#         self.alpha_s1 = nn.Linear(output_channels, 1)
#         self.alpha_i1 = nn.Linear(output_channels, 1)
#         self.w1 = nn.Linear(3, 3)
        
#         #self.reset_parameters()
            
#     def reset_parameters(self):
        
#         stdv = 1. / math.sqrt(self.I1.weight.size(1))
#         std_att = 1. / math.sqrt(self.w1.weight.size(1))
#         std_att_vec = 1. / math.sqrt( self.alpha_a1.weight.size(1))
        
#         self.I1.weight.data.uniform_(-stdv, stdv)
        
#         self.alpha_a1.weight.data.uniform_(-std_att, std_att)
#         self.alpha_s1.weight.data.uniform_(-std_att, std_att)
#         self.alpha_i1.weight.data.uniform_(-std_att, std_att)
        
#         self.w1.weight.data.uniform_(-std_att_vec, std_att_vec)
        
#         self.layer_norm_a1.reset_parameters()
#         self.layer_norm_s1.reset_parameters()
#         self.layer_norm_i1.reset_parameters()
        

#     def forward(self, x0, edge_index, edge_weight=None):
#         a1 = F.relu(self.Aconv1(x0, edge_index, edge_weight))
#         a1 = self.layer_norm_a1(a1)
#         a1 = F.dropout(a1, p=self.p, training=self.training)
        
#         s1 = F.relu(self.Sconv1(x0, edge_index, edge_weight))
#         s1 = self.layer_norm_s1(s1)
#         s1 = F.dropout(s1, p=self.p, training=self.training)

#         i1 = F.relu(self.I1(x0))
#         i1 = self.layer_norm_i1(i1)
#         i1 = F.dropout(i1, p=self.p, training=self.training)
        
#         ala1 = torch.sigmoid(self.alpha_a1(a1))
#         als1 = torch.sigmoid(self.alpha_s1(s1))
#         ali1 = torch.sigmoid(self.alpha_i1(i1))        
#         alpha1 = F.softmax(self.w1(torch.cat([ala1, als1, ali1],dim=-1)/self.T), dim=1)        
        
#         x1 = torch.mm(torch.diag(alpha1[:,0]),a1) + torch.mm(torch.diag(alpha1[:,1]),s1) + torch.mm(torch.diag(alpha1[:,2]),i1)                
        
#         return x1
        
# class AGS_GCN(torch.nn.Module):
#     def __init__(self, num_features, num_classes, hidden_channels=16, dropout=0.2):
#         super().__init__()        
#         self.num_classes = num_classes
#         self.p = dropout
        
#         self.ags_layer1 = AGS_layer(num_features, hidden_channels)
#         self.ags_layer2 = AGS_layer(hidden_channels, hidden_channels)
#         #self.ags_layer2 = AGS_layer(hidden_channels, num_classes)
                
#         self.CombineW = nn.Linear(2 * hidden_channels, hidden_channels)
#         self.PredW = nn.Linear(1*hidden_channels, num_classes)
        
    
#     def forward(self, x0, edge_index, edge_weight=None):
        
#         #x0 = F.dropout(x0, p=self.p, training=self.training)
#         x1 = self.ags_layer1(x0, edge_index, edge_weight)
#         x1 = F.dropout(x1, p=self.p, training=self.training)
        
#         x2 = self.ags_layer2(x1, edge_index, edge_weight)
#         x2 = F.dropout(x2, p=self.p, training=self.training)        
        
#         #x = self.PredW(torch.cat([x1, x2], dim=-1))
#         x = self.PredW(x2)
         
#         #return x
#         return x.log_softmax(dim=-1)