# **Augmentations**


## Overview
This notebook provides a comprehensive guide on how to use data augmentation to improve the performance of deep learning models, with practical examples and a detailed workflow for training and evaluating models.

## Learning Objectives
+ ``Understand Data Augmentation``:
    - Learn why data augmentation is important for improving model generalization.
+ ``Explore Image Augmentation Techniques``:
    - Become familiar with transformations such as rotation, translation, cropping, and color jitter.
+ ``Hands-on Application``:
    - Apply various augmentation techniques using torchvision.transforms.
+ ``Work with Medical Imaging Datasets``:
    - Use the BreastMNIST dataset for binary classification tasks.
+ ``Train and Evaluate Models``:
    - Train a ResNet-18 model on plain, augmented, and mixed datasets.
    - Compare model performance across different training strategies.
+ ``Use Key Python Libraries``:
    - Gain proficiency in PyTorch, torchvision, medmnist, tqdm, and torchshow.
+ ``Analyze Results``:
    - Evaluate and interpret model performance using appropriate metrics.
+ ``Best Practices``:
    - Implement structured workflows and document work for reproducibility.

## Prerequisites
**Data**

The dataset we are using is BreastMNIST, which is part of the MedMNIST collection. It consists of 780 breast ultrasound images at 28x28 resolution and is framed as a binary classification task: normal and benign cases are labeled positive, and malignant cases are labeled negative. [[1,2]](#1and2)

**Libraries**

* ``pytorch`` and ``torchvision``: libraries for building and training machine learning models. ``torchvision`` adds image datasets, model architectures, and image transforms.
* ``medmnist``: a library designed for reading and processing the MedMNIST datasets. It includes functions for data preparation and formatting.
* ``tqdm``: a library that displays progress bars for loops.
* ``torchshow``: a library used for visualization.

## Get Started

In the previous submodule we looked at classification as a whole. Now we will look at a method for improving the dataset, which in turn can improve classification accuracy.

### Why augment the data
---
In deep learning, the objective is to train a model on training data so that it generalizes well to unseen (test) data. This is difficult when the training set is small or lacks variation, because the training data may not represent the underlying data distribution.

To understand this, consider how you would study for a math exam. Would you work the same example problem over and over, or would you work many different examples? Working a variety of examples builds better intuition for how the problem can be solved, and the same idea applies to deep learning.

Unfortunately, a larger dataset is not always available. When we cannot acquire more examples for the training data, data augmentation comes in: we can artificially expand the training set by applying augmentations to the existing data to increase its variation.

![Figure 1: aug](aug0.png)

### Image Augmentation
---
We can create an augmented image by taking the original image and applying a transformation to it. Some common transformations are rotation, translation, distortion, and cropping. The augmentations you select depend on the dataset and the problem, because an irrelevant augmentation can invalidate the labels. For example, if your dataset consists of brain scans and you are trying to identify whether an abnormality exists in the left or right hemisphere, a horizontal flip would make the label correspond to the wrong hemisphere. In this submodule we will be using the [BreastMNIST](https://medmnist.com/) dataset.
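
The hemisphere example can be made concrete with a toy sketch. The tiny 2x4 "scan", the ``abnormal_side`` helper, and the left/right label convention are all assumptions made for illustration; they are not part of this notebook's code.

```python
# Toy example: a 2x4 "scan" whose bright pixel (value 9) sits in the left half.
# The "left"/"right" label convention is an assumption made for this sketch.

def hflip(image):
    """Mirror a 2-D image (a list of rows) left to right."""
    return [row[::-1] for row in image]

def abnormal_side(image):
    """Report which half of the image contains the brightest pixel."""
    width = len(image[0])
    brightest_col = max(range(width), key=lambda c: max(row[c] for row in image))
    return "left" if brightest_col < width / 2 else "right"

scan = [
    [0, 9, 0, 0],
    [0, 0, 0, 0],
]
label = abnormal_side(scan)   # label stored with the original scan
flipped = hflip(scan)

# After the flip, the stored label no longer describes the image.
label_still_valid = abnormal_side(flipped) == label
print(label, abnormal_side(flipped), label_still_valid)
```

In a case like this, either the label has to be swapped together with the pixels or the flip should be left out of the augmentation sequence.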

### Notebook workflow:
---
- <a href="#0">Examples of augmentations</a></br>
- <a href="#1">Multiple transformations at once</a></br>
- <a href="#2">Apply the augmentation sequence and retrieve plain and augmented datasets</a></br>
- <a href="#A">Training ``ResNet`` from scratch on plain dataset (no augmentation used)</a></br>
    1. <a href="#A1">Create ResNet-18 model to be learned from scratch.</a></br>
    2. <a href="#A2">Train on plain training dataset.</a></br>
    3. <a href="#A3">Evaluate on testing dataset.</a></br>
- <a href="#B">Training ``ResNet`` from scratch on augmented dataset (without original dataset)</a></br>
    1. <a href="#B1">Create ResNet-18 model to be learned from scratch.</a></br>
    2. <a href="#B2">Train on augmented training dataset.</a></br>
    3. <a href="#B3">Evaluate on testing dataset.</a></br>
- <a href="#C">Training ``ResNet`` from scratch on a random mix of both plain and augmented datasets</a></br>
    1. <a href="#C1">Create ResNet-18 model to be learned from scratch.</a></br>
    2. <a href="#C2">Train on mixed training dataset.</a></br>
    3. <a href="#C3">Evaluate on testing dataset.</a></br>
- <a href="#3">Conclusion</a></br>
- <a href="#4">References</a></br>

In [None]:
!pip install tqdm torch git+https://github.com/xwying/torchshow.git@master
from torchvision import transforms
import torchvision.transforms.functional as tf
import torch.utils.data as data
import torchshow as ts
import medmnist
from medmnist import INFO, Evaluator

In [None]:
from PrepareDataset import Augment_Data, Get_DataSet_Information

In [None]:
from Loops import train_loop,test_loop,aug_train_loop

In [None]:
from Model import Create_Model_Optimizer_Criterion

## <a name="0">Augmentation examples</a> 

We will use the ``Augment_Data`` function defined in PrepareDataset.py to augment our dataset. We can define a transformation with the ``Compose`` class from **torchvision.transforms** as follows:

``General Note``: uncomment ``transforms.Normalize(mean=[.5], std=[.5])`` if you are using color images. It is not necessary for grayscale images and could produce a warning.

### No Transform

In [None]:
A = Augment_Data(data_flag = 'chestmnist', download = True, batch_size = 4, data_transform =None)

### Horizontal Flip

In [None]:

augmentation_transform_1 = transforms.Compose([transforms.ToTensor(), 
                                               #transforms.Normalize(mean=[.5], std=[.5]), 
                                               lambda x: tf.hflip(x)]
                                             )
A = Augment_Data(data_flag = 'chestmnist', download = True, batch_size = 4, data_transform =augmentation_transform_1)

### Random Horizontal Flip

In [None]:
augmentation_transform_2 = transforms.Compose([transforms.ToTensor(), 
                                               #transforms.Normalize(mean=[.5], std=[.5]), 
                                               transforms.RandomHorizontalFlip()]
                                             )
A = Augment_Data(data_flag = 'chestmnist', download = True, batch_size = 4, data_transform =augmentation_transform_2)
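
The difference between the fixed flip and the random flip can be sketched in plain Python. This is a conceptual illustration on nested lists, not torchvision's implementation; the only torchvision fact it relies on is that ``RandomHorizontalFlip`` flips with probability ``p``, which defaults to 0.5.

```python
import random

def hflip(image):
    """Mirror a 2-D image (a list of rows) left to right."""
    return [row[::-1] for row in image]

def random_hflip(image, p=0.5, rng=random):
    """Flip with probability p, otherwise return the image unchanged."""
    return hflip(image) if rng.random() < p else image

rng = random.Random(0)        # seeded so the sketch is repeatable
image = [[1, 2, 3], [4, 5, 6]]

# Draw many augmented views of the same image.
views = [random_hflip(image, p=0.5, rng=rng) for _ in range(1000)]
n_flipped = sum(view == hflip(image) for view in views)
print(f"{n_flipped} of {len(views)} views were flipped")
```

Because the decision is re-drawn every time an image is loaded, the model sees roughly half of the views flipped over many epochs rather than one fixed version of each image.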

### Random Brightness, Contrast, Saturation, and Hue

In [None]:
#random brightness contrast saturation hue
augmentation_transform_3 = transforms.Compose([transforms.ToTensor(), 
                                               #transforms.Normalize(mean=[.5], std=[.5]), 
                                               transforms.ColorJitter(brightness=(0,1), contrast=(0,1), saturation=(0,1), hue=(-0.5,0.5))]
                                             )
A = Augment_Data(data_flag = 'chestmnist', download = True, batch_size = 4, data_transform =augmentation_transform_3)
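
To see what the brightness and contrast ranges mean, here is a simplified sketch on a grayscale image with pixel values in [0, 1]. It assumes the usual definitions (brightness scales pixels, contrast blends pixels with the image mean) and applies them in a fixed order, whereas torchvision's ``ColorJitter`` applies its adjustments in a random order. Saturation and hue are left out because they only affect color images. The helper names are hypothetical.

```python
import random

def clamp01(value):
    """Keep a pixel value inside the valid [0, 1] range."""
    return min(1.0, max(0.0, value))

def adjust_brightness(image, factor):
    """Scale every pixel: factor 0 gives black, factor 1 leaves the image as is."""
    return [[clamp01(p * factor) for p in row] for row in image]

def adjust_contrast(image, factor):
    """Blend with the image mean: factor 0 gives flat gray, factor 1 no change."""
    pixels = [p for row in image for p in row]
    mean = sum(pixels) / len(pixels)
    return [[clamp01(factor * p + (1 - factor) * mean) for p in row] for row in image]

def jitter(image, brightness=(0.0, 1.0), contrast=(0.0, 1.0), rng=random):
    """Draw one factor per property from its (min, max) range and apply both."""
    b = rng.uniform(*brightness)
    c = rng.uniform(*contrast)
    return adjust_contrast(adjust_brightness(image, b), c)

image = [[0.2, 0.4], [0.6, 0.8]]
dark = adjust_brightness(image, 0.5)    # every pixel halved
flat = adjust_contrast(image, 0.0)      # every pixel equals the mean
sample = jitter(image, rng=random.Random(0))
print(dark, flat, sample, sep="\n")
```

With ranges of ``(0, 1)`` the sampled factors can only darken the image or reduce its contrast; a range such as ``(0.5, 1.5)`` would allow changes in both directions.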

### Random Rotation

In [None]:
augmentation_transform_4 = transforms.Compose([transforms.ToTensor(), 
                                               #transforms.Normalize(mean=[.5], std=[.5]), 
                                               transforms.RandomRotation(degrees=(-180,180))]
                                             )
A = Augment_Data(data_flag = 'chestmnist', download = True, batch_size = 4, data_transform =augmentation_transform_4)
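
The sketch below illustrates the idea behind a random rotation on a small grid: an angle is drawn uniformly from the ``degrees`` range, the image is rotated about its center, the output keeps the input size, and pixels with no source are filled with 0. It uses nearest-neighbor sampling and is a simplified stand-in written for this explanation, not torchvision's code.

```python
import math
import random

def rotate(image, degrees, fill=0):
    """Rotate a 2-D image about its center with nearest-neighbor sampling.

    The output has the same size as the input; output pixels whose source
    falls outside the image are set to `fill`.
    """
    h, w = len(image), len(image[0])
    cy, cx = (h - 1) / 2, (w - 1) / 2
    cos_a = math.cos(math.radians(degrees))
    sin_a = math.sin(math.radians(degrees))
    out = [[fill] * w for _ in range(h)]
    for y in range(h):
        for x in range(w):
            # Inverse mapping: locate the source pixel for each output pixel.
            sx = round(cx + (x - cx) * cos_a + (y - cy) * sin_a)
            sy = round(cy - (x - cx) * sin_a + (y - cy) * cos_a)
            if 0 <= sx < w and 0 <= sy < h:
                out[y][x] = image[sy][sx]
    return out

def random_rotation(image, degrees=(-180, 180), rng=random):
    """Draw an angle uniformly from the range and rotate by it."""
    return rotate(image, rng.uniform(*degrees))

grid = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
upside_down = rotate(grid, 180)
random_view = random_rotation(grid, rng=random.Random(0))
print(upside_down)
```

A range of ``(-180, 180)`` allows any orientation. Whether such large rotations are appropriate depends on the dataset, as discussed in the Image Augmentation section.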

### Random Resized Crop

In [None]:
augmentation_transform_5 = transforms.Compose([transforms.ToTensor(), 
                                               #transforms.Normalize(mean=[.5], std=[.5]), 
                                               transforms.RandomResizedCrop(224)]  # crop a random region, then resize it to 224x224
                                             ) # bilinear interpolation is the default for the resize step
A = Augment_Data(data_flag = 'chestmnist', download = True, batch_size = 4, data_transform =augmentation_transform_5)
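
A random resized crop does two things: it picks a random box inside the image, then resizes that box to a fixed output size. The sketch below illustrates both steps in plain Python. It uses the default ranges documented for ``RandomResizedCrop`` (area fraction between 0.08 and 1.0, aspect ratio between 3/4 and 4/3), but the fallback and the nearest-neighbor resize are simplifications chosen for brevity, and the helper names are hypothetical.

```python
import math
import random

def sample_crop_box(height, width, scale=(0.08, 1.0), ratio=(3 / 4, 4 / 3), rng=random):
    """Return (top, left, crop_h, crop_w) for a random box inside the image."""
    area = height * width
    for _ in range(10):
        target_area = area * rng.uniform(*scale)
        # Aspect ratios are sampled on a log scale so 3/4 and 4/3 are symmetric.
        aspect = math.exp(rng.uniform(math.log(ratio[0]), math.log(ratio[1])))
        crop_w = round(math.sqrt(target_area * aspect))
        crop_h = round(math.sqrt(target_area / aspect))
        if 0 < crop_w <= width and 0 < crop_h <= height:
            top = rng.randint(0, height - crop_h)
            left = rng.randint(0, width - crop_w)
            return top, left, crop_h, crop_w
    # Simplified fallback for this sketch: use the whole image.
    return 0, 0, height, width

def crop_and_resize(image, box, size):
    """Crop the box and resize it to size x size with nearest-neighbor sampling."""
    top, left, crop_h, crop_w = box
    return [
        [image[top + (i * crop_h) // size][left + (j * crop_w) // size] for j in range(size)]
        for i in range(size)
    ]

rng = random.Random(0)
image = [[r * 28 + c for c in range(28)] for r in range(28)]   # 28x28 test pattern
box = sample_crop_box(28, 28, rng=rng)
resized = crop_and_resize(image, box, size=224)
print(box, len(resized), len(resized[0]))
```

Note that with 28x28 inputs and an output size of 224, every crop is upsampled by a factor of eight or more, so the result is a magnified view of part of the image.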

<div class="alert alert-block alert-info"> <b>Knowledge Check</b> </div>

In [None]:
!pip install jupyterquiz==2.0.7 --quiet
from jupyterquiz import display_quiz
display_quiz('../quiz_files/submodule_02/kc1.json')

## <a name="1">Multiple transformations at once</a> 
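A very common sequence of transformations is normalization with a mean and standard deviation of 0.5, followed by a random resized crop and a random horizontal flip. In the cell below the normalization step is commented out, as explained in the general note above.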

In [None]:
Multiple_Augmentation_Transforms = transforms.Compose([transforms.ToTensor(),
                                                       #transforms.Normalize(mean=[.5], std=[.5]),
                                                       transforms.RandomResizedCrop(224),
                                                       transforms.RandomHorizontalFlip(),]
                                                      )
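
Two ideas in this cell can be sketched in a few lines: composing transforms means calling them in order, and normalizing with a mean and standard deviation of 0.5 maps values from [0, 1] to [-1, 1]. The functions below are rough stand-ins written for this explanation, operating on a flat list of pixels; they are not the torchvision classes.

```python
def compose(steps):
    """Chain callables so the output of each step feeds the next one."""
    def pipeline(x):
        for step in steps:
            x = step(x)
        return x
    return pipeline

def to_unit_range(pixels):
    """Rough analogue of ToTensor's scaling: map 0..255 integers to 0..1 floats."""
    return [p / 255 for p in pixels]

def normalize(mean, std):
    """Rough analogue of Normalize: apply (x - mean) / std to every pixel."""
    return lambda pixels: [(p - mean) / std for p in pixels]

pipeline = compose([to_unit_range, normalize(mean=0.5, std=0.5)])
result = pipeline([0, 255])
print(result)  # [-1.0, 1.0]
```

The order of the steps matters: normalizing before scaling to the unit range would produce very different values, which is why the tensor conversion comes first in the sequences above.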

## <a name="2">Apply the augmentation sequence and retrieve plain and augmented datasets</a> 
We will use the Augment_Data function to prepare the data for training.

In [None]:
plain_train_loader,aug_train_loader,test_loader,train_evaluator,test_evaluator = Augment_Data(data_flag = 'breastmnist', download = True, batch_size = 16, data_transform = Multiple_Augmentation_Transforms,train_shuffle = True)
_, Num_Classes = Get_DataSet_Information(data_flag = 'breastmnist')

## <a name="A">Training ``ResNet`` from scratch on plain dataset (no augmentation used)</a> 

####    1. <a name="A1">Create ``ResNet-18`` model to be learned from scratch.</a> 

In [None]:
Model, Optimizer, Criterion = Create_Model_Optimizer_Criterion(n_classes = Num_Classes, feature_extract = False, use_pretrained = False, bw = True)

####    2. <a name="A2">Train on plain training dataset.</a> 

In [None]:
Model = train_loop(Model, plain_train_loader, None, Criterion, Optimizer, train_evaluator, num_epochs=10)

####    3. <a name="A3">Evaluate on testing dataset.</a> 

In [None]:
plain_metrics = test_loop(Model,test_loader,test_evaluator)
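
``test_loop`` is defined in Loops.py and is not shown here. The plotting cell later in this notebook reads index 1 of the returned metrics as an accuracy, so it is reasonable to assume the loop reports accuracy alongside another score such as AUC. The sketch below shows how these two metrics can be computed for a binary task; the labels and scores are made-up numbers for illustration only.

```python
def accuracy(y_true, y_score, threshold=0.5):
    """Fraction of samples whose thresholded score matches the label."""
    preds = [1 if s > threshold else 0 for s in y_score]
    return sum(p == t for p, t in zip(preds, y_true)) / len(y_true)

def auc(y_true, y_score):
    """Chance that a random positive scores above a random negative (ties count 0.5)."""
    pos = [s for s, t in zip(y_score, y_true) if t == 1]
    neg = [s for s, t in zip(y_score, y_true) if t == 0]
    wins = sum((p > n) + 0.5 * (p == n) for p in pos for n in neg)
    return wins / (len(pos) * len(neg))

# Made-up labels and predicted probabilities of the positive class.
y_true = [1, 1, 1, 0, 0]
y_score = [0.9, 0.8, 0.4, 0.6, 0.1]

acc = accuracy(y_true, y_score)
roc_auc = auc(y_true, y_score)
print(f"accuracy={acc:.3f}, auc={roc_auc:.3f}")
```

Accuracy depends on the chosen threshold, while AUC measures how well the scores rank positives above negatives, so the two can disagree when classes are imbalanced.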

## <a name="B">Training ``ResNet`` from scratch on augmented dataset (without original dataset)</a> 

####    1. <a name="B1">Create ``ResNet-18`` model to be learned from scratch.</a> 

In [None]:
Model, Optimizer, Criterion = Create_Model_Optimizer_Criterion(n_classes = Num_Classes, feature_extract = False, use_pretrained = False, bw = True)

####    2. <a name="B2">Train on augmented training dataset.</a> 

In [None]:
Model = train_loop(Model, aug_train_loader, None, Criterion, Optimizer, train_evaluator, num_epochs=10)

####    3. <a name="B3">Evaluate on testing dataset.</a> 

In [None]:
aug_metrics = test_loop(Model,test_loader,test_evaluator)

## <a name="C">Training ``ResNet`` from scratch on a random mix of both plain and augmented datasets</a> 

####    1. <a name="C1">Create ``ResNet-18`` model to be learned from scratch.</a> 

In [None]:
Model, Optimizer, Criterion = Create_Model_Optimizer_Criterion(n_classes = Num_Classes, feature_extract = False, use_pretrained = False,bw = True)

####    2. <a name="C2">Train on mixed training dataset.</a> 

In [None]:
Model = aug_train_loop(Model, plain_train_loader, aug_train_loader, Criterion, Optimizer, train_evaluator, num_epochs=10)
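
``aug_train_loop`` is defined in Loops.py and is not shown here, so how it mixes the two loaders is not visible in this notebook. One plausible reading of "random mix" is that each training step draws its batch from either the plain loader or the augmented loader at random. The generator below sketches that idea; the name ``mixed_batches`` and the ``p_aug`` parameter are hypothetical, and plain lists of strings stand in for real data loaders.

```python
import random

def mixed_batches(plain_loader, aug_loader, p_aug=0.5, rng=random):
    """Yield one batch per step, taken from the augmented loader with probability p_aug."""
    for plain_batch, aug_batch in zip(plain_loader, aug_loader):
        yield aug_batch if rng.random() < p_aug else plain_batch

# Stand-ins for two loaders that iterate over the same number of batches.
plain = [f"plain_{i}" for i in range(6)]
augmented = [f"aug_{i}" for i in range(6)]

epoch = list(mixed_batches(plain, augmented, rng=random.Random(0)))
print(epoch)
```

Under this scheme the number of steps per epoch stays the same as in the plain run, but the model sees a different blend of original and augmented batches each epoch.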

####    3. <a name="C3">Evaluate on testing dataset.</a> 

In [None]:
random_mix_metrics = test_loop(Model,test_loader,test_evaluator)

In [None]:
# test comparison
import matplotlib.pyplot as plt
import numpy as np
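# Named `results` so it does not shadow the `data` module imported from torch.utils earlier.
results = {}
results['Plain Dataset'] = plain_metrics[1]
results['Augmented Dataset'] = aug_metrics[1]
results['Mixed Dataset'] = random_mix_metrics[1]
networks = list(results.keys())
accuracies = list(results.values())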
  
fig = plt.figure(figsize = (10, 5))
 
# creating the bar plot
plt.bar(networks[0], accuracies[0], color ='blue',
        width = 0.3)
plt.bar(networks[1], accuracies[1], color ='red',
        width = 0.3)
plt.bar(networks[2], accuracies[2], color ='green',
        width = 0.3)
plt.ylim((0,1.))

plt.xlabel("ResNet-18 Networks")
plt.ylabel("Accuracies")
plt.title("Test Set Results")
plt.show()

<div class="alert alert-block alert-info"> <b>Knowledge Check</b> </div>

In [None]:
display_quiz('../quiz_files/submodule_02/kc2.json')

## <a name="3">Conclusion</a>
We see that the model trained on the mixed dataset (plain + augmented) achieves higher accuracy than the models trained on the plain dataset or the augmented dataset separately.

## Clean up
To keep your workspace organized, remember to:

1. Save your work.
2. Shut down any notebooks and active sessions to avoid extra charges.
