In [1]:
import pickle
import copy
import torch
from client import Client
from models import SLC, MLP
from dataset import FedDataset, get_data
import numpy as np
import os
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from math import ceil
import pandas as pd
# Single seed for every RNG the notebook uses, so Restart & Run All reproduces results.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)


In [2]:
# --- Experiment configuration -----------------------------------------------
rounds = 100  # number of federated rounds run by the training loop
test_freq = 1  # evaluation interval in rounds — NOTE(review): confirm the training loop reads it
local_epochs = 5  # passed to every Client as its local-epoch count

# Pick the fastest available backend: NVIDIA GPU, then Apple GPU, then CPU.
# `torch.backends.mps` does not exist on older torch builds, so look it up defensively.
_mps_backend = getattr(torch.backends, "mps", None)
if torch.cuda.is_available():
    device = "cuda"
elif _mps_backend is not None and _mps_backend.is_available():
    device = "mps"
else:
    device = "cpu"

print(f"device: {device}")

# --- Model / data dimensions -------------------------------------------------
temporal_len = 10  # NOTE(review): not passed to any active cell; only recorded by the commented-out save cell
transient_dim = 4  # given to both SLC (local) and MLP (global) — presumably SLC output / MLP input width
output_dim = 13  # MLP output size — presumably the number of classes; confirm
hidden_dims = [32]  # hidden layer sizes passed to SLC

batch_size = 32  # DataLoader batch size (keep in sync with the client-construction cell)

# --- Optimisation --------------------------------------------------------------
n = 4  # number of clients; keep equal to nClients (the client-construction count)
learning_rate = 1e-2
momentum = 0  # NOTE(review): not passed to Client by the active client-construction cell
# Must name the optimizer the clients are actually built with: the client cell
# passes the literal "Adam", and this value is what the save cell records.
optimizer = "Adam"

# --- Local-parameter aggregation weighting -----------------------------------
# Used by the training loop's `factor` logic: weight for local params of clients
# whose local state dict is not the single-modality size.
alpha = 1
alpha_per_modality = False  # if True, scale that weight by the client's modality count

device: cuda


In [3]:
# Results are organised as <main_folder>/<test_series_name>/<run> by the save cell.
main_folder, test_series_name = "test_results", "test_z"


In [4]:
# Number of Client objects to build. Derived from `n` (config cell) so the
# client count and the count used elsewhere cannot drift apart.
nClients = n

In [5]:
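# Per-sensor index map, kept for reference only. The grouped `all_mod` in the
# next cell pools these sensors by type (acc / gyro / mag); EKG (4, 5) is not used there.
# all_mod = {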
#     "accChest": [1, 2, 3],
#     "EKG": [4, 5],
#     "accLa": [6, 7, 8],
#     "gyroLa": [9, 10, 11],
#     "magLa": [12, 13, 14],
#     "accRa": [15, 16, 17],
#     "gyroRa": [18, 19, 20],
#     "magRa": [21, 22, 23]
#  }


In [6]:
# Sensor group -> integer feature indices (presumably columns of the data CSV).
# "all" is the concatenation of the acc, gyro and mag groups, in that order,
# and is listed first.
all_mod = {
    "acc": [1, 2, 3, 6, 7, 8, 15, 16, 17],
    "gyro": [9, 10, 11, 18, 19, 20],
    "mag": [12, 13, 14, 21, 22, 23],
}
all_mod = {"all": [col for cols in all_mod.values() for col in cols], **all_mod}

In [7]:
# Per-client modality subsets: every client has the gyro group; client 0 also
# sees acc and client 1 also sees mag. The index lists are taken from `all_mod`
# so they are defined in exactly one place. Key order (gyro first) is preserved
# because SLC receives the dict as-is. Each list is copied so clients never
# share a mutable list with `all_mod`.
client_modality_groups = [
    ("gyro", "acc"),
    ("gyro", "mag"),
    ("gyro",),
    ("gyro",),
]
modalities = [
    {group: list(all_mod[group]) for group in groups}
    for groups in client_modality_groups
]

In [8]:
# --- Aggregation mode ------------------------------------------------------
federatedGlob = True  # average the global-model params returned by Client.train() each round
federatedLoc = False  # average the local-model params returned by Client.train()
lg_frac = 0  # local averaging is switched off once the round index exceeds (1 - lg_frac) * rounds

# --- Result containers -----------------------------------------------------
train_performance = None
test_performance = dict.fromkeys(range(n))  # one slot per client index, initially None

In [9]:
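# Legacy client setup, kept for reference: builds each Client without DataLoaders
# and can copy shared initial weights from template SLC/MLP models. Superseded by
# the client-population cell further down.
# clients = []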
# if federatedLoc:
#     uni_loc = SLC(all_mod, hidden_dims, transient_dim, False)
# uni_glob = MLP(transient_dim, output_dim)
# # Generate clients
# for i in range(n):
#     local_mod = SLC(modalities[i], hidden_dims, transient_dim, False)
#     glob_mod = MLP(transient_dim, output_dim)

#     # if federatedLoc:
#     #     s_dict = {}
#     #     local_dict = uni_loc.state_dict()
#     #     for k in local_mod.state_dict():
#     #         s_dict[k] = copy.deepcopy(local_dict[k])
#     #     local_mod.load_state_dict(s_dict)
#     # if federatedGlob:
#     #     s_dict = {}
#     #     global_dict = uni_glob.state_dict()
#     #     for k in glob_mod.state_dict():
#     #         s_dict[k] = copy.deepcopy(global_dict[k])
#     #     glob_mod.load_state_dict(s_dict)

#     clients.append(
#         Client(
#             glob_mod,
#             local_mod,
#             local_epochs,
#             learning_rate=learning_rate,
#             optimizer=optimizer,
#             device=device,
#             momentum=momentum,
#         )
#     )

In [10]:


# Load the train/test splits; both are indexed per client (data_train[i]) when
# the clients are built.
data_train, data_test = get_data(
    "data/data_all.csv",
    4,  # NOTE(review): presumably the number of client splits (matches nClients) — confirm in dataset.get_data
    False,  # NOTE(review): meaning of this flag is not visible here — see dataset.get_data
)


In [11]:
# Build one Client per participant. Each gets a fresh global model (MLP), a
# local model (SLC) over its own modality subset, and train/test DataLoaders
# over its own data split.
assert len(modalities) >= nClients, "need a modality spec for every client"

clients = []
for i in range(nClients):
    # Construction order (MLP, SLC, then the datasets/loaders) matches the
    # original argument-evaluation order, so seeded weight initialisation is unchanged.
    global_model = MLP(transient_dim, output_dim)
    local_model = SLC(modalities[i], hidden_dims, transient_dim, False)
    train_loader = DataLoader(FedDataset(data_train[i], device), batch_size=batch_size)
    test_loader = DataLoader(FedDataset(data_test[i], device), batch_size=batch_size)
    clients.append(
        Client(
            global_model,
            local_model,
            train_loader,
            test_loader,
            local_epochs,
            learning_rate,
            "Adam",  # NOTE(review): hard-coded; keep the `optimizer` config value in sync
            device=device,
        )
    )


In [12]:
# Federated training. Each round:
#   1. every client trains locally (Client.train) and may be evaluated (Client.test);
#   2. the returned global params are averaged when `federatedGlob` is set, and the
#      local params when local sharing is active for this round;
#   3. the averages are loaded back into every client (Client.load_params).
num_clients = len(clients)  # iterate over the clients actually built, not the separate `n` constant

# Reset per-client histories so re-running this cell does not append to a previous run.
for client_idx in range(num_clients):
    test_performance[client_idx] = []

for round_idx in range(rounds):
    w_glob_tmp = None  # sum of global params over clients
    w_loc_tmp = None  # factor-weighted sum of local params, keyed by param name
    w_loc_tmp_count = None  # total factor per local param name (the denominator)

    # Local sharing stops once round_idx exceeds (1 - lg_frac) * rounds. This is
    # derived per round instead of overwriting the `federatedLoc` config flag, so
    # re-running the cell (or saving the config) still sees the configured mode.
    share_local = federatedLoc and round_idx <= (1 - lg_frac) * rounds
    print(f"Round {round_idx}")

    for client_idx, client in enumerate(clients):
        w_glob_ret, w_local_ret, _train_extra = client.train()

        if round_idx % test_freq == 0:
            performance = client.test()
            test_performance[client_idx].append(performance)
            print(f"Dev {client_idx}-loss:{performance}")

        if federatedGlob:
            if w_glob_tmp is None:
                # deepcopy so the in-place sums below never touch client 0's tensors
                w_glob_tmp = copy.deepcopy(w_glob_ret)
            else:
                for k in w_glob_ret:
                    w_glob_tmp[k] += w_glob_ret[k]

        if share_local:
            # NOTE(review): 8 is presumably the number of local-param tensors a
            # single-modality SLC contributes, making `n_units` ~ the modality count — confirm.
            n_units = len(w_local_ret) / 8
            if n_units == 1:
                factor = 1
            elif alpha_per_modality:
                factor = n_units * alpha
            else:
                factor = alpha

            if w_loc_tmp is None:
                w_loc_tmp, w_loc_tmp_count = {}, {}
            # Clients hold different modality subsets, so sums and weights are tracked per key.
            for k, value in w_local_ret.items():
                if k in w_loc_tmp:
                    w_loc_tmp[k] += factor * value
                    w_loc_tmp_count[k] += factor
                else:
                    w_loc_tmp[k] = factor * value
                    w_loc_tmp_count[k] = factor
    print("-------------------")

    # Global params: simple (unweighted) mean across all clients.
    if federatedGlob and w_glob_tmp is not None:
        for k in w_glob_tmp:
            w_glob_tmp[k] = torch.div(w_glob_tmp[k], num_clients)
    # Local params: each key is normalised by the total factor of the clients holding it.
    if share_local and w_loc_tmp is not None:
        for k in w_loc_tmp:
            w_loc_tmp[k] = torch.div(w_loc_tmp[k], w_loc_tmp_count[k])

    # Broadcast the averages; a None argument means "nothing aggregated for that part".
    if federatedGlob or share_local:
        for client in clients:
            client.load_params(w_glob_tmp, w_loc_tmp)

Round 0
Dev 0-loss:0.4756944444444444
Dev 1-loss:0.2326388888888889
Dev 2-loss:0.09375
Dev 3-loss:0.22569444444444445
-------------------
Round 1
Dev 0-loss:0.3958333333333333
Dev 1-loss:0.2326388888888889
Dev 2-loss:0.010416666666666666
Dev 3-loss:0.13194444444444445
-------------------
Round 2
Dev 0-loss:0.4826388888888889
Dev 1-loss:0.2326388888888889
Dev 2-loss:0.07291666666666667
Dev 3-loss:0.23958333333333334
-------------------
Round 3
Dev 0-loss:0.4722222222222222
Dev 1-loss:0.2326388888888889
Dev 2-loss:0.0
Dev 3-loss:0.3055555555555556
-------------------
Round 4
Dev 0-loss:0.4895833333333333
Dev 1-loss:0.2326388888888889
Dev 2-loss:0.006944444444444444
Dev 3-loss:0.3368055555555556
-------------------
Round 5
Dev 0-loss:0.5868055555555556
Dev 1-loss:0.2326388888888889
Dev 2-loss:0.027777777777777776
Dev 3-loss:0.3888888888888889
-------------------
Round 6
Dev 0-loss:0.5763888888888888
Dev 1-loss:0.2361111111111111
Dev 2-loss:0.2673611111111111
Dev 3-loss:0.3923611111111111


In [13]:
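# NOTE(review) before re-enabling: os.mkdir only creates the last path component,
# so it fails if main_folder/test_series_name does not exist yet (os.makedirs would
# create it). Also, no visible cell sets a `model` attribute on Client (it is built
# from two models), so confirm `clients[i].model` before relying on it.
# save_path = main_folder + "/" + test_series_name + "/test_015"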

# if os.path.isdir(save_path) is False:
#     os.mkdir(
#         save_path,
#     )
#     torch.save(test_performance, f"{save_path}/test_data")
#     info_dict = {
#         "Num of clients": n,
#         "Learning rate": learning_rate,
#         "Federated Global": federatedGlob,
#         "Federated Local": federatedLoc,
#         "Batch Size": batch_size,
#         "Global rounds": rounds,
#         "Local epochs": local_epochs,
#         "Clients": clients,
#         "Modalities": modalities,
#         "Optimizer": optimizer,
#         "Temporal length": temporal_len,
#         "Transient dimension": transient_dim,
#         "Hidden dimensions": hidden_dims,
#         "alpha": alpha,
#         "alpha_per_modality": alpha_per_modality,
#     }
#     with open(f"{save_path}/test_info.txt", "w") as f:
#         f.write(info_dict.__repr__())

#     for i in range(n):
#         torch.save(clients[i].model.state_dict(), f"{save_path}/dev{i}_model")
# else:
#     print("test exists")
