In [1]:
!pip install braindecode moabb ray
!pip install -U "flwr[simulation]"



In [2]:
from braindecode.datasets import MOABBDataset
from sklearn.model_selection import train_test_split


## Loanding dataset

In [3]:
import numpy as np
from braindecode.preprocessing import (
    exponential_moving_standardize,
    preprocess,
    Preprocessor,
)

# Carregar o dataset
subject_id = 3
dataset = MOABBDataset(dataset_name="BNCI2014_001", subject_ids=[subject_id])

# Dividir o dataset em 5 partes manualmente
num_parts = 5
part_size = len(dataset.datasets) // num_parts
dataset_parts = []

# Separar manualmente os dados e recriar subconjuntos como novos datasets
for i in range(num_parts):
    start = i * part_size
    end = (i + 1) * part_size if i < num_parts - 1 else len(dataset.datasets)
    subset = dataset.datasets[start:end]
    dataset_parts.append(subset)

# Definir diferentes conjuntos de parâmetros para cada parte
param_sets = [
    {"low_cut_hz": 4.0, "high_cut_hz": 30.0, "factor_new": 1e-3, "init_block_size": 1000},
    {"low_cut_hz": 5.0, "high_cut_hz": 35.0, "factor_new": 5e-4, "init_block_size": 1200},
    {"low_cut_hz": 3.5, "high_cut_hz": 28.0, "factor_new": 2e-3, "init_block_size": 900},
    {"low_cut_hz": 6.0, "high_cut_hz": 40.0, "factor_new": 1e-4, "init_block_size": 1500},
    {"low_cut_hz": 4.5, "high_cut_hz": 32.0, "factor_new": 1e-3, "init_block_size": 1100},
]

# Lista para armazenar os datasets processados
processed_datasets = []

# Aplicar os preprocessadores a cada subconjunto com parâmetros específicos
for i, (subset, params) in enumerate(zip(dataset_parts, param_sets)):
    preprocessors = [
        Preprocessor("pick_types", eeg=True, meg=False, stim=False),
        Preprocessor(lambda data, factor: np.multiply(data, factor), factor=1e6),
        Preprocessor("filter", l_freq=params["low_cut_hz"], h_freq=params["high_cut_hz"]),
        Preprocessor(exponential_moving_standardize, factor_new=params["factor_new"], init_block_size=params["init_block_size"]),
    ]

    # Criar um novo dataset a partir do subconjunto
    new_dataset = MOABBDataset(dataset_name="BNCI2014_001", subject_ids=[subject_id])
    new_dataset.datasets = subset  # Atribuir subconjunto ao novo dataset

    # Preprocessar o novo dataset de forma independente
    preprocess(new_dataset, preprocessors, n_jobs=-1)
    print(f"Preprocessamento concluído para o subconjunto {i+1} com parâmetros: {params}")

    # Armazenar o dataset processado na lista
    processed_datasets.append(new_dataset)

# Agora `processed_datasets` contém os cinco datasets preprocessados


  warn('Preprocessing choices with lambda functions cannot be saved.')


Preprocessamento concluído para o subconjunto 1 com parâmetros: {'low_cut_hz': 4.0, 'high_cut_hz': 30.0, 'factor_new': 0.001, 'init_block_size': 1000}


  warn('Preprocessing choices with lambda functions cannot be saved.')


Preprocessamento concluído para o subconjunto 2 com parâmetros: {'low_cut_hz': 5.0, 'high_cut_hz': 35.0, 'factor_new': 0.0005, 'init_block_size': 1200}


  warn('Preprocessing choices with lambda functions cannot be saved.')


Preprocessamento concluído para o subconjunto 3 com parâmetros: {'low_cut_hz': 3.5, 'high_cut_hz': 28.0, 'factor_new': 0.002, 'init_block_size': 900}


  warn('Preprocessing choices with lambda functions cannot be saved.')


Preprocessamento concluído para o subconjunto 4 com parâmetros: {'low_cut_hz': 6.0, 'high_cut_hz': 40.0, 'factor_new': 0.0001, 'init_block_size': 1500}


  warn('Preprocessing choices with lambda functions cannot be saved.')


Preprocessamento concluído para o subconjunto 5 com parâmetros: {'low_cut_hz': 4.5, 'high_cut_hz': 32.0, 'factor_new': 0.001, 'init_block_size': 1100}


In [4]:
from braindecode.preprocessing import create_windows_from_events

def extractionWindow(dataset):
  trial_start_offset_seconds = -0.5
  # Extract sampling frequency, check that they are same in all datasets
  sfreq = dataset.datasets[0].raw.info["sfreq"]
  assert all([ds.raw.info["sfreq"] == sfreq for ds in dataset.datasets])
  # Calculate the window start offset in samples.
  trial_start_offset_samples = int(trial_start_offset_seconds * sfreq)

  # Create windows using braindecode function for this. It needs parameters to
  # define how windows should be used.
  windows_dataset = create_windows_from_events(
      dataset,
      trial_start_offset_samples=trial_start_offset_samples,
      trial_stop_offset_samples=0,
      preload=True,
  )


  return windows_dataset

In [5]:
def split_windows_dataset(windows_dataset):
  splitted = windows_dataset.split("session")
  print(splitted)
  if('0train' in splitted):
    return splitted['0train']  # Session train
  else:
    return splitted['1test']
  #test_ set = splitted['1test']  # Session evaluation
  # return train_set

## Criação de modelo

In [6]:
from braindecode.util import set_random_seeds
from braindecode.models import ShallowFBCSPNet

def create_model(shape):
  print('Entrou na função model')
  seed = 20200220

  n_classes = 4
  classes = list(range(n_classes))
  # Extract number of chans and time steps from dataset
  n_channels = shape[0]
  input_window_samples = shape[1]

  model = ShallowFBCSPNet(
      n_channels,
      n_classes,
      input_window_samples=input_window_samples,
      final_conv_length="auto",
  )

  return model


In [7]:
import flwr as fl

In [8]:
def get_client_dataset(node_id):
    """Função que mapeia o node_id do cliente ao seu dataset correspondente"""
    client_datasets = {
        0: processed_datasets[0],
        1: processed_datasets[1],
        2: processed_datasets[2],
        3: processed_datasets[3],
        4: processed_datasets[4]
    }
    return client_datasets.get(node_id)

  and should_run_async(code)


In [9]:
def numpyclient_fn(context, message=None):
    # Usando a função de mapeamento para pegar os dados baseados no node_id do context
    client_id = context.node_config["partition-id"]
    print('Esse é o client id:', client_id)
    dataset_to_use = get_client_dataset(client_id)
    if dataset_to_use is None:
        raise ValueError(f"Dataset não encontrado para o cliente {context.node_id}")

    windows_dataset = extractionWindow(dataset_to_use)
    print('Saiu da função window')
    model = create_model(windows_dataset[0][0].shape)
    print('Saiu da função model')
    # Continua a configuração do cliente
    client = FlowerNumPyClient(model, windows_dataset)
    return client.to_client()

In [10]:
from collections import OrderedDict
import torch

def get_parameters(model):
    return [val.cpu().numpy() for _, val in model.state_dict().items()]


def set_parameters(model, parameters):
    params_dict = zip(model.state_dict().keys(), parameters)
    # Ordenar o dict de acordo com os pesos
    state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
    model.load_state_dict(state_dict, strict=True)

In [14]:
from skorch.callbacks import LRScheduler
from braindecode import EEGClassifier

class FlowerNumPyClient(fl.client.NumPyClient):
    def __init__(self,model, windows_dataset):
        lr = 0.0625 * 0.01
        weight_decay = 0
        batch_size = 64
        n_epochs = 2

        self.model = model
        self.window = windows_dataset
        self.train_set = split_windows_dataset(self.window)
        print('Vai iniciar o classificador')
        self.clf = EEGClassifier(
            self.model,
            criterion=torch.nn.NLLLoss,
            optimizer=torch.optim.AdamW,
            train_split=None,
            optimizer__lr=lr,
            optimizer__weight_decay=weight_decay,
            batch_size=batch_size,
            callbacks=[
                "accuracy",
                ("lr_scheduler", LRScheduler("CosineAnnealingLR", T_max=n_epochs - 1)),
            ],
            device='cpu',
            # 4 classes no dataset
            classes=list(range(4)),
            max_epochs=n_epochs,
        )
        print('Saiu do classificador')

    def get_parameters(self, config):
        print('Get parameters')
        return get_parameters(self.model)

    def set_weights(self, net, parameters):
      params_dict = zip(net.state_dict().keys(), parameters)
      state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
      net.load_state_dict(state_dict, strict=True)

    # treinamento
    def fit(self, parameters, config):
        print('Inicia o fit')
        self.set_weights(self.model, parameters)
        lr = 0.0625 * 0.01
        weight_decay = 0
        batch_size = 64
        n_epochs = 2

        print('Iniciando a classificação')

        self.clf = EEGClassifier(
            self.model,
            criterion=torch.nn.NLLLoss,
            optimizer=torch.optim.AdamW,
            train_split=None,
            optimizer__lr=lr,
            optimizer__weight_decay=weight_decay,
            batch_size=batch_size,
            callbacks=[
                "accuracy",
                ("lr_scheduler", LRScheduler("CosineAnnealingLR", T_max=n_epochs - 1)),
            ],
            device='cpu',
            # 4 classes no dataset
            classes=list(range(4)),
            max_epochs=n_epochs,
        )
        # Model training for a specified number of epochs. `y` is None as it is already supplied
        # in the dataset.

        print('iniciando o fit!!!!!!!!!!!')
        # self.clf.fit(self.train_set, y=None)


        return self.get_parameters(self.model), len(self.window ), {}

    # Avaliação
    def evaluate(self, parameters, config):
        # print(f"[Client evaluate, config: {config}")
        # set_parameters(self.model, parameters)
        # y_test = self.train_set.get_metadata().target
        # test_acc = self.clf.score(self.train_set, y=y_test)
        # print(f"Test acc: {(test_acc * 100):.2f}%")
        # depois calcular o loss

        # evaluated t.he model after training
        self.clf.fit(self.train_set, y=None)
        y_test = self.train_set.get_metadata().target
        test_acc = self.clf.score(self.train_set, y=y_test)
        print(f"Test acc: {(test_acc * 100):.2f}%")

        return float(test_acc), len(self.train_set), {"accuracy": float(test_acc)}

In [12]:
def server_fn(context):
    # Configure the server for 3 rounds of training
    config = fl.server.ServerConfig(num_rounds=3)
    return fl.server.ServerAppComponents(config=config)


# Create ServerApp
server = fl.server.ServerApp(server_fn=server_fn)

In [15]:
NUM_PARTITIONS = 5

# Run simulation
fl.simulation.run_simulation(
    server_app=server,
    client_app=numpyclient_fn,
    num_supernodes=NUM_PARTITIONS,
)


[92mINFO [0m:      Starting Flower ServerApp, config: num_rounds=3, no round_timeout
[92mINFO [0m:      
[92mINFO [0m:      [INIT]
[92mINFO [0m:      Requesting initial parameters from one random client
[36m(pid=54750)[0m 2024-11-16 17:56:08.558167: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
[36m(pid=54750)[0m 2024-11-16 17:56:08.582942: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
[36m(pid=54750)[0m 2024-11-16 17:56:08.589768: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
[36m(ClientAppActor pid=54750)[0m see the appropriate new directories, set the environmen

[36m(ClientAppActor pid=54750)[0m Esse é o client id: 3
[36m(ClientAppActor pid=54750)[0m Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
[36m(ClientAppActor pid=54750)[0m Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
[36m(ClientAppActor pid=54750)[0m Saiu da função window
[36m(ClientAppActor pid=54750)[0m Entrou na função model
[36m(ClientAppActor pid=54750)[0m Saiu da função model
[36m(ClientAppActor pid=54750)[0m {'1test': <braindecode.datasets.base.BaseConcatDataset object at 0x788295511de0>}
[36m(ClientAppActor pid=54750)[0m Vai iniciar o classificador
[36m(ClientAppActor pid=54750)[0m Saiu do classificador


  cuda_attrs = torch.load(f, **load_kwargs)


KeyboardInterrupt: 