In [1]:
import pandas as pd
import numpy as np
import torch
from pathlib import Path

from torch.optim import AdamW
from torch.optim.lr_scheduler import OneCycleLR
from transformers import EarlyStoppingCallback, Trainer, TrainingArguments, set_seed

from tsfm_public import (
    TimeSeriesForecastingPipeline,
    TimeSeriesPreprocessor,
    TinyTimeMixerForPrediction,
    TrackingCallback,
    count_parameters,
    get_datasets,
)

# Transformers and TSFM library imports
from transformers import Trainer, TrainingArguments
from tsfm_public.toolkit.time_series_preprocessor import TimeSeriesPreprocessor
from tsfm_public import get_datasets # Corrected import
from tsfm_public import TinyTimeMixerForPrediction
from tsfm_public import TinyTimeMixerConfig
from tsfm_public.toolkit.time_series_preprocessor import prepare_data_splits
from tsfm_public.toolkit.visualization import plot_predictions
from tsfm_public.toolkit.service_util import save_deployment_package
# --- Configuration ---
# Every input file is resolved relative to DATA_PATH, so pointing the notebook
# at another data directory only requires changing this one constant.
DATA_PATH = Path("./")  # assumes the CSVs live next to the notebook
TRAIN_FILE_PATH = DATA_PATH / "final_train.csv"
TEST_FILE_PATH = DATA_PATH / "final_test.csv"
METADATA_FILE_PATH = DATA_PATH / "metadata.csv"

timestamp_column = "timestamp"
context_length = 168      # 'input' rows per window (history length)
prediction_length = 24    # 'target' rows per window (forecast horizon)
SEED = 79
# Fraction of each building's train windows to keep: values strictly between
# 0 and 1 subsample; any other value keeps every train window.
subset_fraction = 1

print("Cell 1 executed: Setup and configuration complete.")

Cell 1 executed: Setup and configuration complete.


In [2]:
# --- 1) Load data ---
# Timestamps are parsed on read so later cells can use the `.dt` accessor.
print("Loading data...")
train_full_df = pd.read_csv(TRAIN_FILE_PATH, parse_dates=[timestamp_column])
test_full_df = pd.read_csv(TEST_FILE_PATH,  parse_dates=[timestamp_column])


# Keep building_id as string (e.g., 'H001'); trim any whitespace.
# Both frames are modified in place (pandas nullable "string" dtype).
for df in (train_full_df, test_full_df):
    df["building_id"] = df["building_id"].astype("string").str.strip()


# --- 2) Save row_ids for submission (from TARGET rows in competition test) ---
# NOTE(review): assumes the test file has a 'role' column whenever it has
# 'row_id'; otherwise the .loc below raises KeyError.
print("Saving row_ids for submission...")
if "row_id" in test_full_df.columns:
    submission_ids = test_full_df.loc[test_full_df["role"] == "target", ["row_id"]].copy()
else:
    submission_ids = None
    print("Note: 'row_id' not found in test; skipping submission_ids extraction.")

# Keep everything for now (we need 'building_id' and 'role' for clean splits).
# Later cells read these working copies (train_df1 / test_df1).
train_df1 = train_full_df.copy()
test_df1  = test_full_df.copy()

# --- 3) Build chronological 80/10/10 window splits *per building* from TRAIN only ---
# A window's start time is the earliest timestamp among its INPUT rows.
# NOTE(review): if consecutive windows overlap in time (window stride shorter
# than the window length), the last train windows' target rows can coincide
# with the first valid windows' input rows — confirm the window stride.
required = {"building_id", "window_id", "role", timestamp_column}
assert required.issubset(train_df1.columns), \
    f"Required columns missing for window split: {required - set(train_df1.columns)}"

winfo = (
    train_df1.loc[train_df1["role"] == "input", ["building_id", "window_id", timestamp_column]]
    .groupby(["building_id", "window_id"], as_index=False)
    .agg(window_start=(timestamp_column, "min"))
    .sort_values(["building_id", "window_start"])
)

# The splits are stored as plain sets of window_id, and later steps group rows
# by window_id alone, so one window_id must never appear under two buildings.
assert winfo["window_id"].is_unique, \
    "window_id is shared across buildings; per-building splits would collide."

# Quick visibility
n_buildings = winfo["building_id"].nunique()
n_windows   = len(winfo)
print(f"Found {n_windows} windows across {n_buildings} buildings in train.")

# ---- Integer 80/10/10 per building (chronological) ----
train_window_ids, valid_window_ids, test_window_ids = set(), set(), set()

for bld, g in winfo.groupby("building_id", sort=False):
    # winfo is sorted by (building_id, window_start), so each group is chronological.
    ids = g["window_id"].tolist()
    n = len(ids)

    n_train = int(0.8 * n)   # floor(80%) -> train
    n_valid = int(0.1 * n)   # floor(10%) -> valid (0 for buildings with < 10 windows)
    # Everything after train+valid is test: the three parts always sum to n and
    # test absorbs both rounding remainders (so it gets at least 10%).
    train_window_ids.update(ids[:n_train])
    valid_window_ids.update(ids[n_train:n_train + n_valid])
    test_window_ids.update(ids[n_train + n_valid:])

print(f"[Split per building] train={len(train_window_ids)}, "
      f"valid={len(valid_window_ids)}, test={len(test_window_ids)}")

# -----------------------
# Sanity checks (BEFORE subsampling)
# -----------------------

# 1) Window integrity: every window must have exactly `context_length` 'input'
#    rows and `prediction_length` 'target' rows, i.e. the configured geometry.
_counts = train_df1.groupby(["window_id", "role"]).size().unstack(fill_value=0)
assert "input" in _counts.columns and "target" in _counts.columns, \
    "Expected roles 'input' and 'target' not found."
assert _counts["input"].eq(context_length).all(), \
    f"Some windows do not have {context_length} input rows."
assert _counts["target"].eq(prediction_length).all(), \
    f"Some windows do not have {prediction_length} target rows."

# 2) Splits are disjoint (no overlap)
assert len(train_window_ids & valid_window_ids) == 0, "Train and Valid windows overlap."
assert len(train_window_ids & test_window_ids) == 0,  "Train and Test windows overlap."
assert len(valid_window_ids & test_window_ids) == 0,  "Valid and Test windows overlap."

# 3) Coverage BEFORE subsample: every window in train_df1 is assigned to exactly one split
#    (disjointness above + equal union here).
_all_window_ids = set(winfo["window_id"])
_union_before = train_window_ids | valid_window_ids | test_window_ids
assert _union_before == _all_window_ids, \
    "Split does not cover all windows before subsampling."

# 4) Chronology per building: every train window starts no later than every
#    valid window, every valid window starts no later than every test window,
#    and every train window starts no later than every test window (this last
#    check covers buildings that received no valid windows).
_wstarts = (
    train_df1.loc[train_df1["role"] == "input", ["building_id", "window_id", timestamp_column]]
    .groupby(["building_id", "window_id"], as_index=False)
    .agg(window_start=(timestamp_column, "min"))
)

def _per_bld_stat(ids, how):
    """Per-building `how` ("min" or "max") of window_start over the windows in `ids`.

    Returns an empty datetime Series when none of `ids` appear in `_wstarts`.
    """
    s = _wstarts[_wstarts.window_id.isin(ids)].groupby("building_id").window_start
    return getattr(s, how)() if len(s) else pd.Series(dtype="datetime64[ns]")

def _per_bld_min(ids):
    """Earliest window_start per building among `ids`."""
    return _per_bld_stat(ids, "min")

def _per_bld_max(ids):
    """Latest window_start per building among `ids`."""
    return _per_bld_stat(ids, "max")

def _assert_not_before(earlier, later, msg):
    """Assert earlier <= later for every building present in both Series (no-op if either is empty)."""
    if earlier.empty or later.empty:
        return
    idx = earlier.index.intersection(later.index)
    assert (earlier.loc[idx] <= later.loc[idx]).all(), msg

_max_train = _per_bld_max(train_window_ids)
_min_valid = _per_bld_min(valid_window_ids)
_max_valid = _per_bld_max(valid_window_ids)
_min_test  = _per_bld_min(test_window_ids)

_assert_not_before(_max_train, _min_valid,
                   "For some buildings, valid windows start before the last train window.")
_assert_not_before(_max_valid, _min_test,
                   "For some buildings, test windows start before the last valid window.")
_assert_not_before(_max_train, _min_test,
                   "For some buildings, test windows start before the last train window.")

# Keep originals for post-subsample checks. The subsample below may rebind
# train_window_ids; the valid/test sets are not changed by it.
orig_train_window_ids = train_window_ids.copy()
orig_valid_window_ids = valid_window_ids.copy()
orig_test_window_ids  = test_window_ids.copy()

# --- 6) (Optional) Subsample training windows to speed up experiments ---
# Runs only when 0 < subset_fraction < 1; otherwise every train window is kept.
if 0 < subset_fraction < 1.0:
    # Keep the earliest fraction within each building's *train* windows
    # (always at least one window per building).
    # NOTE(review): keeping the *earliest* windows maximizes the time gap to the
    # validation period; keeping the latest may be preferable — confirm intent.
    keep_train = []
    for bld, g in winfo[winfo["window_id"].isin(orig_train_window_ids)].groupby("building_id", sort=False):
        ids = g["window_id"].tolist()  # chronological (winfo is sorted by window_start)
        k = max(1, int(len(ids) * subset_fraction))
        keep_train.extend(ids[:k])
    train_window_ids = set(keep_train)
    print(f"[Subsample] subset_fraction={subset_fraction} -> train={len(train_window_ids)}")

    # -----------------------
    # Sanity checks (AFTER subsampling)
    # -----------------------

    # 1) New train must be a subset of original train; still disjoint from valid/test
    assert train_window_ids.issubset(orig_train_window_ids), \
        "Subsampled train includes windows not in original train set."
    assert len(train_window_ids & orig_valid_window_ids) == 0, \
        "Train and Valid overlap after subsample."
    assert len(train_window_ids & orig_test_window_ids) == 0, \
        "Train and Test overlap after subsample."

    # 2) Coverage is now a subset of all windows (because we dropped some train windows on purpose)
    covered_after = train_window_ids | orig_valid_window_ids | orig_test_window_ids
    dropped = _all_window_ids - covered_after
    print(f"[Info] Dropped {len(dropped)} original train windows due to subset_fraction.")

print("\nCell 2 executed: Data loaded/merged and chronological splits created per building.")
# Print the schema of the full training frame (dtypes, non-null counts).
train_df1.info()


Loading data...
Saving row_ids for submission...
Found 13453 windows across 91 buildings in train.
[Split per building] train=10720, valid=1310, test=1423

Cell 2 executed: Data loaded/merged and chronological splits created per building.
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 2582976 entries, 0 to 2582975
Data columns (total 51 columns):
 #   Column                      Dtype         
---  ------                      -----         
 0   building_id                 string        
 1   window_id                   int64         
 2   timestamp                   datetime64[ns]
 3   meter_reading               float64       
 4   role                        object        
 5   region                      object        
 6   rooms                       float64       
 7   no_of_people                int64         
 8   area_in_sqft                float64       
 9   inverter                    float64       
 10  lights                      int64         
 11  ceiling_fans       

In [3]:
# --- Cell 3: Features → Splits → Preprocessor → Datasets ---
print("Creating time-based features...")

def create_time_features(df, timestamp_col):
    """Return a copy of `df` with calendar features derived from `timestamp_col`.

    Adds (or overwrites) ``hour``, ``dayofweek`` (Monday=0) and ``month``, plus
    a cyclic encoding of the hour (``hour_sin``/``hour_cos`` with a 24-hour
    period) so that 23:00 and 00:00 are close in feature space.
    The input frame itself is left unmodified.
    """
    stamps = df[timestamp_col].dt
    # assign() works on a copy and evaluates the callables in order, so the
    # sin/cos lambdas see the freshly added "hour" column.
    return df.assign(
        hour=stamps.hour,
        dayofweek=stamps.dayofweek,
        month=stamps.month,
        hour_sin=lambda frame: np.sin(2*np.pi*frame["hour"]/24),
        hour_cos=lambda frame: np.cos(2*np.pi*frame["hour"]/24),
    )

# Add calendar features to both raw frames (rebinds train_df1 / test_df1).
train_df1 = create_time_features(train_df1, timestamp_column)
test_df1  = create_time_features(test_df1,  timestamp_column)

# --- 3.1 Split rows by the window sets from Cell 2 ---
# Train/valid rows come from the TRAIN file via the window-id sets.
# NOTE(review): test_window_ids (the internal held-out split) is not used here;
# test_split below is the competition test file.
train_rows = train_df1["window_id"].isin(train_window_ids)
valid_rows = train_df1["window_id"].isin(valid_window_ids)
test_rows_comp = test_df1["window_id"].notna() if "window_id" in test_df1.columns else pd.Series(True, index=test_df1.index)

train_split = train_df1.loc[train_rows].copy()
valid_split = train_df1.loc[valid_rows].copy()
test_split  = test_df1.loc[test_rows_comp].copy()   # competition/test file

# --- 3.2 One-hot 'region' using TRAIN categories only (prevents leakage) ---
# region_cols fixes the dummy-column layout; valid/test are reindexed onto it.
print("Encoding categorical features (region) with train-derived columns...")
train_region_dum = pd.get_dummies(train_split["region"], prefix="region")
region_cols = sorted(train_region_dum.columns)

def _apply_region_dummies(df, region_cols):
    Xr = pd.get_dummies(df["region"], prefix="region")
    return Xr.reindex(columns=region_cols, fill_value=0)

# Region dummies for each split, all aligned to the train-derived region_cols.
train_region = _apply_region_dummies(train_split, region_cols)
valid_region = _apply_region_dummies(valid_split, region_cols)
test_region  = _apply_region_dummies(test_split,  region_cols)

# --- 3.3 Assemble frames with numeric features + region dummies (exclude identifiers) ---
# Every train column not in `exclude` becomes a control feature.
# NOTE(review): that includes any unexpected non-numeric column from the CSV —
# confirm the file has no other text columns.
exclude = {"building_id","window_id","row_id","role","timestamp","meter_reading","region"}

feat_cols_no_region = [c for c in train_split.columns if c not in exclude]
feat_cols_no_region = sorted(feat_cols_no_region)   # keep consistent order

# Train
train_ready = pd.concat(
    [train_split[["window_id","timestamp","meter_reading","role"] + feat_cols_no_region], train_region],
    axis=1
)
# Valid
valid_ready = pd.concat(
    [valid_split[["window_id","timestamp","meter_reading","role"] + feat_cols_no_region], valid_region],
    axis=1
)
# Test (competition): columns missing from the test file are silently skipped,
# so test_ready may lack some feature columns that train/valid have.
base_cols_test = ["window_id","timestamp","meter_reading"] + feat_cols_no_region
base_cols_test = [c for c in base_cols_test if c in test_split.columns]
test_ready  = pd.concat([test_split[base_cols_test], test_region], axis=1)

# --- 3.4 Optional numeric imputation using TRAIN+INPUT medians (no leakage) ---
from pandas.api.types import is_numeric_dtype

control_columns = feat_cols_no_region + region_cols
numeric_controls = [c for c in control_columns if c in train_ready.columns and is_numeric_dtype(train_ready[c])]

# Medians come from train 'input' rows only. Without a 'role' column the mask
# is all-False, the medians are all NaN, and the later fill is a no-op.
fit_mask = (train_ready["role"] == "input") if "role" in train_ready.columns else pd.Series(False, index=train_ready.index)
medians = train_ready.loc[fit_mask, numeric_controls].median(numeric_only=True)

def _fill_missing(df, cols, med):
    df = df.copy()
    if len(med):
        cols_present = [c for c in cols if c in df.columns]
        df[cols_present] = df[cols_present].fillna(med.reindex(cols_present))
    return df

# Impute every split with the SAME train-input medians (no valid/test statistics used).
train_ready = _fill_missing(train_ready, numeric_controls, medians)
valid_ready = _fill_missing(valid_ready, numeric_controls, medians)
test_ready  = _fill_missing(test_ready,  numeric_controls, medians)

# --- 3.4b Enforce numeric dtypes (prevents np.object_ later) ---
def _to_float32(df, cols):
    cols_present = [c for c in cols if c in df.columns]
    if cols_present:
        df[cols_present] = df[cols_present].apply(pd.to_numeric, errors="coerce").astype(np.float32)
    return df

# Target column to float32, in place on each frame (unparseable values -> NaN).
for _df in (train_ready, valid_ready, test_ready):
    _df["meter_reading"] = pd.to_numeric(_df["meter_reading"], errors="coerce").astype(np.float32)

# All control columns (features + region dummies) to float32.
train_ready = _to_float32(train_ready, control_columns)
valid_ready = _to_float32(valid_ready, control_columns)
test_ready  = _to_float32(test_ready,  control_columns)

# --- 3.5 Fit preprocessor on INPUT rows only (no scaling here) ---
# NOTE(review): the fit frame includes valid and competition-test input rows.
# With scaling=False this presumably learns little, but if scaling is turned
# back on it would leak valid/test statistics into the preprocessor.
print("Fitting preprocessor on INPUT rows of train+valid+test windows...")

# Same construction as test_ready, but keeping 'role' so input rows can be
# selected (assumes the competition test file carries a 'role' column).
base_cols_test_with_role = ["window_id","timestamp","meter_reading","role"] + feat_cols_no_region
base_cols_test_with_role = [c for c in base_cols_test_with_role if c in test_split.columns]
test_ready_with_role = pd.concat([test_split[base_cols_test_with_role], test_region], axis=1)
test_ready_with_role = _fill_missing(test_ready_with_role, numeric_controls, medians)
test_ready_with_role["meter_reading"] = pd.to_numeric(test_ready_with_role["meter_reading"], errors="coerce").astype(np.float32)
test_ready_with_role = _to_float32(test_ready_with_role, control_columns)

# Each window_id is treated as its own series.
# NOTE(review): newer pandas deprecates the "H" alias in favor of "h" — confirm
# what the installed tsfm version expects.
tsp = TimeSeriesPreprocessor(
    id_columns=["window_id"],
    timestamp_column="timestamp",
    target_columns=["meter_reading"],
    control_columns=control_columns,
    context_length=context_length,
    prediction_length=prediction_length,
    scaling=False,                 # global scaler disabled; the model scales each window itself
    freq="H",
)

fit_df = pd.concat(
    [
        train_ready.loc[train_ready["role"] == "input"],
        valid_ready.loc[valid_ready["role"] == "input"],
        test_ready_with_role.loc[test_ready_with_role["role"] == "input"],
    ],
    ignore_index=True
).drop(columns=["role"])

tsp.train(fit_df)

# --- 3.6 Drop 'role' and build datasets ---
# Full windows (input + target rows) are kept; only the label column is removed.
train_ready_norole = train_ready.drop(columns=["role"])
valid_ready_norole = valid_ready.drop(columns=["role"])

print("Building datasets via ForecastDFDataset...")
from tsfm_public import ForecastDFDataset

# Shared dataset geometry. Each window holds exactly context_length +
# prediction_length rows (checked in Cell 2), so each window_id is expected to
# yield a single sample.
dataset_params = dict(
    timestamp_column="timestamp",
    id_columns=["window_id"],
    target_columns=["meter_reading"],
    control_columns=control_columns,
    context_length=context_length,
    prediction_length=prediction_length,
)

train_proc = tsp.preprocess(train_ready_norole)
valid_proc = tsp.preprocess(valid_ready_norole)
test_proc  = tsp.preprocess(test_ready)

# sanity: ensure no object dtypes slipped through
def _assert_no_object(df, name):
    bad = [c for c in df.columns if df[c].dtype == "object"]
    if bad:
        raise TypeError(f"{name} has object dtypes: {bad[:10]}")

# Fail fast on object-like columns before they reach tensor conversion.
_assert_no_object(train_proc, "train_proc")
_assert_no_object(valid_proc, "valid_proc")
_assert_no_object(test_proc,  "test_proc")

# test_dataset is built from the competition test file, not from the
# internal test_window_ids split of Cell 2.
train_dataset = ForecastDFDataset(train_proc, **dataset_params)
valid_dataset = ForecastDFDataset(valid_proc, **dataset_params)
test_dataset  = ForecastDFDataset(test_proc,  **dataset_params)

print("Datasets:",
      "train =", len(train_dataset),
      "| valid =", len(valid_dataset),
      "| test =", len(test_dataset))
print("Context/Pred:", tsp.context_length, tsp.prediction_length)

# One sample per window: dataset sizes must equal the number of split windows.
assert len(train_dataset) == len(train_window_ids)
assert len(valid_dataset) == len(valid_window_ids)

print("\nCell 3 done: global scaling OFF; model will do per-window instance norm.")


Creating time-based features...
Encoding categorical features (region) with train-derived columns...
Fitting preprocessor on INPUT rows of train+valid+test windows...
Building datasets via ForecastDFDataset...
Datasets: train = 10720 | valid = 1310 | test = 3525
Context/Pred: 168 24

Cell 3 done: global scaling OFF; model will do per-window instance norm.


In [4]:
# --- Cell 4 (additions marked) ---
# Loads the pretrained TTM checkpoint, adapts its channel count to the data and
# keeps the checkpoint's own context/patch lengths (inputs get padded later).
from tsfm_public import TinyTimeMixerForPrediction, TinyTimeMixerConfig
import torch, torch.nn as nn

print("Loading TTM (keep checkpoint CL/patch; adapt head to FL=H)…")

# NOTE(review): no revision is pinned, so the default branch of this repo is used.
CKPT = "ibm-granite/granite-timeseries-ttm-r2"

# Derive the data geometry from one training sample:
# L = history length, H = horizon, C = number of channels
# (presumably target + control columns — confirm against dataset_params).
s0 = train_dataset[0]
L  = int(s0["past_values"].shape[0])
H  = int(s0["future_values"].shape[0])
C  = int(s0["past_values"].shape[1])
print(f"[data] L={L}, H={H}, C={C}")

cfg = TinyTimeMixerConfig.from_pretrained(CKPT)
cfg.head_dropout = 0.2  # add some head dropout
# (optional, if available) keep model-side instance norm ON
if hasattr(cfg, "scaling") and (cfg.scaling is None or cfg.scaling == "none"):
    cfg.scaling = "std"   # ensure built-in per-window std-scaler

# Checkpoint geometry, read from the config (fallbacks only if an attribute is missing).
ckpt_sl        = int(getattr(cfg, "context_length", 512))
ckpt_fl        = int(getattr(cfg, "prediction_length", 96))
ckpt_patch_len = int(getattr(cfg, "patch_length", 64))
print(f"[ckpt] sl={ckpt_sl}, fl={ckpt_fl}, patch_len={ckpt_patch_len}")

# Point every channel-count field that exists on this config at C.
for k in ("num_input_channels", "nvars", "num_channels", "decoder_num_channels", "decoder_nvars"):
    if hasattr(cfg, k):
        setattr(cfg, k, C)

# ignore_mismatched_sizes lets layers whose shapes changed with C be freshly
# initialized instead of failing to load.
# NOTE(review): set_seed is only called in Cell 5, so that initialization is unseeded.
forecast_model = TinyTimeMixerForPrediction.from_pretrained(
    CKPT,
    config=cfg,
    ignore_mismatched_sizes=True
)

# prune head to H (unchanged) …
def _prune_linear_out_rows(linear: nn.Linear, keep_rows: int):
    new = nn.Linear(linear.in_features, keep_rows, bias=(linear.bias is not None))
    with torch.no_grad():
        new.weight.copy_(linear.weight[:keep_rows, :])
        if linear.bias is not None:
            new.bias.copy_(linear.bias[:keep_rows])
    return new

# Shrink the pretrained forecast head to exactly H outputs by keeping its first
# H output rows. A head with fewer than H outputs cannot be adapted this way, so
# fail fast instead of training a model whose output length mismatches the targets.
if hasattr(forecast_model.head, "base_forecast_block") and isinstance(forecast_model.head.base_forecast_block, nn.Linear):
    lin = forecast_model.head.base_forecast_block
    if lin.out_features < H:
        raise ValueError(
            f"Checkpoint head has {lin.out_features} outputs; cannot prune to H={H}."
        )
    if lin.out_features > H:
        forecast_model.head.base_forecast_block = _prune_linear_out_rows(lin, H)
        print(f"[FLA] pruned head out_features: {lin.out_features} -> {H}")

# gating safety (unchanged) …
# Replace any gating attention Linear whose shape is not (C, C) with a fresh
# C x C layer.
# NOTE(review): the name match is a suffix check, so it is not limited to the
# decoder even though the log line says "decoder".
fixed = 0
for name, mod in forecast_model.named_modules():
    if name.endswith("channel_feature_mixer.gating_block.attn_layer") and isinstance(mod, nn.Linear):
        if mod.in_features != C or mod.out_features != C:
            # Walk the dotted path down to the module that owns `attn_layer`.
            parent = forecast_model
            for part in name.rsplit(".", 1)[0].split("."):
                parent = getattr(parent, part)
            setattr(parent, "attn_layer", nn.Linear(C, C, bias=True))
            fixed += 1
print(f"[fix] decoder gating attn layers reset: {fixed}")

forecast_model.config.prediction_length = H

# freeze everything, then unfreeze decoder + head (all other params, e.g. the
# backbone/encoder, stay frozen)
for p in forecast_model.parameters(): p.requires_grad = False
for n, p in forecast_model.named_parameters():
    if n.startswith("decoder.") or n.startswith("head."):
        p.requires_grad = True

# >>> Sanity: print active scaler in the backbone (best-effort; silently skipped
# if the attribute path does not exist in this model version)
try:
    print("Backbone scaler:", type(forecast_model.backbone.scaler).__name__)
except Exception:
    pass

# smoke test (unchanged) …
# One forward pass on the first sample, left-padded to the checkpoint context
# length with the per-channel mean and a 0 (unobserved) mask on the padding.
# The model is still in train mode here, so dropout is active.
with torch.no_grad():
    pv  = s0["past_values"].float()
    pom = s0["past_observed_mask"].float()
    pad = ckpt_sl - pv.shape[0]
    assert pad >= 0
    filler = pv.mean(dim=0)
    pv  = torch.cat([filler.expand(pad, C), pv], dim=0).unsqueeze(0)
    pom = torch.cat([torch.zeros(pad, C),  pom], dim=0).unsqueeze(0)
    fv  = s0["future_values"].unsqueeze(0).float()
    fom = s0["future_observed_mask"].unsqueeze(0).float()
    _   = forecast_model(past_values=pv, future_values=fv,
                         past_observed_mask=pom, future_observed_mask=fom)

print("Cell 4 OK.")


Loading TTM (keep checkpoint CL/patch; adapt head to FL=H)…
[data] L=168, H=24, C=53
[ckpt] sl=512, fl=96, patch_len=64
[FLA] pruned head out_features: 96 -> 24
[fix] decoder gating attn layers reset: 0
Backbone scaler: TinyTimeMixerStdScaler
Cell 4 OK.


In [5]:
# --- Cell 5 (ES + Best Model on eval_loss): pad-to-ckpt-CL collator + RPT; 2-phase FT (head → head+decoder)
import math, torch, inspect
from transformers import Trainer, TrainingArguments, EarlyStoppingCallback, get_cosine_schedule_with_warmup

device = "cuda" if torch.cuda.is_available() else "cpu"
forecast_model.to(device)
# NOTE(review): seeding happens after the model (and any re-initialized layers)
# was built in Cell 4, so those initial weights are not covered by SEED.
set_seed(SEED)

# 1) Force MSE inside TTM (so Trainer's eval_loss == MSE)
if hasattr(forecast_model, "loss"):
    forecast_model.loss = "mse"
if hasattr(forecast_model.config, "loss"):
    forecast_model.config.loss = "mse"
print("Using loss =", getattr(forecast_model, "loss", None))

# Checkpoint context length that every batch is padded/cropped to.
CKPT_SL = int(getattr(forecast_model.config, "context_length", 512))
# Resolution token used when a sample carries no freq_token of its own. The
# datasets in Cell 3 were built without one, so presumably every sample uses this.
# NOTE(review): verify this id against the frequency-token mapping of the
# installed tsfm version; it is not confirmed here that 1 means hourly.
FREQ_ID = 1  # set to your dataset’s resolution id (e.g., hourly=1)

# 2) Collator: left-pad 168→CKPT_SL, mask pad, add freq_token; only forward-allowed keys
_forward_allowed = set(inspect.signature(forecast_model.forward).parameters.keys())

def pad_to_ckpt_collate(batch):
    """Collate dataset samples into one model-ready batch at the checkpoint context length.

    Each sample's history is left-padded to CKPT_SL (or cropped to its last
    CKPT_SL steps when longer). Padded steps hold the sample's per-channel mean
    and are marked unobserved (mask 0). Future values/masks are stacked as-is
    when present. Each sample gets a freq_token (FREQ_ID when it has none), and
    the return_loss/return_dict flags are set when forward accepts them. Only
    keys accepted by forecast_model.forward are returned.
    """
    n_channels = batch[0]["past_values"].shape[-1]

    def _left_pad(sample):
        # -> (values, mask), each exactly CKPT_SL steps long along dim 0.
        values = sample["past_values"].float()
        mask = sample["past_observed_mask"].float()
        missing = CKPT_SL - values.shape[0]
        if missing < 0:
            values, mask = values[-CKPT_SL:], mask[-CKPT_SL:]
            missing = 0
        filler = values.mean(dim=0).expand(missing, n_channels)
        return (
            torch.cat([filler, values], dim=0),
            torch.cat([torch.zeros(missing, n_channels), mask], dim=0),
        )

    def _freq_token(sample):
        # Resolution Prefix Token (RPT) for one sample.
        token = sample.get("freq_token", None)
        if token is None:
            return torch.tensor(FREQ_ID, dtype=torch.long)
        if torch.is_tensor(token):
            return token
        return torch.tensor(int(token), dtype=torch.long)

    padded = [_left_pad(s) for s in batch]
    out = {
        "past_values": torch.stack([vals for vals, _ in padded], 0),
        "past_observed_mask": torch.stack([msk for _, msk in padded], 0),
    }
    first = batch[0]
    for key in ("future_values", "future_observed_mask"):
        if key in first:
            out[key] = torch.stack([s[key].float() for s in batch], 0)

    out["freq_token"] = torch.stack([_freq_token(s) for s in batch], 0)

    # Make model compute & return loss during train/eval
    for flag in ("return_loss", "return_dict"):
        if flag in _forward_allowed:
            out[flag] = True

    return {key: val for key, val in out.items() if key in _forward_allowed}

# 3) Build Trainer per phase with AdamW + cosine schedule + early stopping on eval_loss
def make_trainer(model, train_ds, eval_ds, lr_head, lr_dec, epochs, output_dir, bs=256, ga=2, wd=1e-2,
                 num_workers=8, patience=5):
    """Build a Trainer that fine-tunes the currently trainable ``head.*`` / ``decoder.*`` params.

    Parameters with ``requires_grad=True`` are split into up to four AdamW
    groups: head/decoder crossed with decay/no-decay. Biases and norm weights
    get no weight decay. Each group uses its own learning rate. A cosine
    schedule with 10% warmup spans all ``epochs``. Evaluation, logging and
    checkpointing run every epoch (at most 2 checkpoints are kept). The
    checkpoint with the lowest ``eval_loss`` is reloaded at the end, and
    training stops after ``patience`` evaluations without improvement.

    Args:
        model: model to train; its ``requires_grad`` flags select what is optimized.
        train_ds, eval_ds: datasets, batched through ``pad_to_ckpt_collate``.
        lr_head, lr_dec: learning rates for head and decoder parameters.
        epochs: number of training epochs (also the schedule length).
        output_dir: checkpoint directory (overwritten).
        bs: per-device batch size for training and evaluation.
        ga: gradient accumulation steps.
        wd: weight decay for the "decay" groups.
        num_workers: DataLoader worker processes.
        patience: early-stopping patience, counted in evaluations (= epochs here).

    Returns:
        A configured ``transformers.Trainer``.

    Raises:
        ValueError: if no ``head.``/``decoder.`` parameter is trainable.
    """
    import inspect  # local, so the block does not rely on another cell's import

    head_named = [(n, p) for n, p in model.named_parameters() if p.requires_grad and n.startswith("head.")]
    dec_named  = [(n, p) for n, p in model.named_parameters() if p.requires_grad and n.startswith("decoder.")]
    if not head_named and not dec_named:
        # AdamW would otherwise fail with a less specific "empty parameter list" error.
        raise ValueError("make_trainer: no trainable 'head.' or 'decoder.' parameters to optimize.")

    def _split_decay(named_params):
        # Biases and normalization weights are excluded from weight decay.
        decay, no_decay = [], []
        for n, p in named_params:
            if n.endswith(".bias") or "norm" in n.lower():
                no_decay.append(p)
            else:
                decay.append(p)
        return decay, no_decay

    head_decay, head_no_decay = _split_decay(head_named)
    dec_decay,  dec_no_decay  = _split_decay(dec_named)

    groups = []
    if head_decay:     groups.append({"params": head_decay,     "lr": lr_head, "weight_decay": wd})
    if head_no_decay:  groups.append({"params": head_no_decay,  "lr": lr_head, "weight_decay": 0.0})
    if dec_decay:      groups.append({"params": dec_decay,      "lr": lr_dec,  "weight_decay": wd})
    if dec_no_decay:   groups.append({"params": dec_no_decay,   "lr": lr_dec,  "weight_decay": 0.0})

    # The fused AdamW kernel only exists in newer torch releases, and in many of
    # them it requires every parameter to live on CUDA. Request it only when both
    # hold; otherwise fall back to the default implementation.
    all_params = [p for g in groups for p in g["params"]]
    use_fused = (
        torch.cuda.is_available()
        and "fused" in inspect.signature(torch.optim.AdamW).parameters
        and all(p.is_cuda for p in all_params)
    )
    optimizer = torch.optim.AdamW(groups, **({"fused": True} if use_fused else {}))

    # Optimizer steps per epoch: one step per `ga` batches of `bs` samples.
    steps_per_epoch = max(1, math.ceil(len(train_ds) / (bs * ga)))
    t_total = steps_per_epoch * epochs
    scheduler = get_cosine_schedule_with_warmup(
        optimizer, num_warmup_steps=int(0.1 * t_total), num_training_steps=t_total
    )

    # bf16 on Ampere+ GPUs, fp16 on older GPUs, full precision on CPU.
    sm_major = torch.cuda.get_device_capability()[0] if torch.cuda.is_available() else 0
    use_bf16 = torch.cuda.is_available() and sm_major >= 8

    args = TrainingArguments(
        output_dir=output_dir,
        overwrite_output_dir=True,
        learning_rate=max(lr_head, lr_dec, 1e-8),  # real LRs set per param group
        num_train_epochs=epochs,
        per_device_train_batch_size=bs,
        per_device_eval_batch_size=bs,
        gradient_accumulation_steps=ga,

        # eval/log/save each epoch
        eval_strategy="epoch",
        logging_strategy="epoch",
        save_strategy="epoch",
        save_total_limit=2,

        # keep the best checkpoint by eval_loss (MSE)
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        greater_is_better=False,

        # keep our keys; collator already filters to forward-allowed
        remove_unused_columns=False,

        dataloader_num_workers=num_workers,
        bf16=use_bf16,
        fp16=(torch.cuda.is_available() and not use_bf16),
        max_grad_norm=1.0,
        report_to=[],
        seed=SEED,
    )

    es = EarlyStoppingCallback(early_stopping_patience=patience, early_stopping_threshold=0.0)

    print(
        f"[make_trainer] head params: {sum(p.numel() for _,p in head_named):,} | "
        f"decoder params: {sum(p.numel() for _,p in dec_named):,} | "
        f"LRs: head={lr_head}, dec={lr_dec} | epochs={epochs}, batch={bs}x{ga}"
    )

    return Trainer(
        model=model,
        args=args,
        train_dataset=train_ds,
        eval_dataset=eval_ds,
        data_collator=pad_to_ckpt_collate,
        optimizers=(optimizer, scheduler),
        callbacks=[es],
    )

# 4) Phase settings (per-device batch size, gradient accumulation, epoch budgets;
#    early stopping may end either phase sooner)
bs, ga = 32, 2
epochs_phase1 = 8
epochs_phase2 = 28

# ----------------- Phase 1: HEAD ONLY -----------------
# Decoder frozen, head trainable (the backbone was already frozen in Cell 4).
for n, p in forecast_model.named_parameters():
    if n.startswith("decoder."): p.requires_grad = False
    if n.startswith("head."):    p.requires_grad = True

print("Phase 1: head-only fine-tuning…")
trainer = make_trainer(
    forecast_model, train_dataset, valid_dataset,
    lr_head=1e-3, lr_dec=0.0, epochs=epochs_phase1,
    output_dir="./ttm_ft_phase1_head", bs=bs, ga=ga
)
trainer.train()
print(f"[Phase 1] best_checkpoint={trainer.state.best_model_checkpoint} | best_eval_loss={trainer.state.best_metric}")

# ----------------- Phase 2: HEAD + DECODER -----------------
# Continues from the weights Phase 1 left in forecast_model (its best checkpoint,
# since make_trainer sets load_best_model_at_end=True); the decoder is unfrozen.
for n, p in forecast_model.named_parameters():
    if n.startswith("decoder."): p.requires_grad = True

print("Phase 2: head+decoder fine-tuning…")
trainer = make_trainer(
    forecast_model, train_dataset, valid_dataset,
    lr_head=8e-4, lr_dec=2e-4, epochs=epochs_phase2,
    output_dir="./ttm_ft_phase2_dec", bs=bs, ga=ga
)
trainer.train()
print(f"[Phase 2] best_checkpoint={trainer.state.best_model_checkpoint} | best_eval_loss={trainer.state.best_metric}")

print("✅ Training complete (ckpt CL preserved via padding, FL=24, RPT on, EarlyStopping + best model by eval_loss).")


Using loss = mse
Phase 1: head-only fine-tuning…
[make_trainer] head params: 24,600 | decoder params: 0 | LRs: head=0.001, dec=0.0 | epochs=8, batch=32x2


Epoch,Training Loss,Validation Loss
1,0.2562,0.1083
2,0.1043,0.085419
3,0.0908,0.079416
4,0.0848,0.076066
5,0.0825,0.074934
6,0.0812,0.074268
7,0.0808,0.074029
8,0.0805,0.074072


[Phase 1] best_checkpoint=./ttm_ft_phase1_head/checkpoint-1176 | best_eval_loss=0.07402917742729187
Phase 2: head+decoder fine-tuning…
[make_trainer] head params: 24,600 | decoder params: 191,296 | LRs: head=0.0008, dec=0.0002 | epochs=28, batch=32x2


Epoch,Training Loss,Validation Loss
1,0.0737,0.069196
2,0.0634,0.066404
3,0.0591,0.065382
4,0.0565,0.06384
5,0.0549,0.063151
6,0.054,0.062671
7,0.0534,0.062039
8,0.053,0.06187
9,0.0526,0.061852
10,0.0524,0.062039


[Phase 2] best_checkpoint=./ttm_ft_phase2_dec/checkpoint-2856 | best_eval_loss=0.06106005236506462
✅ Training complete (ckpt CL preserved via padding, FL=24, RPT on, EarlyStopping + best model by eval_loss).


In [6]:
# Checkpoint with the lowest eval_loss from the most recent trainer (Phase 2).
print("Best ckpt:", trainer.state.best_model_checkpoint)

Best ckpt: ./ttm_ft_phase2_dec/checkpoint-2856


In [7]:
# --- Cell A: write validation preds & GT with per-window de-normalization (paper-style) ---
import torch, numpy as np, pandas as pd, os

# Inference mode (disables dropout) on whatever device the model lives on.
forecast_model.eval()
device = next(forecast_model.parameters()).device

# SL: checkpoint context length the model expects.
# H: forecast horizon (Cell 4 set the config value to the data's horizon).
# CONTEXT_LEN: data history length from the config cell.
SL = int(getattr(forecast_model.config, "context_length", 512))
H  = int(getattr(forecast_model.config, "prediction_length", 24))
CONTEXT_LEN = int(globals().get("context_length", 168))
# NOTE(review): falls back to channel 0 if `tsp` has no `prediction_channel_indices`;
# confirm that the target is channel 0 of past_values.
tgt_idx = getattr(tsp, "prediction_channel_indices", [0])[0]  # target channel

# 0) ordered window_ids for valid set (1 sample == 1 window)
# NOTE(review): this list follows the row order of valid_proc, while the dataset's
# sample order comes from how ForecastDFDataset groups by window_id (presumably
# sorted ids). The assert only compares lengths; if the orders differ, predictions
# get attached to the wrong window ids. Confirm, or sort both consistently.
val_window_ids = valid_proc["window_id"].astype(int).drop_duplicates().tolist()
assert len(val_window_ids) == len(valid_dataset), "valid ids ↔ dataset length mismatch"

def _pad_or_crop_to_SL(pv, pom, SL):
    T, C = pv.shape
    if T == SL: return pv, pom
    if T > SL:  return pv[-SL:, :], pom[-SL:, :]
    pad = SL - T
    filler = pv.mean(dim=0) if T > 0 else torch.zeros(C, dtype=pv.dtype, device=pv.device)
    pv2  = torch.cat([filler.expand(pad, C), pv], dim=0)
    pom2 = torch.cat([torch.zeros(pad, C, dtype=pom.dtype, device=pom.device), pom], dim=0)
    return pv2, pom2

def _extract_forecast(outputs):
    # robust across TTM variants
    if isinstance(outputs, dict):
        for k in ("forecast","forecasts","prediction_outputs","predictions","mean","loc"):
            if k in outputs: 
                v = outputs[k]
                return v[0] if isinstance(v,(list,tuple)) and torch.is_tensor(v[0]) else v
    for k in ("forecast","forecasts","prediction_outputs","predictions","mean","loc"):
        if hasattr(outputs, k):
            v = getattr(outputs, k)
            return v[0] if isinstance(v,(list,tuple)) and torch.is_tensor(v[0]) else v
    if isinstance(outputs,(list,tuple)) and torch.is_tensor(outputs[0]): 
        return outputs[0]
    raise RuntimeError("forecast tensor not found in model output")

ids, yhat_all, ytrue_all = [], [], []
bs, N = 64, len(valid_dataset)

with torch.no_grad():
    for start in range(0, N, bs):
        end   = min(N, start + bs)
        batch = [valid_dataset[i] for i in range(start, end)]
        wids  = val_window_ids[start:end]

        pv_list, pom_list, fv_list, mean_list, std_list = [], [], [], [], []
        for s in batch:
            pv  = torch.as_tensor(s["past_values"],        dtype=torch.float32, device=device)  # [Lp,C]
            pom = torch.as_tensor(s["past_observed_mask"], dtype=torch.float32, device=device)  # [Lp,C]
            fv  = torch.as_tensor(s["future_values"],      dtype=torch.float32, device=device)  # [Hf,C]

            # The head emits exactly H steps and ids are numbered 1..H, so GT is cut to the same H steps;
            # a longer future window would otherwise misalign y_true with y_hat in the flattened files.
            if fv.shape[0] < H:
                raise ValueError(f"future_values has {fv.shape[0]} steps; need at least H={H}")
            fv = fv[:H]
            # NOTE(review): future_observed_mask is not applied, so GT may include unobserved/filled steps.

            # instance stats from the first CONTEXT_LEN past steps, over OBSERVED entries only
            Lp, C = pv.shape
            ctx   = min(CONTEXT_LEN, Lp)
            pom_c = pom[:ctx, :]
            # zero unobserved entries first: masking by multiplication alone keeps NaN (NaN * 0 = NaN)
            pv_c  = torch.where(pom_c > 0, pv[:ctx, :], torch.zeros_like(pv[:ctx, :]))

            denom = torch.clamp(pom_c.sum(dim=0, keepdim=True), min=1.0)       # (1,C)
            mean  = (pv_c * pom_c).sum(dim=0, keepdim=True) / denom            # (1,C)
            var   = ((pv_c - mean)**2 * pom_c).sum(dim=0, keepdim=True) / denom
            std   = torch.sqrt(torch.clamp(var, min=1e-6))                     # (1,C)

            # normalize past (paper), then left-pad/crop to the checkpoint context length
            pv_n, pom = _pad_or_crop_to_SL((pv - mean) / std, pom, SL)

            pv_list.append(pv_n)
            pom_list.append(pom)
            fv_list.append(fv)
            mean_list.append(mean)
            std_list.append(std)

        model_inputs = {
            "past_values":        torch.stack(pv_list,  0),   # [B,SL,C]
            "past_observed_mask": torch.stack(pom_list, 0),   # [B,SL,C]
            "return_dict": True, "return_loss": False
        }
        outputs = forecast_model(**model_inputs)
        yhat = _extract_forecast(outputs)                     # [B,H] or [B,H,1]/[B,H,C]

        # ensure [B,H] on target channel
        if yhat.dim() == 3 and yhat.size(-1) == 1:
            yhat = yhat.squeeze(-1)
        elif yhat.dim() == 3 and yhat.size(-1) > 1:
            yhat = yhat[..., tgt_idx]
        assert yhat.dim() == 2 and yhat.size(1) == H, f"Expected [B, {H}], got {tuple(yhat.shape)}"

        yhat     = yhat.detach().cpu()
        fv_batch = torch.stack(fv_list, 0).detach().cpu()    # [B,H,C] raw future values
        mean     = torch.stack(mean_list, 0).detach().cpu()  # [B,1,C]
        std      = torch.stack(std_list, 0).detach().cpu()   # [B,1,C]

        # de-normalize predictions on the target channel: x = x_norm * std + mean
        yhat_u = yhat * std[..., tgt_idx] + mean[..., tgt_idx]  # [B,H]
        # Clip at 0 exactly like the test-time cell ("Cell 6 (Path A)") does, so blend weights fit
        # on these val preds see the same post-processing as the test predictions they are applied to.
        yhat_u = torch.clamp(yhat_u, min=0.0)
        # GT = raw future target; normalizing then de-normalizing it is an identity up to float error.
        ytrue_u = fv_batch[..., tgt_idx]                         # [B,H]

        # flatten & attach ids ("<window_id>_<step>", step = 1..H)
        for wid, seq_p, seq_t in zip(wids, yhat_u.numpy(), ytrue_u.numpy()):
            ids.extend([f"{int(wid)}_{h}" for h in range(1, H+1)])
            yhat_all.extend(seq_p.tolist())
            ytrue_all.extend(seq_t.tolist())

# write files
pred_path = f"val_preds_seed{SEED}_r2.csv"
gt_path   = "val_ground_truth.csv"

pd.DataFrame({"id": ids, "y_hat":  np.asarray(yhat_all,  dtype=np.float32)}).to_csv(pred_path, index=False)

# GT is meant to be identical across seeds. Rewrite it when it is missing or when its ids/values no
# longer match this run; a shape-only check would keep a stale file whenever the row count matched.
gt_df = pd.DataFrame({"id": ids, "y_true": np.asarray(ytrue_all, dtype=np.float32)})
gt_is_current = False
if os.path.exists(gt_path):
    old_gt = pd.read_csv(gt_path, dtype={"id": str})
    gt_is_current = (
        list(old_gt.columns) == ["id", "y_true"]
        and len(old_gt) == len(gt_df)
        and old_gt["id"].tolist() == gt_df["id"].tolist()
        # float32 -> CSV text -> float64 round trip, hence a tolerance rather than exact equality
        and np.allclose(old_gt["y_true"].to_numpy(np.float64), gt_df["y_true"].to_numpy(np.float64),
                        rtol=1e-5, atol=1e-6, equal_nan=True)
    )
    if not gt_is_current:
        print(f"[warn] {gt_path} differs from this run's validation GT; overwriting. "
              "Val preds written by earlier runs may no longer line up with it.")
if not gt_is_current:
    gt_df.to_csv(gt_path, index=False)
print(f"✅ wrote {pred_path} and {gt_path} | rows={len(ids)}")


✅ wrote val_preds_seed79_r2.csv and val_ground_truth.csv | rows=31440


In [8]:
# Tally all parameters and the subset left trainable (requires_grad) after freezing.
total_params = trainable_params = 0
for param in forecast_model.parameters():
    n_elems = param.numel()
    total_params += n_elems
    if param.requires_grad:
        trainable_params += n_elems
print("TOTAL:", total_params, "| TRAINABLE:", trainable_params)


TOTAL: 731480 | TRAINABLE: 215896


In [9]:
# --- Cell 6 (Path A): Inference & Submission (model-side instance norm only) ---

import inspect
import numpy as np
import pandas as pd
import torch

forecast_model.eval()
device = next(iter(forecast_model.parameters())).device

# 0) Model context length & horizon
SL = int(getattr(forecast_model.config, "context_length", 512))    # ckpt CL (e.g., 512): inputs padded/cropped to this
H  = int(getattr(forecast_model.config, "prediction_length", 24))  # horizon emitted by the (pruned) head
print(f"[model] context_length={SL}, prediction_length={H}")

# 1) Map each window_id to its TARGET row_ids, ordered by timestamp (from the original test frame)
assert set(test_df1.columns) >= {"window_id", "timestamp", "row_id", "role"}, \
    "test_df1 must contain window_id, timestamp, row_id, role"

target_rows = test_df1.loc[test_df1["role"] == "target", ["window_id", "timestamp", "row_id"]].copy()
target_rows = target_rows.sort_values(["window_id", "timestamp"])
target_rows["window_id"] = target_rows["window_id"].astype(int)
window_to_rowids = dict(
    (wid, grp["row_id"].tolist()) for wid, grp in target_rows.groupby("window_id", sort=False)
)

# 2) helpers
def _extract_forecast(outputs):
    # Try common fields
    if isinstance(outputs, dict):
        for k in ("forecast","forecasts","prediction_outputs","predictions","mean","loc"):
            if k in outputs:
                v = outputs[k]
                return v[0] if isinstance(v, (list, tuple)) and torch.is_tensor(v[0]) else v
    for k in ("forecast","forecasts","prediction_outputs","predictions","mean","loc"):
        if hasattr(outputs, k):
            v = getattr(outputs, k)
            return v[0] if isinstance(v, (list, tuple)) and torch.is_tensor(v[0]) else v
    if isinstance(outputs, (list, tuple)) and torch.is_tensor(outputs[0]):
        return outputs[0]
    raise RuntimeError("Could not find forecast tensor in model output.")

def _extract_wid(sample):
    if "window_id" in sample:
        v = sample["window_id"]
        if torch.is_tensor(v): return int(v.view(-1)[0].item())
        if isinstance(v, (np.integer, int, float)): return int(v)
        if isinstance(v, (list, tuple)) and len(v):
            vv = v[0]; return int(vv.item() if torch.is_tensor(vv) else vv)
    if "id" in sample:
        v = sample["id"]
        if torch.is_tensor(v): return int(v.view(-1)[0].item())
        if isinstance(v, (np.integer, int, float)): return int(v)
        if isinstance(v, (list, tuple)) and len(v):
            for vv in v:
                if torch.is_tensor(vv): return int(vv.view(-1)[0].item())
                if isinstance(vv, (np.integer, int, float)): return int(vv)
            try: return int(v[-1])
            except: pass
        if isinstance(v, dict):
            if "window_id" in v: return int(v["window_id"])
            for vv in v.values():
                if isinstance(vv, (np.integer, int, float)): return int(vv)
    raise KeyError("Cannot extract window_id from sample.")

def _pad_or_crop_to_SL(pv: torch.Tensor, pom: torch.Tensor, SL: int):
    # pv, pom are [T, C]
    T, C = pv.shape
    if T == SL:
        return pv, pom
    if T > SL:
        return pv[-SL:, :], pom[-SL:, :]
    pad = SL - T
    filler = pv.mean(dim=0) if T > 0 else torch.zeros(C, dtype=pv.dtype, device=pv.device)
    pv2  = torch.cat([filler.expand(pad, C), pv], dim=0)         # [SL, C]
    pom2 = torch.cat([torch.zeros(pad, C, dtype=pom.dtype, device=pom.device), pom], dim=0)
    return pv2, pom2

# 3) inference loop
bs = 64
N  = len(test_dataset)
preds_per_window = {}
duplicate_wids = []   # window_ids seen more than once (the later prediction wins)

forward_allowed = set(inspect.signature(forecast_model.forward).parameters.keys())
FREQ_ID = 1  # hourly id (only used if model actually accepts freq_token)
# Target channel used when the model returns [B, H, C>1]; loop-invariant, so resolve it once.
tgt_idx = getattr(tsp, "prediction_channel_indices", [0])[0]

with torch.no_grad():
    for start in range(0, N, bs):
        end   = min(N, start + bs)
        batch = [test_dataset[i] for i in range(start, end)]
        wids  = [_extract_wid(s) for s in batch]

        # Prepare per-sample padded past_values / past_observed_mask
        pv_list, pom_list = [], []
        for s in batch:
            pv  = torch.as_tensor(s["past_values"], device=device, dtype=torch.float32)  # [T, C]
            pom = s.get("past_observed_mask", None)
            if pom is None:
                pom = torch.ones_like(pv)                                          # assume fully observed
            else:
                pom = torch.as_tensor(pom, device=device, dtype=torch.float32)

            pv, pom = _pad_or_crop_to_SL(pv, pom, SL)                              # mask zeros on pad
            pv_list.append(pv)
            pom_list.append(pom)

        model_inputs = {
            "past_values":        torch.stack(pv_list, 0),   # [B, SL, C]
            "past_observed_mask": torch.stack(pom_list, 0),  # [B, SL, C]
        }

        # Optional features if present in the dataset and accepted by the model
        for opt_key in ("metadata", "static_categorical_values"):
            if (opt_key in batch[0]) and (opt_key in forward_allowed):
                vals = [s[opt_key] if torch.is_tensor(s[opt_key]) else torch.as_tensor(s[opt_key]) for s in batch]
                model_inputs[opt_key] = torch.stack(vals, 0).to(device)

        # If the model accepts freq_token, provide a constant hourly token
        if "freq_token" in forward_allowed:
            model_inputs["freq_token"] = torch.full((len(batch),), FREQ_ID, dtype=torch.long, device=device)

        # IMPORTANT: don't pass future_* at test-time (no loss, just predictions)
        outputs = forecast_model(**model_inputs, return_dict=True, return_loss=False)
        yhat = _extract_forecast(outputs)  # e.g., [B, H] or [B, H, 1] or [B, H, C?]

        # Normalize shapes to [B, H]
        if yhat.dim() == 3 and yhat.size(-1) == 1:
            yhat = yhat.squeeze(-1)
        elif yhat.dim() == 3 and yhat.size(-1) > 1:
            yhat = yhat[..., tgt_idx]
        assert yhat.dim() == 2 and yhat.size(1) == H, f"Expected [B, {H}], got {tuple(yhat.shape)}"
        yhat = yhat.detach().cpu().numpy()

        # Path A: DO NOT inverse-transform with tsp/global scalers; outputs are used as returned.
        # NOTE(review): this relies on the model's own instance scaling (a checkpoint config setting —
        # confirm it is enabled). Cell A instead normalizes inputs manually and de-normalizes itself.
        for wid, seq in zip(wids, yhat):
            if wid in preds_per_window:
                duplicate_wids.append(wid)
            preds_per_window[wid] = np.clip(seq, 0, None)   # meter readings are non-negative

if duplicate_wids:
    print(f"[warn] {len(duplicate_wids)} duplicate window_ids in test_dataset; "
          f"later predictions overwrote earlier ones (e.g., {duplicate_wids[:5]})")

# 4) assemble submission: pair each window's H predictions with its timestamp-ordered row_ids
rows = []
missing_windows = []
count_mismatch  = 0

for wid, pred in preds_per_window.items():
    row_ids = window_to_rowids.get(wid)
    if not row_ids:
        missing_windows.append(wid)
        continue
    if len(row_ids) != len(pred):
        count_mismatch += 1
    m = min(len(row_ids), len(pred))
    rows.extend(zip(row_ids[:m], pred[:m]))

# Target windows that never received a prediction would otherwise silently vanish from the submission.
unpredicted_windows = [wid for wid in window_to_rowids if wid not in preds_per_window]

if missing_windows:
    print(f"[warn] {len(missing_windows)} predicted windows missing row_ids (e.g., {missing_windows[:5]})")
if count_mismatch:
    print(f"[warn] {count_mismatch} windows had horizon length mismatch; truncated.")
if unpredicted_windows:
    n_unpredicted_rows = sum(len(window_to_rowids[w]) for w in unpredicted_windows)
    print(f"[warn] {len(unpredicted_windows)} target windows ({n_unpredicted_rows} row_ids) have no prediction "
          f"(e.g., {unpredicted_windows[:5]})")

submission = pd.DataFrame(rows, columns=["row_id", "meter_reading"]).sort_values("row_id")
n_dup_rows = int(submission["row_id"].duplicated().sum())
if n_dup_rows:
    print(f"[warn] {n_dup_rows} duplicated row_ids in submission")
if len(submission) != len(target_rows):
    print(f"[warn] submission has {len(submission)} rows but test has {len(target_rows)} target rows")
print("Submission shape:", submission.shape)
display(submission.head(60))




[model] context_length=512, prediction_length=24
Submission shape: (84600, 2)


Unnamed: 0,row_id,meter_reading
0,169,0.347524
1,170,0.364863
2,171,0.374867
3,172,0.371532
4,173,0.355526
5,174,0.390872
6,175,0.541342
7,176,0.753001
8,177,0.785012
9,178,0.568936


In [10]:
# Persist this seed's test submission (row_id, meter_reading) in the working directory
submission_path = f"submission_seed_{SEED}_r2.csv"
submission.to_csv(path_or_buf=submission_path, index=False)
print(f"✅ Wrote {submission_path}")

✅ Wrote submission_seed_79_r2.csv


In [11]:
# # --- Cell B (MSE-only, no calibration): convex blend seeds and write final submission ---

# import numpy as np, pandas as pd

# # point to your per-seed validation preds and ground truth (from Cell A)
# VAL_FILES = [
#     "val_preds_seed40_r1.csv",  # id,y_hat
#     "val_preds_seed42_r1.csv",
#     "val_preds_seed50_r1.csv",
# ]
# VAL_GT = "val_ground_truth.csv"       # id,y_true

# # point to your per-seed test submissions
# TEST_FILES = [
#     "submission_seed_40_r1.csv",       # row_id,meter_reading
#     "submission_seed_42_r1.csv",
#     "submission_seed_50_r1.csv",
# ]

# def _project_to_simplex(v):
#     v = np.asarray(v, dtype=np.float64)
#     if np.isclose(v.sum(), 1.0) and np.all(v >= 0): return v
#     n = v.size; u = np.sort(v)[::-1]; cssv = np.cumsum(u)
#     rho = np.nonzero(u * np.arange(1, n+1) > (cssv - 1))[0][-1]
#     theta = (cssv[rho] - 1) / float(rho + 1)
#     return np.maximum(v - theta, 0.0)

# def fit_convex_blend_weights_mse(val_preds_list, y_true, l2=1e-8, max_iter=2000, tol=1e-9):
#     """
#     Solve: min_w  (1/N)||P w - y||^2 + l2||w||^2  s.t. w>=0, sum w = 1
#     """
#     P = np.stack([p.astype(np.float64) for p in val_preds_list], axis=1)  # [N, K]
#     y = y_true.astype(np.float64)
#     N, K = P.shape
#     # Lipschitz step for PGD
#     M = (P.T @ P) / N + l2 * np.eye(K)
#     L = float(np.linalg.eigvalsh(M).max())
#     step = 1.0 / (L + 1e-12)

#     w = np.ones(K) / K
#     for _ in range(max_iter):
#         grad = (P.T @ (P @ w - y)) / N + l2 * w
#         w_new = _project_to_simplex(w - step * grad)
#         if np.linalg.norm(w_new - w) < tol:
#             w = w_new
#             break
#         w = w_new
#     blended_val = P @ w
#     return w, blended_val

# def mse(a, b):
#     a = np.asarray(a, np.float64); b = np.asarray(b, np.float64)
#     return float(np.mean((a - b) ** 2))

# # 1) load validation ground truth and merge per-seed val preds
# gt = pd.read_csv(VAL_GT)              # columns: id, y_true
# base = gt.copy()
# val_cols = []
# for i, f in enumerate(VAL_FILES):
#     df = pd.read_csv(f)               # columns: id, y_hat
#     col = f"y_hat_{i}"
#     df = df.rename(columns={"y_hat": col}) if "y_hat" in df.columns else df
#     base = base.merge(df[["id", col]], on="id", how="inner")
#     val_cols.append(col)

# y_true = base["y_true"].to_numpy()
# val_preds_list = [base[c].to_numpy() for c in val_cols]

# # 2) fit convex MSE weights and print diagnostics
# w, blended_val = fit_convex_blend_weights_mse(val_preds_list, y_true, l2=1e-8)
# print("\n=== Convex blend weights (sum=1, MSE) ===")
# for f, wi in zip(VAL_FILES, w):
#     print(f"  {f}: {wi:.4f}")
# print(f"Val MSE (best single): {min(mse(p, y_true) for p in val_preds_list):.6f}")
# print(f"Val MSE (blended)    : {mse(blended_val, y_true):.6f}")

# # 3) apply the same weights to test submissions, save final
# subs = [pd.read_csv(f)[["row_id","meter_reading"]].sort_values("row_id").reset_index(drop=True)
#         for f in TEST_FILES]
# row_ids = subs[0]["row_id"].to_numpy()
# for i, df in enumerate(subs[1:], start=2):
#     if not np.array_equal(row_ids, df["row_id"].to_numpy()):
#         raise ValueError(f"Row IDs differ between test file 1 and test file {i}.")

# stack = np.stack([df["meter_reading"].to_numpy() for df in subs], axis=1)  # [N, K]
# y_blend_test = stack @ w
# y_blend_test = np.clip(y_blend_test, 0.0, None)

# final = pd.DataFrame({"row_id": row_ids, "meter_reading": y_blend_test.astype(np.float32)})
# final.to_csv("submission_blend_mse_only_r1.csv", index=False)
# print("✅ Wrote submission_blend_mse_only_r1.csv | rows =", len(final))


In [13]:
# --- Blend R1-ensemble with R2-ensemble (MSE-only, no calibration) ---
# It (1) fits convex weights per family (R1, R2) on validation,
#     (2) fits a scalar α to blend the two family ensembles on validation,
#     (3) applies those weights to test CSVs and writes the final submission.
#
# Also includes Option B: one-shot convex blend across ALL seeds of both families (R1 + R2 lists).

import numpy as np, pandas as pd

# =========== EDIT ONLY THESE PATHS ============

# Validation ground truth (from your Cell A)
VAL_GT = "val_ground_truth.csv"     # columns: id,y_true

# R1 validation prediction files (id,y_hat) and test submissions (row_id,meter_reading)
R1_VAL_FILES = [
    "val_preds_seed40_r1.csv",
    "val_preds_seed42_r1.csv",
    "val_preds_seed50_r1.csv",
    "val_preds_seed79_r1.csv",
    "val_preds_seed17_r1.csv"
]
R1_TEST_FILES = [
    "submission_seed_40_r1.csv",
    "submission_seed_42_r1.csv",
    "submission_seed_50_r1.csv",
    "submission_seed_79_r1.csv",
    "submission_seed_17_r1.csv"
]

# R2 validation prediction files (id,y_hat) and test submissions (row_id,meter_reading)
R2_VAL_FILES = [
    "val_preds_seed40.csv",
    "val_preds_seed42.csv",
    "val_preds_seed50.csv",
    "val_preds_seed79_r2.csv",
    "val_preds_seed19_r2.csv"
]
R2_TEST_FILES = [
    "submission_seed_40_0002.csv",
    "submission_seed_42_0002.csv",
    "submission_seed_50_0002.csv",
    "submission_seed_79_r2.csv",
    "submission_seed_19_r2.csv"
]

OUT_FAMILY_BLEND = "submission_blend_r1r2_family_then_scalar_5seeds.csv"
OUT_GLOBAL_BLEND = "submission_blend_r1r2_global10_5seeds.csv"

# Weights learned on *_VAL_FILES are applied to *_TEST_FILES by POSITION, so each pair of lists
# must have the same length and the same seed order.
assert len(R1_VAL_FILES) == len(R1_TEST_FILES), "R1 val/test file lists must have the same length"
assert len(R2_VAL_FILES) == len(R2_TEST_FILES), "R2 val/test file lists must have the same length"

# ==============================================

def _project_to_simplex(v):
    v = np.asarray(v, dtype=np.float64)
    if np.isclose(v.sum(), 1.0) and np.all(v >= 0): return v
    n = v.size; u = np.sort(v)[::-1]; cssv = np.cumsum(u)
    rho = np.nonzero(u * np.arange(1, n+1) > (cssv - 1))[0][-1]
    theta = (cssv[rho] - 1) / float(rho + 1)
    return np.maximum(v - theta, 0.0)

def fit_convex_blend_weights_mse(val_preds_list, y_true, l2=1e-8, max_iter=2000, tol=1e-9):
    """Fit convex (non-negative, sum-to-one) blend weights by projected gradient descent.

    Minimizes (1/N)||P w - y||^2 + l2||w||^2 subject to w >= 0, sum(w) = 1, where column k of P
    is val_preds_list[k]. The iteration uses the gradient of half that objective (same minimizer)
    with step 1/L, L = largest eigenvalue of P^T P / N + l2 I. Starts from uniform weights and
    stops when an update moves w by less than `tol` (L2 norm) or after `max_iter` steps.

    Returns:
        (w, P @ w): the weight vector [K] and the blended validation predictions [N].
    """
    preds = np.stack([p.astype(np.float64) for p in val_preds_list], axis=1)  # [N,K]
    target = y_true.astype(np.float64)
    n_rows, n_models = preds.shape

    curvature = (preds.T @ preds) / n_rows + l2 * np.eye(n_models)
    lipschitz = float(np.linalg.eigvalsh(curvature).max())
    step = 1.0 / (lipschitz + 1e-12)

    weights = np.ones(n_models) / n_models
    for _ in range(max_iter):
        residual = preds @ weights - target
        gradient = (preds.T @ residual) / n_rows + l2 * weights
        candidate = _project_to_simplex(weights - step * gradient)
        converged = np.linalg.norm(candidate - weights) < tol
        weights = candidate
        if converged:
            break
    return weights, (preds @ weights)

def mse(a, b):
    """Mean squared error between two array-likes, computed in float64; returns a Python float."""
    diff = np.asarray(a, np.float64) - np.asarray(b, np.float64)
    return float(np.mean(diff ** 2))

def _load_val_matrix(files, base_ids):
    """Merge multiple 'id,y_hat' CSVs onto base id order; return stacked preds [N,K]."""
    mats = []
    for f in files:
        df = pd.read_csv(f)
        # Normalize column names
        if "y_hat" not in df.columns and "meter_reading" in df.columns:
            df = df.rename(columns={"meter_reading":"y_hat"})
        df = df[["id","y_hat"]].drop_duplicates("id", keep="last")
        # align on base ids
        df = base_ids.merge(df, on="id", how="inner")
        mats.append(df["y_hat"].to_numpy())
    return np.stack(mats, axis=1)  # [N,K]

def _blend_test_from_weights(test_files, w):
    """Apply weights w to list of test CSVs (row_id,meter_reading)."""
    subs = [pd.read_csv(f)[["row_id","meter_reading"]].sort_values("row_id").reset_index(drop=True)
            for f in test_files]
    row_ids = subs[0]["row_id"].to_numpy()
    for i, df in enumerate(subs[1:], start=2):
        if not np.array_equal(row_ids, df["row_id"].to_numpy()):
            raise ValueError(f"Row IDs differ between test file 1 and test file {i}.")
    stack = np.stack([df["meter_reading"].to_numpy() for df in subs], axis=1)  # [N,K]
    y = stack @ w
    return row_ids, np.clip(y, 0.0, None)

# ---------- 1) Load validation ground truth ----------
gt = (
    pd.read_csv(VAL_GT)[["id", "y_true"]]
    .drop_duplicates("id", keep="last")
    .sort_values("id")
    .reset_index(drop=True)
)
base_ids = gt[["id"]]  # row order every validation matrix is aligned to
y_true = gt["y_true"].to_numpy()

# ---------- 2) Family-level convex weights on validation ----------
# NOTE(review): weights (and α below) are fit and scored on the same validation rows, so the
# printed "blended MSE" values are in-sample and therefore optimistic.
def _fit_family(label, val_files):
    """Fit convex per-seed weights for one family on validation and print diagnostics.

    Returns (P, w, blended_val): the [N, K] prediction matrix, the simplex weights, and P @ w.
    """
    preds = _load_val_matrix(val_files, base_ids)
    columns = [preds[:, k] for k in range(preds.shape[1])]
    weights, blended = fit_convex_blend_weights_mse(columns, y_true, l2=1e-8)
    print(f"{label} per-seed weights:", " ".join(f"{wi:.4f}" for wi in weights),
          "| best single MSE:", f"{min(mse(c, y_true) for c in columns):.6f}",
          "| blended MSE:", f"{mse(blended, y_true):.6f}")
    return preds, weights, blended

P_r1, w_r1, r1_val_blend = _fit_family("\n[R1]", R1_VAL_FILES)  # [N, K1]
P_r2, w_r2, r2_val_blend = _fit_family("[R2]", R2_VAL_FILES)    # [N, K2]

# ---------- 3) Learn scalar α to blend R1 vs R2 on validation ----------
# min_{α∈[0,1]} MSE(α·r1 + (1-α)·r2, y). Setting the derivative to 0 gives
# α* = <r1-r2, y-r2> / ||r1-r2||²; the objective is a convex 1-D quadratic, so clamping α* to [0,1]
# yields the constrained optimum.
r1 = r1_val_blend.astype(np.float64)
r2 = r2_val_blend.astype(np.float64)
y  = y_true.astype(np.float64)

gap = r1 - r2
num = np.sum(gap * (y - r2))
den = np.sum(gap ** 2) + 1e-12
alpha = float(np.clip(num / den, 0.0, 1.0))  # α on R1; (1-α) on R2
val_mse_family_scalar = mse(alpha * r1 + (1 - alpha) * r2, y)
print(f"\n[Family→Scalar] alpha(R1)={alpha:.4f}, (1-alpha)(R2)={(1-alpha):.4f} | Val MSE={val_mse_family_scalar:.6f}")

# ---------- 4) Apply to TEST: family blends then scalar blend ----------
row_ids_r1, y_r1_test = _blend_test_from_weights(R1_TEST_FILES, w_r1)
row_ids_r2, y_r2_test = _blend_test_from_weights(R2_TEST_FILES, w_r2)
if not np.array_equal(row_ids_r1, row_ids_r2):
    raise ValueError("Row IDs differ between R1 and R2 test blends.")

y_final_family = alpha * y_r1_test + (1 - alpha) * y_r2_test
final_family = pd.DataFrame({"row_id": row_ids_r1, "meter_reading": y_final_family.astype(np.float32)})
final_family.to_csv(OUT_FAMILY_BLEND, index=False)
print(f"✅ Wrote {OUT_FAMILY_BLEND} | rows={len(final_family)}")

# ---------- (Optional) 5) One-shot global convex blend across ALL seeds of both families ----------
# Build the global validation matrix [N, K1+K2], learn convex weights, and apply them to the test
# files in the same order (R1 files first, then R2 files).
ALL_VAL_FILES = R1_VAL_FILES + R2_VAL_FILES
ALL_TEST_FILES = R1_TEST_FILES + R2_TEST_FILES

P_all = _load_val_matrix(ALL_VAL_FILES, base_ids)  # [N, K]
w_all, val_blend_all = fit_convex_blend_weights_mse([P_all[:, i] for i in range(P_all.shape[1])], y_true, l2=1e-8)
# label with the actual number of blended files (was hard-coded as "Global-6")
print(f"\n[Global-{P_all.shape[1]}] weights:", " ".join(f"{wi:.4f}" for wi in w_all),
      "| best single MSE:", f"{min(mse(P_all[:, i], y_true) for i in range(P_all.shape[1])):.6f}",
      "| blended MSE:", f"{mse(val_blend_all, y_true):.6f}")

# Apply to test: same row_id alignment check and >= 0 clipping as the family blends
row_ids_all, y_global = _blend_test_from_weights(ALL_TEST_FILES, w_all)
final_global = pd.DataFrame({"row_id": row_ids_all, "meter_reading": y_global.astype(np.float32)})
final_global.to_csv(OUT_GLOBAL_BLEND, index=False)
print(f"✅ Wrote {OUT_GLOBAL_BLEND} | rows={len(final_global)}")



[R1] per-seed weights: 0.2855 0.2724 0.1944 0.0711 0.1766 | best single MSE: 0.599243 | blended MSE: 0.599190
[R2] per-seed weights: 0.2572 0.0002 0.2248 0.1263 0.3914 | best single MSE: 0.605412 | blended MSE: 0.603942

[Family→Scalar] alpha(R1)=0.6731, (1-alpha)(R2)=0.3269 | Val MSE=0.597725
✅ Wrote submission_blend_r1r2_family_then_scalar_5seeds.csv | rows=84600

[Global-6] weights: 0.1748 0.1703 0.1256 0.0770 0.1312 0.0673 0.0000 0.0739 0.0000 0.1799 | best single MSE: 0.599243 | blended MSE: 0.597624
✅ Wrote submission_blend_r1r2_global10_5seeds.csv | rows=84600
