In [None]:
%load_ext autoreload
%autoreload 2

from typing import Optional

import torch
import numpy as np
import numpy.typing as npt
import matplotlib.pyplot as plt
from tqdm import tqdm

from moment.utils.config import Config
from moment.utils.utils import parse_config
from moment.data.forecasting_datasets import get_forecasting_datasets, LongForecastingDataset
from moment.data.dataloader import get_timeseries_dataloader
from moment.models.base import BaseModel
from moment.models.moment import MOMENT

In [None]:
# Fetch the "autoformer" collection of forecasting datasets and display the result.
forecasting_datasets = get_forecasting_datasets(collection="autoformer")
forecasting_datasets

In [None]:
# Load the model: config file paths, target GPU, and the pretraining run whose weights are reused below.
DEFAULT_CONFIG_PATH = "../../configs/default.yaml"
CONFIG_PATH = "../../configs/forecasting/linear_probing.yaml"
# Backward-compatible alias for the original misspelled name, which later cells still reference.
CONGIF_PATH = CONFIG_PATH
GPU_ID = 3  # only used when CUDA is available; otherwise the device falls back to 'cpu'
# run_name = "fast-pyramid-63" # "avid-moon-55" "proud-dust-41" "curious-blaze-53" "laced-firebrand-51" "prime-music-50" "fast-pyramid-63" "fearless-planet-52"
run_name = "fancy-music-127" # "peach-sponge-95"

In [None]:
# Parse the task config together with the default config (merge rules live in Config),
# then choose the device and convert the config into the args object used by the task.
config = Config(
    config_file_path=CONGIF_PATH,
    default_config_file_path=DEFAULT_CONFIG_PATH,
    verbose=False,
).parse()
# No visible CUDA device -> CPU; otherwise the configured GPU index.
config['device'] = 'cpu' if not torch.cuda.is_available() else GPU_ID

args = parse_config(config)

In [None]:
# Instantiate the forecasting fine-tuning task from the parsed args.
# NOTE(review): this import sits here instead of in the top import cell; consider moving it up.
from moment.tasks.forecast_finetune import ForecastFinetuning
task_obj = ForecastFinetuning(args=args)

In [None]:
# Display the encoder submodule of the task's model.
task_obj.model.encoder

In [None]:
# Load pretrained MOMENT weights using the method's default arguments.
# NOTE(review): args.pretraining_run_name is only assigned in a later cell, right before a
# second load_pretrained_moment call; confirm which run this first call actually loads.
task_obj.load_pretrained_moment()

In [None]:
# Display the training split's dataset object.
task_obj.train_dataloader.dataset

In [None]:
# Display the validation split's dataset object.
task_obj.val_dataloader.dataset

In [None]:
# Display the test split's dataset object.
task_obj.test_dataloader.dataset

In [None]:
# Compare the first sample's field shapes across splits: `timeseries` for test then train,
# followed by `forecast` for test then train.
for field in ("timeseries", "forecast"):
    for split_loader in (task_obj.test_dataloader, task_obj.train_dataloader):
        print(getattr(split_loader.dataset[0], field).shape)

In [None]:
# Plot the first 512 rows of the last column (channel) of the test split's raw data.
fig, ax = plt.subplots()
ax.plot(task_obj.test_dataloader.dataset.data[:512, -1])
# Label the figure so it stands alone; trailing `;` suppresses the Text/Line2D repr.
ax.set(title="Test split: last channel, first 512 rows", xlabel="Row index", ylabel="Value");

In [None]:
# Take the first 512 rows of the last two columns and lay them out channel-first.
# Indexing yields shape (512, 2); `.T` gives (2, 512) with one channel per row.
# (`.reshape((2, 512))` would instead interleave the two channels in row-major order.)
a = task_obj.test_dataloader.dataset.data[:512, [-2, -1]].T
a.shape

In [None]:
# Show both split shapes together: a bare expression on a non-final line of a cell is
# evaluated but never displayed, so the test shape was previously lost.
(
    task_obj.test_dataloader.dataset.data.shape,
    task_obj.train_dataloader.dataset.data.shape,
)

In [None]:
# Per-column mean of the validation split's raw data (reduction over axis 0).
task_obj.val_dataloader.dataset.data.mean(axis=0)

In [None]:
# Peek at the first training batch only: print each field's shape, then stop iterating.
for batch_x in task_obj.train_dataloader:
    for field in ("timeseries", "forecast", "input_mask"):
        print(getattr(batch_x, field).shape)
    break

In [None]:
# Move the model to the task's device and switch it to eval mode; the cell displays the
# value returned by eval().
task_obj.model.to(task_obj.device)
task_obj.model.eval()

In [None]:
# Point args at the chosen pretraining run and reload weights for the long-horizon
# forecasting task with do_not_copy_head=False (presumably also copying the head), then move
# the model to the task's device.
# NOTE(review): assumes task_obj keeps a reference to this same `args` object — confirm.
# NOTE(review): if load_pretrained_moment rebuilds task_obj.model, the earlier .eval() may no
# longer apply; confirm that validation() sets eval mode itself.
args.pretraining_run_name = run_name
task_obj.load_pretrained_moment(pretraining_task_name="long-horizon-forecasting", do_not_copy_head=False)
task_obj.model.to(task_obj.device)

In [None]:
# Run the task's validation routine and also collect ground truths, predictions and histories.
# NOTE(review): this evaluates on train_dataloader, not val/test — confirm that is intentional.
average_loss, losses, (trues, preds, histories) = task_obj.validation(data_loader=task_obj.train_dataloader, return_preds=True)

In [None]:
# Display the average loss returned by the validation call.
average_loss

In [None]:
# Report the shapes of the arrays returned with return_preds=True.
# `preds` holds model predictions, so it is labelled "Predicted" (previously mislabelled "False").
print(f"History values: {histories.shape} | True values: {trues.shape} | Predicted values: {preds.shape}")

In [None]:
# Plot the first 608 rows of the last column (channel) of the training split's raw data.
# A single 2-D slice replaces the chained `[:, -1][:608]`; the values are identical.
# NOTE(review): 608 is presumably history + horizon (e.g. 512 + 96) — confirm against the config.
fig, ax = plt.subplots()
ax.plot(task_obj.train_dataloader.dataset.data[:608, -1])
ax.set(title="Train split: last channel, first 608 rows", xlabel="Row index", ylabel="Value");

In [None]:
# Overlay ground truth and prediction for one sample/channel. Both curves share the same
# history prefix and differ only in the forecast horizon appended after it.
idx = 0 # np.random.randint(0, len(histories))
channel = -1
history = histories[idx, channel].squeeze()
true = np.concatenate([history, trues[idx, channel, :].squeeze()])
pred = np.concatenate([history, preds[idx, channel, :].squeeze()])

fig, ax = plt.subplots()
ax.set_title(f"Forecasting: idx={idx}, horizon={trues[idx].shape[-1]}")
ax.plot(true, label="True", c='red')
ax.plot(pred, label="Pred", c='darkblue')
# Mark where the appended forecast begins (first index after the history prefix).
ax.axvline(len(history), color='gray', linestyle='--', label="Forecast start")
ax.set_xlabel("Time step")
ax.set_ylabel("Value")
ax.legend()
plt.show()

In [None]:
# Load the stored weights for `run_name` directly via BaseModel (opt_steps=None).
# NOTE(review): `checkpoint` is not used by any later cell visible here.
checkpoint = BaseModel.load_pretrained_weights(run_name=run_name, opt_steps=None)

In [None]:
# NOTE(review): disabled parameter-comparison check. `pretrained_model` and `finetuned_model`
# are not defined in any cell visible here; bind them before re-enabling, or delete this cell.
# The first loop reports name/shape mismatches; the second reports "head" parameters whose
# values differ between the two models.
# for ((name_p, param_p), (name_f, param_f)) in zip(pretrained_model.named_parameters(), finetuned_model.named_parameters()):
#     if name_p == name_f and param_p.shape == param_f.shape:
#         # print(name_p, param_p.shape)
#         pass
#     else:
#         print("MISMATCH", name_p, name_f, param_p.shape, param_f.shape)

# # Make sure that all necessary parameters have been copied
# for ((name_p, param_p), (name_f, param_f)) in zip(pretrained_model.named_parameters(), finetuned_model.named_parameters()):
#     if (name_p == name_f) and (param_p.shape == param_f.shape) and name_p.startswith("head"):
#         if not torch.allclose(param_f.data, param_p.data):
#             print("MISMATCH", name_p, name_f, param_p.shape, param_f.shape)

In [None]:
# Build a standalone dataloader from the same parsed args.
dataloader = get_timeseries_dataloader(args)

In [None]:
# Display the dataset wrapped by the standalone dataloader.
dataloader.dataset

In [None]:
# Use the dataset's own plot helper; the argument is presumably a sample index.
# NOTE(review): 8032 is a magic number — confirm it is < len(dataloader.dataset).
dataloader.dataset.plot(8032)