In [None]:
%load_ext autoreload
%autoreload 2

import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm

from moment.utils.config import Config
from moment.utils.utils import parse_config
from moment.utils.masking import Masking
from moment.utils.forecasting_metrics import get_forecasting_metrics
from moment.data.dataloader import get_timeseries_dataloader
from moment.data.forecasting_datasets import get_forecasting_datasets
from moment.models.base import BaseModel
from moment.models.moment import MOMENT

In [None]:
def get_dataloaders(args):
    """Build the train, test and validation dataloaders described by ``args``.

    ``args`` is modified in place:

    - ``args.dataset_names`` is set to ``args.full_file_path_and_name``.
    - ``args.data_split`` is set to each split before its loader is built, so
      it is left at ``'val'`` afterwards.

    Returns:
        Tuple ``(train_dataloader, test_dataloader, val_dataloader)``.
    """
    args.dataset_names = args.full_file_path_and_name

    loaders_by_split = {}
    # Order matters: the loaders are built in train -> test -> val order,
    # and args.data_split ends on the last one.
    for split_name in ("train", "test", "val"):
        args.data_split = split_name
        loaders_by_split[split_name] = get_timeseries_dataloader(args=args)

    return (
        loaders_by_split["train"],
        loaders_by_split["test"],
        loaders_by_split["val"],
    )

def load_pretrained_moment(args,
                           pretraining_task_name: str = "pre-training"):
    """Instantiate ``MOMENT`` from ``args`` and load pretrained weights into it.

    Side effect: ``args.task_name`` is overwritten with
    ``pretraining_task_name`` before the model is constructed.

    Args:
        args: Config object. Reads ``pretraining_run_name`` and
            ``pretraining_opt_steps`` to locate the checkpoint; it is also
            passed to ``MOMENT`` as ``configs``.
        pretraining_task_name: Value assigned to ``args.task_name``.

    Returns:
        The ``MOMENT`` model with the checkpoint's ``"model_state_dict"`` loaded.
    """
    args.task_name = pretraining_task_name

    saved_state = BaseModel.load_pretrained_weights(
        run_name=args.pretraining_run_name,
        opt_steps=args.pretraining_opt_steps,
    )

    model = MOMENT(configs=args)
    model.load_state_dict(saved_state["model_state_dict"])
    return model

def statistical_interpolation(y):
    """Fill NaN gaps in ``y`` using linear, nearest and cubic interpolation.

    Each row of ``y`` is treated as an independent series. Missing values are
    interpolated along the columns (``axis=1``) with pandas'
    ``DataFrame.interpolate``. The ``nearest`` and ``cubic`` methods are
    delegated to SciPy by pandas.

    Args:
        y: 2-D array-like with one series per row, or a 1-D array-like
            holding a single series.

    Returns:
        Tuple ``(linear_y, nearest_y, cubic_y)`` of NumPy arrays. 2-D input
        gives arrays of the same 2-D shape; 1-D input gives 1-D arrays.
    """
    # A 1-D input passed straight to pd.DataFrame becomes a single *column*.
    # Interpolating along axis=1 would then see one value per row and
    # silently leave every NaN in place, so treat it as a single row instead.
    is_1d = np.ndim(y) == 1
    frame = pd.DataFrame([y] if is_1d else y)

    linear_y, nearest_y, cubic_y = (
        frame.interpolate(method=method, axis=1).values
        for method in ("linear", "nearest", "cubic")
    )

    if is_1d:
        linear_y = linear_y.reshape(-1)
        nearest_y = nearest_y.reshape(-1)
        cubic_y = cubic_y.reshape(-1)

    return linear_y, nearest_y, cubic_y