In [1]:
# default_exp data.ts_loader

In [2]:
#hide
%load_ext autoreload
%autoreload 2

In [3]:
#export
import numpy as np
import pandas as pd
import random
import torch as t
import copy
from fastcore.foundation import patch
from nixtla.data.ts_dataset import TimeSeriesDataset
from collections import defaultdict

In [7]:
#export
class TimeSeriesLoader(object):
    """Sliding-window batch sampler over a time-series dataset.

    Each series is cut into overlapping windows of
    ``input_size + output_size`` steps. In training mode, iteration yields
    batches sampled (with replacement) from windows whose output segment
    has at least one active mask value. In eval mode it yields the last
    window of every series. Iteration never terminates on its own; the
    caller decides when to stop.
    """

    def __init__(self,
                 ts_dataset: 'TimeSeriesDataset',
                 model: str,
                 offset: int,
                 window_sampling_limit: int,
                 input_size: int,
                 output_size: int,
                 idx_to_sample_freq: int,  # TODO: not active yet
                 batch_size: int,
                 hide_ts_idx=None):
        """
        Args:
            ts_dataset: dataset to draw windows from. It is deep-copied, so the
                masking done here never mutates the caller's object.
            model: batch layout to produce; only ``'nbeats'`` is implemented.
            offset: forwarded to ``ts_dataset.get_filtered_tensor``.
            window_sampling_limit: forwarded to ``ts_dataset.get_filtered_tensor``.
            input_size: number of in-sample steps per window.
            output_size: number of out-of-sample steps per window.
            idx_to_sample_freq: stored but currently unused.
            batch_size: windows per training batch; ``<= 0`` means every
                sampleable window is returned in each batch.
            hide_ts_idx: optional indices into the last axis of
                ``ts_dataset.ts_tensor`` whose last channel (the mask) is
                zeroed, which removes them from the sampleable targets.
        """
        self.model = model
        self.window_sampling_limit = window_sampling_limit
        self.input_size = input_size
        self.output_size = output_size
        self.batch_size = batch_size
        self.idx_to_sample_freq = idx_to_sample_freq
        self.offset = offset
        # TODO: drop the deep copy. It is required today because the masking
        # below writes into the dataset tensor in place.
        self.ts_dataset = copy.deepcopy(ts_dataset)
        self.hide_ts_idx = hide_ts_idx

        # Loader-specific mask: zero the mask channel at the hidden indices.
        if self.hide_ts_idx is not None:
            self.ts_dataset.ts_tensor[:, -1, self.hide_ts_idx] = 0

        print('Creating windows matrix ...')
        self._rebuild_windows()
        self._is_train = True

    def _rebuild_windows(self):
        """Recompute the windows tensor and every attribute derived from it."""
        self.ts_windows = self._create_windows_tensor()
        self.n_windows = len(self.ts_windows)
        # Windows are laid out position-major (see the permute in
        # _create_windows_tensor), so tiling the per-series static data
        # n_windows // n_series times aligns window w with series w % n_series.
        # This assumes dim 0 of the filtered tensor indexes series.
        self.static_data = self.ts_dataset.get_static_data().repeat(
            self.n_windows // self.ts_dataset.n_series, 1)
        self.sampling_idx = self._update_sampling_idx()

    def _update_sampling_idx(self):
        """Return indices of windows whose output segment has an active mask value.

        Only these windows are sampled during training.
        """
        active = t.sum(self.ts_windows[:, -1, -self.output_size:], dim=1)
        active = t.nonzero(active > 0)
        return list(active.flatten().numpy())

    def _create_windows_tensor(self):
        """Build every stride-1 window of length ``input_size + output_size``.

        The filtered tensor from the dataset is padded along time in two ways:
        - left-padded with ``input_size - 1`` zeros, so the first window's
          input ends at the first time step;
        - right-padded with the amount the dataset returns.

        y (channel 0) and the mask (last channel) are then zeroed over the
        final ``output_size`` steps before unfolding.

        Returns:
            Tensor of shape ``(n_windows, n_channels, input_size + output_size)``.
        """
        tensor, right_padding = self.ts_dataset.get_filtered_tensor(
            self.offset, self.output_size, self.window_sampling_limit)
        _, n_channels, _ = tensor.size()

        padder = t.nn.ConstantPad1d(padding=(self.input_size - 1, right_padding), value=0)
        tensor = padder(tensor)

        tensor[:, 0, -self.output_size:] = 0
        tensor[:, -1, -self.output_size:] = 0

        window_len = self.input_size + self.output_size
        windows = tensor.unfold(dimension=-1, size=window_len, step=1)
        # (dim0, c, position, L) -> (position, dim0, c, L), then flatten the
        # first two axes: window index = position * dim0 + dim0 index.
        windows = windows.permute(2, 0, 1, 3)
        return windows.reshape(-1, n_channels, window_len)

    def __len__(self):
        """Return the total number of windows across all series and positions.

        NOTE(review): the previous body referenced a non-existent
        ``self.len_series`` attribute; the window count is returned instead.
        Confirm this is the intended meaning.
        """
        return self.n_windows

    def __iter__(self):
        """Yield batches forever.

        In train mode, windows are sampled at random. In eval mode, each batch
        holds the last window of every series.
        """
        while True:
            if not self._is_train:
                # The last n_series windows are the final position of each series.
                sampled_ts_indices = list(range(self.n_windows - self.ts_dataset.n_series,
                                                self.n_windows))
            elif self.batch_size > 0:
                if len(self.sampling_idx) == 0:
                    raise ValueError('No window has an active output mask; '
                                     'nothing to sample for training.')
                sampled_ts_indices = np.random.choice(self.sampling_idx,
                                                      size=self.batch_size, replace=True)
            else:
                sampled_ts_indices = self.sampling_idx

            yield self.__get_item__(sampled_ts_indices)

    def __get_item__(self, index):
        """Assemble the batch for window indices ``index`` in the layout of ``self.model``.

        Raises:
            NotImplementedError: for ``model == 'esrnn'``.
            ValueError: for any other unsupported model name.
        """
        if self.model == 'nbeats':
            return self._nbeats_batch(index)
        # Explicit raises (not asserts) so the check survives ``python -O``.
        if self.model == 'esrnn':
            raise NotImplementedError('esrnn batches are not implemented yet')
        raise ValueError(f"Unknown model {self.model!r}; expected 'nbeats'")

    def _nbeats_batch(self, index):
        """Split the selected windows into the N-BEATS batch dictionary.

        Channel 0 is y, the last channel is the mask, and the channels in
        between are the time-varying exogenous variables.
        """
        windows = self.ts_windows[index]
        static_data = self.static_data[index]

        insample_y = windows[:, 0, :self.input_size]
        insample_x_t = windows[:, 1:-1, :self.input_size]
        # TODO: if this affects the N-BEATS residuals, switch to the real mask,
        # i.e. windows[:, -1, :self.input_size].
        insample_mask = t.ones((len(insample_y), self.input_size))

        outsample_y = windows[:, 0, self.input_size:]
        outsample_x_t = windows[:, 1:-1, self.input_size:]
        outsample_mask = windows[:, -1, self.input_size:]

        return {'insample_y': insample_y, 'insample_x_t': insample_x_t,
                'insample_mask': insample_mask,
                'outsample_y': outsample_y, 'outsample_x_t': outsample_x_t,
                'outsample_mask': outsample_mask,
                'static_data': static_data}

    def update_offset(self, offset):
        """Rebuild the windows for a new ``offset``; no-op when it is unchanged."""
        if offset == self.offset:
            return  # Avoid extra computation
        self.offset = offset
        # n_windows can change with offset, so everything is rebuilt.
        self._rebuild_windows()

    def get_meta_data_var(self, var):
        """Delegate to ``ts_dataset.get_meta_data_var(var)``."""
        return self.ts_dataset.get_meta_data_var(var)

    def get_n_variables(self):
        """Return ``(n_x_t, n_s_t)`` from the dataset."""
        return self.ts_dataset.n_x_t, self.ts_dataset.n_s_t

    def get_n_series(self):
        """Return the number of series in the dataset."""
        return self.ts_dataset.n_series

    def get_max_len(self):
        """Return the dataset's ``max_len``."""
        return self.ts_dataset.max_len

    def get_n_channels(self):
        """Return the dataset's ``n_channels``."""
        return self.ts_dataset.n_channels

    def get_frequency(self):
        """Return the dataset's ``frequency``."""
        return self.ts_dataset.frequency

    def train(self):
        """Switch iteration to random training batches."""
        self._is_train = True

    def eval(self):
        """Switch iteration to last-window-per-series evaluation batches."""
        self._is_train = False