[Reference](https://medium.com/@kylejones_47003/timesfm-for-time-series-forecasting-in-python-using-oil-production-data-b0a59b89d3ff)

In [1]:
import timesfm
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import TimeSeriesSplit

# Load real oil production data
raw_df = pd.read_csv("north_dakota_production.csv")

# Select the two wells with the highest peak (nonzero) oil production.
# Ranking raw rows with nlargest can return the same well twice when one well
# owns both of the largest monthly values, so rank one value per well instead.
peak_oil_by_well = raw_df.loc[raw_df["Oil"] > 0].groupby("API_WELLNO")["Oil"].max()
top_wells = peak_oil_by_well.nlargest(2).index

# Rename to the unique_id / ds / y column names that forecast_on_df is given
# below, parse the dates, and order each well's history chronologically.
df = (
    raw_df[raw_df["API_WELLNO"].isin(top_wells)]
    .rename(columns={"API_WELLNO": "unique_id", "Date": "ds", "Oil": "y"})
    .assign(ds=lambda frame: pd.to_datetime(frame["ds"]))
    .sort_values(["unique_id", "ds"])
)

# Train-test split using TimeSeriesSplit.
# The frame holds several wells, so the split is done over the sorted unique
# timestamps rather than over row positions: every well is then cut at the same
# date and the test set is the final ~20% of the timeline.
# Only the last fold is used, so the minimum n_splits=2 is enough; n_splits=5
# with a 20% test_size raises ValueError whenever the sample count is a
# multiple of 5 (n_samples - n_splits * test_size must be positive).
dates = pd.DatetimeIndex(df["ds"].unique()).sort_values()
tscv = TimeSeriesSplit(n_splits=2, test_size=max(1, int(0.2 * len(dates))))
train_date_idx, test_date_idx = list(tscv.split(dates))[-1]
train_df = df[df["ds"].isin(dates[train_date_idx])]
test_df = df[df["ds"].isin(dates[test_date_idx])]

# Initialize TimesFM Model
tfm = timesfm.TimesFm(
    hparams=timesfm.TimesFmHparams(
        per_core_batch_size=32, horizon_len=128, input_patch_len=32, output_patch_len=128,
        num_layers=50, model_dims=1280, use_positional_embedding=False
    ),
    checkpoint=timesfm.TimesFmCheckpoint(huggingface_repo_id="google/timesfm-2.0-500m-pytorch"),
)

# Generate forecast from the training window only
forecast_raw = tfm.forecast_on_df(inputs=train_df, freq="M", value_name="y", num_jobs=-1)

# Aggregate forecast: average the per-well point forecasts for each timestamp
forecast_df = (
    forecast_raw
    .assign(ds=lambda frame: pd.to_datetime(frame["ds"]))
    .groupby("ds")["timesfm"]
    .mean()
    .reset_index()
)

# Restrict forecast to match test period
forecast_df = forecast_df[forecast_df["ds"].between(test_df["ds"].min(), test_df["ds"].max())]

In [2]:
# Plot results
# The forecast is the mean across wells per timestamp, so the actuals are
# averaged the same way. Plotting the raw multi-well rows would draw a single
# zig-zag line through both wells and would not be comparable to the forecast.
actual_all = df.groupby("ds")["y"].mean()
actual_test = test_df.groupby("ds")["y"].mean()

fig, ax = plt.subplots(figsize=(12, 6))
ax.plot(actual_all.index, actual_all.to_numpy(), label="Monthly Oil Production", color="black", alpha=0.3)
ax.plot(actual_test.index, actual_test.to_numpy(), label="Test Data", color="blue")
ax.plot(forecast_df["ds"], forecast_df["timesfm"], label="Forecast", color="red")

# A figure should stand alone: title, axis labels, and a legend so the
# label= arguments above are actually rendered.
ax.set_title("TimesFM forecast vs. actual oil production (mean across selected wells)")
ax.set_xlabel("Date")
ax.set_ylabel("Oil production")
ax.legend()

# Save and show
fig.savefig("timesfm_test_forecast_tufte.png", dpi=300)
plt.show()