### General Import

In [None]:
import torch
from torchvision import datasets as torch_datasets
from torchvision import transforms
import numpy as np

### Project Module Imports

In [None]:
from fl_strategy import run_FL
from configurations import EnvSettings, TaskSettings, init_FL_clients, init_global_model
from utils import normalize
from fl_dataset import FLDataLoader, get_FL_datasets
from fl_clients import generate_clients_perf, generate_clients_crash_prob, generate_crash_trace

In [None]:
# Experiment knobs read by the cells below.
task2run = 'mnist'        # task to run; options={boston, mnist, cifar10}
task_mode = 'Semi-Async'  # stored on env_cfg.mode after the config is built
pick_C = 1.0              # ratio of training clients per round
crash_prob = 0.3          # probability that a client will crash
lag_tol = 5               # lag tolerance

### Set Configuration

In [None]:
# Build the environment settings (env_cfg) and task settings (task_cfg) for
# the task selected by `task2run`, then record the run mode on env_cfg.
bw_set = (0.175, 1250)  # (client throughput, bandwidth_server) in MB/s

# NOTE(review): the Boston branch passes `dev=` while the MNIST and CIFAR-10
# branches pass `device=`. Unless EnvSettings accepts both spellings, one of
# them will fail -- confirm against configurations.EnvSettings.
if task2run == 'boston':
    # Boston housing regression settings (ms per epoch)
    env_cfg = EnvSettings(n_clients=5, n_rounds=100, n_epochs=3, batch_size=5, train_frac=0.7, shuffle=False, pick_frac=pick_C, benign_ratio=1.0, data_dist=('N', 0.3), perf_dist=('X', None), crash_dist=('E', crash_prob),
                            keep_best=True, dev='cpu', showplot=False, bw_set=bw_set, max_T=830)
    task_cfg = TaskSettings(task_type='Reg', dataset='Boston', path='data/boston_housing.csv', in_dim=12, out_dim=1, optimizer='SGD', loss='mse', lr=1e-4, lr_decay=1.0)
elif task2run == 'mnist':
    # MNIST digits classification task settings (3s per epoch on GPU)
    env_cfg = EnvSettings(n_clients=50, n_rounds=10, n_epochs=5, batch_size=40, train_frac=6.0/7.0, shuffle=False, pick_frac=pick_C, benign_ratio=0.6, data_dist=('N', 0.3), perf_dist=('X', None), crash_dist=('E', crash_prob),
                            keep_best=True, device='gpu', showplot=False, bw_set=bw_set, max_T=5600)
    task_cfg = TaskSettings(task_type='CNN', dataset='mnist', path='data/MNIST/', in_dim=None, out_dim=None, optimizer='SGD', loss='nllLoss', lr=1e-3, lr_decay=1.0)
elif task2run == 'cifar10':
    # CIFAR-10 classification task settings
    env_cfg = EnvSettings(n_clients=50, n_rounds=10, n_epochs=5, batch_size=20, train_frac=6.0/7.0, shuffle=False, pick_frac=pick_C, benign_ratio=0.6, data_dist=('E', None), perf_dist=('X', None), crash_dist=('E', crash_prob),
                            keep_best=True, device='gpu', showplot=False, bw_set=bw_set, max_T=5600)
    task_cfg = TaskSettings(task_type='ResNet', dataset='cifar10', path='data/cifar10/', in_dim=None, out_dim=None, optimizer='SGD', loss='nllLoss', lr=1e-2, lr_decay=5e-4)
else:
    # An unknown task is an error, so stop with a non-zero exit status.
    # SystemExit with a string argument prints it to stderr and exits with 1.
    raise SystemExit('[Err] Invalid task name provided. Options are {boston, mnist, cifar10}')

env_cfg.mode = task_mode

### Load Dataset

In [None]:
# Load the samples for the configured dataset and produce train_x/train_y,
# test_x/test_y plus the train/test/total sample counts.
# Boston is read as one table and split below (data_merged=True); the
# torchvision datasets already come as separate train and test sets.
dataset_name = task_cfg.dataset

if dataset_name == 'Boston':
    data = normalize(np.loadtxt(task_cfg.path, delimiter=',', skiprows=1))
    data_merged = True
elif dataset_name == 'mnist':
    # NOTE(review): the dataset root is hard-coded here instead of using
    # task_cfg.path -- confirm which location is intended.
    # NOTE(review): samples are taken from `.data` directly, which presumably
    # bypasses `transform` (torchvision applies it on item access) -- confirm
    # whether the inputs are expected to be normalized.
    mnist_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])
    mnist_train = torch_datasets.MNIST('data/mnist/', train=True, download=True, transform=mnist_transform)
    mnist_test = torch_datasets.MNIST('data/mnist/', train=False, download=True, transform=mnist_transform)
    train_x, train_y = mnist_train.data.view(-1, 1, 28, 28).float(), mnist_train.targets.long()
    test_x, test_y = mnist_test.data.view(-1, 1, 28, 28).float(), mnist_test.targets.long()
    data_merged = False
elif dataset_name == 'cifar10':
    # NOTE(review): same remarks as for MNIST regarding the hard-coded root
    # and the transform not being applied to `.data`.
    cifar10_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    cifar10_train = torch_datasets.CIFAR10('data/cifar10/', train=True, download=True, transform=cifar10_transform)
    cifar10_test = torch_datasets.CIFAR10('data/cifar10/', train=False, download=True, transform=cifar10_transform)
    # Move the channel axis (last in the stored arrays) to position 1.
    train_x = torch.tensor(cifar10_train.data).permute(0, 3, 1, 2).view(-1, 3, 32, 32).float()
    train_y = torch.tensor(cifar10_train.targets).long()
    test_x = torch.tensor(cifar10_test.data).permute(0, 3, 1, 2).view(-1, 3, 32, 32).float()
    test_y = torch.tensor(cifar10_test.targets).long()
    data_merged = False
else:
    print('E> Invalid dataset specified. Options are {Boston, mnist, cifar10}')
    exit(-1)

if data_merged:
    # Partition Data: the leading train_frac share of the rows is the training
    # set and the remaining rows are the test set (row order is kept).
    data_size = len(data)
    train_data_size = int(data_size * env_cfg.train_frac)
    test_data_size = data_size - train_data_size
    data = torch.tensor(data).float()

    n_in, n_out = task_cfg.in_dim, task_cfg.out_dim
    train_rows = data[:train_data_size]
    test_rows = data[train_data_size:]
    train_x = train_rows[:, :n_in]                       # training data, x
    train_y = train_rows[:, -n_out:].reshape(-1, n_out)  # training data, y
    test_x = test_rows[:, :n_in]                         # test data, x
    test_y = test_rows[:, -n_out:].reshape(-1, n_out)    # test data, y
else:
    # Pre-split datasets: only the size bookkeeping is left to do.
    train_data_size = len(train_x)
    test_data_size = len(test_x)
    data_size = train_data_size + test_data_size

### Create Clients and Server

In [None]:
# Clients, their federated train/test datasets, and the matching loaders.
clients, cindexmap = init_FL_clients(env_cfg.n_clients)
fed_data_train, fed_data_test, client_shard_sizes = get_FL_datasets(
    train_x, train_y, test_x, test_y, env_cfg, task_cfg, clients
)
loader_args = (cindexmap, env_cfg.batch_size, env_cfg.shuffle)
fed_loader_train = FLDataLoader(fed_data_train, *loader_args)
fed_loader_test = FLDataLoader(fed_data_test, *loader_args)

clients_perf_vec = generate_clients_perf(env_cfg, from_file=True)

# Estimated per-client training time for one round:
# (shard size / batch size) batches per epoch, times n_epochs, divided by the
# client's performance value.
batches_per_epoch = np.array(client_shard_sizes) / env_cfg.batch_size
clients_est_round_T_train = batches_per_epoch * env_cfg.n_epochs / np.array(clients_perf_vec)

# Maximum waiting time for client response in a round setting.
# A truthy env_cfg.max_T wins; otherwise fall back to the slowest estimated
# training time plus 2 * model_size / client throughput (presumably one
# download and one upload of the model -- confirm).
if env_cfg.max_T:
    response_time_limit = env_cfg.max_T
else:
    response_time_limit = max(clients_est_round_T_train) + 2 * task_cfg.model_size / bw_set[0]

# Per-client crash probabilities and the traces generated from them.
clients_crash_prob_vec = generate_clients_crash_prob(env_cfg)
crash_trace, progress_trace = generate_crash_trace(env_cfg, clients_crash_prob_vec)

glob_model = init_global_model(env_cfg, task_cfg)

### Run Program

In [None]:
# Start the federated-learning run. Every argument is positional, so the
# order below must stay in sync with run_FL's parameter order.
best_model, best_round, final_loss = run_FL(
    env_cfg,
    task_cfg,
    glob_model,
    cindexmap,
    data_size,
    fed_loader_train,
    fed_loader_test,
    client_shard_sizes,
    clients_perf_vec,
    crash_trace,
    progress_trace,
    clients_est_round_T_train,
    response_time_limit,
    lag_tol,
)