<a href="https://colab.research.google.com/github/MihaiDogariu/CV3/blob/main/scripts/Unit_5_Datasets_and_Dataloaders.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import math
import types

import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms

In [None]:
# Load the MNIST training split (train=True) and test split (train=False)
# from ./data; download=True asks torchvision to fetch the files if needed.
train_dataset, test_dataset = [
    datasets.MNIST(root='./data', train=is_train, download=True)
    for is_train in (True, False)
]

In [None]:
# Split the training set so that the validation set is the same size as the
# test set and consists of samples taken from the original training set.
val_size = len(test_dataset.targets)
train_size = len(train_dataset.targets) - val_size

train_data, val_data = torch.utils.data.random_split(
    train_dataset, [train_size, val_size]
)

# random_split returns two Subset views. A Subset's `.dataset` attribute is the
# *whole* original dataset, so reading it back would return every training
# sample for both splits and undo the split. The split itself lives in
# `.indices`: slice the raw image/label tensors with those indices, exposing
# them under the same `.data` / `.targets` names the following cells read.
full_train_dataset = train_dataset
train_dataset = types.SimpleNamespace(
    data=full_train_dataset.data[train_data.indices],
    targets=full_train_dataset.targets[train_data.indices],
)
val_dataset = types.SimpleNamespace(
    data=full_train_dataset.data[val_data.indices],
    targets=full_train_dataset.targets[val_data.indices],
)

In [None]:
# A custom map-style Dataset inherits from torch.utils.data.Dataset and
# implements three methods:
# __init__()
# __len__()
# __getitem__()

class CustomDataset(Dataset):
    """Map-style dataset that pairs samples with labels by position.

    Args:
        data: Indexable collection of samples; ``len(data)`` is the dataset
            length.
        labels: Indexable collection of labels aligned with ``data``.
        transform: Optional callable applied to each sample when it is read.
        target_transform: Optional callable applied to each label when it is
            read.

    Raises:
        ValueError: If ``data`` and ``labels`` have different lengths.
    """

    def __init__(self, data, labels, transform=None, target_transform=None):
        # Fail fast: __len__ trusts `data`, so a shorter `labels` would only
        # surface later as an IndexError while a batch is being assembled.
        if len(data) != len(labels):
            raise ValueError(
                'data and labels must have the same length, '
                f'got {len(data)} and {len(labels)}'
            )
        self.data = data
        self.labels = labels
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        """Return the number of samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``(sample, label)`` pair at position ``idx``."""
        sample = self.data[idx]
        label = self.labels[idx]
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            label = self.target_transform(label)
        return sample, label

In [None]:
# Re-wrap every split's raw image and label tensors in a CustomDataset.
# The comprehension reads all three original objects before any name is rebound.
train_dataset, val_dataset, test_dataset = [
    CustomDataset(split.data, split.targets)
    for split in (train_dataset, val_dataset, test_dataset)
]

In [None]:
# Build one DataLoader per split, all sharing the same batch size.
# Only the training loader shuffles; validation and test keep their order.
batch_size = 67
train_loader = DataLoader(dataset=train_dataset, shuffle=True, batch_size=batch_size)
val_loader = DataLoader(dataset=val_dataset, shuffle=False, batch_size=batch_size)
test_loader = DataLoader(dataset=test_dataset, shuffle=False, batch_size=batch_size)

In [None]:
# Walk the full training loader once, reporting each batch's index followed by
# the shape of its image tensor and then its label tensor.
for batch_idx, (data, labels) in enumerate(train_loader):
    print('Batch idx:', batch_idx)
    for tensor in (data, labels):
        print(tensor.shape)

In [None]:
# Plot one batch of validation images in a near-square grid, titled by label
data, labels = next(iter(val_loader))

# Size the grid from the batch actually returned: a loader's final batch can
# hold fewer than batch_size samples.
num_images = len(data)
num_rows = math.ceil(math.sqrt(num_images))
num_cols = math.ceil(num_images / num_rows)

# squeeze=False keeps `axes` 2-D even when the grid has a single row/column.
fig, axes = plt.subplots(num_rows, num_cols, figsize=(12, 12), squeeze=False)
for ax, image, label in zip(axes.flat, data.numpy(), labels):
    # Squeeze each image on its own: this drops a singleton channel axis if
    # one is present and leaves a 2-D image unchanged, whereas squeezing the
    # whole batch would also remove the batch axis for a batch of size 1.
    ax.imshow(image.squeeze(), cmap='gray')  # Assuming grayscale images
    ax.set_title(f'Label: {label.item()}')
# Turn off every cell's axes, including the leftover cells beyond num_images,
# which would otherwise be drawn as empty framed plots.
for ax in axes.flat:
    ax.axis('off')
plt.show()