# PatchTST — Training & Evaluation Walkthrough + Exercises

This notebook is built **from your `train_patchtst.py` script**. We'll go from:
CSV → Dataset → DataLoader → training loop → evaluation → confusion matrix → threshold sweep → session split.

**Goal:** be able to explain every step in your meeting and in your thesis defense.



In [None]:
# === Setup ===
import os, sys
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import matplotlib.pyplot as plt

PROJECT_ROOT = Path(".").resolve()
SRC_DIR = PROJECT_ROOT / "src"
if SRC_DIR.exists():
    sys.path.insert(0, str(SRC_DIR))
else:
    print("WARNING: src/ not found. Update SRC_DIR accordingly.")

from patchtst import PatchTSTClassifier, PatchTSTConfig
from train_patchtst import ForceCSVFolderDataset, USE_COLS, LABEL_RE, metrics_from_cm, confusion_matrix_binary_from_logits

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device:", device)


## 1) Understanding the dataset object

`ForceCSVFolderDataset` does:
1. Collect all CSV files recursively
2. Parse label from filename (`_True.csv` or `_False.csv`)
3. Read the 6 signal columns listed in `USE_COLS`
4. (Optionally) normalize per-sample per-channel
5. Return `(x, y)` where `x: [L, C]` and `y: scalar`
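Step 4 can be pictured as a z-score over the time axis, computed separately for each channel of each sample. The sketch below is a hypothetical stand-in, not the code in `train_patchtst.py`. The function name and the `eps` guard against flat channels are assumptions.

```python
import numpy as np

def normalize_per_sample_per_channel(x: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    """Z-score each channel of one sample independently (hypothetical sketch).

    x: array of shape [L, C] (time steps x channels).
    Statistics are taken over the time axis (axis 0), so each channel
    ends up with mean ~0 and std ~1 within this sample.
    """
    mean = x.mean(axis=0, keepdims=True)  # [1, C]
    std = x.std(axis=0, keepdims=True)    # [1, C]
    return (x - mean) / (std + eps)       # eps avoids division by zero on flat channels

# toy sample: 1000 time steps, 6 channels on very different scales
rng = np.random.default_rng(0)
x = rng.normal(loc=[0, 5, -3, 100, 0.1, 20], scale=[1, 2, 0.5, 30, 0.01, 4], size=(1000, 6))
x_norm = normalize_per_sample_per_channel(x)
print(x_norm.mean(axis=0).round(4))
print(x_norm.std(axis=0).round(4))
```

Because the statistics are per sample, absolute offsets and scales differing between samples are removed. Whether that is desirable depends on whether those offsets carry label information.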

### Exercise 1.1
Inspect the regex and test it on a few filenames.



In [None]:
# TODO: Exercise 1.1
test_names = [
    "FERA2025_08_28_14_46_02_True.csv",
    "FERA2025_08_28_16_09_46_False.csv",
    "something_else.csv",
]

for name in test_names:
    m = LABEL_RE.search(name)
    print(name, "->", m.group(1) if m else None)
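The real pattern is `LABEL_RE` in `train_patchtst.py`. For comparison, here is a hypothetical pattern that produces the behavior this exercise expects: a capture group holding `True`/`False`, and no match otherwise. `LABEL_RE_SKETCH` and `label_from_name_sketch` are illustrative names; check them against the actual `LABEL_RE`.

```python
import re

# Hypothetical pattern, not necessarily identical to LABEL_RE in train_patchtst.py.
# Captures "True" or "False" immediately before the ".csv" extension.
LABEL_RE_SKETCH = re.compile(r"_(True|False)\.csv$")

def label_from_name_sketch(name: str):
    """Return 1 for *_True.csv, 0 for *_False.csv, None if the name doesn't match."""
    m = LABEL_RE_SKETCH.search(name)
    if m is None:
        return None
    return 1 if m.group(1) == "True" else 0

for name in ["FERA2025_08_28_14_46_02_True.csv",
             "FERA2025_08_28_16_09_46_False.csv",
             "something_else.csv"]:
    print(name, "->", label_from_name_sketch(name))
```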


## 2) Load real data (if available) + plot

### Exercise 2.1
Set `DATA_ROOT` to your dataset folder (project-relative) and load the dataset.
Then:
- print number of samples
- print first file path
- load first sample and confirm shapes

If you run this on a machine without the dataset, continue anyway: the cells below check `ds is not None` and skip themselves.



In [None]:
# TODO: Exercise 2.1
DATA_ROOT = "ForceDataNovo/Old_Fixture"  # change if needed

if Path(DATA_ROOT).exists():
    ds = ForceCSVFolderDataset(DATA_ROOT, seq_len=1000, normalize=True)
    print("n_samples:", len(ds))
    print("first file:", ds.files[0])

    x, y = ds[0]
    print("x.shape:", tuple(x.shape), "y:", float(y))
else:
    print("Dataset not found at:", DATA_ROOT)
    ds = None


### Exercise 2.2 (plot one sample)

Plot the 6 channels vs time for one sample.
- x is `[L, C]`
- time index is `0..999` (1000 samples at 500 Hz, i.e. 2 seconds)



In [None]:
# TODO: Exercise 2.2
if ds is not None:
    x, y = ds[0]
    x_np = x.numpy()  # [L, C]

    plt.figure()
    for i, col in enumerate(USE_COLS):
        plt.plot(x_np[:, i], label=col)
    plt.title(f"One sample | label={int(y.item())}")
    plt.xlabel("sample index (500 Hz, 1000 samples = 2 s)")
    plt.ylabel("normalized value")
    plt.legend()
    plt.show()


## 3) Session split vs random split (leakage control)

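Your folder names look like recording sessions (dates). Samples from the same session share conditions such as setup and sensor state, so a random split can place closely related samples in both train and val and inflate the validation score. A safer evaluation holds out entire sessions.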

### Exercise 3.1
Count files per group, where a group is the parent folder name `p.parent.name`.
Print:
- number of groups
- top groups by count



In [None]:
# TODO: Exercise 3.1
from collections import Counter

if ds is not None:
    groups = [p.parent.name for p in ds.files]
    print("unique groups:", len(set(groups)))
    print("top groups:", Counter(groups).most_common(10))


### Exercise 3.2
Create a session split:
- choose `val_groups` (e.g. last 2 dates)
- build `train_idx` and `val_idx`
- compute and print train/val sizes and class balance



In [None]:
# TODO: Exercise 3.2
if ds is not None:
    val_groups = {"2025_09_08", "2025_09_09"}  # change if needed

    groups = [p.parent.name for p in ds.files]
    train_idx = [i for i, g in enumerate(groups) if g not in val_groups]
    val_idx   = [i for i, g in enumerate(groups) if g in val_groups]

    train_ds = torch.utils.data.Subset(ds, train_idx)
    val_ds   = torch.utils.data.Subset(ds, val_idx)

    def labels_for_indices(indices):
        return [ForceCSVFolderDataset.label_from_name(ds.files[i].name) for i in indices]

    train_labels = labels_for_indices(train_idx)
    val_labels   = labels_for_indices(val_idx)

    print("train size:", len(train_ds), "val size:", len(val_ds))
    print("train pos:", sum(train_labels), "neg:", len(train_labels)-sum(train_labels))
    print("val   pos:", sum(val_labels), "neg:", len(val_labels)-sum(val_labels))
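Instead of typing dates by hand, the "last 2 dates" rule can be derived from the folder names. `session_split` is a hypothetical helper, not part of `train_patchtst.py`, and it assumes folder names are zero-padded `YYYY_MM_DD` strings so that lexicographic order equals chronological order.

```python
def session_split(groups, n_val_sessions=2):
    """Split sample indices so that whole sessions go to validation.

    groups: list of session names, one per sample (e.g. p.parent.name).
    Returns (train_idx, val_idx, val_groups).
    Assumes names sort chronologically (zero-padded YYYY_MM_DD).
    """
    sessions = sorted(set(groups))
    if not 0 < n_val_sessions < len(sessions):
        raise ValueError("need at least one training and one validation session")
    val_groups = set(sessions[-n_val_sessions:])
    train_idx = [i for i, g in enumerate(groups) if g not in val_groups]
    val_idx = [i for i, g in enumerate(groups) if g in val_groups]

    # leakage check: no session may appear on both sides
    assert {groups[i] for i in train_idx}.isdisjoint(val_groups)
    return train_idx, val_idx, val_groups

# toy session names, one per sample
toy_groups = ["2025_08_28", "2025_08_28", "2025_09_01", "2025_09_08", "2025_09_09", "2025_09_09"]
toy_train_idx, toy_val_idx, toy_val_groups = session_split(toy_groups, n_val_sessions=2)
print(sorted(toy_val_groups), toy_train_idx, toy_val_idx)
```

In the notebook you would call it as `session_split([p.parent.name for p in ds.files])`.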


## 4) DataLoader and batch shapes

### Exercise 4.1
Create loaders and inspect one batch:
- expected batch x shape is `[B, L, C]`
- expected batch y shape is `[B]`
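The extra `B` dimension comes from the DataLoader's default collation, which stacks per-sample tensors along a new first axis. The toy `TensorDataset` below stands in for `ForceCSVFolderDataset` (an assumption for illustration only) and uses the same `[L, C] = [1000, 6]` sample shape.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# 10 fake samples, each [L=1000, C=6], with scalar float labels
x_all = torch.randn(10, 1000, 6)
y_all = torch.randint(0, 2, (10,)).float()

loader = DataLoader(TensorDataset(x_all, y_all), batch_size=4, shuffle=False)
shapes = [(tuple(xb.shape), tuple(yb.shape)) for xb, yb in loader]
print(shapes)
```

The last batch has only 2 samples because 10 is not divisible by 4; pass `drop_last=True` to the DataLoader if every batch must have the same `B`.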



In [None]:
# TODO: Exercise 4.1
if ds is not None:
    train_loader = torch.utils.data.DataLoader(train_ds, batch_size=64, shuffle=True, num_workers=2, pin_memory=True)
    xb, yb = next(iter(train_loader))
    print("xb.shape:", tuple(xb.shape))
    print("yb.shape:", tuple(yb.shape))


## 5) Training step: forward → loss → backward → update

Your script uses:
- `BCEWithLogitsLoss` (binary classification, logits)
- `AdamW`
- optional AMP (autocast + GradScaler)
- optional grad clipping
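`BCEWithLogitsLoss` applies the sigmoid inside the loss. Mathematically this equals `sigmoid` followed by `BCELoss`, but PyTorch computes the fused version in a numerically stable way. The logits below are made up: at moderate values the two agree, while for a confidently wrong logit the two-step version saturates because `sigmoid(40.0)` rounds to exactly `1.0` in float32 and `BCELoss` clamps `log(0)` to `-100`.

```python
import torch
import torch.nn as nn

logits = torch.tensor([-2.0, 0.0, 3.0])
targets = torch.tensor([0.0, 1.0, 1.0])

fused = nn.BCEWithLogitsLoss()(logits, targets)
two_step = nn.BCELoss()(torch.sigmoid(logits), targets)
print(fused.item(), two_step.item())  # equal up to float rounding

# extreme case: a confidently wrong prediction (logit 40, true label 0)
big, neg = torch.tensor([40.0]), torch.tensor([0.0])
fused_big = nn.BCEWithLogitsLoss()(big, neg).item()              # ~40, the exact loss
two_step_big = nn.BCELoss()(torch.sigmoid(big), neg).item()      # 100.0 after clamping
print(fused_big, two_step_big)
```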

### Exercise 5.1
Run *one* training step manually and print:
- logits stats (mean/std)
- loss
- grad norm before/after clipping (optional)



In [None]:
# TODO: Exercise 5.1
if ds is not None:
    model_cfg = PatchTSTConfig(
        num_classes=1,
        patch_len=25,
        stride=25,
        d_model=128,
        n_heads=8,
        n_layers=2,
        d_ff=256,
        dropout=0.1,
        channel_independent=True,
        pooling="mean",
        fuse="mean",
    )
    model = PatchTSTClassifier(model_cfg).to(device)
    opt = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)
    crit = nn.BCEWithLogitsLoss()

    xb, yb = next(iter(train_loader))
    xb = xb.to(device)
    yb = yb.to(device)

    model.train()
    opt.zero_grad()

    logits = model(xb).squeeze(-1)
    loss = crit(logits, yb)
    loss.backward()

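    # gradient L2 norm over all parameters (the same quantity clip_grad_norm_ measures)
    def grad_l2_norm(params):
        total = 0.0
        for p in params:
            if p.grad is not None:
                total += p.grad.detach().norm(2).item() ** 2
        return total ** 0.5

    norm_before = grad_l2_norm(model.parameters())
    # optional clipping; max_norm=1.0 is an illustrative value, match it to your script
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    norm_after = grad_l2_norm(model.parameters())

    opt.step()

    print("logits mean/std:", logits.mean().item(), logits.std().item())
    print("loss:", loss.item())
    print("grad L2 norm before/after clipping:", norm_before, norm_after)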
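Section 5 lists AMP and gradient clipping as optional. The sketch below shows one common way to combine them in PyTorch; it is not copied from `train_patchtst.py`, and `train_step`, `max_norm=1.0` and the tiny stand-in model are assumptions. The key detail is calling `scaler.unscale_(opt)` before clipping, so the threshold applies to the true gradients rather than the loss-scaled ones. On a CPU-only machine it runs with AMP disabled, and all names are prefixed `toy_` so the notebook's `model`, `opt` and `xb` are not overwritten.

```python
import torch
import torch.nn as nn

def train_step(model, xb, yb, opt, crit, scaler, use_amp, max_norm=1.0):
    """One optimization step with optional AMP and gradient clipping (sketch)."""
    model.train()
    opt.zero_grad(set_to_none=True)
    device_type = "cuda" if xb.is_cuda else "cpu"
    # autocast runs eligible ops in reduced precision when enabled
    with torch.autocast(device_type=device_type, enabled=use_amp):
        logits = model(xb).squeeze(-1)
        loss = crit(logits, yb)
    scaler.scale(loss).backward()   # scale the loss to avoid fp16 gradient underflow
    scaler.unscale_(opt)            # bring gradients back to their true scale...
    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)  # ...then clip
    scaler.step(opt)                # skips the update if gradients contain inf/NaN
    scaler.update()
    return loss.item(), float(grad_norm)  # grad_norm is the norm BEFORE clipping

use_amp = torch.cuda.is_available()
toy_device = torch.device("cuda" if use_amp else "cpu")
# stand-in model: flatten [B, L, C] and map to one logit (NOT PatchTST)
toy_model = nn.Sequential(nn.Flatten(), nn.Linear(100 * 6, 1)).to(toy_device)
toy_opt = torch.optim.AdamW(toy_model.parameters(), lr=1e-3, weight_decay=1e-2)
toy_crit = nn.BCEWithLogitsLoss()
# newer PyTorch versions also offer torch.amp.GradScaler("cuda")
toy_scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

toy_xb = torch.randn(8, 100, 6, device=toy_device)
toy_yb = torch.randint(0, 2, (8,), device=toy_device).float()
toy_loss, toy_grad_norm = train_step(toy_model, toy_xb, toy_yb, toy_opt, toy_crit, toy_scaler, use_amp)
print(f"loss={toy_loss:.4f} grad_norm_before_clip={toy_grad_norm:.4f}")
```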


## 6) Confusion matrix and metrics

Given logits and labels:
- probability = sigmoid(logit)
- prediction = prob >= threshold

### Exercise 6.1
Compute TP, FP, TN, FN for the batch from Exercise 5.1 at threshold 0.5 and compute:
- precision, recall, F1, specificity



In [None]:
# TODO: Exercise 6.1
if ds is not None:
    model.eval()
    with torch.no_grad():
        logits = model(xb).squeeze(-1)

    tp, fp, tn, fn = confusion_matrix_binary_from_logits(logits.cpu(), yb.cpu(), thr=0.5)
    precision, recall, f1, specificity = metrics_from_cm(tp, fp, tn, fn)

    print(f"TP={tp} FP={fp} TN={tn} FN={fn}")
    print(f"precision={precision:.3f} recall={recall:.3f} f1={f1:.3f} specificity={specificity:.3f}")
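For reference, here is a hypothetical re-implementation of the two helpers, named `*_sketch` so they don't shadow the imported ones; the real `confusion_matrix_binary_from_logits` and `metrics_from_cm` may handle edge cases differently. Two details are worth noting: `sigmoid(logit) >= 0.5` is equivalent to `logit >= 0`, and here each ratio falls back to 0 when its denominator is 0 (an assumption about how empty classes are treated).

```python
import torch

def confusion_matrix_sketch(logits, y, thr=0.5):
    """Counts (tp, fp, tn, fn) for binary labels y in {0, 1}."""
    pred = torch.sigmoid(logits) >= thr
    true = y.bool()
    tp = int((pred & true).sum())
    fp = int((pred & ~true).sum())
    tn = int((~pred & ~true).sum())
    fn = int((~pred & true).sum())
    return tp, fp, tn, fn

def safe_div(a, b):
    return a / b if b > 0 else 0.0

def metrics_sketch(tp, fp, tn, fn):
    precision = safe_div(tp, tp + fp)     # of predicted positives, how many are real
    recall = safe_div(tp, tp + fn)        # of real positives, how many were caught (TPR)
    specificity = safe_div(tn, tn + fp)   # of real negatives, how many stayed negative (TNR)
    f1 = safe_div(2 * precision * recall, precision + recall)
    return precision, recall, f1, specificity

demo_logits = torch.tensor([2.0, -1.0, 0.5, -3.0, 1.5, -0.2])
demo_y = torch.tensor([1.0, 1.0, 0.0, 0.0, 1.0, 0.0])
counts = confusion_matrix_sketch(demo_logits, demo_y)
demo_metrics = metrics_sketch(*counts)
print(counts)        # (2, 1, 2, 1)
print(demo_metrics)  # every metric is 2/3 for this toy batch
```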


## 7) Threshold sweep (balanced accuracy)

Balanced accuracy = (TPR + TNR)/2 = (recall + specificity)/2
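A quick arithmetic check shows why this matters under class imbalance. The counts below are made up for illustration: 50 positives and 1000 negatives.

```python
def accuracy(tp, fp, tn, fn):
    return (tp + tn) / (tp + fp + tn + fn)

def balanced_accuracy(tp, fp, tn, fn):
    recall = tp / (tp + fn)          # TPR
    specificity = tn / (tn + fp)     # TNR
    return 0.5 * (recall + specificity)

# a trivial model that always predicts "negative"
print(accuracy(0, 0, 1000, 50), balanced_accuracy(0, 0, 1000, 50))      # ~0.952, 0.5
# a model that catches 40/50 positives and 900/1000 negatives
print(accuracy(40, 100, 900, 10), balanced_accuracy(40, 100, 900, 10))  # ~0.895, ~0.85
```

The trivial model wins on plain accuracy but scores 0.5 balanced accuracy, the same as random guessing.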

### Exercise 7.1
On the validation set:
1. collect all logits and labels
2. sweep thresholds
3. plot balanced accuracy vs threshold
4. report the best threshold

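This reproduces your script's threshold-selection logic interactively. Because the threshold is chosen on the same validation set it is scored on, the best balanced accuracy reported here is slightly optimistic.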



In [None]:
# TODO: Exercise 7.1
if ds is not None:
    # collect logits/labels on val set
    val_loader = torch.utils.data.DataLoader(val_ds, batch_size=64, shuffle=False, num_workers=2, pin_memory=True)

    model.eval()
    all_logits = []
    all_y = []
    with torch.no_grad():
        for xb, yb in val_loader:
            xb = xb.to(device)
            logits = model(xb).squeeze(-1).cpu()
            all_logits.append(logits)
            all_y.append(yb.cpu())

    logits = torch.cat(all_logits)
    y = torch.cat(all_y)

    probs = torch.sigmoid(logits)
    thresholds = torch.linspace(0.05, 0.95, steps=19)

    bal_accs = []
    for thr in thresholds:
        pred = (probs >= float(thr)).to(torch.int64)
        y_i = y.to(torch.int64)
        tp = int(((pred==1)&(y_i==1)).sum())
        fp = int(((pred==1)&(y_i==0)).sum())
        tn = int(((pred==0)&(y_i==0)).sum())
        fn = int(((pred==0)&(y_i==1)).sum())
        precision, recall, f1, specificity = metrics_from_cm(tp, fp, tn, fn)
        bal_acc = 0.5*(recall+specificity)
        bal_accs.append(bal_acc)

    # plot
    plt.figure()
    plt.plot([float(t) for t in thresholds], bal_accs, marker="o")
    plt.xlabel("threshold")
    plt.ylabel("balanced accuracy")
    plt.title("Balanced accuracy vs threshold")
    plt.grid(True)
    plt.show()

    best_i = int(np.argmax(bal_accs))
    print("best thr:", float(thresholds[best_i]), "best bal acc:", float(bal_accs[best_i]))
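The Python loop over thresholds is fine for 19 values. The same sweep can also be done in one shot with broadcasting, comparing every probability against every threshold. This is an equivalent alternative, not what the script does; `sweep_balanced_accuracy` is a hypothetical name, and the toy probabilities are made up.

```python
import torch

def sweep_balanced_accuracy(probs, y, thresholds):
    """Balanced accuracy for each threshold, computed with broadcasting.

    probs: [N] probabilities, y: [N] labels in {0, 1}, thresholds: [T].
    Returns a [T] tensor.
    """
    pred = probs.unsqueeze(0) >= thresholds.unsqueeze(1)  # [T, N] boolean predictions
    pos = y.bool().unsqueeze(0)                             # [1, N], broadcast over T
    tp = (pred & pos).sum(dim=1).float()
    fn = (~pred & pos).sum(dim=1).float()
    tn = (~pred & ~pos).sum(dim=1).float()
    fp = (pred & ~pos).sum(dim=1).float()
    recall = tp / (tp + fn).clamp(min=1)         # clamp avoids 0/0 when a class is empty
    specificity = tn / (tn + fp).clamp(min=1)
    return 0.5 * (recall + specificity)

demo_probs = torch.tensor([0.92, 0.81, 0.67, 0.58, 0.12, 0.28])
demo_y = torch.tensor([1, 1, 1, 0, 0, 0])
demo_thresholds = torch.linspace(0.05, 0.95, steps=19)
demo_bal = sweep_balanced_accuracy(demo_probs, demo_y, demo_thresholds)
demo_best = int(torch.argmax(demo_bal))
print(float(demo_thresholds[demo_best]), float(demo_bal[demo_best]))
```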


## 8) End-to-end training (short run)

### Exercise 8.1
Train for a few epochs and check whether:
- train loss decreases
- val balanced accuracy improves
- the best threshold stabilizes as the model improves

Try:
- different `patch_len` (10, 25, 50)
- overlapping patches (`stride < patch_len`)
- different `d_model` (64 vs 128)



In [None]:
# TODO: Exercise 8.1 (short run)
if ds is not None:
    # fresh model for a short run
    model_cfg = PatchTSTConfig(
        num_classes=1,
        patch_len=25,
        stride=25,
        d_model=128,
        n_heads=8,
        n_layers=4,
        d_ff=256,
        dropout=0.1,
        channel_independent=True,
        pooling="mean",
        fuse="mean",
    )
    model = PatchTSTClassifier(model_cfg).to(device)
    opt = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)
    crit = nn.BCEWithLogitsLoss()

    train_loader = torch.utils.data.DataLoader(train_ds, batch_size=64, shuffle=True, num_workers=2, pin_memory=True)
    val_loader   = torch.utils.data.DataLoader(val_ds, batch_size=64, shuffle=False, num_workers=2, pin_memory=True)

    def eval_bal_acc(model):
        model.eval()
        all_logits, all_y = [], []
        with torch.no_grad():
            for xb, yb in val_loader:
                xb = xb.to(device)
                logits = model(xb).squeeze(-1).cpu()
                all_logits.append(logits)
                all_y.append(yb.cpu())
        logits = torch.cat(all_logits)
        y = torch.cat(all_y)
        probs = torch.sigmoid(logits)

        thresholds = torch.linspace(0.05, 0.95, steps=19)
        best_thr, best_bal = 0.5, -1.0
        for thr in thresholds:
            thr_f = float(thr)
            pred = (probs >= thr_f).to(torch.int64)
            y_i = y.to(torch.int64)
            tp = int(((pred==1)&(y_i==1)).sum())
            fp = int(((pred==1)&(y_i==0)).sum())
            tn = int(((pred==0)&(y_i==0)).sum())
            fn = int(((pred==0)&(y_i==1)).sum())
            precision, recall, f1, specificity = metrics_from_cm(tp, fp, tn, fn)
            bal = 0.5*(recall+specificity)
            if bal > best_bal:
                best_bal, best_thr = bal, thr_f
        return best_bal, best_thr

    for epoch in range(1, 6):
        model.train()
        tr_loss = 0.0
        n = 0
        for xb, yb in train_loader:
            xb = xb.to(device)
            yb = yb.to(device)
            opt.zero_grad()
            logits = model(xb).squeeze(-1)
            loss = crit(logits, yb)
            loss.backward()
            opt.step()
            tr_loss += loss.item() * xb.size(0)
            n += xb.size(0)
        tr_loss /= n

        bal, thr = eval_bal_acc(model)
        print(f"epoch {epoch:02d} | train loss {tr_loss:.3f} | val bal acc {bal:.3f} | best thr {thr:.2f}")


## What you should be able to explain after this notebook

- Why `[B, L, C]` becomes `[B, N, P, C]` (patchify: `N` patches of length `P = patch_len`)
- Why we train on logits using BCEWithLogitsLoss
- What a confusion matrix is
- Why we hold out entire sessions (date folders)
- Why the threshold affects precision/recall/specificity
- Why balanced accuracy is good when both classes matter
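The patchify reshape itself can be written with `Tensor.unfold`, which slides a window along one dimension and appends the window as a new last axis. This is a minimal sketch of the shape change only; how `PatchTSTClassifier` implements it internally is not shown in this notebook, so treat the `patchify` name and the exact axis order as assumptions.

```python
import torch

def patchify(x: torch.Tensor, patch_len: int, stride: int) -> torch.Tensor:
    """[B, L, C] -> [B, N, P, C] with N = (L - P) // S + 1 (no padding)."""
    # unfold along time (dim=1) appends the window as a new last dim: [B, N, C, P]
    patches = x.unfold(dimension=1, size=patch_len, step=stride)
    return patches.permute(0, 1, 3, 2)  # reorder to [B, N, P, C]

x = torch.arange(2 * 1000 * 6, dtype=torch.float32).reshape(2, 1000, 6)
p = patchify(x, patch_len=25, stride=25)
print(tuple(p.shape))  # (2, 40, 25, 6)

# patch n of channel c holds time steps [n*S, n*S + P) of that channel
assert torch.equal(p[0, 3, :, 2], x[0, 75:100, 2])
```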

