# Stocks Forecasting — Train + Predict + Plot

This notebook:
1. Loads a `stocks-forecasting` config
2. Trains a Temporal Fusion Transformer (TFT) model with Darts
3. Loads the saved bundle
4. Produces log-return quantile forecasts for a selected symbol
5. Back-transforms to price-path quantiles and plots historical close + forecast band

Prerequisites:
- The Postgres instance from the `stocks` harness is running and reachable on `localhost:5432`
- The notebook kernel is the `repos/stocks-forecasting/.venv` environment (or any other environment with `stocks-forecasting` installed)


In [None]:
import os
import json
from pathlib import Path

import numpy as np
import pandas as pd


def find_repo_root(start: Path) -> Path:
    """Return the nearest directory at or above *start* that holds a ``pyproject.toml``.

    The search checks *start* itself first and then each of its parents,
    closest first.

    Raises:
        RuntimeError: if no candidate directory contains ``pyproject.toml``.
    """
    candidates = (start, *start.parents)
    match = next((d for d in candidates if (d / "pyproject.toml").exists()), None)
    if match is None:
        raise RuntimeError("Could not find repo root (missing pyproject.toml)")
    return match


# Anchor every path used by this notebook at the repository root.
ROOT = find_repo_root(Path.cwd())
print("Repo root:", ROOT)

# Local cache dirs (avoid writing to ~). setdefault leaves any value that is
# already exported in the environment untouched.
_LOCAL_DIRS = (
    ("XDG_CACHE_HOME", ".cache"),
    ("MPLCONFIGDIR", ".mplconfig"),
    ("TORCH_HOME", ".torch"),
)
for _var, _subdir in _LOCAL_DIRS:
    os.environ.setdefault(_var, str(ROOT / _subdir))
for _var, _ in _LOCAL_DIRS:
    # Create whichever directory the variable ends up naming (ours or a pre-set one).
    Path(os.environ[_var]).mkdir(parents=True, exist_ok=True)

# DB password for local docker-compose (override in your environment if different)
os.environ.setdefault("POSTGRES_PASSWORD", "postgres")

# Kept below the MPLCONFIGDIR setup so the variable is already set at import time.
import matplotlib.pyplot as plt

plt.rcParams.update({"figure.figsize": (12, 6)})


In [None]:
from stocks_forecasting.config.load import load_config
from stocks_forecasting.config.models import RunMode

# Configuration file used by this notebook, resolved relative to the repo root.
CONFIG_PATH = ROOT.joinpath("configs", "config.smoke.yaml")
config = load_config(CONFIG_PATH)

# Recommended for "serving-like" training runs:
# config.project.mode = RunMode.production

# Last expression of the cell: displays the loaded configuration.
config.model_dump()

In [None]:
from stocks_forecasting.artifacts import create_bundle_paths
from stocks_forecasting.training.train_tft import train_tft

# Toggle: True trains a new bundle; False reuses the newest bundle already on disk.
RUN_TRAINING = True

if RUN_TRAINING:
    bundle_paths = train_tft(config, artifacts_root=ROOT / config.artifacts.root_dir)
else:
    bundle_root = ROOT / config.artifacts.root_dir / config.artifacts.bundle_name
    # iterdir() raises FileNotFoundError on a missing directory; check first so an
    # absent bundle directory and an empty one both produce the same clear error.
    if bundle_root.is_dir():
        versions = sorted(p.name for p in bundle_root.iterdir() if p.is_dir())
    else:
        versions = []
    if not versions:
        raise RuntimeError(f"No bundles found under {bundle_root}")
    # NOTE(review): "latest" is the lexicographically greatest directory name;
    # presumably version names sort chronologically — confirm against the writer.
    bundle_paths = create_bundle_paths(
        root_dir=ROOT / config.artifacts.root_dir,
        bundle_name=config.artifacts.bundle_name,
        version=versions[-1],
    )

print("Bundle root:", bundle_paths.root)
bundle_paths

In [None]:
# Parse the bundle manifest. JSON is decoded explicitly as UTF-8 so the result
# does not depend on the platform's default locale encoding.
manifest = json.loads(bundle_paths.manifest_path.read_text(encoding="utf-8"))
manifest

In [None]:
from darts.models import TFTModel

# The trained network is read from tft_model.pt inside the bundle's model directory.
model_path = bundle_paths.model_dir / "tft_model.pt"
model = TFTModel.load(str(model_path))

loaded_class_name = type(model).__name__
print("Loaded model:", loaded_class_name)
print("Model path:", model_path)


In [None]:
from darts import TimeSeries

from stocks_forecasting.calendars import build_trading_calendar
from stocks_forecasting.dataset.metadata import add_market_cap_bucket
from stocks_forecasting.dataset.prepare_symbol import build_symbol_feature_frame
from stocks_forecasting.db import PostgresClient

# Build the Darts inputs for one symbol: the log-return target series, optional
# past/future covariate series, and (if the model was trained with them) static
# covariates. Also records the forecast origin (last price row) for plotting.

SYMBOL = "AAPL"  # symbol to forecast
PLOT_HISTORY_POINTS = 300  # number of trailing price rows shown by the plotting cell
NUM_SAMPLES = 500  # passed as num_samples to model.predict in the forecasting cell

client = PostgresClient(config.database)

meta = client.fetch_stock_metadata(symbols=[SYMBOL])
if meta.empty:
    raise ValueError(f"Unknown symbol or missing metadata: {SYMBOL}")
meta = add_market_cap_bucket(meta)

# Exchange MIC is optional; the calendar builder receives None when it is missing.
# NOTE(review): .loc[0, ...] is label-based, so this assumes the metadata frame
# carries a default 0-based index — confirm.
exchange_mic = None
if pd.notna(meta.loc[0, "exchange_mic"]):
    exchange_mic = str(meta.loc[0, "exchange_mic"])

calendar = build_trading_calendar(config.features.calendar, exchange_mic=exchange_mic)

# as_of_date is summary.end_time truncated to midnight UTC.
summary = client.fetch_price_summary(SYMBOL, price_type=config.data.price_type)
as_of_date = pd.to_datetime(summary.end_time, utc=True).normalize()
print("as_of_date:", as_of_date.date(), "rows:", summary.rows)

# NOTE(review): end_time is midnight of the last date; if stored bar timestamps
# carry an intraday time, whether the final bar is returned depends on how
# fetch_daily_prices compares against end_time — confirm.
prices = client.fetch_daily_prices(SYMBOL, price_type=config.data.price_type, end_time=as_of_date.to_pydatetime())
# Truncate timestamps to midnight UTC, then keep one row per day (the last one
# in time order).
prices["time"] = pd.to_datetime(prices["time"], utc=True).dt.normalize()
prices = (
    prices.dropna(subset=["time"])  # defensive: drop rows without a timestamp
    .sort_values("time")
    .drop_duplicates("time", keep="last")
    .reset_index(drop=True)
)

# Forecast origin: the last available price row. Later cells compound forecast
# log returns from origin_close and draw a vertical line at origin_time.
origin_time = prices["time"].iloc[-1]
origin_close = float(prices["close"].iloc[-1])

built = build_symbol_feature_frame(prices, config, calendar=calendar)
frame = built.frame.sort_values("time").reset_index(drop=True)
# Consecutive integer position in time order; used as the time axis of every
# TimeSeries below instead of the timestamp column.
frame["step"] = np.arange(len(frame), dtype="int32")

# Rows flagged is_future == 0 are treated as observed history; rows with a
# missing target or past covariate are removed from it.
# NOTE(review): dropping rows from the middle or end of the history would leave
# gaps in "step" and/or move the end of the target series away from the
# origin_close row — presumably only leading warm-up rows contain NaN; confirm.
observed = frame[frame["is_future"] == 0].copy()
required = ["log_return", *built.past_covariate_columns]
observed = observed.dropna(subset=required)

# Target and past covariates come from the observed rows only; future covariates
# come from the full frame (observed and is_future rows).
target_df = observed[["step", "log_return"]]
past_cov_df = observed[["step", *built.past_covariate_columns]] if built.past_covariate_columns else None
future_cov_df = frame[["step", *built.future_covariate_columns]] if built.future_covariate_columns else None

# Convert to float32 Darts series; a covariate series is None when the feature
# builder reported no columns of that kind.
target_ts = TimeSeries.from_dataframe(target_df, time_col="step", value_cols=["log_return"]).astype(np.float32)
past_cov_ts = (
    TimeSeries.from_dataframe(past_cov_df, time_col="step", value_cols=built.past_covariate_columns).astype(np.float32)
    if past_cov_df is not None
    else None
)
future_cov_ts = (
    TimeSeries.from_dataframe(future_cov_df, time_col="step", value_cols=built.future_covariate_columns).astype(np.float32)
    if future_cov_df is not None
    else None
)

# Attach static covariates aligned to training dummy columns (if present).
# Missing metadata values become the literal category "unknown" before one-hot
# encoding.
static_cols = ["exchange_mic", "sector", "industry", "country_code", "currency", "market_cap_bucket"]
row = meta.loc[0, static_cols].fillna("unknown").astype(str)
static_row_df = pd.DataFrame([{c: row[c] for c in static_cols}])
static_dum = pd.get_dummies(static_row_df[static_cols], prefix=static_cols, dtype="int8")

# The manifest lists the dummy columns used in training. reindex keeps exactly
# those columns in that order: dummies absent for this symbol are filled with 0
# and dummies not listed in the manifest are dropped. When the manifest lists
# none, no static covariates are attached.
trained_static_cols = manifest.get("features", {}).get("static_covariates", []) or []
if trained_static_cols:
    static_dum = static_dum.reindex(columns=trained_static_cols, fill_value=0)
    target_ts = target_ts.with_static_covariates(static_dum)
    if past_cov_ts is not None:
        past_cov_ts = past_cov_ts.with_static_covariates(static_dum)
    if future_cov_ts is not None:
        future_cov_ts = future_cov_ts.with_static_covariates(static_dum)

print("Prepared observed points:", len(target_ts))
print("Past covariates:", past_cov_ts is not None, "Future covariates:", future_cov_ts is not None)
print("Origin:", origin_time.date(), "close=", origin_close)


In [None]:
def timeseries_to_frame(ts, index_name: str = "step") -> pd.DataFrame:
    """Flatten *ts* into a DataFrame whose first column holds the former index.

    Args:
        ts: object exposing ``to_dataframe()`` (e.g. a Darts ``TimeSeries``).
        index_name: name given to the column produced from the index.

    Returns:
        The data with the index moved into a leading column named *index_name*.
    """
    flat = ts.to_dataframe().reset_index()
    former_index_col = flat.columns[0]
    return flat.rename(columns={former_index_col: index_name})


def forecast_to_quantile_frame(forecast, quantiles: list[float], index_name: str = "step") -> pd.DataFrame:
    """Collect the requested quantiles of *forecast* into one wide DataFrame.

    Args:
        forecast: object exposing ``quantile(q)``; each result is flattened with
            ``timeseries_to_frame``.
        quantiles: quantile levels to extract; each becomes a ``q<level>`` column.
        index_name: name of the shared index column used to join the quantiles.

    Returns:
        One row per index value, sorted by *index_name*, with one column per
        quantile plus a 1-based ``horizon_step`` counter.

    Raises:
        ValueError: if a quantile series has more than one value column, or if
            nothing was extracted (no quantiles, or an empty joined frame).
    """
    merged = None
    for level in quantiles:
        level_df = timeseries_to_frame(forecast.quantile(level), index_name=index_name)
        components = [name for name in level_df.columns if name != index_name]
        if len(components) != 1:
            raise ValueError("Expected a single component when extracting quantiles")
        level_df = level_df.rename(columns={components[0]: f"q{level:g}"})
        if merged is None:
            merged = level_df
        else:
            # Inner join: only index values present for every quantile survive.
            merged = merged.merge(level_df, on=index_name, how="inner")

    if merged is None or merged.empty:
        raise ValueError("No quantiles extracted")

    merged = merged.sort_values(index_name).reset_index(drop=True)
    step_numbers = np.arange(1, len(merged) + 1, dtype="int32")
    return merged.assign(horizon_step=step_numbers)


# Forecast length and quantile levels come from the model section of the config.
horizon = int(config.model.horizon_days)
quantiles = list(map(float, config.model.quantiles))

# Probabilistic forecast: NUM_SAMPLES samples per step over the horizon.
predict_kwargs = {
    "n": horizon,
    "series": target_ts,
    "past_covariates": past_cov_ts,
    "future_covariates": future_cov_ts,
    "num_samples": NUM_SAMPLES,
}
forecast = model.predict(**predict_kwargs)

# Per-step log-return quantiles, then attach each step's timestamp and is_future flag.
step_lookup = frame[["step", "time", "is_future"]]
pred_lr = forecast_to_quantile_frame(forecast, quantiles, index_name="step")
pred_lr = pred_lr.merge(step_lookup, on="step", how="left")
pred_lr

In [None]:
# Back-transform log-return forecasts -> price-path quantiles.
#
# The cumulative log return at step h is the sum of the per-step returns along
# one sample path. Cumulatively summing the per-step quantiles instead would
# treat every step as perfectly dependent and overstate the band, so each
# sample path is accumulated first and quantiles are taken across paths.
pred_prices = pd.DataFrame({"time": pd.to_datetime(pred_lr["time"], utc=True)})

# all_values() has shape (time, component, sample); the target has one component.
sample_lr = forecast.all_values(copy=False)[:, 0, :].astype("float64")
if sample_lr.shape[0] != len(pred_lr):
    raise ValueError("Forecast samples do not align with the per-step quantile frame")

# Price of every sample path at every step, compounded from the origin close.
price_paths = origin_close * np.exp(np.cumsum(sample_lr, axis=0))
for q in quantiles:
    pred_prices[f"q{q:g}"] = np.quantile(price_paths, q, axis=1)

pred_prices

In [None]:
# Plot historical close + forecast band
hist = prices.tail(PLOT_HISTORY_POINTS).copy()

# Band edges are the extreme quantile levels; the line is the middle level
# (upper-middle when the number of levels is even).
ordered_levels = sorted(quantiles)
q_low = ordered_levels[0]
q_high = ordered_levels[-1]
q_mid = ordered_levels[len(quantiles) // 2]

forecast_times = pred_prices["time"]
lower_band = pred_prices[f"q{q_low:g}"]
upper_band = pred_prices[f"q{q_high:g}"]
central_path = pred_prices[f"q{q_mid:g}"]

fig, ax = plt.subplots()
ax.plot(hist["time"], hist["close"], label="Historical close", color="black", linewidth=1.5)
ax.axvline(origin_time, color="gray", linestyle="--", linewidth=1, label="Forecast origin")

ax.plot(forecast_times, central_path, label=f"Forecast q{q_mid:g}", color="tab:blue")
ax.fill_between(
    forecast_times,
    lower_band,
    upper_band,
    color="tab:blue",
    alpha=0.2,
    label=f"Interval q{q_low:g}–q{q_high:g}",
)

ax.set_title(f"{SYMBOL} close price forecast ({horizon} sessions)")
ax.set_ylabel("Close")
ax.grid(True, alpha=0.3)
ax.legend()
plt.show()
