In [1]:
%load_ext autoreload
%autoreload 2

import os

import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm

from moment.common import PATHS
from moment.utils.config import Config
from moment.utils.utils import parse_config
from moment.utils.masking import Masking
from moment.utils.forecasting_metrics import get_forecasting_metrics
from moment.data.dataloader import get_timeseries_dataloader
from moment.data.forecasting_datasets import get_forecasting_datasets
from moment.models.base import BaseModel
from moment.models.moment import MOMENT
from moment.models.timesnet import TimesNet
from moment.models.gpt4ts import GPT4TS

In [None]:
# wandb "notes" strings that identify the imputation runs of each model family.
NOTES = [
    "Supervised finetuning on imputation datasets",  # MOMENT
    "Pre-training TimesNet for imputation",  # TimesNet
    "Pre-training GPT4TS for imputation",  # GPT4TS
]

runs_summary = pd.read_csv("../../assets/data/wandb_runs_summary.csv")

# Select rows and columns in a single .loc call and take an explicit copy:
# assigning columns on a filtered slice triggers SettingWithCopyWarning and
# is not guaranteed to update `runs`.
runs = runs_summary.loc[
    runs_summary["notes"].isin(NOTES),
    ['run_name', 'model_name', 'dataset_names', 'hostname'],
].copy()

# Keep only the file name without its 4-character extension (e.g. ".csv"),
# and only the first label of the hostname.
runs['dataset_names'] = runs['dataset_names'].apply(lambda x: x.split("/")[-1][:-4])
runs['hostname'] = runs['hostname'].apply(lambda x: x.split(".")[0])

In [None]:
# Show the filtered runs table (run name, model, dataset, host) as plain text.
print(runs)

In [None]:
# Maps every supported model name to the imputation config it is evaluated with.
# MOMENT has two entries: "MOMENT_LP" (linear probing) and "MOMENT_0" (zero-shot).
# The keys must match the names accepted by get_model / get_checkpoint_name
# and the value assigned to MODEL_NAME.
config_file_map = {
    "MOMENT_LP": "../../configs/imputation/linear_probing.yaml",
    "MOMENT_0": "../../configs/imputation/zero_shot.yaml",
    "TimesNet": "../../configs/imputation/timesnet_train.yaml",
    "GPT4TS": "../../configs/imputation/gpt4ts_train.yaml",
}

def get_model(model_name, args):
    """Instantiate the model class that corresponds to `model_name`.

    Args:
        model_name: One of "MOMENT_LP", "MOMENT_0", "TimesNet" or "GPT4TS".
            Both MOMENT variants map to the same `MOMENT` class.
        args: Parsed configuration passed to the model constructor.

    Returns:
        A freshly constructed (not yet loaded) model instance.

    Raises:
        ValueError: If `model_name` is not one of the supported names.
    """
    if model_name in ("MOMENT_LP", "MOMENT_0"):
        return MOMENT(args)
    if model_name == "TimesNet":
        return TimesNet(args)
    if model_name == "GPT4TS":
        return GPT4TS(args)
    raise ValueError(f"Model {model_name} not found.")

def get_checkpoint_name(model_name):
    """Return the checkpoint file name stored for `model_name`.

    Args:
        model_name: One of "MOMENT_LP", "MOMENT_0", "TimesNet" or "GPT4TS".
            Both MOMENT variants share a single checkpoint file.

    Returns:
        The checkpoint file name (e.g. 'TimesNet.pth').

    Raises:
        ValueError: If `model_name` is not one of the supported names.
    """
    if model_name in ("MOMENT_LP", "MOMENT_0"):
        return 'MOMENT.pth'
    if model_name == "TimesNet":
        return 'TimesNet.pth'
    if model_name == "GPT4TS":
        return 'GPT4TS.pth'
    raise ValueError(f"Model {model_name} not found.")

# Notebook-wide evaluation settings.
MASK_RATIOS = [0.125, 0.25, 0.375, 0.5]  # mask ratios evaluated in turn
GPU_ID = 0  # CUDA device index used when a GPU is available
DEFAULT_CONFIG_PATH = "../../configs/default.yaml"

# Entries of the 'autoformer' forecasting collection; later cells treat them
# as '/'-separated file paths.
dataset_file_and_path_names = get_forecasting_datasets(
    collection='autoformer',
)

In [None]:
DATASET_NAME = 'ETTh1'
MODEL_NAME = 'MOMENT_LP'  # must be a key of config_file_map

# Pick the first dataset path that contains DATASET_NAME, failing with a clear
# message instead of a bare IndexError when nothing matches.
matching_datasets = [path for path in dataset_file_and_path_names if DATASET_NAME in path]
if not matching_datasets:
    raise ValueError(f"No dataset matching {DATASET_NAME!r} found.")
dataset_name = matching_datasets[0]

config = Config(
    config_file_path=config_file_map[MODEL_NAME],
    default_config_file_path=DEFAULT_CONFIG_PATH).parse()
config['device'] = GPU_ID if torch.cuda.is_available() else 'cpu'
args = parse_config(config)

# Evaluate on the test split with non-overlapping windows
# (stride equals the sequence length) and a fixed order.
args.full_file_path_and_name = dataset_name
args.dataset_names = args.full_file_path_and_name
args.data_split = 'test'
args.output_type = 'multivariate'
args.seq_len = 512
args.shuffle = False
args.data_stride_len = 512
args.batch_size = 8 # args.val_batch_size
# args.n_channels = 7 # 7 21 321

test_dataloader = get_timeseries_dataloader(args=args)
print(f"Length of dataloader={len(test_dataloader)}, timeseries={len(test_dataloader.dataset.data)}")

In [None]:
# ### Load pre-trained MOMENT
# run_name = "fearless-planet-52" 
# checkpoint = BaseModel.load_pretrained_weights(run_name=run_name, opt_steps=None)
# config = Config(config_file_path=DEFAULT_CONFIG_PATH, default_config_file_path=DEFAULT_CONFIG_PATH).parse()
# config['device'] = GPU_ID if torch.cuda.is_available() else 'cpu'

# args = parse_config(config)
# model = MOMENT(configs=args)
# model.load_state_dict(checkpoint["model_state_dict"])
# model.eval()
# model.to(args.device)

In [None]:
# Runs are matched on the model family (MODEL_NAME without its "_LP"/"_0"
# suffix) and on the dataset name.
model_family = MODEL_NAME.split('_')[0]
selected_runs = runs.loc[
    (runs['model_name'] == model_family) & (runs['dataset_names'] == DATASET_NAME)
].copy()  # explicit copy so that adding a column below is well-defined

if selected_runs.empty:
    raise ValueError(
        f"No run found for model {model_family!r} and dataset {DATASET_NAME!r}.")

if selected_runs.shape[0] > 1:
    print(f"More than one run found for the given model and dataset\n{selected_runs}")
    # Prefer the run whose name carries the largest trailing number.
    selected_runs['run_num'] = selected_runs.run_name.apply(lambda x: int(x.split("-")[-1]))
    selected_runs = selected_runs.sort_values(by=['run_num'], ascending=False)
checkpoint_name = str(selected_runs.run_name.values[0])
print(f"Checkpoint name: {checkpoint_name}")

model = get_model(MODEL_NAME, args)

checkpoint_path = os.path.join(
    PATHS.CHECKPOINTS_DIR, checkpoint_name, get_checkpoint_name(MODEL_NAME))
with open(checkpoint_path, 'rb') as f:
    # Load onto the CPU first; the model is moved to the target device below.
    checkpoint = torch.load(f, map_location='cpu')

model.load_state_dict(checkpoint['model_state_dict'])
model.eval()
model.to(args.device)

In [None]:
# Score the model's reconstruction on the test split, one result row per mask
# ratio. Only positions where the generated mask equals 0 enter the metrics.
# NOTE(review): `mask` is used for scoring only; it is not passed to
# `model.pretraining`. Confirm that the model hides the same positions.
results = []
for mask_ratio in tqdm(MASK_RATIOS, total=len(MASK_RATIOS)):
    mask_generator = Masking(mask_ratio=mask_ratio)
    masks, preds, trues = [], [], []

    for batch in tqdm(test_dataloader, total=len(test_dataloader)):
        timeseries = batch.timeseries.float().to(args.device)
        input_mask = batch.input_mask.long().to(args.device)
        n_samples, n_channels, _ = timeseries.shape

        mask = mask_generator.generate_mask(x=timeseries, input_mask=input_mask)
        mask = mask.to(args.device)

        with torch.no_grad():
            outputs = model.pretraining(x_enc=timeseries, input_mask=input_mask)

        # Repeat each mask row once per channel so that, after flattening,
        # it lines up element-wise with the flattened time series.
        per_channel_mask = mask.repeat_interleave(n_channels, dim=0)
        masks.append(per_channel_mask.detach().cpu().numpy())
        trues.append(timeseries.detach().cpu().numpy())
        preds.append(outputs.reconstruction.detach().cpu().numpy())

    masks = np.concatenate(masks, axis=0).reshape(-1)
    trues = np.concatenate(trues, axis=0).reshape(-1)
    preds = np.concatenate(preds, axis=0).reshape(-1)

    scored = masks == 0
    metrics = get_forecasting_metrics(
        y=trues[scored], y_hat=preds[scored], reduction='mean')

    dataset_label = dataset_name.split('/')[-1][:-4]
    results.append([dataset_label, mask_ratio, metrics.mse, metrics.mae])

results = pd.DataFrame(results, columns=['dataset', 'mask_ratio', 'MSE', 'MAE'])
print(results)
print(results[['MSE', 'MAE']])

In [None]:
import time

import wandb
from torch import nn
from torch import optim
from tqdm import trange, tqdm

from moment.common import PATHS


def train(args, model, train_dataloader):
    """Train `model` on the reconstruction objective and return it.

    Args:
        args: Configuration providing `max_epoch`, `init_lr`, `weight_decay`,
            `max_norm` and `device`.
        model: Model exposing `pretraining(x_enc=..., input_mask=...)` whose
            output has `reconstruction` and `pretrain_mask` attributes.
        train_dataloader: Iterable of batches with `timeseries` and
            `input_mask` attributes.

    Returns:
        The same `model` object, trained in place.
    """
    optimizer = optim.AdamW(model.parameters(),
                            lr=args.init_lr,
                            weight_decay=args.weight_decay)

    # Keep the per-element loss: with the default 'mean' reduction the loss is
    # already a scalar and the mask weighting below would have no effect.
    criterion = nn.MSELoss(reduction='none')

    logger = wandb.init(
        project="Time-series Foundation Model",
        dir=PATHS.WANDB_DIR)

    model.train()
    try:
        for _ in trange(args.max_epoch):
            for batch in tqdm(train_dataloader, total=len(train_dataloader)):
                timeseries = batch.timeseries.float().to(args.device)
                input_mask = batch.input_mask.long().to(args.device)

                outputs = model.pretraining(x_enc=timeseries, input_mask=input_mask)

                recon_loss = criterion(outputs.reconstruction, timeseries)
                # Weight positions that are valid in input_mask and have
                # pretrain_mask == 0.
                # NOTE(review): presumably pretrain_mask == 0 marks the patches
                # hidden from the model -- confirm against the model code.
                observed_mask = input_mask * (1 - outputs.pretrain_mask)
                n_channels = outputs.reconstruction.shape[1]
                # Broadcast the (batch, time) mask to every channel.
                observed_mask = observed_mask.unsqueeze(1).repeat((1, n_channels, 1))
                masked_loss = observed_mask * recon_loss
                loss = masked_loss.nansum() / (observed_mask.nansum() + 1e-7)

                logger.log({
                    'step_loss_train': loss.item(),
                    'lr': optimizer.param_groups[0]['lr']})

                # Skip the whole update for a non-finite loss instead of
                # stepping the optimizer without fresh gradients.
                if torch.isfinite(loss):
                    loss.backward()
                    nn.utils.clip_grad_norm_(model.parameters(), args.max_norm)
                    optimizer.step()
                optimizer.zero_grad()
    finally:
        # Close the wandb run even when training is interrupted.
        logger.finish()

    return model

In [None]:
from moment.models.timesnet import TimesNet
from moment.models.gpt4ts import GPT4TS

# Configuration for training GPT4TS from scratch in this notebook.
DEFAULT_CONFIG_PATH = "../../configs/default.yaml"
GPU_ID = 2

# NOTE(review): this config lives under configs/anomaly_detection although the
# surrounding cells evaluate imputation -- confirm that this is intended.
config = Config(
    config_file_path="../../configs/anomaly_detection/gpt4ts_train.yaml",
    default_config_file_path=DEFAULT_CONFIG_PATH,
).parse()

# Fall back to the CPU when CUDA is unavailable.
if torch.cuda.is_available():
    config['device'] = GPU_ID
else:
    config['device'] = 'cpu'

args = parse_config(config)
args.task_name = 'pre-training'

In [None]:
# Build a fresh GPT4TS model from the parsed config and move it to the
# configured device.
model = GPT4TS(args)
model.to(args.device)

In [None]:
# Build the dataloader with the training batch size.
# NOTE(review): `args.data_split` is not set in this cell, so the split comes
# from the config file -- confirm it is the training split.
args.batch_size = args.train_batch_size
train_dataloader = get_timeseries_dataloader(args=args)

In [None]:
# Pull one batch as a quick check of the dataloader; `batch_x` stays bound to
# that first batch afterwards.
for batch_x in train_dataloader:
    break

In [None]:
# Train the model; `train` returns the same model object it was given.
model = train(args, model, train_dataloader)

In [None]:
# Rebuild the dataloader on the test split.
# NOTE(review): the batch size is taken from `train_batch_size` here -- confirm
# that a validation/test batch size was not intended.
args.batch_size = args.train_batch_size
args.data_split = 'test'
test_dataloader = get_timeseries_dataloader(args=args)

In [None]:
# Collect ground truth, reconstructions and the masks reported by the model
# over the whole test split.
model.eval()
trues = []
preds = []
masks = []
# Inference only: disable autograd so no computation graph is built or kept.
with torch.no_grad():
    for batch in tqdm(test_dataloader, total=len(test_dataloader)):
        timeseries = batch.timeseries.float().to(args.device)
        input_mask = batch.input_mask.long().to(args.device)

        # Forward pass (no parameter update happens here).
        outputs = model.pretraining(x_enc=timeseries, input_mask=input_mask)

        trues.append(timeseries.detach().cpu().numpy())
        preds.append(outputs.reconstruction.detach().cpu().numpy())
        masks.append(outputs.pretrain_mask.detach().cpu().numpy())

trues = np.concatenate(trues, axis=0).squeeze()
preds = np.concatenate(preds, axis=0).squeeze()
masks = np.concatenate(masks, axis=0)

In [None]:
# Plot one randomly chosen test sample: ground truth against the model's
# reconstruction (top) and the mask reported for that sample (bottom).
# The index is drawn once; a second draw would only discard the first.
idx = np.random.choice(range(len(trues)))
fig, axs = plt.subplots(2, 1, figsize=(10, 5))

# Plot true and predicted values
axs[0].set_title(f"Sample {idx}")
axs[0].plot(trues[idx], c='k', linewidth=1, label="True")
axs[0].plot(preds[idx], linewidth=0.75, c='darkred', label="Preds")
axs[0].legend()

# Plot masked locations; the mask row is tiled 8 times so that it is visible
# as a band.
axs[1].imshow(np.tile(masks[idx], reps=(8, 1)), cmap='binary')

# Turn off x and y ticks
axs[1].set_xticks([])
axs[1].set_yticks([])

plt.show()

In [None]:
# Config files used by the statistical-baseline evaluation.
config_path = '../../configs/imputation/zero_shot.yaml'
default_config_path = '../../configs/default.yaml'

# Mask ratios at which the baselines are evaluated.
mask_ratios = [0.125, 0.25, 0.375, 0.5]

def statistical_interpolation(y):
    """Fill NaN entries of `y` row by row with three interpolation schemes.

    Args:
        y: 2-D array-like; each row is interpolated independently along
            its columns.

    Returns:
        Tuple `(linear, nearest, cubic)` of NumPy arrays, one per pandas
        interpolation method.
    """
    frame = pd.DataFrame(y)
    methods = ('linear', 'nearest', 'cubic')
    return tuple(
        frame.interpolate(method=method, axis=1).values for method in methods
    )

def forward_backward_fill(y):
    """Fill NaNs in each row with the previous value, then with the next one.

    Args:
        y: 2-D array-like; rows are filled independently along their columns.

    Returns:
        NumPy array of the same shape. The backward pass only affects NaNs
        that the forward pass could not fill (leading NaNs of a row).
    """
    frame = pd.DataFrame(y)
    forward_filled = frame.ffill(axis=1)
    fully_filled = forward_filled.bfill(axis=1)
    return fully_filled.values

In [None]:
# Dataset file evaluated by the statistical baselines.
# NOTE(review): hard-coded absolute path; consider deriving it from
# `dataset_file_and_path_names` so the notebook runs on other machines.
dataset_name = '/XXXX-14/project/public/XXXX-9/TimeseriesDatasets/forecasting/autoformer/ETTm1.csv'

In [None]:
# Evaluate statistical imputation baselines (forward/backward fill and linear,
# nearest, cubic interpolation) on the test split of `dataset_name` for every
# ratio in `mask_ratios`.
config = Config(
    config_file_path=config_path,
    default_config_file_path=default_config_path,
).parse()
# config['device'] = gpu_id if torch.cuda.is_available() else 'cpu'
args = parse_config(config)

# Non-overlapping windows of 512 steps on the test split.
args.output_type = 'multivariate'
args.seq_len = 512
args.data_stride_len = 512
args.batch_size = args.val_batch_size
args.full_file_path_and_name = dataset_name
args.dataset_names = args.full_file_path_and_name
args.data_split = 'test'
test_dataloader = get_timeseries_dataloader(args=args)

results = []
trues = []
# One mask generator and one list of per-batch masks per mask ratio.
mask_generators = {ratio: Masking(mask_ratio=ratio) for ratio in mask_ratios}
masks = {ratio: [] for ratio in mask_ratios}

for batch_x in tqdm(test_dataloader, total=len(test_dataloader)):
    timeseries = batch_x.timeseries.float()
    n_examples, n_channels, _ = timeseries.shape
    # Treat every channel as an independent univariate series.
    timeseries = timeseries.reshape((-1, 1, args.seq_len))

    input_mask = batch_x.input_mask.long().repeat_interleave(n_channels, axis=0)
    trues.append(timeseries.squeeze().numpy())

    for mask_ratio, mask_generator in mask_generators.items():
        m = mask_generator.generate_mask(x=timeseries, input_mask=input_mask)
        masks[mask_ratio].append(m)

trues = np.concatenate(trues, axis=0)
for mask_ratio in mask_ratios:
    masks[mask_ratio] = np.concatenate(masks[mask_ratio], axis=0)

for mask_ratio in tqdm(mask_ratios, total=len(mask_ratios)):
    mask = masks[mask_ratio]
    hidden = mask == 0

    # Blank out the positions to impute, then let each baseline fill them.
    preds = trues.copy()
    preds[hidden] = torch.nan

    preds_fbfill = forward_backward_fill(preds.copy())
    preds_linear, preds_nearest, preds_cubic = statistical_interpolation(preds.copy())

    # Row layout: dataset, mask ratio, then (MSE, MAE) for fbfill, linear,
    # nearest and cubic -- all measured on the hidden positions only.
    row = [dataset_name.split('/')[-1][:-4], mask_ratio]
    for baseline_preds in (preds_fbfill, preds_linear, preds_nearest, preds_cubic):
        baseline_metrics = get_forecasting_metrics(
            y=trues[hidden], y_hat=baseline_preds[hidden], reduction='mean')
        row.extend([baseline_metrics.mse, baseline_metrics.mae])
    results.append(row)


In [None]:
# Plot one randomly chosen series with the statistical baselines computed for
# the last mask ratio of the evaluation loop (`preds`, `preds_*` and `mask`
# all come from that final iteration).
idx = np.random.choice(range(len(trues)))
fig, axs = plt.subplots(2, 1, figsize=(10, 5))

# Plot true and predicted values
axs[0].set_title(f"Sample {idx}")
axs[0].plot(trues[idx], c='k', linewidth=1, label="True")
axs[0].plot(preds[idx], linewidth=0.75, c='darkred', label="Preds")
axs[0].plot(preds_linear[idx], linewidth=0.75, c='darkblue', label="Preds (L)")
axs[0].plot(preds_cubic[idx], linewidth=0.75, c='darkgreen', label="Preds (C)")
axs[0].plot(preds_nearest[idx], linewidth=0.75, c='pink', label="Preds (N)")
axs[0].legend()

# Plot masked locations. `mask` is a NumPy array and cannot be indexed with a
# float such as 0.125; use it directly so it matches the plotted predictions.
axs[1].imshow(np.tile(mask[np.newaxis, idx], reps=(8, 1)), cmap='binary')

# Turn off x and y ticks
axs[1].set_xticks([])
axs[1].set_yticks([])

plt.show()

In [None]:
def get_dataloaders(args):
    """Build the train, test and validation dataloaders for one dataset.

    `args.dataset_names` is set from `args.full_file_path_and_name`, and
    `args.data_split` is overwritten for each loader; it is left at 'val'
    when the function returns.

    Returns:
        Tuple `(train_dataloader, test_dataloader, val_dataloader)`.
    """
    args.dataset_names = args.full_file_path_and_name

    loaders = {}
    # Loaders are created in this order: train, test, val.
    for split in ('train', 'test', 'val'):
        args.data_split = split
        loaders[split] = get_timeseries_dataloader(args=args)

    return loaders['train'], loaders['test'], loaders['val']

def load_pretrained_moment(args,
                           pretraining_task_name: str = "pre-training"):
    """Build a MOMENT model and load its pre-trained weights.

    Args:
        args: Configuration providing `pretraining_run_name` and
            `pretraining_opt_steps`. Its `task_name` is overwritten with
            `pretraining_task_name`.
        pretraining_task_name: Task name set on `args` before the model is
            constructed.

    Returns:
        The MOMENT model with the checkpoint's `model_state_dict` loaded.
    """
    args.task_name = pretraining_task_name

    checkpoint = BaseModel.load_pretrained_weights(
        run_name=args.pretraining_run_name,
        opt_steps=args.pretraining_opt_steps)

    # The model class is imported as `MOMENT` at the top of the notebook.
    pretrained_model = MOMENT(configs=args)
    pretrained_model.load_state_dict(checkpoint["model_state_dict"])

    return pretrained_model

def repeated_fill(timeseries, mask, limit_area=None, limit=None):
    """Impute masked time steps by forward-fill followed by backward-fill.

    NOTE: This only works for univariate time series.

    Args:
        timeseries: Float tensor of shape (n, 1, seq_len) or (n, seq_len).
            It is not modified.
        mask: Tensor of shape (n, seq_len); entries equal to 0 mark the time
            steps to impute.
        limit_area: Unused; kept for interface compatibility.
        limit: Maximum number of consecutive steps each fill pass may fill
            (`None` means no limit). Steps left unfilled stay NaN.

    Returns:
        CPU tensor of shape (n, 1, seq_len) holding the filled series.
    """
    # Work on a copy: squeezing returns a view, so writing NaNs into it would
    # silently modify the caller's tensor. Squeeze only the channel dimension
    # so that a batch with a single example keeps its batch dimension.
    filled = timeseries.detach().clone().squeeze(1)

    # Set the positions where mask == 0 to NaN so pandas treats them as missing.
    filled[mask == 0] = torch.nan

    frame = pd.DataFrame(filled.cpu().numpy())
    frame = frame.ffill(axis=1, limit=limit).bfill(axis=1, limit=limit)
    return torch.tensor(frame.values).unsqueeze(1)

def statistical_interpolation(y):
    """Interpolate NaN entries of `y` along each row with three methods.

    This repeats the definition given earlier in the notebook.

    Args:
        y: 2-D array-like whose rows are interpolated independently.

    Returns:
        Tuple `(linear, nearest, cubic)` of NumPy arrays.
    """
    frame = pd.DataFrame(y)

    def _interpolate(method):
        # Interpolate across the columns of every row with the given method.
        return frame.interpolate(method=method, axis=1).values

    return _interpolate('linear'), _interpolate('nearest'), _interpolate('cubic')

In [None]:
# Zero-shot imputation setup: parse the config and load the pre-trained model.
DEFAULT_CONFIG_PATH = "../../configs/default.yaml"
GPU_ID = 1

config = Config(config_file_path="../../configs/imputation/zero_shot.yaml",
                default_config_file_path=DEFAULT_CONFIG_PATH).parse()
config['device'] = GPU_ID if torch.cuda.is_available() else 'cpu'
args = parse_config(config)

# NOTE(review): hard-coded absolute dataset path; not portable across machines.
args.full_file_path_and_name = '/XXXX-14/project/public/XXXX-9/TimeseriesDatasets/forecasting/autoformer/ETTh1.csv'
args.output_type = 'multivariate'
args.batch_size = 256
args.seq_len = 512

# The helper defined in this notebook is `load_pretrained_moment`.
model = load_pretrained_moment(args)
model.to(args.device)
model.eval()

In [None]:
# Setting 1: Patches missing at random
args.mask_ratio = 0.5
args.data_stride_len = args.seq_len

train_dataloader, test_dataloader, val_dataloader = get_dataloaders(args)
mask_generator = Masking(mask_ratio=args.mask_ratio)

trues = []
preds = []
input_masks = []
masks = []

with torch.no_grad():
    for batch_x in tqdm(test_dataloader, total=len(test_dataloader)):
        timeseries = batch_x.timeseries.float()
        n_examples, n_channels, _ = timeseries.shape
        # Treat every channel as an independent univariate series.
        timeseries = timeseries.reshape(
            (n_examples*n_channels, 1, args.seq_len))

        trues.append(timeseries.detach().cpu().numpy().copy())

        input_mask = batch_x.input_mask
        input_mask = input_mask.repeat_interleave(n_channels, dim=0)
        input_masks.append(input_mask.detach().cpu().numpy())

        mask = mask_generator.generate_mask(
            x=timeseries, input_mask=input_mask)
        masks.append(mask.detach().cpu().numpy())

        # Move to device
        timeseries = timeseries.to(args.device)
        input_mask = input_mask.to(args.device)
        mask = mask.to(args.device)

        outputs = model.reconstruct(
            x_enc=timeseries,
            input_mask=input_mask,
            mask=mask)

        preds.append(outputs.reconstruction.detach().cpu().numpy())

trues = np.concatenate(trues, axis=0).squeeze()
preds = np.concatenate(preds, axis=0).squeeze()
input_masks = np.concatenate(input_masks, axis=0).squeeze()
masks = np.concatenate(masks, axis=0).squeeze()

# Statistical baselines see the same series with the masked positions
# (mask == 0) replaced by NaN.
statistical_preds = trues.copy()
statistical_preds[masks == 0] = np.nan

preds_linear, preds_nearest, preds_cubic = statistical_interpolation(statistical_preds.copy())

# Careful -- Measure reconstruction of the masked patches
moment_metrics = get_forecasting_metrics(
    y=trues[masks == 0], y_hat=preds[masks == 0], reduction='mean')
print(' MOMENT:', moment_metrics)
linear_metrics = get_forecasting_metrics(
    y=trues[masks == 0], y_hat=preds_linear[masks == 0], reduction='mean')
print(' Linear:', linear_metrics)
nearest_metrics = get_forecasting_metrics(
    y=trues[masks == 0], y_hat=preds_nearest[masks == 0], reduction='mean')
print('Nearest:', nearest_metrics)
cubic_metrics = get_forecasting_metrics(
    y=trues[masks == 0], y_hat=preds_cubic[masks == 0], reduction='mean')
print('  Cubic:', cubic_metrics)

In [None]:
# Plot one randomly chosen series: ground truth, the model reconstruction and
# the three interpolation baselines (top), plus the mask of that series
# (bottom). The index is drawn once; a second draw would only discard the first.
idx = np.random.choice(range(len(trues)))
fig, axs = plt.subplots(2, 1, figsize=(10, 5))

# Plot true and predicted values
axs[0].set_title(f"Sample {idx}")
axs[0].plot(trues[idx], c='k', linewidth=1, label="True")
axs[0].plot(preds[idx], linewidth=0.75, c='darkred', label="Preds")
axs[0].plot(preds_linear[idx], linewidth=0.75, c='darkblue', label="Preds (L)")
axs[0].plot(preds_cubic[idx], linewidth=0.75, c='darkgreen', label="Preds (C)")
axs[0].plot(preds_nearest[idx], linewidth=0.75, c='pink', label="Preds (N)")
axs[0].legend()

# Plot masked locations; the mask row is tiled 8 times so it is visible as a band.
axs[1].imshow(np.tile(masks[np.newaxis, idx], reps=(8, 1)), cmap='binary')

# Turn off x and y ticks
axs[1].set_xticks([])
axs[1].set_yticks([])

plt.show()

In [None]:
# trues = []
# preds_with_attention = []
# preds_without_attention = []
# preds_with_repeated_fill = []
# input_masks = []
# masks = []

# missing_mode = 'timesteps_missing_at_random' # 'patches_missing_at_random' 'timesteps_missing_at_random'
# with torch.no_grad():
#     for batch_x in tqdm(test_dataloader, total=len(test_dataloader)):
#         timeseries = batch_x.timeseries.float()
#         n_examples, n_channels, _ = timeseries.shape
#         timeseries = timeseries.reshape(
#             (n_examples*n_channels, 1, args.seq_len))
        
#         trues.append(timeseries.detach().cpu().numpy().copy())
        
#         input_mask = batch_x.input_mask
#         input_mask = input_mask.repeat_interleave(n_channels, dim=0)
#         input_masks.append(input_mask)
        
#         if missing_mode == 'timesteps_missing_at_random':
#             mask = torch.rand_like(input_mask)
            
#             mask[mask <= args.mask_ratio] = 0 # masked
#             mask[mask > args.mask_ratio] = 1  # observed
#             masks.append(mask.detach().cpu().numpy())

#             timeseries = repeated_fill(timeseries, mask, limit_area=None, limit=1)

#             preds_with_repeated_fill.append(
#                 timeseries.detach().cpu().numpy())

#             mask = torch.isnan(timeseries).long().squeeze().to(args.device)
#             torch.nan_to_num(timeseries, nan=0, out=timeseries) 
            
#         elif missing_mode == 'patches_missing_at_random':
#             mask = mask_generator.generate_mask(
#                 x=timeseries, input_mask=input_mask)
#             masks.append(mask)
        
#         # Move to device
#         timeseries = timeseries.to(args.device)
#         input_mask = input_mask.to(args.device)
#         mask = mask.to(args.device) 
        
#         outputs_with_attention = model.reconstruct(
#             x_enc=timeseries, 
#             input_mask=input_mask, 
#             mask=mask)

#         # Turn off attention to masked inputs
#         input_mask = input_mask*mask

#         outputs_without_attention = model.reconstruct(
#             x_enc=timeseries, 
#             input_mask=input_mask)
        
#         preds_with_attention.append(
#             outputs_with_attention.reconstruction.detach().cpu().numpy())
#         preds_without_attention.append(
#             outputs_without_attention.reconstruction.detach().cpu().numpy())
        
#     trues = np.concatenate(trues, axis=0).squeeze()
#     preds_with_attention = np.concatenate(preds_with_attention, axis=0).squeeze()
#     preds_without_attention = np.concatenate(preds_without_attention, axis=0).squeeze()
#     input_masks = np.concatenate(input_masks, axis=0).squeeze()
#     masks = np.concatenate(masks, axis=0).squeeze()
    
#     if missing_mode == 'timesteps_missing_at_random':
#         preds_with_repeated_fill = np.concatenate(preds_with_repeated_fill, axis=0).squeeze()

In [None]:
# trues = []
# preds = []
# preds_with_repeated_fill = []
# input_masks = []
# masks = []

# missing_mode = 'timesteps_missing_at_random' # 'patches_missing_at_random' 'timesteps_missing_at_random'

# with torch.no_grad():
#     for batch_x in tqdm(test_dataloader, total=len(test_dataloader)):
#         timeseries = batch_x.timeseries.float()
#         n_examples, n_channels, _ = timeseries.shape
#         timeseries = timeseries.reshape(
#             (n_examples*n_channels, 1, args.seq_len))
        
#         trues.append(timeseries.detach().cpu().numpy().copy())
        
#         input_mask = batch_x.input_mask
#         input_mask = input_mask.repeat_interleave(n_channels, dim=0)
#         input_masks.append(input_mask)
        
#         if missing_mode == 'timesteps_missing_at_random':
#             mask = torch.rand_like(input_mask)
            
#             mask[mask <= args.mask_ratio] = 0 # masked
#             mask[mask > args.mask_ratio] = 1  # observed
#             masks.append(mask.detach().cpu().numpy())

#             pred_repeated_fill = repeated_fill(timeseries, mask, limit_area=None, limit=None)
#             timeseries = repeated_fill(timeseries, mask, limit_area=None, limit=1)

#             preds_with_repeated_fill.append(pred_repeated_fill.detach().cpu().numpy())

#             mask = torch.isnan(timeseries).long().squeeze().to(args.device)
#             timeseries = torch.nan_to_num(timeseries, nan=0) 
            
#         elif missing_mode == 'patches_missing_at_random':
#             mask = mask_generator.generate_mask(
#                 x=timeseries, input_mask=input_mask)
#             masks.append(mask)
        
#         # Move to device
#         timeseries = timeseries.to(args.device)
#         input_mask = input_mask.to(args.device)
#         mask = mask.to(args.device) 
        
#         outputs = model.reconstruct(
#             x_enc=timeseries, 
#             input_mask=input_mask, 
#             mask=mask)

#         preds.append(outputs.reconstruction.detach().cpu().numpy())
        
#     trues = np.concatenate(trues, axis=0).squeeze()
#     preds = np.concatenate(preds, axis=0).squeeze()
#     input_masks = np.concatenate(input_masks, axis=0).squeeze()
#     masks = np.concatenate(masks, axis=0).squeeze()
    
#     if missing_mode == 'timesteps_missing_at_random':
#         preds_with_repeated_fill = np.concatenate(preds_with_repeated_fill, axis=0).squeeze()

In [None]:
# Plot one randomly chosen series: ground truth, model reconstruction and,
# when available, the repeated-fill baseline.
idx = np.random.choice(range(len(trues)))
fig, axs = plt.subplots(2, 1, figsize=(10, 5))

# Plot true and predicted values
axs[0].set_title(f"Sample {idx}")
axs[0].plot(trues[idx], c='k', linewidth=1, label="True")
axs[0].plot(preds[idx], linewidth=0.75, c='darkblue', label="Preds")
# `preds_with_repeated_fill` is produced only by the commented-out
# "timesteps missing at random" cells; skip the line instead of failing with
# a NameError when those cells have not been run.
if 'preds_with_repeated_fill' in globals():
    axs[0].plot(preds_with_repeated_fill[idx],
                linewidth=0.75, c='darkred', label="Pred (repeated fill)")
axs[0].legend()

# Plot masked locations
axs[1].imshow(np.tile(masks[np.newaxis, idx], reps=(8, 1)), cmap='binary')

# Turn off x and y ticks
axs[1].set_xticks([])
axs[1].set_yticks([])

plt.show()

In [None]:
# Display the ground-truth values of the sample selected by `idx`.
trues[idx]

In [None]:
# Compare reconstructions obtained with and without attention to the masked
# inputs for one randomly chosen series.
# NOTE(review): `preds_with_attention` and `preds_without_attention` are
# produced only by a commented-out cell of this notebook; this cell raises a
# NameError on a fresh kernel unless that cell is restored and run first.
idx = np.random.choice(range(len(trues)))
fig, axs = plt.subplots(2, 1, figsize=(10, 5))

# Plot true and predicted values
axs[0].set_title(f"Sample {idx}")
axs[0].plot(trues[idx], c='k', linewidth=1, label="True")
axs[0].plot(preds_without_attention[idx],
            linewidth=0.75, c='darkblue', label="Pred (w/o attention)")
axs[0].plot(preds_with_attention[idx],
            linewidth=0.75, c='darkgreen', label="Pred (w/ attention)")
# axs[0].plot(preds_with_repeated_fill[idx],
#         linewidth=0.75, c='darkred', label="Pred (repeated fill)")
axs[0].legend()

# Plot masked locations
axs[1].imshow(np.tile(masks[np.newaxis, idx], reps=(8, 1)), cmap='binary')

# Turn off x and y ticks
axs[1].set_xticks([])
axs[1].set_yticks([])

plt.show()