In [10]:
# Libraries

In [14]:
import os
import mne
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset, TensorDataset, random_split

from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE

# Global plot styling for every seaborn/matplotlib figure in this notebook.
sns.set(style="whitegrid")

# Run on the GPU when PyTorch can see one, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [16]:
# Data Loading & Preprocessing

In [None]:
# Root folder of the APPLESEED EEG dataset, relative to the notebook's working directory.
# load_all_subjects() scans it for entries whose name starts with "sub".
dataset_path = "../APPLESEED_Dataset"  # adjust path if needed

def load_sub_session(sub_path, ses, l_freq=1., h_freq=40.):
    """Load and preprocess every BrainVision recording of one subject/session.

    Looks in ``<sub_path>/ses-<ses>/eeg`` for ``.vhdr`` header files, reads each
    one fully into memory, band-pass filters it and re-references it to the
    channel average.

    Parameters
    ----------
    sub_path : str
        Path to the subject folder.
    ses : str or int
        Session identifier used to build the ``ses-<ses>`` folder name.
    l_freq, h_freq : float
        Band-pass edges in Hz passed to ``raw.filter`` (defaults 1 and 40,
        the values that were previously hard-coded).

    Returns
    -------
    list
        One preprocessed Raw object per ``.vhdr`` file, ordered by file name.
        Empty if the folder holds no ``.vhdr`` files.

    Raises
    ------
    FileNotFoundError
        If the session's ``eeg`` folder does not exist (from ``os.listdir``).
    """
    eeg_path = os.path.join(sub_path, f"ses-{ses}", "eeg")
    # os.listdir returns entries in arbitrary order; sort so that the order of
    # recordings (and anything seeded downstream) is reproducible across runs.
    vhdr_files = sorted(f for f in os.listdir(eeg_path) if f.endswith(".vhdr"))
    raws = []
    for vhdr in vhdr_files:
        raw = mne.io.read_raw_brainvision(os.path.join(eeg_path, vhdr), preload=True)
        raw.filter(l_freq, h_freq)  # bandpass filter
        raw.set_eeg_reference('average')
        raws.append(raw)
    return raws

def load_all_subjects(dataset_path, sessions=("1", "2", "3", "4")):
    """Load every available session of every subject under ``dataset_path``.

    Parameters
    ----------
    dataset_path : str
        Dataset root; sub-folders whose name starts with ``"sub"`` are treated
        as subjects.
    sessions : iterable of str
        Session identifiers to try for each subject (default: sessions 1-4).

    Returns
    -------
    all_data : list
        Every recording returned by ``load_sub_session``, subject by subject
        (sorted by folder name) and session by session.
    labels : numpy.ndarray
        Integer session number of each recording, aligned with ``all_data``.

    Notes
    -----
    Loading is best-effort: a session that cannot be loaded (e.g. its folder
    is missing) is reported and skipped instead of aborting the whole run.
    """
    # Sorted so recording order is deterministic; os.listdir order is arbitrary.
    subjects = sorted(
        d for d in os.listdir(dataset_path)
        if d.startswith("sub") and os.path.isdir(os.path.join(dataset_path, d))
    )
    all_data = []
    labels = []
    for sub in subjects:
        sub_path = os.path.join(dataset_path, sub)
        for ses in sessions:
            try:
                raws = load_sub_session(sub_path, ses)
            except Exception as err:
                # Catch Exception rather than using a bare except so that
                # KeyboardInterrupt/SystemExit still stop the notebook, and say
                # what was skipped so missing data does not go unnoticed.
                print(f"Skipping {sub} ses-{ses}: {type(err).__name__}: {err}")
                continue
            all_data.extend(raws)
            # Example: session number as label (4,8,12,16 weeks)
            labels.extend([int(ses)] * len(raws))
    # Explicit integer dtype so an empty result is not silently float64.
    return all_data, np.array(labels, dtype=np.int64)

# Load and preprocess every recording up front; each file is read fully into memory.
print("Loading EEG data...")
all_raws, labels = load_all_subjects(dataset_path)
print(f"Total EEG recordings: {len(all_raws)}")

# load_all_subjects skips sessions it cannot read, so a wrong folder layout
# yields an empty list rather than an error. Stop here with a clear message
# instead of failing later inside feature extraction or training.
if len(all_raws) == 0:
    raise RuntimeError(
        f"No EEG recordings were loaded from {dataset_path!r}; "
        "check the dataset path and the sub-*/ses-*/eeg folder layout."
    )

Loading EEG data...
Extracting parameters from ../APPLESEED_Dataset\sub-01\ses-1\eeg\sub-01_ses-1_task-appleseedexample_eeg.vhdr...
Setting channel info structure...
Reading 0 ... 3431099  =      0.000 ...   686.220 secs...
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 1 - 40 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 1.00
- Lower transition bandwidth: 1.00 Hz (-6 dB cutoff frequency: 0.50 Hz)
- Upper passband edge: 40.00 Hz
- Upper transition bandwidth: 10.00 Hz (-6 dB cutoff frequency: 45.00 Hz)
- Filter length: 16501 samples (3.300 s)

EEG channel type selected for re-referencing
Applying average reference.
Applying a custom ('EEG',) reference.
Extracting parameters from ../APPLESEED_Dataset\sub-01\ses-2\eeg\sub-01_ses-2_task-appleseedexample_ee

In [None]:
# Feature Extraction

In [None]:
def extract_features(raw, fmin=1, fmax=40, n_fft=None, win_sec=2.0):
    """Compute log-scaled Welch power spectral density features for one recording.

    Parameters
    ----------
    raw : mne.io.Raw
        Preprocessed recording.
    fmin, fmax : float
        Frequency band (Hz) to keep.
    n_fft : int or None
        Welch segment length in samples. When None it is derived from the
        recording's sampling rate as ``sfreq * win_sec`` so the frequency
        resolution (``1 / win_sec`` Hz) does not depend on the sampling rate.
        A fixed 256 samples at a 5000 Hz sampling rate gives ~19.5 Hz bins,
        i.e. only two bins between 1 and 40 Hz.
    win_sec : float
        Segment duration in seconds used when ``n_fft`` is None.

    Returns
    -------
    numpy.ndarray
        Natural log of the PSD expressed in µV²/Hz, shape (n_channels, n_freqs).
    """
    if n_fft is None:
        n_fft = int(round(float(raw.info["sfreq"]) * win_sec))
    # A Welch segment cannot be longer than the recording itself.
    n_fft = max(1, min(int(n_fft), int(raw.n_times)))

    if hasattr(raw, "compute_psd"):
        # Current MNE API; the old module-level psd_welch() helper was removed.
        psd = raw.compute_psd(method="welch", fmin=fmin, fmax=fmax, n_fft=n_fft).get_data()
    else:
        # Fallback for older MNE releases that predate Raw.compute_psd().
        from mne.time_frequency import psd_welch
        psd, _ = psd_welch(raw, fmin=fmin, fmax=fmax, n_fft=n_fft)

    # MNE stores EEG in volts, so the PSD is in V²/Hz (values around 1e-12).
    # Convert to µV²/Hz first; otherwise the 1e-6 floor dominates every value
    # and the log features become almost constant.
    psd = np.asarray(psd) * 1e12
    psd = np.log(psd + 1e-6)
    return psd

# Build one feature matrix per recording, then stack them into a single array.
print("Extracting features...")
per_recording_features = []
for recording in all_raws:
    per_recording_features.append(extract_features(recording))
X = np.array(per_recording_features)
print("Feature array shape:", X.shape)

In [None]:
# Supervised Dataset

In [None]:
# Features as float32; presumably (n_recordings, n_channels, n_freqs) — confirm
# against the "Feature array shape" printed by the extraction cell.
X_tensor = torch.tensor(X, dtype=torch.float32)

# nn.CrossEntropyLoss needs class indices in [0, n_classes). The raw labels are
# session numbers starting at 1, so map them to contiguous 0-based indices.
# class_values[i] is the original session number behind class index i.
class_values, class_indices = np.unique(np.asarray(labels).reshape(-1), return_inverse=True)
y_tensor = torch.tensor(class_indices.reshape(-1), dtype=torch.long)

dataset = TensorDataset(X_tensor, y_tensor)
train_size = int(0.8*len(dataset))
val_size = len(dataset) - train_size

# Seed the split so the train/validation partition is the same on every run.
# NOTE(review): the split is per recording, so sessions of the same subject can
# land on both sides; consider a subject-wise split if that matters.
split_generator = torch.Generator().manual_seed(42)
train_ds, val_ds = random_split(dataset, [train_size, val_size], generator=split_generator)

train_loader = DataLoader(train_ds, batch_size=8, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=8)


In [None]:
# CNN+LSTM Model

In [None]:
class CNN_LSTM(nn.Module):
    """1-D convolutional feature extractor followed by an LSTM and a linear head.

    The input follows the ``nn.Conv1d`` convention ``(batch, n_channels, length)``.
    The output is one row of ``n_classes`` unnormalised scores per sample.
    """

    def __init__(self, n_channels, n_classes):
        super().__init__()
        # Two Conv1d -> ReLU -> MaxPool1d(2) stages; every pooling layer halves
        # the length axis, so the input length must be at least 4.
        conv_stages = [
            nn.Conv1d(n_channels, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool1d(2),
            nn.Conv1d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool1d(2),
        ]
        self.cnn = nn.Sequential(*conv_stages)
        self.lstm = nn.LSTM(input_size=128, hidden_size=128, batch_first=True)
        self.fc = nn.Linear(in_features=128, out_features=n_classes)

    def forward(self, x):
        """Return class scores of shape ``(batch, n_classes)`` for input ``x``."""
        conv_out = self.cnn(x)
        # (batch, 128, steps) -> (batch, steps, 128), as the batch_first LSTM expects.
        sequence = conv_out.permute(0, 2, 1)
        lstm_out, _ = self.lstm(sequence)
        # Only the output at the final step feeds the classifier.
        final_step = lstm_out[:, -1, :]
        return self.fc(final_step)


In [None]:
# Training Loop

In [None]:
# Initialize model
# Conv1d input channels come from axis 1 of the feature tensor; a 2-D feature
# tensor falls back to a single channel.
n_channels = X_tensor.shape[1] if len(X_tensor.shape) > 2 else 1
# CrossEntropyLoss needs every target index to be < n_classes. Size the output
# layer from the largest label rather than from the count of distinct labels,
# so labels that do not start at 0 (or that skip a value) remain in range.
n_classes = int(y_tensor.max().item()) + 1
model = CNN_LSTM(n_channels, n_classes).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

epochs = 30
# Per-epoch history, kept for later inspection or plotting.
train_losses, val_losses, val_accs = [], [], []

print(" Training started...\n")

for epoch in range(epochs):
    # Training phase: one optimizer step per mini-batch.
    model.train()
    running_loss = 0.0
    for X_batch, y_batch in train_loader:
        X_batch, y_batch = X_batch.to(device), y_batch.to(device)

        optimizer.zero_grad()
        outputs = model(X_batch)
        loss = criterion(outputs, y_batch)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # Mean of the per-batch losses; max(..., 1) avoids dividing by zero when
    # the loader yields no batches.
    avg_train_loss = running_loss / max(len(train_loader), 1)
    train_losses.append(avg_train_loss)

    # Validation phase
    model.eval()
    val_loss, correct, total = 0.0, 0, 0

    with torch.no_grad():
        for X_val, y_val in val_loader:
            X_val, y_val = X_val.to(device), y_val.to(device)
            outputs = model(X_val)
            loss = criterion(outputs, y_val)
            val_loss += loss.item()

            # Predicted class = index of the highest score.
            _, predicted = torch.max(outputs, 1)
            total += y_val.size(0)
            correct += (predicted == y_val).sum().item()

    avg_val_loss = val_loss / max(len(val_loader), 1)
    val_acc = correct / max(total, 1)
    val_losses.append(avg_val_loss)
    val_accs.append(val_acc)

    print(f"Epoch [{epoch+1}/{epochs}] "
          f"Train Loss: {avg_train_loss:.4f} | "
          f"Val Loss: {avg_val_loss:.4f} | "
          f"Val Acc: {val_acc*100:.2f}%")

print("\n Training complete!")

# Save model
# torch.save does not create missing parent folders, so make sure it exists.
os.makedirs("models", exist_ok=True)
torch.save(model.state_dict(), "models/neurogrow_cnn_lstm.pth")
print("Model saved to models/neurogrow_cnn_lstm.pth")