In [None]:
import os

# Pin this notebook to a single GPU. With PCI_BUS_ID ordering, CUDA device
# indices follow PCI bus order, so '5' below refers to the same card that
# nvidia-smi lists as GPU 5.
# NOTE: this must take effect before CUDA is initialised, which is why it lives
# in the very first cell.
os.environ.update({
    'CUDA_DEVICE_ORDER': 'PCI_BUS_ID',
    'CUDA_VISIBLE_DEVICES': '5',
})

In [None]:
import sys

# Make the parent of sys.path[0] (typically the notebook's own directory)
# importable, so the project-local modules imported below (model, utils,
# training, federated_learning) can be found. Index 1 keeps sys.path[0] first.
project_root = os.path.join(sys.path[0], '../')
sys.path.insert(1, project_root)

In [None]:
import torch
import torchvision

# Use the (first visible) GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
import random
import numpy as np

# Seed for reproducible runs. `None` keeps this notebook's default behaviour:
# nothing is seeded, so client sampling (np.random.choice in the training loop)
# and model initialisation differ from run to run. Set an int, e.g. 42, to make
# them repeatable.
seed = None


def set_seed(seed_value):
    """Seed Python's `random`, NumPy's global RNG and PyTorch with one value.

    Args:
        seed_value: integer seed applied to every library.

    Note:
        PYTHONHASHSEED is read only when an interpreter starts, so setting it
        here affects only Python processes launched afterwards, not this kernel.
    """
    os.environ['PYTHONHASHSEED'] = str(seed_value)
    random.seed(seed_value)
    np.random.seed(seed_value)
    # Uses the same value as the other RNGs (the old commented code hard-coded 0).
    torch.manual_seed(seed_value)


if seed is not None:
    set_seed(seed)

In [None]:
import copy
import matplotlib.pyplot as plt
import model

import utils
from training import get_config, get_config_fedir, ImgDataset, evaluate_model
import federated_learning as fl

In [None]:
# Define dataset and settings.

dataset_name = 'cifar_10'
# dataset_name = 'cifar_100'
# dataset_name = 'tiny_imagenet_200'
# dataset_name = 'svhn'
# dataset_name = 'fashion_mnist'

# Define distribution settings.
num_clients = 100  # [10, 50, 100]

# Number of clients sampled to train in each round, per supported client count.
participation_per_num_clients = {10: 10, 50: 10, 100: 20}
if num_clients not in participation_per_num_clients:
    # Fail early instead of leaving `num_participation` undefined until the training loop.
    raise ValueError('Unsupported num_clients={}; expected one of {}.'.format(
        num_clients, sorted(participation_per_num_clients)))
num_participation = participation_per_num_clients[num_clients]
client_idxes = list(range(num_clients))

# NOTE(review): presumably the concentration parameter used when partitioning
# data across clients (smaller = more heterogeneous); confirm against the
# script that generated `client_data/`.
beta = 0.5 # [0.1, 0.5, 5]

# Pre-partitioned per-client data: one sub-folder per client index.
client_data_dir = os.path.join('./client_data/', dataset_name + '_c_{}_beta_{}'.format(num_clients, beta))

In [None]:
data_config, train_config = get_config(dataset_name)

# Data config.
img_size = data_config['img_size']
channels = data_config['channels']
batch_size = data_config['batch_size']
train_transform = data_config['train_transform']
test_transform = data_config['test_transform']

# Training config.
fedir_args = get_config_fedir()
num_rounds = train_config['rounds']
num_local_epochs = train_config['local_epochs']
save_interval = train_config['save_interval']   # checkpoint every `save_interval` rounds

# mu: balancing parameter. Values tried are listed in the trailing comment.
fedir_args['mu'] = 10   # [0.001, 0.01, 0.1, 0.5, 1, 3, 5, 10]

fedir_args['loss'] = 'contrastive'      # ['contrastive', 'mse', 'neg_cosine']
fedir_args['scaling'] = 'cosine'        # ['avg', 'cosine', 'cka']

# Optimizer hyper-parameters, forwarded to fl.local_update_fedir for local training.
optim = train_config['optim']      # ['sgd', 'adam']
if optim == 'sgd':
    optim_args = {
        'lr': train_config['lr'],
        'weight_decay': train_config['weight_decay'],
        'momentum' : train_config['momentum'],
    }
elif optim == 'adam':
    optim_args = {'lr': train_config['lr']}
else:
    # Fail here with a clear message rather than a TypeError on `optim_args['lr']` below.
    raise ValueError("Unsupported optimizer '{}'; expected 'sgd' or 'adam'.".format(optim))

# The output directory name encodes the run configuration:
# c: num_clients, b: beta, le: num_local_epochs, mu: balancing parameter,
# l: fedir_args['loss'], sim: fedir_args['scaling'], opt: optimizer, lr: learning rate.
save_dir = os.path.join('./output/fedir', dataset_name + '_c_{}_b_{}_le_{}_mu_{}_l_{}_sim_{}_opt_{}_lr_{}'.format(
    num_clients,
    beta,
    num_local_epochs,
    fedir_args['mu'],
    fedir_args['loss'],
    fedir_args['scaling'],
    optim,
    optim_args['lr']
))
os.makedirs(save_dir, exist_ok=True)

print(data_config)
print(train_config)
print(fedir_args)

In [None]:
# def get_linear_mu():
#     num_schedule_rounds = num_rounds - flat_rounds * 2
#     mu_schedule = [ini_mu] * flat_rounds
#     mu_schedule.extend(np.linspace(ini_mu, max_mu, num_schedule_rounds))
#     mu_schedule.extend([max_mu] * flat_rounds)
#     return mu_schedule

# def get_linear_mu():
#     num_schedule_rounds = num_rounds - (start_linear + end_linear)
#     mu_schedule = [ini_mu] * start_linear
#     mu_schedule.extend(np.linspace(ini_mu, max_mu, num_schedule_rounds))
#     mu_schedule.extend([max_mu] * end_linear)
#     return mu_schedule

# mu_schedule = get_linear_mu()
# plot_config = utils.ACC_PLOT_CONFIG.copy()
# plot_config['figsize'] = (12, 6)
# plot_config['save_dir'] = os.path.join(save_dir, 'mu.png')
# plot_config['show_img'] = True
# plot_config['xlabel'] = 'rounds'
# plot_config['ylabel'] = 'mu'
# plot_config['labels'] = ['mu']
# data_list = [
#     mu_schedule
# ]
# utils.save_plot(data_list, plot_config)

In [None]:
# Data: root folder of the raw dataset for `dataset_name`.
data_dir = os.path.join('../datasets/', dataset_name)
test_dir = os.path.join(data_dir, 'test')

# Centralized test set used to evaluate the aggregated global model.
# No shuffling, so evaluation order is fixed.
test_data = ImgDataset(test_dir, transform=test_transform)
test_loader = torch.utils.data.DataLoader(
    test_data,
    batch_size=batch_size,
    shuffle=False,
    num_workers=4,
)
# Alternative that keeps workers alive between epochs and pins host memory:
# test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=4, persistent_workers=True, pin_memory=True)

# The number of classes is taken from the test split's class list.
num_classes = len(test_data.classes)

In [None]:
# Define client data: one shuffled training loader per client, read from
# `client_data_dir/<client_idx>` and stored at index `client_idx`.
# Dedicated names are used for the per-client folder and dataset, so that the
# dataset root `data_dir` from the previous cell is not silently replaced by the
# last client's folder after this cell runs.
client_loaders = []
for client_idx in range(num_clients):
    client_dir = os.path.join(client_data_dir, str(client_idx))
    client_dataset = ImgDataset(client_dir, transform=train_transform)
    client_loaders.append(
        torch.utils.data.DataLoader(client_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
    )
    # Alternative: torch.utils.data.DataLoader(client_dataset, batch_size=batch_size, shuffle=True, num_workers=4, persistent_workers=True, pin_memory=True)

In [None]:
# Sanity check: display (up to) 32 images from client 0's training loader,
# i.e. after `train_transform` has been applied.
client_loader = client_loaders[0]

# Print a few images.
# Use the builtin `next()`: DataLoader iterators no longer provide a `.next()`
# method in recent PyTorch releases, while `next(it)` works on every version.
dataiter = iter(client_loader)
images, labels = next(dataiter)
utils.show_img_tensor(torchvision.utils.make_grid(images[:32]))

In [None]:
# Global model, plus the intermediate layers whose representations FedIR uses.
if dataset_name in ('cifar_10', 'svhn', 'fashion_mnist'):
    glob_model = model.cnn(num_classes=num_classes)
    fedir_args['inter_layers'] = ['conv1', 'conv2', 'conv3', 'fc1', 'fc2']  # Which intermediate representations to use.
    # fedir_args['inter_layers'] = ['conv1', 'conv2', 'fc']
    # fedir_args['inter_layers'] = ['conv1', 'conv2', 'conv3', 'conv4', 'fc1', 'fc2']
elif dataset_name in ('cifar_100', 'tiny_imagenet_200'):
    glob_model = model.resnet20(num_classes=num_classes, image_size=img_size)
    fedir_args['inter_layers'] = ['conv1', 'block1', 'block2', 'block3']
else:
    # Without this, `glob_model` would be undefined and fail on the next line.
    raise ValueError('No model defined for dataset_name={!r}.'.format(dataset_name))

glob_model.to(device)
glob_w = glob_model.state_dict()

# Initialize prev_w for each client: every client starts from its own deep copy
# of the initial global weights.
prev_w_dict = {client_idx: copy.deepcopy(glob_w) for client_idx in range(num_clients)}

# For logging model performance.
performance_dict, performance_log = dict(), dict()
metric_keys = ['g_train_loss', 'g_train_acc', 'g_test_loss', 'g_test_acc']
_, performance_log = utils.get_performance_loggers(metric_keys)

In [None]:
# Automatic resuming from checkpoint.
# The training loop saves the performance log together with the global model
# (g_r_<round>.pth) and every client's previous local weights (prev_w.pth), so the
# number of logged rounds identifies which checkpoint to reload.
log_path = os.path.join(save_dir, 'performance_log.pickle')
if os.path.isfile(log_path):
    performance_log = utils.load_pickle(log_path)
start_round = len(performance_log[metric_keys[0]])

# Reload global and previous local models.
if start_round > 0:
    # map_location lets a checkpoint written on another device load on `device`.
    glob_model.load_state_dict(torch.load(os.path.join(save_dir, 'g_r_{}.pth'.format(start_round)), map_location=device))
    glob_w = glob_model.state_dict()
    # Restore each client's last local weights. The earlier code overwrote every
    # entry with the global weights right after loading, which threw away the
    # restored per-client state.
    prev_w_dict = torch.load(os.path.join(save_dir, 'prev_w.pth'), map_location=device)

In [None]:
# mu_schedule = get_linear_mu()

# Training.
for round_no in range(start_round, num_rounds):
    utils.print_separator(text='Round: {} / {}'.format(round_no + 1, num_rounds))

    # Evaluate the global model *before* this round's local training, so entry r
    # of g_test_* corresponds to the model after r aggregation rounds.
    test_loss, test_acc = evaluate_model(glob_model, test_loader)
    performance_log['g_test_loss'].append(test_loss)
    performance_log['g_test_acc'].append(test_acc)

    # fedir_args['mu'], _ = get_mu(performance_log['g_train_acc'])
    # fedir_args['mu'] = mu_schedule[round_no]

    # Sample this round's participants without replacement.
    participating_clients = sorted(np.random.choice(client_idxes, size=num_participation, replace=False))
    print('participating_clients:', participating_clients)

    # Local training. Each call receives the current global model and this
    # client's weights from the last round it took part in (initially the
    # global weights).
    client_updates = dict()
    for client_idx in participating_clients:
        print('client:', client_idx)
        client_loader = client_loaders[client_idx]
        client_update = fl.local_update_fedir(glob_model, prev_w_dict[client_idx], client_loader, num_local_epochs, optim, optim_args, fedir_args)
        for key in ['local_w', 'num_samples', 'train_loss', 'train_acc']:
            client_updates.setdefault(key, list()).append(client_update[key])
        # Keep a private copy of the new local weights for this client's next participation.
        prev_w_dict[client_idx] = copy.deepcopy(client_update['local_w'])

    # Model aggregation: fl.weighted_averaging over the local weights, with each
    # client's num_samples passed as its weight.
    glob_w = fl.weighted_averaging(client_updates['local_w'], client_updates['num_samples'])
    glob_model.load_state_dict(glob_w)

    # Average local performance (unweighted mean over participating clients).
    performance_log['g_train_loss'].append(sum(client_updates['train_loss'])/len(client_updates['train_loss']))
    performance_log['g_train_acc'].append(sum(client_updates['train_acc'])/len(client_updates['train_acc']))

    # Save global model, per-client weights and log every `save_interval` rounds
    # and always after the final round, so g_r_<num_rounds>.pth (reloaded after
    # training) exists even when num_rounds is not a multiple of save_interval.
    is_last_round = (round_no + 1) == num_rounds
    if (round_no + 1) % save_interval == 0 or is_last_round:
        torch.save(glob_model.state_dict(), os.path.join(save_dir, 'g_r_{}.pth'.format(round_no + 1)))
        torch.save(prev_w_dict, os.path.join(save_dir, 'prev_w.pth'))
        utils.save_pickle(log_path, performance_log)

    for key in sorted(metric_keys):
        print(key, ': ',  performance_log[key][-1])
    print('mu', fedir_args['mu'])

In [None]:
# Evaluate the final global model on every client's local training data
# (all clients, not only the last round's participants).
final_client_results = [
    fl.evaluate_model(glob_model, client_loaders[idx], tqdm_desc='client {}/{}'.format(idx, num_clients))
    for idx in range(num_clients)
]
c_loss_list = [loss for loss, _ in final_client_results]
c_acc_list = [acc for _, acc in final_client_results]

# Unweighted mean over clients.
performance_log['final_g_train_loss'] = sum(c_loss_list) / len(c_loss_list)
performance_log['final_g_train_acc'] = sum(c_acc_list) / len(c_acc_list)

# Final evaluation on the centralized test set.
test_loss, test_acc = evaluate_model(glob_model, test_loader)
performance_log['final_g_test_loss'] = test_loss
performance_log['final_g_test_acc'] = test_acc

# Persist the log, now including the final metrics.
utils.save_pickle(log_path, performance_log)

In [None]:
# Plot training history: one figure for loss and one for accuracy. Each curve is
# the per-round history followed by the final evaluation point.
performance_log = utils.load_pickle(log_path)

history_plots = (
    # (base plot config, output file, ((per-round key, final key), ...))
    (utils.LOSS_PLOT_CONFIG, 'loss.png',
     (('g_train_loss', 'final_g_train_loss'), ('g_test_loss', 'final_g_test_loss'))),
    (utils.ACC_PLOT_CONFIG, 'accuracy.png',
     (('g_train_acc', 'final_g_train_acc'), ('g_test_acc', 'final_g_test_acc'))),
)

for base_config, file_name, series in history_plots:
    plot_config = base_config.copy()
    plot_config['figsize'] = (12, 6)
    plot_config['save_dir'] = os.path.join(save_dir, file_name)
    plot_config['show_img'] = True
    plot_config['xlabel'] = 'rounds'
    plot_config['labels'] = [round_key for round_key, _ in series]
    data_list = [
        performance_log[round_key] + [performance_log[final_key]]
        for round_key, final_key in series
    ]
    utils.save_plot(data_list, plot_config)

In [None]:
# Reload saved global model (the checkpoint written after the final round).
if dataset_name in ('cifar_10', 'svhn', 'fashion_mnist'):
    glob_model = model.cnn(num_classes=num_classes)
elif dataset_name in ('cifar_100', 'tiny_imagenet_200'):
    glob_model = model.resnet20(num_classes=num_classes, image_size=img_size)
else:
    # Without this, `glob_model` would silently keep whatever object an earlier cell left.
    raise ValueError('No model defined for dataset_name={!r}.'.format(dataset_name))

glob_model.to(device)
# map_location lets a checkpoint written on another device load on `device`.
glob_model.load_state_dict(torch.load(os.path.join(save_dir, 'g_r_{}.pth'.format(num_rounds)), map_location=device))

# Evaluate the global model.
test_loss, test_acc = evaluate_model(glob_model, test_loader)
print(test_loss, test_acc)

In [None]:
# `IPython.display` is the public import location for display classes;
# `IPython.core.display` is an internal module.
from IPython.display import HTML

# Ask the front end to restart the kernel once the run has finished
# (presumably to release resources such as GPU memory).
# NOTE: this relies on the classic Notebook's global `Jupyter` JavaScript object;
# in JupyterLab the script is not executed and the kernel is not restarted.
HTML("<script>Jupyter.notebook.kernel.restart()</script>")