In [None]:
# default_exp data.mixed

# Mixed data

> DataLoader that can take data from multiple dataloaders with different types of data

In [None]:
#export
from tsai.imports import *

In [None]:
#export
from packaging import version
from fastai.data.load import _FakeLoader
from torch.utils.data.dataloader import _MultiProcessingDataLoaderIter,_SingleProcessDataLoaderIter,_DatasetKind
_loaders = (_MultiProcessingDataLoaderIter,_SingleProcessDataLoaderIter)

In [None]:
#export
# This implementation of a mixed dataloader is based on a great implementation created by Zach Mueller in this fastai thread:
# https://forums.fast.ai/t/combining-tabular-images-in-fastai2-and-should-work-with-almost-any-other-type/73197

class MixedDataLoaders():
    """Iterate several dataloaders in lockstep and yield merged `(inputs, targets)` batches.

    Every wrapped loader is forced to the same batch size (the smallest one found) and
    to the same `shuffle_fn`, so that all loaders draw the same sample indices for a
    given batch. Targets that turn out to be identical across loaders (checked once on
    a probe batch in `_get_idxs`) are returned only once.
    """
    def __init__(self, *dls, device=None):
        "Accepts any number of `DataLoaders` and a device"
        if not dls: raise ValueError("MixedDataLoaders requires at least one dataloader")
        device = ifnone(device, default_device())
        self.device = device
        self.c = []
        bs = min([dl.bs for dl in dls])
        for i, dl in enumerate(dls): # ensure all dls have the same bs
            dl.bs = bs
            # Keep a reference to the first loader's dataset (the original compared the
            # batch size to 0 here, so this attribute was never populated).
            if i == 0: self.train_ds = dl.dataset
            # All loaders share this object's shuffle_fn so they shuffle identically.
            dl.shuffle_fn = self.shuffle_fn
            # `c` is taken from the first loader that exposes one.
            if self.c == [] and hasattr(dl, "c"): self.c = dl.c
            dl.to(device=device)
        self.dls = dls
        self.count = 0  # number of wrapped loaders that have asked for indices in the current round
        # `_FakeLoader` takes one more positional argument from fastai 2.1 onwards.
        self.fake_l = _FakeLoader(self, False, 0, 0, 0) if version.parse(fastai.__version__) >= version.parse("2.1") else _FakeLoader(self, False, 0, 0)
        self._get_idxs()

    def __len__(self): return len(self.dls[0])

    def _get_vals(self, x):
        "Return the positions of the first occurrence of each distinct tensor in `x` (duplicates are skipped)"
        # Work on flattened numpy copies; the caller's collection is left untouched.
        flat = [o.cpu().numpy().flatten() for o in x]
        idxs, seen = [], []
        for idx, o in enumerate(flat):
            if not self._arrayisin(o, seen):
                idxs.append(idx)
                seen.append(o)
        return idxs

    def _get_idxs(self):
        "Get `x` and `y` indices for batches of data"
        self.n_inps = [dl.n_inp for dl in self.dls]
        self.x_idxs = self._split_idxs(self.n_inps)

        # Identify duplicate targets: draw one batch from every loader, gather everything
        # after the first `n_inp` items, and keep only the positions of distinct tensors.
        outs = L([])
        for dl, n_inp in zip(self.dls, self.n_inps):
            b = next(iter(dl))
            outs += L(b[n_inp:])
        self.y_idxs = self._get_vals(outs)

    def __iter__(self):
        "Yield `(inps, outs)` built from one batch of every wrapped loader"
        # `_loaders` is indexed with a bool: index 1 (single-process iterator) when the
        # loader's `fake_l.num_workers` is 0, index 0 (multi-process iterator) otherwise.
        z = zip(*[_loaders[i.fake_l.num_workers==0](i.fake_l) for i in self.dls])
        for b in z:
            inps = []
            outs = []
            if self.device is not None: b = to_device(b, self.device)
            for batch, dl in zip(b, self.dls):
                batch = dl.after_batch(batch)
                inps += batch[:dl.n_inp]
                outs += batch[dl.n_inp:]
            # Group the inputs per source loader. With a single loader the lone group is
            # returned directly (the original indexed `outs` here, returning targets as inputs).
            grouped = [L(inps)[idx] for idx in self.x_idxs]
            inps = grouped if len(grouped) > 1 else grouped[0]
            # Remove duplicate targets; unwrap when only one target remains.
            outs = L(outs)[self.y_idxs] if len(self.y_idxs) > 1 else L(outs)[self.y_idxs][0]
            yield (inps, outs)

    def one_batch(self):
        "Grab one batch of data"
        with self.fake_l.no_multiproc(): res = first(self)
        if hasattr(self, 'it'): delattr(self, 'it')
        return res

    def shuffle_fn(self, idxs):
        "Generate the same idxs for all dls in each batch"
        # The first caller of a round draws a fresh permutation using the first loader's
        # rng; the following callers receive that same permutation. The counter resets
        # once every wrapped loader has been served.
        if self.count == 0: self.rng = self.dls[0].rng.sample(idxs, len(idxs))
        self.count += 1
        if self.count == len(self.dls): self.count = 0
        return self.rng

    def show_batch(self):
        "Show a batch of data"
        for dl in self.dls: dl.show_batch()

    def to(self, device):
        "Set the target device, propagate it to the wrapped loaders and return `self`"
        self.device = device
        # Mirror what `__init__` does so wrapped loaders stay on the same device as the
        # batches produced in `__iter__`. `None` is only recorded, not propagated.
        if device is not None:
            for dl in self.dls: dl.to(device=device)
        return self

    def _arrayisin(self, arr, arr_list):
        "Checks if `arr` is in `arr_list`"
        return any(np.array_equal(arr, a) for a in arr_list)

    def _split_idxs(self, a):
        "Turn a list of counts into consecutive index groups, e.g. `[2, 1]` -> `[[0, 1], 2]`"
        groups, start = [], 0
        for n in a:
            end = start + int(n)
            # A group of more than one index is a list; a single index is a bare int.
            groups.append(list(range(start, end)) if end - start > 1 else start)
            start = end
        return groups

In [None]:
#export
def get_mixed_dls(*dls, device=None):
    """Build a `DataLoaders` whose train/valid loaders are `MixedDataLoaders`.

    `dls` are objects exposing `.train` and `.valid` loaders. The train loaders are
    combined into one `MixedDataLoaders` and the valid loaders into another.
    `device` defaults to `default_device()` and is forwarded both to the mixed loaders
    and to the resulting `DataLoaders`, so an explicitly requested device is honoured
    everywhere (previously the mixed loaders always used the default device).
    """
    device = ifnone(device, default_device())
    mixed_train_dl = MixedDataLoaders(*[dl.train for dl in dls], device=device)
    mixed_valid_dl = MixedDataLoaders(*[dl.valid for dl in dls], device=device)
    return DataLoaders(mixed_train_dl, mixed_valid_dl, device=device)

In [None]:
from tsai.data.tabular import *

# Fetch the ADULT_SAMPLE dataset and load its csv into a DataFrame.
path = untar_data(URLs.ADULT_SAMPLE)
df = pd.read_csv(path/'adult.csv')
# df['salary'] = np.random.rand(len(df)) # uncomment to simulate a cont dependent variable
target = 'salary'
# One split is computed once and reused for both sets of dataloaders below.
splits = RandomSplitter()(range_of(df))

# First tabular dataloaders: three categorical and two continuous columns, bs=512.
cat_names = ['workclass', 'education', 'marital-status']
cont_names = ['age', 'fnlwgt']
dls1 = get_tabular_dls(df, cat_names=cat_names, cont_names=cont_names, target=target, splits=splits, bs=512)
dls1.show_batch()

# Second tabular dataloaders: same target and splits, a single continuous column, bs=128.
# NOTE: `cat_names` / `cont_names` are rebound here, so after this cell they describe dls2.
cat_names = None #['occupation', 'relationship', 'race']
cont_names = ['education-num']
dls2 = get_tabular_dls(df, cat_names=cat_names, cont_names=cont_names, target=target, splits=splits, bs=128)
dls2.show_batch()

Unnamed: 0,workclass,education,marital-status,age,fnlwgt,salary
0,Private,Some-college,Divorced,30.0,173646.999321,<50k
1,Private,HS-grad,Separated,41.0,216116.001059,<50k
2,State-gov,11th,Widowed,61.000001,159908.000763,>=50k
3,Self-emp-not-inc,Assoc-acdm,Married-civ-spouse,54.000001,392286.005788,>=50k
4,Private,Bachelors,Separated,43.0,27765.999597,>=50k
5,Private,11th,Married-spouse-absent,25.0,210095.000186,<50k
6,Private,Some-college,Married-civ-spouse,35.0,44780.005201,>=50k
7,Private,9th,Married-civ-spouse,47.0,121124.001323,<50k
8,Self-emp-not-inc,Assoc-voc,Never-married,26.0,201137.999878,<50k
9,Private,Assoc-voc,Married-civ-spouse,46.0,237730.99891,<50k


Unnamed: 0,education-num_na,education-num,salary
0,False,13.0,>=50k
1,False,10.0,<50k
2,False,10.0,<50k
3,False,13.0,<50k
4,False,9.0,<50k
5,False,9.0,<50k
6,False,10.0,<50k
7,False,9.0,<50k
8,False,7.0,<50k
9,False,9.0,<50k


In [None]:
# Take the first training batch of dls2 and display the shapes of its first two items.
b = first(dls2.train)
b[0].shape, b[1].shape

(torch.Size([128, 1]), torch.Size([128, 1]))

In [None]:
# Combine both tabular dataloaders; show_batch on the mixed train loader displays
# one batch per wrapped loader.
dls = get_mixed_dls(dls1, dls2)
dls.train.show_batch()

Unnamed: 0,workclass,education,marital-status,age,fnlwgt,salary
0,Local-gov,Bachelors,Never-married,28.0,407671.996379,<50k
1,Private,Some-college,Never-married,24.000001,379417.996016,<50k
2,Private,HS-grad,Married-civ-spouse,50.0,95469.000195,>=50k
3,Private,HS-grad,Separated,39.0,174329.999327,<50k
4,?,HS-grad,Married-civ-spouse,60.0,174073.000497,<50k
5,Self-emp-not-inc,10th,Married-civ-spouse,55.0,278228.002656,<50k
6,Local-gov,11th,Divorced,39.0,189911.000001,<50k
7,Private,Masters,Married-civ-spouse,57.0,112840.001634,>=50k
8,Private,Bachelors,Never-married,31.0,111566.998586,>=50k
9,Private,HS-grad,Never-married,20.000001,102607.002072,<50k


Unnamed: 0,education-num_na,education-num,salary
0,False,13.0,<50k
1,False,10.0,<50k
2,False,9.0,>=50k
3,False,9.0,<50k
4,True,10.0,<50k
5,False,6.0,<50k
6,False,7.0,<50k
7,False,14.0,>=50k
8,False,13.0,>=50k
9,False,9.0,<50k


In [None]:
# A mixed batch is an (inputs, targets) pair: `xs` has one entry per wrapped loader
# (each indexed into its own tensors below) and `ys` is a single target tensor.
xs, ys = first(dls.train)
xs[0][0].shape, xs[0][1].shape, xs[1][0].shape, xs[1][1].shape, ys.shape

(torch.Size([128, 3]),
 torch.Size([128, 2]),
 torch.Size([128, 1]),
 torch.Size([128, 1]),
 torch.Size([128, 1]))

In [None]:
#hide
# Presumably exports the notebooks to library modules (see the "Converted ..." log below)
# and signals completion via `beep` — both helpers are defined outside this notebook.
beep(create_scripts())

<IPython.core.display.Javascript object>

Converted 000_utils.ipynb.
Converted 000b_data.validation.ipynb.
Converted 001_data.external.ipynb.
Converted 002_data.core.ipynb.
Converted 003_data.preprocessing.ipynb.
Converted 003b_data.transforms.ipynb.
Converted 003c_data.image.ipynb.
Converted 005_data.tabular.ipynb.
Converted 006_data.mixed.ipynb.
Converted 007_metrics.ipynb.
Converted 008_learner.ipynb.
Converted 009_optimizer.ipynb.
Converted 010_callback.core.ipynb.
Converted 011_callback.semi_supervised.ipynb.
Converted 100_models.utils.ipynb.
Converted 100b_models.layers.ipynb.
Converted 101_models.ResNet.ipynb.
Converted 101b_models.ResNetPlus.ipynb.
Converted 102_models.InceptionTime.ipynb.
Converted 102b_models.InceptionTimePlus.ipynb.
Converted 103_models.MLP.ipynb.
Converted 103b_models.FCN.ipynb.
Converted 103c_models.FCNPlus.ipynb.
Converted 104_models.ResCNN.ipynb.
Converted 105_models.RNN.ipynb.
Converted 105_models.RNNPlus.ipynb.
Converted 106_models.XceptionTime.ipynb.
Converted 106b_models.XceptionTimePlus.ipy