In [None]:
import pickle
import copy
import torch
from client import Client
from models import SLC, MLP
from dataset import FedDataset, get_data
import numpy as np
import os
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from math import ceil
import pandas as pd
# Seed the global PyTorch and NumPy random number generators with a fixed
# value so that repeated runs start from the same random state.
torch.manual_seed(0)
np.random.seed(0)


In [None]:
# --- Training schedule ---
rounds = 100  # number of federated rounds executed by the training loop
test_freq = 1  # not read by the visible code; presumably rounds between evaluations -- TODO confirm
local_epochs = 5  # passed to every Client; presumably local epochs per round -- TODO confirm

# Determine hardware availability.
# torch.backends.mps is absent on older PyTorch builds, so look it up with
# getattr instead of assuming the attribute exists (avoids an AttributeError).
_mps_backend = getattr(torch.backends, "mps", None)
if torch.cuda.is_available():
    device = "cuda"  # NVIDIA GPU
elif _mps_backend is not None and _mps_backend.is_available():
    device = "mps"  # Apple GPU
else:
    device = "cpu"  # Defaults to CPU if NVIDIA GPU/Apple GPU aren't available

print(f"device: {device}")

# Test parameters
temporal_len = 10  # not read by the visible (uncommented) code

transient_dim = 4  # passed to both SLC and MLP; presumably the width of the interface between them -- TODO confirm
output_dim = 13  # passed to MLP; presumably the number of output classes -- TODO confirm
hidden_dims = [32]  # passed to SLC

batch_size = 32  # mini-batch size intended for the clients' DataLoaders

n = 4  # number of clients the training loop iterates over
learning_rate = 1e-2  # passed to every Client
# NOTE(review): momentum and optimizer are only referenced by commented-out
# code; the live client construction passes the literal "Adam" instead.
momentum = 0
optimizer = "SGD"

# Weighting of local-model updates during aggregation: the training loop
# applies alpha (optionally scaled by len(state_dict) / 8 when
# alpha_per_modality is True) to updates whose state dict does not have
# exactly 8 entries, and a factor of 1 otherwise.
alpha = 1
alpha_per_modality = False

In [None]:
# Names used to build result paths of the form
# main_folder + "/" + test_series_name; in the visible code they are only
# referenced by commented-out load/save cells.
test_series_name = "test_z"
main_folder = "test_results"


In [None]:
# if not os.path.isdir(main_folder+'/'+test_series_name):
#     train_data, test_data, train_dict, test_dict = get_data(
#         "data/data_all.csv", n)
#     num_batch = ceil(max([len(i) for i in test_dict.values()])/batch_size)
#     os.mkdir(main_folder+'/'+test_series_name)
#     with open(main_folder+'/'+test_series_name+'/train_dict.pkl', 'wb') as f:
#         pickle.dump(train_dict, f)
#     with open(main_folder+'/'+test_series_name+'/test_dict.pkl', 'wb') as f:
#         pickle.dump(test_dict, f)
#     torch.save(train_data, main_folder+'/'+test_series_name+'/train_data')

#     torch.save(test_data, main_folder+'/'+test_series_name+'/test_data')
# else:
#     train_dict = {}
#     test_dict = {}
#     with open(main_folder+'/'+test_series_name+'/train_dict.pkl', "rb") as input_file:
#         train_dict = pickle.load(input_file)
#     with open(main_folder+'/'+test_series_name+'/test_dict.pkl', "rb") as input_file:
#         test_dict = pickle.load(input_file)

#     train_data = torch.load(main_folder+'/'+test_series_name+'/train_data')

#     test_data = torch.load(main_folder+'/'+test_series_name+'/test_data')

In [None]:
# Number of Client objects to construct. NOTE(review): the training loop
# indexes clients with range(n), so this must stay equal to n.
nClients = 4

In [None]:
# all_mod = {
#     "accChest": [1, 2, 3],
#     "EKG": [4, 5],
#     "accLa": [6, 7, 8],
#     "gyroLa": [9, 10, 11],
#     "magLa": [12, 13, 14],
#     "accRa": [15, 16, 17],
#     "gyroRa": [18, 19, 20],
#     "magRa": [21, 22, 23]
#  }


In [None]:
# Modality name -> list of integer indices (presumably sensor columns of the
# dataset -- TODO confirm). "all" is the concatenation of the accelerometer,
# gyroscope and magnetometer groups, in that order; it is declared first so
# the key order stays all / acc / gyro / mag.
all_mod = {
    "all": None,
    "acc": [1, 2, 3, 6, 7, 8, 15, 16, 17],
    "gyro": [9, 10, 11, 18, 19, 20],
    "mag": [12, 13, 14, 21, 22, 23],
}
all_mod["all"] = [
    index for group in ("acc", "gyro", "mag") for index in all_mod[group]
]

In [None]:
# One modality specification per client (four entries); every client is given
# the same gyro-only specification. Each entry is its own dict/list object.
modalities = [{"gyro": [9, 10, 11, 18, 19, 20]} for _ in range(4)]

In [None]:
# Mode
# federatedGlob / federatedLoc switch on cross-client averaging of the weights
# that Client.train() returns for the global and the local model respectively.
federatedGlob = federatedLoc = True
# The training loop turns federatedLoc off once the round index exceeds
# (1 - lg_frac) * rounds; with 0 that condition is never met.
lg_frac = 0

# result lists
# train_performance is filled by the training loop (one row per round);
# test_performance holds one slot per client index, initially None.
train_performance = None
test_performance = dict.fromkeys(range(n))

In [None]:
# clients = []
# if federatedLoc:
#     uni_loc = SLC(all_mod, hidden_dims, transient_dim, False)
# uni_glob = MLP(transient_dim, output_dim)
# # Generate clients
# for i in range(n):
#     local_mod = SLC(modalities[i], hidden_dims, transient_dim, False)
#     glob_mod = MLP(transient_dim, output_dim)

#     # if federatedLoc:
#     #     s_dict = {}
#     #     local_dict = uni_loc.state_dict()
#     #     for k in local_mod.state_dict():
#     #         s_dict[k] = copy.deepcopy(local_dict[k])
#     #     local_mod.load_state_dict(s_dict)
#     # if federatedGlob:
#     #     s_dict = {}
#     #     global_dict = uni_glob.state_dict()
#     #     for k in glob_mod.state_dict():
#     #         s_dict[k] = copy.deepcopy(global_dict[k])
#     #     glob_mod.load_state_dict(s_dict)

#     clients.append(
#         Client(
#             glob_mod,
#             local_mod,
#             local_epochs,
#             learning_rate=learning_rate,
#             optimizer=optimizer,
#             device=device,
#             momentum=momentum,
#         )
#     )

In [None]:


# Load the training and test data; data_train is later indexed by client
# index when the clients are built.
# NOTE(review): the literal 4 is presumably the number of client partitions
# and should match n / nClients -- confirm against dataset.get_data.
data_train, data_test = get_data("data/data_all.csv", 4, False)


In [None]:
# Populate clients: one Client per index in range(nClients). Each client gets
# a fresh MLP, a fresh SLC built from its own modality specification, and a
# DataLoader over its own slice of the training data.
# The DataLoader uses the configured batch_size rather than a repeated
# literal, so changing the config value actually takes effect here.
# NOTE(review): the optimizer is hard-coded to "Adam" although the config
# defines optimizer = "SGD" (and momentum); confirm which one is intended
# before replacing the literal, since switching changes training results.
clients = []

for i in range(nClients):
    clients.append(
        Client(
            MLP(transient_dim, output_dim),
            SLC(modalities[i], hidden_dims, transient_dim, False),
            DataLoader(
                FedDataset(data_train[i], device),
                batch_size=batch_size,
            ),
            local_epochs,
            learning_rate,
            "Adam",
            device=device,
        )
    )


In [None]:
last_entry = 0  # not read by the visible code; kept in case another cell uses it

# Federated training loop. In every round each client trains locally, the
# returned weights are summed across clients, averaged, and the averages are
# loaded back into every client.
# The loop variable is called round_idx so the builtin round() is not
# shadowed (and left overwritten in the notebook namespace afterwards).
for round_idx in range(rounds):
    # Running sum of the clients' global-model weights for this round
    w_glob_tmp = None

    # Running factor-weighted sum of the clients' local-model weights
    w_loc_tmp = None

    # Per-key sum of the factors applied above; used as the divisor so keys
    # that only some clients return are averaged over those clients only
    w_loc_tmp_count = None

    # Stop aggregating local weights for the last lg_frac share of rounds.
    # This overwrites the notebook-level flag, so it stays off afterwards.
    if round_idx > (1 - lg_frac) * rounds:
        federatedLoc = False

    # Mean of the performance values returned by each client this round
    print_loss = []
    for client in range(n):
        w_glob_ret, w_local_ret, performance = clients[client].train()

        print_loss.append(np.average(performance))

        if federatedGlob:
            if w_glob_tmp is None:
                # Deep copy so the in-place sums below never touch the
                # object returned by the first client
                w_glob_tmp = copy.deepcopy(w_glob_ret)
            else:
                for k in w_glob_ret:
                    w_glob_tmp[k] += w_glob_ret[k]

        if federatedLoc:
            # NOTE(review): 8 looks like the number of state-dict entries of
            # a single-modality local model -- confirm against models.SLC.
            # Updates with exactly 8 entries get weight 1; all others get
            # alpha, optionally scaled by len(w_local_ret) / 8.
            if alpha_per_modality:
                factor = (
                    1 if len(w_local_ret) / 8 == 1 else len(w_local_ret) / 8 * alpha
                )
            else:
                factor = 1 if len(w_local_ret) / 8 == 1 else alpha

            if w_loc_tmp is None:
                w_loc_tmp = {}
                w_loc_tmp_count = {}
            for k in w_local_ret.keys():
                if k not in w_loc_tmp:
                    w_loc_tmp[k] = factor * w_local_ret[k]
                    w_loc_tmp_count[k] = factor
                else:
                    w_loc_tmp[k] += factor * w_local_ret[k]
                    w_loc_tmp_count[k] += factor

        # performance = clients[client].test()
        # if test_performance[client] is None:
        #     test_performance[client] = copy.deepcopy(performance)
        # else:
        #     test_performance[client] = np.hstack(
        #         (test_performance[client], performance)
        #     )

    # Append this round's per-client values as a new row (rounds x clients)
    if train_performance is None:
        train_performance = np.array(print_loss).reshape(1, -1)
    else:
        train_performance = np.vstack((train_performance, np.array(print_loss)))

    # Global weights: plain (unweighted) mean over the n clients
    if federatedGlob:
        for k in w_glob_tmp.keys():
            w_glob_tmp[k] = torch.div(w_glob_tmp[k], n)
    # Local weights: factor-weighted mean, divided by the accumulated factors
    if federatedLoc:
        for k in w_loc_tmp.keys():
            w_loc_tmp[k] = torch.div(w_loc_tmp[k], w_loc_tmp_count[k])

    # copy weights to each client based on mode
    # (an aggregate that was not computed this round is passed as None)
    if federatedGlob or federatedLoc:
        for client in range(n):
            clients[client].load_params(w_glob_tmp, w_loc_tmp)

    # Redraw the per-client training curves; one legend label per client,
    # derived from n instead of a fixed list of four labels.
    plt.clf()
    plt.plot(train_performance, label=list(range(1, n + 1)))
    plt.legend()

    print(modalities)

In [None]:
# save_path = main_folder + "/" + test_series_name + "/test_015"

# if os.path.isdir(save_path) is False:
#     os.mkdir(
#         save_path,
#     )
#     torch.save(test_performance, f"{save_path}/test_data")
#     info_dict = {
#         "Num of clients": n,
#         "Learning rate": learning_rate,
#         "Federated Global": federatedGlob,
#         "Federated Local": federatedLoc,
#         "Batch Size": batch_size,
#         "Global rounds": rounds,
#         "Local epochs": local_epochs,
#         "Clients": clients,
#         "Modalities": modalities,
#         "Optimizer": optimizer,
#         "Temporal length": temporal_len,
#         "Transient dimension": transient_dim,
#         "Hidden dimensions": hidden_dims,
#         "alpha": alpha,
#         "alpha_per_modality": alpha_per_modality,
#     }
#     with open(f"{save_path}/test_info.txt", "w") as f:
#         f.write(info_dict.__repr__())

#     for i in range(n):
#         torch.save(clients[i].model.state_dict(), f"{save_path}/dev{i}_model")
# else:
#     print("test exists")
