## Intended to be run in Google Colab (the cell below uses the `!pip` notebook magic)

In [None]:
# ========================================================
# Chronos-Bolt-small target-only block (leakage-safe, h=1)
# - Model: amazon/chronos-bolt-small (zero/few-shot quantile API)
# - Warm-up = 60 observations
# - Context: CPU tensor (avoids device mismatch)
# - Features per t: (mean, p50, std≈(p90-p10)/(2*Phi^-1(0.9)), p10, p90)
#   -> index = t; the forecast is for t+1
# ========================================================

!pip -q install -U chronos-forecasting pandas pyarrow

# --- Params ---
# Input / output locations
TARGET_PATH = "target.csv"            # .csv or .parquet file with the target series
INDEX_COL = "date"                    # column that becomes the DatetimeIndex
Y_COL = "IP_change"                   # target column to forecast
OUTPUT_PATH = "chronos_bolt.parquet"  # where the feature frame is written

# Model / run settings
MODEL_ID = "amazon/chronos-bolt-small"
USE_GPU = True                        # CUDA is used only if it is also available
INITIAL_WINDOW = 60                   # warm-up length before the first forecast
SEED = 42

# --- Imports & Seeds ---
from pathlib import Path

import numpy as np
import pandas as pd
import torch

torch.manual_seed(SEED)
np.random.seed(SEED)

# --- Data utils ---
def load_y(path, y_col, index_col="date"):
    """Load the target series from a CSV or Parquet file.

    Parameters
    ----------
    path : str or path-like
        File ending in ``.csv`` or ``.parquet`` (case-insensitive).
    y_col : str
        Column to return. If it is missing, the first column is used
        instead and a ``UserWarning`` is emitted.
    index_col : str, default "date"
        Column holding the timestamps; it becomes the index.

    Returns
    -------
    pandas.Series
        Float series on a DatetimeIndex, sorted in ascending time order.

    Raises
    ------
    ValueError
        If the extension is unsupported, the index is not a DatetimeIndex,
        or the index contains duplicate timestamps.
    """
    import warnings

    path = str(path)
    lowered = path.lower()
    if lowered.endswith(".csv"):
        obj = pd.read_csv(path, index_col=index_col, parse_dates=True)
    elif lowered.endswith(".parquet"):
        obj = pd.read_parquet(path)
        if index_col in obj.columns:
            obj[index_col] = pd.to_datetime(obj[index_col])
            obj = obj.set_index(index_col)
    else:
        raise ValueError("Use .csv or .parquet")

    if y_col in obj.columns:
        s = obj[y_col]
    else:
        # Preserve the historical fallback, but make it visible: silently
        # forecasting the wrong column is hard to notice downstream.
        warnings.warn(
            f"column {y_col!r} not found in {path}; "
            f"falling back to first column {obj.columns[0]!r}",
            stacklevel=2,
        )
        s = obj.iloc[:, 0]

    if not isinstance(s.index, pd.DatetimeIndex):
        raise ValueError("Target needs DatetimeIndex")
    # Duplicate timestamps would otherwise only fail at the final reindex,
    # after every model call has already run.
    if s.index.has_duplicates:
        dupes = s.index[s.index.duplicated()].unique()
        raise ValueError(
            f"Target index has duplicate timestamps, e.g. {list(dupes[:3])}"
        )
    # Rolling forecasts treat position t+1 as the step after t, so the series
    # must be in chronological order for the context to be causal.
    return s.sort_index().astype(float)

# --- Load target ---
y = load_y(path=TARGET_PATH, y_col=Y_COL, index_col=INDEX_COL)
idx = y.index  # timestamps of the target; the output frame is reindexed to these
m = y.size     # number of observations

# --- Chronos-Bolt Pipeline (Quantile-API) ---
from chronos import BaseChronosPipeline

# Run on CUDA only when it is both requested and available.
if USE_GPU and torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print(f"[Chronos-Bolt] model device: {DEVICE}")

pipe = BaseChronosPipeline.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float32,  # parameter name as used in the chronos README
    device_map=DEVICE,          # place the model weights on GPU/CPU
)

# --- Rolling 1-step quantiles ---
# For every t from the end of the warm-up (INITIAL_WINDOW - 1) to m - 2 the
# model sees only y[0..t] and predicts y[t+1]; the features are stored under
# timestamp t. No row is produced for the last timestamp (t = m - 1).
q_levels = [0.10, 0.50, 0.90]    # order must match the p10, p50, p90 unpacking below
ZP = 2.0 * 1.2815515655446004    # 2 * Phi^-1(0.9): 10-90% spread of a normal in std units

# Convert once; slicing a NumPy array per step is cheaper than slicing the Series.
y_np = y.to_numpy(dtype=np.float32)

rows = []
with torch.inference_mode():
    for t in range(INITIAL_WINDOW - 1, m - 1):
        # Context only up to and including t (causal); deliberately a CPU tensor
        ctx = torch.tensor(y_np[: t + 1], dtype=torch.float32)

        # The context is passed positionally because the name of this first
        # parameter has differed between chronos-forecasting releases
        # (`context` in the 1.x README).
        # NOTE(review): confirm against the installed version.
        quantiles, mean = pipe.predict_quantiles(
            ctx,
            prediction_length=1,
            quantile_levels=q_levels,
        )

        # Expected shapes: quantiles [B, H, Q], mean [B, H] with B = H = 1.
        # Validate instead of guessing, so an API change cannot silently
        # yield mislabelled quantiles.
        q = quantiles.detach().float().cpu().numpy().reshape(-1)
        mu_arr = mean.detach().float().cpu().numpy().reshape(-1)
        if q.size != len(q_levels) or mu_arr.size != 1:
            raise RuntimeError(
                "[Chronos-Bolt] unexpected predict_quantiles output shapes: "
                f"quantiles={tuple(quantiles.shape)}, mean={tuple(mean.shape)}"
            )

        p10, p50, p90 = map(float, q)
        mu = float(mu_arr[0])
        # Clamp at 0 in case the predicted quantiles cross (p90 < p10).
        std = max(p90 - p10, 0.0) / ZP

        ts = idx[t]  # feature timestamp t (forecast for t+1)
        rows.append((ts, mu, p50, std, p10, p90))

        if len(rows) % 50 == 0:
            print(f"[Chronos-Bolt] step {len(rows)} @ t={t} -> {ts}")

# --- Frame & Save ---
feature_cols = ["chronos_mean", "chronos_p50", "chronos_std", "chronos_p10", "chronos_p90"]
DF = (
    pd.DataFrame(rows, columns=["date", *feature_cols])
    .set_index("date")
    # Full timeline: warm-up rows (and the final timestamp) have no forecast -> NaN.
    .reindex(idx)
    .astype("float32")
)

out_path = Path(OUTPUT_PATH)
out_path.parent.mkdir(parents=True, exist_ok=True)
DF.to_parquet(OUTPUT_PATH)
print("[Chronos-Bolt] wrote:", OUTPUT_PATH, DF.shape)