In [15]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
import mne
from glob import glob
import pandas as pd

In [16]:
# Paramètres
Fs = 256  # Fréquence d'échantillonnage
n_channels = 4  # Nombre de canaux
Wn = 1  # Durée de la fenêtre d'échantillonnage
n_samples = Wn * Fs  # Longueur de la fenêtre d'échantillonnage
batch_size = 32
epochs = 500
n_ff = [2, 4, 8, 16]  # Nombre de filtres de fréquence pour chaque module d'inception
n_sf = [1, 1, 1, 1]  # Nombre de filtres spatiaux dans chaque sous-bande de fréquence


In [17]:
def convertDF2MNE(sub):
    info = mne.create_info(list(sub.columns), ch_types=['eeg'] * len(sub.columns), sfreq=256)
    info.set_montage('standard_1020')
    data = mne.io.RawArray(sub.T, info)
    data.set_eeg_reference()
    epochs = mne.make_fixed_length_epochs(data, duration=Wn, overlap=0.2 * Wn)
    return epochs.get_data()

In [18]:

# Charger et traiter les données
def load_data():
    files1 = glob("dataset//Train//*gauche*")
    files2 = glob("dataset//Train//*droite*")
    files3 = glob("dataset//Train//*neutre*")

    dfs1, dfs2, dfs3 = [], [], []
    for f in files1:
        df = pd.read_csv(f)
        cols_remove = ['timestamps', 'Right AUX']
        df = df.loc[:, ~df.columns.isin(cols_remove)]
        df.columns = df.columns.str.replace('RAW_', '', 1)
        df = df.fillna(df.mean())
        dfs1.append(df)  
    data1 = pd.concat(dfs1, ignore_index=True)

    for f in files2:
        df = pd.read_csv(f)
        cols_remove = ['timestamps', 'Right AUX']
        df = df.loc[:, ~df.columns.isin(cols_remove)]
        df.columns = df.columns.str.replace('RAW_', '', 1)
        df = df.fillna(df.mean())
        dfs2.append(df)  
    data2 = pd.concat(dfs2, ignore_index=True)

    for f in files3:
        df = pd.read_csv(f)
        cols_remove = ['timestamps', 'Right AUX']
        df = df.loc[:, ~df.columns.isin(cols_remove)]
        df.columns = df.columns.str.replace('RAW_', '', 1)
        df = df.fillna(df.mean())
        dfs3.append(df)  
    data3 = pd.concat(dfs3, ignore_index=True)

    x_left = np.empty((0, n_channels, n_samples))
    y_left = np.empty(0)
    x_right = np.empty((0, n_channels, n_samples))
    y_right = np.empty(0)
    x_neutral = np.empty((0, n_channels, n_samples))
    y_neutral = np.empty(0)

    # Convertir les données
    data = convertDF2MNE(data1)
    for i in range(len(data)):
        label = 0
        y_left = np.append(y_left, label)
    x_left = np.append(x_left, data, axis=0)

    data = convertDF2MNE(data2)
    for i in range(len(data)):
        label = 1
        y_right = np.append(y_right, label)
    x_right = np.append(x_right, data, axis=0)

    data = convertDF2MNE(data3)
    for i in range(len(data)):
        label = 2
        y_neutral = np.append(y_neutral, label)
    x_neutral = np.append(x_neutral, data, axis=0)

    x_left = x_left[:, :, :, np.newaxis]
    x_right = x_right[:, :, :, np.newaxis]
    x_neutral = x_neutral[:, :, :, np.newaxis]

    split = 64
    x_train = np.concatenate((x_left[split:, :, :, :], x_right[split:, :, :, :], x_neutral[split:, :, :, :]), axis=0)
    x_val = np.concatenate((x_left[:split, :, :, :], x_right[:split, :, :, :], x_neutral[split:, :, :, :]), axis=0)

    y_train = np.concatenate((y_left[split:], y_right[split:], y_neutral[split:]))
    y_val = np.concatenate((y_left[:split], y_right[:split], y_neutral[split:]))

    return x_train, y_train, x_val, y_val

In [19]:
class EEGNet(nn.Module):
    def __init__(self, num_classes = 3):
        super(EEGNet, self).__init__()
        
        # Première couche de convolution
        self.conv1 = nn.Conv2d(in_channels=4, out_channels=2, kernel_size=(1, 4), stride=1, padding=(0, 2))
        self.batch_norm1 = nn.BatchNorm2d(2)
        self.activation1 = nn.ReLU()

        # Ajouter la couche depthwise
        self.depthwise_conv1 = nn.Conv2d(in_channels=2, out_channels=4, kernel_size=(2, 1), stride=1, padding=0)
        self.batch_norm2 = nn.BatchNorm2d(4)
        self.activation2 = nn.ReLU()

        # Définir une couche conv2 (si nécessaire)
        self.conv2 = nn.Conv2d(in_channels=4, out_channels=8, kernel_size=(1, 3), stride=1, padding=(0, 1))  # Par exemple
        self.batch_norm3 = nn.BatchNorm2d(8)
        self.activation3 = nn.ReLU()

        # Couches fully connected
        self.fc = nn.Linear(8 * 255 * 2, 64)  # 8 * 255 * 2 = 4080
        self.out = nn.Linear(64, num_classes)  # Par exemple, 28 classes
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        # Appliquer les convolutions et activations
        x = self.conv1(x)
        x = self.batch_norm1(x)
        x = self.activation1(x)

        x = self.depthwise_conv1(x)
        x = self.batch_norm2(x)
        x = self.activation2(x)

        x = self.conv2(x)
        x = self.batch_norm3(x)
        x = self.activation3(x)

        # Vérifier les dimensions de x avant de passer à la couche fully connected
        # print(x.shape)  # Ajouter cette ligne pour vérifier les dimensions

        # Appliquer la couche fully connected
        x = x.view(x.size(0), -1)  # Aplatir la sortie
        x = self.fc(x)  # Passer à la couche fully connected
        x = self.dropout(x)
        x = self.out(x) 

        return x



In [20]:
class EEGDataset(Dataset):
    def __init__(self, data, labels):
        self.data = torch.tensor(data, dtype=torch.float32)
        self.labels = torch.tensor(labels, dtype=torch.long)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]


In [21]:
x_train, y_train, x_val, y_val = load_data()

# Créer les DataLoaders
train_dataset = EEGDataset(x_train, y_train)
val_dataset = EEGDataset(x_val, y_val)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

Creating RawArray with float64 data, n_channels=4, n_times=172023
    Range : 0 ... 172022 =      0.000 ...   671.961 secs
Ready.
EEG channel type selected for re-referencing
Applying average reference.
Applying a custom ('EEG',) reference.
Not setting metadata
839 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 839 events and 256 original time points ...
0 bad epochs dropped
Creating RawArray with float64 data, n_channels=4, n_times=119145
    Range : 0 ... 119144 =      0.000 ...   465.406 secs
Ready.
EEG channel type selected for re-referencing
Applying average reference.
Applying a custom ('EEG',) reference.
Not setting metadata
581 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 581 events and 256 original time points ...
0 bad epochs dropped
Creating RawArray with float64 data, n_channels=4, n_times=104930
    Range : 0 ... 104929 =      0.000 ..

In [22]:
model = EEGNet()
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

In [None]:
# Entraînement du modèle
for epoch in range(epochs):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    
    for inputs, labels in train_loader:
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    epoch_loss = running_loss / len(train_loader)
    accuracy = 100 * correct / total
    print(f"Epoch {epoch+1}/{epochs}, Loss: {epoch_loss:.4f}, Accuracy: {accuracy:.2f}%")

Epoch 1/500, Loss: 1.2540, Accuracy: 44.77%
Epoch 2/500, Loss: 0.9609, Accuracy: 55.69%
Epoch 3/500, Loss: 0.8020, Accuracy: 64.43%
Epoch 4/500, Loss: 0.7250, Accuracy: 68.85%
Epoch 5/500, Loss: 0.6315, Accuracy: 73.79%
Epoch 6/500, Loss: 0.5556, Accuracy: 77.41%
Epoch 7/500, Loss: 0.5045, Accuracy: 79.54%
Epoch 8/500, Loss: 0.4302, Accuracy: 85.57%
Epoch 9/500, Loss: 0.3851, Accuracy: 85.46%
Epoch 10/500, Loss: 0.4032, Accuracy: 85.29%
Epoch 11/500, Loss: 0.3845, Accuracy: 84.94%
Epoch 12/500, Loss: 0.3091, Accuracy: 89.48%
Epoch 13/500, Loss: 0.2548, Accuracy: 90.63%
Epoch 14/500, Loss: 0.1833, Accuracy: 93.45%
Epoch 15/500, Loss: 0.2069, Accuracy: 93.74%
Epoch 16/500, Loss: 0.1821, Accuracy: 93.51%
Epoch 17/500, Loss: 0.1784, Accuracy: 94.31%
Epoch 18/500, Loss: 0.1681, Accuracy: 93.91%
Epoch 19/500, Loss: 0.1759, Accuracy: 94.08%
Epoch 20/500, Loss: 0.1495, Accuracy: 95.34%
Epoch 21/500, Loss: 0.0918, Accuracy: 96.84%
Epoch 22/500, Loss: 0.1047, Accuracy: 96.72%
Epoch 23/500, Loss:

In [99]:
torch.save(model, 'model.pth')

In [None]:
import torch
model = torch.load('model.pth',  weights_only=False)
# model.eval()