# WM391 PMA Assessment

## Convolutional Neural Network for the Exposure Correction of Poorly Exposed Images

This notebook implements a conditional Generative Adversarial Network (GAN) that generates well-exposed images. It is trained on pairs of under- or over-exposed images and their correctly exposed ground-truth images.

### Import required libraries

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.utils import save_image
from tqdm import tqdm
import albumentations as a
from albumentations.pytorch import ToTensorV2
from PIL import Image
import os
import numpy as np
from torch.utils.data import Dataset, DataLoader

### Set Parameters for the use of the Model

This configuration is set up to use the WM391_PMA_dataset. Using other datasets will require modifying the dataset and dataloader code.

In [3]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Dataset locations. Built with os.path.join rather than hard-coded "\\"
# separators so the paths also resolve on Linux and macOS.
# Please change as appropriate.
TRAIN_DIR = os.path.join("WM391_PMA_dataset", "training")
VAL_DIR = os.path.join("WM391_PMA_dataset", "validation")
# Step size of the Adam optimisers used for both networks
LEARNING_RATE = 2e-4
# Number of image pairs sent to the device per training iteration
BATCH_SIZE = 8
# Number of DataLoader worker processes used to load training batches
NUM_WORKERS = 2
# Intended side length (pixels) of the square images used for training
IMAGE_SIZE = 256
# Number of colour channels in the images input to the model (RGB)
CHANNELS_IMG = 3
# Weight of the L1 reconstruction term relative to the adversarial term
# in the generator loss
L1_LAMBDA = 100
# Gradient-penalty weight. NOTE(review): not referenced by the training
# code visible in this notebook.
LAMBDA_GP = 10
# Number of full passes over the training dataset
NUM_EPOCHS = 10
# Load model weights & optimiser state from the checkpoint files at start-up
LOAD_MODEL = False
# Save model weights & optimiser state to the checkpoint files during training
SAVE_MODEL = True
# File locations of the discriminator and generator checkpoints
CHECKPOINT_DISC = "disc.pth.tar"
CHECKPOINT_GEN = "gen.pth.tar"

In [4]:
# Sanity check: True means PyTorch can see a CUDA GPU, so DEVICE is "cuda"
cuda_available = torch.cuda.is_available()
print(cuda_available)

True


### Create Image Transformations

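Resize all images to 256×256 so that the model is trained on a consistent image size, and normalise pixel values to the range [-1, 1]. The poorly exposed input images additionally receive random augmentation.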

In [5]:
# Transforms shared by the input and its ground truth. Registering "image0"
# as an additional "image" target makes albumentations apply the SAME random
# parameters to both images, so every pair stays pixel-aligned.
both_transform = a.Compose(
    [
        a.Resize(width=IMAGE_SIZE, height=IMAGE_SIZE),
        # The flip must live here rather than in the input-only pipeline:
        # flipping the input without its target would misalign the pair
        # that the generator's L1 loss compares pixel by pixel.
        a.HorizontalFlip(p=0.5),
    ],
    additional_targets={"image0": "image"},
)

# Input-only pipeline. ColorJitter perturbs only the poorly exposed input;
# the target must remain the true well-exposed image.
transform_varied_exposure = a.Compose(
    [
        a.ColorJitter(p=0.2),
        # (x / 255 - 0.5) / 0.5 maps [0, 255] to [-1, 1], matching the
        # generator's Tanh output range
        a.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255.0,),
        ToTensorV2(),
    ]
)

# Ground-truth pipeline: normalisation and tensor conversion only, no random
# augmentation
transform_ground_truth = a.Compose(
    [
        a.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255.0,),
        ToTensorV2(),
    ]
)

### Create a Dataset and Dataloader

Return the poorly exposed image and its corresponding ground-truth image. This gives the GAN an input image and a target image to work towards for every example in the training dataset. Using PyTorch dataloaders requires overriding the dataset's `__len__` and `__getitem__` methods, which return the size of the dataset and a single example from it, respectively.

In [6]:
class ExposedImageDataset(Dataset):
    """Paired dataset of poorly exposed input images and their ground truths.

    ``root_dir`` must contain two sub-directories:

    * ``INPUT_IMAGES`` - the differently exposed input images;
    * ``GT_IMAGES`` - the well-exposed ground-truth images.

    Pairing is positional: after sorting both directory listings by file
    name, inputs ``k * exposures_per_image`` to
    ``(k + 1) * exposures_per_image - 1`` are paired with ground truth ``k``.
    NOTE(review): this assumes the file names sort so that each group of
    exposures lines up with its ground truth - confirm for new datasets.

    Args:
        root_dir: dataset split directory (e.g. the training folder).
        transform_both: optional albumentations-style transform called as
            ``transform(image=input, image0=ground_truth)`` so that both
            images receive the same random parameters.
        transform_varied_exposure: optional transform for the input only.
        transform_ground_truth: optional transform for the ground truth only.
        exposures_per_image: number of input exposures per ground-truth image.

    Raises:
        ValueError: if there are more input images than
            ``exposures_per_image`` times the number of ground-truth images,
            which would otherwise surface later as an IndexError mid-training.
    """

    def __init__(self, root_dir, transform_both=None, transform_varied_exposure=None,
                 transform_ground_truth=None, exposures_per_image=5):

        # Set paths to image directories
        self.root_dir = root_dir
        self.variable_exposure_path = os.path.join(root_dir, "INPUT_IMAGES")
        self.ground_truth_path = os.path.join(root_dir, "GT_IMAGES")

        # Initialise transforms to class variables
        self.transform_both = transform_both
        self.transform_varied_exposure = transform_varied_exposure
        self.transform_ground_truth = transform_ground_truth
        self.exposures_per_image = exposures_per_image

        # os.listdir returns names in an arbitrary, filesystem-dependent order.
        # Sort both listings so the positional input -> ground-truth pairing
        # is deterministic on every platform.
        self.variable_exposure_images = sorted(os.listdir(self.variable_exposure_path))
        self.ground_truth_images = sorted(os.listdir(self.ground_truth_path))

        # Get length of individual dataset classes
        self.variable_exposure_len = len(self.variable_exposure_images)
        self.ground_truth_len = len(self.ground_truth_images)

        if self.variable_exposure_len > self.ground_truth_len * exposures_per_image:
            raise ValueError(
                f"{self.variable_exposure_len} input images but only "
                f"{self.ground_truth_len} ground-truth images "
                f"({exposures_per_image} exposures per ground truth expected)"
            )

        # Use the variable exposure length since it holds all the training images
        self.length_dataset = self.variable_exposure_len

    def __len__(self):
        return self.length_dataset

    def get_file_name(self, index):
        """Return ``(input_file_name, ground_truth_file_name)`` for ``index``."""
        # Wrap the index so out-of-range values map back into the dataset
        index = index % self.length_dataset
        variable_exposure_image = self.variable_exposure_images[index]
        # Each ground truth has `exposures_per_image` consecutive inputs
        ground_truth_image = self.ground_truth_images[index // self.exposures_per_image]

        return variable_exposure_image, ground_truth_image

    @staticmethod
    def _load_rgb(path):
        """Read the image at ``path`` as an RGB uint8 numpy array (H, W, 3)."""
        # The context manager closes the file handle deterministically
        with Image.open(path) as img:
            return np.array(img.convert("RGB"))

    def __getitem__(self, index):
        """Return the transformed ``(input_image, ground_truth_image)`` pair."""
        variable_exposure_name, ground_truth_name = self.get_file_name(index)

        variable_exposure_image = self._load_rgb(
            os.path.join(self.variable_exposure_path, variable_exposure_name)
        )
        ground_truth_image = self._load_rgb(
            os.path.join(self.ground_truth_path, ground_truth_name)
        )

        # Joint transform: identical random parameters for both images
        if self.transform_both is not None:
            augmentations = self.transform_both(image=variable_exposure_image, image0=ground_truth_image)
            variable_exposure_image = augmentations["image"]
            ground_truth_image = augmentations["image0"]

        # Input-only transform
        if self.transform_varied_exposure is not None:
            variable_exposure_image = self.transform_varied_exposure(image=variable_exposure_image)["image"]

        # Ground-truth-only transform (previously the input transform was
        # applied here by mistake, adding random jitter/flips to the target)
        if self.transform_ground_truth is not None:
            ground_truth_image = self.transform_ground_truth(image=ground_truth_image)["image"]

        return variable_exposure_image, ground_truth_image

### Test the Exposed Image Dataset

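This snippet prints the tensor shapes of the first 5 batches, each containing 5 image pairs (one varied-exposure image and one ground-truth image per pair). If everything is working correctly, both the varied-exposure and the ground-truth batches should have shape [5, 3, 256, 256], i.e. 3-channel 256x256 images.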

In [7]:
# Build the training split with all three transform stages attached
dataset = ExposedImageDataset(
    TRAIN_DIR,
    transform_both=both_transform,
    transform_varied_exposure=transform_varied_exposure,
    transform_ground_truth=transform_ground_truth,
)
loader = DataLoader(dataset, batch_size=5)

# Print the shapes of the first five batches, then stop
for batch_idx, (x, y) in enumerate(loader):
    if batch_idx >= 5:
        break
    print(f"Variable exposure: {x.shape}")
    print(f"Ground truth: {y.shape}")

Variable exposure: torch.Size([5, 3, 256, 256])
Ground truth: torch.Size([5, 3, 256, 256])
Variable exposure: torch.Size([5, 3, 256, 256])
Ground truth: torch.Size([5, 3, 256, 256])
Variable exposure: torch.Size([5, 3, 256, 256])
Ground truth: torch.Size([5, 3, 256, 256])
Variable exposure: torch.Size([5, 3, 256, 256])
Ground truth: torch.Size([5, 3, 256, 256])
Variable exposure: torch.Size([5, 3, 256, 256])
Ground truth: torch.Size([5, 3, 256, 256])


### Utility Functions to make model easier to train

In [8]:
def save_examples(gen, val_loader, epoch, folder):
    """Write generator samples for the first validation batch to ``folder``.

    Every call saves ``y_gen_{epoch}.png`` (generator output) and
    ``input_{epoch}.png`` (the poorly exposed input). ``label_{epoch}.png``
    (the ground truth) is saved on epoch 0 only: ``main`` counts epochs from
    0 and does not shuffle the validation loader, so the first batch's target
    is the same file every epoch.

    The generator is put in eval mode for sampling and is always returned to
    train mode afterwards, even if sampling or saving raises.
    """
    # save_image does not create missing directories
    os.makedirs(folder, exist_ok=True)
    x, y = next(iter(val_loader))
    x, y = x.to(DEVICE), y.to(DEVICE)
    gen.eval()
    try:
        with torch.no_grad():
            # Undo Normalize(mean=0.5, std=0.5): map [-1, 1] back to [0, 1]
            y_fake = gen(x) * 0.5 + 0.5
            save_image(y_fake, os.path.join(folder, f"y_gen_{epoch}.png"))
            save_image(x * 0.5 + 0.5, os.path.join(folder, f"input_{epoch}.png"))
            if epoch == 0:
                save_image(y * 0.5 + 0.5, os.path.join(folder, f"label_{epoch}.png"))
    finally:
        gen.train()


def save_checkpoint(model, optimizer, filename="my_checkpoint.pth.tar"):
    """Serialise the state of ``model`` and ``optimizer`` to ``filename``.

    The file holds a dict with keys ``"state_dict"`` (model weights) and
    ``"optimizer"`` (optimiser state), the layout ``load_checkpoint`` reads.
    """
    print("=> Saving checkpoint")
    torch.save(
        {"state_dict": model.state_dict(), "optimizer": optimizer.state_dict()},
        filename,
    )


def load_checkpoint(checkpoint_file, model, optimizer, lr):
    """Restore ``model`` and ``optimizer`` from a ``save_checkpoint`` file.

    Tensors are mapped onto ``DEVICE`` while loading. Afterwards the learning
    rate of every optimiser parameter group is overwritten with ``lr``.
    """
    print("=> Loading checkpoint")
    saved = torch.load(checkpoint_file, map_location=DEVICE)
    model.load_state_dict(saved["state_dict"])
    optimizer.load_state_dict(saved["optimizer"])

    # Loading the optimiser state also restores the learning rate stored in
    # the checkpoint; reset it so the currently configured rate is used.
    for group in optimizer.param_groups:
        group.update(lr=lr)

### Create Model for Discriminator

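The discriminator's job is to decide whether an image is real (a ground truth) or fake (generated), given the poorly exposed input it is paired with. Rather than a single score, it outputs a grid of real/fake scores, one for each image patch. Its judgement is fed back to the generator as a loss, improving the generator's ability to produce images that resemble the target image.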

In [9]:
class CNNBlock(nn.Module):
    """Discriminator stage: 4x4 conv (reflect padding) -> BatchNorm -> LeakyReLU(0.2)."""

    def __init__(self, in_channels, out_channels, stride):
        super(CNNBlock, self).__init__()
        self.conv = nn.Sequential(
            # No conv bias: the following BatchNorm re-centres the output anyway
            nn.Conv2d(
                in_channels, out_channels, 4, stride, 1, bias=False, padding_mode="reflect"
            ),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2),
        )

    def forward(self, x):
        return self.conv(x)


class Discriminator(nn.Module):
    """Fully convolutional conditional discriminator (PatchGAN-style).

    The input image and the candidate output image are concatenated along the
    channel axis. The network returns a single-channel grid of logits, each
    scoring one overlapping patch as real or fake (for 256x256 inputs and the
    default ``features`` the grid is 30x30).

    Args:
        in_channels: channels of EACH of the two images; the first convolution
            therefore sees ``2 * in_channels`` channels.
        features: channel widths of the successive stages. Every stage halves
            the spatial size except the last, which uses stride 1.
    """

    def __init__(self, in_channels=3, features=(64, 128, 256, 512)):
        super().__init__()
        # Immutable default above; copy so any sequence type is accepted
        features = list(features)
        self.initial = nn.Sequential(
            nn.Conv2d(
                in_channels * 2,
                features[0],
                kernel_size=4,
                stride=2,
                padding=1,
                padding_mode="reflect",
            ),
            nn.LeakyReLU(0.2),
        )

        layers = []
        channels = features[0]
        last_position = len(features) - 1
        for position, feature in enumerate(features[1:], start=1):
            # Pick the stride-1 stage by POSITION, not by value: comparing
            # widths (feature == features[-1]) mis-strides every stage that
            # happens to share the final width, e.g. features=(64, 128, 128).
            stride = 1 if position == last_position else 2
            layers.append(CNNBlock(channels, feature, stride=stride))
            channels = feature

        # Project to one logit per patch
        layers.append(
            nn.Conv2d(
                channels, 1, kernel_size=4, stride=1, padding=1, padding_mode="reflect"
            ),
        )

        self.model = nn.Sequential(*layers)

    def forward(self, x, y):
        """Score ``y`` as a real/fake counterpart of ``x``; returns raw logits."""
        x = torch.cat([x, y], dim=1)
        x = self.initial(x)
        x = self.model(x)
        return x

In [10]:
# Smoke-test the discriminator on a random (input, candidate) image pair
x = torch.randn(1, 3, 256, 256)
y = torch.randn(1, 3, 256, 256)
model = Discriminator(in_channels=3)
preds = model(x, y)
for item in (model, preds.shape):
    print(item)

Discriminator(
  (initial): Sequential(
    (0): Conv2d(6, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), padding_mode=reflect)
    (1): LeakyReLU(negative_slope=0.2)
  )
  (model): Sequential(
    (0): CNNBlock(
      (conv): Sequential(
        (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False, padding_mode=reflect)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): LeakyReLU(negative_slope=0.2)
      )
    )
    (1): CNNBlock(
      (conv): Sequential(
        (0): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False, padding_mode=reflect)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): LeakyReLU(negative_slope=0.2)
      )
    )
    (2): CNNBlock(
      (conv): Sequential(
        (0): Conv2d(256, 512, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1), bias=False, padding_mode=reflect)
        (1): Batc

### Create model for the Generator

In [11]:
class Block(nn.Module):
    """U-Net stage: 4x4 stride-2 resampling conv -> BatchNorm -> activation.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the block.
        down: ``True`` halves the spatial size with ``Conv2d`` (reflect
            padding); ``False`` doubles it with ``ConvTranspose2d``.
        act: ``"relu"`` selects ReLU; any other value selects LeakyReLU(0.2).
        use_dropout: apply ``nn.Dropout(0.5)`` to the block's output.
    """

    def __init__(self, in_channels, out_channels, down=True, act="relu", use_dropout=False):
        super(Block, self).__init__()
        if down:
            resample = nn.Conv2d(in_channels, out_channels, 4, 2, 1, bias=False, padding_mode="reflect")
        else:
            resample = nn.ConvTranspose2d(in_channels, out_channels, 4, 2, 1, bias=False)
        activation = nn.ReLU() if act == "relu" else nn.LeakyReLU(0.2)
        self.conv = nn.Sequential(resample, nn.BatchNorm2d(out_channels), activation)

        self.use_dropout = use_dropout
        self.dropout = nn.Dropout(0.5)
        self.down = down

    def forward(self, x):
        out = self.conv(x)
        if self.use_dropout:
            out = self.dropout(out)
        return out


class Generator(nn.Module):
    """U-Net generator mapping an image to an image of the same size.

    Eight stride-2 stages (``initial_down``, ``down1``-``down6``,
    ``bottleneck``) shrink a 256x256 input to 1x1. Eight upsampling stages
    (``up1``-``up7``, ``final_up``) restore the resolution; every stage after
    ``up1`` also receives the matching encoder feature map as a skip
    connection. ``final_up`` ends in Tanh, so outputs lie in [-1, 1].

    Args:
        in_channels: channels of the input image, and of the output image.
        features: base channel width; deeper stages use multiples of it.
    """

    def __init__(self, in_channels=3, features=64):
        super().__init__()
        f = features

        def encoder(cin, cout):
            return Block(cin, cout, down=True, act="leaky", use_dropout=False)

        def decoder(cin, cout, dropout):
            return Block(cin, cout, down=False, act="relu", use_dropout=dropout)

        self.initial_down = nn.Sequential(
            nn.Conv2d(in_channels, f, 4, 2, 1, padding_mode="reflect"),
            nn.LeakyReLU(0.2),
        )
        self.down1 = encoder(f, f * 2)
        self.down2 = encoder(f * 2, f * 4)
        self.down3 = encoder(f * 4, f * 8)
        self.down4 = encoder(f * 8, f * 8)
        self.down5 = encoder(f * 8, f * 8)
        self.down6 = encoder(f * 8, f * 8)
        self.bottleneck = nn.Sequential(nn.Conv2d(f * 8, f * 8, 4, 2, 1), nn.ReLU())

        # Decoder inputs after up1 are [previous output, skip] concatenations,
        # hence the doubled input channel counts
        self.up1 = decoder(f * 8, f * 8, dropout=True)
        self.up2 = decoder(f * 16, f * 8, dropout=True)
        self.up3 = decoder(f * 16, f * 8, dropout=True)
        self.up4 = decoder(f * 16, f * 8, dropout=False)
        self.up5 = decoder(f * 16, f * 4, dropout=False)
        self.up6 = decoder(f * 8, f * 2, dropout=False)
        self.up7 = decoder(f * 4, f, dropout=False)
        self.final_up = nn.Sequential(
            nn.ConvTranspose2d(f * 2, in_channels, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        )

    def forward(self, x):
        # skips[0] is initial_down's output; skips[1..6] are down1..down6's
        skips = [self.initial_down(x)]
        for stage in (self.down1, self.down2, self.down3, self.down4, self.down5, self.down6):
            skips.append(stage(skips[-1]))

        h = self.up1(self.bottleneck(skips[-1]))
        # up2 pairs with down6's output, up3 with down5's, ..., up7 with down1's
        for stage, skip in zip(
            (self.up2, self.up3, self.up4, self.up5, self.up6, self.up7),
            reversed(skips[1:]),
        ):
            h = stage(torch.cat([h, skip], 1))
        return self.final_up(torch.cat([h, skips[0]], 1))

In [12]:
# Smoke-test the generator: a 256x256 RGB input should come back at 256x256
x = torch.randn(1, 3, 256, 256)
model = Generator(features=64, in_channels=3)
preds = model(x)
output_shape = preds.shape
print(output_shape)


torch.Size([1, 3, 256, 256])


In [13]:
# Let cuDNN benchmark and cache the fastest convolution algorithms; useful
# here because every training batch has the same input size
torch.backends.cudnn.benchmark = True


def train_fn(disc, gen, loader, opt_disc, opt_gen, l1_loss, bce, g_scaler, d_scaler,):
    """Train the discriminator and generator for one pass over ``loader``.

    Each batch performs one discriminator update followed by one generator
    update:

    * discriminator: mean of ``bce(disc(x, y), 1)`` and
      ``bce(disc(x, gen(x)), 0)``;
    * generator: ``bce(disc(x, gen(x)), 1)`` plus ``L1_LAMBDA`` times
      ``l1_loss(gen(x), y)``.

    Args:
        disc, gen: the discriminator and generator.
        loader: iterable yielding ``(input, target)`` batches.
        opt_disc, opt_gen: optimisers for ``disc`` and ``gen``.
        l1_loss: reconstruction loss between generated and target images.
        bce: adversarial loss applied to the discriminator's raw logits.
        g_scaler, d_scaler: gradient scalers for the two mixed-precision
            updates.
    """
    # float16 autocast is only used on the GPU. On the CPU it is switched off
    # explicitly, instead of relying on the deprecated torch.cuda.amp.autocast
    # to warn and disable itself.
    use_amp = DEVICE == "cuda"
    loop = tqdm(loader, leave=True)

    for idx, (x, y) in enumerate(loop):
        x = x.to(DEVICE)
        y = y.to(DEVICE)

        # Train Discriminator
        with torch.autocast(device_type=DEVICE, enabled=use_amp):
            y_fake = gen(x)
            D_real = disc(x, y)
            D_real_loss = bce(D_real, torch.ones_like(D_real))
            # detach(): this update must not backpropagate into the generator
            D_fake = disc(x, y_fake.detach())
            D_fake_loss = bce(D_fake, torch.zeros_like(D_fake))
            D_loss = (D_real_loss + D_fake_loss) / 2

        opt_disc.zero_grad()
        d_scaler.scale(D_loss).backward()
        d_scaler.step(opt_disc)
        d_scaler.update()

        # Train generator: fool the (just updated) discriminator while staying
        # close to the target in L1
        with torch.autocast(device_type=DEVICE, enabled=use_amp):
            D_fake = disc(x, y_fake)
            G_fake_loss = bce(D_fake, torch.ones_like(D_fake))
            L1 = l1_loss(y_fake, y) * L1_LAMBDA
            G_loss = G_fake_loss + L1

        opt_gen.zero_grad()
        g_scaler.scale(G_loss).backward()
        g_scaler.step(opt_gen)
        g_scaler.update()

        if idx % 10 == 0:
            # Mean probability the discriminator assigns to real / generated
            loop.set_postfix(
                D_real=torch.sigmoid(D_real).mean().item(),
                D_fake=torch.sigmoid(D_fake).mean().item(),
            )


def main():
    """Train the exposure-correction GAN for ``NUM_EPOCHS`` epochs.

    Builds both networks and their Adam optimisers, optionally restores them
    from checkpoints, then trains epoch by epoch. Checkpoints are written
    every 5 epochs and after the final epoch (when ``SAVE_MODEL``), and
    validation samples are written to ``evaluation/`` after every epoch.
    """
    disc = Discriminator(in_channels=3).to(DEVICE)
    gen = Generator(in_channels=3, features=64).to(DEVICE)
    opt_disc = optim.Adam(disc.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999),)
    opt_gen = optim.Adam(gen.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
    BCE = nn.BCEWithLogitsLoss()
    L1_LOSS = nn.L1Loss()

    if LOAD_MODEL:
        load_checkpoint(
            CHECKPOINT_GEN, gen, opt_gen, LEARNING_RATE,
        )
        load_checkpoint(
            CHECKPOINT_DISC, disc, opt_disc, LEARNING_RATE,
        )

    # Without transforms the dataset yields raw numpy images of arbitrary
    # size, which cannot be batched or fed to the networks. Attach the resize,
    # augmentation and normalisation pipelines.
    train_dataset = ExposedImageDataset(
        root_dir=TRAIN_DIR,
        transform_both=both_transform,
        transform_varied_exposure=transform_varied_exposure,
        transform_ground_truth=transform_ground_truth,
    )
    train_loader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=NUM_WORKERS,
    )

    # Gradient scaling only applies to CUDA mixed precision; disable it on CPU
    use_amp = DEVICE == "cuda"
    g_scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
    d_scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    # Validation: deterministic resize + normalisation only (no random
    # augmentation), so the samples saved each epoch are comparable
    val_resize = a.Compose(
        [a.Resize(width=IMAGE_SIZE, height=IMAGE_SIZE)],
        additional_targets={"image0": "image"},
    )
    val_dataset = ExposedImageDataset(
        root_dir=VAL_DIR,
        transform_both=val_resize,
        transform_varied_exposure=transform_ground_truth,
        transform_ground_truth=transform_ground_truth,
    )
    val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)

    # save_image needs the output directory to exist
    os.makedirs("evaluation", exist_ok=True)

    for epoch in range(NUM_EPOCHS):
        train_fn(
            disc, gen, train_loader, opt_disc, opt_gen, L1_LOSS, BCE, g_scaler, d_scaler,
        )

        # Also checkpoint after the last epoch; with `epoch % 5 == 0` alone
        # the final epochs of training would never be saved
        is_last_epoch = epoch == NUM_EPOCHS - 1
        if SAVE_MODEL and (epoch % 5 == 0 or is_last_epoch):
            save_checkpoint(gen, opt_gen, filename=CHECKPOINT_GEN)
            save_checkpoint(disc, opt_disc, filename=CHECKPOINT_DISC)

        save_examples(gen, val_loader, epoch, folder="evaluation")


In [14]:
# Guard the entry point. When this notebook is exported and run as a script,
# DataLoader workers started with the "spawn" method (the default on Windows
# and macOS) re-import the main module; an unguarded main() call would then
# start training again inside every worker. In a notebook __name__ is
# "__main__", so this runs exactly as before.
if __name__ == "__main__":
    main()

  0%|          | 0/2210 [00:00<?, ?it/s]