In [None]:
# default_exp data.ts_loader_pinche

In [None]:
#hide
%load_ext autoreload
%autoreload 2

In [None]:
#export
import numpy as np
import pandas as pd
import random
import torch as t
import copy
from fastcore.foundation import patch
from nixtla.data.ts_dataset import TimeSeriesDataset
from collections import defaultdict

In [None]:
#export
class TimeSeriesLoader(object):
    """Endless iterator that yields batches of input/output windows.

    Each batch is built by cutting one window per series out of
    ``ts_dataset.ts_tensor`` and stacking the per-series samples with
    ``np.stack``.  In train mode (the default) ``batch_size`` series are
    drawn at random with replacement and the cut point is random; in eval
    mode every series is visited once, in order, with a fixed cut point.

    NOTE: the per-sample accessor is called ``__get_item__`` (not the
    ``__getitem__`` protocol method), so ``loader[i]`` is not supported.
    """

    def __init__(self,
                 ts_dataset:TimeSeriesDataset,
                 model:str,
                 offset:int,
                 window_sampling_limit: int,
                 input_size: int,
                 output_size: int,
                 idx_to_sample_freq: int, #TODO: not active yet
                 batch_size: int,
                 ts_outsample_mask: list=None):
        """Store the sampling configuration and a private copy of the dataset.

        Parameters
        ----------
        ts_dataset:
            Dataset to sample from.  It is deep-copied, so later changes to
            the caller's object are not seen by this loader.
        model:
            Name of the model the samples are built for.  Only ``'nbeats'``
            is handled by ``__get_item__``.
        offset:
            Number of trailing time steps (counted from ``max_len``) that
            are held out of the training windows.
        window_sampling_limit:
            How far back from ``max_len - offset`` a training cut point may
            be drawn.
        input_size:
            Length of the insample (input) window.
        output_size:
            Length of the outsample (target) window.
        idx_to_sample_freq:
            Stored but not used yet.
        batch_size:
            Number of series drawn per batch in train mode.
        ts_outsample_mask:
            Stored but not used yet.  ``None`` (the default) is replaced by
            a fresh empty list for every instance.
        """
        self.model = model
        self.window_sampling_limit = window_sampling_limit
        self.input_size = input_size
        self.output_size = output_size
        self.batch_size = batch_size
        self.idx_to_sample_freq = idx_to_sample_freq
        self.offset = offset
        self.ts_dataset = copy.deepcopy(ts_dataset) #TODO: remove deep_copy
        # A new list per instance: a literal ``[]`` default would be a single
        # object shared by every loader created without this argument.
        # TODO: the mask is not applied anywhere yet.
        self.ts_outsample_mask = [] if ts_outsample_mask is None else ts_outsample_mask

        # #TODO: use tcols.get_loc()
        # # own mask
        # if len(self.ts_outsample_mask) > 0:
        #     #self.ts_dataset.ts_tensor[:, -1, self.hide_ts_idx] = 0
        #     self.ts_dataset.ts_tensor[:, -1, :] = t.as_tensor(ts_outsample_mask,dtype=t.float32)

        # We sample from this tensor
        self.time_series = self.ts_dataset.ts_tensor
        self.x_s = self.ts_dataset.get_x_s()

        self._is_train = True

        #TODO: change these prints
        # print('X: time series features, of shape (#series,#times,#features): \t' + str(X.shape))
        # print('Y: target series (in X), of shape (#series,#times): \t \t' + str(Y.shape))
        # print('S: static features, of shape (#series,#features): \t \t' + str(S.shape))

    def __len__(self):
        """Return the number of entries in ``ts_dataset.len_series``."""
        return len(self.ts_dataset.len_series)

    def __iter__(self):
        """Yield batches forever; the caller decides when to stop.

        Each batch is a ``defaultdict`` mapping every key of a single sample
        to the ``np.stack`` of that key across the sampled series.
        """
        while True:
            if self._is_train:
                # Random series, with replacement.
                sampled_ts_indices = np.random.randint(self.ts_dataset.n_series, size=self.batch_size)
            else:
                # Every series exactly once, in order.
                sampled_ts_indices = range(self.ts_dataset.n_series)

            # Collect the per-series samples key by key.
            batch_dict = defaultdict(list)
            for index in sampled_ts_indices:
                for key, value in self.__get_item__(index).items():
                    batch_dict[key].append(value)

            batch = defaultdict(list)
            for key, samples in batch_dict.items():
                batch[key] = np.stack(samples)

            yield batch

    def __get_item__(self, index):
        """Build the sample for series ``index`` according to ``self.model``.

        Raises
        ------
        AssertionError
            If ``self.model`` is not ``'nbeats'``.  Raised explicitly (rather
            than with ``assert``) so the check is not stripped under
            ``python -O``, where the method would silently return ``None``.
        """
        if self.model == 'nbeats':
            return self._nbeats_batch(index)
        if self.model == 'esrnn':
            # Not implemented yet.
            raise AssertionError('hacer esrnn')
        raise AssertionError('error')

    def _nbeats_batch(self, index):
        """Cut one insample/outsample window pair out of series ``index``.

        Returns a dict with keys ``insample_y``, ``insample_x_t``,
        ``insample_mask``, ``outsample_y``, ``outsample_x_t``,
        ``outsample_mask`` and ``x_s``.  Row 0 of each window is returned as
        ``*_y`` and the remaining rows as ``*_x_t``.  Windows shorter than
        the requested size are zero-padded: the insample on the left, the
        outsample on the right.

        Raises
        ------
        AssertionError
            If ``max_len - offset`` is not greater than the first admissible
            cut point of the series.
        """
        n_channels = self.ts_dataset.n_channels
        max_len = self.ts_dataset.max_len

        insample = np.zeros((n_channels, self.input_size), dtype=float)
        # All ones, even over the zero-padded positions.
        insample_mask = np.ones(self.input_size)  #TODO: if this affects the nbeats residuals, change it!
        outsample = np.zeros((n_channels, self.output_size), dtype=float)
        outsample_mask = np.zeros(self.output_size)

        ts = self.time_series[index]
        len_ts = self.ts_dataset.len_series[index]

        # Last position (exclusive) available for training cut points.
        train_end = max_len - self.offset
        # Rolling window (like replay buffer): the first admissible cut point
        # is bounded by the series length and by the sampling limit.
        # NOTE(review): presumably series are right-aligned in the tensor, so
        # max_len - len_ts is where the series starts — TODO confirm.
        init_ts = max(max_len - len_ts + 1, train_end - self.window_sampling_limit)

        # Explicit raise instead of ``assert`` so the check survives -O.
        if not (train_end > init_ts):
            raise AssertionError(f'Offset too big for serie {index}')

        if self._is_train:
            cut_point = np.random.randint(low=init_ts, high=train_end, size=1)[0]
            # Training targets never reach into the held-out tail.
            outsample_end = train_end
        else:
            cut_point = max(train_end, self.input_size)
            # Evaluation targets may use the held-out tail.
            outsample_end = max_len

        # The last two rows are dropped from both windows; the original
        # comment says this removes the mask channel from the end.
        insample_window = ts[:-2, max(0, cut_point - self.input_size):cut_point]
        insample[:, -insample_window.shape[1]:] = insample_window

        outsample_window = ts[:-2, cut_point:min(outsample_end, cut_point + self.output_size)]
        outsample[:, :outsample_window.shape[1]] = outsample_window
        # Only the positions filled with real data count as valid targets.
        outsample_mask[:outsample_window.shape[1]] = 1.0

        sample = {'insample_y': insample[0, :],
                  'insample_x_t': insample[1:, :],
                  'insample_mask': insample_mask,
                  'outsample_y': outsample[0, :],
                  'outsample_x_t': outsample[1:, :],
                  'outsample_mask': outsample_mask,
                  'x_s': self.x_s[index, :]}

        return sample

    def update_offset(self, offset):
        """Replace the current offset used to place the windows."""
        if offset == self.offset:
            return # Avoid extra computation
        self.offset = offset

    def get_meta_data_var(self, var):
        """Forward to ``ts_dataset.get_meta_data_var(var)``."""
        return self.ts_dataset.get_meta_data_var(var)

    def get_n_variables(self):
        """Return the tuple ``(ts_dataset.n_x_t, ts_dataset.n_s_t)``."""
        return self.ts_dataset.n_x_t, self.ts_dataset.n_s_t

    def get_n_series(self):
        """Return ``ts_dataset.n_series``."""
        return self.ts_dataset.n_series

    def get_max_len(self):
        """Return ``ts_dataset.max_len``."""
        return self.ts_dataset.max_len

    def get_n_channels(self):
        """Return ``ts_dataset.n_channels``."""
        return self.ts_dataset.n_channels

    def get_frequency(self):
        """Return ``ts_dataset.frequency``."""
        return self.ts_dataset.frequency

    def train(self):
        """Switch to train mode: random series and random cut points."""
        self._is_train = True

    def eval(self):
        """Switch to eval mode: every series once, with a fixed cut point."""
        self._is_train = False