In [None]:
! pip install -r requirements.txt

Collecting torchvision (from -r requirements.txt (line 5))
  Downloading torchvision-0.20.1-cp312-cp312-win_amd64.whl.metadata (6.2 kB)
Downloading torchvision-0.20.1-cp312-cp312-win_amd64.whl (1.6 MB)
   ---------------------------------------- 0.0/1.6 MB ? eta -:--:--
   ---------------------------------------- 0.0/1.6 MB ? eta -:--:--
   ---------------------------------------- 0.0/1.6 MB ? eta -:--:--
   - -------------------------------------- 0.0/1.6 MB 326.8 kB/s eta 0:00:05
   ----------- ---------------------------- 0.4/1.6 MB 3.1 MB/s eta 0:00:01
   -------------------------------- ------- 1.3/1.6 MB 6.7 MB/s eta 0:00:01
   ---------------------------------------- 1.6/1.6 MB 6.6 MB/s eta 0:00:00
Installing collected packages: torchvision
Successfully installed torchvision-0.20.1



[notice] A new release of pip is available: 24.0 -> 24.3.1
[notice] To update, run: python.exe -m pip install --upgrade pip


In [None]:
# Manipulation de données
import pandas as pd
import numpy as np

# Traitement du signal
from scipy import signal
import mne

# Machine Learning et Deep Learning
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.svm import SVC
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report
import torch


# Visualisation
import matplotlib.pyplot as plt
#import seaborn as sns
#import plotly.express as px

# Gestion de Notebooks
#import papermill as pm
import ipywidgets as widgets

# Utilitaires
import joblib
import yaml
import pickle
import os
import sys
import importlib

# Local project code.
# Make the modules under ./preprocessing (relative to the notebook's working
# directory) importable. Guarded so re-running this cell does not keep
# appending the same entry to sys.path.
if 'preprocessing' not in sys.path:
    sys.path.append('preprocessing')
import preprocess
import dataset
import torchcam
import models.GGN.ggn_model as GGN
import models.GGN.train as train

# Reload the local modules so edits to their source are picked up without
# restarting the kernel (same order as before: preprocess, GGN, train, dataset).
for _local_module in (preprocess, GGN, train, dataset):
    importlib.reload(_local_module)


## Chargement de la configuration du projet

In [None]:
# Load the YAML project configuration. Read explicitly as UTF-8: without an
# encoding, open() uses the platform locale encoding (cp1252 on Windows, where
# this notebook runs), which mis-decodes non-ASCII text such as French accents.
with open("config.yml", "r", encoding="utf-8") as file:
    config = yaml.safe_load(file)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed the numpy and torch RNGs (torch.manual_seed also seeds CUDA) so model
# initialisation and training runs are reproducible. `project_config.seed` is an
# optional override; 42 is used when it is absent.
SEED = config.get("project_config", {}).get("seed", 42)
np.random.seed(SEED)
torch.manual_seed(SEED)

## Preprocessing

In [None]:
# Run the project preprocessing pipeline with the loaded configuration.
# Judging from its log output, it reads the raw BrainVision (.vhdr) recordings of
# each task, applies an average EEG reference, band-stop filters the signal and
# cuts it into epochs.
# NOTE(review): the return value is not captured; presumably the results are
# written to disk for dataset.create_dataloader to load later — TODO confirm.
preprocess.preprocess_data(config)

Extracting parameters from data/raw/2023_eegpainmarkers_laval/sub-003\eeg\sub-003_task-audioactive_eeg.vhdr...
Setting channel info structure...
Reading 0 ... 490000  =      0.000 ...   490.000 secs...
Extracting parameters from data/raw/2023_eegpainmarkers_laval/sub-003\eeg\sub-003_task-audiopassive_eeg.vhdr...
Setting channel info structure...
Reading 0 ... 490000  =      0.000 ...   490.000 secs...
Extracting parameters from data/raw/2023_eegpainmarkers_laval/sub-003\eeg\sub-003_task-thermalactive_eeg.vhdr...
Setting channel info structure...
Reading 0 ... 500000  =      0.000 ...   500.000 secs...
Extracting parameters from data/raw/2023_eegpainmarkers_laval/sub-003\eeg\sub-003_task-thermalpassive_eeg.vhdr...
Setting channel info structure...
Reading 0 ... 500000  =      0.000 ...   500.000 secs...
EEG channel type selected for re-referencing
Applying average reference.
Applying a custom ('EEG',) reference.
Filtering raw data in 4 contiguous segments
Setting up band-stop filter

FI

[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.1s


Setting up band-stop filter

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandstop filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower transition bandwidth: 0.50 Hz
- Upper transition bandwidth: 0.50 Hz
- Filter length: 6601 samples (6.601 s)



[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.1s


Setting up band-stop filter

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandstop filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower transition bandwidth: 0.50 Hz
- Upper transition bandwidth: 0.50 Hz
- Filter length: 6601 samples (6.601 s)



[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.1s


Setting up band-stop filter

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandstop filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower transition bandwidth: 0.50 Hz
- Upper transition bandwidth: 0.50 Hz
- Filter length: 6601 samples (6.601 s)



[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.1s


Used Annotations descriptions: [np.str_('audio'), np.str_('pain')]
Not setting metadata
494 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 494 events and 4001 original time points ...
0 bad epochs dropped


## Model Training

In [None]:
# Train and evaluate the selected model independently for each subject, then
# report the test metrics averaged across subjects.
running_model = config['project_config']['running_model']

subjects_id = config['data']['subjects']

bids_root = config['data']['path']

# Optimisation hyper-parameters of the GGN branch (previously inline literals).
GGN_LEARNING_RATE = 0.001
GGN_NUM_EPOCHS = 50

SUPPORTED_MODELS = ("GGN", "SVM")
if running_model not in SUPPORTED_MODELS:
    # Fail fast: an unknown name previously trained nothing and only printed NaN means.
    raise ValueError(
        f"Unsupported running_model {running_model!r}; expected one of {SUPPORTED_MODELS}"
    )

# Per-subject test metrics. The GGN branch fills every list; the SVM branch only
# records accuracy.
test_losses = []
accuracies = []
recalls = []
precisions = []
f1_scores = []
auc_rocs = []

if not subjects_id:
    # No explicit subject list: use every "sub-*" directory under the BIDS root.
    # Sorted because os.listdir order is arbitrary and platform-dependent.
    subjects_id = sorted(
        d for d in os.listdir(bids_root)
        if os.path.isdir(os.path.join(bids_root, d)) and d.startswith("sub-")
    )
    print(f"Aucun ID de sujet spécifié. Tous les sujets détectés : {subjects_id}")

for subject in subjects_id:
    # Each subject gets its own train/val/test split and its own model.
    train_loader, val_loader, test_loader = dataset.create_dataloader([subject], config)

    if running_model == "GGN":
        model = GGN.GGN(**config['models']['GGN']['parameters'], device=device)

        criterion = torch.nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(model.parameters(), lr=GGN_LEARNING_RATE)
        num_epochs = GGN_NUM_EPOCHS

        # Train and validate the model
        train.train(model, train_loader, val_loader, criterion, optimizer, num_epochs, device)

        # Test the model
        avg_test_loss, accuracy, recall, precision, f1, auc_roc = train.test(model, test_loader, criterion, device)
        test_losses.append(avg_test_loss)
        accuracies.append(accuracy)
        recalls.append(recall)
        precisions.append(precision)
        f1_scores.append(f1)
        auc_rocs.append(auc_roc)

        model.explain_temporal_cnn(test_loader, device)

    elif running_model == "SVM":
        # Convert data loaders to numpy arrays for scikit-learn
        X_train, y_train = preprocess.dataloader_without_topology_to_numpy(train_loader)
        X_val, y_val = preprocess.dataloader_without_topology_to_numpy(val_loader)
        X_test, y_test = preprocess.dataloader_without_topology_to_numpy(test_loader)

        print(f"X_train: {X_train.shape}, y_train: {y_train.shape}")
        print(f"X_val: {X_val.shape}, y_val: {y_val.shape}")
        print(f"X_test: {X_test.shape}, y_test: {y_test.shape}")

        # SVC needs 2-D input: flatten all non-sample dimensions into one feature vector.
        X_train = X_train.reshape(X_train.shape[0], -1)
        X_val = X_val.reshape(X_val.shape[0], -1)
        X_test = X_test.reshape(X_test.shape[0], -1)

        # Train the SVM
        svm_params = config['models']['SVM']['parameters']
        svm = SVC(**svm_params)
        svm.fit(X_train, y_train)

        # Evaluate on validation set (previously computed but never reported)
        y_val_pred = svm.predict(X_val)
        val_acc = accuracy_score(y_val, y_val_pred)
        print(f"{subject} - validation accuracy: {val_acc:.4f}")

        # Test the model
        y_test_pred = svm.predict(X_test)
        test_acc = accuracy_score(y_test, y_test_pred)
        # NOTE(review): accuracy_score returns a fraction in [0, 1], while the summary
        # below prints accuracy with a "%" suffix — confirm which scale train.test
        # uses for GGN and align the two.
        accuracies.append(test_acc)


def _mean_or_nan(values):
    """Return the mean of `values`, or NaN when no value was collected.

    Avoids numpy's "Mean of empty slice" RuntimeWarning for metrics that the
    selected model does not produce (e.g. SVM only records accuracy).
    """
    return np.mean(values) if values else np.nan


# Calculate mean test metrics across all subjects
mean_test_loss = _mean_or_nan(test_losses)
mean_accuracy = _mean_or_nan(accuracies)
mean_recall = _mean_or_nan(recalls)
mean_precision = _mean_or_nan(precisions)
mean_f1 = _mean_or_nan(f1_scores)
mean_auc_roc = _mean_or_nan(auc_rocs)

# Only report metrics that were actually collected, instead of printing "nan".
if test_losses:
    print(f"Mean Test Loss across all subjects: {mean_test_loss:.4f}")
if accuracies:
    print(f"Mean Test Accuracy across all subjects: {mean_accuracy:.2f}%")
if recalls:
    print(f"Mean Recall (Sensitivity) across all subjects: {mean_recall:.2f}")
if precisions:
    print(f"Mean Precision across all subjects: {mean_precision:.2f}")
if f1_scores:
    print(f"Mean F1 Score across all subjects: {mean_f1:.2f}")
if auc_rocs:
    print(f"Mean AUC-ROC across all subjects: {mean_auc_roc:.2f}")
