# Exploration of MNIST data set

In [None]:
import os
import sys
sys.path.append("../")
from dotenv import find_dotenv, load_dotenv
from torch.utils.data import DataLoader
from torchvision.utils import make_grid
import torchvision.transforms as T
import torch.nn as nn

from src.utils import CustomMnistDataset, imshow

load_dotenv(find_dotenv())
from pathlib import Path

# Resolve the data directory from the environment (optionally populated from a
# .env file by load_dotenv above). Fail fast with an actionable message: passing
# None to Path() would otherwise raise an opaque TypeError far from the cause.
_data_dir_env = os.getenv('DATA_DIR')
if not _data_dir_env:
    raise RuntimeError(
        "DATA_DIR is not set; define it in your environment or in a .env file "
        "that find_dotenv() can locate."
    )
DATA_DIR = Path(_data_dir_env)


In [None]:

# Build both splits from the same directory; only the split name differs.
train_dataset, test_dataset = (
    CustomMnistDataset(img_dir=DATA_DIR, type=split) for split in ('train', 'test')
)


In [None]:
# A small batch is enough for the preview grids; shuffling means each run
# draws a different random selection of samples.
BATCH_SIZE = 5
SHUFFLE = True
train_dataloader, test_dataloader = (
    DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=SHUFFLE)
    for dataset in (train_dataset, test_dataset)
)

## Training Examples

In [None]:
# Take one batch from the training loader, tile it into a 5-column grid and
# display it, titled with the integer label of each image.
inputs, classes = next(iter(train_dataloader))
out = make_grid(inputs, nrow=5)
imshow(out, title=list(map(int, classes)))

## Test Examples

In [None]:
# Show one batch from the test loader as a 5-column grid (no label titles).
# NOTE(review): unlike the training cell, the batch is indexed with [0] instead
# of being unpacked into (images, labels). This shows the whole batch only if
# CustomMnistDataset(type='test') returns a tuple/list per sample. If it returns
# a bare image tensor, the collated batch is a single tensor, and inputs[0] is
# presumably just its first image. Confirm against src.utils.
inputs = next(iter(test_dataloader))
out = make_grid(inputs[0], nrow=5)
imshow(out)

## Augmentation: random rotations, translations and scaling

In [None]:
# Random affine augmentation applied per sample by the dataset:
#   degrees=22.5       -> rotation sampled from [-22.5, 22.5] degrees
#   translate=(.2, .2) -> horizontal/vertical shift up to 20% of width/height
#   scale=(0.5, 1)     -> rescale by a factor sampled from [0.5, 1]
augmentation = nn.Sequential(
    T.RandomAffine(
        degrees=22.5,
        translate=(0.2, 0.2),
        scale=(0.5, 1),
    ),
)
aug_dataset = CustomMnistDataset(
    img_dir=DATA_DIR,
    type='train',
    transform=augmentation,
)
aug_dataloader = DataLoader(aug_dataset, shuffle=SHUFFLE, batch_size=BATCH_SIZE)

In [None]:
# Take one augmented training batch, tile it into a 5-column grid and display
# it, titled with the integer label of each image.
inputs, classes = next(iter(aug_dataloader))
out = make_grid(inputs, nrow=5)
imshow(out, title=list(map(int, classes)))