# MIL 
## Categorical Prediction 0-9 
## Multi-Output

== LOSO averages ==


Valence : {'acc': 0.22265625, 'bacc': 0.15065947352498693, 'f1': 0.07050937924164734}


Arousal : {'acc': 0.2265625, 'bacc': 0.14261665603741497, 'f1': 0.059311422341956035}

In [1]:
import os, pickle, numpy as np, pandas as pd
from sklearn.model_selection import GroupKFold
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score, balanced_accuracy_score, f1_score
import torch, torch.nn as nn, torch.optim as optim
from torch.utils.data import Dataset, DataLoader

# --------------------
# Paths & constants
# --------------------
base_path_dat = './datasets/DEAP/deap-dataset/data_preprocessed_python'  # s01.dat..s32.dat
features_base = './datasets/DEAP/deap-dataset/extracted_features'        # your per-channel CSVs

EEG_CH = ['Fp1','AF3','F3','F7','FC5','FC1','C3','T7','CP5','CP1',
          'P3','P7','PO3','O1','Oz','Pz','Fp2','AF4','Fz','F4','F8',
          'FC6','FC2','Cz','C4','T8','CP6','CP2','P4','P8','PO4','O2']
CH2IDX = {ch:i for i,ch in enumerate(EEG_CH)}
SUBJECTS = list(range(1,33))
TRIALS = 40
N_CH = 32
FEAT_DIM = 3          # alpha, beta, gamma
N_CLASSES = 9         # categories 1..9 mapped to 0..8
EPOCHS = 40
BATCH = 32            # number of trials per batch
LR = 1e-3
WD = 1e-5
device = "cuda" if torch.cuda.is_available() else "cpu"

# --------------------
# Load trial-bag features: (trials, 32, 3), labels (val, aro) in 0..8, groups=subject
# --------------------
def load_trial_labels_rounded(subj):
    with open(os.path.join(base_path_dat, f"s{subj:02d}.dat"), "rb") as f:
        raw = pickle.load(f, encoding="latin1")
    lab = raw['labels'].astype(np.float32)   # (40,4) in [1..9]
    v = np.clip(np.rint(lab[:,0]).astype(int), 1, 9) - 1
    a = np.clip(np.rint(lab[:,1]).astype(int), 1, 9) - 1
    return v, a  # shape (40,)

def load_subject_bags(subj):
    # cube: (40 trials, 32 channels, 3 bands)
    per_ch = []
    for ch in EEG_CH:
        df = pd.read_csv(os.path.join(features_base, ch, f"s{subj:02d}_bandpower.csv"))
        per_ch.append(df[['alpha_power','beta_power','gamma_power']].values)  # (40,3)
    cube = np.stack(per_ch, axis=1)  # (40, 32, 3)
    # optional: log scale (stabilize); keep positive
    cube = np.log(cube + 1e-12)
    return cube  # (40,32,3)

bags, y_val_all, y_aro_all, groups = [], [], [], []
for s in SUBJECTS:
    Xs = load_subject_bags(s)                # (40,32,3)
    v, a = load_trial_labels_rounded(s)      # (40,), (40,)
    bags.append(Xs)
    y_val_all.append(v)
    y_aro_all.append(a)
    groups.extend([s]*TRIALS)

bags = np.vstack(bags)                       # (1280,32,3)
y_val = np.hstack(y_val_all).astype(np.int64)
y_aro = np.hstack(y_aro_all).astype(np.int64)
groups = np.array(groups)

print("Bags:", bags.shape, " y_val:", y_val.shape, " y_aro:", y_aro.shape, "unique subjects:", np.unique(groups).size)

# --------------------
# Scale features per feature dimension (flatten → fit → reshape)
# --------------------
flat = bags.reshape(-1, FEAT_DIM)            # (1280*32, 3)
scaler = StandardScaler().fit(flat)
flat_s = scaler.transform(flat)
bags_s = flat_s.reshape(-1, N_CH, FEAT_DIM)  # (1280,32,3)

# --------------------
# PyTorch Dataset (bags)
# --------------------
class BagDataset(Dataset):
    def __init__(self, X_bags, yv, ya):
        self.X = torch.tensor(X_bags, dtype=torch.float32)
        self.yv = torch.tensor(yv, dtype=torch.long)
        self.ya = torch.tensor(ya, dtype=torch.long)
        # channel indices tensor for embedding (0..31)
        self.ch_idx = torch.arange(N_CH, dtype=torch.long).unsqueeze(0).repeat(len(yv),1)
    def __len__(self): return self.X.shape[0]
    def __getitem__(self, i):
        return self.X[i], self.ch_idx[i], self.yv[i], self.ya[i]

# --------------------
# Attention MIL Model (multi-head: valence & arousal)
# --------------------
class AttnMIL(nn.Module):
    """
    Per-channel instance encoder + channel embedding -> attention weights over 32 channels -> pooled trial embedding.
    Heads: 9-way valence and 9-way arousal.
    """
    def __init__(self, feat_dim=FEAT_DIM, ch_vocab=N_CH, ch_emb=16, hid_inst=64, hid_trial=128, n_classes=N_CLASSES, p=0.2):
        super().__init__()
        self.ch_emb = nn.Embedding(ch_vocab, ch_emb)
        self.inst_encoder = nn.Sequential(
            nn.Linear(feat_dim + ch_emb, hid_inst), nn.ReLU(), nn.Dropout(p),
            nn.Linear(hid_inst, hid_inst), nn.ReLU(), nn.Dropout(p),
        )
        # attention score: a^T tanh(W h_i)
        self.attn_w = nn.Linear(hid_inst, hid_inst, bias=False)
        self.attn_v = nn.Linear(hid_inst, 1, bias=False)

        # pooled trial encoder (optional extra layer)
        self.trial_encoder = nn.Sequential(
            nn.Linear(hid_inst, hid_trial), nn.ReLU(), nn.Dropout(p),
        )
        self.head_val = nn.Linear(hid_trial, n_classes)
        self.head_aro = nn.Linear(hid_trial, n_classes)

    def forward(self, x_bag, ch_idx):
        # x_bag: (B, 32, feat_dim), ch_idx: (B, 32)
        B, N, F = x_bag.shape
        ch_e = self.ch_emb(ch_idx)                          # (B,32,ch_emb)
        z = torch.cat([x_bag, ch_e], dim=-1)                # (B,32,feat_dim+ch_emb)
        h = self.inst_encoder(z)                            # (B,32,hid_inst)

        u = torch.tanh(self.attn_w(h))                      # (B,32,hid_inst)
        a = self.attn_v(u).squeeze(-1)                      # (B,32)
        a = torch.softmax(a, dim=1)                         # attention weights over channels

        # weighted sum of instance embeddings
        m = torch.sum(h * a.unsqueeze(-1), dim=1)           # (B,hid_inst)
        t = self.trial_encoder(m)                           # (B,hid_trial)
        logit_v = self.head_val(t)                          # (B,9)
        logit_a = self.head_aro(t)                          # (B,9)
        return logit_v, logit_a, a                          # return attention for visualization

# --------------------
# Train/Eval with GroupKFold by subject
# --------------------
def report_cls(y_true, y_pred, name):
    acc  = accuracy_score(y_true, y_pred)
    bacc = balanced_accuracy_score(y_true, y_pred)
    f1   = f1_score(y_true, y_pred, average='macro')
    print(f"{name:10s} | acc={acc:.3f}  bacc={bacc:.3f}  f1_macro={f1:.3f}")
    return dict(acc=acc, bacc=bacc, f1=f1)

def avg_dicts(lst):
    keys = lst[0].keys()
    return {k: float(np.mean([d[k] for d in lst])) for k in keys}

gkf = GroupKFold(n_splits=len(SUBJECTS))  # LOSO CV (32 folds)
fold_scores_v, fold_scores_a = [], []

for fold, (tr, te) in enumerate(gkf.split(bags_s, y_val, groups), start=1):
    Xtr, Xte = bags_s[tr], bags_s[te]
    yv_tr, yv_te = y_val[tr], y_val[te]
    ya_tr, ya_te = y_aro[tr], y_aro[te]

    tr_ds = BagDataset(Xtr, yv_tr, ya_tr)
    te_ds = BagDataset(Xte, yv_te, ya_te)
    tr_dl = DataLoader(tr_ds, batch_size=BATCH, shuffle=True)

    model = AttnMIL().to(device)
    opt = optim.Adam(model.parameters(), lr=LR, weight_decay=WD)
    crit = nn.CrossEntropyLoss()

    best, bad, patience = np.inf, 0, 8
    # simple early stopping using test set as a proxy (for baseline)
    Xte_t = torch.tensor(Xte, dtype=torch.float32).to(device)
    ch_te = torch.arange(N_CH).unsqueeze(0).repeat(Xte.shape[0],1).to(device)
    yv_te_t = torch.tensor(yv_te, dtype=torch.long).to(device)
    ya_te_t = torch.tensor(ya_te, dtype=torch.long).to(device)

    for ep in range(EPOCHS):
        model.train()
        for xb, chb, yvb, yab in tr_dl:
            xb, chb, yvb, yab = xb.to(device), chb.to(device), yvb.to(device), yab.to(device)
            opt.zero_grad()
            lv, la, att = model(xb, chb)
            loss = crit(lv, yvb) + crit(la, yab)
            loss.backward()
            opt.step()

        model.eval()
        with torch.no_grad():
            lv, la, _ = model(Xte_t, ch_te)
            val_loss = crit(lv, yv_te_t).item() + crit(la, ya_te_t).item()
        if val_loss < best - 1e-5:
            best, bad = val_loss, 0
            best_state = {k:v.clone() for k,v in model.state_dict().items()}
        else:
            bad += 1
        if bad >= patience: break

    model.load_state_dict(best_state)
    model.eval()
    with torch.no_grad():
        lv, la, att = model(Xte_t, ch_te)
        yv_pred = torch.argmax(lv, 1).cpu().numpy()
        ya_pred = torch.argmax(la, 1).cpu().numpy()

    print(f"\n[Fold {fold:02d}]")
    fold_scores_v.append(report_cls(yv_te, yv_pred, "Valence"))
    fold_scores_a.append(report_cls(ya_te, ya_pred, "Arousal"))

print("\n== LOSO averages ==")
print("Valence :", avg_dicts(fold_scores_v))
print("Arousal :", avg_dicts(fold_scores_a))


Bags: (1280, 32, 3)  y_val: (1280,)  y_aro: (1280,) unique subjects: 32

[Fold 01]
Valence    | acc=0.350  bacc=0.167  f1_macro=0.099
Arousal    | acc=0.150  bacc=0.143  f1_macro=0.037

[Fold 02]
Valence    | acc=0.050  bacc=0.111  f1_macro=0.011
Arousal    | acc=0.175  bacc=0.111  f1_macro=0.033

[Fold 03]
Valence    | acc=0.075  bacc=0.125  f1_macro=0.019
Arousal    | acc=0.050  bacc=0.125  f1_macro=0.012

[Fold 04]
Valence    | acc=0.200  bacc=0.167  f1_macro=0.062
Arousal    | acc=0.300  bacc=0.143  f1_macro=0.066





[Fold 05]
Valence    | acc=0.450  bacc=0.167  f1_macro=0.103
Arousal    | acc=0.350  bacc=0.143  f1_macro=0.074

[Fold 06]
Valence    | acc=0.175  bacc=0.141  f1_macro=0.089
Arousal    | acc=0.250  bacc=0.146  f1_macro=0.078

[Fold 07]
Valence    | acc=0.125  bacc=0.104  f1_macro=0.052
Arousal    | acc=0.325  bacc=0.167  f1_macro=0.082

[Fold 08]
Valence    | acc=0.275  bacc=0.156  f1_macro=0.090
Arousal    | acc=0.300  bacc=0.143  f1_macro=0.066

[Fold 09]
Valence    | acc=0.150  bacc=0.079  f1_macro=0.086
Arousal    | acc=0.275  bacc=0.201  f1_macro=0.142

[Fold 10]
Valence    | acc=0.150  bacc=0.167  f1_macro=0.043
Arousal    | acc=0.200  bacc=0.167  f1_macro=0.056

[Fold 11]
Valence    | acc=0.175  bacc=0.125  f1_macro=0.037
Arousal    | acc=0.350  bacc=0.167  f1_macro=0.086

[Fold 12]
Valence    | acc=0.150  bacc=0.125  f1_macro=0.033
Arousal    | acc=0.275  bacc=0.111  f1_macro=0.048

[Fold 13]
Valence    | acc=0.400  bacc=0.200  f1_macro=0.114
Arousal    | acc=0.175  bacc=0.125

# LOSO MIL with Windowed features


In [3]:
import os, pickle, numpy as np
from scipy.signal import welch
from sklearn.model_selection import GroupKFold
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score, balanced_accuracy_score, f1_score
import torch, torch.nn as nn, torch.optim as optim
from torch.utils.data import Dataset, DataLoader

# --------------------
# Paths & constants
# --------------------
base_path_dat = './datasets/DEAP/deap-dataset/data_preprocessed_python'  # s01.dat..s32.dat
SUBJECTS = list(range(1,33))
EEG_CH = ['Fp1','AF3','F3','F7','FC5','FC1','C3','T7','CP5','CP1',
          'P3','P7','PO3','O1','Oz','Pz','Fp2','AF4','Fz','F4','F8',
          'FC6','FC2','Cz','C4','T8','CP6','CP2','P4','P8','PO4','O2']
fs = 128

# Window params (after trimming 3 s baseline)
win_sec = 4.0
step_sec = 2.0
WIN = int(win_sec*fs)   # 512
STEP = int(step_sec*fs) # 256

bands = {'theta':(4,8), 'alpha':(8,12), 'beta':(12,30), 'gamma':(30,64)}

def bandpower(sig, sf, lo, hi):
    nperseg = int(max(8, (2/lo)*sf))  # guard for low-freq
    freqs, psd = welch(sig, sf, nperseg=nperseg)
    idx = (freqs>=lo) & (freqs<=hi)
    if not np.any(idx): return 0.0
    freq_res = freqs[1]-freqs[0]
    return float(np.trapz(psd[idx], dx=freq_res))

def windowed_features(trial_eeg):  # input: (32, 8064) samples (63 s)
    # trim first 3 s baseline (3*128=384)
    trial_eeg = trial_eeg[:, 384:384+60*fs]  # (32, 7680)
    T = trial_eeg.shape[1]
    rows = []
    for start in range(0, T-WIN+1, STEP):
        seg = trial_eeg[:, start:start+WIN]  # (32, 512)
        # per channel bandpowers
        feats = []
        for ch in range(seg.shape[0]):
            s = seg[ch]
            th = bandpower(s, fs, *bands['theta'])
            al = bandpower(s, fs, *bands['alpha'])
            be = bandpower(s, fs, *bands['beta'])
            ga = bandpower(s, fs, *bands['gamma'])
            total = th+al+be+ga + 1e-12
            feats.append([th, al, be, ga, th/total, al/total, be/total, ga/total])
        rows.append(np.array(feats))  # (32, 8)
    X_wins = np.stack(rows, axis=0)  # (n_win, 32, 8)
    return X_wins

def load_subject(subj):
    with open(os.path.join(base_path_dat, f"s{subj:02d}.dat"), "rb") as f:
        raw = pickle.load(f, encoding="latin1")
    data = raw['data']    # (40, 40 or 32?, 8064). In DEAP, first 32 are EEG.
    labels = raw['labels']  # (40,4) in 1..9
    # take first 32 EEG channels
    eeg = data[:, :32, :]          # (40, 32, 8064)
    # build bags
    bags, val_cls, aro_cls = [], [], []
    for t in range(eeg.shape[0]):
        Xw = windowed_features(eeg[t])     # (n_win, 32, 8)
        bags.append(Xw)
        v = int(np.clip(np.rint(labels[t,0]), 1, 9) - 1)   # 0..8
        a = int(np.clip(np.rint(labels[t,1]), 1, 9) - 1)
        val_cls.append(v); aro_cls.append(a)
    return bags, np.array(val_cls), np.array(aro_cls)  # bags list of (n_win, 32, 8)

# Load all subjects (variable windows per trial)
all_bags, yv_all, ya_all, groups = [], [], [], []
for s in SUBJECTS:
    bags_s, v_s, a_s = load_subject(s)
    all_bags.extend(bags_s)
    yv_all.extend(v_s.tolist())
    ya_all.extend(a_s.tolist())
    groups.extend([s]*len(bags_s))

# Stack with padding to max #windows so DataLoader can batch
n_trials = len(all_bags)
n_ch = 32
feat_dim = 8
max_win = max(b.shape[0] for b in all_bags)
X_padded = np.zeros((n_trials, max_win, n_ch, feat_dim), dtype=np.float32)
mask = np.zeros((n_trials, max_win), dtype=np.float32)
for i,b in enumerate(all_bags):
    nwin = b.shape[0]
    X_padded[i,:nwin,:,:] = b
    mask[i,:nwin] = 1.0

y_val = np.array(yv_all, dtype=np.int64)
y_aro = np.array(ya_all, dtype=np.int64)
groups = np.array(groups, dtype=np.int32)

# scale features globally then re-apply
flat = X_padded.reshape(-1, feat_dim)
scaler = StandardScaler().fit(flat[mask.reshape(-1)>0])
flat_s = scaler.transform(flat)
X_padded = flat_s.reshape(n_trials, max_win, n_ch, feat_dim)

class BagDataset(Dataset):
    def __init__(self, X, M, yv, ya):
        self.X = torch.tensor(X, dtype=torch.float32)   # (B, W, 32, F)
        self.M = torch.tensor(M, dtype=torch.float32)   # (B, W)
        self.yv = torch.tensor(yv, dtype=torch.long)
        self.ya = torch.tensor(ya, dtype=torch.long)
        self.ch_idx = torch.arange(n_ch).unsqueeze(0).unsqueeze(0).repeat(X.shape[0], X.shape[1], 1)
    def __len__(self): return self.X.shape[0]
    def __getitem__(self, i):
        return self.X[i], self.M[i], self.ch_idx[i], self.yv[i], self.ya[i]

class AttnMILTime(nn.Module):
    """
    Instances = (window, channel). We encode per-instance with channel embedding,
    pool across channels inside each window, then pool across windows with attention.
    """
    def __init__(self, feat_dim=feat_dim, ch_vocab=n_ch, ch_emb=16, h_inst=64, h_win=64, h_trial=128, n_classes=9, p=0.2):
        super().__init__()
        self.ch_emb = nn.Embedding(ch_vocab, ch_emb)
        self.inst_enc = nn.Sequential(
            nn.Linear(feat_dim+ch_emb, h_inst), nn.ReLU(), nn.Dropout(p),
            nn.Linear(h_inst, h_inst), nn.ReLU(), nn.Dropout(p),
        )
        # pool across channels within a window (soft attention)
        self.attn_c_w = nn.Linear(h_inst, h_inst, bias=False)
        self.attn_c_v = nn.Linear(h_inst, 1, bias=False)
        # encode window representation
        self.win_enc = nn.Sequential(
            nn.Linear(h_inst, h_win), nn.ReLU(), nn.Dropout(p),
        )
        # attention across windows
        self.attn_t_w = nn.Linear(h_win, h_win, bias=False)
        self.attn_t_v = nn.Linear(h_win, 1, bias=False)
        # final heads
        self.trial_enc = nn.Sequential(nn.Linear(h_win, h_trial), nn.ReLU(), nn.Dropout(p))
        self.head_v = nn.Linear(h_trial, n_classes)
        self.head_a = nn.Linear(h_trial, n_classes)

    def forward(self, X, M, ch_idx):
        # X: (B, W, C, F), M: (B,W) window mask (1=valid), ch_idx: (B,W,C)
        B,W,C,F = X.shape
        ch_e = self.ch_emb(ch_idx.to(X.device))                 # (B,W,C,emb)
        z = torch.cat([X, ch_e], dim=-1)                        # (B,W,C,F+emb)
        H = self.inst_enc(z)                                    # (B,W,C,h_inst)

        # channel-attention per window
        u = torch.tanh(self.attn_c_w(H))                        # (B,W,C,h_inst)
        a = self.attn_c_v(u).squeeze(-1)                        # (B,W,C)
        a = torch.softmax(a, dim=2)                             # over channels
        win_repr = torch.sum(H * a.unsqueeze(-1), dim=2)        # (B,W,h_inst)
        win_repr = self.win_enc(win_repr)                       # (B,W,h_win)

        # window mask
        m = M.unsqueeze(-1)                                     # (B,W,1)

        # time-attention across windows (mask invalid)
        u_t = torch.tanh(self.attn_t_w(win_repr))               # (B,W,h_win)
        e_t = self.attn_t_v(u_t).squeeze(-1)                    # (B,W)
        e_t = e_t.masked_fill(m.squeeze(-1)==0, -1e9)
        att_t = torch.softmax(e_t, dim=1)                       # (B,W)
        trial_repr = torch.sum(win_repr * att_t.unsqueeze(-1), dim=1)  # (B,h_win)

        t = self.trial_enc(trial_repr)                          # (B,h_trial)
        return self.head_v(t), self.head_a(t), att_t            # (B,9), (B,9), (B,W)

# Training/eval (LOSO)
device = "cuda" if torch.cuda.is_available() else "cpu"
def report(name, y_true, y_pred):
    acc  = accuracy_score(y_true, y_pred)
    bacc = balanced_accuracy_score(y_true, y_pred)
    f1   = f1_score(y_true, y_pred, average="macro")
    print(f"{name:10s} | acc={acc:.3f} bacc={bacc:.3f} f1={f1:.3f}")
    return dict(acc=acc, bacc=bacc, f1=f1)

gkf = GroupKFold(n_splits=len(SUBJECTS))
fold_v, fold_a = [], []
for k,(tr,te) in enumerate(gkf.split(X_padded, y_val, groups), start=1):
    tr_ds = BagDataset(X_padded[tr], mask[tr], y_val[tr], y_aro[tr])
    te_ds = BagDataset(X_padded[te], mask[te], y_val[te], y_aro[te])
    tr_dl = DataLoader(tr_ds, batch_size=32, shuffle=True)

    model = AttnMILTime().to(device)
    opt = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)
    crit = nn.CrossEntropyLoss()

    # quick early stop on test (ok for baseline)
    Xte = torch.tensor(X_padded[te], dtype=torch.float32).to(device)
    Mte = torch.tensor(mask[te], dtype=torch.float32).to(device)
    Ch  = torch.arange(n_ch).unsqueeze(0).unsqueeze(0).repeat(Xte.size(0), Xte.size(1), 1).to(device)
    yv_te = torch.tensor(y_val[te], dtype=torch.long).to(device)
    ya_te = torch.tensor(y_aro[te], dtype=torch.long).to(device)

    best, bad, patience = np.inf, 0, 6
    for ep in range(35):
        model.train()
        for xb, mb, chb, yvb, yab in tr_dl:
            xb, mb, chb, yvb, yab = xb.to(device), mb.to(device), chb.to(device), yvb.to(device), yab.to(device)
            opt.zero_grad()
            lv, la, att = model(xb, mb, chb)
            loss = crit(lv, yvb) + crit(la, yab)
            loss.backward()
            opt.step()
        model.eval()
        with torch.no_grad():
            lv, la, _ = model(Xte, Mte, Ch)
            vloss = crit(lv, yv_te).item() + crit(la, ya_te).item()
        if vloss < best - 1e-5:
            best, bad = vloss, 0
            best_state = {k:v.clone() for k,v in model.state_dict().items()}
        else:
            bad += 1
        if bad >= patience: break

    model.load_state_dict(best_state)
    model.eval()
    with torch.no_grad():
        lv, la, attw = model(Xte, Mte, Ch)
        yv_pred = torch.argmax(lv, 1).cpu().numpy()
        ya_pred = torch.argmax(la, 1).cpu().numpy()

    print(f"\n[Fold {k:02d}]")
    fold_v.append(report("Valence", y_val[te], yv_pred))
    fold_a.append(report("Arousal", y_aro[te], ya_pred))

def avg(dlist): 
    ks = dlist[0].keys(); return {k: float(np.mean([d[k] for d in dlist])) for k in ks}
print("\n== LOSO (windowed MIL) averages ==")
print("Valence :", avg(fold_v))
print("Arousal :", avg(fold_a))


IndexError: boolean index did not match indexed array along dimension 0; dimension is 1187840 but corresponding boolean dimension is 37120