In [None]:
import os
import unicodedata
from collections import Counter

import pandas as pd
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, classification_report

### Data Preprocessing

In [1]:
# TO DO

### EEG Data Loading

In [None]:
class EEGDataset(Dataset):
    """EEG trials stored as one CSV file per trial in a single directory.

    File names must follow ``<movement>_<direction>_<index>.csv``
    (e.g. ``avant_gauche_0.csv``); both labels are parsed from the name.
    Only ``.csv`` files are picked up, in sorted order, so the index -> file
    mapping does not depend on the operating system's listing order.

    Each item is ``(eeg, (movement, direction))`` where ``eeg`` is a float32
    tensor of shape (1, n_rows, n_cols - 1) built from every CSV column except
    the last, and the labels are scalar int64 tensors indexing
    ``MOVEMENT_CLASSES`` / ``DIRECTION_CLASSES``.

    Args:
        data_dir: Directory containing the CSV files.
    """

    MOVEMENT_CLASSES = ["neutre", "avant", "arrière"]
    DIRECTION_CLASSES = ["neutre", "gauche", "droite"]

    def __init__(self, data_dir):
        self.data_dir = data_dir
        # Skip stray non-CSV entries (e.g. .DS_Store) that would crash read_csv or
        # label parsing, and sort because os.listdir order is arbitrary.
        self.files = sorted(
            name for name in os.listdir(data_dir) if name.lower().endswith(".csv")
        )

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        file_name = self.files[idx]
        file_path = os.path.join(self.data_dir, file_name)

        df = pd.read_csv(file_path)
        # Drop the last column. NOTE(review): presumably a label/marker column — confirm.
        eeg_data = torch.tensor(df.iloc[:, :-1].to_numpy(), dtype=torch.float32)
        # Add a leading singleton channel dimension for Conv2d: (1, rows, cols).
        eeg_data = eeg_data.unsqueeze(0)

        movement_label, direction_label = self._parse_labels(file_name)
        movement_tensor = torch.tensor(movement_label, dtype=torch.long)
        direction_tensor = torch.tensor(direction_label, dtype=torch.long)

        return eeg_data, (movement_tensor, direction_tensor)

    @classmethod
    def _parse_labels(cls, file_name):
        """Return ``(movement_index, direction_index)`` parsed from a file name.

        Raises:
            ValueError: if the name is not ``<movement>_<direction>_<index>.csv``
                or contains an unknown label.
        """
        # Normalize to NFC: some file systems (notably macOS HFS+) return names in
        # decomposed form, where "arrière" would not equal the composed literal.
        stem = unicodedata.normalize("NFC", os.path.splitext(file_name)[0])
        parts = stem.split("_")
        if len(parts) != 3:
            raise ValueError(
                f"Expected '<movement>_<direction>_<index>.csv', got {file_name!r}"
            )
        movement, direction, _ = parts
        if movement not in cls.MOVEMENT_CLASSES or direction not in cls.DIRECTION_CLASSES:
            raise ValueError(f"Unknown movement/direction label in {file_name!r}")
        return cls.MOVEMENT_CLASSES.index(movement), cls.DIRECTION_CLASSES.index(direction)

In [None]:
# Seed the global torch RNG so that, on a top-to-bottom run, the shuffle order
# below and the model's weight initialization in later cells are reproducible.
SEED = 42
torch.manual_seed(SEED)

data_dir = os.path.join(os.getcwd(), "data")  # data directory
dataset = EEGDataset(data_dir)
# Fail with a clear message instead of the sampler's opaque num_samples=0 error.
assert len(dataset) > 0, f"No data files found in {data_dir}"
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

### Define the EEGNet model (PyTorch)

In [None]:
class EEGNet(nn.Module):
    """Compact EEGNet-style CNN with two classification heads.

    Layers: a (1, 64) convolution, a (2, 1) convolution, a depthwise
    (1, 16) convolution (``groups=32``) followed by a 1x1 convolution, then a
    shared fully connected layer feeding separate movement and direction heads.

    Args:
        num_classes_movement: Output size of the movement head.
        num_classes_direction: Output size of the direction head.
        input_height: Size of dim -2 of the input. Together with the default
            ``input_width`` it reproduces the previously hard-coded ``fc1``
            input size of ``32 * 2 * 50``.
        input_width: Size of dim -1 of the input.

    Input: float tensor of shape (batch, 1, input_height, input_width).
    Output: ``(movement_logits, direction_logits)`` of shapes
    (batch, num_classes_movement) and (batch, num_classes_direction).

    Raises:
        ValueError: if the input size is too small to survive the convolutions.
    """

    def __init__(self, num_classes_movement=3, num_classes_direction=3,
                 input_height=3, input_width=48):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, (1, 64), padding=(0, 32))
        self.batchnorm1 = nn.BatchNorm2d(16)
        self.conv2 = nn.Conv2d(16, 32, (2, 1))
        self.conv3 = nn.Conv2d(32, 32, (1, 16), groups=32, padding=(0, 8))
        self.conv4 = nn.Conv2d(32, 32, (1, 1))
        self.batchnorm2 = nn.BatchNorm2d(32)
        # Derive fc1's input size from the conv stack instead of hard-coding it,
        # so the model can be built for any input size.
        self.fc1 = nn.Linear(self._flattened_size(input_height, input_width), 64)
        self.dropout = nn.Dropout(0.5)
        self.fc_movement = nn.Linear(64, num_classes_movement)
        self.fc_direction = nn.Linear(64, num_classes_direction)

    def _flattened_size(self, input_height, input_width):
        """Number of features conv4 produces for an (input_height, input_width) input."""
        height, width = input_height, input_width
        # Every conv here uses stride 1 and dilation 1, so each maps
        # size -> size + 2 * padding - kernel + 1 along each axis.
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            (k_h, k_w), (p_h, p_w) = conv.kernel_size, conv.padding
            height = height + 2 * p_h - k_h + 1
            width = width + 2 * p_w - k_w + 1
            if height < 1 or width < 1:
                raise ValueError(
                    f"Input size ({input_height}, {input_width}) is too small for EEGNet"
                )
        return self.conv4.out_channels * height * width

    def forward(self, x):
        x = F.elu(self.batchnorm1(self.conv1(x)))
        x = F.elu(self.conv2(x))
        x = F.elu(self.batchnorm2(self.conv3(x)))
        x = F.elu(self.conv4(x))
        # flatten (unlike view) also works on non-contiguous tensors.
        x = torch.flatten(x, 1)
        x = F.elu(self.fc1(x))
        x = self.dropout(x)
        movement = self.fc_movement(x)
        direction = self.fc_direction(x)
        return movement, direction

### Model Training

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = EEGNet().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

num_epochs = 30
for epoch in range(num_epochs):
    # Set training mode explicitly so dropout and batch-norm statistics are
    # active regardless of what earlier cells did to the model.
    model.train()
    total_loss = 0.0
    for eeg, (movement, direction) in dataloader:
        eeg, movement, direction = eeg.to(device), movement.to(device), direction.to(device)
        optimizer.zero_grad()
        movement_pred, direction_pred = model(eeg)
        loss_movement = criterion(movement_pred, movement)
        loss_direction = criterion(direction_pred, direction)
        # Both heads are trained jointly with equal weight.
        loss = loss_movement + loss_direction
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    # Report the mean per-batch loss (a raw sum grows with the dataset size and
    # is not comparable between runs); max() guards against an empty loader.
    mean_loss = total_loss / max(len(dataloader), 1)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {mean_loss:.4f}")

### Model Save

In [None]:
# Persist only the learned parameters (the state_dict), not the module object.
checkpoint_path = "eegnet_model.pth"
weights = model.state_dict()
torch.save(weights, checkpoint_path)
print("✅ Modèle Entraîné et Sauvegardé !")

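### Evaluation and error analysis

Note: the evaluation below reuses the same `dataloader` the model was trained on, so these metrics measure fit to the training data, not generalization to unseen recordings.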

In [None]:
def evaluate_model(model, dataloader, device=None):
    """Run ``model`` over ``dataloader`` and collect true/predicted labels for both heads.

    Args:
        model: Module whose forward returns ``(movement_logits, direction_logits)``.
        dataloader: Iterable of ``(eeg, (movement, direction))`` batches.
        device: Device to run inference on. Defaults to the device of the
            model's first parameter (CPU if the model has no parameters).

    Returns:
        Tuple of four lists — movement true, movement predicted, direction
        true, direction predicted — each holding one class index per sample.

    The model is put in eval mode for the pass and its previous
    train/eval mode is restored afterwards.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    was_training = model.training
    model.eval()
    all_movement_true, all_movement_pred = [], []
    all_direction_true, all_direction_pred = [], []

    try:
        with torch.no_grad():
            for eeg, (movement, direction) in dataloader:
                # Only the input needs the model's device; labels stay on the CPU.
                movement_logits, direction_logits = model(eeg.to(device))

                all_movement_true.extend(movement.cpu().numpy())
                all_movement_pred.extend(movement_logits.argmax(dim=1).cpu().numpy())
                all_direction_true.extend(direction.cpu().numpy())
                all_direction_pred.extend(direction_logits.argmax(dim=1).cpu().numpy())
    finally:
        model.train(was_training)

    return all_movement_true, all_movement_pred, all_direction_true, all_direction_pred

def plot_confusion_matrix(true_labels, pred_labels, classes, title):
    """Draw a confusion-matrix heatmap for one classification task.

    Args:
        true_labels: Ground-truth class indices.
        pred_labels: Predicted class indices.
        classes: Class names; ``classes[i]`` names class index ``i``.
        title: Task name appended to the plot title.

    Returns:
        The matplotlib Axes holding the heatmap.
    """
    # Pin the label set so the matrix is always len(classes) x len(classes) and
    # its rows/columns line up with the tick names, even when a class is absent
    # from both lists (otherwise sklearn drops it and the ticks are mislabeled).
    cm = confusion_matrix(true_labels, pred_labels, labels=list(range(len(classes))))
    fig, ax = plt.subplots()
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=classes, yticklabels=classes, ax=ax)
    ax.set(title=f"Matrice de Confusion - {title}", xlabel="Prédit", ylabel="Vrai")
    plt.show()
    return ax

def show_classification_report(true_labels, pred_labels, classes, title):
    """Print per-class precision / recall / F1 for one classification task.

    Args:
        true_labels: Ground-truth class indices.
        pred_labels: Predicted class indices.
        classes: Class names; ``classes[i]`` names class index ``i``.
        title: Task name shown in the heading.
    """
    print(f"\n📊 Rapport de Classification - {title}")
    # Pass every class index explicitly: otherwise sklearn infers the labels from
    # the data and raises when a class never occurs (len(target_names) mismatch).
    # zero_division=0 reports 0 for never-predicted classes without a warning.
    print(classification_report(
        true_labels,
        pred_labels,
        labels=list(range(len(classes))),
        target_names=classes,
        zero_division=0,
    ))

def analyze_errors(true_labels, pred_labels, classes, title, top_k=None):
    """Print the misclassifications (true -> predicted) ordered from most to least frequent.

    Args:
        true_labels: Ground-truth class indices.
        pred_labels: Predicted class indices, aligned with ``true_labels``.
        classes: Class names; ``classes[i]`` names class index ``i``.
        title: Task name shown in the heading.
        top_k: If given, print only the ``top_k`` most frequent confusions.
    """
    confusions = Counter(
        (int(t), int(p)) for t, p in zip(true_labels, pred_labels) if t != p
    )
    print(f"\n🔍 Pires Erreurs - {title}")
    if not confusions:
        print("(aucune erreur)")
        return
    # most_common() sorts by count descending, so the worst confusions come first.
    for (true, pred), count in confusions.most_common(top_k):
        print(f"→ Vrai: {classes[true]} | Prédit: {classes[pred]} | Fois: {count}")


In [None]:
movement_classes = ["neutre", "avant", "arrière"]
direction_classes = ["neutre", "gauche", "droite"]

# NOTE: `dataloader` is the loader the model was trained on, so these metrics
# measure fit to the training data. TODO: evaluate on a held-out split.
all_movement_true, all_movement_pred, all_direction_true, all_direction_pred = evaluate_model(model, dataloader)

# One (title, true, predicted, class names) entry per prediction head; the same
# three analyses run for each instead of being copy-pasted per head.
tasks = [
    ("Movement", all_movement_true, all_movement_pred, movement_classes),
    ("Direction", all_direction_true, all_direction_pred, direction_classes),
]
for title, true_labels, pred_labels, classes in tasks:
    plot_confusion_matrix(true_labels, pred_labels, classes, title)
    show_classification_report(true_labels, pred_labels, classes, title)
    analyze_errors(true_labels, pred_labels, classes, title)