# Mount Google Drive

In [1]:
# Mount Google Drive into the Colab VM; the project repository used below
# lives under this mount point.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Mounted at /content/drive


# Library Imports

In [3]:
import os
import time
import torch

# Path to the FedHEAL Repo

In [4]:
# Location of the FedHEAL_SAS repository on the mounted Drive. Each
# collaborator keeps the repo in a different folder, so the known locations
# are tried in order and the first existing one is used; no more toggling
# commented-out lines. Setting the FEDHEAL_PROJECT_PATH environment variable
# (e.g. `%env FEDHEAL_PROJECT_PATH=/content/drive/...`) skips the search.
PROJECT_PATH_CANDIDATES = [
    '/content/drive/My Drive/Colab Notebooks/Adv ML/Project/FedHEAL_SAS',  # Fayzan's Folder path
    '/content/drive/My Drive/FedHEAL_SAS',  # Zaeem's Folder path
]
project_path = os.environ.get('FEDHEAL_PROJECT_PATH') or next(
    (path for path in PROJECT_PATH_CANDIDATES if os.path.isdir(path)),
    # Nothing found: fall back to the first entry so os.chdir below raises a
    # FileNotFoundError that names the expected location.
    PROJECT_PATH_CANDIDATES[0],
)
# The training cells run `python main.py` with a relative path, so the
# working directory must be the repository root.
os.chdir(project_path)

# Timer Function

In [5]:
def calculate_time(start_time, end_time):
    """Print the elapsed time between two timestamps as ``Hh Mm S.SSs``.

    Args:
        start_time: Start timestamp in seconds (e.g. from ``time.time()``).
        end_time: End timestamp in seconds, on the same clock as ``start_time``.

    Returns:
        None; the formatted duration is printed to stdout.
    """
    # Round to the printed precision *before* splitting into units; otherwise
    # e.g. 119.996 s is shown as "0h 1m 60.00s" instead of "0h 2m 0.00s".
    elapsed_time = round(end_time - start_time, 2)
    hours, remainder = divmod(elapsed_time, 3600)
    minutes, seconds = divmod(remainder, 60)
    print(f"Total Execution Time: {int(hours)}h {int(minutes)}m {seconds:.2f}s")

# Configuration Variables

In [6]:
# Number of communication rounds in federated learning (global aggregation cycles)
communication_epoch_var = 5

# Number of local training epochs for each participant before global aggregation
local_epoch_var = 10

# Total number of participants in the federated learning setup
no_of_total_clients = 6

# Number of clients using MNIST dataset
no_of_mnist_clients = 2

# Number of clients using USPS dataset
no_of_usps_clients = 2

# Number of clients using SVHN dataset
no_of_svhn_clients = 2

# Number of clients using SYN dataset
no_of_syn_clients = 0

# Random seed for reproducibility of experiments
random_seed = 42

# Option for applying learning rate decay (0 = No decay, 1 = Apply decay)
learning_decay = 0

# Threshold value for HEAL (hyperparameter for filtering updates). Default is 0.3
threshold = 0.3

# Momentum update factor (beta) used in HEAL. Default is 0.4
beta = 0.4

# Averaging strategy for federated learning ("weight" or "equal")
averaging = 'weight'

# Fail fast on an inconsistent configuration, before any (long) training run
# is launched with it.
assert communication_epoch_var >= 1 and local_epoch_var >= 1, \
    "communication_epoch_var and local_epoch_var must be positive"
# Each participant uses one of the four datasets above, so the per-dataset
# counts must add up to the total number of participants.
assert (no_of_mnist_clients + no_of_usps_clients
        + no_of_svhn_clients + no_of_syn_clients) == no_of_total_clients, \
    "per-dataset client counts must sum to no_of_total_clients"
assert learning_decay in (0, 1), "learning_decay must be 0 (no decay) or 1 (decay)"
assert averaging in ('weight', 'equal'), "averaging must be 'weight' or 'equal'"

# **FedHEAL** (`fedavgheal`)

In [7]:
model_name = 'fedavgheal'

# Start timer
start_time = time.time()
print("Start Time:", start_time)

# Execute main training script with arguments passed as variables
!python main.py --device_id 0 \
--communication_epoch {communication_epoch_var} \
--local_epoch {local_epoch_var} \
--syn {no_of_syn_clients} \
--parti_num {no_of_total_clients} \
--mnist {no_of_mnist_clients} \
--usps {no_of_usps_clients} \
--svhn {no_of_svhn_clients} \
--seed {random_seed} \
--learning_decay {learning_decay} \
--threshold {threshold} \
--beta {beta} \
--averaging {averaging} \
--model {model_name}

# End timer
end_time = time.time()
print("End Time:", end_time)
calculate_time(start_time, end_time)

Start Time: 1734898624.6289375
digits
fl_digits
officecaltech
fl_officecaltech
fedavgheal_6_fl_digits_5_10
Counter({'mnist': 2, 'svhn': 2, 'usps': 2})
['mnist' 'mnist' 'svhn' 'usps' 'svhn' 'usps']
Downloading datasets for domains:  ['mnist' 'mnist' 'svhn' 'usps' 'svhn' 'usps']
Using downloaded and verified file: ./data0/train_32x32.mat
Using downloaded and verified file: ./data0/train_32x32.mat
DOMAINS_LIST:  ['mnist', 'usps', 'svhn', 'syn']
Using downloaded and verified file: ./data0/test_32x32.mat
Using downloaded and verified file: ./data0/test_32x32.mat
0
online clients:  [3, 0, 1, 2, 5, 4]
Local Pariticipant 3 loss = 1.828: 100% 10/10 [00:02<00:00,  4.88it/s]
Local Pariticipant 0 loss = 1.302: 100% 10/10 [00:05<00:00,  1.82it/s]
Local Pariticipant 1 loss = 1.618: 100% 10/10 [00:04<00:00,  2.01it/s]
Local Pariticipant 2 loss = 2.222: 100% 10/10 [00:06<00:00,  1.56it/s]
Local Pariticipant 5 loss = 1.801: 100% 10/10 [00:00<00:00, 16.52it/s]
Local Pariticipant 4 loss = 2.139: 100% 10/

# **FedAvg**

In [8]:
model_name = 'fedavg'

# Start timer
start_time = time.time()
print("Start Time:", start_time)

# Execute main training script with arguments passed as variables
!python main.py --device_id 0 \
--communication_epoch {communication_epoch_var} \
--local_epoch {local_epoch_var} \
--syn {no_of_syn_clients} \
--parti_num {no_of_total_clients} \
--mnist {no_of_mnist_clients} \
--usps {no_of_usps_clients} \
--svhn {no_of_svhn_clients} \
--seed {random_seed} \
--learning_decay {learning_decay} \
--threshold {threshold} \
--beta {beta} \
--averaging {averaging} \
--model {model_name}

# End timer
end_time = time.time()
print("End Time:", end_time)
calculate_time(start_time, end_time)

Start Time: 1734898961.1919672
digits
fl_digits
officecaltech
fl_officecaltech
fedavg_6_fl_digits_5_10
Counter({'mnist': 2, 'svhn': 2, 'usps': 2})
['mnist' 'mnist' 'svhn' 'usps' 'svhn' 'usps']
Downloading datasets for domains:  ['mnist' 'mnist' 'svhn' 'usps' 'svhn' 'usps']
Using downloaded and verified file: ./data0/train_32x32.mat
Using downloaded and verified file: ./data0/train_32x32.mat
DOMAINS_LIST:  ['mnist', 'usps', 'svhn', 'syn']
Using downloaded and verified file: ./data0/test_32x32.mat
Using downloaded and verified file: ./data0/test_32x32.mat
0
[3, 0, 1, 2, 5, 4]
Local Pariticipant 3 loss = 1.841: 100% 10/10 [00:01<00:00,  9.36it/s]
Local Pariticipant 0 loss = 1.316: 100% 10/10 [00:05<00:00,  1.80it/s]
Local Pariticipant 1 loss = 1.621: 100% 10/10 [00:04<00:00,  2.02it/s]
Local Pariticipant 2 loss = 2.221: 100% 10/10 [00:06<00:00,  1.60it/s]
Local Pariticipant 5 loss = 1.813: 100% 10/10 [00:00<00:00, 12.84it/s]
Local Pariticipant 4 loss = 2.125: 100% 10/10 [00:06<00:00,  1.6