In [1]:
import numpy as np
import pandas as pd
import scipy.io as sio
from tensorflow.keras.utils import to_categorical
from sklearn.preprocessing import StandardScaler
from sklearn.utils import shuffle
import warnings
from sklearn.model_selection import train_test_split
import os
import sys
import shutil
import time
import pywt
import gc
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
import tensorflow as tf
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import CategoricalCrossentropy
from tensorflow.keras.callbacks import Callback,ModelCheckpoint, EarlyStopping, ReduceLROnPlateau
from sklearn.metrics import confusion_matrix, accuracy_score, ConfusionMatrixDisplay
from sklearn.metrics import precision_score, recall_score, f1_score
from sklearn.metrics import cohen_kappa_score
from tensorflow.keras.utils import plot_model
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from tensorflow.keras.constraints import max_norm
from tensorflow.keras import backend as K
from tensorflow.keras.layers import SpatialDropout1D
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import CategoricalCrossentropy
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau
from sklearn.metrics import confusion_matrix, accuracy_score, ConfusionMatrixDisplay
from sklearn.metrics import cohen_kappa_score
from sklearn.model_selection import train_test_split
from tensorflow.keras.layers import Layer, LayerNormalization
from keras.layers import GlobalAveragePooling1D
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.layers import Dense, Dropout, Activation, AveragePooling2D, MaxPooling2D
from tensorflow.keras.layers import Conv1D, Conv2D, SeparableConv2D, DepthwiseConv2D
from tensorflow.keras.layers import BatchNormalization, LayerNormalization, Flatten 
from tensorflow.keras.layers import Add, Concatenate, Lambda, Input, Permute
from tensorflow.keras.regularizers import L2
import math
from tensorflow.keras.layers import GlobalAveragePooling2D, GlobalMaxPooling2D, Reshape, Dense
from tensorflow.keras.layers import multiply, Permute, Concatenate, Conv2D, Add, Activation, Lambda
from tensorflow.keras.layers import Dropout, MultiHeadAttention, LayerNormalization, Reshape
from tensorflow.keras import backend as K
from tensorflow.keras.layers import MultiHeadAttention, LayerNormalization
from tensorflow.keras import layers, regularizers
from tensorflow.keras.constraints import MaxNorm  # Import MaxNorm from the correct module
from tensorflow.keras.layers import Lambda
from tensorflow.keras.layers import Layer
from tensorflow.keras.initializers import HeNormal, HeUniform

# Suppress DeprecationWarnings
warnings.filterwarnings("ignore", category=DeprecationWarning)
# Set logging level to suppress INFO and WARNING messages
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  
tf.get_logger().setLevel('ERROR')  

# Optional: Disable XLA if not needed
os.environ['TF_XLA_FLAGS'] = '--tf_xla_enable_xla_devices=false'

import random
# Set random seeds for reproducibility
seed = 42
np.random.seed(seed)
tf.random.set_seed(seed)
random.seed(seed)
# os.environ['TF_DETERMINISTIC_OPS'] = '1'  # Ensure deterministic operations

# Preprocess & Data Loader

In [2]:
def load_data_LOSO (data_path, subject, dataset): 
    """ Loading and Dividing of the data set based on the 
    'Leave One Subject Out' (LOSO) evaluation approach. 
    LOSO is used for  Subject-independent evaluation.
    In LOSO, the model is trained and evaluated by several folds, equal to the 
    number of subjects, and for each fold, one subject is used for evaluation
    and the others for training. The LOSO evaluation technique ensures that 
    separate subjects (not visible in the training data) are usedto evaluate 
    the model.
    
        Parameters
        ----------
        data_path: string
            dataset path
            # Dataset BCI Competition IV-2a is available at 
            # http://bnci-horizon-2020.eu/database/data-sets
        subject: int
            number of subject in [1, .. ,9/14]
            Here, the subject data is used  test the model and other subjects data
            for training
    """
    
    X_train, y_train = None, None  # Initialize as None for concatenation later
    X_test, y_test = None, None    # Initialize for the test data
    
    for sub in range(0, 9):  # Iterate through subjects (assuming there are 9 subjects)
        if (dataset == 'BCI2a'):
            # Directly load the training and testing data without subdirectories
            X1, y1 = load_BCI2a_data(data_path, sub+1, True)  # Training data
            X2, y2 = load_BCI2a_data(data_path, sub+1, False) # Test data
        elif (dataset == 'CS2R'):
            X1, y1, _, _, _  = load_CS2R_data_v2(data_path, sub, True)
            X2, y2, _, _, _  = load_CS2R_data_v2(data_path, sub, False)

        # Concatenate training and testing data for each subject
        X = np.concatenate((X1, X2), axis=0)
        y = np.concatenate((y1, y2), axis=0)
                   
        if sub == subject:
            # Set aside test data for the current subject
            X_test = X
            y_test = y
        else:
            # If X_train is None, assign the first subject's data
            if X_train is None:
                X_train = X
                y_train = y
            else:
                # Concatenate the current subject's data with previous subjects' data
                X_train = np.concatenate((X_train, X), axis=0)
                y_train = np.concatenate((y_train, y), axis=0)

    return X_train, y_train, X_test, y_test

def load_BCI2a_data(data_path, subject, training, all_trials = True):    
    # Define MI-trials parameters
    n_channels = 22
    n_tests = 6*48     
    window_Length = 7*250 
    
    # Define MI trial window 
    fs = 250          # sampling rate
    t1 = int(2*fs)  # start time_point
    t2 = int(6*fs)    # end time_point

    class_return = np.zeros(n_tests)
    data_return = np.zeros((n_tests, n_channels, window_Length))

    NO_valid_trial = 0
    if training:
        a = sio.loadmat(data_path+'A0'+str(subject+1)+'T.mat')
    else:
        a = sio.loadmat(data_path+'A0'+str(subject+1)+'E.mat')
    a_data = a['data']
    for ii in range(0,a_data.size):
        a_data1 = a_data[0,ii]
        a_data2= [a_data1[0,0]]
        a_data3= a_data2[0]
        a_X         = a_data3[0]
        a_trial     = a_data3[1]
        a_y         = a_data3[2]
        a_artifacts = a_data3[5]

        for trial in range(0,a_trial.size):
             if(a_artifacts[trial] != 0 and not all_trials):
                 continue
             data_return[NO_valid_trial,:,:] = np.transpose(a_X[int(a_trial[trial]):(int(a_trial[trial])+window_Length),:22])
             class_return[NO_valid_trial] = int(a_y[trial])
             NO_valid_trial +=1        
    

    data_return = data_return[0:NO_valid_trial, :, t1:t2]
    class_return = class_return[0:NO_valid_trial]
    class_return = (class_return-1).astype(int)

    return data_return, class_return

def standardize_data(X_train, X_test, channels): 
    # X_train & X_test :[Trials, MI-tasks, Channels, Time points]
    for j in range(channels):
          scaler = StandardScaler()
          scaler.fit(X_train[:, 0, j, :])
          X_train[:, 0, j, :] = scaler.transform(X_train[:, 0, j, :])
          X_test[:, 0, j, :] = scaler.transform(X_test[:, 0, j, :])

    return X_train, X_test

def batch_cwt(batch_signals, frequencies, sampling_frequency, normalization='zscore'):
    """
    Compute the Continuous Wavelet Transform (CWT) for a batch of signals using TensorFlow.
    
    Args:
        batch_signals (tf.Tensor): Input batch of signals with shape [batch_size, channels, signal_length].
        frequencies (np.array): Array of frequencies to use for the CWT.
        sampling_frequency (int): Sampling frequency of the signals.
        normalization (str): Type of normalization ('zscore' or 'minmax').
    
    Returns:
        tf.Tensor: Tensor containing the normalized CWT coefficients for the input batch of signals.
        Shape: (batch_size, frequencies, signal_length, channels)
    """
    # Extract batch size, channels, and signal length from the tensor
    batch_size, channels, signal_length = batch_signals.shape
    
    # Convert the batch to NumPy at once (avoiding per-signal conversion)
    batch_signals_np = batch_signals.numpy()
    
    # Initialize a list to store the CWT results
    cwt_batch = []
    
    # Loop through the batch to compute CWT for each channel in each sample
    for i in range(batch_size):
        cwt_channels = []
        for ch in range(channels):
            # Extract the signal for the current channel
            signal = batch_signals_np[i, ch, :]
            
            # Compute CWT for the signal
            coefficients, _ = pywt.cwt(signal, frequencies, 'cmor1.5-1.0', sampling_period=1/sampling_frequency)
            coefficients = np.abs(coefficients)  # Take the absolute value of the coefficients
            
            # Normalize CWT coefficients based on the chosen method
            if normalization == 'zscore':
                # Z-Score normalization: (x - mean) / std
                mean = np.mean(coefficients)
                std = np.std(coefficients)
                coefficients = (coefficients - mean) / std if std != 0 else coefficients
            elif normalization == 'minmax':
                # Min-Max normalization: (x - min) / (max - min)
                min_val = np.min(coefficients)
                max_val = np.max(coefficients)
                coefficients = (coefficients - min_val) / (max_val - min_val) if max_val != min_val else coefficients
            
            # Append normalized CWT coefficients for the current channel
            cwt_channels.append(coefficients)
        
        # Stack CWT results for all channels of the current sample
        cwt_channels_stacked = np.stack(cwt_channels, axis=0)  # Shape: (channels, frequencies, signal_length)
        cwt_batch.append(cwt_channels_stacked)
    
    # Convert the list of CWT results into a single tensor
    cwt_batch_np = np.array(cwt_batch)  # Shape: (batch_size, channels, frequencies, signal_length)
    
    # Transpose dimensions to match image format
    cwt_batch_np = np.transpose(cwt_batch_np, (0, 2, 3, 1))  # Shape: (batch_size, frequencies, signal_length, channels)
    
    # Convert the normalized CWT batch back to a TensorFlow tensor
    cwt_batch_tensor = tf.convert_to_tensor(cwt_batch_np, dtype=tf.float32)

    return cwt_batch_tensor

def get_data(path, subject, frequencies, sampling_frequency, dataset='BCI2a', classes_labels='all', LOSO=False, 
             isStandard=True, isShuffle=True, include_cwt=True):
    
    # Load and split the dataset into training and testing
    if LOSO:
        X_train, y_train, X_test, y_test = load_data_LOSO(path, subject, dataset)
    else:
        if dataset == 'BCI2a':
            X_train, y_train = load_BCI2a_data(path, subject, True)
            X_test, y_test = load_BCI2a_data(path, subject, False)

    # Shuffle the data if specified
    if isShuffle:
        X_train, y_train = shuffle(X_train, y_train, random_state=seed)
        X_test, y_test = shuffle(X_test, y_test, random_state=seed)

    # Reshape the training data
    N_tr, N_ch, T = X_train.shape
    X_train = X_train.reshape(N_tr, 1, N_ch, T)
    y_train_onehot = to_categorical(y_train)

    # Reshape the testing data
    N_te, N_ch, T = X_test.shape
    X_test = X_test.reshape(N_te, 1, N_ch, T)
    y_test_onehot = to_categorical(y_test)

    # Standardize the data if specified
    if isStandard:
        X_train, X_test = standardize_data(X_train, X_test, N_ch)

    # Ensure shape consistency
    print(f"X_train shape: {X_train.shape}, X_test shape: {X_test.shape}")
    print(f"y_train_onehot shape: {y_train_onehot.shape}, y_test_onehot shape: {y_test_onehot.shape}")

    # Transpose the data
    X_train_transposed = tf.transpose(X_train, perm=[0, 2, 1, 3])  # (288, 22, 1, 1125)
    X_test_transposed = tf.transpose(X_test, perm=[0, 2, 1, 3])    # (288, 22, 1, 1125)

    # Squeeze only the singleton dimension (the last dimension, which is size 1)
    X_train_squeezed = tf.squeeze(X_train_transposed, axis=2)  # Shape: (288, 22, 1125)
    X_test_squeezed = tf.squeeze(X_test_transposed, axis=2)    # Shape: (288, 22, 1125)

    # Apply the batch_cwt function for both training and test data
    X_train_cwt = batch_cwt(X_train_squeezed, frequencies, sampling_frequency)
    X_test_cwt = batch_cwt(X_test_squeezed, frequencies, sampling_frequency)

    # Ensure shape consistency after CWT
    print(f"X_train_cwt shape: {X_train_cwt.shape}, X_test_cwt shape: {X_test_cwt.shape}")

    # Return CWT data and labels
    return X_train, y_train, y_train_onehot, X_test, y_test, y_test_onehot, X_train_cwt, X_test_cwt

# Augmentation

In [3]:
def add_noise(data, noise_factor=0.001, seed=42):
    """Add Gaussian noise to the data."""
    return data + noise_factor * tf.random.normal(shape=tf.shape(data), dtype=tf.float32, seed=seed)

def scale_data(data, min_scale=0.9, max_scale=1.1, seed=42):
    """Scale the data by a random factor."""
    scale = tf.random.uniform([], minval=min_scale, maxval=max_scale, dtype=tf.float32, seed=seed)
    return data * scale

def apply_time_mask(data, max_mask_size=10, seed=42):
    """Apply a random time mask to the data."""
    time_steps = tf.shape(data)[1]
    time_mask_start = tf.random.uniform([], minval=0, maxval=time_steps, dtype=tf.int32, seed=seed)
    mask_size = tf.random.uniform([], minval=0, maxval=tf.minimum(max_mask_size, time_steps - time_mask_start), dtype=tf.int32, seed=seed)

    mask = tf.ones(shape=[tf.shape(data)[0], mask_size, tf.shape(data)[2]])
    mask = tf.pad(mask, [[0, 0], [time_mask_start, time_steps - time_mask_start - mask_size], [0, 0]], "CONSTANT")
    mask = tf.expand_dims(mask, axis=-1)

    return data * (1.0 - mask)

def mixup_augmentation(cwt_data, labels, alpha=0.2, seed=42):
    """Apply mixup augmentation to CWT data."""
    num_samples = tf.shape(cwt_data)[0]
    tf.random.set_seed(seed)

    # Choose random indices for mixup
    random_indices = tf.random.shuffle(tf.range(num_samples), seed=seed)
    mixup_lambda = tf.random.uniform([num_samples, 1, 1, 1], minval=0, maxval=alpha, dtype=tf.float32, seed=seed)

    # Ensure labels are broadcast-compatible and of type float32
    if len(labels.shape) == 2:  # Assuming [num_samples, num_classes]
        mixup_lambda_labels = tf.reshape(mixup_lambda, [num_samples, 1])

    # Convert labels to float32 if needed
    labels = tf.cast(labels, tf.float32)
    mixup_lambda_labels = tf.cast(mixup_lambda_labels, tf.float32)

    # Mix data and labels with chosen lambda
    mixed_data = mixup_lambda * cwt_data + (1 - mixup_lambda) * tf.gather(cwt_data, random_indices)
    mixed_labels = mixup_lambda_labels * labels + (1 - mixup_lambda_labels) * tf.gather(labels, random_indices)
    
    return mixed_data, mixed_labels

def augment_time_domain(timg, label, seed=42):
    """Apply multiple combinations of augmentations on time-domain data."""
    tf.random.set_seed(seed)

    # Apply all augmentations in batch
    noise_data = add_noise(timg, seed=seed)
    scale_data_only = scale_data(timg, seed=seed)
    noise_scale_data = scale_data(noise_data, seed=seed)
    noise_scale_mask_data = apply_time_mask(noise_scale_data, seed=seed)

    # Concatenate once for efficiency
    X_time_aug_combined = tf.concat([timg, noise_data, scale_data_only, noise_scale_data, noise_scale_mask_data], axis=0)
    y_time_aug_combined = tf.tile(label, [5, 1])

    return X_time_aug_combined, y_time_aug_combined

def shift_data(cwt_data, shift_range=(-3, 4), seed=42):
    """Shift the CWT data along the time dimension."""
    batch_size, freqs, time_steps, channels = tf.shape(cwt_data)
    shift = tf.random.uniform([], minval=shift_range[0], maxval=shift_range[1], dtype=tf.int32, seed=seed)

    if shift > 0:
        shifted = tf.concat([tf.zeros([batch_size, freqs, shift, channels]), cwt_data[:, :, :-shift, :]], axis=2)
    elif shift < 0:
        shifted = tf.concat([cwt_data[:, :, -shift:, :], tf.zeros([batch_size, freqs, -shift, channels])], axis=2)
    else:
        shifted = cwt_data

    return shifted

def augment_cwt(cwt_data, label, seed=42):
    """Apply multiple combinations of augmentations on CWT data."""
    tf.random.set_seed(seed)

    noise_cwt = add_noise(cwt_data, seed=seed)
    shift_cwt = shift_data(cwt_data, seed=seed)
    noise_shift_cwt = shift_data(noise_cwt, seed=seed)
    mixup_noise_shift_cwt, mixup_labels = mixup_augmentation(noise_shift_cwt, label, seed=seed)

    X_cwt_aug_combined = tf.concat([cwt_data, noise_cwt, shift_cwt, noise_shift_cwt, mixup_noise_shift_cwt], axis=0)
    y_cwt_aug_combined = tf.concat([label] * 4 + [mixup_labels], axis=0)

    return X_cwt_aug_combined, y_cwt_aug_combined

def time_and_cwt_augment(timg, cwt_data, label, seed=42):
    """Apply both time-domain and CWT augmentations."""
    X_time_aug_combined, y_time_aug_combined = augment_time_domain(timg, label, seed=seed)
    X_cwt_aug_combined, y_cwt_aug_combined = augment_cwt(cwt_data, label, seed=seed)

    return X_time_aug_combined, X_cwt_aug_combined, y_time_aug_combined

# Attention Blocks

In [4]:
!pip install einops

  pid, fd = os.forkpty()


Collecting einops
  Downloading einops-0.8.0-py3-none-any.whl.metadata (12 kB)
Downloading einops-0.8.0-py3-none-any.whl (43 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.2/43.2 kB[0m [31m1.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: einops
Successfully installed einops-0.8.0


In [5]:
class TS_AttentionModule(layers.Layer):
    def __init__(self, emb_size, num_heads, dropout):
        super(TS_AttentionModule, self).__init__()
        self.emb_size = emb_size
        self.num_heads = num_heads
        self.dropout = dropout
        
        # Linear layers for queries, keys, and values
        self.query_dense = layers.Dense(emb_size)
        self.key_dense = layers.Dense(emb_size)
        self.value_dense = layers.Dense(emb_size)
        self.dropout_layer = layers.Dropout(dropout)
        
        # Linear projection after attention output
        self.projection = layers.Dense(emb_size)

    def call(self, query, key, value, mask=None):
        # Dense layers for queries, keys, and values
        queries = self.query_dense(query)
        keys = self.key_dense(key)
        values = self.value_dense(value)
        
        # Reshape for multi-head attention
        batch_size = tf.shape(queries)[0]
        seq_len = tf.shape(queries)[1]
        head_dim = self.emb_size // self.num_heads
        
        # Reshaping queries, keys, and values for multiple heads
        queries = tf.reshape(queries, (batch_size, seq_len, self.num_heads, head_dim))
        keys = tf.reshape(keys, (batch_size, seq_len, self.num_heads, head_dim))
        values = tf.reshape(values, (batch_size, seq_len, self.num_heads, head_dim))

        # Transpose to prepare for scaled dot-product attention
        queries = tf.transpose(queries, perm=[0, 2, 1, 3])
        keys = tf.transpose(keys, perm=[0, 2, 1, 3])
        values = tf.transpose(values, perm=[0, 2, 1, 3])
        
        # Scaled dot-product attention
        attention_scores = tf.matmul(queries, keys, transpose_b=True) # energy
        attention_scores = attention_scores / tf.math.sqrt(tf.cast(head_dim, tf.float32))
        
        if mask is not None:
            attention_scores += (mask * -1e9)
        
        attention_weights = tf.nn.softmax(attention_scores, axis=-1)
        attention_weights = self.dropout_layer(attention_weights)
        
        # Calculate attention output
        attention_output = tf.matmul(attention_weights, values)
        
        # Transpose and reshape back to original form
        attention_output = tf.transpose(attention_output, perm=[0, 2, 1, 3])
        attention_output = tf.reshape(attention_output, (batch_size, seq_len, self.emb_size))
        
        # Final projection to combine heads
        output = self.projection(attention_output)
        
        return output

def TS_ResidualAdd(x, fn):
    """
    Residual connection followed by function application.
    
    Parameters:
        x (tf.Tensor): Input tensor.
        fn (function): A function to apply to the input tensor.
        
    Returns:
        tf.Tensor: Output tensor after residual connection and function application.
    """
    res = x
    x = fn(x)
    return res + x

def TS_FeedForwardBlock(x, emb_size, expansion, drop_p):
    """
    Feed-forward block with linear layers.
    
    Args:
        x (tf.Tensor): Input tensor.
        emb_size (int): Embedding size.
        expansion (int): Expansion factor for the feed-forward layer.
        drop_p (float): Dropout probability.
    
    Returns:
        tf.Tensor: Output tensor after feed-forward processing.
    """
    # First dense layer with expansion
    x = layers.Dense(expansion * emb_size)(x)
    
    # GELU activation wrapped in a Keras layer
    x = layers.Activation('gelu')(x)
    
    # Dropout wrapped in a Keras layer
    x = layers.Dropout(drop_p)(x)
    
    # Second dense layer to map back to embedding size
    x = layers.Dense(emb_size)(x)
    
    return x

# Main Attention Block

In [6]:
from einops import rearrange

def attention_block(in_layer, freq_features, emb_size=32, num_heads=4, dropout=0.3, ratio=8, 
                    residual=True, apply_to_input=True, forward_expansion=4, forward_drop_p=0.5, 
                    **kwargs):
    
    # Positional encoding initialization (can be trainable or fixed)
    seq_len = in_layer.shape[1]
    pos_embedding_in = layers.Embedding(input_dim=seq_len, output_dim=emb_size)
    pos_embedding_freq = layers.Embedding(input_dim=freq_features.shape[1], output_dim=emb_size)
    
    # Step 1: Layer normalization for stability
    layer_norm = layers.LayerNormalization(epsilon=1e-6)
    in_layer_norm = layer_norm(in_layer)
    freq_features_norm = layer_norm(freq_features)

    # Step 2: Add positional encoding to both time-series and frequency features
    pos_indices_in = tf.range(start=0, limit=seq_len, delta=1)
    pos_indices_freq = tf.range(start=0, limit=freq_features.shape[1], delta=1)

    in_layer_pos_encoded = in_layer_norm + pos_embedding_in(pos_indices_in)
    freq_features_pos_encoded = freq_features_norm + pos_embedding_freq(pos_indices_freq)

    # Step 3: Project and Reshape Frequency Features before Cross-Attention
    freq_features_projected = layers.Dense(in_layer.shape[1])(freq_features_pos_encoded)
    freq_features_norm = layers.LayerNormalization(epsilon=1e-6)(freq_features_projected)
    freq_features_norm = layers.Reshape((in_layer.shape[1], -1))(freq_features_norm)

    # Step 4: Use TS_AttentionModule for Cross-Attention with freq_features_norm as key and value
    cross_attention_output = TS_AttentionModule(emb_size, num_heads, dropout)(
        query=in_layer_pos_encoded, key=freq_features_norm, value=in_layer_pos_encoded)

    # Step 5: Residual Connection after Cross-Attention using TS_ResidualAdd
    if residual:
        cross_attention_output = TS_ResidualAdd(
            in_layer_pos_encoded, lambda x: TS_AttentionModule(emb_size, num_heads, dropout)(
                x, key=freq_features_norm, value=in_layer_pos_encoded))

    # Step 6: Apply the Feedforward Block using TS_FeedForwardBlock
    feedforward_output = TS_FeedForwardBlock(cross_attention_output, emb_size, forward_expansion, forward_drop_p)

    # Step 7: Final residual connection (optional)
    if residual:
        final_output = TS_ResidualAdd(cross_attention_output, lambda x: TS_FeedForwardBlock(x, emb_size, forward_expansion, forward_drop_p))
    else:
        final_output = feedforward_output

    return final_output

# Time Frequency Conv

In [7]:
# Custom layer for squeezing
class SqueezeLayer(Layer):
    def __init__(self, axis):
        super(SqueezeLayer, self).__init__()
        self.axis = axis

    def call(self, inputs):
        return tf.squeeze(inputs, axis=self.axis)
    
class TransposeLayer(layers.Layer):
    def call(self, x):
        return tf.transpose(x, perm=[0, 1, 3, 2])  # Transpose to (batch_size, frequencies, channels, time)

def tf_conv_module(x, output_features=32, target_seq_len=None):
    """
    Modified Conv Module to produce features suitable for Key/Value in MultiHeadAttention.
    
    Parameters:
    - x: Input tensor (batch_size, frequencies, samples, channels)
    - output_features: Number of output features to match the query in MHA.
    - target_seq_len: The desired sequence length to match the query for MHA. If None, uses the frequency dimension size.
    
    Returns:
    - block1: Processed tensor ready for use as Key/Value in MultiHeadAttention.
    """
    weightDecay = 0.009
    maxNormValue = 0.6  # MaxNorm constraint value

    # Step 1: Transpose Layer to align channels (from input shape to (batch_size, frequencies, channels, time))
    x = TransposeLayer()(x)  # Output shape: (batch_size, frequencies, channels, time)
#     print("After TransposeLayer:", x.shape)

    # Step 2: First Conv2D Layer (Capture frequency-time patterns)
    x = layers.SeparableConv2D(16, (1, 10), padding='same', 
                           depthwise_regularizer=regularizers.l2(weightDecay),
                           pointwise_regularizer=regularizers.l2(weightDecay),
                           activation=None)(x)
    x = layers.BatchNormalization()(x)
    x = layers.ELU()(x)
#     print("After first Conv2D:", x.shape)

    # Step 3: Depthwise Conv2D Layer (Better capture of frequency patterns across channels)
    x = layers.DepthwiseConv2D(kernel_size=(1, 22), padding='same', 
                               depthwise_regularizer=regularizers.l2(weightDecay),
                               depthwise_constraint=MaxNorm(maxNormValue))(x)
    x = layers.BatchNormalization()(x)
    x = layers.ELU()(x)
#     print("After DepthwiseConv2D:", x.shape)
    
    # Step 4: Dropout for regularization
    x = layers.Dropout(0.5)(x)

    # Step 5: Pointwise Conv2D Layer (Increase feature complexity)
    x = layers.Conv2D(32, kernel_size=(1, 1), padding='same', 
                      kernel_regularizer=regularizers.l2(weightDecay),
                      kernel_constraint=MaxNorm(maxNormValue), activation=None)(x)
    x = layers.BatchNormalization()(x)
    x = layers.ELU()(x)
#     print("After Pointwise Conv2D:", x.shape)

    # Step 6: Second Conv2D Layer (Enhanced features across frequencies)
    x = layers.Conv2D(32, (4, 1), padding='same', 
                      kernel_regularizer=regularizers.l2(weightDecay),
                      kernel_constraint=MaxNorm(maxNormValue), activation=None)(x)
    x = layers.BatchNormalization()(x)
    x = layers.ELU()(x)
#     print("After second Conv2D:", x.shape)

    x = layers.AveragePooling2D(pool_size=(1, 18))(x)
    # print("After first AveragePooling2D:", x.shape)

    # Step 8: Dropout Layer for regularization
    x = layers.Dropout(0.6)(x)

    # Step 9: Final Conv Layer to reduce features to `output_features` (32)
    x = layers.Conv2D(output_features, (1, 1), padding='same', 
                      kernel_regularizer=regularizers.l2(weightDecay),
                      kernel_constraint=MaxNorm(maxNormValue), activation=None)(x)
    x = layers.BatchNormalization()(x)
    x = layers.ELU()(x)
#     print("After final Conv2D (output features):", x.shape)

    x = SqueezeLayer(axis=2)(x)
#     print("After Squeze:", x.shape)

    return x

# Time Series Conv

In [8]:
def TS_Conv_block(input_layer, F1=4, kernLength=64, poolSize=7, D=2, in_chans=22, 
                  weightDecay=0.01, maxNorm=0.6, dropout=0.3):
    """ Conv_block with moderate kernel size variations and feature averaging using `layers`. """

    F2 = F1 * D
    
    # Block 1a: First Convolution with kernLength = 32
    block1a = layers.Conv2D(F1, (kernLength, 1), padding='same', data_format='channels_last', 
                            kernel_regularizer=regularizers.L2(weightDecay),
                            kernel_constraint=max_norm(maxNorm, axis=[0, 1, 2]), use_bias=False)(input_layer)
    block1a = layers.BatchNormalization(axis=-1)(block1a)

    # Block 1b: Convolution with kernLength + 16 = 48
    block1b = layers.Conv2D(F1, (kernLength + 16, 1), padding='same', data_format='channels_last', 
                            kernel_regularizer=regularizers.L2(weightDecay),
                            kernel_constraint=max_norm(maxNorm, axis=[0, 1, 2]), use_bias=False)(input_layer)
    block1b = layers.BatchNormalization(axis=-1)(block1b)

    # Block 1c: Convolution with kernLength - 16 = 16
    block1c = layers.Conv2D(F1, (kernLength - 16, 1), padding='same', data_format='channels_last', 
                            kernel_regularizer=regularizers.L2(weightDecay),
                            kernel_constraint=max_norm(maxNorm, axis=[0, 1, 2]), use_bias=False)(input_layer)
    block1c = layers.BatchNormalization(axis=-1)(block1c)

    # Averaging the outputs from different kernel sizes
    block1 = layers.Average()([block1a, block1b, block1c])

    # Block 2: Depthwise Convolution
    block2 = layers.DepthwiseConv2D((1, in_chans), depth_multiplier=D, data_format='channels_last',
                                    depthwise_regularizer=regularizers.L2(weightDecay),
                                    depthwise_constraint=max_norm(maxNorm, axis=[0, 1, 2]), use_bias=False)(block1)
    block2 = layers.BatchNormalization(axis=-1)(block2)
    block2 = layers.Activation('elu')(block2)
    
    # Adjusted Pooling to retain more spatial information
    block2 = layers.AveragePooling2D((6, 1), data_format='channels_last')(block2)  # Reduce pooling size to (6,1)
    block2 = layers.Dropout(dropout)(block2)
    
    # Block 3: Final Convolution
    block3 = layers.Conv2D(F2, (16, 1), padding='same', data_format='channels_last',
                           kernel_regularizer=regularizers.L2(weightDecay),
                           kernel_constraint=max_norm(maxNorm, axis=[0, 1, 2]), use_bias=False)(block2)
    block3 = layers.BatchNormalization(axis=-1)(block3)
    block3 = layers.Activation('elu')(block3)
    
    # Final Pooling
    block3 = layers.AveragePooling2D((7, 1), data_format='channels_last')(block3)
    block3 = layers.Dropout(dropout)(block3)
    
    return block3

# Temporal Convolution Block

In [9]:
class GatedLinearUnit(Layer):
    def __init__(self, **kwargs):
        super(GatedLinearUnit, self).__init__(**kwargs)
    
    def call(self, x):
        x1, x2 = tf.split(x, num_or_size_splits=2, axis=-1)  # Split into two equal parts
        return x1 * tf.sigmoid(x2)  # Apply gating

def TCN_block(input_layer, input_dimension, depth, kernel_size, filters, dropout, 
               weightDecay=0.009, maxNorm=0.6):
    """ TCN block with GLU and optimized dropout """
    
    # Initial Conv1D block with GLU
    block = Conv1D(filters, kernel_size=kernel_size, dilation_rate=1, activation='linear',
                   kernel_regularizer=L2(weightDecay),
                   kernel_constraint=max_norm(maxNorm, axis=[0, 1]),
                   padding='causal', kernel_initializer='he_uniform')(input_layer)
    block = BatchNormalization()(block)
    block = SpatialDropout1D(dropout)(block)
    
    block = GatedLinearUnit()(block)  # Apply GLU as a layer

    # Second Conv1D block with GLU
    block = Conv1D(filters, kernel_size=kernel_size, dilation_rate=1, activation='linear',
                   kernel_regularizer=L2(weightDecay),
                   kernel_constraint=max_norm(maxNorm, axis=[0, 1]),
                   padding='causal', kernel_initializer='he_uniform')(block)
    block = BatchNormalization()(block)
    block = SpatialDropout1D(dropout)(block)

    # Residual connection
    if input_dimension != filters:
        conv = Conv1D(filters, kernel_size=1, activation='linear',
                      kernel_regularizer=L2(weightDecay),
                      kernel_constraint=max_norm(maxNorm, axis=[0, 1]),
                      padding='same')(input_layer)
        conv = BatchNormalization()(conv)
        conv = GatedLinearUnit()(conv)  # Apply GLU to residual connection
        added = Add()([block, conv])
    else:
        added = Add()([block, input_layer])
    
    out = Activation('linear')(added)  # Maintain the linearity after residual addition

    # Repeat for additional depth
    for i in range(depth - 1):
        block = Conv1D(filters, kernel_size=kernel_size, dilation_rate=2**(i + 1), activation='linear',
                       kernel_regularizer=L2(weightDecay),
                       kernel_constraint=max_norm(maxNorm, axis=[0, 1]),
                       padding='causal', kernel_initializer='he_uniform')(out)
        block = BatchNormalization()(block)
        block = SpatialDropout1D(dropout)(block)
        
        block = GatedLinearUnit()(block)  # Apply GLU

        block = Conv1D(filters, kernel_size=kernel_size, dilation_rate=2**(i + 1), activation='linear',
                       kernel_regularizer=L2(weightDecay),
                       kernel_constraint=max_norm(maxNorm, axis=[0, 1]),
                       padding='causal', kernel_initializer='he_uniform')(block)
        block = BatchNormalization()(block)
        block = SpatialDropout1D(dropout)(block)

        # Add residual connection
        added = Add()([block, out])
        out = Activation('linear')(added)  # Keep it linear for residual
        
    return out

# Our Model

In [10]:
#%% The proposed model, 
def TSxTF(n_classes, frequencies, in_chans=22, in_samples=1125, eegn_F1=16, 
            eegn_D=2, eegn_kernelSize=64, eegn_poolSize=7, eegn_dropout=0.3,
            tcn_depth=1, tcn_kernelSize=4, tcn_filters=32, tcn_dropout=0.3,
            tcn_activation='elu', fuse='average'):

    # Time-series Input
    input_eeg = Input(shape=(1, in_chans, in_samples))  # (batch_size, 1, channels, samples)
    
    # Frequency-series Input
    input_freq = Input(shape=(len(frequencies), in_samples, in_chans))  # (batch_size, frequencies, samples, channels)

    dense_weightDecay = 0.5
    conv_weightDecay = 0.009
    conv_maxNorm = 0.6
    from_logits = False

    numFilters = eegn_F1
    F2 = numFilters * eegn_D

    # EEG Convolution Block for Time-Series Input
    input_eeg_permuted = Permute((3, 2, 1))(input_eeg)  # (batch_size, samples, channels, 1)
    block1 = TS_Conv_block(input_layer=input_eeg_permuted, F1=eegn_F1, D=eegn_D,
                         kernLength=eegn_kernelSize, poolSize=eegn_poolSize,
                         weightDecay=conv_weightDecay, maxNorm=conv_maxNorm,
                         in_chans=in_chans, dropout=eegn_dropout)

    block1 = Lambda(lambda x: x[:, :, -1, :])(block1)  # Squeeze sequence dimension if needed

    # Frequency Features for Attention Mask using TF_ConvModule
    freq_features = tf_conv_module(input_freq)  # Extract features from frequency input

    # Apply a 1D Convolution to the entire time series
    conv_layer = Conv1D(filters=tcn_filters, kernel_size=3, padding='same')(block1)  # Temporal convolution(SW complement)

    attention_layer = attention_block(conv_layer, freq_features=freq_features)

    # Temporal Convolutional Network (TCN)
    tcn_output = TCN_block(input_layer=attention_layer, input_dimension=F2, depth=tcn_depth,
                            kernel_size=tcn_kernelSize, filters=tcn_filters,
                            weightDecay=conv_weightDecay, maxNorm=conv_maxNorm,
                            dropout=tcn_dropout)

    # Global Average Pooling to capture information over the entire sequence
    global_avg_pool = GlobalAveragePooling1D()(tcn_output)

    # Final Dense Layer for classification
    final_dense = Dense(n_classes, kernel_regularizer=L2(dense_weightDecay))(global_avg_pool)

    # Final output layer (softmax or linear)
    if from_logits:
        out = Activation('linear', name='linear')(final_dense)
    else:
        out = Activation('softmax', name='softmax')(final_dense)

    # Create the model with EEG and CWT inputs
    return Model(inputs=[input_eeg, input_freq], outputs=out)

# Train_Test

In [11]:
pip install tqdm

Note: you may need to restart the kernel to use updated packages.


In [12]:
class CustomTQDMProgressBar(Callback):
    def on_train_begin(self, logs=None):
        # Initialize the tqdm progress bar for epochs
        self.epochs_bar = tqdm(total=self.params['epochs'], position=0, desc='Epochs', unit='epoch')
        # Initialize variables to track best validation accuracy and loss, and their corresponding epochs
        self.best_val_acc = 0.0
        self.best_epoch_acc = 0
        self.best_val_loss = float('inf')  # Start with infinity for minimum comparison
        self.best_epoch_loss = 0

    def on_epoch_end(self, epoch, logs=None):
        # Retrieve relevant metrics from logs
        train_acc = logs.get('accuracy', 0.0)
        val_acc = logs.get('val_accuracy', 0.0)
        val_loss = logs.get('val_loss', float('inf'))  # Default to infinity if not available

        # Update the best validation accuracy and epoch if the current accuracy is higher
        if val_acc > self.best_val_acc:
            self.best_val_acc = val_acc
            self.best_epoch_acc = epoch + 1  # Store the best epoch (1-based index)

        # Update the best validation loss and epoch if the current loss is lower
        if val_loss < self.best_val_loss:
            self.best_val_loss = val_loss
            self.best_epoch_loss = epoch + 1  # Store the best epoch (1-based index)

        # Update the progress bar description with metrics, including best validation accuracy and loss
        self.epochs_bar.set_description(
            f"Epoch {epoch + 1}, Train Acc: {train_acc:.4f}, Valid Acc: {val_acc:.4f}, Valid Loss: {val_loss:.4f}, "
            f"Best Valid Acc: {self.best_val_acc:.4f} (Epoch {self.best_epoch_acc}), "
            f"Best Valid Loss: {self.best_val_loss:.4f} (Epoch {self.best_epoch_loss})"
        )

        # Move the progress bar by 1 epoch
        self.epochs_bar.update(1)

    def on_train_end(self, logs=None):
        # Close the tqdm progress bar at the end of training
        self.epochs_bar.close()

#%%
def draw_learning_curves(history, sub):
    # Plot training and validation accuracy
    plt.plot(history.history['accuracy'])
    plt.plot(history.history['val_accuracy'])
    plt.title('Model accuracy - subject: ' + str(sub))
    plt.ylabel('Accuracy')
    plt.xlabel('Epoch')
    plt.legend(['Train', 'Val'], loc='center right')
    plt.show()
    # Plot training and validation loss
    plt.plot(history.history['loss'])
    plt.plot(history.history['val_loss'])
    plt.title('Model loss - subject: ' + str(sub))
    plt.ylabel('Loss')
    plt.xlabel('Epoch')
    plt.legend(['Train', 'Val'], loc='center right')
    plt.show()
    # Plot validation accuracy and loss together on the same plot with dual y-axes
    fig, ax1 = plt.subplots()
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Validation Loss', color='tab:blue')
    ax1.plot(history.history['val_loss'], color='tab:blue', label='Validation Loss')
    ax1.tick_params(axis='y', labelcolor='tab:blue')
    ax2 = ax1.twinx()  # instantiate a second axes that shares the same x-axis
    ax2.set_ylabel('Validation Accuracy', color='tab:red')
    ax2.plot(history.history['val_accuracy'], color='tab:red', label='Validation Accuracy')
    ax2.tick_params(axis='y', labelcolor='tab:red')
    fig.suptitle('Validation Accuracy and Loss - subject: ' + str(sub))
    fig.tight_layout()  # adjust layout to make room for both y-axes
    plt.show()
    # Close the plot to avoid memory issues
    plt.close()

def draw_confusion_matrix(cf_matrix, sub, results_path, classes_labels):
    # Generate confusion matrix plot
    display_labels = classes_labels
    disp = ConfusionMatrixDisplay(confusion_matrix=cf_matrix, 
                                display_labels=display_labels)
    disp.plot()
    disp.ax_.set_xticklabels(display_labels, rotation=12)
    plt.title('Confusion Matrix of Subject: ' + sub )
    plt.savefig(results_path + '/subject_' + sub + '.png')
    plt.show()

def draw_performance_barChart(num_sub, metric, label, mean_best_value):
    fig, ax = plt.subplots()
    x = list(range(1, num_sub + 1))
    # Draw the bar chart for each subject's metric
    bars = ax.bar(x, metric, 0.5, label=label) 
    # Draw a dotted line for the overall mean of the best scores
    ax.axhline(y=mean_best_value, color='r', linestyle='--', label=f'Avg {label} ({mean_best_value:.4f})')
    # Add labels and titles
    ax.set_ylabel(label)
    ax.set_xlabel("Subject")
    ax.set_xticks(x)
    ax.set_title(f'Model {label} per Subject')
    ax.set_ylim([0, 1])
    ax.legend(loc='upper right')
    # Display the accuracy score above each bar
    for bar, score in zip(bars, metric):
        ax.text(
            bar.get_x() + bar.get_width() / 2,  # X-coordinate
            bar.get_height(),                  # Y-coordinate (top of the bar)
            f'{score:.4f}',                    # Text (formatted score)
            ha='center', va='bottom'           # Center text horizontally, below text vertically
        )  
    # Show the plot
    plt.show()
    
def plot_tsne(features, labels, title="t-SNE Plot", class_labels=['Left hand', 'Right hand', 'Foot', 'Tongue'],
              perplexity=40, learning_rate=100, n_pca_components=30, save_path='./tsne_plot.png'):
    # Reduce dimensions with PCA before applying t-SNE
#     pca = PCA(n_components=n_pca_components, random_state=42)
#     pca_results = pca.fit_transform(features)
    # Apply t-SNE
    tsne = TSNE(n_components=2, perplexity=perplexity, learning_rate=learning_rate, init='pca', random_state=42)
#     tsne = TSNE(n_components=2, perplexity=perplexity,init='pca', random_state=166)
#     tsne_results = tsne.fit_transform(pca_results)
    tsne_results = tsne.fit_transform(features)
    # Normalize t-SNE results for consistent plotting
    x_min, x_max = tsne_results.min(0), tsne_results.max(0)
    tsne_normalized = (tsne_results - x_min) / (x_max - x_min)
    # Define specific colors for each class
    colors = ['red', 'blue', 'green', 'brown']
    # Create the scatter plot
    plt.figure(figsize=(8, 6))
    for i, label in enumerate(np.unique(labels)):
        # Select data points belonging to the current class
        class_points = tsne_normalized[labels == label]
        plt.scatter(class_points[:, 0], class_points[:, 1], 
                    color=colors[i], label=class_labels[label], alpha=0.7)
    # Add plot details
    plt.title(title)
    plt.xlabel("t-SNE Component 1")
    plt.ylabel("t-SNE Component 2")
    plt.xticks([])  # Remove x-axis ticks for cleaner visualization
    plt.yticks([])  # Remove y-axis ticks for cleaner visualization
    plt.legend()
    plt.show()

def train_and_test(dataset_conf, train_conf, results_path):
    # Remove the 'results' folder before training
    if os.path.exists(results_path):
        shutil.rmtree(results_path)
    os.makedirs(results_path)

    in_exp = time.time()  # Start time for the overall experiment
    best_models = open(results_path + "/best_models.txt", "w")  # Log best models
    log_write = open(results_path + "/log.txt", "w")  # Log file

    # Dataset and training parameters
    dataset = dataset_conf.get('name')
    n_classes = dataset_conf.get('n_classes')
    n_sub = dataset_conf.get('n_sub')
    data_path = dataset_conf.get('data_path')
    isStandard = dataset_conf.get('isStandard')
    LOSO = dataset_conf.get('LOSO')
    include_cwt = dataset_conf.get('include_cwt')
    batch_size = train_conf.get('batch_size')
    model_name = train_conf.get('model')
    lr = train_conf.get('lr')
    epochs = train_conf.get('epochs')
    n_train = train_conf.get('n_train')
    from_logits = train_conf.get('from_logits')
    frequencies = dataset_conf.get('cwt_frequencies')
    sampling_frequency = dataset_conf.get('sampling_frequency')
    LearnCurves = train_conf.get('LearnCurves')
    classes_label = dataset_conf.get('cl_labels')
    patience = train_conf.get('patience')

    # Initialize arrays for storing training accuracy, kappa, test accuracy, kappa, and confusion matrices
    test_acc = np.zeros((n_sub, n_train))
    test_kappa = np.zeros((n_sub, n_train))
    cf_matrix = np.zeros([n_sub, n_train, n_classes, n_classes])
    test_precision = np.zeros((n_sub, n_train))
    test_recall = np.zeros((n_sub, n_train))
    test_f1 = np.zeros((n_sub, n_train))
    inference_time = 0
    
    # Ensure a consistent seed is set
    global_seed = 42
    tf.random.set_seed(global_seed)
    np.random.seed(global_seed)
    random.seed(global_seed)
    # os.environ['TF_DETERMINISTIC_OPS'] = '1'  # Ensure deterministic operations
  
    for sub in tqdm(range(0,1), desc="Training and testing subjects", unit="subject"):
        print(f'\nTraining on subject {sub + 1}')
        log_write.write(f'\nTraining on subject {sub + 1}\n')
        BestSubjAcc = 0
        bestTrainingHistory = []

        # Get training and test data
        X_train, _, y_train_onehot, X_test, _, y_test_onehot, X_train_cwt, X_test_cwt = get_data(
            data_path, sub, frequencies, sampling_frequency, dataset=dataset, LOSO=LOSO, isStandard=isStandard, include_cwt=True
        )

        # Convert TensorFlow tensors to NumPy arrays if necessary
        X_train = X_train.numpy() if hasattr(X_train, 'numpy') else X_train
        y_train_onehot = y_train_onehot.numpy() if hasattr(y_train_onehot, 'numpy') else y_train_onehot
        X_train_cwt = X_train_cwt.numpy() if hasattr(X_train_cwt, 'numpy') else X_train_cwt

        # Augment the training data using your augmentation function
        X_train_aug, X_train_cwt_aug, y_train_aug = time_and_cwt_augment(X_train, X_train_cwt, y_train_onehot)

        # Convert to NumPy arrays if necessary
        X_train_aug = X_train_aug.numpy() if hasattr(X_train_aug, 'numpy') else X_train_aug
        y_train_aug = y_train_aug.numpy() if hasattr(y_train_aug, 'numpy') else y_train_aug
        X_train_cwt_aug = X_train_cwt_aug.numpy() if hasattr(X_train_cwt_aug, 'numpy') else X_train_cwt_aug
        # Print shapes after augmentation
        print(f"Augmented shapes:")
        print(f"X_train_aug shape: {X_train_aug.shape}")
        print(f"y_train_aug shape: {y_train_aug.shape}")
        print(f"X_train_cwt_aug shape: {X_train_cwt_aug.shape}")
    
        # Training loop with modifications
        for train in tqdm(range(n_train), desc=f"Training runs for subject {sub + 1}", unit="run"):
            # Control the seed for each training run
            #run_seed = train + global_seed
            run_seed = global_seed
            tf.random.set_seed(run_seed)
            np.random.seed(run_seed)
            random.seed(run_seed)
            
            in_run = time.time()
            filepath = os.path.join(results_path, 'saved_models', f'run-{train + 1}', f'subject-{sub + 1}.weights.h5')
            os.makedirs(os.path.dirname(filepath), exist_ok=True)

            # Create the model and set initial weights
            model = getModel(model_name, dataset_conf, from_logits)
            initial_weights = model.get_weights()  # Save initial weights
            
            # Print model input configuration (only for first subject's first run)
            if sub == 0 and train == 0:
                print("Model input configuration:", model.inputs)
                model.summary()
                plot_model(model, to_file=os.path.join(results_path, 'model_summary.png'), show_shapes=True, show_layer_names=True)
            
            # Compile the model with gradient clipping
            model.compile(
                loss=tf.keras.losses.CategoricalCrossentropy(from_logits=from_logits),
                optimizer=tf.keras.optimizers.Adam(learning_rate=lr),  # Make sure lr is defined properly
                metrics=['accuracy']
            )

            callbacks = [
                tf.keras.callbacks.ModelCheckpoint(filepath, monitor='val_accuracy', verbose=0, save_best_only=True, save_weights_only=True, mode='max'),
                tf.keras.callbacks.ReduceLROnPlateau(monitor="val_loss", factor=0.90, patience=20, verbose=0, min_lr=0.0001),
                tf.keras.callbacks.EarlyStopping(monitor='val_accuracy', verbose=1, mode='max', patience=patience),
                CustomTQDMProgressBar()  # Custom progress bar
            ]
            
            # Load the initial weights before training each subject
            model.set_weights(initial_weights)
            
            # Train the model
            history = model.fit([X_train_aug, X_train_cwt_aug], y_train_aug,
                                validation_data=([X_test, X_test_cwt], y_test_onehot),
                                epochs=epochs, batch_size=batch_size, callbacks=callbacks, verbose=0)

            # Evaluate the best model on validation data
            model.load_weights(filepath)

            # Evaluate the model on test data
            y_pred_test = model.predict([X_test, X_test_cwt])
            if from_logits:
                y_pred_test = tf.nn.softmax(y_pred_test).numpy().argmax(axis=-1)
            else:
                y_pred_test = y_pred_test.argmax(axis=-1)

            labels_test = y_test_onehot.argmax(axis=-1)
            test_acc[sub, train] = accuracy_score(labels_test, y_pred_test)
            test_kappa[sub, train] = cohen_kappa_score(labels_test, y_pred_test)
            test_precision[sub, train] = precision_score(labels_test, y_pred_test, average='weighted')
            test_recall[sub, train] = recall_score(labels_test, y_pred_test, average='weighted')
            test_f1[sub, train] = f1_score(labels_test, y_pred_test, average='weighted')
            cf_matrix[sub, train, :, :] = confusion_matrix(labels_test, y_pred_test, normalize='true')

            out_run = time.time()

            # Log the performance after each run
            info = f'Subject: {sub + 1}   seed {train + 1}   time: {(out_run - in_run) / 60:.1f} m   '
            info += f'test_acc: {test_acc[sub, train]:.4f}  test_kappa: {test_kappa[sub, train]:.4f}  '
            info += f'test_precision: {test_precision[sub, train]:.4f}  test_recall: {test_recall[sub, train]:.4f}  test_f1: {test_f1[sub, train]:.4f}'
            print(info)
            log_write.write(info + '\n')
            
            # Save best run for the subject
            if BestSubjAcc < test_acc[sub, train]:
                BestSubjAcc = test_acc[sub, train]
                bestTrainingHistory = history

            # Clear GPU and RAM after each run
            del history  # Remove large objects to free memory
            K.clear_session()  # Clear TensorFlow session
            gc.collect()  # Collect garbage to release memory
            
            # If GPU memory is utilized, reset GPU memory
            if tf.config.experimental.get_visible_devices('GPU'):
                tf.keras.backend.clear_session()  # Clear GPU memory
                try:
                    tf.config.experimental.set_memory_growth(tf.config.experimental.get_visible_devices('GPU')[0], True)
                except:
                    pass

        # Test the best model after all runs
        runs = os.listdir(results_path + "/saved_models")
        X_test, X_test_cwt = np.array(X_test), np.array(X_test_cwt)
        n_samples = X_test.shape[0]

        for seed in range(len(runs)):
            model.load_weights(f'{results_path}/saved_models/run-{seed + 1}/subject-{sub + 1}.weights.h5')
            y_pred_test = []
            start_time = time.time()

            for i in range(0, n_samples, batch_size):
                batch_end = min(i + batch_size, n_samples)
                y_pred_batch = model.predict([X_test[i:batch_end], X_test_cwt[i:batch_end]]).argmax(axis=-1)
                y_pred_test.extend(y_pred_batch)

            inference_time += (time.time() - start_time) / n_samples

            y_pred_test = np.array(y_pred_test)
            labels_test = y_test_onehot.argmax(axis=-1)

            # Compute metrics
            precision = precision_score(labels_test, y_pred_test, average='weighted')
            recall = recall_score(labels_test, y_pred_test, average='weighted')
            f1 = f1_score(labels_test, y_pred_test, average='weighted')

            test_precision[sub, seed] = precision
            test_recall[sub, seed] = recall
            test_f1[sub, seed] = f1
            test_acc[sub, seed] = accuracy_score(labels_test, y_pred_test)
            test_kappa[sub, seed] = cohen_kappa_score(labels_test, y_pred_test)
            cf_matrix[sub, seed, :, :] = confusion_matrix(labels_test, y_pred_test, normalize='true')
            
            # Clean up after each seed-based evaluation
            gc.collect()
            K.clear_session()

        # Logging and plotting results
        best_run = np.argmax(test_acc[sub, :])  # Best run based on accuracy
        best_models.write(f'subject-{sub + 1}.weights.h5 for run-{best_run + 1}\n')
        
        # Load the best model weights for the best run
        best_model_path = f'{results_path}/saved_models/run-{best_run + 1}/subject-{sub + 1}.weights.h5'
        model.load_weights(best_model_path)

        # Plot learning curves if required
        if LearnCurves:
            print('Plotting Learning Curves .......')
            draw_learning_curves(bestTrainingHistory, sub + 1)
            
        draw_confusion_matrix(cf_matrix[sub, best_run, :, :], f'subject_{sub + 1}', results_path, classes_label)
        print(f'Confusion matrix plotted for subject {sub + 1}.')
        
        # Define the feature extraction model
        feature_model = Model(inputs=[model.input[0], model.input[1]], outputs=model.get_layer('global_average_pooling1d').output)

        # Get learned features for t-SNE
        learned_features = feature_model.predict([X_test, X_test_cwt])
        labels_test = y_test_onehot.argmax(axis=-1)  # True class labels

        # Plot t-SNE using learned features
        print(f"Plotting t-SNE for learned features...")
        plot_tsne(learned_features, labels_test)
        
#        # Assuming 'model' is your trained model and you have a layer named 'ts__attention_module_1'
#         attention_layer_name = 'ts__attention_module_1'
#         attention_extractor = Model(inputs=model.input, 
#                                     outputs=model.get_layer(attention_layer_name).output)

#         # Generate Attention Maps for a Batch of Test Data
#         # Replace `attention_inputs` with the actual test inputs
#         attention_maps = attention_extractor.predict([X_test, X_test_cwt])

#         # Visualize Attention Maps for the First Sample in the Batch
#         # Check the shape of attention_maps to understand its dimensions
#         sample_attention_map = attention_maps[0]  # Select the first sample in the batch

#         # Plotting the attention heatmap
#         plt.figure(figsize=(10, 8))
#         sns.heatmap(sample_attention_map, cmap='viridis')
#         plt.title(f'Attention Map for Layer {attention_layer_name}')
#         plt.xlabel('Key (input time steps/channels)')
#         plt.ylabel('Query (output time steps/channels)')
#         plt.show()

        # After each subject
        del X_train, y_train_onehot, X_train_cwt
        del y_pred_test, labels_test, X_test, X_test_cwt, y_test_onehot
        del X_train_aug, X_train_cwt_aug, y_train_aug
        del model, bestTrainingHistory
        gc.collect()
        K.clear_session()
        
    # Timing the end of the experiment
    out_exp = time.time()
    # Prepare header for testing performance logging
    head1_test = head2_test = '                '
    for sub in range(n_sub): 
        head1_test += f'sub_{sub + 1}   '
        head2_test += '-----   '
    head1_test += ' average'
    head2_test += ' -------'
    
    # Prepare test performance logging with additional metrics
    test_info = f'\n---------------------------------\nTest performance (acc %, kappa, precision, recall, f1):\n---------------------------------\n{head1_test}\n{head2_test}'
    
    # Print test performance for each seed and subject-wise
    for run in range(n_train):  # Use n_train to ensure it matches the training runs
        test_info += f'\nSeed {run + 1}:'
        test_info_acc = '(acc %)  '
        test_info_k = '   (k-sco)      '
        test_info_prec = '   (prec)       '
        test_info_recall = '  (recall)      '
        test_info_f1 = '    (f1)        '  
        for sub in range(n_sub): 
            test_info_acc += f'{test_acc[sub, run] * 100:.2f}    '
            test_info_k += f'{test_kappa[sub, run]:.3f}   '
            test_info_prec += f'{test_precision[sub, run]:.3f}   '
            test_info_recall += f'{test_recall[sub, run]:.3f}   '
            test_info_f1 += f'{test_f1[sub, run]:.3f}   '
        test_info_acc += f' {np.average(test_acc[:, run]) * 100:.2f}   '
        test_info_k += f'  {np.average(test_kappa[:, run]):.3f}   '
        test_info_prec += f'  {np.average(test_precision[:, run]):.3f}   '
        test_info_recall += f'  {np.average(test_recall[:, run]):.3f}   '
        test_info_f1 += f'  {np.average(test_f1[:, run]):.3f}   '   
        test_info += test_info_acc + '\n' + test_info_k + '\n' + test_info_prec + '\n' + test_info_recall + '\n' + test_info_f1

    # Subject-wise averages across all seeds
    test_info += f'\n\nSubject-wise averages across all seeds:\n'
    test_info += ' (acc %)        '
    test_info_kappa = '  (k-sco)       '
    test_info_prec_avg = '  (prec)        '
    test_info_recall_avg = '  (recall)      '
    test_info_f1_avg = '    (f1)        '

    subject_best_acc_list = []
    subject_best_kappa_list = []
    subject_best_prec_list = []
    subject_best_recall_list = []
    subject_best_f1_list = []

    for sub in range(n_sub): 
        # Calculate averages
        subject_avg_acc = np.average(test_acc[sub, :])
        subject_avg_kappa = np.average(test_kappa[sub, :])
        subject_avg_prec = np.average(test_precision[sub, :])
        subject_avg_recall = np.average(test_recall[sub, :])
        subject_avg_f1 = np.average(test_f1[sub, :])

        test_info += f'{subject_avg_acc * 100:.2f}    '
        test_info_kappa += f'{subject_avg_kappa:.3f}   '
        test_info_prec_avg += f'{subject_avg_prec:.3f}   '
        test_info_recall_avg += f'{subject_avg_recall:.3f}   '
        test_info_f1_avg += f'{subject_avg_f1:.3f}   '

        # Calculate best values for each subject
        subject_best_acc_list.append(np.max(test_acc[sub, :]))
        subject_best_kappa_list.append(np.max(test_kappa[sub, :]))
        subject_best_prec_list.append(np.max(test_precision[sub, :]))
        subject_best_recall_list.append(np.max(test_recall[sub, :]))
        subject_best_f1_list.append(np.max(test_f1[sub, :]))

    # Overall averages
    test_info += f' {np.average(test_acc) * 100:.2f}   '
    test_info_kappa += f'  {np.average(test_kappa):.3f}   '
    test_info_prec_avg += f'  {np.average(test_precision):.3f}   '
    test_info_recall_avg += f'  {np.average(test_recall):.3f}   '
    test_info_f1_avg += f'  {np.average(test_f1):.3f}   '
    test_info += '\n' + test_info_kappa + '\n' + test_info_prec_avg + '\n' + test_info_recall_avg + '\n' + test_info_f1_avg

    # Display subject-wise best results
    test_info += f'\n\nSubject-wise best results across all runs:\n'
    test_info += '  (best acc %)  '
    test_info_kappa_best = '  (best k-sco)  '
    test_info_prec_best = '  (best prec)   '
    test_info_recall_best = '  (best recall) '
    test_info_f1_best = '    (best f1)   '

    for sub in range(n_sub):
        test_info += f'{subject_best_acc_list[sub] * 100:.2f}    '
        test_info_kappa_best += f'{subject_best_kappa_list[sub]:.3f}   '
        test_info_prec_best += f'{subject_best_prec_list[sub]:.3f}   '
        test_info_recall_best += f'{subject_best_recall_list[sub]:.3f}   '
        test_info_f1_best += f'{subject_best_f1_list[sub]:.3f}   '

    # Overall means of best scores
    mean_best_acc = np.mean(subject_best_acc_list)
    mean_best_kappa = np.mean(subject_best_kappa_list)
    mean_best_prec = np.mean(subject_best_prec_list)
    mean_best_recall = np.mean(subject_best_recall_list)
    mean_best_f1 = np.mean(subject_best_f1_list)

    test_info += f'  {mean_best_acc * 100:.2f}   '
    test_info_kappa_best += f'  {mean_best_kappa:.3f}   '
    test_info_prec_best += f'  {mean_best_prec:.3f}   '
    test_info_recall_best += f'  {mean_best_recall:.3f}   '
    test_info_f1_best += f'  {mean_best_f1:.3f}   '
    test_info += '\n' + test_info_kappa_best + '\n' + test_info_prec_best + '\n' + test_info_recall_best + '\n' + test_info_f1_best

    # Overall averages for testing performance
    test_info += f'\n----------------------------------\nAverage - all seeds (acc %): {np.average(test_acc) * 100:.2f}\n'
    test_info += f'                    (k-sco): {np.average(test_kappa):.3f}\n'
    test_info += f'                    (prec):  {np.average(test_precision):.3f}\n'
    test_info += f'                    (recall):{np.average(test_recall):.3f}\n'
    test_info += f'                    (f1):    {np.average(test_f1):.3f}\n'
    test_info += f'\nSubject-wise best average (acc %): {mean_best_acc*100:.2f}\n'
    test_info += f'                          (k-sco): {mean_best_kappa:.3f}\n'
    test_info += f'                          (prec):  {mean_best_prec:.3f}\n'
    test_info += f'                          (recall):{mean_best_recall:.3f}\n'
    test_info += f'                           (f1):   {mean_best_f1:.3f}\n'
    test_info += f'\nInference time: {inference_time / len(runs):.2f} ms per trial\n'
    test_info += '----------------------------------\n'

    # Final output
    info = test_info

    # Print final results and write to log file
    print(info)
    log_write.write(info + '\n')

    # Save confusion matrices and inference time
    np.save(os.path.join(results_path, 'confusion_matrix.npy'), cf_matrix)
    np.save(os.path.join(results_path, 'inference_time.npy'), inference_time)
    np.save(os.path.join(results_path, 'precision.npy'), test_precision)
    np.save(os.path.join(results_path, 'recall.npy'), test_recall)
    np.save(os.path.join(results_path, 'f1_score.npy'), test_f1)
    
    draw_performance_barChart(n_sub, subject_best_acc_list, 'Testing Accuracy', mean_best_acc)
    draw_performance_barChart(n_sub, subject_best_kappa_list, 'Testing Kappa Score', mean_best_kappa)
    draw_performance_barChart(n_sub, subject_best_prec_list, 'Testing Precision', mean_best_prec)
    draw_performance_barChart(n_sub, subject_best_recall_list, 'Testing Recall', mean_best_recall)
    draw_performance_barChart(n_sub, subject_best_f1_list, 'Testing F1-Score', mean_best_f1)
    draw_confusion_matrix(cf_matrix.mean((0, 1)), 'All', results_path, classes_labels)
    
    # Close log files    
    best_models.close()   
    log_write.close()

    # Total experiment time
    print(f'\nTotal Experiment Time: {(time.time() - in_exp) / 60:.1f} minutes.')
    
#%%
def getModel(model_name, dataset_conf, from_logits = False):
    
    n_classes = dataset_conf.get('n_classes')
    n_channels = dataset_conf.get('n_channels')
    in_samples = dataset_conf.get('in_samples')
    frequencies = dataset_conf.get('cwt_frequencies') 

    # Select the model
    if(model_name == 'TSxTF'):
        # Train using the proposed ATCNet model: https://ieeexplore.ieee.org/document/9852687
        model = TSxTF( 
            # Dataset parameters
            n_classes = n_classes, 
            in_chans = n_channels, 
            in_samples = in_samples,
            frequencies= frequencies,
            # Convolutional (CV) block parameters
            eegn_F1 = 16,
            eegn_D = 2, 
            eegn_kernelSize = 64,
            eegn_poolSize = 7,
            eegn_dropout = 0.3,
            # Temporal convolutional (TC) block parameters
            tcn_depth = 2, 
            tcn_kernelSize = 4,
            tcn_filters = 32,
            tcn_dropout = 0.3, 
            tcn_activation='elu',
            )     
    else:
        raise Exception("'{}' model is not supported yet!".format(model_name))

    return model

# Frequency Generator

In [13]:
def generate_frequencies(total_frequencies):
    # Define the distribution percentages
    band_percentages = {
        '0.5-8 Hz': 0.05,    # 5% between 0 and 8 Hz
        '8-30 Hz': 0.40,   # 60% between 8 and 40 Hz
        '30-50 Hz': 0.30,  # 10% between 40 and 60 Hz
        '50-85 Hz': 0.05,  # 5% between 60 and 80 Hz
        '85-100 Hz': 0.15  # 20% between 80 and 100 Hz
    }

    # Calculate the number of frequencies for each band
    count_0_8 = int(total_frequencies * band_percentages['0.5-8 Hz'])
    count_8_40 = int(total_frequencies * band_percentages['8-30 Hz'])
    count_40_60 = int(total_frequencies * band_percentages['30-50 Hz'])
    count_60_80 = int(total_frequencies * band_percentages['50-85 Hz'])
    count_80_100 = total_frequencies - (count_0_8 + count_8_40 + count_40_60 + count_60_80)  # Remaining for 80-100 Hz

    # Generate frequencies for each band
    band_0_8 = np.linspace(0.5, 8, count_0_8, endpoint=False)         # Frequencies between 0 and 8 Hz
    band_8_40 = np.linspace(8, 30, count_8_40, endpoint=False)       # Frequencies between 8 and 40 Hz
    band_40_60 = np.linspace(30, 50, count_40_60, endpoint=False)    # Frequencies between 40 and 60 Hz
    band_60_80 = np.linspace(50, 85, count_60_80, endpoint=False)    # Frequencies between 60 and 80 Hz
    band_80_100 = np.linspace(85, 100.5, count_80_100, endpoint=False) # Frequencies between 80 and 100 Hz

    # Combine all frequencies
    frequencies = np.concatenate((band_0_8, band_8_40, band_40_60, band_60_80, band_80_100))

    # Sort frequencies (just in case, though it should already be sorted)
    frequencies = np.sort(frequencies)

    return frequencies

# Example usage:
total_frequencies = 32  # You can change this to any number of total frequencies
frequencies = generate_frequencies(total_frequencies)

# print(frequencies)
# print(f"Total frequencies: {len(frequencies)}")

# Run Train_Test

In [None]:
in_samples = 1000
n_channels = 22
n_sub = 9
n_classes = 4
classes_labels = ['Left hand', 'Right hand','Foot','Tongue']
data_path ='/kaggle/input/bcic-iv-2amatlab-version/'
dataset = 'BCI2a'
lr = 0.0009

# Create a folder to store the results of the experiment
results_path = os.getcwd() + "/results"
if not  os.path.exists(results_path):
      os.makedirs(results_path)   # Create a new directory if it does not exist 
    
    # Set dataset paramters 
# Set dataset parameters
dataset_conf = {
    'name': dataset,
    'n_classes': n_classes,
    'cl_labels': classes_labels,
    'n_sub': n_sub,
    'n_channels': n_channels,
    'in_samples': in_samples,
    'data_path': data_path,
    'cwt_frequencies': frequencies,  # CWT frequency range from 0.5 to 100 Hz 
    'isStandard': True,
    'LOSO': False,
    'include_cwt': True,
    'sampling_frequency': 250  # Raw EEG signal sampling frequency
}

train_conf = { 'batch_size': 64, 'epochs': 1000, 'patience': 300, 'lr': lr,'n_train': 10,
                  'LearnCurves': True, 'from_logits': False, 'model':'TSxTF'}
    
results_path= '/kaggle/working/results'
# Call the training function
train_and_test(dataset_conf,train_conf, results_path)

Training and testing subjects:   0%|          | 0/1 [00:00<?, ?subject/s]


Training on subject 1
X_train shape: (288, 1, 22, 1000), X_test shape: (288, 1, 22, 1000)
y_train_onehot shape: (288, 4), y_test_onehot shape: (288, 4)
X_train_cwt shape: (288, 32, 1000, 22), X_test_cwt shape: (288, 32, 1000, 22)
Augmented shapes:
X_train_aug shape: (1440, 1, 22, 1000)
y_train_aug shape: (1440, 4)
X_train_cwt_aug shape: (1440, 32, 1000, 22)



Training runs for subject 1:   0%|          | 0/10 [00:00<?, ?run/s][A

Model input configuration: [<KerasTensor shape=(None, 1, 22, 1000), dtype=float32, sparse=None, name=keras_tensor>, <KerasTensor shape=(None, 32, 1000, 22), dtype=float32, sparse=None, name=keras_tensor_1>]


I0000 00:00:1731614726.307984     103 service.cc:145] XLA service 0x7d1a20004890 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1731614726.308041     103 service.cc:153]   StreamExecutor device (0): Tesla P100-PCIE-16GB, Compute Capability 6.0
I0000 00:00:1731614755.557927     103 device_compiler.h:188] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.
Epoch 17, Train Acc: 0.9812, Valid Acc: 0.3715, Valid Loss: 4.9629, Best Valid Acc: 0.3715 (Epoch 16), Best Valid Loss: 4.9629 (Epoch 17):   2%|▏         | 17/1000 [02:10<57:26,  3.51s/epoch]  