# Training a ResNet classifier to classify mushrooms
In this notebook, we train and save a working ResNet model based on the ResNet9 architecture (LINK) to classify images of mushrooms found in the Norwegian flora. To do so, we first need a pipeline that loads the data, preprocesses it and feeds it to a training loop in mini-batches. We then design the residual convolutional blocks used in the ResNet, as well as the final model.


## Preparing dataset and preprocessing of data
Before declaring residual convolutional blocks and the ResNet model, we should make sure all data can be loaded, preprocessed and iterated over in a consistent manner. To this end, we declare a PyTorch dataset and a preprocessing pipeline, and combine the two in a dataset instance `data`.
### Defining a PyTorch dataset for image data
The image data and the corresponding labels are accessed and loaded into memory using a custom `MushroomDataset`-class, which inherits from PyTorch's standard dataset class `torch.utils.data.Dataset`. A separate class streamlines the retrieval and preprocessing of data, and the inherited interface lets PyTorch DataLoaders divide the data into mini-batches, so that only one batch of images has to be held in memory at a time during training:

In [88]:
import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset
from pathlib import Path

# Create a custom dataset to simplify the use and indexing of the custom mushroom dataset
class MushroomDataset(Dataset):
    # Override __init__ to store the image directory and transform, and to load the labels
    def __init__(self, path_imgs: str, path_labels: str, transform = None) -> None:
        self.img_dir = path_imgs
        self.labels = pd.read_csv(path_labels)
        self.transform = transform
    
    # Implement __len__ so that len(dataset) returns the number of labels
    def __len__(self) -> int:
        return self.labels.shape[0]
    
    # Implement __getitem__ so that dataset[index] yields an image (with transforms applied) and its corresponding label
    def __getitem__(self, index):
        # Find the image path and load the image
        img_path = Path(f"{self.img_dir}/{self.labels.iloc[index, 0]}.jpg")
        img = plt.imread(img_path)

        # Load the corresponding image label as torch.long, the dtype that
        # nn.CrossEntropyLoss expects for integer class indices
        label = torch.tensor(self.labels.iloc[index, 1], dtype=torch.long)

        # If a transform is specified, apply it to the image
        if self.transform:
            img = self.transform(img)

        return (img, label)

### Defining a preprocessing pipeline
The `MushroomDataset`-class takes a `transform` parameter, which is used to apply a set of simple, yet important, preprocessing steps to the image data. Essentially, we want to convert the images into PyTorch tensors and normalize each color channel, since inputs on a comparable scale generally help the gradient-based optimizer converge during training.

Before defining the preprocessing pipeline, however, we need the mean and standard deviation of each color channel across the dataset. These values form the basis of the normalization and have to be computed from the data. Below, the per-image means and standard deviations of each color channel are accumulated into `mean_list` and `std_list`, and then divided by the number of images. This yields the statistics needed for the normalization step:

In [89]:
BASE_DIR = Path("01_Training_RestNet_Classifier.ipynb").parent.resolve()

# Find a list of all .jpg image-files in the dataset
img_paths = Path(f"{BASE_DIR}/data/mushroom_imgs").rglob('*.jpg')

# Find the mean of all color channels by accumulating each value over all available images
mean_list, std_list = np.array([0, 0, 0]), np.array([0, 0, 0])

for count, path in enumerate(img_paths):
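    # Load the image, convert it to a torch.Tensor and permute to (channels, height, width).
    # Dividing by 255 puts the pixels on the [0, 1] scale that transforms.ToTensor produces
    # for uint8 images, so the statistics match the data that Normalize will receive.
    img = torch.Tensor(plt.imread(str(path))).permute((2, 0, 1)) / 255.0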
    
    # Perform elementwise addition using np.add
    mean_list = np.add(mean_list, img.mean([1, 2]))
    std_list = np.add(std_list, img.std([1, 2]))

# Perform elementwise division with the counter to get the average mean/std of all color channels across all images 
mean_list, std_list = mean_list / (count+1), std_list / (count + 1)
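
Averaging per-image standard deviations gives an approximation of the dataset-wide value. As a point of comparison, the sketch below pools every pixel of every image by accumulating sums and sums of squares, which yields the exact per-channel statistics. The helper name `channel_stats` and the random toy images are assumptions made for illustration; they are not part of the notebook. Note that this computes the population standard deviation, whereas `torch.Tensor.std` applies Bessel's correction by default.

```python
import numpy as np

def channel_stats(images):
    """Per-channel mean and std over all pixels of all images (H x W x 3, uint8)."""
    total = np.zeros(3)       # running sum of pixel values per channel
    total_sq = np.zeros(3)    # running sum of squared pixel values per channel
    n_pixels = 0
    for img in images:
        x = img.astype(np.float64) / 255.0      # scale to [0, 1]
        total += x.sum(axis=(0, 1))
        total_sq += (x ** 2).sum(axis=(0, 1))
        n_pixels += x.shape[0] * x.shape[1]
    mean = total / n_pixels
    std = np.sqrt(total_sq / n_pixels - mean ** 2)  # population std: E[x^2] - E[x]^2
    return mean, std

# Toy images of different sizes stand in for the mushroom photos (assumption)
rng = np.random.default_rng(0)
images = [rng.integers(0, 256, size=(h, w, 3), dtype=np.uint8) for h, w in [(8, 8), (4, 6)]]
mean, std = channel_stats(images)
print(mean.shape, std.shape)
```

Because only three running sums per statistic are kept, this approach also works when the images differ in size and do not fit in memory at once.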

With the means and standard deviations available, we can define a simple preprocessing pipeline using a composite transformation from `torchvision.transforms.Compose`. An image fed to the composite transformation is first converted into a PyTorch tensor, and then normalized across all available color channels.

In [90]:
from torchvision import transforms

# Define a composite transform to preprocess the data
preprocessing_pipeline = transforms.Compose([
    transforms.ToTensor(), 
    transforms.Normalize(list(mean_list), list(std_list))
])  
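
To make the arithmetic of the two steps concrete, the following sketch mimics them in plain NumPy for a `uint8` image: scaling to [0, 1] and reordering to channels-first, followed by `(x - mean) / std` per channel. The helper names `to_tensor_like` and `normalize_like`, and the mean and standard deviation values, are placeholders chosen for illustration; this is not the torchvision implementation itself.

```python
import numpy as np

def to_tensor_like(img_uint8):
    """Mimic ToTensor for a uint8 H x W x C array: scale to [0, 1], reorder to C x H x W."""
    return np.transpose(img_uint8.astype(np.float32) / 255.0, (2, 0, 1))

def normalize_like(chw, mean, std):
    """Mimic Normalize: subtract the channel mean and divide by the channel std."""
    mean = np.asarray(mean, dtype=np.float32).reshape(-1, 1, 1)
    std = np.asarray(std, dtype=np.float32).reshape(-1, 1, 1)
    return (chw - mean) / std

# A 2 x 2 image whose first channel is all 0 and whose other channels are all 255
img = np.full((2, 2, 3), 255, dtype=np.uint8)
img[..., 0] = 0

out = normalize_like(to_tensor_like(img), mean=[0.5, 0.5, 0.5], std=[0.25, 0.25, 0.25])
print(out.shape)   # (3, 2, 2)
print(out[0, 0, 0], out[1, 0, 0])   # -2.0 2.0
```

The example also shows why the statistics must be expressed on the [0, 1] scale: the normalization is applied after the pixel values have already been divided by 255.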

### Create training/test sets with DataLoaders
An instance `data` of the `MushroomDataset`-class, preprocessed using `preprocessing_pipeline`, can be split into a training and a test set using `torch.utils.data.random_split()`. Here, 25% of the data is held out as the test set:

In [91]:
# Define the absolute paths to the image data and the corresponding labels
IMAGE_DIR = Path(f"{BASE_DIR}/data/mushroom_imgs")
LABEL_DIR = Path(f"{BASE_DIR}/data/mushroom_imgs/img_labels.csv")

# Instantiate the dataset
data = MushroomDataset(IMAGE_DIR, LABEL_DIR, preprocessing_pipeline)

# Divide the dataset into training and test sets using pytorch's 'random_split' method:
train_data, test_data = torch.utils.data.random_split(data, [0.75, 0.25])
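
Without a fixed seed, `random_split` produces a different partition on every run, which makes results harder to compare. A minimal sketch of a reproducible split is shown below, assuming PyTorch 1.13 or later (needed for fractional lengths). The toy `TensorDataset` and the seed value 42 are stand-ins for illustration, not part of the notebook.

```python
import torch
from torch.utils.data import TensorDataset, random_split

# Toy dataset with 100 samples standing in for the mushroom data (assumption)
toy_data = TensorDataset(torch.arange(100).float().unsqueeze(1),
                         torch.zeros(100, dtype=torch.long))

# A seeded generator makes the split identical across runs
generator = torch.Generator().manual_seed(42)
toy_train, toy_test = random_split(toy_data, [0.75, 0.25], generator=generator)

print(len(toy_train), len(toy_test))   # 75 25
```

Each returned subset stores the indices it was assigned in its `indices` attribute, so the two sets can be checked for overlap if needed.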

For each subset of data, we can now create a dataloader from `torch.utils.data.DataLoader`, allowing us to iterate through the dataset in shuffled mini-batches:

In [94]:
import os
from torch.utils.data import DataLoader

# Define the BATCH_SIZE hyperparameter, which sets the number of images in each mini-batch during training
# NOTE: This should be tuned as a hyperparameter
BATCH_SIZE = 32

# Declare the dataloaders
train_dataloader = DataLoader(dataset = train_data,
                              batch_size = BATCH_SIZE,
                              shuffle = True)

test_dataloader = DataLoader(dataset = test_data,
                            batch_size = BATCH_SIZE,
                            shuffle = True)
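
A quick way to check a dataloader is to pull a single batch and inspect its shape. The sketch below does this with random toy tensors in place of the mushroom images; the 32 x 32 image size and the five classes are assumptions for illustration. Note that the default batching stacks the samples into one tensor, so it assumes all images in a batch share the same size.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# 100 random "images" with 3 channels and 5 possible classes (toy stand-ins)
toy_images = torch.randn(100, 3, 32, 32)
toy_labels = torch.randint(0, 5, (100,))
toy_loader = DataLoader(TensorDataset(toy_images, toy_labels), batch_size=32, shuffle=True)

# Fetch one mini-batch and inspect it
images, labels = next(iter(toy_loader))
print(images.shape)      # torch.Size([32, 3, 32, 32])
print(labels.shape)      # torch.Size([32])
print(len(toy_loader))   # 4 batches: three of 32 samples and a final one of 4
```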


## Defining the ResNet Model