In [None]:
#|default_exp data.metadatasets

# Metadataset

>A dataset of datasets

This functionality allows you to create a dataset from data stored in multiple, smaller datasets.

I'd like to thank both **Thomas Capelle** (https://github.com/tcapelle) and **Xander Dunn** (https://github.com/xanderdunn) for the contributions that made this code possible.

This functionality allows you to use multiple numpy arrays instead of a single one, which can be very useful in many practical settings. It has been tested with 10k+ datasets and works well.

In [None]:
#|export
from tsai.imports import *
from tsai.utils import *
from tsai.data.validation import *
from tsai.data.core import *

In [None]:
#|export
class TSMetaDataset():
    "A dataset capable of indexing multiple datasets at the same time"
    # Type used to wrap the concatenated inputs returned by `__getitem__`.
    _type = (TSTensor,)

    def __init__(self, dataset_list, **kwargs):
        # A single dataset is accepted and wrapped in a list.
        if not is_listy(dataset_list): dataset_list = [dataset_list]
        self.datasets = dataset_list
        # Optional indexable of global sample indices that restricts this metadataset
        # to a subset (this is what `TSMetaDatasets.subset` passes in).
        self.split = kwargs.get('split')
        self.mapping = self._mapping()
        # Reuse the first dataset's loss function when it defines one.
        self.loss_func = getattr(dataset_list[0], 'loss_func', None)

    def __len__(self):
        "Number of samples in `split` if it is set, otherwise across all datasets"
        if self.split is not None: return len(self.split)
        return sum(len(ds) for ds in self.datasets)

    def __getitem__(self, idx):
        """Return the samples at global index/indices `idx` as a tuple of tensors.

        Samples are grouped by source dataset, so they come back ordered by dataset
        number. `self.mapping_idxs` records the (dataset number, index within that
        dataset) of each returned sample, row by row. Returns None when `datasets`
        is empty or None (as in the copy produced by `new_empty`).
        """
        if not self.datasets: return None
        if self.split is not None: idx = self.split[idx]
        # One (dataset number, index within that dataset) row per requested sample.
        idxs = self.mapping[listify(idx)]
        # Group rows by dataset so each dataset is indexed only once. A stable sort keeps
        # the requested order among samples that come from the same dataset.
        idxs = idxs[idxs[:, 0].argsort(kind='stable')]
        self.mapping_idxs = idxs
        parts = [self.datasets[d][idxs[idxs[:, 0] == d, 1]] for d in np.unique(idxs[:, 0])]
        # Each part is a tuple of tensors (inputs first); concatenate them element-wise.
        output = tuple(map(torch.cat, zip(*parts)))
        # Wrap the inputs in `_type[0]` and keep every remaining element unchanged
        # (targets, if any), so datasets that only return inputs also work.
        return (self._type[0](output[0]),) + output[1:]

    def _mapping(self):
        "int32 array with one (dataset number, index within that dataset) row per sample, in dataset order"
        lengths = [len(ds) for ds in self.datasets]
        # builtin `sum` keeps the row count an int (np.sum([]) would be the float 0.0)
        idx_pairs = np.zeros((sum(lengths), 2), dtype=np.int32)
        start = 0
        for i, n in enumerate(lengths):
            idx_pairs[start:start + n, 0] = i
            idx_pairs[start:start + n, 1] = np.arange(n)
            start += n
        return idx_pairs

    def new_empty(self):
        "Return a copy with the same `split` and `mapping` but no datasets (indexing it returns None)"
        new_dset = type(self)(self.datasets, split=self.split)
        new_dset.datasets = None
        return new_dset

    def _sample_x(self):
        "Input of the first sample of the first dataset (first element if the input is a tuple)"
        # assumes `datasets[0][0]` returns a tuple whose first element is the input — TODO confirm
        x = self.datasets[0][0][0]
        return x[0] if isinstance(x, tuple) else x

    @property
    def vars(self):
        "Size of the second-to-last dimension of a sample input (number of variables)"
        return self._sample_x().shape[-2]
    @property
    def len(self):
        "Size of the last dimension of a sample input (sequence length)"
        return self._sample_x().shape[-1]
    @property
    def vocab(self):
        return self.datasets[0].vocab
    @property
    def cat(self):
        "True when the first dataset exposes a `vocab` (i.e. `vocab` does not raise AttributeError)"
        return hasattr(self, "vocab")


class TSMetaDatasets(FilteredBase):
    "Holds a `TSMetaDataset` and its `splits`; `train`/`valid` are metadatasets restricted to `splits[0]`/`splits[1]`"
    def __init__(self, metadataset, splits):
        # stores `metadataset` and `splits` as attributes
        store_attr()
        # expose the wrapped metadataset's index mapping and datasets directly
        self.mapping, self.datasets = metadataset.mapping, metadataset.datasets

    def subset(self, i):
        "New metadataset of the same type over the same datasets, limited to `self.splits[i]`"
        md = self.metadataset
        return type(md)(md.datasets, split=self.splits[i])

    @property
    def train(self): return self.subset(0)

    @property
    def valid(self): return self.subset(1)

Let's create 3 datasets. In this case they will have different sizes.

In [None]:
vocab = alphabet[:10]
dsets = []
for _ in range(3):
    # every dataset has 5 variables of length 50 but a random number of samples
    n_samples = np.random.randint(50, 150)
    X = torch.rand(n_samples, 5, 50)
    y = vocab[torch.randint(0, 10, (n_samples,))]
    dsets.append(TSDatasets(X, y, tfms=[None, TSClassification(vocab=vocab)]))

# wrap the datasets in a single metadataset, split it in time and build the dataloaders
metadataset = TSMetaDataset(dsets)
splits = TimeSplitter(show_plot=False)(metadataset)
metadatasets = TSMetaDatasets(metadataset, splits=splits)
dls = TSDataLoaders.from_dsets(metadatasets.train, metadatasets.valid)
xb, yb = dls.train.one_batch()
xb, yb

(TSTensor(samples:64, vars:5, len:50, device=cpu, dtype=torch.float32),
 TensorCategory([1, 0, 3, 9, 7, 2, 8, 6, 1, 1, 1, 8, 1, 1, 9, 2, 6, 6, 1, 5, 5,
                 6, 9, 2, 7, 1, 6, 4, 9, 2, 5, 0, 4, 9, 1, 4, 4, 6, 0, 8, 8, 5,
                 8, 6, 9, 0, 8, 8, 6, 4, 8, 9, 7, 3, 4, 7, 7, 8, 6, 2, 3, 0, 7,
                 4]))

You can train metadatasets as you would train any other time series model in `tsai`:

```python
learn = ts_learner(dls, arch="TSTPlus")
learn.fit_one_cycle(1)
learn.export("test.pkl")
```

For inference, create the new metadataset the same way you created the one used for training. Then use fastai's `learn.get_preds` method to generate predictions:

```python
vocab = alphabet[:10]
dsets = []
for i in range(3):
    size = np.random.randint(50, 150)
    X = torch.rand(size, 5, 50)
    y = vocab[torch.randint(0, 10, (size,))]
    tfms = [None, TSClassification(vocab=vocab)]
    dset = TSDatasets(X, y, tfms=tfms)
    dsets.append(dset)
metadataset = TSMetaDataset(dsets)
dl = TSDataLoader(metadataset)


learn = load_learner("test.pkl")
learn.get_preds(dl=dl)
```

There is also an easy way to map any particular sample in a batch back to its original dataset and index:

In [None]:
dls = TSDataLoaders.from_dsets(metadatasets.train, metadatasets.valid)
xb, yb = first(dls.train)
# `mapping_idxs` is written by the most recent `TSMetaDataset.__getitem__` call; assuming the
# batch came from a single call, row i is the (dataset number, index) of batch sample i.
mappings = dls.train.dataset.mapping_idxs
for i, (xbi, ybi) in enumerate(zip(xb, yb)):
    ds, idx = mappings[i]
    sample = dsets[ds][idx]
    # the batch sample must match the sample it was drawn from in the original dataset
    test_close(sample[0].data.cpu(), xbi.cpu())
    test_close(sample[1].data.cpu(), ybi.cpu())

For example, the 3rd sample in this batch maps to:

In [None]:
# (dataset number, index within that dataset) of the 3rd sample in the batch drawn above
third_sample_mapping = dls.train.dataset.mapping_idxs[2]
third_sample_mapping

array([  0, 112], dtype=int32)

In [None]:
#|eval: false
#|hide
# Notebook-export helper (marked `#|eval: false`, so it is not run when nbdev executes the notebook).
# Resolves this notebook's name from the current namespace and converts the notebook to a script;
# the recorded output below shows the notebook being saved and converted.
from tsai.export import get_nb_name; nb_name = get_nb_name(locals())
from tsai.imports import create_scripts; create_scripts(nb_name)

<IPython.core.display.Javascript object>

/Users/nacho/notebooks/tsai/nbs/008_data.metadatasets.ipynb saved at 2023-03-24 11:30:57
Correct notebook to script conversion! 😃
Friday 24/03/23 11:31:00 CET
