In [None]:
import os
import pickle
import pathlib
import datetime

from typing import Optional, Callable, Union

import torch
from torch import Tensor, Size
import torch.nn.functional as F
from torch.nn import Module, Parameter
from torch.utils.data import DataLoader

import numpy as np
from sklearn.metrics import accuracy_score

from google.colab import drive
drive.mount('/content/drive')

# Root folder of the project on the mounted Google Drive; experiment pickles live in `<BASE_DIR>/exp`.
BASE_DIR = "/content/drive/MyDrive/torch-esn-stress"
# Train / validation / test split of the WESAD subjects.
# The integers are positions into the `USERS` list (they are passed to `WESAD_Dataset(idx)`),
# not the subject ids themselves.
# The 'train' entry is keyed by the percentage of training users taking part in an experiment.
WESAD_USERS = {
    'train': {
        25: [0, 3, 5],
        50: [0, 3, 5, 6, 9],
        75: [0, 3, 5, 6, 9, 10, 12],
        100: [0, 3, 5, 6, 9, 10, 12, 13, 14]
    }, 
    'valid': [1, 8, 11],
    'test': [2, 4, 7]
}

Mounted at /content/drive


In [None]:
#!wget https://uni-siegen.sciebo.de/s/HGdUkoNlW1Ub0Gx/download -O WESAD.zip
#!unzip "/content/WESAD.zip" -d "/content/drive/MyDrive/torch-esn-stress/data/raw"

# **Reservoir**

In [None]:
def uniform(size: Size, rho: Optional[float] = None, sigma: Optional[float] = None,
            scale: Optional[float] = None) -> Tensor:
    """
    Tensor with i.i.d. entries sampled from U(-1, 1)
    The sample is then rescaled in place by `rescale_`, which applies the first option
    that is not `None` among `rho`, `sigma` and `scale` (in that order).
    :param size: Size of tensor
    :param rho: Target spectral radius
    :param sigma: Target spectral norm
    :param scale: Plain multiplicative factor applied to the sampled tensor
    :return: The (possibly rescaled) random tensor
    """
    weights = torch.empty(size)
    weights.uniform_(-1, 1)
    rescale_(weights, rho, sigma, scale)
    return weights.data


def normal(size: Size, rho: Optional[float] = None, sigma: Optional[float] = None,
           scale: Optional[float] = None) -> Tensor:
    """
    Tensor with i.i.d. entries sampled from the standard normal distribution N(0, 1)
    The sample is then rescaled in place by `rescale_`, which applies the first option
    that is not `None` among `rho`, `sigma` and `scale` (in that order).
    :param size: Size of tensor
    :param rho: Target spectral radius
    :param sigma: Target spectral norm
    :param scale: Plain multiplicative factor applied to the sampled tensor
    :return: The (possibly rescaled) random tensor
    """
    weights = torch.empty(size)
    weights.normal_(mean=0, std=1)
    rescale_(weights, rho, sigma, scale)
    return weights.data


def ring(size: Size, rho: Optional[float] = None, sigma: Optional[float] = None,
         scale: Optional[float] = None) -> Tensor:
    """
    Scaled ring (cyclic shift) matrix
    Reference:
    C. Gallicchio & A. Micheli (2020). Ring Reservoir Neural Networks for Graphs.
    In 2020 International Joint Conference on Neural Networks (IJCNN), IEEE.
    https://doi.org/10.1109/IJCNN48605.2020.9206723

    The three scaling options are interchangeable here; when several are given the
    precedence is `scale`, then `sigma`, then `rho`. At least one must be provided.
    :param size: Size of tensor (must be square)
    :param rho: Spectral radius (equivalent to others)
    :param sigma: Spectral norm (equivalent to others)
    :param scale: Simple rescaling of the matrix (equivalent to others)
    :return: A re-scaled ring matrix
    """
    is_square = (len(size) == 2) and (size[0] == size[1])
    assert is_square
    by_priority = (scale, sigma, rho)
    assert not all(option is None for option in by_priority)
    # First option that was actually supplied, following the precedence above.
    gain = next(option for option in by_priority if option is not None)
    # Identity with its rows cyclically shifted down by one position.
    shifted_identity = torch.roll(torch.eye(size[0]), 1, 0)
    return (shifted_identity * gain).data


def orthogonal(size: Size, rho: Optional[float] = None, sigma: Optional[float] = None,
               scale: Optional[float] = None) -> Tensor:
    """
    Scaled random (semi-)orthogonal matrix, generated by `torch.nn.init.orthogonal_`
    Reference:
    F. Mezzadri (2007). How to Generate Random Matrices from the Classical Compact Groups.
    Notices of the American Mathematical Society, 54(5), pp. 592-604.
    https://www.ams.org/notices/200705/fea-mezzadri-web.pdf

    The three scaling options are interchangeable here; when several are given the
    precedence is `scale`, then `sigma`, then `rho`. At least one must be provided.
    :param size: Size of tensor (if not square, generates a semi-orthogonal matrix)
    :param rho: Spectral radius (equivalent to others)
    :param sigma: Spectral norm (equivalent to others)
    :param scale: Simple rescaling of the matrix (equivalent to others)
    :return: A re-scaled orthogonal matrix
    """
    by_priority = (scale, sigma, rho)
    assert not all(option is None for option in by_priority)
    # First option that was actually supplied, following the precedence above.
    gain = next(option for option in by_priority if option is not None)
    weights = torch.nn.init.orthogonal_(torch.empty(size), gain)
    return weights.data


def ones(size: Size, rho: Optional[float] = None, sigma: Optional[float] = None,
         scale: Optional[float] = None) -> Tensor:
    """
    All-ones tensor (deterministic)
    The tensor is then rescaled in place by `rescale_`, which applies the first option
    that is not `None` among `rho`, `sigma` and `scale` (in that order).
    :param size: Size of tensor
    :param rho: Target spectral radius
    :param sigma: Target spectral norm
    :param scale: Plain multiplicative factor applied to the tensor
    :return: The (possibly rescaled) tensor of ones
    """
    filled = torch.ones(size)
    rescale_(filled, rho, sigma, scale)
    return filled.data


def zeros(size: Size, rho: Optional[float] = None, sigma: Optional[float] = None,
          scale: Optional[float] = None) -> Tensor:
    """
    All-zeros tensor (deterministic)
    The scaling options are accepted only so that the signature matches the other
    initializers; they are ignored because rescaling zeros is meaningless.
    :param size: Size of tensor
    :param rho: Ignored
    :param sigma: Ignored
    :param scale: Ignored
    :return: A tensor of zeros
    """
    return torch.zeros(size).data


def rescale_(W: Tensor, rho: Optional[float] = None, sigma: Optional[float] = None,
             scale: Optional[float] = None) -> Tensor:
    """
    Rescale a matrix in-place
    Exactly one rescaling is applied: the first option that is not `None` among
    `rho`, `sigma` and `scale` (in that order). If none is given, `W` is left untouched.
    :param W: Matrix to rescale (must be square for `rho`, 2-D for `sigma`)
    :param rho: Spectral radius
    :param sigma: Spectral norm
    :param scale: Simple rescaling of the standard random matrix
    :return: Rescaled matrix (the same object as `W` when `W` is already float32)
    """
    if rho is not None:
        # Normalise by the largest eigenvalue modulus, then scale to the target radius.
        return W.div_(torch.linalg.eigvals(W).abs().max()).mul_(rho).float()
    if sigma is not None:
        # Normalise by the largest singular value, then scale to the target norm.
        return W.div_(torch.linalg.matrix_norm(W, ord=2)).mul_(sigma).float()
    if scale is not None:
        return W.mul_(scale).float()
    # No option requested: previously this fell through and returned None,
    # contradicting the declared return type.
    return W.float()



class Reservoir(Module):
    """
    A Reservoir for Echo State Networks

    All weights are created once at construction time and are not trained
    (`requires_grad=False`), except for the optional gain/bias pair.

    Args:
        input_size: the number of expected features in the input `x`
        hidden_size: the number of features in the hidden state `h`
        activation: name of the activation function from `torch` (e.g. `'tanh'` for `torch.tanh`)
        leakage: the value of the leaking parameter `alpha`
        input_scaling: the value for the desired scaling of the input (must be `<= 1`)
        rho: the desired spectral radius of the recurrent matrix (must be `< 1`)
        bias: if ``False``, the layer does not use bias weights `b`
        kernel_initializer: the kind of initialization of the input transformation. Default: `'uniform'`
        recurrent_initializer: the kind of initialization of the recurrent matrix. Default: `'normal'`
        net_gain_and_bias: if ``True``, the network uses additional ``g`` (gain) and ``b`` (bias) parameters. Default: ``False``
    """

    def __init__(self,
                 input_size: int,
                 hidden_size: int,
                 activation: str = 'tanh',
                 leakage: float = 1.,
                 input_scaling: float = 0.9,
                 rho: float = 0.99,
                 bias: bool = False,
                 kernel_initializer: Union[str, Callable[[Size], Tensor]] = 'uniform',
                 recurrent_initializer: Union[str, Callable[[Size], Tensor]] = 'normal',
                 net_gain_and_bias: bool = False) -> None:

        super().__init__()
        assert rho < 1 and input_scaling <= 1

        # Hyper-parameters are stored as frozen parameters so they follow the module
        # across devices and are saved together with it.
        self.input_scaling = Parameter(torch.tensor(input_scaling), requires_grad=False)
        self.rho = Parameter(torch.tensor(rho), requires_grad=False)

        # Input-to-reservoir weights, shape (hidden_size, input_size).
        self.W_in = Parameter(
            init_params(kernel_initializer, scale=input_scaling)([hidden_size, input_size]),
            requires_grad=False
        )
        # Recurrent weights, shape (hidden_size, hidden_size), rescaled to spectral radius `rho`.
        self.W_hat = Parameter(
            init_params(recurrent_initializer, rho=rho)([hidden_size, hidden_size]),
            requires_grad=False
        )
        self.b = Parameter(
            init_params('uniform', scale=input_scaling)(hidden_size),
            requires_grad=False
        ) if bias else None
        # Element-wise non-linearity, looked up by name in the `torch` namespace.
        self.f = getattr(torch, activation)

        self.alpha = Parameter(torch.tensor(leakage), requires_grad=False)

        self.net_gain_and_bias = net_gain_and_bias
        if net_gain_and_bias:
            # Per-unit gain (initialised to 1) and bias (initialised to 0) applied to the
            # pre-activation; these are the only parameters created with requires_grad=True.
            self.net_a = Parameter(
                init_params('ones')(hidden_size),
                requires_grad=True
            )
            self.net_b = Parameter(
                init_params('zeros')(hidden_size),
                requires_grad=True
            )

    @torch.no_grad()
    def forward(self, input: Tensor, initial_state: Optional[Tensor] = None, mask: Optional[Tensor] = None) -> Tensor:
        """
        Run the reservoir over a sequence and return the state at every time step.

        Args:
            input: sequence whose first dimension is time and last dimension is `input_size`
            initial_state: state to start from; a zero vector of size `hidden_size` when omitted
            mask: optional tensor multiplied into every returned state

        Returns:
            The states stacked along a new leading time dimension.
        """
        if initial_state is None:
            initial_state = torch.zeros(self.hidden_size).to(self.W_hat)

        embeddings = torch.stack(list(self._state_comp(input.to(self.W_hat), initial_state, mask)), dim=0)

        return embeddings

    def _state_comp(self, input: Tensor, initial_state: Tensor, mask: Optional[Tensor] = None):
        """Yield the reservoir state after each time step of `input` (leaky-integrator update)."""
        timesteps = input.shape[0]
        state = initial_state
        for t in range(timesteps):
            in_signal_t = F.linear(input[t].to(self.W_in), self.W_in, self.b) + F.linear(state, self.W_hat)
            if self.net_gain_and_bias:
                in_signal_t = in_signal_t * self.net_a + self.net_b
            # Use the configured activation; this was hard-coded to torch.tanh, which
            # silently ignored the `activation` constructor argument.
            h_t = self.f(in_signal_t)
            state = (1 - self.alpha) * state + self.alpha * h_t
            # The mask only affects what is emitted, not the state carried to the next step.
            yield state if mask is None else mask * state

    @property
    def input_size(self) -> int:
        """Input dimension"""
        return self.W_in.shape[1]

    @property
    def hidden_size(self) -> int:
        """Reservoir state dimension"""
        return self.W_hat.shape[1]


def init_params(name: Union[str, Callable[..., Tensor]], **options) -> Callable[[Size], Tensor]:
    """
    Gets a weight initializer bound to the given options
    :param name: Name of one of the initializers defined in this notebook (`'uniform'`,
        `'normal'`, `'ring'`, `'orthogonal'`, `'ones'`, `'zeros'`), or a callable that
        takes the size as first argument and accepts `options` as keyword arguments
    :param options: Generator options (`rho`, `sigma` or `scale`), forwarded on every call
    :return: A function mapping a size to an initialized tensor
    :raises ValueError: If `name` is a string that matches no known initializer
    """
    # A user-supplied initializer is used as-is; previously a callable matched none of
    # the string comparisons and the function silently returned None.
    if callable(name):
        return lambda size: name(size, **options)

    known = ('uniform', 'normal', 'ring', 'orthogonal', 'ones', 'zeros')
    if name not in known:
        # Fail here rather than with "'NoneType' object is not callable" at the call site.
        raise ValueError(f"Unknown initializer {name!r}; expected one of {known} or a callable")

    init = {
        'uniform': uniform,
        'normal': normal,
        'ring': ring,
        'orthogonal': orthogonal,
        'ones': ones,
        'zeros': zeros,
    }[name]
    return lambda size: init(size, **options)


# **WESAD_Dataset**

In [None]:
# Folder with the raw WESAD download; `WESAD_Dataset.preprocess` reads `S<user>/S<user>.pkl` from it.
RAW_WESAD_PATH = '/content/drive/MyDrive/torch-esn-stress/data/raw/WESAD'
# Folder where the preprocessed per-user pickles (`<user>.pkl`) are cached.
WESAD_PATH = '/content/drive/MyDrive/torch-esn-stress/data/processed/WESAD'
# WESAD subject ids (as strings); `WESAD_Dataset(idx)` selects `USERS[idx]`.
USERS = ["2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "13", "14", "15", "16", "17"]


class WESAD_Dataset(torch.utils.data.Dataset):
    """
    Data of a single WESAD subject, served as fixed-length chunks of the recording.

    The preprocessed signals are cached in `WESAD_PATH/<user>.pkl`; if the cache file is
    missing it is built from the raw pickle under `RAW_WESAD_PATH`.

    `seq_length` must be assigned before the dataset is indexed or its length is taken:
    until then `X` and `Y` are `None`.
    """

    def __init__(self, idx: int) -> None:
        """
        :param idx: Position of the subject in the module-level `USERS` list
        """
        super().__init__()
        self.user = USERS[idx]

        u_path = os.path.join(WESAD_PATH, f'{self.user}.pkl')
        if not os.path.exists(u_path):
            self.user_data = self.preprocess()
        else:
            # NOTE(review): pickle can execute arbitrary code on load; only read cache
            # files that this notebook produced.
            with open(u_path, 'rb') as f:
                self.user_data = pickle.load(f)

        self._seq_length = None
        self.X, self.Y = None, None

    @property
    def seq_length(self):
        """Current chunk length (`None` until it is first set)."""
        return self._seq_length

    @seq_length.setter
    def seq_length(self, new_length: int):
        # Re-chunk only when the length actually changes.
        if self._seq_length is None or new_length != self._seq_length:
            print(f"Setting the length of the chunks in WESAD user {self.user} from {self._seq_length} to {new_length}")
            self.X, self.Y = self._to_sequence_chunks(new_length)
            self._seq_length = new_length

    def __len__(self):
        """Number of chunks (second dimension of `X`)."""
        return self.X.shape[1]

    def __getitem__(self, i: int):
        """Return the `i`-th chunk as `(X, Y)`, each with time as first dimension."""
        return self.X[:, i], self.Y[:, i]

    def _to_sequence_chunks(self, length: int):
        """
        Split the recording into consecutive chunks of `length` samples.

        A trailing chunk shorter than `length` is dropped. The chunks are stacked along
        dimension 1, so the results have shape (length, n_chunks, ...).
        """
        X = torch.split(self.user_data['X'], length, dim=0)
        Y = torch.split(self.user_data['Y'], length, dim=0)
        if X[-1].shape[0] != length:
            X, Y = X[:-1], Y[:-1]

        return torch.stack(X, dim=1), torch.stack(Y, dim=1)

    def preprocess(self):
        """
        Build the preprocessed data of this subject from the raw WESAD pickle and cache it.

        Chest signals are concatenated column-wise, samples whose label is not in 1..4 are
        discarded, every column is standardised (zero mean, unit variance) and the labels
        are one-hot encoded over 4 classes.

        :return: Dict with keys `'X'` (signals) and `'Y'` (one-hot labels)
        """
        print(f"Preprocessing user {self.user}...")
        with open(os.path.join(RAW_WESAD_PATH, f'S{self.user}', f'S{self.user}.pkl'), 'rb') as f:
            data = pickle.load(f, encoding='latin1')
        chest = data['signal']['chest']
        X = np.concatenate(
            [chest[name] for name in ('ACC', 'Resp', 'EDA', 'ECG', 'EMG', 'Temp')],
            axis=1
        )
        Y = data['label']
        # Keep only the samples labelled 1..4; the same mask selects signals and labels.
        keep = (Y > 0) & (Y < 5)
        X, Y = X[keep], Y[keep]
        X = torch.tensor((X - np.mean(X, axis=0)) / np.std(X, axis=0))
        Y = F.one_hot(torch.tensor(Y-1, dtype=torch.int64), num_classes=4)
        u_dict = {'X': X, 'Y': Y}
        print(f"Preprocessing of user {self.user} complete!")

        print(f"Saving preprocessed data of user {self.user}...")
        # The cache folder may not exist on a fresh Drive; the file is now closed
        # (and therefore flushed) deterministically.
        os.makedirs(WESAD_PATH, exist_ok=True)
        with open(os.path.join(WESAD_PATH, f'{self.user}.pkl'), 'wb+') as f:
            pickle.dump(u_dict, f)
        print(f"Preprocessed data of user {self.user} saved!")

        return u_dict


def seq_collate_fn(batch):
  """
  Collate `(x, y)` samples by concatenating them along the first dimension.

  :param batch: Iterable of `(x, y)` tensor pairs
  :return: `(inputs, targets)`, both cast to float32
  """
  pairs = [(x_i, y_i) for x_i, y_i in batch]
  inputs = torch.cat([pair[0] for pair in pairs])
  targets = torch.cat([pair[1] for pair in pairs])
  return inputs.float(), targets.float()

# **Federated Learning**

In [None]:
def train(Nu, Nr, Ny, seq_len, perc, lamb, same_res):
  """
  Simulate the clients of one federated experiment and collect their local statistics.

  Every training user selected by `perc` becomes a client. Clients are visited in
  round-robin order; at each visit a client consumes one chunk of its data, feeds it to
  its reservoir (continuing from its previous state) and adds the outer products of the
  last state `x` and last target `y` of the chunk to its running sums
  `A += y x^T` and `B += x x^T`. The last states and targets are also kept in `S` and `Y`.

  :param Nu: Number of input features
  :param Nr: Number of reservoir units
  :param Ny: Number of output classes
  :param seq_len: Length of the chunks each recording is split into
  :param perc: Key into `WESAD_USERS['train']` selecting the participating users
  :param lamb: Regularisation value; only stored in the returned dict
  :param same_res: If true every client shares the server reservoir, otherwise each
      client gets its own freshly initialised one
  :return: The server dict, with the clients stored under `'clients'`
  """
  server = {
    'SEQ_LENGTH': seq_len, 'PERCENTAGE': perc,
    'Nr': Nr, 'LAMBDA': lamb, 'SAME_RESERVOIR': same_res,

    'reservoir': Reservoir(input_size=Nu, hidden_size=Nr),
    'S_train': None, 'Y_train': None,
    'S_valid': None, 'Y_valid': None, 'x_valid': None,
    'S_test': None, 'Y_test': None, 'x_test': None,

    'clients': []
  }

  clients = []
  for user_idx in WESAD_USERS['train'][perc]:
    dataset = WESAD_Dataset(user_idx)
    dataset.seq_length = seq_len
    chunk_iter = iter(DataLoader(dataset, batch_size=1, collate_fn=seq_collate_fn))
    client_reservoir = server['reservoir'] if same_res else Reservoir(input_size=Nu, hidden_size=Nr)

    clients.append({
        'id': user_idx, 'data': chunk_iter, 'reservoir': client_reservoir,
        'x': None, 'A': torch.zeros(Ny,Nr), 'B': torch.zeros(Nr,Nr),
        'W': None, 'W_best': None, 'W_rand': None,
        'S': None, 'Y': None, 'finished': False
    })

  device = 'cuda' if torch.cuda.is_available() else 'cpu'
  print(device)
  round_idx = 0
  while any(not c['finished'] for c in clients):
    for client in clients:
      if client['finished']:
        continue

      client_id = client['id']
      batch = next(client['data'], None)
      if batch is None:
        # This client has consumed all of its chunks.
        client['finished'] = True
        print(f"Client {client_id} finished!")
        continue

      inputs, targets = batch
      embeddings = client['reservoir'](inputs, initial_state=client['x']).float()
      # The final state of this chunk seeds the next chunk of the same client.
      client['x'] = embeddings[-1]
      x = embeddings[-1].view(-1,1)   #(Nr, 1)
      y = targets[-1].view(-1,1)      #(Ny, 1)
      client['S'] = x if client['S'] is None else torch.hstack((client['S'],x))
      client['Y'] = y if client['Y'] is None else torch.hstack((client['Y'],y))

      # The outer products are computed on `device`; the running sums are kept on CPU.
      x_dev, y_dev = x.to(device), y.to(device)
      yx_outer, xx_outer = y_dev @ x_dev.T, x_dev @ x_dev.T
      client['A'] = (client['A'].to(yx_outer) + yx_outer).cpu()
      client['B'] = (client['B'].to(xx_outer) + xx_outer).cpu()

    round_idx += 1
    print("Round",round_idx)

  # Drop the data iterators before the clients are attached to the returned dict.
  for client in clients:
    client['data'] = None
  server['clients'] = clients
  return server

################################################################################

# Nu: number of input features (columns produced by `WESAD_Dataset.preprocess`);
# Ny: number of classes (labels are one-hot encoded over 4 classes).
Nu, Ny = 8, 4

# Make sure the output folder exists before any result is written.
os.makedirs(f'{BASE_DIR}/exp', exist_ok=True)

# Grid over training-user percentage, chunk length and reservoir size.
# A configuration is skipped when its result file already exists, so the cell can be
# re-run after an interruption without repeating finished experiments.
for perc in [100, 75, 50, 25]:
  for seq_len in [350, 700]:
    for Nr in [100, 250, 500]:
      for lamb in [1e-6]:
        print(f'----seq_len={seq_len}, perc={perc}, Nr={Nr}, lamb={lamb}, same_res=True')
        file_path = f'{BASE_DIR}/exp/server_{seq_len}_{perc}_{Nr}_{lamb}_same.pkl'
        if not os.path.exists(file_path):
          s = train(Nu, Nr, Ny, seq_len, perc, lamb, same_res=True)
          # Close the file deterministically so the pickle is fully flushed to Drive.
          with open(file_path, "wb") as out_file:
            pickle.dump(s, out_file)
        else:
          print(f'File {file_path} exists!')

# **Results**

In [None]:
def _load_pickle(path):
  """Load one pickled object from `path`, closing the file afterwards."""
  # NOTE(review): pickle can execute arbitrary code on load; only read files written by
  # this notebook.
  with open(path, "rb") as in_file:
    return pickle.load(in_file)


def _dump_pickle(obj, path):
  """Pickle `obj` to `path`, closing (and flushing) the file afterwards."""
  with open(path, "wb") as out_file:
    pickle.dump(obj, out_file)


def _embed_split(server, split, seq_len):
  """
  Run the server reservoir over the users of `split` ('valid' or 'test') and store, in
  `server`, the last state and last target of every chunk (`S_<split>`, `Y_<split>`).
  The reservoir state (`x_<split>`) is carried from one chunk to the next, including
  across users.
  """
  s_key, y_key, x_key = f'S_{split}', f'Y_{split}', f'x_{split}'
  datasets = [WESAD_Dataset(u) for u in WESAD_USERS[split]]
  for d in datasets:
    d.seq_length = seq_len
  loaders = [DataLoader(d, batch_size=1, collate_fn=seq_collate_fn) for d in datasets]
  for loader in loaders:
    for u, t in loader:
      embeddings = server['reservoir'](u, initial_state=server[x_key]).float()
      server[x_key] = embeddings[-1]
      x, y = embeddings[-1].view(-1,1), t[-1].view(-1,1)
      server[s_key] = x if server[s_key] is None else torch.hstack((server[s_key],x))
      server[y_key] = y if server[y_key] is None else torch.hstack((server[y_key],y))


def _mask_statistics(client, idxs):
  """
  Return copies of the client's `A` and `B` that are zero everywhere except at `idxs`:
  the selected columns of `A` and the selected diagonal entries of `B`.
  """
  A_client, B_client = client['A'], client['B']
  A_kept, B_kept = torch.zeros(A_client.shape), torch.zeros(B_client.shape)
  A_kept[:,idxs] = A_client[:,idxs]
  # NOTE(review): paired index lists address only the diagonal entries (i, i), not the
  # idxs x idxs sub-matrix -- confirm this is the intended selection.
  B_kept[idxs,idxs] = B_client[idxs,idxs]
  return A_kept, B_kept


def _evaluate_readouts(readouts, S, Y):
  """
  Score every readout on one data split.

  :param readouts: Sequence of `(score_key, std_key, W)` triples
  :param S: Reservoir states, one column per sample
  :param Y: One-hot targets, one column per sample
  :return: Dict holding, for each readout, the accuracy under `score_key` and the standard
      deviation of (true class index - predicted class index) under `std_key`;
      all accuracies come first, then all standard deviations
  """
  target = torch.argmax(Y, axis=0)
  scores, stds = {}, {}
  for score_key, std_key, W in readouts:
    predicted = torch.argmax(W @ S, axis=0)
    stds[std_key] = torch.std((target - predicted).float())
    scores[score_key] = accuracy_score(target, predicted)
  return {**scores, **stds}


def exec_experiments(alpha):
  """
  Evaluate four readout aggregation strategies on every experiment stored in `BASE_DIR/exp`.

  For each `server_*.pkl` file the validation and test embeddings are computed (and written
  back to the file) if they are missing, then four readouts are built from the clients'
  `A` / `B` statistics and scored on train, validation and test data:

  * IncFed: sum of all clients' statistics, ridge-regularised with `lamb`;
  * FedAvg: per-client readouts combined;
  * FedImp: each client contributes only a subset of indices, `alpha` of them chosen by
    importance (row-wise squared sum of `B`) and the rest at random;
  * Random: each client contributes a random subset of the same total size.

  :param alpha: Fraction of the FedImp budget filled with the most important indices
  :return: One dict per experiment file with keys `'params'`, `'train'`, `'valid'`, `'test'`
  """
  results = []
  exp_dir = f'{BASE_DIR}/exp'
  for file in os.listdir(exp_dir):
    if not file.endswith(".pkl"):
      continue

    file_path = f'{exp_dir}/{file}'
    server = _load_pickle(file_path)
    clients = server['clients']

    # File names look like server_<seq_len>_<perc>_<Nr>_<lamb>_same.pkl
    params = file.split('_')
    seq_len, perc, Nr = int(params[1]), int(params[2]), int(params[3])
    lamb, same_res = float(params[4]), params[5]=="same.pkl"
    result = {}
    result['params'] = {
        'alpha': alpha,
        'seq_len': seq_len, 'perc': perc, 'Nr': Nr,
        'lamb': lamb, 'same_res': same_res, 'n_clients': len(clients)
    }

    # Embeddings are cached in the experiment file so later calls can skip this step.
    for split in ('valid', 'test'):
      if server[f'S_{split}'] is None:
        _embed_split(server, split, seq_len)
        _dump_pickle(server, file_path)

    S_train = torch.cat([c['S'] for c in clients], dim=1)
    Y_train = torch.cat([c['Y'] for c in clients], dim=1)
    S_valid, Y_valid = server['S_valid'], server['Y_valid']
    S_test, Y_test = server['S_test'], server['Y_test']

    # IncFed
    A = torch.stack([c['A'] for c in clients], dim=0).sum(dim=0)
    B = torch.stack([c['B'] for c in clients], dim=0).sum(dim=0)
    B += lamb*torch.eye(B.size(0))
    W_incfed = A @ B.pinverse()

    # FedAvg
    # NOTE(review): each client readout is divided by that client's number of samples
    # and no regularisation is added to its B -- confirm this weighting is intended.
    W_fedavg = torch.zeros(W_incfed.shape)
    for client in clients:
      W_fedavg += (client['A'] @ client['B'].pinverse())/client['S'].size(1)

    # FedImp
    # Per-client budget: a fraction p of the Nr reservoir units, with the clients
    # sharing a total of 150%.
    # NOTE(review): the constant 150 is not explained in the notebook -- confirm.
    p = (150/len(clients))/100
    n = int(p*Nr)
    k, k_rand = int(alpha*n), int((1-alpha)*n)
    A_best, B_best = torch.zeros(A.shape), torch.zeros(B.shape)
    for client in clients:
      imp = torch.sum(client['B']**2, axis=1)
      idxs = list(range(imp.size(0)))
      _, topk_idxs = torch.topk(imp, k)
      topk_idxs = topk_idxs.tolist()

      # Fill the rest of the budget with indices drawn from those not already selected.
      rand_idxs = list(set(idxs)-set(topk_idxs))
      np.random.shuffle(rand_idxs)
      tot_idxs = topk_idxs + rand_idxs[:k_rand]

      A_client_best, B_client_best = _mask_statistics(client, tot_idxs)
      A_best += A_client_best
      B_best += B_client_best

    B_best += lamb*torch.eye(B_best.size(0))
    W_fedimp = A_best @ B_best.pinverse()

    # Random
    A_rand, B_rand = torch.zeros(A.shape), torch.zeros(B.shape)
    for client in clients:
      idxs = list(range(B_rand.size(0)))
      np.random.shuffle(idxs)
      idxs = idxs[:n]

      A_client_rand, B_client_rand = _mask_statistics(client, idxs)
      A_rand += A_client_rand
      B_rand += B_client_rand

    B_rand += lamb*torch.eye(B_rand.size(0))
    W_rand = A_rand @ B_rand.pinverse()

    # Note the key asymmetry kept from the original results layout: the accuracy of the
    # random baseline is stored under 'random' but its deviation under 'std_rand'.
    readouts = [
        ('incfed', 'std_incfed', W_incfed),
        ('fedavg', 'std_fedavg', W_fedavg),
        ('fedimp', 'std_fedimp', W_fedimp),
        ('random', 'std_rand', W_rand),
    ]
    result['train'] = _evaluate_readouts(readouts, S_train, Y_train)
    result['valid'] = _evaluate_readouts(readouts, S_valid, Y_valid)
    result['test'] = _evaluate_readouts(readouts, S_test, Y_test)
    results.append(result)

  return results


# Evaluate every stored experiment once per FedImp mixing coefficient and pool the results.
r = []
for a in (0.3, 0.6, 0.8, 1):
    res = exec_experiments(a)
    r.extend(res)


# --- IncFed: model selection -------------------------------------------------------------
# Among the runs that used 100% of the training users, pick the one with the highest IncFed
# validation accuracy, then look up the runs with 75/50/25% of the users that share its
# alpha, seq_len and Nr. Each `[...][0]` lookup raises IndexError if no such run exists.
# For every selected run the train/valid/test accuracy and the matching `std_incfed`
# value are copied into individual variables.
results_100 = [e for e in r if e['params']['perc']==100]
best_incfed_idx = np.argmax([e['valid']['incfed'] for e in results_100])
best_incfed_100 = results_100[best_incfed_idx]
best_incfed_100_train       = best_incfed_100['train']['incfed']
best_incfed_100_train_std   = best_incfed_100['train']['std_incfed']
best_incfed_100_valid       = best_incfed_100['valid']['incfed']
best_incfed_100_valid_std   = best_incfed_100['valid']['std_incfed']
best_incfed_100_test        = best_incfed_100['test']['incfed']
best_incfed_100_test_std    = best_incfed_100['test']['std_incfed']

best_incfed_75 = [e for e in r if e['params']['perc']==75 and e['params']['alpha']==best_incfed_100['params']['alpha'] and e['params']['seq_len']==best_incfed_100['params']['seq_len'] and e['params']['Nr']==best_incfed_100['params']['Nr']][0]
best_incfed_75_train      = best_incfed_75['train']['incfed']
best_incfed_75_train_std  = best_incfed_75['train']['std_incfed']
best_incfed_75_valid      = best_incfed_75['valid']['incfed']
best_incfed_75_valid_std  = best_incfed_75['valid']['std_incfed']
best_incfed_75_test       = best_incfed_75['test']['incfed']
best_incfed_75_test_std   = best_incfed_75['test']['std_incfed']

best_incfed_50 = [e for e in r if e['params']['perc']==50 and e['params']['alpha']==best_incfed_100['params']['alpha'] and e['params']['seq_len']==best_incfed_100['params']['seq_len'] and e['params']['Nr']==best_incfed_100['params']['Nr']][0]
best_incfed_50_train      = best_incfed_50['train']['incfed']
best_incfed_50_train_std  = best_incfed_50['train']['std_incfed']
best_incfed_50_valid      = best_incfed_50['valid']['incfed']
best_incfed_50_valid_std  = best_incfed_50['valid']['std_incfed']
best_incfed_50_test       = best_incfed_50['test']['incfed']
best_incfed_50_test_std   = best_incfed_50['test']['std_incfed']

best_incfed_25 = [e for e in r if e['params']['perc']==25 and e['params']['alpha']==best_incfed_100['params']['alpha'] and e['params']['seq_len']==best_incfed_100['params']['seq_len'] and e['params']['Nr']==best_incfed_100['params']['Nr']][0]
best_incfed_25_train      = best_incfed_25['train']['incfed']
best_incfed_25_train_std  = best_incfed_25['train']['std_incfed']
best_incfed_25_valid      = best_incfed_25['valid']['incfed']
best_incfed_25_valid_std  = best_incfed_25['valid']['std_incfed']
best_incfed_25_test       = best_incfed_25['test']['incfed']
best_incfed_25_test_std   = best_incfed_25['test']['std_incfed']

# --- FedImp: model selection -------------------------------------------------------------
# Same procedure as for IncFed: the best 100%-users run is chosen by FedImp validation
# accuracy (reusing `results_100`), and the 75/50/25% runs sharing its alpha, seq_len and
# Nr are looked up. Each `[...][0]` lookup raises IndexError if no such run exists.
best_fedimp_idx = np.argmax([e['valid']['fedimp'] for e in results_100])
best_fedimp_100 = results_100[best_fedimp_idx]
best_fedimp_100_train       = best_fedimp_100['train']['fedimp']
best_fedimp_100_train_std   = best_fedimp_100['train']['std_fedimp']
best_fedimp_100_valid       = best_fedimp_100['valid']['fedimp']
best_fedimp_100_valid_std   = best_fedimp_100['valid']['std_fedimp']
best_fedimp_100_test        = best_fedimp_100['test']['fedimp']
best_fedimp_100_test_std    = best_fedimp_100['test']['std_fedimp']

best_fedimp_75 = [e for e in r if e['params']['perc']==75 and e['params']['alpha']==best_fedimp_100['params']['alpha'] and e['params']['seq_len']==best_fedimp_100['params']['seq_len'] and e['params']['Nr']==best_fedimp_100['params']['Nr']][0]
best_fedimp_75_train       = best_fedimp_75['train']['fedimp']
best_fedimp_75_train_std   = best_fedimp_75['train']['std_fedimp']
best_fedimp_75_valid       = best_fedimp_75['valid']['fedimp']
best_fedimp_75_valid_std   = best_fedimp_75['valid']['std_fedimp']
best_fedimp_75_test        = best_fedimp_75['test']['fedimp']
best_fedimp_75_test_std    = best_fedimp_75['test']['std_fedimp']

best_fedimp_50 = [e for e in r if e['params']['perc']==50 and e['params']['alpha']==best_fedimp_100['params']['alpha'] and e['params']['seq_len']==best_fedimp_100['params']['seq_len'] and e['params']['Nr']==best_fedimp_100['params']['Nr']][0]
best_fedimp_50_train       = best_fedimp_50['train']['fedimp']
best_fedimp_50_train_std   = best_fedimp_50['train']['std_fedimp']
best_fedimp_50_valid       = best_fedimp_50['valid']['fedimp']
best_fedimp_50_valid_std   = best_fedimp_50['valid']['std_fedimp']
best_fedimp_50_test        = best_fedimp_50['test']['fedimp']
best_fedimp_50_test_std    = best_fedimp_50['test']['std_fedimp']

best_fedimp_25 = [e for e in r if e['params']['perc']==25 and e['params']['alpha']==best_fedimp_100['params']['alpha'] and e['params']['seq_len']==best_fedimp_100['params']['seq_len'] and e['params']['Nr']==best_fedimp_100['params']['Nr']][0]
best_fedimp_25_train       = best_fedimp_25['train']['fedimp']
best_fedimp_25_train_std   = best_fedimp_25['train']['std_fedimp']
best_fedimp_25_valid       = best_fedimp_25['valid']['fedimp']
best_fedimp_25_valid_std   = best_fedimp_25['valid']['std_fedimp']
best_fedimp_25_test        = best_fedimp_25['test']['fedimp']
best_fedimp_25_test_std    = best_fedimp_25['test']['std_fedimp']

# --- Random baseline: model selection ----------------------------------------------------
# Same procedure as above, selecting on the 'random' validation accuracy. Note the key
# asymmetry in the results: the accuracy is stored under 'random' while its standard
# deviation is stored under 'std_rand'.
# Each `[...][0]` lookup raises IndexError if no matching run exists.
best_rand_idx = np.argmax([e['valid']['random'] for e in results_100])
best_rand_100 = results_100[best_rand_idx]
best_rand_100_train       = best_rand_100['train']['random']
best_rand_100_train_std   = best_rand_100['train']['std_rand']
best_rand_100_valid       = best_rand_100['valid']['random']
best_rand_100_valid_std   = best_rand_100['valid']['std_rand']
best_rand_100_test        = best_rand_100['test']['random']
best_rand_100_test_std    = best_rand_100['test']['std_rand']

best_rand_75 = [e for e in r if e['params']['perc']==75 and e['params']['alpha']==best_rand_100['params']['alpha'] and e['params']['seq_len']==best_rand_100['params']['seq_len'] and e['params']['Nr']==best_rand_100['params']['Nr']][0]
best_rand_75_train       = best_rand_75['train']['random']
best_rand_75_train_std   = best_rand_75['train']['std_rand']
best_rand_75_valid       = best_rand_75['valid']['random']
best_rand_75_valid_std   = best_rand_75['valid']['std_rand']
best_rand_75_test        = best_rand_75['test']['random']
best_rand_75_test_std    = best_rand_75['test']['std_rand']

best_rand_50 = [e for e in r if e['params']['perc']==50 and e['params']['alpha']==best_rand_100['params']['alpha'] and e['params']['seq_len']==best_rand_100['params']['seq_len'] and e['params']['Nr']==best_rand_100['params']['Nr']][0]
best_rand_50_train       = best_rand_50['train']['random']
best_rand_50_train_std   = best_rand_50['train']['std_rand']
best_rand_50_valid       = best_rand_50['valid']['random']
best_rand_50_valid_std   = best_rand_50['valid']['std_rand']
best_rand_50_test        = best_rand_50['test']['random']
best_rand_50_test_std    = best_rand_50['test']['std_rand']

best_rand_25 = [e for e in r if e['params']['perc']==25 and e['params']['alpha']==best_rand_100['params']['alpha'] and e['params']['seq_len']==best_rand_100['params']['seq_len'] and e['params']['Nr']==best_rand_100['params']['Nr']][0]
best_rand_25_train       = best_rand_25['train']['random']
best_rand_25_train_std   = best_rand_25['train']['std_rand']
best_rand_25_valid       = best_rand_25['valid']['random']
best_rand_25_valid_std   = best_rand_25['valid']['std_rand']
best_rand_25_test        = best_rand_25['test']['random']
best_rand_25_test_std    = best_rand_25['test']['std_rand']

# FedAvg, 100% setting: model selection picks the run in `results_100` with
# the highest validation 'fedavg' value; its score / std for every split are
# then unpacked in a single statement.
best_fedavg_idx = np.argmax([run['valid']['fedavg'] for run in results_100])
best_fedavg_100 = results_100[best_fedavg_idx]
(
    best_fedavg_100_train, best_fedavg_100_train_std,
    best_fedavg_100_valid, best_fedavg_100_valid_std,
    best_fedavg_100_test, best_fedavg_100_test_std,
) = (
    best_fedavg_100[split][field]
    for split in ('train', 'valid', 'test')
    for field in ('fedavg', 'std_fedavg')
)

# FedAvg, 75% setting: look up the entry of `r` whose alpha, seq_len and Nr
# equal those of the best 100% FedAvg run, then unpack score / std per split.
# The first matching entry is used; no match raises IndexError.
best_fedavg_75 = [
    run for run in r
    if run['params']['perc'] == 75
    and all(
        run['params'][name] == best_fedavg_100['params'][name]
        for name in ('alpha', 'seq_len', 'Nr')
    )
][0]
(
    best_fedavg_75_train, best_fedavg_75_train_std,
    best_fedavg_75_valid, best_fedavg_75_valid_std,
    best_fedavg_75_test, best_fedavg_75_test_std,
) = (
    best_fedavg_75[split][field]
    for split in ('train', 'valid', 'test')
    for field in ('fedavg', 'std_fedavg')
)

# FedAvg, 50% setting: look up the entry of `r` whose alpha, seq_len and Nr
# equal those of the best 100% FedAvg run, then unpack score / std per split.
# The first matching entry is used; no match raises IndexError.
best_fedavg_50 = [
    run for run in r
    if run['params']['perc'] == 50
    and all(
        run['params'][name] == best_fedavg_100['params'][name]
        for name in ('alpha', 'seq_len', 'Nr')
    )
][0]
(
    best_fedavg_50_train, best_fedavg_50_train_std,
    best_fedavg_50_valid, best_fedavg_50_valid_std,
    best_fedavg_50_test, best_fedavg_50_test_std,
) = (
    best_fedavg_50[split][field]
    for split in ('train', 'valid', 'test')
    for field in ('fedavg', 'std_fedavg')
)

# FedAvg, 25% setting: look up the entry of `r` whose alpha, seq_len and Nr
# equal those of the best 100% FedAvg run, then unpack score / std per split.
# The first matching entry is used; no match raises IndexError.
best_fedavg_25 = [
    run for run in r
    if run['params']['perc'] == 25
    and all(
        run['params'][name] == best_fedavg_100['params'][name]
        for name in ('alpha', 'seq_len', 'Nr')
    )
][0]
(
    best_fedavg_25_train, best_fedavg_25_train_std,
    best_fedavg_25_valid, best_fedavg_25_valid_std,
    best_fedavg_25_test, best_fedavg_25_test_std,
) = (
    best_fedavg_25[split][field]
    for split in ('train', 'valid', 'test')
    for field in ('fedavg', 'std_fedavg')
)

# Report the selected hyper-parameter set of each method, one line per method.
print(
    f"Best IncFed params: {best_incfed_100['params']}",
    f"Best FedImp params: {best_fedimp_100['params']}",
    f"Best Random params: {best_rand_100['params']}",
    f"Best FedAvg params: {best_fedavg_100['params']}",
    sep="\n",
)

# Summary table of the selected runs.
#   rows          : the 100 / 75 / 50 / 25 % settings
#   column groups : IncFed, FedImp, Random, FedAvg
#   each group    : TR | VL | TS, filled from the *_train / *_valid / *_test
#                   variables, each shown as "value±std"
# Values are printed with 4 decimals and their std with 2 decimals; \u00B1 is
# the plus-minus sign and \t emits a tab between column groups.
# The IncFed and FedImp variables are defined before this point in the notebook.
print(f"""
     \t IncFed \t\t\t\t FedImp \t\t\t\t Random \t\t\t\t FedAvg
     \t TR         | VL        | TS \t\t TR         | VL        | TS \t\t TR         | VL        | TS \t\t TR         | VL        | TS
100% \t {best_incfed_100_train:.4f}\u00B1{best_incfed_100_train_std:.2f}|{best_incfed_100_valid:.4f}\u00B1{best_incfed_100_valid_std:.2f}|{best_incfed_100_test:.4f}\u00B1{best_incfed_100_test_std:.2f} \t {best_fedimp_100_train:.4f}\u00B1{best_fedimp_100_train_std:.2f}|{best_fedimp_100_valid:.4f}\u00B1{best_fedimp_100_valid_std:.2f}|{best_fedimp_100_test:.4f}\u00B1{best_fedimp_100_test_std:.2f} \t {best_rand_100_train:.4f}\u00B1{best_rand_100_train_std:.2f}|{best_rand_100_valid:.4f}\u00B1{best_rand_100_valid_std:.2f}|{best_rand_100_test:.4f}\u00B1{best_rand_100_test_std:.2f} \t {best_fedavg_100_train:.4f}\u00B1{best_fedavg_100_train_std:.2f}|{best_fedavg_100_valid:.4f}\u00B1{best_fedavg_100_valid_std:.2f}|{best_fedavg_100_test:.4f}\u00B1{best_fedavg_100_test_std:.2f}
75%  \t {best_incfed_75_train:.4f}\u00B1{best_incfed_75_train_std:.2f}|{best_incfed_75_valid:.4f}\u00B1{best_incfed_75_valid_std:.2f}|{best_incfed_75_test:.4f}\u00B1{best_incfed_75_test_std:.2f} \t {best_fedimp_75_train:.4f}\u00B1{best_fedimp_75_train_std:.2f}|{best_fedimp_75_valid:.4f}\u00B1{best_fedimp_75_valid_std:.2f}|{best_fedimp_75_test:.4f}\u00B1{best_fedimp_75_test_std:.2f} \t {best_rand_75_train:.4f}\u00B1{best_rand_75_train_std:.2f}|{best_rand_75_valid:.4f}\u00B1{best_rand_75_valid_std:.2f}|{best_rand_75_test:.4f}\u00B1{best_rand_75_test_std:.2f} \t {best_fedavg_75_train:.4f}\u00B1{best_fedavg_75_train_std:.2f}|{best_fedavg_75_valid:.4f}\u00B1{best_fedavg_75_valid_std:.2f}|{best_fedavg_75_test:.4f}\u00B1{best_fedavg_75_test_std:.2f}
50%  \t {best_incfed_50_train:.4f}\u00B1{best_incfed_50_train_std:.2f}|{best_incfed_50_valid:.4f}\u00B1{best_incfed_50_valid_std:.2f}|{best_incfed_50_test:.4f}\u00B1{best_incfed_50_test_std:.2f} \t {best_fedimp_50_train:.4f}\u00B1{best_fedimp_50_train_std:.2f}|{best_fedimp_50_valid:.4f}\u00B1{best_fedimp_50_valid_std:.2f}|{best_fedimp_50_test:.4f}\u00B1{best_fedimp_50_test_std:.2f} \t {best_rand_50_train:.4f}\u00B1{best_rand_50_train_std:.2f}|{best_rand_50_valid:.4f}\u00B1{best_rand_50_valid_std:.2f}|{best_rand_50_test:.4f}\u00B1{best_rand_50_test_std:.2f} \t {best_fedavg_50_train:.4f}\u00B1{best_fedavg_50_train_std:.2f}|{best_fedavg_50_valid:.4f}\u00B1{best_fedavg_50_valid_std:.2f}|{best_fedavg_50_test:.4f}\u00B1{best_fedavg_50_test_std:.2f}
25%  \t {best_incfed_25_train:.4f}\u00B1{best_incfed_25_train_std:.2f}|{best_incfed_25_valid:.4f}\u00B1{best_incfed_25_valid_std:.2f}|{best_incfed_25_test:.4f}\u00B1{best_incfed_25_test_std:.2f} \t {best_fedimp_25_train:.4f}\u00B1{best_fedimp_25_train_std:.2f}|{best_fedimp_25_valid:.4f}\u00B1{best_fedimp_25_valid_std:.2f}|{best_fedimp_25_test:.4f}\u00B1{best_fedimp_25_test_std:.2f} \t {best_rand_25_train:.4f}\u00B1{best_rand_25_train_std:.2f}|{best_rand_25_valid:.4f}\u00B1{best_rand_25_valid_std:.2f}|{best_rand_25_test:.4f}\u00B1{best_rand_25_test_std:.2f} \t {best_fedavg_25_train:.4f}\u00B1{best_fedavg_25_train_std:.2f}|{best_fedavg_25_valid:.4f}\u00B1{best_fedavg_25_valid_std:.2f}|{best_fedavg_25_test:.4f}\u00B1{best_fedavg_25_test_std:.2f}
""")

Best IncFed params: {'alpha': 0.3, 'seq_len': 350, 'perc': 100, 'Nr': 500, 'lamb': 1e-06, 'same_res': True, 'n_clients': 9}
Best FedImp params: {'alpha': 0.3, 'seq_len': 350, 'perc': 100, 'Nr': 500, 'lamb': 1e-06, 'same_res': True, 'n_clients': 9}
Best Random params: {'alpha': 1, 'seq_len': 350, 'perc': 100, 'Nr': 500, 'lamb': 1e-06, 'same_res': True, 'n_clients': 9}
Best FedAvg params: {'alpha': 0.3, 'seq_len': 350, 'perc': 100, 'Nr': 500, 'lamb': 1e-06, 'same_res': True, 'n_clients': 9}

     	 IncFed 				 FedImp 				 Random 				 FedAvg
     	 TR         | VL        | TS 		 TR         | VL        | TS 		 TR         | VL        | TS 		 TR         | VL        | TS
100% 	 0.8988±0.57|0.7546±0.82|0.7801±0.77 	 0.7370±1.01|0.7361±0.97|0.8128±0.70 	 0.7280±1.04|0.7516±0.96|0.8084±0.74 	 0.7388±0.95|0.7532±0.99|0.8170±0.66
75%  	 0.9170±0.54|0.7469±0.90|0.7504±0.81 	 0.7177±1.07|0.7008±1.07|0.8108±0.76 	 0.6962±1.10|0.6647±1.19|0.7887±0.86 	 0.7024±1.00|0.7627±0.90|0.8052±0.72
50%  	 0.9369