In [1]:
import math
import os
import tempfile

import numpy as np
import pandas as pd
from torch.optim import AdamW
from torch.optim.lr_scheduler import OneCycleLR
from transformers import EarlyStoppingCallback, Trainer, TrainingArguments, set_seed
from transformers.integrations import INTEGRATION_TO_CALLBACK

from tsfm_public import TimeSeriesPreprocessor, TrackingCallback, count_parameters, get_datasets
from tsfm_public.toolkit.get_model import get_model
from tsfm_public.toolkit.lr_finder import optimal_lr_finder
from tsfm_public.toolkit.visualization import plot_predictions

In [2]:
# Seed the RNGs once, up front. The synthetic data generated later in this notebook is drawn
# from NumPy's global RNG (np.random.randn), so it is only reproducible if this runs first.
SEED = 42
set_seed(SEED)

# Hugging Face model id of the TTM release to load. Granite-R2 is the default; uncomment one of
# the alternatives below to try another release.
TTM_MODEL_PATH = "ibm-granite/granite-timeseries-ttm-r2"
# TTM_MODEL_PATH = "ibm-granite/granite-timeseries-ttm-r1"
# TTM_MODEL_PATH = "ibm-research/ttm-research-r2"

# Context length, i.e. the number of history points fed to the model.
# The original note listed 512/1024/1536 for Granite-TTM-R2 and Research-Use-TTM-R2, and 512/1024
# for Granite-TTM-R1. This notebook uses 52 instead; the logged runs resolve it to the R2
# revision "52-16-ft-r2.1".
# NOTE(review): confirm that 52 is supported by the R1 / research releases before switching TTM_MODEL_PATH.
CONTEXT_LENGTH = 52 # shortest context length

# Number of future steps to forecast.
# Granite-TTM-R2 supports forecast lengths up to 720 and Granite-TTM-R1 up to 96.
PREDICTION_LENGTH = 7

# Leftover settings from the ETTh1 example; unused here because the data is generated synthetically.
#TARGET_DATASET = "etth1"
#dataset_path = "https://raw.githubusercontent.com/zhouhaoyi/ETDataset/main/ETT-small/ETTh1.csv"


# Results directory. Currently only referenced by the disabled plot_predictions call in zeroshot_eval.
OUT_DIR = "ttm_finetuned_models/"

In [3]:
# Dataset schema shared by the multivariate and univariate examples.
timestamp_column = "date"  # name of the datetime column in the generated frames
id_columns = []  # ids that uniquely identify a time series; empty for these frames

# Train / valid / test boundaries handed to get_datasets.
# NOTE(review): presumably [start, end) row positions within the 40-row synthetic frames --
# confirm against the get_datasets documentation.
split_config = dict(train=[0, 24], valid=[24, 32], test=[32, 40])


def _column_labels(count: int) -> list:
    """Return ``count`` spreadsheet-style column labels: A..Z, AA, AB, ..."""
    labels = []
    for position in range(count):
        label = ""
        remaining = position
        # Bijective base-26 encoding so labels never run out or repeat.
        while True:
            remaining, letter = divmod(remaining, 26)
            label = chr(ord("A") + letter) + label
            if remaining == 0:
                break
            remaining -= 1
        labels.append(label)
    return labels


def generate_time_series_data(start_time: str = "2021-01-01 00:00:00",
                              freq: str = "h",
                              periods: int = 40,
                              n_variables: int = 10,
                              univariate: bool = False) -> pd.DataFrame:
    """
    Generate a univariate or multivariate time series DataFrame of standard-normal noise.

    Values are drawn from NumPy's global random state (``np.random.randn``), so the
    result is reproducible only if that state was seeded beforehand.

    Args:
        start_time (str): Start datetime.
        freq (str): Pandas frequency string (e.g., 'h' for hourly). The default is the
            lowercase alias because the uppercase 'H' alias is deprecated in recent pandas
            and emits a FutureWarning; both describe the same hourly frequency.
        periods (int): Number of time steps.
        n_variables (int): Number of variables (ignored if univariate=True). Columns are
            labelled A, B, ..., Z, AA, AB, ... so any count is supported.
        univariate (bool): Whether to generate univariate (single-column) data.

    Returns:
        pd.DataFrame: Time series data with a (unnamed) datetime index; the single column
        is "A" when univariate, otherwise one column per variable.
    """
    index = pd.date_range(start=start_time, periods=periods, freq=freq)

    if univariate:
        return pd.DataFrame(np.random.randn(periods), index=index, columns=["A"])

    # Labels are generated rather than sliced from a fixed 10-letter string, so
    # n_variables > 10 no longer fails with a column/shape mismatch.
    columns = _column_labels(n_variables)
    data = np.random.randn(periods, n_variables)
    return pd.DataFrame(data, index=index, columns=columns)

# Build both example frames. The generator returns a datetime-indexed frame; naming the
# index "date" and resetting it turns the timestamps into a regular "date" column.
# The multivariate frame is generated first so the random draws happen in the same order.
df_multivariate = generate_time_series_data().rename_axis("date").reset_index()
print(df_multivariate.head())

df_univariate = generate_time_series_data(univariate=True).rename_axis("date").reset_index()
print(df_univariate)

                 date         A         B         C         D         E  \
0 2021-01-01 00:00:00  0.496714 -0.138264  0.647689  1.523030 -0.234153   
1 2021-01-01 01:00:00 -0.463418 -0.465730  0.241962 -1.913280 -1.724918   
2 2021-01-01 02:00:00  1.465649 -0.225776  0.067528 -1.424748 -0.544383   
3 2021-01-01 03:00:00 -0.601707  1.852278 -0.013497 -1.057711  0.822545   
4 2021-01-01 04:00:00  0.738467  0.171368 -0.115648 -0.301104 -1.478522   

          F         G         H         I         J  
0 -0.234137  1.579213  0.767435 -0.469474  0.542560  
1 -0.562288 -1.012831  0.314247 -0.908024 -1.412304  
2  0.110923 -1.150994  0.375698 -0.600639 -0.291694  
3 -1.220844  0.208864 -1.959670 -1.328186  0.196861  
4 -0.719844 -0.460639  1.057122  0.343618 -1.763040  
                  date         A
0  2021-01-01 00:00:00 -1.594428
1  2021-01-01 01:00:00 -0.599375
2  2021-01-01 02:00:00  0.005244
3  2021-01-01 03:00:00  0.046981
4  2021-01-01 04:00:00 -0.450065
5  2021-01-01 05:00:00  0.6

  index = pd.date_range(start=start_time, periods=periods, freq=freq)
  index = pd.date_range(start=start_time, periods=periods, freq=freq)


In [4]:
def zeroshot_eval(data, batch_size, context_length=512, forecast_length=96, column_specifiers=None):
    """
    Run a zero-shot evaluation of a pre-trained TTM model on the test split of ``data``.

    Builds a TimeSeriesPreprocessor, loads the model named by the module-level
    ``TTM_MODEL_PATH``, splits ``data`` with the module-level ``split_config``, then
    evaluates and predicts on the test split with a Hugging Face ``Trainer``.
    Results are printed; nothing is returned.

    Args:
        data: DataFrame holding the timestamp and target columns named in ``column_specifiers``.
        batch_size (int): Per-device evaluation batch size.
        context_length (int): History length passed to the preprocessor and to ``get_model``.
        forecast_length (int): Prediction length passed to the preprocessor and to ``get_model``.
        column_specifiers (dict | None): Keyword arguments forwarded to
            ``TimeSeriesPreprocessor`` (e.g. timestamp_column, id_columns, target_columns).
            ``None`` is treated as an empty dict.
    """
    # Avoid a shared mutable default argument.
    if column_specifiers is None:
        column_specifiers = {}

    # Preprocessor: standard scaling, no categorical encoding.
    tsp = TimeSeriesPreprocessor(
        **column_specifiers,
        context_length=context_length,
        prediction_length=forecast_length,
        scaling=True,
        encode_categorical=False,
        scaler_type="standard",
    )

    # Load model
    zeroshot_model = get_model(
        TTM_MODEL_PATH,
        context_length=context_length,
        prediction_length=forecast_length,
        freq_prefix_tuning=False,
        freq=None,
        prefer_l1_loss=False,
        prefer_longer_context=True,
    )

    # Only the test split is needed for zero-shot evaluation.
    _, _, dset_test = get_datasets(
        tsp, data, split_config, use_frequency_token=zeroshot_model.config.resolution_prefix_tuning
    )

    # The Trainer requires an output directory. Use a self-cleaning temporary one so that
    # repeated calls do not leave stray directories behind.
    with tempfile.TemporaryDirectory() as temp_dir:
        zeroshot_trainer = Trainer(
            model=zeroshot_model,
            args=TrainingArguments(
                output_dir=temp_dir,
                per_device_eval_batch_size=batch_size,
                seed=SEED,
                report_to="none",
            ),
        )
        # evaluate = zero-shot performance
        print("+" * 20, "Test MSE zero-shot", "+" * 20)
        zeroshot_output = zeroshot_trainer.evaluate(dset_test)
        print(zeroshot_output)

        # get predictions
        predictions_dict = zeroshot_trainer.predict(dset_test)

    # First element of the prediction tuple: the forecasts. The first test window is
    # printed transposed. NOTE(review): presumably (time, channel) -> (channel, time); confirm.
    predictions_np = predictions_dict.predictions[0]

    print(np.transpose(predictions_np[0]))
    print(np.transpose(predictions_np[0]).shape)

    # Second element: backbone embeddings (if needed for further analysis).
    backbone_embedding = predictions_dict.predictions[1]

    print(backbone_embedding[0])
    print(backbone_embedding[0].shape)

    # Plotting is intentionally disabled. It was previously kept as a bare string literal.
    # To re-enable it, run it while the trainer is still in scope and supply a dataset name
    # (`dataset_name` is not defined in the visible cells) and indices valid for the test split.
    #
    # plot_predictions(
    #     model=zeroshot_trainer.model,
    #     dset=dset_test,
    #     plot_dir=os.path.join(OUT_DIR, dataset_name),
    #     plot_prefix="test_zeroshot",
    #     indices=[685, 118, 902, 1984, 894, 967, 304, 57, 265, 1015],
    #     channel=0,
    # )

In [5]:
# Multivariate Tiny Time Mixer: every generated column (A through J) is a forecast target.
column_specifiers = dict(
    timestamp_column=timestamp_column,
    id_columns=id_columns,
    target_columns=list("ABCDEFGHIJ"),
    control_columns=[],
)
zeroshot_eval(
    data=df_multivariate,
    batch_size=8,
    context_length=CONTEXT_LENGTH,
    forecast_length=PREDICTION_LENGTH,
    column_specifiers=column_specifiers,
)

INFO:p-3435:t-8022728384:get_model.py:get_model:Loading model from: ibm-granite/granite-timeseries-ttm-r2
INFO:p-3435:t-8022728384:get_model.py:get_model:Model loaded successfully from ibm-granite/granite-timeseries-ttm-r2, revision = 52-16-ft-r2.1.
INFO:p-3435:t-8022728384:get_model.py:get_model:[TTM] context_length = 52, prediction_length = 16


++++++++++++++++++++ Test MSE zero-shot ++++++++++++++++++++




{'eval_loss': 1.2032523155212402, 'eval_model_preparation_time': 0.0078, 'eval_runtime': 1.1825, 'eval_samples_per_second': 0.846, 'eval_steps_per_second': 0.846}
[[ 0.04402132  0.03331154  0.02133745  0.00305077 -0.00776591 -0.00895865
   0.00161481]
 [ 0.10501055  0.01142201 -0.01504199 -0.03931901 -0.03481245 -0.02803831
  -0.01688247]
 [-0.33105555 -0.22063375 -0.14527488 -0.14149031 -0.09604315 -0.05538443
  -0.04476943]
 [ 0.8807009   0.50128347  0.33325475  0.25177714  0.19909978  0.17655475
   0.16773659]
 [-0.26294658 -0.18029238 -0.12163518 -0.08381483 -0.07242465 -0.05933443
  -0.05498382]
 [-0.95935035 -0.75746226 -0.67643684 -0.5775811  -0.5128267  -0.50890464
  -0.50662965]
 [ 0.65836704  0.4228226   0.3083427   0.25424948  0.23301637  0.22384226
   0.22057135]
 [ 0.12895179  0.00613056 -0.016215   -0.03495287 -0.03996248 -0.03250063
  -0.01976608]
 [ 0.6290895   0.4039339   0.29543215  0.23577464  0.21004614  0.21258509
   0.19800416]
 [ 0.23581614  0.16492547  0.1374643

In [6]:
# Univariate Tiny Time Mixer: the single generated column "A" is the only forecast target.
column_specifiers = dict(
    timestamp_column=timestamp_column,
    id_columns=id_columns,
    target_columns=["A"],
    control_columns=[],
)
zeroshot_eval(
    data=df_univariate,
    batch_size=8,
    context_length=CONTEXT_LENGTH,
    forecast_length=PREDICTION_LENGTH,
    column_specifiers=column_specifiers,
)

INFO:p-3435:t-8022728384:get_model.py:get_model:Loading model from: ibm-granite/granite-timeseries-ttm-r2
INFO:p-3435:t-8022728384:get_model.py:get_model:Model loaded successfully from ibm-granite/granite-timeseries-ttm-r2, revision = 52-16-ft-r2.1.
INFO:p-3435:t-8022728384:get_model.py:get_model:[TTM] context_length = 52, prediction_length = 16


++++++++++++++++++++ Test MSE zero-shot ++++++++++++++++++++


{'eval_loss': 1.6534693241119385, 'eval_model_preparation_time': 0.0038, 'eval_runtime': 0.2715, 'eval_samples_per_second': 3.684, 'eval_steps_per_second': 3.684}
[[-0.31559467 -0.26337448 -0.2292342  -0.22169566 -0.17298758 -0.13871092
  -0.11940026]]
(1, 7)
[[[ 2.58540541e-01  1.89691231e-01 -9.90554392e-01 -1.80186749e+00
    1.37157008e-01  2.34319940e-01 -1.92468357e+00  1.77439868e-01
    3.82961571e-01  9.24913809e-02 -7.71513462e-01 -6.02124631e-01
   -9.41056848e-01 -8.00155044e-01 -8.31349969e-01 -1.80217361e+00
   -1.32850003e+00 -1.07328176e+00 -1.23185384e+00  6.71325848e-02]
  [ 1.04538046e-01 -5.37322089e-03 -2.67058402e-01 -4.80283797e-01
    4.71760929e-02  4.77122851e-02 -4.43057716e-01  8.26564506e-02
   -4.81275469e-02  3.16931754e-02  1.04031339e-02 -2.94233203e-01
   -1.69283167e-01 -5.34338057e-01  4.30340230e-01 -4.40655589e-01
   -1.79318916e-02 -3.66196424e-01 -1.23684146e-01 -2.55382862e-02]
  [ 9.88186970e-02  1.78145003e-02 -1.88574865e-01 -4.13320422e-01
 