In [None]:
from google.colab import drive
import os

# Mount Google Drive at /content/drive so later cells can read and write files there.
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import random
import numpy as np
import torch

def set_seed(seed=42):
    """Seed every random number generator used in this notebook.

    Seeds Python's ``random``, NumPy's legacy global RNG and PyTorch (CPU and
    all CUDA devices), then switches cuDNN to deterministic mode with
    auto-tuning disabled.

    Args:
        seed: Integer seed applied to every generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Trade cuDNN speed for run-to-run reproducibility.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


set_seed(42)

# **Extracting Features using Vision Transformer**

In [None]:
import os
import torch
import numpy as np
from PIL import Image
from tqdm import tqdm
from torchvision import models, transforms

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load torchvision's ViT-B/16 with pretrained weights and switch it to eval mode.
# NOTE(review): recent torchvision releases deprecate `pretrained=True` in favour
# of the `weights=` argument — confirm against the installed version.
vit = models.vit_b_16(pretrained=True).to(device)
vit.eval()

# NOTE(review): `feature_extractor` is not referenced by any other visible cell;
# extract_features() calls `vit` directly. Confirm whether it is still needed.
feature_extractor = torch.nn.Sequential(
    vit.encoder,
    torch.nn.Identity()
)

# Replace the classification head with an identity so that vit(x) returns the
# features that would otherwise be fed to the head, not class logits.
vit.heads = torch.nn.Identity()

# Resize every image to 224x224, convert it to a tensor and normalise each channel.
# NOTE(review): mean/std are presumably the ImageNet statistics the pretrained
# weights expect — confirm.
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])


Downloading: "https://download.pytorch.org/models/vit_b_16-c867db91.pth" to /root/.cache/torch/hub/checkpoints/vit_b_16-c867db91.pth
100%|██████████| 330M/330M [00:03<00:00, 99.6MB/s]


In [None]:
def extract_features(image_path):
    """Return the ViT output for a single image as a flat CPU tensor.

    Args:
        image_path: Path of an image file that PIL can open.

    Returns:
        A 1-D ``torch.Tensor`` on the CPU containing the flattened output of
        ``vit`` for the preprocessed image.

    Raises:
        Whatever ``PIL.Image.open`` raises for a missing or unreadable file.
    """
    # The context manager closes the underlying file as soon as the pixels have
    # been copied by convert(); without it one handle per image stays open until
    # garbage collection, which adds up over thousands of spectrograms.
    with Image.open(image_path) as img:
        image = img.convert("RGB")

    # Add the batch dimension the model expects and move to the model's device.
    input_tensor = preprocess(image).unsqueeze(0).to(device)

    with torch.no_grad():
        features = vit(input_tensor).view(-1)
    return features.cpu()

In [None]:
# Dataset partitions to process; each one is a sub-directory name under the dataset roots.
splits = ["train", "val", "test"]

In [None]:
# Root of the input spectrogram images and root where the embeddings are written.
# NOTE(review): both paths are relative, so they resolve against the current
# working directory (presumably /content on Colab) — confirm before running elsewhere.
source_base = "drive/MyDrive/thesis2025/split_dataset_new"
target_base = "drive/MyDrive/thesis2025/split_dataset_vit3"

In [None]:
def _embed_images(image_paths, kind):
    """Embed each image with extract_features(); report and skip unreadable ones.

    Args:
        image_paths: Iterable of image file paths.
        kind: Label used in the skip message ("EEG" or "audio").

    Returns:
        List of embeddings for the images that could be processed.
    """
    embeddings = []
    for p in image_paths:
        try:
            embeddings.append(extract_features(p))
        except Exception as e:
            # Best effort: one bad image must not abort the whole subject.
            print(f"Skipping {kind} image {p}: {e}")
    return embeddings


def _pad_rows(stack, n_rows):
    """Append all-zero rows to a 2-D tensor until it has ``n_rows`` rows."""
    missing = n_rows - stack.size(0)
    if missing <= 0:
        return stack
    return torch.cat([stack, torch.zeros((missing, stack.size(1)))], dim=0)


# For every subject: embed all EEG and audio spectrogram images, zero-pad the
# shorter modality so both have the same number of rows, and save both arrays.
for split in splits:
    split_path = os.path.join(source_base, split)

    # Only directories are subjects; a stray file in the split folder would
    # otherwise make the os.listdir() calls below fail.
    subject_ids = sorted(
        name for name in os.listdir(split_path)
        if os.path.isdir(os.path.join(split_path, name))
    )

    for subject_id in tqdm(subject_ids, desc=f"Processing {split}"):
        subject_path = os.path.join(split_path, subject_id)
        eeg_dir = os.path.join(subject_path, "eeg_stft_spectrogram2")
        audio_dir = os.path.join(subject_path, "audio_spectrogram")

        # NOTE(review): sorting is lexicographic ("10.png" sorts before "2.png");
        # confirm the file names are zero-padded if row order matters downstream.
        eeg_files = sorted(os.path.join(eeg_dir, f) for f in os.listdir(eeg_dir) if f.endswith(".png"))
        audio_files = sorted(os.path.join(audio_dir, f) for f in os.listdir(audio_dir) if f.endswith(".png"))

        eeg_embeddings = _embed_images(eeg_files, "EEG")
        audio_embeddings = _embed_images(audio_files, "audio")

        if not eeg_embeddings or not audio_embeddings:
            print(f"Skipping subject {subject_id} due to insufficient embeddings.")
            continue

        eeg_stack = torch.stack(eeg_embeddings)
        audio_stack = torch.stack(audio_embeddings)

        # Bring both modalities to the same number of rows by zero-padding the shorter one.
        n_rows = max(eeg_stack.size(0), audio_stack.size(0))
        eeg_stack = _pad_rows(eeg_stack, n_rows)
        audio_stack = _pad_rows(audio_stack, n_rows)

        save_dir = os.path.join(target_base, split, subject_id)
        os.makedirs(save_dir, exist_ok=True)

        np.save(os.path.join(save_dir, "eeg_embedding.npy"), eeg_stack.numpy())
        np.save(os.path.join(save_dir, "audio_embedding.npy"), audio_stack.numpy())


Processing train: 100%|██████████| 26/26 [08:11<00:00, 18.92s/it]
Processing val: 100%|██████████| 6/6 [01:52<00:00, 18.69s/it]
Processing test: 100%|██████████| 6/6 [01:51<00:00, 18.65s/it]


In [None]:
tri = np.load("drive/MyDrive/thesis2025/split_dataset_vit3/train/02010006/eeg_embedding.npy")

In [None]:
tri.shape

(29, 768)

# **1D EEG Signal Extraction**

In [None]:
import torch
import torch.nn as nn

class EEG1DEncoder(nn.Module):
    """Encode a single-channel 1D signal into a fixed-size embedding.

    Pipeline: two Conv1d -> ReLU -> MaxPool1d stages (each pooling halves the
    time axis), a bidirectional LSTM over the resulting sequence, mean pooling
    over time, and a final linear projection.

    Args:
        hidden_dim: Hidden size of each LSTM direction.
        embed_dim: Size of the returned embedding.
    """

    def __init__(self, hidden_dim=256, embed_dim=512):
        super().__init__()
        conv_layers = []
        for in_channels, out_channels in ((1, 64), (64, 128)):
            conv_layers.extend([
                nn.Conv1d(in_channels, out_channels, kernel_size=5, stride=1, padding=2),
                nn.ReLU(),
                nn.MaxPool1d(kernel_size=2),
            ])
        self.cnn = nn.Sequential(*conv_layers)
        self.lstm = nn.LSTM(
            input_size=128,
            hidden_size=hidden_dim,
            batch_first=True,
            bidirectional=True,
        )
        # Both LSTM directions are concatenated, hence 2 * hidden_dim inputs.
        self.proj = nn.Linear(hidden_dim * 2, embed_dim)

    def forward(self, x):
        """Map ``x`` of shape (batch, 1, time) to an embedding.

        Returns:
            Tensor of shape (embed_dim,) when batch == 1 (the leading dimension
            is squeezed away), otherwise (batch, embed_dim).
        """
        conv_features = self.cnn(x)                      # (batch, 128, time')
        sequence = conv_features.permute(0, 2, 1)        # (batch, time', 128)
        recurrent_out, _state = self.lstm(sequence)      # (batch, time', 2*hidden_dim)
        time_averaged = recurrent_out.mean(dim=1)        # (batch, 2*hidden_dim)
        return self.proj(time_averaged).squeeze(0)


In [None]:
from pathlib import Path
# Select the device and instantiate the 1D EEG encoder in eval mode.
# NOTE(review): no checkpoint is loaded and no training step for this encoder is
# visible in the notebook, so the embeddings produced below appear to come from
# its initial weights — confirm this is intended.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = EEG1DEncoder(hidden_dim=128, embed_dim=512).to(device)
model.eval()

# Roots of the processed EEG arrays (input) and of the saved embeddings (output).
# NOTE(review): the input root is "split_dataset" while the spectrogram cell reads
# from "split_dataset_new" — confirm both trees hold the same subjects.
base_input_dir = Path("/content/drive/My Drive/thesis2025/split_dataset")
base_output_dir = Path("/content/drive/My Drive/thesis2025/split_dataset_vit3")

# Dataset partitions to process.
splits = ['train', 'val', 'test']

In [None]:
# Feature extraction loop: encode every EEG channel of every subject with the
# 1D encoder and save one (n_channels, embed_dim) array per subject.
with torch.no_grad():
    for split in splits:
        input_split_dir = base_input_dir / split
        output_split_dir = base_output_dir / split

        subject_folders = [f for f in input_split_dir.iterdir() if f.is_dir()]
        print(f"\nProcessing split: {split} ({len(subject_folders)} subjects)")

        for subject_folder in tqdm(subject_folders, desc=f"{split} split"):
            subject_id = subject_folder.name
            input_path = subject_folder / f"{subject_id}_processed.npy"
            if not input_path.exists():
                # Report the skip: a silently missing eeg1d_embedding.npy only
                # surfaces much later, when the dataset tries to load it.
                print(f"Skipping subject {subject_id}: {input_path} not found.")
                continue

            # Load EEG data, laid out as (channels, time)
            eeg = np.load(input_path)

            # Z-score each channel independently; the epsilon guards flat channels.
            channel_mean = eeg.mean(axis=1, keepdims=True)
            channel_std = eeg.std(axis=1, keepdims=True)
            eeg = (eeg - channel_mean) / (channel_std + 1e-6)

            channel_embeddings = []

            # Encode channel by channel so only one (1, 1, T) tensor is on the device at a time.
            for ch_idx in range(eeg.shape[0]):
                ch_data = eeg[ch_idx]  # (T,)
                ch_tensor = torch.tensor(ch_data, dtype=torch.float32).unsqueeze(0).unsqueeze(0).to(device)  # (1, 1, T)

                embedding = model(ch_tensor)  # (embed_dim,)
                channel_embeddings.append(embedding.cpu().numpy())

            # Stack embeddings: (n_channels, embed_dim)
            channel_embeddings = np.stack(channel_embeddings)

            # Save embedding next to the ViT embeddings of the same subject
            output_subject_dir = output_split_dir / subject_id
            output_subject_dir.mkdir(parents=True, exist_ok=True)
            np.save(output_subject_dir / "eeg1d_embedding.npy", channel_embeddings)


Processing split: train (26 subjects)


train split: 100%|██████████| 26/26 [01:16<00:00,  2.93s/it]



Processing split: val (6 subjects)


val split: 100%|██████████| 6/6 [00:17<00:00,  2.96s/it]



Processing split: test (6 subjects)


test split: 100%|██████████| 6/6 [00:18<00:00,  3.02s/it]


In [None]:
tri = np.load("drive/MyDrive/thesis2025/split_dataset_vit3/train/02010010/eeg1d_embedding.npy")

In [None]:
tri.shape

(29, 512)

# **Classifier Module**

## **Max Pooling**

In [None]:
import torch
import torch.nn as nn

class ConvPoolReLUClassifier(nn.Module):
    """Conv1d -> MaxPool1d -> ReLU -> two-layer MLP classifier over a sequence.

    Note: later sections of this notebook define other classes under the same
    name; whichever definition was executed last is the one in effect.

    Args:
        input_dim: Feature size of each sequence step (Conv1d input channels).
        hidden_dim: Width of the hidden fully connected layer.
        num_classes: Number of output logits.
        seq_len: Expected sequence length. The flattened size fed to ``fc1`` is
            ``256 * (seq_len // 2)``; the default of 29 reproduces the
            previously hard-coded ``256 * 14``.

    Raises:
        ValueError: If ``seq_len`` is too short to survive the pooling layer.
    """

    def __init__(self, input_dim=2048, hidden_dim=512, num_classes=2, seq_len=29):
        # Zero-argument super() keeps working when the class is re-defined in a notebook.
        super().__init__()

        self.conv1 = nn.Conv1d(in_channels=input_dim, out_channels=256, kernel_size=3, padding=1)
        self.pool = nn.MaxPool1d(kernel_size=2, stride=2)  # seq_len -> seq_len // 2 (29 -> 14)
        self.relu = nn.ReLU()

        self.pooled_len = seq_len // 2
        if self.pooled_len < 1:
            raise ValueError(f"seq_len must be at least 2, got {seq_len}")

        self.fc1 = nn.Linear(256 * self.pooled_len, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        """Classify ``x`` of shape (batch, seq_len, input_dim); returns (batch, num_classes) logits.

        Raises:
            ValueError: If the pooled sequence length does not match the size
                ``fc1`` was built for.
        """
        x = x.transpose(1, 2)  # (batch, input_dim, seq_len)
        x = self.conv1(x)      # (batch, 256, seq_len)
        x = self.pool(x)       # (batch, 256, seq_len // 2)
        x = self.relu(x)

        if x.size(2) != self.pooled_len:
            # Fail with an actionable message instead of a matmul shape error.
            raise ValueError(
                f"Pooled sequence length {x.size(2)} does not match the expected "
                f"{self.pooled_len}; construct the classifier with a matching seq_len."
            )

        x = x.view(x.size(0), -1)  # (batch, 256 * pooled_len)
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        return x



In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms
from PIL import Image
from sklearn.model_selection import StratifiedKFold, train_test_split
from sklearn.metrics import classification_report, confusion_matrix
import numpy as np
from glob import glob

In [None]:
class EEGAudioDataset(Dataset):
    """Per-subject dataset that concatenates three saved embeddings feature-wise.

    Each item loads ``eeg_embedding.npy``, ``audio_embedding.npy`` and
    ``eeg1d_embedding.npy`` from the subject directory and joins them along
    axis 1. The label is 1 when the subject folder name starts with ``'0201'``
    and 0 otherwise.

    Args:
        subject_dirs: Sequence of subject directory paths.
    """

    # File names are part of the on-disk layout; order defines the feature layout.
    _EMBEDDING_FILES = ("eeg_embedding.npy", "audio_embedding.npy", "eeg1d_embedding.npy")

    def __init__(self, subject_dirs):
        self.subject_dirs = subject_dirs

    def __len__(self):
        return len(self.subject_dirs)

    def __getitem__(self, idx):
        """Return ``(features, label)`` for subject ``idx``.

        Returns:
            features: float32 tensor of shape (rows, sum of the three feature sizes).
            label: 0-dim int64 tensor.

        Raises:
            ValueError: If the three arrays are not 2-D with equal row counts.
        """
        # normpath drops a trailing separator, which would otherwise make
        # basename() return '' and silently assign label 0.
        subject_path = os.path.normpath(self.subject_dirs[idx])
        subject_id = os.path.basename(subject_path)

        parts = [np.load(os.path.join(subject_path, name)) for name in self._EMBEDDING_FILES]

        shapes = {name: part.shape for name, part in zip(self._EMBEDDING_FILES, parts)}
        if any(part.ndim != 2 for part in parts) or len({part.shape[0] for part in parts}) != 1:
            raise ValueError(
                f"Subject {subject_id}: embeddings must be 2-D with the same number "
                f"of rows to be concatenated, got {shapes}"
            )

        combined_embedding = np.concatenate(parts, axis=1)
        label = 1 if subject_id.startswith('0201') else 0

        combined_embedding = torch.tensor(combined_embedding, dtype=torch.float32)
        label = torch.tensor(label, dtype=torch.long)

        return combined_embedding, label


In [None]:
# ---- Training ----
def train_model(model, train_loader, val_loader, criterion, optimizer, device, epochs=70):
    """Train ``model`` and restore the weights of its best validation epoch.

    Args:
        model: Network to train; it is modified in place.
        train_loader: Iterable of ``(x, y)`` training batches.
        val_loader: Iterable of ``(x, y)`` validation batches.
        criterion: Loss called as ``criterion(output, y)``.
        optimizer: Optimizer providing ``zero_grad()`` and ``step()``.
        device: Device the batches are moved to.
        epochs: Number of passes over ``train_loader``.

    Returns:
        The same ``model`` object, loaded with the weights from the epoch with
        the highest validation accuracy (the earliest such epoch on ties).
    """
    import copy  # local import keeps this cell self-contained

    best_state = None
    # Start below any reachable accuracy so the first epoch is always recorded,
    # even when validation accuracy is 0.
    best_val_acc = -1.0

    for _ in range(epochs):
        model.train()
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            output = model(x)
            loss = criterion(output, y)
            loss.backward()
            optimizer.step()

        # Validation (no per-epoch progress bar: it only floods the cell output)
        model.eval()
        correct = total = 0
        with torch.no_grad():
            for x, y in val_loader:
                x, y = x.to(device), y.to(device)
                preds = model(x).argmax(dim=1)
                correct += (preds == y).sum().item()
                total += y.size(0)

        # An empty validation loader counts as accuracy 0 rather than dividing by zero.
        val_acc = correct / total if total else 0.0
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            # state_dict() returns references to the live parameters; without a
            # deep copy the snapshot is overwritten by later optimizer steps and
            # "restoring the best model" silently restores the last epoch.
            best_state = copy.deepcopy(model.state_dict())

    # best_state stays None only when epochs == 0.
    if best_state is not None:
        model.load_state_dict(best_state)
    return model


In [None]:
import torch
from torch.nn.utils.rnn import pad_sequence

def collate_fn_padd(batch):
    """Collate variable-length samples into one zero-padded batch.

    Args:
        batch: List of ``(x, y)`` pairs, where ``x`` is a tensor of shape
            (seq_len, feat_dim) and ``y`` is the sample's label.

    Returns:
        Tuple ``(padded, labels)``: ``padded`` has shape
        (batch, max_seq_len, feat_dim) with zeros after each sequence's end;
        ``labels`` is an int64 tensor of shape (batch,).
    """
    sequences = []
    targets = []
    for sample in batch:
        sequences.append(sample[0])
        targets.append(sample[1])

    labels = torch.tensor(targets, dtype=torch.long)
    padded_seqs = pad_sequence(sequences, batch_first=True, padding_value=0)
    return padded_seqs, labels


In [None]:
# Root of the saved embeddings.
base_dir = "drive/MyDrive/thesis2025/split_dataset_vit3"

# Step 1: Index all subject folders from train, val, test as {subject_id: path}.
# A subject id that appears in more than one split keeps the path of the last split listed.
subject_path_map = {}
for split in ['train', 'val', 'test']:
    split_path = os.path.join(base_dir, split)
    for subj in sorted(os.listdir(split_path)):
        subj_path = os.path.join(split_path, subj)
        # Only directories are subjects; stray files must not be indexed.
        if os.path.isdir(subj_path):
            subject_path_map[subj] = subj_path

In [None]:
import os
import json
import numpy as np
from glob import glob
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import StratifiedKFold, train_test_split
from sklearn.metrics import (
    classification_report,
    confusion_matrix,
    precision_score,
    recall_score,
    f1_score,
)

set_seed(42)

# ---- Cross-Validation ----
# Trains a fresh ConvPoolReLUClassifier per fold and evaluates it on the fold's
# test subjects. The class name is defined more than once in this notebook; the
# definition executed most recently is the one instantiated here.

# Number of folds read from fold_assignments.json (keys "fold_1" ... "fold_5").
N_FOLDS = 5
# Fixed label order so every fold's report and confusion matrix share one layout,
# even when a class is absent from that fold's labels and predictions.
CLASS_LABELS = [0, 1]
CLASS_NAMES = ["Healthy", "Depressed"]

with open("drive/MyDrive/thesis2025/split_dataset_june/fold_assignments.json", "r") as f:
    folds = json.load(f)

results = []
fold_reports = []
fold_conf_matrices = []

for fold_idx in range(N_FOLDS):
    fold_name = f"fold_{fold_idx + 1}"
    train_ids = folds[fold_name]["train"]
    val_ids   = folds[fold_name]["val"]
    test_ids  = folds[fold_name]["test"]

    # Resolve subject ids to embedding directories (KeyError names a missing subject).
    train_subjs = [subject_path_map[sid] for sid in train_ids]
    val_subjs   = [subject_path_map[sid] for sid in val_ids]
    test_subjs  = [subject_path_map[sid] for sid in test_ids]

    train_set = EEGAudioDataset(train_subjs)
    val_set = EEGAudioDataset(val_subjs)
    test_set = EEGAudioDataset(test_subjs)

    train_loader = DataLoader(train_set, batch_size=100, shuffle=True, collate_fn=collate_fn_padd)
    val_loader = DataLoader(val_set, batch_size=100, collate_fn=collate_fn_padd)
    test_loader = DataLoader(test_set, batch_size=100, collate_fn=collate_fn_padd)

    model = ConvPoolReLUClassifier(input_dim=2048).to(device)
    optimizer = optim.Adam(model.parameters(), lr=0.0004)
    criterion = nn.CrossEntropyLoss()

    trained_model = train_model(model, train_loader, val_loader, criterion, optimizer, device)

    # ---- Evaluation ----
    trained_model.eval()
    correct = total = 0
    fold_preds = []
    fold_labels = []

    with torch.no_grad():
        for x, y in test_loader:
            x, y = x.to(device), y.to(device)
            output = trained_model(x)
            preds = output.argmax(dim=1)

            fold_preds.extend(preds.cpu().numpy())
            fold_labels.extend(y.cpu().numpy())

            correct += (preds == y).sum().item()
            total += y.size(0)

    acc = correct / total
    # zero_division=0 yields the same value as the default but without a warning
    # when a class is never predicted in a fold.
    precision = precision_score(fold_labels, fold_preds, average='macro', zero_division=0)
    recall = recall_score(fold_labels, fold_preds, average='macro', zero_division=0)
    f1 = f1_score(fold_labels, fold_preds, average='macro', zero_division=0)

    print(f"Fold {fold_idx + 1} Test Accuracy: {acc:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}")
    results.append((acc, precision, recall, f1))

    # Store classification report and confusion matrix for later averaging.
    # Explicit labels keep the report from raising when only one class shows up
    # and keep every confusion matrix 2x2 so the folds can be averaged.
    report = classification_report(
        fold_labels,
        fold_preds,
        labels=CLASS_LABELS,
        target_names=CLASS_NAMES,
        output_dict=True,
        zero_division=0,
    )
    conf_matrix = confusion_matrix(fold_labels, fold_preds, labels=CLASS_LABELS)

    fold_reports.append(report)
    fold_conf_matrices.append(conf_matrix)

# ---- Final Report ----
# One row per fold, columns: accuracy, precision, recall, F1.
results = np.array(results)
mean_acc, mean_prec, mean_rec, mean_f1 = results.mean(axis=0)
std_acc, std_prec, std_rec, std_f1 = results.std(axis=0)

print(f"\n{N_FOLDS}-Fold CV Results:")
print(f"Mean Accuracy  = {mean_acc:.4f} ± {std_acc:.4f}")
print(f"Mean Precision = {mean_prec:.4f} ± {std_prec:.4f}")
print(f"Mean Recall    = {mean_rec:.4f} ± {std_rec:.4f}")
print(f"Mean F1-Score  = {mean_f1:.4f} ± {std_f1:.4f}")


100%|██████████| 1/1 [00:10<00:00, 10.21s/it]
100%|██████████| 1/1 [00:00<00:00, 43.59it/s]
100%|██████████| 1/1 [00:00<00:00, 42.44it/s]
100%|██████████| 1/1 [00:00<00:00, 41.62it/s]
100%|██████████| 1/1 [00:00<00:00, 37.51it/s]
100%|██████████| 1/1 [00:00<00:00, 35.76it/s]
100%|██████████| 1/1 [00:00<00:00, 41.96it/s]
100%|██████████| 1/1 [00:00<00:00, 41.13it/s]
100%|██████████| 1/1 [00:00<00:00, 43.26it/s]
100%|██████████| 1/1 [00:00<00:00, 40.53it/s]
100%|██████████| 1/1 [00:00<00:00, 44.55it/s]
100%|██████████| 1/1 [00:00<00:00, 41.14it/s]
100%|██████████| 1/1 [00:00<00:00, 26.34it/s]
100%|██████████| 1/1 [00:00<00:00, 41.38it/s]
100%|██████████| 1/1 [00:00<00:00, 43.12it/s]
100%|██████████| 1/1 [00:00<00:00, 42.08it/s]
100%|██████████| 1/1 [00:00<00:00, 41.22it/s]
100%|██████████| 1/1 [00:00<00:00, 39.43it/s]
100%|██████████| 1/1 [00:00<00:00, 41.20it/s]
100%|██████████| 1/1 [00:00<00:00, 40.46it/s]
100%|██████████| 1/1 [00:00<00:00, 41.81it/s]
100%|██████████| 1/1 [00:00<00:00,

Fold 1 Test Accuracy: 0.3750, Precision: 0.3667, Recall: 0.3750, F1: 0.3651


100%|██████████| 1/1 [00:00<00:00, 26.17it/s]
100%|██████████| 1/1 [00:00<00:00, 38.97it/s]
100%|██████████| 1/1 [00:00<00:00, 43.68it/s]
100%|██████████| 1/1 [00:00<00:00, 42.70it/s]
100%|██████████| 1/1 [00:00<00:00, 26.95it/s]
100%|██████████| 1/1 [00:00<00:00, 37.00it/s]
100%|██████████| 1/1 [00:00<00:00, 42.51it/s]
100%|██████████| 1/1 [00:00<00:00, 40.11it/s]
100%|██████████| 1/1 [00:00<00:00, 42.54it/s]
100%|██████████| 1/1 [00:00<00:00, 39.20it/s]
100%|██████████| 1/1 [00:00<00:00, 41.77it/s]
100%|██████████| 1/1 [00:00<00:00, 28.09it/s]
100%|██████████| 1/1 [00:00<00:00, 44.10it/s]
100%|██████████| 1/1 [00:00<00:00, 41.96it/s]
100%|██████████| 1/1 [00:00<00:00, 40.12it/s]
100%|██████████| 1/1 [00:00<00:00, 39.80it/s]
100%|██████████| 1/1 [00:00<00:00, 42.33it/s]
100%|██████████| 1/1 [00:00<00:00, 42.01it/s]
100%|██████████| 1/1 [00:00<00:00, 28.19it/s]
100%|██████████| 1/1 [00:00<00:00, 40.78it/s]
100%|██████████| 1/1 [00:00<00:00, 40.10it/s]
100%|██████████| 1/1 [00:00<00:00,

Fold 2 Test Accuracy: 0.6250, Precision: 0.6333, Recall: 0.6250, F1: 0.6190


100%|██████████| 1/1 [00:00<00:00, 41.95it/s]
100%|██████████| 1/1 [00:00<00:00, 45.89it/s]
100%|██████████| 1/1 [00:00<00:00, 41.55it/s]
100%|██████████| 1/1 [00:00<00:00, 43.20it/s]
100%|██████████| 1/1 [00:00<00:00, 26.21it/s]
100%|██████████| 1/1 [00:00<00:00, 38.71it/s]
100%|██████████| 1/1 [00:00<00:00, 42.02it/s]
100%|██████████| 1/1 [00:00<00:00, 41.39it/s]
100%|██████████| 1/1 [00:00<00:00, 43.68it/s]
100%|██████████| 1/1 [00:00<00:00, 37.59it/s]
100%|██████████| 1/1 [00:00<00:00, 43.53it/s]
100%|██████████| 1/1 [00:00<00:00, 28.36it/s]
100%|██████████| 1/1 [00:00<00:00, 49.04it/s]
100%|██████████| 1/1 [00:00<00:00, 40.25it/s]
100%|██████████| 1/1 [00:00<00:00, 37.81it/s]
100%|██████████| 1/1 [00:00<00:00, 34.02it/s]
100%|██████████| 1/1 [00:00<00:00, 41.57it/s]
100%|██████████| 1/1 [00:00<00:00, 44.60it/s]
100%|██████████| 1/1 [00:00<00:00, 28.34it/s]
100%|██████████| 1/1 [00:00<00:00, 42.00it/s]
100%|██████████| 1/1 [00:00<00:00, 46.02it/s]
100%|██████████| 1/1 [00:00<00:00,

Fold 3 Test Accuracy: 0.6250, Precision: 0.6250, Recall: 0.6333, F1: 0.6190


100%|██████████| 1/1 [00:00<00:00, 11.99it/s]
100%|██████████| 1/1 [00:00<00:00, 18.67it/s]
100%|██████████| 1/1 [00:00<00:00, 32.21it/s]
100%|██████████| 1/1 [00:00<00:00, 20.59it/s]
100%|██████████| 1/1 [00:00<00:00, 28.02it/s]
100%|██████████| 1/1 [00:00<00:00, 28.45it/s]
100%|██████████| 1/1 [00:00<00:00, 29.42it/s]
100%|██████████| 1/1 [00:00<00:00, 28.72it/s]
100%|██████████| 1/1 [00:00<00:00, 32.54it/s]
100%|██████████| 1/1 [00:00<00:00, 31.85it/s]
100%|██████████| 1/1 [00:00<00:00, 32.48it/s]
100%|██████████| 1/1 [00:00<00:00, 34.51it/s]
100%|██████████| 1/1 [00:00<00:00, 31.55it/s]
100%|██████████| 1/1 [00:00<00:00, 31.92it/s]
100%|██████████| 1/1 [00:00<00:00, 32.06it/s]
100%|██████████| 1/1 [00:00<00:00, 30.03it/s]
100%|██████████| 1/1 [00:00<00:00, 30.36it/s]
100%|██████████| 1/1 [00:00<00:00, 32.08it/s]
100%|██████████| 1/1 [00:00<00:00, 32.71it/s]
100%|██████████| 1/1 [00:00<00:00, 20.79it/s]
100%|██████████| 1/1 [00:00<00:00, 31.36it/s]
100%|██████████| 1/1 [00:00<00:00,

Fold 4 Test Accuracy: 0.8571, Precision: 0.9000, Recall: 0.8333, F1: 0.8444


100%|██████████| 1/1 [00:00<00:00, 31.73it/s]
100%|██████████| 1/1 [00:00<00:00, 21.98it/s]
100%|██████████| 1/1 [00:00<00:00, 33.32it/s]
100%|██████████| 1/1 [00:00<00:00, 35.08it/s]
100%|██████████| 1/1 [00:00<00:00, 22.44it/s]
100%|██████████| 1/1 [00:00<00:00, 32.02it/s]
100%|██████████| 1/1 [00:00<00:00, 31.94it/s]
100%|██████████| 1/1 [00:00<00:00, 21.27it/s]
100%|██████████| 1/1 [00:00<00:00, 32.96it/s]
100%|██████████| 1/1 [00:00<00:00, 29.58it/s]
100%|██████████| 1/1 [00:00<00:00, 31.27it/s]
100%|██████████| 1/1 [00:00<00:00, 32.48it/s]
100%|██████████| 1/1 [00:00<00:00, 32.07it/s]
100%|██████████| 1/1 [00:00<00:00, 31.82it/s]
100%|██████████| 1/1 [00:00<00:00, 31.60it/s]
100%|██████████| 1/1 [00:00<00:00, 32.29it/s]
100%|██████████| 1/1 [00:00<00:00, 32.85it/s]
100%|██████████| 1/1 [00:00<00:00, 22.62it/s]
100%|██████████| 1/1 [00:00<00:00, 31.11it/s]
100%|██████████| 1/1 [00:00<00:00, 31.09it/s]
100%|██████████| 1/1 [00:00<00:00, 23.00it/s]
100%|██████████| 1/1 [00:00<00:00,

Fold 5 Test Accuracy: 0.4286, Precision: 0.2500, Recall: 0.3750, F1: 0.3000

5-Fold CV Results:
Mean Accuracy  = 0.5821 ± 0.1708
Mean Precision = 0.5550 ± 0.2274
Mean Recall    = 0.5683 ± 0.1746
Mean F1-Score  = 0.5495 ± 0.1964





## **Adaptive Average Pooling**

In [None]:
import torch
import torch.nn as nn

class ConvPoolReLUClassifier(nn.Module):
    """Conv1d -> adaptive average pooling -> ReLU -> two-layer MLP classifier.

    The convolution output is averaged over the whole sequence axis, so the
    classifier accepts any sequence length.

    Args:
        input_dim: Feature size of each sequence step (Conv1d input channels).
        hidden_dim: Width of the hidden fully connected layer.
        num_classes: Number of output logits.
    """

    def __init__(self, input_dim=2048, hidden_dim=1024, num_classes=2):
        super().__init__()
        self.conv1 = nn.Conv1d(input_dim, 256, kernel_size=3, padding=1)
        self.pool = nn.AdaptiveAvgPool1d(1)  # collapse the sequence axis to length 1
        self.relu = nn.ReLU()
        self.fc1 = nn.Linear(256, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        """Map ``x`` of shape (batch, seq_len, input_dim) to (batch, num_classes) logits."""
        conv_out = self.conv1(x.transpose(1, 2))          # (batch, 256, seq_len)
        pooled = self.relu(self.pool(conv_out))           # (batch, 256, 1)
        flat = pooled.view(pooled.size(0), -1)            # (batch, 256)
        hidden = self.relu(self.fc1(flat))
        return self.fc2(hidden)


In [None]:
import os
import json
import numpy as np
from glob import glob
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import StratifiedKFold, train_test_split
from sklearn.metrics import (
    classification_report,
    confusion_matrix,
    precision_score,
    recall_score,
    f1_score,
)

set_seed(42)

# ---- Cross-Validation ----
# Trains a fresh ConvPoolReLUClassifier per fold and evaluates it on the fold's
# test subjects. The class name is defined more than once in this notebook; the
# definition executed most recently is the one instantiated here.

# Number of folds read from fold_assignments.json (keys "fold_1" ... "fold_5").
N_FOLDS = 5
# Fixed label order so every fold's report and confusion matrix share one layout,
# even when a class is absent from that fold's labels and predictions.
CLASS_LABELS = [0, 1]
CLASS_NAMES = ["Healthy", "Depressed"]

with open("drive/MyDrive/thesis2025/split_dataset_june/fold_assignments.json", "r") as f:
    folds = json.load(f)

results = []
fold_reports = []
fold_conf_matrices = []

for fold_idx in range(N_FOLDS):
    fold_name = f"fold_{fold_idx + 1}"
    train_ids = folds[fold_name]["train"]
    val_ids   = folds[fold_name]["val"]
    test_ids  = folds[fold_name]["test"]

    # Resolve subject ids to embedding directories (KeyError names a missing subject).
    train_subjs = [subject_path_map[sid] for sid in train_ids]
    val_subjs   = [subject_path_map[sid] for sid in val_ids]
    test_subjs  = [subject_path_map[sid] for sid in test_ids]

    train_set = EEGAudioDataset(train_subjs)
    val_set = EEGAudioDataset(val_subjs)
    test_set = EEGAudioDataset(test_subjs)

    train_loader = DataLoader(train_set, batch_size=100, shuffle=True, collate_fn=collate_fn_padd)
    val_loader = DataLoader(val_set, batch_size=100, collate_fn=collate_fn_padd)
    test_loader = DataLoader(test_set, batch_size=100, collate_fn=collate_fn_padd)

    model = ConvPoolReLUClassifier(input_dim=2048).to(device)
    optimizer = optim.Adam(model.parameters(), lr=0.0004)
    criterion = nn.CrossEntropyLoss()

    trained_model = train_model(model, train_loader, val_loader, criterion, optimizer, device)

    # ---- Evaluation ----
    trained_model.eval()
    correct = total = 0
    fold_preds = []
    fold_labels = []

    with torch.no_grad():
        for x, y in test_loader:
            x, y = x.to(device), y.to(device)
            output = trained_model(x)
            preds = output.argmax(dim=1)

            fold_preds.extend(preds.cpu().numpy())
            fold_labels.extend(y.cpu().numpy())

            correct += (preds == y).sum().item()
            total += y.size(0)

    acc = correct / total
    # zero_division=0 yields the same value as the default but without a warning
    # when a class is never predicted in a fold.
    precision = precision_score(fold_labels, fold_preds, average='macro', zero_division=0)
    recall = recall_score(fold_labels, fold_preds, average='macro', zero_division=0)
    f1 = f1_score(fold_labels, fold_preds, average='macro', zero_division=0)

    print(f"Fold {fold_idx + 1} Test Accuracy: {acc:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}")
    results.append((acc, precision, recall, f1))

    # Store classification report and confusion matrix for later averaging.
    # Explicit labels keep the report from raising when only one class shows up
    # and keep every confusion matrix 2x2 so the folds can be averaged.
    report = classification_report(
        fold_labels,
        fold_preds,
        labels=CLASS_LABELS,
        target_names=CLASS_NAMES,
        output_dict=True,
        zero_division=0,
    )
    conf_matrix = confusion_matrix(fold_labels, fold_preds, labels=CLASS_LABELS)

    fold_reports.append(report)
    fold_conf_matrices.append(conf_matrix)

# ---- Final Report ----
# One row per fold, columns: accuracy, precision, recall, F1.
results = np.array(results)
mean_acc, mean_prec, mean_rec, mean_f1 = results.mean(axis=0)
std_acc, std_prec, std_rec, std_f1 = results.std(axis=0)

print(f"\n{N_FOLDS}-Fold CV Results:")
print(f"Mean Accuracy  = {mean_acc:.4f} ± {std_acc:.4f}")
print(f"Mean Precision = {mean_prec:.4f} ± {std_prec:.4f}")
print(f"Mean Recall    = {mean_rec:.4f} ± {std_rec:.4f}")
print(f"Mean F1-Score  = {mean_f1:.4f} ± {std_f1:.4f}")


100%|██████████| 1/1 [00:00<00:00, 42.71it/s]
100%|██████████| 1/1 [00:00<00:00, 39.80it/s]
100%|██████████| 1/1 [00:00<00:00, 42.25it/s]
100%|██████████| 1/1 [00:00<00:00, 39.65it/s]
100%|██████████| 1/1 [00:00<00:00, 43.15it/s]
100%|██████████| 1/1 [00:00<00:00, 41.96it/s]
100%|██████████| 1/1 [00:00<00:00, 43.13it/s]
100%|██████████| 1/1 [00:00<00:00, 41.89it/s]
100%|██████████| 1/1 [00:00<00:00, 27.62it/s]
100%|██████████| 1/1 [00:00<00:00, 40.75it/s]
100%|██████████| 1/1 [00:00<00:00, 30.56it/s]
100%|██████████| 1/1 [00:00<00:00, 43.37it/s]
100%|██████████| 1/1 [00:00<00:00, 42.14it/s]
100%|██████████| 1/1 [00:00<00:00, 40.31it/s]
100%|██████████| 1/1 [00:00<00:00, 42.26it/s]
100%|██████████| 1/1 [00:00<00:00, 26.23it/s]
100%|██████████| 1/1 [00:00<00:00, 41.55it/s]
100%|██████████| 1/1 [00:00<00:00, 38.50it/s]
100%|██████████| 1/1 [00:00<00:00, 38.09it/s]
100%|██████████| 1/1 [00:00<00:00, 41.34it/s]
100%|██████████| 1/1 [00:00<00:00, 43.89it/s]
100%|██████████| 1/1 [00:00<00:00,

Fold 1 Test Accuracy: 0.2500, Precision: 0.2500, Recall: 0.2500, F1: 0.2500


100%|██████████| 1/1 [00:00<00:00, 16.99it/s]
100%|██████████| 1/1 [00:00<00:00, 33.75it/s]
100%|██████████| 1/1 [00:00<00:00, 44.45it/s]
100%|██████████| 1/1 [00:00<00:00, 45.76it/s]
100%|██████████| 1/1 [00:00<00:00, 29.09it/s]
100%|██████████| 1/1 [00:00<00:00, 41.67it/s]
100%|██████████| 1/1 [00:00<00:00, 44.67it/s]
100%|██████████| 1/1 [00:00<00:00, 32.20it/s]
100%|██████████| 1/1 [00:00<00:00, 13.00it/s]
100%|██████████| 1/1 [00:00<00:00, 14.46it/s]
100%|██████████| 1/1 [00:00<00:00, 41.34it/s]
100%|██████████| 1/1 [00:00<00:00, 43.27it/s]
100%|██████████| 1/1 [00:00<00:00, 43.38it/s]
100%|██████████| 1/1 [00:00<00:00, 43.42it/s]
100%|██████████| 1/1 [00:00<00:00, 41.15it/s]
100%|██████████| 1/1 [00:00<00:00, 26.09it/s]
100%|██████████| 1/1 [00:00<00:00, 37.88it/s]
100%|██████████| 1/1 [00:00<00:00, 42.21it/s]
100%|██████████| 1/1 [00:00<00:00, 43.17it/s]
100%|██████████| 1/1 [00:00<00:00, 38.51it/s]
100%|██████████| 1/1 [00:00<00:00, 41.99it/s]
100%|██████████| 1/1 [00:00<00:00,

Fold 2 Test Accuracy: 0.5000, Precision: 0.5000, Recall: 0.5000, F1: 0.4667


100%|██████████| 1/1 [00:00<00:00, 43.14it/s]
100%|██████████| 1/1 [00:00<00:00, 26.80it/s]
100%|██████████| 1/1 [00:00<00:00, 43.69it/s]
100%|██████████| 1/1 [00:00<00:00, 41.07it/s]
100%|██████████| 1/1 [00:00<00:00, 41.23it/s]
100%|██████████| 1/1 [00:00<00:00, 43.39it/s]
100%|██████████| 1/1 [00:00<00:00, 43.53it/s]
100%|██████████| 1/1 [00:00<00:00, 41.91it/s]
100%|██████████| 1/1 [00:00<00:00, 41.16it/s]
100%|██████████| 1/1 [00:00<00:00, 42.38it/s]
100%|██████████| 1/1 [00:00<00:00, 42.77it/s]
100%|██████████| 1/1 [00:00<00:00, 42.54it/s]
100%|██████████| 1/1 [00:00<00:00, 40.58it/s]
100%|██████████| 1/1 [00:00<00:00, 41.80it/s]
100%|██████████| 1/1 [00:00<00:00, 43.13it/s]
100%|██████████| 1/1 [00:00<00:00, 40.99it/s]
100%|██████████| 1/1 [00:00<00:00, 40.46it/s]
100%|██████████| 1/1 [00:00<00:00, 41.18it/s]
100%|██████████| 1/1 [00:00<00:00, 40.81it/s]
100%|██████████| 1/1 [00:00<00:00, 27.65it/s]
100%|██████████| 1/1 [00:00<00:00, 41.13it/s]
100%|██████████| 1/1 [00:00<00:00,

Fold 3 Test Accuracy: 0.7500, Precision: 0.7333, Recall: 0.7333, F1: 0.7333


100%|██████████| 1/1 [00:00<00:00, 14.00it/s]
100%|██████████| 1/1 [00:00<00:00, 11.78it/s]
100%|██████████| 1/1 [00:00<00:00,  4.31it/s]
100%|██████████| 1/1 [00:00<00:00, 14.01it/s]
100%|██████████| 1/1 [00:00<00:00, 21.24it/s]
100%|██████████| 1/1 [00:00<00:00, 18.92it/s]
100%|██████████| 1/1 [00:00<00:00, 24.56it/s]
100%|██████████| 1/1 [00:00<00:00, 23.26it/s]
100%|██████████| 1/1 [00:00<00:00, 10.80it/s]
100%|██████████| 1/1 [00:00<00:00, 20.65it/s]
100%|██████████| 1/1 [00:00<00:00, 19.31it/s]
100%|██████████| 1/1 [00:00<00:00, 22.23it/s]
100%|██████████| 1/1 [00:00<00:00, 30.21it/s]
100%|██████████| 1/1 [00:00<00:00, 29.19it/s]
100%|██████████| 1/1 [00:00<00:00, 15.34it/s]
100%|██████████| 1/1 [00:00<00:00, 26.67it/s]
100%|██████████| 1/1 [00:00<00:00, 28.62it/s]
100%|██████████| 1/1 [00:00<00:00, 22.22it/s]
100%|██████████| 1/1 [00:00<00:00, 31.57it/s]
100%|██████████| 1/1 [00:00<00:00, 30.59it/s]
100%|██████████| 1/1 [00:00<00:00, 20.90it/s]
100%|██████████| 1/1 [00:00<00:00,

Fold 4 Test Accuracy: 0.8571, Precision: 0.9000, Recall: 0.8333, F1: 0.8444


100%|██████████| 1/1 [00:00<00:00, 31.21it/s]
100%|██████████| 1/1 [00:00<00:00, 31.71it/s]
100%|██████████| 1/1 [00:00<00:00, 22.34it/s]
100%|██████████| 1/1 [00:00<00:00, 30.87it/s]
100%|██████████| 1/1 [00:00<00:00, 31.92it/s]
100%|██████████| 1/1 [00:00<00:00, 22.02it/s]
100%|██████████| 1/1 [00:00<00:00, 31.55it/s]
100%|██████████| 1/1 [00:00<00:00, 30.68it/s]
100%|██████████| 1/1 [00:00<00:00, 25.48it/s]
100%|██████████| 1/1 [00:00<00:00, 32.20it/s]
100%|██████████| 1/1 [00:00<00:00, 31.71it/s]
100%|██████████| 1/1 [00:00<00:00, 31.31it/s]
100%|██████████| 1/1 [00:00<00:00, 30.85it/s]
100%|██████████| 1/1 [00:00<00:00, 31.78it/s]
100%|██████████| 1/1 [00:00<00:00, 30.32it/s]
100%|██████████| 1/1 [00:00<00:00, 31.40it/s]
100%|██████████| 1/1 [00:00<00:00, 30.53it/s]
100%|██████████| 1/1 [00:00<00:00, 31.25it/s]
100%|██████████| 1/1 [00:00<00:00, 22.06it/s]
100%|██████████| 1/1 [00:00<00:00, 31.63it/s]
100%|██████████| 1/1 [00:00<00:00, 31.10it/s]
100%|██████████| 1/1 [00:00<00:00,

Fold 5 Test Accuracy: 0.4286, Precision: 0.2500, Recall: 0.3750, F1: 0.3000

5-Fold CV Results:
Mean Accuracy  = 0.5571 ± 0.2197
Mean Precision = 0.5267 ± 0.2592
Mean Recall    = 0.5383 ± 0.2174
Mean F1-Score  = 0.5189 ± 0.2345





## **Global Average Pooling**

In [None]:
import torch
import torch.nn as nn

class ConvPoolReLUClassifier(nn.Module):
    """Conv1d -> ReLU -> global average over the sequence -> two-layer MLP classifier.

    The ReLU is applied before averaging, so only positive convolution responses
    contribute to the pooled features. Any sequence length is accepted.

    Args:
        input_dim: Feature size of each sequence step (Conv1d input channels).
        hidden_dim: Width of the hidden fully connected layer.
        num_classes: Number of output logits.
    """

    def __init__(self, input_dim=2048, hidden_dim=1024, num_classes=2):
        super().__init__()
        self.conv1 = nn.Conv1d(input_dim, 256, kernel_size=3, padding=1)
        self.relu = nn.ReLU()
        self.fc1 = nn.Linear(256, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        """Map ``x`` of shape (batch, seq_len, input_dim) to (batch, num_classes) logits."""
        activated = self.relu(self.conv1(x.transpose(1, 2)))  # (batch, 256, seq_len)
        pooled = activated.mean(dim=2)                        # (batch, 256): average over the sequence axis
        hidden = self.relu(self.fc1(pooled))                  # (batch, hidden_dim)
        return self.fc2(hidden)                               # (batch, num_classes)


In [None]:
import os
import json
import numpy as np
from glob import glob
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import StratifiedKFold, train_test_split
from sklearn.metrics import (
    classification_report,
    confusion_matrix,
    precision_score,
    recall_score,
    f1_score,
)

set_seed(42)  # fix the RNG state once, before any model is created

# Class indices and their display names, in matching order (0 -> Healthy, 1 -> Depressed).
# Passing them explicitly keeps every fold's report and confusion matrix the same
# shape even when a small test fold (or its predictions) contains a single class;
# without `labels`, classification_report raises because target_names has two entries.
CLASS_LABELS = [0, 1]
CLASS_NAMES = ["Healthy", "Depressed"]

# ---- Cross-Validation ----
# Fold membership is read from disk rather than re-drawn in this cell.
with open("drive/MyDrive/thesis2025/split_dataset_june/fold_assignments.json", "r") as f:
    folds = json.load(f)

results = []             # one (accuracy, precision, recall, f1) tuple per fold
fold_reports = []        # per-fold classification_report dictionaries
fold_conf_matrices = []  # per-fold confusion matrices

for fold_idx in range(5):
    fold_name = f"fold_{fold_idx + 1}"
    train_ids = folds[fold_name]["train"]
    val_ids   = folds[fold_name]["val"]
    test_ids  = folds[fold_name]["test"]

    # subject_path_map comes from an earlier cell; presumably it maps a subject
    # ID to that subject's directory -- verify against its definition.
    train_subjs = [subject_path_map[sid] for sid in train_ids]
    val_subjs   = [subject_path_map[sid] for sid in val_ids]
    test_subjs  = [subject_path_map[sid] for sid in test_ids]

    train_set = EEGAudioDataset(train_subjs)
    val_set = EEGAudioDataset(val_subjs)
    test_set = EEGAudioDataset(test_subjs)

    # Only the training loader is shuffled.
    train_loader = DataLoader(train_set, batch_size=100, shuffle=True, collate_fn=collate_fn_padd)
    val_loader = DataLoader(val_set, batch_size=100, collate_fn=collate_fn_padd)
    test_loader = DataLoader(test_set, batch_size=100, collate_fn=collate_fn_padd)

    # A fresh model and optimizer per fold, so no weights carry over between folds.
    model = ConvPoolReLUClassifier(input_dim=2048).to(device)
    optimizer = optim.Adam(model.parameters(), lr=0.0004)
    criterion = nn.CrossEntropyLoss()

    trained_model = train_model(model, train_loader, val_loader, criterion, optimizer, device)

    # ---- Evaluation ----
    trained_model.eval()
    correct = total = 0
    fold_preds = []
    fold_labels = []

    with torch.no_grad():
        for x, y in test_loader:
            x, y = x.to(device), y.to(device)
            output = trained_model(x)
            preds = output.argmax(dim=1)  # predicted class = index of the largest logit

            fold_preds.extend(preds.cpu().numpy())
            fold_labels.extend(y.cpu().numpy())

            correct += (preds == y).sum().item()
            total += y.size(0)

    acc = correct / total
    # zero_division=0 yields the same value as sklearn's default ("warn" -> 0.0)
    # but without an UndefinedMetricWarning when a class is never predicted.
    precision = precision_score(fold_labels, fold_preds, average='macro', zero_division=0)
    recall = recall_score(fold_labels, fold_preds, average='macro', zero_division=0)
    f1 = f1_score(fold_labels, fold_preds, average='macro', zero_division=0)

    print(f"Fold {fold_idx + 1} Test Accuracy: {acc:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}")
    results.append((acc, precision, recall, f1))

    # Store classification report and confusion matrix for later averaging
    report = classification_report(
        fold_labels,
        fold_preds,
        labels=CLASS_LABELS,
        target_names=CLASS_NAMES,
        output_dict=True,
        zero_division=0,
    )
    conf_matrix = confusion_matrix(fold_labels, fold_preds, labels=CLASS_LABELS)

    fold_reports.append(report)
    fold_conf_matrices.append(conf_matrix)

# ---- Final Report ----
results = np.array(results)  # shape (n_folds, 4): accuracy, precision, recall, f1
mean_acc, mean_prec, mean_rec, mean_f1 = results.mean(axis=0)
# ndarray.std() defaults to ddof=0, i.e. the population standard deviation over folds.
std_acc = results[:, 0].std()
std_prec = results[:, 1].std()
std_rec = results[:, 2].std()
std_f1 = results[:, 3].std()

print(f"\n5-Fold CV Results:")
print(f"Mean Accuracy  = {mean_acc:.4f} ± {std_acc:.4f}")
print(f"Mean Precision = {mean_prec:.4f} ± {std_prec:.4f}")
print(f"Mean Recall    = {mean_rec:.4f} ± {std_rec:.4f}")
print(f"Mean F1-Score  = {mean_f1:.4f} ± {std_f1:.4f}")


100%|██████████| 1/1 [00:00<00:00, 41.45it/s]
100%|██████████| 1/1 [00:00<00:00, 28.15it/s]
100%|██████████| 1/1 [00:00<00:00, 41.29it/s]
100%|██████████| 1/1 [00:00<00:00, 41.32it/s]
100%|██████████| 1/1 [00:00<00:00, 45.98it/s]
100%|██████████| 1/1 [00:00<00:00, 26.02it/s]
100%|██████████| 1/1 [00:00<00:00, 24.41it/s]
100%|██████████| 1/1 [00:00<00:00, 45.97it/s]
100%|██████████| 1/1 [00:00<00:00, 41.19it/s]
100%|██████████| 1/1 [00:00<00:00, 39.70it/s]
100%|██████████| 1/1 [00:00<00:00, 45.55it/s]
100%|██████████| 1/1 [00:00<00:00, 28.23it/s]
100%|██████████| 1/1 [00:00<00:00, 26.33it/s]
100%|██████████| 1/1 [00:00<00:00, 38.30it/s]
100%|██████████| 1/1 [00:00<00:00, 33.36it/s]
100%|██████████| 1/1 [00:00<00:00, 38.54it/s]
100%|██████████| 1/1 [00:00<00:00, 41.77it/s]
100%|██████████| 1/1 [00:00<00:00, 43.11it/s]
100%|██████████| 1/1 [00:00<00:00, 45.49it/s]
100%|██████████| 1/1 [00:00<00:00, 26.99it/s]
100%|██████████| 1/1 [00:00<00:00, 42.58it/s]
100%|██████████| 1/1 [00:00<00:00,

Fold 1 Test Accuracy: 0.2500, Precision: 0.2500, Recall: 0.2500, F1: 0.2500


100%|██████████| 1/1 [00:00<00:00, 41.62it/s]
100%|██████████| 1/1 [00:00<00:00, 37.93it/s]
100%|██████████| 1/1 [00:00<00:00, 41.68it/s]
100%|██████████| 1/1 [00:00<00:00, 42.45it/s]
100%|██████████| 1/1 [00:00<00:00, 40.74it/s]
100%|██████████| 1/1 [00:00<00:00, 26.15it/s]
100%|██████████| 1/1 [00:00<00:00, 40.66it/s]
100%|██████████| 1/1 [00:00<00:00, 40.85it/s]
100%|██████████| 1/1 [00:00<00:00, 44.86it/s]
100%|██████████| 1/1 [00:00<00:00, 37.62it/s]
100%|██████████| 1/1 [00:00<00:00, 41.23it/s]
100%|██████████| 1/1 [00:00<00:00, 41.92it/s]
100%|██████████| 1/1 [00:00<00:00, 45.89it/s]
100%|██████████| 1/1 [00:00<00:00, 37.72it/s]
100%|██████████| 1/1 [00:00<00:00, 40.83it/s]
100%|██████████| 1/1 [00:00<00:00, 41.22it/s]
100%|██████████| 1/1 [00:00<00:00, 27.40it/s]
100%|██████████| 1/1 [00:00<00:00, 37.98it/s]
100%|██████████| 1/1 [00:00<00:00, 42.30it/s]
100%|██████████| 1/1 [00:00<00:00, 38.30it/s]
100%|██████████| 1/1 [00:00<00:00, 39.85it/s]
100%|██████████| 1/1 [00:00<00:00,

Fold 2 Test Accuracy: 0.6250, Precision: 0.6333, Recall: 0.6250, F1: 0.6190


100%|██████████| 1/1 [00:00<00:00, 40.07it/s]
100%|██████████| 1/1 [00:00<00:00, 39.90it/s]
100%|██████████| 1/1 [00:00<00:00, 27.03it/s]
100%|██████████| 1/1 [00:00<00:00, 42.02it/s]
100%|██████████| 1/1 [00:00<00:00, 42.18it/s]
100%|██████████| 1/1 [00:00<00:00, 35.35it/s]
100%|██████████| 1/1 [00:00<00:00, 27.30it/s]
100%|██████████| 1/1 [00:00<00:00, 41.57it/s]
100%|██████████| 1/1 [00:00<00:00, 40.49it/s]
100%|██████████| 1/1 [00:00<00:00, 36.46it/s]
100%|██████████| 1/1 [00:00<00:00, 40.90it/s]
100%|██████████| 1/1 [00:00<00:00, 41.82it/s]
100%|██████████| 1/1 [00:00<00:00, 42.95it/s]
100%|██████████| 1/1 [00:00<00:00, 27.36it/s]
100%|██████████| 1/1 [00:00<00:00, 42.13it/s]
100%|██████████| 1/1 [00:00<00:00, 42.28it/s]
100%|██████████| 1/1 [00:00<00:00, 42.10it/s]
100%|██████████| 1/1 [00:00<00:00, 37.50it/s]
100%|██████████| 1/1 [00:00<00:00, 40.80it/s]
100%|██████████| 1/1 [00:00<00:00, 39.81it/s]
100%|██████████| 1/1 [00:00<00:00, 25.21it/s]
100%|██████████| 1/1 [00:00<00:00,

Fold 3 Test Accuracy: 0.7500, Precision: 0.7333, Recall: 0.7333, F1: 0.7333


100%|██████████| 1/1 [00:00<00:00, 31.43it/s]
100%|██████████| 1/1 [00:00<00:00, 30.28it/s]
100%|██████████| 1/1 [00:00<00:00, 22.65it/s]
100%|██████████| 1/1 [00:00<00:00, 32.47it/s]
100%|██████████| 1/1 [00:00<00:00, 31.88it/s]
100%|██████████| 1/1 [00:00<00:00, 22.32it/s]
100%|██████████| 1/1 [00:00<00:00, 31.63it/s]
100%|██████████| 1/1 [00:00<00:00, 29.16it/s]
100%|██████████| 1/1 [00:00<00:00, 23.14it/s]
100%|██████████| 1/1 [00:00<00:00, 27.42it/s]
100%|██████████| 1/1 [00:00<00:00, 30.60it/s]
100%|██████████| 1/1 [00:00<00:00, 32.92it/s]
100%|██████████| 1/1 [00:00<00:00, 31.11it/s]
100%|██████████| 1/1 [00:00<00:00, 33.15it/s]
100%|██████████| 1/1 [00:00<00:00, 35.13it/s]
100%|██████████| 1/1 [00:00<00:00, 29.94it/s]
100%|██████████| 1/1 [00:00<00:00, 26.73it/s]
100%|██████████| 1/1 [00:00<00:00, 28.93it/s]
100%|██████████| 1/1 [00:00<00:00, 28.40it/s]
100%|██████████| 1/1 [00:00<00:00, 27.87it/s]
100%|██████████| 1/1 [00:00<00:00, 31.43it/s]
100%|██████████| 1/1 [00:00<00:00,

Fold 4 Test Accuracy: 0.8571, Precision: 0.9000, Recall: 0.8333, F1: 0.8444


100%|██████████| 1/1 [00:00<00:00, 33.38it/s]
100%|██████████| 1/1 [00:00<00:00, 31.63it/s]
100%|██████████| 1/1 [00:00<00:00, 23.16it/s]
100%|██████████| 1/1 [00:00<00:00, 33.20it/s]
100%|██████████| 1/1 [00:00<00:00, 32.27it/s]
100%|██████████| 1/1 [00:00<00:00, 31.05it/s]
100%|██████████| 1/1 [00:00<00:00, 30.22it/s]
100%|██████████| 1/1 [00:00<00:00, 32.72it/s]
100%|██████████| 1/1 [00:00<00:00, 32.55it/s]
100%|██████████| 1/1 [00:00<00:00, 31.34it/s]
100%|██████████| 1/1 [00:00<00:00, 29.93it/s]
100%|██████████| 1/1 [00:00<00:00, 33.06it/s]
100%|██████████| 1/1 [00:00<00:00, 32.11it/s]
100%|██████████| 1/1 [00:00<00:00, 31.96it/s]
100%|██████████| 1/1 [00:00<00:00, 33.73it/s]
100%|██████████| 1/1 [00:00<00:00, 23.28it/s]
100%|██████████| 1/1 [00:00<00:00, 32.01it/s]
100%|██████████| 1/1 [00:00<00:00, 32.31it/s]
100%|██████████| 1/1 [00:00<00:00, 20.51it/s]
100%|██████████| 1/1 [00:00<00:00, 32.28it/s]
100%|██████████| 1/1 [00:00<00:00, 33.00it/s]
100%|██████████| 1/1 [00:00<00:00,

Fold 5 Test Accuracy: 0.4286, Precision: 0.2500, Recall: 0.3750, F1: 0.3000

5-Fold CV Results:
Mean Accuracy  = 0.5821 ± 0.2189
Mean Precision = 0.5533 ± 0.2619
Mean Recall    = 0.5633 ± 0.2187
Mean F1-Score  = 0.5494 ± 0.2356





## **Attention Pooling**

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class AttentionPooling(nn.Module):
    """Collapse a sequence into a single vector with learned softmax weights.

    A linear layer assigns one scalar score to every time step; the scores are
    normalised over the sequence axis and used as weights of a weighted sum.
    """

    def __init__(self, input_dim):
        super().__init__()
        # Projects each time step's feature vector to a single score.
        self.attn_weights = nn.Linear(input_dim, 1)

    def forward(self, x):
        """Pool x of shape (batch, seq_len, dim) into a (batch, dim) tensor."""
        # (batch, seq_len, 1): scores turned into weights that sum to 1 over seq_len.
        weights = self.attn_weights(x).softmax(dim=1)
        # Broadcasting multiplies every feature of a time step by that step's weight.
        return (weights * x).sum(dim=1)

class ConvPoolReLUClassifier(nn.Module):
    """Conv1d -> ReLU -> attention pooling -> two-layer MLP classifier.

    NOTE: this reuses the name of the global-average-pooling classifier defined
    earlier in the notebook, so running this cell replaces that class for every
    cell executed afterwards.

    Args:
        input_dim: Feature size of each time step of the input sequence.
        hidden_dim: Width of the hidden fully connected layer.
        num_classes: Number of output logits.
        conv_channels: Number of output channels of the 1-D convolution, which
            is also the feature size seen by the attention pooling and by fc1.
            This used to be hard-coded to 256 in three places; the default
            keeps that architecture, so existing calls behave exactly as before.
    """

    def __init__(self, input_dim=2048, hidden_dim=1024, num_classes=2, conv_channels=256):
        super().__init__()
        # kernel_size=3 with padding=1 keeps the sequence length unchanged.
        self.conv1 = nn.Conv1d(in_channels=input_dim, out_channels=conv_channels, kernel_size=3, padding=1)
        self.relu = nn.ReLU()
        self.attn_pool = AttentionPooling(input_dim=conv_channels)
        self.fc1 = nn.Linear(conv_channels, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        """Map x of shape (batch, seq_len, input_dim) to logits of shape (batch, num_classes)."""
        x = x.transpose(1, 2)                    # (batch, input_dim, seq_len) -- Conv1d wants channels first
        x = self.conv1(x)                        # (batch, conv_channels, seq_len)
        x = self.relu(x)
        x = x.transpose(1, 2)                    # (batch, seq_len, conv_channels) -- layout expected by attn_pool
        x = self.attn_pool(x)                    # (batch, conv_channels)
        x = self.relu(self.fc1(x))               # (batch, hidden_dim)
        return self.fc2(x)                       # (batch, num_classes)


In [None]:
import os
import json
import numpy as np
from glob import glob
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import StratifiedKFold, train_test_split
from sklearn.metrics import (
    classification_report,
    confusion_matrix,
    precision_score,
    recall_score,
    f1_score,
)

set_seed(42)  # fix the RNG state once, before any model is created

# Class indices and their display names, in matching order (0 -> Healthy, 1 -> Depressed).
# Passing them explicitly keeps every fold's report and confusion matrix the same
# shape even when a small test fold (or its predictions) contains a single class;
# without `labels`, classification_report raises because target_names has two entries.
CLASS_LABELS = [0, 1]
CLASS_NAMES = ["Healthy", "Depressed"]

# ---- Cross-Validation ----
# Fold membership is read from disk rather than re-drawn in this cell.
with open("drive/MyDrive/thesis2025/split_dataset_june/fold_assignments.json", "r") as f:
    folds = json.load(f)

results = []             # one (accuracy, precision, recall, f1) tuple per fold
fold_reports = []        # per-fold classification_report dictionaries
fold_conf_matrices = []  # per-fold confusion matrices

for fold_idx in range(5):
    fold_name = f"fold_{fold_idx + 1}"
    train_ids = folds[fold_name]["train"]
    val_ids   = folds[fold_name]["val"]
    test_ids  = folds[fold_name]["test"]

    # subject_path_map comes from an earlier cell; presumably it maps a subject
    # ID to that subject's directory -- verify against its definition.
    train_subjs = [subject_path_map[sid] for sid in train_ids]
    val_subjs   = [subject_path_map[sid] for sid in val_ids]
    test_subjs  = [subject_path_map[sid] for sid in test_ids]

    train_set = EEGAudioDataset(train_subjs)
    val_set = EEGAudioDataset(val_subjs)
    test_set = EEGAudioDataset(test_subjs)

    # Only the training loader is shuffled.
    train_loader = DataLoader(train_set, batch_size=100, shuffle=True, collate_fn=collate_fn_padd)
    val_loader = DataLoader(val_set, batch_size=100, collate_fn=collate_fn_padd)
    test_loader = DataLoader(test_set, batch_size=100, collate_fn=collate_fn_padd)

    # A fresh model and optimizer per fold, so no weights carry over between folds.
    # ConvPoolReLUClassifier resolves to whichever definition of that name was
    # executed last in the kernel.
    model = ConvPoolReLUClassifier(input_dim=2048).to(device)
    optimizer = optim.Adam(model.parameters(), lr=0.0004)
    criterion = nn.CrossEntropyLoss()

    trained_model = train_model(model, train_loader, val_loader, criterion, optimizer, device)

    # ---- Evaluation ----
    trained_model.eval()
    correct = total = 0
    fold_preds = []
    fold_labels = []

    with torch.no_grad():
        for x, y in test_loader:
            x, y = x.to(device), y.to(device)
            output = trained_model(x)
            preds = output.argmax(dim=1)  # predicted class = index of the largest logit

            fold_preds.extend(preds.cpu().numpy())
            fold_labels.extend(y.cpu().numpy())

            correct += (preds == y).sum().item()
            total += y.size(0)

    acc = correct / total
    # zero_division=0 yields the same value as sklearn's default ("warn" -> 0.0)
    # but without an UndefinedMetricWarning when a class is never predicted.
    precision = precision_score(fold_labels, fold_preds, average='macro', zero_division=0)
    recall = recall_score(fold_labels, fold_preds, average='macro', zero_division=0)
    f1 = f1_score(fold_labels, fold_preds, average='macro', zero_division=0)

    print(f"Fold {fold_idx + 1} Test Accuracy: {acc:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}")
    results.append((acc, precision, recall, f1))

    # Store classification report and confusion matrix for later averaging
    report = classification_report(
        fold_labels,
        fold_preds,
        labels=CLASS_LABELS,
        target_names=CLASS_NAMES,
        output_dict=True,
        zero_division=0,
    )
    conf_matrix = confusion_matrix(fold_labels, fold_preds, labels=CLASS_LABELS)

    fold_reports.append(report)
    fold_conf_matrices.append(conf_matrix)

# ---- Final Report ----
results = np.array(results)  # shape (n_folds, 4): accuracy, precision, recall, f1
mean_acc, mean_prec, mean_rec, mean_f1 = results.mean(axis=0)
# ndarray.std() defaults to ddof=0, i.e. the population standard deviation over folds.
std_acc = results[:, 0].std()
std_prec = results[:, 1].std()
std_rec = results[:, 2].std()
std_f1 = results[:, 3].std()

print(f"\n5-Fold CV Results:")
print(f"Mean Accuracy  = {mean_acc:.4f} ± {std_acc:.4f}")
print(f"Mean Precision = {mean_prec:.4f} ± {std_prec:.4f}")
print(f"Mean Recall    = {mean_rec:.4f} ± {std_rec:.4f}")
print(f"Mean F1-Score  = {mean_f1:.4f} ± {std_f1:.4f}")


100%|██████████| 1/1 [00:00<00:00, 39.41it/s]
100%|██████████| 1/1 [00:00<00:00, 24.79it/s]
100%|██████████| 1/1 [00:00<00:00, 45.07it/s]
100%|██████████| 1/1 [00:00<00:00, 46.27it/s]
100%|██████████| 1/1 [00:00<00:00, 45.59it/s]
100%|██████████| 1/1 [00:00<00:00, 36.87it/s]
100%|██████████| 1/1 [00:00<00:00, 45.50it/s]
100%|██████████| 1/1 [00:00<00:00, 45.20it/s]
100%|██████████| 1/1 [00:00<00:00, 25.90it/s]
100%|██████████| 1/1 [00:00<00:00, 31.18it/s]
100%|██████████| 1/1 [00:00<00:00, 36.72it/s]
100%|██████████| 1/1 [00:00<00:00, 36.91it/s]
100%|██████████| 1/1 [00:00<00:00, 40.47it/s]
100%|██████████| 1/1 [00:00<00:00, 42.41it/s]
100%|██████████| 1/1 [00:00<00:00, 42.56it/s]
100%|██████████| 1/1 [00:00<00:00, 42.94it/s]
100%|██████████| 1/1 [00:00<00:00, 42.63it/s]
100%|██████████| 1/1 [00:00<00:00, 41.21it/s]
100%|██████████| 1/1 [00:00<00:00, 35.99it/s]
100%|██████████| 1/1 [00:00<00:00, 26.85it/s]
100%|██████████| 1/1 [00:00<00:00, 41.86it/s]
100%|██████████| 1/1 [00:00<00:00,

Fold 1 Test Accuracy: 0.2500, Precision: 0.2500, Recall: 0.2500, F1: 0.2500


100%|██████████| 1/1 [00:00<00:00, 40.54it/s]
100%|██████████| 1/1 [00:00<00:00, 37.39it/s]
100%|██████████| 1/1 [00:00<00:00, 42.32it/s]
100%|██████████| 1/1 [00:00<00:00, 40.00it/s]
100%|██████████| 1/1 [00:00<00:00, 27.45it/s]
100%|██████████| 1/1 [00:00<00:00, 42.56it/s]
100%|██████████| 1/1 [00:00<00:00, 40.19it/s]
100%|██████████| 1/1 [00:00<00:00, 41.65it/s]
100%|██████████| 1/1 [00:00<00:00, 41.67it/s]
100%|██████████| 1/1 [00:00<00:00, 42.91it/s]
100%|██████████| 1/1 [00:00<00:00, 40.26it/s]
100%|██████████| 1/1 [00:00<00:00, 27.36it/s]
100%|██████████| 1/1 [00:00<00:00, 43.16it/s]
100%|██████████| 1/1 [00:00<00:00, 42.76it/s]
100%|██████████| 1/1 [00:00<00:00, 42.45it/s]
100%|██████████| 1/1 [00:00<00:00, 41.57it/s]
100%|██████████| 1/1 [00:00<00:00, 39.76it/s]
100%|██████████| 1/1 [00:00<00:00, 42.80it/s]
100%|██████████| 1/1 [00:00<00:00, 28.02it/s]
100%|██████████| 1/1 [00:00<00:00, 41.26it/s]
100%|██████████| 1/1 [00:00<00:00, 35.61it/s]
100%|██████████| 1/1 [00:00<00:00,

Fold 2 Test Accuracy: 0.5000, Precision: 0.5000, Recall: 0.5000, F1: 0.4667


100%|██████████| 1/1 [00:00<00:00, 26.40it/s]
100%|██████████| 1/1 [00:00<00:00, 38.80it/s]
100%|██████████| 1/1 [00:00<00:00, 36.21it/s]
100%|██████████| 1/1 [00:00<00:00, 41.48it/s]
100%|██████████| 1/1 [00:00<00:00, 41.65it/s]
100%|██████████| 1/1 [00:00<00:00, 43.22it/s]
100%|██████████| 1/1 [00:00<00:00, 43.10it/s]
100%|██████████| 1/1 [00:00<00:00, 40.28it/s]
100%|██████████| 1/1 [00:00<00:00, 42.50it/s]
100%|██████████| 1/1 [00:00<00:00, 43.65it/s]
100%|██████████| 1/1 [00:00<00:00, 40.99it/s]
100%|██████████| 1/1 [00:00<00:00, 28.19it/s]
100%|██████████| 1/1 [00:00<00:00, 43.71it/s]
100%|██████████| 1/1 [00:00<00:00, 38.67it/s]
100%|██████████| 1/1 [00:00<00:00, 36.98it/s]
100%|██████████| 1/1 [00:00<00:00, 42.34it/s]
100%|██████████| 1/1 [00:00<00:00, 41.34it/s]
100%|██████████| 1/1 [00:00<00:00, 42.96it/s]
100%|██████████| 1/1 [00:00<00:00, 26.08it/s]
100%|██████████| 1/1 [00:00<00:00, 42.83it/s]
100%|██████████| 1/1 [00:00<00:00, 41.05it/s]
100%|██████████| 1/1 [00:00<00:00,

Fold 3 Test Accuracy: 0.6250, Precision: 0.6250, Recall: 0.6333, F1: 0.6190


100%|██████████| 1/1 [00:00<00:00, 33.61it/s]
100%|██████████| 1/1 [00:00<00:00, 34.09it/s]
100%|██████████| 1/1 [00:00<00:00, 33.51it/s]
100%|██████████| 1/1 [00:00<00:00, 33.17it/s]
100%|██████████| 1/1 [00:00<00:00, 33.25it/s]
100%|██████████| 1/1 [00:00<00:00, 34.30it/s]
100%|██████████| 1/1 [00:00<00:00, 26.00it/s]
100%|██████████| 1/1 [00:00<00:00, 35.86it/s]
100%|██████████| 1/1 [00:00<00:00, 32.37it/s]
100%|██████████| 1/1 [00:00<00:00, 31.71it/s]
100%|██████████| 1/1 [00:00<00:00, 24.45it/s]
100%|██████████| 1/1 [00:00<00:00, 29.27it/s]
100%|██████████| 1/1 [00:00<00:00, 36.16it/s]
100%|██████████| 1/1 [00:00<00:00, 21.01it/s]
100%|██████████| 1/1 [00:00<00:00, 27.91it/s]
100%|██████████| 1/1 [00:00<00:00, 33.00it/s]
100%|██████████| 1/1 [00:00<00:00, 21.75it/s]
100%|██████████| 1/1 [00:00<00:00, 27.28it/s]
100%|██████████| 1/1 [00:00<00:00, 34.39it/s]
100%|██████████| 1/1 [00:00<00:00, 22.04it/s]
100%|██████████| 1/1 [00:00<00:00, 31.64it/s]
100%|██████████| 1/1 [00:00<00:00,

Fold 4 Test Accuracy: 0.8571, Precision: 0.9000, Recall: 0.8333, F1: 0.8444


100%|██████████| 1/1 [00:00<00:00, 32.75it/s]
100%|██████████| 1/1 [00:00<00:00, 23.49it/s]
100%|██████████| 1/1 [00:00<00:00, 30.90it/s]
100%|██████████| 1/1 [00:00<00:00, 33.10it/s]
100%|██████████| 1/1 [00:00<00:00, 34.43it/s]
100%|██████████| 1/1 [00:00<00:00, 29.88it/s]
100%|██████████| 1/1 [00:00<00:00, 34.11it/s]
100%|██████████| 1/1 [00:00<00:00, 32.11it/s]
100%|██████████| 1/1 [00:00<00:00, 32.68it/s]
100%|██████████| 1/1 [00:00<00:00, 33.74it/s]
100%|██████████| 1/1 [00:00<00:00, 31.60it/s]
100%|██████████| 1/1 [00:00<00:00, 33.50it/s]
100%|██████████| 1/1 [00:00<00:00, 35.23it/s]
100%|██████████| 1/1 [00:00<00:00, 32.64it/s]
100%|██████████| 1/1 [00:00<00:00, 32.04it/s]
100%|██████████| 1/1 [00:00<00:00, 32.59it/s]
100%|██████████| 1/1 [00:00<00:00, 31.60it/s]
100%|██████████| 1/1 [00:00<00:00, 22.48it/s]
100%|██████████| 1/1 [00:00<00:00, 31.96it/s]
100%|██████████| 1/1 [00:00<00:00, 34.07it/s]
100%|██████████| 1/1 [00:00<00:00, 22.01it/s]
100%|██████████| 1/1 [00:00<00:00,

Fold 5 Test Accuracy: 0.4286, Precision: 0.2500, Recall: 0.3750, F1: 0.3000

5-Fold CV Results:
Mean Accuracy  = 0.5321 ± 0.2028
Mean Precision = 0.5050 ± 0.2452
Mean Recall    = 0.5183 ± 0.2026
Mean F1-Score  = 0.4960 ± 0.2174





# **Manual Hyperparameter Tuning**

In [None]:
import os
import numpy as np
from glob import glob
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import StratifiedKFold, train_test_split
from sklearn.metrics import (
    classification_report,
    confusion_matrix,
    precision_score,
    recall_score,
    f1_score,
)

# This cell calls json.load but its own import list does not bring `json` in;
# import it here so the cell does not depend on an earlier cell having done so.
import json

set_seed(42)  # fix the RNG state once, before any model is created

# Class indices and their display names, in matching order (0 -> Healthy, 1 -> Depressed).
# Passing them explicitly keeps every fold's report and confusion matrix the same
# shape even when a small test fold (or its predictions) contains a single class;
# without `labels`, classification_report raises because target_names has two entries.
CLASS_LABELS = [0, 1]
CLASS_NAMES = ["Healthy", "Depressed"]

# ---- Cross-Validation ----
# Fold membership is read from disk rather than re-drawn in this cell.
with open("drive/MyDrive/thesis2025/split_dataset_june/fold_assignments.json", "r") as f:
    folds = json.load(f)

results = []             # one (accuracy, precision, recall, f1) tuple per fold
fold_reports = []        # per-fold classification_report dictionaries
fold_conf_matrices = []  # per-fold confusion matrices

for fold_idx in range(5):
    fold_name = f"fold_{fold_idx + 1}"
    train_ids = folds[fold_name]["train"]
    val_ids   = folds[fold_name]["val"]
    test_ids  = folds[fold_name]["test"]

    # subject_path_map comes from an earlier cell; presumably it maps a subject
    # ID to that subject's directory -- verify against its definition.
    train_subjs = [subject_path_map[sid] for sid in train_ids]
    val_subjs   = [subject_path_map[sid] for sid in val_ids]
    test_subjs  = [subject_path_map[sid] for sid in test_ids]

    train_set = EEGAudioDataset(train_subjs)
    val_set = EEGAudioDataset(val_subjs)
    test_set = EEGAudioDataset(test_subjs)

    # Only the training loader is shuffled.
    train_loader = DataLoader(train_set, batch_size=100, shuffle=True, collate_fn=collate_fn_padd)
    val_loader = DataLoader(val_set, batch_size=100, collate_fn=collate_fn_padd)
    test_loader = DataLoader(test_set, batch_size=100, collate_fn=collate_fn_padd)

    # A fresh model and optimizer per fold, so no weights carry over between folds.
    # NOTE(review): ConvPoolReLUClassifier is defined twice earlier in this
    # notebook (global-average and attention pooling); this cell uses whichever
    # definition was executed last -- confirm which one these results belong to.
    model = ConvPoolReLUClassifier(input_dim=2048).to(device)
    optimizer = optim.Adam(model.parameters(), lr=0.001)  # learning rate under test in this section
    criterion = nn.CrossEntropyLoss()

    trained_model = train_model(model, train_loader, val_loader, criterion, optimizer, device)

    # ---- Evaluation ----
    trained_model.eval()
    correct = total = 0
    fold_preds = []
    fold_labels = []

    with torch.no_grad():
        for x, y in test_loader:
            x, y = x.to(device), y.to(device)
            output = trained_model(x)
            preds = output.argmax(dim=1)  # predicted class = index of the largest logit

            fold_preds.extend(preds.cpu().numpy())
            fold_labels.extend(y.cpu().numpy())

            correct += (preds == y).sum().item()
            total += y.size(0)

    acc = correct / total
    # zero_division=0 yields the same value as sklearn's default ("warn" -> 0.0)
    # but without an UndefinedMetricWarning when a class is never predicted.
    precision = precision_score(fold_labels, fold_preds, average='macro', zero_division=0)
    recall = recall_score(fold_labels, fold_preds, average='macro', zero_division=0)
    f1 = f1_score(fold_labels, fold_preds, average='macro', zero_division=0)

    print(f"Fold {fold_idx + 1} Test Accuracy: {acc:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}")
    results.append((acc, precision, recall, f1))

    # Store classification report and confusion matrix for later averaging
    report = classification_report(
        fold_labels,
        fold_preds,
        labels=CLASS_LABELS,
        target_names=CLASS_NAMES,
        output_dict=True,
        zero_division=0,
    )
    conf_matrix = confusion_matrix(fold_labels, fold_preds, labels=CLASS_LABELS)

    fold_reports.append(report)
    fold_conf_matrices.append(conf_matrix)

# ---- Final Report ----
results = np.array(results)  # shape (n_folds, 4): accuracy, precision, recall, f1
mean_acc, mean_prec, mean_rec, mean_f1 = results.mean(axis=0)
# ndarray.std() defaults to ddof=0, i.e. the population standard deviation over folds.
std_acc = results[:, 0].std()
std_prec = results[:, 1].std()
std_rec = results[:, 2].std()
std_f1 = results[:, 3].std()

print(f"\n5-Fold CV Results:")
print(f"Mean Accuracy  = {mean_acc:.4f} ± {std_acc:.4f}")
print(f"Mean Precision = {mean_prec:.4f} ± {std_prec:.4f}")
print(f"Mean Recall    = {mean_rec:.4f} ± {std_rec:.4f}")
print(f"Mean F1-Score  = {mean_f1:.4f} ± {std_f1:.4f}")


100%|██████████| 1/1 [00:00<00:00, 36.56it/s]
100%|██████████| 1/1 [00:00<00:00, 40.61it/s]
100%|██████████| 1/1 [00:00<00:00, 40.08it/s]
100%|██████████| 1/1 [00:00<00:00, 26.27it/s]
100%|██████████| 1/1 [00:00<00:00, 39.05it/s]
100%|██████████| 1/1 [00:00<00:00, 43.42it/s]
100%|██████████| 1/1 [00:00<00:00, 36.33it/s]
100%|██████████| 1/1 [00:00<00:00, 40.18it/s]
100%|██████████| 1/1 [00:00<00:00, 37.71it/s]
100%|██████████| 1/1 [00:00<00:00, 42.79it/s]
100%|██████████| 1/1 [00:00<00:00, 25.23it/s]
100%|██████████| 1/1 [00:00<00:00, 44.48it/s]
100%|██████████| 1/1 [00:00<00:00, 39.40it/s]
100%|██████████| 1/1 [00:00<00:00, 41.68it/s]
100%|██████████| 1/1 [00:00<00:00, 17.51it/s]
100%|██████████| 1/1 [00:00<00:00, 44.64it/s]
100%|██████████| 1/1 [00:00<00:00, 41.36it/s]
100%|██████████| 1/1 [00:00<00:00, 43.79it/s]
100%|██████████| 1/1 [00:00<00:00, 42.08it/s]
100%|██████████| 1/1 [00:00<00:00, 40.71it/s]
100%|██████████| 1/1 [00:00<00:00, 40.52it/s]
100%|██████████| 1/1 [00:00<00:00,

Fold 1 Test Accuracy: 0.3750, Precision: 0.3667, Recall: 0.3750, F1: 0.3651


100%|██████████| 1/1 [00:00<00:00, 39.77it/s]
100%|██████████| 1/1 [00:00<00:00, 40.09it/s]
100%|██████████| 1/1 [00:00<00:00, 38.32it/s]
100%|██████████| 1/1 [00:00<00:00, 37.93it/s]
100%|██████████| 1/1 [00:00<00:00, 38.77it/s]
100%|██████████| 1/1 [00:00<00:00, 41.33it/s]
100%|██████████| 1/1 [00:00<00:00, 24.23it/s]
100%|██████████| 1/1 [00:00<00:00, 40.34it/s]
100%|██████████| 1/1 [00:00<00:00, 43.59it/s]
100%|██████████| 1/1 [00:00<00:00, 38.40it/s]
100%|██████████| 1/1 [00:00<00:00, 43.74it/s]
100%|██████████| 1/1 [00:00<00:00, 40.53it/s]
100%|██████████| 1/1 [00:00<00:00, 40.76it/s]
100%|██████████| 1/1 [00:00<00:00, 41.50it/s]
100%|██████████| 1/1 [00:00<00:00, 36.27it/s]
100%|██████████| 1/1 [00:00<00:00, 38.75it/s]
100%|██████████| 1/1 [00:00<00:00, 38.64it/s]
100%|██████████| 1/1 [00:00<00:00, 26.91it/s]
100%|██████████| 1/1 [00:00<00:00, 40.72it/s]
100%|██████████| 1/1 [00:00<00:00, 39.30it/s]
100%|██████████| 1/1 [00:00<00:00, 33.72it/s]
100%|██████████| 1/1 [00:00<00:00,

Fold 2 Test Accuracy: 0.7500, Precision: 0.7500, Recall: 0.7500, F1: 0.7500


100%|██████████| 1/1 [00:00<00:00, 41.35it/s]
100%|██████████| 1/1 [00:00<00:00, 25.98it/s]
100%|██████████| 1/1 [00:00<00:00, 26.07it/s]
100%|██████████| 1/1 [00:00<00:00, 37.30it/s]
100%|██████████| 1/1 [00:00<00:00, 40.69it/s]
100%|██████████| 1/1 [00:00<00:00, 41.02it/s]
100%|██████████| 1/1 [00:00<00:00, 37.63it/s]
100%|██████████| 1/1 [00:00<00:00, 37.59it/s]
100%|██████████| 1/1 [00:00<00:00, 36.90it/s]
100%|██████████| 1/1 [00:00<00:00, 24.49it/s]
100%|██████████| 1/1 [00:00<00:00, 39.91it/s]
100%|██████████| 1/1 [00:00<00:00, 43.74it/s]
100%|██████████| 1/1 [00:00<00:00, 41.57it/s]
100%|██████████| 1/1 [00:00<00:00, 42.08it/s]
100%|██████████| 1/1 [00:00<00:00, 40.51it/s]
100%|██████████| 1/1 [00:00<00:00, 41.58it/s]
100%|██████████| 1/1 [00:00<00:00, 27.30it/s]
100%|██████████| 1/1 [00:00<00:00, 42.27it/s]
100%|██████████| 1/1 [00:00<00:00, 41.23it/s]
100%|██████████| 1/1 [00:00<00:00, 38.44it/s]
100%|██████████| 1/1 [00:00<00:00, 40.20it/s]
100%|██████████| 1/1 [00:00<00:00,

Fold 3 Test Accuracy: 0.7500, Precision: 0.7333, Recall: 0.7333, F1: 0.7333


100%|██████████| 1/1 [00:00<00:00, 32.69it/s]
100%|██████████| 1/1 [00:00<00:00, 29.71it/s]
100%|██████████| 1/1 [00:00<00:00, 28.81it/s]
100%|██████████| 1/1 [00:00<00:00, 32.40it/s]
100%|██████████| 1/1 [00:00<00:00, 32.30it/s]
100%|██████████| 1/1 [00:00<00:00, 28.72it/s]
100%|██████████| 1/1 [00:00<00:00, 31.57it/s]
100%|██████████| 1/1 [00:00<00:00, 28.08it/s]
100%|██████████| 1/1 [00:00<00:00, 30.90it/s]
100%|██████████| 1/1 [00:00<00:00, 31.62it/s]
100%|██████████| 1/1 [00:00<00:00, 29.90it/s]
100%|██████████| 1/1 [00:00<00:00, 29.03it/s]
100%|██████████| 1/1 [00:00<00:00, 30.24it/s]
100%|██████████| 1/1 [00:00<00:00, 31.03it/s]
100%|██████████| 1/1 [00:00<00:00, 23.13it/s]
100%|██████████| 1/1 [00:00<00:00, 32.18it/s]
100%|██████████| 1/1 [00:00<00:00, 32.64it/s]
100%|██████████| 1/1 [00:00<00:00, 22.21it/s]
100%|██████████| 1/1 [00:00<00:00, 30.09it/s]
100%|██████████| 1/1 [00:00<00:00, 30.15it/s]
100%|██████████| 1/1 [00:00<00:00, 34.01it/s]
100%|██████████| 1/1 [00:00<00:00,

Fold 4 Test Accuracy: 0.7143, Precision: 0.8333, Recall: 0.6667, F1: 0.6500


100%|██████████| 1/1 [00:00<00:00, 32.83it/s]
100%|██████████| 1/1 [00:00<00:00, 31.36it/s]
100%|██████████| 1/1 [00:00<00:00, 21.55it/s]
100%|██████████| 1/1 [00:00<00:00, 30.74it/s]
100%|██████████| 1/1 [00:00<00:00, 32.24it/s]
100%|██████████| 1/1 [00:00<00:00, 23.96it/s]
100%|██████████| 1/1 [00:00<00:00, 31.69it/s]
100%|██████████| 1/1 [00:00<00:00, 30.15it/s]
100%|██████████| 1/1 [00:00<00:00, 32.84it/s]
100%|██████████| 1/1 [00:00<00:00, 31.40it/s]
100%|██████████| 1/1 [00:00<00:00, 34.16it/s]
100%|██████████| 1/1 [00:00<00:00, 31.48it/s]
100%|██████████| 1/1 [00:00<00:00, 22.74it/s]
100%|██████████| 1/1 [00:00<00:00, 29.70it/s]
100%|██████████| 1/1 [00:00<00:00, 34.90it/s]
100%|██████████| 1/1 [00:00<00:00, 23.09it/s]
100%|██████████| 1/1 [00:00<00:00, 30.50it/s]
100%|██████████| 1/1 [00:00<00:00, 31.59it/s]
100%|██████████| 1/1 [00:00<00:00, 20.67it/s]
100%|██████████| 1/1 [00:00<00:00, 30.75it/s]
100%|██████████| 1/1 [00:00<00:00, 32.49it/s]
100%|██████████| 1/1 [00:00<00:00,

Fold 5 Test Accuracy: 0.4286, Precision: 0.2500, Recall: 0.3750, F1: 0.3000

5-Fold CV Results:
Mean Accuracy  = 0.6036 ± 0.1661
Mean Precision = 0.5867 ± 0.2327
Mean Recall    = 0.5800 ± 0.1697
Mean F1-Score  = 0.5597 ± 0.1897





# **Incorporating Text Modality**

In [None]:
# Load a single subject's text embedding (subject 02010006) as a sample for inspection.
tri = np.load("drive/MyDrive/thesis2025/dataset_he/02010006/text_embedding.npy")

In [None]:
# Shape of the sample text embedding; the stored output of this cell shows (29, 768).
tri.shape

(29, 768)

In [None]:
from pathlib import Path
import shutil

# Where the per-subject text embeddings live, and the split tree they are copied into.
source_base = Path("drive/MyDrive/thesis2025/dataset_he")
target_base = Path("drive/MyDrive/thesis2025/split_dataset_vit3")

splits = ["train", "val", "test"]

# For every subject directory already present under a split, copy that subject's
# text_embedding.npy next to it. Subjects without a source file are reported and skipped.
for split in splits:
    subject_dirs = [entry for entry in (target_base / split).iterdir() if entry.is_dir()]
    print(f"Processing {split} split: {len(subject_dirs)} subjects")

    for subject_dir in subject_dirs:
        subject_id = subject_dir.name  # the directory name is used as the subject ID
        src_file = source_base / subject_id / "text_embedding.npy"

        if not src_file.exists():
            print(f"Warning: {src_file} does not exist!")
            continue

        shutil.copy(src_file, subject_dir / "text_embedding.npy")


Processing train split: 26 subjects
Processing val split: 6 subjects
Processing test split: 6 subjects


In [None]:
class EEGAudioTextDataset(Dataset):
    """Per-subject dataset that fuses EEG, audio, raw-EEG and text embeddings.

    Each item is built from four ``.npy`` files in the subject's directory,
    concatenated along axis 1, so the four arrays must agree on their first
    dimension (np.concatenate raises ValueError otherwise).

    Args:
        subject_dirs: Sequence of subject directory paths. The directory name is
            used as the subject ID, from which the label is derived.
    """

    # Files loaded for every subject, in concatenation order.
    EMBEDDING_FILES = (
        "eeg_embedding.npy",
        "audio_embedding.npy",
        "eeg1d_embedding.npy",
        "text_embedding.npy",
    )
    # Subject IDs with this prefix receive label 1; every other subject receives 0.
    POSITIVE_ID_PREFIX = '0201'

    def __init__(self, subject_dirs):
        self.subject_dirs = subject_dirs

    def __len__(self):
        """Number of subjects."""
        return len(self.subject_dirs)

    def __getitem__(self, idx):
        """Return ``(combined_embedding, label)`` for the subject at ``idx``.

        Returns:
            combined_embedding: float32 tensor holding the four embeddings
                concatenated along axis 1.
            label: int64 scalar tensor, 1 if the subject ID starts with
                ``POSITIVE_ID_PREFIX`` and 0 otherwise.
        """
        subject_path = self.subject_dirs[idx]
        # normpath drops a trailing separator first: basename("a/0201x/") is "",
        # which would silently label every subject 0 (e.g. for paths from glob("*/")).
        subject_id = os.path.basename(os.path.normpath(subject_path))

        embeddings = [
            np.load(os.path.join(subject_path, file_name))
            for file_name in self.EMBEDDING_FILES
        ]
        combined_embedding = np.concatenate(embeddings, axis=1)
        label = 1 if subject_id.startswith(self.POSITIVE_ID_PREFIX) else 0

        combined_embedding = torch.tensor(combined_embedding, dtype=torch.float32)
        label = torch.tensor(label, dtype=torch.long)

        return combined_embedding, label

In [None]:
import os
import numpy as np
from glob import glob
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import StratifiedKFold, train_test_split
from sklearn.metrics import (
    classification_report,
    confusion_matrix,
    precision_score,
    recall_score,
    f1_score,
)

import json  # needed for the fold file; not among the imports at the top of this cell

# Reseed Python, NumPy and PyTorch so re-running this cell repeats the experiment.
set_seed(42)
# ---- Cross-Validation ----
# Multimodal experiment on predefined folds: the JSON file maps "fold_<k>" to
# "train" / "val" / "test" lists of subject ids, and `subject_path_map`
# (built elsewhere in the notebook) resolves each id to its subject directory.
with open("drive/MyDrive/thesis2025/split_dataset_june/fold_assignments.json", "r") as f:
    folds = json.load(f)

results = []             # one (accuracy, precision, recall, f1) tuple per fold
fold_reports = []        # classification_report dict per fold
fold_conf_matrices = []  # confusion matrix per fold

for fold_idx in range(5):
    fold_name = f"fold_{fold_idx + 1}"
    train_ids = folds[fold_name]["train"]
    val_ids = folds[fold_name]["val"]
    test_ids = folds[fold_name]["test"]

    train_subjs = [subject_path_map[sid] for sid in train_ids]
    val_subjs = [subject_path_map[sid] for sid in val_ids]
    test_subjs = [subject_path_map[sid] for sid in test_ids]

    train_set = EEGAudioTextDataset(train_subjs)
    val_set = EEGAudioTextDataset(val_subjs)
    test_set = EEGAudioTextDataset(test_subjs)

    # Only the training loader shuffles; validation and test keep dataset order.
    train_loader = DataLoader(train_set, batch_size=100, shuffle=True, collate_fn=collate_fn_padd)
    val_loader = DataLoader(val_set, batch_size=100, collate_fn=collate_fn_padd)
    test_loader = DataLoader(test_set, batch_size=100, collate_fn=collate_fn_padd)

    # A fresh model and optimizer per fold, so nothing learned carries over.
    # NOTE(review): input_dim=2816 presumably equals the width of the
    # concatenated embeddings -- confirm against the saved arrays.
    model = ConvPoolReLUClassifier(input_dim=2816).to(device)
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()

    trained_model = train_model(model, train_loader, val_loader, criterion, optimizer, device)

    # ---- Evaluation ----
    trained_model.eval()
    correct = total = 0
    fold_preds = []
    fold_labels = []

    with torch.no_grad():
        for x, y in test_loader:
            x, y = x.to(device), y.to(device)
            output = trained_model(x)
            preds = output.argmax(dim=1)  # predicted class = index of the largest logit

            fold_preds.extend(preds.cpu().numpy())
            fold_labels.extend(y.cpu().numpy())

            correct += (preds == y).sum().item()
            total += y.size(0)

    acc = correct / total
    # zero_division=0 gives the same score as the default when a class is never
    # predicted (0.0), but without emitting UndefinedMetricWarning.
    precision = precision_score(fold_labels, fold_preds, average='macro', zero_division=0)
    recall = recall_score(fold_labels, fold_preds, average='macro', zero_division=0)
    f1 = f1_score(fold_labels, fold_preds, average='macro', zero_division=0)

    print(f"Fold {fold_idx + 1} Test Accuracy: {acc:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}")
    results.append((acc, precision, recall, f1))

    # Store classification report and confusion matrix for later averaging.
    # labels=[0, 1] pins the class set, so every fold yields both report rows
    # and a 2x2 matrix even if one class is absent from that fold's outputs.
    report = classification_report(
        fold_labels,
        fold_preds,
        labels=[0, 1],
        target_names=["Healthy", "Depressed"],
        output_dict=True,
        zero_division=0,
    )
    conf_matrix = confusion_matrix(fold_labels, fold_preds, labels=[0, 1])

    fold_reports.append(report)
    fold_conf_matrices.append(conf_matrix)

# ---- Final Report ----
# Rows are folds, columns are (accuracy, precision, recall, f1).
results = np.array(results)
mean_acc, mean_prec, mean_rec, mean_f1 = results.mean(axis=0)
# NumPy's default std (ddof=0), i.e. the population standard deviation over folds.
std_acc = results[:, 0].std()
std_prec = results[:, 1].std()
std_rec = results[:, 2].std()
std_f1 = results[:, 3].std()

print(f"\n5-Fold CV Results:")
print(f"Mean Accuracy  = {mean_acc:.4f} ± {std_acc:.4f}")
print(f"Mean Precision = {mean_prec:.4f} ± {std_prec:.4f}")
print(f"Mean Recall    = {mean_rec:.4f} ± {std_rec:.4f}")
print(f"Mean F1-Score  = {mean_f1:.4f} ± {std_f1:.4f}")


100%|██████████| 1/1 [00:02<00:00,  2.63s/it]
100%|██████████| 1/1 [00:00<00:00, 22.54it/s]
100%|██████████| 1/1 [00:00<00:00, 33.16it/s]
100%|██████████| 1/1 [00:00<00:00, 32.30it/s]
100%|██████████| 1/1 [00:00<00:00, 29.99it/s]
100%|██████████| 1/1 [00:00<00:00, 33.33it/s]
100%|██████████| 1/1 [00:00<00:00, 32.37it/s]
100%|██████████| 1/1 [00:00<00:00, 30.87it/s]
100%|██████████| 1/1 [00:00<00:00, 35.73it/s]
100%|██████████| 1/1 [00:00<00:00, 32.22it/s]
100%|██████████| 1/1 [00:00<00:00, 31.91it/s]
100%|██████████| 1/1 [00:00<00:00, 22.37it/s]
100%|██████████| 1/1 [00:00<00:00, 37.70it/s]
100%|██████████| 1/1 [00:00<00:00, 27.37it/s]
100%|██████████| 1/1 [00:00<00:00, 26.90it/s]
100%|██████████| 1/1 [00:00<00:00, 29.29it/s]
100%|██████████| 1/1 [00:00<00:00, 31.60it/s]
100%|██████████| 1/1 [00:00<00:00, 29.48it/s]
100%|██████████| 1/1 [00:00<00:00, 30.48it/s]
100%|██████████| 1/1 [00:00<00:00, 33.39it/s]
100%|██████████| 1/1 [00:00<00:00, 30.09it/s]
100%|██████████| 1/1 [00:00<00:00,

Fold 1 Test Accuracy: 0.3750, Precision: 0.3667, Recall: 0.3750, F1: 0.3651


100%|██████████| 1/1 [00:00<00:00, 32.22it/s]
100%|██████████| 1/1 [00:00<00:00, 31.83it/s]
100%|██████████| 1/1 [00:00<00:00, 33.35it/s]
100%|██████████| 1/1 [00:00<00:00, 28.03it/s]
100%|██████████| 1/1 [00:00<00:00, 32.77it/s]
100%|██████████| 1/1 [00:00<00:00, 27.75it/s]
100%|██████████| 1/1 [00:00<00:00, 22.23it/s]
100%|██████████| 1/1 [00:00<00:00, 30.31it/s]
100%|██████████| 1/1 [00:00<00:00, 30.36it/s]
100%|██████████| 1/1 [00:00<00:00, 26.56it/s]
100%|██████████| 1/1 [00:00<00:00, 29.03it/s]
100%|██████████| 1/1 [00:00<00:00, 30.16it/s]
100%|██████████| 1/1 [00:00<00:00, 28.21it/s]
100%|██████████| 1/1 [00:00<00:00, 23.34it/s]
100%|██████████| 1/1 [00:00<00:00, 31.61it/s]
100%|██████████| 1/1 [00:00<00:00, 32.84it/s]
100%|██████████| 1/1 [00:00<00:00, 32.59it/s]
100%|██████████| 1/1 [00:00<00:00, 32.46it/s]
100%|██████████| 1/1 [00:00<00:00, 31.79it/s]
100%|██████████| 1/1 [00:00<00:00, 35.16it/s]
100%|██████████| 1/1 [00:00<00:00, 22.80it/s]
100%|██████████| 1/1 [00:00<00:00,

Fold 2 Test Accuracy: 0.6250, Precision: 0.6333, Recall: 0.6250, F1: 0.6190


100%|██████████| 1/1 [00:00<00:00, 30.46it/s]
100%|██████████| 1/1 [00:00<00:00, 31.10it/s]
100%|██████████| 1/1 [00:00<00:00, 34.70it/s]
100%|██████████| 1/1 [00:00<00:00, 30.88it/s]
100%|██████████| 1/1 [00:00<00:00, 22.39it/s]
100%|██████████| 1/1 [00:00<00:00, 30.73it/s]
100%|██████████| 1/1 [00:00<00:00, 31.99it/s]
100%|██████████| 1/1 [00:00<00:00, 29.29it/s]
100%|██████████| 1/1 [00:00<00:00, 32.38it/s]
100%|██████████| 1/1 [00:00<00:00, 31.09it/s]
100%|██████████| 1/1 [00:00<00:00, 31.97it/s]
100%|██████████| 1/1 [00:00<00:00, 23.76it/s]
100%|██████████| 1/1 [00:00<00:00, 31.59it/s]
100%|██████████| 1/1 [00:00<00:00, 29.71it/s]
100%|██████████| 1/1 [00:00<00:00, 31.58it/s]
100%|██████████| 1/1 [00:00<00:00, 31.53it/s]
100%|██████████| 1/1 [00:00<00:00, 29.90it/s]
100%|██████████| 1/1 [00:00<00:00, 31.31it/s]
100%|██████████| 1/1 [00:00<00:00, 33.59it/s]
100%|██████████| 1/1 [00:00<00:00, 28.83it/s]
100%|██████████| 1/1 [00:00<00:00, 33.29it/s]
100%|██████████| 1/1 [00:00<00:00,

Fold 3 Test Accuracy: 0.8750, Precision: 0.9167, Recall: 0.8333, F1: 0.8545


100%|██████████| 1/1 [00:00<00:00, 23.40it/s]
100%|██████████| 1/1 [00:00<00:00, 23.91it/s]
100%|██████████| 1/1 [00:00<00:00, 17.56it/s]
100%|██████████| 1/1 [00:00<00:00, 23.51it/s]
100%|██████████| 1/1 [00:00<00:00, 24.77it/s]
100%|██████████| 1/1 [00:00<00:00, 22.43it/s]
100%|██████████| 1/1 [00:00<00:00,  9.01it/s]
100%|██████████| 1/1 [00:00<00:00, 26.61it/s]
100%|██████████| 1/1 [00:00<00:00, 23.69it/s]
100%|██████████| 1/1 [00:00<00:00, 23.37it/s]
100%|██████████| 1/1 [00:00<00:00, 22.11it/s]
100%|██████████| 1/1 [00:00<00:00, 18.45it/s]
100%|██████████| 1/1 [00:00<00:00, 24.76it/s]
100%|██████████| 1/1 [00:00<00:00, 22.23it/s]
100%|██████████| 1/1 [00:00<00:00, 22.50it/s]
100%|██████████| 1/1 [00:00<00:00, 19.15it/s]
100%|██████████| 1/1 [00:00<00:00, 23.14it/s]
100%|██████████| 1/1 [00:00<00:00, 23.19it/s]
100%|██████████| 1/1 [00:00<00:00, 21.92it/s]
100%|██████████| 1/1 [00:00<00:00, 18.81it/s]
100%|██████████| 1/1 [00:00<00:00, 23.69it/s]
100%|██████████| 1/1 [00:00<00:00,

Fold 4 Test Accuracy: 0.8571, Precision: 0.9000, Recall: 0.8333, F1: 0.8444


100%|██████████| 1/1 [00:00<00:00, 24.67it/s]
100%|██████████| 1/1 [00:00<00:00, 22.31it/s]
100%|██████████| 1/1 [00:00<00:00, 21.67it/s]
100%|██████████| 1/1 [00:00<00:00, 18.00it/s]
100%|██████████| 1/1 [00:00<00:00, 24.69it/s]
100%|██████████| 1/1 [00:00<00:00, 24.27it/s]
100%|██████████| 1/1 [00:00<00:00, 23.72it/s]
100%|██████████| 1/1 [00:00<00:00, 17.83it/s]
100%|██████████| 1/1 [00:00<00:00, 24.30it/s]
100%|██████████| 1/1 [00:00<00:00, 23.51it/s]
100%|██████████| 1/1 [00:00<00:00, 21.07it/s]
100%|██████████| 1/1 [00:00<00:00, 24.37it/s]
100%|██████████| 1/1 [00:00<00:00, 19.01it/s]
100%|██████████| 1/1 [00:00<00:00, 25.32it/s]
100%|██████████| 1/1 [00:00<00:00, 26.31it/s]
100%|██████████| 1/1 [00:00<00:00, 24.00it/s]
100%|██████████| 1/1 [00:00<00:00, 17.88it/s]
100%|██████████| 1/1 [00:00<00:00, 24.16it/s]
100%|██████████| 1/1 [00:00<00:00, 23.05it/s]
100%|██████████| 1/1 [00:00<00:00, 24.88it/s]
100%|██████████| 1/1 [00:00<00:00, 18.89it/s]
100%|██████████| 1/1 [00:00<00:00,

Fold 5 Test Accuracy: 0.4286, Precision: 0.2500, Recall: 0.3750, F1: 0.3000

5-Fold CV Results:
Mean Accuracy  = 0.6321 ± 0.2084
Mean Precision = 0.6133 ± 0.2711
Mean Recall    = 0.6083 ± 0.2051
Mean F1-Score  = 0.5966 ± 0.2324





# **Let's do Unimodal Text**

In [None]:
class TextDataset(Dataset):
    """Per-subject dataset serving only the saved text embedding.

    Each subject directory must contain ``text_embedding.npy``. The label is
    derived from the directory name: names starting with ``'0201'`` map to
    class 1, everything else to class 0.
    """

    def __init__(self, subject_dirs):
        """Keep the indexable collection of subject directory paths."""
        self.subject_dirs = subject_dirs

    def __len__(self):
        """Return how many subject directories the dataset holds."""
        return len(self.subject_dirs)

    def __getitem__(self, idx):
        """Load the text embedding of subject ``idx`` and return ``(features, label)``.

        Returns:
            tuple: a float32 tensor with the embedding and a 0-dim int64 tensor
            holding the class id.
        """
        subject_dir = self.subject_dirs[idx]
        folder_name = os.path.basename(subject_dir)

        embedding_path = os.path.join(subject_dir, "text_embedding.npy")
        features = torch.tensor(np.load(embedding_path), dtype=torch.float32)

        target = torch.tensor(int(folder_name.startswith('0201')), dtype=torch.long)

        return features, target

In [None]:
import os
import numpy as np
from glob import glob
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import StratifiedKFold, train_test_split
from sklearn.metrics import (
    classification_report,
    confusion_matrix,
    precision_score,
    recall_score,
    f1_score,
)

# Reseed Python, NumPy and PyTorch so re-running this cell repeats the experiment.
set_seed(42)
# ---- Cross-Validation ----
# Unimodal (text-only) experiment: pool the subjects of all three on-disk
# splits and re-partition them with stratified 5-fold cross-validation.
base_dir = "drive/MyDrive/thesis2025/split_dataset_vit3"
split_dirs = ["train", "val", "test"]

all_subject_dirs = []
for split in split_dirs:
    # Keep directories only, so a stray file in a split folder is never
    # mistaken for a subject.
    all_subject_dirs.extend(
        sorted(p for p in glob(os.path.join(base_dir, split, "*")) if os.path.isdir(p))
    )

# Same labelling rule as TextDataset: folder names starting with "0201" are class 1.
labels = [1 if os.path.basename(d).startswith("0201") else 0 for d in all_subject_dirs]
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

results = []             # one (accuracy, precision, recall, f1) tuple per fold
fold_reports = []        # classification_report dict per fold
fold_conf_matrices = []  # confusion matrix per fold

for fold_idx, (train_idx, test_idx) in enumerate(skf.split(all_subject_dirs, labels)):
    print(f"\nFold {fold_idx + 1}/5")
    train_subjs = [all_subject_dirs[i] for i in train_idx]
    test_subjs = [all_subject_dirs[i] for i in test_idx]
    # NOTE(review): this hold-out is not stratified, so the small validation
    # subset may contain a single class. Left unchanged to keep reported
    # numbers reproducible.
    train_subjs, val_subjs = train_test_split(train_subjs, test_size=0.1, random_state=42)

    train_set = TextDataset(train_subjs)
    val_set = TextDataset(val_subjs)
    test_set = TextDataset(test_subjs)

    # Only the training loader shuffles; validation and test keep dataset order.
    train_loader = DataLoader(train_set, batch_size=100, shuffle=True, collate_fn=collate_fn_padd)
    val_loader = DataLoader(val_set, batch_size=100, collate_fn=collate_fn_padd)
    test_loader = DataLoader(test_set, batch_size=100, collate_fn=collate_fn_padd)

    # A fresh model and optimizer per fold, so nothing learned carries over.
    model = ConvPoolReLUClassifier(input_dim=768).to(device)
    optimizer = optim.Adamax(model.parameters(), lr=0.0004, weight_decay=1e-4)
    criterion = nn.CrossEntropyLoss()

    trained_model = train_model(model, train_loader, val_loader, criterion, optimizer, device)

    # ---- Evaluation ----
    trained_model.eval()
    correct = total = 0
    fold_preds = []
    fold_labels = []

    with torch.no_grad():
        for x, y in test_loader:
            x, y = x.to(device), y.to(device)
            output = trained_model(x)
            preds = output.argmax(dim=1)  # predicted class = index of the largest logit

            fold_preds.extend(preds.cpu().numpy())
            fold_labels.extend(y.cpu().numpy())

            correct += (preds == y).sum().item()
            total += y.size(0)

    acc = correct / total
    # zero_division=0 gives the same score as the default when a class is never
    # predicted (0.0), but without emitting UndefinedMetricWarning.
    precision = precision_score(fold_labels, fold_preds, average='macro', zero_division=0)
    recall = recall_score(fold_labels, fold_preds, average='macro', zero_division=0)
    f1 = f1_score(fold_labels, fold_preds, average='macro', zero_division=0)

    print(f"Fold {fold_idx + 1} Test Accuracy: {acc:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}")
    results.append((acc, precision, recall, f1))

    # Store classification report and confusion matrix for later averaging.
    # labels=[0, 1] pins the class set, so every fold yields both report rows
    # and a 2x2 matrix even if one class is absent from that fold's outputs.
    report = classification_report(
        fold_labels,
        fold_preds,
        labels=[0, 1],
        target_names=["Healthy", "Depressed"],
        output_dict=True,
        zero_division=0,
    )
    conf_matrix = confusion_matrix(fold_labels, fold_preds, labels=[0, 1])

    fold_reports.append(report)
    fold_conf_matrices.append(conf_matrix)

# ---- Final Report ----
# Rows are folds, columns are (accuracy, precision, recall, f1).
results = np.array(results)
mean_acc, mean_prec, mean_rec, mean_f1 = results.mean(axis=0)
# NumPy's default std (ddof=0), i.e. the population standard deviation over folds.
std_acc = results[:, 0].std()
std_prec = results[:, 1].std()
std_rec = results[:, 2].std()
std_f1 = results[:, 3].std()

print(f"\n5-Fold CV Results:")
print(f"Mean Accuracy  = {mean_acc:.4f} ± {std_acc:.4f}")
print(f"Mean Precision = {mean_prec:.4f} ± {std_prec:.4f}")
print(f"Mean Recall    = {mean_rec:.4f} ± {std_rec:.4f}")
print(f"Mean F1-Score  = {mean_f1:.4f} ± {std_f1:.4f}")


Fold 1/5


100%|██████████| 1/1 [00:00<00:00, 62.58it/s]
100%|██████████| 1/1 [00:00<00:00, 100.34it/s]
100%|██████████| 1/1 [00:00<00:00, 80.75it/s]
100%|██████████| 1/1 [00:00<00:00, 82.66it/s]
100%|██████████| 1/1 [00:00<00:00, 83.46it/s]
100%|██████████| 1/1 [00:00<00:00, 97.21it/s]
100%|██████████| 1/1 [00:00<00:00, 101.22it/s]
100%|██████████| 1/1 [00:00<00:00, 103.07it/s]
100%|██████████| 1/1 [00:00<00:00, 101.72it/s]
100%|██████████| 1/1 [00:00<00:00, 40.55it/s]
100%|██████████| 1/1 [00:00<00:00, 104.06it/s]
100%|██████████| 1/1 [00:00<00:00, 99.08it/s]
100%|██████████| 1/1 [00:00<00:00, 72.14it/s]
100%|██████████| 1/1 [00:00<00:00, 95.21it/s]
100%|██████████| 1/1 [00:00<00:00, 94.05it/s]
100%|██████████| 1/1 [00:00<00:00, 99.86it/s]
100%|██████████| 1/1 [00:00<00:00, 80.79it/s]
100%|██████████| 1/1 [00:00<00:00, 101.03it/s]
100%|██████████| 1/1 [00:00<00:00, 74.52it/s]
100%|██████████| 1/1 [00:00<00:00, 97.08it/s]
100%|██████████| 1/1 [00:00<00:00, 97.74it/s]
100%|██████████| 1/1 [00:00<

Fold 1 Test Accuracy: 0.8750, Precision: 0.9000, Recall: 0.8750, F1: 0.8730

Fold 2/5


100%|██████████| 1/1 [00:00<00:00, 104.85it/s]
100%|██████████| 1/1 [00:00<00:00, 110.36it/s]
100%|██████████| 1/1 [00:00<00:00, 101.65it/s]
100%|██████████| 1/1 [00:00<00:00, 102.31it/s]
100%|██████████| 1/1 [00:00<00:00, 113.27it/s]
100%|██████████| 1/1 [00:00<00:00, 97.77it/s]
100%|██████████| 1/1 [00:00<00:00, 103.21it/s]
100%|██████████| 1/1 [00:00<00:00, 96.73it/s]
100%|██████████| 1/1 [00:00<00:00, 92.07it/s]
100%|██████████| 1/1 [00:00<00:00, 112.61it/s]
100%|██████████| 1/1 [00:00<00:00, 94.40it/s]
100%|██████████| 1/1 [00:00<00:00, 86.28it/s]
100%|██████████| 1/1 [00:00<00:00, 102.04it/s]
100%|██████████| 1/1 [00:00<00:00, 104.00it/s]
100%|██████████| 1/1 [00:00<00:00, 93.29it/s]
100%|██████████| 1/1 [00:00<00:00, 108.34it/s]
100%|██████████| 1/1 [00:00<00:00, 98.78it/s]
100%|██████████| 1/1 [00:00<00:00, 93.99it/s]
100%|██████████| 1/1 [00:00<00:00, 103.41it/s]
100%|██████████| 1/1 [00:00<00:00, 107.25it/s]
100%|██████████| 1/1 [00:00<00:00, 43.24it/s]
100%|██████████| 1/1 [

Fold 2 Test Accuracy: 0.6250, Precision: 0.6333, Recall: 0.6250, F1: 0.6190

Fold 3/5


100%|██████████| 1/1 [00:00<00:00, 91.05it/s]
100%|██████████| 1/1 [00:00<00:00, 105.36it/s]
100%|██████████| 1/1 [00:00<00:00, 101.83it/s]
100%|██████████| 1/1 [00:00<00:00, 111.13it/s]
100%|██████████| 1/1 [00:00<00:00, 46.07it/s]
100%|██████████| 1/1 [00:00<00:00, 101.02it/s]
100%|██████████| 1/1 [00:00<00:00, 116.63it/s]
100%|██████████| 1/1 [00:00<00:00, 103.97it/s]
100%|██████████| 1/1 [00:00<00:00, 117.86it/s]
100%|██████████| 1/1 [00:00<00:00, 113.33it/s]
100%|██████████| 1/1 [00:00<00:00, 100.01it/s]
100%|██████████| 1/1 [00:00<00:00, 103.68it/s]
100%|██████████| 1/1 [00:00<00:00, 101.03it/s]
100%|██████████| 1/1 [00:00<00:00, 89.09it/s]
100%|██████████| 1/1 [00:00<00:00, 104.76it/s]
100%|██████████| 1/1 [00:00<00:00, 102.84it/s]
100%|██████████| 1/1 [00:00<00:00, 92.29it/s]
100%|██████████| 1/1 [00:00<00:00, 97.00it/s]
100%|██████████| 1/1 [00:00<00:00, 112.53it/s]
100%|██████████| 1/1 [00:00<00:00, 94.89it/s]
100%|██████████| 1/1 [00:00<00:00, 103.57it/s]
100%|██████████| 1/

Fold 3 Test Accuracy: 0.7500, Precision: 0.8000, Recall: 0.8000, F1: 0.7500

Fold 4/5


100%|██████████| 1/1 [00:00<00:00, 31.98it/s]
100%|██████████| 1/1 [00:00<00:00, 79.63it/s]
100%|██████████| 1/1 [00:00<00:00, 66.85it/s]
100%|██████████| 1/1 [00:00<00:00, 82.62it/s]
100%|██████████| 1/1 [00:00<00:00, 79.60it/s]
100%|██████████| 1/1 [00:00<00:00, 77.55it/s]
100%|██████████| 1/1 [00:00<00:00, 56.24it/s]
100%|██████████| 1/1 [00:00<00:00, 27.19it/s]
100%|██████████| 1/1 [00:00<00:00, 65.42it/s]
100%|██████████| 1/1 [00:00<00:00, 43.32it/s]
100%|██████████| 1/1 [00:00<00:00, 65.05it/s]
100%|██████████| 1/1 [00:00<00:00, 80.57it/s]
100%|██████████| 1/1 [00:00<00:00, 67.86it/s]
100%|██████████| 1/1 [00:00<00:00, 78.96it/s]
100%|██████████| 1/1 [00:00<00:00, 73.70it/s]
100%|██████████| 1/1 [00:00<00:00, 38.72it/s]
100%|██████████| 1/1 [00:00<00:00, 70.56it/s]
100%|██████████| 1/1 [00:00<00:00, 77.37it/s]
100%|██████████| 1/1 [00:00<00:00, 86.42it/s]
100%|██████████| 1/1 [00:00<00:00, 88.41it/s]
100%|██████████| 1/1 [00:00<00:00, 77.56it/s]
100%|██████████| 1/1 [00:00<00:00,

Fold 4 Test Accuracy: 0.7143, Precision: 0.8333, Recall: 0.6667, F1: 0.6500

Fold 5/5


100%|██████████| 1/1 [00:00<00:00, 77.48it/s]
100%|██████████| 1/1 [00:00<00:00, 83.67it/s]
100%|██████████| 1/1 [00:00<00:00, 81.57it/s]
100%|██████████| 1/1 [00:00<00:00, 87.45it/s]
100%|██████████| 1/1 [00:00<00:00, 65.94it/s]
100%|██████████| 1/1 [00:00<00:00, 74.42it/s]
100%|██████████| 1/1 [00:00<00:00, 82.70it/s]
100%|██████████| 1/1 [00:00<00:00, 78.47it/s]
100%|██████████| 1/1 [00:00<00:00, 70.03it/s]
100%|██████████| 1/1 [00:00<00:00, 90.10it/s]
100%|██████████| 1/1 [00:00<00:00, 84.45it/s]
100%|██████████| 1/1 [00:00<00:00, 78.82it/s]
100%|██████████| 1/1 [00:00<00:00, 83.23it/s]
100%|██████████| 1/1 [00:00<00:00, 83.88it/s]
100%|██████████| 1/1 [00:00<00:00, 81.82it/s]
100%|██████████| 1/1 [00:00<00:00, 88.17it/s]
100%|██████████| 1/1 [00:00<00:00, 77.97it/s]
100%|██████████| 1/1 [00:00<00:00, 84.04it/s]
100%|██████████| 1/1 [00:00<00:00, 85.62it/s]
100%|██████████| 1/1 [00:00<00:00, 76.25it/s]
100%|██████████| 1/1 [00:00<00:00, 71.90it/s]
100%|██████████| 1/1 [00:00<00:00,

Fold 5 Test Accuracy: 1.0000, Precision: 1.0000, Recall: 1.0000, F1: 1.0000

5-Fold CV Results:
Mean Accuracy  = 0.7929 ± 0.1310
Mean Precision = 0.8333 ± 0.1211
Mean Recall    = 0.7933 ± 0.1370
Mean F1-Score  = 0.7784 ± 0.1419


In [None]:
import os
import numpy as np
from glob import glob
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import StratifiedKFold, train_test_split
from sklearn.metrics import (
    classification_report,
    confusion_matrix,
    precision_score,
    recall_score,
    f1_score,
)

# Reseed Python, NumPy and PyTorch so re-running this cell repeats the experiment.
set_seed(42)
# ---- Cross-Validation ----
# Unimodal (text-only) experiment with the same protocol as the preceding
# cell, except that Adamax is used here without weight decay.
base_dir = "drive/MyDrive/thesis2025/split_dataset_vit3"
split_dirs = ["train", "val", "test"]

all_subject_dirs = []
for split in split_dirs:
    # Keep directories only, so a stray file in a split folder is never
    # mistaken for a subject.
    all_subject_dirs.extend(
        sorted(p for p in glob(os.path.join(base_dir, split, "*")) if os.path.isdir(p))
    )

# Same labelling rule as TextDataset: folder names starting with "0201" are class 1.
labels = [1 if os.path.basename(d).startswith("0201") else 0 for d in all_subject_dirs]
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

results = []             # one (accuracy, precision, recall, f1) tuple per fold
fold_reports = []        # classification_report dict per fold
fold_conf_matrices = []  # confusion matrix per fold

for fold_idx, (train_idx, test_idx) in enumerate(skf.split(all_subject_dirs, labels)):
    print(f"\nFold {fold_idx + 1}/5")
    train_subjs = [all_subject_dirs[i] for i in train_idx]
    test_subjs = [all_subject_dirs[i] for i in test_idx]
    # NOTE(review): this hold-out is not stratified, so the small validation
    # subset may contain a single class. Left unchanged to keep reported
    # numbers reproducible.
    train_subjs, val_subjs = train_test_split(train_subjs, test_size=0.1, random_state=42)

    train_set = TextDataset(train_subjs)
    val_set = TextDataset(val_subjs)
    test_set = TextDataset(test_subjs)

    # Only the training loader shuffles; validation and test keep dataset order.
    train_loader = DataLoader(train_set, batch_size=100, shuffle=True, collate_fn=collate_fn_padd)
    val_loader = DataLoader(val_set, batch_size=100, collate_fn=collate_fn_padd)
    test_loader = DataLoader(test_set, batch_size=100, collate_fn=collate_fn_padd)

    # A fresh model and optimizer per fold, so nothing learned carries over.
    model = ConvPoolReLUClassifier(input_dim=768).to(device)
    optimizer = optim.Adamax(model.parameters(), lr=0.0004)
    criterion = nn.CrossEntropyLoss()

    trained_model = train_model(model, train_loader, val_loader, criterion, optimizer, device)

    # ---- Evaluation ----
    trained_model.eval()
    correct = total = 0
    fold_preds = []
    fold_labels = []

    with torch.no_grad():
        for x, y in test_loader:
            x, y = x.to(device), y.to(device)
            output = trained_model(x)
            preds = output.argmax(dim=1)  # predicted class = index of the largest logit

            fold_preds.extend(preds.cpu().numpy())
            fold_labels.extend(y.cpu().numpy())

            correct += (preds == y).sum().item()
            total += y.size(0)

    acc = correct / total
    # zero_division=0 gives the same score as the default when a class is never
    # predicted (0.0), but without emitting UndefinedMetricWarning.
    precision = precision_score(fold_labels, fold_preds, average='macro', zero_division=0)
    recall = recall_score(fold_labels, fold_preds, average='macro', zero_division=0)
    f1 = f1_score(fold_labels, fold_preds, average='macro', zero_division=0)

    print(f"Fold {fold_idx + 1} Test Accuracy: {acc:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}")
    results.append((acc, precision, recall, f1))

    # Store classification report and confusion matrix for later averaging.
    # labels=[0, 1] pins the class set, so every fold yields both report rows
    # and a 2x2 matrix even if one class is absent from that fold's outputs.
    report = classification_report(
        fold_labels,
        fold_preds,
        labels=[0, 1],
        target_names=["Healthy", "Depressed"],
        output_dict=True,
        zero_division=0,
    )
    conf_matrix = confusion_matrix(fold_labels, fold_preds, labels=[0, 1])

    fold_reports.append(report)
    fold_conf_matrices.append(conf_matrix)

# ---- Final Report ----
# Rows are folds, columns are (accuracy, precision, recall, f1).
results = np.array(results)
mean_acc, mean_prec, mean_rec, mean_f1 = results.mean(axis=0)
# NumPy's default std (ddof=0), i.e. the population standard deviation over folds.
std_acc = results[:, 0].std()
std_prec = results[:, 1].std()
std_rec = results[:, 2].std()
std_f1 = results[:, 3].std()

print(f"\n5-Fold CV Results:")
print(f"Mean Accuracy  = {mean_acc:.4f} ± {std_acc:.4f}")
print(f"Mean Precision = {mean_prec:.4f} ± {std_prec:.4f}")
print(f"Mean Recall    = {mean_rec:.4f} ± {std_rec:.4f}")
print(f"Mean F1-Score  = {mean_f1:.4f} ± {std_f1:.4f}")


Fold 1/5


100%|██████████| 1/1 [00:00<00:00,  1.13it/s]
100%|██████████| 1/1 [00:00<00:00, 39.09it/s]
100%|██████████| 1/1 [00:00<00:00, 33.54it/s]
100%|██████████| 1/1 [00:00<00:00, 40.60it/s]
100%|██████████| 1/1 [00:00<00:00, 41.26it/s]
100%|██████████| 1/1 [00:00<00:00, 43.51it/s]
100%|██████████| 1/1 [00:00<00:00, 41.16it/s]
100%|██████████| 1/1 [00:00<00:00, 32.04it/s]
100%|██████████| 1/1 [00:00<00:00, 25.99it/s]
100%|██████████| 1/1 [00:00<00:00, 46.39it/s]
100%|██████████| 1/1 [00:00<00:00, 50.41it/s]
100%|██████████| 1/1 [00:00<00:00, 40.72it/s]
100%|██████████| 1/1 [00:00<00:00, 44.73it/s]
100%|██████████| 1/1 [00:00<00:00, 41.51it/s]
100%|██████████| 1/1 [00:00<00:00, 48.96it/s]
100%|██████████| 1/1 [00:00<00:00, 43.03it/s]
100%|██████████| 1/1 [00:00<00:00, 36.50it/s]
100%|██████████| 1/1 [00:00<00:00, 51.25it/s]
100%|██████████| 1/1 [00:00<00:00, 47.43it/s]
100%|██████████| 1/1 [00:00<00:00, 47.25it/s]
100%|██████████| 1/1 [00:00<00:00, 42.86it/s]
100%|██████████| 1/1 [00:00<00:00,

Fold 1 Test Accuracy: 0.8750, Precision: 0.9000, Recall: 0.8750, F1: 0.8730

Fold 2/5


100%|██████████| 1/1 [00:00<00:00, 36.06it/s]
100%|██████████| 1/1 [00:00<00:00, 36.70it/s]
100%|██████████| 1/1 [00:00<00:00, 45.15it/s]
100%|██████████| 1/1 [00:00<00:00, 48.01it/s]
100%|██████████| 1/1 [00:00<00:00, 37.85it/s]
100%|██████████| 1/1 [00:00<00:00, 38.50it/s]
100%|██████████| 1/1 [00:00<00:00, 39.57it/s]
100%|██████████| 1/1 [00:00<00:00, 44.87it/s]
100%|██████████| 1/1 [00:00<00:00, 35.99it/s]
100%|██████████| 1/1 [00:00<00:00, 43.74it/s]
100%|██████████| 1/1 [00:00<00:00, 46.98it/s]
100%|██████████| 1/1 [00:00<00:00, 44.46it/s]
100%|██████████| 1/1 [00:00<00:00, 45.09it/s]
100%|██████████| 1/1 [00:00<00:00, 44.63it/s]
100%|██████████| 1/1 [00:00<00:00, 35.12it/s]
100%|██████████| 1/1 [00:00<00:00, 37.88it/s]
100%|██████████| 1/1 [00:00<00:00, 37.60it/s]
100%|██████████| 1/1 [00:00<00:00, 42.29it/s]
100%|██████████| 1/1 [00:00<00:00, 29.02it/s]
100%|██████████| 1/1 [00:00<00:00, 43.05it/s]
100%|██████████| 1/1 [00:00<00:00, 48.97it/s]
100%|██████████| 1/1 [00:00<00:00,

Fold 2 Test Accuracy: 0.6250, Precision: 0.6333, Recall: 0.6250, F1: 0.6190

Fold 3/5


100%|██████████| 1/1 [00:00<00:00, 47.20it/s]
100%|██████████| 1/1 [00:00<00:00, 43.77it/s]
100%|██████████| 1/1 [00:00<00:00, 43.55it/s]
100%|██████████| 1/1 [00:00<00:00, 45.44it/s]
100%|██████████| 1/1 [00:00<00:00, 46.98it/s]
100%|██████████| 1/1 [00:00<00:00, 37.29it/s]
100%|██████████| 1/1 [00:00<00:00, 40.18it/s]
100%|██████████| 1/1 [00:00<00:00, 37.80it/s]
100%|██████████| 1/1 [00:00<00:00, 37.81it/s]
100%|██████████| 1/1 [00:00<00:00, 39.62it/s]
100%|██████████| 1/1 [00:00<00:00, 52.54it/s]
100%|██████████| 1/1 [00:00<00:00, 45.64it/s]
100%|██████████| 1/1 [00:00<00:00, 50.00it/s]
100%|██████████| 1/1 [00:00<00:00, 46.45it/s]
100%|██████████| 1/1 [00:00<00:00, 51.27it/s]
100%|██████████| 1/1 [00:00<00:00, 51.52it/s]
100%|██████████| 1/1 [00:00<00:00, 40.34it/s]
100%|██████████| 1/1 [00:00<00:00, 45.94it/s]
100%|██████████| 1/1 [00:00<00:00, 44.83it/s]
100%|██████████| 1/1 [00:00<00:00, 30.67it/s]
100%|██████████| 1/1 [00:00<00:00, 53.15it/s]
100%|██████████| 1/1 [00:00<00:00,

Fold 3 Test Accuracy: 0.7500, Precision: 0.8000, Recall: 0.8000, F1: 0.7500

Fold 4/5


100%|██████████| 1/1 [00:00<00:00, 36.63it/s]
100%|██████████| 1/1 [00:00<00:00, 39.93it/s]
100%|██████████| 1/1 [00:00<00:00, 37.52it/s]
100%|██████████| 1/1 [00:00<00:00, 39.07it/s]
100%|██████████| 1/1 [00:00<00:00, 39.74it/s]
100%|██████████| 1/1 [00:00<00:00, 40.01it/s]
100%|██████████| 1/1 [00:00<00:00, 39.90it/s]
100%|██████████| 1/1 [00:00<00:00, 34.57it/s]
100%|██████████| 1/1 [00:00<00:00, 37.44it/s]
100%|██████████| 1/1 [00:00<00:00, 41.76it/s]
100%|██████████| 1/1 [00:00<00:00, 35.07it/s]
100%|██████████| 1/1 [00:00<00:00, 37.83it/s]
100%|██████████| 1/1 [00:00<00:00, 34.40it/s]
100%|██████████| 1/1 [00:00<00:00, 38.63it/s]
100%|██████████| 1/1 [00:00<00:00, 39.63it/s]
100%|██████████| 1/1 [00:00<00:00, 36.28it/s]
100%|██████████| 1/1 [00:00<00:00, 40.17it/s]
100%|██████████| 1/1 [00:00<00:00, 37.42it/s]
100%|██████████| 1/1 [00:00<00:00, 37.31it/s]
100%|██████████| 1/1 [00:00<00:00, 39.51it/s]
100%|██████████| 1/1 [00:00<00:00, 36.63it/s]
100%|██████████| 1/1 [00:00<00:00,

Fold 4 Test Accuracy: 0.7143, Precision: 0.8333, Recall: 0.6667, F1: 0.6500

Fold 5/5


100%|██████████| 1/1 [00:00<00:00, 39.52it/s]
100%|██████████| 1/1 [00:00<00:00, 41.15it/s]
100%|██████████| 1/1 [00:00<00:00, 39.99it/s]
100%|██████████| 1/1 [00:00<00:00, 42.63it/s]
100%|██████████| 1/1 [00:00<00:00, 42.40it/s]
100%|██████████| 1/1 [00:00<00:00, 39.61it/s]
100%|██████████| 1/1 [00:00<00:00, 40.04it/s]
100%|██████████| 1/1 [00:00<00:00, 42.62it/s]
100%|██████████| 1/1 [00:00<00:00, 33.80it/s]
100%|██████████| 1/1 [00:00<00:00, 37.63it/s]
100%|██████████| 1/1 [00:00<00:00, 38.06it/s]
100%|██████████| 1/1 [00:00<00:00, 37.08it/s]
100%|██████████| 1/1 [00:00<00:00, 41.38it/s]
100%|██████████| 1/1 [00:00<00:00, 41.17it/s]
100%|██████████| 1/1 [00:00<00:00, 38.06it/s]
100%|██████████| 1/1 [00:00<00:00, 33.35it/s]
100%|██████████| 1/1 [00:00<00:00, 43.47it/s]
100%|██████████| 1/1 [00:00<00:00, 32.83it/s]
100%|██████████| 1/1 [00:00<00:00, 38.37it/s]
100%|██████████| 1/1 [00:00<00:00, 38.52it/s]
100%|██████████| 1/1 [00:00<00:00, 37.67it/s]
100%|██████████| 1/1 [00:00<00:00,

Fold 5 Test Accuracy: 1.0000, Precision: 1.0000, Recall: 1.0000, F1: 1.0000

5-Fold CV Results:
Mean Accuracy  = 0.7929 ± 0.1310
Mean Precision = 0.8333 ± 0.1211
Mean Recall    = 0.7933 ± 0.1370
Mean F1-Score  = 0.7784 ± 0.1419



