In [1]:
! pip install -r requirements.txt

Defaulting to user installation because normal site-packages is not writeable
Looking in links: /cvmfs/soft.computecanada.ca/custom/python/wheelhouse/gentoo2020/avx2, /cvmfs/soft.computecanada.ca/custom/python/wheelhouse/gentoo/avx2, /cvmfs/soft.computecanada.ca/custom/python/wheelhouse/gentoo2020/generic, /cvmfs/soft.computecanada.ca/custom/python/wheelhouse/gentoo/generic, /cvmfs/soft.computecanada.ca/custom/python/wheelhouse/generic


In [1]:
# Manipulation de données
import pandas as pd
import numpy as np

# Traitement du signal
from scipy import signal
import mne

# Machine Learning et Deep Learning
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.svm import SVC
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report
import torch


# Visualisation
import matplotlib.pyplot as plt
#import seaborn as sns
#import plotly.express as px

# Gestion de Notebooks
#import papermill as pm
import ipywidgets as widgets

# Utilitaires
import joblib
import yaml
import pickle
import os
import sys
import importlib

# Import local project code. The `preprocessing` directory provides dataset.py
# (see the module path in this cell's output) and presumably preprocess.py.
# Only add it once so re-running this cell does not keep growing sys.path.
if 'preprocessing' not in sys.path:
    sys.path.append('preprocessing')
import preprocess
import dataset
import torchcam
import models.GGN.ggn_model as GGN
import models.GGN.train as train
# Reload so edits to the local modules are picked up without restarting the
# kernel; the trailing ';' suppresses the module repr that would otherwise be
# displayed as this cell's output.
importlib.reload(preprocess)
importlib.reload(GGN)
importlib.reload(train)
importlib.reload(dataset);


<module 'dataset' from '/project/166600089/marc-debug-grad-cam-cnn/preprocessing/dataset.py'>

## Chargement de la configuration du projet

In [2]:
# Load the project-wide YAML configuration, then select the compute device:
# CUDA when PyTorch reports it as available, CPU otherwise.
with open("config.yml", "r") as config_stream:
    config = yaml.safe_load(config_stream)

use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

## Preprocessing

In [4]:
# Run the project's preprocessing pipeline, driven entirely by `config`.
# Judging from the captured output below, it re-references the EEG (average
# reference), applies an FIR band-stop filter, and saves each epoch as a
# `.fif` file under data/processed/<subject>/<task>/.
# NOTE(review): this recomputes everything on every run; consider guarding it
# behind a check for existing outputs so Restart & Run All stays cheap.
preprocess.preprocess_data(config)

EEG channel type selected for re-referencing
Applying average reference.
Applying a custom ('EEG',) reference.
Filtering raw data in 1 contiguous segment
Setting up band-stop filter

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandstop filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower transition bandwidth: 0.50 Hz
- Upper transition bandwidth: 0.50 Hz
- Filter length: 6601 samples (6.601 s)



[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.3s


Saved epoch 1 for sub-003 task audioactive to data/processed/sub-003/audioactive/1-epo.fif
Saved epoch 2 for sub-003 task audioactive to data/processed/sub-003/audioactive/2-epo.fif
Saved epoch 3 for sub-003 task audioactive to data/processed/sub-003/audioactive/3-epo.fif
Saved epoch 4 for sub-003 task audioactive to data/processed/sub-003/audioactive/4-epo.fif
Saved epoch 5 for sub-003 task audioactive to data/processed/sub-003/audioactive/5-epo.fif
Saved epoch 6 for sub-003 task audioactive to data/processed/sub-003/audioactive/6-epo.fif
Saved epoch 7 for sub-003 task audioactive to data/processed/sub-003/audioactive/7-epo.fif
Saved epoch 8 for sub-003 task audioactive to data/processed/sub-003/audioactive/8-epo.fif
Saved epoch 9 for sub-003 task audioactive to data/processed/sub-003/audioactive/9-epo.fif
Saved epoch 10 for sub-003 task audioactive to data/processed/sub-003/audioactive/10-epo.fif
Saved epoch 11 for sub-003 task audioactive to data/processed/sub-003/audioactive/11-epo

[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.3s


Saved epoch 1 for sub-003 task audiopassive to data/processed/sub-003/audiopassive/1-epo.fif
Saved epoch 2 for sub-003 task audiopassive to data/processed/sub-003/audiopassive/2-epo.fif
Saved epoch 3 for sub-003 task audiopassive to data/processed/sub-003/audiopassive/3-epo.fif
Saved epoch 4 for sub-003 task audiopassive to data/processed/sub-003/audiopassive/4-epo.fif
Saved epoch 5 for sub-003 task audiopassive to data/processed/sub-003/audiopassive/5-epo.fif
Saved epoch 6 for sub-003 task audiopassive to data/processed/sub-003/audiopassive/6-epo.fif
Saved epoch 7 for sub-003 task audiopassive to data/processed/sub-003/audiopassive/7-epo.fif
Saved epoch 8 for sub-003 task audiopassive to data/processed/sub-003/audiopassive/8-epo.fif
Saved epoch 9 for sub-003 task audiopassive to data/processed/sub-003/audiopassive/9-epo.fif
Saved epoch 10 for sub-003 task audiopassive to data/processed/sub-003/audiopassive/10-epo.fif
Saved epoch 11 for sub-003 task audiopassive to data/processed/sub-0

[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.3s


Saved epoch 1 for sub-003 task thermalactive to data/processed/sub-003/thermalactive/1-epo.fif
Saved epoch 2 for sub-003 task thermalactive to data/processed/sub-003/thermalactive/2-epo.fif
Saved epoch 3 for sub-003 task thermalactive to data/processed/sub-003/thermalactive/3-epo.fif
Saved epoch 4 for sub-003 task thermalactive to data/processed/sub-003/thermalactive/4-epo.fif
Saved epoch 5 for sub-003 task thermalactive to data/processed/sub-003/thermalactive/5-epo.fif
Saved epoch 6 for sub-003 task thermalactive to data/processed/sub-003/thermalactive/6-epo.fif
Saved epoch 7 for sub-003 task thermalactive to data/processed/sub-003/thermalactive/7-epo.fif
Saved epoch 8 for sub-003 task thermalactive to data/processed/sub-003/thermalactive/8-epo.fif
Saved epoch 9 for sub-003 task thermalactive to data/processed/sub-003/thermalactive/9-epo.fif
Saved epoch 10 for sub-003 task thermalactive to data/processed/sub-003/thermalactive/10-epo.fif
Saved epoch 11 for sub-003 task thermalactive to

[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.2s


Saved epoch 1 for sub-003 task thermalpassive to data/processed/sub-003/thermalpassive/1-epo.fif
Saved epoch 2 for sub-003 task thermalpassive to data/processed/sub-003/thermalpassive/2-epo.fif
Saved epoch 3 for sub-003 task thermalpassive to data/processed/sub-003/thermalpassive/3-epo.fif
Saved epoch 4 for sub-003 task thermalpassive to data/processed/sub-003/thermalpassive/4-epo.fif
Saved epoch 5 for sub-003 task thermalpassive to data/processed/sub-003/thermalpassive/5-epo.fif
Saved epoch 6 for sub-003 task thermalpassive to data/processed/sub-003/thermalpassive/6-epo.fif
Saved epoch 7 for sub-003 task thermalpassive to data/processed/sub-003/thermalpassive/7-epo.fif
Saved epoch 8 for sub-003 task thermalpassive to data/processed/sub-003/thermalpassive/8-epo.fif
Saved epoch 9 for sub-003 task thermalpassive to data/processed/sub-003/thermalpassive/9-epo.fif
Saved epoch 10 for sub-003 task thermalpassive to data/processed/sub-003/thermalpassive/10-epo.fif
Saved epoch 11 for sub-003 t

## Model Training

In [26]:
# Train and evaluate `running_model` separately for every subject, then report
# the mean of each collected metric across subjects.
#
# The project modules used here (preprocess, dataset, GGN, train) are imported
# and reloaded in the imports cell at the top of the notebook; re-run that cell
# (or enable `%autoreload 2`) after editing them instead of re-importing here.

# Tunable constants (previously inline magic numbers).
SEED = 42                  # re-applied per subject so each run is repeatable
GGN_LEARNING_RATE = 1e-4   # Adam learning rate for GGN
GGN_NUM_EPOCHS = 10        # number of GGN training epochs
SUPPORTED_MODELS = ("GGN", "SVM")

running_model = config['project_config']['running_model']
subjects_id = config['data']['subjects']
bids_root = config['data']['path']

# Fail fast instead of silently skipping every subject and printing NaN means.
if running_model not in SUPPORTED_MODELS:
    raise ValueError(
        f"Unsupported running_model {running_model!r}; expected one of {SUPPORTED_MODELS}"
    )

# Per-subject metrics; element i belongs to the i-th processed subject.
test_losses = []
accuracies = []
recalls = []
precisions = []
f1_scores = []
auc_rocs = []

if not subjects_id:
    # No subjects configured: use every `sub-*` directory under the BIDS root.
    # os.listdir order is arbitrary, so sort for a reproducible run order.
    subjects_id = sorted(
        d for d in os.listdir(bids_root)
        if os.path.isdir(os.path.join(bids_root, d)) and d.startswith("sub-")
    )
    print(f"Aucun ID de sujet spécifié. Tous les sujets détectés : {subjects_id}")

for subject in subjects_id:
    # Seed before building loaders/model so any shuffling or weight
    # initialisation drawing from the torch/NumPy global RNGs is repeatable.
    torch.manual_seed(SEED)
    np.random.seed(SEED)

    # Each model is trained on a single subject's train/val/test split.
    train_loader, val_loader, test_loader = dataset.create_dataloader([subject], config)

    if running_model == "GGN":
        model = GGN.GGN(**config['models']['GGN']['parameters'], device=device)

        criterion = torch.nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(model.parameters(), lr=GGN_LEARNING_RATE)
        num_epochs = GGN_NUM_EPOCHS

        # Train and validate the model
        train.train(model, train_loader, val_loader, criterion, optimizer, num_epochs, device)

        # Test the model.
        # NOTE(review): the summary prints accuracy with a "%" suffix, which
        # assumes train.test returns accuracy in percent; confirm in
        # models/GGN/train.py.
        avg_test_loss, accuracy, recall, precision, f1, auc_roc = train.test(model, test_loader, criterion, device)
        test_losses.append(avg_test_loss)
        accuracies.append(accuracy)
        recalls.append(recall)
        precisions.append(precision)
        f1_scores.append(f1)
        auc_rocs.append(auc_roc)

        model.explain_temporal_cnn(test_loader, device)

    elif running_model == "SVM":
        # Convert data loaders to numpy arrays
        X_train, y_train = preprocess.dataloader_without_topology_to_numpy(train_loader)
        X_val, y_val = preprocess.dataloader_without_topology_to_numpy(val_loader)
        X_test, y_test = preprocess.dataloader_without_topology_to_numpy(test_loader)

        print(f"X_train: {X_train.shape}, y_train: {y_train.shape}")
        print(f"X_val: {X_val.shape}, y_val: {y_val.shape}")
        print(f"X_test: {X_test.shape}, y_test: {y_test.shape}")

        # Flatten every sample into a single feature vector for the SVM.
        X_train = X_train.reshape(X_train.shape[0], -1)
        X_val = X_val.reshape(X_val.shape[0], -1)
        X_test = X_test.reshape(X_test.shape[0], -1)

        # Train the SVM
        svm_params = config['models']['SVM']['parameters']
        svm = SVC(**svm_params)
        svm.fit(X_train, y_train)

        # Evaluate on the validation and test sets (previously val_acc was
        # computed but never reported).
        y_val_pred = svm.predict(X_val)
        val_acc = accuracy_score(y_val, y_val_pred)
        y_test_pred = svm.predict(X_test)
        test_acc = accuracy_score(y_test, y_test_pred)
        print(f"{subject} - SVM validation accuracy: {val_acc:.2%} | test accuracy: {test_acc:.2%}")

        # accuracy_score returns a fraction in [0, 1]; store it as a percentage
        # so it matches the "%" unit used in the summary below.
        accuracies.append(test_acc * 100)


def _mean_or_nan(values):
    """Return np.mean(values), or NaN (without numpy's empty-slice warning) when empty."""
    return np.mean(values) if values else np.nan


# Mean of every metric across all subjects (NaN for metrics never collected,
# e.g. loss/recall/precision/F1/AUC when running the SVM).
mean_test_loss = _mean_or_nan(test_losses)
mean_accuracy = _mean_or_nan(accuracies)
mean_recall = _mean_or_nan(recalls)
mean_precision = _mean_or_nan(precisions)
mean_f1 = _mean_or_nan(f1_scores)
mean_auc_roc = _mean_or_nan(auc_rocs)

summary_rows = [
    ("Mean Test Loss across all subjects", test_losses, mean_test_loss, "{:.4f}"),
    ("Mean Test Accuracy across all subjects", accuracies, mean_accuracy, "{:.2f}%"),
    ("Mean Recall (Sensitivity) across all subjects", recalls, mean_recall, "{:.2f}"),
    ("Mean Precision across all subjects", precisions, mean_precision, "{:.2f}"),
    ("Mean F1 Score across all subjects", f1_scores, mean_f1, "{:.2f}"),
    ("Mean AUC-ROC across all subjects", auc_rocs, mean_auc_roc, "{:.2f}"),
]
for label, collected, mean_value, template in summary_rows:
    if collected:
        print(f"{label}: {template.format(mean_value)}")
    else:
        print(f"{label}: n/a (not collected for {running_model})")


                                                           

Epoch [1/10], Loss: 0.6936
Validation Loss: 0.6932


                                                           

KeyboardInterrupt: 