In [1]:
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import classification_report, accuracy_score

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

In [3]:

# Load embeddings and labels.
# np.load on an .npz archive returns a lazy NpzFile that keeps the underlying
# file open; using it as a context manager releases the handle as soon as both
# arrays have been read into memory. After the block, `data` is closed and
# must not be indexed again.
with np.load("../data/processed/subject_features.npz") as data:
    X = data["X"]   # features
    y = data["y"]   # labels (0 = control, 1 = depressed)

# Fail fast if the archive is inconsistent, rather than deep inside the split.
assert X.shape[0] == y.shape[0], "X and y must contain the same number of samples"

print("Feature matrix:", X.shape)
print("Labels:", y.shape, "Positive rate:", np.mean(y))



Feature matrix: (486, 1152)
Labels: (486,) Positive rate: 0.17078189300411523


In [4]:
# Train/validation split: hold out 20% of the samples, stratified on the label
# and with a fixed random_state so the split is the same on every run.
split_kwargs = dict(test_size=0.2, stratify=y, random_state=42)
X_train, X_val, y_train, y_val = train_test_split(X, y, **split_kwargs)

print("Train size:", len(X_train), "Val size:", len(X_val))

Train size: 388 Val size: 98


In [5]:
# Standardization: fit the scaler on the training split only, then apply that
# same fitted transform to both splits so no validation statistics are used.
scaler = StandardScaler().fit(X_train)
X_train = scaler.transform(X_train)
X_val = scaler.transform(X_val)

# Convert to tensors: features become float32, labels become int64 (torch.long)
# class indices, which is what the classification loss takes as targets.
X_train_tensor, X_val_tensor = (
    torch.tensor(features, dtype=torch.float32) for features in (X_train, X_val)
)
y_train_tensor, y_val_tensor = (
    torch.tensor(labels, dtype=torch.long) for labels in (y_train, y_val)
)

# Pair each feature row with its label so the loaders yield (features, label) batches.
train_ds = TensorDataset(X_train_tensor, y_train_tensor)
val_ds   = TensorDataset(X_val_tensor, y_val_tensor)

# Training data is reshuffled every epoch and served in batches of 32 samples.
# The shuffle draws from its own seeded generator so the batch order is
# reproducible from a fresh kernel instead of depending on global RNG state.
shuffle_rng = torch.Generator().manual_seed(42)
train_loader = DataLoader(train_ds, batch_size=32, shuffle=True, generator=shuffle_rng)

# Validation data is served in its original order, in batches of 64.
val_loader   = DataLoader(val_ds, batch_size=64, shuffle=False)

In [None]:
# Define simple MLP classifier 
class MLPClassifier(nn.Module):
    """Single-hidden-layer MLP that maps a feature vector to class logits.

    Architecture: Linear -> ReLU -> Dropout -> Linear. The output is raw
    (unnormalized) logits, one per class; no softmax is applied here.

    Args:
        input_dim: Number of input features per sample.
        hidden_dim: Width of the hidden layer (64 or 256 are candidates to try).
        dropout: Dropout probability applied after the hidden activation.
        num_classes: Number of output logits (2 for binary classification).
    """

    def __init__(self, input_dim, hidden_dim=128, dropout=0.3, num_classes=2):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, num_classes),
        )

    def forward(self, x):
        """Return logits of shape (..., num_classes) for x of shape (..., input_dim)."""
        return self.net(x)

# Seed PyTorch's global RNG before building the model so the initial weights
# (and later draws from the global RNG, e.g. dropout masks) are repeatable
# when the notebook is re-run top to bottom on a fresh kernel.
torch.manual_seed(42)

model = MLPClassifier(input_dim=X.shape[1])
print(model)



MLPClassifier(
  (net): Sequential(
    (0): Linear(in_features=1152, out_features=128, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.3, inplace=False)
    (3): Linear(in_features=128, out_features=2, bias=True)
  )
)


In [7]:
# Training setup
# Move the model to its device first and only then build the optimizer, so the
# optimizer is constructed over parameters that already live where they will
# be trained.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

def _train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimization pass over `loader` and return the mean per-sample loss."""
    model.train()  # training mode: dropout active
    total_loss = 0
    for xb, yb in loader:
        xb, yb = xb.to(device), yb.to(device)

        optimizer.zero_grad()
        logits = model(xb)
        loss = criterion(logits, yb)
        loss.backward()
        optimizer.step()

        # Weight each batch loss by its size so a smaller final batch does not
        # skew the epoch average (assumes `criterion` returns a batch mean,
        # which is the nn.CrossEntropyLoss default).
        total_loss += loss.item() * xb.size(0)

    return total_loss / len(loader.dataset)


def _predict(model, loader, device):
    """Return (true_labels, predicted_labels) collected over `loader` in eval mode."""
    model.eval()  # eval mode: dropout disabled
    all_preds, all_true = [], []
    with torch.no_grad():
        for xb, yb in loader:
            xb, yb = xb.to(device), yb.to(device)
            logits = model(xb)
            preds = torch.argmax(logits, dim=1)  # predicted class = largest logit
            all_preds.extend(preds.cpu().numpy())
            all_true.extend(yb.cpu().numpy())
    return all_true, all_preds


# Training loop
n_epochs = 10
for epoch in range(1, n_epochs + 1):
    avg_loss = _train_one_epoch(model, train_loader, criterion, optimizer, device)

    # Validation: all_true / all_preds are rebound every epoch, so after the
    # loop they hold the predictions of the final epoch.
    all_true, all_preds = _predict(model, val_loader, device)

    val_acc = accuracy_score(all_true, all_preds)
    print(f"Epoch {epoch}: train loss={avg_loss:.4f}, val acc={val_acc:.4f}")



Epoch 1: train loss=0.4542, val acc=0.9184
Epoch 2: train loss=0.1536, val acc=0.8980
Epoch 3: train loss=0.0784, val acc=0.9082
Epoch 4: train loss=0.0401, val acc=0.9082
Epoch 5: train loss=0.0255, val acc=0.9082
Epoch 6: train loss=0.0131, val acc=0.9184
Epoch 7: train loss=0.0104, val acc=0.9184
Epoch 8: train loss=0.0074, val acc=0.9184
Epoch 9: train loss=0.0068, val acc=0.9184
Epoch 10: train loss=0.0052, val acc=0.9184


In [8]:
# Final evaluation: per-class precision / recall / F1 on the validation split.
# all_true and all_preds are the lists left over from the last training epoch.
print("\nValidation classification report:")
val_report = classification_report(all_true, all_preds)
print(val_report)


Validation classification report:
              precision    recall  f1-score   support

           0       0.95      0.95      0.95        81
           1       0.76      0.76      0.76        17

    accuracy                           0.92        98
   macro avg       0.86      0.86      0.86        98
weighted avg       0.92      0.92      0.92        98

