# Imports

In [None]:
import os
import copy
import time
import pickle
import numpy as np
from tqdm import tqdm
import torch
from tensorboardX import SummaryWriter
from dataclasses import dataclass
import numpy as np
import matplotlib
import matplotlib.pyplot as plt

from options import args_parser
from update import test_inference, LocalUpdate, LocalUpdate_PFL
from models import MLP, CNNMnist, CNNFashion_Mnist, CNNCifar
from utils import get_dataset, average_weights, exp_details

# Setup & Functions

In [None]:
from typing import Optional


@dataclass
class Args:
    """
      Simulation settings, held in a C-like structure instead of being passed
      as bash command-line arguments.

      Every field has a default, so an experiment only overrides what it
      changes, e.g. ``Args(pfl=0, model='cnn', num_users=5)``. Field order is
      part of the interface (positional construction), so new fields should be
      appended at the end.
    """
    pfl: int = 1 # pfl/tfl (0=TFL, 1=PFL)
    comm_rounds: int = 10 # number of rounds of training
    num_users: int = 100 # number of users: K
    frac: float = 0.1 # the fraction of clients: C
    local_ep: int = 10 # the number of local epochs: E
    local_bs: int = 10 # local batch size: B
    lr: float = 0.01 # learning rate
    momentum: float = 0.5 # SGD momentum (default: 0.5)
    # model arguments
    model: str = 'mlp' # model name ('mlp' or 'cnn' are the values build_model accepts)
    kernel_num: int = 9 # number of each kind of kernel
    kernel_sizes: str = '3,4,5' # comma-separated kernel size to use for convolution
    num_channels: int = 1 # number of channels of imgs
    norm: str = 'batch_norm' # batch_norm, layer_norm, or None
    num_filters: int = 32 # number of filters for conv nets -- 32 for mini-imagenet, 64 for omniglot.
    max_pool: str = 'True' # Whether use max pooling rather than strided convolutions (kept as the string 'True', not a bool)
    # other arguments
    dataset: str = 'mnist' # name of dataset ('mnist', 'fmnist' or 'cifar' in build_model)
    num_classes: int = 10 # number of target classes (output size of the MLP)
    # GPU id, or None for CPU.
    # NOTE(review): 0 is falsy, so code that tests `if args.gpu` treats GPU 0
    # like "no GPU" -- confirm that is intended.
    gpu: Optional[int] = None
    optimizer: str = 'sgd'
    iid: int = 1 # 1 = IID split, 0 = non-IID split (as used by the experiments)
    unequal: int = 0 # presumably toggles unequal-size user splits -- TODO confirm against get_dataset
    stopping_rounds: int = 10
    verbose: int = 1
    seed: int = 1

In [None]:
def sum_weights(model1, model2):
    """Add two weight mappings entry by entry.

    Entries are paired by iteration order of the two mappings (first key of
    ``model1`` with first key of ``model2``, and so on), not by key name.
    The result is a deep copy of ``model1`` whose values are replaced by the
    NumPy sums, stored under ``model1``'s keys. Neither input is modified.

    @ Args:
      - model1: mapping of layer name -> array-like weights
      - model2: mapping with the same number/order of layers
    @ Returns:
      - mapping with model1's keys and NumPy values ``model1[k] + model2[k']``
    """
    total = copy.deepcopy(model1)
    for key_a, key_b in zip(model1, model2):
        # Convert both sides to ndarrays so the addition is element-wise.
        total[key_a] = np.array(model1[key_a]) + np.array(model2[key_b])
    return total

def multiply_weights(model, w):
    """Return a copy of ``model`` with every entry multiplied by the scalar ``w``.

    Tensor entries are scaled directly in torch, so the result stays on the
    tensor's own device and the call also works for CUDA tensors and for
    tensors that require grad (a NumPy round-trip raises for both). Floating
    point tensors keep their dtype; integer tensors are promoted to torch's
    default float dtype. Non-tensor entries (lists, ndarrays) are converted
    through NumPy as before.

    @ Args:
      - model: mapping of layer name -> weights (e.g. a ``state_dict``)
      - w: scalar factor
    @ Returns:
      - a deep copy of ``model`` (same mapping type and keys) holding the
        scaled tensors; ``model`` itself is left untouched
    """
    # deepcopy keeps the mapping type and any attributes attached to it.
    scaled = copy.deepcopy(model)
    for layer, value in model.items():
        if torch.is_tensor(value):
            # detach(): the scaled weights are plain data, not part of a graph.
            scaled[layer] = value.detach() * w
        else:
            scaled[layer] = torch.tensor(w * np.array(value))
    return scaled

def plot_distrbutions(dataset, n_clients, max_cols=5):
    """Plot one label histogram per client on a grid of subplots.

    The sample indices of client ``i`` are read from the module-level
    ``user_groups[i]`` (it is not a parameter), so ``user_groups`` must have
    been assigned before calling this function.

    @ Args:
      - dataset: indexable dataset whose items are ``(sample, label)`` pairs
      - n_clients: number of clients to plot (clients ``0 .. n_clients-1``)
      - max_cols: maximum number of subplots per row
    """
    cols = min(n_clients, max_cols)
    # Ceiling division: rounding to the nearest integer can allocate too few
    # rows (e.g. 7 clients with 5 columns) and make add_subplot fail.
    rows = max(1, -(-n_clients // max_cols))
    figure = plt.figure(figsize=(15*rows, 2*rows))
    print("rows= {}, cols={} ".format(rows, cols))
    for i in range(n_clients):
        idxs = user_groups[i]
        # Only the label (second element of each item) is needed.
        ys = [dataset[int(idx)][1] for idx in idxs]
        figure.add_subplot(rows, cols, i+1)
        plt.title("client {}".format(i))
        # 10 bins over [0, 10): assumes integer class labels 0..9 -- TODO confirm
        # for datasets with a different number of classes.
        plt.hist(ys, 10, range=(0, 10))
    plt.show()
    
def plot_data(dataset):
    """Show a single row of 9 randomly drawn samples from ``dataset``.

    Each subplot is titled with the sample's label and rendered in grayscale.

    @ Args:
      - dataset: indexable dataset whose items are ``(img, label)`` pairs;
        ``img`` must support ``.squeeze()``
    """
    figure = plt.figure(figsize=(10, 10))
    cols, rows = 9, 1
    for i in range(1, cols * rows + 1):
        # Draw the index from the dataset that was passed in, not from a
        # global, so the function works for any dataset.
        sample_idx = torch.randint(len(dataset), size=(1,)).item()
        img, label = dataset[sample_idx]
        figure.add_subplot(rows, cols, i)
        plt.title(label)
        plt.axis("off")
        plt.imshow(img.squeeze(), cmap="gray")
    plt.show()
#plot_data()

def plot_history(history_dict, comm_rounds, metric="loss"):
    """Plot one curve per user of ``metric`` against the communication round.

    @ Args:
      - history_dict: mapping ``user id -> {metric name: list of values}``;
        each list must hold exactly ``comm_rounds`` values
      - comm_rounds: number of communication rounds (x axis is 1..comm_rounds)
      - metric: key of the series to plot, e.g. "loss" or "accuracy"
    """
    plt.figure(figsize=(6, 2))
    rounds = range(1, comm_rounds+1)
    # Iterate the mapping itself so user ids need not be 0..n-1.
    for idx, history in history_dict.items():
        plt.plot(rounds, history[metric], label="user {}".format(idx))
    plt.title("{} per communication rounds".format(metric))
    plt.xlabel("communication round")
    plt.ylabel(metric)
    if history_dict:
        # Without a legend the per-user labels are never displayed.
        plt.legend(fontsize="small")
    plt.show()

In [None]:
def build_model(args, train_dataset=None):
    """Create the (untrained) global model selected by ``args``.

    @ Args:
      - args: settings object; reads ``args.model`` ('cnn' or 'mlp'),
        ``args.dataset`` ('mnist', 'cifar' or 'fmnist' for the CNN) and, for
        the MLP, ``args.num_classes``
      - train_dataset: only needed for ``model='mlp'``; the shape of its first
        sample determines the MLP input size
    @ Returns:
      - the model instance
    @ Raises:
      - ValueError: unknown model, unknown dataset for the CNN, or missing
        ``train_dataset`` for the MLP
    """
    if args.model == 'cnn': # Convolutional neural network
        if args.dataset == 'mnist':
            return CNNMnist(args=args)
        if args.dataset == 'cifar':
            return CNNCifar(args=args)
        if args.dataset == 'fmnist':
            return CNNFashion_Mnist(args=args)
        raise ValueError('Error: unrecognized dataset for cnn: {}'.format(args.dataset))

    if args.model == 'mlp':
        if train_dataset is None:
            raise ValueError('Error: train_dataset is required to size the mlp input')
        # Flattened input size = product of all dimensions of one sample.
        img_size = train_dataset[0][0].shape
        len_in = 1
        for x in img_size:
            len_in *= x
        # Build the model once, after the full input size is known.
        return MLP(dim_in=len_in, dim_hidden=64, dim_out=args.num_classes)

    raise ValueError('Error: unrecognized model')

In [None]:
def _train_rounds(args, global_model, local_users, user_points, print_every):
    """Run all communication rounds; return ``(users_histories, train_accuracy)``."""
    # Histories for plotting..
    users_histories = {idx:{"loss":[], "accuracy":[]} for idx in range(args.num_users)}
    total_points = sum(user_points)

    train_loss, train_accuracy = [], []
    for comm_round in tqdm(range(args.comm_rounds)):
        # init local weights and loss
        local_weights, local_losses = [], []
        print(f'\n | Global Training Round : {comm_round+1} |\n')

        global_model.train()
        # Sample a fraction of users (with args frac). The sample is currently
        # NOT used -- every user trains below -- but the draw is kept so the
        # NumPy random stream stays the same as before.
        m = max(int(args.frac * args.num_users), 1)
        idxs_users = np.random.choice(range(args.num_users), m, replace=False)

        for idx in range(args.num_users): # to be returned to idxs_users
            user = local_users[idx]
            w, loss = user.update_weights(model=copy.deepcopy(global_model), global_round=comm_round)
            # Scale each update by the user's share of all data points.
            # NOTE(review): the scaled updates are then passed to
            # average_weights; if that helper computes a plain mean, the result
            # is divided by the number of users a second time -- confirm.
            w = multiply_weights(w, user_points[idx]/total_points)
            local_weights.append(copy.deepcopy(w))
            local_losses.append(copy.deepcopy(loss))
            users_histories[idx]["loss"].append(loss) #for plotting

        # update global weights
        global_weights = average_weights(local_weights)
        global_model.load_state_dict(global_weights)

        loss_avg = sum(local_losses) / len(local_losses)
        train_loss.append(loss_avg)

        # Calculate avg training accuracy over all users at every comm_round
        list_acc = []
        global_model.eval()
        for idx in range(args.num_users):
            user = local_users[idx]
            acc, _ = user.inference(model=global_model)
            list_acc.append(acc)
            users_histories[idx]["accuracy"].append(acc) #for plotting
        print("train accs: {}".format(list_acc))
        train_accuracy.append(sum(list_acc)/len(list_acc))

        # print global training loss after every 'i' rounds
        if (comm_round+1) % print_every == 0:
            print(f' \nAvg Training Stats after {comm_round+1} global rounds:')
            if args.pfl:
                print('Last Average Train Accuracy: {:.2f}% \n'.format(100*train_accuracy[-1]))
            else:
                print(f'Training Loss : {np.mean(np.array(train_loss))}')
                print('Train Accuracy: {:.2f}% \n'.format(100*train_accuracy[-1]))

    return users_histories, train_accuracy


def _report_results(args, global_model, local_users, train_accuracy, start_time):
    """Run the per-user test inference and print the final summary."""
    test_accs = []
    for idx in range(args.num_users):
        user = local_users[idx]
        acc, _ = user.inference(model=global_model, type="test")
        test_accs.append("{}%".format(round(100*acc,2)))

    print(f' \n Results after {args.comm_rounds} global rounds of training:')
    print("|---- Avg Train Accuracy: {:.2f}%".format(100*train_accuracy[-1]))
    print("|---- Test Accuracies: {}".format(test_accs))
    print('\n Total Run Time: {0:0.4f}'.format(time.time()-start_time))


def run(args, global_model, train_dataset, test_dataset, user_groups, ratio=0.3, print_every=5):
    """
      Run one full federated-learning simulation on an already built model and
      an already partitioned dataset: create one local updater per user, train
      every user in each communication round, aggregate the local weights into
      the global model, track per-user loss/accuracy, and finally print the
      per-user test accuracies and the total run time.

      @ Args:
        - args: the simulation parameters (reads gpu, pfl, num_users,
                comm_rounds, frac)
        - global_model: model to train; it is moved to the device and its
                weights are overwritten in place, so pass a copy to keep the
                original
        - train_dataset: dataset handed to every local updater
        - test_dataset: currently unused by this function
        - user_groups: mapping user id -> sample indices of that user
        - ratio: forwarded to LocalUpdate_PFL only (PFL mode, args.pfl == 1)
        - print_every: print averaged training stats every this many rounds
      @ Returns:
        - users_histories: {user id: {"loss": [...], "accuracy": [...]}} with
                one value per communication round
     """
    start_time = time.time()
    logger = SummaryWriter('../logs')
    try:
        # NOTE(review): `if args.gpu` is False for gpu=0, so GPU 0 falls back
        # to the CPU here -- confirm this is intended before changing it, since
        # the local updaters may apply the same test.
        if args.gpu:
            torch.cuda.set_device(args.gpu)
        device = 'cuda' if args.gpu else 'cpu'

        # Set the model to train and send it to device.
        global_model.to(device)
        global_model.train()

        # Number of data points per user, used to weight each user's update.
        user_points = [len(user_groups[idx]) for idx in range(args.num_users)]

        if args.pfl == 1:
            local_users = {idx:LocalUpdate_PFL(args=args, id=idx, dataset=train_dataset, idxs=user_groups[idx], ratio=ratio,\
                                         logger=logger) for idx in range(args.num_users)}
        else:
            local_users = {idx:LocalUpdate(args=args, id=idx, dataset=train_dataset, idxs=user_groups[idx], logger=logger)\
                           for idx in range(args.num_users)}

        # Training
        users_histories, train_accuracy = _train_rounds(args, global_model, local_users, user_points, print_every)

        # Test inference after completion of training
        _report_results(args, global_model, local_users, train_accuracy, start_time)

        return users_histories
    finally:
        # Flush and release the TensorBoard event file even if training fails.
        logger.close()

# Experiments

## Ex 1. MNIST | IID | 5 users | 10 comm rounds

In [None]:
# Choose the arguments: TFL (pfl=0), CNN on MNIST, IID split (iid=1),
# 5 users, 10 communication rounds, 3 local epochs.
# NOTE(review): gpu=0 is falsy, so run()'s `if args.gpu` check selects the CPU.
args = Args(pfl=0, model='cnn', dataset='mnist', gpu=0, iid=1, comm_rounds=10, num_users=5, frac=1, local_ep=3, verbose=0)

In [None]:
# Initialize the global model once; each run below trains a deep copy of it,
# so TFL and PFL start from the same initial weights.
org_model = build_model(args)

### Data

In [None]:
# Get the dataset and the per-user index groups
train_dataset, test_dataset, user_groups = get_dataset(args)
# Plot the label distribution of each user's data
# (plot_distrbutions reads the module-level user_groups assigned above)
plot_distrbutions(dataset=train_dataset, n_clients=args.num_users)

### TFL

In [None]:
# Run the FL operation (TFL). run() overwrites the model's weights in place,
# hence the deep copy of org_model.
users_histories_tfl = run(args, copy.deepcopy(org_model), train_dataset, test_dataset, user_groups, print_every=5)

In [None]:
# TFL: per-user local training loss for each communication round
plot_history(users_histories_tfl, args.comm_rounds, metric="loss")

In [None]:
# TFL: per-user accuracy recorded by run() after each round's aggregation
plot_history(users_histories_tfl, args.comm_rounds, metric="accuracy")

### PFL

In [None]:
# Modify the arguments: only pfl changes (0 -> 1), so run() builds
# LocalUpdate_PFL users; everything else matches the TFL run.
args = Args(pfl=1, model='cnn', dataset='mnist', gpu=0, iid=1, comm_rounds=10, num_users=5, frac=1, local_ep=3, verbose=0)

In [None]:
# Run the FL operation (PFL) from a fresh copy of the same initial model,
# on the same dataset split as the TFL run.
users_histories_pfl = run(args, copy.deepcopy(org_model), train_dataset, test_dataset, user_groups, print_every=5)

In [None]:
# PFL: per-user local training loss for each communication round
plot_history(users_histories_pfl, args.comm_rounds, metric="loss")

In [None]:
# PFL: per-user accuracy recorded by run() after each round's aggregation
plot_history(users_histories_pfl, args.comm_rounds, metric="accuracy")

## Ex 2. MNIST | nonIID | 5 users | 10 comm rounds

In [None]:
# Choose the arguments: same as Ex 1 but with a non-IID split (iid=0).
# org_model from Ex 1 (MNIST CNN) is reused as the initial model.
args = Args(pfl=0, model='cnn', dataset='mnist', gpu=0, iid=0, comm_rounds=10, num_users=5, frac=1, local_ep=3, verbose=0)

### Data

In [None]:
# Get the dataset; this replaces the module-level user_groups with the
# non-IID partition
train_dataset, test_dataset, user_groups = get_dataset(args)
# Plot the label distribution of each user's data
plot_distrbutions(dataset=train_dataset, n_clients=args.num_users)

### TFL

In [None]:
# Run the FL operation (TFL) on a deep copy, leaving org_model untouched
users_histories_tfl = run(args, copy.deepcopy(org_model), train_dataset, test_dataset, user_groups, print_every=5)

In [None]:
# TFL (non-IID): per-user local training loss per communication round
plot_history(users_histories_tfl, args.comm_rounds, metric="loss")

In [None]:
# TFL (non-IID): per-user accuracy after each round's aggregation
plot_history(users_histories_tfl, args.comm_rounds, metric="accuracy")

### PFL

In [None]:
# Modify the arguments: only pfl changes (0 -> 1)
args = Args(pfl=1, model='cnn', dataset='mnist', gpu=0, iid=0, comm_rounds=10, num_users=5, frac=1, local_ep=3, verbose=0)
# Run the FL operation (PFL) on the same split, from the same initial model
users_histories_pfl = run(args, copy.deepcopy(org_model), train_dataset, test_dataset, user_groups, print_every=5)

In [None]:
# PFL (non-IID): per-user local training loss per communication round
plot_history(users_histories_pfl, args.comm_rounds, metric="loss")

In [None]:
# PFL (non-IID): per-user accuracy after each round's aggregation
plot_history(users_histories_pfl, args.comm_rounds, metric="accuracy")

## Ex 3. MNIST | p-nonIID | 5 users | 10 comm rounds

In [None]:
# Choose the arguments: identical to Ex 2; the difference in this experiment
# comes from the nonequal=True flag passed to get_dataset below.
args = Args(pfl=0, model='cnn', dataset='mnist', gpu=0, iid=0, comm_rounds=10, num_users=5, frac=1, local_ep=3, verbose=0)

### Data

In [None]:
# Get the dataset.
# nonequal=True presumably selects the "p-nonIID" partition named in the
# experiment title -- TODO confirm against utils.get_dataset.
train_dataset, test_dataset, user_groups = get_dataset(args, nonequal=True)
# Plot the label distribution of each user's data
plot_distrbutions(dataset=train_dataset, n_clients=args.num_users)

### TFL

In [None]:
# Run the FL operation (TFL) on a deep copy, leaving org_model untouched
users_histories_tfl = run(args, copy.deepcopy(org_model), train_dataset, test_dataset, user_groups, print_every=5)

In [None]:
# TFL (p-nonIID): per-user local training loss per communication round
plot_history(users_histories_tfl, args.comm_rounds, metric="loss")

In [None]:
# TFL (p-nonIID): per-user accuracy after each round's aggregation
plot_history(users_histories_tfl, args.comm_rounds, metric="accuracy")

### PFL

In [None]:
# Modify the arguments: only pfl changes (0 -> 1)
args = Args(pfl=1, model='cnn', dataset='mnist', gpu=0, iid=0, comm_rounds=10, num_users=5, frac=1, local_ep=3, verbose=0)
# Run the FL operation (PFL) on the same split, from the same initial model
users_histories_pfl = run(args, copy.deepcopy(org_model), train_dataset, test_dataset, user_groups, print_every=5)

In [None]:
# PFL (p-nonIID): per-user local training loss per communication round
plot_history(users_histories_pfl, args.comm_rounds, metric="loss")

In [None]:
# PFL (p-nonIID): per-user accuracy after each round's aggregation
plot_history(users_histories_pfl, args.comm_rounds, metric="accuracy")

## Ex 4. CIFAR10 | nonIID | 5 users | 10 comm rounds

In [None]:
# Choose the arguments: TFL, CNN on CIFAR, non-IID split, 5 users, 10 rounds
args = Args(pfl=0, model='cnn', dataset='cifar', gpu=0, iid=0, comm_rounds=10, num_users=5, frac=1, local_ep=3, verbose=0)
# Initialize the global model; it is rebuilt here because the dataset changed
# from MNIST to CIFAR, and build_model picks the CNN class by dataset.
org_model = build_model(args)

### Data

In [None]:
# Get the dataset and the per-user index groups
train_dataset, test_dataset, user_groups = get_dataset(args)
# Plot the label distribution of each user's data
plot_distrbutions(dataset=train_dataset, n_clients=args.num_users)

### TFL

In [None]:
# Run the FL operation (TFL) on a deep copy, leaving org_model untouched
users_histories_tfl = run(args, copy.deepcopy(org_model), train_dataset, test_dataset, user_groups, print_every=5)

In [None]:
# TFL (CIFAR, non-IID): per-user local training loss per communication round
plot_history(users_histories_tfl, args.comm_rounds, metric="loss")

In [None]:
# TFL (CIFAR, non-IID): per-user accuracy after each round's aggregation
plot_history(users_histories_tfl, args.comm_rounds, metric="accuracy")

### PFL

In [None]:
# Modify the arguments: only pfl changes (0 -> 1)
args = Args(pfl=1, model='cnn', dataset='cifar', gpu=0, iid=0, comm_rounds=10, num_users=5, frac=1, local_ep=3, verbose=0)
# Run the FL operation (PFL) on the same split, from the same initial model
users_histories_pfl = run(args, copy.deepcopy(org_model), train_dataset, test_dataset, user_groups, print_every=5)

In [None]:
# PFL (CIFAR, non-IID): per-user local training loss per communication round
plot_history(users_histories_pfl, args.comm_rounds, metric="loss")

In [None]:
# PFL (CIFAR, non-IID): per-user accuracy after each round's aggregation
plot_history(users_histories_pfl, args.comm_rounds, metric="accuracy")

## Ex 5. CIFAR10 | p-nonIID | 5 users | 10 comm rounds

In [None]:
# Choose the arguments: identical to Ex 4; the difference in this experiment
# comes from the nonequal=True flag passed to get_dataset below.
# org_model from Ex 4 (CIFAR CNN) is reused as the initial model.
args = Args(pfl=0, model='cnn', dataset='cifar', gpu=0, iid=0, comm_rounds=10, num_users=5, frac=1, local_ep=3, verbose=0)

### Data

In [None]:
# Get the dataset.
# nonequal=True presumably selects the "p-nonIID" partition named in the
# experiment title -- TODO confirm against utils.get_dataset.
train_dataset, test_dataset, user_groups = get_dataset(args, nonequal=True)
# Plot the label distribution of each user's data
plot_distrbutions(dataset=train_dataset, n_clients=args.num_users)

### TFL

In [None]:
# Run the FL operation (TFL) on a deep copy, leaving org_model untouched
users_histories_tfl = run(args, copy.deepcopy(org_model), train_dataset, test_dataset, user_groups, print_every=5)

In [None]:
# TFL (CIFAR, p-nonIID): per-user local training loss per communication round
plot_history(users_histories_tfl, args.comm_rounds, metric="loss")

In [None]:
# TFL (CIFAR, p-nonIID): per-user accuracy after each round's aggregation
plot_history(users_histories_tfl, args.comm_rounds, metric="accuracy")

### PFL

In [None]:
# Modify the arguments: only pfl changes (0 -> 1)
args = Args(pfl=1, model='cnn', dataset='cifar', gpu=0, iid=0, comm_rounds=10, num_users=5, frac=1, local_ep=3, verbose=0)
# Run the FL operation (PFL) on the same split, from the same initial model
users_histories_pfl = run(args, copy.deepcopy(org_model), train_dataset, test_dataset, user_groups, print_every=5)

In [None]:
# PFL (CIFAR, p-nonIID): per-user local training loss per communication round
plot_history(users_histories_pfl, args.comm_rounds, metric="loss")

In [None]:
# PFL (CIFAR, p-nonIID): per-user accuracy after each round's aggregation
plot_history(users_histories_pfl, args.comm_rounds, metric="accuracy")

## Ex 6. CIFAR10 | p-nonIID | 10 users | 15 comm rounds

In [None]:
# Choose the arguments: TFL, CNN on CIFAR, non-IID, scaled up to 10 users
# and 15 communication rounds. org_model from Ex 4 (CIFAR CNN) is reused.
args = Args(pfl=0, model='cnn', dataset='cifar', gpu=0, iid=0, comm_rounds=15, num_users=10, frac=1, local_ep=3, verbose=0)

### Data

In [None]:
# Get the dataset.
# nonequal=True presumably selects the "p-nonIID" partition named in the
# experiment title -- TODO confirm against utils.get_dataset.
train_dataset, test_dataset, user_groups = get_dataset(args, nonequal=True)
# Plot the label distribution of each user's data
plot_distrbutions(dataset=train_dataset, n_clients=args.num_users)

### TFL

In [None]:
# Run the FL operation (TFL) on a deep copy, leaving org_model untouched
users_histories_tfl = run(args, copy.deepcopy(org_model), train_dataset, test_dataset, user_groups, print_every=5)

In [None]:
# TFL (10 users, 15 rounds): per-user local training loss per round
plot_history(users_histories_tfl, args.comm_rounds, metric="loss")

In [None]:
# TFL (10 users, 15 rounds): per-user accuracy after each round's aggregation
plot_history(users_histories_tfl, args.comm_rounds, metric="accuracy")

### PFL

In [None]:
# Modify the arguments: only pfl changes (0 -> 1). The user count and round
# count stay at 10 users / 15 rounds so the PFL run matches this experiment's
# TFL run and the 10-user partition held in user_groups.
args = Args(pfl=1, model='cnn', dataset='cifar', gpu=0, iid=0, comm_rounds=15, num_users=10, frac=1, local_ep=3, verbose=0)
# Run the FL operation (PFL) on the same split, from the same initial model
users_histories_pfl = run(args, copy.deepcopy(org_model), train_dataset, test_dataset, user_groups, print_every=5)

In [None]:
# PFL (Ex 6): per-user local training loss per communication round
plot_history(users_histories_pfl, args.comm_rounds, metric="loss")

In [None]:
# PFL (Ex 6): per-user accuracy after each round's aggregation
plot_history(users_histories_pfl, args.comm_rounds, metric="accuracy")