# **Finetuning TimesFM 2.0 on Stock Data**
**IMPORTANT:** Run this notebook using Vertex AI Workbench.

This notebook demonstrates how to finetune a [TimesFM](https://github.com/google-research/google-research/tree/master/times_fm) model (or a variant) on a custom time series dataset. We'll specifically show how to:

1. Set up dependencies and prerequisites.
2. Define a flexible framework for finetuning TimesFM models on your own time-series data.
3. Download stock data (e.g., AAPL) with [yfinance](https://pypi.org/project/yfinance/).
4. Optionally log training progress with [Weights & Biases](https://wandb.ai/site) (W&B).
5. Visualize model predictions vs. ground truth.

We'll use a simplified single-GPU example, but the framework also supports distributed training if desired.

## **1. Prerequisites**

Before running this notebook, ensure you have the following libraries installed:

- `torch` (PyTorch)
- `timesfm` (version 1.2.6 or later)
- `yfinance` (for data fetching)
- `wandb` (optional, for logging)

Below are example commands to install these packages. Run the cell if you haven't installed them in your environment.

> **Note:** If you already have these packages installed or prefer alternative versions, skip these cells or adjust as needed.


In [1]:
%%capture 

!pip install torch
!pip install "timesfm[torch]==1.2.6"
!pip install yfinance
!pip install wandb

### **Weights & Biases Setup**
If you intend to log metrics to W&B, you can store your access token in an environment variable or specify it directly. In practice, you might run something like:

```
import os
os.environ['WANDB_API_KEY'] = 'YOUR_WANDB_ACCESS_TOKEN'
```

You can also enter your W&B credentials when prompted.
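
One way to avoid hard-coding the token is to read it from the environment and fall back to W&B's disabled mode when it is missing. The sketch below is illustrative only: `resolve_wandb_mode` is a hypothetical helper, not part of this notebook or the W&B API. It only picks the string you might pass as `mode=` to `wandb.init`, and it deliberately ignores other login methods such as an interactive prompt.

```python
import os
from typing import Mapping, Optional


def resolve_wandb_mode(env: Optional[Mapping[str, str]] = None) -> str:
    """Return "online" when WANDB_API_KEY is set and non-empty, else "disabled".

    Hypothetical helper for illustration; pass the result as `mode=` to wandb.init.
    """
    env = os.environ if env is None else env
    api_key = env.get("WANDB_API_KEY", "").strip()
    return "online" if api_key else "disabled"


# Decide the mode from explicit mappings, without touching the real environment.
print(resolve_wandb_mode({"WANDB_API_KEY": "YOUR_WANDB_ACCESS_TOKEN"}))  # online
print(resolve_wandb_mode({}))  # disabled
```

Keeping the decision in one place makes it easier to run the notebook on machines where no W&B account is configured.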


## **2. Imports**

In [2]:
import logging
import os
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional, Tuple

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, Dataset

import yfinance as yf
import numpy as np
import pandas as pd
import wandb

from huggingface_hub import snapshot_download
from timesfm import TimesFm, TimesFmCheckpoint, TimesFmHparams
from timesfm.pytorch_patched_decoder import PatchedTimeSeriesDecoder


TimesFM v1.2.0. See https://github.com/google-research/timesfm/blob/master/README.md for updated APIs.
Loaded Jax TimesFM.
Loaded PyTorch TimesFM.


In [3]:
import os 

os.environ['WANDB_API_KEY'] = 'xxxxxxxxxx'
os.environ["WANDB_NOTEBOOK_NAME"] = "./TimesFM-Examples/src/tune/02-finetune.ipynb"

wandb.init(project="TimesFM-Examples")

wandb: Currently logged in as: shankar-arunp. Use `wandb login --relogin` to force relogin
wandb: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.


## **3. Finetuning Framework**

Below is a flexible framework for training or finetuning a TimesFM model on custom time-series data. The code supports single- and multi-GPU (distributed) training.
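
In the distributed path of the framework below, `_create_dataloader` hands each rank a `DistributedSampler`, so every GPU iterates over its own shard of the dataset. The sketch here approximates that sharding in plain Python for the unshuffled, `drop_last=False` case, as we understand the sampler's behavior. `shard_indices` is a hypothetical helper written for illustration; it is not part of PyTorch or of this notebook.

```python
import math
from typing import List


def shard_indices(num_samples: int, world_size: int, rank: int) -> List[int]:
    """Approximate DistributedSampler sharding (shuffle=False, drop_last=False)."""
    if num_samples <= 0 or world_size <= 0 or not 0 <= rank < world_size:
        raise ValueError("invalid num_samples, world_size, or rank")
    per_rank = math.ceil(num_samples / world_size)
    total = per_rank * world_size
    # Pad by wrapping around so every rank receives the same number of samples.
    padded = [i % num_samples for i in range(total)]
    # Each rank takes every world_size-th index, starting at its own rank.
    return padded[rank:total:world_size]


for rank in range(4):
    print(rank, shard_indices(10, world_size=4, rank=rank))
# 0 [0, 4, 8]
# 1 [1, 5, 9]
# 2 [2, 6, 0]
# 3 [3, 7, 1]
```

Because the padding wraps around, a few samples can be seen twice per epoch when the dataset size is not divisible by the number of GPUs.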


In [4]:
class MetricsLogger(ABC):
    """Abstract base class for logging metrics during training."""

    @abstractmethod
    def log_metrics(self, metrics: Dict[str, Any], step: Optional[int] = None) -> None:
        """Log metrics to the specified backend."""
        pass

    @abstractmethod
    def close(self) -> None:
        """Clean up any resources used by the logger."""
        pass

class WandBLogger(MetricsLogger):
    """Weights & Biases implementation of metrics logging."""

    def __init__(self, project: str, config: Dict[str, Any], rank: int = 0):
        self.rank = rank
        if rank == 0:
            wandb.init(project=project, config=config)

    def log_metrics(self, metrics: Dict[str, Any], step: Optional[int] = None) -> None:
        if self.rank == 0:
            wandb.log(metrics, step=step)

    def close(self) -> None:
        if self.rank == 0:
            wandb.finish()

class DistributedManager:
    """Manages distributed training setup and cleanup."""

    def __init__(
        self,
        world_size: int,
        rank: int,
        master_addr: str = "localhost",
        master_port: str = "12358",
        backend: str = "nccl",
    ):
        self.world_size = world_size
        self.rank = rank
        self.master_addr = master_addr
        self.master_port = master_port
        self.backend = backend

    def setup(self) -> None:
        os.environ["MASTER_ADDR"] = self.master_addr
        os.environ["MASTER_PORT"] = self.master_port

        if not dist.is_initialized():
            dist.init_process_group(backend=self.backend, world_size=self.world_size, rank=self.rank)

    def cleanup(self) -> None:
        if dist.is_initialized():
            dist.destroy_process_group()

@dataclass
class FinetuningConfig:
    """Configuration for model training."""
    batch_size: int = 32
    num_epochs: int = 20
    learning_rate: float = 1e-4
    weight_decay: float = 0.01
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    distributed: bool = False
    gpu_ids: List[int] = field(default_factory=lambda: [0])
    master_port: str = "12358"
    master_addr: str = "localhost"
    use_wandb: bool = False
    wandb_project: str = "timesfm-finetuning"

class TimesFMFinetuner:
    """Handles model training and validation for TimesFM."""

    def __init__(
        self,
        model: nn.Module,
        config: FinetuningConfig,
        rank: int = 0,
        loss_fn: Optional[Callable] = None,
        logger: Optional[logging.Logger] = None,
    ):
        self.model = model
        self.config = config
        self.rank = rank
        self.logger = logger or logging.getLogger(__name__)
        self.device = torch.device(f"cuda:{rank}" if torch.cuda.is_available() else "cpu")
        self.loss_fn = loss_fn or (lambda x, y: torch.mean((x - y.squeeze(-1)) ** 2))

        if config.use_wandb:
            self.metrics_logger = WandBLogger(config.wandb_project, config.__dict__, rank)

        if config.distributed:
            self.dist_manager = DistributedManager(
                world_size=len(config.gpu_ids),
                rank=rank,
                master_addr=config.master_addr,
                master_port=config.master_port,
            )
            self.dist_manager.setup()
            self.model = self._setup_distributed_model()

    def _setup_distributed_model(self) -> nn.Module:
        self.model = self.model.to(self.device)
        return DDP(
            self.model,
            device_ids=[self.config.gpu_ids[self.rank]],
            output_device=self.config.gpu_ids[self.rank]
        )

    def _create_dataloader(self, dataset: Dataset, is_train: bool) -> DataLoader:
        if self.config.distributed:
            sampler = torch.utils.data.distributed.DistributedSampler(
                dataset,
                num_replicas=len(self.config.gpu_ids),
                rank=dist.get_rank(),
                shuffle=is_train
            )
        else:
            sampler = None

        return DataLoader(
            dataset,
            batch_size=self.config.batch_size,
            shuffle=(is_train and not self.config.distributed),
            sampler=sampler,
        )

    def _process_batch(self, batch: List[torch.Tensor]) -> tuple:
        x_context, x_padding, freq, x_future = [t.to(self.device, non_blocking=True) for t in batch]

        predictions = self.model(x_context, x_padding.float(), freq)
        predictions_mean = predictions[..., 0]
        last_patch_pred = predictions_mean[:, -1, :]

        loss = self.loss_fn(last_patch_pred, x_future.squeeze(-1))
        return loss, predictions

    def _train_epoch(self, train_loader: DataLoader, optimizer: torch.optim.Optimizer) -> float:
        self.model.train()
        total_loss = 0.0

        for batch in train_loader:
            loss, _ = self._process_batch(batch)

            if self.config.distributed:
                losses = [torch.zeros_like(loss) for _ in range(dist.get_world_size())]
                dist.all_gather(losses, loss)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        return total_loss / len(train_loader)

    def _validate(self, val_loader: DataLoader) -> float:
        self.model.eval()
        total_loss = 0.0

        with torch.no_grad():
            for batch in val_loader:
                loss, _ = self._process_batch(batch)

                if self.config.distributed:
                    losses = [torch.zeros_like(loss) for _ in range(dist.get_world_size())]
                    dist.all_gather(losses, loss)

                total_loss += loss.item()

        return total_loss / len(val_loader)

    def finetune(self, train_dataset: Dataset, val_dataset: Dataset) -> Dict[str, Any]:
        self.model = self.model.to(self.device)
        train_loader = self._create_dataloader(train_dataset, is_train=True)
        val_loader = self._create_dataloader(val_dataset, is_train=False)

        optimizer = torch.optim.Adam(
            self.model.parameters(),
            lr=self.config.learning_rate,
            weight_decay=self.config.weight_decay
        )

        history = {"train_loss": [], "val_loss": [], "learning_rate": []}

        self.logger.info(f"Starting training for {self.config.num_epochs} epochs...")
        self.logger.info(f"Training samples: {len(train_dataset)}")
        self.logger.info(f"Validation samples: {len(val_dataset)}")

        try:
            for epoch in range(self.config.num_epochs):
                train_loss = self._train_epoch(train_loader, optimizer)
                val_loss = self._validate(val_loader)
                current_lr = optimizer.param_groups[0]["lr"]

                metrics = {
                    "train_loss": train_loss,
                    "val_loss": val_loss,
                    "learning_rate": current_lr,
                    "epoch": epoch + 1,
                }

                if self.config.use_wandb:
                    self.metrics_logger.log_metrics(metrics)

                history["train_loss"].append(train_loss)
                history["val_loss"].append(val_loss)
                history["learning_rate"].append(current_lr)

                if self.rank == 0:
                    self.logger.info(f"[Epoch {epoch+1}] Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")

        except KeyboardInterrupt:
            self.logger.info("Training interrupted by user")

        if self.config.distributed:
            self.dist_manager.cleanup()

        if self.config.use_wandb:
            self.metrics_logger.close()

        return {"history": history}
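
The `MetricsLogger` interface above only requires `log_metrics` and `close`, so other backends can stand in for W&B. The following is a minimal sketch of an in-memory logger with the same two methods. `InMemoryLogger` is a hypothetical name, and the metric values are made up for illustration. It is duck-typed rather than subclassing `MetricsLogger` so the snippet stays self-contained.

```python
from typing import Any, Dict, List, Optional


class InMemoryLogger:
    """Hypothetical logger that keeps metrics in a list instead of sending them to W&B."""

    def __init__(self) -> None:
        self.records: List[Dict[str, Any]] = []
        self.closed = False

    def log_metrics(self, metrics: Dict[str, Any], step: Optional[int] = None) -> None:
        # Store a copy so later mutation of `metrics` does not change the record.
        record = dict(metrics)
        if step is not None:
            record["step"] = step
        self.records.append(record)

    def close(self) -> None:
        # Nothing to release; just remember that logging has finished.
        self.closed = True


logger = InMemoryLogger()
logger.log_metrics({"train_loss": 0.9, "val_loss": 1.2, "epoch": 1})
logger.log_metrics({"train_loss": 0.7, "val_loss": 1.1, "epoch": 2}, step=2)
logger.close()
print(len(logger.records))  # 2
```

Note that `TimesFMFinetuner` only constructs a `WandBLogger` when `use_wandb` is true, so plugging in a logger like this one would need a small change to the finetuner.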


## **4. Dataset Preparation**

We define a simple `TimeSeriesDataset` that creates sliding-window samples from a single time series. This structure can be adapted for more complex multivariate datasets.


In [5]:
class TimeSeriesDataset(Dataset):
    """Dataset for time series data compatible with TimesFM."""

    def __init__(self, series: np.ndarray, context_length: int, horizon_length: int):
        """
        Args:
            series: Time series data.
            context_length: Number of past timesteps to use as input.
            horizon_length: Number of future timesteps to predict.
        """
        self.series = series
        self.context_length = context_length
        self.horizon_length = horizon_length
        self._prepare_samples()

    def _prepare_samples(self) -> None:
        self.samples = []
        total_length = self.context_length + self.horizon_length

        for start_idx in range(0, len(self.series) - total_length + 1):
            end_idx = start_idx + self.context_length
            x_context = self.series[start_idx:end_idx]
            x_future = self.series[end_idx : end_idx + self.horizon_length]
            self.samples.append((x_context, x_future))

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        x_context, x_future = self.samples[index]
        x_context = torch.tensor(x_context, dtype=torch.float32)
        x_future = torch.tensor(x_future, dtype=torch.float32)

        # TimesFM expects a certain input format with freq, etc.
        input_padding = torch.zeros_like(x_context)
        freq = torch.zeros(1, dtype=torch.long)

        return x_context, input_padding, freq, x_future

def prepare_datasets(
    series: np.ndarray,
    context_length: int,
    horizon_length: int,
    train_split: float = 0.8
) -> Tuple[Dataset, Dataset]:
    """
    Prepare training and validation datasets from time series data.

    Args:
        series: Input time series data.
        context_length: Number of past timesteps to use.
        horizon_length: Number of future timesteps to predict.
        train_split: Fraction of data to use for training.

    Returns:
        Tuple of (train_dataset, val_dataset).
    """
    train_size = int(len(series) * train_split)
    train_data = series[:train_size]
    val_data = series[train_size:]

    train_dataset = TimeSeriesDataset(train_data, context_length=context_length, horizon_length=horizon_length)
    val_dataset = TimeSeriesDataset(val_data, context_length=context_length, horizon_length=horizon_length)
    return train_dataset, val_dataset
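
To see how many samples the sliding window produces, note that a segment of length `n` yields `n - (context_length + horizon_length) + 1` windows, or none if it is shorter than one window. The sketch below applies that formula to the chronological split used by `prepare_datasets`. `window_counts` is a hypothetical helper, and the series length of 2264 is an assumption (it is not stated in the notebook), chosen because it reproduces the sample counts printed in the run output.

```python
from typing import Tuple


def window_counts(
    series_len: int,
    context_length: int,
    horizon_length: int,
    train_split: float = 0.8,
) -> Tuple[int, int]:
    """Return (train_windows, val_windows) for a chronological split."""
    window = context_length + horizon_length
    train_size = int(series_len * train_split)
    val_size = series_len - train_size
    # A segment shorter than one full window yields no samples.
    train_windows = max(0, train_size - window + 1)
    val_windows = max(0, val_size - window + 1)
    return train_windows, val_windows


# 2264 is an assumed series length, not a value taken from the notebook.
print(window_counts(2264, context_length=128, horizon_length=128))  # (1556, 198)
```

Because the split happens before windowing, no window straddles the train/validation boundary, at the cost of losing `context_length + horizon_length - 1` potential samples in each segment.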


## **5. Model Creation**

We'll define a helper function that downloads the official `google/timesfm-2.0-500m-pytorch` checkpoint (using [Hugging Face Hub](https://huggingface.co/)) and constructs a [PatchedTimeSeriesDecoder](https://github.com/google-research/google-research/tree/master/times_fm) model. 

This function allows you to toggle whether to load weights from the checkpoint. By default, it loads them.

In [6]:
from os import path

def get_model(load_weights: bool = True):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    repo_id = "google/timesfm-2.0-500m-pytorch"
    hparams = TimesFmHparams(
        backend=device,
        per_core_batch_size=32,
        horizon_len=128,
        num_layers=50,
        use_positional_embedding=False,
        context_len=192,
    )

    tfm = TimesFm(hparams=hparams, checkpoint=TimesFmCheckpoint(huggingface_repo_id=repo_id))
    model = PatchedTimeSeriesDecoder(tfm._model_config)

    if load_weights:
        checkpoint_path = path.join(snapshot_download(repo_id), "torch_model.ckpt")
        loaded_checkpoint = torch.load(checkpoint_path, weights_only=True)
        model.load_state_dict(loaded_checkpoint)

    return model, hparams, tfm._model_config
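
The context and horizon lengths above are easier to reason about in terms of patches. The sketch below assumes TimesFM's default patch sizes as we understand them (32 input steps per patch, 128 output steps per patch); these constants and the helper `num_input_patches` are assumptions for illustration and are not read from the checkpoint.

```python
INPUT_PATCH_LEN = 32    # assumed TimesFM input patch length
OUTPUT_PATCH_LEN = 128  # assumed TimesFM output patch length


def num_input_patches(context_len: int, patch_len: int = INPUT_PATCH_LEN) -> int:
    """Number of input patches the decoder sees for a given context length."""
    if context_len <= 0 or context_len % patch_len != 0:
        raise ValueError("context_len must be a positive multiple of the patch length")
    return context_len // patch_len


print(num_input_patches(192))  # 6
print(num_input_patches(128))  # 4
```

Under these assumptions, `horizon_len=128` corresponds to exactly one output patch. The decoder output is indexed per input patch, which is why the finetuner keeps only the last one with `predictions_mean[:, -1, :]`.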


## **6. Utility Functions for Plotting & Data**

Below are:
1. A utility function to plot predictions vs. ground truth.
2. A function `get_data` which downloads Apple (AAPL) stock prices from `yfinance`, then prepares training and validation sets.

In [7]:
import matplotlib.pyplot as plt

def plot_predictions(
    model: nn.Module,
    val_dataset: Dataset,
    save_path: Optional[str] = "predictions.png",
) -> None:
    """
    Plot model predictions against ground truth for the first sample of the validation dataset.
    """
    model.eval()

    # Take the first sample in the validation dataset
    x_context, x_padding, freq, x_future = val_dataset[0]
    x_context = x_context.unsqueeze(0)
    x_padding = x_padding.unsqueeze(0)
    freq = freq.unsqueeze(0)
    x_future = x_future.unsqueeze(0)

    device = next(model.parameters()).device
    x_context = x_context.to(device)
    x_padding = x_padding.to(device)
    freq = freq.to(device)
    x_future = x_future.to(device)

    with torch.no_grad():
        predictions = model(x_context, x_padding.float(), freq)
        predictions_mean = predictions[..., 0]  # shape [B, N, horizon_len]
        last_patch_pred = predictions_mean[:, -1, :]  # shape [B, horizon_len]

    context_vals = x_context[0].cpu().numpy()
    future_vals = x_future[0].cpu().numpy()
    pred_vals = last_patch_pred[0].cpu().numpy()

    context_len = len(context_vals)
    horizon_len = len(future_vals)

    plt.figure(figsize=(12, 6))
    plt.plot(range(context_len), context_vals, label="Historical Data", color="blue", linewidth=2)
    plt.plot(
        range(context_len, context_len + horizon_len),
        future_vals,
        label="Ground Truth",
        color="green",
        linestyle="--",
        linewidth=2,
    )
    plt.plot(
        range(context_len, context_len + horizon_len),
        pred_vals,
        label="Prediction",
        color="red",
        linewidth=2,
    )

    plt.xlabel("Time Step")
    plt.ylabel("Value")
    plt.title("TimesFM Predictions vs Ground Truth")
    plt.legend()
    plt.grid(True)

    if save_path:
        plt.savefig(save_path)
        print(f"Plot saved to {save_path}")
    plt.close()

def get_data(context_len: int, horizon_len: int) -> Tuple[Dataset, Dataset]:
    """
    Download AAPL stock data from yfinance, then split into train/val sets.
    """
    df = yf.download("AAPL", start="2010-01-01", end="2019-01-01")
    time_series = df["Close"].values

    train_dataset, val_dataset = prepare_datasets(
        series=time_series,
        context_length=context_len,
        horizon_length=horizon_len,
        train_split=0.8,
    )

    print("Created datasets:")
    print(f"- Training samples: {len(train_dataset)}")
    print(f"- Validation samples: {len(val_dataset)}")
    return train_dataset, val_dataset


## **7. Single-GPU Finetuning Example**

Below is a convenience function that:
1. Builds the model.
2. Creates the training and validation datasets.
3. Initializes the finetuner.
4. Trains for a few epochs.
5. Logs results (optional via W&B).
6. Plots predictions vs. ground truth.


In [8]:
def single_gpu_example():
    # 1. Create model & load checkpoint
    model, hparams, tfm_config = get_model(load_weights=True)

    # 2. Define finetuning config (feel free to adjust epochs, batch_size, etc.)
    config = FinetuningConfig(
        batch_size=256,
        num_epochs=5,
        learning_rate=1e-4,
        use_wandb=True,  # set to False if you don't want to log to W&B
        wandb_project="timesfm-finetuning",
    )

    # 3. Prepare data
    train_dataset, val_dataset = get_data(128, tfm_config.horizon_len)

    # 4. Finetuner
    finetuner = TimesFMFinetuner(model, config)

    # 5. Train
    print("\nStarting finetuning...")
    results = finetuner.finetune(train_dataset=train_dataset, val_dataset=val_dataset)

    print("\nFinetuning completed!")
    print(f"Training history: {len(results['history']['train_loss'])} epochs")

    # 6. Plot predictions
    plot_predictions(
        model=model,
        val_dataset=val_dataset,
        save_path="timesfm_predictions.png",
    )
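
As a rough sanity check on the configuration above, the number of optimizer steps follows from the dataset size, the batch size, and the epoch count. The sketch below assumes the 1556 training windows reported in this notebook's run output and the `DataLoader` default of keeping the last partial batch; `optimizer_steps` is a hypothetical helper.

```python
import math


def optimizer_steps(num_samples: int, batch_size: int, num_epochs: int) -> int:
    """Total optimizer steps when the last partial batch is kept (drop_last=False)."""
    batches_per_epoch = math.ceil(num_samples / batch_size)
    return batches_per_epoch * num_epochs


print(math.ceil(1556 / 256))          # 7 batches per epoch
print(optimizer_steps(1556, 256, 5))  # 35 steps in total
```

With so few updates, each epoch moves the weights only a handful of times, so a smaller `batch_size` or more epochs may be worth trying.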


## **8. Run the Example**

Execute the cell below to run the entire pipeline on a single GPU. If you haven't configured W&B or you don't have an account, simply set `use_wandb=False` in the `FinetuningConfig`.

> **Note**: This may download a large checkpoint (2 GB) from the Hugging Face Hub on first run.

In [9]:
# Run the single-GPU finetuning example.
single_gpu_example()

Fetching 3 files:   0%|          | 0/3 [00:00<?, ?it/s]

Fetching 3 files:   0%|          | 0/3 [00:00<?, ?it/s]

[*********************100%***********************]  1 of 1 completed


Created datasets:
- Training samples: 1556
- Validation samples: 198

Starting finetuning...


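Run history:

| Metric | Trend |
|---|---|
| epoch | ▁▃▅▆█ |
| learning_rate | ▁▁▁▁▁ |
| train_loss | ██▃▂▁ |
| val_loss | █▅▁▂▃ |

Run summary:

| Metric | Final value |
|---|---|
| epoch | 5.0 |
| learning_rate | 0.0001 |
| train_loss | 0.71242 |
| val_loss | 13.67925 |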



Finetuning completed!
Training history: 5 epochs
Plot saved to timesfm_predictions.png
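
The `val_loss` sparkline in the output suggests that validation loss bottomed out mid-run while training loss kept falling. One simple way to check this is to look up the best epoch in the `history` dict that `finetune` returns. The sketch below is a hypothetical helper, and the loss values in the example are made up; they are not the per-epoch losses of this run.

```python
from typing import Dict, List, Tuple


def best_epoch(history: Dict[str, List[float]]) -> Tuple[int, float]:
    """Return the 1-based epoch with the lowest validation loss, and that loss."""
    val_losses = history["val_loss"]
    if not val_losses:
        raise ValueError("history contains no validation losses")
    idx = min(range(len(val_losses)), key=val_losses.__getitem__)
    return idx + 1, val_losses[idx]


# Illustrative values only; they are not the losses from the run above.
example_history = {"val_loss": [15.0, 14.1, 13.2, 13.4, 13.7]}
print(best_epoch(example_history))  # (3, 13.2)
```

In the notebook, you would pass `results["history"]` from inside `single_gpu_example`, for instance to decide how many epochs to train for on a rerun.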


In [10]:
wandb.finish()

# **Conclusion**
This notebook has shown how to:
- Install and prepare a TimesFM model.
- Build a training and validation pipeline with logging.
- Finetune on custom time series data.
- Visualize the model’s predictions.

Adapt or extend this workflow for your own dataset and tasks!