In [1]:
import os
import sys
# Unless the working directory already ends with "Submodular", append the
# relative path '../Submodular' to the import search path so that modules
# stored there can be imported by the cells below.
launched_from_project_dir = os.getcwd().endswith("Submodular")
if not launched_from_project_dir:
    sys.path.append('../Submodular')

In [2]:
import DeviceDir

# Machine-specific settings come from the project-local DeviceDir module.
# DIR is used further down as a string prefix (DIR+'GSAINT/'+DATASET_NAME).
# NOTE(review): that concatenation presumably relies on DIR ending with a
# path separator -- confirm in DeviceDir.get_directory().
DIR, RESULTS_DIR = DeviceDir.get_directory()
# `device` is the target passed to `.to(device)` for the model and the
# mini-batches in the cells below.
device, NUM_PROCESSORS = DeviceDir.get_device()

In [3]:
from ipynb.fs.full.Dataset import get_data
from ipynb.fs.full.Dataset import datasets as available_datasets
from ipynb.fs.full.Utils import save_plot

In [4]:
import argparse
from argparse import ArgumentParser

#set default arguments here
def parse_bool_flag(value):
    """Convert a command-line string such as 'true' / 'False' / '0' to a bool.

    ``argparse`` with ``type=bool`` treats every non-empty string as ``True``
    (``bool('False') is True``), so boolean options need an explicit parser.

    Args:
        value: The raw option value (a string from the command line, or an
            actual bool).

    Returns:
        The corresponding ``bool``.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised spelling
            of true or false.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    if text in ('false', 'f', 'no', 'n', '0', 'off'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


def get_configuration():
    """Parse the notebook's command-line options.

    Defaults are set here. Options and defaults are unchanged; the only
    difference is that ``--log_info`` and ``--pbar`` now really accept
    ``False`` (previously any supplied value was parsed as ``True``).

    Returns:
        tuple: ``(args, dict_args)`` where ``args`` is the parsed
        ``argparse.Namespace`` and ``dict_args`` is ``vars(args)``, i.e. a
        dict view of the same namespace.
    """
    parser = ArgumentParser()
    parser.add_argument('--epochs', type=int, default=1)
    parser.add_argument('--log_info', type=parse_bool_flag, default=True)
    parser.add_argument('--pbar', type=parse_bool_flag, default=False)
    parser.add_argument('--batch_size', type=int, default=2048)
    parser.add_argument('--learning_rate', type=float, default=0.01)
    parser.add_argument('--num_gpus', type=int, default=-1)
    parser.add_argument('--parallel_mode', type=str, default="dp", choices=['dp', 'ddp', 'ddp2'])
    parser.add_argument('--dataset', type=str, default="Cora", choices=available_datasets)
    # Opt-in switch: normalization is off unless --use_normalization is given.
    parser.add_argument('--use_normalization', action='store_true')
    # Dummy option that swallows the "-f <file>" argument a Jupyter kernel is
    # started with, so parse_args() does not reject it.
    parser.add_argument('-f')

    args = parser.parse_args()
    dict_args = vars(args)

    return args, dict_args

# Parse the options once; `args` is then read as a module-level global by the
# functions in this notebook (args.log_info, args.use_normalization).
args, dict_args = get_configuration()

## Packages

In [5]:
import torch.nn as nn
import numpy as np
from torch.nn import init
from random import shuffle, randint
import torch.nn.functional as F
from itertools import combinations, combinations_with_replacement
from sklearn.metrics import f1_score, accuracy_score
from sklearn.decomposition import TruncatedSVD
import matplotlib.pyplot as plt
import sys
from torch_geometric.data import Data
import logging
import time

import argparse
import os.path as osp
import math

In [6]:
import random
import numpy as np
import torch

# One fixed seed shared by Python's `random`, NumPy and PyTorch.
seed = 123

# Seed the three generators in turn (random, then NumPy, then PyTorch).
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed):
    _seed_rng(seed)
# A bare None as the final expression keeps the cell from displaying a value.
None

### GSAINT model

In [7]:
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel
from tqdm import tqdm

from torch_geometric.loader import NeighborSampler, NeighborLoader
from torch_geometric.loader import GraphSAINTRandomWalkSampler, GraphSAINTNodeSampler, GraphSAINTEdgeSampler, GraphSAINTSampler
from torch_geometric.nn import GraphConv
from torch_geometric.utils import degree

In [8]:
class Net(torch.nn.Module):
    """Three stacked ``GraphConv`` layers followed by a linear classifier.

    The (ReLU + dropout) output of every convolution is kept; the three
    outputs are concatenated and fed to ``lin``, and the result is returned
    as log-probabilities.

    Args:
        num_features: Input feature size of the first convolution.
        num_classes: Output size of the final linear layer.
        hidden_channels: Output size of each of the three convolutions.
    """

    def __init__(self, num_features, num_classes, hidden_channels):
        super().__init__()
        # Layers are created in the order conv1, conv2, conv3, lin.
        self.conv1 = GraphConv(num_features, hidden_channels)
        self.conv2 = GraphConv(hidden_channels, hidden_channels)
        self.conv3 = GraphConv(hidden_channels, hidden_channels)
        # The classifier sees the concatenation of all three layer outputs.
        self.lin = torch.nn.Linear(3 * hidden_channels, num_classes)

    def set_aggr(self, aggr):
        """Assign ``aggr`` to the ``aggr`` attribute of all three convolutions.

        NOTE(review): some PyTorch Geometric releases aggregate through a
        separate ``aggr_module`` object; confirm that assigning ``.aggr``
        still changes the aggregation in the installed version.
        """
        for conv in (self.conv1, self.conv2, self.conv3):
            conv.aggr = aggr

    def forward(self, x0, edge_index, edge_weight=None):
        """Return per-node log-probabilities.

        Args:
            x0: Input node features.
            edge_index: Graph connectivity passed to every convolution.
            edge_weight: Optional edge weights passed to every convolution.
        """
        layer_outputs = []
        hidden = x0
        for conv in (self.conv1, self.conv2, self.conv3):
            hidden = F.relu(conv(hidden, edge_index, edge_weight))
            hidden = F.dropout(hidden, p=0.2, training=self.training)
            layer_outputs.append(hidden)

        logits = self.lin(torch.cat(layer_outputs, dim=-1))
        return logits.log_softmax(dim=-1)

    # graphsage
    @torch.no_grad()
    def inference(self, x_all, device, subgraph_loader):
        """Run ``forward`` on every batch of ``subgraph_loader`` and stack the outputs.

        Each loader item is unpacked as ``(batch_size, n_id, adj)``. The
        features ``x_all[n_id]`` are moved to ``device``, the full three-layer
        ``forward`` is applied with the batch's ``edge_index``, and the first
        ``size[1]`` output rows are kept.

        Args:
            x_all: Feature matrix indexed by the node ids the loader yields.
            device: Device on which each batch is evaluated.
            subgraph_loader: Iterable of ``(batch_size, n_id, adj)`` items.

        Returns:
            The kept rows of all batches, concatenated on the CPU.
        """
        show_progress = args.log_info
        progress = None
        if show_progress:
            progress = tqdm(total=x_all.size(0))
            progress.set_description('Evaluating')

        per_batch = []
        for num_targets, node_ids, adj in subgraph_loader:
            edge_index, _, size = adj.to(device)
            batch_out = self.forward(x_all[node_ids].to(device), edge_index)
            per_batch.append(batch_out[:size[1]].cpu())

            if show_progress:
                progress.update(num_targets)

        stacked = torch.cat(per_batch, dim=0)

        if show_progress:
            progress.close()

        return stacked

In [9]:
def test(model, loader, mask, name='Train'):
    """Compute the accuracy of ``model`` over the mini-batches of ``loader``.

    Only the first ``batch.batch_size`` rows of every batch (the rows the
    loader reports as the batch proper) are scored; the remaining rows are
    used as input only.

    Args:
        model: Network exposing ``eval()``, ``set_aggr()`` and a
            ``(x, edge_index)`` call.
        loader: Iterable of batches with ``x``, ``edge_index``, ``y`` and
            ``batch_size`` attributes.
        mask: Tensor whose sum is used as the progress-bar total; only read
            when ``args.log_info`` is set.
        name: Label shown in the progress-bar description.

    Returns:
        float: Correct predictions divided by scored examples, or ``0.0`` if
        the loader yields no examples (previously this case raised).
    """
    show_progress = args.log_info
    if show_progress:
        # Tensor-level reduction; the builtin sum() walks the mask one
        # element at a time in Python.
        pbar = tqdm(total=int(mask.sum()))
        pbar.set_description(f'Evaluating {name}')

    model.eval()
    # NOTE(review): with --use_normalization this selects 'add' although no
    # edge weights are passed to the model below; the upstream GraphSAINT
    # example evaluates with 'mean'. Behavior left unchanged -- please confirm.
    model.set_aggr('add' if args.use_normalization else 'mean')

    total_correct = 0
    total_examples = 0

    with torch.no_grad():
        for batch_data in loader:
            num_scored = batch_data.batch_size
            out = model(batch_data.x.to(device), batch_data.edge_index.to(device))
            pred = out[:num_scored, :].argmax(dim=-1)
            target = batch_data.y[:num_scored].to(device)

            # Accumulated as a tensor so the device is synchronised only once,
            # at the end.
            total_correct += pred.eq(target).sum()
            total_examples += num_scored

            if show_progress:
                pbar.update(num_scored)

    if show_progress:
        pbar.close()

    if total_examples == 0:
        # Empty loader: nothing was scored, so avoid int.item() / division by 0.
        return 0.0

    return int(total_correct) / total_examples

In [10]:
def train(DATASET_NAME, model, data, dataset, epochs=10,train_neighbors=[8,4],test_neighbors=[8,4]):
    """Train ``model`` on GraphSAINT random-walk mini-batches and record metrics.

    The model is optimised with Adam (lr=0.001) for ``epochs`` epochs. Every
    10th epoch, and additionally after the final epoch, accuracy is measured:
    full-batch for graphs with fewer than 10000 nodes, otherwise through
    ``test()`` on neighbor-sampled loaders. Without the final-epoch evaluation
    a run shorter than 10 epochs was never evaluated.

    Side effects:
        * sets ``data.edge_weight`` (inverse in-degree of each edge's target);
        * creates the sampler cache directory ``DIR + 'GSAINT/' + DATASET_NAME``;
        * appends losses, accuracies and epoch timings to ``Runtime/GSAINT.txt``;
        * when ``args.log_info`` is set, prints progress and calls ``save_plot``.

    Args:
        DATASET_NAME: Name used for the sampler cache directory.
        model: Network to train; must provide ``set_aggr``.
        data: Graph object with ``x``, ``y``, ``edge_index`` and the
            train/val/test masks.
        dataset: Not used by this function; kept for call compatibility.
        epochs: Number of training epochs.
        train_neighbors: ``num_neighbors`` of the training-accuracy loader
            (the default list is never mutated here).
        test_neighbors: ``num_neighbors`` of the validation and test loaders.

    Returns:
        tuple: ``(best_acc, num_iteration)`` -- the highest test accuracy seen
        at an evaluation point, and the number of epochs run (always
        ``epochs``, since early stopping is disabled).
    """
    log_info = args.log_info
    eval_every = 10

    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    _, col = data.edge_index
    data.edge_weight = 1. / degree(col, data.num_nodes)[col]  # Norm by in-degree.

    sampler_dir = DIR+'GSAINT/'+DATASET_NAME
    os.makedirs(sampler_dir, exist_ok=True)

    batch_size=4096
    num_steps=math.ceil(data.num_nodes/batch_size)

    # Worker processes are only used for graphs with more than 100000 nodes.
    worker = 8 if data.num_nodes>100000 else 0

    # Constructing the sampler is timed and reported as "norm time".
    # (GraphSAINTNodeSampler / GraphSAINTEdgeSampler with
    # save_dir=dataset.processed_dir were earlier alternatives.)
    norm_start_time = time.time()
    loader = GraphSAINTRandomWalkSampler(data, batch_size=batch_size, walk_length=2,
                                         num_steps=num_steps, sample_coverage=100,
                                         save_dir=sampler_dir,
                                         num_workers=worker)
    initial_time = time.time()-norm_start_time

    if log_info:
        print("Norm time: ",initial_time)
        print("Train neighbors: ", train_neighbors)
        print("Test neighbors: ", test_neighbors)

    # Neighbor-sampled loaders, used only for evaluation on large graphs.
    sample_batch_size=2048
    train_loader = NeighborLoader(data,
                            input_nodes=data.train_mask,
                            num_neighbors=train_neighbors,
                            batch_size=sample_batch_size, shuffle=True, num_workers=worker)
    val_loader = NeighborLoader(data,input_nodes=data.val_mask,num_neighbors=test_neighbors,
                                batch_size=sample_batch_size,shuffle=False, num_workers=worker)
    test_loader = NeighborLoader(data, input_nodes=data.test_mask,num_neighbors=test_neighbors,
                                 batch_size=sample_batch_size,shuffle=False, num_workers=worker)

    best_acc=0
    num_iteration = epochs
    train_losses = []
    val_accuracies=[]
    train_accuracies=[]
    test_accuracies=[]
    training_times = []

    for epoch in range(1,epochs+1):
        if log_info:
            pbar = tqdm(total=batch_size*num_steps)
            pbar.set_description(f'Epoch {epoch:02d}')

        model.train()
        model.set_aggr('add' if args.use_normalization else 'mean')

        total_loss = total_examples = 0

        epoch_start = time.time()

        for batch_data in loader:
            batch_data = batch_data.to(device)
            optimizer.zero_grad()

            if args.use_normalization:
                # Per-edge and per-node weights supplied by the sampler
                # reweight the aggregation and the loss.
                edge_weight = batch_data.edge_norm * batch_data.edge_weight
                out = model(batch_data.x, batch_data.edge_index, edge_weight)
                loss = F.nll_loss(out, batch_data.y, reduction='none')
                loss = (loss * batch_data.node_norm)[batch_data.train_mask].sum()
            else:
                out = model(batch_data.x, batch_data.edge_index)
                loss = F.nll_loss(out[batch_data.train_mask], batch_data.y[batch_data.train_mask])

            loss.backward()
            optimizer.step()
            total_loss += loss.item() * batch_data.num_nodes
            total_examples += batch_data.num_nodes

            if log_info:
                pbar.update(batch_size)

        if log_info:
            pbar.close()

        training_times.append(time.time()-epoch_start)

        # Evaluate on every 10th epoch and on the last one, so that runs
        # shorter than 10 epochs still yield accuracies.
        if epoch % eval_every != 0 and epoch != epochs:
            continue

        epoch_loss = total_loss / total_examples
        train_losses.append(epoch_loss)

        if log_info:
            print("Training Loss: ",epoch_loss)

        if data.num_nodes<10000:
            # Small graph: one full-batch forward pass scores all three splits.
            model.eval()
            # NOTE(review): 'add' without edge weights when
            # --use_normalization is set; the upstream example uses 'mean'
            # for evaluation. Left unchanged -- please confirm.
            model.set_aggr('add' if args.use_normalization else 'mean')

            with torch.no_grad():
                out = model(data.x.to(device), data.edge_index.to(device))
                correct = out.argmax(dim=-1).eq(data.y.to(device))

            accs = [correct[mask].sum().item() / mask.sum().item()
                    for _, mask in data('train_mask', 'val_mask', 'test_mask')]

            if log_info:
                print(accs)

            train_acc = accs[0]
            val_acc = accs[1]
            test_acc = accs[2]
        else:
            # Train/validation accuracy is only computed when logging is on.
            if log_info:
                train_acc=test(model, train_loader,data.train_mask,'Train')
                val_acc = test(model, val_loader,data.val_mask,'Validation')
            else:
                train_acc=0
                val_acc =0

            test_acc = test(model, test_loader,data.test_mask,'Test')

            if log_info:
                print(f'Epoch: {epoch:03d}, Train: {train_acc:.4f}, Val: {val_acc:.4f}, Test: {test_acc:.4f}')

        # NOTE(review): the "best" accuracy is chosen on the test split, not
        # on the validation split.
        if test_acc>best_acc:
            best_acc=test_acc

        # Spread of the five most recent recorded losses; only reported
        # because the early-stopping rule that used it is disabled.
        std_dev = np.std(train_losses[-5:])
        if log_info:
            print('std_dev: ', std_dev)

        train_accuracies.append(train_acc)
        val_accuracies.append(val_acc)
        test_accuracies.append(test_acc)

    if log_info:
        save_plot([train_losses, train_accuracies, val_accuracies, test_accuracies], labels=['Loss','Train','Validation','Test'], name='Results/AGSNSVal', yname='Accuracy', xname='Epoch')

        # default= guards against empty lists (epochs == 0), where max() raises.
        print ("Best Validation Accuracy, ",max(val_accuracies, default=float('nan')))
        print ("Best Test Accuracy, ",max(test_accuracies, default=float('nan')))

    # `with` closes the runtime log even if one of the writes fails.
    with open("Runtime/GSAINT.txt",'a+') as acc_file:
        acc_file.write(f'\n norm_time {initial_time:0.4f}\n')
        acc_file.write(str(train_losses))
        acc_file.write(str(train_accuracies))
        acc_file.write(str(val_accuracies))
        acc_file.write(str(test_accuracies))
        acc_file.write(str(training_times))
        acc_file.write(str(np.mean(training_times)))
        acc_file.write(f'\nworker {worker:1d} avg epoch runtime {np.mean(training_times):0.8f}')

    return best_acc, num_iteration

In [11]:
def GSAINTperformance(DATASET_NAME, data, dataset, num_classes, epochs=20, train_neighbors=[8,4],test_neighbors=[8,4]):
    """Build a fresh ``Net`` (256 hidden channels), train it and report the outcome.

    Args:
        DATASET_NAME: Dataset name, forwarded to ``train``.
        data: Graph object, forwarded to ``train``.
        dataset: Supplies ``num_features`` for the model; also forwarded to ``train``.
        num_classes: Number of output classes of the model.
        epochs: Number of training epochs.
        train_neighbors: Forwarded to ``train``.
        test_neighbors: Forwarded to ``train``.

    Returns:
        tuple: ``(best_acc, num_iteration, model)`` -- the two values returned
        by ``train`` plus the trained model.
    """
    network = Net(dataset.num_features, num_classes, hidden_channels=256)
    model = network.to(device)

    if args.log_info:
        print(model)

    outcome = train(DATASET_NAME, model, data, dataset, epochs, train_neighbors, test_neighbors)
    best_acc, num_iteration = outcome

    return best_acc, num_iteration, model

In [12]:
from torch_geometric.utils import add_self_loops

In [13]:
# args.log_info = True
# DATASET_NAME = 'Cora'
# data, dataset = get_data(DATASET_NAME, DIR=None, log=False, h_score=True, split_no=0); print("")

# if DATASET_NAME in ['Cornell', 'cornell5']:
#     data.edge_index, _ = add_self_loops(data.edge_index)            



# best_acc, num_iteration, _ = GSAINTperformance(DATASET_NAME, data, dataset, dataset.num_classes, epochs=1,
#                              train_neighbors=[8,4],test_neighbors=[8,4])    
# print(best_acc, num_iteration)

# Batch Experiments

In [14]:
def batch_experiments(num_run=1, datasets=None):
    """Train and evaluate GraphSAINT ``num_run`` times on each dataset.

    For every dataset one summary line (mean / standard deviation of the best
    accuracy and of the iteration count) is printed and appended to
    ``Results/GSAINT.txt``; the dataset name is also appended to
    ``Runtime/GSAINT.txt`` ahead of the per-run records that ``train`` writes
    there.

    Side effect: sets the global ``args.log_info`` to ``False``.

    Args:
        num_run: Number of runs per dataset. Run ``i`` loads the data with
            ``split_no=i`` and ``random_state=i``. Must be at least 1.
        datasets: Optional iterable of dataset names. Defaults to the three
            synthetic Reddit variants this notebook was last run on.

    Raises:
        ValueError: If ``num_run`` is smaller than 1 (there would be nothing
            to average).
    """
    if num_run < 1:
        raise ValueError(f'num_run must be at least 1, got {num_run}')

    # Other dataset names that have been used with this routine:
    #   Cornell, Texas, Wisconsin, reed98, amherst41, penn94, Roman-empire,
    #   cornell5, Squirrel, johnshopkins55, AmazonProducts, Actor, Minesweeper,
    #   Questions, Chameleon, Tolokers, Flickr, Yelp, Amazon-ratings, genius,
    #   cora, CiteSeer, dblp, Computers, PubMed, pubmed, Reddit, cora_ml, Cora,
    #   Reddit2, CS, Photo, Physics, citeseer, pokec, arxiv-year, snap-patents,
    #   twitch-gamer, wiki, karate
    if datasets is None:
        datasets = ["Reddit0.525","Reddit0.425","Reddit0.325"]

    # Silence per-epoch output for the batch run.
    args.log_info = False

    runtime_filename = "Runtime/GSAINT.txt"

    for DATASET_NAME in datasets:
        print(DATASET_NAME, end=' ')

        # `with` guarantees the results file is closed (and the dataset name
        # flushed) even if a run fails part-way through.
        with open("Results/GSAINT.txt",'a+') as result_file:
            result_file.write(f'{DATASET_NAME} ')

            with open(runtime_filename,'a+') as acc_file:
                acc_file.write(f'{DATASET_NAME}\n')

            accs = []
            itrs = []

            for i in range(num_run):
                data, dataset = get_data(DATASET_NAME, DIR=None, log=False, h_score=False, split_no=i, random_state = i)

                # Labels with more than one dimension are collapsed to class
                # indices via argmax.
                if len(data.y.shape) > 1:
                    data.y = data.y.argmax(dim=1)

                # The class count is always taken from the largest label
                # present (the former code overrode dataset.num_classes
                # whenever the two disagreed).
                num_classes = torch.max(data.y).item()+1

                # Fewer epochs for graphs with 100000 nodes or more.
                max_epochs = 150 if data.num_nodes<100000 else 50

                if DATASET_NAME in ['Cornell', 'cornell5']:
                    data.edge_index, _ = add_self_loops(data.edge_index)

                accuracy, itr, _ = GSAINTperformance(DATASET_NAME, data, dataset, num_classes, epochs=max_epochs,train_neighbors=[8,4],test_neighbors=[8,4])
                accs.append(accuracy)
                itrs.append(itr)

            summary = f'acc {np.mean(accs):0.4f} sd {np.std(accs):0.4f} itr {int(np.mean(itrs)):d} sd {int(np.std(itrs)):d}'
            print(summary)
            result_file.write(summary + '\n')
                
# Five runs per dataset (split_no / random_state 0..4); each dataset's summary
# line is printed and appended to Results/GSAINT.txt.
batch_experiments(num_run = 5)

Reddit0.525 loaded from:  /scratch/gilbreth/das90/Dataset/RedditSynthetic/Reddit0.525


Compute GraphSAINT normalization: : 23720142it [01:39, 238414.19it/s]                            


loaded from:  /scratch/gilbreth/das90/Dataset/RedditSynthetic/Reddit0.525
loaded from:  /scratch/gilbreth/das90/Dataset/RedditSynthetic/Reddit0.525
loaded from:  /scratch/gilbreth/das90/Dataset/RedditSynthetic/Reddit0.525
loaded from:  /scratch/gilbreth/das90/Dataset/RedditSynthetic/Reddit0.525
acc 0.9046 sd 0.0005 itr 50 sd 0
Reddit0.425 loaded from:  /scratch/gilbreth/das90/Dataset/RedditSynthetic/Reddit0.425


Compute GraphSAINT normalization: : 23584405it [01:49, 216245.57it/s]                            


loaded from:  /scratch/gilbreth/das90/Dataset/RedditSynthetic/Reddit0.425
loaded from:  /scratch/gilbreth/das90/Dataset/RedditSynthetic/Reddit0.425
loaded from:  /scratch/gilbreth/das90/Dataset/RedditSynthetic/Reddit0.425
loaded from:  /scratch/gilbreth/das90/Dataset/RedditSynthetic/Reddit0.425
acc 0.8913 sd 0.0006 itr 50 sd 0
Reddit0.325 loaded from:  /scratch/gilbreth/das90/Dataset/RedditSynthetic/Reddit0.325


Compute GraphSAINT normalization: : 23424493it [01:55, 202926.20it/s]                            


loaded from:  /scratch/gilbreth/das90/Dataset/RedditSynthetic/Reddit0.325
loaded from:  /scratch/gilbreth/das90/Dataset/RedditSynthetic/Reddit0.325
loaded from:  /scratch/gilbreth/das90/Dataset/RedditSynthetic/Reddit0.325
loaded from:  /scratch/gilbreth/das90/Dataset/RedditSynthetic/Reddit0.325
acc 0.8769 sd 0.0015 itr 50 sd 0
