In [None]:
import json, numpy as np, pandas as pd
import os, gc, numpy as np, pandas as pd, torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm
import statsmodels.api as sm
from statsmodels.stats.sandwich_covariance import cov_hac

In [None]:
# Colab libraries (uncomment to mount Google Drive)
# from google.colab import drive
# drive.mount('/content/drive')

In [None]:
# These are the seeds that the original authors used; we keep them constant for reproduction.
seeds=[3,154295,583240,321536,0, 865466, 260937, 32, 23549, 469686]

In [None]:
# Run configuration for the CNN walk-forward experiment (read as module globals below).
MODE               = 'FULL'      # 'FAST' keeps only the last two OOS years as a smoke test (see main)
XZ_PATH            = '/content/drive/MyDrive/Dissertation/data/Xz_float32.npy'   # from the outputs of the data_cleaning_and_pca file
META_PATH          = '/content/drive/MyDrive/Dissertation/data/meta.parquet'     # from the outputs of the data_cleaning_and_pca file
TARGET_COL         = 'r_dh_no_tc' # or 'r_dh_tc'
BATCH_SIZE         = 8192        # default training batch size of make_loader
LR                 = 1e-3        # Adam learning rate
WEIGHT_DECAY       = 1e-5        # Adam weight decay
DROPOUT_P          = 0.4         # dropout probability passed to HoflerCNN
N_EPOCHS_INIT      = 12          # epochs used for the first OOS year
N_EPOCHS_PER_YEAR  = 6           # epochs used for every later OOS year
SEEDS              = seeds       # one model per seed; their OOS predictions are averaged
ENSEMBLE_N         = len(SEEDS)        
DEVICE             = 'cuda' if torch.cuda.is_available() else 'cpu'
PORTS_N            = 10           # deciles
LAGS_NW            = 6           # lag count passed to the HAC covariance in newey_west_mean_t
SAVE_DIR           = None         # falsy -> main() writes no output files
ADD_MASK_CHANNEL   = False        # True appends a finite-value mask as a second input channel
GRAD_CLIP_NORM     = 5.0          # max gradient norm; None disables clipping
DISPLAY_IN_PCT     = True         # multiply the Pred./Real. columns by 100 when printing
PRE_STANDARDIZED_INPUT = True     # True -> main() uses mu=0, sd=1 (identity transform)

def _require_cols(df, cols):
    miss = [c for c in cols if c not in df.columns]
    if miss: raise ValueError(f"no columns in meta: {miss}")

def _ensure_channel_dim(X):
    X = np.asarray(X)
    if X.ndim == 3:
        X = X[:, None, :, :]
    elif X.ndim == 4 and X.shape[1] == 1:
        pass
    else:
        raise ValueError(f"Xz should be (N,H,W) or (N,1,H,W).it is {X.shape}")
    return np.ascontiguousarray(X.astype(np.float32))

def _load_Xz(path):
    if path is None:
        if 'Xz' not in globals():
            raise RuntimeError("Xz not in ram or XZ_PATH=None.")
        return _ensure_channel_dim(globals()['Xz'])
    ext = os.path.splitext(path)[1].lower()
    if ext == '.npy':
        X = np.load(path, allow_pickle=False)
    elif ext == '.npz':
        npz = np.load(path, allow_pickle=False)
        X = npz['Xz'] if 'Xz' in npz.files else npz[npz.files[0]]
    else:
        raise ValueError("XZ_PATH should be .npy or .npz")
    return _ensure_channel_dim(X)

def _load_meta(path):
    if path is None:
        if 'meta' not in globals():
            raise RuntimeError("Meta not in ram or META_PATH=None.")
        return globals()['meta'].copy()
    ext = os.path.splitext(path)[1].lower()
    if ext == '.parquet':
        df = pd.read_parquet(path)
    elif ext == '.csv':
        df = pd.read_csv(path)
    else:
        raise ValueError("META_PATH should be  .parquet or .csv")
    return df

#standardization if not already
def compute_pixel_stats(X, idx):
    """Per-position mean and standard deviation over the rows `X[idx]`.

    Non-finite entries are ignored. Positions whose std is non-finite or below
    1e-8 get std 1.0; positions whose mean is non-finite get mean 0.0.

    Returns:
        (mu, sd) as float32 arrays with the shape of one row of `X`.
    """
    sample = X[idx].astype(np.float32)
    sample = np.where(np.isfinite(sample), sample, np.nan)
    mean_px = np.nanmean(sample, axis=0)
    std_px = np.nanstd(sample, axis=0, ddof=0)
    unusable_sd = (~np.isfinite(std_px)) | (std_px < 1e-8)
    std_px = np.where(unusable_sd, 1.0, std_px)
    mean_px = np.where(np.isfinite(mean_px), mean_px, 0.0)
    return mean_px.astype(np.float32), std_px.astype(np.float32)

def standardize_and_stack(X, mu, sd, mask=None, add_mask_channel=False):
    """
    Standardize `X` as (X - mu) / sd, writing 0.0 wherever the result is not finite.

    If mu=0 and sd=1 the transformation has no impact on finite values.
    When `add_mask_channel` is True, `mask` (or, if None, the finite-value
    indicator of `X`) is concatenated along axis 1.
    Returns a contiguous float32 array.
    """
    raw = X.astype(np.float32)
    is_finite = np.isfinite(raw)
    scaled = (np.where(is_finite, raw, np.nan) - mu) / sd
    scaled = np.where(np.isfinite(scaled), scaled, 0.0)
    if add_mask_channel:
        channel = is_finite.astype(np.float32) if mask is None else mask
        scaled = np.concatenate([scaled, channel.astype(np.float32)], axis=1)  # (N,2,H,W)
    return np.ascontiguousarray(scaled.astype(np.float32))

class SurfaceDataset(Dataset):
    """Torch dataset pairing each row of `X` with the matching entry of `y` (both stored as float32)."""

    def __init__(self, X, y):
        features = X.astype(np.float32)
        self.X = np.ascontiguousarray(features)
        self.y = np.asarray(y, dtype=np.float32)

    def __len__(self):
        return self.y.shape[0]

    def __getitem__(self, k):
        sample = torch.from_numpy(self.X[k])
        target = torch.tensor(self.y[k], dtype=torch.float32)
        return sample, target
    
class HoflerCNN(nn.Module):
    """
    Four 3x3 conv blocks (16, 32, 64, 128 channels), each followed by BatchNorm
    and LeakyReLU(0.1). The first three blocks end with a 2x2 max-pool; the
    fourth is followed by global average pooling, dropout and a linear layer
    that emits one value per sample.
    """
    def __init__(self, in_ch=1, p_drop=0.5):
        super().__init__()
        # Layers are created in the same order as before so that parameter
        # registration (and the random draws made at construction) is unchanged.
        self.conv1 = nn.Conv2d(in_ch, 16, kernel_size=3, padding=1)
        self.bn1   = nn.BatchNorm2d(16)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.bn2   = nn.BatchNorm2d(32)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn3   = nn.BatchNorm2d(64)
        self.conv4 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn4   = nn.BatchNorm2d(128)
        self.act   = nn.LeakyReLU(0.1, inplace=True)
        self.pool  = nn.MaxPool2d(2, 2)
        self.drop  = nn.Dropout(p_drop)
        self.out   = nn.Linear(128, 1)
        self._init_weights()

    def _init_weights(self):
        """Xavier-uniform weights and zero biases for every conv and linear layer."""
        for layer in self.modules():
            if not isinstance(layer, (nn.Conv2d, nn.Linear)):
                continue
            nn.init.xavier_uniform_(layer.weight)
            if layer.bias is not None:
                nn.init.zeros_(layer.bias)

    def forward(self, x):
        # Pooled blocks; for a 10x18 input the spatial size goes 10x18 -> 5x9 -> 2x4 -> 1x2.
        pooled_blocks = ((self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3))
        for conv, bn in pooled_blocks:
            x = self.pool(self.act(bn(conv(x))))
        x = self.act(self.bn4(self.conv4(x)))
        x = F.adaptive_avg_pool2d(x, (1, 1))
        x = self.drop(torch.flatten(x, 1))
        return self.out(x).squeeze(1)

# Training and prediction
def make_loader(X, y, bs=BATCH_SIZE, shuffle=True):
    """Wrap `X`/`y` in a SurfaceDataset and return a DataLoader over it.

    Uses 8 worker processes; pinned memory is requested only when DEVICE is 'cuda'.
    """
    dataset = SurfaceDataset(X, y)
    pin = (DEVICE == 'cuda')
    return DataLoader(dataset, batch_size=bs, shuffle=shuffle, num_workers=8, pin_memory=pin)

def train_epochs(model, loader, nepochs, opt=None, device=DEVICE):
    """Train `model` with MSE loss for `nepochs` passes over `loader`.

    Args:
        opt: optimizer to continue with; when None a new Adam optimizer is
            built from LR and WEIGHT_DECAY.

    Returns:
        The optimizer that was used, so the caller can keep its state.

    Raises:
        RuntimeError: if a batch produces a non-finite loss.
    """
    model.to(device)
    model.train()
    if opt is None:
        opt = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)

    for ep in range(1, nepochs+1):
        losses = []
        for xb, yb in loader:
            xb = xb.to(device, non_blocking=True)
            yb = yb.to(device, non_blocking=True)
            opt.zero_grad(set_to_none=True)
            loss = F.mse_loss(model(xb).float(), yb)
            if not torch.isfinite(loss):
                # Report how much of the offending batch was finite to help locate bad data.
                fin_x = float(torch.isfinite(xb).float().mean().item())
                fin_y = float(torch.isfinite(yb).float().mean().item())
                raise RuntimeError(f"NaN/Inf loss. finite_ratio(x)={fin_x:.3f}, finite_ratio(y)={fin_y:.3f}")
            loss.backward()
            if GRAD_CLIP_NORM is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP_NORM)
            opt.step()
            losses.append(loss.item())
        tqdm.write(f"  epoch {ep}/{nepochs}  train MSE={np.mean(losses):.6f}")
    return opt

@torch.no_grad()
def predict(model, X, device=DEVICE, bs=8192):
    """Run `model` in eval mode over the numpy array `X` in chunks of `bs` rows.

    Returns a float32 numpy vector with one prediction per row of `X`.
    """
    model.eval()
    model.to(device)
    n_rows = X.shape[0]
    preds = np.empty(n_rows, dtype=np.float32)
    for lo in range(0, n_rows, bs):
        chunk = slice(lo, lo + bs)
        batch = torch.from_numpy(X[chunk]).to(device)
        preds[chunk] = model(batch).float().cpu().numpy()
    return preds



def newey_west_mean_t(y: pd.Series, lags: int = 6):
    """Mean of `y` and its t-statistic using a HAC covariance with `lags` lags.

    The mean is estimated by regressing `y` on a constant. Missing values are
    dropped first.

    Returns:
        (mean, t); both NaN for an empty series, t NaN when the HAC variance
        is not positive.
    """
    series = pd.Series(y).astype(float).dropna()
    if series.empty:
        return np.nan, np.nan
    const = np.ones((len(series), 1))
    fit = sm.OLS(series.values, const).fit()
    hac = cov_hac(fit, nlags=lags)
    mu = float(fit.params[0])
    if hac is not None and hac[0, 0] > 0:
        se = float(np.sqrt(hac[0, 0]))
    else:
        se = np.nan
    if se is not None and se > 0:
        return mu, mu / se
    return mu, np.nan

def sharpe_ann(y: pd.Series):
    """Mean over sample standard deviation of `y`, scaled by sqrt(12); NaN when the std is not positive."""
    rets = pd.Series(y).astype(float).dropna()
    vol = rets.std(ddof=1)
    if vol > 0:
        return (rets.mean() / vol) * np.sqrt(12.0)
    return np.nan

#rank-based assignment 
def _assign_ports_per_month_rank(s: pd.Series, n: int) -> pd.Series:
    x = pd.Series(s)
    mask = x.notna(); xn = x[mask]
    k = min(n, int(xn.nunique()))
    if (k < 2) or (len(xn) < 2):
        return pd.Series(pd.array([pd.NA]*len(x), dtype="Int64"), index=x.index)
    r = xn.rank(method='first')
    nobs = len(xn)
    port = ((r - 1) * k / nobs).astype(int) + 1
    port = port.clip(1, k).astype("Int64")
    out = pd.Series(pd.array([pd.NA]*len(x), dtype="Int64"), index=x.index)
    out.loc[xn.index] = port
    return out

def _monthly_port_avgs(d, weight_col, pred_col, ret_col):
    if weight_col is None:
        g = (d.groupby(['month','port'], observed=True)
               .agg(pred=(pred_col,'mean'), real=(ret_col,'mean'))
               .reset_index())
    else:
        grp = d[['month','port', pred_col, ret_col, weight_col]].dropna()
        grp = grp[grp[weight_col] > 0]
        for col in (pred_col, ret_col):
            grp[f'w_{col}'] = grp[col]*grp[weight_col]
        g = (grp.groupby(['month','port'], observed=True)
                 .agg(pred=('w_'+pred_col,'sum'),
                      real=('w_'+ret_col,'sum'),
                      w=(weight_col,'sum'))
                 .reset_index())
        g['pred'] = np.where(g['w']>0, g['pred']/g['w'], np.nan)
        g['real'] = np.where(g['w']>0, g['real']/g['w'], np.nan)
        g = g.drop(columns='w')
    return g

def cnn_decile_table(meta, pred_col, ret_col, n_ports=10, weighting='oiw', lags=LAGS_NW):
    """Sort rows into per-month portfolios on `pred_col` and summarise them.

    Args:
        meta: frame with 'month', `pred_col`, `ret_col` and optionally 'doi'
            (a missing 'doi' column is filled with 1.0).
        n_ports: requested number of portfolios per month.
        weighting: 'ew' for equal weights; any other value weights by 'doi'.
        lags: lag count for the HAC t-statistic.

    Returns:
        (table, {'HML': series}): the table is indexed by bucket
        (Low, 2..n_ports-1, High, HML) with columns Pred., Real., t, SR; the
        series is the monthly realised High-minus-Low spread.

    Raises:
        ValueError: if a required column is missing.
    """
    d = meta.copy()
    _require_cols(d, ['month', pred_col, ret_col])
    if 'doi' not in d.columns:
        d['doi'] = 1.0
    d = d.dropna(subset=['month', pred_col, ret_col])
    d['month'] = pd.to_datetime(d['month'])
    d['port'] = d.groupby('month', observed=True)[pred_col].transform(
        lambda s: _assign_ports_per_month_rank(s, n_ports))
    d = d.dropna(subset=['port'])
    d['port'] = d['port'].astype('Int64')

    weight_col = 'doi' if weighting.lower() != 'ew' else None
    g = _monthly_port_avgs(d, weight_col, pred_col, ret_col)

    # "High" is the largest portfolio number present in each month, which can be
    # below n_ports when a month has few distinct predictions.
    maxp = d.groupby('month', observed=True)['port'].max().rename('maxp')
    gj = g.merge(maxp, on='month', how='left')

    def _pair(frame, selector):
        # (pred, real) monthly series of the selected rows, sorted by month
        picked = frame[selector].set_index('month')
        return picked['pred'].sort_index(), picked['real'].sort_index()

    low_pred_ts, low_real_ts = _pair(g, g['port'] == 1)
    high_pred_ts, high_real_ts = _pair(gj, gj['port'] == gj['maxp'])

    # High-minus-Low spreads on the months where both legs exist
    hml_pred_ts = (high_pred_ts - low_pred_ts).dropna()
    hml_real_ts = (high_real_ts - low_real_ts).dropna()

    def _row(label, pr, rr):
        mu_pred = float(pd.Series(pr).dropna().mean()) if len(pr) > 0 else np.nan
        mu_real = float(pd.Series(rr).dropna().mean()) if len(rr) > 0 else np.nan
        return {'Bucket': label, 'Pred.': mu_pred, 'Real.': mu_real,
                't': newey_west_mean_t(rr, lags=lags)[1], 'SR': sharpe_ann(rr)}

    rows = [_row('Low', low_pred_ts, low_real_ts)]
    for k in range(2, n_ports):
        rows.append(_row(k, *_pair(g, g['port'] == k)))
    rows.append(_row('High', high_pred_ts, high_real_ts))

    rows.append({'Bucket': 'HML',
                 'Pred.': float(hml_pred_ts.mean()) if not hml_pred_ts.empty else np.nan,
                 'Real.': float(hml_real_ts.mean()) if not hml_real_ts.empty else np.nan,
                 't': newey_west_mean_t(hml_real_ts, lags=lags)[1],
                 'SR': sharpe_ann(hml_real_ts)})

    tbl = pd.DataFrame(rows).set_index('Bucket')
    return tbl, {'HML': hml_real_ts}


def main():
    """Run the annual walk-forward CNN ensemble and print the portfolio tables.

    Steps: load the surfaces and metadata, train one HoflerCNN per seed on all
    data before each out-of-sample (OOS) year, average the per-seed predictions
    for that year, then build 'doi'-weighted and equal-weighted tables from the
    OOS predictions. Outputs are written to SAVE_DIR when it is set.

    Returns:
        (meta_eval, tbl_oiw, tbl_ew): the metadata with an added 'rb_cnn'
        prediction column and the two tables.

    Raises:
        ValueError: a required metadata column is missing.
        RuntimeError: fewer than 7 years of data, or no OOS prediction was made.
    """
    X = _load_Xz(XZ_PATH)
    df = _load_meta(META_PATH).copy()
    _require_cols(df, ['month', TARGET_COL])
    if 'doi' not in df.columns: df['doi'] = 1.0
    df['month'] = pd.to_datetime(df['month'])
    assert len(df) == X.shape[0], f"Not correct amount of lines at len(meta)={len(df)} vs N in Xz={X.shape[0]}"

    finite_ratio_X = float(np.isfinite(X).mean())
    finite_ratio_y = float(np.isfinite(df[TARGET_COL].to_numpy()).mean())
    pix_nan_rate = (~np.isfinite(X)).mean(axis=0)
    print(f"Xz finite ratio: {finite_ratio_X:.4f} | Target finite ratio: {finite_ratio_y:.4f}")
    print("Pixel NaN rate [min/mean/max]:",
          float(pix_nan_rate.min()), float(pix_nan_rate.mean()), float(pix_nan_rate.max()))

    # Define months and the OOS start (at least 7 years of training data)
    months = np.sort(df['month'].unique())
    min_month = pd.to_datetime(df['month'].min())
    first_oos_month = min_month + pd.DateOffset(years=7)
    if first_oos_month > months.max():
        raise RuntimeError("Not enough data for 7 years training")

    # Calendar years that contain OOS months
    years_oos = sorted(pd.Series(months[months >= first_oos_month]).dt.year.unique())
    if MODE.upper() == 'FAST' and len(years_oos) > 2:
        years_oos = years_oos[-2:]  # smoke test: last two years only

    print(f"Mode={MODE} | OOS years: {len(years_oos)} | Xz shape: {X.shape} | meta shape: {df.shape}")

    IN_CH = 1 + (1 if ADD_MASK_CHANNEL else 0)
    MASK = np.isfinite(X).astype(np.float32) if ADD_MASK_CHANNEL else None

    if PRE_STANDARDIZED_INPUT:
        mu = np.zeros(X.shape[1:], dtype=np.float32)
        sd = np.ones (X.shape[1:], dtype=np.float32)
        print("PRE_STANDARDIZED_INPUT=True:identity transform.")
    else:
        # Statistics come only from the pre-OOS rows
        init_idx = np.where(df['month'].lt(first_oos_month).values)[0]
        mu, sd = compute_pixel_stats(X, init_idx)

    rb_hat = np.full(len(df), np.nan, dtype=np.float32)
    models, opts = {}, {}
    for s in SEEDS:
        torch.manual_seed(s); np.random.seed(s)
        m = HoflerCNN(in_ch=IN_CH, p_drop=DROPOUT_P)
        o = torch.optim.Adam(m.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
        models[s], opts[s] = m, o

    def make_tr_loader(idx):
        X_tr = standardize_and_stack(X[idx], mu, sd,
                                     mask=None if MASK is None else MASK[idx],
                                     add_mask_channel=ADD_MASK_CHANNEL)
        # `idx` holds row POSITIONS (from np.where), so select positionally;
        # label-based .loc would pick wrong rows when the index is not 0..N-1.
        y_tr = df[TARGET_COL].iloc[idx].to_numpy(np.float32)
        good = np.isfinite(y_tr)
        if (~good).sum() > 0:
            print(f"Warning: βρέθηκαν {(~good).sum()} μη-περατά targets — αγνοούνται στο training.")
        return make_loader(X_tr[good], y_tr[good], bs=BATCH_SIZE, shuffle=True)

    # Annual walk-forward
    for i, Y in enumerate(tqdm(years_oos, desc="OOS years")):
        start_Y = pd.Timestamp(year=Y, month=1, day=1)
        start_next = pd.Timestamp(year=Y+1, month=1, day=1)

        tr_idx = np.where(df['month'].lt(start_Y).values)[0]
        te_mask = (df['month'] >= start_Y) & (df['month'] < start_next) & (df['month'] >= first_oos_month)
        te_idx  = np.where(te_mask.values)[0]
        if len(tr_idx)==0 or len(te_idx)==0:
            continue
        tr_loader = make_tr_loader(tr_idx)
        ne = N_EPOCHS_INIT if i==0 else N_EPOCHS_PER_YEAR
        print(f"\n=== Year {Y} | train_n={len(tr_idx):,} | test_n={len(te_idx):,} | epochs={ne} ===")

        # The test tensor does not depend on the seed, so build it once per year.
        X_te = standardize_and_stack(X[te_idx], mu, sd,
                                     mask=None if MASK is None else MASK[te_idx],
                                     add_mask_channel=ADD_MASK_CHANNEL)
        oos_preds = []
        for s in SEEDS:
            opts[s] = train_epochs(models[s], tr_loader, nepochs=ne, opt=opts[s], device=DEVICE)
            oos_preds.append(predict(models[s], X_te, device=DEVICE))
        # Ensemble prediction = mean over seeds
        rb_hat[te_idx] = np.mean(np.vstack(oos_preds), axis=0)
        del tr_loader, X_te; gc.collect()
        if DEVICE=='cuda':
            torch.cuda.empty_cache()

    # Collect OOS predictions
    meta_eval = df.copy()
    meta_eval['rb_cnn'] = rb_hat
    nnz = int(np.isfinite(meta_eval['rb_cnn']).sum())
    print(f"\nOOS non-NaN preds: {nnz} / {len(meta_eval)}")
    if nnz == 0:
        raise RuntimeError("no predictions")

    # Tables
    print(f"\nBuilding decile tables (N={PORTS_N}) ...")
    tbl_oiw, hml_ts_oiw = cnn_decile_table(meta_eval, 'rb_cnn', TARGET_COL, n_ports=PORTS_N, weighting='oiw', lags=LAGS_NW)
    tbl_ew,  hml_ts_ew  = cnn_decile_table(meta_eval, 'rb_cnn', TARGET_COL, n_ports=PORTS_N, weighting='ew',  lags=LAGS_NW)

    def _fmt(tbl):
        # Display-only copy with Pred./Real. in percent when DISPLAY_IN_PCT is set
        out = tbl.copy()
        if DISPLAY_IN_PCT:
            for c in ['Pred.','Real.']:
                out[c] = 100.0 * out[c]
        return out

    print("\n=== (a) Dollar open interest weighted ===")
    print(_fmt(tbl_oiw).round({'Pred.':2,'Real.':2,'t':2,'SR':2}).to_string())

    print("\n=== (b) Equal weighted ===")
    print(_fmt(tbl_ew).round({'Pred.':2,'Real.':2,'t':2,'SR':2}).to_string())

    # Saving
    if SAVE_DIR:
        os.makedirs(SAVE_DIR, exist_ok=True)
        meta_eval.to_parquet(os.path.join(SAVE_DIR, 'meta_eval_with_rb_cnn.parquet'))
        tbl_oiw.to_csv(os.path.join(SAVE_DIR, 'cnn_table_oiw.csv'))
        tbl_ew.to_csv(os.path.join(SAVE_DIR, 'cnn_table_ew.csv'))
        hml_ts_oiw['HML'].to_csv(os.path.join(SAVE_DIR, 'hml_oiw_series.csv'))
        hml_ts_ew['HML'].to_csv(os.path.join(SAVE_DIR, 'hml_ew_series.csv'))
        print(f"\nSaved outputs to: {os.path.abspath(SAVE_DIR)}")

    return meta_eval, tbl_oiw, tbl_ew

# Run the whole pipeline when this cell executes as the main module; the
# returned tuple is bound to `_` (presumably so the notebook does not echo it).
if __name__ == "__main__":
    _ = main()

In [None]:
import os, time, gc, math, numpy as np, pandas as pd
import torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm
from torchdiffeq import odeint
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import copy
import statsmodels.api as sm
from statsmodels.stats.sandwich_covariance import cov_hac
import shap

In [None]:



class CFG:
    """Static configuration for the CNN -> Neural-ODE experiment; attributes are read as `CFG.NAME`."""
    MODE               = 'FULL'             
    XZ_PATH            = '/content/drive/MyDrive/Dissertation/data/Xz_float32.npy' # same input file as the CNN section above
    META_PATH          = '/content/drive/MyDrive/Dissertation/data/meta.parquet'
    TARGET_COL         = 'r_dh_no_tc'
    PRE_STANDARDIZED_INPUT = True           
    ADD_MASK_CHANNEL   = False             
    DISPLAY_IN_PCT     = True               # default for decile_table(display_pct=...)

    DEVICE             = 'cuda' if torch.cuda.is_available() else 'cpu'
    AMP_DTYPE          = torch.bfloat16     # autocast dtype used in train_epochs on CUDA
    USE_SCALER         = False              # enables the GradScaler in train_epochs (CUDA only)
    CHANNELS_LAST      = True               # convert model and 4-D batches to channels_last memory format
    TORCH_COMPILE      = False              # NOTE(review): not read in the code visible here

    BATCH_SIZE         = 8192
    LR                 = 1e-3               # AdamW learning rate when train_epochs builds the optimizer
    WEIGHT_DECAY       = 1e-5               # AdamW weight decay when train_epochs builds the optimizer
    GRAD_CLIP_NORM     = 5.0                # max gradient norm; None disables clipping
    DROPOUT_P          = 0.4                # default dropout probability of CNN_NODE

    EPOCHS_INIT        = 12                
    EPOCHS_PER_YEAR    = 6                  
    SEEDS              = seeds              # defined in an earlier cell

    #NODE
    DZ                 = 128                # latent state dimension
    VF_WIDTH           = 256                # hidden width of the vector-field MLP
    ODE_METHOD         = 'dopri5'           # solver name passed to odeint
    ODE_RTOL           = 1e-3
    ODE_ATOL           = 1e-4
    ODE_MAX_STEPS      = 64                 # passed to odeint as options['max_num_steps']
    ADJOINT            = False              # NOTE(review): not read in the code visible here


    LOSS_ALPHA_MSE     = 0.75               # loss = alpha*MSE + (1-alpha)*correlation loss

    PORTS_N            = 10
    LAGS_NW            = 6                  # default HAC lag count for the table functions


    OPTION_CHAR_COLS   = ['spread','delta','gamma','vega','theta','iv','odp']  # default columns for hofler_option_chars_table
    OPTION_CHAR_PCT    = ['spread']        
    HOF_ROUND          = 2                  # default `rnd` of hofler_option_chars_table
    OPTION_CHAR_SCALERS = {
        'gamma': 50.0,   
        'vega' : 0.01,   
        'theta': 0.02,  

    }

    #SHAP  (NOTE(review): consumers of these settings are outside the code visible here)
    SHAP_ENABLE            = True
    SHAP_BG_SAMPLES        = 64
    SHAP_EXPLAIN_SAMPLES   = 64
    SHAP_SAVE_DIR          = './shap_out'
    SHAP_SHOW_FIGS        = True  
    SHAP_SAVE_FIGS        = True   
    SHAP_SAVE_STATS       = False  
    SHAP_STATS_TOPK       = 10     

    #LIME  (NOTE(review): consumers of these settings are outside the code visible here)
    LIME_ENABLE            = True
    LIME_SAVE_DIR          = './lime_out'
    LIME_NUM_SAMPLES       = 1500     
    LIME_NUM_FEATURES      = 12        
    LIME_SEGMENTATION_MODE = 'grid'   
    LIME_SLIC_SEGMENTS     = 100       
    LIME_BATCH_SIZE        = 2048      
    LIME_SHOW_FIGS = True
    LIME_SAVE_FIGS = True



def _ensure_channel_dim(X):
    X = np.asarray(X, dtype=np.float32)
    if X.ndim == 3:  #(N,H,W)
        X = X[:, None, :, :]
    if not (X.ndim == 4 and X.shape[1] in (1,2)):
        raise ValueError(f"Xz must be (N,1,H,W) or (N,2,H,W). Got: {X.shape}")
    return np.ascontiguousarray(X)

def load_Xz(path):
    """Load the surface tensor from a `.npy` or `.npz` file and return it via `_ensure_channel_dim`.

    Raises:
        ValueError: the file extension is neither `.npy` nor `.npz`.
    """
    ext = os.path.splitext(path)[1].lower()
    if ext == '.npy':
        X = np.load(path, allow_pickle=False)
    elif ext == '.npz':
        # NpzFile keeps the archive open until closed; the context manager
        # releases the handle once the wanted member has been read into memory.
        with np.load(path, allow_pickle=False) as npz:
            # Prefer the member named 'Xz'; otherwise fall back to the first one.
            X = npz['Xz'] if 'Xz' in npz.files else npz[npz.files[0]]
    else:
        raise ValueError("XZ_PATH should be .npy or .npz")
    return _ensure_channel_dim(X)

def load_meta(path):
    """Read the metadata table from a `.parquet` or `.csv` file.

    Raises:
        ValueError: for any other file extension.
    """
    suffix = os.path.splitext(path)[1].lower()
    readers = {'.parquet': pd.read_parquet, '.csv': pd.read_csv}
    reader = readers.get(suffix)
    if reader is None:
        raise ValueError("META_PATH should be .parquet or .csv")
    return reader(path)

def identity_stack(X, add_mask=False):
    """Replace non-finite entries of `X` with 0.0 and optionally append a finite-value mask.

    With `add_mask` True the float32 indicator of finite entries is
    concatenated along axis 1. Returns a contiguous array.
    """
    raw = np.asarray(X, dtype=np.float32)
    finite = np.isfinite(raw)
    filled = np.where(finite, raw, 0.0)
    if add_mask:
        filled = np.concatenate([filled, finite.astype(np.float32)], axis=1)
    return np.ascontiguousarray(filled)

class SurfDS(Dataset):
    """Torch dataset pairing each row of `X` with the matching entry of `y` (both stored as float32)."""

    def __init__(self, X, y):
        self.X = np.ascontiguousarray(X, dtype=np.float32)  # (N,C,H,W)
        self.y = np.asarray(y, dtype=np.float32)

    def __len__(self):
        return self.y.shape[0]

    def __getitem__(self, i):
        sample = torch.from_numpy(self.X[i])
        target = torch.tensor(self.y[i], dtype=torch.float32)
        return sample, target

def make_loader(X, y, bs, shuffle=True, device=CFG.DEVICE):
    """Wrap `X`/`y` in a SurfDS and return a single-process DataLoader over it.

    Pinned memory is requested only when `device` equals 'cuda'; the final
    partial batch is kept.
    """
    dataset = SurfDS(X, y)
    pin = (device == 'cuda')
    return DataLoader(dataset, batch_size=bs, shuffle=shuffle,
                      num_workers=0, pin_memory=pin, drop_last=False)

#CNN->NODE
class Encoder(nn.Module):
    """Höfler-style encoder: four conv blocks, global average pooling, dropout and a linear layer -> 128-dim."""
    def __init__(self, in_ch=1, p_drop=0.35):
        super().__init__()
        # Layers are created in the same order as before so that parameter
        # registration (and the random draws made at construction) is unchanged.
        self.conv1 = nn.Conv2d(in_ch, 16, 3, padding=1)
        self.bn1   = nn.BatchNorm2d(16)
        self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
        self.bn2   = nn.BatchNorm2d(32)
        self.conv3 = nn.Conv2d(32, 64, 3, padding=1)
        self.bn3   = nn.BatchNorm2d(64)
        self.conv4 = nn.Conv2d(64, 128, 3, padding=1)
        self.bn4   = nn.BatchNorm2d(128)
        self.pool  = nn.MaxPool2d(2, 2)
        self.act   = nn.LeakyReLU(0.1, inplace=True)
        self.drop  = nn.Dropout(p_drop)
        self.out   = nn.Linear(128, 128)
        self._init()

    def _init(self):
        """Xavier-uniform weights and zero biases for every conv and linear layer."""
        for layer in self.modules():
            if not isinstance(layer, (nn.Conv2d, nn.Linear)):
                continue
            nn.init.xavier_uniform_(layer.weight)
            nn.init.zeros_(layer.bias)

    def forward(self, x):
        # The first three blocks are max-pooled; the fourth is not.
        pooled_blocks = ((self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3))
        for conv, bn in pooled_blocks:
            x = self.pool(self.act(bn(conv(x))))
        x = self.act(self.bn4(self.conv4(x)))
        x = F.adaptive_avg_pool2d(x, (1, 1))
        x = self.drop(torch.flatten(x, 1))
        return self.out(x)  # (B,128)

class VecField(nn.Module):
    """MLP vector field: maps the state `z` plus a time column to a tensor of the same width as `z`.

    `nfe` counts forward calls since the last `reset_nfe()`.
    """
    def __init__(self, dz=CFG.DZ, width=CFG.VF_WIDTH):
        super().__init__()
        self.nfe = 0
        layers = [
            nn.Linear(dz + 1, width), nn.SiLU(),
            nn.Linear(width, width), nn.SiLU(),
            nn.Linear(width, dz),
        ]
        self.mlp = nn.Sequential(*layers)
        for layer in self.mlp:
            if not isinstance(layer, nn.Linear):
                continue
            nn.init.xavier_uniform_(layer.weight)
            nn.init.zeros_(layer.bias)

    def forward(self, s, z):
        self.nfe += 1
        if z.ndim == 1:
            z = z.unsqueeze(0)
        n_rows = z.shape[0]
        # Broadcast the scalar time `s` into one extra input column per row.
        t_col = torch.full((n_rows, 1), float(s), device=z.device, dtype=z.dtype)
        return self.mlp(torch.cat([z, t_col], dim=1))

    def reset_nfe(self):
        self.nfe = 0
#NODE
class NODEBlock(nn.Module):
    """Integrates a VecField over `tspan` = [0, 1] with `odeint` and returns the final state.

    forward returns (z1, nfe): the state at the last time point, cast back to
    the dtype of the input, and the number of vector-field evaluations used.
    """
    def __init__(self, dz=CFG.DZ, method=CFG.ODE_METHOD, rtol=CFG.ODE_RTOL, atol=CFG.ODE_ATOL, max_steps=CFG.ODE_MAX_STEPS):
        super().__init__()
        self.func = VecField(dz=dz, width=CFG.VF_WIDTH)
        self.method, self.rtol, self.atol = method, rtol, atol
        self.max_steps = max_steps
        self.register_buffer('tspan', torch.tensor([0.0, 1.0], dtype=torch.float32))

    def forward(self, z0):
        in_dtype = z0.dtype
        self.func.reset_nfe()
        solver_opts = {'max_num_steps': int(self.max_steps)}
        # The solve runs on a float32 copy of the initial state.
        trajectory = odeint(self.func, z0.float(), self.tspan,
                            method=self.method, rtol=self.rtol, atol=self.atol,
                            options=solver_opts)
        final_state = trajectory[-1]
        return final_state.to(in_dtype), self.func.nfe

class CNN_NODE(nn.Module):
    """Encoder -> linear projection to the latent state -> NODEBlock -> linear head.

    forward returns (y, nfe): one prediction per sample and the number of
    vector-field evaluations reported by the NODE block.
    """
    def __init__(self, in_ch=1, dz=CFG.DZ, p_drop=CFG.DROPOUT_P):
        super().__init__()
        self.enc    = Encoder(in_ch=in_ch, p_drop=p_drop)
        self.lin_z0 = nn.Linear(128, dz)
        self.node   = NODEBlock(dz=dz)
        self.head   = nn.Linear(dz, 1)
        for linear in (self.lin_z0, self.head):
            nn.init.xavier_uniform_(linear.weight)
            nn.init.zeros_(linear.bias)

    def forward(self, x):
        z0 = self.lin_z0(self.enc(x))
        z1, nfe = self.node(z0)
        return self.head(z1).squeeze(1), nfe

#LOSS and Metrics
@torch.no_grad()
def batch_corr(yhat, y):
    """Pearson-style correlation between `yhat` and `y` as a Python float.

    A 1e-12 term inside the square root keeps the denominator away from zero;
    0.0 is returned when the denominator is not positive.
    """
    pred_c = yhat.float()
    true_c = y.float()
    pred_c = pred_c - pred_c.mean()
    true_c = true_c - true_c.mean()
    cov = (pred_c * true_c).sum()
    scale = torch.sqrt((pred_c * pred_c).sum() * (true_c * true_c).sum() + 1e-12)
    if scale > 0:
        return (cov / scale).item()
    return 0.0

def corr_loss(yhat, y):
    """Return 1 minus the correlation between `yhat` and `y` as a tensor.

    Small 1e-12 terms are added inside the square root and to the denominator.
    """
    pred_c = yhat.float()
    true_c = y.float()
    pred_c = pred_c - pred_c.mean()
    true_c = true_c - true_c.mean()
    cov = (pred_c * true_c).sum()
    scale = torch.sqrt((pred_c * pred_c).sum() * (true_c * true_c).sum() + 1e-12)
    return 1.0 - cov / (scale + 1e-12)

#Train-predict
def train_epochs(model, loader, nepochs, opt=None, device=CFG.DEVICE):
    """Train `model` for `nepochs` passes over `loader` with a blended loss.

    The loss is CFG.LOSS_ALPHA_MSE * MSE + (1 - CFG.LOSS_ALPHA_MSE) * corr_loss.
    `model(xb)` must return (prediction, nfe). On CUDA the forward pass runs
    under autocast with CFG.AMP_DTYPE.

    Args:
        opt: optimizer to continue with; when None a new AdamW optimizer is
            built from CFG.LR and CFG.WEIGHT_DECAY.

    Returns:
        (opt, model): the optimizer used and the trained model.
    """
    model.to(device)
    if CFG.CHANNELS_LAST:
        model = model.to(memory_format=torch.channels_last)
    model.train()
    if opt is None:
        opt = torch.optim.AdamW(model.parameters(), lr=CFG.LR, weight_decay=CFG.WEIGHT_DECAY)
    scaler = torch.cuda.amp.GradScaler(enabled=(CFG.USE_SCALER and device=='cuda'))

    for ep in range(1, nepochs+1):
        t0 = time.time()
        # Sample-weighted running sums for the epoch summary
        loss_meter, corr_meter, nfe_meter, nb = 0.0, 0.0, 0.0, 0
        steps = 0  # optimizer steps taken in this epoch
        for xb, yb in loader:
            steps += 1
            if xb.ndim == 3: xb = xb.unsqueeze(0)
            xb = xb.to(device, non_blocking=True)
            yb = yb.to(device, non_blocking=True).float()
            if CFG.CHANNELS_LAST and xb.ndim == 4:
                xb = xb.contiguous(memory_format=torch.channels_last)

            opt.zero_grad(set_to_none=True)
            with torch.amp.autocast(device_type='cuda', dtype=CFG.AMP_DTYPE, enabled=(device=='cuda')):
                yhat, nfe = model(xb)
                # Loss terms are evaluated on float32 predictions
                mse  = F.mse_loss(yhat.float(), yb)
                cl   = corr_loss(yhat.float(), yb)
                loss = CFG.LOSS_ALPHA_MSE*mse + (1.0-CFG.LOSS_ALPHA_MSE)*cl

            if scaler.is_enabled():
                scaler.scale(loss).backward()
                if CFG.GRAD_CLIP_NORM is not None:
                    # Gradients must be unscaled before their norm is clipped
                    scaler.unscale_(opt)
                    nn.utils.clip_grad_norm_(model.parameters(), CFG.GRAD_CLIP_NORM)
                scaler.step(opt); scaler.update()
            else:
                loss.backward()
                if CFG.GRAD_CLIP_NORM is not None:
                    nn.utils.clip_grad_norm_(model.parameters(), CFG.GRAD_CLIP_NORM)
                opt.step()

            loss_meter += loss.item() * yb.size(0)
            corr_meter += batch_corr(yhat.detach(), yb) * yb.size(0)
            nfe_meter  += float(nfe) * yb.size(0)
            nb         += yb.size(0)

        t1 = time.time()
        denom = max(nb, 1)  # an empty loader must not crash the summary line
        print(f"  epoch {ep}/{nepochs}  loss={loss_meter/denom:.6f}  corr≈{corr_meter/denom:.3f}  "
              f"time={t1-t0:.1f}s  steps={steps}  NFE≈{nfe_meter/denom:.1f}")
    return opt, model

@torch.no_grad()
def predict(model, X, device=CFG.DEVICE, bs=8192):
    """Run `model` in eval mode over the numpy array `X` in chunks of `bs` rows.

    `model(xb)` must return a tuple whose first element is the prediction.
    Returns a float32 numpy vector with one prediction per row of `X`.
    """
    model.eval()
    model.to(device)
    n_rows = X.shape[0]
    preds = np.empty(n_rows, dtype=np.float32)
    for lo in range(0, n_rows, bs):
        chunk = slice(lo, lo + bs)
        batch = torch.from_numpy(X[chunk]).to(device, non_blocking=True)
        if CFG.CHANNELS_LAST and batch.ndim == 4:
            batch = batch.contiguous(memory_format=torch.channels_last)
        yhat, _ = model(batch)
        preds[chunk] = yhat.float().detach().cpu().numpy()
    return preds


def newey_west_mean_t(y, lags=6):
    """Mean of `y` and its t-statistic using a HAC covariance with `lags` lags.

    The mean is estimated by regressing `y` on a constant after dropping
    missing values. Returns (mean, t); both NaN for an empty series, t NaN
    when the HAC variance is not positive.
    """
    series = pd.Series(y).astype(float).dropna()
    if series.empty:
        return np.nan, np.nan
    const = np.ones((len(series), 1))
    fit = sm.OLS(series.values, const).fit()
    hac = cov_hac(fit, nlags=lags)
    mu = float(fit.params[0])
    if hac is not None and hac[0, 0] > 0:
        se = float(np.sqrt(hac[0, 0]))
    else:
        se = np.nan
    if se and se > 0:
        return mu, mu / se
    return mu, np.nan

def sharpe_ann(y):
    """Mean over sample standard deviation of `y`, scaled by sqrt(12); NaN when the std is not positive."""
    rets = pd.Series(y).astype(float).dropna()
    vol = rets.std(ddof=1)
    if vol > 0:
        return (rets.mean() / vol) * np.sqrt(12.0)
    return np.nan

def _assign_ports_per_month_rank(s, n):
    x = pd.Series(s); mask = x.notna(); xn = x[mask]
    k = min(n, int(xn.nunique()))
    if (k<2) or (len(xn)<2):
        return pd.Series(pd.array([pd.NA]*len(x), dtype="Int64"), index=x.index)
    r = xn.rank(method='first')
    nobs = len(xn)
    port = ((r-1)*k/nobs).astype(int)+1
    port = port.clip(1,k).astype("Int64")
    out = pd.Series(pd.array([pd.NA]*len(x), dtype="Int64"), index=x.index)
    out.loc[xn.index] = port
    return out

def _monthly_port_avgs(d, weight_col, pred_col, ret_col):
    if weight_col is None:
        g = (d.groupby(['month','port'], observed=True)
              .agg(pred=(pred_col,'mean'), real=(ret_col,'mean'))
              .reset_index())
    else:
        grp = d[['month','port',pred_col,ret_col,weight_col]].dropna()
        grp = grp[grp[weight_col]>0]
        for col in (pred_col, ret_col):
            grp[f'w_{col}'] = grp[col]*grp[weight_col]
        g = (grp.groupby(['month','port'], observed=True)
                .agg(pred=('w_'+pred_col,'sum'),
                    real=('w_'+ret_col,'sum'),
                    w=(weight_col,'sum'))
                .reset_index())
        g['pred'] = np.where(g['w']>0, g['pred']/g['w'], np.nan)
        g['real'] = np.where(g['w']>0, g['real']/g['w'], np.nan)
        g = g.drop(columns='w')
    return g

def decile_table(meta, pred_col, ret_col, n_ports=10, weighting='oiw', lags=CFG.LAGS_NW, display_pct=CFG.DISPLAY_IN_PCT):
    """Sort rows into per-month portfolios on `pred_col` and summarise them.

    Args:
        meta: frame with 'month', `pred_col`, `ret_col` and optionally 'doi'
            (a missing 'doi' column is filled with 1.0).
        n_ports: requested number of portfolios per month.
        weighting: 'ew' for equal weights; any other value weights by 'doi'.
        lags: lag count for the HAC t-statistic.
        display_pct: multiply the Pred. and Real. columns by 100.

    Returns:
        Table indexed by bucket (Low, 2..n_ports-1, High, HML) with columns
        Pred., Real., t, SR.
    """
    d = meta.copy()
    if 'doi' not in d.columns:
        d['doi'] = 1.0
    d = d.dropna(subset=['month', pred_col, ret_col])
    d['month'] = pd.to_datetime(d['month'])
    d['port'] = d.groupby('month', observed=True)[pred_col].transform(
        lambda s: _assign_ports_per_month_rank(s, n_ports))
    d = d.dropna(subset=['port'])
    d['port'] = d['port'].astype('Int64')

    weight_col = 'doi' if weighting.lower() != 'ew' else None
    g = _monthly_port_avgs(d, weight_col, pred_col, ret_col)

    # "High" is the largest portfolio number present in each month.
    maxp = d.groupby('month', observed=True)['port'].max().rename('maxp')
    gj = g.merge(maxp, on='month', how='left')

    def _pair(frame, selector):
        # (pred, real) monthly series of the selected rows, sorted by month
        picked = frame[selector].set_index('month')
        return picked['pred'].sort_index(), picked['real'].sort_index()

    def _row(label, pr, rr):
        mu_pred = float(pd.Series(pr).dropna().mean()) if len(pr) > 0 else np.nan
        mu_real = float(pd.Series(rr).dropna().mean()) if len(rr) > 0 else np.nan
        return {'Bucket': label, 'Pred.': mu_pred, 'Real.': mu_real,
                't': newey_west_mean_t(rr, lags=lags)[1], 'SR': sharpe_ann(rr)}

    low_p, low_r = _pair(g, g['port'] == 1)
    high_p, high_r = _pair(gj, gj['port'] == gj['maxp'])

    rows = [_row('Low', low_p, low_r)]
    for k in range(2, n_ports):
        rows.append(_row(k, *_pair(g, g['port'] == k)))
    rows.append(_row('High', high_p, high_r))

    # High-minus-Low spreads on the months where both legs exist
    hml_pred = (high_p - low_p).dropna()
    hml_real = (high_r - low_r).dropna()
    rows.append({'Bucket': 'HML',
                 'Pred.': float(hml_pred.mean()) if not hml_pred.empty else np.nan,
                 'Real.': float(hml_real.mean()) if not hml_real.empty else np.nan,
                 't': newey_west_mean_t(hml_real, lags=lags)[1],
                 'SR': sharpe_ann(hml_real)})

    tbl = pd.DataFrame(rows).set_index('Bucket')
    if display_pct:
        for c in ['Pred.', 'Real.']:
            tbl[c] = 100.0 * tbl[c]
    return tbl

#Option Characteristics
def _monthly_port_wavg_for_chars(df, port_col, weight_col, char_cols):
    keep = ['month', port_col] + char_cols + ([weight_col] if weight_col else [])
    d = df[keep].dropna(subset=['month', port_col])
    d['month'] = pd.to_datetime(d['month'])
    if weight_col is None:
        g = d.groupby(['month', port_col], observed=True)[char_cols].mean().reset_index()
    else:
        parts = []
        for c in char_cols:
            tmp = d[['month', port_col, c, weight_col]].dropna()
            tmp = tmp[tmp[weight_col]>0]
            tmp['wx'] = tmp[c]*tmp[weight_col]
            g = (tmp.groupby(['month', port_col], observed=True)
                    .agg(wx=('wx','sum'), w=(weight_col,'sum')).reset_index())
            g[c] = np.where(g['w']>0, g['wx']/g['w'], np.nan)
            parts.append(g[['month', port_col, c]])
        g = parts[0]
        for k in range(1, len(parts)): g = g.merge(parts[k], on=['month',port_col], how='outer')
    return g

def hofler_option_chars_table(meta, pred_col, char_cols=None, n_ports=10,
                              weighting='oiw', lags=CFG.LAGS_NW,
                              pct_cols=None, rnd=CFG.HOF_ROUND,
                              scalers=None):
    """Tabulate average option characteristics of forecast-sorted portfolios.

    Each month the rows of ``meta`` are assigned to portfolios from their ``pred_col``
    value (via ``_assign_ports_per_month_rank``); every characteristic is averaged per
    (month, portfolio) and then across months.

    Parameters
    ----------
    meta : DataFrame with 'month', ``pred_col`` and the characteristic columns.
        A 'doi' column is used as weight; if absent it is created as 1.0.
    pred_col : column holding the forecast used for sorting.
    char_cols : characteristics to report (default ``CFG.OPTION_CHAR_COLS``);
        names missing from ``meta`` are silently dropped.
    n_ports : number of portfolios.
    weighting : 'ew' for equal weights, anything else weights by 'doi'.
    lags : lag count handed to ``newey_west_mean_t``.
    pct_cols : characteristics multiplied by 100 for display.
    rnd : decimals used when rounding the display table.
    scalers : optional {characteristic: factor} display multipliers.

    Returns
    -------
    (numeric_table, display_table). Rows are 'Low', 2..n_ports-1, 'High', 'HML'
    (High minus Low) and 't' (statistic returned by ``newey_west_mean_t`` for the
    HML series). Scaling is applied to every row except 't'.

    Raises
    ------
    ValueError if none of the requested characteristics exists in ``meta``.
    """
    panel = meta.copy()
    if 'doi' not in panel.columns:
        panel['doi'] = 1.0

    if char_cols is None:
        char_cols = CFG.OPTION_CHAR_COLS
    char_cols = [name for name in char_cols if name in panel.columns]
    if not char_cols:
        raise ValueError("no columns option characteristics at meta.")

    if pct_cols is None:
        pct_cols = []
    else:
        pct_cols = [name for name in pct_cols if name in char_cols]
    if scalers is None:
        scalers = {}
    else:
        scalers = {key: factor for key, factor in scalers.items() if key in char_cols}

    # Monthly portfolio assignment from the forecast.
    panel = panel.dropna(subset=['month', pred_col])
    panel['month'] = pd.to_datetime(panel['month'])
    panel['port'] = panel.groupby('month', observed=True)[pred_col].transform(
        lambda s: _assign_ports_per_month_rank(s, n_ports))
    panel = panel.dropna(subset=['port'])
    panel['port'] = panel['port'].astype('Int64')

    weight_col = None if weighting.lower() == 'ew' else 'doi'
    port_avgs = _monthly_port_wavg_for_chars(panel, 'port', weight_col, char_cols)
    # 'High' is the largest portfolio number actually present in each month.
    top_port = panel.groupby('month', observed=True)['port'].max().rename('maxp')
    port_avgs = port_avgs.merge(top_port, on='month', how='left')

    def _bucket_mask(label):
        # Row selector for one bucket label ('Low', 'High' or a portfolio number).
        if label == 'Low':
            return port_avgs['port'] == 1
        if label == 'High':
            return port_avgs['port'] == port_avgs['maxp']
        return port_avgs['port'] == int(label)

    labels = (['Low'] + list(range(2, n_ports)) + ['High'])
    bucket_means = {}
    for label in labels:
        monthly = (port_avgs[_bucket_mask(label)]
                   .drop(columns=['port', 'maxp'])
                   .set_index('month')
                   .sort_index())
        bucket_means[label] = monthly[char_cols].mean(axis=0, skipna=True).to_dict()

    # High-minus-low spread on the months where both extremes exist.
    low_ts = port_avgs[port_avgs['port'] == 1].set_index('month').sort_index()[char_cols]
    high_ts = port_avgs[port_avgs['port'] == port_avgs['maxp']].set_index('month').sort_index()[char_cols]
    common_months = low_ts.index.intersection(high_ts.index)
    hml_ts = high_ts.loc[common_months] - low_ts.loc[common_months]

    hml_mean = {}
    tstats = {}
    for name in char_cols:
        mean_val, t_val = newey_west_mean_t(hml_ts[name], lags=lags)
        hml_mean[name] = float(mean_val)
        tstats[name] = float(t_val)

    records = [{'Bucket': label, **bucket_means[label]} for label in labels]
    records.append({'Bucket': 'HML', **hml_mean})
    records.append({'Bucket': 't', **tstats})
    tbl = pd.DataFrame(records).set_index('Bucket')

    # Display scaling: every row but the t-statistics.
    base_rows = [row for row in tbl.index if row != 't']
    for col, factor in scalers.items():
        tbl.loc[base_rows, col] = tbl.loc[base_rows, col] * float(factor)
    for col in pct_cols:
        tbl.loc[base_rows, col] = 100.0 * tbl.loc[base_rows, col]

    # Display copy: rounded values, t-statistics rendered as "(x.xx)" or "(.)" if missing.
    shown = tbl.copy()
    shown.loc[base_rows, :] = shown.loc[base_rows, :].round(rnd)
    shown = shown.astype(object)
    shown.loc['t', :] = tbl.loc['t', :].apply(lambda x: f"({x:.2f})" if pd.notna(x) else "(.)").astype(object)
    return tbl, shown


#shap analysis
class EnsembleWrapper(nn.Module):
    """Wrap several models as one module whose output is their mean prediction.

    Every wrapped model must return a 2-tuple whose first element is the prediction
    tensor; the second element is ignored.
    """

    def __init__(self, models):
        super().__init__()
        # ModuleList registers the members so .to()/.eval()/state_dict() reach them.
        self.models = nn.ModuleList(models)

    def forward(self, x):
        """Return the member-average prediction with a trailing singleton dimension."""
        outputs = (member(x) for member in self.models)
        # Strict two-element unpacking; each prediction gains an axis at position 1.
        columns = [pred.unsqueeze(1) for pred, _ in outputs]
        return torch.stack(columns, dim=0).mean(dim=0)

#plots
def _decile_split_and_plot(sv_arr, X_ex_cl, f_nhwc, save_dir, show=True, save=True):
    import numpy as np, os, matplotlib.pyplot as plt
    from IPython.display import display as ipy_display, Image as IPImage
    preds = f_nhwc(X_ex_cl).ravel()
    qL, qH = np.percentile(preds, [10,90])
    L = np.where(preds<=qL)[0]; Hidx = np.where(preds>=qH)[0]
    Lmap = sv_arr[L].mean(axis=0).sum(axis=-1)
    Hmap = sv_arr[Hidx].mean(axis=0).sum(axis=-1)
    Dmap = Hmap - Lmap
    vmax = np.percentile(np.abs(Dmap), 99)+1e-12
    fig, axs = plt.subplots(1,3,figsize=(14,4))
    for ax, M, ttl in zip(axs, [Lmap,Hmap,Dmap], ["Signed SHAP – Low decile",
                                                  "Signed SHAP – High decile",
                                                  "High − Low"]):
        im = ax.imshow(M, aspect='auto', cmap='RdBu_r', vmin=-vmax, vmax=vmax, interpolation='nearest')
        ax.set_title(ttl); fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
        ax.set_xlabel("maturity bins"); ax.set_ylabel("moneyness bins")
    out = os.path.join(save_dir, "shap_decile_split.png")
    if save: fig.savefig(out, dpi=180, bbox_inches='tight')
    if show: ipy_display(IPImage(filename=out)) if save else plt.show()
    plt.close(fig)

def _faithfulness_occlusion_report(sv_arr, X_ex_cl, f_nhwc, save_dir, show=True, save=True):
    import numpy as np, os, matplotlib.pyplot as plt
    from IPython.display import display as ipy_display, Image as IPImage
    mass = np.abs(sv_arr).sum(axis=(0,3))   # (H,W)
    thr = np.quantile(mass, 0.95)           # top-5%
    roi = mass >= thr
    X_mask = X_ex_cl.copy()
    X_mask[..., roi] = 0.0                 
    p0 = f_nhwc(X_ex_cl).ravel()
    p1 = f_nhwc(X_mask).ravel()
    drop = float(np.mean(np.abs(p1 - p0)))
    print(f"Faithfulness: mean |Δprediction| after masking top-5% cells = {drop:.3e}")
    
    fig, ax = plt.subplots(1,1,figsize=(5,4))
    im = ax.imshow(roi.astype(float), aspect='auto', cmap='Greys')
    ax.set_title("Masked ROI (top-5% mass cells)")
    out = os.path.join(save_dir, "shap_masked_roi.png")
    if save: fig.savefig(out, dpi=180, bbox_inches='tight')
    if show: ipy_display(IPImage(filename=out)) if save else plt.show()
    plt.close(fig)

def _roi_newey_west_share(sv_arr, ex_idx, meta, rows=None, cols=None, lags=6):
    """Share of absolute SHAP mass inside a rectangular region, tested over time.

    For every explained sample the |SHAP| mass inside the ``rows`` x ``cols`` region is
    divided by the sample's total |SHAP| mass; the shares are averaged per month and
    the monthly series is passed to ``newey_west_mean_t``.

    Parameters
    ----------
    sv_arr : SHAP values indexed as (sample, row, col, channel).
    ex_idx : positional row indices into ``meta`` of the explained samples.
    meta : DataFrame with a 'month' column.
    rows, cols : iterables of row / column indices (default: all).
    lags : lag count handed to ``newey_west_mean_t``.

    Returns
    -------
    (mu, t) exactly as returned by ``newey_west_mean_t``; both are also printed.
    """
    import numpy as np, pandas as pd

    n_rows, n_cols = sv_arr.shape[1], sv_arr.shape[2]
    rows = range(n_rows) if rows is None else rows
    cols = range(n_cols) if cols is None else cols

    region = np.zeros((n_rows, n_cols), dtype=bool)
    region[np.ix_(list(rows), list(cols))] = True

    magnitude = np.abs(sv_arr)                         #(N,H,W,C)
    inside = magnitude[:, region, :].sum(axis=(1, 2))
    if inside.ndim == 2:
        inside = inside.sum(axis=-1)
    total = magnitude.sum(axis=(1, 2, 3)) + 1e-12      # epsilon avoids division by zero
    share = inside / total

    months = pd.to_datetime(meta.iloc[ex_idx]['month']).values
    per_sample = pd.DataFrame({"month": months, "s": share})
    monthly = per_sample.groupby("month")["s"].mean().dropna()

    mu, t = newey_west_mean_t(monthly.values, lags=lags)
    print(f"ROI share mean={mu:.3%}  NW t={t:.2f}   (rows={list(rows)}, cols={list(cols)})")
    return mu, t




def clone_model_no_inplace(src_model, in_ch, dz, p_drop):
    """Build a fresh ``CNN_NODE`` with ``src_model``'s weights and out-of-place activations.

    Parameters
    ----------
    src_model : trained model to copy; may be the wrapper returned by ``torch.compile``.
    in_ch, dz, p_drop : constructor arguments forwarded to ``CNN_NODE``.

    Returns
    -------
    A new ``CNN_NODE`` whose ``nn.ReLU`` / ``nn.LeakyReLU`` modules have
    ``inplace=False``. ``src_model`` is left untouched.
    """
    # torch.compile wraps the network and prefixes its state_dict keys with
    # "_orig_mod.", which a plain CNN_NODE cannot load; read from the wrapped module.
    base_model = getattr(src_model, "_orig_mod", src_model)
    dst = CNN_NODE(in_ch=in_ch, dz=dz, p_drop=p_drop)
    # load_state_dict copies values into dst's own tensors, so no deepcopy is needed
    # to keep the clone independent of the source.
    dst.load_state_dict(base_model.state_dict())
    for mod in dst.modules():
        if isinstance(mod, (nn.ReLU, nn.LeakyReLU)):
            mod.inplace = False
    return dst


def run_shap_analysis(Xz, meta, models_dict, first_oos, device=CFG.DEVICE):
    """Explain the ensemble's out-of-sample forecasts with partition SHAP and report diagnostics.

    Steps: draw a random subset of rows that have a finite 'rb_cnn_node' forecast,
    compute SHAP values for the seed-averaged prediction with a blur image masker,
    print/save attribution-mass statistics, draw the SHAP figures, a Sharpe-ratio
    histogram (individual seeds vs. ensemble) and the optional decile / occlusion /
    ROI diagnostics.

    Parameters
    ----------
    Xz : array indexed positionally like ``meta`` and laid out as (N, C, H, W).
    meta : DataFrame with 'month', 'rb_cnn_node', ``CFG.TARGET_COL`` and 'doi'
        (the weight column of the open-interest-weighted portfolios).
    models_dict : mapping seed -> model; each model returns ``(y, _)``.
    first_oos : rows with ``month < first_oos`` count as in-sample.
    device : torch device used for prediction.

    Returns
    -------
    dict with 'row_share', 'col_share', 'topk', 'coarse_3x3' and 'global_mean_abs',
    or None when there is no in-sample row or no finite forecast.

    Behaviour is configured through optional ``CFG`` attributes (``SHAP_*``).
    Sampling uses the global NumPy RNG.
    """
    # Imported here so the function does not depend on another notebook cell having
    # brought `plt` / `shap` into the global namespace.
    import matplotlib.pyplot as plt
    import shap
    from IPython.display import display as ipy_display, Image as IPImage

    outdir = getattr(CFG, "SHAP_SAVE_DIR", "./shap_out")
    os.makedirs(outdir, exist_ok=True)

    SHOW         = getattr(CFG, "SHAP_SHOW_FIGS", True)
    SAVE         = getattr(CFG, "SHAP_SAVE_FIGS", True)
    SAVE_STATS   = getattr(CFG, "SHAP_SAVE_STATS", False)
    TOPK         = int(getattr(CFG, "SHAP_STATS_TOPK", 10))
    MAX_EVALS    = getattr(CFG, "SHAP_MAX_EVALS", None)
    # extra toggles
    DO_DECILES   = getattr(CFG, "SHAP_DO_DECILES", True)
    DO_FAITHFUL  = getattr(CFG, "SHAP_DO_FAITHFULNESS", True)
    DO_ROI_NW    = getattr(CFG, "SHAP_DO_ROI_NW", True)

    def _nhwc_predictor(x_nhwc: np.ndarray) -> np.ndarray:
        """Seed-averaged prediction for channel-last input; returns shape (N, 1)."""
        x = np.array(x_nhwc, dtype=np.float32)
        if x.ndim == 3: x = x[None, ...]           # single sample -> batch of one
        x_nchw = np.transpose(x, (0,3,1,2))
        xb = torch.from_numpy(x_nchw)
        if getattr(CFG, "CHANNELS_LAST", True) and xb.ndim == 4:
            xb = xb.contiguous(memory_format=torch.channels_last)
        xb = xb.to(device, non_blocking=True)
        with torch.no_grad():
            preds = []
            for s in models_dict.keys():
                m = models_dict[s].eval().to(device)
                y, _ = m(xb)
                preds.append(y.unsqueeze(1))
            yavg = torch.mean(torch.stack(preds, dim=0), dim=0)
        return yavg.detach().cpu().numpy()

    def _normalize_sv(values):
        """Coerce the explainer output to a 4-D (N, H, W, CH) array."""
        sv_raw = values if not isinstance(values, list) else values[0]
        arr = np.asarray(sv_raw)
        if arr.ndim == 5:  #(N,H,W,C,O) -> (N,H,W,C*O)
            N,H,W,C,O = arr.shape
            arr = arr.reshape(N,H,W,C*O)
        if arr.ndim == 3:  #(N,H,W) -> (N,H,W,1)
            arr = arr[..., np.newaxis]
        while arr.ndim > 4 and arr.shape[-1] == 1:
            arr = arr[..., 0]
        if arr.ndim != 4:
            raise ValueError(f"SHAP: shape shape {arr.shape} (it espect 4D).")
        return arr

    def _save_or_show(fig, path_png):
        """Write and/or display ``fig`` according to SAVE / SHOW, then close it."""
        if SAVE: fig.savefig(path_png, dpi=180, bbox_inches='tight')
        if SHOW:
            if SAVE: ipy_display(IPImage(filename=path_png))
            else:    plt.show()
        plt.close(fig)

    def _decile_split_and_plot(sv_arr, X_ex_cl):
        """Mean signed SHAP maps for the bottom / top prediction decile and their difference."""
        preds = _nhwc_predictor(X_ex_cl).ravel()
        qL, qH = np.percentile(preds, [10,90])
        L = np.where(preds<=qL)[0]; Hidx = np.where(preds>=qH)[0]
        # An empty selection yields a zero map instead of a NaN mean.
        Lmap = sv_arr[L].mean(axis=0).sum(axis=-1) if len(L)>0 else np.zeros(sv_arr.shape[1:3])
        Hmap = sv_arr[Hidx].mean(axis=0).sum(axis=-1) if len(Hidx)>0 else np.zeros(sv_arr.shape[1:3])
        Dmap = Hmap - Lmap
        vmax = np.percentile(np.abs(Dmap), 99)+1e-12
        fig, axs = plt.subplots(1,3,figsize=(14,4))
        for ax, M, ttl in zip(axs, [Lmap,Hmap,Dmap], ["Signed SHAP – Low decile",
                                                      "Signed SHAP – High decile",
                                                      "High − Low"]):
            im = ax.imshow(M, aspect='auto', cmap='RdBu_r', vmin=-vmax, vmax=vmax, interpolation='nearest')
            ax.set_title(ttl); fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
            ax.set_xlabel("maturity bins"); ax.set_ylabel("moneyness bins")
        _save_or_show(fig, os.path.join(outdir, "shap_decile_split.png"))

    def _faithfulness_occlusion_report(sv_arr, X_ex_cl):
        """Zero the top-5% attribution cells and print the mean absolute prediction change."""
        mass = np.abs(sv_arr).sum(axis=(0,3))     #(H,W)
        thr  = np.quantile(mass, 0.95)
        roi  = mass >= thr                         #(H,W) True = high-importance
        mask4 = (~roi)[None, :, :, None].astype(X_ex_cl.dtype)  # (1,H,W,1)
        X_mask = X_ex_cl * mask4

        p0 = _nhwc_predictor(X_ex_cl).ravel()
        p1 = _nhwc_predictor(X_mask).ravel()
        drop = float(np.mean(np.abs(p1 - p0)))
        print(f"Faithfulness: mean |Δprediction| after masking top-5% cells = {drop:.3e}")
        fig, ax = plt.subplots(1,1,figsize=(5,4))
        ax.imshow(roi.astype(float), aspect='auto', cmap='Greys')
        ax.set_title("Masked ROI (top-5% mass cells)")
        _save_or_show(fig, os.path.join(outdir, "shap_masked_roi.png"))

    def _roi_newey_west_share(sv_arr, ex_idx, rows=None, cols=None, lags=CFG.LAGS_NW):
        """Monthly mean share of |SHAP| mass inside rows x cols, tested with ``newey_west_mean_t``."""
        H, W = sv_arr.shape[1], sv_arr.shape[2]
        R = np.zeros((H,W), dtype=bool)
        if rows is None: rows = range(H)
        if cols is None: cols = range(W)
        R[np.ix_(list(rows), list(cols))] = True
        absv = np.abs(sv_arr)
        num = absv[:, R, :].sum(axis=(1,2))
        num = num.sum(axis=-1) if num.ndim==2 else num
        den = absv.sum(axis=(1,2,3)) + 1e-12
        share = (num/den)
        m = pd.DataFrame({"month": pd.to_datetime(meta.iloc[ex_idx]['month']).values, "s": share})
        r = m.groupby("month")["s"].mean().dropna()
        mu, t = newey_west_mean_t(r.values, lags=lags)
        print(f"ROI share mean={mu:.3%}  NW t={t:.2f}   (rows={list(rows)}, cols={list(cols)})")
        return mu, t

    def _hml_series_from_preds(meta_df, pred_vec, idx_vec, name="pred_tmp",
                              n_ports=CFG.PORTS_N, weighting='oiw'):
        """Monthly realised high-minus-low return of portfolios sorted on ``pred_vec``.

        ``idx_vec`` holds positional row indices of ``meta_df`` (as produced by
        ``np.where``); rows outside it get no forecast and are dropped.
        """
        d = meta_df.copy()
        # Scatter by position: `.loc` would treat the indices as labels and misplace
        # (or fail on) the forecasts whenever meta does not carry a default RangeIndex.
        col = np.full(len(d), np.nan, dtype=np.float64)
        col[np.asarray(idx_vec)] = np.asarray(pred_vec, dtype=np.float32).reshape(-1)
        d[name] = col
        d = d.dropna(subset=['month', name, CFG.TARGET_COL])
        d['month'] = pd.to_datetime(d['month'])
        d['port'] = d.groupby('month', observed=True)[name].transform(
            lambda s: _assign_ports_per_month_rank(s, n_ports))
        d = d.dropna(subset=['port']); d['port'] = d['port'].astype('Int64')
        weight_col = None if weighting.lower()=='ew' else 'doi'
        g = _monthly_port_avgs(d, weight_col, name, CFG.TARGET_COL)
        maxp = d.groupby('month', observed=True)['port'].max().rename('maxp')
        gj   = g.merge(maxp, on='month', how='left').set_index('month').sort_index()
        low_r  = gj[gj['port']==1]['real']
        high_r = gj[gj['port']==gj['maxp']]['real']
        return (high_r - low_r).dropna().astype(float)  # monthly H−L

    ins_idx = np.where(pd.to_datetime(meta['month']).lt(first_oos).values)[0]
    if ins_idx.size == 0:
        print("SHAP: no in-sample for background.")
        return
    # The blur masker below needs no background batch, so none is materialised. The
    # draw itself is kept so the global RNG stream, and hence `ex_idx`, is the same
    # for a given seed as before.
    np.random.choice(ins_idx, size=min(getattr(CFG, "SHAP_BG_SAMPLES", 64), ins_idx.size), replace=False)

    oos_mask = np.isfinite(meta['rb_cnn_node'].to_numpy(np.float32))
    ex_idx_all = np.where(oos_mask)[0]
    if ex_idx_all.size == 0:
        print("SHAP: no OOS forecasts for explanation.")
        return
    ex_idx = np.random.choice(ex_idx_all, size=min(getattr(CFG, "SHAP_EXPLAIN_SAMPLES", 64), ex_idx_all.size), replace=False)

    X_ex_cl = np.transpose(Xz[ex_idx].astype(np.float32), (0,2,3,1))  #(N,H,W,C)
    H, W = X_ex_cl.shape[1], X_ex_cl.shape[2]

    # Blur kernel of roughly a quarter of each spatial dimension.
    kr = max(1, H // 4); kc = max(1, W // 4)
    masker = shap.maskers.Image(f"blur({kr},{kc})", X_ex_cl[0].shape)
    explainer = shap.Explainer(_nhwc_predictor, masker, algorithm="partition")
    sv = explainer(X_ex_cl) if MAX_EVALS is None else explainer(X_ex_cl, max_evals=(4*H*W if MAX_EVALS=="auto" else int(MAX_EVALS)))
    sv_arr = _normalize_sv(sv.values)  # (N,H,W,CH)

    # Attribution-mass statistics (shares of total |SHAP|).
    abs_vals   = np.abs(sv_arr)
    total_mass = abs_vals.sum() + 1e-12
    mass_map   = abs_vals.sum(axis=(0,3)) / total_mass   # (H,W)
    row_share  = mass_map.sum(axis=1)
    col_share  = mass_map.sum(axis=0)
    flat = mass_map.ravel()
    idxs = np.argsort(flat)[::-1][:TOPK]
    topk = [(int(i // W), int(i % W), float(flat[i])) for i in idxs]
    share_topk = float(flat[idxs].sum())
    global_mean_abs = float(abs_vals.mean())

    #3×3 coarse
    bins = 3
    r_edges = np.linspace(0, H, bins+1, dtype=int); c_edges = np.linspace(0, W, bins+1, dtype=int)
    coarse = np.zeros((bins, bins), dtype=float)
    for i in range(bins):
        for j in range(bins):
            coarse[i,j] = mass_map[r_edges[i]:r_edges[i+1], c_edges[j]:c_edges[j+1]].sum()
    coarse_df = pd.DataFrame(np.round(coarse*100, 2),
                            index=[f"R{i+1}" for i in range(bins)],
                            columns=[f"C{j+1}" for j in range(bins)])

    print("\nSHAP — summary stats")
    print(f"  global mean |SHAP|: {global_mean_abs:.4e}")
    print(f"  top-{TOPK} cells share of attribution mass: {share_topk*100:.1f}%")
    print(f"  most important row index: {int(np.argmax(row_share))}  (share={row_share.max()*100:.1f}%)")
    print(f"  most important col index: {int(np.argmax(col_share))}  (share={col_share.max()*100:.1f}%)")
    print("\n  3×3 coarse share (% of mass) by row×col bins:")
    print(coarse_df.to_string(index=True))

    if SAVE_STATS:
        pd.DataFrame({"row":np.arange(H), "share":row_share}).to_csv(os.path.join(outdir, "shap_row_share.csv"), index=False)
        pd.DataFrame({"col":np.arange(W), "share":col_share}).to_csv(os.path.join(outdir, "shap_col_share.csv"), index=False)
        pd.DataFrame(topk, columns=["row","col","share"]).to_csv(os.path.join(outdir, "shap_topk_cells.csv"), index=False)
        coarse_df.to_csv(os.path.join(outdir, "shap_coarse_3x3.csv"))

    #Plots (SHAP)
    X_mean2d = X_ex_cl.mean(axis=(0,3))                    # (H,W)
    signed_map = sv_arr.mean(axis=0).sum(axis=-1)          # (H,W)
    vmax = np.percentile(np.abs(signed_map), 99) + 1e-12

    fig, axs = plt.subplots(1, 2, figsize=(12, 4))
    im0 = axs[0].imshow(X_mean2d, aspect='auto', cmap='gray', interpolation='nearest')
    axs[0].set_title("Mean IV surface (OOS samples)")
    fig.colorbar(im0, ax=axs[0], fraction=0.046, pad=0.04)

    im1 = axs[1].imshow(signed_map, aspect='auto', cmap='RdBu_r',
                        vmin=-vmax, vmax=vmax, interpolation='nearest')
    axs[1].set_title("Signed SHAP (mean across OOS)")
    fig.colorbar(im1, ax=axs[1], fraction=0.046, pad=0.04)
    for ax in axs:
        ax.set_xlabel("maturity bins"); ax.set_ylabel("moneyness bins")
        ax.set_xticks(np.linspace(0, W-1, min(W, 6)).astype(int))
        ax.set_yticks(np.linspace(0, H-1, min(H, 6)).astype(int))
    _save_or_show(fig, os.path.join(outdir, "shap_signed_vs_surface.png"))

    fig2, ax2 = plt.subplots(1, 1, figsize=(6.5, 4))
    im2 = ax2.imshow(100*mass_map, aspect='auto', cmap='magma', interpolation='nearest')
    ax2.set_title("Attribution mass share (%)")
    fig2.colorbar(im2, ax=ax2, fraction=0.046, pad=0.04)
    ax2.set_xlabel("maturity bins"); ax2.set_ylabel("moneyness bins")
    ax2.set_xticks(np.linspace(0, W-1, min(W, 6)).astype(int))
    ax2.set_yticks(np.linspace(0, H-1, min(H, 6)).astype(int))
    _save_or_show(fig2, os.path.join(outdir, "shap_mass_share.png"))

    # Sharpe-ratio histogram: each seed's own long-short portfolio vs. the ensemble's.
    # Best-effort: a failure here is reported but does not abort the SHAP diagnostics.
    try:
        oos_idx_for_hist = ex_idx_all

        ens_vec_all = meta['rb_cnn_node'].to_numpy(np.float32)
        ens_vec_oos = ens_vec_all[oos_idx_for_hist]
        hml_ens = _hml_series_from_preds(meta, ens_vec_oos, oos_idx_for_hist,
                                        name='rb_cnn_node', weighting='oiw')
        sr_ens = sharpe_ann(hml_ens)

        indiv_srs = []
        for s, m in models_dict.items():
            yhat_oos = predict(m, Xz[oos_idx_for_hist], device=device,
                              bs=max(8192, getattr(CFG, "BATCH_SIZE", 8192)))
            hml_ind = _hml_series_from_preds(meta, yhat_oos, oos_idx_for_hist,
                                            name=f'pred_{s}', weighting='oiw')
            sr_ind = sharpe_ann(hml_ind)
            if np.isfinite(sr_ind): indiv_srs.append(float(sr_ind))

        fig3, ax3 = plt.subplots(1,1, figsize=(7.2, 4.0))
        if len(indiv_srs) > 0:
            nbins = min(12, max(5, len(indiv_srs)))
            ax3.hist(indiv_srs, bins=nbins, color='#0b3c5d', alpha=0.95, label='Individual')
        if np.isfinite(sr_ens):
            ax3.axvline(sr_ens, color='limegreen', linewidth=10, label='Ensemble', alpha=0.9)
        ax3.set_xlabel("Annualized Sharpe ratio"); ax3.set_ylabel("Frequency")
        ax3.set_title("Sharpe ratios of long–short portfolios (OIW)")
        ax3.legend()
        xs = indiv_srs + ([sr_ens] if np.isfinite(sr_ens) else [])
        if len(xs) > 0:
            xmin, xmax = min(xs), max(xs)
            pad = 0.1*(xmax - xmin + 1e-9)
            ax3.set_xlim(xmin - pad, xmax + pad)
        _save_or_show(fig3, os.path.join(outdir, "sharpe_histogram_oiw.png"))

        print(f"Figure3-style: Ensemble Sharpe (OIW) = {sr_ens:.2f}; "
              f"Individual mean={np.mean(indiv_srs) if len(indiv_srs)>0 else float('nan'):.2f} (N={len(indiv_srs)})")
    except Exception as e_fig3:
        print(f"Sharpe histogram (Figure 3-style) failed: {e_fig3}")

    if DO_DECILES:
        _decile_split_and_plot(sv_arr, X_ex_cl)
    if DO_FAITHFUL:
        _faithfulness_occlusion_report(sv_arr, X_ex_cl)
    if DO_ROI_NW:
        # 3x3 window (clipped at the borders) around the most important row / column.
        r0, c0 = int(np.argmax(row_share)), int(np.argmax(col_share))
        r_low, r_high = max(0, r0-1), min(H, r0+2)
        c_low, c_high = max(0, c0-1), min(W, c0+2)
        _roi_newey_west_share(sv_arr, ex_idx, rows=range(r_low, r_high), cols=range(c_low, c_high), lags=getattr(CFG, "LAGS_NW", 6))

    return {
        "row_share": row_share,
        "col_share": col_share,
        "topk": topk,
        "coarse_3x3": coarse,
        "global_mean_abs": global_mean_abs
    }
#LIME analysis
def run_lime_analysis(Xz, meta, models_dict, first_oos, device=CFG.DEVICE):
    """Explain three out-of-sample forecasts (lowest, median, highest) with LIME.

    The regression output of the seed-averaged ensemble is turned into a two-class
    pseudo-probability (sigmoid of the standardised prediction) so that
    ``LimeImageExplainer`` can be used. For each explained sample two boundary
    overlays are drawn and, in 'grid' segmentation mode, a per-cell weight map.

    Parameters
    ----------
    Xz : array indexed positionally like ``meta`` and laid out as (N, C, H, W).
        Only the first two channels survive the RGB packing used for LIME.
    meta : DataFrame with 'rb_cnn_node' (rows with a finite value are candidates)
        and, optionally, 'month' (used in the figure captions).
    models_dict : mapping seed -> model; each model returns ``(y, _)``.
    first_oos : not used by this function.
    device : torch device used for prediction.

    Returns
    -------
    None. Figures are shown and/or written according to ``CFG.LIME_SHOW_FIGS`` /
    ``CFG.LIME_SAVE_FIGS`` into ``CFG.LIME_SAVE_DIR`` (default "./lime_out").

    Raises
    ------
    RuntimeError if ``lime`` or ``scikit-image`` cannot be imported.
    """
    import os, numpy as np, pandas as pd, matplotlib.pyplot as plt

    SHOW = getattr(CFG, "LIME_SHOW_FIGS", True)
    SAVE = getattr(CFG, "LIME_SAVE_FIGS", False)

    try:
        from lime.lime_image import LimeImageExplainer
        from skimage.segmentation import mark_boundaries
    except Exception as e:
        raise RuntimeError("missing lime / scikit-image: pip install lime scikit-image") from e

    outdir = getattr(CFG, "LIME_SAVE_DIR", "./lime_out")
    os.makedirs(outdir, exist_ok=True)

    # ---------------- Helpers ----------------
    def _save_or_show(fig, fname=None):
        """Write ``fig`` (when saving is on and a name is given), show it, then close it."""
        if SAVE and fname:
            os.makedirs(outdir, exist_ok=True)
            fig.savefig(os.path.join(outdir, fname), dpi=180, bbox_inches='tight')
        if SHOW:
            plt.show()
        plt.close(fig)

    def _predict_nhwc(x_nhwc: np.ndarray) -> np.ndarray:
        """Seed-averaged prediction for channel-last input, batched; returns shape (N, 1)."""
        x = np.array(x_nhwc, dtype=np.float32)
        if x.ndim == 3: x = x[None, ...]
        bs = int(getattr(CFG, "LIME_BATCH_SIZE", 2048))
        N = x.shape[0]
        out = np.empty((N,1), dtype=np.float32)
        for i in range(0, N, bs):
            xb = x[i:i+bs]
            x_nchw = np.transpose(xb, (0,3,1,2))               # NHWC -> NCHW
            xt = torch.from_numpy(x_nchw).to(device, non_blocking=True)
            if getattr(CFG, "CHANNELS_LAST", True) and xt.ndim == 4:
                xt = xt.contiguous(memory_format=torch.channels_last)
            with torch.no_grad():
                preds = []
                for s in models_dict.keys():
                    m = models_dict[s].eval().to(device)
                    y, _ = m(xt)                                # (B,)
                    preds.append(y.unsqueeze(1))
                yavg = torch.mean(torch.stack(preds, dim=0), dim=0)  # (B,1)
            out[i:i+bs] = yavg.detach().cpu().numpy().astype(np.float32)
        return out

    def _pack_to_rgb_from_sample(X1_chw: np.ndarray) -> np.ndarray:
        """Place channel 0 (and channel 1, if present) into an (H, W, 3) image; the rest stays 0."""
        C,H,W = X1_chw.shape
        x_hwc = np.transpose(X1_chw, (1,2,0))      # -> H,W,C
        rgb = np.zeros((H,W,3), dtype=np.float32)
        rgb[...,0] = x_hwc[...,0]
        if C >= 2:
            rgb[...,1] = x_hwc[...,1]
        return rgb

    def _unpack_from_rgb_to_nhwc(imgs_rgb: np.ndarray, C_in: int) -> np.ndarray:
        """Inverse of the RGB packing: keep the first one or two channels."""
        n_keep = 1 if C_in == 1 else 2
        return imgs_rgb[..., :n_keep].astype(np.float32)

    def _grid_segmentation(image_hw3: np.ndarray) -> np.ndarray:
        """One superpixel per cell, numbered row-major."""
        H, W = image_hw3.shape[0], image_hw3.shape[1]
        return np.arange(H*W, dtype=np.int32).reshape(H, W)

    def _classifier_for_lime(imgs_rgb: np.ndarray) -> np.ndarray:
        """Map ensemble predictions to [p0, p1] pseudo-probabilities for LIME."""
        C_in = int(Xz.shape[1])
        x_nhwc = _unpack_from_rgb_to_nhwc(imgs_rgb, C_in=C_in)
        y = _predict_nhwc(x_nhwc).reshape(-1)  # (N,)

        # `_center` is set per explained sample to its unperturbed prediction.
        center = getattr(_classifier_for_lime, "_center", float(y.mean()))
        # NOTE(review): the scale is re-estimated from whatever batch LIME passes in,
        # so probabilities from different batches are not on one common scale;
        # consider fixing it once per explained sample.
        scale  = float(np.std(y) + 1e-6)
        z = (y - center) / scale
        p1 = 1.0 / (1.0 + np.exp(-z))
        p0 = 1.0 - p1
        return np.stack([p0, p1], axis=1).astype(np.float32)

    def _save_overlay(base_rgb, mask, fname, title, idx, y0):
        """Draw the LIME segment boundaries over the packed image."""
        fig, ax = plt.subplots(1,1, figsize=(6,5))
        ax.imshow(mark_boundaries(base_rgb, mask))
        dt = pd.to_datetime(meta.iloc[idx]['month']).date() if 'month' in meta.columns else 'n/a'
        ax.set_title(title)
        ax.set_xlabel(f"index={idx} | month={dt} | y0={y0:.4f}")
        ax.axis('off')
        _save_or_show(fig, fname if SAVE else None)

    # ---------------- Sample selection ----------------
    oos_mask = np.isfinite(meta['rb_cnn_node'].to_numpy(np.float32))
    ex_idx_all = np.where(oos_mask)[0]
    if ex_idx_all.size == 0:
        print("LIME: not enough OOS forecasts.")
        return

    # ex_idx_all holds positions, so index the values positionally: `.loc` would read
    # them as labels and pick wrong rows unless meta has a default RangeIndex.
    scores = meta['rb_cnn_node'].astype(float).to_numpy()[ex_idx_all]
    order  = np.argsort(scores)
    low_i  = ex_idx_all[order[0]]
    high_i = ex_idx_all[order[-1]]
    mid_i  = ex_idx_all[order[len(order)//2]]
    candidates = [('low', low_i), ('mid', mid_i), ('high', high_i)]

    H, W = Xz.shape[2], Xz.shape[3]
    mode = str(getattr(CFG, "LIME_SEGMENTATION_MODE", "grid")).lower()

    #Segmentation fn (identical for every explained sample)
    if mode == 'grid':
        segmentation_fn = _grid_segmentation
    else:
        from skimage.segmentation import slic
        def segmentation_fn(im):
            return slic(im, n_segments=int(getattr(CFG, "LIME_SLIC_SEGMENTS", 100)),
                        compactness=0.1, sigma=1, start_label=0)

    label = 1                                           # explain the "high prediction" class
    num_feats = int(getattr(CFG, "LIME_NUM_FEATURES", 12))

    for tag, idx in candidates:
        X1 = Xz[idx]                                   # (C,H,W)
        x_rgb = _pack_to_rgb_from_sample(X1)           # (H,W,3)

        x_nhwc_full = np.transpose(X1, (1,2,0)).astype(np.float32)  # H,W,C
        y0 = float(_predict_nhwc(x_nhwc_full[None, ...])[0,0])
        _classifier_for_lime._center = y0

        explainer = LimeImageExplainer(random_state=123)
        explanation = explainer.explain_instance(
            image=x_rgb,
            classifier_fn=_classifier_for_lime,
            top_labels=2,
            hide_color=0.0,
            num_samples=int(getattr(CFG, "LIME_NUM_SAMPLES", 1500)),
            segmentation_fn=segmentation_fn
        )

        img_pos,  mask_pos  = explanation.get_image_and_mask(label, positive_only=True,
                                                            num_features=num_feats, hide_rest=False)
        img_both, mask_both = explanation.get_image_and_mask(label, positive_only=False,
                                                            num_features=num_feats, hide_rest=False)

        _save_overlay(x_rgb, mask_pos,  f"lime_{tag}_pos_{idx}.png",
                      f"LIME (positive-only) — {tag.upper()} decile", idx, y0)
        _save_overlay(x_rgb, mask_both, f"lime_{tag}_both_{idx}.png",
                      f"LIME (pos/neg) — {tag.upper()} decile", idx, y0)

        if mode == 'grid':
            # Grid superpixels are numbered row-major, so superpixel k is cell (k // W, k % W).
            sp_weights = dict(explanation.local_exp[label])
            Wgrid = np.array([sp_weights.get(k, 0.0) for k in range(H*W)],
                             dtype=np.float32).reshape(H, W)

            vmax = np.percentile(np.abs(Wgrid), 99) + 1e-12
            fig, ax = plt.subplots(1,1, figsize=(6,5))
            im = ax.imshow(Wgrid, cmap='RdBu_r', vmin=-vmax, vmax=vmax, aspect='auto')
            ax.set_title(f"LIME local weights per cell — {tag.upper()} (index={idx})")
            ax.set_xlabel("maturity bins"); ax.set_ylabel("moneyness bins")
            fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
            _save_or_show(fig, f"lime_{tag}_weights_grid_{idx}.png")

    print(f"LIME: finished files saved at: {os.path.abspath(outdir)}")




def main():
    """Run the expanding-window out-of-sample experiment and print the result tables.

    Pipeline:
      1. Load the input array and the meta table named by ``CFG.XZ_PATH`` /
         ``CFG.META_PATH``; add a constant ``doi`` column when it is missing.
      2. Determine the out-of-sample (OOS) years: every calendar year containing a
         month at or after (first month + 7 years). In ``FAST`` mode only the last
         two OOS years are kept.
      3. Build one ``CNN_NODE`` model and one AdamW optimizer per seed in ``CFG.SEEDS``.
      4. For each OOS year, train every model on all rows dated before 1 January of
         that year (rows with non-finite targets are dropped), then store the
         ensemble-mean prediction for that year's OOS rows. Models and optimizers
         are carried over from year to year.
      5. Print decile tables (open-interest weighted and equal weighted), the
         option-characteristics tables when those columns exist, and optionally
         run the SHAP and LIME analyses.

    Returns
    -------
    tuple
        ``(meta_eval, tbl_oiw, tbl_ew)``: the meta table with the added
        ``rb_cnn_node`` prediction column, and the two decile tables.

    Raises
    ------
    AssertionError
        If the meta table and the input array have a different number of rows.
    ValueError
        If the data contain no month at or after the first OOS date.
    """
    dev = CFG.DEVICE
    # Compare the device *type* so that 'cuda', 'cuda:0' and torch.device objects all count as GPU.
    torch_dev = torch.device(dev)
    on_cuda = torch_dev.type == 'cuda'
    print(f"CUDA: {torch.cuda.get_device_name(torch_dev) if on_cuda else 'CPU'} | AMP={CFG.AMP_DTYPE} | channels_last={CFG.CHANNELS_LAST} | compile={CFG.TORCH_COMPILE}")
    Xz = load_Xz(CFG.XZ_PATH)            #(N,1,H,W)
    meta = load_meta(CFG.META_PATH).copy()
    # Fall back to unit weights when no 'doi' column is provided.
    if 'doi' not in meta.columns: meta['doi'] = 1.0
    assert len(meta)==Xz.shape[0], "len(meta) != N of Xz"
    meta['month'] = pd.to_datetime(meta['month'])
    y = meta[CFG.TARGET_COL].to_numpy(np.float32)

    print(f"Xz finite ratio={float(np.isfinite(Xz).mean()):.4f} | y finite ratio={float(np.isfinite(y).mean()):.4f}")
    Xz = identity_stack(Xz, add_mask=CFG.ADD_MASK_CHANNEL)  # NaN->0 (no scaling)
    in_ch = Xz.shape[1]
    months = np.sort(meta['month'].unique())
    first_oos = pd.to_datetime(meta['month'].min()) + pd.DateOffset(years=7)
    years_oos = sorted(pd.Series(months[months>=first_oos]).dt.year.unique())
    if not years_oos:
        # Previously this surfaced as an opaque IndexError on years_oos[0] below.
        raise ValueError(
            f"No out-of-sample months: data end before first_oos={first_oos.date()} "
            "(first month + 7 years)."
        )
    if CFG.MODE.upper()=='FAST' and len(years_oos)>2:
        years_oos = years_oos[-2:]
    print(f"Mode={CFG.MODE} | OOS years: {len(years_oos)} -> {years_oos[0]}..{years_oos[-1]}")

    # OOS predictions aligned with meta rows; rows never predicted stay NaN.
    rb_hat = np.full(len(meta), np.nan, dtype=np.float32)

    # One model/optimizer pair per seed; seeding right before construction fixes the init.
    models, opts = {}, {}
    for s in CFG.SEEDS:
        torch.manual_seed(s); np.random.seed(s)
        m = CNN_NODE(in_ch=in_ch, dz=CFG.DZ, p_drop=CFG.DROPOUT_P)
        if CFG.TORCH_COMPILE and hasattr(torch, 'compile'):
            m = torch.compile(m, fullgraph=False, dynamic=True)
        o = torch.optim.AdamW(m.parameters(), lr=CFG.LR, weight_decay=CFG.WEIGHT_DECAY)
        models[s], opts[s] = m, o

    pred_bs = max(8192, CFG.BATCH_SIZE)
    # Tracks whether any year has been trained yet, so the larger initial epoch budget
    # goes to the first year that actually trains even if earlier years were skipped.
    trained_once = False
    for Y in tqdm(years_oos, desc="OOS years"):
        start_Y = pd.Timestamp(year=Y, month=1, day=1)
        start_next = pd.Timestamp(year=Y+1, month=1, day=1)
        # Expanding window: everything strictly before this calendar year is training data.
        tr_idx = np.where(meta['month'].lt(start_Y).values)[0]
        te_idx = np.where((meta['month']>=start_Y) & (meta['month']<start_next) & (meta['month']>=first_oos))[0]
        if tr_idx.size==0 or te_idx.size==0: continue

        X_tr = Xz[tr_idx]
        y_tr = y[tr_idx]
        good = np.isfinite(y_tr)
        n_bad = int((~good).sum())
        if n_bad>0:
            # Runtime message is partly Greek: "non-finite targets — ignored in training".
            print(f"Warning: {n_bad} μη-περατά targets — αγνοούνται στο training.")
        X_tr, y_tr = X_tr[good], y_tr[good]
        if y_tr.size==0:
            # Nothing to learn from: do not record predictions for this year.
            print(f"Year {Y}: no finite training targets, skipping.")
            continue
        loader = make_loader(X_tr, y_tr, bs=CFG.BATCH_SIZE, shuffle=True, device=dev)

        ne = CFG.EPOCHS_PER_YEAR if trained_once else CFG.EPOCHS_INIT
        print(f"\n=== Year {Y} | train_n={y_tr.size:,} | test_n={te_idx.size:,} | epochs={ne} ===")
        t0 = time.perf_counter()
        for s in models:
            opts[s], models[s] = train_epochs(models[s], loader, nepochs=ne, opt=opts[s], device=dev)
        elapsed = time.perf_counter() - t0
        trained_once = True
        print(f"Train time (year {Y}): {elapsed:.1f}s")

        # Slice the test rows once instead of re-copying them for every ensemble member.
        X_te = Xz[te_idx]
        ens = [predict(models[s], X_te, device=dev, bs=pred_bs) for s in models]
        rb_hat[te_idx] = np.mean(np.vstack(ens), axis=0)
        # Release the per-year copies before the next, larger window is materialized.
        del loader, X_tr, y_tr, X_te, ens; gc.collect()
        if on_cuda: torch.cuda.empty_cache()

    meta_eval = meta.copy()
    meta_eval['rb_cnn_node'] = rb_hat
    nnz = int(np.isfinite(meta_eval['rb_cnn_node']).sum())
    print(f"\nOOS non-NaN preds: {nnz} / {len(meta_eval)}")

    print(f"\nBuilding decile tables (N={CFG.PORTS_N}) using rb_cnn_node ...")
    tbl_oiw = decile_table(meta_eval, 'rb_cnn_node', CFG.TARGET_COL, n_ports=CFG.PORTS_N, weighting='oiw')
    tbl_ew  = decile_table(meta_eval, 'rb_cnn_node', CFG.TARGET_COL, n_ports=CFG.PORTS_N, weighting='ew')

    # Same rounding for both printed tables.
    round_spec = {col: CFG.HOF_ROUND for col in ('Pred.', 'Real.', 't', 'SR')}
    print("\n=== (a) Dollar open interest weighted ===")
    print(tbl_oiw.round(round_spec).to_string())
    print("\n=== (b) Equal weighted ===")
    print(tbl_ew.round(round_spec).to_string())

    # Option-characteristics tables are built only for the configured columns present in meta.
    try_cols = [c for c in CFG.OPTION_CHAR_COLS if c in meta_eval.columns]
    if len(try_cols)==0:
        print("\n[Info] No option-characteristic columns at meta.")
    else:
        print("\nBuilding table for OPTION characteristics")
        shared_kwargs = dict(
            pred_col='rb_cnn_node', char_cols=try_cols,
            n_ports=CFG.PORTS_N, lags=CFG.LAGS_NW,
            pct_cols=CFG.OPTION_CHAR_PCT, rnd=CFG.HOF_ROUND,
            scalers=CFG.OPTION_CHAR_SCALERS,
        )
        tbl_opt_oiw, pretty_oiw = hofler_option_chars_table(meta_eval, weighting='oiw', **shared_kwargs)
        tbl_opt_ew,  pretty_ew  = hofler_option_chars_table(meta_eval, weighting='ew',  **shared_kwargs)
        print("\n=== (c) Option characteristics — OIW ===")
        print(pretty_oiw.to_string())
        print("\n=== (d) Option characteristics — EW ===")
        print(pretty_ew.to_string())

    #shap analysis
    if CFG.SHAP_ENABLE:
        print("\nRunning SHAP analysis (Deep/GradientExplainer) on ensemble...")
        run_shap_analysis(Xz, meta_eval, models, first_oos, device=dev)

    # LIME is best-effort: a failure here is reported but must not discard the
    # predictions and tables computed above.
    try:
        if getattr(CFG, "LIME_ENABLE", True):
            print("\nRunning LIME analysis on selected OOS examples...")
            run_lime_analysis(Xz, meta_eval, models, first_oos, device=dev)
    except Exception as e:
        print(f"LIME analysis skipped/failed: {e}")

    return meta_eval, tbl_oiw, tbl_ew



# Entry point: run the whole pipeline and bind the returned meta table (with the
# prediction column) and the two decile tables to module-level names.
if __name__ == "__main__":
    meta_eval, tbl_oiw, tbl_ew = main()
