In [10]:
# Libraries

In [14]:
import os
import mne
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset, TensorDataset, random_split

from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE

sns.set_theme(style="whitegrid")

# Single seed for every stochastic step below (random_split, weight init,
# DataLoader shuffling) so Restart & Run All reproduces the same results.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the RNG on all devices (CPU and CUDA)

device = "cuda" if torch.cuda.is_available() else "cpu"

In [16]:
# Data Loading & Preprocessing

In [None]:
dataset_path = "../APPLESEED_Dataset"  # adjust path if needed

def load_sub_session(sub_path, ses):
    """Load, band-pass filter and re-reference every BrainVision recording of one session.

    Parameters
    ----------
    sub_path : str
        Path to one subject directory.
    ses : str
        Session identifier; recordings are read from ``<sub_path>/ses-<ses>/eeg``.

    Returns
    -------
    list
        One preloaded recording per ``.vhdr`` header file, in sorted filename
        order, filtered to 1-40 Hz and set to an average EEG reference.
        Empty if the directory contains no ``.vhdr`` files.

    Raises
    ------
    FileNotFoundError
        If the session's ``eeg`` directory does not exist (the caller relies on
        this to skip missing sessions).
    """
    eeg_path = os.path.join(sub_path, f"ses-{ses}", "eeg")
    # os.listdir order is filesystem-dependent; sorting keeps the recording
    # order (and hence the downstream train/val split) reproducible.
    vhdr_files = sorted(f for f in os.listdir(eeg_path) if f.endswith(".vhdr"))
    raws = []
    for vhdr in vhdr_files:
        raw = mne.io.read_raw_brainvision(os.path.join(eeg_path, vhdr), preload=True)
        raw.filter(1., 40.)  # band-pass 1-40 Hz
        raw.set_eeg_reference('average')
        raws.append(raw)
    return raws

def load_all_subjects(dataset_path):
    """Load every available session (1-4) of every subject under ``dataset_path``.

    Subjects are the sub-directories whose names start with ``"sub"``; each
    session is loaded with ``load_sub_session``.

    Parameters
    ----------
    dataset_path : str
        Dataset root that contains the ``sub*`` folders.

    Returns
    -------
    all_data : list
        All recordings returned by ``load_sub_session``, ordered by sorted
        subject name and then by session.
    labels : np.ndarray of int
        Same length as ``all_data``; the session number (1-4) of each recording.

    Notes
    -----
    A session whose directory is missing (``FileNotFoundError``) is skipped
    silently. Any other exception while loading a session is reported via
    ``warnings.warn`` and that session is skipped, so one bad file does not abort
    the whole load. Unlike the previous bare ``except:``, KeyboardInterrupt and
    SystemExit are no longer swallowed.
    """
    import warnings

    subjects = sorted(
        d for d in os.listdir(dataset_path)
        if d.startswith("sub") and os.path.isdir(os.path.join(dataset_path, d))
    )
    all_data = []
    labels = []
    for sub in subjects:
        sub_path = os.path.join(dataset_path, sub)
        for ses in ["1", "2", "3", "4"]:
            try:
                raws = load_sub_session(sub_path, ses)
            except FileNotFoundError:
                continue  # this subject has no such session
            except Exception as err:
                warnings.warn(f"Skipping {sub} ses-{ses}: {type(err).__name__}: {err}")
                continue
            all_data.extend(raws)
            # Session number used as the label (presumably 4/8/12/16 weeks; confirm)
            labels.extend([int(ses)] * len(raws))
    return all_data, np.array(labels, dtype=int)

# Read, filter and re-reference every recording found under dataset_path
print("Loading EEG data...")
all_raws, labels = load_all_subjects(dataset_path)
print("Total EEG recordings: " + str(len(all_raws)))

In [None]:
# Feature Extraction

In [None]:
def extract_features(raw):
    """Return the natural-log Welch PSD of ``raw`` between 1 and 40 Hz.

    Parameters
    ----------
    raw : mne.io.Raw
        A preloaded continuous recording.

    Returns
    -------
    np.ndarray
        ``log(PSD)`` with one row per channel picked by ``compute_psd`` and one
        column per frequency bin.

    Notes
    -----
    * ``psd_welch`` was never imported here. It was deprecated in MNE 1.2 and
      later removed. ``Raw.compute_psd(method="welch", ...)`` is the supported
      replacement and takes the same ``fmin``/``fmax``/``n_fft`` arguments.
    * MNE stores EEG in volts, so PSD values are tiny (V**2/Hz, roughly 1e-12).
      The previous ``+ 1e-6`` offset was orders of magnitude larger than the
      signal and flattened every feature to about log(1e-6). A floor far below
      any physical power is used instead, only to keep ``log`` finite on flat
      channels.
    """
    spectrum = raw.compute_psd(method="welch", fmin=1, fmax=40, n_fft=256)
    psd = spectrum.get_data()
    return np.log(np.maximum(psd, 1e-30))

# Stack one (channels x frequency-bins) log-PSD matrix per recording along axis 0
print("Extracting features...")
X = np.array(list(map(extract_features, all_raws)))
print("Feature array shape:", X.shape)

In [None]:
# Supervised Dataset

In [None]:
# Labels are session numbers (1-4), but nn.CrossEntropyLoss needs class indices
# in [0, n_classes). Map each distinct label to its rank; class_values[k] is the
# original session number of class index k.
class_values, class_idx = np.unique(labels, return_inverse=True)

X_tensor = torch.tensor(X, dtype=torch.float32)
y_tensor = torch.tensor(class_idx, dtype=torch.long)

dataset = TensorDataset(X_tensor, y_tensor)
train_size = int(0.8*len(dataset))
val_size = len(dataset) - train_size
# A seeded generator gives the same train/val split on every run.
# NOTE(review): recordings are split independently, so one subject's sessions can
# land in both train and val (subject leakage); consider a subject-wise split.
split_gen = torch.Generator().manual_seed(42)
train_ds, val_ds = random_split(dataset, [train_size, val_size], generator=split_gen)

train_loader = DataLoader(train_ds, batch_size=8, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=8)


In [None]:
# CNN+LSTM Model

In [None]:
class CNN_LSTM(nn.Module):
    """1-D CNN front end, LSTM over the pooled positions, linear classifier head.

    ``forward`` takes ``(batch, n_channels, length)``. Two Conv1d -> ReLU ->
    MaxPool1d(2) stages map it to ``(batch, 128, length // 4)``. The LSTM reads
    that as a sequence of 128-d steps, and its output at the last step is mapped
    by ``fc`` to ``(batch, n_classes)`` logits.
    """

    def __init__(self, n_channels, n_classes):
        super().__init__()
        # The Sequential indices (cnn.0 / cnn.3) appear in saved state_dict keys,
        # so the layer order must stay exactly Conv, ReLU, Pool, Conv, ReLU, Pool.
        conv_stages = []
        for in_ch, out_ch in ((n_channels, 64), (64, 128)):
            conv_stages += [
                nn.Conv1d(in_ch, out_ch, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.MaxPool1d(2),
            ]
        self.cnn = nn.Sequential(*conv_stages)
        self.lstm = nn.LSTM(128, 128, batch_first=True)
        self.fc = nn.Linear(128, n_classes)

    def forward(self, x):
        features = self.cnn(x)                # (batch, 128, steps)
        sequence = features.transpose(1, 2)   # (batch, steps, 128) for the batch_first LSTM
        lstm_out, _ = self.lstm(sequence)
        last_step = lstm_out[:, -1, :]
        return self.fc(last_step)


In [None]:
# Training Loop

In [None]:
# Initialize model
# Fail early with a clear message: a tiny dataset can leave the train or val split
# empty, and the per-epoch averages below would then divide by zero.
if len(train_loader) == 0 or len(val_loader) == 0:
    raise ValueError(
        f"Need non-empty train and validation splits; got {len(train_loader.dataset)} "
        f"train and {len(val_loader.dataset)} val samples"
    )

# NOTE(review): a 2-D X_tensor yields n_channels=1, but batches are never given a
# channel axis, so CNN_LSTM would not receive (batch, 1, length). Confirm X is 3-D.
n_channels = X_tensor.shape[1] if len(X_tensor.shape) > 2 else 1
# nn.CrossEntropyLoss needs every target in [0, n_classes). Sizing the head from
# the largest label (rather than the count of distinct labels) stays valid even
# when labels are not 0-based, e.g. raw session numbers 1-4.
n_classes = int(y_tensor.max().item()) + 1
model = CNN_LSTM(n_channels, n_classes).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

epochs = 30
train_losses, val_losses, val_accs = [], [], []

print(" Training started...\n")

for epoch in range(epochs):
    model.train()
    running_loss = 0.0
    for X_batch, y_batch in train_loader:
        X_batch, y_batch = X_batch.to(device), y_batch.to(device)

        optimizer.zero_grad()
        outputs = model(X_batch)
        loss = criterion(outputs, y_batch)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # Mean of per-batch losses (not weighted by batch size)
    avg_train_loss = running_loss / len(train_loader)
    train_losses.append(avg_train_loss)

    # Validation phase
    model.eval()
    val_loss, correct, total = 0.0, 0, 0

    with torch.no_grad():
        for X_val, y_val in val_loader:
            X_val, y_val = X_val.to(device), y_val.to(device)
            outputs = model(X_val)
            loss = criterion(outputs, y_val)
            val_loss += loss.item()

            _, predicted = torch.max(outputs, 1)
            total += y_val.size(0)
            correct += (predicted == y_val).sum().item()

    avg_val_loss = val_loss / len(val_loader)
    val_acc = correct / total
    val_losses.append(avg_val_loss)
    val_accs.append(val_acc)

    print(f"Epoch [{epoch+1}/{epochs}] "
          f"Train Loss: {avg_train_loss:.4f} | "
          f"Val Loss: {avg_val_loss:.4f} | "
          f"Val Acc: {val_acc*100:.2f}%")

print("\n Training complete!")

# Save model. torch.save does not create missing parent directories, so make
# sure the target folder exists first.
# NOTE(review): this path is relative to the CWD, while the autoencoder cell
# saves under ../models. Consider unifying the two.
model_path = "models/neurogrow_cnn_lstm.pth"
os.makedirs(os.path.dirname(model_path), exist_ok=True)
torch.save(model.state_dict(), model_path)
print(f"Model saved to {model_path}")

In [None]:
# Visualization - Supervised

In [None]:
plt.figure(figsize=(10,5))
plt.plot(train_losses, label="Train Loss")
plt.plot(val_losses, label="Val Loss")
plt.xlabel("Epoch")
plt.ylabel("Cross-entropy loss")
plt.legend()
plt.title("Loss Curves")
plt.show()

plt.figure(figsize=(10,5))
plt.plot(val_accs, label="Validation Accuracy")
plt.xlabel("Epoch")
plt.ylabel("Accuracy")
plt.legend()
plt.title("Validation Accuracy")
plt.show()

# Scalp map of the mean Welch PSD per channel for the first recording.
# psd_welch was never imported (and has been removed from MNE); Raw.compute_psd is
# the supported replacement. The Spectrum's own info/ch_names are used so that
# positions and names line up with the PSD rows (bad/non-data channels excluded).
raw = all_raws[0]
spectrum = raw.compute_psd(method="welch")
psd = spectrum.get_data()
freqs = spectrum.freqs
# NOTE(review): plot_topomap needs channel positions in the info (a montage);
# BrainVision headers may not carry them. Confirm, or call raw.set_montage(...) first.
mne.viz.plot_topomap(np.mean(psd, axis=1), pos=spectrum.info, show=True, names=spectrum.ch_names)

In [None]:
# Unsupervised Autoencoder

In [None]:
# --- Autoencoder inputs ----------------------------------------------------
# One row per recording: its (channels x frequency-bins) log-PSD matrix flattened
# into a single feature vector of length n_channels * n_freq_bins.
X_flat = np.reshape(X, (X.shape[0], -1))
X_tensor_flat = torch.tensor(X_flat, dtype=torch.float32)

batch_size_ae = 8
dataset_flat = TensorDataset(X_tensor_flat)   # single-tensor dataset: batch[0] is the input
loader_flat = DataLoader(
    dataset_flat,
    shuffle=True,
    batch_size=batch_size_ae,
)

# Autoencoder model
class EEG_Autoencoder(nn.Module):
    """Symmetric fully connected autoencoder: input_dim -> 1024 -> 512 -> latent_dim and back.

    ReLU sits between hidden layers only. Neither the latent code nor the
    reconstruction passes through a non-linearity. ``forward`` returns
    ``(reconstruction, latent_code)``.
    """

    def __init__(self, input_dim, latent_dim=64):
        super().__init__()
        widths = [input_dim, 1024, 512, latent_dim]
        # The encoder walks the widths forwards and the decoder walks them backwards.
        # Both yield Sequential(Linear, ReLU, Linear, ReLU, Linear), so state_dict
        # keys stay encoder.{0,2,4} / decoder.{0,2,4}.
        self.encoder = self._mlp(widths)
        self.decoder = self._mlp(widths[::-1])

    @staticmethod
    def _mlp(widths):
        """Linear layers between consecutive widths, with a ReLU between (not after) them."""
        layers = []
        for position, (d_in, d_out) in enumerate(zip(widths[:-1], widths[1:])):
            if position > 0:
                layers.append(nn.ReLU())
            layers.append(nn.Linear(d_in, d_out))
        return nn.Sequential(*layers)

    def forward(self, x):
        code = self.encoder(x)
        return self.decoder(code), code

# --- Autoencoder setup -------------------------------------------------------
latent_dim = 64
input_dim = X_flat.shape[1]   # flattened feature length per recording
ae_model = EEG_Autoencoder(input_dim=input_dim, latent_dim=latent_dim).to(device)

criterion_ae = nn.MSELoss()
optimizer_ae = torch.optim.Adam(ae_model.parameters(), lr=1e-3, weight_decay=1e-5)

ae_epochs = 30
ae_losses = []

# Output folders (the visualisation cell writes into ../results)
os.makedirs("../models", exist_ok=True)
os.makedirs("../results", exist_ok=True)


def _train_ae_one_epoch():
    """Run one optimisation pass over loader_flat; return the mean per-batch MSE."""
    ae_model.train()
    batch_losses = []
    for (xb,) in loader_flat:
        xb = xb.to(device)   # the reconstruction target is the input itself
        optimizer_ae.zero_grad()
        recon, _ = ae_model(xb)
        loss = criterion_ae(recon, xb)
        loss.backward()
        optimizer_ae.step()
        batch_losses.append(loss.item())
    return sum(batch_losses) / len(batch_losses)


print("🔁 Starting autoencoder training...")
for epoch in range(ae_epochs):
    avg_loss = _train_ae_one_epoch()
    ae_losses.append(avg_loss)
    print(f"AE Epoch [{epoch+1}/{ae_epochs}] - Loss: {avg_loss:.6f}")

# Persist the trained weights
ae_path = "../models/autoencoder_eeg.pth"
torch.save(ae_model.state_dict(), ae_path)
print(f" Autoencoder saved to {ae_path}")

In [None]:
# Unsupervised Visualizations (full)

# 1) Autoencoder training loss per epoch (x axis starts at epoch 1)
fig, ax = plt.subplots(figsize=(8, 5))
ax.plot(range(1, len(ae_losses) + 1), ae_losses, marker='o')
ax.set_title("Autoencoder Training Loss")
ax.set_xlabel("Epoch")
ax.set_ylabel("MSE Loss")
ax.grid(True)
fig.tight_layout()
fig.savefig("../results/autoencoder_loss.png", dpi=150)
plt.show()

# 2) Latent code for every recording in a single forward pass (eval mode, no autograd)
ae_model.eval()
with torch.no_grad():
    X_tensor_flat_device = X_tensor_flat.to(device)
    latent_vectors = ae_model(X_tensor_flat_device)[1]   # (n_samples, latent_dim)
latent = latent_vectors.cpu().numpy()

# 3) t-SNE visualization of latent space
# scikit-learn requires perplexity < n_samples. Clamp it so small cohorts
# (<= 30 recordings) do not raise ValueError; larger ones keep perplexity 30.
tsne_perplexity = min(30, max(latent.shape[0] - 1, 1))
# n_iter was renamed max_iter in scikit-learn 1.5 and later removed. Its
# default (1000 iterations) is what was requested, so it is simply omitted.
tsne = TSNE(n_components=2, random_state=42, perplexity=tsne_perplexity)
latent_2d = tsne.fit_transform(latent)

plt.figure(figsize=(8,6))
scatter = plt.scatter(latent_2d[:,0], latent_2d[:,1], c=labels, cmap="tab10", s=25, alpha=0.85)
plt.title("t-SNE of Autoencoder Latent Space")
plt.xlabel("t-SNE 1")
plt.ylabel("t-SNE 2")
plt.colorbar(scatter, label="Session label (example)")
plt.tight_layout()
plt.savefig("../results/ae_tsne.png", dpi=150)
plt.show()

# 4) PCA visualization (linear, deterministic complement to t-SNE)
pca = PCA(n_components=2)
pca_2d = pca.fit_transform(latent)
plt.figure(figsize=(8,6))
plt.scatter(pca_2d[:,0], pca_2d[:,1], c=labels, cmap="tab10", s=25, alpha=0.85)
plt.title("PCA of Autoencoder Latent Space")
plt.xlabel("PC1")
plt.ylabel("PC2")
plt.tight_layout()
plt.savefig("../results/ae_pca.png", dpi=150)
plt.show()

# 5) Reconstruction quality: original vs reconstructed PSD for the first few recordings
num_examples = min(4, len(X_flat))
example_idx = list(range(num_examples))

with torch.no_grad():
    recon_all = ae_model(X_tensor_flat.to(device))[0].cpu().numpy()   # (n_samples, input_dim)

# Shape needed to un-flatten each row back to (n_channels, n_freq_bins)
n_channels, n_freq_bins = X.shape[1], X.shape[2]

plt.figure(figsize=(12, 3 * num_examples))
for row, idx in enumerate(example_idx, start=1):
    # Average over channels (axis 0) so each panel shows one spectrum per curve
    orig_mean = X_flat[idx].reshape(n_channels, n_freq_bins).mean(axis=0)
    recon_mean = recon_all[idx].reshape(n_channels, n_freq_bins).mean(axis=0)

    ax = plt.subplot(num_examples, 1, row)
    ax.plot(orig_mean, label="Original (mean across channels)", linewidth=1)
    ax.plot(recon_mean, label="Reconstructed (mean across channels)", linewidth=1, linestyle="--")
    ax.set_title(f"Sample {idx} — Original vs Reconstructed (mean PSD)")
    ax.set_xlabel("Frequency bin index")
    ax.set_ylabel("Log PSD (a.u.)")
    ax.legend()
    ax.grid(True)

plt.tight_layout()
plt.savefig("../results/reconstruction_examples.png", dpi=150)
plt.show()

# 6) Per-recording reconstruction MSE; large values point at atypical recordings
recon_error = ((X_flat - recon_all) ** 2).mean(axis=1)

fig, ax = plt.subplots(figsize=(8, 4))
ax.hist(recon_error, bins=40)
ax.set_title("Histogram of Reconstruction Error (MSE) per Sample")
ax.set_xlabel("MSE")
ax.set_ylabel("Count")
fig.tight_layout()
fig.savefig("../results/reconstruction_error_hist.png", dpi=150)
plt.show()

# 7) Persist embeddings and errors for downstream analysis
np.save("../results/latent_vectors.npy", latent)             # (n_samples, latent_dim)
np.save("../results/reconstruction_error.npy", recon_error)  # (n_samples,)
print("✅ Saved latent vectors and reconstruction errors to ../results/")