# Question 7

Use the dataloaders to load both the train and the test set into large tensors:
one for the instances, one for the labels. Split the training data into 50 000 training instances
and 10 000 validation instances. Then write a training loop that loops over batches of 16
instances at a time.

In [None]:
import torch
import torchvision
from torchvision import transforms  # explicit import; `transforms.Compose` is used below
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader

In [None]:
TRAIN_SIZE = 50000
VALIDATION_SIZE = 10000
ROOT = '../data' # '..' keeps the data folder outside the notebooks folder
BATCH_SIZE = 16
EPOCHS = 1

### Data transformer

See [this thread](https://discuss.pytorch.org/t/normalization-in-the-mnist-example/457/2) on the PyTorch discussion forum for the choice of 0.1307 (mean) and 0.3081 (standard deviation) for normalization. The official PyTorch example code for MNIST also uses these parameters.

In [None]:
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.1307,), (0.3081,))])

### Training / validation data

(Note: ```.dataset``` on **train_set** or **validation_set** returns the original dataset with all 60000 instances, not the subset. Use the dataloaders defined below instead.)

In [None]:
dataset = torchvision.datasets.MNIST(root = ROOT,
                                     train = True,
                                     transform = transform,
                                     download = True)
train_set, validation_set = torch.utils.data.random_split(dataset, [TRAIN_SIZE, VALIDATION_SIZE])

train_loader = torch.utils.data.DataLoader(train_set,
                                           batch_size = BATCH_SIZE,
                                           shuffle = True,
                                           num_workers = 2)

validation_loader = torch.utils.data.DataLoader(validation_set,
                                                batch_size = BATCH_SIZE,
                                                shuffle = False, # evaluation does not depend on batch order
                                                num_workers = 2)
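
The behaviour described in the note above can be checked on a small example. This is a minimal sketch on a synthetic 100-item dataset (an assumption, used so that nothing has to be downloaded); the seeded generator is also an addition that makes the split reproducible and is not part of the cell above.

```python
import torch
from torch.utils.data import TensorDataset, random_split

# Small synthetic dataset standing in for MNIST (100 items instead of 60000)
full = TensorDataset(torch.zeros(100, 1, 28, 28),
                     torch.zeros(100, dtype=torch.long))

# A seeded generator makes the random split reproducible
generator = torch.Generator().manual_seed(42)
train_part, val_part = random_split(full, [80, 20], generator=generator)

print(len(train_part), len(val_part))  # sizes of the two subsets
print(len(train_part.dataset))         # size of the underlying full dataset
print(train_part.dataset is full)      # each subset only stores indices into it
```

Each subset is a thin wrapper holding a list of indices, which is why `.dataset` leads back to the complete data.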

### Test data

In [None]:
test_set = torchvision.datasets.MNIST(root = ROOT,
                                      train = False,
                                      transform = transform,
                                      download = True)

test_loader = torch.utils.data.DataLoader(test_set,
                                          batch_size = BATCH_SIZE,
                                          shuffle = False, # evaluation does not depend on batch order
                                          num_workers = 2)
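
The question also asks for the data as large tensors, one for the instances and one for the labels. One possible way to get there is to run a dataloader over a dataset once and concatenate the batches. The helper name `load_all`, its `chunk_size` parameter, and the synthetic dataset below are assumptions for illustration; the same function should also accept `train_set`, `validation_set` or `test_set`.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

def load_all(data, chunk_size=256):
    """Run a dataloader over `data` once and concatenate all batches."""
    loader = DataLoader(data, batch_size=chunk_size, shuffle=False)
    instance_chunks, label_chunks = [], []
    for instances, labels in loader:
        instance_chunks.append(instances)
        label_chunks.append(labels)
    return torch.cat(instance_chunks), torch.cat(label_chunks)

# Synthetic stand-in for a dataset of (image, label) pairs
torch.manual_seed(0)
fake_set = TensorDataset(torch.rand(1000, 1, 28, 28),
                         torch.randint(0, 10, (1000,)))

all_instances, all_labels = load_all(fake_set)
print(all_instances.shape, all_labels.shape)
```

Reading in chunks rather than with `batch_size = len(data)` keeps peak memory during collation lower, at the cost of one extra concatenation.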

### Training loop

(Note: The training loop iterates over the dataloader **train_loader** and unpacks each batch into _instances_ and _labels_. _Instances_ has shape (batch_size, 1, 28, 28), as required for further operations; _labels_ has shape (batch_size,).)

In [None]:
for e in range(EPOCHS):
    print(f'# batches ({TRAIN_SIZE}/{BATCH_SIZE}): {TRAIN_SIZE // BATCH_SIZE}')
    for batch_i, batch_data in enumerate(train_loader, 0):
        # if batch_i >= 10: break  # uncomment to stop after 10 batches when debugging
        instances, labels = batch_data
        
        # (...)
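
If the data is held in large tensors rather than read through a dataloader, batches of 16 can be taken by slicing. The following is only a sketch under stated assumptions: the tensors are synthetic (160 instances), and the model, loss and optimizer are hypothetical placeholders, not the ones used in this notebook.

```python
import torch

BATCH_SIZE = 16

# Synthetic stand-ins for the large training tensors (160 instances, not 50000)
torch.manual_seed(0)
instances = torch.rand(160, 1, 28, 28)
labels = torch.randint(0, 10, (160,))

# Hypothetical model: a single linear layer on the flattened image
model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(28 * 28, 10))
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = torch.nn.CrossEntropyLoss()

num_batches = 0
for start in range(0, instances.shape[0], BATCH_SIZE):
    batch_x = instances[start:start + BATCH_SIZE]  # shape (16, 1, 28, 28)
    batch_y = labels[start:start + BATCH_SIZE]     # shape (16,)

    optimizer.zero_grad()                    # clear gradients from the previous step
    loss = loss_fn(model(batch_x), batch_y)  # forward pass and loss
    loss.backward()                          # backpropagate
    optimizer.step()                         # update the parameters
    num_batches += 1

print(num_batches)  # 160 / 16 = 10 batches
```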