# ARIMA Forecast Example (Small Subset)

This notebook demonstrates building a simple next-month stock price forecast using an ARIMA model on a small subset of the large dataset (`data/ret_sample.csv`). We:

- Load a few tickers and dates to keep memory small
- Aggregate to monthly prices for one company
- Fit an ARIMA model
- Forecast next month and visualize


In [1]:
# Imports
import sys
import subprocess
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

# Ensure statsmodels is available
try:
    import statsmodels.api as sm
    from statsmodels.tsa.arima.model import ARIMA
except Exception:
    subprocess.check_call([sys.executable, "-m", "pip", "install", "statsmodels==0.14.2"], stdout=subprocess.DEVNULL)
    import statsmodels.api as sm
    from statsmodels.tsa.arima.model import ARIMA

pd.set_option("display.max_rows", 10)
pd.set_option("display.width", 120)

project_root = Path(__file__).resolve().parents[2]
data_path = project_root / "ret_sample.csv"
print(f"Using data at: {data_path}")


Pyarrow will become a required dependency of pandas in the next major release of pandas (pandas 3.0),
(to allow more performant data types, such as the Arrow string type, and better interoperability with other libraries)
but was not found to be installed on your system.
If this would cause problems for you,
please provide us feedback at https://github.com/pandas-dev/pandas/issues/54466
        
  import pandas as pd


NameError: name '__file__' is not defined

In [None]:
# Chunked loader: detect schema and sample a few tickers
from typing import List, Optional, Tuple

def detect_columns(cols: List[str]) -> Tuple[Optional[str], Optional[str], Optional[str], Optional[str]]:
    lower = {c.lower(): c for c in cols}
    # Likely names
    date_candidates = ["date", "datadate", "pricedate", "time", "datetime"]
    id_candidates = ["asset_id", "ticker", "permno", "gvkey", "id", "asset", "symbol"]
    price_candidates = ["price", "prc", "adj_close", "adjclose", "close", "px_last"]
    ret_candidates = ["ret", "return", "stock_ret", "retx", "simple_return"]

    def pick(cands):
        for c in cands:
            if c in lower:
                return lower[c]
        return None

    return pick(date_candidates), pick(id_candidates), pick(price_candidates), pick(ret_candidates)


def read_small_subset(csv_path: Path, max_ids: int = 3, max_rows: int = 200_000) -> pd.DataFrame:
    # Peek to detect schema and sample ids
    head_df = pd.read_csv(csv_path, nrows=2000)
    date_col, id_col, price_col, ret_col = detect_columns(list(head_df.columns))
    if id_col is None or date_col is None:
        raise ValueError("Could not detect required date/id columns. Please rename or provide a smaller sample.")

    # Choose up to max_ids from the head sample
    candidate_ids = head_df[id_col].dropna().astype(str).unique().tolist()[:max_ids]
    print(f"Detected columns -> date: {date_col}, id: {id_col}, price: {price_col}, ret: {ret_col}")
    print(f"Selected {len(candidate_ids)} ids: {candidate_ids}")

    usecols = [date_col, id_col]
    if price_col: usecols.append(price_col)
    if ret_col: usecols.append(ret_col)

    chunks = []
    total = 0
    for chunk in pd.read_csv(csv_path, usecols=usecols, chunksize=200_000):
        mask = chunk[id_col].astype(str).isin(candidate_ids)
        filtered = chunk.loc[mask]
        chunks.append(filtered)
        total += len(filtered)
        if total >= max_rows:
            break

    df = pd.concat(chunks, ignore_index=True)
    # Parse dates and sort
    df[date_col] = pd.to_datetime(df[date_col], errors='coerce')
    df = df.dropna(subset=[date_col, id_col]).sort_values([id_col, date_col]).reset_index(drop=True)
    return df, date_col, id_col, price_col, ret_col

subset_df, DATE_COL, ID_COL, PRICE_COL, RET_COL = read_small_subset(data_path)
subset_df.head()


In [None]:
# Prepare one company monthly series

# Choose the first id with enough observations
counts = subset_df.groupby(ID_COL)[DATE_COL].count().sort_values(ascending=False)
chosen_id = counts.index[0]
print(f"Chosen id: {chosen_id} (n={counts.iloc[0]})")

company_df = subset_df[subset_df[ID_COL].astype(str) == str(chosen_id)].copy()

# Prefer price; if missing, reconstruct from returns assuming base 100
if PRICE_COL and PRICE_COL in company_df.columns:
    company_df = company_df[[DATE_COL, PRICE_COL]].dropna()
    company_df = company_df.rename(columns={PRICE_COL: "price"})
else:
    if RET_COL is None or RET_COL not in company_df.columns:
        raise ValueError("Neither price nor returns found for the selected id.")
    tmp = company_df[[DATE_COL, RET_COL]].dropna().rename(columns={RET_COL: "ret"})
    tmp = tmp.sort_values(DATE_COL)
    tmp["price"] = 100 * (1 + tmp["ret"]).cumprod()
    company_df = tmp[[DATE_COL, "price"]]

company_df = company_df.sort_values(DATE_COL)
company_df = company_df.set_index(DATE_COL)

# Aggregate to month-end price
monthly = company_df.resample("M").last().dropna()
print(monthly.head())
print(monthly.tail())


In [None]:
# Fit ARIMA(1,1,1) and forecast next month

# Check we have enough data
if len(monthly) < 24:
    print(f"Warning: only {len(monthly)} monthly points; ARIMA may be unstable.")

# Build and fit model
series = monthly["price"].asfreq("M")
model = ARIMA(series, order=(1,1,1))
results = model.fit()
print(results.summary())

# Forecast next 1 month
n_steps = 1
forecast_res = results.get_forecast(steps=n_steps)
forecast_mean = forecast_res.predicted_mean
forecast_ci = forecast_res.conf_int(alpha=0.05)

next_month = forecast_mean.index[-1]
print(f"Next month forecast date: {next_month.date()} value: {forecast_mean.iloc[-1]:.4f}")
print(f"95% CI: [{forecast_ci.iloc[-1,0]:.4f}, {forecast_ci.iloc[-1,1]:.4f}]")


In [None]:
# Plot historical and forecast
fig, ax = plt.subplots(figsize=(9, 4))
series.plot(ax=ax, label="History")
forecast_mean.plot(ax=ax, style="r--", label="Forecast")
ax.fill_between(forecast_ci.index, forecast_ci.iloc[:,0], forecast_ci.iloc[:,1], color="r", alpha=0.2, label="95% CI")
ax.set_title(f"ARIMA(1,1,1) forecast for {chosen_id}")
ax.set_xlabel("Date")
ax.set_ylabel("Price")
ax.legend()
plt.tight_layout()
plt.show()
