In [None]:
import os
import pandas as pd
import torch
from torch import Tensor
from typing import Tuple, List, Callable
import random
import warnings
# NOTE(review): with no category/module filter this ignores *every* warning for the whole
# session (deprecation and numerical warnings included) -- consider narrowing it.
warnings.filterwarnings("ignore")

import pdb  # NOTE(review): debugger import; no breakpoint is visible in this notebook -- likely leftover
import time

# Displays whether PyTorch can see a CUDA device; the setups below request GPUs via `gpu_ids`.
torch.cuda.is_available() 

In [None]:
# GPU device indices forwarded to every setup through its 'gpu_ids' entry.
gpu_ids = [0]

In [None]:
# Project helpers. `run` and `plot_trends` used below are not defined in this notebook,
# so they presumably come from these modules -- TODO confirm and switch to explicit imports.
from run import *
from utils import *

In [None]:
### Per-dataset simulation arguments, keyed by dataset name:
### TIME_LIMITS supplies each setup's 'time_limit', LRs supplies its 'lr'.
TIME_LIMITS = {
    'mnist': 500,
    'fashion mnist': 10000,
    'cifar 10': 500,
    'celeba': 1500,
}
LRs = {
    'mnist': 0.01,
    'fashion mnist': 0.001,
    'cifar 10': 0.01,
    'celeba': 0.001,
}

In [None]:
# Select the dataset; its time limit and learning rate are looked up in TIME_LIMITS / LRs.
dataset_name = 'fashion mnist'
time_limit = TIME_LIMITS[dataset_name]
lr = LRs[dataset_name]

In [None]:
### Here you should define the different setups that you want to run and compare.
### Each config is added to the "setups" dictionary under a descriptive label (used in the plots).
### A config maps argument names to values.

setups = {}

group_count = 5
client_count = 20
max_local_steps = 5

# Interaction time for full-precision (unquantized) messages. It is never modified below:
# quantized setups derive their own scaled copy, so unquantized setups all see this value.
server_interaction_time = 10
server_waiting_time = 0


def make_setup(algorithm: str, clients: int, local_steps: int, groups: int,
               quantizer: dict, sit: float, swt=None) -> dict:
    """Build one setup config dict to be stored in `setups`.

    Args:
        algorithm: value of the 'algorithm' argument (e.g. "Fed_Avg", "quantized_fl", "FedBuff").
        clients: value of 'client count'.
        local_steps: value of 'local step'.
        groups: value of 'group count'.
        quantizer: quantizer spec, e.g. {"method": "identity"}.
        sit: value of 'sit' (server interaction time).
        swt: value of 'swt' (server waiting time); the key is omitted entirely when None.

    The notebook-level `time_limit`, `lr` and `gpu_ids` are filled in for every setup.
    Keys are inserted in the same order for every setup.
    """
    config = {'algorithm': algorithm,
              'client count': clients,
              'local step': local_steps,
              'group count': groups,
              'quantizer': quantizer,
              'time_limit': time_limit,
              'lr': lr}
    if swt is not None:
        config['swt'] = swt
    config['sit'] = sit
    config['gpu_ids'] = gpu_ids
    return config


# Baseline: one client in one group doing one local step with no quantization,
# as if the model were trained directly on the server.
setups["Baseline"] = make_setup("Fed_Avg", 1, 1, 1, {"method": "identity"}, sit=0)


# Fed-Avg (unquantized, full-precision interaction time)
local_steps = max_local_steps
setups[f"Fed-Avg ({client_count},{group_count},{local_steps}) sit: {server_interaction_time}"] = make_setup(
    "Fed_Avg", client_count, local_steps, group_count, {"method": "identity"},
    sit=server_interaction_time)


# QuAFL (quantized)
method = "lattice"
quant_q = 14
# Scale the interaction time by quant_q / 32 (presumably quant_q-bit messages vs. 32-bit floats).
# Kept in its own variable so the unquantized setups are not affected.
quafl_sit = float("{0:.3f}".format(server_interaction_time * quant_q / 32))
quantizer = {"method": method, 'quant_q': quant_q, 'quant_s': 0.0001}
# Alternative: QSGD quantizer with 16 levels (scaled by 4 / 32):
# method = "qsgd"
# q_levels = 16
# quafl_sit = float("{0:.3f}".format(server_interaction_time * 4 / 32))
# quantizer = {"method": method, 'k': q_levels}
setups[f"QuAFL   ({client_count},{group_count},{max_local_steps},{method}) swt: {server_waiting_time} sit: {quafl_sit}"] = make_setup(
    "quantized_fl", client_count, max_local_steps, group_count, quantizer,
    sit=quafl_sit, swt=server_waiting_time)


# FedBuff: unquantized, so it uses the unscaled server_interaction_time like Fed-Avg.
setups[f"FedBuff ({client_count},{group_count},{local_steps}) sit: {server_interaction_time}"] = make_setup(
    "FedBuff", client_count, local_steps, group_count, {"method": "identity"},
    sit=server_interaction_time)




In [None]:
# Run every configured setup and report how long the whole simulation took.
# perf_counter is monotonic, so the duration is unaffected by system clock adjustments.
start = time.perf_counter()

log_period = 200 ## Simulation time difference between loggings.
logs, trainers = run(setups, dataset_name, log_period, count=client_count, decreasing=False, slow_client_ratio = 0.50)

end = time.perf_counter()
print(f"Finished in {end - start:.2f} s")

In [None]:
# Loss curves against each progress measure; the extra positional 0 is passed through
# to plot_trends unchanged (its meaning is defined there).
y_axis = "Loss"
for x_axis in ("Local steps", "Server steps", "Time", "Aggregated local steps"):
    plot_trends(logs, x_axis, y_axis, 0)


In [None]:
# Accuracy curves against each progress measure.
y_axis = "Accuracy"
for x_axis in ("Local steps", "Server steps", "Time"):
    plot_trends(logs, x_axis, y_axis)