In [1]:
import os
import sys
# Make the sibling ``Submodular`` directory importable unless we are already running inside it.
_running_inside_submodular = os.getcwd().endswith("Submodular")
if not _running_inside_submodular:
    sys.path.append('../Submodular')

In [2]:
import DeviceDir

# Project paths and compute device come from the local ``DeviceDir`` helper module.
# NOTE(review): DIR is later used as a plain string prefix (DIR + 'AGSGIN/' + ...),
# so it presumably ends with a path separator — confirm in DeviceDir.get_directory().
# RESULTS_DIR is not referenced elsewhere in this notebook section.
DIR, RESULTS_DIR = DeviceDir.get_directory()
# ``device`` is the target of every ``.to(device)`` call below; NUM_PROCESSORS is not
# referenced elsewhere in this notebook section.
device, NUM_PROCESSORS = DeviceDir.get_device()

In [3]:
from ipynb.fs.full.Dataset import get_data
from ipynb.fs.full.Dataset import datasets as available_datasets
from ipynb.fs.full.Utils import save_plot

In [4]:
import argparse
from argparse import ArgumentParser

def _str2bool(value):
    """Parse a command-line boolean such as ``True``/``false``/``1``/``no``.

    ``argparse`` with ``type=bool`` calls ``bool("False")``, which is ``True``, so
    flags that take an explicit value need a real parser.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


# Set default arguments here.
def get_configuration(argv=None):
    """Build the experiment configuration from the command line.

    Args:
        argv: optional list of argument strings. ``None`` (the default) parses
            ``sys.argv[1:]``, exactly as before.

    Returns:
        ``(args, dict_args)``: the parsed ``argparse.Namespace`` and ``vars(args)``.
    """
    parser = ArgumentParser()
    parser.add_argument('--epochs', type=int, default=1)
    parser.add_argument('--log_info', type=_str2bool, default=True)
    parser.add_argument('--pbar', type=_str2bool, default=False)
    parser.add_argument('--batch_size', type=int, default=2048)
    parser.add_argument('--learning_rate', type=float, default=0.01)
    parser.add_argument('--num_gpus', type=int, default=-1)
    parser.add_argument('--parallel_mode', type=str, default="dp", choices=['dp', 'ddp', 'ddp2'])
    parser.add_argument('--dataset', type=str, default="Cora", choices=available_datasets)
    parser.add_argument('--use_normalization', action='store_true')
    # Dummy option so the kernel's own ``-f <connection file>`` argument parses in Jupyter.
    parser.add_argument('-f')

    args = parser.parse_args(argv)
    dict_args = vars(args)
    return args, dict_args


args, dict_args = get_configuration()

## Packages

In [5]:
import torch.nn as nn
import numpy as np
from torch.nn import init
from random import shuffle, randint
import torch.nn.functional as F
from itertools import combinations, combinations_with_replacement
from sklearn.metrics import f1_score, accuracy_score
from sklearn.decomposition import TruncatedSVD
import matplotlib.pyplot as plt
import sys
from torch_geometric.data import Data
import logging
import time

import argparse
import os.path as osp
import math

In [6]:
import random
import numpy as np
import torch

# Seed every RNG the notebook touches (Python, NumPy, PyTorch) so runs are reproducible.
seed = 123

for _seed_rng in (random.seed, np.random.seed, torch.manual_seed):
    _seed_rng(seed)
del _seed_rng

### GraphSAINT (GSAINT) model

In [7]:
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel
from tqdm import tqdm

from torch_geometric.loader import NeighborSampler, NeighborLoader
from torch_geometric.loader import GraphSAINTRandomWalkSampler, GraphSAINTNodeSampler, GraphSAINTEdgeSampler, GraphSAINTSampler
from ipynb.fs.full.AGSGraphSampler import AGSGraphSampler

from torch_geometric.nn import GraphConv
from torch_geometric.utils import degree

In [8]:
from torch_geometric.nn import GCNConv, GATConv, GINConv, SAGEConv
# NOTE(review): this module-level name is NOT used by ``GCN`` by default — the class's
# ``GNNconv`` parameter defaults to ``GCNConv`` and shadows it. Pass
# ``GNNconv=GNNconv`` explicitly to build the GIN variant.
GNNconv = GINConv


class GCN(torch.nn.Module):
    """Two-layer graph neural network node classifier returning log-probabilities.

    Despite the class name, the convolution type is configurable.

    Args:
        num_features: size of the input node features.
        num_classes: number of output classes.
        hidden_channels: width of the hidden layer (default 16).
        GNNconv: PyG convolution class used for both layers. ``GINConv`` is built
            around a single ``nn.Linear`` per layer; any other class is constructed
            as ``GNNconv(in_channels, out_channels)``.
    """

    def __init__(self, num_features, num_classes, hidden_channels=16, GNNconv=GCNConv):
        super().__init__()
        if GNNconv == GINConv:
            # GINConv wraps a neural network instead of taking channel sizes.
            self.MLP1 = nn.Linear(num_features, hidden_channels)
            self.MLP2 = nn.Linear(hidden_channels, num_classes)
            self.conv1 = GNNconv(self.MLP1)
            self.conv2 = GNNconv(self.MLP2)
        else:
            self.conv1 = GNNconv(num_features, hidden_channels)
            self.conv2 = GNNconv(hidden_channels, num_classes)

        # Only layers whose ``forward`` declares ``edge_weight`` (e.g. GCNConv) can
        # consume edge weights. Layers such as GINConv/SAGEConv take ``size`` as their
        # third positional argument, so passing weights positionally is misinterpreted.
        import inspect
        self.accepts_edge_weight = 'edge_weight' in inspect.signature(self.conv1.forward).parameters

    def _run_conv(self, conv, x, edge_index, edge_weight):
        """Apply ``conv``, forwarding ``edge_weight`` only if the layer accepts it."""
        if self.accepts_edge_weight:
            return conv(x, edge_index, edge_weight=edge_weight)
        return conv(x, edge_index)

    def forward(self, x, edge_index, edge_weight=None):
        """Return log-softmax class scores for every node in ``x``.

        ``edge_weight`` is ignored by convolution types that have no
        ``edge_weight`` argument (see ``accepts_edge_weight``).
        """
        x = self._run_conv(self.conv1, x, edge_index, edge_weight)
        x = x.relu()
        x = F.dropout(x, p=0.5, training=self.training)
        x = self._run_conv(self.conv2, x, edge_index, edge_weight)
        return x.log_softmax(dim=-1)

In [9]:
def test(model, loader, mask, name='Train'):
    """Compute the accuracy of ``model`` on the seed nodes yielded by ``loader``.

    Each batch follows the ``NeighborLoader`` convention used in this notebook: only
    the first ``batch_data.batch_size`` nodes of a batch are scored.

    Args:
        model: the GNN, called as ``model(x, edge_index)``.
        loader: iterable of mini-batches with ``x``, ``edge_index``, ``y`` and
            ``batch_size`` attributes.
        mask: boolean node mask; only its count is used, to size the progress bar.
        name: label shown on the progress bar.

    Returns:
        Fraction of correctly classified seed nodes, or ``0.0`` if ``loader``
        yields no nodes.
    """
    model.eval()

    total_correct = 0
    total_examples = 0
    # ``disable`` turns the bar into a no-op when logging is off.
    with tqdm(total=int(mask.sum()), desc=f'Evaluating {name}', disable=not args.log_info) as pbar, \
            torch.no_grad():
        for batch_data in loader:
            n_seed = batch_data.batch_size
            out = model(batch_data.x.to(device), batch_data.edge_index.to(device))[:n_seed]
            pred = out.argmax(dim=-1)
            # Accumulate on-device; converted to a Python int once at the end.
            total_correct = total_correct + pred.eq(batch_data.y[:n_seed].to(device)).sum()
            total_examples += n_seed
            pbar.update(n_seed)

    if total_examples == 0:
        return 0.0
    return int(total_correct) / total_examples

In [10]:
def _full_graph_accuracies(model, data):
    """Return ``[train_acc, val_acc, test_acc]`` from a single full-graph forward pass.

    Used for small graphs, where the whole graph goes through the model at once.
    A mask that selects no nodes yields ``0.0`` instead of dividing by zero.
    """
    model.eval()
    with torch.no_grad():
        out = model(data.x.to(device), data.edge_index.to(device))
        correct = out.argmax(dim=-1).eq(data.y.to(device))

    accs = []
    for _, mask in data('train_mask', 'val_mask', 'test_mask'):
        num_selected = mask.sum().item()
        accs.append(correct[mask].sum().item() / num_selected if num_selected else 0.0)
    return accs


def _append_runtime_log(path, train_losses, train_accuracies, val_accuracies,
                        test_accuracies, training_times, worker):
    """Append per-epoch histories and the mean epoch time to ``path``.

    The record format is unchanged: the list reprs and the mean epoch time are
    written back to back, followed by a ``worker ... avg epoch runtime`` line.
    The parent directory is created if it does not exist yet.
    """
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    mean_time = np.mean(training_times)
    with open(path, 'a+') as acc_file:
        for record in (train_losses, train_accuracies, val_accuracies,
                       test_accuracies, training_times, mean_time):
            acc_file.write(str(record))
        acc_file.write(f'\nworker {worker:1d} avg epoch runtime {mean_time:0.8f}')


def train(DATASET_NAME, model, data, dataset, epochs=10,train_neighbors=[8,4],test_neighbors=[8,4]):
    """Train ``model`` with ``AGSGraphSampler`` mini-batches and track accuracy per epoch.

    Args:
        DATASET_NAME: name used for the sampler cache directory ``DIR + 'AGSGIN/' + name``.
        model: the GNN to train (moved to ``device`` by the caller).
        data: the full PyG graph with ``train_mask`` / ``val_mask`` / ``test_mask``.
        dataset: unused here; kept for interface compatibility.
        epochs: maximum number of epochs.
        train_neighbors / test_neighbors: ``num_neighbors`` for the evaluation
            ``NeighborLoader``s (used only for graphs with >= 10000 nodes).

    Returns:
        ``(best_acc, num_iteration)``: the highest *test* accuracy over all epochs,
        and the number of epochs run (smaller than ``epochs`` if early stopping fires).

    Side effects:
        Sets ``data.edge_weight``, creates the sampler directory, appends a record
        to ``Runtime/AGSGSAINT.txt`` and, when ``args.log_info`` is true, saves a
        plot and prints progress.
    """
    # NOTE(review): lr is fixed at 0.001; ``args.learning_rate`` is not used here.
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    # Weight each edge by 1 / in-degree of its target node.
    _, col = data.edge_index
    data.edge_weight = 1. / degree(col, data.num_nodes)[col]

    sampler_dir = DIR + 'AGSGIN/' + DATASET_NAME
    os.makedirs(sampler_dir, exist_ok=True)

    batch_size = min(data.num_nodes, 4096)
    num_steps = math.ceil(data.num_nodes / batch_size)
    num_workers = 0 if data.num_nodes < 100000 else 8

    sample_func = ['wrw']
    weight_func = [
        {'exact': False, 'weight': 'knn'},  # 'exact' requests batches of exactly batch_size
    ]
    params = {'knn': {'metric': 'cosine'},
              'submodular': {'metric': 'cosine'},
              'link-nn': {'value': 'min'},
              'link-sub': {'value': 'max'},
              'disjoint': {'value': 'mst'},
              }

    loader = AGSGraphSampler(
        data, batch_size=batch_size, walk_length=2, num_steps=num_steps, sample_coverage=100,
        num_workers=num_workers, log=args.log_info, save_dir=sampler_dir, recompute=False, shuffle=False,
        sample_func=sample_func, weight_func=weight_func, params=params)
    # Baseline alternative: GraphSAINTRandomWalkSampler(data, batch_size=batch_size, walk_length=2,
    #     num_steps=num_steps, sample_coverage=100, save_dir=sampler_dir, num_workers=num_workers)

    if args.log_info:
        print("Train neighbors: ", train_neighbors)
        print("Test neighbors: ", test_neighbors)

    # Mini-batch evaluation loaders, used only for graphs with >= 10000 nodes.
    sample_batch_size = 512
    train_loader = NeighborLoader(data, input_nodes=data.train_mask, num_neighbors=train_neighbors,
                                  batch_size=sample_batch_size, shuffle=False, num_workers=num_workers)
    val_loader = NeighborLoader(data, input_nodes=data.val_mask, num_neighbors=test_neighbors,
                                batch_size=sample_batch_size, shuffle=False, num_workers=num_workers)
    test_loader = NeighborLoader(data, input_nodes=data.test_mask, num_neighbors=test_neighbors,
                                 batch_size=sample_batch_size, shuffle=False, num_workers=num_workers)

    best_acc = 0
    num_iteration = epochs
    train_losses, val_accuracies, train_accuracies, test_accuracies = [], [], [], []
    training_times = []

    for epoch in range(1, epochs + 1):
        if args.log_info:
            pbar = tqdm(total=batch_size * num_steps)
            pbar.set_description(f'Epoch {epoch:02d}')

        model.train()
        total_loss = total_examples = 0
        epoch_start = time.time()

        for batch_data in loader:
            batch_data = batch_data.to(device)
            optimizer.zero_grad()

            if args.use_normalization:
                # Presumably the sampler's GraphSAINT-style normalisation coefficients
                # (edge_norm / node_norm) — TODO confirm in AGSGraphSampler.
                edge_weight = batch_data.edge_norm * batch_data.edge_weight
                out = model(batch_data.x, batch_data.edge_index, edge_weight)
                loss = F.nll_loss(out, batch_data.y, reduction='none')
                loss = (loss * batch_data.node_norm)[batch_data.train_mask].sum()
            else:
                out = model(batch_data.x, batch_data.edge_index)
                loss = F.nll_loss(out[batch_data.train_mask], batch_data.y[batch_data.train_mask])

            loss.backward()
            optimizer.step()
            total_loss += loss.item() * batch_data.num_nodes
            total_examples += batch_data.num_nodes

            if args.log_info:
                pbar.update(batch_size)

        if args.log_info:
            pbar.close()

        training_times.append(time.time() - epoch_start)

        epoch_loss = total_loss / total_examples
        train_losses.append(epoch_loss)
        if args.log_info:
            print("Training Loss: ", epoch_loss)

        if data.num_nodes < 10000:
            accs = _full_graph_accuracies(model, data)
            if args.log_info:
                print(accs)
        else:
            if args.log_info == True:
                train_acc = test(model, train_loader, data.train_mask, 'Train')
                val_acc = test(model, val_loader, data.val_mask, 'Validation')
            else:
                # Skip the train/val passes when not logging; only test accuracy is needed.
                train_acc = 0
                val_acc = 0
            test_acc = test(model, test_loader, data.test_mask, 'Test')
            accs = [train_acc, val_acc, test_acc]
            if args.log_info:
                print(f'Epoch: {epoch:03d}, Train: {train_acc:.4f}, Val: {val_acc:.4f}, Test: {test_acc:.4f}')

        # NOTE(review): the returned "best" accuracy is selected on the TEST split.
        if accs[2] > best_acc:
            best_acc = accs[2]

        train_accuracies.append(accs[0])
        val_accuracies.append(accs[1])
        test_accuracies.append(accs[2])

        # Early stopping: stop once the last five epoch losses are essentially flat.
        std_dev = np.std(train_losses[-5:])
        if args.log_info:
            print('std_dev: ', std_dev)
        if epoch >= 5 and std_dev <= 1e-3:
            num_iteration = epoch
            if args.log_info:
                print("Iteration for convergence: ", epoch)
            break

    if args.log_info:
        save_plot([train_losses, train_accuracies, val_accuracies, test_accuracies], labels=['Loss','Train','Validation','Test'], name='Results/AGSGSValidation', yname='Accuracy', xname='Epoch')
        print("Best Validation Accuracy, ", max(val_accuracies))
        print("Best Test Accuracy, ", max(test_accuracies))

    _append_runtime_log("Runtime/AGSGSAINT.txt", train_losses, train_accuracies,
                        val_accuracies, test_accuracies, training_times, num_workers)

    return best_acc, num_iteration

In [11]:
def GSAINTperformance(DATASET_NAME, data, dataset, num_classes, epochs=20, train_neighbors=[8,4],test_neighbors=[8,4],
                      hidden_channels=256, GNNconv=None):
    """Build a ``GCN`` sized for ``data`` and train it with ``train``.

    Args:
        DATASET_NAME, data, dataset, epochs, train_neighbors, test_neighbors:
            forwarded unchanged to ``train``.
        num_classes: number of output classes of the model.
        hidden_channels: hidden-layer width of the model (default 256, as before).
        GNNconv: optional PyG convolution class forwarded to ``GCN``; ``None``
            keeps ``GCN``'s own default.

    Returns:
        ``(best_acc, num_iteration, model)``.
    """
    model_kwargs = {'hidden_channels': hidden_channels}
    if GNNconv is not None:
        model_kwargs['GNNconv'] = GNNconv
    model = GCN(data.x.shape[1], num_classes, **model_kwargs).to(device)

    if args.log_info:
        print(model)

    best_acc, num_iteration = train(DATASET_NAME, model, data, dataset, epochs, train_neighbors, test_neighbors)
    return best_acc, num_iteration, model

In [12]:
from sklearn.decomposition import TruncatedSVD
import numpy as np

def adj_feature(data):
    """Return low-rank structural node features from a truncated SVD of the adjacency.

    The graph is treated as undirected and unweighted: every edge in
    ``data.edge_index`` is set to 1 in both directions, and duplicate edges collapse
    to 1. The matrix is built as a SciPy sparse matrix instead of a dense
    ``num_nodes x num_nodes`` tensor, so memory grows with the number of edges
    rather than quadratically with the number of nodes. ``TruncatedSVD`` accepts
    sparse input, and results match the dense version up to floating-point rounding.

    Returns:
        float32 ``torch.Tensor`` with one row per node and
        ``min(256, data.x.shape[1], data.num_nodes)`` columns.
    """
    import scipy.sparse as sp

    n = data.num_nodes
    src, dst = data.edge_index.cpu().numpy()
    rows = np.concatenate([src, dst])
    cols = np.concatenate([dst, src])
    ones = np.ones(rows.shape[0], dtype=np.float32)
    adj_mat = sp.csr_matrix((ones, (rows, cols)), shape=(n, n))
    adj_mat.data[:] = 1  # duplicates were summed on construction; clamp back to 0/1

    n_components = min(256, data.x.shape[1], n)
    svd = TruncatedSVD(n_components=n_components)
    return torch.Tensor(svd.fit_transform(adj_mat))

# x = adj_feature(data)
# x.shape

In [13]:
from torch_geometric.utils import add_self_loops

In [14]:
# # Define your adjacency matrix (replace this with your actual adjacency matrix)
# adj = np.array([[0, 1, 0, 1],
#                 [1, 0, 1, 0],
#                 [0, 1, 0, 1],
#                 [1, 0, 1, 0]], dtype=np.float32)

# n_components = 2

# # Perform SVD dimensionality reduction
# svd = TruncatedSVD(n_components=n_components)
# low_dimensional_matrix = svd.fit_transform(adj_feature(data))

# low_dimensional_matrix.shape


In [15]:
# args.log_info = True
# DATASET_NAME = 'reed98'
# data, dataset = get_data(DATASET_NAME, DIR=None, log=False, h_score=True, split_no=1); print("")

# if len(data.y.shape) > 1:
#     data.y = data.y.argmax(dim=1)        
#     num_classes = torch.max(data.y).item()+1
# else:
#     num_classes = dataset.num_classes

# if num_classes!= torch.max(data.y)+1:
#     num_classes = torch.max(data.y).item()+1

# if DATASET_NAME in ['Cornell', 'cornell5']:
#     data.edge_index, _ = add_self_loops(data.edge_index)


# if DATASET_NAME in ['Squirrel', 'Chameleon', 'amherst41',
#                     'Cornell','cornell5']:
#     data.x = torch.cat((data.x, adj_feature(data)), dim=1)
#     if args.log_info == True:
#         print(data.x.shape)
    
# best_acc, num_iteration, _ = GSAINTperformance(DATASET_NAME, data, dataset, num_classes, epochs=150,
#                              train_neighbors=[8,4],test_neighbors=[8,4])    
# print(best_acc, num_iteration)

# # data.y

# Batch Experiments

In [16]:
from ipynb.fs.full.Dataset import generate_synthetic2homophily
import torch_geometric.utils.homophily as homophily
import torch_geometric

In [17]:
def ablation(num_run = 1, datasets=None, filename="Results/AGSGS-Ablation.txt"):
    """Run ``GSAINTperformance`` on several splits per dataset and log mean ± std.

    For each dataset, test accuracy and epochs-to-convergence are collected over
    ``split_no = 0 .. num_run-1``. A summary line is printed and appended to
    ``filename``.

    Args:
        num_run: number of data splits per dataset.
        datasets: optional list of dataset names; ``None`` uses the default list below.
        filename: results file that the per-dataset summary is appended to.

    Side effects:
        Sets the global ``args.log_info = False`` and mutates each loaded ``data``
        (undirected edges, and extra SVD features for some datasets).
    """
    if datasets is None:
        datasets = [
            'Wisconsin', 'reed98', 'amherst41', 'Roman-empire', 'cornell5', 'Squirrel',
            'johnshopkins55', 'Actor', 'Minesweeper', 'Chameleon', 'Tolokers']

    args.log_info = False
    random_state = 10  # NOTE(review): only printed; not passed to any loader or split

    parent = os.path.dirname(filename)
    if parent:
        os.makedirs(parent, exist_ok=True)

    for DATASET_NAME in datasets:
        print(DATASET_NAME, "-", random_state, end=' ')
        with open(filename, 'a+') as result_file:
            result_file.write(f'{DATASET_NAME} ')

        accs = []
        itrs = []
        for i in range(num_run):
            data, dataset = get_data(DATASET_NAME, DIR=None, log=False, h_score=False, split_no=i)

            # Make the graph undirected: append reversed edges, then coalesce duplicates.
            (row, col) = data.edge_index
            data.edge_index = torch.stack((torch.cat((row, col), dim=0), torch.cat((col, row), dim=0)), dim=0)
            data.edge_index = torch_geometric.utils.coalesce(data.edge_index)

            if args.log_info:
                print("Node Homophily:", homophily(data.edge_index, data.y, method='node'))
                print("Edge Homophily:", homophily(data.edge_index, data.y, method='edge'))
                print("Edge_insensitive Homophily:", homophily(data.edge_index, data.y, method='edge_insensitive'))
                print("Degree: ", data.num_edges / data.num_nodes)

            if len(data.y.shape) > 1:
                # 2-D labels (presumably one-hot) are reduced to class indices.
                data.y = data.y.argmax(dim=1)
                num_classes = torch.max(data.y).item() + 1
            else:
                num_classes = dataset.num_classes
            if num_classes != torch.max(data.y) + 1:
                num_classes = torch.max(data.y).item() + 1

            # Large graphs get a shorter epoch budget.
            max_epochs = 150 if data.num_nodes < 100000 else 20

            if DATASET_NAME in ['Squirrel', 'Chameleon', 'cornell5', 'penn94', 'johnshopkins55', 'amherst41']:
                data.x = torch.cat((data.x, adj_feature(data)), dim=1)
                if args.log_info == True:
                    print(data.x.shape)

            accuracy, itr, _ = GSAINTperformance(DATASET_NAME, data, dataset, num_classes, epochs=max_epochs,
                                                 train_neighbors=[8, 4], test_neighbors=[8, 4])
            accs.append(accuracy)
            itrs.append(itr)

        print(accs, itrs)
        print(DATASET_NAME, "-", random_state, end=' ')
        # Same scaling (x100) for the printed and the recorded standard deviation.
        summary = (f'acc {np.mean(accs)*100:0.4f} \\pm {np.std(accs)*100:0.4f} '
                   f'itr {int(np.mean(itrs)):d} sd {int(np.std(itrs)):d}')
        print(summary)
        with open(filename, 'a+') as result_file:
            result_file.write(summary + '\n')

    return

# st_time = time.time()
# ablation(num_run=5)
# en_time = time.time()

# print("Runtime: ", en_time-st_time)

Wisconsin - 10 [0.5490196078431373, 0.7058823529411765, 0.5686274509803921, 0.5294117647058824, 0.5882352941176471] [150, 150, 150, 150, 150]
Wisconsin - 10 acc 58.8235 \pm 6.2005 itr 150 sd 0
reed98 - 10 [0.6269430051813472, 0.616580310880829, 0.6217616580310881, 0.616580310880829, 0.6062176165803109] [150, 150, 150, 150, 150]
reed98 - 10 acc 61.7617 \pm 0.6874 itr 150 sd 0
amherst41 - 10 [0.6420581655480985, 0.6554809843400448, 0.6442953020134228, 0.6353467561521253, 0.6532438478747203] [150, 150, 150, 150, 150]
amherst41 - 10 acc 64.6085 \pm 0.7406 itr 150 sd 0
Roman-empire - 10 [0.42728556300741266, 0.39198729262266147, 0.41651959054006354, 0.4177550300035298, 0.4151076597246735] [150, 80, 150, 150, 150]
Roman-empire - 10 acc 41.3731 \pm 1.1682 itr 136 sd 28
cornell5 - 10 [0.587352625937835, 0.5897642015005359, 0.5865487674169346, 0.587352625937835, 0.5908360128617364] [150, 72, 57, 50, 119]
cornell5 - 10 acc 58.8371 \pm 0.1638 itr 89 sd 38
Squirrel - 10 [0.6872586872586872, 0.6882

In [18]:
# Appears to be the Squirrel accuracies whose printed output above was truncated.
a = [0.6872586872586872, 0.6882239382239382, 0.6766409266409267, 0.6911196911196911, 0.666023166023166]

# "\\pm" prints a literal "\pm"; an unescaped "\pm" is an invalid escape sequence.
print(np.mean(a), "\\pm", np.std(a))

0.6818532818532818 \pm 0.009310544827635114
