In [None]:
# pip install plotly

import torch
import numpy as np
from data_provider.data_factory import data_provider_subset as data_provider
from data_provider.ictsp_dataloader import ForecastingDatasetWrapper
from types import SimpleNamespace
from models.ICPretrain import ICPretrain
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from tqdm import tqdm
import os

import json
# Load the pretraining configuration and expose its keys as attributes.
# JSON text is UTF-8 by specification, so decode it explicitly instead of
# relying on the platform's locale encoding.
with open('configs/pretrain_configs_sequential.json', 'r', encoding='utf-8') as file:
    config_data = json.load(file)
icpretrain_configs = SimpleNamespace(**config_data)
# Mark the configuration as inference before it is handed to the model and
# to the dataset wrapper.
icpretrain_configs.stage = "inference"
# Checkpoint file whose weights are loaded into the model.
weight_path = './pt_model_2048_96_current.pth'
model_name = 'ICTSP_FT'

def nested_collate_fn(batch):
    """Collate a sequence of samples, packing ragged arrays as nested tensors.

    The behaviour is chosen from the type of the first element of ``batch``:

    * ``np.ndarray``: every element is converted with ``torch.Tensor(...)``
      and the results are packed into one ``torch.nested`` tensor.
    * ``torch.Tensor``: the elements are packed into one ``torch.nested``
      tensor as they are.
    * ``float`` or NumPy floating scalar: 1-D ``torch.float`` tensor.
    * ``int`` or NumPy integer scalar: 1-D ``torch.long`` tensor. ``bool`` is
      a subclass of ``int`` and therefore also lands here.
    * ``str``: ``batch`` is returned unchanged.
    * ``tuple`` or ``list``: the samples are transposed field by field and
      each field is collated recursively; the result is a ``list``.

    Args:
        batch: non-empty sequence of samples that all share one structure.

    Returns:
        The collated batch, as described above.

    Raises:
        TypeError: if the first element has an unsupported type.
        IndexError: if ``batch`` is empty.
    """
    elem = batch[0]

    if isinstance(elem, np.ndarray):
        return torch.nested.nested_tensor([torch.Tensor(item) for item in batch])
    if isinstance(elem, torch.Tensor):
        # The error message below has always promised tensor support.
        return torch.nested.nested_tensor(list(batch))
    if isinstance(elem, (float, np.floating)):
        return torch.tensor(batch, dtype=torch.float)
    if isinstance(elem, (int, np.integer)):
        return torch.tensor(batch, dtype=torch.long)
    if isinstance(elem, str):
        return batch
    if isinstance(elem, (tuple, list)):
        # zip(*batch) regroups the samples so that each field is collated
        # together with the same field of every other sample.
        return [nested_collate_fn(samples) for samples in zip(*batch)]

    raise TypeError(f"batch must contain tensors, numpy arrays or numbers; found {type(elem)}")

def get_dataset(seq_len=2048, pred_len=720, data_type='ETTh2', root_path='./dataset/', data_path='ETTh2.csv', train_ratio=0.7, test_ratio=0.2, 
                flag='test', do_forecasting=False, batch_size=8, force_lookback=52):
    """Build the wrapped forecasting dataset and a DataLoader over it.

    The arguments are packed into the namespace expected by ``data_provider``.
    The dataset it returns is wrapped in ``ForecastingDatasetWrapper`` together
    with the global ``icpretrain_configs``, and ``force_lookback`` is stored on
    the wrapper as ``force_legacy_lookback_for_inference``. The loader returned
    by ``data_provider`` itself is discarded.

    Args:
        seq_len, pred_len, data_type, root_path, data_path, train_ratio,
        test_ratio, do_forecasting: forwarded to ``data_provider`` through the
            namespace (``data_type`` under the key ``data``).
        flag: split selector passed to ``data_provider``.
        batch_size: used for the namespace (``batch_size`` and
            ``batch_size_test``) and for the returned DataLoader.
        force_lookback: value assigned to the wrapper's
            ``force_legacy_lookback_for_inference`` attribute.

    Returns:
        ``(ds, dl)``: the wrapped dataset and an unshuffled DataLoader that
        collates with ``nested_collate_fn``.
    """
    provider_settings = {
        'embed': 'timeF',
        'batch_size': batch_size,
        'batch_size_test': batch_size,
        'freq': 'h',
        'data': data_type,
        'root_path': root_path,
        'data_path': data_path,
        'seq_len': seq_len,
        'label_len': 0,
        'pred_len': pred_len,
        'features': 'M',
        'target': 'OT',
        'scale': 1,
        'train_ratio': train_ratio,
        'test_ratio': test_ratio,
        'num_workers': 0,
        'do_forecasting': do_forecasting,
    }
    # Only the dataset is needed; a dedicated loader is built below.
    base_dataset, _ = data_provider(SimpleNamespace(**provider_settings), flag=flag)

    ds = ForecastingDatasetWrapper(base_dataset, icpretrain_configs)
    ds.force_legacy_lookback_for_inference = force_lookback

    loader_options = dict(
        batch_size=batch_size,
        shuffle=False,
        drop_last=False,
        pin_memory=False,
        collate_fn=nested_collate_fn,
    )
    dl = DataLoader(ds, **loader_options)
    return ds, dl

def preprocess_data(data, device):
    """Move one collated batch to ``device`` and densify its nested tensors.

    ``data`` is indexed positionally: entries 0-7 are nested tensors and
    entry 8 holds the task ids. Every nested entry is cast, moved to
    ``device``, padded with zeros into a regular tensor and, except for the
    last one (``y_true_shape``), flipped along dimension 1.

    The flip mirrors one done in the tokenizer so that the "float to the
    right" alignment is applied on the channel dimension.

    ``y_true`` and ``token_y_part`` are cast to int when the first task id
    equals 1 and to float otherwise.

    Returns:
        ``(task_id, token_x_part, y_true, token_y_part, channel_label,
        position_label, source_label, tag_multihot, y_true_shape)``.
    """
    def _densify(nested, as_int, flip=True):
        # cast -> transfer -> zero-pad -> optional flip along dim 1
        cast = nested.int() if as_int else nested.float()
        dense = torch.nested.to_padded_tensor(cast.to(device, non_blocking=True), 0)
        return dense.flip(1) if flip else dense

    task_id = data[8].int().to(device, non_blocking=True)
    # The whole batch is treated according to its first task id.
    targets_are_int = not (task_id[0] != 1)

    token_x_part = _densify(data[0], as_int=False)
    y_true = _densify(data[1], as_int=targets_are_int)      # C L or C
    token_y_part = _densify(data[2], as_int=targets_are_int)
    channel_label = _densify(data[3], as_int=True)
    position_label = _densify(data[4], as_int=True)
    source_label = _densify(data[5], as_int=True)
    tag_multihot = _densify(data[6], as_int=False)
    y_true_shape = _densify(data[7], as_int=True, flip=False)

    return (
        task_id,
        token_x_part,
        y_true,
        token_y_part,
        channel_label,
        position_label,
        source_label,
        tag_multihot,
        y_true_shape,
    )

from collections import OrderedDict

# Build the model from the inference configuration.
model = ICPretrain(icpretrain_configs)

# NOTE(security): torch.load unpickles the file, which can execute arbitrary
# code. Only load checkpoints from a trusted source; consider passing
# weights_only=True if the installed PyTorch supports it and the checkpoint
# contains nothing but tensors.
state_dict = torch.load(weight_path, map_location='cpu')

# Strip a leading `module.` from parameter names (typically left behind when
# the checkpoint was saved from a DataParallel-style wrapper) so the keys
# match the bare model.
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    name = k[7:] if k.startswith('module.') else k  # remove `module.` prefix
    new_state_dict[name] = v

model.load_state_dict(new_state_dict)

# Replace every torch.compile-wrapped attribute with the original module it
# wraps, so inference runs the uncompiled modules.
for attribute in dir(model):
    if isinstance(getattr(model, attribute), torch._dynamo.eval_frame.OptimizedModule):
        setattr(model, attribute, getattr(model, attribute)._orig_mod)

# Prefer the GPU, but fall back to the CPU instead of failing on machines
# without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
model.process_output = True
model.eval()

In [None]:
# Path to your inference data.
data_path = 'main.csv'

# seq_len: L_I, force_lookback: L_b, pred_len: L_P
# The validation and test loaders share every setting except the split flag,
# so both are built from one call template (validation first, then test).
(vali_data, vali_loader), (test_data, test_loader) = [
    get_dataset(
        seq_len=416,
        pred_len=4,
        data_type='custom',
        root_path='./dataset/',
        data_path=data_path,
        train_ratio=0.87,
        test_ratio=0.1,
        flag=split,
        do_forecasting=True,
        batch_size=8,
        force_lookback=104,
    )
    for split in ('val', 'test')
]
number_of_targets = 0  # All

def calculation(dataloader, number_of_targets=53):
    """Run the global ``model`` over ``dataloader`` and stack its outputs.

    Each batch is prepared with ``preprocess_data`` on the global ``device``
    and fed to the model without gradient tracking. The model output and
    ``y_true`` are moved to NumPy, transposed with axes ``(0, 2, 1)`` and
    trimmed to the last ``number_of_targets`` entries of the final axis.
    Presumably that axis is the channel axis, given the ``(L, C)`` note
    below; confirm against the model.

    Args:
        dataloader: iterable of batches as produced by ``nested_collate_fn``.
        number_of_targets: how many trailing entries of the last axis to
            keep. ``0`` keeps all of them, because ``-0:`` is the full slice.
            Defaults to 53, the value that was previously hard-coded here.
            The module-level ``number_of_targets`` is not consulted; pass it
            explicitly if it should apply.

    Returns:
        ``(preds, trues)``: the per-batch arrays concatenated along axis 0.

    Raises:
        ValueError: from ``np.concatenate`` when ``dataloader`` yields no
            batches.
    """
    preds = []
    trues = []
    # x: (L_I, C), y: (L_P, C)
    # Passing the loader itself lets tqdm show the total when it has a length.
    for data in tqdm(dataloader):
        (task_id, token_x_part, y_true, token_y_part, channel_label,
         position_label, source_label, tag_multihot, y_true_shape) = preprocess_data(data, device=device)
        with torch.no_grad():
            res = model(token_x_part, token_y_part, channel_label, position_label,
                        source_label, tag_multihot, y_true_shape, task_id)
        res = res.detach().cpu().numpy().transpose((0, 2, 1))[:, :, -number_of_targets:]
        y = y_true.detach().cpu().numpy().transpose((0, 2, 1))[:, :, -number_of_targets:]
        preds.append(res)
        trues.append(y)

    preds = np.concatenate(preds, axis=0)
    trues = np.concatenate(trues, axis=0)
    return preds, trues

# Run inference on the validation split first, then on the test split.
preds_vali, trues_vali = calculation(vali_loader)
preds, trues = calculation(test_loader)
# Bare expression: in a notebook cell this displays the shape of the stacked
# test predictions.
preds.shape