# WM391 PMA Assessment

## Convolutional Neural Network for the Exposure Correction of Poorly Exposed Images

This notebook implements a Generative Adversarial Network (GAN) that generates well-exposed images. It is trained on under- or over-exposed images paired with their correctly exposed ground-truth images.

### Import required libraries

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import tqdm
from PIL import Image
import os
import numpy as np
from torch.utils.data import Dataset, DataLoader

### Set Parameters for the use of the Model

This configuration is set up to use the WM391_PMA_dataset. Using it with other datasets will require modifying the dataloader.

In [2]:
# Choose the most appropriate device given the machine's capabilities
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Dataset locations, relative to the working directory. Please change as appropriate.
# Built with os.path.join so the paths resolve on any OS (hard-coded "\\"
# separators only work on Windows).
TRAIN_DIR = os.path.join("WM391_PMA_dataset", "training")
VAL_DIR = os.path.join("WM391_PMA_dataset", "validation")
# Optimizer learning rate (step size of each parameter update)
LEARNING_RATE = 2e-4
# Number of examples per batch sent to the device per iteration
BATCH_SIZE = 64
# Number of DataLoader worker processes (presumably passed as num_workers= in a later cell)
NUM_WORKERS = 2
# Image size used to train the model (presumably the side length images are
# resized/cropped to by transforms defined elsewhere -- confirm)
IMAGE_SIZE = 256
# Number of channels in the images input to the model (3 = RGB)
CHANNELS_IMG = 3
# Loss-weighting hyperparameters, currently disabled
#L1_LAMBDA = 100
#LAMBDA_GP = 10
# Number of times the model is trained with the entire training dataset
NUM_EPOCHS = 500
# Load model weights & parameters from checkpoint state
LOAD_MODEL = False
# Save model weights & parameters to checkpoint file
SAVE_MODEL = False
# File locations for the discriminator and generator checkpoint files
CHECKPOINT_DISC = "disc.pth.tar"
CHECKPOINT_GEN = "gen.pth.tar"

### Create a Dataset and Dataloader

Return the poorly exposed image and its corresponding ground truth image. This gives the GAN an input image and a target image to work towards for every example in the training dataset. Using PyTorch dataloaders requires overriding the dataset's ```__len__``` and ```__getitem__``` methods. These return the size of the dataset and a single example from it, respectively.

In [3]:
class ExposedImageDataset(Dataset):
    """Pair each poorly exposed input image with its ground-truth image.

    ``root_dir`` must contain two sub-directories: ``INPUT_IMAGES`` (the
    variably exposed inputs) and ``GT_IMAGES`` (the ground-truth targets).
    Each ground-truth image is expected to have ``exposures_per_ground_truth``
    input images. The file names must sort so that the inputs belonging to the
    k-th ground truth (in sorted order) occupy sorted positions
    ``k * n`` .. ``k * n + n - 1``.

    Transforms are called with keyword arguments and must return a dict:
    ``transform_both(image=..., image0=...)`` -> ``{"image": ..., "image0": ...}``
    and the single-image transforms ``t(image=...)`` -> ``{"image": ...}``.
    Without transforms, each item is a pair of ``(H, W, 3)`` uint8 numpy arrays.
    """

    def __init__(self, root_dir, transform_both=None, transform_varied_exposure=None,
                 transform_ground_truth=None, exposures_per_ground_truth=5):
        if exposures_per_ground_truth < 1:
            raise ValueError("exposures_per_ground_truth must be a positive integer")

        # Set paths to image directories
        self.root_dir = root_dir
        self.variable_exposure_path = os.path.join(root_dir, "INPUT_IMAGES")
        self.ground_truth_path = os.path.join(root_dir, "GT_IMAGES")

        # Store the (optional) transforms
        self.transform_both = transform_both
        self.transform_varied_exposure = transform_varied_exposure
        self.transform_ground_truth = transform_ground_truth

        # Number of consecutive input exposures that share one ground-truth image
        self.exposures_per_ground_truth = exposures_per_ground_truth

        # os.listdir returns entries in arbitrary, filesystem-dependent order.
        # Sort both listings so the index-based pairing in get_file_name is
        # deterministic and lines each input up with its ground truth.
        self.variable_exposure_images = sorted(os.listdir(self.variable_exposure_path))
        self.ground_truth_images = sorted(os.listdir(self.ground_truth_path))

        # Get length of individual dataset classes
        self.variable_exposure_len = len(self.variable_exposure_images)
        self.ground_truth_len = len(self.ground_truth_images)

        # Every input image is one example, so the input count is the dataset length
        self.length_dataset = self.variable_exposure_len

    def __len__(self):
        """Return the number of (input, ground truth) examples."""
        return self.length_dataset

    def get_file_name(self, index):
        """Return ``(input_file_name, ground_truth_file_name)`` for ``index``.

        ``index`` wraps around modulo the dataset length. Raises ``IndexError``
        if the dataset is empty, or if there are too few ground-truth images
        for the number of inputs.
        """
        if self.length_dataset == 0:
            raise IndexError("ExposedImageDataset is empty")
        # Wrap the index so out-of-range values still map to a valid example
        index %= self.length_dataset
        variable_exposure_image = self.variable_exposure_images[index]
        # Consecutive groups of inputs share a single ground-truth image
        ground_truth_image = self.ground_truth_images[index // self.exposures_per_ground_truth]

        return variable_exposure_image, ground_truth_image

    @staticmethod
    def _load_rgb(path):
        """Read the image at ``path`` as an RGB numpy array, closing the file afterwards."""
        with Image.open(path) as img:
            return np.array(img.convert("RGB"))

    def __getitem__(self, index):
        """Return the ``(input_image, ground_truth_image)`` pair at ``index``."""
        variable_exposure_name, ground_truth_name = self.get_file_name(index)

        variable_exposure_image = self._load_rgb(
            os.path.join(self.variable_exposure_path, variable_exposure_name))
        ground_truth_image = self._load_rgb(
            os.path.join(self.ground_truth_path, ground_truth_name))

        # Joint transform applied to both images in one call (presumably so random
        # augmentations such as crops/flips stay in sync between input and target)
        if self.transform_both:
            augmentations = self.transform_both(image=variable_exposure_image, image0=ground_truth_image)
            variable_exposure_image = augmentations["image"]
            ground_truth_image = augmentations["image0"]

        # Transform applied only to the varied-exposure input image
        if self.transform_varied_exposure:
            variable_exposure_image = self.transform_varied_exposure(image=variable_exposure_image)["image"]

        # Transform applied only to the ground-truth image (uses its own transform,
        # not the input-image one)
        if self.transform_ground_truth:
            ground_truth_image = self.transform_ground_truth(image=ground_truth_image)["image"]

        return variable_exposure_image, ground_truth_image

### Test the Exposed Image Dataset

In [10]:
# Smoke-test the dataset by printing the tensor shapes of the first few batches.
# NOTE(review): default batching stacks arrays, so every image in a batch must have
# the same size. With batch_size=5 this presumably relies on each batch holding the
# exposures of a single scene -- confirm if the batch size or file layout changes.
dataset = ExposedImageDataset(TRAIN_DIR)
loader = DataLoader(dataset, batch_size=5)
num_preview_batches = 5
# zip stops as soon as range() is exhausted, before asking the loader for another
# batch, so exactly num_preview_batches batches are loaded and decoded.
for _, (x, y) in zip(range(num_preview_batches), loader):
    print("Variable exposure: {}".format(x.shape))
    print("Ground truth: {}".format(y.shape))

Variable exposure: torch.Size([5, 600, 903, 3])
Ground truth: torch.Size([5, 600, 903, 3])
Variable exposure: torch.Size([5, 1277, 850, 3])
Ground truth: torch.Size([5, 1277, 850, 3])
Variable exposure: torch.Size([5, 1167, 778, 3])
Ground truth: torch.Size([5, 1167, 778, 3])
Variable exposure: torch.Size([5, 874, 1311, 3])
Ground truth: torch.Size([5, 874, 1311, 3])
Variable exposure: torch.Size([5, 1052, 701, 3])
Ground truth: torch.Size([5, 1052, 701, 3])
