In [2]:
import os
from pathlib import Path
import random
from typing import Tuple, List, Dict, Any

import numpy as np
import pandas as pd
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report
from sklearn.model_selection import train_test_split
from sklearn.model_selection import ParameterGrid

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

In [3]:
# One seed shared by every random number generator the notebook relies on.
SEED = 42
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(SEED)

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", device)

Device: cpu


In [4]:
# Location of the balanced sequence index produced by the preprocessing.
PREPARED_ROOT = Path("prepared/")          # root output folder of the preprocessing
SPLIT = "all"                           # everything was placed in a single logical split
CSV_NAME = "train_sequences_index_balanced.csv" # name of the balanced CSV

csv_path = PREPARED_ROOT / SPLIT / CSV_NAME
print("CSV utilisé :", csv_path)

df = pd.read_csv(csv_path)

# Fail early with a readable message if the index lacks the columns read
# later in the notebook: "path" (file to load) and "label" (class target).
missing_cols = {"path", "label"} - set(df.columns)
if missing_cols:
    raise ValueError(f"Colonnes manquantes dans le CSV : {sorted(missing_cols)}")

print(df.head())
print("Taille CSV :", len(df))
print(df["label"].value_counts())

prepared\all\train_sequences_index_balanced.csv
CSV utilisé : prepared\all\train_sequences_index_balanced.csv
                                           path  label   patient    session  \
0  all/sequences/aaaaarjc_s007_t002_seq0083.npy      0  aaaaarjc  s007_2015   
1  all/sequences/aaaaaqnr_s001_t000_seq0080.npy      0  aaaaaqnr  s001_2014   
2  all/sequences/aaaaaqtw_s002_t011_seq0037.npy      1  aaaaaqtw  s002_2014   
3  all/sequences/aaaaanme_s010_t010_seq0105.npy      0  aaaaanme  s010_2014   
4  all/sequences/aaaaaqvx_s002_t001_seq0199.npy      0  aaaaaqvx  s002_2015   

            recording  last_win_center_s  
0  aaaaarjc_s007_t002              186.0  
1  aaaaaqnr_s001_t000              180.0  
2  aaaaaqtw_s002_t011              156.0  
3  aaaaanme_s010_t010              572.0  
4  aaaaaqvx_s002_t001             2658.0  
Taille CSV : 9546
label
0    4773
1    4773
Name: count, dtype: int64


In [5]:
# Name of the column that identifies the patient in the CSV.
PATIENT_COL = "patient"

# Stop right away if the CSV does not expose that column.
if not (PATIENT_COL in df.columns):
    raise ValueError(f"La colonne '{PATIENT_COL}' est absente du CSV. Vérifie le nom des colonnes.")

# Distinct patient identifiers found in the index.
patients = df.loc[:, PATIENT_COL].unique()
print("Nombre de patients uniques :", len(patients))

Nombre de patients uniques : 91


In [7]:
# Shuffle the patient identifiers reproducibly.
# The identifiers are re-derived from the DataFrame and shuffled with
# `permutation` (which returns a shuffled copy) instead of shuffling
# `patients` in place: re-running this cell therefore always produces the
# same order. An in-place shuffle would reshuffle the already-shuffled array
# on every re-run and silently change the train/val/test split.
rng = np.random.default_rng(SEED)
patients = rng.permutation(df[PATIENT_COL].unique())

# Split sizes are counted in patients: 60 % for train, 20 % for val
# (both rounded down).
n_patients = len(patients)
n_train = int(0.6 * n_patients)
n_val   = int(0.2 * n_patients)

In [8]:
# Cut the patient array into three consecutive, non-overlapping slices:
# the first n_train go to train, the next n_val to val, the rest to test.
train_patients, val_patients, test_patients = (
    patients[slice(None, n_train)],
    patients[slice(n_train, n_train + n_val)],
    patients[slice(n_train + n_val, None)],
)

def filter_by_patients(df, allowed_patients):
    """Return the rows of ``df`` whose patient belongs to ``allowed_patients``.

    The patient is read from the column named by the module-level
    ``PATIENT_COL``. The returned frame has a fresh 0..n-1 index.
    """
    keep_mask = df[PATIENT_COL].isin(allowed_patients)
    return df.loc[keep_mask].reset_index(drop=True)

# Materialise one DataFrame per split from the patient lists.
df_train = filter_by_patients(df, train_patients)
df_val   = filter_by_patients(df, val_patients)
df_test  = filter_by_patients(df, test_patients)

# Guard against data leakage: a patient must belong to exactly one split,
# and every row of the index must land in one of the three frames.
assert not set(train_patients) & set(val_patients), "Patients partagés entre train et val"
assert not set(train_patients) & set(test_patients), "Patients partagés entre train et test"
assert not set(val_patients) & set(test_patients), "Patients partagés entre val et test"
assert len(df_train) + len(df_val) + len(df_test) == len(df), "Des lignes du CSV ne sont dans aucun split"

print("Taille train :", len(df_train))
print("Taille val   :", len(df_val))
print("Taille test  :", len(df_test))

# Class balance inside each split (the split is made per patient, so the
# per-split label counts are not forced to be equal).
print("\nDistribution des labels (train)")
print(df_train["label"].value_counts())
print("\nDistribution des labels (val)")
print(df_val["label"].value_counts())
print("\nDistribution des labels (test)")
print(df_test["label"].value_counts())

Taille train : 6117
Taille val   : 1797
Taille test  : 1632

Distribution des labels (train)
label
0    3244
1    2873
Name: count, dtype: int64

Distribution des labels (val)
label
1    957
0    840
Name: count, dtype: int64

Distribution des labels (test)
label
1    943
0    689
Name: count, dtype: int64


In [9]:
# Number of channels every sample is forced to have (default for the dataset below).
FIXED_N_CHANNELS = 27

class SeizureSequenceDataset(Dataset):
    """
    Dataset of spectrogram sequences stored as .npy files.

    Each file is expected to hold an array shaped [seq_len, C_i, F, T], where
    the channel count C_i may vary from file to file. Every sample is returned
    as [fixed_n_channels, seq_len, F, T]:
      - if C_i > fixed_n_channels -> only the first channels are kept
      - if C_i < fixed_n_channels -> missing channels are zero-padded

    Args:
        df: index with a "path" column (file path relative to ``prepared_root``)
            and a "label" column (integer class).
        prepared_root: folder the relative paths are resolved against
            (a ``Path`` or a plain string).
        fixed_n_channels: channel count imposed on every sample.
    """
    def __init__(self, df: pd.DataFrame, prepared_root: Path, fixed_n_channels: int = FIXED_N_CHANNELS):
        self.df = df.reset_index(drop=True)
        # Path() is a no-op on a Path and lets callers pass a plain string too.
        self.prepared_root = Path(prepared_root)
        self.fixed_n_channels = fixed_n_channels

    def __len__(self) -> int:
        """Number of sequences listed in the index."""
        return len(self.df)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Load sample ``idx``.

        Returns:
            x: float32 tensor shaped [fixed_n_channels, seq_len, F, T].
            y: 0-dim int64 tensor holding the label.

        Raises:
            ValueError: if the stored array is not 4-dimensional.
        """
        row = self.df.iloc[idx]
        rel_path = row["path"]
        label = int(row["label"])

        npy_path = self.prepared_root / rel_path
        arr = np.load(npy_path).astype(np.float32)   # [seq_len, C_i, F, T]

        # Name the offending file instead of letting transpose fail with a
        # generic "axes don't match array" message.
        if arr.ndim != 4:
            raise ValueError(
                f"Tableau 4D attendu [seq_len, C, F, T], forme reçue {arr.shape} ({npy_path})"
            )

        # Reorder to [C_i, seq_len, F, T]
        arr = np.transpose(arr, (1, 0, 2, 3))
        n_channels = arr.shape[0]

        # Bring the channel axis to exactly fixed_n_channels
        if n_channels > self.fixed_n_channels:
            # Keep the first channels only
            arr = arr[:self.fixed_n_channels]
        elif n_channels < self.fixed_n_channels:
            # Zero-pad at the end of the channel axis
            pad_c = self.fixed_n_channels - n_channels
            pad_width = ((0, pad_c), (0, 0), (0, 0), (0, 0))
            arr = np.pad(arr, pad_width, mode="constant", constant_values=0.0)

        x = torch.from_numpy(arr)                  # [C_fix, seq_len, F, T]
        y = torch.tensor(label, dtype=torch.long)
        return x, y


# -------------------------------------------------------------------
# Build the datasets / dataloaders
# -------------------------------------------------------------------
train_dataset = SeizureSequenceDataset(df_train, PREPARED_ROOT, fixed_n_channels=FIXED_N_CHANNELS)
val_dataset   = SeizureSequenceDataset(df_val,   PREPARED_ROOT, fixed_n_channels=FIXED_N_CHANNELS)
test_dataset  = SeizureSequenceDataset(df_test,  PREPARED_ROOT, fixed_n_channels=FIXED_N_CHANNELS)

BATCH_SIZE = 8

# Pinned host memory only speeds up host-to-GPU copies. Asking for it on a
# CPU-only machine brings nothing and makes PyTorch emit a warning, so it is
# enabled only when CUDA is actually available.
PIN_MEMORY = torch.cuda.is_available()

# Only the training loader shuffles; val/test keep the index order.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                          num_workers=0, pin_memory=PIN_MEMORY)
val_loader   = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False,
                          num_workers=0, pin_memory=PIN_MEMORY)
test_loader  = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False,
                          num_workers=0, pin_memory=PIN_MEMORY)

print("Batches train :", len(train_loader))
print("Batches val   :", len(val_loader))
print("Batches test  :", len(test_loader))

# Sanity check on one batch
batch_x, batch_y = next(iter(train_loader))
print("Shape batch_x :", batch_x.shape)  # (B, C_fix, N, F, T)
print("Shape batch_y :", batch_y.shape)

Batches train : 765
Batches val   : 225
Batches test  : 204
Shape batch_x : torch.Size([8, 27, 10, 40, 7])
Shape batch_y : torch.Size([8])




In [10]:
class Seizure3DCNN(nn.Module):
    """Three-block 3D CNN classifier for inputs shaped [B, C, N, F, T].

    The input channel axis (C) is treated as the Conv3d channel dimension and
    (N, F, T) as the three spatial dimensions. Each block is
    Conv3d(3x3x3, padding 1) -> BatchNorm3d -> ReLU; blocks 1 and 2 are
    followed by max pooling. The feature map is then averaged over (N, F, T),
    passed through dropout and a linear layer producing the class logits.

    Args:
        in_channels: number of input channels (C).
        base_filters: output channels of block 1; blocks 2 and 3 use 2x and 4x.
        dropout: dropout probability applied before the final linear layer.
        num_classes: number of output logits. Defaults to 2
            (preictal / interictal), which is the original behaviour.

    Layer attribute names and creation order are kept stable so that existing
    checkpoints still load and seeded initialisation is unchanged.
    """
    def __init__(self, in_channels: int, base_filters: int = 16, dropout: float = 0.5,
                 num_classes: int = 2):
        super().__init__()

        # Block 1: Conv3D + BN + ReLU + MaxPool3D (pooling on F and T only)
        self.conv1 = nn.Conv3d(in_channels, base_filters,
                               kernel_size=(3, 3, 3),
                               padding=(1, 1, 1))
        self.bn1 = nn.BatchNorm3d(base_filters)
        self.pool1 = nn.MaxPool3d(kernel_size=(1, 2, 2),
                                  stride=(1, 2, 2))   # N unchanged, F and T halved (rounded down)

        # Block 2: Conv3D + BN + ReLU + MaxPool3D (pooling on N, F and T)
        self.conv2 = nn.Conv3d(base_filters, base_filters * 2,
                               kernel_size=(3, 3, 3),
                               padding=(1, 1, 1))
        self.bn2 = nn.BatchNorm3d(base_filters * 2)
        self.pool2 = nn.MaxPool3d(kernel_size=(2, 2, 2),
                                  stride=(2, 2, 2))   # N, F and T halved (rounded down)

        # Block 3: Conv3D + BN + ReLU (no pooling, so that T does not shrink to 0)
        self.conv3 = nn.Conv3d(base_filters * 2, base_filters * 4,
                               kernel_size=(3, 3, 3),
                               padding=(1, 1, 1))
        self.bn3 = nn.BatchNorm3d(base_filters * 4)

        # Dropout
        self.dropout = nn.Dropout(dropout)

        # Final classifier
        self.fc = nn.Linear(base_filters * 4, num_classes)

    def forward(self, x):
        """Map a batch [B, C, N, F, T] to logits [B, num_classes]."""
        x = self.pool1(F.relu(self.bn1(self.conv1(x))))  # -> [B, base,   N,    F//2, T//2]
        x = self.pool2(F.relu(self.bn2(self.conv2(x))))  # -> [B, 2*base, N//2, F//4, T//4]
        x = F.relu(self.bn3(self.conv3(x)))              # -> [B, 4*base, N//2, F//4, T//4]

        # Global average pooling over the three remaining spatial dimensions
        x = x.mean(dim=[2, 3, 4])    # -> [B, 4*base]

        x = self.dropout(x)
        x = self.fc(x)               # -> [B, num_classes]
        return x


# Smoke test: run the sample batch through a freshly built network and
# check the shape of the logits.
C_fix, N, Freq, Time = tuple(batch_x.shape[1:])
print("C_fix, N, Freq, Time :", C_fix, N, Freq, Time)

model_test = Seizure3DCNN(in_channels=C_fix, base_filters=16, dropout=0.5)
model_test = model_test.to(device)
out = model_test(batch_x.to(device))
print("Sortie modèle (shape):", out.shape)

C_fix, N, Freq, Time : 27 10 40 7
Sortie modèle (shape): torch.Size([8, 2])


In [11]:
def train_one_epoch(model, dataloader, optimizer, criterion, device=None):
    """Train ``model`` for one pass over ``dataloader``.

    Args:
        model: network to optimise; switched to training mode.
        dataloader: iterable of ``(x, y)`` batches.
        optimizer: optimizer stepped once per batch.
        criterion: loss called as ``criterion(outputs, y)``.
        device: where batches are moved before the forward pass. When omitted,
            the device of the model's first parameter is used (CPU if the
            model has no parameters), so the function does not depend on a
            notebook-level ``device`` variable.

    Returns:
        ``(epoch_loss, epoch_acc)``: per-sample average loss and accuracy.

    Raises:
        ValueError: if ``dataloader`` yields no sample.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.train()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    for x, y in dataloader:
        x = x.to(device)   # [B, C_fix, N, F, T]
        y = y.to(device)   # [B]

        optimizer.zero_grad()
        outputs = model(x)              # [B, n_classes]
        loss = criterion(outputs, y)
        loss.backward()
        optimizer.step()

        batch_size = x.size(0)
        # Weight each batch loss by its size so a smaller last batch does not
        # skew the epoch average (assumes the criterion returns a batch mean).
        loss_sum += loss.item() * batch_size
        n_correct += (outputs.argmax(dim=1) == y).sum().item()
        n_seen += batch_size

    if n_seen == 0:
        raise ValueError("train_one_epoch : le dataloader n'a fourni aucun échantillon.")

    return loss_sum / n_seen, n_correct / n_seen


def eval_one_epoch(model, dataloader, criterion, device=None):
    """Evaluate ``model`` on every batch of ``dataloader`` without gradients.

    Args:
        model: network to evaluate; switched to evaluation mode.
        dataloader: iterable of ``(x, y)`` batches.
        criterion: loss called as ``criterion(outputs, y)``.
        device: where batches are moved before the forward pass. When omitted,
            the device of the model's first parameter is used (CPU if the
            model has no parameters), so the function does not depend on a
            notebook-level ``device`` variable.

    Returns:
        ``(epoch_loss, epoch_acc)``: per-sample average loss and accuracy.

    Raises:
        ValueError: if ``dataloader`` yields no sample.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    with torch.no_grad():
        for x, y in dataloader:
            x = x.to(device)
            y = y.to(device)

            outputs = model(x)
            loss = criterion(outputs, y)

            batch_size = x.size(0)
            # Weight each batch loss by its size so a smaller last batch does
            # not skew the average (assumes the criterion returns a batch mean).
            loss_sum += loss.item() * batch_size
            n_correct += (outputs.argmax(dim=1) == y).sum().item()
            n_seen += batch_size

    if n_seen == 0:
        raise ValueError("eval_one_epoch : le dataloader n'a fourni aucun échantillon.")

    return loss_sum / n_seen, n_correct / n_seen

In [27]:
# Hyper-parameter grid: every combination below is trained for a few epochs.
param_grid = {
    "base_filters": [16, 32],
    "dropout": [0.3, 0.5],
    "lr": [1e-3, 3e-4]
}

grid = list(ParameterGrid(param_grid))
print("Nombre de configs à tester :", len(grid))

N_EPOCHS_SEARCH = 5   # can be increased later

# Start at -inf so the first evaluated config always becomes the incumbent:
# with 0.0 and a strict ">" a run where every val accuracy is 0 would leave
# best_config at None.
best_val_acc = float("-inf")
best_config = None
best_state_dict = None

criterion = nn.CrossEntropyLoss()

C_fix = FIXED_N_CHANNELS

for i, params in enumerate(grid):
    print(f"\n===== Config {i+1}/{len(grid)} =====")
    print(params)

    # fresh model for every config
    model = Seizure3DCNN(
        in_channels=C_fix,
        base_filters=params["base_filters"],
        dropout=params["dropout"],
    ).to(device)

    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=params["lr"],
        weight_decay=1e-4,    # light L2 regularisation
    )

    for epoch in range(1, N_EPOCHS_SEARCH + 1):
        train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, criterion)
        val_loss, val_acc = eval_one_epoch(model, val_loader, criterion)

        print(f"  Epoch {epoch}/{N_EPOCHS_SEARCH} - "
              f"Train loss: {train_loss:.4f}, acc: {train_acc:.3f} | "
              f"Val loss: {val_loss:.4f}, acc: {val_acc:.3f}")

    # keep the best config according to the LAST epoch's validation accuracy
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        best_config = params
        # state_dict() returns references to the live tensors of the model.
        # Store detached CPU clones so the snapshot is independent of the
        # model and does not keep device memory alive.
        best_state_dict = {
            name: tensor.detach().cpu().clone()
            for name, tensor in model.state_dict().items()
        }

print("\n===== Meilleure config trouvée =====")
print(best_config)
print("Best val acc :", best_val_acc)

Nombre de configs à tester : 8

===== Config 1/8 =====
{'base_filters': 16, 'dropout': 0.3, 'lr': 0.001}
  Epoch 1/5 - Train loss: 0.5712, acc: 0.698 | Val loss: 1.2807, acc: 0.459
  Epoch 2/5 - Train loss: 0.3994, acc: 0.824 | Val loss: 1.0786, acc: 0.499
  Epoch 3/5 - Train loss: 0.2926, acc: 0.881 | Val loss: 1.4746, acc: 0.501
  Epoch 4/5 - Train loss: 0.2268, acc: 0.910 | Val loss: 1.6145, acc: 0.457
  Epoch 5/5 - Train loss: 0.1706, acc: 0.934 | Val loss: 1.7552, acc: 0.508

===== Config 2/8 =====
{'base_filters': 16, 'dropout': 0.3, 'lr': 0.0003}
  Epoch 1/5 - Train loss: 0.5614, acc: 0.692 | Val loss: 0.9653, acc: 0.527
  Epoch 2/5 - Train loss: 0.3425, acc: 0.860 | Val loss: 0.9542, acc: 0.525
  Epoch 3/5 - Train loss: 0.2096, acc: 0.925 | Val loss: 1.6489, acc: 0.492
  Epoch 4/5 - Train loss: 0.1616, acc: 0.940 | Val loss: 1.7659, acc: 0.445
  Epoch 5/5 - Train loss: 0.1221, acc: 0.957 | Val loss: 2.0133, acc: 0.522

===== Config 3/8 =====
{'base_filters': 16, 'dropout': 0.5,

In [None]:
# Final training of the best configuration, with early stopping on val loss.
N_EPOCHS_FINAL = 20
PATIENCE = 5            # epochs without val-loss improvement tolerated before stopping
MODEL_SAVE_PATH = "models/best_cnn3d_tuh.pth"

if best_config is None:
    raise RuntimeError("Aucune config n'a été trouvée pendant le grid search.")

# torch.save does not create missing folders: make sure the target directory
# exists, otherwise the first checkpoint write fails on a fresh checkout.
Path(MODEL_SAVE_PATH).parent.mkdir(parents=True, exist_ok=True)

model = Seizure3DCNN(
    in_channels=FIXED_N_CHANNELS,                 # the fixed channel count, not a per-file C
    base_filters=best_config["base_filters"],
    dropout=best_config["dropout"],
).to(device)

optimizer = torch.optim.Adam(
    model.parameters(),
    lr=best_config["lr"],
    weight_decay=1e-4,
)

criterion = nn.CrossEntropyLoss()

best_val_loss = float("inf")
epochs_no_improve = 0

for epoch in range(1, N_EPOCHS_FINAL + 1):
    train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, criterion)
    val_loss, val_acc = eval_one_epoch(model, val_loader, criterion)

    print(f"Epoch {epoch}/{N_EPOCHS_FINAL} - "
          f"Train loss: {train_loss:.4f}, acc: {train_acc:.3f} | "
          f"Val loss: {val_loss:.4f}, acc: {val_acc:.3f}")

    if val_loss < best_val_loss:
        # New best validation loss: checkpoint the weights and reset patience.
        best_val_loss = val_loss
        epochs_no_improve = 0
        torch.save(model.state_dict(), MODEL_SAVE_PATH)
        print("  -> Nouveau meilleur modèle sauvegardé.")
    else:
        # Stop once PATIENCE consecutive epochs brought no improvement.
        epochs_no_improve += 1
        if epochs_no_improve >= PATIENCE:
            print("  -> Early stopping déclenché.")
            break

print("\nEntraînement terminé.")
print("Meilleur modèle enregistré dans :", MODEL_SAVE_PATH)



Epoch 1/20 - Train loss: 0.5473, acc: 0.700 | Val loss: 1.1719, acc: 0.504
  -> Nouveau meilleur modèle sauvegardé.
Epoch 2/20 - Train loss: 0.3312, acc: 0.862 | Val loss: 1.3441, acc: 0.503
Epoch 3/20 - Train loss: 0.2117, acc: 0.915 | Val loss: 1.4210, acc: 0.490
Epoch 4/20 - Train loss: 0.1446, acc: 0.949 | Val loss: 1.5470, acc: 0.504
Epoch 5/20 - Train loss: 0.1212, acc: 0.957 | Val loss: 2.8695, acc: 0.494
Epoch 6/20 - Train loss: 0.1040, acc: 0.962 | Val loss: 1.6452, acc: 0.502
  -> Early stopping déclenché.

Entraînement terminé.
Meilleur modèle enregistré dans : models/best_cnn3d_tuh.pth


In [12]:
# Rebuild the architecture of the selected configuration and load the
# checkpoint written during the final training.
best_model = Seizure3DCNN(
    in_channels=FIXED_N_CHANNELS,
    base_filters=best_config["base_filters"],
    dropout=best_config["dropout"],
).to(device)

# weights_only=True restricts unpickling to tensors and plain containers,
# which is all a state_dict holds, so a tampered checkpoint file cannot run
# arbitrary code on load.
best_model.load_state_dict(
    torch.load(MODEL_SAVE_PATH, map_location=device, weights_only=True)
)
best_model.eval()

all_preds = []
all_targets = []

with torch.no_grad():
    for x, y in test_loader:
        # Only the inputs need to be on the model's device; the labels are
        # consumed on the CPU by scikit-learn, so they are not moved.
        outputs = best_model(x.to(device))
        preds = outputs.argmax(dim=1)

        all_preds.append(preds.cpu().numpy())
        all_targets.append(y.numpy())

all_preds = np.concatenate(all_preds)
all_targets = np.concatenate(all_targets)

test_acc = accuracy_score(all_targets, all_preds)
cm = confusion_matrix(all_targets, all_preds)
report = classification_report(all_targets, all_preds, digits=3)

print("Test accuracy :", test_acc)
print("\nConfusion matrix :\n", cm)
print("\nClassification report :\n", report)