In [1]:
# default_exp data.ts_loader_general

In [2]:
#hide
%load_ext autoreload
%autoreload 2

In [3]:
#export
import numpy as np
import pandas as pd
import random
import torch as t
import copy
from fastcore.foundation import patch
from nixtla.data.ts_dataset import TimeSeriesDataset
from collections import defaultdict

In [None]:
#export
class TimeSeriesLoader(object):
    """Sample batches of fixed-length windows from a time-series dataset.

    Every batch is built by:
      1. drawing `n_series_per_batch` distinct series (training mode only);
      2. cutting each sampled series into rolling windows of
         `input_size + output_size` steps;
      3. drawing `batch_size` windows, with replacement, among the windows
         whose last `output_size` steps contain at least one non-zero
         'outsample_mask' entry.

    Only the 'nbeats' batch format and training mode are implemented.
    """
    def __init__(self,
                 ts_dataset: 'TimeSeriesDataset',
                 model: str,
                 offset: int,
                 window_sampling_limit: int,
                 input_size: int,
                 output_size: int,
                 idx_to_sample_freq: int,  # TODO: not active yet
                 batch_size: int,
                 n_series_per_batch: int,
                 ts_outsample_mask: list = None):
        """
        Parameters
        ----------
        ts_dataset: TimeSeriesDataset
            Source dataset. It is deep-copied, so the optional mask overwrite
            below does not touch the caller's object.
        model: str
            Batch format. Only 'nbeats' is implemented; 'esrnn' raises
            NotImplementedError and anything else raises ValueError.
        offset: int
            Forwarded to `ts_dataset.get_filtered_tensor` on every batch.
        window_sampling_limit: int
            Forwarded to `ts_dataset.get_filtered_tensor` on every batch.
        input_size: int
            Number of steps in the insample part of each window.
        output_size: int
            Number of steps in the outsample part of each window.
        idx_to_sample_freq: int
            Stored but not used yet.
        batch_size: int
            Number of windows per batch. Must be a multiple of the effective
            `n_series_per_batch`.
        n_series_per_batch: int
            Number of distinct series sampled per batch, capped at the number
            of series in the dataset.
        ts_outsample_mask: list or None
            If non-empty, overwrites the 'outsample_mask' channel of the copied
            dataset tensor. Presumably shaped (n_series, n_time) — TODO confirm.

        Raises
        ------
        ValueError
            If the effective `n_series_per_batch` is smaller than 1, or
            `batch_size` is not a multiple of it.
        """
        self.model = model
        self.window_sampling_limit = window_sampling_limit
        self.input_size = input_size
        self.output_size = output_size
        self.batch_size = batch_size
        self.idx_to_sample_freq = idx_to_sample_freq
        self.offset = offset
        self.ts_dataset = copy.deepcopy(ts_dataset)  # TODO: avoid the deep copy
        # A None default avoids sharing one mutable list between instances.
        self.ts_outsample_mask = [] if ts_outsample_mask is None else ts_outsample_mask
        self.t_cols = self.ts_dataset.t_cols
        self.n_series_per_batch = min(n_series_per_batch, self.ts_dataset.n_series)
        if self.n_series_per_batch < 1:
            raise ValueError(f'n_series_per_batch {self.n_series_per_batch} must be at least 1')
        self.windows_per_serie = self.batch_size // self.n_series_per_batch

        # Explicit check (not `assert`) so it still runs under `python -O`.
        if self.batch_size % self.n_series_per_batch != 0:
            raise ValueError(f'batch_size {self.batch_size} must be multiple of '
                             f'n_series_per_batch {self.n_series_per_batch}')

        # Overwrite mask if provided
        if len(self.ts_outsample_mask) > 0:
            outsample_mask_idx = self.t_cols.index('outsample_mask')
            self.ts_dataset.ts_tensor[:, outsample_mask_idx, :] = t.as_tensor(self.ts_outsample_mask,
                                                                              dtype=t.float32)

        self._is_train = True

    def _get_sampleable_windows_idxs(self, ts_windows):
        """Return the flat indices of windows eligible for sampling.

        A window is eligible when its last `output_size` steps contain at
        least one non-zero 'outsample_mask' entry.

        Returns a list of numpy integers indexing dim 0 of `ts_windows`.
        """
        outsample_mask = ts_windows[:, self.t_cols.index('outsample_mask'), -self.output_size:]
        active_steps = t.sum(outsample_mask, dim=1)
        sampling_idx = t.nonzero(active_steps > 0)
        return list(sampling_idx.flatten().numpy())

    def _create_windows_tensor(self, ts_idxs=None):
        """Build every rolling window of the series selected by `ts_idxs`.

        `ts_idxs`, `self.offset`, `self.output_size` and
        `self.window_sampling_limit` are forwarded to
        `ts_dataset.get_filtered_tensor`. That call returns a
        rank-3 tensor (series, channel, time) and a right padding.

        Returns a tensor of shape
        (n_windows * n_series, n_channels, input_size + output_size).
        It is ordered window-position-major:
        flat index = window_position * n_series + series_position.

        TODO: when another dataloader is added, move this into a shared
        transform function in utils if it is compatible.
        """
        tensor, right_padding = self.ts_dataset.get_filtered_tensor(offset=self.offset, output_size=self.output_size,
                                                                    window_sampling_limit=self.window_sampling_limit,
                                                                    ts_idxs=ts_idxs)
        _, n_channels, _ = tensor.size()

        # Left-pad with input_size-1 zeros so the first window's insample
        # part ends at the first real time step.
        padder = t.nn.ConstantPad1d(padding=(self.input_size-1, right_padding), value=0)
        tensor = padder(tensor)

        # Zero y and outsample_mask on the last output_size steps. Windows whose
        # output lies there are never sampled and carry no target.
        tensor[:, self.t_cols.index('y'), -self.output_size:] = 0  # overkill to ensure no leakage
        tensor[:, self.t_cols.index('outsample_mask'), -self.output_size:] = 0

        # Rolling windows: (series, channel, n_windows, size) -> (n_windows, series, channel, size),
        # then flatten the first two dims.
        windows = tensor.unfold(dimension=-1, size=self.input_size + self.output_size, step=1)
        windows = windows.permute(2, 0, 1, 3)
        windows = windows.reshape(-1, n_channels, self.input_size + self.output_size)
        return windows

    def __iter__(self):
        """Yield batches indefinitely.

        In training mode, each batch uses `n_series_per_batch` distinct series
        drawn uniformly without replacement.

        Raises NotImplementedError in eval mode.
        """
        while True:
            if not self._is_train:
                # TODO: in eval mode take the last window of every series
                # (windows are ordered by position because of unfold).
                raise NotImplementedError('implementar')

            ts_idxs = np.random.choice(range(self.ts_dataset.n_series),
                                       size=self.n_series_per_batch, replace=False)
            batch = self.__get_item__(index=ts_idxs)

            yield batch

    def __get_item__(self, index):
        """Build one batch from the series in `index`, in the format of `self.model`."""
        if self.model == 'nbeats':
            return self._nbeats_batch(index)
        if self.model == 'esrnn':
            raise NotImplementedError('hacer esrnn')
        raise ValueError(f"Unknown model '{self.model}'; only 'nbeats' is supported")

    def _nbeats_batch(self, index):
        """Sample `batch_size` windows from the series in `index` (N-BEATS format).

        Returns a dict with the following entries:
        - 'insample_y', 'insample_x_t', 'insample_mask': the first
          `input_size` steps of each window;
        - 'outsample_y', 'outsample_x_t', 'outsample_mask': the remaining
          `output_size` steps;
        - 'x_s': the static features of each window's series.

        Raises ValueError if no window of the sampled series is sampleable.
        """
        # Create windows for each sampled ts and sample random unmasked windows from each ts
        windows = self._create_windows_tensor(index)
        sampleable_windows = self._get_sampleable_windows_idxs(windows)
        if len(sampleable_windows) == 0:
            raise ValueError('No window of the sampled series has an active outsample_mask; '
                             'cannot build a batch')
        windows_idxs = np.random.choice(sampleable_windows, self.batch_size, replace=True)
        windows = windows[windows_idxs]

        # Windows are flattened as window_position * n_sampled_series + series_position,
        # so each window's series position is its flat index modulo n_sampled_series.
        # Assumes get_filtered_tensor keeps the series in `index` order — TODO confirm.
        x_s = self.ts_dataset.x_s[index]
        n_sampled_series = x_s.shape[0]
        x_s = x_s[windows_idxs % n_sampled_series]

        y_idx = self.t_cols.index('y')
        insample_mask_idx = self.t_cols.index('insample_mask')
        outsample_mask_idx = self.t_cols.index('outsample_mask')
        # Exogenous channels; assumes t_cols places them between 'y' and 'insample_mask'.
        x_t_channels = slice(y_idx + 1, insample_mask_idx)

        insample_y = windows[:, y_idx, :self.input_size]
        insample_x_t = windows[:, x_t_channels, :self.input_size]
        insample_mask = windows[:, insample_mask_idx, :self.input_size]

        outsample_y = windows[:, y_idx, self.input_size:]
        outsample_x_t = windows[:, x_t_channels, self.input_size:]
        outsample_mask = windows[:, outsample_mask_idx, self.input_size:]

        batch = {'insample_y': insample_y, 'insample_x_t': insample_x_t, 'insample_mask': insample_mask,
                 'outsample_y': outsample_y, 'outsample_x_t': outsample_x_t, 'outsample_mask': outsample_mask,
                 'x_s': x_s}

        return batch

    def update_offset(self, offset):
        """Set the offset forwarded to `get_filtered_tensor`.

        Windows are rebuilt from `self.offset` on every batch, so the new value
        applies from the next batch on.
        """
        if offset == self.offset:
            return  # Avoid extra computation
        self.offset = offset

    def get_meta_data_var(self, var):
        """Delegate to `ts_dataset.get_meta_data_var(var)`."""
        return self.ts_dataset.get_meta_data_var(var)

    def get_n_variables(self):
        """Return `(ts_dataset.n_x_t, ts_dataset.n_s_t)`."""
        return self.ts_dataset.n_x_t, self.ts_dataset.n_s_t

    def get_n_series(self):
        """Return the number of series in the (copied) dataset."""
        return self.ts_dataset.n_series

    def get_max_len(self):
        """Return `ts_dataset.max_len`."""
        return self.ts_dataset.max_len

    def get_n_channels(self):
        """Return `ts_dataset.n_channels`."""
        return self.ts_dataset.n_channels

    def get_frequency(self):
        """Return `ts_dataset.frequency`."""
        return self.ts_dataset.frequency

    def train(self):
        """Switch to training mode (random series sampling in `__iter__`)."""
        self._is_train = True

    def eval(self):
        """Switch to eval mode (iteration currently raises NotImplementedError)."""
        self._is_train = False