In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

from utils.consensus_node import ConsensusNode
from utils.master_node import MasterNode

import wide_resnet_submodule.config as cf
import torchvision
import torchvision.transforms as transforms
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torch.optim as optim

from wide_resnet_submodule.networks import *
from utils.graphs_config import *
from utils.config import *
from utils.functions import *

In [None]:
num_classes = 10  # number of output classes passed to every network constructor

def get_model(model_name):
    """Look up a network class and its constructor arguments by name.

    Args:
        model_name: One of 'lenet', 'vggnet', 'resnet' or 'wide-resnet'
            (exact, lower-case match).

    Returns:
        A ``(model, model_args)`` pair: the network class and the list of
        positional arguments to construct it with.

    Raises:
        ValueError: If ``model_name`` is not one of the supported names.
            (Raising keeps the notebook kernel alive and signals failure,
            unlike exiting the interpreter with status 0.)
    """
    if model_name == 'lenet':
        return LeNet, [num_classes]
    if model_name == 'vggnet':
        # VGGnet depth should be either 11, 13, 16, 19
        return VGG, [11, num_classes]
    if model_name == 'resnet':
        # Resnet depth should be either 18, 34, 50, 101, 152
        return ResNet, [18, num_classes]
    if model_name == 'wide-resnet':
        # depth, widen_factor, dropout (Wide-resnet depth should be 6n+4)
        return Wide_ResNet, [28, 10, 0.3, num_classes]

    raise ValueError(
        f"Unknown model name {model_name!r}: network should be one of "
        "['lenet', 'vggnet', 'resnet', 'wide-resnet']"
    )

# Loss and optimizer are kept as classes (not instances); train_model hands
# them to the MasterNode through set_error / set_optimizer.
criterion = nn.CrossEntropyLoss
optimizer = optim.SGD

# Extra keyword arguments passed alongside the optimizer class.
optimizer_kwargs = dict(momentum=0.9, weight_decay=5e-4)

In [None]:
# Mini-batch size for the per-node training loaders, read from the
# wide-resnet config module.
batch_size = cf.batch_size

# Dataset key: indexes cf.mean / cf.std for normalization and prefixes the
# session name and the checkpoint path.
dataset_name = 'cifar10'

In [None]:
def get_train_test(topology, seed):
    """Build one CIFAR-10 training loader per node plus a shared test loader.

    The global NumPy and PyTorch RNGs are re-seeded with ``seed``. The
    training indices are then shuffled, truncated to a multiple of the number
    of nodes and split into equally sized, disjoint shards, one per node.

    Args:
        topology: Sized iterable of node names; its length is the number of
            agents and its iteration order decides which shard a node gets.
        seed: Seed applied to ``np.random`` and ``torch``.

    Returns:
        ``(train_loaders, test_loader)`` where ``train_loaders`` maps each
        node name to the DataLoader over its shard.
    """
    np.random.seed(seed)
    torch.manual_seed(seed)

    data_root = '../data/cifar10'
    mean, std = cf.mean[dataset_name], cf.std[dataset_name]

    # Training images are augmented (random crop + flip) before the
    # mean/std normalization; test images are only normalized.
    augment_and_normalize = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    normalize_only = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    trainset = torchvision.datasets.CIFAR10(
        root=data_root, train=True, download=True,
        transform=augment_and_normalize,
    )
    # The training split above is created first with download=True, so the
    # archive is already on disk when the test split is opened.
    testset = torchvision.datasets.CIFAR10(
        root=data_root, train=False, download=False,
        transform=normalize_only,
    )

    n_agents = len(topology)
    shard_size = len(trainset) // n_agents

    order = list(range(len(trainset)))
    np.random.shuffle(order)
    # Drop the remainder so every node receives exactly shard_size samples.
    shards = np.array_split(order[:n_agents * shard_size], n_agents)

    train_loaders = {}
    for node_name, shard in zip(topology, shards):
        node_subset = torch.utils.data.Subset(trainset, indices=shard)
        train_loaders[node_name] = torch.utils.data.DataLoader(
            node_subset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=2,
        )

    test_loader = torch.utils.data.DataLoader(
        testset, batch_size=100, shuffle=False, num_workers=2,
    )

    return train_loaders, test_loader

In [None]:
# Detect GPU support once; train_model forwards this flag to the MasterNode.
use_cuda = torch.cuda.is_available()
print("Yes, we use CUDA!" if use_cuda else "CUDA is not available. Check your drivers.")

In [None]:
def train_model(stat_funcs, session_name, seed, w_schedule_func, topology, lr_schedule_func=lr_schedule_default, lr=0.02, equal_w=False, is_consensus=True, num_epochs=120, model_name='resnet'):
    """Run one training session on ``topology`` and return its MasterNode.

    Args:
        stat_funcs: Mapping of statistic name -> function, handed to
            ``MasterNode.set_stats``.
        session_name: Session identifier; it is prefixed with ``dataset_name``
            to form the MasterNode ``session``.
        seed: Seed forwarded to ``get_train_test``.
        w_schedule_func: Passed to the MasterNode as ``w_schedule``.
        topology: Mapping keyed by node name; also passed as the MasterNode
            ``weights`` (param_a = a*param_a + b*param_b + c*param_c, where
            a + b + c = 1.0).
        lr_schedule_func: Passed to the MasterNode as ``lr_schedule``.
        lr: Passed to the MasterNode as ``lr``.
        equal_w: When true, ``equalize_all_model_params()`` is called on the
            master before training starts.
        is_consensus: When false, the first consensus epoch is set to
            ``num_epochs + 1`` (presumably never reached, so no parameter
            exchange happens - confirm against MasterNode).
        num_epochs: Number of epochs.
        model_name: Key understood by ``get_model``.

    Returns:
        The MasterNode, after ``start_consensus()`` has returned.
    """
    train_loaders, test_loader = get_train_test(topology, seed)

    # Every node runs the same number of steps per epoch: the length of the
    # shortest training loader.
    epoch_len = min(map(len, train_loaders.values()))

    # One fresh accumulator record per (statistic, node) pair.
    statistics = {}
    for func_name in stat_funcs:
        per_node = {}
        for node_name in topology:
            per_node[node_name] = {'values': [], 'epoch': [], 'tmp': 0.0}
        statistics[func_name] = per_node

    # the first epoch from which consensus begins
    consensus_start = num_epochs + 1 if not is_consensus else 0

    master_config = dict(
        node_names=topology.keys(),
        weights=topology,
        train_loaders=train_loaders,        # train_loaders[node_name] = train loader for node_name
        test_loader=test_loader,            # general test loader
        fit_step=fit_batch_cifar,           # function(node: ConsensusNode, epoch: Int)
        update_params=update_params_cifar,  # function(node: ConsensusNode)
        lr=lr,
        w_schedule=w_schedule_func,
        lr_schedule=lr_schedule_func,
        epoch=num_epochs,                   # number of epochs
        epoch_len=epoch_len,                # length each epoch
        update_params_epoch_start=consensus_start,
        update_params_period=1,             # consensus iteration period
        use_cuda=use_cuda,
        resume_path=None,                   # f'./checkpoint/{dataset_name}/{session_name}',
        session=f'{dataset_name}/{session_name}',
        verbose=0,                          # verbose mode
    )
    master = MasterNode(**master_config)

    model, model_args = get_model(model_name)
    master.set_model(model, *model_args)
    master.set_optimizer(optimizer, optimizer_kwargs)
    master.set_error(criterion)
    master.set_stats(stat_funcs=stat_funcs, statistics=statistics)
    master.initialize_nodes()

    if equal_w:
        master.equalize_all_model_params()

    master.start_consensus()
    return master

In [None]:
def get_stat_funcs(topology_name):
    """Return the statistic-name -> function mapping tracked for a topology.

    Every topology tracks 'test_accuracy' and 'cumulative_train_loss'. Any
    topology other than 'LONELY' additionally tracks 'param_dev' and
    'self_weight'.
    """
    stat_funcs = {
        'test_accuracy': calc_accuracy_cifar,
        'cumulative_train_loss': get_cumulative_train_loss,
    }
    if topology_name != 'LONELY':
        stat_funcs['param_dev'] = get_flat_params_cifar
        stat_funcs['self_weight'] = get_self_weight
    return stat_funcs

In [None]:
# Full-length run: number of training epochs and the network key that
# get_model understands.
num_epochs = 120
model_name = 'resnet'

# Quick smoke-test alternative (uncomment to override the values above):
#num_epochs = 3
#model_name = 'lenet'

In [None]:
# Parallel lists: topology[i] is a graph definition (star-imported) and
# topology_name[i] is the label used for it in session paths. Keep both lists
# the same length and in the same order.
topology =      [ABC_3,   LONELY]
topology_name = ['ABC_3', 'LONELY']

In [None]:
# Random seeds; the training loop runs one session per (topology, seed) pair.
seeds = [0, 13, 42, 1337]
# Single-seed alternative for a shorter sweep:
#seeds = [1137]

In [None]:
# Names of the sessions completed in this kernel. The training loop skips
# sessions listed here and the plotting cells iterate over them.
# NOTE: re-running this cell clears the list.
finished_sessions = []

In [None]:
# Sweep every (topology, seed) combination. Sessions already listed in
# finished_sessions are skipped, so the cell can be re-run after an
# interruption without repeating completed work.
for top_iter, top in enumerate(topology):
    top_name = topology_name[top_iter]
    for seed in seeds:
        stat_funcs = get_stat_funcs(top_name)

        # LONELY keeps the base learning rate and the default schedule; other
        # topologies multiply the rate by len(top) and use the div3 schedule.
        lr, lr_schedule = (
            (0.02, lr_schedule_default) if top_name == 'LONELY'
            else (0.02 * len(top), lr_schedule_div3)
        )

        session_name = f'lr_div3_testing/seed_{seed}/{top_name}/lr_{lr}/weights_log_dec/{model_name}'

        if session_name in finished_sessions:
            print(f'Session {session_name} already finished.')
            continue

        print(f'\nSession {session_name} started.')
        train_model(
            stat_funcs,
            session_name,
            seed,
            weights_schedule_log_decrease,
            top,
            lr_schedule_func=lr_schedule,
            lr=lr,
            is_consensus=top_name != 'LONELY',
            num_epochs=num_epochs,
            model_name=model_name,
        )
        print(f'\nSession {session_name} ended.')
        finished_sessions.append(session_name)

In [None]:
import pickle5 as pickle

def load_stats(session_name):
    """Load the pickled statistics stored for a session.

    Reads ``./checkpoint/<dataset_name>/<session_name>/statistics.pickle``
    and returns the unpickled object.

    NOTE: unpickling can execute arbitrary code; only load checkpoint files
    that you produced yourself.
    """
    stats_file = f'./checkpoint/{dataset_name}/{session_name}/' + 'statistics.pickle'
    with open(stats_file, 'rb') as stats_fh:
        return pickle.load(stats_fh)

def get_cmap(n, name='hsv'):
    '''
    Returns a function that maps each index in 0, 1, ..., n-1 to a distinct
    RGB color; the keyword argument name must be a standard mpl colormap name.

    Uses the ``matplotlib.colormaps`` registry where available, because
    ``plt.cm.get_cmap`` was deprecated in Matplotlib 3.7 and removed in 3.9.
    Older Matplotlib versions fall back to the legacy call.

    Raises:
        ValueError: If ``name`` is not a registered colormap name.
    '''
    try:
        from matplotlib import colormaps
    except ImportError:
        # Old Matplotlib without the colormap registry.
        return plt.cm.get_cmap(name, n)

    try:
        base_cmap = colormaps[name]
    except KeyError:
        # Keep the exception type the legacy lookup raised for unknown names.
        raise ValueError(f'{name!r} is not a known colormap name') from None

    if hasattr(base_cmap, 'resampled'):
        return base_cmap.resampled(n)
    # Registry exists but Colormap.resampled does not yet.
    return plt.cm.get_cmap(name, n)

In [None]:
# Test accuracy per epoch, one curve per finished session.
fig, ax = plt.subplots(figsize=(24, 12), ncols=1)
ax.set_xlabel('Epoch', fontsize=16)
ax.set_ylabel('Test accuracy', fontsize=16)
ax.set_title('Lr div3 testing.', fontsize=20, pad=10.0)

for session in finished_sessions:
    # Session names look like '<experiment>/seed_<seed>/<topology>/...'.
    is_lonely = session.split('/')[2] == 'LONELY'
    # LONELY sessions are read from node 'Model'; other topologies from 'Alice'.
    node_name = 'Model' if is_lonely else 'Alice'

    data = load_stats(session)
    values = data['test_accuracy'][node_name]['values']
    epochs = data['test_accuracy'][node_name]['epoch']

    # Best accuracy goes into the legend; an empty series shows 'nan'
    # instead of raising.
    best = max(values, default=float('nan'))
    ax.plot(epochs, values, label=f'{session}: {best:.2f}',
            linestyle='-' if is_lonely else '--')

ax.grid(True)
if ax.lines:  # skip the legend (and its warning) when nothing was plotted
    ax.legend(fontsize=16)
plt.show()

In [None]:
# Cumulative training loss per epoch (log scale), one curve per finished session.
fig, ax = plt.subplots(figsize=(24, 12), ncols=1)
ax.set_xlabel('Epoch', fontsize=16)
ax.set_ylabel('Train_loss', fontsize=16)
ax.set_title('Lr div3 testing.', fontsize=20, pad=10.0)

for session in finished_sessions:
    # Session names look like '<experiment>/seed_<seed>/<topology>/...'.
    is_lonely = session.split('/')[2] == 'LONELY'
    # LONELY sessions are read from node 'Model'; other topologies from 'Alice'.
    node_name = 'Model' if is_lonely else 'Alice'

    data = load_stats(session)
    values = data['cumulative_train_loss'][node_name]['values']
    epochs = data['cumulative_train_loss'][node_name]['epoch']

    ax.plot(epochs, values, label=session,
            linestyle='-' if is_lonely else '--')

ax.set_yscale('log')
ax.grid(True)
if ax.lines:  # skip the legend (and its warning) when nothing was plotted
    ax.legend(fontsize=16)
plt.show()

In [None]:
# 'param_dev' statistic per epoch (log scale) for the non-LONELY sessions.
fig, ax = plt.subplots(figsize=(24, 12), ncols=1)
ax.set_xlabel('Epoch', fontsize=16)
ax.set_ylabel('Param_dev', fontsize=16)
ax.set_title('Lr div3 testing.', fontsize=20, pad=10.0)

for session in finished_sessions:
    # Session names look like '<experiment>/seed_<seed>/<topology>/...'.
    # 'param_dev' is not tracked for LONELY, so those sessions are skipped
    # before their statistics file is even opened.
    if session.split('/')[2] == 'LONELY':
        continue

    data = load_stats(session)
    values = data['param_dev']['Alice']['values']
    epochs = data['param_dev']['Alice']['epoch']
    ax.plot(epochs, values, label=session, linestyle='--')

ax.set_yscale('log')
ax.grid(True)
if ax.lines:  # skip the legend (and its warning) when nothing was plotted
    ax.legend(fontsize=16)
plt.show()

In [None]:
# 'self_weight' statistic per epoch (log scale) for the non-LONELY sessions.
fig, ax = plt.subplots(figsize=(24, 12), ncols=1)
ax.set_xlabel('Epoch', fontsize=16)
ax.set_ylabel('Self-weight', fontsize=16)
ax.set_title('Lr div3 testing.', fontsize=20, pad=10.0)

for session in finished_sessions:
    # Session names look like '<experiment>/seed_<seed>/<topology>/...'.
    # 'self_weight' is not tracked for LONELY, so those sessions are skipped
    # before their statistics file is even opened.
    if session.split('/')[2] == 'LONELY':
        continue

    data = load_stats(session)
    values = data['self_weight']['Alice']['values']
    epochs = data['self_weight']['Alice']['epoch']
    ax.plot(epochs, values, label=session, linestyle='--')

ax.set_yscale('log')
ax.grid(True)
if ax.lines:  # skip the legend (and its warning) when nothing was plotted
    ax.legend(fontsize=16)
plt.show()