<a href="https://colab.research.google.com/github/Salehin555/MOE-RUL-Prediction/blob/main/RUL_PREDICTION.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

FD001

In [85]:
# ============================================================
# FINAL Interpretable MOE for C-MAPSS RUL (ONE-CELL Colab)
# Fixes:
#  - Noisy Top-K gating + load balancing + entropy bonus (Sub-Obj 1)
#  - Physics constraints: monotonic + smooth (Sub-Obj 2)
#  - UQ: NLL + MC Dropout + PI calibration (Sub-Obj 3)
#  - Transfer-ready structure (Sub-Obj 4 hooks; single-domain run here)
# ============================================================

# ---------------------------
# USER SETTINGS
# ---------------------------
DATA_DIR = "/content"
PRETRAIN_DOMAINS = ["FD001"]
TARGETS = ["FD001"]
DO_FINETUNE = False
ENSEMBLE_SIZE = 1

# ---------------------------
# Install deps
# ---------------------------
import sys, subprocess, importlib
def _ensure(pkg, import_name=None):
    name = import_name or pkg
    try:
        importlib.import_module(name)
    except Exception:
        subprocess.check_call([sys.executable, "-m", "pip", "-q", "install", pkg])

_ensure("scikit-learn", "sklearn")
_ensure("scipy", "scipy")

# ---------------------------
# Imports
# ---------------------------
import os, math, random
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Sampler

from sklearn.preprocessing import StandardScaler
from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error
from scipy.stats import norm

# ---------------------------
# Config
# ---------------------------
@dataclass
class CFG:
    window: int = 30
    stride: int = 1
    max_rul: int = 125
    normalize_y: bool = True
    val_ratio_units: float = 0.2
    seed: int = 42

    drop_low_var: bool = True
    low_var_thresh: float = 1e-6

    n_experts: int = 4
    top_k: int = 2
    enc_hidden: int = 192
    head_hidden: int = 192
    dropout: float = 0.08
    gate_dropout: float = 0.05

    gate_noise_max: float = 1.0
    gate_noise_min: float = 0.15
    temp_max: float = 2.0
    temp_min: float = 0.7

    epochs: int = 85
    batch_size: int = 128
    block_len: int = 12
    lr: float = 2e-3
    weight_decay: float = 1e-4
    grad_clip: float = 1.0
    early_stop_patience: int = 14

    aux_mse_weight: float = 0.55
    huber_weight: float = 0.15
    huber_delta: float = 0.08  # in normalized y scale

    lambda_mono: float = 0.10
    lambda_smooth: float = 0.03
    lambda_lb: float = 0.25
    lambda_ent: float = 0.02
    lambda_div: float = 0.01
    lambda_dead: float = 0.06
    dead_floor: float = 0.03

    warmup_epochs: int = 12
    ramp_epochs: int = 20

    mc_samples: int = 30
    pi_alpha: float = 0.10  # 90% PI

cfg = CFG()

# ---------------------------
# Utils
# ---------------------------
def set_seed(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
set_seed(cfg.seed)

def ensure_exists(path: str):
    if not os.path.exists(path):
        raise FileNotFoundError(f"Missing file: {path}")

def base_feature_columns() -> List[str]:
    return ["op1", "op2", "op3"] + [f"s{i}" for i in range(1, 22)]

def ramp(ep, warmup, ramp_len):
    if ep <= warmup: return 0.0
    return float(min(1.0, (ep - warmup) / max(1, ramp_len)))

def lerp(a, b, t):
    return a + (b - a) * t

# ---------------------------
# C-MAPSS load
# ---------------------------
def load_cmapss_split(data_dir: str, fd: str):
    train_file = os.path.join(data_dir, f"train_{fd}.txt")
    test_file  = os.path.join(data_dir, f"test_{fd}.txt")
    rul_file   = os.path.join(data_dir, f"RUL_{fd}.txt")
    ensure_exists(train_file); ensure_exists(test_file); ensure_exists(rul_file)

    train_df = pd.read_csv(train_file, sep=r"\s+", header=None)
    test_df  = pd.read_csv(test_file,  sep=r"\s+", header=None)
    rul_df   = pd.read_csv(rul_file,   sep=r"\s+", header=None)

    cols = ["unit","cycle","op1","op2","op3"] + [f"s{i}" for i in range(1,22)]
    train_df.columns = cols
    test_df.columns  = cols
    rul_df.columns   = ["RUL_last"]
    return train_df, test_df, rul_df

def add_rul_train(df: pd.DataFrame, max_rul: int):
    df = df.copy()
    max_cycle = df.groupby("unit")["cycle"].max()
    df["RUL"] = df.apply(lambda r: max_cycle.loc[r["unit"]] - r["cycle"], axis=1)
    df["RUL"] = df["RUL"].clip(upper=max_rul)
    return df

def add_rul_test(test_df: pd.DataFrame, rul_df: pd.DataFrame, max_rul: int):
    df = test_df.copy()
    max_cycle = df.groupby("unit")["cycle"].max()
    rul_last = rul_df["RUL_last"].values
    mapping = {u: rul_last[u-1] for u in sorted(df["unit"].unique())}
    df["RUL"] = df.apply(lambda r: (max_cycle.loc[r["unit"]] + mapping[r["unit"]]) - r["cycle"], axis=1)
    df["RUL"] = df["RUL"].clip(upper=max_rul)
    return df

def make_windows(df, window, stride, feature_cols, target_col="RUL"):
    xs, ys, units, cycles = [], [], [], []
    for unit, g in df.groupby("unit"):
        g = g.sort_values("cycle")
        feats = g[feature_cols].values.astype(np.float32)
        targ  = g[target_col].values.astype(np.float32)
        cyc   = g["cycle"].values.astype(np.int32)

        for end in range(window-1, len(g), stride):
            start = end-window+1
            xs.append(feats[start:end+1])
            ys.append(targ[end])
            units.append(unit)
            cycles.append(cyc[end])
    return np.stack(xs), np.array(ys), np.array(units), np.array(cycles)

def last_window_per_unit(X, y, unit_ids, cycles):
    idx=[]
    for u in np.unique(unit_ids):
        m = unit_ids==u
        i = np.argmax(cycles[m])
        idx.append(np.where(m)[0][i])
    idx=np.array(idx)
    return X[idx], y[idx], unit_ids[idx], cycles[idx]

def split_by_units(unit_ids, val_ratio, seed):
    rng=np.random.default_rng(seed)
    units=np.unique(unit_ids)
    rng.shuffle(units)
    n_val=max(1,int(len(units)*val_ratio))
    val=set(units[:n_val].tolist())
    tr=np.array([i for i,u in enumerate(unit_ids) if u not in val])
    va=np.array([i for i,u in enumerate(unit_ids) if u in val])
    return tr, va

# ---------------------------
# Dataset + contiguous sampler
# ---------------------------
class RULWindowDataset(Dataset):
    def __init__(self, X, y, unit_ids, cycles, domain_ids):
        self.X = torch.tensor(X, dtype=torch.float32)
        self.y = torch.tensor(y, dtype=torch.float32)
        self.unit_ids = np.array(unit_ids)
        self.cycles = np.array(cycles)
        self.domain_ids = torch.tensor(domain_ids, dtype=torch.long)

    def __len__(self): return len(self.X)
    def __getitem__(self, idx):
        return self.X[idx], self.y[idx], int(self.unit_ids[idx]), int(self.cycles[idx]), self.domain_ids[idx]

class ContiguousEngineBatchSampler(Sampler):
    def __init__(self, unit_ids, cycles, batch_size, block_len, shuffle=True):
        self.unit_ids=np.array(unit_ids)
        self.cycles=np.array(cycles)
        self.batch_size=batch_size
        self.block_len=block_len
        self.shuffle=shuffle
        self.units=np.unique(self.unit_ids)
        self.unit_to_sorted={}
        for u in self.units:
            idx=np.where(self.unit_ids==u)[0]
            idx=idx[np.argsort(self.cycles[idx])]
            self.unit_to_sorted[u]=idx
        self.engines_per_batch=max(1,batch_size//block_len)

    def __iter__(self):
        units=self.units.copy()
        if self.shuffle: np.random.shuffle(units)
        batch=[]
        for u in units:
            idx=self.unit_to_sorted[u]
            L=len(idx)
            if L<=self.block_len:
                take=np.random.choice(idx,size=self.block_len,replace=True)
            else:
                s=np.random.randint(0,L-self.block_len)
                take=idx[s:s+self.block_len]
            batch.extend(take.tolist())
            if len(batch)>=self.engines_per_batch*self.block_len:
                yield batch[:self.batch_size]
                batch=[]
        if batch: yield batch

    def __len__(self):
        return math.ceil(len(self.units)/self.engines_per_batch)

# ---------------------------
# Physics losses
# ---------------------------
def physics_losses(pred_mean, unit_ids, cycles, margin=0.0):
    unit_ids=np.array(unit_ids); cycles=np.array(cycles)
    mono_terms=[]; smooth_terms=[]
    for u in np.unique(unit_ids):
        idx=np.where(unit_ids==u)[0]
        if len(idx)<2: continue
        ord_idx=idx[np.argsort(cycles[idx])]
        p=pred_mean[ord_idx]
        mono=F.relu(p[1:]-p[:-1]+margin)
        mono_terms.append(mono.mean())
        if len(ord_idx)>=3:
            second=p[2:]-2*p[1:-1]+p[:-2]
            smooth_terms.append(torch.abs(second).mean())
    mono_loss=torch.stack(mono_terms).mean() if mono_terms else pred_mean.new_tensor(0.)
    smooth_loss=torch.stack(smooth_terms).mean() if smooth_terms else pred_mean.new_tensor(0.)
    return mono_loss, smooth_loss

# ---------------------------
# Gating + load balance
# ---------------------------
class GatingNet(nn.Module):
    def __init__(self, in_dim, n_experts, gate_dropout):
        super().__init__()
        h=max(64,in_dim//2)
        self.net=nn.Sequential(
            nn.Linear(in_dim,h),
            nn.ReLU(),
            nn.Dropout(gate_dropout),
            nn.Linear(h,n_experts),
        )
    def forward(self,x): return self.net(x)

def sparse_topk_softmax(logits, k, temperature):
    logits = logits / max(1e-6, temperature)
    topk_vals, topk_idx = torch.topk(logits, k=k, dim=-1)
    masked = torch.full_like(logits, float("-inf"))
    masked.scatter_(dim=-1, index=topk_idx, src=topk_vals)
    w = F.softmax(masked, dim=-1)
    return w, topk_idx

def switch_load_balance_loss(w, top1_idx):
    B,E=w.shape
    importance=w.sum(dim=0)/(B+1e-8)
    load=torch.bincount(top1_idx,minlength=E).float().to(w.device)/(B+1e-8)
    return E*torch.sum(importance*load)

def dead_expert_penalty(w, floor):
    avg = w.mean(dim=0)
    return F.relu(floor - avg).mean()

# ---------------------------
# UQ + calibration
# ---------------------------
def gaussian_nll(y, mean, log_var):
    return 0.5*(torch.exp(-log_var)*(y-mean)**2 + log_var)

def prediction_interval(mean, var, alpha):
    z = norm.ppf(1-alpha/2)
    std=np.sqrt(np.maximum(var,1e-8))
    return mean - z*std, mean + z*std

def coverage(y, lo, hi):
    return float(np.mean((y>=lo)&(y<=hi)))

# ✅ FIXED predict_mc: safe unpacking for 7-value forward()
@torch.no_grad()
def predict_mc(model, X, n_mc, batch_size=256):
    model.train()
    dl=DataLoader(torch.tensor(X,dtype=torch.float32),batch_size=batch_size,shuffle=False)
    means_all=[]; vars_all=[]
    for xb in dl:
        xb=xb.to(DEVICE)
        mc_m=[]; mc_v=[]
        for _ in range(n_mc):
            out = model(xb, use_topk=True, noise_std=0.0, temperature=1.0, train=True)
            mean, log_var = out[0], out[1]
            mc_m.append(mean)
            mc_v.append(torch.exp(log_var))
        mc_m=torch.stack(mc_m,dim=0)
        mc_v=torch.stack(mc_v,dim=0)
        mean_pred=mc_m.mean(dim=0)
        epistemic=mc_m.var(dim=0,unbiased=False)
        aleatoric=mc_v.mean(dim=0)
        total=aleatoric+epistemic
        means_all.append(mean_pred.cpu().numpy())
        vars_all.append(total.cpu().numpy())
    return np.concatenate(means_all), np.concatenate(vars_all)

def calibrate_scale_min_width(model, X_val, y_val_raw, alpha, n_mc, max_over=0.01):
    mean, var = predict_mc(model, X_val, n_mc=n_mc)
    if cfg.normalize_y:
        mean = mean * cfg.max_rul
        var  = var  * (cfg.max_rul**2)
    std = np.sqrt(np.maximum(var,1e-8))
    z = norm.ppf(1-alpha/2)
    target = 1-alpha
    scales = np.linspace(0.6, 3.0, 121)
    best=None
    for s in scales:
        lo = mean - z*s*std
        hi = mean + z*s*std
        cov = np.mean((y_val_raw>=lo)&(y_val_raw<=hi))
        if cov >= (target - max_over):
            best=s
            break
    return float(best if best is not None else scales[-1])

# ---------------------------
# Model
# ---------------------------
class SharedEncoder(nn.Module):
    def __init__(self, n_features, hidden, dropout):
        super().__init__()
        self.conv1=nn.Conv1d(n_features,96,3,padding=1)
        self.conv2=nn.Conv1d(96,96,3,padding=1)
        self.bn1=nn.BatchNorm1d(96)
        self.bn2=nn.BatchNorm1d(96)
        self.gru=nn.GRU(96,hidden,batch_first=True)
        self.drop=nn.Dropout(dropout)
    def forward(self,x):
        x=x.transpose(1,2)
        x=F.relu(self.bn1(self.conv1(x)))
        x=F.relu(self.bn2(self.conv2(x)))
        x=x.transpose(1,2)
        _,h=self.gru(x)
        return self.drop(h[-1])

class ExpertHead(nn.Module):
    def __init__(self, in_dim, hidden, base_dropout, expert_id):
        super().__init__()
        width = hidden if expert_id%2==0 else max(64,hidden//2)
        drop  = min(0.30, base_dropout + 0.03*expert_id)
        self.emb = nn.Parameter(torch.randn(16)*0.02)
        self.net = nn.Sequential(
            nn.Linear(in_dim+16, width), nn.ReLU(), nn.Dropout(drop),
            nn.Linear(width, max(32,width//2)), nn.ReLU(), nn.Dropout(drop),
        )
        out=max(32,width//2)
        self.mu=nn.Linear(out,1)
        self.logv=nn.Linear(out,1)
    def forward(self,z):
        B=z.size(0)
        e=self.emb.unsqueeze(0).expand(B,-1)
        h=self.net(torch.cat([z,e],dim=-1))
        mu=self.mu(h).squeeze(-1)
        logv=self.logv(h).squeeze(-1).clamp(-9,4)
        return mu, logv

class InterpretableMoE(nn.Module):
    def __init__(self, n_features):
        super().__init__()
        self.enc = SharedEncoder(n_features,cfg.enc_hidden,cfg.dropout)
        self.experts = nn.ModuleList([ExpertHead(cfg.enc_hidden,cfg.head_hidden,cfg.dropout,i) for i in range(cfg.n_experts)])
        self.gate = GatingNet(cfg.enc_hidden+9, cfg.n_experts, cfg.gate_dropout)

    def forward(self, x, *, use_topk, noise_std, temperature, train):
        z = self.enc(x)
        op = x[:,:, :3]
        gate_in = torch.cat([z, op[:,-1,:], op.mean(dim=1), op.std(dim=1)], dim=-1)

        logits = self.gate(gate_in)
        if train and noise_std>0:
            logits = logits + torch.randn_like(logits)*noise_std

        if use_topk:
            w, topk_idx = sparse_topk_softmax(logits, cfg.top_k, temperature)
            top1 = topk_idx[:,0]
        else:
            w = F.softmax(logits/max(1e-6,temperature), dim=-1)
            top1 = torch.argmax(w, dim=-1)

        mus=[]; vars_=[]
        for ex in self.experts:
            mu, logv = ex(z)
            mus.append(mu); vars_.append(torch.exp(logv))
        mus=torch.stack(mus,dim=-1)
        vars_=torch.stack(vars_,dim=-1)

        mean=torch.sum(w*mus,dim=-1)
        second=torch.sum(w*(vars_+mus**2),dim=-1)
        var=(second-mean**2).clamp_min(1e-6)
        log_var=torch.log(var)

        lb = switch_load_balance_loss(w, top1)
        ent = -(w*torch.log(w+1e-8)).sum(dim=-1).mean()
        dead = dead_expert_penalty(w, cfg.dead_floor)

        div=0.0
        E=mus.shape[-1]
        for i in range(E):
            for j in range(i+1,E):
                a=mus[:,i]-mus[:,i].mean()
                b=mus[:,j]-mus[:,j].mean()
                div += torch.abs(F.cosine_similarity(a.unsqueeze(-1), b.unsqueeze(-1), dim=-1)).mean()
        div = div / max(1,(E*(E-1)//2))

        return mean, log_var, w, lb, ent, div, dead

# ---------------------------
# Prepare domain
# ---------------------------
def prepare_domain(fd, scaler=None, kept_cols=None):
    tr, te, rul = load_cmapss_split(DATA_DIR, fd)
    tr=add_rul_train(tr,cfg.max_rul)
    te=add_rul_test(te,rul,cfg.max_rul)

    all_cols = base_feature_columns()
    if kept_cols is None:
        kept_cols = all_cols
        if cfg.drop_low_var:
            v = tr[all_cols].var(axis=0).values
            kept_cols = [c for c,vv in zip(all_cols,v) if vv>cfg.low_var_thresh]
            for c in ["op1","op2","op3"]:
                if c not in kept_cols:
                    kept_cols = ["op1","op2","op3"] + [x for x in kept_cols if x not in ["op1","op2","op3"]]

    X_tr_all, y_tr_all, u_tr_all, c_tr_all = make_windows(tr, cfg.window, cfg.stride, kept_cols)
    X_te_all, y_te_all, u_te_all, c_te_all = make_windows(te, cfg.window, cfg.stride, kept_cols)
    X_te_last, y_te_last, u_te_last, c_te_last = last_window_per_unit(X_te_all, y_te_all, u_te_all, c_te_all)

    if scaler is None:
        scaler = StandardScaler()
        scaler.fit(X_tr_all.reshape(-1, X_tr_all.shape[-1]))

    X_tr_all = scaler.transform(X_tr_all.reshape(-1,X_tr_all.shape[-1])).reshape(X_tr_all.shape)
    X_te_last = scaler.transform(X_te_last.reshape(-1,X_te_last.shape[-1])).reshape(X_te_last.shape)

    tr_idx, va_idx = split_by_units(u_tr_all, cfg.val_ratio_units, cfg.seed)
    X_tr, y_tr, u_tr, c_tr = X_tr_all[tr_idx], y_tr_all[tr_idx], u_tr_all[tr_idx], c_tr_all[tr_idx]
    X_va, y_va, u_va, c_va = X_tr_all[va_idx], y_tr_all[va_idx], u_tr_all[va_idx], c_tr_all[va_idx]

    y_te_raw = y_te_last.copy()

    if cfg.normalize_y:
        y_tr = y_tr/cfg.max_rul
        y_va = y_va/cfg.max_rul
        y_te = y_te_last/cfg.max_rul
    else:
        y_te = y_te_last

    return {
        "scaler": scaler, "kept_cols": kept_cols,
        "train": (X_tr,y_tr,u_tr,c_tr),
        "val": (X_va,y_va,u_va,c_va),
        "test_last": (X_te_last,y_te,u_te_last,c_te_last),
        "test_last_y_raw": y_te_raw
    }

# ---------------------------
# Interpretability
# ---------------------------
@torch.no_grad()
def gate_stats(model, X, n_show=1024):
    model.eval()
    xb=torch.tensor(X[:n_show],dtype=torch.float32).to(DEVICE)
    out = model(xb, use_topk=True, noise_std=0.0, temperature=1.0, train=False)
    w = out[2].cpu().numpy()
    print("Gate usage (avg weights):", np.round(w.mean(axis=0),4))

# ---------------------------
# Training
# ---------------------------
def train_on_domain(domain):
    X_tr,y_tr,u_tr,c_tr = domain["train"]
    X_va,y_va,u_va,c_va = domain["val"]

    model = InterpretableMoE(n_features=X_tr.shape[-1]).to(DEVICE)

    ds_tr = RULWindowDataset(X_tr,y_tr,u_tr,c_tr,domain_ids=np.zeros(len(X_tr),dtype=int))
    ds_va = RULWindowDataset(X_va,y_va,u_va,c_va,domain_ids=np.zeros(len(X_va),dtype=int))

    sampler = ContiguousEngineBatchSampler(ds_tr.unit_ids, ds_tr.cycles, cfg.batch_size, cfg.block_len, shuffle=True)
    dl_tr = DataLoader(ds_tr, batch_sampler=sampler)
    dl_va = DataLoader(ds_va, batch_size=cfg.batch_size, shuffle=False)

    opt = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=cfg.epochs)

    best_val=float("inf"); best_state=None; bad=0

    for ep in range(1,cfg.epochs+1):
        s = ramp(ep, cfg.warmup_epochs, cfg.ramp_epochs)
        use_topk = (ep > cfg.warmup_epochs)
        noise_std = lerp(cfg.gate_noise_max, cfg.gate_noise_min, s)
        temperature = lerp(cfg.temp_max, cfg.temp_min, s)

        mono_w   = cfg.lambda_mono*s
        smooth_w = cfg.lambda_smooth*s
        lb_w     = cfg.lambda_lb*s
        ent_w    = cfg.lambda_ent*s
        div_w    = cfg.lambda_div*s
        dead_w   = cfg.lambda_dead*s

        model.train()
        tr_loss=0.0
        for xb,yb,ub,cb,db in dl_tr:
            xb=xb.to(DEVICE); yb=yb.to(DEVICE)
            opt.zero_grad(set_to_none=True)

            mean, logv, w, lb, ent, div, dead = model(
                xb, use_topk=use_topk, noise_std=noise_std, temperature=temperature, train=True
            )

            nll = gaussian_nll(yb, mean, logv).mean()
            mse = F.mse_loss(mean, yb)
            hub = F.huber_loss(mean, yb, delta=cfg.huber_delta)
            mono, smooth = physics_losses(mean, ub, cb)

            if ep <= cfg.warmup_epochs:
                loss = 1.0*mse + 0.10*nll + cfg.huber_weight*hub
            else:
                loss = nll + cfg.aux_mse_weight*mse + cfg.huber_weight*hub

            loss = loss + mono_w*mono + smooth_w*smooth + lb_w*lb - ent_w*ent + div_w*div + dead_w*dead

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip)
            opt.step()
            tr_loss += float(loss.detach().cpu())

        sched.step()

        # val loss + quick metrics
        model.eval()
        va_loss=0.0
        with torch.no_grad():
            for xb,yb,ub,cb,db in dl_va:
                xb=xb.to(DEVICE); yb=yb.to(DEVICE)
                mean, logv, *_ = model(xb, use_topk=True, noise_std=0.0, temperature=1.0, train=False)
                nll = gaussian_nll(yb, mean, logv).mean()
                mse = F.mse_loss(mean, yb)
                hub = F.huber_loss(mean, yb, delta=cfg.huber_delta)
                va_loss += float((nll + cfg.aux_mse_weight*mse + cfg.huber_weight*hub).cpu())
        va_loss /= max(1,len(dl_va))

        yhat_val, _ = predict_mc(model, X_va, n_mc=1)
        if cfg.normalize_y:
            yhat_val = yhat_val*cfg.max_rul
            y_val_raw = y_va*cfg.max_rul
        else:
            y_val_raw = y_va
        r2 = r2_score(y_val_raw, yhat_val)
        rmse = math.sqrt(mean_squared_error(y_val_raw, yhat_val))
        print(f"[Epoch {ep:02d}] s={s:.2f} topk={int(use_topk)} val={va_loss:.4f} | R2={r2:.4f} RMSE={rmse:.3f}")

        if va_loss + 1e-6 < best_val:
            best_val=va_loss
            best_state={k:v.detach().cpu().clone() for k,v in model.state_dict().items()}
            bad=0
        else:
            bad+=1
            if bad>=cfg.early_stop_patience:
                print(f"Early stopping (best val={best_val:.4f})")
                break

    if best_state is not None:
        model.load_state_dict(best_state)
    return model

# ---------------------------
# RUN
# ---------------------------
for fd in set(PRETRAIN_DOMAINS + TARGETS):
    for f in [f"train_{fd}.txt", f"test_{fd}.txt", f"RUL_{fd}.txt"]:
        ensure_exists(os.path.join(DATA_DIR, f))

domains={}
scaler=None
kept=None
for fd in PRETRAIN_DOMAINS:
    domains[fd]=prepare_domain(fd, scaler=scaler, kept_cols=kept)
    scaler=domains[fd]["scaler"]
    kept=domains[fd]["kept_cols"]
for fd in TARGETS:
    if fd not in domains:
        domains[fd]=prepare_domain(fd, scaler=scaler, kept_cols=kept)

models=[]
for mi in range(ENSEMBLE_SIZE):
    print("\n==============================")
    print(f" Training model {mi+1}/{ENSEMBLE_SIZE}")
    print("==============================")
    set_seed(cfg.seed + mi)
    models.append(train_on_domain(domains[PRETRAIN_DOMAINS[0]]))

for fd in TARGETS:
    print("\n==============================")
    print(f" Target {fd} | LAST window per unit")
    print("==============================")

    X_te, y_te, *_ = domains[fd]["test_last"]
    y_te_raw = domains[fd]["test_last_y_raw"]
    X_va, y_va, *_ = domains[fd]["val"]
    y_va_raw = (y_va*cfg.max_rul) if cfg.normalize_y else y_va

    mean, var = predict_mc(models[0], X_te, n_mc=cfg.mc_samples)
    if cfg.normalize_y:
        mean_raw = mean*cfg.max_rul
        var_raw  = var*(cfg.max_rul**2)
    else:
        mean_raw = mean; var_raw = var

    scale = calibrate_scale_min_width(models[0], X_va, y_va_raw, cfg.pi_alpha, cfg.mc_samples, max_over=0.01)
    var_raw = (scale**2)*var_raw

    r2 = r2_score(y_te_raw, mean_raw)
    rmse = math.sqrt(mean_squared_error(y_te_raw, mean_raw))
    mae = mean_absolute_error(y_te_raw, mean_raw)

    lo, hi = prediction_interval(mean_raw, var_raw, cfg.pi_alpha)
    cov = coverage(y_te_raw, lo, hi)

    print(f"Point: R2={r2:.4f} RMSE={rmse:.3f} MAE={mae:.3f}")
    print(f"UQ: {int((1-cfg.pi_alpha)*100)}% PI coverage={cov:.3f} | mean width={(hi-lo).mean():.3f} | cal_scale={scale:.3f}")

    gate_stats(models[0], X_te)

print("\n✅ Done.")



 Training model 1/1
[Epoch 01] s=0.00 topk=0 val=0.1561 | R2=0.5270 RMSE=28.475
[Epoch 02] s=0.00 topk=0 val=-0.1708 | R2=0.6182 RMSE=25.583
[Epoch 03] s=0.00 topk=0 val=-0.3969 | R2=0.6275 RMSE=25.270
[Epoch 04] s=0.00 topk=0 val=-0.7160 | R2=0.7017 RMSE=22.614
[Epoch 05] s=0.00 topk=0 val=-0.8247 | R2=0.5785 RMSE=26.882
[Epoch 06] s=0.00 topk=0 val=-0.7446 | R2=0.4484 RMSE=30.750
[Epoch 07] s=0.00 topk=0 val=-1.2552 | R2=0.7263 RMSE=21.662
[Epoch 08] s=0.00 topk=0 val=-1.3979 | R2=0.7389 RMSE=21.156
[Epoch 09] s=0.00 topk=0 val=1.6415 | R2=-0.0068 RMSE=41.544
[Epoch 10] s=0.00 topk=0 val=-1.1568 | R2=0.6992 RMSE=22.709
[Epoch 11] s=0.00 topk=0 val=-1.0664 | R2=0.5933 RMSE=26.403
[Epoch 12] s=0.00 topk=0 val=-1.1788 | R2=0.7430 RMSE=20.990
[Epoch 13] s=0.05 topk=1 val=-1.4030 | R2=0.7206 RMSE=21.886
[Epoch 14] s=0.10 topk=1 val=-0.3117 | R2=-0.1598 RMSE=44.588
[Epoch 15] s=0.15 topk=1 val=-1.3494 | R2=0.7115 RMSE=22.237
[Epoch 16] s=0.20 topk=1 val=-1.3317 | R2=0.6758 RMSE=23.575
[Ep

FD002

In [95]:
# ============================================================
# FINAL Interpretable MOE for C-MAPSS RUL (ONE-CELL Colab)
# Fixes:
#  - Noisy Top-K gating + load balancing + entropy bonus (Sub-Obj 1)
#  - Physics constraints: monotonic + smooth (Sub-Obj 2)
#  - UQ: NLL + MC Dropout + PI calibration (Sub-Obj 3)
#  - Transfer-ready structure (Sub-Obj 4 hooks; single-domain run here)
# ============================================================

# ---------------------------
# USER SETTINGS
# ---------------------------
DATA_DIR = "/content"
PRETRAIN_DOMAINS = ["FD002"]
TARGETS = ["FD002"]
DO_FINETUNE = False
ENSEMBLE_SIZE = 1

# ---------------------------
# Install deps
# ---------------------------
import sys, subprocess, importlib
def _ensure(pkg, import_name=None):
    name = import_name or pkg
    try:
        importlib.import_module(name)
    except Exception:
        subprocess.check_call([sys.executable, "-m", "pip", "-q", "install", pkg])

_ensure("scikit-learn", "sklearn")
_ensure("scipy", "scipy")

# ---------------------------
# Imports
# ---------------------------
import os, math, random
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Sampler

from sklearn.preprocessing import StandardScaler
from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error
from scipy.stats import norm

# ---------------------------
# Config
# ---------------------------
@dataclass
class CFG:
    window: int = 30
    stride: int = 1
    max_rul: int = 125
    normalize_y: bool = True
    val_ratio_units: float = 0.2
    seed: int = 42

    drop_low_var: bool = True
    low_var_thresh: float = 1e-6

    n_experts: int = 4
    top_k: int = 2
    enc_hidden: int = 192
    head_hidden: int = 192
    dropout: float = 0.08
    gate_dropout: float = 0.05

    gate_noise_max: float = 1.0
    gate_noise_min: float = 0.15
    temp_max: float = 2.0
    temp_min: float = 0.7

    epochs: int = 85
    batch_size: int = 128
    block_len: int = 12
    lr: float = 2e-3
    weight_decay: float = 1e-4
    grad_clip: float = 1.0
    early_stop_patience: int = 14

    aux_mse_weight: float = 0.55
    huber_weight: float = 0.15
    huber_delta: float = 0.08  # in normalized y scale

    lambda_mono: float = 0.10
    lambda_smooth: float = 0.03
    lambda_lb: float = 0.25
    lambda_ent: float = 0.02
    lambda_div: float = 0.01
    lambda_dead: float = 0.06
    dead_floor: float = 0.03

    warmup_epochs: int = 12
    ramp_epochs: int = 20

    mc_samples: int = 30
    pi_alpha: float = 0.10  # 90% PI

cfg = CFG()

# ---------------------------
# Utils
# ---------------------------
def set_seed(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
set_seed(cfg.seed)

def ensure_exists(path: str):
    if not os.path.exists(path):
        raise FileNotFoundError(f"Missing file: {path}")

def base_feature_columns() -> List[str]:
    return ["op1", "op2", "op3"] + [f"s{i}" for i in range(1, 22)]

def ramp(ep, warmup, ramp_len):
    if ep <= warmup: return 0.0
    return float(min(1.0, (ep - warmup) / max(1, ramp_len)))

def lerp(a, b, t):
    return a + (b - a) * t

# ---------------------------
# C-MAPSS load
# ---------------------------
def load_cmapss_split(data_dir: str, fd: str):
    train_file = os.path.join(data_dir, f"train_{fd}.txt")
    test_file  = os.path.join(data_dir, f"test_{fd}.txt")
    rul_file   = os.path.join(data_dir, f"RUL_{fd}.txt")
    ensure_exists(train_file); ensure_exists(test_file); ensure_exists(rul_file)

    train_df = pd.read_csv(train_file, sep=r"\s+", header=None)
    test_df  = pd.read_csv(test_file,  sep=r"\s+", header=None)
    rul_df   = pd.read_csv(rul_file,   sep=r"\s+", header=None)

    cols = ["unit","cycle","op1","op2","op3"] + [f"s{i}" for i in range(1,22)]
    train_df.columns = cols
    test_df.columns  = cols
    rul_df.columns   = ["RUL_last"]
    return train_df, test_df, rul_df

def add_rul_train(df: pd.DataFrame, max_rul: int):
    df = df.copy()
    max_cycle = df.groupby("unit")["cycle"].max()
    df["RUL"] = df.apply(lambda r: max_cycle.loc[r["unit"]] - r["cycle"], axis=1)
    df["RUL"] = df["RUL"].clip(upper=max_rul)
    return df

def add_rul_test(test_df: pd.DataFrame, rul_df: pd.DataFrame, max_rul: int):
    df = test_df.copy()
    max_cycle = df.groupby("unit")["cycle"].max()
    rul_last = rul_df["RUL_last"].values
    mapping = {u: rul_last[u-1] for u in sorted(df["unit"].unique())}
    df["RUL"] = df.apply(lambda r: (max_cycle.loc[r["unit"]] + mapping[r["unit"]]) - r["cycle"], axis=1)
    df["RUL"] = df["RUL"].clip(upper=max_rul)
    return df

def make_windows(df, window, stride, feature_cols, target_col="RUL"):
    xs, ys, units, cycles = [], [], [], []
    for unit, g in df.groupby("unit"):
        g = g.sort_values("cycle")
        feats = g[feature_cols].values.astype(np.float32)
        targ  = g[target_col].values.astype(np.float32)
        cyc   = g["cycle"].values.astype(np.int32)

        for end in range(window-1, len(g), stride):
            start = end-window+1
            xs.append(feats[start:end+1])
            ys.append(targ[end])
            units.append(unit)
            cycles.append(cyc[end])
    return np.stack(xs), np.array(ys), np.array(units), np.array(cycles)

def last_window_per_unit(X, y, unit_ids, cycles):
    idx=[]
    for u in np.unique(unit_ids):
        m = unit_ids==u
        i = np.argmax(cycles[m])
        idx.append(np.where(m)[0][i])
    idx=np.array(idx)
    return X[idx], y[idx], unit_ids[idx], cycles[idx]

def split_by_units(unit_ids, val_ratio, seed):
    rng=np.random.default_rng(seed)
    units=np.unique(unit_ids)
    rng.shuffle(units)
    n_val=max(1,int(len(units)*val_ratio))
    val=set(units[:n_val].tolist())
    tr=np.array([i for i,u in enumerate(unit_ids) if u not in val])
    va=np.array([i for i,u in enumerate(unit_ids) if u in val])
    return tr, va

# ---------------------------
# Dataset + contiguous sampler
# ---------------------------
class RULWindowDataset(Dataset):
    def __init__(self, X, y, unit_ids, cycles, domain_ids):
        self.X = torch.tensor(X, dtype=torch.float32)
        self.y = torch.tensor(y, dtype=torch.float32)
        self.unit_ids = np.array(unit_ids)
        self.cycles = np.array(cycles)
        self.domain_ids = torch.tensor(domain_ids, dtype=torch.long)

    def __len__(self): return len(self.X)
    def __getitem__(self, idx):
        return self.X[idx], self.y[idx], int(self.unit_ids[idx]), int(self.cycles[idx]), self.domain_ids[idx]

class ContiguousEngineBatchSampler(Sampler):
    def __init__(self, unit_ids, cycles, batch_size, block_len, shuffle=True):
        self.unit_ids=np.array(unit_ids)
        self.cycles=np.array(cycles)
        self.batch_size=batch_size
        self.block_len=block_len
        self.shuffle=shuffle
        self.units=np.unique(self.unit_ids)
        self.unit_to_sorted={}
        for u in self.units:
            idx=np.where(self.unit_ids==u)[0]
            idx=idx[np.argsort(self.cycles[idx])]
            self.unit_to_sorted[u]=idx
        self.engines_per_batch=max(1,batch_size//block_len)

    def __iter__(self):
        units=self.units.copy()
        if self.shuffle: np.random.shuffle(units)
        batch=[]
        for u in units:
            idx=self.unit_to_sorted[u]
            L=len(idx)
            if L<=self.block_len:
                take=np.random.choice(idx,size=self.block_len,replace=True)
            else:
                s=np.random.randint(0,L-self.block_len)
                take=idx[s:s+self.block_len]
            batch.extend(take.tolist())
            if len(batch)>=self.engines_per_batch*self.block_len:
                yield batch[:self.batch_size]
                batch=[]
        if batch: yield batch

    def __len__(self):
        return math.ceil(len(self.units)/self.engines_per_batch)

# ---------------------------
# Physics losses
# ---------------------------
def physics_losses(pred_mean, unit_ids, cycles, margin=0.0):
    unit_ids=np.array(unit_ids); cycles=np.array(cycles)
    mono_terms=[]; smooth_terms=[]
    for u in np.unique(unit_ids):
        idx=np.where(unit_ids==u)[0]
        if len(idx)<2: continue
        ord_idx=idx[np.argsort(cycles[idx])]
        p=pred_mean[ord_idx]
        mono=F.relu(p[1:]-p[:-1]+margin)
        mono_terms.append(mono.mean())
        if len(ord_idx)>=3:
            second=p[2:]-2*p[1:-1]+p[:-2]
            smooth_terms.append(torch.abs(second).mean())
    mono_loss=torch.stack(mono_terms).mean() if mono_terms else pred_mean.new_tensor(0.)
    smooth_loss=torch.stack(smooth_terms).mean() if smooth_terms else pred_mean.new_tensor(0.)
    return mono_loss, smooth_loss

# ---------------------------
# Gating + load balance
# ---------------------------
class GatingNet(nn.Module):
    def __init__(self, in_dim, n_experts, gate_dropout):
        super().__init__()
        h=max(64,in_dim//2)
        self.net=nn.Sequential(
            nn.Linear(in_dim,h),
            nn.ReLU(),
            nn.Dropout(gate_dropout),
            nn.Linear(h,n_experts),
        )
    def forward(self,x): return self.net(x)

def sparse_topk_softmax(logits, k, temperature):
    logits = logits / max(1e-6, temperature)
    topk_vals, topk_idx = torch.topk(logits, k=k, dim=-1)
    masked = torch.full_like(logits, float("-inf"))
    masked.scatter_(dim=-1, index=topk_idx, src=topk_vals)
    w = F.softmax(masked, dim=-1)
    return w, topk_idx

def switch_load_balance_loss(w, top1_idx):
    B,E=w.shape
    importance=w.sum(dim=0)/(B+1e-8)
    load=torch.bincount(top1_idx,minlength=E).float().to(w.device)/(B+1e-8)
    return E*torch.sum(importance*load)

def dead_expert_penalty(w, floor):
    avg = w.mean(dim=0)
    return F.relu(floor - avg).mean()

# ---------------------------
# UQ + calibration
# ---------------------------
def gaussian_nll(y, mean, log_var):
    return 0.5*(torch.exp(-log_var)*(y-mean)**2 + log_var)

def prediction_interval(mean, var, alpha):
    z = norm.ppf(1-alpha/2)
    std=np.sqrt(np.maximum(var,1e-8))
    return mean - z*std, mean + z*std

def coverage(y, lo, hi):
    return float(np.mean((y>=lo)&(y<=hi)))

# ✅ FIXED predict_mc: safe unpacking for 7-value forward()
@torch.no_grad()
def predict_mc(model, X, n_mc, batch_size=256):
    model.train()
    dl=DataLoader(torch.tensor(X,dtype=torch.float32),batch_size=batch_size,shuffle=False)
    means_all=[]; vars_all=[]
    for xb in dl:
        xb=xb.to(DEVICE)
        mc_m=[]; mc_v=[]
        for _ in range(n_mc):
            out = model(xb, use_topk=True, noise_std=0.0, temperature=1.0, train=True)
            mean, log_var = out[0], out[1]
            mc_m.append(mean)
            mc_v.append(torch.exp(log_var))
        mc_m=torch.stack(mc_m,dim=0)
        mc_v=torch.stack(mc_v,dim=0)
        mean_pred=mc_m.mean(dim=0)
        epistemic=mc_m.var(dim=0,unbiased=False)
        aleatoric=mc_v.mean(dim=0)
        total=aleatoric+epistemic
        means_all.append(mean_pred.cpu().numpy())
        vars_all.append(total.cpu().numpy())
    return np.concatenate(means_all), np.concatenate(vars_all)

def calibrate_scale_min_width(model, X_val, y_val_raw, alpha, n_mc, max_over=0.01):
    mean, var = predict_mc(model, X_val, n_mc=n_mc)
    if cfg.normalize_y:
        mean = mean * cfg.max_rul
        var  = var  * (cfg.max_rul**2)
    std = np.sqrt(np.maximum(var,1e-8))
    z = norm.ppf(1-alpha/2)
    target = 1-alpha
    scales = np.linspace(0.6, 3.0, 121)
    best=None
    for s in scales:
        lo = mean - z*s*std
        hi = mean + z*s*std
        cov = np.mean((y_val_raw>=lo)&(y_val_raw<=hi))
        if cov >= (target - max_over):
            best=s
            break
    return float(best if best is not None else scales[-1])

# ---------------------------
# Model
# ---------------------------
class SharedEncoder(nn.Module):
    def __init__(self, n_features, hidden, dropout):
        super().__init__()
        self.conv1=nn.Conv1d(n_features,96,3,padding=1)
        self.conv2=nn.Conv1d(96,96,3,padding=1)
        self.bn1=nn.BatchNorm1d(96)
        self.bn2=nn.BatchNorm1d(96)
        self.gru=nn.GRU(96,hidden,batch_first=True)
        self.drop=nn.Dropout(dropout)
    def forward(self,x):
        x=x.transpose(1,2)
        x=F.relu(self.bn1(self.conv1(x)))
        x=F.relu(self.bn2(self.conv2(x)))
        x=x.transpose(1,2)
        _,h=self.gru(x)
        return self.drop(h[-1])

class ExpertHead(nn.Module):
    def __init__(self, in_dim, hidden, base_dropout, expert_id):
        super().__init__()
        width = hidden if expert_id%2==0 else max(64,hidden//2)
        drop  = min(0.30, base_dropout + 0.03*expert_id)
        self.emb = nn.Parameter(torch.randn(16)*0.02)
        self.net = nn.Sequential(
            nn.Linear(in_dim+16, width), nn.ReLU(), nn.Dropout(drop),
            nn.Linear(width, max(32,width//2)), nn.ReLU(), nn.Dropout(drop),
        )
        out=max(32,width//2)
        self.mu=nn.Linear(out,1)
        self.logv=nn.Linear(out,1)
    def forward(self,z):
        B=z.size(0)
        e=self.emb.unsqueeze(0).expand(B,-1)
        h=self.net(torch.cat([z,e],dim=-1))
        mu=self.mu(h).squeeze(-1)
        logv=self.logv(h).squeeze(-1).clamp(-9,4)
        return mu, logv

class InterpretableMoE(nn.Module):
    def __init__(self, n_features):
        super().__init__()
        self.enc = SharedEncoder(n_features,cfg.enc_hidden,cfg.dropout)
        self.experts = nn.ModuleList([ExpertHead(cfg.enc_hidden,cfg.head_hidden,cfg.dropout,i) for i in range(cfg.n_experts)])
        self.gate = GatingNet(cfg.enc_hidden+9, cfg.n_experts, cfg.gate_dropout)

    def forward(self, x, *, use_topk, noise_std, temperature, train):
        z = self.enc(x)
        op = x[:,:, :3]
        gate_in = torch.cat([z, op[:,-1,:], op.mean(dim=1), op.std(dim=1)], dim=-1)

        logits = self.gate(gate_in)
        if train and noise_std>0:
            logits = logits + torch.randn_like(logits)*noise_std

        if use_topk:
            w, topk_idx = sparse_topk_softmax(logits, cfg.top_k, temperature)
            top1 = topk_idx[:,0]
        else:
            w = F.softmax(logits/max(1e-6,temperature), dim=-1)
            top1 = torch.argmax(w, dim=-1)

        mus=[]; vars_=[]
        for ex in self.experts:
            mu, logv = ex(z)
            mus.append(mu); vars_.append(torch.exp(logv))
        mus=torch.stack(mus,dim=-1)
        vars_=torch.stack(vars_,dim=-1)

        mean=torch.sum(w*mus,dim=-1)
        second=torch.sum(w*(vars_+mus**2),dim=-1)
        var=(second-mean**2).clamp_min(1e-6)
        log_var=torch.log(var)

        lb = switch_load_balance_loss(w, top1)
        ent = -(w*torch.log(w+1e-8)).sum(dim=-1).mean()
        dead = dead_expert_penalty(w, cfg.dead_floor)

        div=0.0
        E=mus.shape[-1]
        for i in range(E):
            for j in range(i+1,E):
                a=mus[:,i]-mus[:,i].mean()
                b=mus[:,j]-mus[:,j].mean()
                div += torch.abs(F.cosine_similarity(a.unsqueeze(-1), b.unsqueeze(-1), dim=-1)).mean()
        div = div / max(1,(E*(E-1)//2))

        return mean, log_var, w, lb, ent, div, dead

# ---------------------------
# Prepare domain
# ---------------------------
def prepare_domain(fd, scaler=None, kept_cols=None):
    tr, te, rul = load_cmapss_split(DATA_DIR, fd)
    tr=add_rul_train(tr,cfg.max_rul)
    te=add_rul_test(te,rul,cfg.max_rul)

    all_cols = base_feature_columns()
    if kept_cols is None:
        kept_cols = all_cols
        if cfg.drop_low_var:
            v = tr[all_cols].var(axis=0).values
            kept_cols = [c for c,vv in zip(all_cols,v) if vv>cfg.low_var_thresh]
            for c in ["op1","op2","op3"]:
                if c not in kept_cols:
                    kept_cols = ["op1","op2","op3"] + [x for x in kept_cols if x not in ["op1","op2","op3"]]

    X_tr_all, y_tr_all, u_tr_all, c_tr_all = make_windows(tr, cfg.window, cfg.stride, kept_cols)
    X_te_all, y_te_all, u_te_all, c_te_all = make_windows(te, cfg.window, cfg.stride, kept_cols)
    X_te_last, y_te_last, u_te_last, c_te_last = last_window_per_unit(X_te_all, y_te_all, u_te_all, c_te_all)

    if scaler is None:
        scaler = StandardScaler()
        scaler.fit(X_tr_all.reshape(-1, X_tr_all.shape[-1]))

    X_tr_all = scaler.transform(X_tr_all.reshape(-1,X_tr_all.shape[-1])).reshape(X_tr_all.shape)
    X_te_last = scaler.transform(X_te_last.reshape(-1,X_te_last.shape[-1])).reshape(X_te_last.shape)

    tr_idx, va_idx = split_by_units(u_tr_all, cfg.val_ratio_units, cfg.seed)
    X_tr, y_tr, u_tr, c_tr = X_tr_all[tr_idx], y_tr_all[tr_idx], u_tr_all[tr_idx], c_tr_all[tr_idx]
    X_va, y_va, u_va, c_va = X_tr_all[va_idx], y_tr_all[va_idx], u_tr_all[va_idx], c_tr_all[va_idx]

    y_te_raw = y_te_last.copy()

    if cfg.normalize_y:
        y_tr = y_tr/cfg.max_rul
        y_va = y_va/cfg.max_rul
        y_te = y_te_last/cfg.max_rul
    else:
        y_te = y_te_last

    return {
        "scaler": scaler, "kept_cols": kept_cols,
        "train": (X_tr,y_tr,u_tr,c_tr),
        "val": (X_va,y_va,u_va,c_va),
        "test_last": (X_te_last,y_te,u_te_last,c_te_last),
        "test_last_y_raw": y_te_raw
    }

# ---------------------------
# Interpretability
# ---------------------------
@torch.no_grad()
def gate_stats(model, X, n_show=1024):
    model.eval()
    xb=torch.tensor(X[:n_show],dtype=torch.float32).to(DEVICE)
    out = model(xb, use_topk=True, noise_std=0.0, temperature=1.0, train=False)
    w = out[2].cpu().numpy()
    print("Gate usage (avg weights):", np.round(w.mean(axis=0),4))

# ---------------------------
# Training
# ---------------------------
def train_on_domain(domain):
    X_tr,y_tr,u_tr,c_tr = domain["train"]
    X_va,y_va,u_va,c_va = domain["val"]

    model = InterpretableMoE(n_features=X_tr.shape[-1]).to(DEVICE)

    ds_tr = RULWindowDataset(X_tr,y_tr,u_tr,c_tr,domain_ids=np.zeros(len(X_tr),dtype=int))
    ds_va = RULWindowDataset(X_va,y_va,u_va,c_va,domain_ids=np.zeros(len(X_va),dtype=int))

    sampler = ContiguousEngineBatchSampler(ds_tr.unit_ids, ds_tr.cycles, cfg.batch_size, cfg.block_len, shuffle=True)
    dl_tr = DataLoader(ds_tr, batch_sampler=sampler)
    dl_va = DataLoader(ds_va, batch_size=cfg.batch_size, shuffle=False)

    opt = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=cfg.epochs)

    best_val=float("inf"); best_state=None; bad=0

    for ep in range(1,cfg.epochs+1):
        s = ramp(ep, cfg.warmup_epochs, cfg.ramp_epochs)
        use_topk = (ep > cfg.warmup_epochs)
        noise_std = lerp(cfg.gate_noise_max, cfg.gate_noise_min, s)
        temperature = lerp(cfg.temp_max, cfg.temp_min, s)

        mono_w   = cfg.lambda_mono*s
        smooth_w = cfg.lambda_smooth*s
        lb_w     = cfg.lambda_lb*s
        ent_w    = cfg.lambda_ent*s
        div_w    = cfg.lambda_div*s
        dead_w   = cfg.lambda_dead*s

        model.train()
        tr_loss=0.0
        for xb,yb,ub,cb,db in dl_tr:
            xb=xb.to(DEVICE); yb=yb.to(DEVICE)
            opt.zero_grad(set_to_none=True)

            mean, logv, w, lb, ent, div, dead = model(
                xb, use_topk=use_topk, noise_std=noise_std, temperature=temperature, train=True
            )

            nll = gaussian_nll(yb, mean, logv).mean()
            mse = F.mse_loss(mean, yb)
            hub = F.huber_loss(mean, yb, delta=cfg.huber_delta)
            mono, smooth = physics_losses(mean, ub, cb)

            if ep <= cfg.warmup_epochs:
                loss = 1.0*mse + 0.10*nll + cfg.huber_weight*hub
            else:
                loss = nll + cfg.aux_mse_weight*mse + cfg.huber_weight*hub

            loss = loss + mono_w*mono + smooth_w*smooth + lb_w*lb - ent_w*ent + div_w*div + dead_w*dead

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip)
            opt.step()
            tr_loss += float(loss.detach().cpu())

        sched.step()

        # val loss + quick metrics
        model.eval()
        va_loss=0.0
        with torch.no_grad():
            for xb,yb,ub,cb,db in dl_va:
                xb=xb.to(DEVICE); yb=yb.to(DEVICE)
                mean, logv, *_ = model(xb, use_topk=True, noise_std=0.0, temperature=1.0, train=False)
                nll = gaussian_nll(yb, mean, logv).mean()
                mse = F.mse_loss(mean, yb)
                hub = F.huber_loss(mean, yb, delta=cfg.huber_delta)
                va_loss += float((nll + cfg.aux_mse_weight*mse + cfg.huber_weight*hub).cpu())
        va_loss /= max(1,len(dl_va))

        yhat_val, _ = predict_mc(model, X_va, n_mc=1)
        if cfg.normalize_y:
            yhat_val = yhat_val*cfg.max_rul
            y_val_raw = y_va*cfg.max_rul
        else:
            y_val_raw = y_va
        r2 = r2_score(y_val_raw, yhat_val)
        rmse = math.sqrt(mean_squared_error(y_val_raw, yhat_val))
        print(f"[Epoch {ep:02d}] s={s:.2f} topk={int(use_topk)} val={va_loss:.4f} | R2={r2:.4f} RMSE={rmse:.3f}")

        if va_loss + 1e-6 < best_val:
            best_val=va_loss
            best_state={k:v.detach().cpu().clone() for k,v in model.state_dict().items()}
            bad=0
        else:
            bad+=1
            if bad>=cfg.early_stop_patience:
                print(f"Early stopping (best val={best_val:.4f})")
                break

    if best_state is not None:
        model.load_state_dict(best_state)
    return model

# ---------------------------
# RUN
# ---------------------------
for fd in set(PRETRAIN_DOMAINS + TARGETS):
    for f in [f"train_{fd}.txt", f"test_{fd}.txt", f"RUL_{fd}.txt"]:
        ensure_exists(os.path.join(DATA_DIR, f))

domains={}
scaler=None
kept=None
for fd in PRETRAIN_DOMAINS:
    domains[fd]=prepare_domain(fd, scaler=scaler, kept_cols=kept)
    scaler=domains[fd]["scaler"]
    kept=domains[fd]["kept_cols"]
for fd in TARGETS:
    if fd not in domains:
        domains[fd]=prepare_domain(fd, scaler=scaler, kept_cols=kept)

models=[]
for mi in range(ENSEMBLE_SIZE):
    print("\n==============================")
    print(f" Training model {mi+1}/{ENSEMBLE_SIZE}")
    print("==============================")
    set_seed(cfg.seed + mi)
    models.append(train_on_domain(domains[PRETRAIN_DOMAINS[0]]))

for fd in TARGETS:
    print("\n==============================")
    print(f" Target {fd} | LAST window per unit")
    print("==============================")

    X_te, y_te, *_ = domains[fd]["test_last"]
    y_te_raw = domains[fd]["test_last_y_raw"]
    X_va, y_va, *_ = domains[fd]["val"]
    y_va_raw = (y_va*cfg.max_rul) if cfg.normalize_y else y_va

    mean, var = predict_mc(models[0], X_te, n_mc=cfg.mc_samples)
    if cfg.normalize_y:
        mean_raw = mean*cfg.max_rul
        var_raw  = var*(cfg.max_rul**2)
    else:
        mean_raw = mean; var_raw = var

    scale = calibrate_scale_min_width(models[0], X_va, y_va_raw, cfg.pi_alpha, cfg.mc_samples, max_over=0.01)
    var_raw = (scale**2)*var_raw

    r2 = r2_score(y_te_raw, mean_raw)
    rmse = math.sqrt(mean_squared_error(y_te_raw, mean_raw))
    mae = mean_absolute_error(y_te_raw, mean_raw)

    lo, hi = prediction_interval(mean_raw, var_raw, cfg.pi_alpha)
    cov = coverage(y_te_raw, lo, hi)

    print(f"Point: R2={r2:.4f} RMSE={rmse:.3f} MAE={mae:.3f}")
    print(f"UQ: {int((1-cfg.pi_alpha)*100)}% PI coverage={cov:.3f} | mean width={(hi-lo).mean():.3f} | cal_scale={scale:.3f}")

    gate_stats(models[0], X_te)

print("\n✅ Done.")



 Training model 1/1
[Epoch 01] s=0.00 topk=0 val=-0.2407 | R2=-0.0040 RMSE=41.841
[Epoch 02] s=0.00 topk=0 val=-0.5685 | R2=0.0405 RMSE=40.904
[Epoch 03] s=0.00 topk=0 val=-0.8230 | R2=0.5002 RMSE=29.521
[Epoch 04] s=0.00 topk=0 val=-0.8968 | R2=0.5992 RMSE=26.435
[Epoch 05] s=0.00 topk=0 val=-0.8805 | R2=0.5406 RMSE=28.303
[Epoch 06] s=0.00 topk=0 val=-0.9473 | R2=0.5696 RMSE=27.395
[Epoch 07] s=0.00 topk=0 val=-1.1210 | R2=0.6504 RMSE=24.692
[Epoch 08] s=0.00 topk=0 val=-1.1902 | R2=0.7201 RMSE=22.093
[Epoch 09] s=0.00 topk=0 val=-0.9997 | R2=0.4966 RMSE=29.627
[Epoch 10] s=0.00 topk=0 val=-1.1631 | R2=0.6726 RMSE=23.893
[Epoch 11] s=0.00 topk=0 val=-0.8876 | R2=0.4254 RMSE=31.653
[Epoch 12] s=0.00 topk=0 val=-0.9271 | R2=0.4042 RMSE=32.233
[Epoch 13] s=0.05 topk=1 val=-0.8902 | R2=0.6009 RMSE=26.379
[Epoch 14] s=0.10 topk=1 val=-0.5136 | R2=0.7117 RMSE=22.419
[Epoch 15] s=0.15 topk=1 val=-1.1794 | R2=0.6934 RMSE=23.123
[Epoch 16] s=0.20 topk=1 val=-1.1896 | R2=0.7095 RMSE=22.508
[E

FD003

In [96]:
# ============================================================
# FINAL Interpretable MOE for C-MAPSS RUL (ONE-CELL Colab)
# Fixes:
#  - Noisy Top-K gating + load balancing + entropy bonus (Sub-Obj 1)
#  - Physics constraints: monotonic + smooth (Sub-Obj 2)
#  - UQ: NLL + MC Dropout + PI calibration (Sub-Obj 3)
#  - Transfer-ready structure (Sub-Obj 4 hooks; single-domain run here)
# ============================================================

# ---------------------------
# USER SETTINGS
# ---------------------------
DATA_DIR = "/content"
PRETRAIN_DOMAINS = ["FD003"]
TARGETS = ["FD003"]
DO_FINETUNE = False
ENSEMBLE_SIZE = 1

# ---------------------------
# Install deps
# ---------------------------
import sys, subprocess, importlib
def _ensure(pkg, import_name=None):
    name = import_name or pkg
    try:
        importlib.import_module(name)
    except Exception:
        subprocess.check_call([sys.executable, "-m", "pip", "-q", "install", pkg])

_ensure("scikit-learn", "sklearn")
_ensure("scipy", "scipy")

# ---------------------------
# Imports
# ---------------------------
import os, math, random
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Sampler

from sklearn.preprocessing import StandardScaler
from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error
from scipy.stats import norm

# ---------------------------
# Config
# ---------------------------
@dataclass
class CFG:
    window: int = 30
    stride: int = 1
    max_rul: int = 125
    normalize_y: bool = True
    val_ratio_units: float = 0.2
    seed: int = 42

    drop_low_var: bool = True
    low_var_thresh: float = 1e-6

    n_experts: int = 4
    top_k: int = 2
    enc_hidden: int = 192
    head_hidden: int = 192
    dropout: float = 0.08
    gate_dropout: float = 0.05

    gate_noise_max: float = 1.0
    gate_noise_min: float = 0.15
    temp_max: float = 2.0
    temp_min: float = 0.7

    epochs: int = 85
    batch_size: int = 128
    block_len: int = 12
    lr: float = 2e-3
    weight_decay: float = 1e-4
    grad_clip: float = 1.0
    early_stop_patience: int = 14

    aux_mse_weight: float = 0.55
    huber_weight: float = 0.15
    huber_delta: float = 0.08  # in normalized y scale

    lambda_mono: float = 0.10
    lambda_smooth: float = 0.03
    lambda_lb: float = 0.25
    lambda_ent: float = 0.02
    lambda_div: float = 0.01
    lambda_dead: float = 0.06
    dead_floor: float = 0.03

    warmup_epochs: int = 12
    ramp_epochs: int = 20

    mc_samples: int = 30
    pi_alpha: float = 0.10  # 90% PI

cfg = CFG()

# ---------------------------
# Utils
# ---------------------------
def set_seed(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
set_seed(cfg.seed)

def ensure_exists(path: str):
    if not os.path.exists(path):
        raise FileNotFoundError(f"Missing file: {path}")

def base_feature_columns() -> List[str]:
    return ["op1", "op2", "op3"] + [f"s{i}" for i in range(1, 22)]

def ramp(ep, warmup, ramp_len):
    if ep <= warmup: return 0.0
    return float(min(1.0, (ep - warmup) / max(1, ramp_len)))

def lerp(a, b, t):
    return a + (b - a) * t

# ---------------------------
# C-MAPSS load
# ---------------------------
def load_cmapss_split(data_dir: str, fd: str):
    train_file = os.path.join(data_dir, f"train_{fd}.txt")
    test_file  = os.path.join(data_dir, f"test_{fd}.txt")
    rul_file   = os.path.join(data_dir, f"RUL_{fd}.txt")
    ensure_exists(train_file); ensure_exists(test_file); ensure_exists(rul_file)

    train_df = pd.read_csv(train_file, sep=r"\s+", header=None)
    test_df  = pd.read_csv(test_file,  sep=r"\s+", header=None)
    rul_df   = pd.read_csv(rul_file,   sep=r"\s+", header=None)

    cols = ["unit","cycle","op1","op2","op3"] + [f"s{i}" for i in range(1,22)]
    train_df.columns = cols
    test_df.columns  = cols
    rul_df.columns   = ["RUL_last"]
    return train_df, test_df, rul_df

def add_rul_train(df: pd.DataFrame, max_rul: int):
    df = df.copy()
    max_cycle = df.groupby("unit")["cycle"].max()
    df["RUL"] = df.apply(lambda r: max_cycle.loc[r["unit"]] - r["cycle"], axis=1)
    df["RUL"] = df["RUL"].clip(upper=max_rul)
    return df

def add_rul_test(test_df: pd.DataFrame, rul_df: pd.DataFrame, max_rul: int):
    df = test_df.copy()
    max_cycle = df.groupby("unit")["cycle"].max()
    rul_last = rul_df["RUL_last"].values
    mapping = {u: rul_last[u-1] for u in sorted(df["unit"].unique())}
    df["RUL"] = df.apply(lambda r: (max_cycle.loc[r["unit"]] + mapping[r["unit"]]) - r["cycle"], axis=1)
    df["RUL"] = df["RUL"].clip(upper=max_rul)
    return df

def make_windows(df, window, stride, feature_cols, target_col="RUL"):
    xs, ys, units, cycles = [], [], [], []
    for unit, g in df.groupby("unit"):
        g = g.sort_values("cycle")
        feats = g[feature_cols].values.astype(np.float32)
        targ  = g[target_col].values.astype(np.float32)
        cyc   = g["cycle"].values.astype(np.int32)

        for end in range(window-1, len(g), stride):
            start = end-window+1
            xs.append(feats[start:end+1])
            ys.append(targ[end])
            units.append(unit)
            cycles.append(cyc[end])
    return np.stack(xs), np.array(ys), np.array(units), np.array(cycles)

def last_window_per_unit(X, y, unit_ids, cycles):
    idx=[]
    for u in np.unique(unit_ids):
        m = unit_ids==u
        i = np.argmax(cycles[m])
        idx.append(np.where(m)[0][i])
    idx=np.array(idx)
    return X[idx], y[idx], unit_ids[idx], cycles[idx]

def split_by_units(unit_ids, val_ratio, seed):
    rng=np.random.default_rng(seed)
    units=np.unique(unit_ids)
    rng.shuffle(units)
    n_val=max(1,int(len(units)*val_ratio))
    val=set(units[:n_val].tolist())
    tr=np.array([i for i,u in enumerate(unit_ids) if u not in val])
    va=np.array([i for i,u in enumerate(unit_ids) if u in val])
    return tr, va

# ---------------------------
# Dataset + contiguous sampler
# ---------------------------
class RULWindowDataset(Dataset):
    def __init__(self, X, y, unit_ids, cycles, domain_ids):
        self.X = torch.tensor(X, dtype=torch.float32)
        self.y = torch.tensor(y, dtype=torch.float32)
        self.unit_ids = np.array(unit_ids)
        self.cycles = np.array(cycles)
        self.domain_ids = torch.tensor(domain_ids, dtype=torch.long)

    def __len__(self): return len(self.X)
    def __getitem__(self, idx):
        return self.X[idx], self.y[idx], int(self.unit_ids[idx]), int(self.cycles[idx]), self.domain_ids[idx]

class ContiguousEngineBatchSampler(Sampler):
    def __init__(self, unit_ids, cycles, batch_size, block_len, shuffle=True):
        self.unit_ids=np.array(unit_ids)
        self.cycles=np.array(cycles)
        self.batch_size=batch_size
        self.block_len=block_len
        self.shuffle=shuffle
        self.units=np.unique(self.unit_ids)
        self.unit_to_sorted={}
        for u in self.units:
            idx=np.where(self.unit_ids==u)[0]
            idx=idx[np.argsort(self.cycles[idx])]
            self.unit_to_sorted[u]=idx
        self.engines_per_batch=max(1,batch_size//block_len)

    def __iter__(self):
        units=self.units.copy()
        if self.shuffle: np.random.shuffle(units)
        batch=[]
        for u in units:
            idx=self.unit_to_sorted[u]
            L=len(idx)
            if L<=self.block_len:
                take=np.random.choice(idx,size=self.block_len,replace=True)
            else:
                s=np.random.randint(0,L-self.block_len)
                take=idx[s:s+self.block_len]
            batch.extend(take.tolist())
            if len(batch)>=self.engines_per_batch*self.block_len:
                yield batch[:self.batch_size]
                batch=[]
        if batch: yield batch

    def __len__(self):
        return math.ceil(len(self.units)/self.engines_per_batch)

# ---------------------------
# Physics losses
# ---------------------------
def physics_losses(pred_mean, unit_ids, cycles, margin=0.0):
    unit_ids=np.array(unit_ids); cycles=np.array(cycles)
    mono_terms=[]; smooth_terms=[]
    for u in np.unique(unit_ids):
        idx=np.where(unit_ids==u)[0]
        if len(idx)<2: continue
        ord_idx=idx[np.argsort(cycles[idx])]
        p=pred_mean[ord_idx]
        mono=F.relu(p[1:]-p[:-1]+margin)
        mono_terms.append(mono.mean())
        if len(ord_idx)>=3:
            second=p[2:]-2*p[1:-1]+p[:-2]
            smooth_terms.append(torch.abs(second).mean())
    mono_loss=torch.stack(mono_terms).mean() if mono_terms else pred_mean.new_tensor(0.)
    smooth_loss=torch.stack(smooth_terms).mean() if smooth_terms else pred_mean.new_tensor(0.)
    return mono_loss, smooth_loss

# ---------------------------
# Gating + load balance
# ---------------------------
class GatingNet(nn.Module):
    def __init__(self, in_dim, n_experts, gate_dropout):
        super().__init__()
        h=max(64,in_dim//2)
        self.net=nn.Sequential(
            nn.Linear(in_dim,h),
            nn.ReLU(),
            nn.Dropout(gate_dropout),
            nn.Linear(h,n_experts),
        )
    def forward(self,x): return self.net(x)

def sparse_topk_softmax(logits, k, temperature):
    logits = logits / max(1e-6, temperature)
    topk_vals, topk_idx = torch.topk(logits, k=k, dim=-1)
    masked = torch.full_like(logits, float("-inf"))
    masked.scatter_(dim=-1, index=topk_idx, src=topk_vals)
    w = F.softmax(masked, dim=-1)
    return w, topk_idx

def switch_load_balance_loss(w, top1_idx):
    B,E=w.shape
    importance=w.sum(dim=0)/(B+1e-8)
    load=torch.bincount(top1_idx,minlength=E).float().to(w.device)/(B+1e-8)
    return E*torch.sum(importance*load)

def dead_expert_penalty(w, floor):
    avg = w.mean(dim=0)
    return F.relu(floor - avg).mean()

# ---------------------------
# UQ + calibration
# ---------------------------
def gaussian_nll(y, mean, log_var):
    return 0.5*(torch.exp(-log_var)*(y-mean)**2 + log_var)

def prediction_interval(mean, var, alpha):
    z = norm.ppf(1-alpha/2)
    std=np.sqrt(np.maximum(var,1e-8))
    return mean - z*std, mean + z*std

def coverage(y, lo, hi):
    return float(np.mean((y>=lo)&(y<=hi)))

# ✅ FIXED predict_mc: safe unpacking for 7-value forward()
@torch.no_grad()
def predict_mc(model, X, n_mc, batch_size=256):
    model.train()
    dl=DataLoader(torch.tensor(X,dtype=torch.float32),batch_size=batch_size,shuffle=False)
    means_all=[]; vars_all=[]
    for xb in dl:
        xb=xb.to(DEVICE)
        mc_m=[]; mc_v=[]
        for _ in range(n_mc):
            out = model(xb, use_topk=True, noise_std=0.0, temperature=1.0, train=True)
            mean, log_var = out[0], out[1]
            mc_m.append(mean)
            mc_v.append(torch.exp(log_var))
        mc_m=torch.stack(mc_m,dim=0)
        mc_v=torch.stack(mc_v,dim=0)
        mean_pred=mc_m.mean(dim=0)
        epistemic=mc_m.var(dim=0,unbiased=False)
        aleatoric=mc_v.mean(dim=0)
        total=aleatoric+epistemic
        means_all.append(mean_pred.cpu().numpy())
        vars_all.append(total.cpu().numpy())
    return np.concatenate(means_all), np.concatenate(vars_all)

def calibrate_scale_min_width(model, X_val, y_val_raw, alpha, n_mc, max_over=0.01):
    mean, var = predict_mc(model, X_val, n_mc=n_mc)
    if cfg.normalize_y:
        mean = mean * cfg.max_rul
        var  = var  * (cfg.max_rul**2)
    std = np.sqrt(np.maximum(var,1e-8))
    z = norm.ppf(1-alpha/2)
    target = 1-alpha
    scales = np.linspace(0.6, 3.0, 121)
    best=None
    for s in scales:
        lo = mean - z*s*std
        hi = mean + z*s*std
        cov = np.mean((y_val_raw>=lo)&(y_val_raw<=hi))
        if cov >= (target - max_over):
            best=s
            break
    return float(best if best is not None else scales[-1])

# ---------------------------
# Model
# ---------------------------
class SharedEncoder(nn.Module):
    def __init__(self, n_features, hidden, dropout):
        super().__init__()
        self.conv1=nn.Conv1d(n_features,96,3,padding=1)
        self.conv2=nn.Conv1d(96,96,3,padding=1)
        self.bn1=nn.BatchNorm1d(96)
        self.bn2=nn.BatchNorm1d(96)
        self.gru=nn.GRU(96,hidden,batch_first=True)
        self.drop=nn.Dropout(dropout)
    def forward(self,x):
        x=x.transpose(1,2)
        x=F.relu(self.bn1(self.conv1(x)))
        x=F.relu(self.bn2(self.conv2(x)))
        x=x.transpose(1,2)
        _,h=self.gru(x)
        return self.drop(h[-1])

class ExpertHead(nn.Module):
    def __init__(self, in_dim, hidden, base_dropout, expert_id):
        super().__init__()
        width = hidden if expert_id%2==0 else max(64,hidden//2)
        drop  = min(0.30, base_dropout + 0.03*expert_id)
        self.emb = nn.Parameter(torch.randn(16)*0.02)
        self.net = nn.Sequential(
            nn.Linear(in_dim+16, width), nn.ReLU(), nn.Dropout(drop),
            nn.Linear(width, max(32,width//2)), nn.ReLU(), nn.Dropout(drop),
        )
        out=max(32,width//2)
        self.mu=nn.Linear(out,1)
        self.logv=nn.Linear(out,1)
    def forward(self,z):
        B=z.size(0)
        e=self.emb.unsqueeze(0).expand(B,-1)
        h=self.net(torch.cat([z,e],dim=-1))
        mu=self.mu(h).squeeze(-1)
        logv=self.logv(h).squeeze(-1).clamp(-9,4)
        return mu, logv

class InterpretableMoE(nn.Module):
    def __init__(self, n_features):
        super().__init__()
        self.enc = SharedEncoder(n_features,cfg.enc_hidden,cfg.dropout)
        self.experts = nn.ModuleList([ExpertHead(cfg.enc_hidden,cfg.head_hidden,cfg.dropout,i) for i in range(cfg.n_experts)])
        self.gate = GatingNet(cfg.enc_hidden+9, cfg.n_experts, cfg.gate_dropout)

    def forward(self, x, *, use_topk, noise_std, temperature, train):
        z = self.enc(x)
        op = x[:,:, :3]
        gate_in = torch.cat([z, op[:,-1,:], op.mean(dim=1), op.std(dim=1)], dim=-1)

        logits = self.gate(gate_in)
        if train and noise_std>0:
            logits = logits + torch.randn_like(logits)*noise_std

        if use_topk:
            w, topk_idx = sparse_topk_softmax(logits, cfg.top_k, temperature)
            top1 = topk_idx[:,0]
        else:
            w = F.softmax(logits/max(1e-6,temperature), dim=-1)
            top1 = torch.argmax(w, dim=-1)

        mus=[]; vars_=[]
        for ex in self.experts:
            mu, logv = ex(z)
            mus.append(mu); vars_.append(torch.exp(logv))
        mus=torch.stack(mus,dim=-1)
        vars_=torch.stack(vars_,dim=-1)

        mean=torch.sum(w*mus,dim=-1)
        second=torch.sum(w*(vars_+mus**2),dim=-1)
        var=(second-mean**2).clamp_min(1e-6)
        log_var=torch.log(var)

        lb = switch_load_balance_loss(w, top1)
        ent = -(w*torch.log(w+1e-8)).sum(dim=-1).mean()
        dead = dead_expert_penalty(w, cfg.dead_floor)

        div=0.0
        E=mus.shape[-1]
        for i in range(E):
            for j in range(i+1,E):
                a=mus[:,i]-mus[:,i].mean()
                b=mus[:,j]-mus[:,j].mean()
                div += torch.abs(F.cosine_similarity(a.unsqueeze(-1), b.unsqueeze(-1), dim=-1)).mean()
        div = div / max(1,(E*(E-1)//2))

        return mean, log_var, w, lb, ent, div, dead

# ---------------------------
# Prepare domain
# ---------------------------
def prepare_domain(fd, scaler=None, kept_cols=None):
    tr, te, rul = load_cmapss_split(DATA_DIR, fd)
    tr=add_rul_train(tr,cfg.max_rul)
    te=add_rul_test(te,rul,cfg.max_rul)

    all_cols = base_feature_columns()
    if kept_cols is None:
        kept_cols = all_cols
        if cfg.drop_low_var:
            v = tr[all_cols].var(axis=0).values
            kept_cols = [c for c,vv in zip(all_cols,v) if vv>cfg.low_var_thresh]
            for c in ["op1","op2","op3"]:
                if c not in kept_cols:
                    kept_cols = ["op1","op2","op3"] + [x for x in kept_cols if x not in ["op1","op2","op3"]]

    X_tr_all, y_tr_all, u_tr_all, c_tr_all = make_windows(tr, cfg.window, cfg.stride, kept_cols)
    X_te_all, y_te_all, u_te_all, c_te_all = make_windows(te, cfg.window, cfg.stride, kept_cols)
    X_te_last, y_te_last, u_te_last, c_te_last = last_window_per_unit(X_te_all, y_te_all, u_te_all, c_te_all)

    if scaler is None:
        scaler = StandardScaler()
        scaler.fit(X_tr_all.reshape(-1, X_tr_all.shape[-1]))

    X_tr_all = scaler.transform(X_tr_all.reshape(-1,X_tr_all.shape[-1])).reshape(X_tr_all.shape)
    X_te_last = scaler.transform(X_te_last.reshape(-1,X_te_last.shape[-1])).reshape(X_te_last.shape)

    tr_idx, va_idx = split_by_units(u_tr_all, cfg.val_ratio_units, cfg.seed)
    X_tr, y_tr, u_tr, c_tr = X_tr_all[tr_idx], y_tr_all[tr_idx], u_tr_all[tr_idx], c_tr_all[tr_idx]
    X_va, y_va, u_va, c_va = X_tr_all[va_idx], y_tr_all[va_idx], u_tr_all[va_idx], c_tr_all[va_idx]

    y_te_raw = y_te_last.copy()

    if cfg.normalize_y:
        y_tr = y_tr/cfg.max_rul
        y_va = y_va/cfg.max_rul
        y_te = y_te_last/cfg.max_rul
    else:
        y_te = y_te_last

    return {
        "scaler": scaler, "kept_cols": kept_cols,
        "train": (X_tr,y_tr,u_tr,c_tr),
        "val": (X_va,y_va,u_va,c_va),
        "test_last": (X_te_last,y_te,u_te_last,c_te_last),
        "test_last_y_raw": y_te_raw
    }

# ---------------------------
# Interpretability
# ---------------------------
@torch.no_grad()
def gate_stats(model, X, n_show=1024):
    model.eval()
    xb=torch.tensor(X[:n_show],dtype=torch.float32).to(DEVICE)
    out = model(xb, use_topk=True, noise_std=0.0, temperature=1.0, train=False)
    w = out[2].cpu().numpy()
    print("Gate usage (avg weights):", np.round(w.mean(axis=0),4))

# ---------------------------
# Training
# ---------------------------
def train_on_domain(domain):
    X_tr,y_tr,u_tr,c_tr = domain["train"]
    X_va,y_va,u_va,c_va = domain["val"]

    model = InterpretableMoE(n_features=X_tr.shape[-1]).to(DEVICE)

    ds_tr = RULWindowDataset(X_tr,y_tr,u_tr,c_tr,domain_ids=np.zeros(len(X_tr),dtype=int))
    ds_va = RULWindowDataset(X_va,y_va,u_va,c_va,domain_ids=np.zeros(len(X_va),dtype=int))

    sampler = ContiguousEngineBatchSampler(ds_tr.unit_ids, ds_tr.cycles, cfg.batch_size, cfg.block_len, shuffle=True)
    dl_tr = DataLoader(ds_tr, batch_sampler=sampler)
    dl_va = DataLoader(ds_va, batch_size=cfg.batch_size, shuffle=False)

    opt = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=cfg.epochs)

    best_val=float("inf"); best_state=None; bad=0

    for ep in range(1,cfg.epochs+1):
        s = ramp(ep, cfg.warmup_epochs, cfg.ramp_epochs)
        use_topk = (ep > cfg.warmup_epochs)
        noise_std = lerp(cfg.gate_noise_max, cfg.gate_noise_min, s)
        temperature = lerp(cfg.temp_max, cfg.temp_min, s)

        mono_w   = cfg.lambda_mono*s
        smooth_w = cfg.lambda_smooth*s
        lb_w     = cfg.lambda_lb*s
        ent_w    = cfg.lambda_ent*s
        div_w    = cfg.lambda_div*s
        dead_w   = cfg.lambda_dead*s

        model.train()
        tr_loss=0.0
        for xb,yb,ub,cb,db in dl_tr:
            xb=xb.to(DEVICE); yb=yb.to(DEVICE)
            opt.zero_grad(set_to_none=True)

            mean, logv, w, lb, ent, div, dead = model(
                xb, use_topk=use_topk, noise_std=noise_std, temperature=temperature, train=True
            )

            nll = gaussian_nll(yb, mean, logv).mean()
            mse = F.mse_loss(mean, yb)
            hub = F.huber_loss(mean, yb, delta=cfg.huber_delta)
            mono, smooth = physics_losses(mean, ub, cb)

            if ep <= cfg.warmup_epochs:
                loss = 1.0*mse + 0.10*nll + cfg.huber_weight*hub
            else:
                loss = nll + cfg.aux_mse_weight*mse + cfg.huber_weight*hub

            loss = loss + mono_w*mono + smooth_w*smooth + lb_w*lb - ent_w*ent + div_w*div + dead_w*dead

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip)
            opt.step()
            tr_loss += float(loss.detach().cpu())

        sched.step()

        # val loss + quick metrics
        model.eval()
        va_loss=0.0
        with torch.no_grad():
            for xb,yb,ub,cb,db in dl_va:
                xb=xb.to(DEVICE); yb=yb.to(DEVICE)
                mean, logv, *_ = model(xb, use_topk=True, noise_std=0.0, temperature=1.0, train=False)
                nll = gaussian_nll(yb, mean, logv).mean()
                mse = F.mse_loss(mean, yb)
                hub = F.huber_loss(mean, yb, delta=cfg.huber_delta)
                va_loss += float((nll + cfg.aux_mse_weight*mse + cfg.huber_weight*hub).cpu())
        va_loss /= max(1,len(dl_va))

        yhat_val, _ = predict_mc(model, X_va, n_mc=1)
        if cfg.normalize_y:
            yhat_val = yhat_val*cfg.max_rul
            y_val_raw = y_va*cfg.max_rul
        else:
            y_val_raw = y_va
        r2 = r2_score(y_val_raw, yhat_val)
        rmse = math.sqrt(mean_squared_error(y_val_raw, yhat_val))
        print(f"[Epoch {ep:02d}] s={s:.2f} topk={int(use_topk)} val={va_loss:.4f} | R2={r2:.4f} RMSE={rmse:.3f}")

        if va_loss + 1e-6 < best_val:
            best_val=va_loss
            best_state={k:v.detach().cpu().clone() for k,v in model.state_dict().items()}
            bad=0
        else:
            bad+=1
            if bad>=cfg.early_stop_patience:
                print(f"Early stopping (best val={best_val:.4f})")
                break

    if best_state is not None:
        model.load_state_dict(best_state)
    return model

# ---------------------------
# RUN
# ---------------------------
for fd in set(PRETRAIN_DOMAINS + TARGETS):
    for f in [f"train_{fd}.txt", f"test_{fd}.txt", f"RUL_{fd}.txt"]:
        ensure_exists(os.path.join(DATA_DIR, f))

domains={}
scaler=None
kept=None
for fd in PRETRAIN_DOMAINS:
    domains[fd]=prepare_domain(fd, scaler=scaler, kept_cols=kept)
    scaler=domains[fd]["scaler"]
    kept=domains[fd]["kept_cols"]
for fd in TARGETS:
    if fd not in domains:
        domains[fd]=prepare_domain(fd, scaler=scaler, kept_cols=kept)

models=[]
for mi in range(ENSEMBLE_SIZE):
    print("\n==============================")
    print(f" Training model {mi+1}/{ENSEMBLE_SIZE}")
    print("==============================")
    set_seed(cfg.seed + mi)
    models.append(train_on_domain(domains[PRETRAIN_DOMAINS[0]]))

for fd in TARGETS:
    print("\n==============================")
    print(f" Target {fd} | LAST window per unit")
    print("==============================")

    X_te, y_te, *_ = domains[fd]["test_last"]
    y_te_raw = domains[fd]["test_last_y_raw"]
    X_va, y_va, *_ = domains[fd]["val"]
    y_va_raw = (y_va*cfg.max_rul) if cfg.normalize_y else y_va

    mean, var = predict_mc(models[0], X_te, n_mc=cfg.mc_samples)
    if cfg.normalize_y:
        mean_raw = mean*cfg.max_rul
        var_raw  = var*(cfg.max_rul**2)
    else:
        mean_raw = mean; var_raw = var

    scale = calibrate_scale_min_width(models[0], X_va, y_va_raw, cfg.pi_alpha, cfg.mc_samples, max_over=0.01)
    var_raw = (scale**2)*var_raw

    r2 = r2_score(y_te_raw, mean_raw)
    rmse = math.sqrt(mean_squared_error(y_te_raw, mean_raw))
    mae = mean_absolute_error(y_te_raw, mean_raw)

    lo, hi = prediction_interval(mean_raw, var_raw, cfg.pi_alpha)
    cov = coverage(y_te_raw, lo, hi)

    print(f"Point: R2={r2:.4f} RMSE={rmse:.3f} MAE={mae:.3f}")
    print(f"UQ: {int((1-cfg.pi_alpha)*100)}% PI coverage={cov:.3f} | mean width={(hi-lo).mean():.3f} | cal_scale={scale:.3f}")

    gate_stats(models[0], X_te)

print("\n✅ Done.")



 Training model 1/1
[Epoch 01] s=0.00 topk=0 val=0.0894 | R2=0.6722 RMSE=23.885
[Epoch 02] s=0.00 topk=0 val=-0.3108 | R2=0.5821 RMSE=26.968
[Epoch 03] s=0.00 topk=0 val=-0.6049 | R2=0.6585 RMSE=24.381
[Epoch 04] s=0.00 topk=0 val=-0.9182 | R2=0.6446 RMSE=24.870
[Epoch 05] s=0.00 topk=0 val=-1.3225 | R2=0.6947 RMSE=23.051
[Epoch 06] s=0.00 topk=0 val=-0.8848 | R2=0.6259 RMSE=25.517
[Epoch 07] s=0.00 topk=0 val=-1.2841 | R2=0.7014 RMSE=22.798
[Epoch 08] s=0.00 topk=0 val=-1.0305 | R2=0.7097 RMSE=22.478
[Epoch 09] s=0.00 topk=0 val=-1.4801 | R2=0.7613 RMSE=20.381
[Epoch 10] s=0.00 topk=0 val=-1.5606 | R2=0.7627 RMSE=20.325
[Epoch 11] s=0.00 topk=0 val=-1.3519 | R2=0.7336 RMSE=21.531
[Epoch 12] s=0.00 topk=0 val=-1.1461 | R2=0.6607 RMSE=24.303
[Epoch 13] s=0.05 topk=1 val=-1.5128 | R2=0.8100 RMSE=18.187
[Epoch 14] s=0.10 topk=1 val=-1.5152 | R2=0.7922 RMSE=19.018
[Epoch 15] s=0.15 topk=1 val=-1.0954 | R2=0.6798 RMSE=23.606
[Epoch 16] s=0.20 topk=1 val=-1.1003 | R2=0.6635 RMSE=24.200
[Epo

FD004

In [1]:
# ============================================================
# FINAL Interpretable MOE for C-MAPSS RUL (ONE-CELL Colab)
# Fixes:
#  - Noisy Top-K gating + load balancing + entropy bonus (Sub-Obj 1)
#  - Physics constraints: monotonic + smooth (Sub-Obj 2)
#  - UQ: NLL + MC Dropout + PI calibration (Sub-Obj 3)
#  - Transfer-ready structure (Sub-Obj 4 hooks; single-domain run here)
# ============================================================

# ---------------------------
# USER SETTINGS
# ---------------------------
DATA_DIR = "/content"
PRETRAIN_DOMAINS = ["FD004"]
TARGETS = ["FD004"]
DO_FINETUNE = False
ENSEMBLE_SIZE = 1

# ---------------------------
# Install deps
# ---------------------------
import sys, subprocess, importlib
def _ensure(pkg, import_name=None):
    name = import_name or pkg
    try:
        importlib.import_module(name)
    except Exception:
        subprocess.check_call([sys.executable, "-m", "pip", "-q", "install", pkg])

_ensure("scikit-learn", "sklearn")
_ensure("scipy", "scipy")

# ---------------------------
# Imports
# ---------------------------
import os, math, random
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Sampler

from sklearn.preprocessing import StandardScaler
from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error
from scipy.stats import norm

# ---------------------------
# Config
# ---------------------------
@dataclass
class CFG:
    window: int = 30
    stride: int = 1
    max_rul: int = 125
    normalize_y: bool = True
    val_ratio_units: float = 0.2
    seed: int = 42

    drop_low_var: bool = True
    low_var_thresh: float = 1e-6

    n_experts: int = 4
    top_k: int = 2
    enc_hidden: int = 192
    head_hidden: int = 192
    dropout: float = 0.08
    gate_dropout: float = 0.05

    gate_noise_max: float = 1.0
    gate_noise_min: float = 0.15
    temp_max: float = 2.0
    temp_min: float = 0.7

    epochs: int = 85
    batch_size: int = 128
    block_len: int = 12
    lr: float = 2e-3
    weight_decay: float = 1e-4
    grad_clip: float = 1.0
    early_stop_patience: int = 14

    aux_mse_weight: float = 0.55
    huber_weight: float = 0.15
    huber_delta: float = 0.08  # in normalized y scale

    lambda_mono: float = 0.10
    lambda_smooth: float = 0.03
    lambda_lb: float = 0.25
    lambda_ent: float = 0.02
    lambda_div: float = 0.01
    lambda_dead: float = 0.06
    dead_floor: float = 0.03

    warmup_epochs: int = 12
    ramp_epochs: int = 20

    mc_samples: int = 30
    pi_alpha: float = 0.10  # 90% PI

cfg = CFG()

# ---------------------------
# Utils
# ---------------------------
def set_seed(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
set_seed(cfg.seed)

def ensure_exists(path: str):
    if not os.path.exists(path):
        raise FileNotFoundError(f"Missing file: {path}")

def base_feature_columns() -> List[str]:
    return ["op1", "op2", "op3"] + [f"s{i}" for i in range(1, 22)]

def ramp(ep, warmup, ramp_len):
    if ep <= warmup: return 0.0
    return float(min(1.0, (ep - warmup) / max(1, ramp_len)))

def lerp(a, b, t):
    return a + (b - a) * t

# ---------------------------
# C-MAPSS load
# ---------------------------
def load_cmapss_split(data_dir: str, fd: str):
    train_file = os.path.join(data_dir, f"train_{fd}.txt")
    test_file  = os.path.join(data_dir, f"test_{fd}.txt")
    rul_file   = os.path.join(data_dir, f"RUL_{fd}.txt")
    ensure_exists(train_file); ensure_exists(test_file); ensure_exists(rul_file)

    train_df = pd.read_csv(train_file, sep=r"\s+", header=None)
    test_df  = pd.read_csv(test_file,  sep=r"\s+", header=None)
    rul_df   = pd.read_csv(rul_file,   sep=r"\s+", header=None)

    cols = ["unit","cycle","op1","op2","op3"] + [f"s{i}" for i in range(1,22)]
    train_df.columns = cols
    test_df.columns  = cols
    rul_df.columns   = ["RUL_last"]
    return train_df, test_df, rul_df

def add_rul_train(df: pd.DataFrame, max_rul: int):
    df = df.copy()
    max_cycle = df.groupby("unit")["cycle"].max()
    df["RUL"] = df.apply(lambda r: max_cycle.loc[r["unit"]] - r["cycle"], axis=1)
    df["RUL"] = df["RUL"].clip(upper=max_rul)
    return df

def add_rul_test(test_df: pd.DataFrame, rul_df: pd.DataFrame, max_rul: int):
    df = test_df.copy()
    max_cycle = df.groupby("unit")["cycle"].max()
    rul_last = rul_df["RUL_last"].values
    mapping = {u: rul_last[u-1] for u in sorted(df["unit"].unique())}
    df["RUL"] = df.apply(lambda r: (max_cycle.loc[r["unit"]] + mapping[r["unit"]]) - r["cycle"], axis=1)
    df["RUL"] = df["RUL"].clip(upper=max_rul)
    return df

def make_windows(df, window, stride, feature_cols, target_col="RUL"):
    xs, ys, units, cycles = [], [], [], []
    for unit, g in df.groupby("unit"):
        g = g.sort_values("cycle")
        feats = g[feature_cols].values.astype(np.float32)
        targ  = g[target_col].values.astype(np.float32)
        cyc   = g["cycle"].values.astype(np.int32)

        for end in range(window-1, len(g), stride):
            start = end-window+1
            xs.append(feats[start:end+1])
            ys.append(targ[end])
            units.append(unit)
            cycles.append(cyc[end])
    return np.stack(xs), np.array(ys), np.array(units), np.array(cycles)

def last_window_per_unit(X, y, unit_ids, cycles):
    idx=[]
    for u in np.unique(unit_ids):
        m = unit_ids==u
        i = np.argmax(cycles[m])
        idx.append(np.where(m)[0][i])
    idx=np.array(idx)
    return X[idx], y[idx], unit_ids[idx], cycles[idx]

def split_by_units(unit_ids, val_ratio, seed):
    rng=np.random.default_rng(seed)
    units=np.unique(unit_ids)
    rng.shuffle(units)
    n_val=max(1,int(len(units)*val_ratio))
    val=set(units[:n_val].tolist())
    tr=np.array([i for i,u in enumerate(unit_ids) if u not in val])
    va=np.array([i for i,u in enumerate(unit_ids) if u in val])
    return tr, va

# ---------------------------
# Dataset + contiguous sampler
# ---------------------------
class RULWindowDataset(Dataset):
    def __init__(self, X, y, unit_ids, cycles, domain_ids):
        self.X = torch.tensor(X, dtype=torch.float32)
        self.y = torch.tensor(y, dtype=torch.float32)
        self.unit_ids = np.array(unit_ids)
        self.cycles = np.array(cycles)
        self.domain_ids = torch.tensor(domain_ids, dtype=torch.long)

    def __len__(self): return len(self.X)
    def __getitem__(self, idx):
        return self.X[idx], self.y[idx], int(self.unit_ids[idx]), int(self.cycles[idx]), self.domain_ids[idx]

class ContiguousEngineBatchSampler(Sampler):
    def __init__(self, unit_ids, cycles, batch_size, block_len, shuffle=True):
        self.unit_ids=np.array(unit_ids)
        self.cycles=np.array(cycles)
        self.batch_size=batch_size
        self.block_len=block_len
        self.shuffle=shuffle
        self.units=np.unique(self.unit_ids)
        self.unit_to_sorted={}
        for u in self.units:
            idx=np.where(self.unit_ids==u)[0]
            idx=idx[np.argsort(self.cycles[idx])]
            self.unit_to_sorted[u]=idx
        self.engines_per_batch=max(1,batch_size//block_len)

    def __iter__(self):
        units=self.units.copy()
        if self.shuffle: np.random.shuffle(units)
        batch=[]
        for u in units:
            idx=self.unit_to_sorted[u]
            L=len(idx)
            if L<=self.block_len:
                take=np.random.choice(idx,size=self.block_len,replace=True)
            else:
                s=np.random.randint(0,L-self.block_len)
                take=idx[s:s+self.block_len]
            batch.extend(take.tolist())
            if len(batch)>=self.engines_per_batch*self.block_len:
                yield batch[:self.batch_size]
                batch=[]
        if batch: yield batch

    def __len__(self):
        return math.ceil(len(self.units)/self.engines_per_batch)

# ---------------------------
# Physics losses
# ---------------------------
def physics_losses(pred_mean, unit_ids, cycles, margin=0.0):
    unit_ids=np.array(unit_ids); cycles=np.array(cycles)
    mono_terms=[]; smooth_terms=[]
    for u in np.unique(unit_ids):
        idx=np.where(unit_ids==u)[0]
        if len(idx)<2: continue
        ord_idx=idx[np.argsort(cycles[idx])]
        p=pred_mean[ord_idx]
        mono=F.relu(p[1:]-p[:-1]+margin)
        mono_terms.append(mono.mean())
        if len(ord_idx)>=3:
            second=p[2:]-2*p[1:-1]+p[:-2]
            smooth_terms.append(torch.abs(second).mean())
    mono_loss=torch.stack(mono_terms).mean() if mono_terms else pred_mean.new_tensor(0.)
    smooth_loss=torch.stack(smooth_terms).mean() if smooth_terms else pred_mean.new_tensor(0.)
    return mono_loss, smooth_loss

# ---------------------------
# Gating + load balance
# ---------------------------
class GatingNet(nn.Module):
    def __init__(self, in_dim, n_experts, gate_dropout):
        super().__init__()
        h=max(64,in_dim//2)
        self.net=nn.Sequential(
            nn.Linear(in_dim,h),
            nn.ReLU(),
            nn.Dropout(gate_dropout),
            nn.Linear(h,n_experts),
        )
    def forward(self,x): return self.net(x)

def sparse_topk_softmax(logits, k, temperature):
    logits = logits / max(1e-6, temperature)
    topk_vals, topk_idx = torch.topk(logits, k=k, dim=-1)
    masked = torch.full_like(logits, float("-inf"))
    masked.scatter_(dim=-1, index=topk_idx, src=topk_vals)
    w = F.softmax(masked, dim=-1)
    return w, topk_idx

def switch_load_balance_loss(w, top1_idx):
    B,E=w.shape
    importance=w.sum(dim=0)/(B+1e-8)
    load=torch.bincount(top1_idx,minlength=E).float().to(w.device)/(B+1e-8)
    return E*torch.sum(importance*load)

def dead_expert_penalty(w, floor):
    avg = w.mean(dim=0)
    return F.relu(floor - avg).mean()

# ---------------------------
# UQ + calibration
# ---------------------------
def gaussian_nll(y, mean, log_var):
    return 0.5*(torch.exp(-log_var)*(y-mean)**2 + log_var)

def prediction_interval(mean, var, alpha):
    z = norm.ppf(1-alpha/2)
    std=np.sqrt(np.maximum(var,1e-8))
    return mean - z*std, mean + z*std

def coverage(y, lo, hi):
    return float(np.mean((y>=lo)&(y<=hi)))

# ✅ FIXED predict_mc: safe unpacking for 7-value forward()
@torch.no_grad()
def predict_mc(model, X, n_mc, batch_size=256):
    model.train()
    dl=DataLoader(torch.tensor(X,dtype=torch.float32),batch_size=batch_size,shuffle=False)
    means_all=[]; vars_all=[]
    for xb in dl:
        xb=xb.to(DEVICE)
        mc_m=[]; mc_v=[]
        for _ in range(n_mc):
            out = model(xb, use_topk=True, noise_std=0.0, temperature=1.0, train=True)
            mean, log_var = out[0], out[1]
            mc_m.append(mean)
            mc_v.append(torch.exp(log_var))
        mc_m=torch.stack(mc_m,dim=0)
        mc_v=torch.stack(mc_v,dim=0)
        mean_pred=mc_m.mean(dim=0)
        epistemic=mc_m.var(dim=0,unbiased=False)
        aleatoric=mc_v.mean(dim=0)
        total=aleatoric+epistemic
        means_all.append(mean_pred.cpu().numpy())
        vars_all.append(total.cpu().numpy())
    return np.concatenate(means_all), np.concatenate(vars_all)

def calibrate_scale_min_width(model, X_val, y_val_raw, alpha, n_mc, max_over=0.01):
    mean, var = predict_mc(model, X_val, n_mc=n_mc)
    if cfg.normalize_y:
        mean = mean * cfg.max_rul
        var  = var  * (cfg.max_rul**2)
    std = np.sqrt(np.maximum(var,1e-8))
    z = norm.ppf(1-alpha/2)
    target = 1-alpha
    scales = np.linspace(0.6, 3.0, 121)
    best=None
    for s in scales:
        lo = mean - z*s*std
        hi = mean + z*s*std
        cov = np.mean((y_val_raw>=lo)&(y_val_raw<=hi))
        if cov >= (target - max_over):
            best=s
            break
    return float(best if best is not None else scales[-1])

# ---------------------------
# Model
# ---------------------------
class SharedEncoder(nn.Module):
    def __init__(self, n_features, hidden, dropout):
        super().__init__()
        self.conv1=nn.Conv1d(n_features,96,3,padding=1)
        self.conv2=nn.Conv1d(96,96,3,padding=1)
        self.bn1=nn.BatchNorm1d(96)
        self.bn2=nn.BatchNorm1d(96)
        self.gru=nn.GRU(96,hidden,batch_first=True)
        self.drop=nn.Dropout(dropout)
    def forward(self,x):
        x=x.transpose(1,2)
        x=F.relu(self.bn1(self.conv1(x)))
        x=F.relu(self.bn2(self.conv2(x)))
        x=x.transpose(1,2)
        _,h=self.gru(x)
        return self.drop(h[-1])

class ExpertHead(nn.Module):
    def __init__(self, in_dim, hidden, base_dropout, expert_id):
        super().__init__()
        width = hidden if expert_id%2==0 else max(64,hidden//2)
        drop  = min(0.30, base_dropout + 0.03*expert_id)
        self.emb = nn.Parameter(torch.randn(16)*0.02)
        self.net = nn.Sequential(
            nn.Linear(in_dim+16, width), nn.ReLU(), nn.Dropout(drop),
            nn.Linear(width, max(32,width//2)), nn.ReLU(), nn.Dropout(drop),
        )
        out=max(32,width//2)
        self.mu=nn.Linear(out,1)
        self.logv=nn.Linear(out,1)
    def forward(self,z):
        B=z.size(0)
        e=self.emb.unsqueeze(0).expand(B,-1)
        h=self.net(torch.cat([z,e],dim=-1))
        mu=self.mu(h).squeeze(-1)
        logv=self.logv(h).squeeze(-1).clamp(-9,4)
        return mu, logv

class InterpretableMoE(nn.Module):
    def __init__(self, n_features):
        super().__init__()
        self.enc = SharedEncoder(n_features,cfg.enc_hidden,cfg.dropout)
        self.experts = nn.ModuleList([ExpertHead(cfg.enc_hidden,cfg.head_hidden,cfg.dropout,i) for i in range(cfg.n_experts)])
        self.gate = GatingNet(cfg.enc_hidden+9, cfg.n_experts, cfg.gate_dropout)

    def forward(self, x, *, use_topk, noise_std, temperature, train):
        z = self.enc(x)
        op = x[:,:, :3]
        gate_in = torch.cat([z, op[:,-1,:], op.mean(dim=1), op.std(dim=1)], dim=-1)

        logits = self.gate(gate_in)
        if train and noise_std>0:
            logits = logits + torch.randn_like(logits)*noise_std

        if use_topk:
            w, topk_idx = sparse_topk_softmax(logits, cfg.top_k, temperature)
            top1 = topk_idx[:,0]
        else:
            w = F.softmax(logits/max(1e-6,temperature), dim=-1)
            top1 = torch.argmax(w, dim=-1)

        mus=[]; vars_=[]
        for ex in self.experts:
            mu, logv = ex(z)
            mus.append(mu); vars_.append(torch.exp(logv))
        mus=torch.stack(mus,dim=-1)
        vars_=torch.stack(vars_,dim=-1)

        mean=torch.sum(w*mus,dim=-1)
        second=torch.sum(w*(vars_+mus**2),dim=-1)
        var=(second-mean**2).clamp_min(1e-6)
        log_var=torch.log(var)

        lb = switch_load_balance_loss(w, top1)
        ent = -(w*torch.log(w+1e-8)).sum(dim=-1).mean()
        dead = dead_expert_penalty(w, cfg.dead_floor)

        div=0.0
        E=mus.shape[-1]
        for i in range(E):
            for j in range(i+1,E):
                a=mus[:,i]-mus[:,i].mean()
                b=mus[:,j]-mus[:,j].mean()
                div += torch.abs(F.cosine_similarity(a.unsqueeze(-1), b.unsqueeze(-1), dim=-1)).mean()
        div = div / max(1,(E*(E-1)//2))

        return mean, log_var, w, lb, ent, div, dead

# ---------------------------
# Prepare domain
# ---------------------------
def prepare_domain(fd, scaler=None, kept_cols=None):
    tr, te, rul = load_cmapss_split(DATA_DIR, fd)
    tr=add_rul_train(tr,cfg.max_rul)
    te=add_rul_test(te,rul,cfg.max_rul)

    all_cols = base_feature_columns()
    if kept_cols is None:
        kept_cols = all_cols
        if cfg.drop_low_var:
            v = tr[all_cols].var(axis=0).values
            kept_cols = [c for c,vv in zip(all_cols,v) if vv>cfg.low_var_thresh]
            for c in ["op1","op2","op3"]:
                if c not in kept_cols:
                    kept_cols = ["op1","op2","op3"] + [x for x in kept_cols if x not in ["op1","op2","op3"]]

    X_tr_all, y_tr_all, u_tr_all, c_tr_all = make_windows(tr, cfg.window, cfg.stride, kept_cols)
    X_te_all, y_te_all, u_te_all, c_te_all = make_windows(te, cfg.window, cfg.stride, kept_cols)
    X_te_last, y_te_last, u_te_last, c_te_last = last_window_per_unit(X_te_all, y_te_all, u_te_all, c_te_all)

    if scaler is None:
        scaler = StandardScaler()
        scaler.fit(X_tr_all.reshape(-1, X_tr_all.shape[-1]))

    X_tr_all = scaler.transform(X_tr_all.reshape(-1,X_tr_all.shape[-1])).reshape(X_tr_all.shape)
    X_te_last = scaler.transform(X_te_last.reshape(-1,X_te_last.shape[-1])).reshape(X_te_last.shape)

    tr_idx, va_idx = split_by_units(u_tr_all, cfg.val_ratio_units, cfg.seed)
    X_tr, y_tr, u_tr, c_tr = X_tr_all[tr_idx], y_tr_all[tr_idx], u_tr_all[tr_idx], c_tr_all[tr_idx]
    X_va, y_va, u_va, c_va = X_tr_all[va_idx], y_tr_all[va_idx], u_tr_all[va_idx], c_tr_all[va_idx]

    y_te_raw = y_te_last.copy()

    if cfg.normalize_y:
        y_tr = y_tr/cfg.max_rul
        y_va = y_va/cfg.max_rul
        y_te = y_te_last/cfg.max_rul
    else:
        y_te = y_te_last

    return {
        "scaler": scaler, "kept_cols": kept_cols,
        "train": (X_tr,y_tr,u_tr,c_tr),
        "val": (X_va,y_va,u_va,c_va),
        "test_last": (X_te_last,y_te,u_te_last,c_te_last),
        "test_last_y_raw": y_te_raw
    }

# ---------------------------
# Interpretability
# ---------------------------
@torch.no_grad()
def gate_stats(model, X, n_show=1024):
    model.eval()
    xb=torch.tensor(X[:n_show],dtype=torch.float32).to(DEVICE)
    out = model(xb, use_topk=True, noise_std=0.0, temperature=1.0, train=False)
    w = out[2].cpu().numpy()
    print("Gate usage (avg weights):", np.round(w.mean(axis=0),4))

# ---------------------------
# Training
# ---------------------------
def train_on_domain(domain):
    X_tr,y_tr,u_tr,c_tr = domain["train"]
    X_va,y_va,u_va,c_va = domain["val"]

    model = InterpretableMoE(n_features=X_tr.shape[-1]).to(DEVICE)

    ds_tr = RULWindowDataset(X_tr,y_tr,u_tr,c_tr,domain_ids=np.zeros(len(X_tr),dtype=int))
    ds_va = RULWindowDataset(X_va,y_va,u_va,c_va,domain_ids=np.zeros(len(X_va),dtype=int))

    sampler = ContiguousEngineBatchSampler(ds_tr.unit_ids, ds_tr.cycles, cfg.batch_size, cfg.block_len, shuffle=True)
    dl_tr = DataLoader(ds_tr, batch_sampler=sampler)
    dl_va = DataLoader(ds_va, batch_size=cfg.batch_size, shuffle=False)

    opt = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=cfg.epochs)

    best_val=float("inf"); best_state=None; bad=0

    for ep in range(1,cfg.epochs+1):
        s = ramp(ep, cfg.warmup_epochs, cfg.ramp_epochs)
        use_topk = (ep > cfg.warmup_epochs)
        noise_std = lerp(cfg.gate_noise_max, cfg.gate_noise_min, s)
        temperature = lerp(cfg.temp_max, cfg.temp_min, s)

        mono_w   = cfg.lambda_mono*s
        smooth_w = cfg.lambda_smooth*s
        lb_w     = cfg.lambda_lb*s
        ent_w    = cfg.lambda_ent*s
        div_w    = cfg.lambda_div*s
        dead_w   = cfg.lambda_dead*s

        model.train()
        tr_loss=0.0
        for xb,yb,ub,cb,db in dl_tr:
            xb=xb.to(DEVICE); yb=yb.to(DEVICE)
            opt.zero_grad(set_to_none=True)

            mean, logv, w, lb, ent, div, dead = model(
                xb, use_topk=use_topk, noise_std=noise_std, temperature=temperature, train=True
            )

            nll = gaussian_nll(yb, mean, logv).mean()
            mse = F.mse_loss(mean, yb)
            hub = F.huber_loss(mean, yb, delta=cfg.huber_delta)
            mono, smooth = physics_losses(mean, ub, cb)

            if ep <= cfg.warmup_epochs:
                loss = 1.0*mse + 0.10*nll + cfg.huber_weight*hub
            else:
                loss = nll + cfg.aux_mse_weight*mse + cfg.huber_weight*hub

            loss = loss + mono_w*mono + smooth_w*smooth + lb_w*lb - ent_w*ent + div_w*div + dead_w*dead

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip)
            opt.step()
            tr_loss += float(loss.detach().cpu())

        sched.step()

        # val loss + quick metrics
        model.eval()
        va_loss=0.0
        with torch.no_grad():
            for xb,yb,ub,cb,db in dl_va:
                xb=xb.to(DEVICE); yb=yb.to(DEVICE)
                mean, logv, *_ = model(xb, use_topk=True, noise_std=0.0, temperature=1.0, train=False)
                nll = gaussian_nll(yb, mean, logv).mean()
                mse = F.mse_loss(mean, yb)
                hub = F.huber_loss(mean, yb, delta=cfg.huber_delta)
                va_loss += float((nll + cfg.aux_mse_weight*mse + cfg.huber_weight*hub).cpu())
        va_loss /= max(1,len(dl_va))

        yhat_val, _ = predict_mc(model, X_va, n_mc=1)
        if cfg.normalize_y:
            yhat_val = yhat_val*cfg.max_rul
            y_val_raw = y_va*cfg.max_rul
        else:
            y_val_raw = y_va
        r2 = r2_score(y_val_raw, yhat_val)
        rmse = math.sqrt(mean_squared_error(y_val_raw, yhat_val))
        print(f"[Epoch {ep:02d}] s={s:.2f} topk={int(use_topk)} val={va_loss:.4f} | R2={r2:.4f} RMSE={rmse:.3f}")

        if va_loss + 1e-6 < best_val:
            best_val=va_loss
            best_state={k:v.detach().cpu().clone() for k,v in model.state_dict().items()}
            bad=0
        else:
            bad+=1
            if bad>=cfg.early_stop_patience:
                print(f"Early stopping (best val={best_val:.4f})")
                break

    if best_state is not None:
        model.load_state_dict(best_state)
    return model

# ---------------------------
# RUN
# ---------------------------
for fd in set(PRETRAIN_DOMAINS + TARGETS):
    for f in [f"train_{fd}.txt", f"test_{fd}.txt", f"RUL_{fd}.txt"]:
        ensure_exists(os.path.join(DATA_DIR, f))

domains={}
scaler=None
kept=None
for fd in PRETRAIN_DOMAINS:
    domains[fd]=prepare_domain(fd, scaler=scaler, kept_cols=kept)
    scaler=domains[fd]["scaler"]
    kept=domains[fd]["kept_cols"]
for fd in TARGETS:
    if fd not in domains:
        domains[fd]=prepare_domain(fd, scaler=scaler, kept_cols=kept)

models=[]
for mi in range(ENSEMBLE_SIZE):
    print("\n==============================")
    print(f" Training model {mi+1}/{ENSEMBLE_SIZE}")
    print("==============================")
    set_seed(cfg.seed + mi)
    models.append(train_on_domain(domains[PRETRAIN_DOMAINS[0]]))

for fd in TARGETS:
    print("\n==============================")
    print(f" Target {fd} | LAST window per unit")
    print("==============================")

    X_te, y_te, *_ = domains[fd]["test_last"]
    y_te_raw = domains[fd]["test_last_y_raw"]
    X_va, y_va, *_ = domains[fd]["val"]
    y_va_raw = (y_va*cfg.max_rul) if cfg.normalize_y else y_va

    mean, var = predict_mc(models[0], X_te, n_mc=cfg.mc_samples)
    if cfg.normalize_y:
        mean_raw = mean*cfg.max_rul
        var_raw  = var*(cfg.max_rul**2)
    else:
        mean_raw = mean; var_raw = var

    scale = calibrate_scale_min_width(models[0], X_va, y_va_raw, cfg.pi_alpha, cfg.mc_samples, max_over=0.01)
    var_raw = (scale**2)*var_raw

    r2 = r2_score(y_te_raw, mean_raw)
    rmse = math.sqrt(mean_squared_error(y_te_raw, mean_raw))
    mae = mean_absolute_error(y_te_raw, mean_raw)

    lo, hi = prediction_interval(mean_raw, var_raw, cfg.pi_alpha)
    cov = coverage(y_te_raw, lo, hi)

    print(f"Point: R2={r2:.4f} RMSE={rmse:.3f} MAE={mae:.3f}")
    print(f"UQ: {int((1-cfg.pi_alpha)*100)}% PI coverage={cov:.3f} | mean width={(hi-lo).mean():.3f} | cal_scale={scale:.3f}")

    gate_stats(models[0], X_te)

print("\n✅ Done.")



 Training model 1/1
[Epoch 01] s=0.00 topk=0 val=-0.1370 | R2=-0.0879 RMSE=43.112
[Epoch 02] s=0.00 topk=0 val=-0.4181 | R2=-0.0584 RMSE=42.523
[Epoch 03] s=0.00 topk=0 val=-0.6923 | R2=0.5067 RMSE=29.030
[Epoch 04] s=0.00 topk=0 val=-0.7007 | R2=0.6014 RMSE=26.095
[Epoch 05] s=0.00 topk=0 val=-0.7130 | R2=0.3685 RMSE=32.848
[Epoch 06] s=0.00 topk=0 val=-0.6488 | R2=0.3043 RMSE=34.476
[Epoch 07] s=0.00 topk=0 val=-1.0896 | R2=0.5807 RMSE=26.764
[Epoch 08] s=0.00 topk=0 val=-1.0222 | R2=0.5436 RMSE=27.926
[Epoch 09] s=0.00 topk=0 val=-0.6560 | R2=0.1965 RMSE=37.052
[Epoch 10] s=0.00 topk=0 val=-0.9359 | R2=0.4450 RMSE=30.793
[Epoch 11] s=0.00 topk=0 val=-0.4969 | R2=0.2563 RMSE=35.645
[Epoch 12] s=0.00 topk=0 val=-1.0868 | R2=0.5786 RMSE=26.831
[Epoch 13] s=0.05 topk=1 val=-0.9888 | R2=0.4901 RMSE=29.515
[Epoch 14] s=0.10 topk=1 val=-1.0477 | R2=0.5623 RMSE=27.345
[Epoch 15] s=0.15 topk=1 val=-1.0026 | R2=0.5436 RMSE=27.926
[Epoch 16] s=0.20 topk=1 val=-1.1605 | R2=0.6136 RMSE=25.694
[

In [2]:
from google.colab import drive
drive.mount('/content/drive')


Mounted at /content/drive


In [4]:
# ============================================================
# FINAL Generalized Interpretable MoE for C-MAPSS RUL (ONE CELL)
#  - One joint model for FD001..FD004 (best generalization)
#  - Global scaler on all domains (no FD001 anchoring)
#  - Domain-balanced contiguous batching
#  - Regime embedding from (op1, op2, op3)
#  - GroupNorm + FiLM conditioning for domain/regime shift
#  - MoE Top-K gating + load balance + entropy
#  - Physics constraints: monotonic + smooth
#  - UQ: Gaussian NLL + MC Dropout + PI calibration
# ============================================================

# ---------------------------
# USER SETTINGS
# ---------------------------
DATA_DIR = "/content"   # change if needed
FDS = ["FD001","FD002","FD003","FD004"]
ENSEMBLE_SIZE = 1       # keep 1 for "single best model"

# ---------------------------
# Install deps
# ---------------------------
import sys, subprocess, importlib
def _ensure(pkg, import_name=None):
    name = import_name or pkg
    try:
        importlib.import_module(name)
    except Exception:
        subprocess.check_call([sys.executable, "-m", "pip", "-q", "install", pkg])

_ensure("scikit-learn", "sklearn")
_ensure("scipy", "scipy")

# ---------------------------
# Imports
# ---------------------------
import os, math, random
from dataclasses import dataclass
from typing import Dict, List, Tuple, Optional

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Sampler

from sklearn.preprocessing import StandardScaler
from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error
from scipy.stats import norm

# ---------------------------
# Config
# ---------------------------
@dataclass
class CFG:
    # windows
    window: int = 30
    stride: int = 1
    max_rul: int = 125
    normalize_y: bool = True
    val_ratio_units: float = 0.2
    seed: int = 42

    # model
    n_experts: int = 6
    top_k: int = 2
    enc_hidden: int = 224
    head_hidden: int = 224
    dropout: float = 0.12
    gate_dropout: float = 0.08

    # gating schedule
    gate_noise_max: float = 1.0
    gate_noise_min: float = 0.10
    temp_max: float = 2.0
    temp_min: float = 0.7

    # training
    epochs: int = 110
    batch_size: int = 128
    block_len: int = 12
    lr: float = 2e-3
    weight_decay: float = 1e-4
    grad_clip: float = 1.0
    early_stop_patience: int = 16

    # losses
    aux_mse_weight: float = 0.50
    huber_weight: float = 0.15
    huber_delta: float = 0.08  # normalized y scale

    # physics
    lambda_mono: float = 0.10
    lambda_smooth: float = 0.03

    # moe regularizers
    lambda_lb: float = 0.30
    lambda_ent: float = 0.02
    lambda_div: float = 0.01
    lambda_dead: float = 0.06
    dead_floor: float = 0.03

    # ramp
    warmup_epochs: int = 12
    ramp_epochs: int = 20

    # RUL focus (helps FD004)
    rul_focus: float = 0.7   # 0..1; higher => focus on low RUL region

    # UQ
    mc_samples: int = 30
    pi_alpha: float = 0.10  # 90% PI

cfg = CFG()

# ---------------------------
# Utils
# ---------------------------
def set_seed(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
set_seed(cfg.seed)

def ensure_exists(path: str):
    if not os.path.exists(path):
        raise FileNotFoundError(f"Missing file: {path}")

def base_feature_columns() -> List[str]:
    return ["op1", "op2", "op3"] + [f"s{i}" for i in range(1, 22)]

def ramp(ep, warmup, ramp_len):
    if ep <= warmup: return 0.0
    return float(min(1.0, (ep - warmup) / max(1, ramp_len)))

def lerp(a, b, t):
    return a + (b - a) * t

# ---------------------------
# C-MAPSS load
# ---------------------------
def load_cmapss_split(data_dir: str, fd: str):
    train_file = os.path.join(data_dir, f"train_{fd}.txt")
    test_file  = os.path.join(data_dir, f"test_{fd}.txt")
    rul_file   = os.path.join(data_dir, f"RUL_{fd}.txt")
    ensure_exists(train_file); ensure_exists(test_file); ensure_exists(rul_file)

    train_df = pd.read_csv(train_file, sep=r"\s+", header=None)
    test_df  = pd.read_csv(test_file,  sep=r"\s+", header=None)
    rul_df   = pd.read_csv(rul_file,   sep=r"\s+", header=None)

    cols = ["unit","cycle","op1","op2","op3"] + [f"s{i}" for i in range(1,22)]
    train_df.columns = cols
    test_df.columns  = cols
    rul_df.columns   = ["RUL_last"]
    return train_df, test_df, rul_df

def add_rul_train(df: pd.DataFrame, max_rul: int):
    df = df.copy()
    max_cycle = df.groupby("unit")["cycle"].max()
    df["RUL"] = df.apply(lambda r: max_cycle.loc[r["unit"]] - r["cycle"], axis=1)
    df["RUL"] = df["RUL"].clip(upper=max_rul)
    return df

def add_rul_test(test_df: pd.DataFrame, rul_df: pd.DataFrame, max_rul: int):
    df = test_df.copy()
    max_cycle = df.groupby("unit")["cycle"].max()
    rul_last = rul_df["RUL_last"].values
    mapping = {u: rul_last[u-1] for u in sorted(df["unit"].unique())}
    df["RUL"] = df.apply(lambda r: (max_cycle.loc[r["unit"]] + mapping[r["unit"]]) - r["cycle"], axis=1)
    df["RUL"] = df["RUL"].clip(upper=max_rul)
    return df

# ---------------------------
# Regime key/id (from ops)
# ---------------------------
def compute_regime_key(df: pd.DataFrame):
    # robust discrete key across FD001..FD004:
    # op1: round to nearest int (FD001 ~0 -> 0, FD002/4 -> 10/20/35/42)
    # op2: round op2*100 (0.25->25, 0.70->70, 0.84->84, FD001 ~0 -> 0)
    # op3: int (60, 100)
    a = np.round(df["op1"].values).astype(int)
    b = np.round(df["op2"].values * 100).astype(int)
    c = df["op3"].values.astype(int)
    return list(zip(a,b,c))

def build_regime_vocab(all_train_dfs: Dict[str,pd.DataFrame]):
    keys = []
    for fd, df in all_train_dfs.items():
        keys.extend(compute_regime_key(df))
    uniq = sorted(set(keys))
    key2id = {k:i for i,k in enumerate(uniq)}
    return key2id

def attach_regime_id(df: pd.DataFrame, key2id: Dict[Tuple[int,int,int],int]):
    df = df.copy()
    keys = compute_regime_key(df)
    # if unseen keys appear in test, add them (rare but safe)
    for k in keys:
        if k not in key2id:
            key2id[k] = len(key2id)
    df["regime_id"] = [key2id[k] for k in keys]
    return df, key2id

# ---------------------------
# Windowing
# ---------------------------
def make_windows(df, window, stride, feature_cols, target_col="RUL"):
    xs, ys, units, cycles, regimes = [], [], [], [], []
    for unit, g in df.groupby("unit"):
        g = g.sort_values("cycle")
        feats = g[feature_cols].values.astype(np.float32)
        targ  = g[target_col].values.astype(np.float32)
        cyc   = g["cycle"].values.astype(np.int32)
        reg   = g["regime_id"].values.astype(np.int64)

        for end in range(window-1, len(g), stride):
            start = end-window+1
            xs.append(feats[start:end+1])
            ys.append(targ[end])
            units.append(unit)
            cycles.append(cyc[end])
            regimes.append(reg[end])  # last-step regime
    return np.stack(xs), np.array(ys), np.array(units), np.array(cycles), np.array(regimes)

def last_window_per_unit(X, y, unit_ids, cycles, regimes):
    idx=[]
    for u in np.unique(unit_ids):
        m = unit_ids==u
        i = np.argmax(cycles[m])
        idx.append(np.where(m)[0][i])
    idx=np.array(idx)
    return X[idx], y[idx], unit_ids[idx], cycles[idx], regimes[idx]

def split_by_units(unit_ids, val_ratio, seed):
    rng=np.random.default_rng(seed)
    units=np.unique(unit_ids)
    rng.shuffle(units)
    n_val=max(1,int(len(units)*val_ratio))
    val=set(units[:n_val].tolist())
    tr=np.array([i for i,u in enumerate(unit_ids) if u not in val])
    va=np.array([i for i,u in enumerate(unit_ids) if u in val])
    return tr, va

# ---------------------------
# Dataset + domain-balanced contiguous sampler
# ---------------------------
class RULWindowDataset(Dataset):
    def __init__(self, X, y, unit_ids, cycles, domain_ids, regime_ids):
        self.X = torch.tensor(X, dtype=torch.float32)
        self.y = torch.tensor(y, dtype=torch.float32)
        self.unit_ids = np.array(unit_ids)
        self.cycles = np.array(cycles)
        self.domain_ids = torch.tensor(domain_ids, dtype=torch.long)
        self.regime_ids = torch.tensor(regime_ids, dtype=torch.long)

    def __len__(self): return len(self.X)
    def __getitem__(self, idx):
        return (self.X[idx],
                self.y[idx],
                int(self.unit_ids[idx]),
                int(self.cycles[idx]),
                self.domain_ids[idx],
                self.regime_ids[idx])

class DomainBalancedContiguousSampler(Sampler):
    """
    Each batch: pick engines_per_batch engines total, balanced across domains.
    For each engine: sample a contiguous block_len segment (like your original sampler).
    """
    def __init__(self, unit_ids, cycles, domain_ids, batch_size, block_len, n_domains, shuffle=True):
        self.unit_ids=np.array(unit_ids)
        self.cycles=np.array(cycles)
        self.domain_ids=np.array(domain_ids)
        self.batch_size=batch_size
        self.block_len=block_len
        self.shuffle=shuffle
        self.n_domains=n_domains

        self.engines_per_batch=max(1,batch_size//block_len)

        # build engine index lists per domain: (domain, unit) -> sorted sample indices
        self.domain_units = {d: [] for d in range(n_domains)}
        self.engine_to_sorted = {}  # (d,u)-> idx array sorted by cycle

        for d in range(n_domains):
            mask_d = self.domain_ids == d
            units_d = np.unique(self.unit_ids[mask_d])
            self.domain_units[d] = units_d.tolist()
            for u in units_d:
                idx = np.where(mask_d & (self.unit_ids==u))[0]
                idx = idx[np.argsort(self.cycles[idx])]
                self.engine_to_sorted[(d,u)] = idx

        # handle domains with few units
        self.domain_units = {d: np.array(v, dtype=int) for d,v in self.domain_units.items()}

    def __iter__(self):
        # prepare per-domain engine order
        dom_orders = {}
        dom_ptr = {}
        for d in range(self.n_domains):
            units = self.domain_units[d].copy()
            if len(units)==0:
                continue
            if self.shuffle:
                np.random.shuffle(units)
            dom_orders[d] = units
            dom_ptr[d] = 0

        # how many engines per domain each batch
        base = self.engines_per_batch // self.n_domains
        rem  = self.engines_per_batch % self.n_domains
        engines_per_dom = {d: base + (1 if d < rem else 0) for d in range(self.n_domains)}

        # generate batches
        batch=[]
        while True:
            any_added = False
            for d in range(self.n_domains):
                need = engines_per_dom[d]
                if need <= 0: continue
                if d not in dom_orders or len(dom_orders[d])==0:
                    continue
                for _ in range(need):
                    # cycle through engines if we run out
                    if dom_ptr[d] >= len(dom_orders[d]):
                        dom_ptr[d] = 0
                        if self.shuffle:
                            np.random.shuffle(dom_orders[d])
                    u = int(dom_orders[d][dom_ptr[d]])
                    dom_ptr[d] += 1

                    idx = self.engine_to_sorted[(d,u)]
                    L = len(idx)
                    if L <= self.block_len:
                        take = np.random.choice(idx, size=self.block_len, replace=True)
                    else:
                        s = np.random.randint(0, L-self.block_len)
                        take = idx[s:s+self.block_len]
                    batch.extend(take.tolist())
                    any_added = True

                    if len(batch) >= self.batch_size:
                        yield batch[:self.batch_size]
                        batch=[]

            if not any_added:
                break

        if batch:
            yield batch

    def __len__(self):
        # rough estimate
        n_eng = sum(len(self.domain_units[d]) for d in range(self.n_domains))
        if n_eng == 0: return 0
        return max(1, math.ceil(n_eng / max(1,self.engines_per_batch)))

# ---------------------------
# Physics losses
# ---------------------------
def physics_losses(pred_mean, unit_ids, cycles, margin=0.0):
    unit_ids=np.array(unit_ids); cycles=np.array(cycles)
    mono_terms=[]; smooth_terms=[]
    for u in np.unique(unit_ids):
        idx=np.where(unit_ids==u)[0]
        if len(idx)<2: continue
        ord_idx=idx[np.argsort(cycles[idx])]
        p=pred_mean[ord_idx]
        mono=F.relu(p[1:]-p[:-1]+margin)   # RUL should decrease over time
        mono_terms.append(mono.mean())
        if len(ord_idx)>=3:
            second=p[2:]-2*p[1:-1]+p[:-2]
            smooth_terms.append(torch.abs(second).mean())
    mono_loss=torch.stack(mono_terms).mean() if mono_terms else pred_mean.new_tensor(0.)
    smooth_loss=torch.stack(smooth_terms).mean() if smooth_terms else pred_mean.new_tensor(0.)
    return mono_loss, smooth_loss

# ---------------------------
# Gating + load balance
# ---------------------------
class GatingNet(nn.Module):
    def __init__(self, in_dim, n_experts, gate_dropout):
        super().__init__()
        h=max(96,in_dim//2)
        self.net=nn.Sequential(
            nn.Linear(in_dim,h),
            nn.SiLU(),
            nn.Dropout(gate_dropout),
            nn.Linear(h,n_experts),
        )
    def forward(self,x): return self.net(x)

def sparse_topk_softmax(logits, k, temperature):
    logits = logits / max(1e-6, temperature)
    topk_vals, topk_idx = torch.topk(logits, k=k, dim=-1)
    masked = torch.full_like(logits, float("-inf"))
    masked.scatter_(dim=-1, index=topk_idx, src=topk_vals)
    w = F.softmax(masked, dim=-1)
    return w, topk_idx

def switch_load_balance_loss(w, top1_idx):
    B,E=w.shape
    importance=w.sum(dim=0)/(B+1e-8)
    load=torch.bincount(top1_idx,minlength=E).float().to(w.device)/(B+1e-8)
    return E*torch.sum(importance*load)

def dead_expert_penalty(w, floor):
    avg = w.mean(dim=0)
    return F.relu(floor - avg).mean()

# ---------------------------
# UQ + calibration
# ---------------------------
def gaussian_nll(y, mean, log_var):
    return 0.5*(torch.exp(-log_var)*(y-mean)**2 + log_var)

def prediction_interval(mean, var, alpha):
    z = norm.ppf(1-alpha/2)
    std=np.sqrt(np.maximum(var,1e-8))
    return mean - z*std, mean + z*std

def coverage(y, lo, hi):
    return float(np.mean((y>=lo)&(y<=hi)))

@torch.no_grad()
def predict_mc(model, X, domain_ids, regime_ids, n_mc, batch_size=256):
    model.train()  # enable dropout for MC
    dl=DataLoader(
        list(zip(torch.tensor(X,dtype=torch.float32),
                 torch.tensor(domain_ids,dtype=torch.long),
                 torch.tensor(regime_ids,dtype=torch.long))),
        batch_size=batch_size, shuffle=False
    )
    means_all=[]; vars_all=[]
    for xb, db, rb in dl:
        xb=xb.to(DEVICE); db=db.to(DEVICE); rb=rb.to(DEVICE)
        mc_m=[]; mc_v=[]
        for _ in range(n_mc):
            out = model(xb, db, rb, use_topk=True, noise_std=0.0, temperature=1.0, train=True)
            mean, log_var = out[0], out[1]
            mc_m.append(mean)
            mc_v.append(torch.exp(log_var))
        mc_m=torch.stack(mc_m,dim=0)
        mc_v=torch.stack(mc_v,dim=0)
        mean_pred=mc_m.mean(dim=0)
        epistemic=mc_m.var(dim=0,unbiased=False)
        aleatoric=mc_v.mean(dim=0)
        total=aleatoric+epistemic
        means_all.append(mean_pred.cpu().numpy())
        vars_all.append(total.cpu().numpy())
    return np.concatenate(means_all), np.concatenate(vars_all)

def calibrate_scale_min_width(model, X_val, d_val, r_val, y_val_raw, alpha, n_mc, max_over=0.01):
    mean, var = predict_mc(model, X_val, d_val, r_val, n_mc=n_mc)
    if cfg.normalize_y:
        mean = mean * cfg.max_rul
        var  = var  * (cfg.max_rul**2)
    std = np.sqrt(np.maximum(var,1e-8))
    z = norm.ppf(1-alpha/2)
    target = 1-alpha
    scales = np.linspace(0.6, 3.0, 121)
    best=None
    for s in scales:
        lo = mean - z*s*std
        hi = mean + z*s*std
        cov = np.mean((y_val_raw>=lo)&(y_val_raw<=hi))
        if cov >= (target - max_over):
            best=s
            break
    return float(best if best is not None else scales[-1])

# ---------------------------
# Model (GroupNorm + FiLM conditioning + Regime/Domain emb)
# ---------------------------
class SharedEncoder(nn.Module):
    def __init__(self, n_features, hidden, dropout):
        super().__init__()
        self.conv1=nn.Conv1d(n_features,96,3,padding=1)
        self.conv2=nn.Conv1d(96,96,3,padding=1)
        self.gn1=nn.GroupNorm(8,96)   # more domain-robust than BatchNorm
        self.gn2=nn.GroupNorm(8,96)
        self.gru=nn.GRU(96,hidden,batch_first=True)
        self.drop=nn.Dropout(dropout)
    def forward(self,x):
        x=x.transpose(1,2)
        x=F.silu(self.gn1(self.conv1(x)))
        x=F.silu(self.gn2(self.conv2(x)))
        x=x.transpose(1,2)
        _,h=self.gru(x)
        return self.drop(h[-1])

class FiLMConditioner(nn.Module):
    def __init__(self, z_dim, op_dim, dom_emb_dim, reg_emb_dim):
        super().__init__()
        in_dim = op_dim + dom_emb_dim + reg_emb_dim
        h = max(96, z_dim//2)
        self.mlp = nn.Sequential(
            nn.Linear(in_dim, h),
            nn.SiLU(),
            nn.Linear(h, 2*z_dim)
        )
    def forward(self, op_stats, dom_emb, reg_emb):
        x = torch.cat([op_stats, dom_emb, reg_emb], dim=-1)
        gb = self.mlp(x)
        gamma, beta = gb.chunk(2, dim=-1)
        # stable scaling
        gamma = 0.1 * torch.tanh(gamma)
        beta  = 0.1 * torch.tanh(beta)
        return gamma, beta

class ExpertHead(nn.Module):
    def __init__(self, in_dim, hidden, base_dropout, expert_id, extra_dim):
        super().__init__()
        width = hidden if expert_id%2==0 else max(96,hidden//2)
        drop  = min(0.35, base_dropout + 0.03*expert_id)
        self.expert_emb = nn.Parameter(torch.randn(16)*0.02)
        self.net = nn.Sequential(
            nn.Linear(in_dim + extra_dim + 16, width), nn.SiLU(), nn.Dropout(drop),
            nn.Linear(width, max(64,width//2)), nn.SiLU(), nn.Dropout(drop),
        )
        out=max(64,width//2)
        self.mu=nn.Linear(out,1)
        self.logv=nn.Linear(out,1)
    def forward(self, z, extra):
        B=z.size(0)
        e=self.expert_emb.unsqueeze(0).expand(B,-1)
        h=self.net(torch.cat([z, extra, e], dim=-1))
        mu=self.mu(h).squeeze(-1)
        logv=self.logv(h).squeeze(-1).clamp(-9,4)
        return mu, logv

class InterpretableMoE(nn.Module):
    def __init__(self, n_features, n_domains, n_regimes,
                 dom_emb_dim=8, reg_emb_dim=8):
        super().__init__()
        self.enc = SharedEncoder(n_features, cfg.enc_hidden, cfg.dropout)
        self.domain_emb = nn.Embedding(n_domains, dom_emb_dim)
        self.regime_emb = nn.Embedding(n_regimes, reg_emb_dim)

        self.film = FiLMConditioner(cfg.enc_hidden, op_dim=9, dom_emb_dim=dom_emb_dim, reg_emb_dim=reg_emb_dim)

        extra_dim = dom_emb_dim + reg_emb_dim
        self.experts = nn.ModuleList([
            ExpertHead(cfg.enc_hidden, cfg.head_hidden, cfg.dropout, i, extra_dim=extra_dim)
            for i in range(cfg.n_experts)
        ])

        gate_in_dim = cfg.enc_hidden + 9 + extra_dim
        self.gate = GatingNet(gate_in_dim, cfg.n_experts, cfg.gate_dropout)

    def forward(self, x, domain_id, regime_id, *, use_topk, noise_std, temperature, train):
        z = self.enc(x)

        # op stats (last, mean, std)
        op = x[:, :, :3]
        op_stats = torch.cat([op[:,-1,:], op.mean(dim=1), op.std(dim=1)], dim=-1)  # (B,9)

        dom_e = self.domain_emb(domain_id)
        reg_e = self.regime_emb(regime_id)

        gamma, beta = self.film(op_stats, dom_e, reg_e)
        zc = z * (1.0 + gamma) + beta

        extra = torch.cat([dom_e, reg_e], dim=-1)
        gate_in = torch.cat([zc, op_stats, extra], dim=-1)

        logits = self.gate(gate_in)
        if train and noise_std>0:
            logits = logits + torch.randn_like(logits)*noise_std

        if use_topk:
            w, topk_idx = sparse_topk_softmax(logits, cfg.top_k, temperature)
            top1 = topk_idx[:,0]
        else:
            w = F.softmax(logits/max(1e-6,temperature), dim=-1)
            top1 = torch.argmax(w, dim=-1)

        mus=[]; vars_=[]
        for ex in self.experts:
            mu, logv = ex(zc, extra)
            mus.append(mu); vars_.append(torch.exp(logv))
        mus=torch.stack(mus,dim=-1)
        vars_=torch.stack(vars_,dim=-1)

        mean=torch.sum(w*mus,dim=-1)
        second=torch.sum(w*(vars_+mus**2),dim=-1)
        var=(second-mean**2).clamp_min(1e-6)
        log_var=torch.log(var)

        lb = switch_load_balance_loss(w, top1)
        ent = -(w*torch.log(w+1e-8)).sum(dim=-1).mean()
        dead = dead_expert_penalty(w, cfg.dead_floor)

        # expert diversity (light)
        div=0.0
        E=mus.shape[-1]
        for i in range(E):
            for j in range(i+1,E):
                a=mus[:,i]-mus[:,i].mean()
                b=mus[:,j]-mus[:,j].mean()
                div += torch.abs(F.cosine_similarity(a.unsqueeze(-1), b.unsqueeze(-1), dim=-1)).mean()
        div = div / max(1,(E*(E-1)//2))

        return mean, log_var, w, lb, ent, div, dead

# ---------------------------
# Prepare ALL domains (global scaler, global regime vocab)
# ---------------------------
def prepare_all_domains(data_dir: str, fds: List[str]):
    # load all train dfs first (for global regime vocab)
    train_dfs = {}
    raw_splits = {}
    for fd in fds:
        tr, te, rul = load_cmapss_split(data_dir, fd)
        tr = add_rul_train(tr, cfg.max_rul)
        te = add_rul_test(te, rul, cfg.max_rul)
        train_dfs[fd] = tr
        raw_splits[fd] = (tr, te)

    key2id = build_regime_vocab(train_dfs)
    domains = {}

    feature_cols = base_feature_columns()  # keep all sensors always

    # attach regime id for each domain
    for di, fd in enumerate(fds):
        tr, te = raw_splits[fd]
        tr, key2id = attach_regime_id(tr, key2id)
        te, key2id = attach_regime_id(te, key2id)

        X_tr_all, y_tr_all, u_tr_all, c_tr_all, r_tr_all = make_windows(tr, cfg.window, cfg.stride, feature_cols)
        X_te_all, y_te_all, u_te_all, c_te_all, r_te_all = make_windows(te, cfg.window, cfg.stride, feature_cols)
        X_te_last, y_te_last, u_te_last, c_te_last, r_te_last = last_window_per_unit(
            X_te_all, y_te_all, u_te_all, c_te_all, r_te_all
        )

        domains[fd] = {
            "domain_id": di,
            "feature_cols": feature_cols,
            "train_all": (X_tr_all, y_tr_all, u_tr_all, c_tr_all, r_tr_all),
            "test_last_raw": (X_te_last, y_te_last, u_te_last, c_te_last, r_te_last)
        }

    # global scaler on ALL TRAIN windows of ALL domains
    scaler = StandardScaler()
    X_stack = []
    for fd in fds:
        X_tr_all = domains[fd]["train_all"][0]
        X_stack.append(X_tr_all.reshape(-1, X_tr_all.shape[-1]))
    X_stack = np.concatenate(X_stack, axis=0)
    scaler.fit(X_stack)

    # apply scaler + split into train/val by units per domain
    for fd in fds:
        X_tr_all, y_tr_all, u_tr_all, c_tr_all, r_tr_all = domains[fd]["train_all"]
        X_te_last, y_te_last, u_te_last, c_te_last, r_te_last = domains[fd]["test_last_raw"]

        X_tr_all = scaler.transform(X_tr_all.reshape(-1, X_tr_all.shape[-1])).reshape(X_tr_all.shape)
        X_te_last = scaler.transform(X_te_last.reshape(-1, X_te_last.shape[-1])).reshape(X_te_last.shape)

        tr_idx, va_idx = split_by_units(u_tr_all, cfg.val_ratio_units, cfg.seed)

        X_tr, y_tr, u_tr, c_tr, r_tr = X_tr_all[tr_idx], y_tr_all[tr_idx], u_tr_all[tr_idx], c_tr_all[tr_idx], r_tr_all[tr_idx]
        X_va, y_va, u_va, c_va, r_va = X_tr_all[va_idx], y_tr_all[va_idx], u_tr_all[va_idx], c_tr_all[va_idx], r_tr_all[va_idx]

        y_te_raw = y_te_last.copy()

        if cfg.normalize_y:
            y_tr = y_tr / cfg.max_rul
            y_va = y_va / cfg.max_rul
            y_te = y_te_last / cfg.max_rul
        else:
            y_te = y_te_last

        d_tr = np.full(len(X_tr), domains[fd]["domain_id"], dtype=np.int64)
        d_va = np.full(len(X_va), domains[fd]["domain_id"], dtype=np.int64)
        d_te = np.full(len(X_te_last), domains[fd]["domain_id"], dtype=np.int64)

        domains[fd].update({
            "scaler": scaler,
            "n_regimes": len(key2id),
            "train": (X_tr, y_tr, u_tr, c_tr, d_tr, r_tr),
            "val": (X_va, y_va, u_va, c_va, d_va, r_va),
            "test_last": (X_te_last, y_te, u_te_last, c_te_last, d_te, r_te_last),
            "test_last_y_raw": y_te_raw
        })

    return domains, scaler, len(key2id)

# ---------------------------
# Interpretability
# ---------------------------
@torch.no_grad()
def gate_stats(model, X, d, r, n_show=2048):
    model.eval()
    xb=torch.tensor(X[:n_show],dtype=torch.float32).to(DEVICE)
    db=torch.tensor(d[:n_show],dtype=torch.long).to(DEVICE)
    rb=torch.tensor(r[:n_show],dtype=torch.long).to(DEVICE)
    out = model(xb, db, rb, use_topk=True, noise_std=0.0, temperature=1.0, train=False)
    w = out[2].cpu().numpy()
    print("Gate usage (avg weights):", np.round(w.mean(axis=0),4))

# ---------------------------
# Training (ONE joint model)
# ---------------------------
def train_joint(domains: Dict[str,dict], n_domains: int, n_regimes: int):
    # build combined train set
    X_tr=[]; y_tr=[]; u_tr=[]; c_tr=[]; d_tr=[]; r_tr=[]
    X_va_by_fd={}; y_va_by_fd={}; d_va_by_fd={}; r_va_by_fd={}; u_va_by_fd={}; c_va_by_fd={}

    for fd, dom in domains.items():
        X,y,u,c,d,r = dom["train"]
        X_tr.append(X); y_tr.append(y); u_tr.append(u); c_tr.append(c); d_tr.append(d); r_tr.append(r)

        Xv,yv,uv,cv,dv,rv = dom["val"]
        X_va_by_fd[fd]=Xv; y_va_by_fd[fd]=yv; d_va_by_fd[fd]=dv; r_va_by_fd[fd]=rv
        u_va_by_fd[fd]=uv; c_va_by_fd[fd]=cv

    X_tr=np.concatenate(X_tr); y_tr=np.concatenate(y_tr)
    u_tr=np.concatenate(u_tr); c_tr=np.concatenate(c_tr)
    d_tr=np.concatenate(d_tr); r_tr=np.concatenate(r_tr)

    model = InterpretableMoE(n_features=X_tr.shape[-1], n_domains=n_domains, n_regimes=n_regimes).to(DEVICE)

    ds_tr = RULWindowDataset(X_tr,y_tr,u_tr,c_tr,domain_ids=d_tr, regime_ids=r_tr)

    sampler = DomainBalancedContiguousSampler(
        ds_tr.unit_ids, ds_tr.cycles, ds_tr.domain_ids.numpy(),
        cfg.batch_size, cfg.block_len, n_domains=n_domains, shuffle=True
    )
    dl_tr = DataLoader(ds_tr, batch_sampler=sampler)

    # val loaders per domain
    val_loaders={}
    for fd in domains.keys():
        Xv=y_va_by_fd[fd]*0  # dummy init
        ds = RULWindowDataset(
            X_va_by_fd[fd], y_va_by_fd[fd],
            u_va_by_fd[fd], c_va_by_fd[fd],
            domain_ids=d_va_by_fd[fd], regime_ids=r_va_by_fd[fd]
        )
        val_loaders[fd]=DataLoader(ds, batch_size=cfg.batch_size, shuffle=False)

    opt = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=cfg.epochs)

    best_val=float("inf"); best_state=None; bad=0

    for ep in range(1, cfg.epochs+1):
        s = ramp(ep, cfg.warmup_epochs, cfg.ramp_epochs)
        use_topk = (ep > cfg.warmup_epochs)
        noise_std = lerp(cfg.gate_noise_max, cfg.gate_noise_min, s)
        temperature = lerp(cfg.temp_max, cfg.temp_min, s)

        mono_w   = cfg.lambda_mono*s
        smooth_w = cfg.lambda_smooth*s
        lb_w     = cfg.lambda_lb*s
        ent_w    = cfg.lambda_ent*s
        div_w    = cfg.lambda_div*s
        dead_w   = cfg.lambda_dead*s

        model.train()
        tr_loss=0.0

        for xb,yb,ub,cb,db,rb in dl_tr:
            xb=xb.to(DEVICE); yb=yb.to(DEVICE)
            db=db.to(DEVICE); rb=rb.to(DEVICE)

            opt.zero_grad(set_to_none=True)

            mean, logv, w, lb, ent, div, dead = model(
                xb, db, rb, use_topk=use_topk, noise_std=noise_std, temperature=temperature, train=True
            )

            # emphasize low RUL region (helps FD004)
            # yb normalized in [0,1]; low RUL => small yb
            weight = 1.0 + cfg.rul_focus * (1.0 - yb)
            nll = (gaussian_nll(yb, mean, logv) * weight).mean()
            mse = (F.mse_loss(mean, yb, reduction="none") * weight).mean()
            hub = (F.huber_loss(mean, yb, delta=cfg.huber_delta, reduction="none") * weight).mean()

            mono, smooth = physics_losses(mean, ub, cb)

            if ep <= cfg.warmup_epochs:
                loss = 1.0*mse + 0.10*nll + cfg.huber_weight*hub
            else:
                loss = nll + cfg.aux_mse_weight*mse + cfg.huber_weight*hub

            loss = loss + mono_w*mono + smooth_w*smooth + lb_w*lb - ent_w*ent + div_w*div + dead_w*dead

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip)
            opt.step()
            tr_loss += float(loss.detach().cpu())

        sched.step()

        # ---- validation: average across domains (key for generalization)
        model.eval()
        val_total=0.0
        val_rmse={}
        with torch.no_grad():
            for fd, dl_va in val_loaders.items():
                va_loss=0.0
                y_true=[]; y_hat=[]
                for xb,yb,ub,cb,db,rb in dl_va:
                    xb=xb.to(DEVICE); yb=yb.to(DEVICE)
                    db=db.to(DEVICE); rb=rb.to(DEVICE)
                    mean, logv, *_ = model(xb, db, rb, use_topk=True, noise_std=0.0, temperature=1.0, train=False)

                    # same weighted objective for selection
                    weight = 1.0 + cfg.rul_focus * (1.0 - yb)
                    nll = (gaussian_nll(yb, mean, logv) * weight).mean()
                    mse = (F.mse_loss(mean, yb, reduction="none") * weight).mean()
                    hub = (F.huber_loss(mean, yb, delta=cfg.huber_delta, reduction="none") * weight).mean()
                    va_loss += float((nll + cfg.aux_mse_weight*mse + cfg.huber_weight*hub).cpu())

                    y_true.append(yb.detach().cpu().numpy())
                    y_hat.append(mean.detach().cpu().numpy())

                va_loss /= max(1,len(dl_va))
                val_total += va_loss

                yt = np.concatenate(y_true)
                yh = np.concatenate(y_hat)
                if cfg.normalize_y:
                    yt = yt*cfg.max_rul
                    yh = yh*cfg.max_rul
                val_rmse[fd] = math.sqrt(mean_squared_error(yt, yh))

        val_total /= max(1, len(val_loaders))
        rmse_str = " | ".join([f"{fd}:{val_rmse[fd]:.2f}" for fd in FDS])
        print(f"[Epoch {ep:03d}] s={s:.2f} topk={int(use_topk)} val_avg={val_total:.4f} | valRMSE {rmse_str}")

        if val_total + 1e-6 < best_val:
            best_val=val_total
            best_state={k:v.detach().cpu().clone() for k,v in model.state_dict().items()}
            bad=0
        else:
            bad+=1
            if bad>=cfg.early_stop_patience:
                print(f"Early stopping (best avg val={best_val:.4f})")
                break

    if best_state is not None:
        model.load_state_dict(best_state)
    return model

# ---------------------------
# EVALUATION
# ---------------------------
def eval_domain(model, domains, fd):
    X_te, y_te, u_te, c_te, d_te, r_te = domains[fd]["test_last"]
    y_te_raw = domains[fd]["test_last_y_raw"]

    X_va, y_va, u_va, c_va, d_va, r_va = domains[fd]["val"]
    y_va_raw = (y_va*cfg.max_rul) if cfg.normalize_y else y_va

    mean, var = predict_mc(model, X_te, d_te, r_te, n_mc=cfg.mc_samples)
    if cfg.normalize_y:
        mean_raw = mean*cfg.max_rul
        var_raw  = var*(cfg.max_rul**2)
    else:
        mean_raw = mean; var_raw = var

    # PI calibration per-domain (optional; doesn't change point accuracy)
    scale = calibrate_scale_min_width(model, X_va, d_va, r_va, y_va_raw, cfg.pi_alpha, cfg.mc_samples, max_over=0.01)
    var_raw = (scale**2)*var_raw

    r2 = r2_score(y_te_raw, mean_raw)
    rmse = math.sqrt(mean_squared_error(y_te_raw, mean_raw))
    mae = mean_absolute_error(y_te_raw, mean_raw)

    lo, hi = prediction_interval(mean_raw, var_raw, cfg.pi_alpha)
    cov = coverage(y_te_raw, lo, hi)

    print(f"Point: R2={r2:.4f} RMSE={rmse:.3f} MAE={mae:.3f}")
    print(f"UQ: {int((1-cfg.pi_alpha)*100)}% PI coverage={cov:.3f} | mean width={(hi-lo).mean():.3f} | cal_scale={scale:.3f}")
    gate_stats(model, X_te, d_te, r_te)

# ---------------------------
# RUN
# ---------------------------
# auto-fallback for local testing (optional)
if not os.path.exists(DATA_DIR):
    DATA_DIR = "/mnt/data"

# file checks
missing=[]
for fd in FDS:
    for f in [f"train_{fd}.txt", f"test_{fd}.txt", f"RUL_{fd}.txt"]:
        if not os.path.exists(os.path.join(DATA_DIR, f)):
            missing.append(os.path.join(DATA_DIR,f))
if missing:
    raise FileNotFoundError("Missing required C-MAPSS files:\n" + "\n".join(missing))

domains, scaler, n_regimes = prepare_all_domains(DATA_DIR, FDS)
n_domains = len(FDS)

models=[]
for mi in range(ENSEMBLE_SIZE):
    print("\n==============================")
    print(f" Training JOINT model {mi+1}/{ENSEMBLE_SIZE}")
    print("==============================")
    set_seed(cfg.seed + mi)
    models.append(train_joint(domains, n_domains=n_domains, n_regimes=n_regimes))

model = models[0]

for fd in FDS:
    print("\n==============================")
    print(f" Target {fd} | LAST window per unit")
    print("==============================")
    eval_domain(model, domains, fd)

print("\n✅ Done. One generalized model trained on FD001..FD004.")



 Training JOINT model 1/1


KeyboardInterrupt: 

In [5]:
# ============================================================
# END-TO-END: Baseline vs MoE Ablation Table for C-MAPSS (One Cell)
# Outputs: RMSE table for FD001..FD004 + Avg RMSE
# Variants:
#  1) Single Expert
#  2) Dense MoE (no top-k)
#  3) MoE + Top-K
#  4) MoE + Top-K + Physics
#  5) Full (Top-K + Physics + UQ + Calib)
# ============================================================

# ---------------------------
# USER SETTINGS
# ---------------------------
DATA_DIR = "/content"          # change if needed
FDS = ["FD001","FD002","FD003","FD004"]
ENSEMBLE_SIZE = 1             # keep 1 for table fairness
FAST_DEV_RUN = False          # True = fewer epochs for quick check
PRINT_UQ_FOR_FULL = True      # prints PI coverage/width for Full variant

# ---------------------------
# Install deps
# ---------------------------
import sys, subprocess, importlib
def _ensure(pkg, import_name=None):
    name = import_name or pkg
    try:
        importlib.import_module(name)
    except Exception:
        subprocess.check_call([sys.executable, "-m", "pip", "-q", "install", pkg])

_ensure("scikit-learn", "sklearn")
_ensure("scipy", "scipy")

# ---------------------------
# Imports
# ---------------------------
import os, math, random
from dataclasses import dataclass
from typing import List, Dict, Tuple, Optional

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Sampler

from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_squared_error
from scipy.stats import norm

# ---------------------------
# Config
# ---------------------------
@dataclass
class CFG:
    window: int = 30
    stride: int = 1
    max_rul: int = 125
    normalize_y: bool = True
    val_ratio_units: float = 0.2
    seed: int = 42

    # feature filter (leave as-is; it only drops near-constant columns)
    drop_low_var: bool = True
    low_var_thresh: float = 1e-6

    # MoE
    n_experts: int = 4
    top_k: int = 2
    enc_hidden: int = 192
    head_hidden: int = 192
    dropout: float = 0.08
    gate_dropout: float = 0.05

    # gating schedule
    gate_noise_max: float = 1.0
    gate_noise_min: float = 0.15
    temp_max: float = 2.0
    temp_min: float = 0.7

    # training
    epochs: int = 85
    batch_size: int = 128
    block_len: int = 12
    lr: float = 2e-3
    weight_decay: float = 1e-4
    grad_clip: float = 1.0
    early_stop_patience: int = 14

    # loss mix
    aux_mse_weight: float = 0.55
    huber_weight: float = 0.15
    huber_delta: float = 0.08  # normalized y scale

    # physics (used only for physics variants)
    lambda_mono: float = 0.10
    lambda_smooth: float = 0.03

    # MoE regularizers (used only for MoE variants)
    lambda_lb: float = 0.25
    lambda_ent: float = 0.02
    lambda_div: float = 0.01
    lambda_dead: float = 0.06
    dead_floor: float = 0.03

    # ramp
    warmup_epochs: int = 12
    ramp_epochs: int = 20

    # UQ
    mc_samples: int = 30
    pi_alpha: float = 0.10  # 90% PI

cfg = CFG()
if FAST_DEV_RUN:
    cfg.epochs = 12
    cfg.warmup_epochs = 3
    cfg.ramp_epochs = 4
    cfg.early_stop_patience = 4
    cfg.mc_samples = 10

# ---------------------------
# Utils
# ---------------------------
def set_seed(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
set_seed(cfg.seed)

def ensure_exists(path: str):
    if not os.path.exists(path):
        raise FileNotFoundError(f"Missing file: {path}")

def base_feature_columns() -> List[str]:
    return ["op1", "op2", "op3"] + [f"s{i}" for i in range(1, 22)]

def ramp(ep, warmup, ramp_len):
    if ep <= warmup: return 0.0
    return float(min(1.0, (ep - warmup) / max(1, ramp_len)))

def lerp(a, b, t):
    return a + (b - a) * t

# ---------------------------
# C-MAPSS load + labels
# ---------------------------
def load_cmapss_split(data_dir: str, fd: str):
    train_file = os.path.join(data_dir, f"train_{fd}.txt")
    test_file  = os.path.join(data_dir, f"test_{fd}.txt")
    rul_file   = os.path.join(data_dir, f"RUL_{fd}.txt")
    ensure_exists(train_file); ensure_exists(test_file); ensure_exists(rul_file)

    train_df = pd.read_csv(train_file, sep=r"\s+", header=None)
    test_df  = pd.read_csv(test_file,  sep=r"\s+", header=None)
    rul_df   = pd.read_csv(rul_file,   sep=r"\s+", header=None)

    cols = ["unit","cycle","op1","op2","op3"] + [f"s{i}" for i in range(1,22)]
    train_df.columns = cols
    test_df.columns  = cols
    rul_df.columns   = ["RUL_last"]
    return train_df, test_df, rul_df

def add_rul_train(df: pd.DataFrame, max_rul: int):
    df = df.copy()
    max_cycle = df.groupby("unit")["cycle"].max()
    df["RUL"] = df.apply(lambda r: max_cycle.loc[r["unit"]] - r["cycle"], axis=1)
    df["RUL"] = df["RUL"].clip(upper=max_rul)
    return df

def add_rul_test(test_df: pd.DataFrame, rul_df: pd.DataFrame, max_rul: int):
    df = test_df.copy()
    max_cycle = df.groupby("unit")["cycle"].max()
    rul_last = rul_df["RUL_last"].values
    mapping = {u: rul_last[u-1] for u in sorted(df["unit"].unique())}
    df["RUL"] = df.apply(lambda r: (max_cycle.loc[r["unit"]] + mapping[r["unit"]]) - r["cycle"], axis=1)
    df["RUL"] = df["RUL"].clip(upper=max_rul)
    return df

# ---------------------------
# Windowing
# ---------------------------
def make_windows(df, window, stride, feature_cols, target_col="RUL"):
    xs, ys, units, cycles = [], [], [], []
    for unit, g in df.groupby("unit"):
        g = g.sort_values("cycle")
        feats = g[feature_cols].values.astype(np.float32)
        targ  = g[target_col].values.astype(np.float32)
        cyc   = g["cycle"].values.astype(np.int32)

        for end in range(window-1, len(g), stride):
            start = end-window+1
            xs.append(feats[start:end+1])
            ys.append(targ[end])
            units.append(unit)
            cycles.append(cyc[end])
    return np.stack(xs), np.array(ys), np.array(units), np.array(cycles)

def last_window_per_unit(X, y, unit_ids, cycles):
    idx=[]
    for u in np.unique(unit_ids):
        m = unit_ids==u
        i = np.argmax(cycles[m])
        idx.append(np.where(m)[0][i])
    idx=np.array(idx)
    return X[idx], y[idx], unit_ids[idx], cycles[idx]

def split_by_units(unit_ids, val_ratio, seed):
    rng=np.random.default_rng(seed)
    units=np.unique(unit_ids)
    rng.shuffle(units)
    n_val=max(1,int(len(units)*val_ratio))
    val=set(units[:n_val].tolist())
    tr=np.array([i for i,u in enumerate(unit_ids) if u not in val])
    va=np.array([i for i,u in enumerate(unit_ids) if u in val])
    return tr, va

# ---------------------------
# Dataset + contiguous sampler
# ---------------------------
class RULWindowDataset(Dataset):
    def __init__(self, X, y, unit_ids, cycles):
        self.X = torch.tensor(X, dtype=torch.float32)
        self.y = torch.tensor(y, dtype=torch.float32)
        self.unit_ids = np.array(unit_ids)
        self.cycles = np.array(cycles)

    def __len__(self): return len(self.X)
    def __getitem__(self, idx):
        return self.X[idx], self.y[idx], int(self.unit_ids[idx]), int(self.cycles[idx])

class ContiguousEngineBatchSampler(Sampler):
    def __init__(self, unit_ids, cycles, batch_size, block_len, shuffle=True):
        self.unit_ids=np.array(unit_ids)
        self.cycles=np.array(cycles)
        self.batch_size=batch_size
        self.block_len=block_len
        self.shuffle=shuffle
        self.units=np.unique(self.unit_ids)
        self.unit_to_sorted={}
        for u in self.units:
            idx=np.where(self.unit_ids==u)[0]
            idx=idx[np.argsort(self.cycles[idx])]
            self.unit_to_sorted[u]=idx
        self.engines_per_batch=max(1,batch_size//block_len)

    def __iter__(self):
        units=self.units.copy()
        if self.shuffle: np.random.shuffle(units)
        batch=[]
        for u in units:
            idx=self.unit_to_sorted[u]
            L=len(idx)
            if L<=self.block_len:
                take=np.random.choice(idx,size=self.block_len,replace=True)
            else:
                s=np.random.randint(0,L-self.block_len)
                take=idx[s:s+self.block_len]
            batch.extend(take.tolist())
            if len(batch)>=self.engines_per_batch*self.block_len:
                yield batch[:self.batch_size]
                batch=[]
        if batch: yield batch

    def __len__(self):
        return math.ceil(len(self.units)/self.engines_per_batch)

# ---------------------------
# Physics losses
# ---------------------------
def physics_losses(pred_mean, unit_ids, cycles, margin=0.0):
    unit_ids=np.array(unit_ids); cycles=np.array(cycles)
    mono_terms=[]; smooth_terms=[]
    for u in np.unique(unit_ids):
        idx=np.where(unit_ids==u)[0]
        if len(idx)<2: continue
        ord_idx=idx[np.argsort(cycles[idx])]
        p=pred_mean[ord_idx]
        mono=F.relu(p[1:]-p[:-1]+margin)
        mono_terms.append(mono.mean())
        if len(ord_idx)>=3:
            second=p[2:]-2*p[1:-1]+p[:-2]
            smooth_terms.append(torch.abs(second).mean())
    mono_loss=torch.stack(mono_terms).mean() if mono_terms else pred_mean.new_tensor(0.)
    smooth_loss=torch.stack(smooth_terms).mean() if smooth_terms else pred_mean.new_tensor(0.)
    return mono_loss, smooth_loss

# ---------------------------
# MoE utilities
# ---------------------------
class GatingNet(nn.Module):
    def __init__(self, in_dim, n_experts, gate_dropout):
        super().__init__()
        h=max(64,in_dim//2)
        self.net=nn.Sequential(
            nn.Linear(in_dim,h),
            nn.ReLU(),
            nn.Dropout(gate_dropout),
            nn.Linear(h,n_experts),
        )
    def forward(self,x): return self.net(x)

def sparse_topk_softmax(logits, k, temperature):
    logits = logits / max(1e-6, temperature)
    topk_vals, topk_idx = torch.topk(logits, k=k, dim=-1)
    masked = torch.full_like(logits, float("-inf"))
    masked.scatter_(dim=-1, index=topk_idx, src=topk_vals)
    w = F.softmax(masked, dim=-1)
    return w, topk_idx

def switch_load_balance_loss(w, top1_idx):
    B,E=w.shape
    importance=w.sum(dim=0)/(B+1e-8)
    load=torch.bincount(top1_idx,minlength=E).float().to(w.device)/(B+1e-8)
    return E*torch.sum(importance*load)

def dead_expert_penalty(w, floor):
    avg = w.mean(dim=0)
    return F.relu(floor - avg).mean()

# ---------------------------
# UQ + calibration
# ---------------------------
def gaussian_nll(y, mean, log_var):
    return 0.5*(torch.exp(-log_var)*(y-mean)**2 + log_var)

def prediction_interval(mean, var, alpha):
    z = norm.ppf(1-alpha/2)
    std=np.sqrt(np.maximum(var,1e-8))
    return mean - z*std, mean + z*std

def coverage(y, lo, hi):
    return float(np.mean((y>=lo)&(y<=hi)))

@torch.no_grad()
def predict_mc_mean_var(model, X, n_mc, batch_size=256, use_topk=True):
    model.train()  # enable dropout
    dl=DataLoader(torch.tensor(X,dtype=torch.float32), batch_size=batch_size, shuffle=False)
    means_all=[]; vars_all=[]
    for xb in dl:
        xb=xb.to(DEVICE)
        mc_m=[]; mc_v=[]
        for _ in range(n_mc):
            mean, logv, *_ = model(xb, use_topk=use_topk, noise_std=0.0, temperature=1.0, train=True)
            mc_m.append(mean)
            mc_v.append(torch.exp(logv))
        mc_m=torch.stack(mc_m,dim=0)
        mc_v=torch.stack(mc_v,dim=0)
        mean_pred=mc_m.mean(dim=0)
        epistemic=mc_m.var(dim=0,unbiased=False)
        aleatoric=mc_v.mean(dim=0)
        total=aleatoric+epistemic
        means_all.append(mean_pred.cpu().numpy())
        vars_all.append(total.cpu().numpy())
    return np.concatenate(means_all), np.concatenate(vars_all)

def calibrate_scale_min_width(model, X_val, y_val_raw, alpha, n_mc, max_over=0.01, use_topk=True):
    mean, var = predict_mc_mean_var(model, X_val, n_mc=n_mc, use_topk=use_topk)
    if cfg.normalize_y:
        mean = mean * cfg.max_rul
        var  = var  * (cfg.max_rul**2)
    std = np.sqrt(np.maximum(var,1e-8))
    z = norm.ppf(1-alpha/2)
    target = 1-alpha
    scales = np.linspace(0.6, 3.0, 121)
    best=None
    for s in scales:
        lo = mean - z*s*std
        hi = mean + z*s*std
        cov = np.mean((y_val_raw>=lo)&(y_val_raw<=hi))
        if cov >= (target - max_over):
            best=s
            break
    return float(best if best is not None else scales[-1])

# ---------------------------
# Models
# ---------------------------
class SharedEncoder(nn.Module):
    def __init__(self, n_features, hidden, dropout):
        super().__init__()
        self.conv1=nn.Conv1d(n_features,96,3,padding=1)
        self.conv2=nn.Conv1d(96,96,3,padding=1)
        self.bn1=nn.BatchNorm1d(96)
        self.bn2=nn.BatchNorm1d(96)
        self.gru=nn.GRU(96,hidden,batch_first=True)
        self.drop=nn.Dropout(dropout)
    def forward(self,x):
        x=x.transpose(1,2)
        x=F.relu(self.bn1(self.conv1(x)))
        x=F.relu(self.bn2(self.conv2(x)))
        x=x.transpose(1,2)
        _,h=self.gru(x)
        return self.drop(h[-1])

class ExpertHead(nn.Module):
    def __init__(self, in_dim, hidden, base_dropout, expert_id):
        super().__init__()
        width = hidden if expert_id%2==0 else max(64,hidden//2)
        drop  = min(0.30, base_dropout + 0.03*expert_id)
        self.emb = nn.Parameter(torch.randn(16)*0.02)
        self.net = nn.Sequential(
            nn.Linear(in_dim+16, width), nn.ReLU(), nn.Dropout(drop),
            nn.Linear(width, max(32,width//2)), nn.ReLU(), nn.Dropout(drop),
        )
        out=max(32,width//2)
        self.mu=nn.Linear(out,1)
        self.logv=nn.Linear(out,1)
    def forward(self,z):
        B=z.size(0)
        e=self.emb.unsqueeze(0).expand(B,-1)
        h=self.net(torch.cat([z,e],dim=-1))
        mu=self.mu(h).squeeze(-1)
        logv=self.logv(h).squeeze(-1).clamp(-9,4)
        return mu, logv

class InterpretableMoE(nn.Module):
    def __init__(self, n_features):
        super().__init__()
        self.enc = SharedEncoder(n_features,cfg.enc_hidden,cfg.dropout)
        self.experts = nn.ModuleList([ExpertHead(cfg.enc_hidden,cfg.head_hidden,cfg.dropout,i)
                                      for i in range(cfg.n_experts)])
        self.gate = GatingNet(cfg.enc_hidden+9, cfg.n_experts, cfg.gate_dropout)

    def forward(self, x, *, use_topk, noise_std, temperature, train):
        z = self.enc(x)
        op = x[:,:, :3]
        gate_in = torch.cat([z, op[:,-1,:], op.mean(dim=1), op.std(dim=1)], dim=-1)

        logits = self.gate(gate_in)
        if train and noise_std>0:
            logits = logits + torch.randn_like(logits)*noise_std

        if use_topk:
            w, topk_idx = sparse_topk_softmax(logits, cfg.top_k, temperature)
            top1 = topk_idx[:,0]
        else:
            w = F.softmax(logits/max(1e-6,temperature), dim=-1)
            top1 = torch.argmax(w, dim=-1)

        mus=[]; vars_=[]
        for ex in self.experts:
            mu, logv = ex(z)
            mus.append(mu); vars_.append(torch.exp(logv))
        mus=torch.stack(mus,dim=-1)
        vars_=torch.stack(vars_,dim=-1)

        mean=torch.sum(w*mus,dim=-1)
        second=torch.sum(w*(vars_+mus**2),dim=-1)
        var=(second-mean**2).clamp_min(1e-6)
        log_var=torch.log(var)

        lb = switch_load_balance_loss(w, top1)
        ent = -(w*torch.log(w+1e-8)).sum(dim=-1).mean()
        dead = dead_expert_penalty(w, cfg.dead_floor)

        div=0.0
        E=mus.shape[-1]
        for i in range(E):
            for j in range(i+1,E):
                a=mus[:,i]-mus[:,i].mean()
                b=mus[:,j]-mus[:,j].mean()
                div += torch.abs(F.cosine_similarity(a.unsqueeze(-1), b.unsqueeze(-1), dim=-1)).mean()
        div = div / max(1,(E*(E-1)//2))

        return mean, log_var, w, lb, ent, div, dead

class SingleExpertBaseline(nn.Module):
    """
    Baseline: same encoder, one probabilistic head (mu, log_var).
    To keep training code unified, it returns 7 values like MoE but with zeros.
    """
    def __init__(self, n_features):
        super().__init__()
        self.enc = SharedEncoder(n_features,cfg.enc_hidden,cfg.dropout)
        h = cfg.head_hidden
        self.net = nn.Sequential(
            nn.Linear(cfg.enc_hidden, h), nn.ReLU(), nn.Dropout(cfg.dropout),
            nn.Linear(h, max(32,h//2)), nn.ReLU(), nn.Dropout(cfg.dropout),
        )
        out = max(32,h//2)
        self.mu = nn.Linear(out, 1)
        self.logv = nn.Linear(out, 1)

    def forward(self, x, *, use_topk, noise_std, temperature, train):
        z = self.enc(x)
        h = self.net(z)
        mu = self.mu(h).squeeze(-1)
        logv = self.logv(h).squeeze(-1).clamp(-9,4)

        # placeholders
        w = torch.zeros((x.size(0), cfg.n_experts), device=x.device)
        lb = mu.new_tensor(0.0)
        ent = mu.new_tensor(0.0)
        div = mu.new_tensor(0.0)
        dead = mu.new_tensor(0.0)
        return mu, logv, w, lb, ent, div, dead

# ---------------------------
# Domain preparation (single FD)
# ---------------------------
def prepare_domain(fd: str, scaler=None, kept_cols=None):
    tr, te, rul = load_cmapss_split(DATA_DIR, fd)
    tr = add_rul_train(tr, cfg.max_rul)
    te = add_rul_test(te, rul, cfg.max_rul)

    all_cols = base_feature_columns()
    if kept_cols is None:
        kept_cols = all_cols
        if cfg.drop_low_var:
            v = tr[all_cols].var(axis=0).values
            kept_cols = [c for c,vv in zip(all_cols,v) if vv > cfg.low_var_thresh]
            # always keep operating settings
            for c in ["op1","op2","op3"]:
                if c not in kept_cols:
                    kept_cols = ["op1","op2","op3"] + [x for x in kept_cols if x not in ["op1","op2","op3"]]

    X_tr_all, y_tr_all, u_tr_all, c_tr_all = make_windows(tr, cfg.window, cfg.stride, kept_cols)
    X_te_all, y_te_all, u_te_all, c_te_all = make_windows(te, cfg.window, cfg.stride, kept_cols)
    X_te_last, y_te_last, u_te_last, c_te_last = last_window_per_unit(X_te_all, y_te_all, u_te_all, c_te_all)

    if scaler is None:
        scaler = StandardScaler()
        scaler.fit(X_tr_all.reshape(-1, X_tr_all.shape[-1]))

    X_tr_all = scaler.transform(X_tr_all.reshape(-1,X_tr_all.shape[-1])).reshape(X_tr_all.shape)
    X_te_last = scaler.transform(X_te_last.reshape(-1,X_te_last.shape[-1])).reshape(X_te_last.shape)

    tr_idx, va_idx = split_by_units(u_tr_all, cfg.val_ratio_units, cfg.seed)
    X_tr, y_tr, u_tr, c_tr = X_tr_all[tr_idx], y_tr_all[tr_idx], u_tr_all[tr_idx], c_tr_all[tr_idx]
    X_va, y_va, u_va, c_va = X_tr_all[va_idx], y_tr_all[va_idx], u_tr_all[va_idx], c_tr_all[va_idx]

    y_te_raw = y_te_last.copy()

    if cfg.normalize_y:
        y_tr = y_tr/cfg.max_rul
        y_va = y_va/cfg.max_rul
        y_te = y_te_last/cfg.max_rul
    else:
        y_te = y_te_last

    return {
        "scaler": scaler, "kept_cols": kept_cols,
        "train": (X_tr,y_tr,u_tr,c_tr),
        "val": (X_va,y_va,u_va,c_va),
        "test_last": (X_te_last,y_te,u_te_last,c_te_last),
        "test_last_y_raw": y_te_raw
    }

# ---------------------------
# Variant specs
# ---------------------------
VARIANTS = [
    {"name":"Single Expert",                       "kind":"single", "routing":"na",    "physics":False, "full_uq":False},
    {"name":"Dense MoE (no top-k)",               "kind":"moe",    "routing":"dense", "physics":False, "full_uq":False},
    {"name":"MoE + Top-K",                        "kind":"moe",    "routing":"topk",  "physics":False, "full_uq":False},
    {"name":"MoE + Top-K + Physics",              "kind":"moe",    "routing":"topk",  "physics":True,  "full_uq":False},
    {"name":"Full (Top-K + Physics + UQ + Calib)","kind":"moe",    "routing":"topk",  "physics":True,  "full_uq":True},
]

# ---------------------------
# Training one model on one FD for one variant
# ---------------------------
def train_variant_on_domain(domain, variant):
    X_tr,y_tr,u_tr,c_tr = domain["train"]
    X_va,y_va,u_va,c_va = domain["val"]

    if variant["kind"] == "single":
        model = SingleExpertBaseline(n_features=X_tr.shape[-1]).to(DEVICE)
    else:
        model = InterpretableMoE(n_features=X_tr.shape[-1]).to(DEVICE)

    ds_tr = RULWindowDataset(X_tr,y_tr,u_tr,c_tr)
    ds_va = RULWindowDataset(X_va,y_va,u_va,c_va)

    sampler = ContiguousEngineBatchSampler(ds_tr.unit_ids, ds_tr.cycles, cfg.batch_size, cfg.block_len, shuffle=True)
    dl_tr = DataLoader(ds_tr, batch_sampler=sampler)
    dl_va = DataLoader(ds_va, batch_size=cfg.batch_size, shuffle=False)

    opt = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=cfg.epochs)

    best_val=float("inf"); best_state=None; bad=0

    for ep in range(1, cfg.epochs+1):
        s = ramp(ep, cfg.warmup_epochs, cfg.ramp_epochs)

        # routing control
        if variant["routing"] == "dense" or variant["kind"] == "single":
            use_topk = False
        else:
            use_topk = (ep > cfg.warmup_epochs)

        # only meaningful for topk MoE
        noise_std = lerp(cfg.gate_noise_max, cfg.gate_noise_min, s) if (variant["routing"]=="topk" and variant["kind"]=="moe") else 0.0
        temperature = lerp(cfg.temp_max, cfg.temp_min, s) if (variant["routing"]=="topk" and variant["kind"]=="moe") else 1.0

        # weights
        mono_w   = (cfg.lambda_mono*s) if variant["physics"] else 0.0
        smooth_w = (cfg.lambda_smooth*s) if variant["physics"] else 0.0

        lb_w   = (cfg.lambda_lb*s)   if variant["kind"]=="moe" else 0.0
        ent_w  = (cfg.lambda_ent*s)  if variant["kind"]=="moe" else 0.0
        div_w  = (cfg.lambda_div*s)  if variant["kind"]=="moe" else 0.0
        dead_w = (cfg.lambda_dead*s) if variant["kind"]=="moe" else 0.0

        model.train()
        for xb,yb,ub,cb in dl_tr:
            xb=xb.to(DEVICE); yb=yb.to(DEVICE)
            opt.zero_grad(set_to_none=True)

            mean, logv, w, lb, ent, div, dead = model(
                xb, use_topk=use_topk, noise_std=noise_std, temperature=temperature, train=True
            )

            nll = gaussian_nll(yb, mean, logv).mean()
            mse = F.mse_loss(mean, yb)
            hub = F.huber_loss(mean, yb, delta=cfg.huber_delta)

            mono, smooth = physics_losses(mean, ub, cb) if variant["physics"] else (mean.new_tensor(0.0), mean.new_tensor(0.0))

            if ep <= cfg.warmup_epochs:
                loss = 1.0*mse + 0.10*nll + cfg.huber_weight*hub
            else:
                loss = nll + cfg.aux_mse_weight*mse + cfg.huber_weight*hub

            # add-ons
            loss = loss + mono_w*mono + smooth_w*smooth
            loss = loss + lb_w*lb - ent_w*ent + div_w*div + dead_w*dead

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip)
            opt.step()

        sched.step()

        # ---- validation objective for early stopping
        model.eval()
        va_loss=0.0
        with torch.no_grad():
            for xb,yb,ub,cb in dl_va:
                xb=xb.to(DEVICE); yb=yb.to(DEVICE)
                mean, logv, *_ = model(xb, use_topk=use_topk, noise_std=0.0, temperature=1.0, train=False)
                nll = gaussian_nll(yb, mean, logv).mean()
                mse = F.mse_loss(mean, yb)
                hub = F.huber_loss(mean, yb, delta=cfg.huber_delta)
                va_loss += float((nll + cfg.aux_mse_weight*mse + cfg.huber_weight*hub).cpu())
        va_loss /= max(1,len(dl_va))

        if va_loss + 1e-6 < best_val:
            best_val=va_loss
            best_state={k:v.detach().cpu().clone() for k,v in model.state_dict().items()}
            bad=0
        else:
            bad+=1
            if bad>=cfg.early_stop_patience:
                break

    if best_state is not None:
        model.load_state_dict(best_state)
    return model

# ---------------------------
# Evaluation on LAST window per unit (RMSE)
# ---------------------------
def eval_rmse_last(model, domain, variant):
    X_te, y_te, *_ = domain["test_last"]
    y_te_raw = domain["test_last_y_raw"]

    # point prediction choice:
    # - for Full: use MC-dropout mean (matches UQ methodology)
    # - otherwise: deterministic mean (faster + fair)
    if variant["full_uq"]:
        # for Full, use_topk=True if moe-topk; else False
        use_topk = (variant["kind"]=="moe" and variant["routing"]=="topk")
        mean, var = predict_mc_mean_var(model, X_te, n_mc=cfg.mc_samples, use_topk=use_topk)
        mean_raw = mean*cfg.max_rul if cfg.normalize_y else mean
    else:
        model.eval()
        xb=torch.tensor(X_te,dtype=torch.float32).to(DEVICE)
        use_topk = (variant["kind"]=="moe" and variant["routing"]=="topk")
        with torch.no_grad():
            mean, logv, *_ = model(xb, use_topk=use_topk, noise_std=0.0, temperature=1.0, train=False)
        mean = mean.detach().cpu().numpy()
        mean_raw = mean*cfg.max_rul if cfg.normalize_y else mean

    rmse = math.sqrt(mean_squared_error(y_te_raw, mean_raw))
    return rmse

def eval_uq_full(model, domain, variant):
    if not variant["full_uq"]:
        return None

    X_te, _, *_ = domain["test_last"]
    y_te_raw = domain["test_last_y_raw"]
    X_va, y_va, *_ = domain["val"]
    y_va_raw = (y_va*cfg.max_rul) if cfg.normalize_y else y_va

    use_topk = (variant["kind"]=="moe" and variant["routing"]=="topk")

    mean, var = predict_mc_mean_var(model, X_te, n_mc=cfg.mc_samples, use_topk=use_topk)
    if cfg.normalize_y:
        mean_raw = mean*cfg.max_rul
        var_raw  = var*(cfg.max_rul**2)
    else:
        mean_raw = mean; var_raw = var

    scale = calibrate_scale_min_width(model, X_va, y_va_raw, cfg.pi_alpha, cfg.mc_samples, use_topk=use_topk)
    var_raw = (scale**2)*var_raw

    lo, hi = prediction_interval(mean_raw, var_raw, cfg.pi_alpha)
    cov = coverage(y_te_raw, lo, hi)
    width = float((hi-lo).mean())
    return {"coverage": float(cov), "width": width, "scale": float(scale)}

# ---------------------------
# RUN: prepare domains + train all variants + build table
# ---------------------------
# auto fallback if running outside Colab
if not os.path.exists(DATA_DIR):
    DATA_DIR = "/mnt/data"

# file checks
missing=[]
for fd in FDS:
    for f in [f"train_{fd}.txt", f"test_{fd}.txt", f"RUL_{fd}.txt"]:
        if not os.path.exists(os.path.join(DATA_DIR, f)):
            missing.append(os.path.join(DATA_DIR,f))
if missing:
    raise FileNotFoundError("Missing required C-MAPSS files:\n" + "\n".join(missing))

results = {v["name"]: {} for v in VARIANTS}
uq_results = {v["name"]: {} for v in VARIANTS}

for fd in FDS:
    # prepare once per fd (shared scaler/kept_cols across variants for fairness)
    domain = prepare_domain(fd, scaler=None, kept_cols=None)

    for v in VARIANTS:
        for mi in range(ENSEMBLE_SIZE):
            set_seed(cfg.seed + mi)

            model = train_variant_on_domain(domain, v)
            rmse = eval_rmse_last(model, domain, v)

            # store best if ensemble>1 (here 1)
            key = f"m{mi}"
            # keep single value
            results[v["name"]][fd] = rmse

            if v["full_uq"]:
                uq = eval_uq_full(model, domain, v)
                uq_results[v["name"]][fd] = uq

# build table dataframe
rows=[]
for v in VARIANTS:
    name=v["name"]
    rmses=[results[name][fd] for fd in FDS]
    rows.append({
        "Model": name,
        "FD001 RMSE": rmses[0],
        "FD002 RMSE": rmses[1],
        "FD003 RMSE": rmses[2],
        "FD004 RMSE": rmses[3],
        "Avg RMSE": float(np.mean(rmses)),
    })

df = pd.DataFrame(rows)

# pretty print markdown table
def fmt(x): return f"{x:.3f}"
print("\n==============================")
print(" Ablation RMSE Table (LAST window per unit)")
print("==============================")
print("| Model | FD001 RMSE | FD002 RMSE | FD003 RMSE | FD004 RMSE | Avg RMSE |")
print("|---|---:|---:|---:|---:|---:|")
for _,r in df.iterrows():
    print(f"| {r['Model']} | {fmt(r['FD001 RMSE'])} | {fmt(r['FD002 RMSE'])} | {fmt(r['FD003 RMSE'])} | {fmt(r['FD004 RMSE'])} | {fmt(r['Avg RMSE'])} |")

# optional UQ report for Full
if PRINT_UQ_FOR_FULL:
    for v in VARIANTS:
        if not v["full_uq"]:
            continue
        print("\n==============================")
        print(f" UQ report: {v['name']}")
        print("==============================")
        for fd in FDS:
            uq = uq_results[v["name"]].get(fd, None)
            if uq is None:
                continue
            print(f"{fd}: PI{int((1-cfg.pi_alpha)*100)}% coverage={uq['coverage']:.3f} | width={uq['width']:.3f} | cal_scale={uq['scale']:.3f}")

print("\n✅ Done.")


KeyboardInterrupt: 