In [None]:
import typing

import torch


class DummyDataset(torch.utils.data.Dataset):
    """Minimal map-style dataset that serves samples from an in-memory sequence.

    Only the two methods a map-style ``torch.utils.data.Dataset`` subclass
    implements here are provided: ``__getitem__`` and ``__len__``.

    Args:
        l: The samples to expose; sample ``i`` is ``l[i]``.
    """

    def __init__(self, l: typing.Sequence[typing.Any]):
        self.l = l

    def __getitem__(self, index):
        # Lookup is delegated to the wrapped sequence's own indexing, so any
        # out-of-range index raises whatever that sequence raises.
        return self.l[index]

    def __len__(self):
        # Number of samples == length of the wrapped sequence.
        return len(self.l)


dataset = DummyDataset([0, 1, 2, 3, 4, 5])
len(dataset), dataset[3]

In [None]:
# `Dataset` is not defined anywhere in this notebook (NameError on a fresh
# kernel). TensorDataset wraps aligned tensors so that
# dataset[i] == (features[i], labels[i]).
dataset = torch.utils.data.TensorDataset(
    torch.randn(300, 10),  # 300 feature vectors of size 10
    torch.randint(0, 5, size=(300,)),  # 300 integer labels in [0, 5)
)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=64)

# Batches of 64; the final batch holds the remaining 300 - 4 * 64 = 44 samples.
for X, y in dataloader:
    print(X.shape, y.shape)

In [None]:
import tempfile

import matplotlib.pyplot as plt
import torchvision

with tempfile.TemporaryDirectory() as tmp_dir:
    # Training split: fetched into the throwaway directory, each image passed
    # through ToTensor() when indexed.
    train_dataset = torchvision.datasets.MNIST(
        tmp_dir,
        train=True,
        download=True,
        transform=torchvision.transforms.ToTensor(),
    )
    # Test split: same root and transform, but no download flag; presumably it
    # relies on the call above having fetched the test files as well.
    test_dataset = torchvision.datasets.MNIST(
        tmp_dir,
        train=False,
        transform=torchvision.transforms.ToTensor(),
    )

# NOTE(review): both datasets are used after `tmp_dir` has been deleted, which
# relies on MNIST keeping its data in memory once constructed; confirm for
# the installed torchvision version.
len(train_dataset), len(test_dataset)

In [None]:
# Inspect one training sample together with its label.
x, label = train_dataset[0]
print(x.shape)  # presumably (1, H, W) for a grayscale image; confirm from output

fig, ax = plt.subplots()
# squeeze() drops singleton axes so imshow receives a 2-D array.
ax.imshow(x.squeeze().numpy(), cmap="gray")
ax.set_title(f"train_dataset[0] (label = {label})")
ax.axis("off")
plt.show()

In [None]:
# Hold out a validation set from the full training split (50K train / 10K
# validation). A seeded generator makes the split identical across kernel
# restarts.
# NOTE: this cell rebinds `train_dataset` to the 50K subset, so re-running it
# without first re-running the loading cell trips the length check below.
TRAIN_SIZE, VALIDATION_SIZE = 50_000, 10_000
SPLIT_SEED = 42

assert len(train_dataset) == TRAIN_SIZE + VALIDATION_SIZE, (
    f"expected {TRAIN_SIZE + VALIDATION_SIZE} samples, got {len(train_dataset)}; "
    "re-run the dataset-loading cell before splitting again"
)
train_dataset, validation_dataset = torch.utils.data.random_split(
    train_dataset,
    [TRAIN_SIZE, VALIDATION_SIZE],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

In [None]:
BATCH_SIZE = 64
# Enable pin_memory only when a CUDA device is available.
PIN_MEMORY = torch.cuda.is_available()


def _make_loader(split_data, **overrides):
    """Build a DataLoader with the notebook-wide batch size and pinning policy."""
    return torch.utils.data.DataLoader(
        split_data,
        batch_size=BATCH_SIZE,
        pin_memory=PIN_MEMORY,
        **overrides,
    )


# Only the training split is shuffled; the evaluation splits keep their order.
dataloaders = {
    "train": _make_loader(train_dataset, shuffle=True),
    "validation": _make_loader(validation_dataset),
    "test": _make_loader(test_dataset),
}