In [13]:
import os
import sys
# Make the shared Submodular helpers (presumably where DeviceDir lives) importable
# when the notebook is not run from inside the Submodular directory itself.
# The membership check keeps this cell idempotent: re-running it does not append
# duplicate sys.path entries.
_SUBMODULAR_PATH = '../../Submodular'
if not os.getcwd().endswith("Submodular") and _SUBMODULAR_PATH not in sys.path:
    sys.path.append(_SUBMODULAR_PATH)

In [14]:
import DeviceDir

# Resolve project paths and compute resources from the shared DeviceDir helper.
# DIR is the dataset root (its value is displayed in the next cell); RESULTS_DIR is
# presumably the results root — TODO confirm against DeviceDir.
DIR, RESULTS_DIR = DeviceDir.get_directory()
# `device` is the target passed to tensors/models in later cells
# (e.g. torch.zeros(..., device=device)); NUM_PROCESSORS is presumably a CPU/worker
# count — it is not referenced in the cells visible here.
device, NUM_PROCESSORS = DeviceDir.get_device()

In [15]:
DIR  # display the dataset root directory returned by DeviceDir.get_directory()

'/scratch/gilbreth/das90/Dataset/'

In [16]:
from ipynb.fs.full.Dataset import get_data
from ipynb.fs.full.Dataset import datasets as available_datasets
from ipynb.fs.full.Utils import save_plot

In [17]:
import argparse
import sys
import os
from tqdm import tqdm
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.utils import to_undirected, sort_edge_index
from torch_geometric.data import NeighborSampler, ClusterData, ClusterLoader, Data, GraphSAINTNodeSampler, GraphSAINTEdgeSampler, GraphSAINTRandomWalkSampler, RandomNodeSampler
from torch_scatter import scatter

from logger import Logger, SimpleLogger
from dataset import load_nc_dataset, NCDataset
from data_utils import normalize, gen_normalized_adjs, evaluate, eval_acc, eval_rocauc, to_sparse_tensor
from parse import parse_method, parser_add_main_args
from batch_utils import nc_dataset_to_torch_geo, torch_geo_to_nc_dataset, AdjRowLoader, make_loader

In [18]:
import argparse
from argparse import ArgumentParser

#set default arguments here
def _str2bool(value):
    """Parse a command-line boolean flag value.

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    ``"False"``) as True, so boolean options need explicit parsing.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in ('true', 't', 'yes', 'y', '1'):
        return True
    if normalized in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


def get_configuration():
    """Build the training argument parser and parse the command line.

    Starts from the shared options registered by ``parser_add_main_args`` and adds
    the mini-batching, logging and optimisation options used by this notebook.

    Returns:
        (args, dict_args): the parsed ``argparse.Namespace`` and ``vars(args)``.
    """
    parser = argparse.ArgumentParser(description='General Training Pipeline')
    parser_add_main_args(parser)

    # Mini-batch loading options.
    parser.add_argument('--train_batch', type=str, default='cluster', help='type of mini batch loading scheme for training GNN')
    parser.add_argument('--no_mini_batch_test', action='store_true', help='whether to test on mini batches as well')
    parser.add_argument('--batch_size', type=int, default=10000)
    parser.add_argument('--num_parts', type=int, default=100, help='number of partitions for partition batching')
    parser.add_argument('--cluster_batch_size', type=int, default=1, help='number of clusters to use per cluster-gcn step')
    parser.add_argument('--saint_num_steps', type=int, default=5, help='number of steps for graphsaint')
    parser.add_argument('--test_num_parts', type=int, default=10, help='number of partitions for testing')

    # Logging / runtime options.
    parser.add_argument('--log_info', type=_str2bool, default=True)
    parser.add_argument('--pbar', type=_str2bool, default=False)
    parser.add_argument('--learning_rate', type=float, default=0.01)
    parser.add_argument('--num_gpus', type=int, default=-1)
    parser.add_argument('--parallel_mode', type=str, default="dp", choices=['dp', 'ddp', 'ddp2'])
    parser.add_argument('--use_normalization', action='store_true')
    # Dummy option so the extra "-f ..." argument present when running inside a
    # Jupyter kernel does not make parse_args() fail.
    parser.add_argument('-f')

    args = parser.parse_args()
    return args, vars(args)


args, dict_args = get_configuration()

In [19]:
import os.path as osp
import torch
import torch.nn as nn
import torch.nn.functional as F
# from torch_geometric.datasets import LINKXDataset
# from torch_geometric.nn import LINKX
import numpy as np
from tqdm import tqdm
from torch_geometric.loader import NeighborSampler, NeighborLoader
from torch_sparse import SparseTensor, matmul
from torch_geometric.nn import GCNConv, SGConv, GATConv, JumpingKnowledge, APPNP, GCN2Conv, MessagePassing
from torch_geometric.nn.conv.gcn_conv import gcn_norm
import scipy.sparse
import time

# LINKX model

### Available models
LINK, GCN, MLP, SGC, GAT, SGCMem, MultiLP, MixHop, GCNJK, GATJK, H2GCN, APPNP_Net, LINK_Concat, LINKX, GPRGNN, GCNII

### Available samplers

- NeighborSampler
- ClusterData, ClusterLoader
- GraphSAINTNodeSampler, GraphSAINTEdgeSampler, GraphSAINTRandomWalkSampler
- RandomNodeSampler

### AGS-GNN (our work)

- KNNsampler
- SubmodularSampler
- LinkSampler
- JointSampler
- DisjointEdgeSampler

# Train

In [20]:
def test(model, loader, mask, name='Train'):
    """Compute the accuracy of ``model`` on the masked nodes produced by ``loader``.

    Each mini-batch is converted with ``torch_geo_to_nc_dataset`` and run through
    ``model`` in eval mode without gradients. Predictions are the argmax over the
    last output dimension, compared with ``batch_dataset.label`` on the nodes where
    ``tg_batch.mask`` is set.

    Args:
        model: module called as ``model(batch_dataset)``.
        loader: iterable of PyG batches that carry a per-node ``mask``.
        mask: full-graph node mask; only used to size the progress bar.
        name: label shown on the progress bar.

    Returns:
        Fraction of masked nodes predicted correctly, or 0.0 if no node was masked.

    NOTE: leaves ``model`` in eval mode; a caller that keeps training must call
    ``model.train()`` again.
    """
    total_correct = 0
    total_examples = 0

    model.eval()
    pbar = None
    if args.log_info:
        pbar = tqdm(total=int(mask.sum().item()))
        pbar.set_description(f'Evaluating {name}')

    with torch.no_grad():
        for tg_batch in loader:
            batch_dataset = torch_geo_to_nc_dataset(tg_batch, device=device)
            out = model(batch_dataset)

            batch_test_idx = tg_batch.mask.to(torch.bool)
            # Count exactly the nodes whose predictions are compared below.
            batch_count = int(batch_test_idx.sum().item())

            pred = out[batch_test_idx].argmax(dim=-1)
            correct = pred.eq(batch_dataset.label[batch_test_idx])

            total_correct += correct.sum().cpu().item()
            total_examples += batch_count

            if pbar is not None:
                pbar.update(batch_count)

    if pbar is not None:
        pbar.close()

    if total_examples == 0:
        return 0.0
    return total_correct / total_examples

In [21]:
def eval_accuracy(y_true, y_pred):
    """Return the fraction of rows whose argmax in ``y_pred`` equals ``y_true``.

    Both tensors are detached and moved to CPU first. ``y_pred`` holds per-class
    scores; the predicted class is the argmax over its last dimension.
    """
    labels = y_true.detach().cpu()
    predicted = y_pred.argmax(dim=-1).detach().cpu()
    num_correct = predicted.eq(labels).sum().item()
    return num_correct / len(labels)

In [22]:
def linkxtest(model, data, nc_dataset, dataset, loader, num_classes):
    """Evaluate ``model`` on every node and return (train, val, test) accuracy.

    ``loader`` must cover every node of the graph. Each batch's outputs are
    scattered into a full ``(data.num_nodes, num_classes)`` buffer using
    ``tg_batch.node_ids``. Accuracies are then computed on that assembled buffer
    with ``data.train_mask`` / ``data.val_mask`` / ``data.test_mask``.

    ``nc_dataset`` and ``dataset`` are unused; they are kept for call compatibility.

    Returns:
        (train_acc, valid_acc, test_acc) as floats.
    """
    model.eval()
    full_out = torch.zeros(data.num_nodes, num_classes, device=device)

    with torch.no_grad():
        for tg_batch in loader:
            batch_dataset = torch_geo_to_nc_dataset(tg_batch, device=device)
            full_out[tg_batch.node_ids] = model(batch_dataset)

    # Score the assembled full-graph predictions, not just the last batch's output.
    # Move to CPU so the (possibly CPU-resident) node masks index it directly.
    full_out = full_out.cpu()
    train_acc = eval_accuracy(data.y[data.train_mask], full_out[data.train_mask])
    valid_acc = eval_accuracy(data.y[data.val_mask], full_out[data.val_mask])
    test_acc = eval_accuracy(data.y[data.test_mask], full_out[data.test_mask])

    return train_acc, valid_acc, test_acc

In [23]:
def _uses_bce_objective():
    """Return True when ``args`` select BCE-with-logits (``--rocauc`` or a listed dataset)."""
    return bool(args.rocauc) or args.dataset in ('yelp-chi', 'twitch-e', 'ogbn-proteins')


def _batch_loss(criterion, out, batch_dataset, batch_train_idx):
    """Loss of one mini-batch restricted to its training nodes (``batch_train_idx``).

    BCE branch: single-column integer labels are one-hot encoded to the width of
    ``out`` so that target and logit shapes match. NLL branch: ``out`` is converted
    to log-probabilities first, because ``nn.NLLLoss`` expects log-probabilities.
    """
    label = batch_dataset.label
    if _uses_bce_objective():
        if label.dim() == 1 or label.shape[1] == 1:
            # Take the width from the logits rather than label.max(): a batch that
            # lacks the highest class would otherwise get a too-narrow one-hot.
            true_label = F.one_hot(label, out.shape[-1]).squeeze(1)
        else:
            true_label = label
        return criterion(out[batch_train_idx], true_label[batch_train_idx].to(out.dtype))

    log_probs = F.log_softmax(out, dim=1)
    return criterion(log_probs[batch_train_idx], label[batch_train_idx])


def _append_runtime_log(path, worker, train_losses, train_accuracies,
                        val_accuracies, test_accuracies, training_times):
    """Append the run's loss/accuracy curves and epoch timings to ``path``."""
    with open(path, 'a+') as acc_file:
        for record in (train_losses, train_accuracies, val_accuracies,
                       test_accuracies, training_times):
            acc_file.write(str(record))
        acc_file.write(str(np.mean(training_times)))
        acc_file.write(f'\nworker {worker:1d} avg epoch runtime {np.mean(training_times):0.8f}')


def train(model, data, nc_dataset, dataset, epochs, num_classes,
          runtime_filename="../Runtime/LINKX.txt", eval_every=10):
    """Train ``model`` with mini-batches built by ``make_loader`` and evaluate periodically.

    Each epoch makes ``args.num_parts`` full passes over the training loader.
    Every ``eval_every`` epochs it records the mean batch loss and the test
    accuracy. Train and validation accuracy are recorded only when
    ``args.log_info`` is set; otherwise they are stored as 0.

    Args:
        model: module exposing ``reset_parameters()`` and called as ``model(batch_dataset)``.
        data: PyG-style graph with ``train_mask``/``val_mask``/``test_mask`` and ``num_nodes``.
        nc_dataset: NCDataset view of ``data`` used to build the loaders.
        dataset: companion object from ``get_data`` (unused here; kept for compatibility).
        epochs: number of training epochs.
        num_classes: number of classes (unused here; kept for compatibility).
        runtime_filename: file the per-run curves and epoch runtimes are appended to.
        eval_every: evaluate/record metrics every this many epochs.

    Returns:
        (best_acc, num_iteration): the highest test accuracy seen at evaluation
        epochs, and ``epochs``.
        NOTE(review): best_acc is selected on the test set, not on validation.
    """
    criterion = nn.BCEWithLogitsLoss() if _uses_bce_objective() else nn.NLLLoss()

    if args.log_info:
        print('making train loader')

    # view(-1) keeps a 1-D index tensor even when a split has a single node
    # (squeeze() would collapse it to a 0-d scalar).
    train_idx = torch.nonzero(data.train_mask).view(-1)
    val_idx = torch.nonzero(data.val_mask).view(-1)
    test_idx = torch.nonzero(data.test_mask).view(-1)

    # Loader worker processes only for large graphs.
    worker = 8 if data.num_nodes > 100000 else 0

    train_loader = make_loader(args, nc_dataset, train_idx, mini_batch=True, num_workers=worker)
    t_loader = make_loader(args, nc_dataset, train_idx, mini_batch=True, test=True, num_workers=worker)
    val_loader = make_loader(args, nc_dataset, val_idx, mini_batch=True, test=True, num_workers=worker)
    test_loader = make_loader(args, nc_dataset, test_idx, mini_batch=True, test=True, num_workers=worker)

    model.reset_parameters()
    optimizer_cls = torch.optim.Adam if args.adam else torch.optim.AdamW
    optimizer = optimizer_cls(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    if args.log_info:
        print(train_loader)
        print(val_loader)

    train_losses = []
    train_accuracies, val_accuracies, test_accuracies = [], [], []
    training_times = []
    best_acc = 0
    num_iteration = epochs

    for epoch in range(1, epochs + 1):
        # test() switches the model to eval mode; switch back so training-only
        # behaviour (e.g. dropout) is active again for this epoch.
        model.train()

        total_loss = 0.0
        num_batches = 0
        epoch_start = time.time()

        pbar = None
        if args.log_info:
            pbar = tqdm(total=args.num_parts * len(train_loader))
            pbar.set_description(f'Epoch {epoch:02d}')

        # NOTE(review): each epoch iterates train_loader args.num_parts times.
        for _ in range(args.num_parts):
            for tg_batch in train_loader:
                if int(tg_batch.mask.sum().item()) == 0:
                    continue

                batch_train_idx = tg_batch.mask.to(torch.bool)
                batch_dataset = torch_geo_to_nc_dataset(tg_batch, device=device)

                optimizer.zero_grad()
                out = model(batch_dataset)
                loss = _batch_loss(criterion, out, batch_dataset, batch_train_idx)
                loss.backward()
                optimizer.step()

                total_loss += loss.item()
                num_batches += 1

                if pbar is not None:
                    pbar.update(1)

        if pbar is not None:
            pbar.close()

        training_times.append(time.time() - epoch_start)

        if epoch % eval_every != 0:
            continue

        if num_batches == 0:
            print("Something went wrong, training examples cannot be zero")
            continue

        # Each loss.item() is already a per-batch mean, so average over batches.
        epoch_loss = total_loss / num_batches
        train_losses.append(epoch_loss)

        train_acc = 0
        val_acc = 0
        if args.log_info:
            train_acc = test(model, t_loader, data.train_mask, 'Train')
            val_acc = test(model, val_loader, data.val_mask, 'Validation')
        test_acc = test(model, test_loader, data.test_mask, 'Test')

        train_accuracies.append(train_acc)
        val_accuracies.append(val_acc)
        test_accuracies.append(test_acc)
        best_acc = max(best_acc, test_acc)

        if args.log_info:
            std_dev = np.std(train_losses[-5:])
            print(f'Epoch: {epoch:03d}, Train Loss: {epoch_loss:.4f}, Train: {train_acc:.4f}, Val: {val_acc:.4f}, Test: {test_acc:.4f}, Std dev: {std_dev:.4f}')

    if args.log_info:
        save_plot([train_losses, train_accuracies, val_accuracies, test_accuracies], labels=['Loss','Train','Validation','Test'], name='../Results/AGSGSValidation', yname='Accuracy', xname='Epoch')
        # Nothing is recorded when no evaluation epoch was reached (epochs < eval_every).
        if val_accuracies:
            print("Best Validation Accuracy, ", max(val_accuracies))
            print("Best Test Accuracy, ", max(test_accuracies))

    _append_runtime_log(runtime_filename, worker, train_losses, train_accuracies,
                        val_accuracies, test_accuracies, training_times)

    return best_acc, num_iteration

In [24]:
def AllperformanceSampler(data, dataset, num_classes, epochs=1, train_neighbors=[8,4], test_neighbors=[8,4]):
    """Build the model named by ``args.method`` for ``data`` and train it.

    Args:
        data: PyG-style graph with ``num_nodes`` and node features ``x``.
        dataset: companion object returned by ``get_data``; forwarded to ``train``.
        num_classes: number of output classes for the model.
        epochs: number of training epochs passed to ``train``.
        train_neighbors, test_neighbors: unused; kept for call compatibility.

    Returns:
        (best_acc, num_iteration, model): the first two as returned by ``train``,
        plus the trained model.
    """
    num_nodes = data.num_nodes
    feature_dim = data.x.shape[1]
    nc_dataset = torch_geo_to_nc_dataset(data)

    model = parse_method(args, nc_dataset, num_nodes, num_classes, feature_dim, device)
    if args.log_info:
        print(model)

    best_acc, num_iteration = train(model, data, nc_dataset, dataset, epochs, num_classes)
    return best_acc, num_iteration, model

In [25]:
# Model names that can be assigned to args.method (presumably the names accepted
# by parse_method — TODO confirm against parse.py).
methods = (
    'link gcn mlp cs sgc '
    'gprgnn appnp gat lp '
    'mixhop gcnjk gatjk h2gcn '
    'link_concat linkx gcn2'
).split()

#others = ['gsage','gsaint','acmgcn','clustergcn','gcn','gin','gat','linkx']

In [26]:
# Mini-batching schemes that can be assigned to args.train_batch.
sampling_methods = (
    'cluster graphsaint-node graphsaint-edge graphsaint-rw '
    'random full-batch row'
).split()

# Defaults set in get_configuration():
#   train_batch cluster, batch_size 10000, num_parts 100,
#   cluster_batch_size 1, saint_num_steps 5, test_num_parts 10

In [27]:
# #!/bin/bash

# dataset=$1
# sub_dataset=${2:-''}

# hidden_channels_lst=(16 32 128 256)
# num_layers_lst=(1 2 3)


# for num_layers in "${num_layers_lst[@]}"; do
#     for hidden_channels in "${hidden_channels_lst[@]}"; do
#         if [ "$dataset" = "snap-patents" ] || [ "$dataset" = "arxiv-year" ]; then
#             echo "Running $dataset "
#             python main_scalable.py --dataset $dataset --sub_dataset ${sub_dataset:-''} --method linkx  --num_layers $num_layers --hidden_channels $hidden_channels --display_step 25 --runs 5 --directed --train_batch row --num_parts 10
#         else
#             python main_scalable.py --dataset $dataset --sub_dataset ${sub_dataset:-''} --method linkx  --num_layers $num_layers --hidden_channels $hidden_channels --display_step 25 --runs 5 --train_batch row --num_parts 10
#         fi
#     done
# done


## LINKX

In [28]:
# args.log_info = True
# DATASET_NAME = 'Reddit0.525'
# args.dataset = DATASET_NAME
# gnn_name = 'linkx'

# args.method = gnn_name
# args.train_batch = 'row'
# args.num_parts = 10

# data, dataset = get_data(DATASET_NAME, DIR=None, log=False, h_score=True, split_no=0); print("")
# print(data)

# # best_acc, num_iteration, _ = AllperformanceSampler(data, dataset, dataset.num_classes, epochs=50)
# # print(best_acc, num_iteration)

## Other methods

In [33]:
# Quick check: load the small 'karate' graph and configure the full-batch LINK model.
gnn_name = 'link'
DATASET_NAME = 'karate'

args.log_info = True
args.dataset = DATASET_NAME
args.method = gnn_name
args.train_batch = 'full-batch'

data, dataset = get_data(DATASET_NAME, DIR=None, log=False, h_score=True, split_no=0)
print("")
print(data)

# Uncomment to train on this graph:
# best_acc, num_iteration, _ = AllperformanceSampler(data, dataset, dataset.num_classes, epochs=50)
# print(best_acc, num_iteration)

N  34  E  156  d  4.588235294117647 0.8020520210266113 0.7564102411270142 0.6170591711997986 -0.4756128787994385 
Data(x=[34, 34], edge_index=[2, 156], y=[34], train_mask=[34], val_mask=[34], test_mask=[34])


# Batch Experiments

In [30]:
def batch_experiments(num_run=1, datasets=None):
    """Run scalable LINKX ('row' mini-batching) over several datasets and log results.

    NOTE(review): a later cell redefines ``batch_experiments`` under the same name.
    Once that cell has run, this LINKX-only version is shadowed.

    For each dataset, this trains ``num_run`` models (split ``split_no=i``) for 50
    epochs via ``AllperformanceSampler``. It then appends the dataset name and the
    mean/std of the returned accuracies and iteration counts to
    ``../Results/LINKXscale.txt``. The dataset name is also appended to
    ``../Runtime/LINKX.txt`` before training.

    Side effects: mutates the global ``args`` (method, train_batch, num_parts,
    dataset, log_info).

    Args:
        num_run: number of runs (splits) per dataset; must be >= 1.
        datasets: dataset names to run; defaults to the three Reddit variants.

    Raises:
        ValueError: if ``num_run`` < 1 (the summary statistics would be undefined).
    """
    if num_run < 1:
        raise ValueError(f'num_run must be >= 1, got {num_run}')
    if datasets is None:
        # Previously used: Roman-empire, Texas, Squirrel, Chameleon, Cornell, Actor,
        # Wisconsin, Flickr, Amazon-ratings, reed98, amherst41, genius, AmazonProducts,
        # cornell5, penn94, johnshopkins55, Yelp, cora, Tolokers, Minesweeper, CiteSeer,
        # Computers, PubMed, Reddit, cora_ml, dblp, Reddit2, Cora, CS, Photo, Questions,
        # Physics, pokec, twitch-gamer, wiki, arxiv-year, snap-patents, karate.
        datasets = ["Reddit0.525", "Reddit0.425", "Reddit0.325"]

    gnn_name = 'linkx'
    args.method = gnn_name
    args.train_batch = 'row'
    args.num_parts = 10

    result_filename = "../Results/LINKXscale.txt"
    runtime_filename = "../Runtime/LINKX.txt"
    max_epochs = 50

    with open(result_filename, 'a+') as result_file:
        result_file.write(f'{gnn_name} {args.train_batch} {args.num_parts} ')

    args.log_info = False

    for DATASET_NAME in datasets:
        # Larger graphs are split into more row partitions.
        args.num_parts = 50 if DATASET_NAME in ('wiki', 'snap-patents', 'AmazonProducts') else 10
        args.dataset = DATASET_NAME
        print(DATASET_NAME, end=' ')

        with open(runtime_filename, 'a+') as runtime_file:
            runtime_file.write(f'{DATASET_NAME}\n')

        accs = []
        itrs = []
        for i in range(num_run):
            data, dataset = get_data(DATASET_NAME, DIR=None, log=False, h_score=False, split_no=i, random_state=5)

            if data.y.dim() > 1:
                # Multi-column labels -> class indices.
                data.y = data.y.argmax(dim=1)
            # Class count derived from the labels (max label + 1).
            num_classes = int(data.y.max().item()) + 1

            accuracy, itr, _ = AllperformanceSampler(data, dataset, num_classes, epochs=max_epochs)
            accs.append(accuracy)
            itrs.append(itr)

        summary = (f'acc {np.mean(accs):0.4f} sd {np.std(accs):0.4f} '
                   f'itr {int(np.mean(itrs)):d} sd {int(np.std(itrs)):d}')
        print(summary)
        with open(result_filename, 'a+') as result_file:
            result_file.write(f'{DATASET_NAME} {summary}\n')
                
# batch_experiments(num_run=5)

## Batch processing of all methods

In [31]:
def batch_experiments(num_run=1, datasets=None, methods=None):
    """Run each listed GNN method full-batch on each listed dataset and log results.

    NOTE(review): this redefines, and therefore shadows, the LINKX-only
    ``batch_experiments`` defined in the previous cell.

    For each method, a header "<method> full-batch <num_parts> " is appended to
    ``../Results/<method>scale.txt``. For each dataset, its name is appended to
    ``../Runtime/<method>.txt``. Then ``num_run`` models (split ``split_no=i``)
    are trained for 150 epochs via ``AllperformanceSampler``, and the dataset name
    plus the mean/std of the returned accuracies and iteration counts are appended
    to the results file.

    NOTE(review): the per-epoch runtimes are written by ``train`` to its own
    runtime file, which may differ from ``../Runtime/<method>.txt``.

    Side effects: mutates the global ``args`` (method, train_batch, num_parts,
    dataset, log_info).

    Args:
        num_run: number of runs (splits) per dataset; must be >= 1.
        datasets: dataset names to run; defaults to the small/medium benchmark list below.
        methods: method names to run; defaults to the list below.

    Raises:
        ValueError: if ``num_run`` < 1 (the summary statistics would be undefined).
    """
    if num_run < 1:
        raise ValueError(f'num_run must be >= 1, got {num_run}')
    if datasets is None:
        datasets = [
            "Cornell", "Texas", "Wisconsin", "reed98", "amherst41", "penn94",
            "Roman-empire", "cornell5", "Squirrel", "johnshopkins55", "Actor",
            "Minesweeper", "Questions", "Chameleon", "Tolokers", "Flickr",
            "Amazon-ratings",
        ]
    if methods is None:
        # 'lp' currently errors; 'h2gcn' is also excluded.
        methods = [
            'link', 'mlp', 'cs', 'sgc', 'gprgnn', 'appnp',
            'mixhop', 'gcnjk', 'gatjk',
            'link_concat', 'linkx', 'gcn2',
        ]

    max_epochs = 150

    for gnn_name in methods:
        print(gnn_name)
        print("-" * 100)

        args.method = gnn_name
        args.train_batch = 'full-batch'

        result_filename = "../Results/" + gnn_name + "scale.txt"
        runtime_filename = "../Runtime/" + gnn_name + ".txt"

        # NOTE(review): args.num_parts here is whatever value is currently set
        # (for later methods, the value left by the previous method's last dataset).
        with open(result_filename, 'a+') as result_file:
            result_file.write(f'{gnn_name} {args.train_batch} {args.num_parts} ')

        args.log_info = False

        for DATASET_NAME in datasets:
            # Larger graphs are split into more partitions.
            args.num_parts = 50 if DATASET_NAME in ('wiki', 'snap-patents', 'AmazonProducts') else 10
            args.dataset = DATASET_NAME
            print(DATASET_NAME, end=' ')

            with open(runtime_filename, 'a+') as runtime_file:
                runtime_file.write(f'{DATASET_NAME}\n')

            accs = []
            itrs = []
            for i in range(num_run):
                data, dataset = get_data(DATASET_NAME, DIR=None, log=False, h_score=False, split_no=i, random_state=5)

                if data.y.dim() > 1:
                    # Multi-column labels -> class indices.
                    data.y = data.y.argmax(dim=1)
                # Class count derived from the labels (max label + 1).
                num_classes = int(data.y.max().item()) + 1

                accuracy, itr, _ = AllperformanceSampler(data, dataset, num_classes, epochs=max_epochs)
                accs.append(accuracy)
                itrs.append(itr)

            summary = (f'acc {np.mean(accs):0.4f} sd {np.std(accs):0.4f} '
                       f'itr {int(np.mean(itrs)):d} sd {int(np.std(itrs)):d}')
            print(summary)
            with open(result_filename, 'a+') as result_file:
                result_file.write(f'{DATASET_NAME} {summary}\n')
                
# batch_experiments(num_run=5)