# 04 — Deep Learning Dataset Preparation

This notebook prepares multimodal windowed physiological signals,
together with per-window behavioral features, for deep learning
models (CNN + Transformer).

Inputs:
- Windowed signals from Step-2 (.npz)
- Labels aligned per window

Outputs:
- PyTorch Datasets (train / test)
- PyTorch DataLoaders (train / test)

No model training is performed in this step.


In [31]:
# ============================================================
# STEP 4 — Deep Learning Dataset Preparation (PATCHED)
# Physiological + Behavioral
# Compatible with Step-5 & Step-6
# ============================================================

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split

# -----------------------------
# CONFIG
# -----------------------------
DATA_PATH = "features_S2_w60.npz"   # Step-3 output
BATCH_SIZE = 16      # windows per mini-batch in both DataLoaders
TEST_SIZE = 0.2      # fraction of windows held out for evaluation
RANDOM_STATE = 42    # seed that makes the train/test split reproducible

# -----------------------------
# Load engineered data
# -----------------------------
data = np.load(DATA_PATH)

# Windowed physiological channels, read in a fixed order.
EDA, BVP, ACC, TEMP = (
    data[key]
    for key in ("EDA_windows", "BVP_windows", "ACC_windows", "TEMP_windows")
)
BEH = data["behavior_features"]   # behavioral features, (N, F)
labels = data["labels"]           # one class label per window, (N,)

print("Loaded:")
for caption, arr in (("  EDA :", EDA), ("  BEH :", BEH), ("  Y   :", labels)):
    print(caption, arr.shape)

# -----------------------------
# Normalization
# -----------------------------
# Physiological → per-window z-score
def normalize_windowwise(x, *, axis=1, eps=1e-6):
    """Z-score every window independently along its time axis.

    Args:
        x: NumPy array of windows. With the default ``axis=1`` each row
            ``x[i]`` is normalized on its own (e.g. input shape ``(N, T)``).
        axis: Axis that holds the samples of one window. Defaults to 1.
        eps: Added to the standard deviation so a constant window maps to
            zeros instead of triggering a division by zero.

    Returns:
        Array with the same shape as ``x``, centred to zero mean and scaled
        to (approximately, because of ``eps``) unit standard deviation along
        ``axis``.
    """
    # keepdims=True keeps the reduced axis so the statistics broadcast back
    # against ``x`` for any input rank.
    mean = x.mean(axis=axis, keepdims=True)
    std = x.std(axis=axis, keepdims=True) + eps
    return (x - mean) / std

# Per-window z-score for every physiological channel.
EDA, BVP, ACC, TEMP = (
    normalize_windowwise(signal) for signal in (EDA, BVP, ACC, TEMP)
)

# Behavioral → one global z-score per feature column.
# NOTE(review): these column statistics are computed over every window,
# before the train/test split, so test rows influence the normalization.
_beh_center = BEH.mean(axis=0, keepdims=True)
_beh_scale = BEH.std(axis=0, keepdims=True) + 1e-6
BEH = (BEH - _beh_center) / _beh_scale
del _beh_center, _beh_scale

print("✓ Normalization complete")

# -----------------------------
# PyTorch Dataset
# -----------------------------
class WESADDataset(Dataset):
    """Map-style dataset yielding one multimodal window per index.

    Each item is the tuple ``(EDA, BVP, ACC, TEMP, BEH, y)``:

    * the four signal tensors are float32 with a trailing channel axis
      appended (a ``(N, T)`` input yields ``(T, 1)`` items);
    * ``BEH`` is the float32 behavioral feature row of that window;
    * ``y`` is a scalar ``torch.long`` label.

    Args:
        EDA, BVP, ACC, TEMP: Windowed signals whose first axis indexes
            windows.
        BEH: Behavioral features, expected shape ``(N, F)``.
        labels: Integer class labels, shape ``(N,)``.

    Raises:
        ValueError: If the inputs do not all contain ``len(labels)``
            windows.
    """

    def __init__(self, EDA, BVP, ACC, TEMP, BEH, labels):
        # torch.tensor copies its input, so later in-place edits to the
        # source arrays do not alter the dataset.
        self.EDA  = self._to_signal_tensor(EDA)
        self.BVP  = self._to_signal_tensor(BVP)
        self.ACC  = self._to_signal_tensor(ACC)
        self.TEMP = self._to_signal_tensor(TEMP)

        self.BEH  = torch.tensor(BEH,  dtype=torch.float32)  # (N, F); items are (F,)
        self.y    = torch.tensor(labels, dtype=torch.long)   # (N,)

        # Fail fast on misaligned inputs. Otherwise a shorter array raises an
        # IndexError mid-epoch and a longer one has its extra rows silently
        # ignored (__len__ follows the labels).
        n_windows = len(self.y)
        mismatched = {
            name: len(tensor)
            for name, tensor in (
                ("EDA", self.EDA),
                ("BVP", self.BVP),
                ("ACC", self.ACC),
                ("TEMP", self.TEMP),
                ("BEH", self.BEH),
            )
            if len(tensor) != n_windows
        }
        if mismatched:
            raise ValueError(
                f"every input must have len(labels)={n_windows} windows; "
                f"got {mismatched}"
            )

    @staticmethod
    def _to_signal_tensor(x):
        """Return ``x`` as float32 with a trailing size-1 channel axis."""
        return torch.tensor(x, dtype=torch.float32).unsqueeze(-1)

    def __len__(self):
        """Number of windows (equal to the number of labels)."""
        return len(self.y)

    def __getitem__(self, idx):
        """Return ``(EDA, BVP, ACC, TEMP, BEH, y)`` for window ``idx``."""
        return (
            self.EDA[idx],
            self.BVP[idx],
            self.ACC[idx],
            self.TEMP[idx],
            self.BEH[idx],
            self.y[idx]
        )

# -----------------------------
# Train / Test Split
# -----------------------------
indices = np.arange(len(labels))

# Stratify on the label so both splits keep the overall class balance.
# NOTE(review): windows are assigned to splits at random. If consecutive
# windows overlap in time, near-duplicate samples can land on both sides of
# the split and inflate test scores; a contiguous (blocked) split would
# avoid that.
train_idx, test_idx = train_test_split(
    indices,
    test_size=TEST_SIZE,
    stratify=labels,
    random_state=RANDOM_STATE
)

# Behavioral features: fit the z-score statistics on the training windows
# only, so the test split contributes nothing to the normalization (no
# leakage). Z-scoring is unaffected by any earlier positive affine rescaling
# of a column, so this is correct whether BEH arrives raw or already
# standardized over all windows.
beh_mean = BEH[train_idx].mean(axis=0, keepdims=True)
beh_std = BEH[train_idx].std(axis=0, keepdims=True)
# Columns that are (near-)constant on the training split are only centered:
# dividing by a ~0 std would blow up their values on the test split.
beh_std = np.where(beh_std < 1e-6, 1.0, beh_std)
BEH = (BEH - beh_mean) / beh_std

train_dataset = WESADDataset(
    EDA[train_idx],
    BVP[train_idx],
    ACC[train_idx],
    TEMP[train_idx],
    BEH[train_idx],
    labels[train_idx]
)

test_dataset = WESADDataset(
    EDA[test_idx],
    BVP[test_idx],
    ACC[test_idx],
    TEMP[test_idx],
    BEH[test_idx],
    labels[test_idx]
)

# -----------------------------
# DataLoaders
# -----------------------------
# Reshuffle the training windows every epoch; keep evaluation order fixed.
train_loader, test_loader = (
    DataLoader(ds, batch_size=BATCH_SIZE, shuffle=reshuffle)
    for ds, reshuffle in ((train_dataset, True), (test_dataset, False))
)

# -----------------------------
# Sanity check
# -----------------------------
# Pull one training batch and report the shape of every component.
first_batch = next(iter(train_loader))
EDA_b, BVP_b, ACC_b, TEMP_b, BEH_b, y_b = first_batch

batch_captions = (
    "EDA batch :",
    "BVP batch :",
    "ACC batch :",
    "TEMP batch:",
    "BEH batch :",
    "Labels    :",
)
for caption, part in zip(batch_captions, (EDA_b, BVP_b, ACC_b, TEMP_b, BEH_b, y_b)):
    print(caption, part.shape)


Loaded:
  EDA : (281, 60)
  BEH : (281, 5)
  Y   : (281,)
✓ Normalization complete
EDA batch : torch.Size([16, 60, 1])
BVP batch : torch.Size([16, 60, 1])
ACC batch : torch.Size([16, 60, 1])
TEMP batch: torch.Size([16, 60, 1])
BEH batch : torch.Size([16, 5])
Labels    : torch.Size([16])
