# Understanding the impact of varying the batch size

We will compare the following:
- The loss and accuracy values of the training and validation data when the training batch size is 32
- The loss and accuracy values of the training and validation data when the training batch size is 10000

## Batch size of 32

1. Download and import the training set:

In [7]:
from torchvision import datasets
import torch
# Directory handed to torchvision as the dataset root.
data_folder = "./data/FMNIST"

# Training split of Fashion-MNIST (train=True); download=True lets torchvision
# fetch the files into data_folder.
training_fmnist = datasets.FashionMNIST(root=data_folder, train=True, download=True)

# Images and their targets, pulled out of the dataset object for later use.
training_images, training_labels = training_fmnist.data, training_fmnist.targets

2. In a similar manner to training images, we must download and import the validation dataset by specifying `train = False` while calling the `FashionMNIST` method in our datasets:

In [8]:
# Held-out split of Fashion-MNIST (train=False), stored under the same root
# folder as the training split.
validation_fmnist = datasets.FashionMNIST(root=data_folder, train=False, download=True)

# Images and their targets, pulled out of the dataset object for later use.
validation_images, validation_labels = validation_fmnist.data, validation_fmnist.targets

3. Import relevant packages and define `device`:

In [9]:
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
from torch.utils.data import Dataset, DataLoader
import torch
import torch.nn as nn
# Use the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Device: {device}")

# Number of passes over the training data in the training loop.
total_epochs = 5

Device: cuda


4. Define the dataset class and helper functions:

In [10]:
class FMNISTDataset(Dataset):
    """Dataset of flattened, rescaled images paired with their labels.

    Args:
        x: Image tensor whose elements can be viewed as rows of 28 * 28
            values. It is converted to float and divided by 255 (assumes
            pixel values in the 0..255 range).
        y: Label tensor aligned with ``x`` along the first dimension.
        device: Device on which the tensors are kept. When ``None`` (the
            default), CUDA is used if available, otherwise the CPU, so the
            batches end up on the same device as a model created with the
            same rule.
    """

    def __init__(self, x, y, device=None):
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        x = x.float() / 255
        # One row of 784 features per image, as expected by a Linear layer.
        x = x.view(-1, 28 * 28)
        # Keep the prepared tensors on the instance; without this,
        # __getitem__ and __len__ have nothing to read.
        self.x, self.y = x.to(device), y.to(device)

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair stored at position ``idx``."""
        return self.x[idx], self.y[idx]

    def __len__(self):
        """Return the number of images in the dataset."""
        return len(self.x)


from torch.optim import SGD, Adam
def get_model(use_SGD=True, lr=1e-2):
    """Build the classifier together with its loss function and optimizer.

    The network is a two-layer perceptron (784 -> 1000 -> 10) with a ReLU
    in between, placed on the module-level ``device``.

    Args:
        use_SGD: If true, optimise with ``SGD``; otherwise with ``Adam``.
        lr: Learning rate passed to the optimizer.

    Returns:
        A ``(model, loss_func, optimizer)`` tuple.
    """
    model = nn.Sequential(
        nn.Linear(28 * 28, 1000),
        nn.ReLU(),  # the class is spelled ReLU; nn.ReLu does not exist
        nn.Linear(1000, 10),
    ).to(device)

    loss_func = nn.CrossEntropyLoss()
    optimizer_cls = SGD if use_SGD else Adam
    optimizer = optimizer_cls(model.parameters(), lr=lr)

    return model, loss_func, optimizer


def train_batch(x, y, model, loss_func, optimizer):
    """Run one optimisation step on a single batch.

    Args:
        x: Batch of inputs fed to ``model``.
        y: Targets passed to ``loss_func`` alongside the model output.
        model: Network to train; it is switched to training mode.
        loss_func: Callable computing the loss from ``(output, y)``.
        optimizer: Optimizer that updates the model parameters.

    Returns:
        The batch loss as a Python number.
    """
    model.train()
    loss = loss_func(model(x), y)
    loss.backward()
    optimizer.step()
    # Reset gradients so they do not accumulate into the next batch.
    optimizer.zero_grad()
    return loss.item()

def accuracy(x, y, model):
    """Report, per sample, whether the model's top prediction equals the target.

    Args:
        x: Batch of inputs fed to ``model``.
        y: Targets compared against the predicted class indices.
        model: Network to evaluate; it is switched to evaluation mode.

    Returns:
        A list of booleans, one per sample in the batch.
    """
    model.eval()
    with torch.no_grad():
        scores = model(x)
    # Index of the largest score along the last dimension is the predicted class.
    _, predicted = torch.max(scores, dim=-1)
    hits = predicted == y
    return hits.cpu().tolist()

5. Define a function for getting data with `batch_size=` as a parameter:

In [11]:
def get_data(batch_size):
    """Create the training and validation DataLoaders.

    Args:
        batch_size: Number of samples per training batch.

    Returns:
        A ``(training_dataloader, validation_dataloader)`` tuple. The
        training loader shuffles its data; the validation loader does not
        and uses a batch size equal to the number of validation images.
    """
    training_dataloader = DataLoader(
        FMNISTDataset(training_images, training_labels),
        batch_size=batch_size,
        shuffle=True,
    )
    # Batch size equals the number of validation images: one batch in total.
    validation_dataloader = DataLoader(
        FMNISTDataset(validation_images, validation_labels),
        batch_size=len(validation_images),
        shuffle=False,
    )
    return training_dataloader, validation_dataloader

6. Define a function that calculates the loss of the validation data.

In [12]:
def validation_loss(x, y, model, loss_func):
    """Compute the loss on a batch without tracking gradients.

    Args:
        x: Batch of inputs fed to ``model``.
        y: Targets passed to ``loss_func`` alongside the model output.
        model: Network to evaluate; it is switched to evaluation mode.
        loss_func: Callable computing the loss from ``(output, y)``.

    Returns:
        The loss as a Python number.
    """
    with torch.no_grad():
        model.eval()
        return loss_func(model(x), y).item()

7. Fetch the training and validation DataLoaders and initialize model, loss function and optimizer

In [None]:
# get_data has no default batch size; this section uses a batch size of 32.
training_dataloader, validation_dataloader = get_data(32)
model, loss_func, optimizer = get_model()

8. Train the model

In [None]:
training_losses, training_accuracies = [], []
validation_losses, validation_accuracies = [], []

for epoch in range(total_epochs):
    training_epoch_losses, training_epoch_accuracies = [], []
    
    for idx, batch, in enumerate(iter(training_dataloader)):
        x, y = batch
        batch_loss = train_batch(x, y, model, loss_func, optimizer)
        training_epoch_losses.append(batch_loss)
    training_epoch_loss = np.array(training_epoch_losses).mean()

    for idx, batch in enumerate(iter(training_dataloader)):
        x, y = batch
        is_correct = accuracy(x, y, model)
        training_epoch_accuracies.extend(is_correct)
    training_epoch_accuracy = np.mean(training_epoch_accuracies) 

    for idx, batch in enumerate(iter(validation_dataloader)):
        x, y = batch
