# ChronoTick 2: Granite TTM Fine-Tuning — Larger Context (ctx=1024)

Fine-tune IBM Granite TTM at larger context windows to test whether more
context improves fine-tuned predictions (especially for ares-comp-10).

## Motivation
- Notebook 03a trained at (512, 96), but ZS results show ctx=1024+ is dramatically better
- ZS@512 is broken for TTM (no zero-shot branch at ctx=512 for the 96-step horizon), so 03a's FT-vs-ZS comparison was misleading
- ares-comp-10 FT@512 MAE=0.171 vs ZS@1800 MAE=0.054 — needs more context

## Experiments
**Group A — (1024, 96): Larger context, same horizon**
- E4: Univariate FT (common_channel, baseline)
- E5: Channel-mix FT with top-10 SHAP features

**Group B — (1024, 192): Larger context + longer horizon**
- E6: Univariate FT (common_channel, baseline)
- E7: Channel-mix FT with top-10 SHAP features

## Training Mode
Set `TRAINING_MODE` in Cell 3 to control how the machines are grouped for training:
- `"combined"`: train on all 4 machines (default)
- `"per_machine"`: train separately per machine

In [None]:
# === Environment Setup ===
import os
import subprocess
import sys

# A session counts as Colab when COLAB_GPU is set or the /content directory exists.
IN_COLAB = "COLAB_GPU" in os.environ or os.path.exists("/content")

if IN_COLAB:
    REPO_DIR = "/content/sensor-collector"
    REPO_URL = "https://github.com/JaimeCernuda/sensor-collector.git"

    # Read GitHub token from Colab secrets (set via sidebar key icon).
    # Required for git push; without it the notebook runs but cannot push.
    GITHUB_TOKEN = None
    try:
        from google.colab import userdata

        GITHUB_TOKEN = userdata.get("GITHUB_TOKEN")
    except Exception:
        print("WARNING: GITHUB_TOKEN secret not available. Git push will be skipped.")
        print("  To enable: run from Colab UI with Secrets > GITHUB_TOKEN set.")

    # Build authenticated URL if token available
    if GITHUB_TOKEN:
        auth_url = (
            f"https://{GITHUB_TOKEN}@github.com/JaimeCernuda/sensor-collector.git"
        )
    else:
        auth_url = REPO_URL

    def _run_git_masked(args):
        """Run a git command whose arguments may embed the GitHub token.

        Args:
            args: Full argv list for the git invocation.

        Raises:
            subprocess.CalledProcessError: if git exits non-zero. The command
                stored on the error has the token replaced by ``***`` so the
                secret is not echoed in the exception message.
        """
        try:
            subprocess.run(args, check=True)
        except subprocess.CalledProcessError as err:
            masked = [
                a.replace(GITHUB_TOKEN, "***") if GITHUB_TOKEN else a for a in args
            ]
            # "from None" drops the original error, whose message still
            # contains the unmasked command.
            raise subprocess.CalledProcessError(err.returncode, masked) from None

    # Clone or pull latest tick2 code from GitHub
    if os.path.exists(REPO_DIR):
        # Update remote URL in case token was added after initial clone
        _run_git_masked(
            ["git", "-C", REPO_DIR, "remote", "set-url", "origin", auth_url]
        )
        # Reset to remote HEAD to avoid divergence from previous Colab commits.
        # Colab outputs are regenerated each run, so local commits are disposable.
        subprocess.run(["git", "-C", REPO_DIR, "fetch", "-q", "origin"], check=True)
        subprocess.run(
            ["git", "-C", REPO_DIR, "reset", "--hard", "origin/main"], check=True
        )
    else:
        _run_git_masked(["git", "clone", "-q", auth_url, REPO_DIR])

    # Configure git identity (Colab has no global config)
    subprocess.run(
        ["git", "-C", REPO_DIR, "config", "user.name", "Colab Runner"], check=True
    )
    subprocess.run(
        ["git", "-C", REPO_DIR, "config", "user.email", "colab@chronotick.dev"],
        check=True,
    )

    # Install tick2 package in editable mode. pip is invoked through the
    # kernel's own interpreter so the install targets the environment this
    # notebook actually imports from.
    subprocess.run(
        [sys.executable, "-m", "pip", "install", "-q", "-e", f"{REPO_DIR}/tick2/"],
        check=True,
    )

    # Ensure tick2 is importable (pip install via subprocess doesn't always
    # update sys.path in the running kernel)
    tick2_src = f"{REPO_DIR}/tick2/src"
    if tick2_src not in sys.path:
        sys.path.insert(0, tick2_src)

    # Always mount Drive — needed for checkpoint persistence (models too large for git)
    from google.colab import drive

    drive.mount("/content/drive")

    # Data: prefer repo copy, fall back to Drive
    DATA_DIR = f"{REPO_DIR}/sensors/data"
    if not os.path.isdir(f"{DATA_DIR}/24h_snapshot"):
        DATA_DIR = "/content/drive/MyDrive/chronotick2/data"

    # Output directory inside the repo (will be git-pushed)
    RESULTS_DIR = f"{REPO_DIR}/tick2/notebooks/output/03"
else:
    # Local run: no token, no explicit data directory (printed as "(default)"
    # below), and results go to ./output/03 next to this file when __file__
    # is defined, otherwise relative to the current directory.
    GITHUB_TOKEN = None
    DATA_DIR = None
    RESULTS_DIR = os.path.join(
        os.path.dirname(__file__) if "__file__" in dir() else ".", "output", "03"
    )

print(f"Environment: {'Colab' if IN_COLAB else 'Local'}")
print(f"Data dir:    {DATA_DIR or '(default)'}")
print(f"Results dir: {RESULTS_DIR}")

In [None]:
# === Granite TTM Dependencies ===
# granite-tsfm pins torch<2.9; install with --no-deps to keep CUDA torch
import subprocess
import sys

if IN_COLAB:
    # Invoke pip through the kernel's own interpreter so the packages are
    # installed into the environment this notebook imports from.
    pip_install = [sys.executable, "-m", "pip", "install", "-q"]
    subprocess.run([*pip_install, "granite-tsfm>=0.3.3", "--no-deps"], check=True)
    # Runtime dependencies installed explicitly because --no-deps skipped them.
    subprocess.run(
        [*pip_install, "transformers>=4.56,<5", "datasets", "deprecated"],
        check=True,
    )

# No import check is performed here: reaching this line only means that any
# install command that ran exited successfully.
print("granite-tsfm ready")

In [None]:
# === Imports, Config & TRAINING_MODE ===
from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import torch

from tick2.finetuning.base import FineTuneConfig
from tick2.finetuning.data_prep import prepare_datasets
from tick2.finetuning.evaluate import (
    compare_ft_vs_zero_shot,
    evaluate_finetuned,
    load_zero_shot_baselines,
)
from tick2.finetuning.granite_ft import finetune_granite
from tick2.utils.gpu import clear_gpu_memory

# Apply the seaborn theme used for this notebook's plots.
sns.set_theme(style="whitegrid", font_scale=1.1)

# --- User-tunable switches ---
TRAINING_MODE = "combined"  # "combined" or "per_machine"
DEVICE_OVERRIDE = None
FORCE_RETRAIN = False

# Device resolution: an explicit override wins; otherwise use CUDA when torch
# reports it as available, falling back to CPU.
if DEVICE_OVERRIDE:
    device = DEVICE_OVERRIDE
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Short label used for per-device result directories; unknown device strings
# are used verbatim.
DEVICE_DIR_MAP = {"cuda": "gpu", "cpu": "cpu"}
device_label = DEVICE_DIR_MAP[device] if device in DEVICE_DIR_MAP else device

# Data preparation uses max_covariates=30 to load all features; experiments
# select subsets.  Context/prediction lengths are set per experiment group.
config_base = FineTuneConfig(
    seed=42,
    max_covariates=30,
    context_length=1024,
    prediction_length=96,
)

print(f"Device: {device}")
print(f"Training mode: {TRAINING_MODE}")

In [None]:
# === Load Data + Temporal Split ===
# DATA_DIR is None on local runs; in that case prepare_datasets receives None.
data_dir = None
if DATA_DIR:
    data_dir = Path(DATA_DIR)

prepared = prepare_datasets(config_base, data_dir=data_dir)

# One summary line per prepared dataset: split sizes and feature count.
for machine_name, bundle in prepared.items():
    split = bundle.split
    summary = (
        f"train={len(split.train)}, val={len(split.val)}, "
        f"test={len(split.test)}, features={len(bundle.feature_cols)}"
    )
    print(f"  {machine_name:16s}: {summary}")

In [None]:
# === Fine-Tune Granite TTM — ctx=1024 Experiments ===
import os
import subprocess

from tick2.utils.colab import (
    load_checkpoint_from_drive,
    save_checkpoint_to_drive,
    setup_training_log,
)

# Directory layout under the results root:
#   <root>/<device_label>/                 per-device evaluation CSVs
#   <root>/granite_ttm_ft/<TRAINING_MODE>/ fine-tuning checkpoints and log
output_base = Path(RESULTS_DIR)
device_results_dir = output_base.joinpath(device_label)
device_results_dir.mkdir(parents=True, exist_ok=True)
ft_output_dir = output_base.joinpath("granite_ttm_ft", TRAINING_MODE)

# Persist training logs to disk (epoch losses, early stopping, errors)
log_path = setup_training_log(ft_output_dir)
print(f"Training log: {log_path}")


def checkpoint_push(label: str) -> None:
    """Git add, commit, and push results after a step completes.

    Best-effort: does nothing outside Colab, returns quietly when the output
    directory has no changes, and reports any failure as a printed warning
    instead of raising.

    Args:
        label: Short tag for the finished step; used in the commit message and
            in the printed status lines.
    """
    if not IN_COLAB:
        return

    git = ["git", "-C", REPO_DIR]
    tracked = "tick2/notebooks/output/03/"
    try:
        subprocess.run([*git, "add", tracked], check=True, capture_output=True)
        status = subprocess.run(
            [*git, "status", "--porcelain", tracked],
            capture_output=True,
            text=True,
        )
        if not status.stdout.strip():
            return  # nothing new to commit
        msg = f"results: notebook 03a2 granite-ttm-ft-ctx1024 {label} ({device_label})"
        subprocess.run(
            [*git, "commit", "-m", msg],
            check=True,
            capture_output=True,
        )
        if not GITHUB_TOKEN:
            print("  [CHECKPOINT] Committed (no token for push)")
            return

        subprocess.run([*git, "fetch", "-q", "origin"], capture_output=True, timeout=30)
        try:
            rebased = (
                subprocess.run(
                    [*git, "rebase", "origin/main"],
                    capture_output=True,
                    timeout=30,
                ).returncode
                == 0
            )
        except subprocess.TimeoutExpired:
            rebased = False
        if not rebased:
            # A failed or interrupted rebase leaves the repository mid-rebase,
            # which would break this push and every later checkpoint. Abort
            # it so the branch is back on the commit just created.
            subprocess.run(
                [*git, "rebase", "--abort"], capture_output=True, timeout=30
            )
            print(
                f"  [CHECKPOINT WARNING] rebase onto origin/main failed for {label};"
                " aborted rebase, attempting push without it"
            )
        subprocess.run(
            [*git, "push"],
            check=True,
            capture_output=True,
            timeout=60,
        )
        print(f"  [CHECKPOINT] Pushed {label} results")
    except subprocess.CalledProcessError as e:
        # Output was captured, so include git's stderr in the warning;
        # otherwise only the exit status would be visible.
        detail = e.stderr or b""
        if isinstance(detail, bytes):
            detail = detail.decode(errors="replace")
        detail = detail.strip()
        if GITHUB_TOKEN:
            # The remote URL embeds the token; never echo it.
            detail = detail.replace(GITHUB_TOKEN, "***")
        suffix = f": {detail}" if detail else ""
        print(f"  [CHECKPOINT WARNING] {e}{suffix}")
    except Exception as e:
        print(f"  [CHECKPOINT WARNING] {e}")


# --- Experiment definitions ---
# Group A: (1024, 96) — larger context, same horizon as 03a
# Group B: (1024, 192) — larger context + longer horizon (unlocks hz=120)
def _experiment(name, prediction_length, decoder_mode, max_covariates):
    """Build one experiment spec; every experiment here uses context_length=1024."""
    return {
        "name": name,
        "context_length": 1024,
        "prediction_length": prediction_length,
        "decoder_mode": decoder_mode,
        "max_covariates": max_covariates,
    }


EXPERIMENTS = [
    # Group A: (1024, 96)
    _experiment("E4_ctx1024_uni", 96, "common_channel", 0),
    _experiment("E5_ctx1024_mix10", 96, "mix_channel", 10),
    # Group B: (1024, 192)
    _experiment("E6_ctx1024_h192_uni", 192, "common_channel", 0),
    _experiment("E7_ctx1024_h192_mix10", 192, "mix_channel", 10),
]

from tick2.finetuning.base import FineTuneResult

all_ft_results = []  # FineTuneResult objects across experiments
experiment_labels = {}  # experiment name -> list of FineTuneResult

# Run (or restore) every experiment in turn. A failure in one experiment is
# reported and does not stop the remaining ones.
for exp in EXPERIMENTS:
    exp_name = exp["name"]
    ctx_len = exp["context_length"]
    pred_len = exp["prediction_length"]
    dec = exp["decoder_mode"]
    mcov = exp["max_covariates"]
    print(f"\n{'=' * 60}")
    print(
        f"  {exp_name}  (ctx={ctx_len}, pred={pred_len}, decoder={dec}, max_cov={mcov})"
    )
    print(f"{'=' * 60}")

    # Check for cached checkpoint (local first, then Drive). The glob result
    # is sorted so the checkpoint picked below is the same on every run.
    exp_checkpoint_dir = ft_output_dir / exp_name
    drive_model_name = f"granite_ttm_ft/{TRAINING_MODE}/{exp_name}"
    cached_flags = sorted(exp_checkpoint_dir.glob("*/best/config.json"))

    if not cached_flags and not FORCE_RETRAIN:
        # Try restoring from Drive
        resumed = load_checkpoint_from_drive(
            model_name=drive_model_name,
            local_path=str(exp_checkpoint_dir),
        )
        if resumed:
            print(f"  [RESUMED] Loaded from Drive: {resumed}")
            cached_flags = sorted(exp_checkpoint_dir.glob("*/best/config.json"))

    if cached_flags and not FORCE_RETRAIN:
        best_path = cached_flags[0].parent
        print(f"  [CACHED] Checkpoint exists at {best_path}")
        # No training ran, so register a stub that carries only the
        # checkpoint location and the experiment definition.
        stub = FineTuneResult(
            model_name=f"granite-ttm-ft-{exp_name}",
            machine=TRAINING_MODE,
            checkpoint_path=str(best_path),
            config=exp,
        )
        all_ft_results.append(stub)
        experiment_labels[exp_name] = [stub]
        continue

    # Prepare data with experiment-specific config. An experiment with
    # max_covariates == 0 falls back to the base config's covariate limit.
    max_cov = mcov if mcov > 0 else config_base.max_covariates
    exp_config = FineTuneConfig(
        context_length=ctx_len,
        prediction_length=pred_len,
        max_covariates=max_cov,
        seed=config_base.seed,
    )

    clear_gpu_memory()

    try:
        ft_results = finetune_granite(
            prepared=prepared,
            config=exp_config,
            output_dir=str(exp_checkpoint_dir),
            training_mode=TRAINING_MODE,
            decoder_mode=dec,
            freeze_backbone=True,
            learning_rate=0.001,
            num_epochs=150,
            batch_size=64,
            early_stopping_patience=10,
        )

        for r in ft_results:
            r.model_name = f"granite-ttm-ft-{exp_name}"

        # Register the results before any logging or upload so that a failure
        # in those later steps cannot discard a successful training run.
        all_ft_results.extend(ft_results)
        experiment_labels[exp_name] = ft_results

        for r in ft_results:
            print(f"  {r.machine}: {r.training_time_s:.1f}s, best_epoch={r.best_epoch}")
            # Only index val_loss when best_epoch is inside the recorded history.
            if r.val_loss and r.best_epoch < len(r.val_loss):
                print(f"    val_loss: {r.val_loss[r.best_epoch]:.6f}")

        # Save checkpoint to Drive for persistence
        save_checkpoint_to_drive(
            local_path=exp_checkpoint_dir,
            model_name=drive_model_name,
        )

        # Checkpoint push after each experiment
        checkpoint_push(exp_name)

    except Exception as e:
        print(f"  [FAIL] {exp_name}: {e}")
        import traceback

        traceback.print_exc()
    finally:
        clear_gpu_memory()

print(f"\n{'=' * 60}")
print(f"  Completed: {list(experiment_labels.keys())}")
print(f"{'=' * 60}")

In [None]:
# === Evaluate FT Models on Test Set ===
from tick2.finetuning.data_prep import combine_training_data
from tick2.finetuning.granite_ft import load_finetuned_granite
from tick2.models.granite import GraniteTTMWrapper

# Compute shared feature intersection once (same as training used)
_, shared_features_all = combine_training_data(prepared)
print(f"Shared features across all machines: {len(shared_features_all)}")

# Evaluation grid per experiment group:
# Group A (pred_len=96):  ctx=[1024], hz=[60, 96]
# Group B (pred_len=192): ctx=[1024], hz=[60, 96, 120, 192]
EVAL_GRIDS = {
    96: {"context_lengths": [1024], "horizons": [60, 96]},
    192: {"context_lengths": [1024], "horizons": [60, 96, 120, 192]},
}

eval_dfs = []

# Evaluate each experiment in turn. Cached CSVs are reused unless
# FORCE_RETRAIN is set; a failure in one experiment is reported and the loop
# moves on to the next.
for exp in EXPERIMENTS:
    exp_name = exp["name"]
    ctx_len = exp["context_length"]
    pred_len = exp["prediction_length"]
    print(f"\n--- Evaluating {exp_name} (ctx={ctx_len}, pred={pred_len}) ---")

    # Check for cached evaluation results
    cached_eval_path = (
        device_results_dir / f"granite-ttm-ft-{exp_name}_{TRAINING_MODE}.csv"
    )
    if cached_eval_path.exists() and not FORCE_RETRAIN:
        print(f"  [CACHED] Loading from {cached_eval_path}")
        eval_dfs.append(pd.read_csv(cached_eval_path))
        continue

    # Find the checkpoint
    results_for_exp = experiment_labels.get(exp_name, [])
    if not results_for_exp:
        print(f"  [SKIP] No training results for {exp_name}")
        continue

    # Prefer the recorded checkpoint path; otherwise try the two conventional
    # locations under the experiment's output directory.
    checkpoint_path = results_for_exp[0].checkpoint_path
    if not checkpoint_path or not Path(checkpoint_path).exists():
        checkpoint_path = str(ft_output_dir / exp_name / TRAINING_MODE / "best")
        if not Path(checkpoint_path).exists():
            checkpoint_path = str(ft_output_dir / exp_name / "best")

    if not Path(checkpoint_path).exists():
        print(f"  [SKIP] Checkpoint not found for {exp_name}")
        continue

    clear_gpu_memory()

    # Bind both names up front so the cleanup in `finally` is valid even when
    # loading fails before either object has been created.
    wrapper = None
    ft_model_raw = None

    try:
        # Load fine-tuned model
        ft_model_raw = load_finetuned_granite(
            checkpoint_path,
            context_length=ctx_len,
            prediction_length=pred_len,
        )

        # Create wrapper conforming to ModelWrapper protocol
        wrapper = GraniteTTMWrapper(
            model_name=f"granite-ttm-ft-{exp_name}",
            context_length=ctx_len,
            prediction_length=pred_len,
        )
        # Replace internal model with our fine-tuned one
        wrapper._model = ft_model_raw
        wrapper._device = device
        wrapper._model.to(device)

        # Auto-detect channel count for mix_channel models
        n_channels = getattr(ft_model_raw.config, "num_input_channels", 1)
        wrapper._n_input_channels = n_channels

        # For mix_channel models, use the shared feature
        # intersection (same columns and order as training)
        if n_channels > 1:
            n_cov = n_channels - 1
            eval_features = shared_features_all[:n_cov]
            print(
                f"  Mix-channel: {n_channels} channels (1 target + {n_cov} covariates)"
            )
        else:
            eval_features = None
            print(f"  Univariate: {n_channels} channel")

        # Extract training metadata
        ft_epochs = results_for_exp[0].best_epoch
        ft_time = results_for_exp[0].training_time_s
        ft_machines = results_for_exp[0].machine

        # Use the correct eval grid for prediction_length
        eval_grid = EVAL_GRIDS[pred_len]

        # Config matching this experiment's ctx/pred lengths
        eval_config = FineTuneConfig(
            context_length=ctx_len,
            prediction_length=pred_len,
            max_covariates=config_base.max_covariates,
            seed=config_base.seed,
        )

        eval_df = evaluate_finetuned(
            model=wrapper,
            prepared=prepared,
            config=eval_config,
            training_mode=f"ft_{TRAINING_MODE}",
            ft_epochs=ft_epochs,
            ft_time_s=ft_time,
            ft_train_machines=ft_machines,
            context_lengths=eval_grid["context_lengths"],
            horizons=eval_grid["horizons"],
            n_samples=25,
            progress=True,
            shared_feature_cols=eval_features,
        )

        if not eval_df.empty:
            eval_df["experiment"] = exp_name
            eval_df.to_csv(cached_eval_path, index=False)
            eval_dfs.append(eval_df)
            print(f"  Mean MAE: {eval_df['mae'].mean():.4f}")
            print(f"  Saved: {cached_eval_path}")
        else:
            print(f"  [WARN] No eval results for {exp_name}")

        checkpoint_push(f"eval-{exp_name}")

    except Exception as e:
        print(f"  [FAIL] Evaluation {exp_name}: {e}")
        import traceback

        traceback.print_exc()
    finally:
        # Drop the model references before clearing GPU memory. Rebinding to
        # None (instead of `del`) cannot raise when a name was never assigned.
        wrapper = None
        ft_model_raw = None
        clear_gpu_memory()

# Combine all evaluation results
if eval_dfs:
    ft_results_df = pd.concat(eval_dfs, ignore_index=True)
    print(f"\nTotal FT eval rows: {len(ft_results_df)}")
    display(ft_results_df)
else:
    ft_results_df = pd.DataFrame()
    print("No evaluation results collected.")

In [None]:
# === Load Baselines: ZS from Notebook 02 + FT@512 from Notebook 03a ===
# Zero-shot results are read from the sibling "02" output directory.
zs_dir = Path(RESULTS_DIR).parent / "02"
zs_results_raw = load_zero_shot_baselines(zs_dir, model_name="granite-ttm")

# Filter to valid ZS contexts (no ZS branch at ctx=512 for 96-step horizon)
if zs_results_raw.empty:
    zs_results = pd.DataFrame()
    print("No zero-shot baselines found.")
else:
    large_ctx = zs_results_raw["context_length"] >= 1024
    zs_results = zs_results_raw[large_ctx].copy()
    print(f"Zero-shot baselines: {len(zs_results)} rows (ctx>=1024)")
    print(f"  Contexts: {sorted(zs_results['context_length'].unique().tolist())}")

# Load FT@512 results from notebook 03a for cross-context comparison.
# Each CSV that exists is tagged with the experiment it came from.
ft512_dir = Path(RESULTS_DIR) / device_label
ft512_dfs = []
for exp_name in ("E1_univariate", "E2_mix10", "E3_mix30"):
    csv_path = ft512_dir / f"granite-ttm-ft-{exp_name}_{TRAINING_MODE}.csv"
    if not csv_path.exists():
        continue
    ft512_dfs.append(pd.read_csv(csv_path).assign(experiment=exp_name))

if not ft512_dfs:
    ft512_results = pd.DataFrame()
    print("\nNo FT@512 baselines found (run notebook 03a first).")
else:
    ft512_results = pd.concat(ft512_dfs, ignore_index=True)
    print(f"\nFT@512 baselines (from 03a): {len(ft512_results)} rows")
    print(f"  Experiments: {ft512_results['experiment'].unique().tolist()}")

In [None]:
# === Comparison Tables ===


def _ft_mae_values(frame, exp_name, machine, ctx=None):
    """Return the "mae" values for one experiment and machine.

    Rows are matched when the "model" column contains ``exp_name`` and the
    "machine" column equals ``machine``; when ``ctx`` is given, the
    "context_length" column must equal it as well.
    """
    rows = frame["model"].str.contains(exp_name, na=False) & (
        frame["machine"] == machine
    )
    if ctx is not None:
        rows = rows & (frame["context_length"] == ctx)
    return frame.loc[rows, "mae"]


# --- Table 1: FT@1024 vs ZS@1024 (apples-to-apples) ---
if not ft_results_df.empty and not zs_results.empty:
    zs_at_1024 = zs_results[zs_results["context_length"] == 1024]

    if zs_at_1024.empty:
        print("No ZS@1024 baselines available.")
    else:
        print("=" * 60)
        print("Table 1: FT@1024 vs ZS@1024 (apples-to-apples)")
        print("=" * 60)

        summary_rows = []
        for machine in ft_results_df["machine"].unique():
            zs_m = zs_at_1024[zs_at_1024["machine"] == machine]
            if zs_m.empty:
                continue
            zs_mae = zs_m["mae"].mean()

            for exp in EXPERIMENTS:
                exp_name = exp["name"]
                ft_values = _ft_mae_values(ft_results_df, exp_name, machine, ctx=1024)
                if ft_values.empty:
                    continue
                ft_mae = ft_values.mean()

                # Relative improvement is only defined for a positive baseline.
                if not zs_mae > 0:
                    continue
                summary_rows.append(
                    {
                        "machine": machine,
                        "experiment": exp_name,
                        "ft_mae": ft_mae,
                        "zs_1024_mae": zs_mae,
                        "vs_zs1024_pct": (zs_mae - ft_mae) / zs_mae * 100,
                    }
                )

        if not summary_rows:
            print("  No overlapping machines.")
        else:
            display(pd.DataFrame(summary_rows).round(4))

# --- Table 2: FT@1024 vs Best ZS (context gap analysis) ---
if not ft_results_df.empty and not zs_results.empty:
    # Per machine: the lowest zero-shot MAE and the context length of that row.
    zs_mae_by_machine = zs_results.groupby("machine")["mae"]
    best_zs_per_machine = zs_mae_by_machine.agg(["min", "idxmin"]).rename(
        columns={"min": "best_zs_mae"}
    )
    best_rows = best_zs_per_machine["idxmin"]
    best_zs_per_machine["best_zs_ctx"] = zs_results.loc[
        best_rows, "context_length"
    ].values
    best_zs_per_machine = best_zs_per_machine.drop(columns=["idxmin"])

    print(f"\n{'=' * 60}")
    print("Table 2: FT@1024 vs Best ZS (context gap)")
    print(f"{'=' * 60}")

    summary_rows = []
    for machine in ft_results_df["machine"].unique():
        if machine not in best_zs_per_machine.index:
            continue
        bzs_mae = best_zs_per_machine.loc[machine, "best_zs_mae"]
        bzs_ctx = int(best_zs_per_machine.loc[machine, "best_zs_ctx"])

        for exp in EXPERIMENTS:
            exp_name = exp["name"]
            # No context filter here: every FT row for the experiment counts.
            ft_values = _ft_mae_values(ft_results_df, exp_name, machine)
            if ft_values.empty:
                continue
            ft_mae = ft_values.mean()

            if not bzs_mae > 0:
                continue
            summary_rows.append(
                {
                    "machine": machine,
                    "experiment": exp_name,
                    "ft_ctx": 1024,
                    "ft_mae": ft_mae,
                    "best_zs_ctx": bzs_ctx,
                    "best_zs_mae": bzs_mae,
                    "vs_best_zs_pct": (bzs_mae - ft_mae) / bzs_mae * 100,
                }
            )

    if not summary_rows:
        print("  No overlapping machines.")
    else:
        display(pd.DataFrame(summary_rows).round(4))

# --- Table 3: FT@512 vs FT@1024 — more context? ---
if not ft_results_df.empty and not ft512_results.empty:
    print(f"\n{'=' * 60}")
    print("Table 3: FT@512 vs FT@1024 — more context?")
    print(f"{'=' * 60}")

    # Pair each 03a experiment with its ctx=1024 counterpart.
    ft512_to_ft1024 = {
        "E1_univariate": "E4_ctx1024_uni",
        "E2_mix10": "E5_ctx1024_mix10",
    }

    summary_rows = []
    for machine in ft_results_df["machine"].unique():
        for ft512_exp, ft1024_exp in ft512_to_ft1024.items():
            ft512_values = _ft_mae_values(ft512_results, ft512_exp, machine)
            ft1024_values = _ft_mae_values(
                ft_results_df, ft1024_exp, machine, ctx=1024
            )
            if ft512_values.empty or ft1024_values.empty:
                continue
            ft512_mae = ft512_values.mean()
            ft1024_mae = ft1024_values.mean()

            if not ft512_mae > 0:
                continue
            summary_rows.append(
                {
                    "machine": machine,
                    "decoder": "uni" if "uni" in ft1024_exp else "mix10",
                    "ft512_exp": ft512_exp,
                    "ft512_mae": ft512_mae,
                    "ft1024_exp": ft1024_exp,
                    "ft1024_mae": ft1024_mae,
                    "ctx_improvement_pct": (ft512_mae - ft1024_mae) / ft512_mae * 100,
                }
            )

    if not summary_rows:
        print("  No comparable experiments.")
    else:
        ctx_df = pd.DataFrame(summary_rows)
        display(ctx_df.round(4))

        print("\nPer-decoder summary:")
        for dec in ctx_df["decoder"].unique():
            mean_imp = ctx_df.loc[
                ctx_df["decoder"] == dec, "ctx_improvement_pct"
            ].mean()
            print(f"  {dec}: mean improvement = {mean_imp:+.1f}%")

if ft_results_df.empty:
    print("No FT@1024 results to compare.")

In [None]:
# === Visualizations ===
results_dir = Path(RESULTS_DIR)
results_dir.mkdir(parents=True, exist_ok=True)


def _variant_rows(frame, variant):
    """Build one bar-chart row per machine in ``frame``: the mean "mae" under ``variant``."""
    return [
        {
            "machine": machine,
            "variant": variant,
            "mae": frame.loc[frame["machine"] == machine, "mae"].mean(),
        }
        for machine in frame["machine"].unique()
    ]


# --- 1. MAE Comparison Bar Chart ---
if not ft_results_df.empty:
    fig, ax = plt.subplots(figsize=(14, 5))

    plot_rows = []

    # Add ZS@1024 baseline
    if not zs_results.empty:
        zs_1024 = zs_results[zs_results["context_length"] == 1024]
        plot_rows.extend(_variant_rows(zs_1024, "ZS@1024"))

    # Add FT@512 baselines (from 03a)
    if not ft512_results.empty:
        for exp_name in ["E2_mix10"]:
            exp_data = ft512_results[ft512_results["experiment"] == exp_name]
            plot_rows.extend(_variant_rows(exp_data, "FT@512 (E2)"))

    # Add FT@1024 experiments
    for exp in EXPERIMENTS:
        exp_name = exp["name"]
        is_exp = ft_results_df["model"].str.contains(exp_name, na=False)
        at_1024 = ft_results_df["context_length"] == 1024
        plot_rows.extend(_variant_rows(ft_results_df[is_exp & at_1024], exp_name))

    if not plot_rows:
        plt.close(fig)
        print("No data for MAE comparison plot.")
    else:
        plot_df = pd.DataFrame(plot_rows)
        sns.barplot(data=plot_df, x="machine", y="mae", hue="variant", ax=ax)
        ax.set_ylabel("MAE (ppm)")
        ax.set_title("Granite TTM ctx=1024: FT vs ZS MAE by Machine")
        # Legend is anchored outside the axes, to the right.
        ax.legend(title="Variant", bbox_to_anchor=(1.05, 1), loc="upper left")
        plt.tight_layout()
        fig.savefig(
            results_dir / "ft_ctx1024_vs_zs_mae_comparison.png",
            dpi=150,
            bbox_inches="tight",
        )
        plt.show()

# --- 2. Training Loss Curves ---
if all_ft_results:
    n_exp = len(EXPERIMENTS)
    # One panel per experiment, laid out in a single row.
    fig, axes = plt.subplots(1, n_exp, figsize=(5 * n_exp, 4), squeeze=False)

    for ax, exp in zip(axes[0], EXPERIMENTS):
        exp_name = exp["name"]
        has_data = False

        for r in experiment_labels.get(exp_name, []):
            if r.train_loss:
                ax.plot(r.train_loss, label="Train", alpha=0.7)
                has_data = True
            if r.val_loss:
                ax.plot(r.val_loss, label="Validation", alpha=0.7)
                # Vertical marker at the result's best epoch.
                ax.axvline(
                    r.best_epoch,
                    color="red",
                    linestyle="--",
                    alpha=0.5,
                    label=f"Best (epoch {r.best_epoch})",
                )
                has_data = True

        ax.set_xlabel("Epoch")
        ax.set_ylabel("Loss")
        ax.set_title(exp_name)
        if not has_data:
            # Results without any loss history get a placeholder note.
            ax.text(
                0.5,
                0.5,
                "(cached, no loss history)",
                transform=ax.transAxes,
                ha="center",
            )
        else:
            ax.legend(fontsize=8)

    plt.suptitle("Granite TTM ctx=1024 FT Loss Curves", fontsize=14)
    plt.tight_layout()
    fig.savefig(
        results_dir / "ft_ctx1024_training_loss_curves.png",
        dpi=150,
        bbox_inches="tight",
    )
    plt.show()

print(f"Saved figures to: {results_dir}")

In [None]:
# === Export Results CSV + LaTeX ===
from tick2.benchmark.reporting import results_to_latex, save_results

results_dir = Path(RESULTS_DIR)
device_results_dir = results_dir / device_label
device_results_dir.mkdir(parents=True, exist_ok=True)

if ft_results_df.empty:
    print("No results to export.")
else:
    # Save FT@1024 results
    ft_csv = device_results_dir / f"granite-ttm-ft-ctx1024_{TRAINING_MODE}.csv"
    ft_results_df.to_csv(ft_csv, index=False)
    print(f"FT@1024 results CSV: {ft_csv}")

    if zs_results.empty:
        # No zero-shot baselines: export the FT results on their own.
        csv_path, latex_path = save_results(
            ft_results_df,
            results_dir,
            prefix="granite_ttm_ft_ctx1024",
        )
        print(f"FT-only CSV:   {csv_path}")
        print(f"FT-only LaTeX: {latex_path}")
        latex = results_to_latex(
            ft_results_df,
            caption="Granite TTM FT (ctx=1024) results",
            label="tab:granite-ft-ctx1024",
        )
        print(f"\n{latex}")
    else:
        # Save combined comparison with ZS
        comparison = compare_ft_vs_zero_shot(ft_results_df, zs_results)
        csv_path, latex_path = save_results(
            comparison,
            results_dir,
            prefix="granite_ttm_ft_ctx1024_comparison",
        )
        print(f"\nComparison CSV:   {csv_path}")
        print(f"Comparison LaTeX: {latex_path}")
        latex = results_to_latex(
            comparison,
            caption="Granite TTM FT (ctx=1024) vs zero-shot",
            label="tab:granite-ft-ctx1024",
        )
        print(f"\n{latex}")

# Save training metadata: one row per result, with the result's config
# entries expanded into additional columns.
if all_ft_results:
    meta_rows = [
        {
            "model": r.model_name,
            "machine": r.machine,
            "training_time_s": r.training_time_s,
            "best_epoch": r.best_epoch,
            "checkpoint_path": r.checkpoint_path,
            **r.config,
        }
        for r in all_ft_results
    ]
    meta_df = pd.DataFrame(meta_rows)
    meta_path = results_dir / f"granite_ttm_ft_ctx1024_meta_{TRAINING_MODE}.csv"
    meta_df.to_csv(meta_path, index=False)
    print(f"\nTraining metadata: {meta_path}")
    display(meta_df)

In [None]:
# === Save & Push Results ===
# Final commit of everything under the notebook's output directory. On Colab
# the commit is pushed when a GitHub token is available; on a local run only
# instructions are printed.
if IN_COLAB:
    os.chdir(REPO_DIR)

    subprocess.run(
        ["git", "add", "tick2/notebooks/output/03/"],
        check=True,
    )

    status = subprocess.run(
        ["git", "status", "--porcelain", "tick2/notebooks/output/03/"],
        capture_output=True,
        text=True,
    )
    if status.stdout.strip():
        msg = (
            "results: notebook 03a2 granite-ttm-ft-ctx1024"
            f" figures and combined results ({device_label})"
        )
        subprocess.run(
            ["git", "commit", "-m", msg],
            check=True,
        )
        if GITHUB_TOKEN:
            subprocess.run(
                ["git", "fetch", "-q", "origin"],
                capture_output=True,
                timeout=30,
            )
            try:
                rebased = (
                    subprocess.run(
                        ["git", "rebase", "origin/main"],
                        capture_output=True,
                        timeout=30,
                    ).returncode
                    == 0
                )
            except subprocess.TimeoutExpired:
                rebased = False
            if not rebased:
                # A failed or interrupted rebase leaves the repository
                # mid-rebase, which makes the push below fail. Abort it so
                # the branch is back on the commit just created.
                subprocess.run(
                    ["git", "rebase", "--abort"],
                    capture_output=True,
                    timeout=30,
                )
                print(
                    "WARNING: rebase onto origin/main failed and was aborted;"
                    " attempting push without it."
                )
            subprocess.run(["git", "push"], check=True)
            print("Pushed final outputs to GitHub.")
        else:
            print("Committed locally (no token for push).")
            print("Set Colab sidebar > Secrets > GITHUB_TOKEN")
    else:
        print("No new outputs to commit.")
else:
    print(f"Local run. Outputs saved to: {results_dir}")
    print(
        "Run 'git add tick2/notebooks/output/03/ && git commit && git push' to share."
    )