In [None]:
import tensorflow as tf
!pip install tensorflow-addons
!pip install mne
import tensorflow_addons as tfa
from tqdm.notebook import tqdm

from tensorflow.keras.models import Model, Sequential, load_model
from tensorflow.keras import regularizers
from tensorflow.keras.layers import Input, Dense, Activation, Dropout, SpatialDropout1D, SpatialDropout2D, BatchNormalization
from tensorflow.keras.layers import Flatten, InputSpec, Layer, Concatenate, AveragePooling2D, MaxPooling2D, Reshape, Permute
from tensorflow.keras.layers import Conv2D, SeparableConv2D, DepthwiseConv2D, LayerNormalization
from tensorflow.keras.layers import TimeDistributed, Lambda, AveragePooling1D, Add, Conv1D, Multiply
from tensorflow.keras.layers import Conv2D, DepthwiseConv2D, BatchNormalization, Activation, Concatenate, Input, AveragePooling2D, Dropout, Flatten, Dense, LSTM
from tensorflow.keras.models import Model
from tensorflow.keras.constraints import max_norm, unit_norm 
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
from tensorflow.keras.optimizers import Adam
from tensorflow_addons.layers import WeightNormalization
from tensorflow.keras.utils import plot_model

import random
import mne
import matplotlib.pyplot as plt
import sklearn
from sklearn.metrics import silhouette_score, confusion_matrix

import pandas as pd
import numpy as np
from glob import glob

# Model

In [None]:
# --- Acquisition / windowing ---
Fs = 256                 # sampling frequency (samples per second)
n_channels = 4           # number of EEG channels
Wn = 1                   # window duration, in seconds (n_samples = Fs * Wn)
n_samples = Fs * Wn      # samples per channel in one window

# --- EEG-ITNet filter configuration ---
# NOTE(review): these two lists are not referenced by Hybrid_CNN_LSTM in this
# view; presumably consumed by an EEG-ITNet model elsewhere — confirm.
n_ff = [2, 4, 8, 16]     # frequency filters for each inception module
n_sf = [1] * 4           # spatial filters per frequency sub-band

# --- Training ---
batch_size = 32
epochs = 500

In [None]:
def _temporal_spatial_branch(x, chans, filters, kernel_len, index):
    """Build one multi-scale branch: temporal Conv2D -> BN -> depthwise conv -> BN -> ELU.

    Args:
        x: Input Keras tensor, shaped ``(batch, chans, samples, 1)`` as built
            by ``Hybrid_CNN_LSTM``.
        chans: Number of channel rows; the depthwise kernel spans all of them.
        filters: Number of temporal convolution filters.
        kernel_len: Temporal kernel length, in samples.
        index: Branch number, used for the layer names 'Conv{index}' and
            'DepthConv{index}'.

    Returns:
        The branch output tensor. Its height is collapsed to 1 by the
        ``(chans, 1)`` 'valid' depthwise kernel.
    """
    x = Conv2D(filters, (1, kernel_len), activation='relu', padding='same',
               name=f'Conv{index}')(x)
    x = BatchNormalization()(x)
    # A (chans, 1) kernel with 'valid' padding spans every channel row, so the
    # channel (height) axis collapses to 1.
    x = DepthwiseConv2D((chans, 1), activation='relu', depth_multiplier=1,
                        padding='valid', name=f'DepthConv{index}')(x)
    x = BatchNormalization()(x)
    # NOTE(review): ReLU is applied inside the depthwise conv, and ELU is then
    # applied after BN. This stacking is unusual but is kept so the
    # architecture (and any saved weights) stays unchanged.
    return Activation('elu')(x)


def Hybrid_CNN_LSTM(Chans, Samples, out_class=3, drop_rate=0.2):
    """Build a multi-scale CNN front end followed by an LSTM classifier.

    Four parallel branches run temporal convolutions with kernel lengths
    16/32/64/128 samples (32/64/128/256 filters), each followed by a depthwise
    convolution across all channels. The branch outputs are concatenated on
    the feature axis. The result is average-pooled along time by 4, flattened,
    projected to 128 units and fed to an LSTM before the softmax layer.

    Args:
        Chans: Number of EEG channels (input height).
        Samples: Number of time samples per window (input width).
        out_class: Number of output classes (softmax units).
        drop_rate: Dropout rate applied after pooling and after the dense
            projection.

    Returns:
        An uncompiled Keras ``Model`` mapping inputs of shape
        ``(batch, Chans, Samples, 1)`` to class probabilities of shape
        ``(batch, out_class)``.
    """
    dense_units = 128
    input_block = Input(shape=(Chans, Samples, 1))

    # (filters, temporal kernel length) for branches 1..4. They are built in
    # order, so auto-generated layer names match the hand-unrolled version.
    branch_specs = ((32, 16), (64, 32), (128, 64), (256, 128))
    branches = [
        _temporal_spatial_branch(input_block, Chans, filters, kernel_len, idx)
        for idx, (filters, kernel_len) in enumerate(branch_specs, start=1)
    ]
    features = Concatenate(axis=-1)(branches)

    # Downsample along time, then flatten into one feature vector per window.
    x = AveragePooling2D((1, 4))(features)
    x = Dropout(drop_rate)(x)
    x = Flatten()(x)
    x = Dense(dense_units, activation='relu')(x)
    x = Dropout(drop_rate)(x)

    # Add a length-1 time axis for the LSTM. A Keras Reshape layer is used
    # instead of tf.expand_dims, because raw TF ops on symbolic Keras tensors
    # are rejected by Keras 3.
    # NOTE(review): with a single time step, the LSTM acts as a gated dense
    # layer and does not model temporal dynamics.
    x = Reshape((1, dense_units))(x)
    lstm_output = LSTM(64, return_sequences=False)(x)

    out = Dense(out_class, activation='softmax')(lstm_output)
    return Model(inputs=input_block, outputs=out)

# Data Preparation