In [1]:

# Set the WANDB_MODE environment variable to "disabled" so that fine-tuning does not try
# to connect to wandb to log the run; connecting to wandb caused SSL errors here.

import os
os.environ["WANDB_MODE"] = "disabled"


In [2]:

# Model setup. The timesfm sources were downloaded to a local folder, so that folder has
# to be added to the module search path (sys.path.append) before they can be imported.
# To avoid SSL errors when the model download cannot reach the web, the checkpoint was
# downloaded in advance and its local path is passed instead of downloading it.
# Per the author's note, the `path` argument (see TimesFmCheckpoint(..., path=...) in
# get_model) makes the download unnecessary.

from os import path
import torch
from huggingface_hub import snapshot_download

import sys
sys.path.append(r"C:\Users\传防科电脑\Desktop\timesfm-master\src")
from timesfm import TimesFm, TimesFmCheckpoint, TimesFmHparams
from timesfm.pytorch_patched_decoder import PatchedTimeSeriesDecoder
from finetuning.finetuning_torch import FinetuningConfig, TimesFMFinetuner

def get_model(load_weights: bool = False,
              checkpoint_path: str = r"C:\Users\传防科电脑\Desktop\timesfm\timesfm_model\torch_model.ckpt"):
  """Build a PatchedTimeSeriesDecoder for the "google/timesfm-2.0-500m-pytorch" setup.

  A TimesFm wrapper is created first (from the local checkpoint) only to obtain
  its model config; a fresh PatchedTimeSeriesDecoder is then built from that
  config and, optionally, populated with the checkpoint weights.

  Args:
    load_weights: When True, load the state dict stored at `checkpoint_path`
      into the returned decoder. When False the decoder keeps its freshly
      initialised parameters.
    checkpoint_path: Location of the local `torch_model.ckpt` file. The same
      path is given to TimesFmCheckpoint and to torch.load, so both always
      refer to the same file. Defaults to the path this notebook was written
      against.

  Returns:
    A tuple `(model, hparams, model_config)`: the decoder, the TimesFmHparams
    used to build the wrapper, and the wrapper's `_model_config`.
  """
  device = "cuda" if torch.cuda.is_available() else "cpu"
  repo_id = "google/timesfm-2.0-500m-pytorch"
  hparams = TimesFmHparams(
      backend=device,
      per_core_batch_size=32,
      horizon_len=128,
      num_layers=50,
      use_positional_embedding=False,
      context_len=192,  # Context length can be anything up to 2048 in multiples of 32
  )
  # Passing a local path alongside the repo id is what this notebook relies on
  # to avoid downloading the checkpoint from the hub.
  checkpoint = TimesFmCheckpoint(huggingface_repo_id=repo_id, path=checkpoint_path)
  tfm = TimesFm(hparams=hparams, checkpoint=checkpoint)

  model = PatchedTimeSeriesDecoder(tfm._model_config)
  if load_weights:
    # Alternative when hub access works:
    #   checkpoint_path = path.join(snapshot_download(repo_id), "torch_model.ckpt")
    # map_location="cpu" makes loading independent of the device the tensors
    # were saved from; load_state_dict then copies the values into the model's
    # existing parameters.
    loaded_checkpoint = torch.load(checkpoint_path,
                                   map_location="cpu",
                                   weights_only=True)
    model.load_state_dict(loaded_checkpoint)
  return model, hparams, tfm._model_config


 See https://github.com/google-research/timesfm/blob/master/README.md for updated APIs.
Loaded PyTorch TimesFM, likely because python version is 3.12.4 | packaged by Anaconda, Inc. | (main, Jun 18 2024, 15:03:56) [MSC v.1929 64 bit (AMD64)].


In [3]:

# Build the decoder with the checkpoint weights loaded, and keep the hyperparameters and
# the model config for the following cells.
model, hparams, tfm_config = get_model(load_weights=True)


In [4]:
# Display the decoder's module structure.
model

PatchedTimeSeriesDecoder(
  (input_ff_layer): ResidualBlock(
    (hidden_layer): Sequential(
      (0): Linear(in_features=64, out_features=1280, bias=True)
      (1): SiLU()
    )
    (output_layer): Linear(in_features=1280, out_features=1280, bias=True)
    (residual_layer): Linear(in_features=64, out_features=1280, bias=True)
  )
  (freq_emb): Embedding(3, 1280)
  (horizon_ff_layer): ResidualBlock(
    (hidden_layer): Sequential(
      (0): Linear(in_features=1280, out_features=1280, bias=True)
      (1): SiLU()
    )
    (output_layer): Linear(in_features=1280, out_features=1280, bias=True)
    (residual_layer): Linear(in_features=1280, out_features=1280, bias=True)
  )
  (stacked_transformer): StackedDecoder(
    (layers): ModuleList(
      (0-49): 50 x TimesFMDecoderLayer(
        (self_attn): TimesFMAttention(
          (qkv_proj): Linear(in_features=1280, out_features=3840, bias=True)
          (o_proj): Linear(in_features=1280, out_features=1280, bias=True)
        )
        (

In [5]:
# Display the TimesFmHparams returned by get_model.
hparams

TimesFmHparams(context_len=192, horizon_len=128, input_patch_len=32, output_patch_len=128, num_layers=50, num_heads=16, model_dims=1280, per_core_batch_size=32, backend='cpu', quantiles=(0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9), use_positional_embedding=False, point_forecast_mode='median')

In [6]:
# Display the model config returned by get_model (the wrapper's _model_config).
tfm_config

TimesFMConfig(num_layers=50, num_heads=16, num_kv_heads=16, hidden_size=1280, intermediate_size=1280, head_dim=80, rms_norm_eps=1e-06, patch_len=32, horizon_len=128, quantiles=(0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9), pad_val=1123581321.0, tolerance=1e-06, dtype='bfloat32', use_positional_embedding=False)

In [7]:

# Configure the fine-tuning config for the model.

# The path append and import below repeat the ones in the model-setup cell.
sys.path.append(r"C:\Users\传防科电脑\Desktop\timesfm-master\src")
from finetuning.finetuning_torch import FinetuningConfig, TimesFMFinetuner

# NOTE(review): use_wandb=True is requested here although WANDB_MODE is set to "disabled"
# in the first cell; presumably wandb calls then become no-ops — confirm, or pass
# use_wandb=False to make the intent explicit.
config = FinetuningConfig(batch_size=256,
                        num_epochs=5,
                        learning_rate=1e-4,
                        use_wandb=True,
                        freq_type=1,
                        log_every_n_steps=10,
                        val_check_interval=0.5,
                        use_quantile_loss=True)
# Show the resulting config, including the fields left at their defaults.
config


FinetuningConfig(batch_size=256, num_epochs=5, learning_rate=0.0001, weight_decay=0.01, freq_type=1, use_quantile_loss=True, quantiles=None, device='cpu', distributed=False, gpu_ids=[0], master_port='12358', master_addr='localhost', use_wandb=True, wandb_project='timesfm-finetuning', log_every_n_steps=10, val_check_interval=0.5)

In [8]:

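# Preprocessing of the fine-tuning data.
# If there is not enough data, adjust the train_split argument of prepare_datasets to change how many training and validation samples are produced.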

from torch.utils.data import Dataset
import numpy as np
import pandas as pd
from typing import Optional, Tuple

class TimeSeriesDataset(Dataset):
  """Sliding-window dataset over one series, yielding TimesFM-style tuples."""

  def __init__(self,
               series: np.ndarray,
               context_length: int,
               horizon_length: int,
               freq_type: int = 0):
    """
    Store the series and window sizes, then precompute every window.

    Args:
        series: Time series data
        context_length: Number of past timesteps used as model input
        horizon_length: Number of future timesteps to predict
        freq_type: Frequency type (0, 1, or 2)

    Raises:
        ValueError: If freq_type is not 0, 1, or 2.
    """
    if freq_type not in (0, 1, 2):
      raise ValueError("freq_type must be 0, 1, or 2")

    self.series = series
    self.context_length = context_length
    self.horizon_length = horizon_length
    self.freq_type = freq_type
    self._prepare_samples()

  def _prepare_samples(self) -> None:
    """Build `self.samples`: one (context, future) slice pair per window start."""
    window = self.context_length + self.horizon_length
    # A series shorter than one full window gives an empty range, hence no samples.
    n_windows = len(self.series) - window + 1
    self.samples = [
        (self.series[lo:lo + self.context_length],
         self.series[lo + self.context_length:lo + window])
        for lo in range(n_windows)
    ]

  def __len__(self) -> int:
    """Number of precomputed windows."""
    return len(self.samples)

  def __getitem__(
      self, index: int
  ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Return (context, padding, freq, future) tensors for window `index`.

    The context and future are float32 copies of the stored slices, the
    padding is an all-zero tensor shaped like the context, and freq is a
    one-element long tensor holding `freq_type`.
    """
    past, ahead = self.samples[index]

    context_tensor = torch.tensor(past, dtype=torch.float32)
    future_tensor = torch.tensor(ahead, dtype=torch.float32)
    padding = context_tensor.new_zeros(context_tensor.shape)
    freq = torch.tensor([self.freq_type], dtype=torch.long)

    return context_tensor, padding, freq, future_tensor

def prepare_datasets(series: np.ndarray,
                     context_length: int,
                     horizon_length: int,
                     freq_type: int = 0,
                     train_split: float = 0.8) -> Tuple[Dataset, Dataset]:
  """
  Split a series chronologically and wrap each part in a TimeSeriesDataset.

  Args:
      series: Input time series data
      context_length: Number of past timesteps to use
      horizon_length: Number of future timesteps to predict
      freq_type: Frequency type (0, 1, or 2)
      train_split: Fraction of the series (from the start) used for training

  Returns:
      Tuple of (train_dataset, val_dataset)
  """
  # The first `split_at` points are for training, the remainder for validation.
  split_at = int(len(series) * train_split)
  segments = (series[:split_at], series[split_at:])

  # Both datasets share the same window sizes and frequency type.
  train_dataset, val_dataset = (
      TimeSeriesDataset(segment,
                        context_length=context_length,
                        horizon_length=horizon_length,
                        freq_type=freq_type)
      for segment in segments
  )
  return train_dataset, val_dataset


In [9]:

import akshare as ak

# Fetch the period="daily" history of symbol 513770 from 2022-02-18 to 2025-07-23 with
# akshare's fund_etf_hist_em and print the result for inspection.
# The same request is issued again inside get_data in the next cell.
fund_etf_hist_em_df = ak.fund_etf_hist_em(symbol="513770", period="daily", start_date="20220218", end_date="20250723", adjust="")
print(fund_etf_hist_em_df)


  0%|          | 0/11 [00:00<?, ?it/s]

             日期     开盘     收盘     最高     最低      成交量          成交额    振幅   涨跌幅  \
0    2022-02-18  0.997  0.989  1.004  0.988  2917198  292174176.0  1.57 -2.94   
1    2022-02-21  0.965  0.948  0.973  0.944   846233   80811395.0  2.93 -4.15   
2    2022-02-22  0.930  0.922  0.940  0.913  1457760  134019602.0  2.85 -2.74   
3    2022-02-23  0.928  0.940  0.944  0.922  1404163  131535030.0  2.39  1.95   
4    2022-02-24  0.917  0.906  0.931  0.886  1353017  122475991.0  4.79 -3.62   
..          ...    ...    ...    ...    ...      ...          ...   ...   ...   
827  2025-07-17  1.110  1.116  1.121  1.108  4662467  519992175.0  1.17  0.09   
828  2025-07-18  1.132  1.134  1.139  1.128  4006023  454250379.0  0.99  1.61   
829  2025-07-21  1.151  1.144  1.151  1.134  4434565  506451738.0  1.50  0.88   
830  2025-07-22  1.141  1.130  1.145  1.127  3614534  409798410.0  1.57 -1.22   
831  2025-07-23  1.138  1.160  1.165  1.133  6481090  747140338.0  2.83  2.65   

       涨跌额    换手率  
0   -0.

In [10]:

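# Fetch and preprocess the data.
# Data for a domestic (Chinese) public fund is fetched through akshare's fund_etf_hist_em interface, replacing the foreign data source.
# If there is not enough data, change the context (training input) length passed to get_data, e.g. to 32, 64, 96 or 128; the horizon length is a model configuration parameter and must not be changed.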

import akshare as ak

def get_data(context_len: int,
             horizon_len: int,
             freq_type: int = 0,
             symbol: str = "513770",
             start_date: str = "20220218",
             end_date: str = "20250723",
             train_split: float = 0.7) -> Tuple[Dataset, Dataset]:
  """Fetch a daily price history with akshare and build train/validation datasets.

  The "收盘" (close) column of the frame returned by `ak.fund_etf_hist_em` is
  used as the series, which is then split by `prepare_datasets`.

  Args:
    context_len: Number of past timesteps in each sample.
    horizon_len: Number of future timesteps in each sample.
    freq_type: Frequency type forwarded to the datasets (0, 1, or 2).
    symbol: Symbol passed to `ak.fund_etf_hist_em`.
    start_date: First date to request, formatted like "20220218".
    end_date: Last date to request, formatted like "20250723".
    train_split: Fraction of the series (from the start) used for training.

  Returns:
    Tuple of (train_dataset, val_dataset).
  """
  fund_etf_hist_em_df = ak.fund_etf_hist_em(symbol=symbol,
                                            period="daily",
                                            start_date=start_date,
                                            end_date=end_date,
                                            adjust="")
  time_series = fund_etf_hist_em_df["收盘"].values

  train_dataset, val_dataset = prepare_datasets(
      series=time_series,
      context_length=context_len,
      horizon_length=horizon_len,
      freq_type=freq_type,
      train_split=train_split,
  )
  print(time_series)
  print("Created datasets:")
  # A count of 0 means that part of the series is shorter than
  # context_len + horizon_len, so no window fits in it.
  print(f"- Training samples: {len(train_dataset)}")
  print(f"- Validation samples: {len(val_dataset)}")
  print(f"- Using frequency type: {freq_type}")
  return train_dataset, val_dataset

# Build the datasets with a 96-step context; the horizon length is taken from the model
# config and the frequency type from the fine-tuning config.
train_dataset, val_dataset = get_data(96,
                                    tfm_config.horizon_len,
                                    freq_type=config.freq_type)

# Prints the two dataset objects themselves, not their contents.
print(train_dataset, val_dataset)


[0.989 0.948 0.922 0.94  0.906 0.907 0.896 0.918 0.897 0.874 0.84  0.811
 0.781 0.773 0.776 0.75  0.69  0.626 0.689 0.758 0.797 0.781 0.795 0.842
 0.818 0.772 0.791 0.817 0.827 0.806 0.798 0.83  0.802 0.79  0.747 0.769
 0.765 0.777 0.77  0.765 0.747 0.748 0.719 0.724 0.692 0.716 0.731 0.737
 0.806 0.784 0.75  0.742 0.734 0.759 0.748 0.784 0.786 0.825 0.814 0.793
 0.814 0.79  0.756 0.765 0.762 0.781 0.803 0.83  0.829 0.826 0.849 0.866
 0.904 0.89  0.914 0.885 0.89  0.9   0.874 0.887 0.891 0.91  0.885 0.896
 0.926 0.954 0.967 0.936 0.924 0.925 0.934 0.932 0.907 0.911 0.915 0.877
 0.869 0.876 0.872 0.85  0.876 0.87  0.887 0.879 0.88  0.865 0.876 0.863
 0.866 0.82  0.819 0.796 0.803 0.827 0.834 0.822 0.818 0.789 0.812 0.82
 0.83  0.8   0.804 0.798 0.807 0.805 0.801 0.777 0.808 0.812 0.817 0.803
 0.823 0.803 0.798 0.789 0.787 0.778 0.767 0.785 0.781 0.764 0.765 0.753
 0.734 0.743 0.73  0.715 0.708 0.72  0.726 0.713 0.69  0.672 0.664 0.653
 0.66  0.633 0.647 0.642 0.67  0.655 0.641 0.644 0.5

In [11]:

# Fine-tune the model

finetuner = TimesFMFinetuner(model, config)

print("\nStarting finetuning...")
results = finetuner.finetune(train_dataset=train_dataset,
                           val_dataset=val_dataset)

print("\nFinetuning completed!")
# Presumably results['history']['train_loss'] holds one entry per epoch, which is why its
# length is reported as the number of epochs — confirm against TimesFMFinetuner.
print(f"Training history: {len(results['history']['train_loss'])} epochs")



Starting finetuning...

Finetuning completed!
Training history: 5 epochs


In [12]:

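# Plot the training (context) data, the validation (ground-truth) data and the predicted data together; the plot shows how well the fine-tuning worked.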

def plot_predictions(
    model: torch.nn.Module,
    val_dataset: Dataset,
    save_path: Optional[str] = "predictions.png",
) -> None:
  """
  Plot the model's prediction against the ground truth for the first validation sample.

  The model is switched to eval mode and run without gradients on
  `val_dataset[0]`. The context, the true future values and the prediction are
  drawn on one figure. The figure is written to `save_path` when one is given;
  otherwise it is shown, so the plot is never silently discarded.

  Args:
    model: Trained decoder module; it is called as
      `model(context, padding, freq)`.
    val_dataset: Validation dataset whose items are
      `(context, padding, freq, future)` tensors.
    save_path: File to save the plot to. Pass None (or "") to display the
      figure instead of saving it.

  Raises:
    Whatever `val_dataset[0]` raises when the dataset has no samples.
  """
  import matplotlib.pyplot as plt

  model.eval()

  # Take the first sample, add a batch dimension and move everything to the
  # device the model's parameters live on.
  device = next(model.parameters()).device
  x_context, x_padding, freq, x_future = (
      tensor.unsqueeze(0).to(device) for tensor in val_dataset[0])

  with torch.no_grad():
    predictions = model(x_context, x_padding.float(), freq)
    predictions_mean = predictions[..., 0]  # [B, N, horizon_len]
    last_patch_pred = predictions_mean[:, -1, :]  # [B, horizon_len]

  context_vals = x_context[0].cpu().numpy()
  future_vals = x_future[0].cpu().numpy()
  pred_vals = last_patch_pred[0].cpu().numpy()

  context_len = len(context_vals)
  horizon_len = len(future_vals)
  # The future and the prediction are drawn right after the context on the x axis.
  future_steps = range(context_len, context_len + horizon_len)

  fig, ax = plt.subplots(figsize=(12, 6))

  ax.plot(range(context_len),
          context_vals,
          label="Historical Data",
          color="blue",
          linewidth=2)
  ax.plot(future_steps,
          future_vals,
          label="Ground Truth",
          color="green",
          linestyle="--",
          linewidth=2)
  ax.plot(future_steps,
          pred_vals,
          label="Prediction",
          color="red",
          linewidth=2)

  ax.set_xlabel("Time Step")
  ax.set_ylabel("Value")
  ax.set_title("TimesFM Predictions vs Ground Truth")
  ax.legend()
  ax.grid(True)

  if save_path:
    fig.savefig(save_path)
    print(f"Plot saved to {save_path}")
  else:
    plt.show()

  # Close this specific figure so repeated calls do not accumulate open figures.
  plt.close(fig)
    

In [13]:

# Plot the first validation sample and save the figure to timesfm_predictions.png.

plot_predictions(
      model=model,
      val_dataset=val_dataset,
      save_path="timesfm_predictions.png",
  )


Plot saved to timesfm_predictions.png
