## How to load mini batches?

The examples I've shown so far are just toy examples. In real life, nobody trains a neural network on the entire dataset in one go, or iterates over the dataset one data instance at a time. You use mini-batches.
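At its core, a mini-batch loader does little more than shuffle the example indices once per epoch and slice the data into fixed-size chunks. Here is a minimal NumPy sketch of that idea. The function name `iterate_minibatches` is hypothetical and not part of any library. The fixed `seed` is only there for reproducibility; a real loop would use a different order every epoch.

```python
import numpy as np

def iterate_minibatches(images, labels, batch_size, shuffle=True, seed=0):
    """Yield (images, labels) mini-batches that together cover one epoch."""
    rng = np.random.default_rng(seed)
    indices = np.arange(len(images))
    if shuffle:
        # Random order of examples for this epoch
        rng.shuffle(indices)
    for start in range(0, len(indices), batch_size):
        batch_idx = indices[start:start + batch_size]
        # The final batch may be smaller than batch_size
        yield images[batch_idx], labels[batch_idx]

# Toy data: 10 "images" of shape 28x28 and their labels
toy_images = np.zeros((10, 28, 28), dtype=np.uint8)
toy_labels = np.arange(10)
sizes = [len(y) for _, y in iterate_minibatches(toy_images, toy_labels, batch_size=4)]
print(sizes)  # [4, 4, 2]
```

Libraries like the one used below add parallel workers, collation of arbitrary sample types and framework integration on top of this basic loop.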

JAX and its ecosystem of libraries were built primarily with computation in mind (or, to use fancier words, machine learning research). As a result, JAX has nothing similar to TensorFlow Datasets or PyTorch's data utilities for structuring and loading data. Instead, JAX encourages you to use PyTorch or TF for data loading.

As outlined in [one of the official guides](https://jax.readthedocs.io/en/latest/notebooks/Neural_Network_and_Data_Loading.html), you can just use a PyTorch dataset and pass a custom `collate_fn` to the PyTorch `DataLoader` so that it returns NumPy arrays, which JAX can consume directly, instead of torch tensors. This sounds easy, but for complex data, writing a new `collate_fn` can be daunting and error-prone. You could also use TF datasets, but personally I'm not a big fan of TF, and setting it up these days feels akin to making an offering [to some god of the pantheon](https://youtu.be/FxoFiu-M0mg?si=XFn7ZlB-RWNhFZmM). So, sticking with PyTorch, there must be an easier way, right?
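For reference, here is a collate function roughly in the spirit of the one in that guide: it recursively stacks a list of samples into NumPy arrays. It is a sketch, not the guide's exact code. Note that it only handles arrays, scalars and nested tuples/lists; a dataset whose samples are dicts would already need another branch, which is exactly where hand-written collate functions start getting fiddly.

```python
import numpy as np

def numpy_collate(batch):
    """Turn a list of samples into NumPy arrays with a leading batch axis."""
    first = batch[0]
    if isinstance(first, np.ndarray):
        # A list of arrays becomes one stacked array
        return np.stack(batch)
    if isinstance(first, (tuple, list)):
        # A list of (image, label) tuples becomes [images_batch, labels_batch]
        transposed = zip(*batch)
        return [numpy_collate(list(samples)) for samples in transposed]
    # Scalars such as int labels become a 1-D array
    return np.array(batch)

# A hypothetical batch of four MNIST-shaped (image, label) samples
samples = [(np.zeros((28, 28), dtype=np.uint8), i) for i in range(4)]
batch_images, batch_labels = numpy_collate(samples)
print(batch_images.shape, batch_labels.shape)  # (4, 28, 28) (4,)
```

With PyTorch you would plug it in as `torch.utils.data.DataLoader(dataset, batch_size=64, collate_fn=numpy_collate)`.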

Well, it turns out someone has written a package named [jax-dataloader](https://github.com/birkhoffg/jax-dataloader) that takes care of data loading for you. All you need is a torch dataset. (If you visit the jax-dataloader repo, give it a star and consider contributing; we need more work like this.)

So in this notebook, I'll define a torch dataset and show you how to use jax-dataloader with it. The data will come from the Hugging Face `datasets` package.

```python
from datasets import load_dataset

dataset_name = "mnist"
dataset = load_dataset(dataset_name, split="train")
dataset
```


```
Dataset({
    features: ['image', 'label'],
    num_rows: 60000
})
```

```python
# let's inspect a bit

sample = dataset[0]
print(type(sample["image"]))
print(type(sample["label"]))
```

```
<class 'PIL.PngImagePlugin.PngImageFile'>
<class 'int'>
```


```python
import numpy as np

print(np.array(sample["image"]).shape)
```

```
(28, 28)
```


Let's define the torch dataset.

```python
from torch.utils.data import Dataset

class MNISTDataset(Dataset):
    def __init__(self, split) -> None:
        self.dataset_name = "mnist"
        self.split = split

        self.data = load_dataset(self.dataset_name, split=self.split)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        data_instance = self.data[index]

        image = data_instance["image"]
        label = data_instance["label"]

        # the image is a pillow image so we'll have to convert it to
        # a numpy array
        image = np.array(image)

        return image, label

# create a training dataset
train_set = MNISTDataset("train")
print(type(train_set[0][0]))
```

```
<class 'numpy.ndarray'>
```


Now for the dataloader. 

```python
import jax_dataloader as jdl

train_loader = jdl.DataLoader(dataset=train_set, backend="pytorch", batch_size=64, shuffle=True)
```

A couple of things are worth noticing here. The API of `jdl.DataLoader` is quite similar to that of PyTorch's `DataLoader`: you pass the dataset, a `batch_size`, and whether to `shuffle`. The difference is the `backend` parameter, which tells `jdl.DataLoader` which framework's loading machinery to use. Since `train_set` is a PyTorch `Dataset`, we pass `backend="pytorch"`.
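With `batch_size=64` and the 60,000 training examples reported above, one pass over the data works out as follows. This assumes the final partial batch is kept rather than dropped (`drop_last=False`, which is PyTorch's default; check jax-dataloader's own signature for its default).

```python
import math

num_examples = 60_000  # num_rows of the MNIST train split
batch_size = 64

# 937 full batches plus one leftover batch of 32 examples
full_batches, leftover = divmod(num_examples, batch_size)
batches_per_epoch = math.ceil(num_examples / batch_size)

print(full_batches, leftover)  # 937 32
print(batches_per_epoch)       # 938
```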

```python
# iterate through the loader
for batch in train_loader:
    print(batch)
    break  # let's stop here!
```

```
[array([[[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        ...,
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]],

       [[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        ...,
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]],

       [[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        ...,
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]],

       ...,

       [[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        ...,
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]],

       [[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        ...,
        [0, 0, 0, ...,
```
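The printed batch is a list of NumPy arrays (not torch tensors): the images should come first with shape `(64, 28, 28)`, followed by the 64 labels. Before handing a batch to a model you will typically unpack it and convert the `uint8` pixels to floats. Below is a sketch that assumes a simple MLP expecting flattened 784-dimensional inputs scaled to [0, 1]. The helper `prepare_batch` and that input format are illustrative assumptions, not part of jax-dataloader. A random stand-in batch replaces the real one so the snippet runs on its own.

```python
import numpy as np

def prepare_batch(batch):
    """Unpack an (images, labels) batch and turn uint8 pixels into flat float32 features."""
    images, labels = batch
    images = np.asarray(images, dtype=np.float32) / 255.0  # scale pixels to [0, 1]
    images = images.reshape(images.shape[0], -1)           # (B, 28, 28) -> (B, 784)
    labels = np.asarray(labels, dtype=np.int32)
    return images, labels

# Stand-in for one batch from train_loader: 64 random MNIST-shaped images
rng = np.random.default_rng(0)
fake_batch = [rng.integers(0, 256, size=(64, 28, 28), dtype=np.uint8),
              rng.integers(0, 10, size=(64,))]
x, y = prepare_batch(fake_batch)
print(x.shape, x.dtype, y.shape)  # (64, 784) float32 (64,)
```

JAX functions accept NumPy arrays directly, so `x` and `y` can be passed straight into a jitted training step.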

## What if I don't want to use a torch dataset?

Good question. For simple dataset structures, such as an array of images plus an array of labels, you may not need something as expressive as a `Dataset` class. You can instead use the `ArrayDataset` type from jax-dataloader.
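Conceptually, an array dataset just holds several arrays that share the same first dimension and returns the i-th row of each. The class below is a hypothetical stand-in written for illustration (it is not jax-dataloader's implementation), but it follows the same `__len__`/`__getitem__` protocol as the torch dataset defined earlier.

```python
import numpy as np

class SimpleArrayDataset:
    """Hypothetical minimal array dataset: index i returns (arrays[0][i], arrays[1][i], ...)."""

    def __init__(self, *arrays):
        n = len(arrays[0])
        # Every array must have exactly one entry per example
        if any(len(a) != n for a in arrays):
            raise ValueError("All arrays must have the same first dimension.")
        self.arrays = arrays

    def __len__(self):
        return len(self.arrays[0])

    def __getitem__(self, index):
        # Works for a single int index or an array of indices (a whole batch)
        return tuple(a[index] for a in self.arrays)

toy_xs = np.zeros((100, 28, 28), dtype=np.uint8)
toy_ys = np.arange(100)
toy_ds = SimpleArrayDataset(toy_xs, toy_ys)
batch_x, batch_y = toy_ds[np.arange(8)]
print(len(toy_ds), batch_x.shape, batch_y.shape)  # 100 (8, 28, 28) (8,)
```

Because indexing accepts an array of indices, a loader can fetch a whole shuffled batch with one fancy-indexing call instead of collating 64 separate samples, which is one reason array-backed datasets tend to be fast.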

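```python
# Materialize the whole training split as NumPy arrays:
# images has shape (60000, 28, 28), labels has shape (60000,)
images = np.array(
    [np.array(d["image"]) for d in dataset]
)

labels = np.array([d["label"] for d in dataset])
```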

```python
arr_dataset = jdl.ArrayDataset(images, labels)
# dataloader
arr_loader = jdl.DataLoader(dataset=arr_dataset, backend="jax", batch_size=64, shuffle=True)
# iterate
for batch in arr_loader:
    print(batch)
    break
```

```
(array([[[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        ...,
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]],

       [[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        ...,
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]],

       [[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        ...,
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]],

       ...,

       [[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        ...,
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]],

       [[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        ...,
        [0, 0, 0, ...,
```

So what has changed here? Only the `backend` parameter: `arr_dataset` is no longer a PyTorch dataset, so we pass `backend="jax"`. Notice also that the batch now comes back as a tuple of arrays rather than a list. There's another thing you can do with `ArrayDataset`. Say your data originally came as PyTorch tensors. How can you use them here? Just change the `backend` parameter to `"pytorch"`.

```python
import torch

images_tensor = torch.from_numpy(images)
labels_tensor = torch.from_numpy(labels)

ds = jdl.ArrayDataset(images_tensor, labels_tensor)
loader = jdl.DataLoader(dataset=ds, backend="pytorch", batch_size=64, shuffle=True)
for batch in loader:
    print(batch)
    break
```

```
[array([[[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        ...,
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]],

       [[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        ...,
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]],

       [[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        ...,
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]],

       ...,

       [[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        ...,
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]],

       [[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        ...,
        [0, 0, 0, ...,
```