In [1]:
%%capture
pip install mne

In [2]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os
import mne
import pywt
from scipy.stats import skew, kurtosis
from scipy.signal import welch
from scipy.signal import find_peaks



# folder = 'E:/Data WareHouse/1sec/Epileptic Seizure/equalized'
folder = '/content/drive/MyDrive/EEG Signal /Epileptic seizure/data/equalized epoch'
epochs_path = [os.path.join(folder,i) for i in os.listdir(folder) if i[-3:]=='fif']


data= [mne.read_epochs(i).pick_types(eeg=True) for i in epochs_path]
labels = [mne.read_epochs(i).events[:,2] for i in epochs_path]
group = [[i]*len(j) for i,j in enumerate(data)]
X=np.vstack(data)
Y=np.hstack(labels)
group= np.hstack(group)




def zero_crossing_rate(signal):
    zero_crossings = np.where(np.diff(np.sign(signal)))[0]
    return len(zero_crossings) / len(signal)

def hjorth_parameters(signal):
    diff_input = np.diff(signal)
    diff_diff_input = np.diff(diff_input)

    activity = np.var(signal)
    mobility = np.sqrt(np.var(diff_input)/activity)
    complexity = np.sqrt(np.var(diff_diff_input)/np.var(diff_input)) / mobility

    return activity, mobility, complexity



def extract_time_domain_features(epochs):
    features = []

    for epoch in epochs:
        epoch_features = []

        for channel_data in epoch:
            # Flatten the channel data
            flattened_data = channel_data.flatten()

            # Basic Time-Domain Features
            mean_val = np.mean(flattened_data)
            median_val = np.median(flattened_data)
            var_val = np.var(flattened_data)
            std_dev = np.std(flattened_data)
            skewness = skew(flattened_data)
            kurt = kurtosis(flattened_data)
            zcr = zero_crossing_rate(flattened_data)
            peak_amp = np.ptp(flattened_data)

            # Hjorth Parameters
            activity, mobility, complexity = hjorth_parameters(flattened_data)

            # Additional Features
            num_waves = len(find_peaks(flattened_data)[0])
            wave_duration = len(flattened_data) / num_waves if num_waves > 0 else 0

            channel_features = [
                mean_val, median_val, var_val, std_dev, skewness, kurt, zcr, num_waves,
                wave_duration, peak_amp, activity, mobility, complexity
            ]
            epoch_features.append(channel_features)

        features.append(epoch_features)

    return np.array(features)







def get_wavelet_coeffs(channel_data, wavelet='db4', level=5):
    coeffs = pywt.wavedec(channel_data, wavelet, level=level)
    return coeffs


def extract_frequency_domain_features(epochs, sfreq,wavelet='db4', bands={'delta': (0.5, 4), 'theta': (4, 8), 'alpha': (8, 12), 'beta': (12, 30), 'gamma': (30, 100), 'sigma': (11, 16)}):
    features = []

    for epoch in epochs:
        epoch_features = []

        for channel_data in epoch:
            # Compute the Power Spectral Density (PSD)
            freqs, psd = welch(channel_data, sfreq, nperseg=256)

            # Frequency domain features
            mean_val = np.mean(psd)
            median_val = np.median(psd)
            var_val = np.var(psd)
            std_dev = np.std(psd)
            skewness = skew(psd)
            kurt = kurtosis(psd)

            # Compute wavelet coefficients
            wave_coeffs = get_wavelet_coeffs(channel_data, wavelet, level=5)
            wave_coeffs_mean = np.mean(wave_coeffs[0])

            # Band Power Features
            band_powers = {}
            for band, freq_range in bands.items():
                freq_mask = (freqs >= freq_range[0]) & (freqs <= freq_range[1])
                band_power = np.sum(psd[freq_mask])
                band_powers[band] = band_power

            # Band Power Ratios
            theta_alpha_ratio = band_powers['theta'] / band_powers['alpha']
            beta_alpha_ratio = band_powers['beta'] / band_powers['alpha']
            theta_alpha_beta_ratio = (band_powers['theta'] + band_powers['alpha']) / band_powers['beta']
            # Additional Band Power Ratios
            theta_beta_ratio = band_powers['theta'] / band_powers['beta']
            theta_alpha_beta_alpha_ratio = (band_powers['theta'] + band_powers['alpha']) / (band_powers['alpha'] + band_powers['beta'])
            gamma_delta_ratio = band_powers['gamma'] / band_powers['delta']
            gamma_beta_delta_alpha_ratio = (band_powers['gamma'] + band_powers['beta']) / (band_powers['delta'] + band_powers['alpha'])


            channel_features = [
                mean_val, median_val, var_val, std_dev, skewness, kurt,
                band_powers['delta'], band_powers['theta'], band_powers['alpha'],
                band_powers['beta'], band_powers['gamma'], band_powers['sigma'],
                theta_alpha_ratio, beta_alpha_ratio, theta_alpha_beta_ratio,theta_beta_ratio,
                theta_alpha_beta_alpha_ratio, gamma_delta_ratio, gamma_beta_delta_alpha_ratio,
                wave_coeffs_mean
            ]
            epoch_features.append(channel_features)

        features.append(epoch_features)

    return np.array(features)



X_time = extract_time_domain_features(X)
sfreq = 256  # Replace with the sampling frequency of your data
# epochs_data = [epoch.get_data() for epoch in epochs]  # Assuming epochs is a list of MNE Epochs objects
X_frequency = extract_frequency_domain_features(X, sfreq)
X_merged_features = np.concatenate((X_time ,X_frequency), axis=2)

Reading /content/drive/MyDrive/EEG Signal /Epileptic seizure/data/equalized epoch/PN0-epo.fif ...
    Found the data of interest:
        t =    -199.22 ...     500.00 ms
        0 CTF compensation matrices available
Not setting metadata
190 matching events found
No baseline correction applied
0 projection items activated
NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Reading /content/drive/MyDrive/EEG Signal /Epileptic seizure/data/equalized epoch/PN1-epo.fif ...
    Found the data of interest:
        t =    -199.22 ...     500.00 ms
        0 CTF compensation matrices available
Not setting metadata
146 matching events found
No baseline correction applied
0 projection items activated
NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Reading /content/drive/MyDrive/EEG Signal /Epileptic seizure/data/equalized epoch/PN2-epo.fif ...
    Found the data of interest:
        t =    -199.22 ...     500.00 ms
        0 CTF compensation 



In [3]:
 pip install flwr flwr-datasets

Collecting flwr
  Downloading flwr-1.6.0-py3-none-any.whl (219 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m219.2/219.2 kB[0m [31m2.4 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting flwr-datasets
  Downloading flwr_datasets-0.0.2-py3-none-any.whl (22 kB)
Collecting cryptography<42.0.0,>=41.0.2 (from flwr)
  Downloading cryptography-41.0.7-cp37-abi3-manylinux_2_28_x86_64.whl (4.4 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m4.4/4.4 MB[0m [31m27.1 MB/s[0m eta [36m0:00:00[0m
Collecting iterators<0.0.3,>=0.0.2 (from flwr)
  Downloading iterators-0.0.2-py3-none-any.whl (3.9 kB)
Collecting pycryptodome<4.0.0,>=3.18.0 (from flwr)
  Downloading pycryptodome-3.20.0-cp35-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.1/2.1 MB[0m [31m70.9 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting datasets<3.0.0,>=2.14.3 (from flwr-datasets)
  Downloading datasets-2.16.1-py3-no

In [4]:
pip install xgboost



In [5]:
import argparse
from typing import Union
from logging import INFO
from datasets import Dataset, DatasetDict
import xgboost as xgb

import flwr as fl
from flwr_datasets import FederatedDataset
from flwr.common.logger import log
from flwr.common import (
    Code,
    EvaluateIns,
    EvaluateRes,
    FitIns,
    FitRes,
    GetParametersIns,
    GetParametersRes,
    Parameters,
    Status,
)
from flwr_datasets.partitioner import IidPartitioner

In [8]:
# Load (HIGGS) dataset and conduct partitioning
# We use a small subset (num_partitions=30) of the dataset for demonstration to speed up the data loading process.
partitioner = IidPartitioner(num_partitions=3)
fds = FederatedDataset(dataset="jxie/higgs", partitioners={"train": partitioner})

# Load the partition for this `node_id`
partition = fds.load_partition(node_id=args.node_id, split="train")
partition.set_format("numpy")

NameError: name 'args' is not defined

In [None]:
# We first define arguments parser for user to specify the client/node ID.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--node-id",
    default=0,
    type=int,
    help="Node ID used for the current client.",
)
args = parser.parse_args()

# Load the partition for this `node_id`.
partition = fds.load_partition(idx=args.node_id, split="train")
partition.set_format("numpy")

In [None]:
# Train/test splitting
train_data, valid_data, num_train, num_val = train_test_split(
    partition, test_fraction=0.2, seed=42
)

# Reformat data to DMatrix for xgboost
train_dmatrix = transform_dataset_to_dmatrix(train_data)
valid_dmatrix = transform_dataset_to_dmatrix(valid_data)

In [None]:
# Define data partitioning related functions
def train_test_split(partition: Dataset, test_fraction: float, seed: int):
    """Split the data into train and validation set given split rate."""
    train_test = partition.train_test_split(test_size=test_fraction, seed=seed)
    partition_train = train_test["train"]
    partition_test = train_test["test"]

    num_train = len(partition_train)
    num_test = len(partition_test)

    return partition_train, partition_test, num_train, num_test


def transform_dataset_to_dmatrix(data: Union[Dataset, DatasetDict]) -> xgb.core.DMatrix:
    """Transform dataset to DMatrix format for xgboost."""
    x = data["inputs"]
    y = data["label"]
    new_data = xgb.DMatrix(x, label=y)
    return new_data

In [None]:
num_local_round = 1
params = {
    "objective": "binary:logistic",
    "eta": 0.1,  # lr
    "max_depth": 8,
    "eval_metric": "auc",
    "nthread": 16,
    "num_parallel_tree": 1,
    "subsample": 1,
    "tree_method": "hist",
}

In [None]:
class XgbClient(fl.client.Client):
    def __init__(self):
        self.bst = None
        self.config = None

In [None]:
def get_parameters(self, ins: GetParametersIns) -> GetParametersRes:
    _ = (self, ins)
    return GetParametersRes(
        status=Status(
            code=Code.OK,
            message="OK",
        ),
        parameters=Parameters(tensor_type="", tensors=[]),
    )

In [None]:
def fit(self, ins: FitIns) -> FitRes:
    if not self.bst:
        # First round local training
        log(INFO, "Start training at round 1")
        bst = xgb.train(
            params,
            train_dmatrix,
            num_boost_round=num_local_round,
            evals=[(valid_dmatrix, "validate"), (train_dmatrix, "train")],
        )
        self.config = bst.save_config()
        self.bst = bst
    else:
        for item in ins.parameters.tensors:
            global_model = bytearray(item)

        # Load global model into booster
        self.bst.load_model(global_model)
        self.bst.load_config(self.config)

        bst = self._local_boost()

    local_model = bst.save_raw("json")
    local_model_bytes = bytes(local_model)

    return FitRes(
        status=Status(
            code=Code.OK,
            message="OK",
        ),
        parameters=Parameters(tensor_type="", tensors=[local_model_bytes]),
        num_examples=num_train,
        metrics={},
    )

In [None]:
def _local_boost(self):
    # Update trees based on local training data.
    for i in range(num_local_round):
        self.bst.update(train_dmatrix, self.bst.num_boosted_rounds())

    # Extract the last N=num_local_round trees for sever aggregation
    bst = self.bst[
        self.bst.num_boosted_rounds()
        - num_local_round : self.bst.num_boosted_rounds()
    ]

In [None]:
def evaluate(self, ins: EvaluateIns) -> EvaluateRes:
    eval_results = self.bst.eval_set(
        evals=[(valid_dmatrix, "valid")],
        iteration=self.bst.num_boosted_rounds() - 1,
    )
    auc = round(float(eval_results.split("\t")[1].split(":")[1]), 4)

    return EvaluateRes(
        status=Status(
            code=Code.OK,
            message="OK",
        ),
        loss=0.0,
        num_examples=num_val,
        metrics={"AUC": auc},
    )

In [None]:
fl.client.start_client(server_address="127.0.0.1:8080", client=XgbClient())