In [1]:
import os
import sys
# When the notebook is not started from a directory whose path ends in
# "Submodular", add the sibling '../Submodular' folder to the import path so
# that the `import DeviceDir` below can be resolved.
if not os.getcwd().endswith("Submodular"):
    sys.path.append('../Submodular')    

import DeviceDir

# Directory and device handles provided by the DeviceDir helper module.
# NOTE(review): the exact meaning of DIR / RESULTS_DIR / NUM_PROCESSORS is
# defined inside DeviceDir, which is not visible here — confirm there.
# `device` is the target that tensors and the model are moved to further down.
DIR, RESULTS_DIR = DeviceDir.get_directory()
device, NUM_PROCESSORS = DeviceDir.get_device()

In [2]:
from ipynb.fs.full.Dataset import get_data
from ipynb.fs.full.Dataset import datasets as available_datasets
from ipynb.fs.full.Utils import save_plot

In [3]:
import argparse
from argparse import ArgumentParser

#set default arguments here
def get_configuration():
    """Parse the run configuration from the process command line.

    Defaults for every option are set here. Boolean options accept textual
    values such as ``true``/``false``, ``yes``/``no`` or ``1``/``0``
    (case-insensitive).

    Returns:
        tuple: ``(args, dict_args)`` where ``args`` is the parsed
        ``argparse.Namespace`` and ``dict_args`` is ``vars(args)``, i.e. a
        dict view of the same values.

    Raises:
        SystemExit: raised by ``argparse`` when the command line is invalid
            (unknown option, value outside ``choices``, malformed boolean).
    """

    def str2bool(value):
        """Convert a command-line token into a real boolean."""
        # argparse's ``type=bool`` simply calls bool() on the raw string, so
        # "--log_info False" would evaluate to True. Parse the text instead.
        if isinstance(value, bool):
            return value
        token = value.strip().lower()
        if token in ('true', 't', 'yes', 'y', '1'):
            return True
        if token in ('false', 'f', 'no', 'n', '0', ''):
            return False
        raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')

    parser = ArgumentParser()
    parser.add_argument('--epochs', type=int, default=1)
    parser.add_argument('--log_info', type=str2bool, default=True)
    parser.add_argument('--pbar', type=str2bool, default=False)
    parser.add_argument('--batch_size', type=int, default=2048)
    parser.add_argument('--learning_rate', type=float, default=0.01)
    parser.add_argument('--num_gpus', type=int, default=-1)
    parser.add_argument('--parallel_mode', type=str, default="dp", choices=['dp', 'ddp', 'ddp2'])
    parser.add_argument('--dataset', type=str, default="Cora", choices=available_datasets)
    # store_false: normalization is ON by default and passing the flag turns it OFF.
    parser.add_argument('--use_normalization', action='store_false', default=True)
    # Dummy option so the "-f <file>" argument present when running inside a
    # Jupyter notebook is accepted instead of rejected.
    parser.add_argument('-f')
    args = parser.parse_args()

    dict_args = vars(args)

    return args, dict_args

# Parsed once when this cell runs; `args` is then read as a module-level global
# (e.g. `args.log_info`) by the model, training and experiment code below.
args, dict_args = get_configuration()

## Packages

In [4]:
import torch.nn as nn
import numpy as np
from torch.nn import init
from random import shuffle, randint
import torch.nn.functional as F
from itertools import combinations, combinations_with_replacement
from sklearn.metrics import f1_score, accuracy_score
from sklearn.decomposition import TruncatedSVD
import matplotlib.pyplot as plt
import sys
from torch_geometric.data import Data
import logging
import time

### GraphSAGE model

In [5]:
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel
from tqdm import tqdm
from torch_geometric.loader import NeighborSampler
from torch_geometric.nn import SAGEConv
from torch_geometric.nn import GCNConv

In [6]:
import random
import numpy as np
import torch

# One fixed seed shared by Python's `random`, NumPy and PyTorch so that runs
# are repeatable.
seed = 123

# A loop (rather than a bare trailing call) also keeps the cell from echoing
# the return value of torch.manual_seed.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(seed)

In [7]:
#https://www.arangodb.com/2021/08/a-comprehensive-case-study-of-graphsage-using-pytorchgeometric/

class SAGE(torch.nn.Module):
    """Node classifier made of a stack of ``SAGEConv`` layers.

    The stack has exactly ``num_layers`` convolutions with widths
    ``in_channels -> hidden_channels -> ... -> hidden_channels -> out_channels``.
    With a single layer the stack is just ``in_channels -> out_channels``.

    Args:
        in_channels: size of the input node features.
        out_channels: size of the output (one score per class).
        hidden_channels: width of every intermediate layer.
        num_layers: number of convolutions; must be at least 1.

    Raises:
        ValueError: if ``num_layers`` is smaller than 1.
    """

    def __init__(self, in_channels, out_channels, hidden_channels, num_layers=2):
        super().__init__()
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")
        self.num_layers = num_layers

        # Build the widths from num_layers so the number of convolutions always
        # matches it. Previously num_layers=1 still created two layers, and the
        # only layer actually used produced hidden_channels outputs.
        widths = [in_channels] + [hidden_channels] * (num_layers - 1) + [out_channels]
        self.convs = torch.nn.ModuleList(
            [SAGEConv(w_in, w_out) for w_in, w_out in zip(widths[:-1], widths[1:])]
        )

    def forward(self, x, adjs):
        """Run the sampled mini-batch through the layer stack.

        Args:
            x: features of the sampled nodes for this batch.
            adjs: one ``(edge_index, _, size)`` triple per layer; the middle
                element is ignored and ``size[1]`` is the number of target
                nodes of that layer.

        Returns:
            Log-probabilities (``log_softmax`` over the last dimension).
        """
        for i, (edge_index, _, size) in enumerate(adjs):
            x_target = x[:size[1]]  # Target nodes are always placed first.
            x = self.convs[i]((x, x_target), edge_index)
            if i != self.num_layers - 1:
                # Non-linearity and dropout on every layer except the last.
                x = F.relu(x)
                x = F.dropout(x, p=0.2, training=self.training)
        return x.log_softmax(dim=-1)

    @torch.no_grad()
    def inference(self, x_all, device, subgraph_loader):
        """Compute outputs for all nodes, one layer at a time.

        For each layer every batch from ``subgraph_loader`` is pushed through
        that layer on ``device``; per-batch results are moved to the CPU and
        concatenated to form the input of the next layer.

        Args:
            x_all: feature matrix indexed by node id.
            device: device on which each batch is computed.
            subgraph_loader: iterable yielding ``(batch_size, n_id, adj)``
                where ``adj`` unpacks into ``(edge_index, _, size)``.

        Returns:
            CPU tensor with the last layer's raw outputs (no softmax applied).
        """
        # The progress bar is shown only when the global `args.log_info` is set.
        pbar = None
        if args.log_info:
            pbar = tqdm(total=x_all.size(0) * self.num_layers)
            pbar.set_description('Evaluating')

        try:
            for i in range(self.num_layers):
                xs = []
                for batch_size, n_id, adj in subgraph_loader:
                    edge_index, _, size = adj.to(device)
                    x = x_all[n_id].to(device)
                    x_target = x[:size[1]]
                    x = self.convs[i]((x, x_target), edge_index)
                    if i != self.num_layers - 1:
                        x = F.relu(x)
                    xs.append(x.cpu())

                    if pbar is not None:
                        pbar.update(batch_size)

                x_all = torch.cat(xs, dim=0)
        finally:
            # Close the bar even if a batch fails, so it does not linger.
            if pbar is not None:
                pbar.close()

        return x_all

In [8]:
def train(model, data, epochs=100, train_neighbors=[25,10]):
    """Train ``model`` on neighbor-sampled mini-batches, evaluating after every epoch.

    Training stops early once at least 5 epochs have run and the standard
    deviation of the last 5 epoch losses is at most 1e-3.

    Args:
        model: module providing ``forward(x, adjs)`` and
            ``inference(x_all, device, subgraph_loader)`` (see ``SAGE``).
        data: graph object providing ``x``, ``y``, ``edge_index`` and the
            boolean ``train_mask`` / ``val_mask`` / ``test_mask``.
        epochs: maximum number of epochs.
        train_neighbors: per-layer neighbor sample sizes for the training
            sampler. The list is only read, never modified.

    Returns:
        tuple: ``(best_acc, num_iteration)`` — the highest test accuracy seen
        over all epochs, and the epoch at which training stopped (``epochs``
        when the early-stopping criterion never fired).
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
    # SAGE.forward returns log_softmax output, so NLLLoss is the matching loss.
    criterion = torch.nn.NLLLoss()

    log_info = args.log_info
    if log_info:
        print("Train neighbors: ", train_neighbors)

    train_idx = data.train_mask.nonzero(as_tuple=False).view(-1)

    # Both samplers load in the main process (num_workers=0).
    train_loader = NeighborSampler(data.edge_index, node_idx=train_idx,
                                   sizes=train_neighbors, batch_size=1024,
                                   shuffle=True, num_workers=0)

    subgraph_loader = NeighborSampler(data.edge_index, node_idx=None,
                                      sizes=[25], batch_size=2048,
                                      shuffle=False, num_workers=0)

    # Only features and labels are moved to the device. The masks and `data.y`
    # stay where they are because evaluation below compares against the CPU
    # tensor returned by `model.inference`.
    x, y = data.x.to(device), data.y.to(device)

    best_acc = 0
    num_iteration = epochs
    train_losses = []

    for epoch in range(1, epochs + 1):

        pbar = None
        if log_info:
            pbar = tqdm(total=train_idx.size(0))
            pbar.set_description(f'Epoch {epoch:02d}')

        total_loss = total_correct = 0
        model.train()

        try:
            for batch_size, n_id, adjs in train_loader:
                adjs = [adj.to(device) for adj in adjs]
                # The first `batch_size` sampled ids are the batch's target nodes.
                target = y[n_id[:batch_size]]

                optimizer.zero_grad()
                out = model(x[n_id], adjs)
                loss = criterion(out, target)
                loss.backward()
                optimizer.step()

                total_loss += float(loss)
                total_correct += int(out.argmax(dim=-1).eq(target).sum())

                if pbar is not None:
                    pbar.update(batch_size)
        finally:
            if pbar is not None:
                pbar.close()

        loss = total_loss / len(train_loader)
        approx_acc = total_correct / train_idx.size(0)
        train_losses.append(loss)

        if log_info:
            print(f'Epoch: {epoch:03d}, Training Loss: {loss:.4f}, Training Accuracy: {approx_acc:.4f}')

        #### EVALUATION
        model.eval()
        with torch.no_grad():
            out = model.inference(x, device, subgraph_loader)

        res = out.argmax(dim=-1) == data.y
        train_acc = int(res[data.train_mask].sum()) / int(data.train_mask.sum())
        val_acc = int(res[data.val_mask].sum()) / int(data.val_mask.sum())
        test_acc = int(res[data.test_mask].sum()) / int(data.test_mask.sum())

        # NOTE(review): the reported best accuracy is picked on the test split,
        # not on validation — confirm this is the intended protocol.
        if test_acc > best_acc:
            best_acc = test_acc

        std_dev = np.std(train_losses[-5:])

        if log_info:
            print(f'Epoch: {epoch:03d}, Train Loss: {loss:.4f}, Train: {train_acc:.4f}, Val: {val_acc:.4f}, Test: {test_acc:.4f}, Std dev: {std_dev:.4f}')

        # Convergence: the loss has been flat over the last five epochs.
        if epoch >= 5 and std_dev <= 1e-3:
            num_iteration = epoch

            if log_info:
                print("Iteration for convergence: ", epoch)
            break

    return best_acc, num_iteration


In [9]:
def GSAGEperformance(data, dataset, num_classes, epochs=20, train_neighbors=[25,10]):
    """Create a fresh SAGE model (256 hidden channels), train it, and report the outcome.

    Returns:
        tuple: ``(best_acc, num_iteration, model)`` — the two values returned
        by ``train`` followed by the trained model.
    """
    network = SAGE(
        in_channels=dataset.num_features,
        out_channels=num_classes,
        hidden_channels=256,
    ).to(device)

    if args.log_info:
        print(network)

    training_outcome = train(model=network, data=data, epochs=epochs, train_neighbors=train_neighbors)
    return (*training_outcome, network)

### Main function

In [10]:
# args.log_info = True
# DATASET_NAME = 'karate'
# data, dataset = get_data(DATASET_NAME, DIR=None, log=False, h_score=True, split_no=0); print("")
# print(data)
# best_acc, num_iteration, _ = GSAGEperformance(data, dataset, dataset.num_classes, epochs=25, train_neighbors=[25,10])
# print(best_acc, num_iteration)

## Batch Experiments

In [11]:
def batch_experiments(num_run=1):
    """Train and evaluate GraphSAGE ``num_run`` times on each configured dataset.

    For every dataset, one summary line (mean/std of the best accuracy and of
    the stopping epoch over the runs) is printed and appended to
    ``Results/GSAGE.txt``.

    Side effects:
        Sets the global ``args.log_info`` to ``False`` (it is not restored),
        and creates the ``Results`` directory if it does not exist.

    Args:
        num_run: number of runs per dataset; run ``i`` uses ``split_no=i`` and
            ``random_state=i``.
    """
    ALL_DATASETs = ["Reddit0.525","Reddit0.425","Reddit0.325"]

    args.log_info = False

    # Appending would otherwise fail when the output folder is missing.
    os.makedirs("Results", exist_ok=True)

    for DATASET_NAME in ALL_DATASETs:

        # Context manager: the handle is closed even if a run raises.
        with open("Results/GSAGE.txt", 'a+') as result_file:
            print(DATASET_NAME, end=' ')
            result_file.write(f'{DATASET_NAME} ')

            accs = []
            itrs = []

            for i in range(num_run):
                data, dataset = get_data(DATASET_NAME, DIR=None, log=False, h_score=False, split_no=i, random_state=i)

                # Labels with more than one dimension are reduced to class indices.
                if len(data.y.shape) > 1:
                    data.y = data.y.argmax(dim=1)

                # The class count is always taken from the labels actually present
                # (largest label + 1); the previous branches all reduced to this.
                num_classes = torch.max(data.y).item() + 1

                # Larger graphs get a smaller epoch budget.
                max_epochs = 150 if data.num_nodes < 100000 else 50

                accuracy, itr, _ = GSAGEperformance(data, dataset, num_classes, epochs=max_epochs, train_neighbors=[8,4])
                accs.append(accuracy)
                itrs.append(itr)

            print(f'acc {np.mean(accs):0.4f} sd {np.std(accs):0.4f} itr {int(np.mean(itrs)):d} sd {int(np.std(itrs)):d}')
            result_file.write(f'acc {np.mean(accs):0.4f} sd {np.std(accs):0.4f}, itr {int(np.mean(itrs)):d} sd {int(np.std(itrs)):d}\n')
                
# batch_experiments(num_run=5)

In [12]:
# Intentionally empty entry point: nothing runs when executed as a script.
if __name__ == '__main__':
    pass

In [13]:
def ablation(num_run = 1):
    """Run the GraphSAGE ablation sweep: ``num_run`` runs on each listed dataset.

    For every dataset, one summary line (mean/std of the best accuracy and of
    the stopping epoch over the runs) is printed and appended to
    ``Results/GSAGEablation.txt``.

    Side effects:
        Sets the global ``args.log_info`` to ``False`` (it is not restored),
        and creates the ``Results`` directory if it does not exist.

    Args:
        num_run: number of runs per dataset; run ``i`` uses ``split_no=i`` and
            ``random_state=i``.

    Returns:
        None.
    """
    # Not part of this sweep (were commented out): AmazonProducts, Yelp, pokec,
    # twitch-gamer, wiki, Reddit, Reddit2.
    ALL_DATASETs = [
        "reed98",
        "amherst41",
        "penn94",
        "cornell5",
        "Squirrel",
        "johnshopkins55",
        "Chameleon",
        "Tolokers",
        "Flickr",

        "Computers",
        "Photo",
        "Physics",
    ]

    args.log_info = False

    # Appending would otherwise fail when the output folder is missing.
    os.makedirs("Results", exist_ok=True)

    for DATASET_NAME in ALL_DATASETs:

        # Context manager: the handle is closed even if a run raises.
        with open("Results/GSAGEablation.txt", 'a+') as result_file:
            print(DATASET_NAME, end=' ')
            result_file.write(f'{DATASET_NAME} ')

            accs = []
            itrs = []

            for i in range(num_run):
                data, dataset = get_data(DATASET_NAME, DIR=None, log=False, h_score=False, split_no=i, random_state=i)

                # Labels with more than one dimension are reduced to class indices.
                if len(data.y.shape) > 1:
                    data.y = data.y.argmax(dim=1)

                # The class count is always taken from the labels actually present
                # (largest label + 1); the previous branches all reduced to this.
                num_classes = torch.max(data.y).item() + 1

                # Larger graphs get a smaller epoch budget.
                max_epochs = 500 if data.num_nodes < 100000 else 50

                accuracy, itr, _ = GSAGEperformance(data, dataset, num_classes, epochs=max_epochs, train_neighbors=[8,4])
                accs.append(accuracy)
                itrs.append(itr)

            print(f'acc {np.mean(accs):0.4f} sd {np.std(accs):0.4f} itr {int(np.mean(itrs)):d} sd {int(np.std(itrs)):d}')
            result_file.write(f'acc {np.mean(accs):0.4f} sd {np.std(accs):0.4f}, itr {int(np.mean(itrs)):d} sd {int(np.std(itrs)):d}\n')

    return

# Run the ablation sweep with 5 runs per dataset; each dataset prints one
# "acc ... sd ... itr ... sd ..." summary line.
ablation(num_run=5)

reed98 acc 0.6218 sd 0.0314 itr 81 sd 22
amherst41 acc 0.6559 sd 0.0114 itr 151 sd 84
penn94 acc 0.7570 sd 0.0030 itr 83 sd 28
cornell5 acc 0.7022 sd 0.0052 itr 92 sd 27
Squirrel acc 0.3812 sd 0.0094 itr 175 sd 68
johnshopkins55 acc 0.6612 sd 0.0044 itr 500 sd 0
Chameleon acc 0.5118 sd 0.0161 itr 500 sd 0
Tolokers acc 0.7912 sd 0.0029 itr 34 sd 21
Flickr acc 0.5083 sd 0.0023 itr 198 sd 83
Computers acc 0.9112 sd 0.0076 itr 500 sd 0
Photo acc 0.9612 sd 0.0043 itr 401 sd 166
Physics acc 0.9674 sd 0.0028 itr 41 sd 15
