In [None]:
from google.colab import drive
import os

drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import random
import numpy as np
import torch

def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(42)

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
import os
import numpy as np
from glob import glob
import json
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import StratifiedKFold, train_test_split
from sklearn.metrics import (
    classification_report,
    confusion_matrix,
    precision_score,
    recall_score,
    f1_score,
)

## **Extracting Features using DenseNet-121**

In [None]:
import os
import torch
import numpy as np
from PIL import Image
from tqdm import tqdm
from torchvision import models, transforms


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load DenseNet-121 model
densenet = models.densenet121(pretrained=True).to(device)
densenet.eval()

feature_extractor = torch.nn.Sequential(*list(densenet.features.children()), torch.nn.AdaptiveAvgPool2d((1, 1)))

preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],  # ImageNet means
                         std=[0.229, 0.224, 0.225])   # ImageNet stds
])

Downloading: "https://download.pytorch.org/models/densenet121-a639ec97.pth" to /root/.cache/torch/hub/checkpoints/densenet121-a639ec97.pth
100%|██████████| 30.8M/30.8M [00:00<00:00, 78.1MB/s]


In [None]:
# Extract feature vector
def extract_features(image_path):
    image = Image.open(image_path).convert("RGB")
    input_tensor = preprocess(image).unsqueeze(0).to(device)  # Shape: (1, 3, 224, 224)

    with torch.no_grad():
        features = feature_extractor(input_tensor)  # Shape: (1, 1024, 1, 1)
        features = features.view(-1)  # Flatten to (1024,)
    return features.cpu()

In [None]:
splits = ["train", "val", "test"]

In [None]:
source_base = "drive/MyDrive/thesis2025/split_dataset_new"
target_base = "drive/MyDrive/thesis2025/split_dataset_densenet"

In [None]:
for split in splits:
    split_path = os.path.join(source_base, split)
    subject_ids = os.listdir(split_path)

    for subject_id in tqdm(subject_ids, desc=f"Processing {split}"):
        subject_path = os.path.join(split_path, subject_id)
        eeg_dir = os.path.join(subject_path, "eeg_stft_spectrogram2")
        audio_dir = os.path.join(subject_path, "audio_spectrogram")

        eeg_files = sorted([os.path.join(eeg_dir, f) for f in os.listdir(eeg_dir) if f.endswith(".png")])
        audio_files = sorted([os.path.join(audio_dir, f) for f in os.listdir(audio_dir) if f.endswith(".png")])

        eeg_embeddings = []
        for p in eeg_files:
            try:
                eeg_embeddings.append(extract_features(p))
            except Exception as e:
                print(f"Skipping EEG image {p}: {e}")

        audio_embeddings = []
        for p in audio_files:
            try:
                audio_embeddings.append(extract_features(p))
            except Exception as e:
                print(f"Skipping audio image {p}: {e}")

        if not eeg_embeddings or not audio_embeddings:
            print(f"Skipping subject {subject_id} due to insufficient embeddings.")
            continue

        eeg_stack = torch.stack(eeg_embeddings)  # (N, 1024)
        audio_stack = torch.stack(audio_embeddings)  # (M, 1024)

        eeg_len = eeg_stack.size(0)
        audio_len = audio_stack.size(0)

        if audio_len < eeg_len:
            padding = torch.zeros((eeg_len - audio_len, audio_stack.size(1)))
            audio_stack = torch.cat([audio_stack, padding], dim=0)
        elif eeg_len < audio_len:
            padding = torch.zeros((audio_len - eeg_len, eeg_stack.size(1)))
            eeg_stack = torch.cat([eeg_stack, padding], dim=0)

        combined = torch.cat([eeg_stack, audio_stack], dim=1)  # (max_len, 2048)
        combined_mean = combined.mean(dim=0)  # (2048,)

        save_dir = os.path.join(target_base, split, subject_id)
        os.makedirs(save_dir, exist_ok=True)

        np.save(os.path.join(save_dir, "eeg_embedding.npy"), eeg_stack.numpy())
        np.save(os.path.join(save_dir, "audio_embedding.npy"), audio_stack.numpy())

Processing train: 100%|██████████| 26/26 [10:06<00:00, 23.33s/it]
Processing val: 100%|██████████| 6/6 [02:27<00:00, 24.65s/it]
Processing test: 100%|██████████| 6/6 [02:12<00:00, 22.02s/it]


In [None]:
import os
import numpy as np

base_dir = "drive/MyDrive/thesis2025/split_dataset_densenet/train"

for subject_id in os.listdir(base_dir):
    subject_path = os.path.join(base_dir, subject_id)
    eeg_path = os.path.join(subject_path, "eeg_embedding.npy")

    if os.path.isfile(eeg_path):
        eeg_embedding = np.load(eeg_path)
        print(f"Subject {subject_id}: eeg_embedding shape = {eeg_embedding.shape}")
    else:
        print(f"Subject {subject_id}: eeg_embedding.npy not found")


Subject 02010010: eeg_embedding shape = (29, 1024)
Subject 02020008: eeg_embedding shape = (29, 1024)
Subject 02010024: eeg_embedding shape = (29, 1024)
Subject 02030017: eeg_embedding shape = (29, 1024)
Subject 02020015: eeg_embedding shape = (29, 1024)
Subject 02010023: eeg_embedding shape = (29, 1024)
Subject 02030002: eeg_embedding shape = (29, 1024)
Subject 02020023: eeg_embedding shape = (29, 1024)
Subject 02030006: eeg_embedding shape = (29, 1024)
Subject 02010025: eeg_embedding shape = (29, 1024)
Subject 02020010: eeg_embedding shape = (29, 1024)
Subject 02030007: eeg_embedding shape = (29, 1024)
Subject 02030009: eeg_embedding shape = (29, 1024)
Subject 02020026: eeg_embedding shape = (29, 1024)
Subject 02020022: eeg_embedding shape = (29, 1024)
Subject 02030005: eeg_embedding shape = (29, 1024)
Subject 02020014: eeg_embedding shape = (29, 1024)
Subject 02010036: eeg_embedding shape = (29, 1024)
Subject 02020018: eeg_embedding shape = (29, 1024)
Subject 02010005: eeg_embedding

## **AdaptiveAvg Pooling**

In [None]:
import torch
import torch.nn as nn

class ConvPoolReLUClassifier(nn.Module):
    def __init__(self, input_dim=2048, hidden_dim=1024, num_classes=2):
        super().__init__()
        self.conv1 = nn.Conv1d(in_channels=input_dim, out_channels=256, kernel_size=3, padding=1)
        self.pool = nn.AdaptiveAvgPool1d(1)
        self.relu = nn.ReLU()
        self.fc1 = nn.Linear(256, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        x = x.transpose(1, 2)
        x = self.conv1(x)
        x = self.pool(x)
        x = self.relu(x)
        x = x.view(x.size(0), -1)
        x = self.relu(self.fc1(x))
        return self.fc2(x)


In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms
from PIL import Image
from sklearn.model_selection import StratifiedKFold, train_test_split
from sklearn.metrics import classification_report, confusion_matrix
import numpy as np
from glob import glob

In [None]:
class EEGAudioDataset(Dataset):
    def __init__(self, subject_dirs):
        self.subject_dirs = subject_dirs

    def __len__(self):
        return len(self.subject_dirs)

    def __getitem__(self, idx):
        subject_path = self.subject_dirs[idx]
        subject_id = os.path.basename(subject_path)

        eeg_embedding = np.load(os.path.join(subject_path, "eeg_embedding.npy"))
        audio_embedding = np.load(os.path.join(subject_path, "audio_embedding.npy"))

        combined_embedding = np.concatenate([eeg_embedding, audio_embedding], axis=1)
        label = 1 if subject_id.startswith('0201') else 0

        combined_embedding = torch.tensor(combined_embedding, dtype=torch.float32)
        label = torch.tensor(label, dtype=torch.long)

        return combined_embedding, label


In [None]:
# ---- Training ----
def train_model(model, train_loader, val_loader, criterion, optimizer, device, epochs=30):
    best_model = None
    best_val_acc = 0

    for epoch in range(epochs):
        model.train()
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            output = model(x)
            loss = criterion(output, y)
            loss.backward()
            optimizer.step()

        # Validation
        model.eval()
        correct = total = 0
        with torch.no_grad():
            for x, y in tqdm(val_loader):
                x, y = x.to(device), y.to(device)
                output = model(x)
                preds = output.argmax(dim=1)
                correct += (preds == y).sum().item()
                total += y.size(0)

        val_acc = correct / total
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_model = model.state_dict()

    model.load_state_dict(best_model)
    return model


In [None]:
import torch
from torch.nn.utils.rnn import pad_sequence

def collate_fn_padd(batch):
    """
    batch: list of tuples (x, y)
    x: tensor with shape (seq_len, feat_dim)
    y: int label
    """
    sequences = [item[0] for item in batch]  # list of (seq_len, feat_dim) tensors
    labels = torch.tensor([item[1] for item in batch], dtype=torch.long)

    padded_seqs = pad_sequence(sequences, batch_first=True, padding_value=0)

    return padded_seqs, labels


In [None]:
base_dir = "drive/MyDrive/thesis2025/split_dataset_densenet"

# Step 1: Index all subject folders from train, val, test
subject_path_map = {}
for split in ['train', 'val', 'test']:
    split_path = os.path.join(base_dir, split)
    for subj in os.listdir(split_path):
        subject_path_map[subj] = os.path.join(split_path, subj)

In [None]:
import os
import numpy as np
from glob import glob
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import StratifiedKFold, train_test_split
from sklearn.metrics import (
    classification_report,
    confusion_matrix,
    precision_score,
    recall_score,
    f1_score,
)

set_seed(42)
# ---- Cross-Validation ----
with open("drive/MyDrive/thesis2025/split_dataset_june/fold_assignments.json", "r") as f:
    folds = json.load(f)

results = []
fold_reports = []
fold_conf_matrices = []

for fold_idx in range(5):
    fold_name = f"fold_{fold_idx + 1}"
    train_ids = folds[fold_name]["train"]
    val_ids   = folds[fold_name]["val"]
    test_ids  = folds[fold_name]["test"]

    train_subjs = [subject_path_map[sid] for sid in train_ids]
    val_subjs   = [subject_path_map[sid] for sid in val_ids]
    test_subjs  = [subject_path_map[sid] for sid in test_ids]

    train_set = EEGAudioDataset(train_subjs)
    val_set = EEGAudioDataset(val_subjs)
    test_set = EEGAudioDataset(test_subjs)

    train_loader = DataLoader(train_set, batch_size=16, shuffle=True, collate_fn=collate_fn_padd)
    val_loader = DataLoader(val_set, batch_size=16, collate_fn=collate_fn_padd)
    test_loader = DataLoader(test_set, batch_size=16, collate_fn=collate_fn_padd)

    model = ConvPoolReLUClassifier(input_dim=2048).to(device)
    optimizer = optim.Adamax(model.parameters(), lr=0.01)
    criterion = nn.CrossEntropyLoss()

    trained_model = train_model(model, train_loader, val_loader, criterion, optimizer, device)

    # ---- Evaluation ----
    trained_model.eval()
    correct = total = 0
    fold_preds = []
    fold_labels = []

    with torch.no_grad():
        for x, y in test_loader:
            x, y = x.to(device), y.to(device)
            output = trained_model(x)
            preds = output.argmax(dim=1)

            fold_preds.extend(preds.cpu().numpy())
            fold_labels.extend(y.cpu().numpy())

            correct += (preds == y).sum().item()
            total += y.size(0)

    acc = correct / total
    precision = precision_score(fold_labels, fold_preds, average='macro')
    recall = recall_score(fold_labels, fold_preds, average='macro')
    f1 = f1_score(fold_labels, fold_preds, average='macro')

    print(f"Fold {fold_idx + 1} Test Accuracy: {acc:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}")
    results.append((acc, precision, recall, f1))

    # Store classification report and confusion matrix for later averaging
    report = classification_report(fold_labels, fold_preds, output_dict=True, target_names=["Healthy", "Depressed"])
    conf_matrix = confusion_matrix(fold_labels, fold_preds)

    fold_reports.append(report)
    fold_conf_matrices.append(conf_matrix)

# ---- Final Report ----
results = np.array(results)
mean_acc, mean_prec, mean_rec, mean_f1 = results.mean(axis=0)
std_acc = results[:, 0].std()
std_prec = results[:, 1].std()
std_rec = results[:, 2].std()
std_f1 = results[:, 3].std()

print(f"\n5-Fold CV Results:")
print(f"Mean Accuracy  = {mean_acc:.4f} ± {std_acc:.4f}")
print(f"Mean Precision = {mean_prec:.4f} ± {std_prec:.4f}")
print(f"Mean Recall    = {mean_rec:.4f} ± {std_rec:.4f}")
print(f"Mean F1-Score  = {mean_f1:.4f} ± {std_f1:.4f}")


100%|██████████| 1/1 [00:00<00:00, 51.81it/s]
100%|██████████| 1/1 [00:00<00:00, 46.51it/s]
100%|██████████| 1/1 [00:00<00:00, 56.50it/s]
100%|██████████| 1/1 [00:00<00:00, 50.58it/s]
100%|██████████| 1/1 [00:00<00:00, 57.79it/s]
100%|██████████| 1/1 [00:00<00:00, 59.36it/s]
100%|██████████| 1/1 [00:00<00:00, 56.01it/s]
100%|██████████| 1/1 [00:00<00:00, 47.94it/s]
100%|██████████| 1/1 [00:00<00:00, 45.90it/s]
100%|██████████| 1/1 [00:00<00:00, 57.90it/s]
100%|██████████| 1/1 [00:00<00:00, 57.83it/s]
100%|██████████| 1/1 [00:00<00:00, 52.21it/s]
100%|██████████| 1/1 [00:00<00:00, 58.82it/s]
100%|██████████| 1/1 [00:00<00:00, 57.45it/s]
100%|██████████| 1/1 [00:00<00:00, 32.85it/s]
100%|██████████| 1/1 [00:00<00:00, 55.45it/s]
100%|██████████| 1/1 [00:00<00:00, 54.66it/s]
100%|██████████| 1/1 [00:00<00:00, 53.56it/s]
100%|██████████| 1/1 [00:00<00:00, 56.24it/s]
100%|██████████| 1/1 [00:00<00:00, 54.73it/s]
100%|██████████| 1/1 [00:00<00:00, 55.15it/s]
100%|██████████| 1/1 [00:00<00:00,

Fold 1 Test Accuracy: 0.0000, Precision: 0.0000, Recall: 0.0000, F1: 0.0000


100%|██████████| 1/1 [00:00<00:00, 54.70it/s]
100%|██████████| 1/1 [00:00<00:00, 51.93it/s]
100%|██████████| 1/1 [00:00<00:00, 46.44it/s]
100%|██████████| 1/1 [00:00<00:00, 55.69it/s]
100%|██████████| 1/1 [00:00<00:00, 50.22it/s]
100%|██████████| 1/1 [00:00<00:00, 52.83it/s]
100%|██████████| 1/1 [00:00<00:00, 56.19it/s]
100%|██████████| 1/1 [00:00<00:00, 53.31it/s]
100%|██████████| 1/1 [00:00<00:00, 52.48it/s]
100%|██████████| 1/1 [00:00<00:00, 54.57it/s]
100%|██████████| 1/1 [00:00<00:00, 51.40it/s]
100%|██████████| 1/1 [00:00<00:00, 54.24it/s]
100%|██████████| 1/1 [00:00<00:00, 31.17it/s]
100%|██████████| 1/1 [00:00<00:00, 50.54it/s]
100%|██████████| 1/1 [00:00<00:00, 54.23it/s]
100%|██████████| 1/1 [00:00<00:00, 50.67it/s]
100%|██████████| 1/1 [00:00<00:00, 53.66it/s]
100%|██████████| 1/1 [00:00<00:00, 50.97it/s]
100%|██████████| 1/1 [00:00<00:00, 53.81it/s]
100%|██████████| 1/1 [00:00<00:00, 30.34it/s]
100%|██████████| 1/1 [00:00<00:00, 49.89it/s]
100%|██████████| 1/1 [00:00<00:00,

Fold 2 Test Accuracy: 0.7500, Precision: 0.8333, Recall: 0.7500, F1: 0.7333


100%|██████████| 1/1 [00:00<00:00, 50.48it/s]
100%|██████████| 1/1 [00:00<00:00, 50.00it/s]
100%|██████████| 1/1 [00:00<00:00, 48.44it/s]
100%|██████████| 1/1 [00:00<00:00, 51.33it/s]
100%|██████████| 1/1 [00:00<00:00, 71.04it/s]
100%|██████████| 1/1 [00:00<00:00, 50.30it/s]
100%|██████████| 1/1 [00:00<00:00, 57.60it/s]
100%|██████████| 1/1 [00:00<00:00, 56.97it/s]
100%|██████████| 1/1 [00:00<00:00, 53.03it/s]
100%|██████████| 1/1 [00:00<00:00, 56.05it/s]
100%|██████████| 1/1 [00:00<00:00, 30.82it/s]
100%|██████████| 1/1 [00:00<00:00, 54.84it/s]
100%|██████████| 1/1 [00:00<00:00, 60.71it/s]
100%|██████████| 1/1 [00:00<00:00, 53.43it/s]
100%|██████████| 1/1 [00:00<00:00, 56.42it/s]
100%|██████████| 1/1 [00:00<00:00, 57.89it/s]
100%|██████████| 1/1 [00:00<00:00, 55.12it/s]
100%|██████████| 1/1 [00:00<00:00, 31.44it/s]
100%|██████████| 1/1 [00:00<00:00, 53.35it/s]
100%|██████████| 1/1 [00:00<00:00, 56.53it/s]
100%|██████████| 1/1 [00:00<00:00, 53.64it/s]
100%|██████████| 1/1 [00:00<00:00,

Fold 3 Test Accuracy: 0.6250, Precision: 0.3125, Recall: 0.5000, F1: 0.3846


100%|██████████| 1/1 [00:00<00:00, 34.35it/s]
100%|██████████| 1/1 [00:00<00:00, 43.97it/s]
100%|██████████| 1/1 [00:00<00:00, 25.64it/s]
100%|██████████| 1/1 [00:00<00:00, 41.37it/s]
100%|██████████| 1/1 [00:00<00:00, 39.66it/s]
100%|██████████| 1/1 [00:00<00:00, 24.19it/s]
100%|██████████| 1/1 [00:00<00:00, 30.02it/s]
100%|██████████| 1/1 [00:00<00:00, 43.81it/s]
100%|██████████| 1/1 [00:00<00:00, 43.03it/s]
100%|██████████| 1/1 [00:00<00:00, 41.69it/s]
100%|██████████| 1/1 [00:00<00:00, 23.81it/s]
100%|██████████| 1/1 [00:00<00:00, 36.42it/s]
100%|██████████| 1/1 [00:00<00:00, 42.31it/s]
100%|██████████| 1/1 [00:00<00:00, 43.01it/s]
100%|██████████| 1/1 [00:00<00:00, 42.64it/s]
100%|██████████| 1/1 [00:00<00:00, 44.80it/s]
100%|██████████| 1/1 [00:00<00:00, 38.85it/s]
100%|██████████| 1/1 [00:00<00:00, 43.15it/s]
100%|██████████| 1/1 [00:00<00:00, 26.80it/s]
100%|██████████| 1/1 [00:00<00:00, 44.02it/s]
100%|██████████| 1/1 [00:00<00:00, 40.61it/s]
100%|██████████| 1/1 [00:00<00:00,

Fold 4 Test Accuracy: 0.5714, Precision: 0.2857, Recall: 0.5000, F1: 0.3636


100%|██████████| 1/1 [00:00<00:00, 42.85it/s]
100%|██████████| 1/1 [00:00<00:00, 43.76it/s]
100%|██████████| 1/1 [00:00<00:00, 41.94it/s]
100%|██████████| 1/1 [00:00<00:00, 43.31it/s]
100%|██████████| 1/1 [00:00<00:00, 43.96it/s]
100%|██████████| 1/1 [00:00<00:00, 43.58it/s]
100%|██████████| 1/1 [00:00<00:00, 12.09it/s]
100%|██████████| 1/1 [00:00<00:00, 12.31it/s]
100%|██████████| 1/1 [00:00<00:00, 12.26it/s]
100%|██████████| 1/1 [00:00<00:00, 24.50it/s]
100%|██████████| 1/1 [00:00<00:00, 20.05it/s]
100%|██████████| 1/1 [00:00<00:00, 32.39it/s]
100%|██████████| 1/1 [00:00<00:00, 19.91it/s]
100%|██████████| 1/1 [00:00<00:00, 20.07it/s]
100%|██████████| 1/1 [00:00<00:00, 15.73it/s]
100%|██████████| 1/1 [00:00<00:00, 22.07it/s]
100%|██████████| 1/1 [00:00<00:00, 37.83it/s]
100%|██████████| 1/1 [00:00<00:00, 30.38it/s]
100%|██████████| 1/1 [00:00<00:00, 31.07it/s]
100%|██████████| 1/1 [00:00<00:00, 41.42it/s]
100%|██████████| 1/1 [00:00<00:00, 30.10it/s]
100%|██████████| 1/1 [00:00<00:00,

Fold 5 Test Accuracy: 0.5714, Precision: 0.5500, Recall: 0.5417, F1: 0.5333

5-Fold CV Results:
Mean Accuracy  = 0.5036 ± 0.2601
Mean Precision = 0.3963 ± 0.2796
Mean Recall    = 0.4583 ± 0.2472
Mean F1-Score  = 0.4030 ± 0.2410





## **Max Pooling**

In [None]:
import torch
import torch.nn as nn

class ConvPoolReLUClassifier(nn.Module):
    def __init__(self, input_dim=2048, hidden_dim=1024, num_classes=2):
        super().__init__()
        self.conv1 = nn.Conv1d(in_channels=input_dim, out_channels=256, kernel_size=3, padding=1)
        self.pool = nn.AdaptiveMaxPool1d(1)  # Changed from Avg to Max
        self.relu = nn.ReLU()
        self.fc1 = nn.Linear(256, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        x = x.transpose(1, 2)
        x = self.conv1(x)
        x = self.pool(x)
        x = self.relu(x)
        x = x.view(x.size(0), -1)
        x = self.relu(self.fc1(x))
        return self.fc2(x)


In [None]:
import os
import numpy as np
from glob import glob
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import StratifiedKFold, train_test_split
from sklearn.metrics import (
    classification_report,
    confusion_matrix,
    precision_score,
    recall_score,
    f1_score,
)

set_seed(42)
# ---- Cross-Validation ----
with open("drive/MyDrive/thesis2025/split_dataset_june/fold_assignments.json", "r") as f:
    folds = json.load(f)

results = []
fold_reports = []
fold_conf_matrices = []

for fold_idx in range(5):
    fold_name = f"fold_{fold_idx + 1}"
    train_ids = folds[fold_name]["train"]
    val_ids   = folds[fold_name]["val"]
    test_ids  = folds[fold_name]["test"]

    train_subjs = [subject_path_map[sid] for sid in train_ids]
    val_subjs   = [subject_path_map[sid] for sid in val_ids]
    test_subjs  = [subject_path_map[sid] for sid in test_ids]

    train_set = EEGAudioDataset(train_subjs)
    val_set = EEGAudioDataset(val_subjs)
    test_set = EEGAudioDataset(test_subjs)

    train_loader = DataLoader(train_set, batch_size=16, shuffle=True, collate_fn=collate_fn_padd)
    val_loader = DataLoader(val_set, batch_size=16, collate_fn=collate_fn_padd)
    test_loader = DataLoader(test_set, batch_size=16, collate_fn=collate_fn_padd)

    model = ConvPoolReLUClassifier(input_dim=2048).to(device)
    optimizer = optim.Adamax(model.parameters(), lr=0.01)
    criterion = nn.CrossEntropyLoss()

    trained_model = train_model(model, train_loader, val_loader, criterion, optimizer, device)

    # ---- Evaluation ----
    trained_model.eval()
    correct = total = 0
    fold_preds = []
    fold_labels = []

    with torch.no_grad():
        for x, y in test_loader:
            x, y = x.to(device), y.to(device)
            output = trained_model(x)
            preds = output.argmax(dim=1)

            fold_preds.extend(preds.cpu().numpy())
            fold_labels.extend(y.cpu().numpy())

            correct += (preds == y).sum().item()
            total += y.size(0)

    acc = correct / total
    precision = precision_score(fold_labels, fold_preds, average='macro')
    recall = recall_score(fold_labels, fold_preds, average='macro')
    f1 = f1_score(fold_labels, fold_preds, average='macro')

    print(f"Fold {fold_idx + 1} Test Accuracy: {acc:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}")
    results.append((acc, precision, recall, f1))

    # Store classification report and confusion matrix for later averaging
    report = classification_report(fold_labels, fold_preds, output_dict=True, target_names=["Healthy", "Depressed"])
    conf_matrix = confusion_matrix(fold_labels, fold_preds)

    fold_reports.append(report)
    fold_conf_matrices.append(conf_matrix)

# ---- Final Report ----
results = np.array(results)
mean_acc, mean_prec, mean_rec, mean_f1 = results.mean(axis=0)
std_acc = results[:, 0].std()
std_prec = results[:, 1].std()
std_rec = results[:, 2].std()
std_f1 = results[:, 3].std()

print(f"\n5-Fold CV Results:")
print(f"Mean Accuracy  = {mean_acc:.4f} ± {std_acc:.4f}")
print(f"Mean Precision = {mean_prec:.4f} ± {std_prec:.4f}")
print(f"Mean Recall    = {mean_rec:.4f} ± {std_rec:.4f}")
print(f"Mean F1-Score  = {mean_f1:.4f} ± {std_f1:.4f}")


100%|██████████| 1/1 [00:00<00:00, 59.75it/s]
100%|██████████| 1/1 [00:00<00:00, 53.05it/s]
100%|██████████| 1/1 [00:00<00:00, 52.24it/s]
100%|██████████| 1/1 [00:00<00:00, 51.92it/s]
100%|██████████| 1/1 [00:00<00:00, 30.43it/s]
100%|██████████| 1/1 [00:00<00:00, 58.38it/s]
100%|██████████| 1/1 [00:00<00:00, 53.86it/s]
100%|██████████| 1/1 [00:00<00:00, 53.14it/s]
100%|██████████| 1/1 [00:00<00:00, 55.83it/s]
100%|██████████| 1/1 [00:00<00:00, 57.73it/s]
100%|██████████| 1/1 [00:00<00:00, 48.03it/s]
100%|██████████| 1/1 [00:00<00:00, 49.34it/s]
100%|██████████| 1/1 [00:00<00:00, 45.43it/s]
100%|██████████| 1/1 [00:00<00:00, 55.79it/s]
100%|██████████| 1/1 [00:00<00:00, 59.13it/s]
100%|██████████| 1/1 [00:00<00:00, 47.92it/s]
100%|██████████| 1/1 [00:00<00:00, 53.90it/s]
100%|██████████| 1/1 [00:00<00:00, 43.57it/s]
100%|██████████| 1/1 [00:00<00:00, 49.84it/s]
100%|██████████| 1/1 [00:00<00:00, 45.21it/s]
100%|██████████| 1/1 [00:00<00:00, 58.26it/s]
100%|██████████| 1/1 [00:00<00:00,

Fold 1 Test Accuracy: 0.1250, Precision: 0.1000, Recall: 0.1250, F1: 0.1111


100%|██████████| 1/1 [00:00<00:00, 52.34it/s]
100%|██████████| 1/1 [00:00<00:00, 55.47it/s]
100%|██████████| 1/1 [00:00<00:00, 31.77it/s]
100%|██████████| 1/1 [00:00<00:00, 55.15it/s]
100%|██████████| 1/1 [00:00<00:00, 50.97it/s]
100%|██████████| 1/1 [00:00<00:00, 57.41it/s]
100%|██████████| 1/1 [00:00<00:00, 56.53it/s]
100%|██████████| 1/1 [00:00<00:00, 56.95it/s]
100%|██████████| 1/1 [00:00<00:00, 44.26it/s]
100%|██████████| 1/1 [00:00<00:00, 54.64it/s]
100%|██████████| 1/1 [00:00<00:00, 54.39it/s]
100%|██████████| 1/1 [00:00<00:00, 49.80it/s]
100%|██████████| 1/1 [00:00<00:00, 56.80it/s]
100%|██████████| 1/1 [00:00<00:00, 59.78it/s]
100%|██████████| 1/1 [00:00<00:00, 55.58it/s]
100%|██████████| 1/1 [00:00<00:00, 55.86it/s]
100%|██████████| 1/1 [00:00<00:00, 53.96it/s]
100%|██████████| 1/1 [00:00<00:00, 56.63it/s]
100%|██████████| 1/1 [00:00<00:00, 37.42it/s]
100%|██████████| 1/1 [00:00<00:00, 41.58it/s]
100%|██████████| 1/1 [00:00<00:00, 38.39it/s]
100%|██████████| 1/1 [00:00<00:00,

Fold 2 Test Accuracy: 0.5000, Precision: 0.5000, Recall: 0.5000, F1: 0.5000


100%|██████████| 1/1 [00:00<00:00, 31.18it/s]
100%|██████████| 1/1 [00:00<00:00, 55.83it/s]
100%|██████████| 1/1 [00:00<00:00, 33.20it/s]
100%|██████████| 1/1 [00:00<00:00, 42.61it/s]
100%|██████████| 1/1 [00:00<00:00, 57.59it/s]
100%|██████████| 1/1 [00:00<00:00, 59.84it/s]
100%|██████████| 1/1 [00:00<00:00, 58.14it/s]
100%|██████████| 1/1 [00:00<00:00, 27.77it/s]
100%|██████████| 1/1 [00:00<00:00, 57.19it/s]
100%|██████████| 1/1 [00:00<00:00, 54.26it/s]
100%|██████████| 1/1 [00:00<00:00, 54.32it/s]
100%|██████████| 1/1 [00:00<00:00, 14.90it/s]
100%|██████████| 1/1 [00:00<00:00, 25.57it/s]
100%|██████████| 1/1 [00:00<00:00, 39.14it/s]
100%|██████████| 1/1 [00:00<00:00, 17.17it/s]
100%|██████████| 1/1 [00:00<00:00, 35.49it/s]
100%|██████████| 1/1 [00:00<00:00, 47.50it/s]
100%|██████████| 1/1 [00:00<00:00, 56.39it/s]
100%|██████████| 1/1 [00:00<00:00, 52.02it/s]
100%|██████████| 1/1 [00:00<00:00, 55.63it/s]
100%|██████████| 1/1 [00:00<00:00, 45.74it/s]
100%|██████████| 1/1 [00:00<00:00,

Fold 3 Test Accuracy: 0.8750, Precision: 0.9167, Recall: 0.8333, F1: 0.8545


100%|██████████| 1/1 [00:00<00:00, 26.94it/s]
100%|██████████| 1/1 [00:00<00:00, 39.41it/s]
100%|██████████| 1/1 [00:00<00:00, 35.04it/s]
100%|██████████| 1/1 [00:00<00:00, 19.14it/s]
100%|██████████| 1/1 [00:00<00:00, 39.50it/s]
100%|██████████| 1/1 [00:00<00:00, 44.41it/s]
100%|██████████| 1/1 [00:00<00:00, 40.44it/s]
100%|██████████| 1/1 [00:00<00:00, 25.38it/s]
100%|██████████| 1/1 [00:00<00:00, 41.76it/s]
100%|██████████| 1/1 [00:00<00:00, 41.91it/s]
100%|██████████| 1/1 [00:00<00:00, 37.42it/s]
100%|██████████| 1/1 [00:00<00:00, 37.70it/s]
100%|██████████| 1/1 [00:00<00:00, 43.96it/s]
100%|██████████| 1/1 [00:00<00:00, 20.73it/s]
100%|██████████| 1/1 [00:00<00:00, 13.65it/s]
100%|██████████| 1/1 [00:00<00:00, 25.72it/s]
100%|██████████| 1/1 [00:00<00:00, 43.20it/s]
100%|██████████| 1/1 [00:00<00:00, 29.21it/s]
100%|██████████| 1/1 [00:00<00:00, 26.51it/s]
100%|██████████| 1/1 [00:00<00:00, 36.45it/s]
100%|██████████| 1/1 [00:00<00:00, 37.35it/s]
100%|██████████| 1/1 [00:00<00:00,

Fold 4 Test Accuracy: 0.5714, Precision: 0.5500, Recall: 0.5417, F1: 0.5333


100%|██████████| 1/1 [00:00<00:00, 40.83it/s]
100%|██████████| 1/1 [00:00<00:00, 41.20it/s]
100%|██████████| 1/1 [00:00<00:00, 29.37it/s]
100%|██████████| 1/1 [00:00<00:00, 40.33it/s]
100%|██████████| 1/1 [00:00<00:00, 26.50it/s]
100%|██████████| 1/1 [00:00<00:00, 27.38it/s]
100%|██████████| 1/1 [00:00<00:00, 44.73it/s]
100%|██████████| 1/1 [00:00<00:00, 43.31it/s]
100%|██████████| 1/1 [00:00<00:00, 40.93it/s]
100%|██████████| 1/1 [00:00<00:00, 42.81it/s]
100%|██████████| 1/1 [00:00<00:00, 17.36it/s]
100%|██████████| 1/1 [00:00<00:00, 16.99it/s]
100%|██████████| 1/1 [00:00<00:00,  8.58it/s]
100%|██████████| 1/1 [00:00<00:00, 29.11it/s]
100%|██████████| 1/1 [00:00<00:00, 41.37it/s]
100%|██████████| 1/1 [00:00<00:00, 35.55it/s]
100%|██████████| 1/1 [00:00<00:00, 18.76it/s]
100%|██████████| 1/1 [00:00<00:00, 34.22it/s]
100%|██████████| 1/1 [00:00<00:00, 24.73it/s]
100%|██████████| 1/1 [00:00<00:00, 30.58it/s]
100%|██████████| 1/1 [00:00<00:00, 25.15it/s]
100%|██████████| 1/1 [00:00<00:00,

Fold 5 Test Accuracy: 0.4286, Precision: 0.4167, Recall: 0.4167, F1: 0.4167

5-Fold CV Results:
Mean Accuracy  = 0.5000 ± 0.2414
Mean Precision = 0.4967 ± 0.2619
Mean Recall    = 0.4833 ± 0.2276
Mean F1-Score  = 0.4831 ± 0.2381


## **GlobalAvg**

In [None]:
import torch
import torch.nn as nn

class ConvPoolReLUClassifier(nn.Module):
    def __init__(self, input_dim=2048, hidden_dim=1024, num_classes=2):
        super().__init__()
        self.conv1 = nn.Conv1d(in_channels=input_dim, out_channels=256, kernel_size=3, padding=1)
        self.relu = nn.ReLU()
        self.fc1 = nn.Linear(256, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        x = x.transpose(1, 2)        # (batch, input_dim, seq_len)
        x = self.conv1(x)            # (batch, 256, seq_len)
        x = self.relu(x)
        x = torch.mean(x, dim=2)     # Global average pooling across time dimension
        x = self.relu(self.fc1(x))   # (batch, hidden_dim)
        return self.fc2(x)           # (batch, num_classes)


In [None]:
import os
import numpy as np
from glob import glob
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import StratifiedKFold, train_test_split
from sklearn.metrics import (
    classification_report,
    confusion_matrix,
    precision_score,
    recall_score,
    f1_score,
)

set_seed(42)
# ---- Cross-Validation ----
with open("drive/MyDrive/thesis2025/split_dataset_june/fold_assignments.json", "r") as f:
    folds = json.load(f)

results = []
fold_reports = []
fold_conf_matrices = []

for fold_idx in range(5):
    fold_name = f"fold_{fold_idx + 1}"
    train_ids = folds[fold_name]["train"]
    val_ids   = folds[fold_name]["val"]
    test_ids  = folds[fold_name]["test"]

    train_subjs = [subject_path_map[sid] for sid in train_ids]
    val_subjs   = [subject_path_map[sid] for sid in val_ids]
    test_subjs  = [subject_path_map[sid] for sid in test_ids]

    train_set = EEGAudioDataset(train_subjs)
    val_set = EEGAudioDataset(val_subjs)
    test_set = EEGAudioDataset(test_subjs)

    train_loader = DataLoader(train_set, batch_size=16, shuffle=True, collate_fn=collate_fn_padd)
    val_loader = DataLoader(val_set, batch_size=16, collate_fn=collate_fn_padd)
    test_loader = DataLoader(test_set, batch_size=16, collate_fn=collate_fn_padd)

    model = ConvPoolReLUClassifier(input_dim=2048).to(device)
    optimizer = optim.Adamax(model.parameters(), lr=0.01)
    criterion = nn.CrossEntropyLoss()

    trained_model = train_model(model, train_loader, val_loader, criterion, optimizer, device)

    # ---- Evaluation ----
    trained_model.eval()
    correct = total = 0
    fold_preds = []
    fold_labels = []

    with torch.no_grad():
        for x, y in test_loader:
            x, y = x.to(device), y.to(device)
            output = trained_model(x)
            preds = output.argmax(dim=1)

            fold_preds.extend(preds.cpu().numpy())
            fold_labels.extend(y.cpu().numpy())

            correct += (preds == y).sum().item()
            total += y.size(0)

    acc = correct / total
    precision = precision_score(fold_labels, fold_preds, average='macro')
    recall = recall_score(fold_labels, fold_preds, average='macro')
    f1 = f1_score(fold_labels, fold_preds, average='macro')

    print(f"Fold {fold_idx + 1} Test Accuracy: {acc:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}")
    results.append((acc, precision, recall, f1))

    # Store classification report and confusion matrix for later averaging
    report = classification_report(fold_labels, fold_preds, output_dict=True, target_names=["Healthy", "Depressed"])
    conf_matrix = confusion_matrix(fold_labels, fold_preds)

    fold_reports.append(report)
    fold_conf_matrices.append(conf_matrix)

# ---- Final Report ----
results = np.array(results)
mean_acc, mean_prec, mean_rec, mean_f1 = results.mean(axis=0)
std_acc = results[:, 0].std()
std_prec = results[:, 1].std()
std_rec = results[:, 2].std()
std_f1 = results[:, 3].std()

print(f"\n5-Fold CV Results:")
print(f"Mean Accuracy  = {mean_acc:.4f} ± {std_acc:.4f}")
print(f"Mean Precision = {mean_prec:.4f} ± {std_prec:.4f}")
print(f"Mean Recall    = {mean_rec:.4f} ± {std_rec:.4f}")
print(f"Mean F1-Score  = {mean_f1:.4f} ± {std_f1:.4f}")


100%|██████████| 1/1 [00:00<00:00, 28.41it/s]
100%|██████████| 1/1 [00:00<00:00, 52.84it/s]
100%|██████████| 1/1 [00:00<00:00, 40.53it/s]
100%|██████████| 1/1 [00:00<00:00, 49.76it/s]
100%|██████████| 1/1 [00:00<00:00, 43.02it/s]
100%|██████████| 1/1 [00:00<00:00, 50.09it/s]
100%|██████████| 1/1 [00:00<00:00, 50.13it/s]
100%|██████████| 1/1 [00:00<00:00, 27.45it/s]
100%|██████████| 1/1 [00:00<00:00, 49.25it/s]
100%|██████████| 1/1 [00:00<00:00, 55.54it/s]
100%|██████████| 1/1 [00:00<00:00, 55.56it/s]
100%|██████████| 1/1 [00:00<00:00, 53.65it/s]
100%|██████████| 1/1 [00:00<00:00, 52.42it/s]
100%|██████████| 1/1 [00:00<00:00, 58.43it/s]
100%|██████████| 1/1 [00:00<00:00, 54.74it/s]
100%|██████████| 1/1 [00:00<00:00, 54.09it/s]
100%|██████████| 1/1 [00:00<00:00, 53.66it/s]
100%|██████████| 1/1 [00:00<00:00, 51.48it/s]
100%|██████████| 1/1 [00:00<00:00, 44.93it/s]
100%|██████████| 1/1 [00:00<00:00, 55.95it/s]
100%|██████████| 1/1 [00:00<00:00, 53.90it/s]
100%|██████████| 1/1 [00:00<00:00,

Fold 1 Test Accuracy: 0.1250, Precision: 0.1000, Recall: 0.1250, F1: 0.1111


100%|██████████| 1/1 [00:00<00:00, 55.97it/s]
100%|██████████| 1/1 [00:00<00:00, 55.80it/s]
100%|██████████| 1/1 [00:00<00:00, 52.59it/s]
100%|██████████| 1/1 [00:00<00:00, 51.20it/s]
100%|██████████| 1/1 [00:00<00:00, 47.34it/s]
100%|██████████| 1/1 [00:00<00:00, 51.40it/s]
100%|██████████| 1/1 [00:00<00:00, 27.75it/s]
100%|██████████| 1/1 [00:00<00:00, 53.20it/s]
100%|██████████| 1/1 [00:00<00:00, 44.78it/s]
100%|██████████| 1/1 [00:00<00:00, 42.70it/s]
100%|██████████| 1/1 [00:00<00:00, 33.07it/s]
100%|██████████| 1/1 [00:00<00:00, 34.70it/s]
100%|██████████| 1/1 [00:00<00:00, 39.31it/s]
100%|██████████| 1/1 [00:00<00:00, 42.26it/s]
100%|██████████| 1/1 [00:00<00:00, 36.06it/s]
100%|██████████| 1/1 [00:00<00:00, 41.15it/s]
100%|██████████| 1/1 [00:00<00:00, 38.58it/s]
100%|██████████| 1/1 [00:00<00:00, 18.39it/s]
100%|██████████| 1/1 [00:00<00:00, 46.73it/s]
100%|██████████| 1/1 [00:00<00:00, 33.95it/s]
100%|██████████| 1/1 [00:00<00:00, 24.78it/s]
100%|██████████| 1/1 [00:00<00:00,

Fold 2 Test Accuracy: 0.6250, Precision: 0.7857, Recall: 0.6250, F1: 0.5636


100%|██████████| 1/1 [00:00<00:00, 60.27it/s]
100%|██████████| 1/1 [00:00<00:00, 53.63it/s]
100%|██████████| 1/1 [00:00<00:00, 50.43it/s]
100%|██████████| 1/1 [00:00<00:00, 60.70it/s]
100%|██████████| 1/1 [00:00<00:00, 30.03it/s]
100%|██████████| 1/1 [00:00<00:00, 32.57it/s]
100%|██████████| 1/1 [00:00<00:00, 59.58it/s]
100%|██████████| 1/1 [00:00<00:00, 56.99it/s]
100%|██████████| 1/1 [00:00<00:00, 55.05it/s]
100%|██████████| 1/1 [00:00<00:00, 50.47it/s]
100%|██████████| 1/1 [00:00<00:00, 56.94it/s]
100%|██████████| 1/1 [00:00<00:00, 57.02it/s]
100%|██████████| 1/1 [00:00<00:00, 30.16it/s]
100%|██████████| 1/1 [00:00<00:00, 52.01it/s]
100%|██████████| 1/1 [00:00<00:00, 49.83it/s]
100%|██████████| 1/1 [00:00<00:00, 18.71it/s]
100%|██████████| 1/1 [00:00<00:00, 43.23it/s]
100%|██████████| 1/1 [00:00<00:00, 31.04it/s]
100%|██████████| 1/1 [00:00<00:00, 31.02it/s]
100%|██████████| 1/1 [00:00<00:00, 25.38it/s]
100%|██████████| 1/1 [00:00<00:00, 46.37it/s]
100%|██████████| 1/1 [00:00<00:00,

Fold 3 Test Accuracy: 0.8750, Precision: 0.9167, Recall: 0.8333, F1: 0.8545


100%|██████████| 1/1 [00:00<00:00, 40.57it/s]
100%|██████████| 1/1 [00:00<00:00, 37.06it/s]
100%|██████████| 1/1 [00:00<00:00, 32.61it/s]
100%|██████████| 1/1 [00:00<00:00, 28.36it/s]
100%|██████████| 1/1 [00:00<00:00, 22.55it/s]
100%|██████████| 1/1 [00:00<00:00, 41.32it/s]
100%|██████████| 1/1 [00:00<00:00, 40.38it/s]
100%|██████████| 1/1 [00:00<00:00, 35.86it/s]
100%|██████████| 1/1 [00:00<00:00, 20.11it/s]
100%|██████████| 1/1 [00:00<00:00, 21.88it/s]
100%|██████████| 1/1 [00:00<00:00, 30.18it/s]
100%|██████████| 1/1 [00:00<00:00, 29.07it/s]
100%|██████████| 1/1 [00:00<00:00,  5.37it/s]
100%|██████████| 1/1 [00:00<00:00, 21.21it/s]
100%|██████████| 1/1 [00:00<00:00, 16.04it/s]
100%|██████████| 1/1 [00:00<00:00, 14.60it/s]
100%|██████████| 1/1 [00:00<00:00, 10.95it/s]
100%|██████████| 1/1 [00:00<00:00, 30.92it/s]
100%|██████████| 1/1 [00:00<00:00, 31.88it/s]
100%|██████████| 1/1 [00:00<00:00, 32.45it/s]
100%|██████████| 1/1 [00:00<00:00, 33.72it/s]
100%|██████████| 1/1 [00:00<00:00,

Fold 4 Test Accuracy: 0.7143, Precision: 0.7083, Recall: 0.7083, F1: 0.7083


100%|██████████| 1/1 [00:00<00:00, 40.79it/s]
100%|██████████| 1/1 [00:00<00:00, 20.40it/s]
100%|██████████| 1/1 [00:00<00:00, 41.93it/s]
100%|██████████| 1/1 [00:00<00:00, 42.63it/s]
100%|██████████| 1/1 [00:00<00:00, 41.30it/s]
100%|██████████| 1/1 [00:00<00:00, 42.99it/s]
100%|██████████| 1/1 [00:00<00:00, 42.50it/s]
100%|██████████| 1/1 [00:00<00:00, 31.29it/s]
100%|██████████| 1/1 [00:00<00:00, 40.90it/s]
100%|██████████| 1/1 [00:00<00:00, 43.76it/s]
100%|██████████| 1/1 [00:00<00:00, 25.63it/s]
100%|██████████| 1/1 [00:00<00:00, 41.16it/s]
100%|██████████| 1/1 [00:00<00:00, 16.93it/s]
100%|██████████| 1/1 [00:00<00:00, 40.25it/s]
100%|██████████| 1/1 [00:00<00:00, 35.30it/s]
100%|██████████| 1/1 [00:00<00:00, 42.66it/s]
100%|██████████| 1/1 [00:00<00:00, 30.31it/s]
100%|██████████| 1/1 [00:00<00:00, 34.69it/s]
100%|██████████| 1/1 [00:00<00:00, 21.06it/s]
100%|██████████| 1/1 [00:00<00:00, 40.74it/s]
100%|██████████| 1/1 [00:00<00:00, 40.14it/s]
100%|██████████| 1/1 [00:00<00:00,

Fold 5 Test Accuracy: 0.4286, Precision: 0.4167, Recall: 0.4167, F1: 0.4167

5-Fold CV Results:
Mean Accuracy  = 0.5536 ± 0.2583
Mean Precision = 0.5855 ± 0.2929
Mean Recall    = 0.5417 ± 0.2486
Mean F1-Score  = 0.5309 ± 0.2556





## **Attention Pooling**

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class AttentionPooling(nn.Module):
    def __init__(self, input_dim):
        super().__init__()
        self.attn_weights = nn.Linear(input_dim, 1)

    def forward(self, x):
        # x: (batch, seq_len, dim)
        scores = self.attn_weights(x)             # (batch, seq_len, 1)
        weights = F.softmax(scores, dim=1)        # (batch, seq_len, 1)
        pooled = torch.sum(weights * x, dim=1)    # (batch, dim)
        return pooled

class ConvPoolReLUClassifier(nn.Module):
    def __init__(self, input_dim=2048, hidden_dim=1024, num_classes=2):
        super().__init__()
        self.conv1 = nn.Conv1d(in_channels=input_dim, out_channels=256, kernel_size=3, padding=1)
        self.relu = nn.ReLU()
        self.attn_pool = AttentionPooling(input_dim=256)
        self.fc1 = nn.Linear(256, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        x = x.transpose(1, 2)                    # (batch, input_dim, seq_len)
        x = self.conv1(x)                        # (batch, 256, seq_len)
        x = self.relu(x)
        x = x.transpose(1, 2)                    # (batch, seq_len, 256)
        x = self.attn_pool(x)                    # (batch, 256)
        x = self.relu(self.fc1(x))               # (batch, hidden_dim)
        return self.fc2(x)                       # (batch, num_classes)


In [None]:
import os
import numpy as np
from glob import glob
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import StratifiedKFold, train_test_split
from sklearn.metrics import (
    classification_report,
    confusion_matrix,
    precision_score,
    recall_score,
    f1_score,
)

set_seed(42)
# ---- Cross-Validation ----
with open("drive/MyDrive/thesis2025/split_dataset_june/fold_assignments.json", "r") as f:
    folds = json.load(f)

results = []
fold_reports = []
fold_conf_matrices = []

for fold_idx in range(5):
    fold_name = f"fold_{fold_idx + 1}"
    train_ids = folds[fold_name]["train"]
    val_ids   = folds[fold_name]["val"]
    test_ids  = folds[fold_name]["test"]

    train_subjs = [subject_path_map[sid] for sid in train_ids]
    val_subjs   = [subject_path_map[sid] for sid in val_ids]
    test_subjs  = [subject_path_map[sid] for sid in test_ids]

    train_set = EEGAudioDataset(train_subjs)
    val_set = EEGAudioDataset(val_subjs)
    test_set = EEGAudioDataset(test_subjs)

    train_loader = DataLoader(train_set, batch_size=16, shuffle=True, collate_fn=collate_fn_padd)
    val_loader = DataLoader(val_set, batch_size=16, collate_fn=collate_fn_padd)
    test_loader = DataLoader(test_set, batch_size=16, collate_fn=collate_fn_padd)

    model = ConvPoolReLUClassifier(input_dim=2048).to(device)
    optimizer = optim.Adamax(model.parameters(), lr=0.01)
    criterion = nn.CrossEntropyLoss()

    trained_model = train_model(model, train_loader, val_loader, criterion, optimizer, device)

    # ---- Evaluation ----
    trained_model.eval()
    correct = total = 0
    fold_preds = []
    fold_labels = []

    with torch.no_grad():
        for x, y in test_loader:
            x, y = x.to(device), y.to(device)
            output = trained_model(x)
            preds = output.argmax(dim=1)

            fold_preds.extend(preds.cpu().numpy())
            fold_labels.extend(y.cpu().numpy())

            correct += (preds == y).sum().item()
            total += y.size(0)

    acc = correct / total
    precision = precision_score(fold_labels, fold_preds, average='macro')
    recall = recall_score(fold_labels, fold_preds, average='macro')
    f1 = f1_score(fold_labels, fold_preds, average='macro')

    print(f"Fold {fold_idx + 1} Test Accuracy: {acc:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}")
    results.append((acc, precision, recall, f1))

    # Store classification report and confusion matrix for later averaging
    report = classification_report(fold_labels, fold_preds, output_dict=True, target_names=["Healthy", "Depressed"])
    conf_matrix = confusion_matrix(fold_labels, fold_preds)

    fold_reports.append(report)
    fold_conf_matrices.append(conf_matrix)

# ---- Final Report ----
results = np.array(results)
mean_acc, mean_prec, mean_rec, mean_f1 = results.mean(axis=0)
std_acc = results[:, 0].std()
std_prec = results[:, 1].std()
std_rec = results[:, 2].std()
std_f1 = results[:, 3].std()

print(f"\n5-Fold CV Results:")
print(f"Mean Accuracy  = {mean_acc:.4f} ± {std_acc:.4f}")
print(f"Mean Precision = {mean_prec:.4f} ± {std_prec:.4f}")
print(f"Mean Recall    = {mean_rec:.4f} ± {std_rec:.4f}")
print(f"Mean F1-Score  = {mean_f1:.4f} ± {std_f1:.4f}")


100%|██████████| 1/1 [00:00<00:00, 28.68it/s]
100%|██████████| 1/1 [00:00<00:00, 52.02it/s]
100%|██████████| 1/1 [00:00<00:00, 31.18it/s]
100%|██████████| 1/1 [00:00<00:00, 33.50it/s]
100%|██████████| 1/1 [00:00<00:00, 35.86it/s]
100%|██████████| 1/1 [00:00<00:00, 51.87it/s]
100%|██████████| 1/1 [00:00<00:00, 11.94it/s]
100%|██████████| 1/1 [00:00<00:00, 30.80it/s]
100%|██████████| 1/1 [00:00<00:00, 38.20it/s]
100%|██████████| 1/1 [00:00<00:00, 56.25it/s]
100%|██████████| 1/1 [00:00<00:00, 64.39it/s]
100%|██████████| 1/1 [00:00<00:00, 52.12it/s]
100%|██████████| 1/1 [00:00<00:00, 59.66it/s]
100%|██████████| 1/1 [00:00<00:00, 20.60it/s]
100%|██████████| 1/1 [00:00<00:00, 59.46it/s]
100%|██████████| 1/1 [00:00<00:00, 60.33it/s]
100%|██████████| 1/1 [00:00<00:00, 53.34it/s]
100%|██████████| 1/1 [00:00<00:00, 57.59it/s]
100%|██████████| 1/1 [00:00<00:00, 34.48it/s]
100%|██████████| 1/1 [00:00<00:00, 54.63it/s]
100%|██████████| 1/1 [00:00<00:00, 28.86it/s]
100%|██████████| 1/1 [00:00<00:00,

Fold 1 Test Accuracy: 0.2500, Precision: 0.1667, Recall: 0.2500, F1: 0.2000


100%|██████████| 1/1 [00:00<00:00, 54.87it/s]
100%|██████████| 1/1 [00:00<00:00, 58.33it/s]
100%|██████████| 1/1 [00:00<00:00, 51.49it/s]
100%|██████████| 1/1 [00:00<00:00, 27.47it/s]
100%|██████████| 1/1 [00:00<00:00, 50.25it/s]
100%|██████████| 1/1 [00:00<00:00, 33.37it/s]
100%|██████████| 1/1 [00:00<00:00, 56.58it/s]
100%|██████████| 1/1 [00:00<00:00, 54.49it/s]
100%|██████████| 1/1 [00:00<00:00, 54.19it/s]
100%|██████████| 1/1 [00:00<00:00, 55.47it/s]
100%|██████████| 1/1 [00:00<00:00, 55.49it/s]
100%|██████████| 1/1 [00:00<00:00, 53.66it/s]
100%|██████████| 1/1 [00:00<00:00, 30.90it/s]
100%|██████████| 1/1 [00:00<00:00, 54.69it/s]
100%|██████████| 1/1 [00:00<00:00, 52.08it/s]
100%|██████████| 1/1 [00:00<00:00, 55.31it/s]
100%|██████████| 1/1 [00:00<00:00, 57.11it/s]
100%|██████████| 1/1 [00:00<00:00, 50.29it/s]
100%|██████████| 1/1 [00:00<00:00, 53.39it/s]
100%|██████████| 1/1 [00:00<00:00, 30.75it/s]
100%|██████████| 1/1 [00:00<00:00, 56.72it/s]
100%|██████████| 1/1 [00:00<00:00,

Fold 2 Test Accuracy: 0.5000, Precision: 0.5000, Recall: 0.5000, F1: 0.5000


100%|██████████| 1/1 [00:00<00:00, 49.44it/s]
100%|██████████| 1/1 [00:00<00:00, 47.75it/s]
100%|██████████| 1/1 [00:00<00:00, 55.76it/s]
100%|██████████| 1/1 [00:00<00:00, 41.67it/s]
100%|██████████| 1/1 [00:00<00:00, 37.42it/s]
100%|██████████| 1/1 [00:00<00:00, 30.04it/s]
100%|██████████| 1/1 [00:00<00:00, 32.81it/s]
100%|██████████| 1/1 [00:00<00:00, 26.22it/s]
100%|██████████| 1/1 [00:00<00:00, 40.02it/s]
100%|██████████| 1/1 [00:00<00:00, 34.65it/s]
100%|██████████| 1/1 [00:00<00:00, 29.33it/s]
100%|██████████| 1/1 [00:00<00:00, 54.07it/s]
100%|██████████| 1/1 [00:00<00:00, 26.26it/s]
100%|██████████| 1/1 [00:00<00:00, 20.27it/s]
100%|██████████| 1/1 [00:00<00:00, 55.80it/s]
100%|██████████| 1/1 [00:00<00:00, 56.24it/s]
100%|██████████| 1/1 [00:00<00:00, 51.75it/s]
100%|██████████| 1/1 [00:00<00:00, 31.04it/s]
100%|██████████| 1/1 [00:00<00:00, 34.06it/s]
100%|██████████| 1/1 [00:00<00:00, 63.41it/s]
100%|██████████| 1/1 [00:00<00:00, 60.85it/s]
100%|██████████| 1/1 [00:00<00:00,

Fold 3 Test Accuracy: 0.5000, Precision: 0.2857, Recall: 0.4000, F1: 0.3333


100%|██████████| 1/1 [00:00<00:00, 42.90it/s]
100%|██████████| 1/1 [00:00<00:00, 44.08it/s]
100%|██████████| 1/1 [00:00<00:00, 40.35it/s]
100%|██████████| 1/1 [00:00<00:00, 26.64it/s]
100%|██████████| 1/1 [00:00<00:00, 44.11it/s]
100%|██████████| 1/1 [00:00<00:00, 41.40it/s]
100%|██████████| 1/1 [00:00<00:00, 42.12it/s]
100%|██████████| 1/1 [00:00<00:00, 44.39it/s]
100%|██████████| 1/1 [00:00<00:00, 38.92it/s]
100%|██████████| 1/1 [00:00<00:00, 41.38it/s]
100%|██████████| 1/1 [00:00<00:00, 42.18it/s]
100%|██████████| 1/1 [00:00<00:00, 27.93it/s]
100%|██████████| 1/1 [00:00<00:00, 39.69it/s]
100%|██████████| 1/1 [00:00<00:00, 42.78it/s]
100%|██████████| 1/1 [00:00<00:00, 40.49it/s]
100%|██████████| 1/1 [00:00<00:00, 43.48it/s]
100%|██████████| 1/1 [00:00<00:00, 37.29it/s]
100%|██████████| 1/1 [00:00<00:00, 41.79it/s]
100%|██████████| 1/1 [00:00<00:00, 44.77it/s]
100%|██████████| 1/1 [00:00<00:00, 41.18it/s]
100%|██████████| 1/1 [00:00<00:00, 27.07it/s]
100%|██████████| 1/1 [00:00<00:00,

Fold 4 Test Accuracy: 0.5714, Precision: 0.2857, Recall: 0.5000, F1: 0.3636


100%|██████████| 1/1 [00:00<00:00, 27.14it/s]
100%|██████████| 1/1 [00:00<00:00, 44.05it/s]
100%|██████████| 1/1 [00:00<00:00, 37.65it/s]
100%|██████████| 1/1 [00:00<00:00, 41.41it/s]
100%|██████████| 1/1 [00:00<00:00, 41.05it/s]
100%|██████████| 1/1 [00:00<00:00, 28.74it/s]
100%|██████████| 1/1 [00:00<00:00, 10.69it/s]
100%|██████████| 1/1 [00:00<00:00, 30.52it/s]
100%|██████████| 1/1 [00:00<00:00, 21.02it/s]
100%|██████████| 1/1 [00:00<00:00, 30.80it/s]
100%|██████████| 1/1 [00:00<00:00, 13.54it/s]
100%|██████████| 1/1 [00:00<00:00, 32.97it/s]
100%|██████████| 1/1 [00:00<00:00, 26.72it/s]
100%|██████████| 1/1 [00:00<00:00, 19.14it/s]
100%|██████████| 1/1 [00:00<00:00,  9.36it/s]
100%|██████████| 1/1 [00:00<00:00, 27.14it/s]
100%|██████████| 1/1 [00:00<00:00, 25.56it/s]
100%|██████████| 1/1 [00:00<00:00, 42.81it/s]
100%|██████████| 1/1 [00:00<00:00, 37.26it/s]
100%|██████████| 1/1 [00:00<00:00, 45.34it/s]
100%|██████████| 1/1 [00:00<00:00, 43.59it/s]
100%|██████████| 1/1 [00:00<00:00,

Fold 5 Test Accuracy: 0.2857, Precision: 0.2000, Recall: 0.2500, F1: 0.2222

5-Fold CV Results:
Mean Accuracy  = 0.4214 ± 0.1286
Mean Precision = 0.2876 ± 0.1161
Mean Recall    = 0.3800 ± 0.1122
Mean F1-Score  = 0.3238 ± 0.1080


### **After Manual Hyperparameter Tuning**

In [None]:
import os
import numpy as np
from glob import glob
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import StratifiedKFold, train_test_split
from sklearn.metrics import (
    classification_report,
    confusion_matrix,
    precision_score,
    recall_score,
    f1_score,
)

set_seed(42)
# ---- Cross-Validation ----
with open("drive/MyDrive/thesis2025/split_dataset_june/fold_assignments.json", "r") as f:
    folds = json.load(f)

results = []
fold_reports = []
fold_conf_matrices = []

for fold_idx in range(5):
    fold_name = f"fold_{fold_idx + 1}"
    train_ids = folds[fold_name]["train"]
    val_ids   = folds[fold_name]["val"]
    test_ids  = folds[fold_name]["test"]

    train_subjs = [subject_path_map[sid] for sid in train_ids]
    val_subjs   = [subject_path_map[sid] for sid in val_ids]
    test_subjs  = [subject_path_map[sid] for sid in test_ids]

    train_set = EEGAudioDataset(train_subjs)
    val_set = EEGAudioDataset(val_subjs)
    test_set = EEGAudioDataset(test_subjs)

    train_loader = DataLoader(train_set, batch_size=32, shuffle=True, collate_fn=collate_fn_padd)
    val_loader = DataLoader(val_set, batch_size=32, collate_fn=collate_fn_padd)
    test_loader = DataLoader(test_set, batch_size=32, collate_fn=collate_fn_padd)

    model = ConvPoolReLUClassifier(input_dim=2048, hidden_dim = 1024).to(device)
    optimizer = optim.Adamax(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()

    trained_model = train_model(model, train_loader, val_loader, criterion, optimizer, device, epochs =40)

    # ---- Evaluation ----
    trained_model.eval()
    correct = total = 0
    fold_preds = []
    fold_labels = []

    with torch.no_grad():
        for x, y in test_loader:
            x, y = x.to(device), y.to(device)
            output = trained_model(x)
            preds = output.argmax(dim=1)

            fold_preds.extend(preds.cpu().numpy())
            fold_labels.extend(y.cpu().numpy())

            correct += (preds == y).sum().item()
            total += y.size(0)

    acc = correct / total
    precision = precision_score(fold_labels, fold_preds, average='macro')
    recall = recall_score(fold_labels, fold_preds, average='macro')
    f1 = f1_score(fold_labels, fold_preds, average='macro')

    print(f"Fold {fold_idx + 1} Test Accuracy: {acc:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}")
    results.append((acc, precision, recall, f1))

    # Store classification report and confusion matrix for later averaging
    report = classification_report(fold_labels, fold_preds, output_dict=True, target_names=["Healthy", "Depressed"])
    conf_matrix = confusion_matrix(fold_labels, fold_preds)

    fold_reports.append(report)
    fold_conf_matrices.append(conf_matrix)

# ---- Final Report ----
results = np.array(results)
mean_acc, mean_prec, mean_rec, mean_f1 = results.mean(axis=0)
std_acc = results[:, 0].std()
std_prec = results[:, 1].std()
std_rec = results[:, 2].std()
std_f1 = results[:, 3].std()

print(f"\n5-Fold CV Results:")

print(f"Mean Accuracy  = {mean_acc:.4f} ± {std_acc:.4f}")
print(f"Mean Precision = {mean_prec:.4f} ± {std_prec:.4f}")
print(f"Mean Recall    = {mean_rec:.4f} ± {std_rec:.4f}")
print(f"Mean F1-Score  = {mean_f1:.4f} ± {std_f1:.4f}")

100%|██████████| 1/1 [00:00<00:00, 19.77it/s]
100%|██████████| 1/1 [00:00<00:00, 42.73it/s]
100%|██████████| 1/1 [00:00<00:00, 31.01it/s]
100%|██████████| 1/1 [00:00<00:00, 52.45it/s]
100%|██████████| 1/1 [00:00<00:00, 45.58it/s]
100%|██████████| 1/1 [00:00<00:00, 35.16it/s]
100%|██████████| 1/1 [00:00<00:00, 42.22it/s]
100%|██████████| 1/1 [00:00<00:00, 33.45it/s]
100%|██████████| 1/1 [00:00<00:00, 39.18it/s]
100%|██████████| 1/1 [00:00<00:00, 22.60it/s]
100%|██████████| 1/1 [00:00<00:00, 46.56it/s]
100%|██████████| 1/1 [00:00<00:00, 52.58it/s]
100%|██████████| 1/1 [00:00<00:00, 52.44it/s]
100%|██████████| 1/1 [00:00<00:00, 35.86it/s]
100%|██████████| 1/1 [00:00<00:00, 50.80it/s]
100%|██████████| 1/1 [00:00<00:00, 43.25it/s]
100%|██████████| 1/1 [00:00<00:00, 46.87it/s]
100%|██████████| 1/1 [00:00<00:00, 57.47it/s]
100%|██████████| 1/1 [00:00<00:00, 52.00it/s]
100%|██████████| 1/1 [00:00<00:00, 56.18it/s]
100%|██████████| 1/1 [00:00<00:00, 50.50it/s]
100%|██████████| 1/1 [00:00<00:00,

Fold 1 Test Accuracy: 0.2500, Precision: 0.1667, Recall: 0.2500, F1: 0.2000


100%|██████████| 1/1 [00:00<00:00, 47.05it/s]
100%|██████████| 1/1 [00:00<00:00, 51.94it/s]
100%|██████████| 1/1 [00:00<00:00, 21.04it/s]
100%|██████████| 1/1 [00:00<00:00, 57.85it/s]
100%|██████████| 1/1 [00:00<00:00, 57.60it/s]
100%|██████████| 1/1 [00:00<00:00, 47.95it/s]
100%|██████████| 1/1 [00:00<00:00, 31.15it/s]
100%|██████████| 1/1 [00:00<00:00, 54.52it/s]
100%|██████████| 1/1 [00:00<00:00, 54.31it/s]
100%|██████████| 1/1 [00:00<00:00, 53.73it/s]
100%|██████████| 1/1 [00:00<00:00, 22.32it/s]
100%|██████████| 1/1 [00:00<00:00, 37.99it/s]
100%|██████████| 1/1 [00:00<00:00, 18.16it/s]
100%|██████████| 1/1 [00:00<00:00, 24.98it/s]
100%|██████████| 1/1 [00:00<00:00, 42.20it/s]
100%|██████████| 1/1 [00:00<00:00, 50.87it/s]
100%|██████████| 1/1 [00:00<00:00, 49.50it/s]
100%|██████████| 1/1 [00:00<00:00,  7.22it/s]
100%|██████████| 1/1 [00:00<00:00, 13.56it/s]
100%|██████████| 1/1 [00:00<00:00, 34.62it/s]
100%|██████████| 1/1 [00:00<00:00, 26.30it/s]
100%|██████████| 1/1 [00:00<00:00,

Fold 2 Test Accuracy: 0.7500, Precision: 0.8333, Recall: 0.7500, F1: 0.7333


100%|██████████| 1/1 [00:00<00:00, 54.21it/s]
100%|██████████| 1/1 [00:00<00:00, 53.75it/s]
100%|██████████| 1/1 [00:00<00:00, 31.21it/s]
100%|██████████| 1/1 [00:00<00:00, 48.11it/s]
100%|██████████| 1/1 [00:00<00:00, 57.30it/s]
100%|██████████| 1/1 [00:00<00:00, 55.64it/s]
100%|██████████| 1/1 [00:00<00:00, 52.84it/s]
100%|██████████| 1/1 [00:00<00:00, 58.49it/s]
100%|██████████| 1/1 [00:00<00:00, 56.59it/s]
100%|██████████| 1/1 [00:00<00:00, 28.31it/s]
100%|██████████| 1/1 [00:00<00:00, 52.81it/s]
100%|██████████| 1/1 [00:00<00:00, 57.30it/s]
100%|██████████| 1/1 [00:00<00:00, 54.88it/s]
100%|██████████| 1/1 [00:00<00:00, 53.58it/s]
100%|██████████| 1/1 [00:00<00:00, 60.05it/s]
100%|██████████| 1/1 [00:00<00:00, 49.07it/s]
100%|██████████| 1/1 [00:00<00:00, 30.60it/s]
100%|██████████| 1/1 [00:00<00:00, 54.02it/s]
100%|██████████| 1/1 [00:00<00:00, 57.03it/s]
100%|██████████| 1/1 [00:00<00:00, 56.40it/s]
100%|██████████| 1/1 [00:00<00:00, 53.42it/s]
100%|██████████| 1/1 [00:00<00:00,

Fold 3 Test Accuracy: 0.7500, Precision: 0.7333, Recall: 0.7333, F1: 0.7333


100%|██████████| 1/1 [00:00<00:00, 45.66it/s]
100%|██████████| 1/1 [00:00<00:00, 45.47it/s]
100%|██████████| 1/1 [00:00<00:00, 41.07it/s]
100%|██████████| 1/1 [00:00<00:00, 42.08it/s]
100%|██████████| 1/1 [00:00<00:00, 44.18it/s]
100%|██████████| 1/1 [00:00<00:00, 44.80it/s]
100%|██████████| 1/1 [00:00<00:00, 41.53it/s]
100%|██████████| 1/1 [00:00<00:00, 26.51it/s]
100%|██████████| 1/1 [00:00<00:00, 48.95it/s]
100%|██████████| 1/1 [00:00<00:00, 43.92it/s]
100%|██████████| 1/1 [00:00<00:00, 29.00it/s]
100%|██████████| 1/1 [00:00<00:00, 43.91it/s]
100%|██████████| 1/1 [00:00<00:00, 41.06it/s]
100%|██████████| 1/1 [00:00<00:00, 38.33it/s]
100%|██████████| 1/1 [00:00<00:00, 43.65it/s]
100%|██████████| 1/1 [00:00<00:00, 43.76it/s]
100%|██████████| 1/1 [00:00<00:00, 43.05it/s]
100%|██████████| 1/1 [00:00<00:00, 42.48it/s]
100%|██████████| 1/1 [00:00<00:00, 44.09it/s]
100%|██████████| 1/1 [00:00<00:00, 44.58it/s]
100%|██████████| 1/1 [00:00<00:00, 42.40it/s]
100%|██████████| 1/1 [00:00<00:00,

Fold 4 Test Accuracy: 0.8571, Precision: 0.9000, Recall: 0.8333, F1: 0.8444


100%|██████████| 1/1 [00:00<00:00, 44.66it/s]
100%|██████████| 1/1 [00:00<00:00, 41.71it/s]
100%|██████████| 1/1 [00:00<00:00, 42.06it/s]
100%|██████████| 1/1 [00:00<00:00, 43.74it/s]
100%|██████████| 1/1 [00:00<00:00, 26.46it/s]
100%|██████████| 1/1 [00:00<00:00, 44.06it/s]
100%|██████████| 1/1 [00:00<00:00, 44.75it/s]
100%|██████████| 1/1 [00:00<00:00, 44.02it/s]
100%|██████████| 1/1 [00:00<00:00, 41.51it/s]
100%|██████████| 1/1 [00:00<00:00, 45.25it/s]
100%|██████████| 1/1 [00:00<00:00, 44.01it/s]
100%|██████████| 1/1 [00:00<00:00, 44.70it/s]
100%|██████████| 1/1 [00:00<00:00, 41.70it/s]
100%|██████████| 1/1 [00:00<00:00, 29.27it/s]
100%|██████████| 1/1 [00:00<00:00, 42.34it/s]
100%|██████████| 1/1 [00:00<00:00, 42.74it/s]
100%|██████████| 1/1 [00:00<00:00, 33.94it/s]
100%|██████████| 1/1 [00:00<00:00, 38.61it/s]
100%|██████████| 1/1 [00:00<00:00, 41.37it/s]
100%|██████████| 1/1 [00:00<00:00, 40.44it/s]
100%|██████████| 1/1 [00:00<00:00, 41.04it/s]
100%|██████████| 1/1 [00:00<00:00,

Fold 5 Test Accuracy: 0.4286, Precision: 0.4167, Recall: 0.4167, F1: 0.4167

5-Fold CV Results:
Mean Accuracy  = 0.6071 ± 0.2292
Mean Precision = 0.6100 ± 0.2768
Mean Recall    = 0.5967 ± 0.2240
Mean F1-Score  = 0.5856 ± 0.2399





# **Incorporating Text Modality**

In [None]:
class EEGAudioTextDataset(Dataset):
    def __init__(self, subject_dirs):
        self.subject_dirs = subject_dirs

    def __len__(self):
        return len(self.subject_dirs)

    def __getitem__(self, idx):
        subject_path = self.subject_dirs[idx]
        subject_id = os.path.basename(subject_path)

        eeg_embedding = np.load(os.path.join(subject_path, "eeg_embedding.npy"))
        audio_embedding = np.load(os.path.join(subject_path, "audio_embedding.npy"))
        text_embedding = np.load(os.path.join(subject_path, "text_embedding.npy"))

        combined_embedding = np.concatenate([eeg_embedding, audio_embedding, text_embedding], axis=1)
        label = 1 if subject_id.startswith('0201') else 0

        combined_embedding = torch.tensor(combined_embedding, dtype=torch.float32)
        label = torch.tensor(label, dtype=torch.long)

        return combined_embedding, label

In [None]:
import os
import numpy as np
from glob import glob
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import StratifiedKFold, train_test_split
from sklearn.metrics import (
    classification_report,
    confusion_matrix,
    precision_score,
    recall_score,
    f1_score,
)

set_seed(42)
# ---- Cross-Validation ----
with open("drive/MyDrive/thesis2025/split_dataset_june/fold_assignments.json", "r") as f:
    folds = json.load(f)

results = []
fold_reports = []
fold_conf_matrices = []

for fold_idx in range(5):
    fold_name = f"fold_{fold_idx + 1}"
    train_ids = folds[fold_name]["train"]
    val_ids   = folds[fold_name]["val"]
    test_ids  = folds[fold_name]["test"]

    train_subjs = [subject_path_map[sid] for sid in train_ids]
    val_subjs   = [subject_path_map[sid] for sid in val_ids]
    test_subjs  = [subject_path_map[sid] for sid in test_ids]

    train_set = EEGAudioTextDataset(train_subjs)
    val_set = EEGAudioTextDataset(val_subjs)
    test_set = EEGAudioTextDataset(test_subjs)

    train_loader = DataLoader(train_set, batch_size=32, shuffle=True, collate_fn=collate_fn_padd)
    val_loader = DataLoader(val_set, batch_size=32, collate_fn=collate_fn_padd)
    test_loader = DataLoader(test_set, batch_size=32, collate_fn=collate_fn_padd)

    model = ConvPoolReLUClassifier(input_dim=2816, hidden_dim = 1024).to(device)
    optimizer = optim.Adamax(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()

    trained_model = train_model(model, train_loader, val_loader, criterion, optimizer, device, epochs =40)

    # ---- Evaluation ----
    trained_model.eval()
    correct = total = 0
    fold_preds = []
    fold_labels = []

    with torch.no_grad():
        for x, y in test_loader:
            x, y = x.to(device), y.to(device)
            output = trained_model(x)
            preds = output.argmax(dim=1)

            fold_preds.extend(preds.cpu().numpy())
            fold_labels.extend(y.cpu().numpy())

            correct += (preds == y).sum().item()
            total += y.size(0)

    acc = correct / total
    precision = precision_score(fold_labels, fold_preds, average='macro')
    recall = recall_score(fold_labels, fold_preds, average='macro')
    f1 = f1_score(fold_labels, fold_preds, average='macro')

    print(f"Fold {fold_idx + 1} Test Accuracy: {acc:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}")
    results.append((acc, precision, recall, f1))

    # Store classification report and confusion matrix for later averaging
    report = classification_report(fold_labels, fold_preds, output_dict=True, target_names=["Healthy", "Depressed"])
    conf_matrix = confusion_matrix(fold_labels, fold_preds)

    fold_reports.append(report)
    fold_conf_matrices.append(conf_matrix)

# ---- Final Report ----
results = np.array(results)
mean_acc, mean_prec, mean_rec, mean_f1 = results.mean(axis=0)
std_acc = results[:, 0].std()
std_prec = results[:, 1].std()
std_rec = results[:, 2].std()
std_f1 = results[:, 3].std()

print(f"\n5-Fold CV Results:")

print(f"Mean Accuracy  = {mean_acc:.4f} ± {std_acc:.4f}")
print(f"Mean Precision = {mean_prec:.4f} ± {std_prec:.4f}")
print(f"Mean Recall    = {mean_rec:.4f} ± {std_rec:.4f}")
print(f"Mean F1-Score  = {mean_f1:.4f} ± {std_f1:.4f}")

100%|██████████| 1/1 [00:00<00:00, 37.63it/s]
100%|██████████| 1/1 [00:00<00:00, 16.32it/s]
100%|██████████| 1/1 [00:00<00:00, 38.43it/s]
100%|██████████| 1/1 [00:00<00:00, 36.15it/s]
100%|██████████| 1/1 [00:00<00:00, 22.01it/s]
100%|██████████| 1/1 [00:00<00:00, 35.29it/s]
100%|██████████| 1/1 [00:00<00:00, 37.09it/s]
100%|██████████| 1/1 [00:00<00:00, 35.13it/s]
100%|██████████| 1/1 [00:00<00:00, 38.75it/s]
100%|██████████| 1/1 [00:00<00:00, 36.53it/s]
100%|██████████| 1/1 [00:00<00:00, 38.74it/s]
100%|██████████| 1/1 [00:00<00:00, 22.38it/s]
100%|██████████| 1/1 [00:00<00:00, 38.70it/s]
100%|██████████| 1/1 [00:00<00:00, 34.16it/s]
100%|██████████| 1/1 [00:00<00:00, 27.24it/s]
100%|██████████| 1/1 [00:00<00:00, 40.22it/s]
100%|██████████| 1/1 [00:00<00:00, 27.71it/s]
100%|██████████| 1/1 [00:00<00:00, 40.91it/s]
100%|██████████| 1/1 [00:00<00:00, 39.86it/s]
100%|██████████| 1/1 [00:00<00:00, 38.19it/s]
100%|██████████| 1/1 [00:00<00:00, 28.17it/s]
100%|██████████| 1/1 [00:00<00:00,

Fold 1 Test Accuracy: 0.2500, Precision: 0.1667, Recall: 0.2500, F1: 0.2000


100%|██████████| 1/1 [00:00<00:00, 38.62it/s]
100%|██████████| 1/1 [00:00<00:00, 40.82it/s]
100%|██████████| 1/1 [00:00<00:00, 41.77it/s]
100%|██████████| 1/1 [00:00<00:00, 36.88it/s]
100%|██████████| 1/1 [00:00<00:00, 36.84it/s]
100%|██████████| 1/1 [00:00<00:00, 34.56it/s]
100%|██████████| 1/1 [00:00<00:00, 31.78it/s]
100%|██████████| 1/1 [00:00<00:00, 24.27it/s]
100%|██████████| 1/1 [00:00<00:00, 40.34it/s]
100%|██████████| 1/1 [00:00<00:00, 40.04it/s]
100%|██████████| 1/1 [00:00<00:00, 37.47it/s]
100%|██████████| 1/1 [00:00<00:00, 39.15it/s]
100%|██████████| 1/1 [00:00<00:00, 40.72it/s]
100%|██████████| 1/1 [00:00<00:00, 39.53it/s]
100%|██████████| 1/1 [00:00<00:00, 27.47it/s]
100%|██████████| 1/1 [00:00<00:00, 37.13it/s]
100%|██████████| 1/1 [00:00<00:00, 40.33it/s]
100%|██████████| 1/1 [00:00<00:00, 38.86it/s]
100%|██████████| 1/1 [00:00<00:00, 35.87it/s]
100%|██████████| 1/1 [00:00<00:00, 40.67it/s]
100%|██████████| 1/1 [00:00<00:00, 39.69it/s]
100%|██████████| 1/1 [00:00<00:00,

Fold 2 Test Accuracy: 0.7500, Precision: 0.8333, Recall: 0.7500, F1: 0.7333


100%|██████████| 1/1 [00:00<00:00, 25.17it/s]
100%|██████████| 1/1 [00:00<00:00, 12.33it/s]
100%|██████████| 1/1 [00:00<00:00, 10.67it/s]
100%|██████████| 1/1 [00:00<00:00, 40.14it/s]
100%|██████████| 1/1 [00:00<00:00, 37.09it/s]
100%|██████████| 1/1 [00:00<00:00, 40.78it/s]
100%|██████████| 1/1 [00:00<00:00, 38.82it/s]
100%|██████████| 1/1 [00:00<00:00, 25.84it/s]
100%|██████████| 1/1 [00:00<00:00, 38.37it/s]
100%|██████████| 1/1 [00:00<00:00, 38.80it/s]
100%|██████████| 1/1 [00:00<00:00, 38.13it/s]
100%|██████████| 1/1 [00:00<00:00, 36.17it/s]
100%|██████████| 1/1 [00:00<00:00, 40.94it/s]
100%|██████████| 1/1 [00:00<00:00, 37.67it/s]
100%|██████████| 1/1 [00:00<00:00, 26.78it/s]
100%|██████████| 1/1 [00:00<00:00, 41.17it/s]
100%|██████████| 1/1 [00:00<00:00, 34.73it/s]
100%|██████████| 1/1 [00:00<00:00, 38.56it/s]
100%|██████████| 1/1 [00:00<00:00, 34.43it/s]
100%|██████████| 1/1 [00:00<00:00, 35.64it/s]
100%|██████████| 1/1 [00:00<00:00, 36.41it/s]
100%|██████████| 1/1 [00:00<00:00,

Fold 3 Test Accuracy: 0.7500, Precision: 0.7333, Recall: 0.7333, F1: 0.7333


100%|██████████| 1/1 [00:00<00:00, 30.71it/s]
100%|██████████| 1/1 [00:00<00:00, 26.75it/s]
100%|██████████| 1/1 [00:00<00:00, 37.53it/s]
100%|██████████| 1/1 [00:00<00:00, 30.13it/s]
100%|██████████| 1/1 [00:00<00:00, 20.94it/s]
100%|██████████| 1/1 [00:00<00:00, 33.53it/s]
100%|██████████| 1/1 [00:00<00:00, 25.68it/s]
100%|██████████| 1/1 [00:00<00:00, 20.99it/s]
100%|██████████| 1/1 [00:00<00:00, 33.88it/s]
100%|██████████| 1/1 [00:00<00:00, 26.51it/s]
100%|██████████| 1/1 [00:00<00:00, 21.92it/s]
100%|██████████| 1/1 [00:00<00:00, 28.72it/s]
100%|██████████| 1/1 [00:00<00:00, 28.14it/s]
100%|██████████| 1/1 [00:00<00:00, 20.29it/s]
100%|██████████| 1/1 [00:00<00:00, 30.29it/s]
100%|██████████| 1/1 [00:00<00:00, 30.20it/s]
100%|██████████| 1/1 [00:00<00:00, 21.36it/s]
100%|██████████| 1/1 [00:00<00:00, 30.00it/s]
100%|██████████| 1/1 [00:00<00:00, 31.56it/s]
100%|██████████| 1/1 [00:00<00:00, 21.47it/s]
100%|██████████| 1/1 [00:00<00:00, 30.50it/s]
100%|██████████| 1/1 [00:00<00:00,

Fold 4 Test Accuracy: 0.8571, Precision: 0.9000, Recall: 0.8333, F1: 0.8444


100%|██████████| 1/1 [00:00<00:00, 29.76it/s]
100%|██████████| 1/1 [00:00<00:00, 28.58it/s]
100%|██████████| 1/1 [00:00<00:00, 29.71it/s]
100%|██████████| 1/1 [00:00<00:00, 30.86it/s]
100%|██████████| 1/1 [00:00<00:00, 29.59it/s]
100%|██████████| 1/1 [00:00<00:00, 30.33it/s]
100%|██████████| 1/1 [00:00<00:00, 30.95it/s]
100%|██████████| 1/1 [00:00<00:00, 30.85it/s]
100%|██████████| 1/1 [00:00<00:00, 30.20it/s]
100%|██████████| 1/1 [00:00<00:00, 29.32it/s]
100%|██████████| 1/1 [00:00<00:00, 28.70it/s]
100%|██████████| 1/1 [00:00<00:00, 26.71it/s]
100%|██████████| 1/1 [00:00<00:00, 29.37it/s]
100%|██████████| 1/1 [00:00<00:00, 33.50it/s]
100%|██████████| 1/1 [00:00<00:00, 31.00it/s]
100%|██████████| 1/1 [00:00<00:00, 33.65it/s]
100%|██████████| 1/1 [00:00<00:00, 31.83it/s]
100%|██████████| 1/1 [00:00<00:00, 30.08it/s]
100%|██████████| 1/1 [00:00<00:00, 19.54it/s]
100%|██████████| 1/1 [00:00<00:00, 28.20it/s]
100%|██████████| 1/1 [00:00<00:00, 27.91it/s]
100%|██████████| 1/1 [00:00<00:00,

Fold 5 Test Accuracy: 0.5714, Precision: 0.5500, Recall: 0.5417, F1: 0.5333

5-Fold CV Results:
Mean Accuracy  = 0.6357 ± 0.2136
Mean Precision = 0.6367 ± 0.2630
Mean Recall    = 0.6217 ± 0.2089
Mean F1-Score  = 0.6089 ± 0.2278





# **[Ignored] Trial with Hyperparameter Tuning**

In [None]:
!pip install optuna

Collecting optuna
  Downloading optuna-4.3.0-py3-none-any.whl.metadata (17 kB)
Collecting alembic>=1.5.0 (from optuna)
  Downloading alembic-1.16.1-py3-none-any.whl.metadata (7.3 kB)
Collecting colorlog (from optuna)
  Downloading colorlog-6.9.0-py3-none-any.whl.metadata (10 kB)
Downloading optuna-4.3.0-py3-none-any.whl (386 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m386.6/386.6 kB[0m [31m32.4 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading alembic-1.16.1-py3-none-any.whl (242 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m242.5/242.5 kB[0m [31m24.0 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading colorlog-6.9.0-py3-none-any.whl (11 kB)
Installing collected packages: colorlog, alembic, optuna
Successfully installed alembic-1.16.1 colorlog-6.9.0 optuna-4.3.0


In [None]:
import os
import numpy as np
from glob import glob
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import StratifiedKFold, train_test_split
from sklearn.metrics import (
    classification_report,
    confusion_matrix,
    precision_score,
    recall_score,
    f1_score,
)

set_seed(42)
# ---- Cross-Validation ----
base_dir = "drive/MyDrive/thesis2025/split_dataset_densenet"
split_dirs = ["train", "val", "test"]

all_subject_dirs = []
for split in split_dirs:
    all_subject_dirs.extend(sorted(glob(os.path.join(base_dir, split, "*"))))

## **Hyperparameter Tuning: Audio + EEG Modalities**

In [None]:
import optuna

def objective(trial):
    # --- Sample hyperparameters ---
    lr = trial.suggest_float("lr", 1e-5, 1e-2, log=True)
    weight_decay = trial.suggest_float("weight_decay", 1e-5, 1e-3, log=True)
    hidden_dim = trial.suggest_categorical("hidden_dim", [256, 512, 1024])
    batch_size = trial.suggest_categorical("batch_size", [8, 16, 32, 64, 128])

    # --- Initialize model ---
    model = ConvPoolReLUClassifier(input_dim=2048, hidden_dim=hidden_dim).to(device)
    criterion = nn.CrossEntropyLoss()

    # --- Choose optimizer ---
    optimizer = torch.optim.Adamax(model.parameters(), lr=lr, weight_decay=weight_decay)

    # --- Prepare data ---
    train_subjs, val_subjs = train_test_split(all_subject_dirs, test_size=0.2, random_state=42, stratify=labels)
    train_set = EEGAudioDataset(train_subjs)
    val_set = EEGAudioDataset(val_subjs)
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, collate_fn=collate_fn_padd)
    val_loader = DataLoader(val_set, batch_size=batch_size, collate_fn=collate_fn_padd)

    # --- Train ---
    trained_model = train_model(model, train_loader, val_loader, criterion, optimizer, device, epochs=20)

    # --- Evaluate ---
    trained_model.eval()
    correct = total = 0
    with torch.no_grad():
        for x, y in val_loader:
            x, y = x.to(device), y.to(device)
            output = trained_model(x)
            preds = output.argmax(dim=1)
            correct += (preds == y).sum().item()
            total += y.size(0)

    val_acc = correct / total
    return val_acc

In [None]:
study = optuna.create_study(direction="maximize", sampler=optuna.samplers.TPESampler(seed=42))
study.optimize(objective, n_trials=50)

print("Best hyperparameters:", study.best_params)


[I 2025-06-13 09:15:05,831] A new study created in memory with name: no-name-2587c040-26a1-47e1-9a79-385f5bc2efae
100%|██████████| 1/1 [00:00<00:00, 17.00it/s]
100%|██████████| 1/1 [00:00<00:00, 17.32it/s]
100%|██████████| 1/1 [00:00<00:00, 21.12it/s]
100%|██████████| 1/1 [00:00<00:00, 21.38it/s]
100%|██████████| 1/1 [00:00<00:00, 21.83it/s]
100%|██████████| 1/1 [00:00<00:00, 21.57it/s]
100%|██████████| 1/1 [00:00<00:00, 21.84it/s]
100%|██████████| 1/1 [00:00<00:00, 21.44it/s]
100%|██████████| 1/1 [00:00<00:00, 20.89it/s]
100%|██████████| 1/1 [00:00<00:00, 22.19it/s]
100%|██████████| 1/1 [00:00<00:00, 20.13it/s]
100%|██████████| 1/1 [00:00<00:00, 22.66it/s]
100%|██████████| 1/1 [00:00<00:00, 18.91it/s]
100%|██████████| 1/1 [00:00<00:00, 18.21it/s]
100%|██████████| 1/1 [00:00<00:00, 17.18it/s]
100%|██████████| 1/1 [00:00<00:00, 21.71it/s]
100%|██████████| 1/1 [00:00<00:00, 21.37it/s]
100%|██████████| 1/1 [00:00<00:00, 20.91it/s]
100%|██████████| 1/1 [00:00<00:00, 21.84it/s]
100%|███████

Best hyperparameters: {'lr': 0.002227610432767094, 'weight_decay': 0.00010888025537501785, 'hidden_dim': 256, 'batch_size': 8}


In [None]:
best_params = study.best_params

In [None]:
best_params

{'lr': 0.002227610432767094,
 'weight_decay': 0.00010888025537501785,
 'hidden_dim': 256,
 'batch_size': 8}

In [None]:
class ConvPoolReLUClassifier(nn.Module):
    def __init__(self, input_dim=2048, hidden_dim=1024, num_classes=2):
        super().__init__()
        self.conv1 = nn.Conv1d(in_channels=input_dim, out_channels=256, kernel_size=3, padding=1)
        self.pool = nn.AdaptiveAvgPool1d(1)
        self.relu = nn.ReLU()
        self.fc1 = nn.Linear(256, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        x = x.transpose(1, 2)
        x = self.conv1(x)
        x = self.pool(x)
        x = self.relu(x)
        x = x.view(x.size(0), -1)
        x = self.relu(self.fc1(x))
        return self.fc2(x)

In [None]:
import os
import numpy as np
from glob import glob
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import StratifiedKFold, train_test_split
from sklearn.metrics import (
    classification_report,
    confusion_matrix,
    precision_score,
    recall_score,
    f1_score,
)

set_seed(42)
# ---- Cross-Validation ----
base_dir = "drive/MyDrive/thesis2025/split_dataset_densenet"
split_dirs = ["train", "val", "test"]

all_subject_dirs = []
for split in split_dirs:
    all_subject_dirs.extend(sorted(glob(os.path.join(base_dir, split, "*"))))

labels = [1 if os.path.basename(d).startswith("0201") else 0 for d in all_subject_dirs]
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

# ---- Hyperparameters ----

lr = best_params['lr']
weight_decay = best_params['weight_decay']
hidden_dim = best_params['hidden_dim']
batch_size = best_params['batch_size']

results = []
fold_reports = []
fold_conf_matrices = []

for fold_idx, (train_idx, test_idx) in enumerate(skf.split(all_subject_dirs, labels)):
    print(f"\nFold {fold_idx + 1}/5")
    train_subjs = [all_subject_dirs[i] for i in train_idx]
    test_subjs = [all_subject_dirs[i] for i in test_idx]
    train_subjs, val_subjs = train_test_split(train_subjs, test_size=0.1, random_state=42)

    train_set = EEGAudioDataset(train_subjs)
    val_set = EEGAudioDataset(val_subjs)
    test_set = EEGAudioDataset(test_subjs)

    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, collate_fn=collate_fn_padd)
    val_loader = DataLoader(val_set, batch_size=batch_size, collate_fn=collate_fn_padd)
    test_loader = DataLoader(test_set, batch_size=batch_size, collate_fn=collate_fn_padd)

    model = ConvPoolReLUClassifier(input_dim=2048, hidden_dim = hidden_dim).to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay= weight_decay)
    criterion = nn.CrossEntropyLoss()

    trained_model = train_model(model, train_loader, val_loader, criterion, optimizer, device)

    # ---- Evaluation ----
    trained_model.eval()
    correct = total = 0
    fold_preds = []
    fold_labels = []

    with torch.no_grad():
        for x, y in test_loader:
            x, y = x.to(device), y.to(device)
            output = trained_model(x)
            preds = output.argmax(dim=1)

            fold_preds.extend(preds.cpu().numpy())
            fold_labels.extend(y.cpu().numpy())

            correct += (preds == y).sum().item()
            total += y.size(0)

    acc = correct / total
    precision = precision_score(fold_labels, fold_preds, average='macro')
    recall = recall_score(fold_labels, fold_preds, average='macro')
    f1 = f1_score(fold_labels, fold_preds, average='macro')

    print(f"Fold {fold_idx + 1} Test Accuracy: {acc:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}")
    results.append((acc, precision, recall, f1))

    # Store classification report and confusion matrix for later averaging
    report = classification_report(fold_labels, fold_preds, output_dict=True, target_names=["Healthy", "Depressed"])
    conf_matrix = confusion_matrix(fold_labels, fold_preds)

    fold_reports.append(report)
    fold_conf_matrices.append(conf_matrix)

# ---- Final Report ----
results = np.array(results)
mean_acc, mean_prec, mean_rec, mean_f1 = results.mean(axis=0)
std_acc = results[:, 0].std()
std_prec = results[:, 1].std()
std_rec = results[:, 2].std()
std_f1 = results[:, 3].std()

print(f"\n5-Fold CV Results:")
print(f"Mean Accuracy  = {mean_acc:.4f} ± {std_acc:.4f}")
print(f"Mean Precision = {mean_prec:.4f} ± {std_prec:.4f}")
print(f"Mean Recall    = {mean_rec:.4f} ± {std_rec:.4f}")
print(f"Mean F1-Score  = {mean_f1:.4f} ± {std_f1:.4f}")



Fold 1/5


100%|██████████| 1/1 [00:00<00:00, 49.39it/s]
100%|██████████| 1/1 [00:00<00:00, 51.43it/s]
100%|██████████| 1/1 [00:00<00:00, 49.80it/s]
100%|██████████| 1/1 [00:00<00:00, 36.07it/s]
100%|██████████| 1/1 [00:00<00:00, 26.23it/s]
100%|██████████| 1/1 [00:00<00:00, 45.92it/s]
100%|██████████| 1/1 [00:00<00:00, 52.82it/s]
100%|██████████| 1/1 [00:00<00:00, 49.86it/s]
100%|██████████| 1/1 [00:00<00:00, 36.50it/s]
100%|██████████| 1/1 [00:00<00:00, 24.00it/s]
100%|██████████| 1/1 [00:00<00:00, 38.57it/s]
100%|██████████| 1/1 [00:00<00:00, 46.50it/s]
100%|██████████| 1/1 [00:00<00:00, 50.21it/s]
100%|██████████| 1/1 [00:00<00:00, 51.49it/s]
100%|██████████| 1/1 [00:00<00:00, 52.89it/s]
100%|██████████| 1/1 [00:00<00:00, 52.00it/s]
100%|██████████| 1/1 [00:00<00:00, 53.36it/s]
100%|██████████| 1/1 [00:00<00:00, 53.87it/s]
100%|██████████| 1/1 [00:00<00:00, 54.78it/s]
100%|██████████| 1/1 [00:00<00:00, 52.21it/s]
100%|██████████| 1/1 [00:00<00:00, 56.95it/s]
100%|██████████| 1/1 [00:00<00:00,

Fold 1 Test Accuracy: 0.6250, Precision: 0.6333, Recall: 0.6250, F1: 0.6190

Fold 2/5


100%|██████████| 1/1 [00:00<00:00, 54.42it/s]
100%|██████████| 1/1 [00:00<00:00, 55.94it/s]
100%|██████████| 1/1 [00:00<00:00, 54.69it/s]
100%|██████████| 1/1 [00:00<00:00, 54.61it/s]
100%|██████████| 1/1 [00:00<00:00, 45.57it/s]
100%|██████████| 1/1 [00:00<00:00, 52.57it/s]
100%|██████████| 1/1 [00:00<00:00, 46.24it/s]
100%|██████████| 1/1 [00:00<00:00, 54.95it/s]
100%|██████████| 1/1 [00:00<00:00, 54.16it/s]
100%|██████████| 1/1 [00:00<00:00, 30.78it/s]
100%|██████████| 1/1 [00:00<00:00, 51.05it/s]
100%|██████████| 1/1 [00:00<00:00, 52.49it/s]
100%|██████████| 1/1 [00:00<00:00, 50.28it/s]
100%|██████████| 1/1 [00:00<00:00, 50.87it/s]
100%|██████████| 1/1 [00:00<00:00, 51.85it/s]
100%|██████████| 1/1 [00:00<00:00, 33.22it/s]
100%|██████████| 1/1 [00:00<00:00, 37.36it/s]
100%|██████████| 1/1 [00:00<00:00, 53.68it/s]
100%|██████████| 1/1 [00:00<00:00, 48.70it/s]
100%|██████████| 1/1 [00:00<00:00, 44.89it/s]
100%|██████████| 1/1 [00:00<00:00, 56.97it/s]
100%|██████████| 1/1 [00:00<00:00,

Fold 2 Test Accuracy: 0.5000, Precision: 0.5000, Recall: 0.5000, F1: 0.5000

Fold 3/5


100%|██████████| 1/1 [00:00<00:00, 33.21it/s]
100%|██████████| 1/1 [00:00<00:00, 57.41it/s]
100%|██████████| 1/1 [00:00<00:00, 57.68it/s]
100%|██████████| 1/1 [00:00<00:00, 57.71it/s]
100%|██████████| 1/1 [00:00<00:00, 48.20it/s]
100%|██████████| 1/1 [00:00<00:00, 54.26it/s]
100%|██████████| 1/1 [00:00<00:00, 56.54it/s]
100%|██████████| 1/1 [00:00<00:00, 35.39it/s]
100%|██████████| 1/1 [00:00<00:00, 49.89it/s]
100%|██████████| 1/1 [00:00<00:00, 57.21it/s]
100%|██████████| 1/1 [00:00<00:00, 56.82it/s]
100%|██████████| 1/1 [00:00<00:00, 57.07it/s]
100%|██████████| 1/1 [00:00<00:00, 55.46it/s]
100%|██████████| 1/1 [00:00<00:00, 54.83it/s]
100%|██████████| 1/1 [00:00<00:00, 32.80it/s]
100%|██████████| 1/1 [00:00<00:00, 50.80it/s]
100%|██████████| 1/1 [00:00<00:00, 58.91it/s]
100%|██████████| 1/1 [00:00<00:00, 55.46it/s]
100%|██████████| 1/1 [00:00<00:00, 58.01it/s]
100%|██████████| 1/1 [00:00<00:00, 55.47it/s]
100%|██████████| 1/1 [00:00<00:00, 57.68it/s]
100%|██████████| 1/1 [00:00<00:00,

Fold 3 Test Accuracy: 0.3750, Precision: 0.3750, Recall: 0.3667, F1: 0.3651

Fold 4/5


100%|██████████| 1/1 [00:00<00:00, 40.66it/s]
100%|██████████| 1/1 [00:00<00:00, 46.09it/s]
100%|██████████| 1/1 [00:00<00:00, 47.17it/s]
100%|██████████| 1/1 [00:00<00:00, 43.45it/s]
100%|██████████| 1/1 [00:00<00:00, 42.27it/s]
100%|██████████| 1/1 [00:00<00:00, 43.09it/s]
100%|██████████| 1/1 [00:00<00:00, 44.74it/s]
100%|██████████| 1/1 [00:00<00:00, 28.65it/s]
100%|██████████| 1/1 [00:00<00:00, 39.67it/s]
100%|██████████| 1/1 [00:00<00:00, 46.14it/s]
100%|██████████| 1/1 [00:00<00:00, 43.31it/s]
100%|██████████| 1/1 [00:00<00:00, 43.30it/s]
100%|██████████| 1/1 [00:00<00:00, 38.36it/s]
100%|██████████| 1/1 [00:00<00:00, 42.52it/s]
100%|██████████| 1/1 [00:00<00:00, 36.82it/s]
100%|██████████| 1/1 [00:00<00:00, 27.09it/s]
100%|██████████| 1/1 [00:00<00:00, 38.51it/s]
100%|██████████| 1/1 [00:00<00:00, 41.73it/s]
100%|██████████| 1/1 [00:00<00:00, 38.71it/s]
100%|██████████| 1/1 [00:00<00:00, 36.01it/s]
100%|██████████| 1/1 [00:00<00:00, 39.06it/s]
100%|██████████| 1/1 [00:00<00:00,

Fold 4 Test Accuracy: 0.4286, Precision: 0.4500, Recall: 0.4583, F1: 0.4167

Fold 5/5


100%|██████████| 1/1 [00:00<00:00, 42.64it/s]
100%|██████████| 1/1 [00:00<00:00, 44.49it/s]
100%|██████████| 1/1 [00:00<00:00, 38.75it/s]
100%|██████████| 1/1 [00:00<00:00, 29.50it/s]
100%|██████████| 1/1 [00:00<00:00, 48.95it/s]
100%|██████████| 1/1 [00:00<00:00, 44.43it/s]
100%|██████████| 1/1 [00:00<00:00, 40.72it/s]
100%|██████████| 1/1 [00:00<00:00, 44.12it/s]
100%|██████████| 1/1 [00:00<00:00, 43.85it/s]
100%|██████████| 1/1 [00:00<00:00, 44.20it/s]
100%|██████████| 1/1 [00:00<00:00, 42.28it/s]
100%|██████████| 1/1 [00:00<00:00, 30.10it/s]
100%|██████████| 1/1 [00:00<00:00, 38.98it/s]
100%|██████████| 1/1 [00:00<00:00, 43.75it/s]
100%|██████████| 1/1 [00:00<00:00, 39.54it/s]
100%|██████████| 1/1 [00:00<00:00, 33.73it/s]
100%|██████████| 1/1 [00:00<00:00, 42.64it/s]
100%|██████████| 1/1 [00:00<00:00, 45.40it/s]
100%|██████████| 1/1 [00:00<00:00, 40.42it/s]
100%|██████████| 1/1 [00:00<00:00, 27.66it/s]
100%|██████████| 1/1 [00:00<00:00, 41.32it/s]
100%|██████████| 1/1 [00:00<00:00,

Fold 5 Test Accuracy: 0.7143, Precision: 0.8000, Recall: 0.7500, F1: 0.7083

5-Fold CV Results:
Mean Accuracy  = 0.5286 ± 0.1251
Mean Precision = 0.5517 ± 0.1500
Mean Recall    = 0.5400 ± 0.1339
Mean F1-Score  = 0.5218 ± 0.1268


## **Hyperparameter tuning: Text Modalities**

In [None]:
from sklearn.model_selection import StratifiedKFold

def objective(trial):
    # Sample hyperparameters
    lr = trial.suggest_float("lr", 1e-5, 1e-2, log=True)
    #weight_decay = trial.suggest_float("weight_decay", 1e-5, 1e-3, log=True)
    hidden_dim = trial.suggest_categorical("hidden_dim", [256, 512, 1024])
    batch_size = trial.suggest_categorical("batch_size", [8, 16, 32, 64])
    optimizer_name = trial.suggest_categorical("optimizer", ["Adam", "AdamW", "Adamax", "SGD"])
    epochs = trial.suggest_int("epochs", 5, 120)  # Added epochs hyperparameter

    skf_inner = StratifiedKFold(n_splits=3, shuffle=True, random_state=42)
    val_accuracies = []

    # Prepare your labels for all_subject_dirs here
    labels = [1 if os.path.basename(d).startswith("0201") else 0 for d in all_subject_dirs]

    for train_idx, val_idx in skf_inner.split(all_subject_dirs, labels):
        train_subjs = [all_subject_dirs[i] for i in train_idx]
        val_subjs = [all_subject_dirs[i] for i in val_idx]

        train_set = TextDataset(train_subjs)
        val_set = TextDataset(val_subjs)

        train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, collate_fn=collate_fn_padd)
        val_loader = DataLoader(val_set, batch_size=batch_size, collate_fn=collate_fn_padd)

        model = ConvPoolReLUClassifier(input_dim=768, hidden_dim=hidden_dim).to(device)
        criterion = nn.CrossEntropyLoss()

        if optimizer_name == "Adam":
            optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        elif optimizer_name == "AdamW":
            optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
        elif optimizer_name == "Adamax":
            optimizer = torch.optim.Adamax(model.parameters(), lr=lr)
        elif optimizer_name == "SGD":
            optimizer = torch.optim.SGD(model.parameters(), lr=lr)
        else:
            raise ValueError(f"Unknown optimizer {optimizer_name}")


        trained_model = train_model(model, train_loader, val_loader, criterion, optimizer, device, epochs=epochs)

        # Evaluate validation accuracy for this fold
        trained_model.eval()
        correct = total = 0
        with torch.no_grad():
            for x, y in val_loader:
                x, y = x.to(device), y.to(device)
                output = trained_model(x)
                preds = output.argmax(dim=1)
                correct += (preds == y).sum().item()
                total += y.size(0)
        val_acc = correct / total
        val_accuracies.append(val_acc)

    # Return mean val accuracy over folds as objective metric
    return np.mean(val_accuracies)

In [None]:
study = optuna.create_study(direction="maximize", sampler=optuna.samplers.TPESampler(seed=42))
study.optimize(objective, n_trials=35)

print("Best hyperparameters:", study.best_params)


[1;30;43mStreaming output truncated to the last 5000 lines.[0m
100%|██████████| 2/2 [00:00<00:00, 57.21it/s]
100%|██████████| 2/2 [00:00<00:00, 40.96it/s]
100%|██████████| 2/2 [00:00<00:00, 56.99it/s]
100%|██████████| 2/2 [00:00<00:00, 51.77it/s]
100%|██████████| 2/2 [00:00<00:00, 56.43it/s]
100%|██████████| 2/2 [00:00<00:00, 53.92it/s]
100%|██████████| 2/2 [00:00<00:00, 43.07it/s]
100%|██████████| 2/2 [00:00<00:00, 45.07it/s]
100%|██████████| 2/2 [00:00<00:00, 62.49it/s]
100%|██████████| 2/2 [00:00<00:00, 41.31it/s]
100%|██████████| 2/2 [00:00<00:00, 58.93it/s]
100%|██████████| 2/2 [00:00<00:00, 49.23it/s]
100%|██████████| 2/2 [00:00<00:00, 39.49it/s]
100%|██████████| 2/2 [00:00<00:00, 51.57it/s]
100%|██████████| 2/2 [00:00<00:00, 39.28it/s]
100%|██████████| 2/2 [00:00<00:00, 57.37it/s]
100%|██████████| 2/2 [00:00<00:00, 56.23it/s]
100%|██████████| 2/2 [00:00<00:00, 53.92it/s]
100%|██████████| 2/2 [00:00<00:00, 54.76it/s]
100%|██████████| 2/2 [00:00<00:00, 57.80it/s]
100%|██████████

Best hyperparameters: {'lr': 0.0003180958563999211, 'hidden_dim': 1024, 'batch_size': 16, 'optimizer': 'Adamax', 'epochs': 55}


In [None]:
best_params = study.best_params

In [None]:
import os
import numpy as np
from glob import glob
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import StratifiedKFold, train_test_split
from sklearn.metrics import (
    classification_report,
    confusion_matrix,
    precision_score,
    recall_score,
    f1_score,
)

set_seed(42)
# ---- Cross-Validation ----
base_dir = "drive/MyDrive/thesis2025/split_dataset_vit3"
split_dirs = ["train", "val", "test"]

all_subject_dirs = []
for split in split_dirs:
    all_subject_dirs.extend(sorted(glob(os.path.join(base_dir, split, "*"))))

labels = [1 if os.path.basename(d).startswith("0201") else 0 for d in all_subject_dirs]
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

# ---- Hyperparameters ----

lr = best_params['lr']
#weight_decay = best_params['weight_decay']
hidden_dim = best_params['hidden_dim']
batch_size = best_params['batch_size']
epochs = best_params['epochs']

results = []
fold_reports = []
fold_conf_matrices = []

for fold_idx, (train_idx, test_idx) in enumerate(skf.split(all_subject_dirs, labels)):
    print(f"\nFold {fold_idx + 1}/5")
    train_subjs = [all_subject_dirs[i] for i in train_idx]
    test_subjs = [all_subject_dirs[i] for i in test_idx]
    train_subjs, val_subjs = train_test_split(train_subjs, test_size=0.1, random_state=42)

    train_set = TextDataset(train_subjs)
    val_set = TextDataset(val_subjs)
    test_set = TextDataset(test_subjs)

    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, collate_fn=collate_fn_padd)
    val_loader = DataLoader(val_set, batch_size=batch_size, collate_fn=collate_fn_padd)
    test_loader = DataLoader(test_set, batch_size=batch_size, collate_fn=collate_fn_padd)

    model = ConvPoolReLUClassifier(input_dim=768, hidden_dim = hidden_dim).to(device)
    optimizer = optim.Adamax(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    trained_model = train_model(model, train_loader, val_loader, criterion, optimizer, device, epochs = epochs)

    # ---- Evaluation ----
    trained_model.eval()
    correct = total = 0
    fold_preds = []
    fold_labels = []

    with torch.no_grad():
        for x, y in test_loader:
            x, y = x.to(device), y.to(device)
            output = trained_model(x)
            preds = output.argmax(dim=1)

            fold_preds.extend(preds.cpu().numpy())
            fold_labels.extend(y.cpu().numpy())

            correct += (preds == y).sum().item()
            total += y.size(0)

    acc = correct / total
    precision = precision_score(fold_labels, fold_preds, average='macro')
    recall = recall_score(fold_labels, fold_preds, average='macro')
    f1 = f1_score(fold_labels, fold_preds, average='macro')

    print(f"Fold {fold_idx + 1} Test Accuracy: {acc:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}")
    results.append((acc, precision, recall, f1))

    # Store classification report and confusion matrix for later averaging
    report = classification_report(fold_labels, fold_preds, output_dict=True, target_names=["Healthy", "Depressed"])
    conf_matrix = confusion_matrix(fold_labels, fold_preds)

    fold_reports.append(report)
    fold_conf_matrices.append(conf_matrix)

# ---- Final Report ----
results = np.array(results)
mean_acc, mean_prec, mean_rec, mean_f1 = results.mean(axis=0)
std_acc = results[:, 0].std()
std_prec = results[:, 1].std()
std_rec = results[:, 2].std()
std_f1 = results[:, 3].std()

print(f"\n5-Fold CV Results:")
print(f"Mean Accuracy  = {mean_acc:.4f} ± {std_acc:.4f}")
print(f"Mean Precision = {mean_prec:.4f} ± {std_prec:.4f}")
print(f"Mean Recall    = {mean_rec:.4f} ± {std_rec:.4f}")
print(f"Mean F1-Score  = {mean_f1:.4f} ± {std_f1:.4f}")



Fold 1/5


100%|██████████| 1/1 [00:00<00:00, 57.07it/s]
100%|██████████| 1/1 [00:00<00:00, 44.18it/s]
100%|██████████| 1/1 [00:00<00:00, 35.70it/s]
100%|██████████| 1/1 [00:00<00:00, 96.98it/s]
100%|██████████| 1/1 [00:00<00:00, 63.66it/s]
100%|██████████| 1/1 [00:00<00:00, 88.29it/s]
100%|██████████| 1/1 [00:00<00:00, 95.90it/s]
100%|██████████| 1/1 [00:00<00:00, 89.13it/s]
100%|██████████| 1/1 [00:00<00:00, 96.95it/s]
100%|██████████| 1/1 [00:00<00:00, 40.60it/s]
100%|██████████| 1/1 [00:00<00:00, 63.99it/s]
100%|██████████| 1/1 [00:00<00:00, 112.17it/s]
100%|██████████| 1/1 [00:00<00:00, 79.08it/s]
100%|██████████| 1/1 [00:00<00:00, 61.15it/s]
100%|██████████| 1/1 [00:00<00:00, 88.12it/s]
100%|██████████| 1/1 [00:00<00:00, 109.70it/s]
100%|██████████| 1/1 [00:00<00:00, 36.03it/s]
100%|██████████| 1/1 [00:00<00:00, 71.70it/s]
100%|██████████| 1/1 [00:00<00:00, 89.65it/s]
100%|██████████| 1/1 [00:00<00:00, 75.70it/s]
100%|██████████| 1/1 [00:00<00:00, 47.67it/s]
100%|██████████| 1/1 [00:00<00:0

Fold 1 Test Accuracy: 0.7500, Precision: 0.8333, Recall: 0.7500, F1: 0.7333

Fold 2/5


100%|██████████| 1/1 [00:00<00:00, 86.66it/s]
100%|██████████| 1/1 [00:00<00:00, 112.49it/s]
100%|██████████| 1/1 [00:00<00:00, 98.73it/s]
100%|██████████| 1/1 [00:00<00:00, 102.39it/s]
100%|██████████| 1/1 [00:00<00:00, 112.11it/s]
100%|██████████| 1/1 [00:00<00:00, 109.49it/s]
100%|██████████| 1/1 [00:00<00:00, 73.58it/s]
100%|██████████| 1/1 [00:00<00:00, 63.03it/s]
100%|██████████| 1/1 [00:00<00:00, 77.67it/s]
100%|██████████| 1/1 [00:00<00:00, 121.95it/s]
100%|██████████| 1/1 [00:00<00:00, 95.08it/s]
100%|██████████| 1/1 [00:00<00:00, 75.84it/s]
100%|██████████| 1/1 [00:00<00:00, 50.63it/s]
100%|██████████| 1/1 [00:00<00:00, 61.09it/s]
100%|██████████| 1/1 [00:00<00:00, 68.09it/s]
100%|██████████| 1/1 [00:00<00:00, 89.59it/s]
100%|██████████| 1/1 [00:00<00:00, 76.04it/s]
100%|██████████| 1/1 [00:00<00:00, 106.41it/s]
100%|██████████| 1/1 [00:00<00:00, 58.59it/s]
100%|██████████| 1/1 [00:00<00:00, 60.98it/s]
100%|██████████| 1/1 [00:00<00:00, 60.35it/s]
100%|██████████| 1/1 [00:00<

Fold 2 Test Accuracy: 0.5000, Precision: 0.5000, Recall: 0.5000, F1: 0.4667

Fold 3/5


100%|██████████| 1/1 [00:00<00:00, 62.30it/s]
100%|██████████| 1/1 [00:00<00:00, 114.82it/s]
100%|██████████| 1/1 [00:00<00:00, 116.82it/s]
100%|██████████| 1/1 [00:00<00:00, 53.97it/s]
100%|██████████| 1/1 [00:00<00:00, 96.23it/s]
100%|██████████| 1/1 [00:00<00:00, 44.06it/s]
100%|██████████| 1/1 [00:00<00:00, 77.77it/s]
100%|██████████| 1/1 [00:00<00:00, 90.91it/s]
100%|██████████| 1/1 [00:00<00:00, 85.95it/s]
100%|██████████| 1/1 [00:00<00:00, 55.72it/s]
100%|██████████| 1/1 [00:00<00:00, 78.16it/s]
100%|██████████| 1/1 [00:00<00:00, 110.45it/s]
100%|██████████| 1/1 [00:00<00:00, 99.45it/s]
100%|██████████| 1/1 [00:00<00:00, 96.75it/s]
100%|██████████| 1/1 [00:00<00:00, 101.89it/s]
100%|██████████| 1/1 [00:00<00:00, 100.80it/s]
100%|██████████| 1/1 [00:00<00:00, 93.27it/s]
100%|██████████| 1/1 [00:00<00:00, 104.93it/s]
100%|██████████| 1/1 [00:00<00:00, 83.06it/s]
100%|██████████| 1/1 [00:00<00:00, 98.02it/s]
100%|██████████| 1/1 [00:00<00:00, 111.15it/s]
100%|██████████| 1/1 [00:00

Fold 3 Test Accuracy: 0.7500, Precision: 0.8000, Recall: 0.8000, F1: 0.7500

Fold 4/5


100%|██████████| 1/1 [00:00<00:00, 37.90it/s]
100%|██████████| 1/1 [00:00<00:00, 53.65it/s]
100%|██████████| 1/1 [00:00<00:00, 39.20it/s]
100%|██████████| 1/1 [00:00<00:00, 76.59it/s]
100%|██████████| 1/1 [00:00<00:00, 90.86it/s]
100%|██████████| 1/1 [00:00<00:00, 63.38it/s]
100%|██████████| 1/1 [00:00<00:00, 45.88it/s]
100%|██████████| 1/1 [00:00<00:00, 72.90it/s]
100%|██████████| 1/1 [00:00<00:00, 71.59it/s]
100%|██████████| 1/1 [00:00<00:00, 66.23it/s]
100%|██████████| 1/1 [00:00<00:00, 45.03it/s]
100%|██████████| 1/1 [00:00<00:00, 54.62it/s]
100%|██████████| 1/1 [00:00<00:00, 71.77it/s]
100%|██████████| 1/1 [00:00<00:00, 38.90it/s]
100%|██████████| 1/1 [00:00<00:00, 73.66it/s]
100%|██████████| 1/1 [00:00<00:00, 76.11it/s]
100%|██████████| 1/1 [00:00<00:00, 25.49it/s]
100%|██████████| 1/1 [00:00<00:00, 76.50it/s]
100%|██████████| 1/1 [00:00<00:00, 78.84it/s]
100%|██████████| 1/1 [00:00<00:00, 78.39it/s]
100%|██████████| 1/1 [00:00<00:00, 78.83it/s]
100%|██████████| 1/1 [00:00<00:00,

Fold 4 Test Accuracy: 0.8571, Precision: 0.9000, Recall: 0.8333, F1: 0.8444

Fold 5/5


100%|██████████| 1/1 [00:00<00:00, 42.62it/s]
100%|██████████| 1/1 [00:00<00:00, 94.35it/s]
100%|██████████| 1/1 [00:00<00:00, 60.62it/s]
100%|██████████| 1/1 [00:00<00:00, 63.92it/s]
100%|██████████| 1/1 [00:00<00:00, 47.68it/s]
100%|██████████| 1/1 [00:00<00:00, 53.38it/s]
100%|██████████| 1/1 [00:00<00:00, 44.74it/s]
100%|██████████| 1/1 [00:00<00:00, 59.76it/s]
100%|██████████| 1/1 [00:00<00:00, 28.06it/s]
100%|██████████| 1/1 [00:00<00:00, 31.49it/s]
100%|██████████| 1/1 [00:00<00:00, 76.11it/s]
100%|██████████| 1/1 [00:00<00:00, 25.02it/s]
100%|██████████| 1/1 [00:00<00:00, 71.23it/s]
100%|██████████| 1/1 [00:00<00:00, 80.59it/s]
100%|██████████| 1/1 [00:00<00:00, 60.61it/s]
100%|██████████| 1/1 [00:00<00:00, 87.11it/s]
100%|██████████| 1/1 [00:00<00:00, 69.18it/s]
100%|██████████| 1/1 [00:00<00:00, 75.99it/s]
100%|██████████| 1/1 [00:00<00:00, 67.57it/s]
100%|██████████| 1/1 [00:00<00:00, 61.78it/s]
100%|██████████| 1/1 [00:00<00:00, 24.20it/s]
100%|██████████| 1/1 [00:00<00:00,

Fold 5 Test Accuracy: 1.0000, Precision: 1.0000, Recall: 1.0000, F1: 1.0000

5-Fold CV Results:
Mean Accuracy  = 0.7714 ± 0.1638
Mean Precision = 0.8067 ± 0.1679
Mean Recall    = 0.7767 ± 0.1618
Mean F1-Score  = 0.7589 ± 0.1741


## **Hyperparameter Tuning: EEG + Audio + Text Modalities**

In [None]:
# ---- Cross-Validation ----
base_dir = "drive/MyDrive/thesis2025/split_dataset_densenet"
split_dirs = ["train", "val", "test"]

all_subject_dirs = []
for split in split_dirs:
    all_subject_dirs.extend(sorted(glob(os.path.join(base_dir, split, "*"))))

In [None]:
from sklearn.model_selection import StratifiedKFold

def objective(trial):
    # Sample hyperparameters
    lr = trial.suggest_float("lr", 1e-5, 1e-2, log=True)
    #weight_decay = trial.suggest_float("weight_decay", 1e-5, 1e-3, log=True)
    hidden_dim = trial.suggest_categorical("hidden_dim", [256, 512, 1024])
    batch_size = trial.suggest_categorical("batch_size", [8, 16, 32, 64])
    optimizer_name = trial.suggest_categorical("optimizer", ["Adam", "AdamW", "Adamax", "SGD"])
    epochs = trial.suggest_int("epochs", 5, 120)  # Added epochs hyperparameter

    skf_inner = StratifiedKFold(n_splits=3, shuffle=True, random_state=42)
    val_accuracies = []

    # Prepare your labels for all_subject_dirs here
    labels = [1 if os.path.basename(d).startswith("0201") else 0 for d in all_subject_dirs]

    for train_idx, val_idx in skf_inner.split(all_subject_dirs, labels):
        train_subjs = [all_subject_dirs[i] for i in train_idx]
        val_subjs = [all_subject_dirs[i] for i in val_idx]

        train_set = EEGAudioTextDataset(train_subjs)
        val_set = EEGAudioTextDataset(val_subjs)

        train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, collate_fn=collate_fn_padd)
        val_loader = DataLoader(val_set, batch_size=batch_size, collate_fn=collate_fn_padd)

        model = ConvPoolReLUClassifier(input_dim=2816, hidden_dim=hidden_dim).to(device)
        criterion = nn.CrossEntropyLoss()

        if optimizer_name == "Adam":
            optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        elif optimizer_name == "AdamW":
            optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
        elif optimizer_name == "Adamax":
            optimizer = torch.optim.Adamax(model.parameters(), lr=lr)
        elif optimizer_name == "SGD":
            optimizer = torch.optim.SGD(model.parameters(), lr=lr)
        else:
            raise ValueError(f"Unknown optimizer {optimizer_name}")


        trained_model = train_model(model, train_loader, val_loader, criterion, optimizer, device, epochs=epochs)

        # Evaluate validation accuracy for this fold
        trained_model.eval()
        correct = total = 0
        with torch.no_grad():
            for x, y in val_loader:
                x, y = x.to(device), y.to(device)
                output = trained_model(x)
                preds = output.argmax(dim=1)
                correct += (preds == y).sum().item()
                total += y.size(0)
        val_acc = correct / total
        val_accuracies.append(val_acc)

    # Return mean val accuracy over folds as objective metric
    return np.mean(val_accuracies)

In [None]:
study = optuna.create_study(direction="maximize", sampler=optuna.samplers.TPESampler(seed=42))
study.optimize(objective, n_trials=35)

print("Best hyperparameters:", study.best_params)


[1;30;43mStreaming output truncated to the last 5000 lines.[0m
100%|██████████| 2/2 [00:00<00:00, 15.59it/s]
100%|██████████| 2/2 [00:00<00:00, 18.12it/s]
100%|██████████| 2/2 [00:00<00:00, 18.07it/s]
100%|██████████| 2/2 [00:00<00:00, 15.55it/s]
100%|██████████| 2/2 [00:00<00:00, 18.09it/s]
100%|██████████| 2/2 [00:00<00:00, 16.51it/s]
100%|██████████| 2/2 [00:00<00:00, 17.08it/s]
100%|██████████| 2/2 [00:00<00:00, 17.10it/s]
100%|██████████| 2/2 [00:00<00:00, 16.68it/s]
100%|██████████| 2/2 [00:00<00:00, 18.80it/s]
100%|██████████| 2/2 [00:00<00:00, 15.48it/s]
100%|██████████| 2/2 [00:00<00:00, 16.11it/s]
100%|██████████| 2/2 [00:00<00:00, 18.10it/s]
100%|██████████| 2/2 [00:00<00:00, 16.07it/s]
100%|██████████| 2/2 [00:00<00:00, 16.59it/s]
100%|██████████| 2/2 [00:00<00:00, 18.10it/s]
100%|██████████| 2/2 [00:00<00:00, 16.54it/s]
100%|██████████| 2/2 [00:00<00:00, 17.82it/s]
100%|██████████| 2/2 [00:00<00:00, 16.13it/s]
100%|██████████| 2/2 [00:00<00:00, 15.13it/s]
100%|██████████

Best hyperparameters: {'lr': 3.6283583803549155e-05, 'hidden_dim': 256, 'batch_size': 8, 'optimizer': 'Adamax', 'epochs': 64}


In [None]:
best_params = study.best_params

In [None]:
import os
import numpy as np
from glob import glob
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import StratifiedKFold, train_test_split
from sklearn.metrics import (
    classification_report,
    confusion_matrix,
    precision_score,
    recall_score,
    f1_score,
)

set_seed(42)
# ---- Cross-Validation ----
base_dir = "drive/MyDrive/thesis2025/split_dataset_densenet"
split_dirs = ["train", "val", "test"]

all_subject_dirs = []
for split in split_dirs:
    all_subject_dirs.extend(sorted(glob(os.path.join(base_dir, split, "*"))))

labels = [1 if os.path.basename(d).startswith("0201") else 0 for d in all_subject_dirs]
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

results = []
fold_reports = []
fold_conf_matrices = []

for fold_idx, (train_idx, test_idx) in enumerate(skf.split(all_subject_dirs, labels)):
    print(f"\nFold {fold_idx + 1}/5")
    train_subjs = [all_subject_dirs[i] for i in train_idx]
    test_subjs = [all_subject_dirs[i] for i in test_idx]
    train_subjs, val_subjs = train_test_split(train_subjs, test_size=0.1, random_state=42)

    train_set = EEGAudioTextDataset(train_subjs)
    val_set = EEGAudioTextDataset(val_subjs)
    test_set = EEGAudioTextDataset(test_subjs)

    train_loader = DataLoader(train_set, batch_size=8, shuffle=True, collate_fn=collate_fn_padd)
    val_loader = DataLoader(val_set, batch_size=8, collate_fn=collate_fn_padd)
    test_loader = DataLoader(test_set, batch_size=8, collate_fn=collate_fn_padd)

    model = ConvPoolReLUClassifier(input_dim=2816, hidden_dim = 256).to(device)
    optimizer = optim.Adamax(model.parameters(), lr=3.6283583803549155e-05)
    criterion = nn.CrossEntropyLoss()

    trained_model = train_model(model, train_loader, val_loader, criterion, optimizer, device, epochs = 64)

    # ---- Evaluation ----
    trained_model.eval()
    correct = total = 0
    fold_preds = []
    fold_labels = []

    with torch.no_grad():
        for x, y in test_loader:
            x, y = x.to(device), y.to(device)
            output = trained_model(x)
            preds = output.argmax(dim=1)

            fold_preds.extend(preds.cpu().numpy())
            fold_labels.extend(y.cpu().numpy())

            correct += (preds == y).sum().item()
            total += y.size(0)

    acc = correct / total
    precision = precision_score(fold_labels, fold_preds, average='macro')
    recall = recall_score(fold_labels, fold_preds, average='macro')
    f1 = f1_score(fold_labels, fold_preds, average='macro')

    print(f"Fold {fold_idx + 1} Test Accuracy: {acc:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}")
    results.append((acc, precision, recall, f1))

    # Store classification report and confusion matrix for later averaging
    report = classification_report(fold_labels, fold_preds, output_dict=True, target_names=["Healthy", "Depressed"])
    conf_matrix = confusion_matrix(fold_labels, fold_preds)

    fold_reports.append(report)
    fold_conf_matrices.append(conf_matrix)

# ---- Final Report ----
results = np.array(results)
mean_acc, mean_prec, mean_rec, mean_f1 = results.mean(axis=0)
std_acc = results[:, 0].std()
std_prec = results[:, 1].std()
std_rec = results[:, 2].std()
std_f1 = results[:, 3].std()

print(f"\n5-Fold CV Results:")
print(f"Mean Accuracy  = {mean_acc:.4f} ± {std_acc:.4f}")
print(f"Mean Precision = {mean_prec:.4f} ± {std_prec:.4f}")
print(f"Mean Recall    = {mean_rec:.4f} ± {std_rec:.4f}")
print(f"Mean F1-Score  = {mean_f1:.4f} ± {std_f1:.4f}")


Fold 1/5


100%|██████████| 1/1 [00:02<00:00,  2.82s/it]
100%|██████████| 1/1 [00:00<00:00, 18.76it/s]
100%|██████████| 1/1 [00:00<00:00, 20.31it/s]
100%|██████████| 1/1 [00:00<00:00, 22.48it/s]
100%|██████████| 1/1 [00:00<00:00, 20.88it/s]
100%|██████████| 1/1 [00:00<00:00, 15.23it/s]
100%|██████████| 1/1 [00:00<00:00, 21.68it/s]
100%|██████████| 1/1 [00:00<00:00, 19.40it/s]
100%|██████████| 1/1 [00:00<00:00, 17.03it/s]
100%|██████████| 1/1 [00:00<00:00, 20.76it/s]
100%|██████████| 1/1 [00:00<00:00, 20.30it/s]
100%|██████████| 1/1 [00:00<00:00, 22.03it/s]
100%|██████████| 1/1 [00:00<00:00, 15.38it/s]
100%|██████████| 1/1 [00:00<00:00, 20.00it/s]
100%|██████████| 1/1 [00:00<00:00, 21.88it/s]
100%|██████████| 1/1 [00:00<00:00, 22.79it/s]
100%|██████████| 1/1 [00:00<00:00, 17.82it/s]
100%|██████████| 1/1 [00:00<00:00, 23.31it/s]
100%|██████████| 1/1 [00:00<00:00, 16.69it/s]
100%|██████████| 1/1 [00:00<00:00, 22.60it/s]
100%|██████████| 1/1 [00:00<00:00, 20.68it/s]
100%|██████████| 1/1 [00:00<00:00,

Fold 1 Test Accuracy: 0.7500, Precision: 0.8333, Recall: 0.7500, F1: 0.7333

Fold 2/5


100%|██████████| 1/1 [00:00<00:00, 20.36it/s]
100%|██████████| 1/1 [00:00<00:00, 22.51it/s]
100%|██████████| 1/1 [00:00<00:00, 16.73it/s]
100%|██████████| 1/1 [00:00<00:00, 20.52it/s]
100%|██████████| 1/1 [00:00<00:00, 21.38it/s]
100%|██████████| 1/1 [00:00<00:00, 21.07it/s]
100%|██████████| 1/1 [00:00<00:00, 19.18it/s]
100%|██████████| 1/1 [00:00<00:00, 18.85it/s]
100%|██████████| 1/1 [00:00<00:00, 23.52it/s]
100%|██████████| 1/1 [00:00<00:00, 15.07it/s]
100%|██████████| 1/1 [00:00<00:00, 20.83it/s]
100%|██████████| 1/1 [00:00<00:00, 21.68it/s]
100%|██████████| 1/1 [00:00<00:00, 21.90it/s]
100%|██████████| 1/1 [00:00<00:00, 22.73it/s]
100%|██████████| 1/1 [00:00<00:00, 23.29it/s]
100%|██████████| 1/1 [00:00<00:00, 24.16it/s]
100%|██████████| 1/1 [00:00<00:00, 15.56it/s]
100%|██████████| 1/1 [00:00<00:00, 22.37it/s]
100%|██████████| 1/1 [00:00<00:00, 21.65it/s]
100%|██████████| 1/1 [00:00<00:00, 22.36it/s]
100%|██████████| 1/1 [00:00<00:00, 21.14it/s]
100%|██████████| 1/1 [00:00<00:00,

Fold 2 Test Accuracy: 0.5000, Precision: 0.5000, Recall: 0.5000, F1: 0.5000

Fold 3/5


100%|██████████| 1/1 [00:00<00:00, 16.28it/s]
100%|██████████| 1/1 [00:00<00:00, 22.99it/s]
100%|██████████| 1/1 [00:00<00:00, 15.91it/s]
100%|██████████| 1/1 [00:00<00:00, 21.09it/s]
100%|██████████| 1/1 [00:00<00:00, 19.30it/s]
100%|██████████| 1/1 [00:00<00:00, 18.45it/s]
100%|██████████| 1/1 [00:00<00:00, 21.14it/s]
100%|██████████| 1/1 [00:00<00:00, 11.04it/s]
100%|██████████| 1/1 [00:00<00:00, 24.33it/s]
100%|██████████| 1/1 [00:00<00:00, 24.18it/s]
100%|██████████| 1/1 [00:00<00:00, 21.45it/s]
100%|██████████| 1/1 [00:00<00:00, 22.19it/s]
100%|██████████| 1/1 [00:00<00:00, 23.85it/s]
100%|██████████| 1/1 [00:00<00:00, 23.60it/s]
100%|██████████| 1/1 [00:00<00:00, 11.05it/s]
100%|██████████| 1/1 [00:00<00:00, 15.57it/s]
100%|██████████| 1/1 [00:00<00:00, 13.67it/s]
100%|██████████| 1/1 [00:00<00:00, 12.88it/s]
100%|██████████| 1/1 [00:00<00:00, 14.25it/s]
100%|██████████| 1/1 [00:00<00:00, 15.42it/s]
100%|██████████| 1/1 [00:00<00:00, 17.47it/s]
100%|██████████| 1/1 [00:00<00:00,

Fold 3 Test Accuracy: 0.5000, Precision: 0.4667, Recall: 0.4667, F1: 0.4667

Fold 4/5


100%|██████████| 1/1 [00:00<00:00, 14.12it/s]
100%|██████████| 1/1 [00:00<00:00, 12.14it/s]
100%|██████████| 1/1 [00:00<00:00, 15.39it/s]
100%|██████████| 1/1 [00:00<00:00, 18.01it/s]
100%|██████████| 1/1 [00:00<00:00, 13.06it/s]
100%|██████████| 1/1 [00:00<00:00, 17.17it/s]
100%|██████████| 1/1 [00:00<00:00, 13.72it/s]
100%|██████████| 1/1 [00:00<00:00, 13.74it/s]
100%|██████████| 1/1 [00:00<00:00, 15.90it/s]
100%|██████████| 1/1 [00:00<00:00, 16.37it/s]
100%|██████████| 1/1 [00:00<00:00, 12.43it/s]
100%|██████████| 1/1 [00:00<00:00, 16.81it/s]
100%|██████████| 1/1 [00:00<00:00, 16.64it/s]
100%|██████████| 1/1 [00:00<00:00, 11.53it/s]
100%|██████████| 1/1 [00:00<00:00, 15.78it/s]
100%|██████████| 1/1 [00:00<00:00, 14.48it/s]
100%|██████████| 1/1 [00:00<00:00, 11.69it/s]
100%|██████████| 1/1 [00:00<00:00, 15.43it/s]
100%|██████████| 1/1 [00:00<00:00, 16.08it/s]
100%|██████████| 1/1 [00:00<00:00, 12.61it/s]
100%|██████████| 1/1 [00:00<00:00, 15.56it/s]
100%|██████████| 1/1 [00:00<00:00,

Fold 4 Test Accuracy: 0.7143, Precision: 0.8000, Recall: 0.7500, F1: 0.7083

Fold 5/5


100%|██████████| 1/1 [00:00<00:00, 13.34it/s]
100%|██████████| 1/1 [00:00<00:00, 17.33it/s]
100%|██████████| 1/1 [00:00<00:00, 17.09it/s]
100%|██████████| 1/1 [00:00<00:00, 13.23it/s]
100%|██████████| 1/1 [00:00<00:00, 15.38it/s]
100%|██████████| 1/1 [00:00<00:00, 17.36it/s]
100%|██████████| 1/1 [00:00<00:00, 11.79it/s]
100%|██████████| 1/1 [00:00<00:00, 17.36it/s]
100%|██████████| 1/1 [00:00<00:00, 16.38it/s]
100%|██████████| 1/1 [00:00<00:00, 12.68it/s]
100%|██████████| 1/1 [00:00<00:00, 18.04it/s]
100%|██████████| 1/1 [00:00<00:00, 17.05it/s]
100%|██████████| 1/1 [00:00<00:00, 11.66it/s]
100%|██████████| 1/1 [00:00<00:00, 15.20it/s]
100%|██████████| 1/1 [00:00<00:00, 15.92it/s]
100%|██████████| 1/1 [00:00<00:00, 11.60it/s]
100%|██████████| 1/1 [00:00<00:00, 15.52it/s]
100%|██████████| 1/1 [00:00<00:00, 15.10it/s]
100%|██████████| 1/1 [00:00<00:00, 11.17it/s]
100%|██████████| 1/1 [00:00<00:00, 12.28it/s]
100%|██████████| 1/1 [00:00<00:00, 15.78it/s]
100%|██████████| 1/1 [00:00<00:00,

Fold 5 Test Accuracy: 0.8571, Precision: 0.9000, Recall: 0.8333, F1: 0.8444

5-Fold CV Results:
Mean Accuracy  = 0.6643 ± 0.1421
Mean Precision = 0.7000 ± 0.1801
Mean Recall    = 0.6600 ± 0.1478
Mean F1-Score  = 0.6506 ± 0.1444





# **[Ignored]Hyperparameter tuning trial - nested CV**

In [None]:
base_dir = "drive/MyDrive/thesis2025/split_dataset_densenet"
split_dirs = ["train", "val", "test"]
all_subject_dirs = []

for split in split_dirs:
    all_subject_dirs.extend(sorted(glob(os.path.join(base_dir, split, "*"))))

labels = [1 if os.path.basename(d).startswith("0201") else 0 for d in all_subject_dirs]

In [None]:
all_subject_dirs

['drive/MyDrive/thesis2025/split_dataset_densenet/train/02010005',
 'drive/MyDrive/thesis2025/split_dataset_densenet/train/02010006',
 'drive/MyDrive/thesis2025/split_dataset_densenet/train/02010010',
 'drive/MyDrive/thesis2025/split_dataset_densenet/train/02010011',
 'drive/MyDrive/thesis2025/split_dataset_densenet/train/02010012',
 'drive/MyDrive/thesis2025/split_dataset_densenet/train/02010013',
 'drive/MyDrive/thesis2025/split_dataset_densenet/train/02010015',
 'drive/MyDrive/thesis2025/split_dataset_densenet/train/02010018',
 'drive/MyDrive/thesis2025/split_dataset_densenet/train/02010023',
 'drive/MyDrive/thesis2025/split_dataset_densenet/train/02010024',
 'drive/MyDrive/thesis2025/split_dataset_densenet/train/02010025',
 'drive/MyDrive/thesis2025/split_dataset_densenet/train/02010036',
 'drive/MyDrive/thesis2025/split_dataset_densenet/train/02020008',
 'drive/MyDrive/thesis2025/split_dataset_densenet/train/02020010',
 'drive/MyDrive/thesis2025/split_dataset_densenet/train/020200

In [None]:
from sklearn.model_selection import StratifiedKFold
from tqdm import tqdm
import optuna

# Setup
base_dir = "drive/MyDrive/thesis2025/split_dataset_densenet"
split_dirs = ["train", "val", "test"]
all_subject_dirs = []

for split in split_dirs:
    all_subject_dirs.extend(sorted(glob(os.path.join(base_dir, split, "*"))))

labels = [1 if os.path.basename(d).startswith("0201") else 0 for d in all_subject_dirs]

# Outer CV: 5-fold
outer_skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
final_results = []

for outer_fold, (train_val_idx, test_idx) in enumerate(outer_skf.split(all_subject_dirs, labels)):
    print(f"\n===== Outer Fold {outer_fold + 1}/5 =====")

    train_val_dirs = [all_subject_dirs[i] for i in train_val_idx]
    test_dirs = [all_subject_dirs[i] for i in test_idx]
    train_val_labels = [labels[i] for i in train_val_idx]

    # Inner CV for hyperparameter tuning
    def objective(trial):
        lr = trial.suggest_float("lr", 1e-5, 1e-2, log=True)
        hidden_dim = trial.suggest_categorical("hidden_dim", [256, 512, 1024])
        batch_size = trial.suggest_categorical("batch_size", [8, 16, 32, 64, 100])
        optimizer_name = trial.suggest_categorical("optimizer", ["Adam", "AdamW", "Adamax", "SGD"])
        epochs = trial.suggest_int("epochs", 10, 120)

        inner_skf = StratifiedKFold(n_splits=3, shuffle=True, random_state=42)
        inner_val_scores = []

        for fold_idx, (inner_train_idx, inner_val_idx) in enumerate(tqdm(inner_skf.split(train_val_dirs, train_val_labels),
                                                                  total=inner_skf.get_n_splits(),
                                                                  desc="Inner CV Folds")):
            inner_train_dirs = [train_val_dirs[i] for i in inner_train_idx]
            inner_val_dirs = [train_val_dirs[i] for i in inner_val_idx]

            train_set = EEGAudioTextDataset(inner_train_dirs)
            val_set = EEGAudioTextDataset(inner_val_dirs)
            train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, collate_fn=collate_fn_padd)
            val_loader = DataLoader(val_set, batch_size=batch_size, collate_fn=collate_fn_padd)

            model = ConvPoolReLUClassifier(input_dim=2816, hidden_dim=hidden_dim).to(device)
            criterion = nn.CrossEntropyLoss()
            if optimizer_name == "Adam":
                optimizer = torch.optim.Adam(model.parameters(), lr=lr)
            elif optimizer_name == "AdamW":
                optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
            elif optimizer_name == "Adamax":
                optimizer = torch.optim.Adamax(model.parameters(), lr=lr)
            elif optimizer_name == "SGD":
                optimizer = torch.optim.SGD(model.parameters(), lr=lr)

            trained_model = train_model(model, train_loader, val_loader, criterion, optimizer, device, epochs)

            correct = total = 0
            with torch.no_grad():
                for x, y in val_loader:
                    x, y = x.to(device), y.to(device)
                    preds = trained_model(x).argmax(dim=1)
                    correct += (preds == y).sum().item()
                    total += y.size(0)
            inner_val_scores.append(correct / total)

        return np.mean(inner_val_scores)

    study = optuna.create_study(direction="maximize", sampler=optuna.samplers.TPESampler(seed=42))
    study.optimize(objective, n_trials=25)
    best_params = study.best_params

    # Train on full train+val with best hyperparams
    best_batch = best_params["batch_size"]
    best_hidden = best_params["hidden_dim"]
    best_lr = best_params["lr"]
    best_epochs = best_params["epochs"]
    best_opt = best_params["optimizer"]

    train_set = EEGAudioTextDataset(train_val_dirs)
    test_set = EEGAudioTextDataset(test_dirs)
    train_loader = DataLoader(train_set, batch_size=best_batch, shuffle=True, collate_fn=collate_fn_padd)
    test_loader = DataLoader(test_set, batch_size=best_batch, collate_fn=collate_fn_padd)

    model = ConvPoolReLUClassifier(input_dim=2816, hidden_dim=best_hidden).to(device)
    criterion = nn.CrossEntropyLoss()
    if best_opt == "Adam":
        optimizer = torch.optim.Adam(model.parameters(), lr=best_lr)
    elif best_opt == "AdamW":
        optimizer = torch.optim.AdamW(model.parameters(), lr=best_lr)
    elif best_opt == "Adamax":
        optimizer = torch.optim.Adamax(model.parameters(), lr=best_lr)
    elif best_opt == "SGD":
        optimizer = torch.optim.SGD(model.parameters(), lr=best_lr)

    trained_model = train_model(model, train_loader, test_loader, criterion, optimizer, device, epochs=best_epochs)

    # Evaluate
    trained_model.eval()
    correct = total = 0
    preds, labels_true = [], []
    with torch.no_grad():
        for x, y in test_loader:
            x, y = x.to(device), y.to(device)
            out = trained_model(x).argmax(dim=1)
            preds.extend(out.cpu().numpy())
            labels_true.extend(y.cpu().numpy())
            correct += (out == y).sum().item()
            total += y.size(0)

    acc = correct / total
    prec = precision_score(labels_true, preds, average='macro')
    rec = recall_score(labels_true, preds, average='macro')
    f1 = f1_score(labels_true, preds, average='macro')
    print(f"Fold {outer_fold+1} Acc: {acc:.4f}, Prec: {prec:.4f}, Rec: {rec:.4f}, F1: {f1:.4f}")
    final_results.append([acc, prec, rec, f1])

# Final Metrics
final_results = np.array(final_results)
print("\n===== Nested CV Final Results =====")
print(f"Accuracy: {final_results[:,0].mean():.4f} ± {final_results[:,0].std():.4f}")
print(f"Precision: {final_results[:,1].mean():.4f} ± {final_results[:,1].std():.4f}")
print(f"Recall: {final_results[:,2].mean():.4f} ± {final_results[:,2].std():.4f}")
print(f"F1-score: {final_results[:,3].mean():.4f} ± {final_results[:,3].std():.4f}")


[I 2025-06-13 13:04:43,420] A new study created in memory with name: no-name-27f3510b-2697-4998-9ad2-338c346b9608



===== Outer Fold 1/5 =====


[1;30;43mStreaming output truncated to the last 5000 lines.[0m

100%|██████████| 1/1 [00:00<00:00, 10.83it/s]

  0%|          | 0/1 [00:00<?, ?it/s][A
100%|██████████| 1/1 [00:00<00:00,  9.10it/s]

  0%|          | 0/1 [00:00<?, ?it/s][A
100%|██████████| 1/1 [00:00<00:00,  8.83it/s]

  0%|          | 0/1 [00:00<?, ?it/s][A
100%|██████████| 1/1 [00:00<00:00,  8.48it/s]

100%|██████████| 1/1 [00:00<00:00, 10.23it/s]

100%|██████████| 1/1 [00:00<00:00, 10.55it/s]

  0%|          | 0/1 [00:00<?, ?it/s][A
100%|██████████| 1/1 [00:00<00:00,  7.85it/s]

100%|██████████| 1/1 [00:00<00:00, 11.42it/s]

  0%|          | 0/1 [00:00<?, ?it/s][A
100%|██████████| 1/1 [00:00<00:00,  9.42it/s]

  0%|          | 0/1 [00:00<?, ?it/s][A
100%|██████████| 1/1 [00:00<00:00,  8.75it/s]

100%|██████████| 1/1 [00:00<00:00, 11.77it/s]

100%|██████████| 1/1 [00:00<00:00, 10.74it/s]

100%|██████████| 1/1 [00:00<00:00, 10.08it/s]

100%|██████████| 1/1 [00:00<00:00, 10.01it/s]

100%|██████████| 1/1 [00:00<00

Fold 1 Acc: 0.5000, Prec: 0.2500, Rec: 0.5000, F1: 0.3333

===== Outer Fold 2/5 =====


[1;30;43mStreaming output truncated to the last 5000 lines.[0m
100%|██████████| 1/1 [00:00<00:00,  9.01it/s]

  0%|          | 0/1 [00:00<?, ?it/s][A
100%|██████████| 1/1 [00:00<00:00,  7.48it/s]

  0%|          | 0/1 [00:00<?, ?it/s][A
100%|██████████| 1/1 [00:00<00:00,  8.49it/s]

100%|██████████| 1/1 [00:00<00:00, 11.45it/s]

100%|██████████| 1/1 [00:00<00:00, 10.77it/s]

  0%|          | 0/1 [00:00<?, ?it/s][A
100%|██████████| 1/1 [00:00<00:00,  7.87it/s]

100%|██████████| 1/1 [00:00<00:00, 10.19it/s]

  0%|          | 0/1 [00:00<?, ?it/s][A
100%|██████████| 1/1 [00:00<00:00,  7.58it/s]

  0%|          | 0/1 [00:00<?, ?it/s][A
100%|██████████| 1/1 [00:00<00:00,  8.25it/s]

  0%|          | 0/1 [00:00<?, ?it/s][A
100%|██████████| 1/1 [00:00<00:00,  8.88it/s]

100%|██████████| 1/1 [00:00<00:00, 10.86it/s]

  0%|          | 0/1 [00:00<?, ?it/s][A
100%|██████████| 1/1 [00:00<00:00,  9.29it/s]

  0%|          | 0/1 [00:00<?, ?it/s][A
100%|██████████| 1/1 [00:00<00:00,  9.45it/