# dataloaders

> Bundle train/validation/test datasets together and add helpers for viewing the data

In [None]:
#| hide
#|default_exp dataloaders

In [None]:
#|hide
%load_ext autoreload
%autoreload 2

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
from isaacai.utils import *
import pandas as pd, numpy as np, fastcore.all as fc
import matplotlib.pyplot as plt,matplotlib as mpl
import random
import torch
from torch import nn, Tensor
from torch.utils.data import DataLoader
from datasets import Dataset, load_dataset
import torchvision.transforms.functional as TF,torch.nn.functional as F
from torch.utils.data import default_collate
from operator import itemgetter

In [None]:
import logging

# Show single-channel images in grayscale by default
mpl.rcParams['image.cmap'] = 'gray'
# Compact tensor reprs: 2 decimals, wide lines, no scientific notation
torch.set_printoptions(precision=2, linewidth=140, sci_mode=False)

# Suppress log records at WARNING severity and below
logging.disable(logging.WARNING)

# Fixed seed so the demo cells below are repeatable
set_seed(42)

In [None]:
#|export
@fc.delegates(DataLoader)
def get_dataloaders(train_dataset, valid_dataset, batch_size, **kwargs):
    "Return a `(train, valid)` pair of `DataLoader`s; extra `kwargs` go to both"
    # Training batches are reshuffled on every pass
    train_dl = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, **kwargs)
    # Validation keeps dataset order and uses twice the batch size
    # (presumably because no gradients are held during evaluation)
    valid_dl = DataLoader(valid_dataset, batch_size=batch_size*2, shuffle=False, **kwargs)
    return train_dl, valid_dl

In [None]:
#| export
def collate_dataset_dict(dataset):
    "Build a `collate_fn` that collates dict rows and returns the columns in `dataset.features` order"
    columns = tuple(dataset.features)
    pick = itemgetter(*columns)
    def _collate(batch):
        # `default_collate` turns a list of dict rows into one dict of batched columns
        collated = default_collate(batch)
        return pick(collated)
    return _collate

In [None]:
#| export
class DataLoaders():
    "Pair of training and validation `DataLoader`s with helpers to build and preview them"
    def __init__(self, train, valid): fc.store_attr()

    @classmethod
    def from_dataset_dict(cls, dataset_dict, batch_size, **kwargs):
        "Create `DataLoaders` from a dict of splits; `kwargs` are forwarded to every `DataLoader`"
        # The splits are unpacked positionally into `get_dataloaders`, so the dict must hold
        # exactly the training and validation splits, in that order.
        # A caller-supplied `collate_fn` takes precedence; by default each batch is returned
        # as a tuple of columns in `features` order.
        kwargs.setdefault('collate_fn', collate_dataset_dict(dataset_dict['train']))
        return cls(*get_dataloaders(*dataset_dict.values(), batch_size=batch_size, **kwargs))

    @fc.delegates(get_grid)
    def show_batch(self, n=9, train_dataset=True, **kwargs):
        "Plot `n` randomly sampled items titled with their label names; `kwargs` go to `get_grid`"
        _dataset = (self.train if train_dataset else self.valid).dataset
        idxs = random.sample(range(len(_dataset)), n)
        # Indexing with a list of indices yields a dict of columns; zip the columns back into
        # per-item rows. Assumes the columns are exactly (image, label), in that order.
        rows = zip(*_dataset[idxs].values())
        _,axs = get_grid(n=n,**kwargs)
        for (image,label),ax in zip(rows,axs.flat):
            show_image(image,ax=ax,title=_dataset.features['label'].names[label])

In [None]:
#| export
@inplace
def sample_dataset_dict(dataset, sample_sizes=(2000,2000)):
    "Replace splits of `dataset` with random subsets, pairing `sample_sizes` with splits in iteration order"
    for k,split in zip(sample_sizes,dataset):
        population = range(len(dataset[split]))
        # Sampling is without replacement, so `k` must not exceed the split length
        keep = random.sample(population,k)
        dataset[split] = dataset[split].select(keep)

In [None]:
# Normalisation constants used by `transformi` (presumably Fashion-MNIST pixel statistics - confirm)
xmean,xstd = 0.28, 0.35

@inplace
def transformi(b):
    "Replace each image in batch `b` with a tensor normalised by `xmean` and `xstd`"
    def _norm(img): return (TF.to_tensor(img)-xmean)/xstd
    b['image'] = list(map(_norm, b['image']))

# Load Fashion-MNIST with the transform attached, subsample the splits, then build the loaders
_dataset = sample_dataset_dict(load_dataset('fashion_mnist').with_transform(transformi))
dls = DataLoaders.from_dataset_dict(_dataset, 64, num_workers=4)

  0%|          | 0/2 [00:00<?, ?it/s]

In [None]:
# Preview a random sample of training items (defaults: n=9, train_dataset=True)
dls.show_batch()

In [None]:
class SimpleNet(nn.Module):
    "Small MLP classifier (flatten -> 784-64-10), simplified from the PyTorch tutorial"
    def __init__(self):
        super().__init__()
        n_in, n_hidden, n_out = 28*28, 64, 10
        self.flatten = nn.Flatten()
        layers = [nn.Linear(n_in, n_hidden), nn.ReLU(), nn.Linear(n_hidden, n_out)]
        self.linear_relu_stack = nn.Sequential(*layers)

    def forward(self, x):
        "Flatten `x` per sample and return the 10 raw logits"
        return self.linear_relu_stack(self.flatten(x))
    
model = SimpleNet()
# Smoke test: the first training batch (batch size 64) should yield one row of 10 logits per item
_xb = fc.first(dls.train)[0]
fc.test_eq(model(_xb).shape,(64,10))

In [None]:
#| hide
# Run nbdev's export step so the `#| export` cells are written out to the library module
import nbdev; nbdev.nbdev_export()