# In [None]:
import time, numpy as np, pandas as pd, matplotlib.pyplot as plt
import torch, torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import RobustScaler
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
from pathlib import Path


# Directory for the per-fold checkpoints; created at import time if missing.
CKPT_DIR = Path("./checkpoints"); CKPT_DIR.mkdir(parents=True, exist_ok=True)

# --- Data schema ---
# NOTE(review): "Proceesed" looks like a typo, but it must match the file name on disk.
CSV_PATH      = "Final_Proceesed_Dataset.csv"
TRIAL_COL     = "TrialID"        # groups rows into trials (windowing, Hampel filter, splits)
PARTIC_COL    = "ParticipantID"  # groups trials by participant (robust clipping, splits)
TARGET_X      = "newCopX"        # regression targets, predicted jointly as a 2-vector
TARGET_Y      = "newCopY"
TIME_COLS     = ["aligned_time", "Timestamp"]  # time-column candidates; the first one present is used

# --- Windowing, in seconds (converted to samples with the estimated sampling rate) ---
DES_LONG_SEC  = 8.0   # long input window
DES_SHORT_SEC = 2.0   # short window = last DES_SHORT_SEC of the long window
DES_HOP_SEC   = 0.20  # stride between consecutive windows
DES_HOR_SEC   = 0.06  # forecast horizon after the last sample of the long window


# --- Cleaning ---
CLEAN_LEVEL   = 1     # 0 disables robust clipping and Hampel filtering
MAD_K         = 3.5   # clip inputs/targets to median +/- MAD_K * scaled MAD
HAMPEL_K      = 5     # Hampel half-window in samples (raised to at least 3; <= 0 disables)
HAMPEL_T0     = 3.5   # Hampel outlier threshold, in scaled MADs

# NOTE(review): INCLUDE_PAST_COP is not referenced anywhere in the visible source.
INCLUDE_PAST_COP = False
# Enabling this calls add_derived_feats, which is not defined in the visible source.
USE_DERIVED_FEATS = False

# --- Training-time augmentation ---
USE_AUG          = True
AUG_NOISE_STD    = 0.01   # std of additive Gaussian noise on the scaled inputs
AUG_TIMEMASK_P   = 0.25   # per-batch probability of zeroing a random time span
AUG_TIMEMASK_LEN = 6      # length (in samples) of that zeroed span

RANDOM_SEED   = 42


# --- Model (MultiPatchFormerX_Reg) ---
D_MODEL       = 160
N_HEADS       = 8
N_LAYERS      = 3
DROPOUT       = 0.20
PATCH_OVERLAP = 0.75     # fraction of overlap between consecutive patches
SCALES        = (1,2,4)  # patch length = max(4, window_len // scale)

# Train
BATCH         = 64
EPOCHS        = 100
LR            = 3e-4
WEIGHT_DECAY  = 5e-4      # AdamW weight decay
PATIENCE      = 12        # epochs without val-MAE improvement before early stopping
CLIP_NORM     = 1.0       # gradient-norm clipping threshold

# SWA
USE_SWA        = True
SWA_START_FRAC = 0.7      # SWA starts at epoch int(max(1, EPOCHS * SWA_START_FRAC))
SWA_LR_FACTOR  = 0.5      # SWA learning rate = LR * SWA_LR_FACTOR

# Weighted loss
WEIGHT_Q          = 0.70  # quantile of |normalised train target| (per axis) used as the spike threshold
SPIKE_WEIGHT      = 2.0   # loss weight for targets at/above that threshold
AXIS_WEIGHT_FROM_TRAIN_IQR = True  # weight the Y-axis loss by IQR_y / IQR_x of the training targets

# TTA: average TTA_RUNS forward passes with Gaussian input noise of std TTA_NOISE
USE_TTA      = False
TTA_RUNS     = 5
TTA_NOISE    = 0.01

# Latency: warm-up/measured batch counts and per-sample counts
PRINT_LATENCY          = True
LAT_WARMUP_STEPS       = 10
LAT_MEASURE_STEPS      = 50
LAT_WARMUP_SAMPLES_PS  = 20
LAT_MEASURE_SAMPLES_PS = 50

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def set_seed(s=RANDOM_SEED, deterministic=False):
    """Seed Python's ``random``, NumPy and PyTorch (CPU and every CUDA device).

    Args:
        s: Seed value (defaults to ``RANDOM_SEED``).
        deterministic: When True, disable cuDNN autotuning and request
            deterministic cuDNN kernels, so repeated GPU runs can reproduce
            each other. The default keeps ``cudnn.benchmark = True``, which
            favours speed. cuDNN may then choose different, possibly
            non-deterministic, algorithms between runs, so GPU results can vary
            even with the same seed.
    """
    import random
    random.seed(s)
    np.random.seed(s)
    torch.manual_seed(s)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(s)
    # Autotuning selects the fastest kernels, which are not guaranteed to be deterministic.
    torch.backends.cudnn.benchmark = not deterministic
    if deterministic:
        torch.backends.cudnn.deterministic = True

def find_time_col(df, candidates=None):
    """Return the first time column present in ``df`` and parse it to datetime in place.

    Args:
        df: DataFrame to inspect. The chosen column is overwritten with
            ``pd.to_datetime(..., errors='coerce')``; unparseable values become NaT.
        candidates: Column names to try, in order. Defaults to ``TIME_COLS``.

    Returns:
        The name of the first candidate column found, or None if none is
        present. The name is returned even when parsing raises; in that case
        the column keeps its original values and a warning is printed.

    NOTE(review): ``pd.to_datetime`` treats numeric values as nanoseconds since
    the epoch. If a candidate column holds seconds, the parsed timestamps (and
    any sampling rate derived from them) will be wrong. Confirm the unit.
    """
    if candidates is None:
        candidates = TIME_COLS
    for c in candidates:
        if c not in df.columns:
            continue
        try:
            df[c] = pd.to_datetime(df[c], errors='coerce')
        except (TypeError, ValueError, OverflowError) as err:
            # Best effort: keep the raw column, but do not hide why parsing failed.
            print(f"[WARN] Could not parse time column {c!r}: {err}")
        return c
    return None

def seconds_to_samples(sec, fs):
    """Convert ``sec`` seconds to a rounded sample count at ``fs`` Hz (never below 1)."""
    n_samples = int(round(sec * fs))
    return n_samples if n_samples > 1 else 1

def estimate_fs(df, tcol, default=50.0, trial_col=None):
    """Estimate the sampling rate (Hz) from the median time step within trials.

    For each trial, compute the median step between consecutive timestamps.
    The rate is 1 / (median of those per-trial medians).

    Args:
        df: Data containing ``tcol`` and the trial column.
        tcol: Name of a datetime column, or None.
        default: Rate returned when ``tcol`` is None or not datetime-typed, when
            no trial has two valid timestamps, or when the median step is <= 0.
        trial_col: Grouping column; defaults to ``TRIAL_COL``.

    Steps involving NaT are ignored. ``total_seconds()`` makes the result
    independent of the datetime resolution (ns/us/ms) and of time zones.
    """
    if tcol is None or not pd.api.types.is_datetime64_any_dtype(df[tcol]):
        return default
    if trial_col is None:
        trial_col = TRIAL_COL
    per_trial = []
    for _, times in df.groupby(trial_col, sort=False)[tcol]:
        steps = times.diff().dt.total_seconds().to_numpy()[1:]
        steps = steps[np.isfinite(steps)]  # drop gaps caused by NaT
        if steps.size:
            per_trial.append(np.median(steps))
    if per_trial:
        step = float(np.median(per_trial))
        if step > 0:
            return 1.0 / step
    return default



def find_inputs(df):
    """Return the sorted, de-duplicated names of tracking-feature columns in ``df``.

    A column is selected when its name starts with one of the HMD, controller
    or hand prefixes below. Angular-velocity columns such as
    ``hmd_ang_velocity_Y`` are covered by these prefixes.
    """
    wanted_prefixes = ('hmd_', 'controller_lefthand_', 'controller_righthand_',
                       'lefthand_', 'righthand_', 'hands_')
    selected = {name for name in df.columns if name.startswith(wanted_prefixes)}
    return sorted(selected)

def make_diffs_per_trial_df(df, cols, trial_col=None):
    """Return first and second per-trial differences of ``cols`` as a new DataFrame.

    For each column ``c`` the output holds ``c_d1`` (difference from the
    previous row of the same trial) and ``c_d2`` (difference of ``c_d1``
    within the same trial). Positions without enough history are filled with
    0.0. Values are float32, and the result shares ``df``'s index.

    Args:
        df: Source data.
        cols: Columns to difference.
        trial_col: Grouping column; defaults to ``TRIAL_COL``.
    """
    if trial_col is None:
        trial_col = TRIAL_COL
    keys = df[trial_col]
    out = {}
    for c in cols:
        d1 = df[c].groupby(keys, sort=False).diff()
        # The second difference must also be grouped. A plain Series.diff()
        # would subtract rows of different trials whenever trial rows are
        # interleaved in the frame.
        d2 = d1.groupby(keys, sort=False).diff()
        out[c + "_d1"] = d1.fillna(0.0).astype(np.float32)
        out[c + "_d2"] = d2.fillna(0.0).astype(np.float32)
    return pd.DataFrame(out, index=df.index)

def robust_clip_inplace(df, cols, by=None, k=3.5):
    """Clip ``cols`` of ``df`` in place to median +/- k * scaled MAD.

    Args:
        df: DataFrame to modify (also returned).
        cols: Columns to clip. An empty list is a no-op.
        by: Optional grouping column. When given and present, the statistics
            are computed and applied separately per group; otherwise they are
            computed over all rows.
        k: Half-width of the clipping band, in scaled MADs (1.4826 * MAD).

    Does nothing when ``CLEAN_LEVEL == 0``. Columns whose MAD is exactly zero
    (more than half of the values equal the median, e.g. flags or mostly
    constant channels) are left unclipped. A zero spread would otherwise squash
    every non-median value onto the median.
    """
    if CLEAN_LEVEL==0 or not cols: return df

    def _clip(sub):
        med = sub.median()
        raw_mad = (sub - med).abs().median() * 1.4826
        half_width = k * (raw_mad + 1e-9)
        # No measurable spread -> no clipping for that column.
        half_width = half_width.where(raw_mad > 0, np.inf)
        return sub.clip(med - half_width, med + half_width, axis=1)

    if by and by in df.columns:
        for _, gidx in df.groupby(by, sort=False).groups.items():
            df.loc[gidx, cols] = _clip(df.loc[gidx, cols])
    else:
        df.loc[:, cols] = _clip(df.loc[:, cols])
    return df

def _hampel_inplace_np(y: np.ndarray, k=5, t0=3.5):
    n = y.shape[0]
    if k <= 0 or n == 0: return
    kk = max(3, int(k))
    for i in range(kk, n-kk):
        w = y[i-kk:i+kk+1]
        med = np.median(w)
        mad = 1.4826*np.median(np.abs(w - med)) + 1e-9
        if abs(y[i] - med) > t0 * mad:
            y[i] = med

def hampel_targets_per_trial_inplace(df, tgt_cols, k=5, t0=3.5):
    """Hampel-filter every column in ``tgt_cols`` separately within each trial.

    Each trial/column series is copied to a float32 array, filtered with
    ``_hampel_inplace_np`` and written back into ``df``. ``df`` is modified in
    place and also returned. This is a no-op when cleaning is disabled
    (``CLEAN_LEVEL == 0``) or when ``k <= 0``.
    """
    if CLEAN_LEVEL == 0 or k <= 0:
        return df
    rows_by_trial = df.groupby(TRIAL_COL, sort=False).groups
    for trial_rows in rows_by_trial.values():
        for col in tgt_cols:
            series = df.loc[trial_rows, col].to_numpy(copy=True, dtype=np.float32)
            _hampel_inplace_np(series, k=k, t0=t0)
            df.loc[trial_rows, col] = series
    return df


def positional_encoding(n, d, device):
    """Return sinusoidal positional encodings of shape ``(1, n, d)``.

    Column ``2j`` holds ``sin(pos * w_j)`` and column ``2j+1`` holds
    ``cos(pos * w_j)``, with ``w_j = 10000 ** (-2j / d)``. Odd ``d`` is
    supported; in that case the last column is a sine term.
    """
    pe = torch.zeros(n, d, device=device)
    pos = torch.arange(0, n, dtype=torch.float, device=device).unsqueeze(1)
    div = torch.exp(torch.arange(0, d, 2, device=device).float() * (-np.log(10000.0) / d))
    angles = pos * div  # (n, ceil(d / 2))
    pe[:, 0::2] = torch.sin(angles)
    # With odd d there is one fewer cosine column than sine column.
    pe[:, 1::2] = torch.cos(angles[:, : d // 2])
    return pe.unsqueeze(0)

class SqueezeExcite1D(nn.Module):
    """Channel-wise squeeze-and-excitation gate for ``(batch, time, channels)`` inputs.

    The time axis is average-pooled into one descriptor per channel. That
    descriptor goes through a bottleneck MLP (width ``max(1, channels // r)``,
    GELU, sigmoid), and the resulting gate in (0, 1) rescales every time step
    of the corresponding channel. The output has the same shape as the input.
    """
    def __init__(self, channels: int, r: int = 8):
        super().__init__()
        hidden = max(1, channels // r)
        self.pool = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Sequential(
            nn.Linear(channels, hidden, bias=True),
            nn.GELU(),
            nn.Linear(hidden, channels, bias=True),
            nn.Sigmoid(),
        )

    def forward(self, x):
        channels_first = x.transpose(1, 2)                # (B, C, T)
        squeezed = self.pool(channels_first).squeeze(-1)  # (B, C)
        gate = self.fc(squeezed).unsqueeze(-1)            # (B, C, 1)
        return (channels_first * gate).transpose(1, 2)

class ConvStem(nn.Module):
    """Depthwise-separable 1-D convolution stem for ``(batch, time, in_ch)`` inputs.

    The pipeline is: depthwise conv (kernel ``k``) -> pointwise conv -> GELU ->
    BatchNorm -> Dropout -> SqueezeExcite1D channel gate. The output has the
    same shape as the input for any kernel size.
    """
    def __init__(self, in_ch, k=7, dropout=0.1):
        super().__init__()
        # Creation order and attribute names are kept stable: they fix parameter
        # initialisation under a seed and define the checkpoint keys.
        self.dw=nn.Conv1d(in_ch,in_ch,kernel_size=k,padding=k//2,groups=in_ch)
        self.pw=nn.Conv1d(in_ch,in_ch,kernel_size=1)
        self.act=nn.GELU(); self.bn=nn.BatchNorm1d(in_ch); self.do=nn.Dropout(dropout)
        self.se=SqueezeExcite1D(in_ch, r=8)

    def forward(self, x):
        n_steps = x.size(1)
        z = x.transpose(1, 2)  # (B, C, T) layout expected by Conv1d
        z = self.dw(z)
        # padding=k//2 keeps length T for odd k but gives T+1 for even k;
        # trim so the stem always preserves the sequence length (no-op for odd k).
        z = z[..., :n_steps]
        z = self.pw(z)
        z = self.do(self.bn(self.act(z)))
        return self.se(z.transpose(1, 2))

class TimePooledPatch(nn.Module):
    """Multi-scale overlapping patch tokeniser for ``(batch, time, feats)`` sequences.

    For each scale ``s``:
      * the sequence is cut into windows of ``size = max(4, seq_len // s)``
        steps, with stride ``max(1, round(size * (1 - overlap)))``;
      * each window is averaged over ``dim=2`` of the unfolded tensor;
      * the result is linearly projected to ``d_model``.
    Tokens from all scales are concatenated along the token axis. A scale whose
    window is longer than the input contributes no tokens.

    NOTE(review): ``x.unfold(1, size, step)`` has shape (B, n_patches, F, size),
    so ``mean(dim=2)`` averages over the *feature* axis, and each token is a
    length-``size`` time profile. If pooling over time was intended, this should
    be ``mean(dim=-1)`` with ``Linear(F, d_model)``. Confirm before changing,
    because it alters the architecture and the checkpoint shapes.
    """
    def __init__(self, seq_len, d_model, scales=(1,2,4), overlap=0.75):
        super().__init__()
        self.seq_len = int(seq_len)
        self.scales = tuple(scales)
        self.overlap = float(overlap)
        self.d_model = int(d_model)
        # Build the projections eagerly: their input width (the patch length) is
        # already known here. Lazily created layers would not exist yet when an
        # optimizer collects model.parameters(), so they would never be trained.
        # They would also be missing from freshly constructed models used to
        # reload checkpoints.
        self.proj = nn.ModuleList(
            [nn.Linear(self._patch_size(s), self.d_model) for s in self.scales]
        )
        self._inited = [True] * len(self.scales)

    def _patch_size(self, s):
        """Window length (time steps) used at scale ``s``."""
        return max(4, self.seq_len // s)

    def _ensure(self, i, in_feats, dev, dtype):
        """Safety net: rebuild projection ``i`` if its input width does not match."""
        if (not self._inited[i]) or (not isinstance(self.proj[i], nn.Linear)) or (self.proj[i].in_features != in_feats):
            self.proj[i] = nn.Linear(in_feats, self.d_model).to(device=dev, dtype=dtype)
            self._inited[i] = True

    def forward(self, x):
        B, T, _ = x.shape
        outs = []
        for i, s in enumerate(self.scales):
            size = self._patch_size(s)
            step = max(1, int(round(size * (1.0 - self.overlap))))
            if T < size:
                outs.append(torch.zeros(B, 0, self.d_model, device=x.device, dtype=x.dtype))
                continue
            patches = x.unfold(dimension=1, size=size, step=step)  # (B, n_patches, F, size)
            pooled = patches.mean(dim=2)                           # (B, n_patches, size)
            self._ensure(i, pooled.size(-1), pooled.device, pooled.dtype)
            outs.append(self.proj[i](pooled))
        return torch.cat(outs, dim=1)

class CrossAttention(nn.Module):
    """Multi-head attention in which the ``q`` tokens attend over ``kv`` (keys = values).

    Inputs are batch-first ``(batch, tokens, d_model)``. Only the attended
    output is returned; the attention weights are discarded.
    """
    def __init__(self, d_model, heads=8):
        super().__init__()
        self.mha = nn.MultiheadAttention(d_model, heads, batch_first=True)

    def forward(self, q, kv):
        attended = self.mha(q, kv, kv)[0]
        return attended

class MultiPatchFormerX_Reg(nn.Module):
    """Two-stream (short/long window) patch transformer regressor.

    Processing steps:
      1. Each window passes through its own ConvStem and multi-scale
         TimePooledPatch tokeniser.
      2. Tokens receive sinusoidal positions, then LayerNorm and dropout.
      3. The two streams exchange information through cross-attention in both
         directions, with residual connections.
      4. The streams are concatenated and encoded by a TransformerEncoder.
      5. The mean token is mapped to ``out_dim`` outputs by a small MLP head.

    ``forward(xs, xl)`` takes the short window ``xs`` and the long window
    ``xl``, each shaped ``(batch, time, in_feats)``.
    """
    def __init__(self, long_win, short_win, in_feats, d_model=160, nheads=8, nlayers=3,
                 dropout=0.2, patch_overlap=0.75, scales=(1,2,4), out_dim=2):
        super().__init__()
        # Submodules are created in a fixed order, which fixes parameter
        # initialisation under a seed; their attribute names are the checkpoint keys.
        stem_dropout = dropout * 0.5
        self.stem_long = ConvStem(in_feats, k=7, dropout=stem_dropout)
        self.stem_short = ConvStem(in_feats, k=5, dropout=stem_dropout)
        self.patch_long = TimePooledPatch(long_win, d_model, scales, patch_overlap)
        self.patch_short = TimePooledPatch(short_win, d_model, scales, patch_overlap)
        self.norm_long = nn.LayerNorm(d_model)
        self.norm_short = nn.LayerNorm(d_model)
        self.drop = nn.Dropout(dropout)
        self.cross_l2s = CrossAttention(d_model, nheads)
        self.cross_s2l = CrossAttention(d_model, nheads)
        layer = nn.TransformerEncoderLayer(
            d_model, nheads, d_model * 4,
            batch_first=True, dropout=dropout, activation='gelu',
        )
        self.encoder = nn.TransformerEncoder(layer, num_layers=nlayers)
        self.head = nn.Sequential(
            nn.Linear(d_model, 128), nn.GELU(), nn.Dropout(dropout), nn.Linear(128, out_dim),
        )

    def _embed(self, tokens, norm):
        """Add sinusoidal positions, apply ``norm``, then dropout."""
        positions = positional_encoding(tokens.size(1), tokens.size(2), tokens.device)
        return self.drop(norm(tokens + positions))

    def forward(self, xs, xl):
        # Keep the call order (stems, then patches, then embeddings): it fixes
        # the order in which dropout draws random numbers.
        feats_short = self.stem_short(xs)
        feats_long = self.stem_long(xl)
        tok_short = self.patch_short(feats_short)
        tok_long = self.patch_long(feats_long)
        tok_short = self._embed(tok_short, self.norm_short)
        tok_long = self._embed(tok_long, self.norm_long)
        # Bidirectional cross-attention with residual connections.
        fused_short = tok_short + self.cross_l2s(tok_short, tok_long)
        fused_long = tok_long + self.cross_s2l(tok_long, tok_short)
        encoded = self.encoder(torch.cat([fused_short, fused_long], dim=1))
        return self.head(encoded.mean(dim=1))

def weighted_huber_2axis(pred, y, thresh_vec, wx=1.0, wy=1.0, spike_weight=None):
    """Smooth-L1 (Huber, beta=1) loss with spike and per-axis weighting.

    Each element of the error is weighted by:
      * ``spike_weight`` where ``|y| >= thresh_vec`` (per axis), else 1;
      * times the axis weight (``wx`` for column 0, ``wy`` for column 1).
    The weights are normalised to mean 1 over the batch, so the loss stays on
    the scale of an unweighted smooth-L1 loss.

    Args:
        pred: Predictions, shape (batch, 2).
        y: Targets, same shape as ``pred``.
        thresh_vec: Per-axis spike thresholds, broadcastable to ``y``.
        wx: Weight for the X axis (column 0).
        wy: Weight for the Y axis (column 1).
        spike_weight: Weight for spike elements; defaults to ``SPIKE_WEIGHT``.

    Returns:
        Scalar loss tensor.
    """
    if spike_weight is None:
        spike_weight = SPIKE_WEIGHT
    base = nn.functional.smooth_l1_loss(pred, y, reduction='none')
    is_spike = y.abs() >= thresh_vec
    # Build the weights in the loss dtype instead of the default dtype.
    spike = torch.where(is_spike, torch.full_like(base, float(spike_weight)), torch.ones_like(base))
    w_axis = torch.tensor([wx, wy], device=pred.device, dtype=pred.dtype)
    w = spike * w_axis
    w = w / w.mean()
    return (base * w).mean()

def time_mask(x, L=6):
    """Zero one random contiguous span of ``L`` time steps per sample, in place.

    ``x`` is ``(batch, time, feats)``. Each sample gets its own start index,
    drawn uniformly from ``[0, T - L]``, and all features in the span are
    zeroed. Inputs with ``T <= L`` are returned untouched. The same (mutated)
    tensor is returned.
    """
    batch, steps, _ = x.shape
    if steps <= L:
        return x
    starts = torch.randint(0, steps - L + 1, (batch,), device=x.device)
    offsets = torch.arange(steps, device=x.device)
    inside = (offsets >= starts[:, None]) & (offsets < starts[:, None] + L)  # (batch, time)
    x[inside] = 0.0
    return x

class DualWinStreamDS(Dataset):
    """Sliding-window dataset of (long window, short window, target) samples.

    For every trial in ``df`` (restricted to rows where the boolean ``mask`` is
    True):
      * windows of ``long_win`` consecutive rows start every ``stride`` rows
        (values below 1 are treated as 1);
      * the short window is the last ``short_win`` rows of the long window;
      * the target row lies ``horizon`` rows after the last row of the long window.
    Windows whose target would fall past the end of the trial are skipped, as
    are trials with fewer than ``long_win + horizon`` rows.

    Each item is ``(Xl, Xs, y_norm, last_true_norm, trial_id, target_pos)``:
      * ``Xl`` / ``Xs``: float32 tensors of shape ``(long_win, n_inputs)`` /
        ``(short_win, n_inputs)``;
      * ``y_norm``: the ``y_cols_norm`` values at the target row;
      * ``last_true_norm``: the ``y_cols_norm`` values at the last row of the
        long window;
      * ``trial_id``: ``str(trial)``;
      * ``target_pos``: the target's row position within the masked trial.
    """

    def __init__(self, df, input_cols, y_cols_norm, trial_col, mask,
                 long_win, short_win, stride, horizon):
        self.df = df; self.input_cols = input_cols; self.y_cols_norm = y_cols_norm
        self.trial_col = trial_col
        self.long_win = int(long_win); self.short_win = int(short_win)
        self.stride = int(stride); self.horizon = int(horizon)
        self.index = []; self.trial_groups = {}
        # Per-trial float32 arrays, extracted once. This makes __getitem__ plain
        # array slicing instead of several pandas selections per sample, which
        # was the DataLoader hot path.
        self._x_arrays = {}; self._y_arrays = {}
        step = max(1, self.stride)
        for t, g in df.groupby(trial_col, sort=False):
            g2 = g[mask.loc[g.index]]
            if len(g2) < self.long_win + self.horizon:
                continue
            key = str(t)
            self.trial_groups[key] = g2
            self._x_arrays[key] = g2[self.input_cols].to_numpy(dtype=np.float32)
            self._y_arrays[key] = g2[self.y_cols_norm].to_numpy(dtype=np.float32)
            T = len(g2)
            n = 1 + (T - self.long_win) // step
            for i in range(n):
                s = i * step
                if s + self.long_win - 1 + self.horizon < T:
                    self.index.append((key, s))

    def __len__(self): return len(self.index)

    def __getitem__(self, i):
        tid, s = self.index[i]
        X = self._x_arrays[tid]; Y = self._y_arrays[tid]
        end = s + self.long_win
        tgt_pos = end - 1 + self.horizon
        # Copy, so that callers mutating a returned tensor cannot corrupt the cache.
        Xl = X[s:end].copy()
        Xs = Xl[-self.short_win:, :]
        y_norm = Y[tgt_pos].copy()
        last_true_norm = Y[end - 1].copy()
        return (torch.from_numpy(Xl), torch.from_numpy(Xs),
                torch.from_numpy(y_norm), torch.from_numpy(last_true_norm),
                tid, int(tgt_pos))

def predict_tta(model, xs, xl, n=5, noise=0.01):
    """Test-time augmentation: average ``n`` forward passes on noise-jittered inputs.

    On every pass, Gaussian noise with std ``noise`` is added independently to
    both windows (no noise when ``noise <= 0``). The passes run under
    ``torch.no_grad()``; the caller is responsible for putting ``model`` in
    eval mode.
    """
    def jitter(t):
        return t + (torch.randn_like(t) * noise if noise > 0 else 0.0)

    with torch.no_grad():
        outputs = [model(jitter(xs), jitter(xl)) for _ in range(n)]
    return torch.stack(outputs, 0).mean(0)

def measure_latency(model, loader, device, mode="plain",
                    warmup_steps=10, measure_steps=50,
                    tta_runs=5, tta_noise=0.01):
    """Measure average inference latency over whole batches.

    Puts ``model`` in eval mode. It first runs ``warmup_steps`` untimed
    batches, then times up to ``measure_steps`` batches. Host-to-device copies
    are excluded from the timing, and on CUDA the device is synchronised around
    each timed call. ``mode="plain"`` calls ``model(xs, xl)``; any other mode
    uses ``predict_tta`` with ``tta_runs`` / ``tta_noise``. Each loader item
    must start with ``(xl, xs, ...)``.

    Returns:
        ``(ms_per_item, items_per_s)``. When nothing could be timed (empty
        loader or zero elapsed time), both values are NaN. This lets callers
        format them with ``:.2f`` and test them with ``np.isnan`` safely.
    """
    import time as _time

    def _run(xs_b, xl_b):
        if mode == "plain":
            return model(xs_b, xl_b)
        return predict_tta(model, xs_b, xl_b, n=tta_runs, noise=tta_noise)

    def _sync():
        if device.type == 'cuda':
            torch.cuda.synchronize()

    model.eval()
    with torch.no_grad():
        for warmed, (xl_b, xs_b, *_rest) in enumerate(loader, 1):
            _run(xs_b.to(device), xl_b.to(device))
            if warmed >= warmup_steps:
                break

    total_time = 0.0; total_items = 0; steps = 0
    with torch.no_grad():
        for xl_b, xs_b, *_rest in loader:
            xl_b = xl_b.to(device); xs_b = xs_b.to(device)
            _sync()
            t0 = _time.perf_counter()
            _run(xs_b, xl_b)
            _sync()
            total_time += _time.perf_counter() - t0
            total_items += xs_b.size(0); steps += 1
            if steps >= measure_steps:
                break
    if total_items == 0 or total_time <= 0.0:
        nan = float("nan")
        return nan, nan
    ms_per_item = (total_time / total_items) * 1000.0
    items_per_s = total_items / total_time
    return ms_per_item, items_per_s

def measure_latency_per_sample(model, loader, device, mode="plain",
                               warmup_samples=20, measure_samples=50,
                               tta_runs=5, tta_noise=0.01):
    """Measure inference latency with batch size 1 (streaming-style).

    Puts ``model`` in eval mode. Batches from ``loader`` are split into single
    samples. It first runs ``warmup_samples`` untimed samples, then times up to
    ``measure_samples`` samples. Host-to-device copies are excluded from the
    timing, and on CUDA the device is synchronised around each timed call.
    ``mode="plain"`` calls ``model(xs, xl)``; any other mode uses
    ``predict_tta``. Each loader item must start with ``(xl, xs, ...)``.

    Returns:
        ``(ms_per_sample, samples_per_s)``. Both values are NaN when nothing
        could be timed (empty loader or zero elapsed time).
    """
    import time as _time

    def _run(xs, xl):
        if mode == "plain":
            return model(xs, xl)
        return predict_tta(model, xs, xl, n=tta_runs, noise=tta_noise)

    def _sync():
        if device.type == 'cuda':
            torch.cuda.synchronize()

    def _samples():
        # Yield (xs, xl) slices of batch size 1, keeping the batch dimension.
        for xl_b, xs_b, *_rest in loader:
            for i in range(xl_b.size(0)):
                yield xs_b[i:i+1], xl_b[i:i+1]

    model.eval()
    with torch.no_grad():
        for warmed, (xs, xl) in enumerate(_samples(), 1):
            _run(xs.to(device), xl.to(device))
            if warmed >= warmup_samples:
                break

    total_time = 0.0; total_items = 0
    with torch.no_grad():
        for xs, xl in _samples():
            xs = xs.to(device); xl = xl.to(device)
            _sync()
            t0 = _time.perf_counter()
            _run(xs, xl)
            _sync()
            total_time += _time.perf_counter() - t0
            total_items += 1
            if total_items >= measure_samples:
                break
    if total_items == 0 or total_time <= 0.0:
        nan = float("nan")
        return nan, nan
    ms_per_sample = (total_time / total_items) * 1000.0
    samples_per_s = total_items / total_time
    return ms_per_sample, samples_per_s


# ---------------------------------------------------------------------------
# Load the dataset, select the input features, clean inputs/targets and
# convert the window durations to sample counts.
# ---------------------------------------------------------------------------
set_seed()
df = pd.read_csv(CSV_PATH)
# Downcast every numeric column to float32 to reduce memory.
# NOTE(review): this also converts ID columns (TrialID / ParticipantID) to
# float32 if they are stored as integers; very large IDs could lose precision.
num_cols = [c for c in df.columns if pd.api.types.is_numeric_dtype(df[c])]
df[num_cols] = df[num_cols].astype(np.float32, copy=False)

if TARGET_X not in df.columns or TARGET_Y not in df.columns:
    raise ValueError("Missing newCopX / newCopY in CSV.")
# Name of the first available time column (parsed to datetime in place), or None.
tcol = find_time_col(df)
print(f"[LOAD] rows={len(df)} | trials={df[TRIAL_COL].nunique()}")

# NOTE(review): add_derived_feats is not defined in the visible source;
# enabling USE_DERIVED_FEATS raises NameError unless it is defined elsewhere.
if USE_DERIVED_FEATS:
    df = add_derived_feats(df)

base_inputs = find_inputs(df)

# Keep only numeric feature columns; the targets must be numeric as well.
numeric_cols=[c for c in df.columns if pd.api.types.is_numeric_dtype(df[c])]
feat_cols=[c for c in base_inputs if c in numeric_cols]
tgt_cols=[c for c in [TARGET_X,TARGET_Y] if c in numeric_cols]


# Robust clipping (per participant when that column exists), then Hampel
# filtering of the targets within each trial.
# NOTE(review): these statistics use every row, including trials that later
# become validation/test trials. Confirm that this mild leakage is acceptable.
if CLEAN_LEVEL>0:
    robust_clip_inplace(df, feat_cols+tgt_cols,
                        by=(PARTIC_COL if (PARTIC_COL and PARTIC_COL in df.columns) else None),
                        k=MAD_K)
    if HAMPEL_K>0:
        try:
            hampel_targets_per_trial_inplace(df, tgt_cols, k=HAMPEL_K, t0=HAMPEL_T0)
        except MemoryError:
            print("[WARN] Hampel OOM -> skipping Hampel.")

# Turn +/-inf into NaN so that the dropna below also removes those rows.
arr = df[feat_cols+tgt_cols].to_numpy(copy=False)
bad = ~np.isfinite(arr)
if bad.any():
    arr = arr.copy(); arr[bad] = np.nan
    df[feat_cols+tgt_cols] = arr
df = df.dropna(subset=feat_cols+tgt_cols).reset_index(drop=True)
print(f"[CLEAN] rows={len(df)} | trials={df[TRIAL_COL].nunique()} after cleaning")
print(f"[INFO] Using {len(feat_cols)} input features.")

# Sampling rate from the time column (50 Hz fallback) -> window sizes in samples.
fs = estimate_fs(df, tcol, default=50.0)
long_len  = seconds_to_samples(DES_LONG_SEC,  fs)
short_len = seconds_to_samples(DES_SHORT_SEC, fs)
stride    = max(1, seconds_to_samples(DES_HOP_SEC, fs))
horizon   = seconds_to_samples(DES_HOR_SEC, fs)

# NOTE(review): each "fold" in the loop below is an independent random
# per-participant trial split, not a disjoint k-fold partition.
unique_parts = sorted(df[PARTIC_COL].unique())
n_splits_cv = 10
print(f"[CV] Using {n_splits_cv}-fold CV, stratified by participant ({len(unique_parts)} participants).")


# ---------------------------------------------------------------------------
# Repeated train/val/test evaluation. Each fold:
#   * draws a fresh random trial split per participant;
#   * fits the normalisation on the training trials only;
#   * trains a model with early stopping (and optionally SWA);
#   * evaluates it on the test trials against a persistence baseline;
#   * plots the results and saves a checkpoint.
# ---------------------------------------------------------------------------
fold_metrics = []

for fold_idx in range(1, n_splits_cv + 1):
    # Per-fold NumPy seed: it drives the trial shuffle below and the
    # np.random-based augmentation decision during training.
    # NOTE(review): the torch RNG is not re-seeded per fold, so model
    # initialisation depends on how many folds ran before this one.
    np.random.seed(RANDOM_SEED + fold_idx)

    # For each participant: shuffle their trials, then use 5 for training, 1 for
    # validation and 1 for testing. Trials beyond the 7th are unused in this fold.
    train_trials = []; val_trials = []; test_trials = []
    for p in unique_parts:
        their_trials = df[df[PARTIC_COL] == p][TRIAL_COL].unique()
        if len(their_trials) < 7:
            raise ValueError(f"Participant {p} has <7 trials; adjust splitting.")
        np.random.shuffle(their_trials)
        train_trials.extend(their_trials[:5])
        val_trials.append(their_trials[5])
        test_trials.append(their_trials[6])

    print(f"\n FOLD {fold_idx}/{n_splits_cv}")
    print(f"Trials: train={len(train_trials)}, val={len(val_trials)}, test={len(test_trials)}")

    train_mask = df[TRIAL_COL].isin(train_trials)
    val_mask   = df[TRIAL_COL].isin(val_trials)
    test_mask  = df[TRIAL_COL].isin(test_trials)

    # Work on a copy so that this fold's normalisation does not touch df.
    df_fold = df.copy(deep=True)


    # Target normalisation: (value - train median) / train IQR, per axis.
    # The IQR is floored at 1e-6.
    train_targets = df_fold.loc[train_mask, [TARGET_X, TARGET_Y]]
    gmx = float(np.median(train_targets[TARGET_X])); gmy = float(np.median(train_targets[TARGET_Y]))
    gsx = float(np.subtract(*np.percentile(train_targets[TARGET_X], [75,25]))); gsx = max(gsx, 1e-6)
    gsy = float(np.subtract(*np.percentile(train_targets[TARGET_Y], [75,25]))); gsy = max(gsy, 1e-6)

    df_fold['tx_norm'] = ((df_fold[TARGET_X] - gmx) / gsx).astype(np.float32)
    df_fold['ty_norm'] = ((df_fold[TARGET_Y] - gmy) / gsy).astype(np.float32)
    ycols_norm = ['tx_norm','ty_norm']


    # Input scaling: fit on the training rows only, then apply to all rows.
    fold_input_cols = list(feat_cols)
    x_scaler = RobustScaler().fit(df_fold.loc[train_mask, fold_input_cols])
    df_fold[fold_input_cols] = x_scaler.transform(df_fold[fold_input_cols])


    train_ds = DualWinStreamDS(df_fold, fold_input_cols, ycols_norm, TRIAL_COL, train_mask,
                               long_len, short_len, stride, horizon)
    val_ds   = DualWinStreamDS(df_fold, fold_input_cols, ycols_norm, TRIAL_COL, val_mask,
                               long_len, short_len, stride, horizon) if val_mask.any() else None
    test_ds  = DualWinStreamDS(df_fold, fold_input_cols, ycols_norm, TRIAL_COL, test_mask,
                               long_len, short_len, stride, horizon)
    print(f"[WIN] train={len(train_ds)}, val={0 if val_ds is None else len(val_ds)}, test={len(test_ds)}")

    pin = (DEVICE.type=='cuda')
    train_loader=DataLoader(train_ds, batch_size=BATCH, shuffle=True,  pin_memory=pin)
    val_loader  =DataLoader(val_ds,   batch_size=BATCH, shuffle=False, pin_memory=pin) if (val_ds and len(val_ds)) else None
    test_loader =DataLoader(test_ds,  batch_size=BATCH, shuffle=False, pin_memory=pin)

    model=MultiPatchFormerX_Reg(long_win=long_len, short_win=short_len, in_feats=len(fold_input_cols),
                                d_model=D_MODEL, nheads=N_HEADS, nlayers=N_LAYERS,
                                dropout=DROPOUT, patch_overlap=PATCH_OVERLAP, scales=SCALES, out_dim=2).to(DEVICE)
    # NOTE(review): the optimizer only sees parameters that exist at this point.
    # Verify that no submodule creates parameters lazily during its first forward pass.
    opt=torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
    scheduler=torch.optim.lr_scheduler.ReduceLROnPlateau(opt, mode='min', factor=0.5, patience=3)

    # SWA placeholders. The real AveragedModel/SWALR are created once
    # swa_start_epoch is reached in the training loop.
    if USE_SWA:
        from torch.optim.swa_utils import AveragedModel, SWALR
        swa_model = nn.Identity(); swa_start_epoch = int(max(1, EPOCHS * SWA_START_FRAC))
        swa_scheduler = None


    # Spike thresholds: per-axis WEIGHT_Q quantile of |normalised target| over
    # the training rows. Targets at or above them get SPIKE_WEIGHT in the loss.
    if len(train_ds):
        train_targets_n = df_fold.loc[train_mask, ycols_norm].to_numpy(dtype=np.float32)
        thr_vals = np.quantile(np.abs(train_targets_n), WEIGHT_Q, axis=0)
        spike_thr = torch.tensor(thr_vals, dtype=torch.float32, device=DEVICE)
    else:
        spike_thr = torch.tensor([1.0,1.0], dtype=torch.float32, device=DEVICE)

    # Optional axis weighting: Y-axis errors weighted by IQR_y / IQR_x.
    wx = 1.0; wy = 1.0
    if AXIS_WEIGHT_FROM_TRAIN_IQR: wy = gsy / (gsx + 1e-6)


    # One forward pass to surface shape errors early.
    # NOTE(review): the model is still in train mode here, so this batch also
    # updates the BatchNorm running statistics.
    if len(train_ds):
        with torch.no_grad():
            xl_b, xs_b, y_b, last_b, _tid, _tidx = next(iter(train_loader))
            _=model(xs_b.to(DEVICE), xl_b.to(DEVICE))
        print("[DEBUG] Forward sanity check OK.")


    # Training loop: early stopping on validation MAE (in original target units).
    # The best weights are kept in best_state (as CPU tensors).
    best_val=np.inf; best_state=None; bad=0
    for epoch in range(1, EPOCHS+1):
        if not len(train_ds): break
        model.train(); run=0.0
        for xl_b, xs_b, y_b, last_b, _tid, _tidx in train_loader:
            xl_b=xl_b.to(DEVICE); xs_b=xs_b.to(DEVICE); y_b=y_b.to(DEVICE)
            # Augmentation: additive Gaussian noise, plus (with probability
            # AUG_TIMEMASK_P per batch) a zeroed time span in each window.
            if USE_AUG:
                if AUG_NOISE_STD>0:
                    xl_b = xl_b + torch.randn_like(xl_b)*AUG_NOISE_STD
                    xs_b = xs_b + torch.randn_like(xs_b)*AUG_NOISE_STD
                if np.random.rand() < AUG_TIMEMASK_P:
                    xl_b = time_mask(xl_b, L=AUG_TIMEMASK_LEN)
                    xs_b = time_mask(xs_b, L=AUG_TIMEMASK_LEN)
            opt.zero_grad(set_to_none=True)
            pred=model(xs_b, xl_b)
            loss=weighted_huber_2axis(pred, y_b, spike_thr, wx=wx, wy=wy)
            loss.backward(); torch.nn.utils.clip_grad_norm_(model.parameters(), CLIP_NORM)
            opt.step(); run+=float(loss.item())

        # Switch to SWA at the configured epoch.
        if USE_SWA and epoch == swa_start_epoch:
            swa_model = AveragedModel(model)
            swa_scheduler = SWALR(opt, anneal_strategy='cos', anneal_epochs=5, swa_lr=LR * SWA_LR_FACTOR)

        if val_loader and len(val_ds):
            model.eval(); ya=[]; yh=[]
            with torch.no_grad():
                for xl_b, xs_b, y_b, last_b, t_b, t_idx in val_loader:
                    xl_b=xl_b.to(DEVICE); xs_b=xs_b.to(DEVICE); y_b=y_b.to(DEVICE)
                    pred=model(xs_b, xl_b); ya.append(y_b.cpu().numpy()); yh.append(pred.cpu().numpy())
            # Undo the target normalisation before computing the MAE
            # (sklearn averages the two axes uniformly).
            y_true_n = np.concatenate(ya,0); y_pred_n = np.concatenate(yh,0)
            y_true = np.stack([y_true_n[:,0]*gsx+gmx, y_true_n[:,1]*gsy+gmy], axis=1)
            y_pred = np.stack([y_pred_n[:,0]*gsx+gmx, y_pred_n[:,1]*gsy+gmy], axis=1)
            v_mae  = mean_absolute_error(y_true, y_pred)

            # During SWA, average the weights and follow the SWA LR schedule;
            # before that, reduce the LR on a validation plateau.
            if USE_SWA and epoch >= swa_start_epoch:
                swa_model.update_parameters(model)
                if swa_scheduler is not None: swa_scheduler.step()
            else:
                scheduler.step(v_mae)

            print(f"[F{fold_idx} E{epoch:02d}] train_loss={run/max(1,len(train_loader)):.4f} | val MAE={v_mae:.3f}")
            if v_mae<best_val:
                best_val=v_mae; best_state={k:v.detach().cpu() for k,v in model.state_dict().items()}; bad=0
            else:
                bad+=1
                if bad>=PATIENCE: print("[EARLY STOP]"); break
        else:
            print(f"[F{fold_idx} E{epoch:02d}] train_loss={run/max(1,len(train_loader)):.4f}")
            if USE_SWA and epoch >= swa_start_epoch:
                swa_model.update_parameters(model)
                if swa_scheduler is not None: swa_scheduler.step()

    # Restore the best-validation weights. If SWA started, the averaged model
    # then replaces them.
    # NOTE(review): the SWA model replaces the best checkpoint even if its val
    # MAE is worse, and its BatchNorm running statistics are not recomputed
    # (torch.optim.swa_utils.update_bn is never called). Confirm this is intended.
    if best_state is not None and len(train_ds): model.load_state_dict(best_state)
    if USE_SWA and isinstance(swa_model, nn.Module) and not isinstance(swa_model, nn.Identity):
        model = swa_model


    # Latency benchmarks on the test loader (batched and batch-size-1).
    # NOTE(review): the :.2f formatting below fails if a latency helper
    # returns None (e.g. for an empty test loader).
    lat_plain_batch = lat_tta_batch = lat_plain_ps = lat_tta_ps = (np.nan, np.nan)
    if PRINT_LATENCY:
        lat_plain_batch = measure_latency(model, test_loader, DEVICE, mode="plain",
                                          warmup_steps=LAT_WARMUP_STEPS, measure_steps=LAT_MEASURE_STEPS,
                                          tta_runs=TTA_RUNS, tta_noise=TTA_NOISE)
        print(f"[LATENCY] Plain forward (batches): {lat_plain_batch[0]:.2f} ms/sample | {lat_plain_batch[1]:.1f} samples/s")
        lat_plain_ps = measure_latency_per_sample(model, test_loader, DEVICE, mode="plain",
                                                  warmup_samples=LAT_WARMUP_SAMPLES_PS, measure_samples=LAT_MEASURE_SAMPLES_PS,
                                                  tta_runs=TTA_RUNS, tta_noise=TTA_NOISE)
        print(f"[LATENCY] Plain per-sample: {lat_plain_ps[0]:.2f} ms/sample | {lat_plain_ps[1]:.1f} samples/s")
        if USE_TTA:
            lat_tta_batch = measure_latency(model, test_loader, DEVICE, mode="tta",
                                            warmup_steps=max(5, LAT_WARMUP_STEPS//2),
                                            measure_steps=max(10, LAT_MEASURE_STEPS//2),
                                            tta_runs=TTA_RUNS, tta_noise=TTA_NOISE)
            print(f"[LATENCY] TTA (batches): {lat_tta_batch[0]:.2f} ms/sample | {lat_tta_batch[1]:.1f} samples/s")
            lat_tta_ps = measure_latency_per_sample(model, test_loader, DEVICE, mode="tta",
                                                    warmup_samples=max(10, LAT_WARMUP_SAMPLES_PS//2),
                                                    measure_samples=max(30, LAT_MEASURE_SAMPLES_PS//2),
                                                    tta_runs=TTA_RUNS, tta_noise=TTA_NOISE)
            print(f"[LATENCY] TTA per-sample: {lat_tta_ps[0]:.2f} ms/sample | {lat_tta_ps[1]:.1f} samples/s")


    # Test-set inference. Also collects the last observed target of each
    # window (the persistence baseline) plus trial ids and target positions
    # for the plots.
    model.eval(); ya=[]; yh=[]; ylast=[]; tids=[]; tidxs=[]
    with torch.no_grad():
        for xl_b, xs_b, y_b, last_b, t_b, t_idx in test_loader:
            xl_b=xl_b.to(DEVICE); xs_b=xs_b.to(DEVICE)
            pred = predict_tta(model, xs_b, xl_b, n=TTA_RUNS, noise=TTA_NOISE) if USE_TTA else model(xs_b, xl_b)
            ya.append(y_b.cpu().numpy()); yh.append(pred.cpu().numpy()); ylast.append(last_b.cpu().numpy())
            tids += list(t_b); tidxs += list(t_idx.numpy())

    y_true_n = np.concatenate(ya,0) if ya else np.zeros((0,2),np.float32)
    y_pred_n = np.concatenate(yh,0) if yh else np.zeros((0,2),np.float32)
    y_last_n = np.concatenate(ylast,0) if ylast else np.zeros((0,2),np.float32)


    # Back to original target units.
    y_true = np.stack([y_true_n[:,0]*gsx+gmx, y_true_n[:,1]*gsy+gmy], axis=1)
    y_pred = np.stack([y_pred_n[:,0]*gsx+gmx, y_pred_n[:,1]*gsy+gmy], axis=1)
    y_pers = np.stack([y_last_n[:,0]*gsx+gmx, y_last_n[:,1]*gsy+gmy], axis=1)

    # Metrics
    # Per-axis metrics for the model and for the persistence baseline
    # (*_p); all NaN when the test split produced no windows.
    if len(y_true):
        mae_x=mean_absolute_error(y_true[:,0], y_pred[:,0]); mae_y=mean_absolute_error(y_true[:,1], y_pred[:,1])
        mse_x=mean_squared_error(y_true[:,0], y_pred[:,0]);  mse_y=mean_squared_error(y_true[:,1], y_pred[:,1])
        rmse_x=np.sqrt(mse_x); rmse_y=np.sqrt(mse_y)
        r2_x=r2_score(y_true[:,0], y_pred[:,0]); r2_y=r2_score(y_true[:,1], y_pred[:,1])

        mae_x_p=mean_absolute_error(y_true[:,0], y_pers[:,0]); mae_y_p=mean_absolute_error(y_true[:,1], y_pers[:,1])
        mse_x_p=mean_squared_error(y_true[:,0], y_pers[:,0]);  mse_y_p=mean_squared_error(y_true[:,1], y_pers[:,1])
        rmse_x_p=np.sqrt(mse_x_p); rmse_y_p=np.sqrt(mse_y_p)
        r2_x_p=r2_score(y_true[:,0], y_pers[:,0]); r2_y_p=r2_score(y_true[:,1], y_pers[:,1])
    else:
        mae_x=mae_y=mse_x=mse_y=rmse_x=rmse_y=r2_x=r2_y=np.nan
        mae_x_p=mae_y_p=mse_x_p=mse_y_p=rmse_x_p=rmse_y_p=r2_x_p=r2_y_p=np.nan

    print(f"[FOLD {fold_idx}] "
          f"MAE X={mae_x:.3f} Y={mae_y:.3f} | "
          f"MSE X={mse_x:.3f} Y={mse_y:.3f} | "
          f"RMSE X={rmse_x:.3f} Y={rmse_y:.3f} | "
          f"R^2 X={r2_x:.3f} Y={r2_y:.3f} || "
          f"[PERSIST] MAE X={mae_x_p:.3f} Y={mae_y_p:.3f} | "
          f"MSE X={mse_x_p:.3f} Y={mse_y_p:.3f} | "
          f"RMSE X={rmse_x_p:.3f} Y={rmse_y_p:.3f} | "
          f"R^2 X={r2_x_p:.3f} Y={r2_y_p:.3f}")

    fold_metrics.append({
        "mae_x": mae_x, "mae_y": mae_y,
        "mse_x": mse_x, "mse_y": mse_y,
        "rmse_x": rmse_x, "rmse_y": rmse_y,
        "r2_x": r2_x, "r2_y": r2_y,
        "mae_x_p": mae_x_p, "mae_y_p": mae_y_p,
        "mse_x_p": mse_x_p, "mse_y_p": mse_y_p,
        "rmse_x_p": rmse_x_p, "rmse_y_p": rmse_y_p,
        "r2_x_p": r2_x_p, "r2_y_p": r2_y_p,
        "lat_plain_ps_ms": lat_plain_ps[0], "lat_plain_ps_sps": lat_plain_ps[1],
        "lat_tta_ps_ms": lat_tta_ps[0] if USE_TTA else np.nan,
        "lat_tta_ps_sps": lat_tta_ps[1] if USE_TTA else np.nan
    })


    if len(y_true):

        # True-vs-predicted scatter per axis, with the identity line.
        fig, ax = plt.subplots(1, 2, figsize=(11, 4))
        labels = ['X', 'Y']
        for i in range(2):
            ax[i].scatter(y_true[:, i], y_pred[:, i], s=6, alpha=0.5)
            lo = float(min(y_true[:, i].min(), y_pred[:, i].min()))
            hi = float(max(y_true[:, i].max(), y_pred[:, i].max()))
            ax[i].plot([lo, hi], [lo, hi], 'k--', lw=1)
            ax[i].set_xlabel('True'); ax[i].set_ylabel('Pred')
            ax[i].set_title(f'Fold {fold_idx}: newCoP {labels[i]} — True vs Pred')
        plt.tight_layout(); plt.show()


        # Group test predictions by trial for the time-series plots.
        series = {}
        for i, (tid, ti) in enumerate(zip(tids, tidxs)):
            if tid not in series: series[tid] = {'t': [], 'tx': [], 'px': [], 'ty': [], 'py': []}
            series[tid]['t'].append(ti)
            series[tid]['tx'].append(y_true[i,0]); series[tid]['px'].append(y_pred[i,0])
            series[tid]['ty'].append(y_true[i,1]); series[tid]['py'].append(y_pred[i,1])

        # x axis: target row position within the trial divided by fs
        # (seconds since the first row of the trial).
        plot_trials = [str(t) for t in test_trials if str(t) in series]
        for tid in plot_trials:
            s = series[tid]
            order = np.argsort(s['t'])
            t_sec = (np.asarray(s['t'])[order] / fs).astype(float)
            tx = np.asarray(s['tx'])[order]; px = np.asarray(s['px'])[order]
            ty = np.asarray(s['ty'])[order]; py = np.asarray(s['py'])[order]

            plt.figure(figsize=(10,4))
            plt.plot(t_sec, tx, label='True newCopX')
            plt.plot(t_sec, px, label='Pred newCopX')
            plt.title(f'Fold {fold_idx} | Trial {tid}: Time Series newCopX')
            plt.xlabel('Time (s)'); plt.ylabel('newCopX'); plt.legend(); plt.tight_layout(); plt.show()

            plt.figure(figsize=(10,4))
            plt.plot(t_sec, ty, label='True newCopY')
            plt.plot(t_sec, py, label='Pred newCopY')
            plt.title(f'Fold {fold_idx} | Trial {tid}: Time Series newCopY')
            plt.xlabel('Time (s)'); plt.ylabel('newCopY'); plt.legend(); plt.tight_layout(); plt.show()


    # Checkpoint: copy the (possibly SWA-wrapped) weights into a plain CPU model.
    # The "module." prefix added by AveragedModel is stripped.
    # NOTE(review): strict=False silently skips non-matching keys, such as
    # AveragedModel's n_averaged buffer, but it would equally hide genuinely
    # missing weights.
    save_model = MultiPatchFormerX_Reg(
        long_win=long_len, short_win=short_len, in_feats=len(fold_input_cols),
        d_model=D_MODEL, nheads=N_HEADS, nlayers=N_LAYERS,
        dropout=DROPOUT, patch_overlap=PATCH_OVERLAP, scales=SCALES, out_dim=2
    )
    raw_state = model.state_dict()
    clean_state = {k.replace("module.", ""): v.cpu() for k, v in raw_state.items()}
    save_model.load_state_dict(clean_state, strict=False)
    save_model.cpu().eval()
    # NOTE(review): x_scaler is a pickled scikit-learn object. Loading this file
    # needs torch.load(..., weights_only=False) on PyTorch versions that
    # default to weights_only=True.
    ckpt = {
        "model_state": save_model.state_dict(),
        "model_cfg": {
            "long_win": long_len, "short_win": short_len, "in_feats": len(fold_input_cols),
            "d_model": D_MODEL, "nheads": N_HEADS, "nlayers": N_LAYERS, "dropout": DROPOUT,
            "patch_overlap": PATCH_OVERLAP, "scales": SCALES, "out_dim": 2
        },
        "input_cols": fold_input_cols,
        "x_scaler": x_scaler,
        "global_target_stats": {"gmx": gmx, "gmy": gmy, "gsx": gsx, "gsy": gsy},
        "fs": fs, "horizon": horizon, "stride": stride,
        "ycols_norm": ycols_norm, "seed": RANDOM_SEED, "fold_idx": fold_idx
    }
    ckpt_path = CKPT_DIR / f"mpf_xreg_fair_fold{fold_idx}.pt"
    torch.save(ckpt, ckpt_path)
    print(f"[SAVE] Checkpoint saved to: {ckpt_path}")


def mean_std(arr):
    """Return ``(mean, sample std)`` of the finite entries of ``arr``.

    NaN values (e.g. metrics of folds with no test windows) and +/-inf values
    are excluded from both statistics, so the two describe the same values.
    The mean is NaN when no finite value exists. The standard deviation uses
    ddof=1 and is NaN unless at least two finite values exist. No NumPy
    "empty slice" warnings are emitted.
    """
    values = np.asarray(arr, dtype=float).ravel()
    values = values[np.isfinite(values)]
    mu = float(values.mean()) if values.size else float("nan")
    sd = float(values.std(ddof=1)) if values.size > 1 else float("nan")
    return mu, sd

# Gather each model metric across folds. NaN entries come from folds whose
# test split produced no windows.
mae_x_vals  = [m["mae_x"]  for m in fold_metrics]
mae_y_vals  = [m["mae_y"]  for m in fold_metrics]
mse_x_vals  = [m["mse_x"]  for m in fold_metrics]
mse_y_vals  = [m["mse_y"]  for m in fold_metrics]
rmse_x_vals = [m["rmse_x"] for m in fold_metrics]
rmse_y_vals = [m["rmse_y"] for m in fold_metrics]
r2_x_vals   = [m["r2_x"]   for m in fold_metrics]
r2_y_vals   = [m["r2_y"]   for m in fold_metrics]

def report(name, vals):
    """Print one summary line: ``name`` left-aligned to 22 chars, then mean ± std of ``vals``."""
    mu, sd = mean_std(vals)
    print(f"{name:<22}: {mu:.3f} ± {sd:.3f}")

# Cross-fold mean ± std of the model metrics (the persistence-baseline
# metrics are stored in fold_metrics but are not summarised here).
print("\n=== Cross-Fold Summary (Model) ===")
report("MAE X", mae_x_vals);   report("MAE Y", mae_y_vals)
report("MSE X", mse_x_vals);   report("MSE Y", mse_y_vals)
report("RMSE X", rmse_x_vals); report("RMSE Y", rmse_y_vals)
report("R^2 X", r2_x_vals);    report("R^2 Y", r2_y_vals)


# Per-fold batch-size-1 latency. The plain values print as "nan" when
# latency measurement was disabled; the TTA line only appears when a TTA
# latency was recorded.
for i, m in enumerate(fold_metrics, 1):
    print(f"[FOLD {i}] Latency per-sample (plain): {m['lat_plain_ps_ms']:.2f} ms/sample | {m['lat_plain_ps_sps']:.1f} samples/s")
    if not np.isnan(m['lat_tta_ps_ms']):
        print(f"[FOLD {i}] Latency per-sample (TTA):   {m['lat_tta_ps_ms']:.2f} ms/sample | {m['lat_tta_ps_sps']:.1f} samples/s")