In [2]:
!pip install torch

Collecting torch
  Downloading torch-2.6.0-cp310-cp310-manylinux1_x86_64.whl.metadata (28 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)


In [1]:
#!/usr/bin/env python3
# ============================================================
# Many-to-Many LSTM for Wake Detection (windowed, vote-merge)
# - Windows: L=300 (WINDOW_SIZE), stride=5 (STRIDE)
# - Per-timestep labels & loss
# - Reconstruct full-length probabilities by averaging votes
# - Event-level IoU evaluation (TP/FP/FN/Precision/Recall/F1)
# - Plots: red = ground-truth wake, green = predicted wake
# ============================================================

import os, glob, random, math, json
from collections import defaultdict, Counter

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from math import sqrt

import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import precision_recall_fscore_support

In [2]:
SEED = 42


def _seed_everything(seed):
    """Seed Python, NumPy and PyTorch RNGs; on CUDA also pin cuDNN to deterministic kernels."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True   # reproducible convolution/RNN kernels
    torch.backends.cudnn.benchmark = False      # no autotuner-driven kernel selection


_seed_everything(SEED)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"[INFO] Using device: {device}")

[INFO] Using device: cuda


In [9]:
BASE_DATA_DIR = "processed_ts"   # expects train/valid/test/*.csv
WINDOW_SIZE   = 300              # samples per window (L)
STRIDE        = 5                # samples between consecutive window starts
TAIL_MIN_KEEP = WINDOW_SIZE // 2   # = 150: tail with >= this many samples is padded to full length; shorter tail dropped
USE_TAIL      = True               # enable tail handling

LR            = 1e-3
BATCH_SIZE    = 256
EPOCHS        = 30
POS_WEIGHT    = 1.5                # floor for BCE pos_weight: training uses max(POS_WEIGHT, data neg/pos ratio); raise to bias recall
WEIGHT_DECAY  = 1e-4
GRAD_CLIP     = 1.0                # max global grad L2 norm per step
EARLY_STOP    = 8                  # patience (epochs without val-loss improvement)
THR_MASK      = 0.5                # per-sample probability threshold for mask
IOU_EVENT_THR = 0.5                # IoU threshold for event matching
MERGE_GAP_S   = 6.0                # seconds: merge predicted intervals if gap <= this
MIN_DUR_S     = 0.5                # seconds: drop predicted intervals shorter than this

In [10]:
# -------------------------
# 2) Data loading & helpers
# -------------------------
def scan_splits(root):
    """Collect the CSV files of each data split under ``root``.

    Looks in ``root/train``, ``root/valid`` and ``root/test``; a split whose
    directory does not exist maps to an empty list.

    Returns:
        dict with keys "train", "valid", "test" (in that order), each mapping
        to a sorted list of ``*.csv`` paths.
    """
    def _csvs_in(folder):
        if not os.path.isdir(folder):
            return []
        return sorted(glob.glob(os.path.join(folder, "*.csv")))

    return {name: _csvs_in(os.path.join(root, name))
            for name in ("train", "valid", "test")}

def load_df(fp):
    """Read one CSV and return it ordered by its ``t_s`` column.

    The index is reset to 0..n-1 so positional slicing follows time order.
    """
    return pd.read_csv(fp).sort_values("t_s").reset_index(drop=True)

def build_windows_many2many(df, window_size, stride, scaler=None, fit_scaler=False,
                            use_tail=True, tail_min_keep=0):
    """
    Slice one time series into overlapping windows with per-timestep labels.

    Args:
      df: frame with columns ``z_m`` (signal) and ``wake_label`` (per-sample label).
      window_size: window length L in samples.
      stride: step between consecutive full-window starts.
      scaler: object with a ``transform`` method (e.g. a fitted StandardScaler);
              required unless ``fit_scaler`` is True.
      fit_scaler: if True, fit a new StandardScaler on this frame's ``z_m``.
      use_tail: if True, the samples left after the last full window become one
                extra window, padded up to L.
      tail_min_keep: minimum number of real samples for that tail window to be kept.

    Returns:
      (X, Y, M, starts, scaler):
      X: (N,1,L) float32 scaled z_m
      Y: (N,L)   float32 copy of wake_label (padded positions are 0)
      M: (N,L)   float32 mask (1 = valid, 0 = padded)
      starts: list of start indices in original series
      scaler: the scaler used (the newly fitted one when fit_scaler=True)
      When no window fits, X, Y and M are None and starts is [].

    Raises:
      ValueError: if ``scaler`` is None and ``fit_scaler`` is False.
    """
    if scaler is None and not fit_scaler:
        raise ValueError("build_windows_many2many: a fitted `scaler` is required "
                         "unless fit_scaler=True")

    z = df["z_m"].to_numpy(dtype=np.float32).reshape(-1, 1)
    y = df["wake_label"].to_numpy(dtype=np.int64)
    N = len(z)

    # Scale the whole series once; windows are then plain slices of it.
    if fit_scaler:
        scaler = StandardScaler()
        z_scaled = scaler.fit_transform(z)
    else:
        z_scaled = scaler.transform(z)

    Xs, Ys, Ms, starts = [], [], [], []
    i = 0
    while i + window_size <= N:
        Xs.append(z_scaled[i:i+window_size])          # (L,1)
        Ys.append(y[i:i+window_size])                 # (L,)
        Ms.append(np.ones(window_size, dtype=np.float32))
        starts.append(i)
        i += stride

    # i is now the first start whose full window would run past the end;
    # samples i..N-1 form the (shorter) tail.
    if use_tail and i < N:
        tail = N - i
        if tail >= tail_min_keep:
            x_tail = z_scaled[i:N]
            y_tail = y[i:N]
            pad_needed = window_size - tail
            # Features are edge-padded with the last real sample; padded labels
            # are 0 and excluded from the loss through the mask.
            pad_x = np.repeat(x_tail[-1:], pad_needed, axis=0)
            pad_y = np.zeros(pad_needed, dtype=np.int64)
            mask = np.concatenate([np.ones(tail, dtype=np.float32),
                                   np.zeros(pad_needed, dtype=np.float32)])
            Xs.append(np.vstack([x_tail, pad_x]))
            Ys.append(np.concatenate([y_tail, pad_y]))
            Ms.append(mask)
            starts.append(i)
        # else: drop tail

    if not Xs:
        return None, None, None, [], scaler

    X = np.stack(Xs, axis=0).transpose(0,2,1).astype(np.float32)  # (N,1,L)
    Y = np.stack(Ys, axis=0).astype(np.float32)                    # (N,L)
    M = np.stack(Ms, axis=0).astype(np.float32)                    # (N,L)
    return X, Y, M, starts, scaler

def streaming_mean_std_over_train(train_files):
    """
    Mean and sample std (ddof=1) of ``z_m`` pooled over all TRAIN files.

    Per-file moments are computed with vectorised NumPy and combined with the
    pairwise update of Chan et al. (the batched form of Welford's algorithm).
    This avoids a Python-level loop over every sample while keeping only one
    file's column in memory at a time. Only the ``z_m`` column is read and rows
    are not sorted, because row order does not affect these statistics.

    Returns:
      (mean, std) as floats. When fewer than 2 samples exist, std is 1.0
      (and mean is 0.0 for no samples). std is floored at sqrt(1e-12).
    """
    n_total = 0
    mean = 0.0
    M2 = 0.0   # running sum of squared deviations from the pooled mean
    for fp in train_files:
        z = pd.read_csv(fp, usecols=["z_m"])["z_m"].to_numpy(dtype=np.float64)
        n_file = z.size
        if n_file == 0:
            continue
        mean_file = float(z.mean())
        M2_file = float(np.square(z - mean_file).sum())
        n_new = n_total + n_file
        delta = mean_file - mean
        # Chan et al. merge of (n_total, mean, M2) with (n_file, mean_file, M2_file)
        mean += delta * n_file / n_new
        M2 += M2_file + delta * delta * n_total * n_file / n_new
        n_total = n_new
    if n_total < 2:
        return float(mean), 1.0
    var = M2 / (n_total - 1)
    std = sqrt(max(var, 1e-12))
    return float(mean), float(std)

def build_window_index(files, window_size, stride, use_tail=True, tail_min_keep=0):
    """
    Enumerate window positions for every file without keeping any data.

    Full windows start at 0, stride, 2*stride, ... for as long as they fit in
    the series. With ``use_tail``, the samples remaining after the last full
    window form one extra, shorter window, provided there are at least
    ``tail_min_keep`` of them.

    Returns:
      list of (file_path, start, end, valid_len) tuples; ``valid_len`` equals
      ``end - start`` and is < window_size only for a tail window.
    """
    index = []
    for path in files:
        n_rows = len(load_df(path))
        full_starts = range(0, n_rows - window_size + 1, stride)
        index.extend((path, s, s + window_size, window_size) for s in full_starts)
        # First start whose full window would overrun the series.
        tail_start = len(full_starts) * stride
        tail_len = n_rows - tail_start
        if use_tail and tail_len > 0 and tail_len >= tail_min_keep:
            index.append((path, tail_start, n_rows, tail_len))
    return index

class LazyWindowDataset(torch.utils.data.Dataset):
    """
    Serve one (padded) window at a time, scaled with the global TRAIN mean/std.

    Items are built from index entries (file_path, start, end, valid_len) as
    produced by ``build_window_index``. Each file's ``z_m`` / ``wake_label``
    arrays are cached after the first load, so a file is read and sorted once
    per process rather than once per window. With num_workers > 0 each
    DataLoader worker gets its own copy of the dataset, hence its own cache.

    Returns: x (1,L), y (L), m (L), fp, s, e
    """
    def __init__(self, index, mean, std, window_size, cache_size=None):
        """
        Args:
          index: list of (file_path, start, end, valid_len) tuples.
          mean, std: scaling statistics; a non-positive std is replaced by 1.0.
          window_size: output length L; shorter (tail) windows are padded.
          cache_size: max number of files kept in the per-file cache, evicting
                      the least recently used; None (default) caches every file
                      touched. Set a bound if the split does not fit in memory.
        """
        self.index = index
        self.mean = float(mean)
        self.std = float(std if std > 0 else 1.0)
        self.window_size = int(window_size)
        self.cache_size = cache_size
        # fp -> (z float32 array, y float32 array); dict order tracks recency.
        self._cache = {}

    def __len__(self): return len(self.index)

    def _arrays(self, fp):
        """Return cached (z_m, wake_label) float32 arrays for ``fp``, loading on a miss."""
        hit = self._cache.pop(fp, None)
        if hit is None:
            df = load_df(fp)
            hit = (df["z_m"].to_numpy(dtype=np.float32),
                   df["wake_label"].to_numpy(dtype=np.float32))
        self._cache[fp] = hit   # (re)insert as most recently used
        if self.cache_size is not None:
            while len(self._cache) > max(int(self.cache_size), 1):
                del self._cache[next(iter(self._cache))]   # evict least recently used
        return hit

    def __getitem__(self, i):
        fp, s, e, valid_len = self.index[i]
        z, y = self._arrays(fp)
        seg_z = z[s:e]
        # Copy labels so returned tensors never alias the cached array.
        seg_y = np.array(y[s:e], dtype=np.float32)
        L = len(seg_z)
        if L < self.window_size:
            # Tail window: edge-pad features, zero-pad labels; mask marks real samples.
            pad = self.window_size - L
            seg_z = np.concatenate([seg_z, np.repeat(seg_z[-1], pad).astype(np.float32)])
            seg_y = np.concatenate([seg_y, np.zeros(pad, dtype=np.float32)])
        m = np.zeros(self.window_size, dtype=np.float32)
        m[:valid_len] = 1.0
        seg_z = (seg_z - self.mean) / self.std   # new array; cache left untouched
        x = torch.from_numpy(seg_z[None, :]).float()
        y_t = torch.from_numpy(seg_y).float()
        m = torch.from_numpy(m).float()
        return x, y_t, m, fp, s, e


def build_dataset(split_files, window_size, stride, scaler=None, fit_scaler=False,
                  use_tail=True, tail_min_keep=0):
    """
    Materialise all windows of a split in memory (non-streaming alternative to
    ``build_window_index`` + ``LazyWindowDataset``).

    When ``fit_scaler`` is True, ONE StandardScaler is fitted incrementally
    (``partial_fit``) over ``z_m`` of every file in ``split_files`` before any
    windowing, so all files are normalised with the same split-level
    statistics rather than each with its own.

    Returns:
      (X, Y, M, idx_triplets, scaler): X (N,1,L), Y (N,L), M (N,L) float32
      arrays concatenated over files; idx_triplets[k] = (file_path, start, end)
      maps row k back to its file; scaler is the scaler used. X, Y and M are
      None when no file yields a window. Empty files are skipped.
    """
    if fit_scaler:
        fitted = StandardScaler()
        n_seen = 0
        # First pass: accumulate statistics over the whole split.
        for fp in split_files:
            z = load_df(fp)["z_m"].to_numpy(dtype=np.float32).reshape(-1, 1)
            if len(z):
                fitted.partial_fit(z)
                n_seen += len(z)
        if n_seen == 0:
            # Nothing to fit and nothing to window.
            return None, None, None, [], scaler
        scaler = fitted

    X_list, Y_list, M_list, idx_triplets = [], [], [], []
    for fp in split_files:
        df = load_df(fp)
        if df.empty:
            continue
        # The scaler is already fitted here, so never refit per file.
        X, Y, M, starts, _ = build_windows_many2many(df, window_size, stride, scaler,
                                                    False, use_tail, tail_min_keep)
        if X is None:
            continue
        X_list.append(X)
        Y_list.append(Y)
        M_list.append(M)
        for s in starts:
            e = min(s + window_size, len(df))   # tail windows end at the series end
            idx_triplets.append((fp, int(s), int(e)))  # map window back to file
    if not X_list:
        return None, None, None, [], scaler
    X = np.concatenate(X_list, axis=0)
    Y = np.concatenate(Y_list, axis=0)
    M = np.concatenate(M_list, axis=0)
    return X, Y, M, idx_triplets, scaler

In [11]:
# -------------------------
# 3) Model (per-timestep output)
# -------------------------
class LSTMWakeSeq(nn.Module):
    """Sequence-to-sequence wake tagger mapping a (B,1,L) signal to (B,L) logits.

    An LSTM (by default 2-layer, bidirectional) runs over the time axis, and a
    shared linear head turns every timestep's hidden state into one logit.
    """
    def __init__(self, input_size=1, hidden=96, num_layers=2, bidir=True, dropout=0.2):
        super().__init__()
        # nn.LSTM applies dropout only between stacked layers; disable it for a single layer.
        inter_layer_dropout = 0.0 if num_layers <= 1 else dropout
        self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden, num_layers=num_layers,
                            batch_first=True, bidirectional=bidir, dropout=inter_layer_dropout)
        directions = 2 if bidir else 1
        self.head = nn.Linear(hidden * directions, 1)

    def forward(self, x):
        """x: (B,1,L) -> per-timestep logits (B,L)."""
        seq = x.permute(0, 2, 1)            # (B,L,1): batch_first LSTM input
        states = self.lstm(seq)[0]          # (B,L,hidden*directions)
        return self.head(states).squeeze(-1)


In [12]:
# -------------------------
# 4) Training/Eval utilities
# -------------------------
@torch.no_grad()
def compute_pos_weight_from_loader(loader):
    """
    Estimate a BCE ``pos_weight`` as (#negative / #positive) over valid labels.

    Only positions with mask > 0.5 are counted; a label > 0.5 is positive and a
    label <= 0.5 negative. Labels are counted per window, so samples shared by
    overlapping windows contribute once per window. Returns 1.0 when there are
    no positives, and never less than 1.0.
    """
    n_pos = 0.0
    n_neg = 0.0
    for _xb, labels_t, mask_t, *_rest in loader:
        kept = labels_t.numpy()[mask_t.numpy() > 0.5]
        n_pos += np.count_nonzero(kept > 0.5)
        n_neg += np.count_nonzero(kept <= 0.5)
    if n_pos == 0:
        return 1.0
    return max(1.0, float(n_neg / n_pos))

def train_one_epoch(model, loader, criterion, optimizer, grad_clip=1.0):
    """
    Run one optimisation pass over ``loader`` and return the mean batch loss.

    ``criterion`` is expected to return an unreduced (B,L) loss (e.g.
    BCEWithLogitsLoss with reduction='none'). Each batch loss is its average
    over mask==1 positions only, so padded tail positions contribute nothing.
    Gradients are clipped to a global L2 norm of ``grad_clip`` before each
    step. Tensors are moved to the module-level ``device``.
    """
    model.train()
    batch_losses = []
    for inputs, targets, mask, *_ in loader:
        inputs, targets, mask = inputs.to(device), targets.to(device), mask.to(device)
        per_step = criterion(model(inputs), targets)              # (B,L)
        loss = (per_step * mask).sum() / (mask.sum() + 1e-8)

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        optimizer.step()

        batch_losses.append(loss.item())
    return float(np.mean(batch_losses))

@torch.no_grad()
def infer_probs(model, loader):
    """
    Run the model in eval mode over ``loader`` and collect sigmoid probabilities.

    Returns:
      (probs, idx_triplets): ``probs`` is an (N_windows, L) numpy array and
      ``idx_triplets[k] = (file_path, start, end)`` identifies row k, in the
      order the loader yielded the windows.
    """
    model.eval()
    batches = []
    window_ids = []
    for xb, _yb, _mb, paths, starts, ends in loader:
        logits = model(xb.to(device))
        batches.append(torch.sigmoid(logits).cpu().numpy())   # (B,L)
        window_ids.extend(
            (paths[k], int(starts[k]), int(ends[k])) for k in range(len(paths))
        )
    return np.concatenate(batches, axis=0), window_ids


def mask_to_intervals(mask, t_s):
    """
    Convert a per-sample mask into runs of consecutive positive samples.

    Any entry > 0 counts as positive, so boolean masks and labels other than
    exactly 1 are handled correctly.

    Args:
      mask: 1-D array-like, one entry per sample.
      t_s:  1-D array of sample timestamps, same length as ``mask``.

    Returns:
      list of (start_time, end_time) in seconds. start_time is the timestamp
      of a run's first positive sample and end_time that of its LAST positive
      sample (both inclusive), so a single-sample run has zero duration.
    """
    active = (np.asarray(mask) > 0).astype(np.int8)
    edges = np.diff(np.pad(active, (1, 1)))
    starts = np.flatnonzero(edges == 1)
    ends = np.flatnonzero(edges == -1)   # one past the last positive sample of each run
    return [(float(t_s[s]), float(t_s[e - 1])) for s, e in zip(starts, ends)]

def merge_and_filter(intervals, merge_gap_s=MERGE_GAP_S, min_dur_s=MIN_DUR_S):
    """
    Merge nearby intervals, then drop short ones.

    Intervals are visited in sorted order; one whose start lies at most
    ``merge_gap_s`` after the current merged end is absorbed into it. Each
    merged interval is kept only if end - start >= ``min_dur_s``.

    Returns:
      list of (start, end) tuples, sorted by start.
    """
    if not intervals:
        return []
    ordered = sorted(intervals)
    merged = [list(ordered[0])]
    for start, end in ordered[1:]:
        current = merged[-1]
        if start - current[1] <= merge_gap_s:
            current[1] = max(current[1], end)
        else:
            merged.append([start, end])
    return [(s, e) for s, e in merged if e - s >= min_dur_s]

def get_ground_truth_wakes(df):
    """
    Ground-truth wake intervals (seconds) for one file.

    Runs of ``wake_label`` are turned into (start, end) timestamps from
    ``t_s`` via ``mask_to_intervals``, then passed through ``merge_and_filter``
    with zero gap and zero minimum duration, so no real post-processing is
    applied to the ground truth.
    """
    labels = df["wake_label"].to_numpy().astype(int)
    times = df["t_s"].to_numpy()
    raw_runs = mask_to_intervals(labels, times)
    return merge_and_filter(raw_runs, 0.0, 0.0)  # no merge/filter on GT

def iou_interval(a, b):
    """
    Intersection-over-union of two 1-D intervals given as (start, end).

    The denominator is the span from the earliest start to the latest end.
    That span equals the true union whenever the intervals overlap, and the
    IoU is 0 anyway when they do not. Returns 0.0 if the span is not positive.
    """
    (a_start, a_end), (b_start, b_end) = a, b
    overlap = min(a_end, b_end) - max(a_start, b_start)
    inter = overlap if overlap > 0.0 else 0.0
    hull = max(a_end, b_end) - min(a_start, b_start)
    if hull <= 0:
        return 0.0
    return inter / hull

def evaluate_events_iou(gt_intervals, pred_intervals, iou_thr=IOU_EVENT_THR):
    """
    Greedy one-to-one event matching by IoU.

    Predictions are visited in the given order. Each is paired with the
    still-unmatched ground-truth interval of highest IoU (the first one wins
    ties). If that IoU >= ``iou_thr``, the prediction is a TP and consumes the
    GT interval; otherwise it is an FP. GT intervals left unmatched are FNs.

    Returns:
      (tp, fp, fn, precision, recall, f1); each ratio is 0.0 when undefined.
    """
    used = set()
    tp = fp = 0
    for pred in pred_intervals:
        best_idx, best_iou = -1, 0.0
        for gi, gt in enumerate(gt_intervals):
            if gi in used:
                continue
            score = iou_interval(gt, pred)
            if score > best_iou:
                best_idx, best_iou = gi, score
        if best_idx != -1 and best_iou >= iou_thr:
            used.add(best_idx)
            tp += 1
        else:
            fp += 1
    fn = len(gt_intervals) - len(used)
    precision = tp / (tp + fp) if tp + fp > 0 else 0.0
    recall = tp / (tp + fn) if tp + fn > 0 else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall > 0 else 0.0
    return tp, fp, fn, precision, recall, f1

def reconstruct_probs_for_files(dataset_splits, idx_triplets, probs_windows):
    """
    Stitch per-window probabilities back into one per-sample series per file.

    Window (fp, s, e) contributes its first e-s probabilities to samples
    s..e-1 of its file. Each sample's value is the mean over all windows
    covering it; samples covered by no window get 0.0.

    Args:
      dataset_splits: unused; kept for call compatibility.
      idx_triplets: (file_path, start, end) per window, aligned with rows of probs_windows.
      probs_windows: (N_windows, L) per-timestep probabilities.

    Returns:
      dict fp -> numpy array of length len(load_df(fp)); float32 sums divided
      by int32 counts, which NumPy promotes to float64.
    """
    windows_per_file = defaultdict(list)
    for (path, start, end), row in zip(idx_triplets, probs_windows):
        windows_per_file[path].append((start, end, row))

    stitched = {}
    for path, windows in windows_per_file.items():
        n_samples = len(load_df(path))
        acc = np.zeros(n_samples, dtype=np.float32)
        hits = np.zeros(n_samples, dtype=np.int32)
        for start, end, row in windows:
            acc[start:end] += row[:end - start]
            hits[start:end] += 1
        stitched[path] = acc / np.maximum(hits, 1)   # uncovered samples: 0 / 1
    return stitched

In [13]:
# -------------------------
# 5) Build datasets (streaming)
# -------------------------
dataset_splits = scan_splits(BASE_DATA_DIR)
train_files, valid_files, test_files = (dataset_splits[k] for k in ("train", "valid", "test"))

print("[INFO] computing streaming mean/std over TRAIN files...")
z_mean, z_std = streaming_mean_std_over_train(train_files)
print(f"[INFO] scaler: mean={z_mean:.6f} std={z_std:.6f}")

# One (fp, start, end, valid_len) index per split; no window data is loaded yet.
train_index, valid_index, test_index = (
    build_window_index(files, WINDOW_SIZE, STRIDE, USE_TAIL, TAIL_MIN_KEEP)
    for files in (train_files, valid_files, test_files)
)
print(f"[IDX] train={len(train_index)} | valid={len(valid_index)} | test={len(test_index)}")


def _window_loader(index, shuffle):
    """Wrap ``index`` in a LazyWindowDataset + DataLoader with the shared settings."""
    return DataLoader(
        LazyWindowDataset(index, z_mean, z_std, WINDOW_SIZE),
        batch_size=BATCH_SIZE, shuffle=shuffle, num_workers=2, pin_memory=True,
    )


# Dataloaders; valid/test are None when their split has no windows.
train_loader = _window_loader(train_index, shuffle=True)
val_loader = _window_loader(valid_index, shuffle=False) if valid_index else None
test_loader = _window_loader(test_index, shuffle=False) if test_index else None


[INFO] computing streaming mean/std over TRAIN files...
[INFO] scaler: mean=0.000002 std=0.032041
[IDX] train=5302646 | valid=1835017 | test=3302900


In [None]:
# -------------------------
# 6) Train
# -------------------------
model = LSTMWakeSeq().to(device)
# POS_WEIGHT acts as a floor on the data-derived neg/pos ratio.
pw_val = max(POS_WEIGHT, compute_pos_weight_from_loader(train_loader))
pos_weight = torch.tensor(pw_val, dtype=torch.float32, device=device)
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight, reduction='none')
optimizer = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)
print(f"[INFO] pos_weight used = {pos_weight.item():.3f}")


@torch.no_grad()
def _masked_val_loss(net, loader):
    """Mean over batches of the mask-weighted per-timestep loss, in eval mode."""
    net.eval()
    per_batch = []
    for xb, yb, mb, *_ in loader:
        xb, yb, mb = xb.to(device), yb.to(device), mb.to(device)
        loss_mat = criterion(net(xb), yb)                          # (B,L)
        per_batch.append(((loss_mat * mb).sum() / (mb.sum() + 1e-8)).item())
    return float(np.mean(per_batch))


# --- train loop with early stopping on val loss ---
best_val = float("inf")
best_state = None
no_imp = 0

for epoch in range(1, EPOCHS + 1):
    train_loss = train_one_epoch(model, train_loader, criterion, optimizer, grad_clip=GRAD_CLIP)

    if val_loader is None:
        print(f"Epoch {epoch:03d} | TrainLoss {train_loss:.4f}")
        continue

    val_loss = _masked_val_loss(model, val_loader)
    scheduler.step(val_loss)
    print(f"Epoch {epoch:03d} | TrainLoss {train_loss:.4f} | ValLoss {val_loss:.4f}")

    if val_loss < best_val - 1e-6:
        best_val = val_loss
        # Snapshot on CPU so later updates do not overwrite the best weights.
        best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
        no_imp = 0
        continue

    no_imp += 1
    if no_imp >= EARLY_STOP:
        print(f"[INFO] Early stop at epoch {epoch} (best ValLoss={best_val:.4f})")
        break

# restore best weights
if best_state is not None:
    model.load_state_dict({k: v.to(device) for k, v in best_state.items()})

[INFO] pos_weight used = 2.137


In [None]:
# -------------------------
# 7) Inference: per-window probs, reconstruction to per-file probs
# -------------------------
# infer_probs returns (probs, idx_triplets): per-window probabilities plus the
# (file_path, start, end) of each row, which the reconstruction step needs.
p_tr, idx_tr = infer_probs(model, train_loader)
p_va, idx_va = infer_probs(model, val_loader) if val_loader is not None else (None, None)
p_te, idx_te = infer_probs(model, test_loader) if test_loader is not None else (None, None)

# Reconstruct per-file probabilities (average vote)
recon_tr = reconstruct_probs_for_files(dataset_splits, idx_tr, p_tr)
recon_va = reconstruct_probs_for_files(dataset_splits, idx_va, p_va) if p_va is not None else {}
recon_te = reconstruct_probs_for_files(dataset_splits, idx_te, p_te) if p_te is not None else {}


In [None]:
# -------------------------
# 8) Event-level IoU evaluation per split
# -------------------------
def eval_split(recon_dict, file_list, thr=THR_MASK, count_missing_as_fn=True):
    """
    Event-level IoU metrics for one split, pooled (micro-averaged) over files.

    For each file, the reconstructed probabilities are thresholded at ``thr``,
    converted to intervals and post-processed with MERGE_GAP_S / MIN_DUR_S.
    The result is then greedily matched against the ground-truth wakes at
    IOU_EVENT_THR. TP/FP/FN are summed over files before computing
    precision, recall and F1.

    A file with no entry in ``recon_dict`` (e.g. one that produced no windows)
    has no predictions. With ``count_missing_as_fn`` (default), its
    ground-truth wakes count as FN instead of being silently excluded.

    Returns:
      dict with TP, FP, FN, Precision, Recall, F1.
    """
    tp_total = fp_total = fn_total = 0
    for path in file_list:
        prob = recon_dict.get(path)
        if prob is None and not count_missing_as_fn:
            continue
        df = load_df(path)
        gt_intervals = get_ground_truth_wakes(df)

        if prob is None:
            pred_intervals = []   # no windows -> nothing predicted
        else:
            mask = (prob >= thr).astype(np.int8)
            t_s = df["t_s"].to_numpy()
            pred_intervals = merge_and_filter(mask_to_intervals(mask, t_s),
                                              merge_gap_s=MERGE_GAP_S, min_dur_s=MIN_DUR_S)

        tp, fp_, fn, *_ = evaluate_events_iou(gt_intervals, pred_intervals, IOU_EVENT_THR)
        tp_total += tp
        fp_total += fp_
        fn_total += fn

    prec = tp_total / (tp_total + fp_total) if (tp_total + fp_total) > 0 else 0.0
    rec = tp_total / (tp_total + fn_total) if (tp_total + fn_total) > 0 else 0.0
    f1 = 2 * prec * rec / (prec + rec) if (prec + rec) > 0 else 0.0
    return dict(TP=tp_total, FP=fp_total, FN=fn_total, Precision=prec, Recall=rec, F1=f1)

print("\n===== Event-level IoU metrics (IoU>=%.2f, thr=%.2f) =====" % (IOU_EVENT_THR, THR_MASK))


def _report(tag, recon, split):
    """Evaluate one split, print its metrics under ``tag``, and return them."""
    metrics = eval_split(recon, dataset_splits[split])
    print(tag, metrics)
    return metrics


train_evt = _report("[Train]", recon_tr, "train")
if val_loader is not None:
    valid_evt = _report("[Valid]", recon_va, "valid")
if test_loader is not None:
    test_evt = _report("[Test ]", recon_te, "test")

In [None]:
# -------------------------
# 9) Plots (test set): red = GT wakes, green = Pred wakes
# -------------------------
def plot_file(fp, prob_full, thr=THR_MASK, nmax=3):
    """
    Plot one file's z_m signal with ground-truth and predicted wake spans.

    Red spans are ground-truth wakes. Green spans are predicted wakes,
    obtained by thresholding ``prob_full`` at ``thr`` and post-processing
    with MERGE_GAP_S / MIN_DUR_S. A legend identifies both colours so the
    figure stands alone.

    Args:
      fp: CSV path of the file.
      prob_full: per-sample probabilities, same length as the file.
      thr: probability threshold for the predicted mask.
      nmax: unused; kept for call compatibility.
    """
    df = load_df(fp)
    t = df["t_s"].to_numpy()
    z = df["z_m"].to_numpy()
    gt = get_ground_truth_wakes(df)

    pred_mask = (prob_full >= thr).astype(np.int8)
    pred_intv = merge_and_filter(mask_to_intervals(pred_mask, t), MERGE_GAP_S, MIN_DUR_S)

    fig, ax = plt.subplots(1, 1, figsize=(12, 4))
    ax.plot(t, z, lw=1.0, label="z_m")
    # Label only the first span of each kind so the legend has one entry per colour.
    for k, (s, e) in enumerate(gt):
        ax.axvspan(s, e, color="red", alpha=0.25, label="GT wake" if k == 0 else "_nolegend_")
    for k, (s, e) in enumerate(pred_intv):
        ax.axvspan(s, e, color="green", alpha=0.25, label="Pred wake" if k == 0 else "_nolegend_")

    ax.set_title(os.path.basename(fp))
    ax.set_xlabel("time (s)")
    ax.set_ylabel("z_m")
    ax.legend(loc="upper right")
    plt.tight_layout()
    plt.show()
    plt.close(fig)  # release the figure; this is called in a loop

# show a few random test files
# show a few random test files (those without reconstructed probabilities are skipped)
test_paths = dataset_splits["test"]
to_plot = random.sample(test_paths, k=min(3, len(test_paths)))
for fp in to_plot:
    prob = recon_te.get(fp)
    if prob is None:
        continue
    plot_file(fp, prob, THR_MASK)