In [None]:
# default_exp data.tsloader

# TimeSeriesLoader
> PyTorch `DataLoader` wrapper with a special collate function for `TimeSeriesDataset` outputs.

In [None]:
#hide
from nbdev import *
%load_ext autoreload
%autoreload 2

In [None]:
#export
import collections.abc
from typing import Dict

import torch as t
from torch.utils.data import DataLoader
from fastcore.foundation import patch

In [None]:
#export
class TimeSeriesLoader(DataLoader):
    """PyTorch `DataLoader` that always batches with `self.collate_fn`.

    `collate_fn` is attached to this class in a separate cell via fastcore's
    `patch` decorator.
    """

    def __init__(self, dataset, **kwargs) -> None:
        """Wrap the pytorch `DataLoader` with a special collate function
        for the `TimeSeriesDataset` outputs.

        Parameters
        ----------
        dataset:
            Dataset to load from; forwarded unchanged to `DataLoader`.
        **kwargs:
            Any other `DataLoader` keyword argument except `collate_fn`.

        Raises
        ------
        ValueError
            If `collate_fn` is supplied. This class always uses its own
            collate function; use a plain `DataLoader` for a custom one.
        """
        if 'collate_fn' in kwargs:
            raise ValueError(
                'This class wraps the pytorch `DataLoader` with a '
                'special collate function. If you want to use yours '
                'simply use `DataLoader`'
            )
        super().__init__(dataset=dataset, collate_fn=self.collate_fn, **kwargs)

In [None]:
#export
@patch
def collate_fn(self: TimeSeriesLoader, batch: list):
    """Special collate fn for the `TimeSeriesDataset`.

    Merges a list of samples into one batch. Unlike PyTorch's default
    collate, which stacks along a new leading dimension, tensors here are
    concatenated with `torch.cat` along their existing first dimension.
    Mappings are collated key by key, recursing into the values.

    Parameters
    ----------
    batch: list
        Samples to merge. The type of `batch[0]` decides the branch:
        either all tensors, or all mappings containing the keys of `batch[0]`.

    Returns
    -------
    torch.Tensor or dict
        The concatenated tensor, or a dict with the keys of `batch[0]`
        whose values are collated recursively.

    Raises
    ------
    TypeError
        If `batch[0]` is neither a tensor nor a mapping.

    Notes
    -----
    [1] Adapted from https://github.com/pytorch/pytorch/blob/master/torch/utils/data/_utils/collate.py.
    """
    elem = batch[0]
    elem_type = type(elem)
    if isinstance(elem, t.Tensor):
        out = None
        if t.utils.data.get_worker_info() is not None:
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            numel = sum(x.numel() for x in batch)
            storage = elem.storage()._new_shared(numel)
            # Give `out` the exact shape of the concatenation up front;
            # otherwise `t.cat` must resize a non-empty flat `out`, which
            # PyTorch deprecates and warns about.
            cat_len = sum(x.shape[0] for x in batch)
            out = elem.new(storage).resize_(cat_len, *elem.shape[1:])
        return t.cat(batch, out=out)
    elif isinstance(elem, collections.abc.Mapping):
        return {key: self.collate_fn([d[key] for d in batch]) for key in elem}

    raise TypeError(f'Unknown {elem_type}')

In [None]:
from nixtla.data.tsdataset import TimeSeriesDataset
from nixtla.data.utils import create_synthetic_tsdata

# Build synthetic frames; only `Y_df` is passed on to the dataset.
synthetic_frames = create_synthetic_tsdata(sort=True)
Y_df, S_df, X_df = synthetic_frames
dataset = TimeSeriesDataset(Y_df=Y_df, skip_nonsamplable=True)
# Wrap it in the loader defined above (no `collate_fn` may be passed).
dataloader = TimeSeriesLoader(dataset=dataset)

In [None]:
# Smoke test: draw every batch so each one goes through the loader's
# `collate_fn`; `batch` keeps the last one for inspection.
n_batches = 0
for batch in dataloader:
    n_batches += 1
assert n_batches > 0, 'the dataloader yielded no batches'
n_batches