In [None]:
%load_ext autoreload
%autoreload 2

import os
from copy import deepcopy

import torch
import numpy as np
import pandas as pd
import numpy.typing as npt
import matplotlib.pyplot as plt
from statsmodels.tsa.seasonal import STL
from tqdm import tqdm

from moment.common import PATHS
from moment.utils.config import Config
from moment.utils.utils import parse_config, dtype_map
from moment.utils.forecasting_metrics import get_forecasting_metrics, sMAPELoss
from moment.data.dataloader import get_timeseries_dataloader
from moment.data.forecasting_datasets import get_forecasting_datasets, ShortForecastingDataset
from moment.models.base import BaseModel
from moment.models.moment import MOMENT
from moment.models.nbeats import NBEATS
from moment.models.timesnet import TimesNet
from moment.models.gpt4ts import GPT4TS

### Find relevant runs

In [None]:
MODEL_NAME = "MOMENT"

# W&B run notes that identify the source-training run of each model.
# "N-BEATS" duplicates "NBEATS" because the model-side helpers in this notebook
# (config_file_map, get_model, get_checkpoint_name) spell the model "N-BEATS";
# without the alias, MODEL_NAME = "N-BEATS" raises a KeyError here.
NOTES = {
    "MOMENT": "Fine-tune MOMENT on source short-horizon forecasting datasets",
    "NBEATS": "Train N-BEATS on source short-horizon forecasting datasets",
    "N-BEATS": "Train N-BEATS on source short-horizon forecasting datasets",
    "TimesNet": "Train TimesNet on source short-horizon forecasting datasets",
    "GPT4TS": "Train GPT4TS on source short-horizon forecasting datasets"
}

runs_summary = pd.read_csv("../../assets/data/wandb_runs_summary.csv")
selected_runs_raw = runs_summary[runs_summary["notes"] == NOTES[MODEL_NAME]]

# Dataset file stem, e.g. ".../m4_quarterly_dataset.tsf" -> "m4_quarterly_dataset"
# (file names follow the f"{DATASET}_{FREQUENCY}_dataset.{ext}" pattern used below).
dataset_stem = selected_runs_raw["dataset_names"].map(lambda path: path.split("/")[-1].split(".")[0])

# Build a new frame with .assign instead of assigning columns on a filtered
# slice, which pandas flags with SettingWithCopyWarning.
runs = selected_runs_raw[["model_name", "run_name"]].assign(
    source_frequency=dataset_stem.map(lambda stem: stem.split("_")[1]),
    source_collection=dataset_stem.map(lambda stem: stem.split("_")[0]),
)
runs

### List datasets, define helpers, and load configs

In [None]:
short_forecasting_datasets = get_forecasting_datasets(collection="monash")
fred_forecasting_datasets = get_forecasting_datasets(collection="fred/preprocessed")

# Directory of the first file in each collection; the config cells below use
# these as BASE_PATH. `rpartition('/')[0]` is everything before the last "/"
# (same result as '/'.join(path.split('/')[:-1])).
m_datasets_base_path = short_forecasting_datasets[0].rpartition('/')[0]
fred_datasets_base_path = fred_forecasting_datasets[0].rpartition('/')[0]

# The monash listing contains both the M3 and the M4 splits, so label it that way.
print("M3 & M4 datasets:")
print("--- M3 splits:", [path.rpartition('/')[2] for path in short_forecasting_datasets if "m3" in path])
print("--- M4 splits:", [path.rpartition('/')[2] for path in short_forecasting_datasets if "m4" in path])

print("Fred datasets:")
print('--- Splits:', [path.rpartition('/')[2] for path in fred_forecasting_datasets if "fred" in path])

In [None]:
def get_dataloaders(args):
    """Build the train, test and validation dataloaders for one dataset file.

    ``args`` is mutated in place: ``dataset_names`` is set to
    ``args.full_file_path_and_name``, and ``data_split`` / ``batch_size`` are
    overwritten before each loader is built. After the call they therefore
    hold the values used for the last ("val") split.

    Returns:
        Tuple ``(train_dataloader, test_dataloader, val_dataloader)``.
    """
    args.dataset_names = args.full_file_path_and_name

    loaders = {}
    # args is shared by every loader, so each one is created immediately after
    # its split fields are set. Train uses its own batch size; test and val
    # share the evaluation batch size.
    for split_name, batch_size_attr in (('train', 'train_batch_size'),
                                        ('test', 'val_batch_size'),
                                        ('val', 'val_batch_size')):
        args.data_split = split_name
        args.batch_size = getattr(args, batch_size_attr)
        loaders[split_name] = get_timeseries_dataloader(args=args)

    return loaders['train'], loaders['test'], loaders['val']


# Forecast horizon used for each sampling frequency (looked up by the config cells).
HORIZON_MAPPING = dict(
    hourly=48,
    daily=14,
    weekly=13,
    monthly=18,
    quarterly=8,
    other=8,
    yearly=6,
)

def validation(args, model, data_loader, return_preds):
    """Evaluate a forecasting model on ``data_loader`` with the sMAPE loss.

    Each batch is run through ``model(x_enc=..., input_mask=..., mask=None)``
    under ``torch.no_grad()``, inside a CUDA autocast context that is enabled
    when ``args.use_amp`` is set. If the prediction's shape differs from the
    target's, the prediction is truncated along its last axis to the target
    horizon before it is scored.

    Side effect: the model is put in eval mode for the loop and is switched to
    train mode before returning, regardless of the mode it was in before.

    NOTE(review): a second ``validation`` with a different signature is defined
    later in this notebook and shadows this one once that cell has run.

    Args:
        args: config namespace; reads ``device``, ``torch_dtype`` and ``use_amp``.
        model: forecasting model whose output exposes a ``.forecast`` tensor.
        data_loader: iterable of batches exposing ``timeseries``,
            ``input_mask`` and ``forecast`` tensors.
        return_preds: if True, also collect targets, predictions and inputs.

    Returns:
        ``average_loss`` when ``return_preds`` is False; otherwise
        ``(average_loss, per_batch_losses, (trues, preds, histories))``, with
        the three arrays concatenated along axis 0.
    """
    trues, preds, histories, losses = [], [], [], []
    criterion = sMAPELoss(reduction='mean')

    model.eval()
    with torch.no_grad():
        for batch_x in tqdm(data_loader, total=len(data_loader)):
            timeseries = batch_x.timeseries.float().to(args.device)
            input_mask = batch_x.input_mask.long().to(args.device)
            forecast = batch_x.forecast.float().to(args.device)
            forecast_horizon = forecast.shape[-1]

            with torch.autocast(device_type='cuda',
                                dtype=dtype_map(args.torch_dtype),
                                enabled=args.use_amp):
                outputs = model(x_enc=timeseries, input_mask=input_mask, mask=None)

            # Compare shape with shape. The previous `shape != forecast` compared
            # a torch.Size with a Tensor, which is effectively always true.
            if outputs.forecast.shape != forecast.shape:
                outputs.forecast = outputs.forecast[:, :, :forecast_horizon]

            loss = criterion(outputs.forecast, forecast)
            losses.append(loss.item())

            if return_preds:
                trues.append(forecast.detach().cpu().numpy())
                preds.append(outputs.forecast.detach().cpu().numpy())
                histories.append(timeseries.detach().cpu().numpy())

    losses = np.array(losses)
    average_loss = np.average(losses)
    model.train()

    if return_preds:
        trues = np.concatenate(trues, axis=0)
        preds = np.concatenate(preds, axis=0)
        histories = np.concatenate(histories, axis=0)
        return average_loss, losses, (trues, preds, histories)
    return average_loss

# Model-specific config YAML for each MODEL_NAME. "NBEATS" is an alias of
# "N-BEATS" so the key spelling used in NOTES resolves here as well.
config_file_map = {
    "MOMENT": "../../configs/forecasting/linear_probing_short_horizon.yaml",
    "N-BEATS": "../../configs/forecasting/nbeats.yaml",
    "NBEATS": "../../configs/forecasting/nbeats.yaml",
    "TimesNet": "../../configs/forecasting/timesnet.yaml",
    "GPT4TS": "../../configs/forecasting/gpt4ts.yaml",
}

def get_model(model_name, args):
    """Instantiate the forecasting model named ``model_name`` with ``args``.

    Supported names: "MOMENT", "TimesNet", "GPT4TS" and "N-BEATS". "NBEATS"
    (the spelling used as a NOTES key) is accepted as an alias of "N-BEATS".

    Raises:
        ValueError: if ``model_name`` is not a supported name.
    """
    if model_name == "MOMENT":
        return MOMENT(args)
    if model_name == "TimesNet":
        return TimesNet(args)
    if model_name == "GPT4TS":
        return GPT4TS(args)
    if model_name in ("N-BEATS", "NBEATS"):
        return NBEATS(args)
    raise ValueError(f"Model {model_name} not found.")

def get_checkpoint_name(model_name):
    """Return the checkpoint file name stored for ``model_name`` in a run directory.

    "NBEATS" (the spelling used as a NOTES key) is accepted as an alias of
    "N-BEATS"; both map to 'NBEATS.pth'.

    Raises:
        ValueError: if ``model_name`` is not a supported name.
    """
    checkpoint_files = {
        "MOMENT": 'MOMENT.pth',
        "TimesNet": 'TimesNet.pth',
        "GPT4TS": 'GPT4TS.pth',
        "N-BEATS": 'NBEATS.pth',
        "NBEATS": 'NBEATS.pth',
    }
    if model_name not in checkpoint_files:
        raise ValueError(f"Model {model_name} not found.")
    return checkpoint_files[model_name]

In [None]:
DEFAULT_CONFIG_PATH = "../../configs/default.yaml"
GPU_ID = 0

# Target dataset the fine-tuned model is evaluated on.
DATASET = "m3" # "m3" | "m4" | "fred"
FREQUENCY = "quarterly" # "monthly" | "quarterly" | "yearly" | "daily" | "hourly" | "weekly" | "other"

# Source dataset the checkpoint was trained on.
SOURCE_COLLECTION = "m4" # m4 fred
SOURCE_FREQUENCY = "quarterly" # monthly quarterly yearly daily hourly weekly

# M3/M4 files live under the monash directory as .tsf; FRED files as .npy.
if DATASET in ('m3', 'm4'):
    BASE_PATH, file_format = m_datasets_base_path, 'tsf'
else:
    BASE_PATH, file_format = fred_datasets_base_path, 'npy'

# Model-specific config (looked up in config_file_map) together with the default config.
config = Config(
    config_file_path=config_file_map[MODEL_NAME],
    default_config_file_path=DEFAULT_CONFIG_PATH,
).parse()
config['device'] = GPU_ID if torch.cuda.is_available() else 'cpu'
args = parse_config(config)

args.full_file_path_and_name = os.path.join(BASE_PATH, f"{DATASET}_{FREQUENCY}_dataset.{file_format}")
args.dataset_names = args.full_file_path_and_name
# Horizon of the *source* frequency (the alternative would be HORIZON_MAPPING[FREQUENCY]);
# presumably this is the horizon the checkpoint was trained to predict.
args.forecast_horizon = HORIZON_MAPPING[SOURCE_FREQUENCY]
args.val_batch_size = 128

train_dataloader, test_dataloader, val_dataloader = get_dataloaders(args)
print(f"Source forecast horizon: {train_dataloader.dataset.forecast_horizon}")
print(f"Lengths: Train: {train_dataloader.dataset.length_dataset} | "
      f"Test: {test_dataloader.dataset.length_dataset} | "
      f"Val: {val_dataloader.dataset.length_dataset}")

### Load the source model

In [None]:
# Fine-tuned run to evaluate, pinned by hand. It could instead be chosen from
# `runs` by matching SOURCE_COLLECTION / SOURCE_FREQUENCY and taking the run
# whose name ends in the highest number.
checkpoint_name = 'vibrant-glitter-1232'
print(f"Checkpoint name: {checkpoint_name}")

model = get_model(MODEL_NAME, args)

# NOTE(review): torch.load unpickles the file, so only load checkpoints you trust.
with open(os.path.join(PATHS.CHECKPOINTS_DIR, checkpoint_name, get_checkpoint_name(MODEL_NAME)),
          'rb') as checkpoint_file:
    checkpoint = torch.load(checkpoint_file, map_location='cpu')

model.load_state_dict(checkpoint['model_state_dict'])
# eval() and to() each return the module itself; the chained result is the cell output.
model.eval().to(args.device)

In [None]:
# Score the fine-tuned model on the validation and test splits, then pool both
# splits for the reported metrics.
# NOTE(review): needs the 4-argument `validation` defined above; it fails if the
# STL `validation` cell in the zero-shot section has been run first.
_, _, (trues_val, preds_val, _) = validation(args, model, data_loader=val_dataloader, return_preds=True)
_, _, (trues_test, preds_test, _) = validation(args, model, data_loader=test_dataloader, return_preds=True)
trues, preds = (np.concatenate(parts, axis=0)
                for parts in ((trues_val, trues_test), (preds_val, preds_test)))

get_forecasting_metrics(y=trues, y_hat=preds)

## Zero-shot forecasting

In [None]:
def load_pretrained_moment(args,
                           pretraining_task_name: str = "pre-training"):
    """Build a MOMENT model and load the pre-trained weights referenced by ``args``.

    Side effect: ``args.task_name`` is overwritten with ``pretraining_task_name``
    before the model is constructed.

    The weights are fetched with ``BaseModel.load_pretrained_weights`` using
    ``args.pretraining_run_name`` and ``args.pretraining_opt_steps``.

    Returns:
        The MOMENT model with the checkpoint's ``"model_state_dict"`` loaded.
    """
    args.task_name = pretraining_task_name

    weights = BaseModel.load_pretrained_weights(
        run_name=args.pretraining_run_name,
        opt_steps=args.pretraining_opt_steps,
    )

    moment_model = MOMENT(configs=args)
    moment_model.load_state_dict(weights["model_state_dict"])
    return moment_model

# def validation(args, model, data_loader, return_preds):
#     trues, preds, histories, losses = [], [], [], []

#     # criterion = nn.MSELoss(reduction='mean')
#     criterion = sMAPELoss(reduction='mean')
    
#     model.eval()
#     with torch.no_grad():
#         for batch_x in tqdm(data_loader, total=len(data_loader)):
#             timeseries = batch_x.timeseries.float().to(args.device)
#             input_mask = batch_x.input_mask.long().to(args.device)
#             forecast = batch_x.forecast.float().to(args.device)
#             forecast_horizon = forecast.shape[-1]

#             # scaler = torch.max(timeseries, dim=-1, keepdim=True)[0]
#             # timeseries = timeseries / scaler

#             with torch.autocast(device_type='cuda', 
#                                 dtype=dtype_map(args.torch_dtype), 
#                                 enabled=args.use_amp): 
#                 outputs = model.short_forecast(
#                     x_enc=timeseries, input_mask=input_mask, forecast_horizon=forecast_horizon)
                
#                 # outputs.forecast = outputs.forecast * scaler
    
#             loss = criterion(outputs.forecast, forecast)                
#             losses.append(loss.item())

#             if return_preds:
#                 trues.append(forecast.detach().cpu().numpy())
#                 preds.append(outputs.forecast.detach().cpu().numpy())
#                 histories.append(timeseries.detach().cpu().numpy())
    
#     losses = np.array(losses)
#     average_loss = np.average(losses)
#     model.train()

#     if return_preds:
#         trues = np.concatenate(trues, axis=0)
#         preds = np.concatenate(preds, axis=0)
#         histories = np.concatenate(histories, axis=0)
#         return average_loss, losses, (trues, preds, histories)
#     else:
#         return average_loss

def validation(args, model, data_loader, return_preds, season, period):
    """Zero-shot sMAPE evaluation on STL-decomposed input series.

    For each batch, the input series is decomposed with statsmodels ``STL``
    (``seasonal=season``, ``period=period``). The trend and seasonal components
    are forecast as a batch of two with ``model.short_forecast``, and the two
    forecasts are summed to form the prediction scored against the target.
    The STL residual is not used.

    STL works on a single 1-D series, so ``data_loader`` must yield one series
    per batch (the zero-shot config cell sets ``args.val_batch_size = 1``).

    Side effect: the model is put in eval mode for the loop and is switched to
    train mode before returning.

    NOTE(review): this redefinition shadows the earlier
    ``validation(args, model, data_loader, return_preds)``; calls using that
    signature fail once this cell has run.

    Args:
        args: config namespace; reads ``device``, ``torch_dtype`` and ``use_amp``.
        model: model exposing ``short_forecast(x_enc=..., input_mask=...,
            forecast_horizon=...)`` whose output has a ``.forecast`` tensor.
        data_loader: iterable of batches exposing ``timeseries``,
            ``input_mask`` and ``forecast`` tensors.
        return_preds: if True, also collect targets, predictions and inputs.
        season: STL ``seasonal`` smoother length.
        period: STL ``period``.

    Returns:
        ``average_loss`` when ``return_preds`` is False; otherwise
        ``(average_loss, per_batch_losses, (trues, preds, histories))``.
        ``histories`` holds the (trend, seasonal) component tensors fed to the
        model, not the raw input series.

    Raises:
        ValueError: if a batch contains more than one series.
    """
    trues, preds, histories, losses = [], [], [], []
    criterion = sMAPELoss(reduction='mean')

    model.eval()
    with torch.no_grad():
        for batch_x in tqdm(data_loader, total=len(data_loader)):
            batch_size = batch_x.timeseries.shape[0]
            if batch_size != 1:
                # STL would otherwise fail with an opaque ndim error.
                raise ValueError(
                    f"STL validation expects one series per batch, got {batch_size}; "
                    "set args.val_batch_size = 1.")

            input_mask = batch_x.input_mask.long().to(args.device)
            forecast = batch_x.forecast.float().to(args.device)
            forecast_horizon = forecast.shape[-1]

            # STL needs a 1-D numpy series, so decompose on the CPU copy.
            series = batch_x.timeseries.squeeze().numpy()
            decomposition = STL(series, seasonal=season, period=period).fit()
            # Trend and seasonal components as a batch of two single-channel
            # series: (2, L) -> (2, 1, L).
            components = np.stack([decomposition.trend, decomposition.seasonal], axis=0)
            timeseries = torch.from_numpy(components).unsqueeze(1).float().to(args.device)

            with torch.autocast(device_type='cuda',
                                dtype=dtype_map(args.torch_dtype),
                                enabled=args.use_amp):
                outputs = model.short_forecast(
                    x_enc=timeseries, input_mask=input_mask, forecast_horizon=forecast_horizon)

            # Recombine: prediction = trend forecast + seasonal forecast.
            outputs.forecast = outputs.forecast.sum(dim=0, keepdim=True)
            loss = criterion(outputs.forecast, forecast)
            losses.append(loss.item())

            if return_preds:
                trues.append(forecast.detach().cpu().numpy())
                preds.append(outputs.forecast.detach().cpu().numpy())
                histories.append(timeseries.detach().cpu().numpy())

    losses = np.array(losses)
    average_loss = np.average(losses)
    model.train()

    if return_preds:
        trues = np.concatenate(trues, axis=0)
        preds = np.concatenate(preds, axis=0)
        histories = np.concatenate(histories, axis=0)
        return average_loss, losses, (trues, preds, histories)
    return average_loss

In [None]:
DEFAULT_CONFIG_PATH = "../../configs/default.yaml"
GPU_ID = 0

# Target dataset for zero-shot evaluation.
DATASET = "m4" # "m3" | "m4" | "fred"
FREQUENCY = "yearly" # "monthly" | "quarterly" | "yearly" | "daily" | "hourly" | "weekly" | "other"
if DATASET in ('m3', 'm4'):
    BASE_PATH, file_format = m_datasets_base_path, 'tsf'
else:
    BASE_PATH, file_format = fred_datasets_base_path, 'npy'

# STL settings for validation(..., season=SEASON, period=PERIOD).
# (season, period) values noted for M3:
#   yearly (25, 2) | quarterly (3, 4) | monthly (11, 12) | other (7, 7)
SEASON = 25
PERIOD = 2

config = Config(
    config_file_path="../../configs/forecasting/zero_shot.yaml",
    default_config_file_path=DEFAULT_CONFIG_PATH,
).parse()
config['device'] = GPU_ID if torch.cuda.is_available() else 'cpu'
args = parse_config(config)

args.full_file_path_and_name = os.path.join(BASE_PATH, f"{DATASET}_{FREQUENCY}_dataset.{file_format}")
args.dataset_names = args.full_file_path_and_name
# Zero-shot: forecast the target frequency's own horizon.
args.forecast_horizon = HORIZON_MAPPING[FREQUENCY]
# The STL validation decomposes one series at a time.
args.val_batch_size = 1

train_dataloader, test_dataloader, val_dataloader = get_dataloaders(args)
print(f"Source forecast horizon: {train_dataloader.dataset.forecast_horizon}")
print(f"Lengths: Train: {train_dataloader.dataset.length_dataset} | "
      f"Test: {test_dataloader.dataset.length_dataset} | "
      f"Val: {val_dataloader.dataset.length_dataset}")

In [None]:
# Pre-trained MOMENT with no fine-tuning, used for zero-shot forecasting.
# nn.Module.to() returns the module itself, so the assignment keeps the moved model.
model = load_pretrained_moment(args).to(args.device)
model.eval()

In [None]:
# _, _, (trues_val, preds_val, _) = validation(args, model, val_dataloader, return_preds=True, season=SEASON, period=PERIOD)
# _, _, (trues_test, preds_test, _) = validation(args, model, test_dataloader, return_preds=True, season=SEASON, period=PERIOD)
# trues = np.concatenate([trues_val, trues_test], axis=0)
# preds = np.concatenate([preds_val, preds_test], axis=0)

# get_forecasting_metrics(y=trues, y_hat=preds)