In [None]:
from utils import *
import argparse
import multiprocessing as mp

# Module-level switch for gradient-variance tracking; off by default.
# NOTE(review): presumably forwarded to the `calculate_grad_vars` argument of the
# training functions by driver code outside this view — confirm.
calculate_grad_vars = False

In [None]:
# Replace sys.argv with a single empty entry so that the argparse
# `parse_args()` call further down sees no command-line arguments and every
# option keeps its default value.
# NOTE(review): presumably needed because the notebook kernel's own argv would
# otherwise be rejected by argparse — confirm. It also discards real CLI
# arguments if this file is ever run as a plain script.
import sys; sys.argv=['']; del sys
import warnings
# Silence all RuntimeWarning and UserWarning messages for the rest of the session.
warnings.filterwarnings("ignore", category=RuntimeWarning)
warnings.filterwarnings("ignore", category=UserWarning)

In [None]:
"""
Dataset arguments
"""
parser = argparse.ArgumentParser(
    description='Training GCN on Large-scale Graph Datasets')
parser.add_argument('--dataset', type=str, default='reddit',
                    help='Dataset name: pubmed/flickr/reddit/ppi-large')
parser.add_argument('--sample_method', type=str, default='full',
                    help='Sampled Algorithms: full/ladies/fastgcn/graphsage/exact/graphsaint/vrgcn')
parser.add_argument('--nhid', type=int, default=256,
                    help='Hidden state dimension')
parser.add_argument('--epoch_num', type=int, default=200,
                    help='Number of Epoch')
parser.add_argument('--pool_num', type=int, default=10,
                    help='Number of Pool')
parser.add_argument('--batch_num', type=int, default=10,
                    help='Maximum Batch Number')
parser.add_argument('--batch_size', type=int, default=512,
                    help='size of output node in a batch')
parser.add_argument('--large_batch_size', type=int, default=81920,
                    help='size of output node in a batch')
parser.add_argument('--n_layers', type=int, default=2,
                    help='Number of GCN layers')
parser.add_argument('--n_stops', type=int, default=200,
                    help='Early stops')
parser.add_argument('--dropout', type=float, default=0,
                    help='Dropout rate')
parser.add_argument('--cuda', type=int, default=1,
                    help='Avaiable GPU ID')
parser.add_argument('--save_prefix', type=str, default='exps',
                    help='Save file prefix')
parser.add_argument('--run_options', type=str, default='True-True-True',
                    help='Run Vanilla? Zeroth? Doubly?')
parser.add_argument('--dist_bound', type=float, default=0.1, # sometime 0.01 also work well
                    help='Restart if the different is large')
parser.add_argument('--use_SGD', type=str, default='False', 
                    help='Whether use SGD?')                 
args = parser.parse_args()
print(args)

vanilla, zeroth_order, doubly_order = [option=='True' for option in args.run_options.split('-')]
sample_method_list = args.sample_method.split('/')
dist_bound = args.dist_bound
print(vanilla, zeroth_order, doubly_order, dist_bound)

In [None]:
"""
Prepare devices
"""
# --cuda -1 selects the CPU; any other value is used as the CUDA device index.
device = torch.device("cpu") if args.cuda == -1 else torch.device("cuda:" + str(args.cuda))
    
"""
Prepare data using multi-process
"""
def prepare_data(pool, sampler, process_ids, candidate_nodes, samp_num_list, num_nodes, lap_matrix, lap_matrix_sq, depth):
    """Queue one asynchronous sampling task per entry of ``process_ids``.

    For every entry, ``candidate_nodes`` is randomly permuted and the first
    ``args.batch_size`` nodes are kept as the batch (``args`` is the
    module-level argparse namespace). A random integer seed is then drawn and
    ``sampler`` is submitted to ``pool`` with the arguments
    ``(seed, batch_nodes, samp_num_list, num_nodes, lap_matrix, lap_matrix_sq, depth)``.

    Returns:
        list: the handles returned by ``pool.apply_async``, in submission
        order. They are not results; callers call ``.get()`` on each.
    """
    pending = []
    for _unused in process_ids:
        # Keep the draw order (batch first, then seed) so the global NumPy
        # random stream is consumed exactly as before.
        shuffled = np.random.permutation(candidate_nodes)
        batch_nodes = shuffled[:args.batch_size]
        worker_seed = np.random.randint(2**32 - 1)
        task_args = (worker_seed, batch_nodes, samp_num_list, num_nodes,
                     lap_matrix, lap_matrix_sq, depth)
        pending.append(pool.apply_async(sampler, args=task_args))
    return pending

# Load the graph (Laplacian), labels, node features and the train/valid/test node splits.
# NOTE(review): the meaning of the second positional argument (False) is defined
# by preprocess_data in utils and is not visible here — confirm.
lap_matrix, labels, feat_data, train_nodes, valid_nodes, test_nodes = preprocess_data(args.dataset, False)

print("Dataset information")
print(lap_matrix.shape, labels.shape, feat_data.shape,
      train_nodes.shape, valid_nodes.shape, test_nodes.shape)
# torch.FloatTensor cannot be built from a SciPy sparse matrix, so densify first.
# `sp.issparse` replaces the former exact-type check against `sp.lil.lil_matrix`:
# the `scipy.sparse.lil` namespace is deprecated, and other sparse formats
# (CSR, COO, ...) are now handled as well instead of failing in the else branch.
if sp.issparse(feat_data):
    feat_data = torch.FloatTensor(feat_data.toarray()).to(device)
else:
    feat_data = torch.FloatTensor(feat_data).to(device)

In [None]:
"""
Setup datasets and models for training (multi-label datasets use sigmoid+binary_cross_entropy, single-label datasets use softmax+nll_loss)
"""

if args.dataset in ['cora', 'citeseer', 'pubmed', 'flickr', 'reddit']:
    # Single-label datasets: integer class ids, one per node.
    from model import GCN
    from optimizers import sgcn_first, sgcn_zeroth, sgcn_doubly, sgd_step, full_step, VRGCN_step, VRGCN_doubly
    from optimizers import ForwardWrapper, VRGCNWrapper, package_mxl
    from samplers import fastgcn_sampler, ladies_sampler, graphsage_sampler, exact_sampler, full_batch_sampler, graphsaint_sampler, vrgcn_sampler
    labels = torch.LongTensor(labels).to(device)
    num_classes = labels.max().item()+1
elif args.dataset in ['ppi', 'ppi-large', 'amazon', 'yelp']:
    # Multi-label datasets: labels are a float matrix with one column per class.
    from model_mc import GCN
    from optimizers_mc import sgcn_first, sgcn_zeroth, sgcn_doubly, sgd_step, full_step, VRGCN_step, VRGCN_doubly
    from optimizers_mc import ForwardWrapper, VRGCNWrapper, package_mxl
    from samplers_sage_support import fastgcn_sampler, ladies_sampler, graphsage_sampler, exact_sampler, full_batch_sampler, graphsaint_sampler, vrgcn_sampler
    labels = torch.FloatTensor(labels).to(device)
    num_classes = labels.shape[1]
else:
    # Without this branch GCN, the optimizers and num_classes stay undefined
    # and the failure only surfaces later as a NameError inside training.
    raise ValueError(
        "Unsupported dataset %r: expected one of cora/citeseer/pubmed/flickr/reddit "
        "or ppi/ppi-large/amazon/yelp" % args.dataset)
    

def calculate_grad_variance(net, feat_data, labels, train_nodes, adjs_full):
    """Measure how far ``net``'s current gradients are from freshly computed ones.

    The gradients currently stored on ``net``'s parameters are compared with
    the gradients obtained by calling ``calculate_loss_grad(feat_data,
    adjs_full, labels, train_nodes)`` on a deep copy of ``net``.

    Returns:
        The square root of the summed squared L2 norms of the per-parameter
        gradient differences, i.e. the Euclidean distance between the two
        gradient sets (despite the name, no mean or normalisation is applied).
    """
    current_grads = [param.grad.data for param in net.parameters()]

    # Recompute the gradients on a deep copy of the model rather than on net itself.
    reference_net = copy.deepcopy(net)
    _, _ = reference_net.calculate_loss_grad(
        feat_data, adjs_full, labels, train_nodes)
    reference_grads = [param.grad.data for param in reference_net.parameters()]
    del reference_net

    squared_distance = 0.0
    for current, reference in zip(current_grads, reference_grads):
        squared_distance = squared_distance + (current - reference).norm(2) ** 2
    return torch.sqrt(squared_distance)

def sample_large_batch(args, train_nodes, samp_num_list, num_nodes, lap_matrix, lap_matrix_sq, depth):
    """Draw one large batch of training nodes and sample it with ``exact_sampler``.

    ``train_nodes`` is randomly permuted and its first
    ``args.large_batch_size`` entries become the batch. A random integer seed
    is then drawn and passed, with the batch and the remaining arguments,
    to ``exact_sampler``.

    Returns:
        tuple: ``(adjs, input_nodes, output_nodes, sampled_nodes)`` exactly as
        produced by ``exact_sampler``.
    """
    # Draw order (batch first, then seed) is kept so the NumPy random stream
    # is consumed exactly as before.
    shuffled = np.random.permutation(train_nodes)
    large_batch = shuffled[:args.large_batch_size]
    sampler_seed = np.random.randint(2**32 - 1)
    adjs, input_nodes, output_nodes, sampled_nodes = exact_sampler(
        sampler_seed, large_batch, samp_num_list, num_nodes,
        lap_matrix, lap_matrix_sq, depth)
    return adjs, input_nodes, output_nodes, sampled_nodes

In [None]:
"""
This is a zeroth-order and first-order variance reduced version of SGCN++
"""

def algorithm_sgcn_doubly(feat_data, labels, lap_matrix,
                          train_nodes, valid_nodes, test_nodes,
                          args, device, calculate_grad_vars=False):
    """Train a GCN with the doubly (zeroth- and first-order) variance-reduced scheme.

    Every epoch draws one large batch with ``sample_large_batch``, samples
    ``args.batch_num`` mini-batches from the large batch's output nodes in a
    multiprocessing pool, and passes both to ``sgcn_doubly`` for the parameter
    updates. After each epoch the validation loss is computed on the full
    graph; the model with the lowest validation loss is kept and training
    stops early after ``args.n_stops // args.batch_num`` epochs without
    improvement.

    Relies on module-level names that are not parameters: ``GCN``,
    ``num_classes``, ``sampler``, ``samp_num_list``, ``large_samp_num_list``
    and ``dist_bound``.

    Args:
        feat_data: node feature tensor; ``feat_data.shape[1]`` is the input width.
        labels: label tensor passed through to the model's loss/F1 methods.
        lap_matrix: graph Laplacian; its element-wise square is also given to the samplers.
        train_nodes, valid_nodes, test_nodes: node index collections for each split.
        args: argparse namespace (uses ``batch_num``, ``nhid``, ``n_layers``,
            ``dropout``, ``use_SGD``, ``epoch_num``, ``pool_num``, ``cuda``, ``n_stops``).
        device: torch device the model and adjacency data are moved to.
        calculate_grad_vars: forwarded to ``sgcn_doubly``; forced off from epoch 20 on.

    Returns:
        tuple: ``(best_model, loss_train, loss_test, loss_train_all,
        f1_score_test, grad_variance_all, memory_allocated,
        max_memory_allocated, wall_clock_time)``. ``loss_test`` holds the
        per-epoch *validation* loss. ``wall_clock_time`` starts with ``0``
        followed by one timing dict per epoch. Memory figures are in MiB
        (0 when ``args.cuda == -1``).
    """
    memory_allocated, max_memory_allocated = [], []
    lap_matrix_sq = lap_matrix.multiply(lap_matrix)

    # one sampling job per mini-batch of the inner loop
    process_ids = np.arange(args.batch_num)

    susage = GCN(nfeat=feat_data.shape[1], nhid=args.nhid, num_classes=num_classes,
                 layers=args.n_layers, dropout=args.dropout).to(device)

    print(susage)

    # Full-graph adjacency, used for validation loss and the final F1 score.
    adjs_full, _, _ = full_batch_sampler(
        train_nodes, len(feat_data), lap_matrix, args.n_layers)
    adjs_full = package_mxl(adjs_full, device)

    forward_wrapper = ForwardWrapper(
        len(feat_data), args.nhid, args.n_layers, num_classes)

    optimizer = optim.Adam(susage.parameters()) if args.use_SGD=='False' else optim.SGD(susage.parameters(), lr=0.7)

    best_model = copy.deepcopy(susage)
    best_val_loss, _ = susage.calculate_loss_grad(feat_data, adjs_full, labels, valid_nodes)
    cnt = 0  # epochs since the last validation improvement

    wall_clock_time = [0]
    loss_train = [best_val_loss]
    loss_test = [best_val_loss]
    grad_variance_all = []
    loss_train_all = [best_val_loss]

    for epoch in np.arange(args.epoch_num):
        # create large batch
        large_batch_sample_time_start = time.perf_counter()
        large_adjs, large_input_nodes, large_output_nodes, large_sampled_nodes = sample_large_batch(
            args, train_nodes, large_samp_num_list, len(feat_data),
            lap_matrix, lap_matrix_sq, args.n_layers)
        # Stop the sampling clock before the transfer starts so the transfer is
        # not counted twice (it has its own entry below), matching how
        # algorithm_sgcn_zeroth measures the same two phases.
        large_batch_sample_time = time.perf_counter() - large_batch_sample_time_start

        large_batch_transfer_time_start = time.perf_counter()
        large_adjs = package_mxl(large_adjs, device)
        large_batch_transfer_time = time.perf_counter() - large_batch_transfer_time_start

        memory_allocated += [torch.cuda.memory_allocated(device)/1024/1024 if args.cuda != -1 else 0]
        max_memory_allocated += [torch.cuda.max_memory_allocated(device)/1024/1024 if args.cuda != -1 else 0]

        # prepare this epoch's mini-batches; the context manager terminates the
        # workers if a sampling job raises, instead of leaking the pool
        mini_batch_sample_time_start = time.perf_counter()
        with mp.Pool(args.pool_num) as pool:
            jobs = prepare_data(pool, sampler, process_ids, large_output_nodes, samp_num_list, len(feat_data),
                                lap_matrix, lap_matrix_sq, args.n_layers)
            # fetch train data
            train_data = [job.get() for job in jobs]
            pool.close()
            pool.join()
        mini_batch_sample_time = time.perf_counter() - mini_batch_sample_time_start

        inner_loop_num = args.batch_num
        # gradient-variance tracking is only performed during the first 20 epochs
        calculate_grad_vars = calculate_grad_vars and epoch<20
        cur_train_loss, cur_train_loss_all, grad_variance, time_counter = sgcn_doubly(
            susage, optimizer, feat_data, labels,
            train_nodes, valid_nodes,
            large_adjs, large_input_nodes, large_output_nodes, large_sampled_nodes,
            train_data, inner_loop_num, forward_wrapper, device, dist_bound=dist_bound, #2e-4
            calculate_grad_vars=calculate_grad_vars)

        epoch_time_counter = {
            'large_batch_sample_time': large_batch_sample_time,
            'large_batch_transfer_time': large_batch_transfer_time,
            'mini_batch_sample_time': mini_batch_sample_time,
            'compute_time': time_counter['compute_time'],
            'mini_batch_transfer_time': time_counter['transfer_time'],
        }
        wall_clock_time.append(epoch_time_counter)
        loss_train_all.extend(cur_train_loss_all)
        grad_variance_all.extend(grad_variance)

        # calculate validate loss
        susage.eval()

        susage.zero_grad()
        val_loss, _ = susage.calculate_loss_grad(
            feat_data, adjs_full, labels, valid_nodes)

        if val_loss < best_val_loss:
            best_model = copy.deepcopy(susage)
            best_val_loss = val_loss
            cnt = 0
        else:
            cnt += 1

        # early stopping; note the losses of the stopping epoch are not recorded
        if cnt == args.n_stops//args.batch_num:
            break

        # the value reported as "test loss" is the validation loss
        cur_test_loss = val_loss

        loss_train.append(cur_train_loss)
        loss_test.append(cur_test_loss)

        # print progress
        print('Epoch: ', epoch,
              '| train loss: %.8f' % cur_train_loss,
              '| test loss: %.8f' % cur_test_loss)

    f1_score_test = best_model.calculate_f1(feat_data, adjs_full, labels, test_nodes)
    print('f1_score_test', f1_score_test)
    return best_model, loss_train, loss_test, loss_train_all, f1_score_test, grad_variance_all, memory_allocated, max_memory_allocated, wall_clock_time





"""
Plain SGCN baseline without variance reduction (can also run full-batch gradient descent via full_batch=True)
"""


def sgcn(feat_data, labels, lap_matrix,
         train_nodes, valid_nodes, test_nodes,
         args, device, calculate_grad_vars=False, full_batch=False):
    """Train a GCN with plain sampled mini-batch steps, or with full-batch steps.

    With ``full_batch=False`` each epoch samples ``args.batch_num``
    mini-batches from ``train_nodes`` in a multiprocessing pool and calls
    ``sgd_step``; with ``full_batch=True`` no sampling is done and
    ``full_step`` is called on the full-graph adjacency instead. After each
    epoch the validation loss is computed on the full graph; the model with
    the lowest validation loss is kept and training stops early after
    ``args.n_stops // args.batch_num`` epochs without improvement.

    Relies on module-level names that are not parameters: ``GCN``,
    ``num_classes``, ``sampler`` and ``samp_num_list``.

    Args:
        feat_data: node feature tensor; ``feat_data.shape[1]`` is the input width.
        labels: label tensor passed through to the model's loss/F1 methods.
        lap_matrix: graph Laplacian; its element-wise square is also given to the sampler.
        train_nodes, valid_nodes, test_nodes: node index collections for each split.
        args: argparse namespace (uses ``batch_num``, ``nhid``, ``n_layers``,
            ``dropout``, ``use_SGD``, ``epoch_num``, ``pool_num``, ``cuda``, ``n_stops``).
        device: torch device the model and adjacency data are moved to.
        calculate_grad_vars: forwarded unchanged to the step function.
        full_batch: run full-batch gradient descent and skip all sampling.

    Returns:
        tuple: ``(best_model, loss_train, loss_test, loss_train_all,
        f1_score_test, grad_variance_all, memory_allocated,
        max_memory_allocated, wall_clock_time)``. ``loss_test`` holds the
        per-epoch *validation* loss. ``wall_clock_time`` starts with ``0``
        followed by one timing dict per epoch. Memory figures are in MiB
        (0 when ``args.cuda == -1``).
    """
    memory_allocated, max_memory_allocated = [], []

    # The sampler inputs are only needed for mini-batch training; skip the
    # sparse element-wise product entirely in full-batch mode.
    if not full_batch:
        process_ids = np.arange(args.batch_num)
        lap_matrix_sq = lap_matrix.multiply(lap_matrix)

    susage = GCN(nfeat=feat_data.shape[1], nhid=args.nhid, num_classes=num_classes,
                 layers=args.n_layers, dropout=args.dropout).to(device)

    print(susage)

    # Full-graph adjacency, used for validation loss, full-batch steps and the final F1 score.
    adjs_full, _, _ = full_batch_sampler(
        train_nodes, len(feat_data), lap_matrix, args.n_layers)
    adjs_full = package_mxl(adjs_full, device)

    optimizer = optim.Adam(susage.parameters()) if args.use_SGD=='False' else optim.SGD(susage.parameters(), lr=0.7)

    best_model = copy.deepcopy(susage)
    best_val_loss, _ = susage.calculate_loss_grad(feat_data, adjs_full, labels, valid_nodes)
    cnt = 0  # epochs since the last validation improvement

    wall_clock_time = [0]
    loss_train = [best_val_loss]
    loss_test = [best_val_loss]
    grad_variance_all = []
    loss_train_all = [best_val_loss]

    for epoch in np.arange(args.epoch_num):
        memory_allocated += [torch.cuda.memory_allocated(device)/1024/1024 if args.cuda != -1 else 0]
        max_memory_allocated += [torch.cuda.max_memory_allocated(device)/1024/1024 if args.cuda != -1 else 0]

        inner_loop_num = args.batch_num

        # it can also run full-batch GD by ignoring all the samplings
        if full_batch:
            mini_batch_sample_time = 0
            cur_train_loss, cur_train_loss_all, grad_variance, time_counter = full_step(
                susage, optimizer, feat_data, labels,
                train_nodes, valid_nodes,
                adjs_full, inner_loop_num, device,
                calculate_grad_vars=calculate_grad_vars)
        else:
            # prepare this epoch's mini-batches; the context manager terminates
            # the workers if a sampling job raises, instead of leaking the pool
            mini_batch_sample_time_start = time.perf_counter()
            with mp.Pool(args.pool_num) as pool:
                jobs = prepare_data(pool, sampler, process_ids, train_nodes, samp_num_list, len(feat_data),
                                    lap_matrix, lap_matrix_sq, args.n_layers)
                # fetch train data
                train_data = [job.get() for job in jobs]
                pool.close()
                pool.join()
            mini_batch_sample_time = time.perf_counter() - mini_batch_sample_time_start

            cur_train_loss, cur_train_loss_all, grad_variance, time_counter = sgd_step(
                susage, optimizer, feat_data, labels,
                train_nodes, valid_nodes,
                adjs_full, train_data, inner_loop_num, device,
                calculate_grad_vars=calculate_grad_vars)

        epoch_time_counter = {
            'mini_batch_sample_time': mini_batch_sample_time,
            'compute_time': time_counter['compute_time'],
            'mini_batch_transfer_time': time_counter['transfer_time'],
        }

        wall_clock_time.append(epoch_time_counter)

        loss_train_all.extend(cur_train_loss_all)
        grad_variance_all.extend(grad_variance)

        # calculate validation loss
        susage.eval()

        susage.zero_grad()
        val_loss, _ = susage.calculate_loss_grad(
            feat_data, adjs_full, labels, valid_nodes)

        if val_loss < best_val_loss:
            best_model = copy.deepcopy(susage)
            best_val_loss = val_loss
            cnt = 0
        else:
            cnt += 1

        # early stopping; note the losses of the stopping epoch are not recorded
        if cnt == args.n_stops//args.batch_num:
            break

        # the value reported as "test loss" is the validation loss
        cur_test_loss = val_loss

        loss_train.append(cur_train_loss)
        loss_test.append(cur_test_loss)

        # print progress
        print('Epoch: ', epoch,
              '| train loss: %.8f' % cur_train_loss,
              '| test loss: %.8f' % cur_test_loss)
    f1_score_test = best_model.calculate_f1(feat_data, adjs_full, labels, test_nodes)
    print('f1_score_test', f1_score_test)
    return best_model, loss_train, loss_test, loss_train_all, f1_score_test, grad_variance_all, memory_allocated, max_memory_allocated, wall_clock_time

def algorithm_vrgcn(feat_data, labels, lap_matrix,
                    train_nodes, valid_nodes, test_nodes,
                    args, device, calculate_grad_vars=False):
    """Train a GCN with ``VRGCN_step`` on sampled mini-batches.

    Every epoch samples ``args.batch_num`` mini-batches from ``train_nodes``
    in a multiprocessing pool and passes them, together with a
    ``VRGCNWrapper``, to ``VRGCN_step``. After each epoch the validation loss
    is computed on the full graph; the model with the lowest validation loss
    is kept and training stops early after ``args.n_stops // args.batch_num``
    epochs without improvement.

    Relies on module-level names that are not parameters: ``GCN``,
    ``num_classes``, ``sampler`` and ``samp_num_list``.

    Args:
        feat_data: node feature tensor; ``feat_data.shape[1]`` is the input width.
        labels: label tensor passed through to the model's loss/F1 methods.
        lap_matrix: graph Laplacian; its element-wise square is also given to the sampler.
        train_nodes, valid_nodes, test_nodes: node index collections for each split.
        args: argparse namespace (uses ``batch_num``, ``nhid``, ``n_layers``,
            ``dropout``, ``use_SGD``, ``epoch_num``, ``pool_num``, ``cuda``, ``n_stops``).
        device: torch device the model and adjacency data are moved to.
        calculate_grad_vars: forwarded to ``VRGCN_step``; forced off from epoch 20 on.

    Returns:
        tuple: ``(best_model, loss_train, loss_test, loss_train_all,
        f1_score_test, grad_variance_all, memory_allocated,
        max_memory_allocated, wall_clock_time)``. ``loss_test`` holds the
        per-epoch *validation* loss. ``wall_clock_time`` starts with ``0``
        followed by one timing dict per epoch. Memory figures are in MiB
        (0 when ``args.cuda == -1``).
    """
    memory_allocated, max_memory_allocated = [], []

    # one sampling job per mini-batch of the inner loop
    process_ids = np.arange(args.batch_num)
    lap_matrix_sq = lap_matrix.multiply(lap_matrix)

    susage = GCN(nfeat=feat_data.shape[1], nhid=args.nhid, num_classes=num_classes,
                 layers=args.n_layers, dropout=args.dropout).to(device)

    print(susage)

    # Full-graph adjacency, used for validation loss and the final F1 score.
    adjs_full, _, _ = full_batch_sampler(
        train_nodes, len(feat_data), lap_matrix, args.n_layers)
    adjs_full = package_mxl(adjs_full, device)

    forward_wrapper = VRGCNWrapper(
        len(feat_data), args.nhid, args.n_layers, num_classes)

    optimizer = optim.Adam(susage.parameters()) if args.use_SGD=='False' else optim.SGD(susage.parameters(), lr=0.7)

    best_model = copy.deepcopy(susage)
    best_val_loss, _ = susage.calculate_loss_grad(feat_data, adjs_full, labels, valid_nodes)
    cnt = 0  # epochs since the last validation improvement

    wall_clock_time = [0]
    loss_train = [best_val_loss]
    loss_test = [best_val_loss]
    grad_variance_all = []
    loss_train_all = [best_val_loss]

    for epoch in np.arange(args.epoch_num):
        memory_allocated += [torch.cuda.memory_allocated(device)/1024/1024 if args.cuda != -1 else 0]
        max_memory_allocated += [torch.cuda.max_memory_allocated(device)/1024/1024 if args.cuda != -1 else 0]

        # prepare this epoch's mini-batches; the context manager terminates the
        # workers if a sampling job raises, instead of leaking the pool
        mini_batch_sample_time_start = time.perf_counter()
        with mp.Pool(args.pool_num) as pool:
            jobs = prepare_data(pool, sampler, process_ids, train_nodes, samp_num_list, len(feat_data),
                                lap_matrix, lap_matrix_sq, args.n_layers)
            # fetch train data
            train_data = [job.get() for job in jobs]
            pool.close()
            pool.join()
        mini_batch_sample_time = time.perf_counter() - mini_batch_sample_time_start

        inner_loop_num = args.batch_num
        # gradient-variance tracking is only performed during the first 20 epochs
        calculate_grad_vars = calculate_grad_vars and epoch<20

        cur_train_loss, cur_train_loss_all, grad_variance, time_counter = VRGCN_step(
            susage, optimizer, feat_data, labels,
            train_nodes, valid_nodes, adjs_full,
            train_data, inner_loop_num, forward_wrapper, device,
            calculate_grad_vars=calculate_grad_vars)

        epoch_time_counter = {
            'mini_batch_sample_time': mini_batch_sample_time,
            'compute_time': time_counter['compute_time'],
            'mini_batch_transfer_time': time_counter['transfer_time'],
        }
        wall_clock_time.append(epoch_time_counter)

        loss_train_all.extend(cur_train_loss_all)
        grad_variance_all.extend(grad_variance)
        # calculate validate loss
        susage.eval()

        susage.zero_grad()
        val_loss, _ = susage.calculate_loss_grad(
            feat_data, adjs_full, labels, valid_nodes)

        if val_loss < best_val_loss:
            best_model = copy.deepcopy(susage)
            best_val_loss = val_loss
            cnt = 0
        else:
            cnt += 1

        # early stopping; note the losses of the stopping epoch are not recorded
        if cnt == args.n_stops//args.batch_num:
            break

        # the value reported as "test loss" is the validation loss
        cur_test_loss = val_loss

        loss_train.append(cur_train_loss)
        loss_test.append(cur_test_loss)
        # print progress
        print('Epoch: ', epoch,
              '| train loss: %.8f' % cur_train_loss,
              '| test loss: %.8f' % cur_test_loss)

    f1_score_test = best_model.calculate_f1(feat_data, adjs_full, labels, test_nodes)
    print('f1_score_test', f1_score_test)
    return best_model, loss_train, loss_test, loss_train_all, f1_score_test, grad_variance_all, memory_allocated, max_memory_allocated, wall_clock_time

def algorithm_vrgcn_doubly(feat_data, labels, lap_matrix,
                          train_nodes, valid_nodes, test_nodes,
                          args, device, calculate_grad_vars=False):
    """Train a GCN with ``VRGCN_doubly`` on sampled mini-batches.

    Every epoch samples ``args.batch_num`` mini-batches from ``train_nodes``
    in a multiprocessing pool and passes them, together with a
    ``VRGCNWrapper``, to ``VRGCN_doubly``. After each epoch the validation
    loss is computed on the full graph; the model with the lowest validation
    loss is kept and training stops early after
    ``args.n_stops // args.batch_num`` epochs without improvement.

    Relies on module-level names that are not parameters: ``GCN``,
    ``num_classes``, ``sampler`` and ``samp_num_list``.

    Args:
        feat_data: node feature tensor; ``feat_data.shape[1]`` is the input width.
        labels: label tensor passed through to the model's loss/F1 methods.
        lap_matrix: graph Laplacian; its element-wise square is also given to the sampler.
        train_nodes, valid_nodes, test_nodes: node index collections for each split.
        args: argparse namespace (uses ``batch_num``, ``nhid``, ``n_layers``,
            ``dropout``, ``use_SGD``, ``epoch_num``, ``pool_num``, ``cuda``, ``n_stops``).
        device: torch device the model and adjacency data are moved to.
        calculate_grad_vars: forwarded to ``VRGCN_doubly``; forced off from epoch 20 on.

    Returns:
        tuple: ``(best_model, loss_train, loss_test, loss_train_all,
        f1_score_test, grad_variance_all, memory_allocated,
        max_memory_allocated, wall_clock_time)``. ``loss_test`` holds the
        per-epoch *validation* loss. ``wall_clock_time`` starts with ``0``
        followed by one timing dict per epoch. Memory figures are in MiB
        (0 when ``args.cuda == -1``).
    """
    memory_allocated, max_memory_allocated = [], []

    # one sampling job per mini-batch of the inner loop
    process_ids = np.arange(args.batch_num)
    lap_matrix_sq = lap_matrix.multiply(lap_matrix)

    susage = GCN(nfeat=feat_data.shape[1], nhid=args.nhid, num_classes=num_classes,
                 layers=args.n_layers, dropout=args.dropout).to(device)

    print(susage)

    # Full-graph adjacency, used for validation loss and the final F1 score.
    adjs_full, _, _ = full_batch_sampler(
        train_nodes, len(feat_data), lap_matrix, args.n_layers)
    adjs_full = package_mxl(adjs_full, device)

    forward_wrapper = VRGCNWrapper(len(feat_data), args.nhid, args.n_layers, num_classes)

    optimizer = optim.Adam(susage.parameters()) if args.use_SGD=='False' else optim.SGD(susage.parameters(), lr=0.7)

    best_model = copy.deepcopy(susage)
    best_val_loss, _ = susage.calculate_loss_grad(feat_data, adjs_full, labels, valid_nodes)
    cnt = 0  # epochs since the last validation improvement

    wall_clock_time = [0]
    loss_train = [best_val_loss]
    loss_test = [best_val_loss]
    grad_variance_all = []
    loss_train_all = [best_val_loss]

    for epoch in np.arange(args.epoch_num):
        memory_allocated += [torch.cuda.memory_allocated(device)/1024/1024 if args.cuda != -1 else 0]
        max_memory_allocated += [torch.cuda.max_memory_allocated(device)/1024/1024 if args.cuda != -1 else 0]

        # prepare this epoch's mini-batches; the context manager terminates the
        # workers if a sampling job raises, instead of leaking the pool
        mini_batch_sample_time_start = time.perf_counter()
        with mp.Pool(args.pool_num) as pool:
            jobs = prepare_data(pool, sampler, process_ids, train_nodes, samp_num_list, len(feat_data),
                                lap_matrix, lap_matrix_sq, args.n_layers)
            # fetch train data
            train_data = [job.get() for job in jobs]
            pool.close()
            pool.join()
        mini_batch_sample_time = time.perf_counter() - mini_batch_sample_time_start

        inner_loop_num = args.batch_num
        # gradient-variance tracking is only performed during the first 20 epochs
        calculate_grad_vars = calculate_grad_vars and epoch<20

        cur_train_loss, cur_train_loss_all, grad_variance, time_counter = VRGCN_doubly(
            susage, optimizer, feat_data, labels,
            train_nodes, valid_nodes,
            adjs_full, train_data, inner_loop_num, forward_wrapper, device,
            calculate_grad_vars=calculate_grad_vars)

        epoch_time_counter = {
            'mini_batch_sample_time': mini_batch_sample_time,
            'compute_time': time_counter['compute_time'],
            'mini_batch_transfer_time': time_counter['transfer_time'],
        }

        wall_clock_time.append(epoch_time_counter)

        loss_train_all.extend(cur_train_loss_all)
        grad_variance_all.extend(grad_variance)
        # calculate validation loss
        susage.eval()

        susage.zero_grad()
        val_loss, _ = susage.calculate_loss_grad(
            feat_data, adjs_full, labels, valid_nodes)

        if val_loss < best_val_loss:
            best_model = copy.deepcopy(susage)
            best_val_loss = val_loss
            cnt = 0
        else:
            cnt += 1

        # early stopping; note the losses of the stopping epoch are not recorded
        if cnt == args.n_stops//args.batch_num:
            break

        # the value reported as "test loss" is the validation loss
        cur_test_loss = val_loss

        loss_train.append(cur_train_loss)
        loss_test.append(cur_test_loss)

        # print progress
        print('Epoch: ', epoch,
              '| train loss: %.8f' % cur_train_loss,
              '| test loss: %.8f' % cur_test_loss)

    f1_score_test = best_model.calculate_f1(feat_data, adjs_full, labels, test_nodes)
    print('f1_score_test', f1_score_test)
    return best_model, loss_train, loss_test, loss_train_all, f1_score_test, grad_variance_all, memory_allocated, max_memory_allocated, wall_clock_time

"""
This is a zeroth-order variance reduced version of SGCN+
"""

def algorithm_sgcn_zeroth(feat_data, labels, lap_matrix,
                          train_nodes, valid_nodes, test_nodes,
                          args, device, calculate_grad_vars=False):
    """Train a GCN with the zeroth-order variance-reduced scheme (``sgcn_zeroth``).

    Every epoch draws one large batch with ``sample_large_batch``, samples
    ``args.batch_num`` mini-batches from the large batch's output nodes in a
    multiprocessing pool, and passes both to ``sgcn_zeroth`` for the parameter
    updates. After each epoch the validation loss is computed on the full
    graph; the model with the lowest validation loss is kept and training
    stops early after ``args.n_stops // args.batch_num`` epochs without
    improvement.

    Relies on module-level names that are not parameters: ``GCN``,
    ``num_classes``, ``sampler``, ``samp_num_list``, ``large_samp_num_list``
    and ``dist_bound``.

    Args:
        feat_data: node feature tensor; ``feat_data.shape[1]`` is the input width.
        labels: label tensor passed through to the model's loss/F1 methods.
        lap_matrix: graph Laplacian; its element-wise square is also given to the samplers.
        train_nodes, valid_nodes, test_nodes: node index collections for each split.
        args: argparse namespace (uses ``batch_num``, ``nhid``, ``n_layers``,
            ``dropout``, ``epoch_num``, ``pool_num``, ``cuda``, ``n_stops``).
        device: torch device the model and adjacency data are moved to.
        calculate_grad_vars: forwarded to ``sgcn_zeroth``; forced off from epoch 20 on.

    Returns:
        tuple: ``(best_model, loss_train, loss_test, loss_train_all,
        f1_score_test, grad_variance_all, memory_allocated,
        max_memory_allocated, wall_clock_time)``. ``loss_test`` holds the
        per-epoch *validation* loss. ``wall_clock_time`` starts with ``0``
        followed by one timing dict per epoch. Memory figures are in MiB
        (0 when ``args.cuda == -1``).
    """
    memory_allocated, max_memory_allocated = [], []

    # one sampling job per mini-batch of the inner loop
    process_ids = np.arange(args.batch_num)
    lap_matrix_sq = lap_matrix.multiply(lap_matrix)

    susage = GCN(nfeat=feat_data.shape[1], nhid=args.nhid, num_classes=num_classes,
                 layers=args.n_layers, dropout=args.dropout).to(device)

    print(susage)

    # Full-graph adjacency, used for validation loss and the final F1 score.
    adjs_full, _, _ = full_batch_sampler(
        train_nodes, len(feat_data), lap_matrix, args.n_layers)
    adjs_full = package_mxl(adjs_full, device)

    # this wrapper is only used by the variance-reduced sgcn variants
    forward_wrapper = ForwardWrapper(
        len(feat_data), args.nhid, args.n_layers, num_classes)

    # NOTE(review): unlike the sibling training functions, this one always uses
    # Adam and ignores args.use_SGD — confirm whether that is intentional.
    optimizer = optim.Adam(
        filter(lambda p: p.requires_grad, susage.parameters()))

    best_model = copy.deepcopy(susage)
    best_val_loss, _ = susage.calculate_loss_grad(feat_data, adjs_full, labels, valid_nodes)
    cnt = 0  # epochs since the last validation improvement

    wall_clock_time = [0]
    loss_train = [best_val_loss]
    loss_test = [best_val_loss]
    grad_variance_all = []
    loss_train_all = [best_val_loss]

    for epoch in np.arange(args.epoch_num):
        # create large batch
        large_batch_sample_time_start = time.perf_counter()
        large_adjs, large_input_nodes, large_output_nodes, large_sampled_nodes = sample_large_batch(
            args, train_nodes, large_samp_num_list, len(feat_data),
            lap_matrix, lap_matrix_sq, args.n_layers)
        large_batch_sample_time = time.perf_counter() - large_batch_sample_time_start

        large_batch_transfer_time_start = time.perf_counter()
        large_adjs = package_mxl(large_adjs, device)
        large_batch_transfer_time = time.perf_counter() - large_batch_transfer_time_start

        memory_allocated += [torch.cuda.memory_allocated(device)/1024/1024 if args.cuda != -1 else 0]
        max_memory_allocated += [torch.cuda.max_memory_allocated(device)/1024/1024 if args.cuda != -1 else 0]

        # prepare this epoch's mini-batches; the context manager terminates the
        # workers if a sampling job raises, instead of leaking the pool
        mini_batch_sample_time_start = time.perf_counter()
        with mp.Pool(args.pool_num) as pool:
            jobs = prepare_data(pool, sampler, process_ids, large_output_nodes, samp_num_list, len(feat_data),
                                lap_matrix, lap_matrix_sq, args.n_layers)
            # fetch train data
            train_data = [job.get() for job in jobs]
            pool.close()
            pool.join()
        mini_batch_sample_time = time.perf_counter() - mini_batch_sample_time_start

        inner_loop_num = args.batch_num
        # gradient-variance tracking is only performed during the first 20 epochs
        calculate_grad_vars = calculate_grad_vars and epoch<20
        cur_train_loss, cur_train_loss_all, grad_variance, time_counter = sgcn_zeroth(
            susage, optimizer, feat_data, labels,
            train_nodes, valid_nodes,
            large_adjs, large_input_nodes, large_output_nodes, large_sampled_nodes,
            train_data, inner_loop_num, forward_wrapper, device, dist_bound=dist_bound,
            calculate_grad_vars=calculate_grad_vars)

        epoch_time_counter = {
            'large_batch_sample_time': large_batch_sample_time,
            'large_batch_transfer_time': large_batch_transfer_time,
            'mini_batch_sample_time': mini_batch_sample_time,
            'compute_time': time_counter['compute_time'],
            'mini_batch_transfer_time': time_counter['transfer_time'],
        }
        wall_clock_time.append(epoch_time_counter)

        loss_train_all.extend(cur_train_loss_all)
        grad_variance_all.extend(grad_variance)
        # calculate validate loss
        susage.eval()

        susage.zero_grad()
        val_loss, _ = susage.calculate_loss_grad(
            feat_data, adjs_full, labels, valid_nodes)

        if val_loss < best_val_loss:
            best_model = copy.deepcopy(susage)
            best_val_loss = val_loss
            cnt = 0
        else:
            cnt += 1

        # early stopping; note the losses of the stopping epoch are not recorded
        if cnt == args.n_stops//args.batch_num:
            break

        # the value reported as "test loss" is the validation loss
        cur_test_loss = val_loss

        loss_train.append(cur_train_loss)
        loss_test.append(cur_test_loss)

        # print progress
        print('Epoch: ', epoch,
              '| train loss: %.8f' % cur_train_loss,
              '| test loss: %.8f' % cur_test_loss)

    f1_score_test = best_model.calculate_f1(feat_data, adjs_full, labels, test_nodes)
    print('f1_score_test', f1_score_test)
    return best_model, loss_train, loss_test, loss_train_all, f1_score_test, grad_variance_all, memory_allocated, max_memory_allocated, wall_clock_time

"""
This is a first-order variance reduced version of SGCN++
"""

def algorithm_sgcn_first(feat_data, labels, lap_matrix, 
                         train_nodes, valid_nodes, test_nodes,  
                         args, device, calculate_grad_vars=False):
    """Train a GCN with the first-order variance-reduced loop (``sgcn_first``).

    Each epoch samples one large batch, samples ``args.batch_num`` mini-batches
    in a worker pool, runs ``sgcn_first`` on them, and then evaluates the model
    on ``valid_nodes`` for model selection / early stopping.

    Besides its parameters, this function reads the following names from the
    enclosing (notebook) scope: ``num_classes``, ``sampler``, ``samp_num_list``,
    ``large_samp_num_list`` and ``dist_bound``. They must be defined by an
    earlier cell before calling.

    Args:
        feat_data: node features; ``feat_data.shape[1]`` is used as the input
            dimension and ``len(feat_data)`` as the number of nodes.
        labels: node labels, forwarded to the loss / F1 routines.
        lap_matrix: sparse matrix supporting ``.multiply``; its element-wise
            square is forwarded to the samplers.
        train_nodes, valid_nodes, test_nodes: node index sets for the splits.
        args: parsed options (``batch_num``, ``pool_num``, ``nhid``,
            ``n_layers``, ``dropout``, ``epoch_num``, ``n_stops``, ``cuda``,
            ``use_SGD``).
        device: torch device the model and adjacency matrices are moved to.
        calculate_grad_vars: forwarded to ``sgcn_first``; forced to a falsy
            value from epoch 20 onwards.

    Returns:
        Tuple ``(best_model, loss_train, loss_test, loss_train_all,
        f1_score_test, grad_variance_all, memory_allocated,
        max_memory_allocated, wall_clock_time)``. The three loss lists are
        seeded with the initial validation loss, and ``wall_clock_time``
        starts with a ``0`` entry followed by one timing dict per epoch.
    """
    memory_allocated, max_memory_allocated = [], []

    # one id per mini-batch sampled by the worker pool
    process_ids = np.arange(args.batch_num)
    lap_matrix_sq = lap_matrix.multiply(lap_matrix)

    susage = GCN(nfeat=feat_data.shape[1], nhid=args.nhid, num_classes=num_classes,
                 layers=args.n_layers, dropout=args.dropout).to(device)

    print(susage)

    # adjacency matrices used for every validation / test evaluation
    adjs_full, input_nodes_full, sampled_nodes_full = full_batch_sampler(
        train_nodes, len(feat_data), lap_matrix, args.n_layers)
    adjs_full = package_mxl(adjs_full, device)

    # args.use_SGD is compared as a string: anything but 'False' selects SGD
    if args.use_SGD == 'False':
        optimizer = optim.Adam(susage.parameters())
    else:
        optimizer = optim.SGD(susage.parameters(), lr=0.7)

    best_model = copy.deepcopy(susage)
    best_val_loss, _ = susage.calculate_loss_grad(feat_data, adjs_full, labels, valid_nodes)
    cnt = 0  # consecutive epochs without validation improvement

    wall_clock_time = [0]
    loss_train = [best_val_loss]
    loss_test = [best_val_loss]
    grad_variance_all = []
    loss_train_all = [best_val_loss]

    # early-stop threshold, expressed in epochs
    patience = args.n_stops // args.batch_num
    use_cuda = args.cuda != -1

    for epoch in np.arange(args.epoch_num):
        # create large batch
        large_batch_sample_time_start = time.perf_counter()
        large_adjs, large_input_nodes, large_output_nodes, large_sampled_nodes = sample_large_batch(
            args, train_nodes, large_samp_num_list, len(feat_data),
            lap_matrix, lap_matrix_sq, args.n_layers)
        large_batch_sample_time = time.perf_counter() - large_batch_sample_time_start

        large_batch_transfer_time_start = time.perf_counter()
        large_adjs = package_mxl(large_adjs, device)
        large_batch_transfer_time = time.perf_counter() - large_batch_transfer_time_start

        # memory statistics are recorded in MiB; 0 when running without CUDA
        memory_allocated.append(
            torch.cuda.memory_allocated(device)/1024/1024 if use_cuda else 0)
        max_memory_allocated.append(
            torch.cuda.max_memory_allocated(device)/1024/1024 if use_cuda else 0)

        # sample this epoch's mini-batches in a worker pool
        mini_batch_sample_time_start = time.perf_counter()
        pool = mp.Pool(args.pool_num)
        try:
            jobs = prepare_data(pool, sampler, process_ids, large_output_nodes, samp_num_list, len(feat_data),
                                lap_matrix, lap_matrix_sq, args.n_layers)
            # fetch train data
            train_data = [job.get() for job in jobs]
        finally:
            # always release the workers, even if sampling raised
            pool.close()
            pool.join()
        mini_batch_sample_time = time.perf_counter() - mini_batch_sample_time_start

        inner_loop_num = args.batch_num
        # gradient variance is only tracked during the first 20 epochs
        calculate_grad_vars = calculate_grad_vars and epoch<20
        cur_train_loss, cur_train_loss_all, grad_variance, time_counter = sgcn_first(
            susage, optimizer, feat_data, labels,
            train_nodes, valid_nodes,
            large_adjs, large_input_nodes, large_output_nodes, large_sampled_nodes,
            train_data, inner_loop_num, device, dist_bound=dist_bound,
            calculate_grad_vars=calculate_grad_vars)

        epoch_time_counter = {
            'large_batch_sample_time': large_batch_sample_time,
            'large_batch_transfer_time': large_batch_transfer_time,
            'mini_batch_sample_time': mini_batch_sample_time,
            'compute_time': time_counter['compute_time'],
            'mini_batch_transfer_time': time_counter['transfer_time'],
        }
        wall_clock_time.append(epoch_time_counter)
        loss_train_all.extend(cur_train_loss_all)
        grad_variance_all.extend(grad_variance)

        # calculate validation loss (reported below as "test loss")
        susage.eval()

        susage.zero_grad()
        val_loss, _ = susage.calculate_loss_grad(
            feat_data, adjs_full, labels, valid_nodes)

        if val_loss < best_val_loss:
            best_model = copy.deepcopy(susage)
            best_val_loss = val_loss
            cnt = 0
        else:
            cnt += 1

        # early stop: note the current epoch's losses are not appended below
        if cnt == patience:
            break

        cur_test_loss = val_loss

        loss_train.append(cur_train_loss)
        loss_test.append(cur_test_loss)

        # print progress
        print('Epoch: ', epoch,
              '| train loss: %.8f' % cur_train_loss,
              '| test loss: %.8f' % cur_test_loss)

    f1_score_test = best_model.calculate_f1(feat_data, adjs_full, labels, test_nodes)
    print('f1_score_test', f1_score_test)
    return best_model, loss_train, loss_test, loss_train_all, f1_score_test, grad_variance_all, memory_allocated, max_memory_allocated, wall_clock_time

In [None]:
# One pickle per (save_prefix, dataset); every experiment cell below adds its
# entry to `results` and rewrites this file.
fn = './results/%s_%s_results.pkl'%(args.save_prefix, args.dataset)

# Create the output directory up front: otherwise the first `open(fn, 'wb')`
# after a training run would fail with FileNotFoundError.
os.makedirs(os.path.dirname(fn), exist_ok=True)

if os.path.exists(fn):
    # NOTE: unpickling can execute arbitrary code -- only reuse result files
    # produced by this notebook.
    with open(fn, 'rb') as f:
        results = pkl.load(f)
else:
    results = dict()

In [None]:
# ###########################################################################################
# ########################################### Full ##########################################
# ###########################################################################################
if 'full' in sample_method_list:
    # Baseline: sgcn called with full_batch=True.
    st = time.time()
    print('>>> Full')
    (susage, loss_train, loss_test, loss_train_all, f1_score_test,
     grad_variance_all, memory_allocated, max_memory_allocated,
     wall_clock_time) = sgcn(
        feat_data, labels, lap_matrix,
        train_nodes, valid_nodes, test_nodes,
        args, device, calculate_grad_vars, full_batch=True)
    # Everything except the trained model is kept under this run's key.
    results['fullgcn'] = [
        loss_train, loss_test, loss_train_all, f1_score_test,
        grad_variance_all, memory_allocated, max_memory_allocated,
        wall_clock_time,
    ]
    print('fullgcn', time.time() - st)

    # Write the updated results dict back to `fn` right after the run.
    with open(fn, 'wb') as f:
        pkl.dump(results, f)

In [None]:
# ###########################################################################################
# ########################################### LADIES ########################################
# ###########################################################################################
if 'ladies' in sample_method_list:
    # Module-level sampler settings; the large batch uses 10x the per-layer size.
    sampler = ladies_sampler
    samp_num = 512
    samp_num_list = np.array([samp_num] * args.n_layers)
    large_samp_num = 10 * samp_num
    large_samp_num_list = np.array([large_samp_num] * args.n_layers)

    # Each enabled variant is trained, recorded in `results`, and saved to `fn`.
    if doubly_order:
        st = time.time()
        print('>>> ladies_doubly')
        (susage, loss_train, loss_test, loss_train_all, f1_score_test,
         grad_variance_all, memory_allocated, max_memory_allocated,
         wall_clock_time) = algorithm_sgcn_doubly(
            feat_data, labels, lap_matrix,
            train_nodes, valid_nodes, test_nodes,
            args, device, calculate_grad_vars)
        results['ladies_doubly'] = [
            loss_train, loss_test, loss_train_all, f1_score_test,
            grad_variance_all, memory_allocated, max_memory_allocated,
            wall_clock_time,
        ]
        print('ladies_doubly', time.time() - st)

        with open(fn, 'wb') as f:
            pkl.dump(results, f)

    if vanilla:
        st = time.time()
        print('>>> ladies')
        (susage, loss_train, loss_test, loss_train_all, f1_score_test,
         grad_variance_all, memory_allocated, max_memory_allocated,
         wall_clock_time) = sgcn(
            feat_data, labels, lap_matrix,
            train_nodes, valid_nodes, test_nodes,
            args, device, calculate_grad_vars, full_batch=False)
        results['ladies'] = [
            loss_train, loss_test, loss_train_all, f1_score_test,
            grad_variance_all, memory_allocated, max_memory_allocated,
            wall_clock_time,
        ]
        print('ladies', time.time() - st)

        with open(fn, 'wb') as f:
            pkl.dump(results, f)

    if zeroth_order:
        st = time.time()
        print('>>> ladies_zeroth')
        (susage, loss_train, loss_test, loss_train_all, f1_score_test,
         grad_variance_all, memory_allocated, max_memory_allocated,
         wall_clock_time) = algorithm_sgcn_zeroth(
            feat_data, labels, lap_matrix,
            train_nodes, valid_nodes, test_nodes,
            args, device, calculate_grad_vars)
        results['ladies_zeroth'] = [
            loss_train, loss_test, loss_train_all, f1_score_test,
            grad_variance_all, memory_allocated, max_memory_allocated,
            wall_clock_time,
        ]
        print('ladies_zeroth', time.time() - st)

        with open(fn, 'wb') as f:
            pkl.dump(results, f)

In [None]:
###########################################################################################
########################################### FastGCN #######################################
###########################################################################################
if 'fastgcn' in sample_method_list:
    # Module-level sampler settings; the large batch uses 10x the per-layer size.
    sampler = fastgcn_sampler
    samp_num = 4096
    samp_num_list = np.array([samp_num] * args.n_layers)
    large_samp_num = 10 * samp_num
    large_samp_num_list = np.array([large_samp_num] * args.n_layers)

    # Each enabled variant is trained, recorded in `results`, and saved to `fn`.
    if doubly_order:
        st = time.time()
        print('>>> fastgcn_doubly')
        (susage, loss_train, loss_test, loss_train_all, f1_score_test,
         grad_variance_all, memory_allocated, max_memory_allocated,
         wall_clock_time) = algorithm_sgcn_doubly(
            feat_data, labels, lap_matrix,
            train_nodes, valid_nodes, test_nodes,
            args, device, calculate_grad_vars)
        results['fastgcn_doubly'] = [
            loss_train, loss_test, loss_train_all, f1_score_test,
            grad_variance_all, memory_allocated, max_memory_allocated,
            wall_clock_time,
        ]
        print('fastgcn_doubly', time.time() - st)

        with open(fn, 'wb') as f:
            pkl.dump(results, f)

    if vanilla:
        st = time.time()
        print('>>> fastgcn')
        (susage, loss_train, loss_test, loss_train_all, f1_score_test,
         grad_variance_all, memory_allocated, max_memory_allocated,
         wall_clock_time) = sgcn(
            feat_data, labels, lap_matrix,
            train_nodes, valid_nodes, test_nodes,
            args, device, calculate_grad_vars, full_batch=False)
        results['fastgcn'] = [
            loss_train, loss_test, loss_train_all, f1_score_test,
            grad_variance_all, memory_allocated, max_memory_allocated,
            wall_clock_time,
        ]
        print('fastgcn', time.time() - st)

        with open(fn, 'wb') as f:
            pkl.dump(results, f)

    if zeroth_order:
        st = time.time()
        print('>>> fastgcn_zeroth')
        (susage, loss_train, loss_test, loss_train_all, f1_score_test,
         grad_variance_all, memory_allocated, max_memory_allocated,
         wall_clock_time) = algorithm_sgcn_zeroth(
            feat_data, labels, lap_matrix,
            train_nodes, valid_nodes, test_nodes,
            args, device, calculate_grad_vars)
        results['fastgcn_zeroth'] = [
            loss_train, loss_test, loss_train_all, f1_score_test,
            grad_variance_all, memory_allocated, max_memory_allocated,
            wall_clock_time,
        ]
        print('fastgcn_zeroth', time.time() - st)

        with open(fn, 'wb') as f:
            pkl.dump(results, f)

In [None]:
###########################################################################################
########################################### GraphSaint ####################################
###########################################################################################
if 'graphsaint' in sample_method_list:
    # Module-level sampler settings; the large batch uses 10x the per-layer size.
    sampler = graphsaint_sampler
    samp_num = 2048
    samp_num_list = np.array([samp_num] * args.n_layers)
    large_samp_num = 10 * samp_num
    large_samp_num_list = np.array([large_samp_num] * args.n_layers)

    # Each enabled variant is trained, recorded in `results`, and saved to `fn`.
    if doubly_order:
        st = time.time()
        print('>>> graphsaint_doubly')
        (susage, loss_train, loss_test, loss_train_all, f1_score_test,
         grad_variance_all, memory_allocated, max_memory_allocated,
         wall_clock_time) = algorithm_sgcn_doubly(
            feat_data, labels, lap_matrix,
            train_nodes, valid_nodes, test_nodes,
            args, device, calculate_grad_vars)
        results['graphsaint_doubly'] = [
            loss_train, loss_test, loss_train_all, f1_score_test,
            grad_variance_all, memory_allocated, max_memory_allocated,
            wall_clock_time,
        ]
        print('graphsaint_doubly', time.time() - st)

        with open(fn, 'wb') as f:
            pkl.dump(results, f)

    if vanilla:
        st = time.time()
        print('>>> graphsaint')
        (susage, loss_train, loss_test, loss_train_all, f1_score_test,
         grad_variance_all, memory_allocated, max_memory_allocated,
         wall_clock_time) = sgcn(
            feat_data, labels, lap_matrix,
            train_nodes, valid_nodes, test_nodes,
            args, device, calculate_grad_vars, full_batch=False)
        results['graphsaint'] = [
            loss_train, loss_test, loss_train_all, f1_score_test,
            grad_variance_all, memory_allocated, max_memory_allocated,
            wall_clock_time,
        ]
        print('graphsaint', time.time() - st)

        with open(fn, 'wb') as f:
            pkl.dump(results, f)

    if zeroth_order:
        st = time.time()
        print('>>> graphsaint_zeroth')
        (susage, loss_train, loss_test, loss_train_all, f1_score_test,
         grad_variance_all, memory_allocated, max_memory_allocated,
         wall_clock_time) = algorithm_sgcn_zeroth(
            feat_data, labels, lap_matrix,
            train_nodes, valid_nodes, test_nodes,
            args, device, calculate_grad_vars)
        results['graphsaint_zeroth'] = [
            loss_train, loss_test, loss_train_all, f1_score_test,
            grad_variance_all, memory_allocated, max_memory_allocated,
            wall_clock_time,
        ]
        print('graphsaint_zeroth', time.time() - st)

        with open(fn, 'wb') as f:
            pkl.dump(results, f)

In [None]:
###########################################################################################
########################################### GraphSage #####################################
###########################################################################################
if 'graphsage' in sample_method_list:
    # Module-level sampler settings; the large batch uses 10x the per-layer size.
    sampler = graphsage_sampler
    samp_num = 5
    samp_num_list = np.array([samp_num] * args.n_layers)
    large_samp_num = 10 * samp_num
    large_samp_num_list = np.array([large_samp_num] * args.n_layers)

    # Each enabled variant is trained, recorded in `results`, and saved to `fn`.
    if doubly_order:
        st = time.time()
        print('>>> graphsage_doubly')
        (susage, loss_train, loss_test, loss_train_all, f1_score_test,
         grad_variance_all, memory_allocated, max_memory_allocated,
         wall_clock_time) = algorithm_sgcn_doubly(
            feat_data, labels, lap_matrix,
            train_nodes, valid_nodes, test_nodes,
            args, device, calculate_grad_vars)
        results['graphsage_doubly'] = [
            loss_train, loss_test, loss_train_all, f1_score_test,
            grad_variance_all, memory_allocated, max_memory_allocated,
            wall_clock_time,
        ]
        print('graphsage_doubly', time.time() - st)

        with open(fn, 'wb') as f:
            pkl.dump(results, f)

    if vanilla:
        st = time.time()
        print('>>> graphsage')
        (susage, loss_train, loss_test, loss_train_all, f1_score_test,
         grad_variance_all, memory_allocated, max_memory_allocated,
         wall_clock_time) = sgcn(
            feat_data, labels, lap_matrix,
            train_nodes, valid_nodes, test_nodes,
            args, device, calculate_grad_vars, full_batch=False)
        results['graphsage'] = [
            loss_train, loss_test, loss_train_all, f1_score_test,
            grad_variance_all, memory_allocated, max_memory_allocated,
            wall_clock_time,
        ]
        print('graphsage', time.time() - st)

        with open(fn, 'wb') as f:
            pkl.dump(results, f)

    if zeroth_order:
        st = time.time()
        print('>>> graphsage_zeroth')
        (susage, loss_train, loss_test, loss_train_all, f1_score_test,
         grad_variance_all, memory_allocated, max_memory_allocated,
         wall_clock_time) = algorithm_sgcn_zeroth(
            feat_data, labels, lap_matrix,
            train_nodes, valid_nodes, test_nodes,
            args, device, calculate_grad_vars)
        results['graphsage_zeroth'] = [
            loss_train, loss_test, loss_train_all, f1_score_test,
            grad_variance_all, memory_allocated, max_memory_allocated,
            wall_clock_time,
        ]
        print('graphsage_zeroth', time.time() - st)

        with open(fn, 'wb') as f:
            pkl.dump(results, f)

In [None]:
###########################################################################################
########################################### Exact #########################################
###########################################################################################
if 'exact' in sample_method_list:
    # samp_num is 0 here, so both sample-size arrays are all zeros.
    sampler = exact_sampler
    samp_num = 0
    samp_num_list = np.array([samp_num] * args.n_layers)
    large_samp_num = 10 * samp_num
    large_samp_num_list = np.array([large_samp_num] * args.n_layers)

    if vanilla:
        st = time.time()
        print('>>> exact')
        (susage, loss_train, loss_test, loss_train_all, f1_score_test,
         grad_variance_all, memory_allocated, max_memory_allocated,
         wall_clock_time) = sgcn(
            feat_data, labels, lap_matrix,
            train_nodes, valid_nodes, test_nodes,
            args, device, calculate_grad_vars, full_batch=False)
        results['exact'] = [
            loss_train, loss_test, loss_train_all, f1_score_test,
            grad_variance_all, memory_allocated, max_memory_allocated,
            wall_clock_time,
        ]
        print('exact', time.time() - st)

        with open(fn, 'wb') as f:
            pkl.dump(results, f)

    # The doubly_order flag selects the first-order routine for this sampler;
    # the run is stored under 'exact_first'.
    if doubly_order:
        st = time.time()
        print('>>> exact_first')
        (susage, loss_train, loss_test, loss_train_all, f1_score_test,
         grad_variance_all, memory_allocated, max_memory_allocated,
         wall_clock_time) = algorithm_sgcn_first(
            feat_data, labels, lap_matrix,
            train_nodes, valid_nodes, test_nodes,
            args, device, calculate_grad_vars)
        results['exact_first'] = [
            loss_train, loss_test, loss_train_all, f1_score_test,
            grad_variance_all, memory_allocated, max_memory_allocated,
            wall_clock_time,
        ]
        print('exact_first', time.time() - st)

        with open(fn, 'wb') as f:
            pkl.dump(results, f)

In [None]:
###########################################################################################
########################################### VRGCN #########################################
###########################################################################################
if 'vrgcn' in sample_method_list:
    # NOTE(review): unlike the other sampler cells, this one does not set
    # large_samp_num_list -- if the vrgcn routines read it, they see whatever
    # an earlier cell left behind. Confirm.
    samp_num = 2
    sampler = vrgcn_sampler
    samp_num_list = np.array([samp_num] * args.n_layers)

    if doubly_order:
        st = time.time()
        print('>>> vrgcn_doubly')
        (susage, loss_train, loss_test, loss_train_all, f1_score_test,
         grad_variance_all, memory_allocated, max_memory_allocated,
         wall_clock_time) = algorithm_vrgcn_doubly(
            feat_data, labels, lap_matrix,
            train_nodes, valid_nodes, test_nodes,
            args, device, calculate_grad_vars)
        results['vrgcn_doubly'] = [
            loss_train, loss_test, loss_train_all, f1_score_test,
            grad_variance_all, memory_allocated, max_memory_allocated,
            wall_clock_time,
        ]
        print('vrgcn_doubly', time.time() - st)

        with open(fn, 'wb') as f:
            pkl.dump(results, f)

    # NOTE(review): this run is gated by `vanilla` but labelled and stored as
    # 'vrgcn_zeroth'; other cells store vanilla runs under the bare method
    # name. Confirm which of the flag / key is intended before changing.
    if vanilla:
        st = time.time()
        print('>>> vrgcn_zeroth')
        (susage, loss_train, loss_test, loss_train_all, f1_score_test,
         grad_variance_all, memory_allocated, max_memory_allocated,
         wall_clock_time) = algorithm_vrgcn(
            feat_data, labels, lap_matrix,
            train_nodes, valid_nodes, test_nodes,
            args, device, calculate_grad_vars)
        results['vrgcn_zeroth'] = [
            loss_train, loss_test, loss_train_all, f1_score_test,
            grad_variance_all, memory_allocated, max_memory_allocated,
            wall_clock_time,
        ]
        print('vrgcn_zeroth', time.time() - st)

        with open(fn, 'wb') as f:
            pkl.dump(results, f)