### Only Run on Google Colab

In [None]:
# Mount Google Drive when running on Google Colab.
# The google.colab package only exists inside Colab, so the import is guarded:
# on any other machine the cell is skipped instead of raising ImportError,
# which keeps "Restart & Run All" usable locally.
try:
    from google.colab import drive
except ImportError:
    IN_COLAB = False
    print("Not running on Google Colab; skipping Google Drive mount.")
else:
    IN_COLAB = True
    drive.mount('/content/drive')

In [None]:
# Colab only: switch the working directory to the project checkout on the
# mounted Drive. The project imports (data/, models/, configs/) and the relative
# paths such as "./datasets/..." used below are presumably resolved from here.
%cd /content/drive/MyDrive/github/FYP_low_light_image_enhancement/

# Low Light Image Enhancement

### Import libraries

In [1]:
from data.custom_image_dataset import CustomImageDataset
from models.enlighten import EnlightenGAN
from configs.option import Option
from torch.utils.data import DataLoader

import torch
import time

  device: torch.device = torch.device("cpu"),


In [2]:
# Prefer the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

### Hyperparameters

In [3]:
# --- Paths ---
img_dir = "./datasets/light_enhancement"        # passed to CustomImageDataset
checkpoint_dir = "./checkpoints/enlightenGAN/"  # passed to model.save_model

# --- Data loading (forwarded to the DataLoader) ---
batch_size = 32
batch_shuffle = True

# --- Optimisation ---
lr = 1e-4       # learning rate handed to EnlightenGAN
n_epochs = 100  # number of full passes over the dataloader

# --- Logging / checkpointing ---
# Both frequencies are compared against the running count of processed samples
# (the training loop adds len(data['img_A']) per batch), not against batch counts.
print_freq = 1_000
save_freq = 15_000
save_freq = 15000

### Load Dataset

In [None]:
# Build the training dataset from the images under img_dir,
# configured with the "train" phase options.
train_options = Option(phase="train")
dataset = CustomImageDataset(img_dir=img_dir, opt=train_options)

In [None]:
# Wrap the dataset in a DataLoader that yields (optionally shuffled) mini-batches.
dataloader = DataLoader(
    dataset=dataset,
    batch_size=batch_size,
    shuffle=batch_shuffle,
)

In [None]:
# len(dataloader) is the number of batches per epoch, not the number of images,
# so the image count is taken from the dataset itself.
dataloader_size = len(dataloader)
dataset_size = len(dataset)

print("The number of training images = %d" % dataset_size)
print("The number of batches per epoch = %d" % dataloader_size)

### Load Model

In [4]:
# Instantiate EnlightenGAN with the configured learning rate on the selected device.
model = EnlightenGAN(
    use_src=True,
    lr=lr,
    device=device,
)

In [7]:
# Dump the layer structure of the model's G attribute for inspection.
generator = model.G
print(generator)

Unet_resize_conv(
  (conv1_1): Conv2d(4, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (downsample_1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (downsample_2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (downsample_3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (downsample_4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (LReLU1_1): LeakyReLU(negative_slope=0.2, inplace=True)
  (bn1_1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv1_2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (LReLU1_2): LeakyReLU(negative_slope=0.2, inplace=True)
  (bn1_2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (max_pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2_1): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), pa

In [12]:
from torchvision import models

# Inspect the module at index 4 of VGG16's `features` stack
# (a freshly constructed vgg16() is enough to look at the architecture).
vgg16_features = models.vgg16().features
print(vgg16_features[4])

MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)


### Start Training

In [None]:
# Train the model for n_epochs passes over the dataloader.
#
# Progress is counted in processed samples, not batches: every batch adds
# len(data['img_A']) to the counters. Losses are printed each time the running
# total passes a multiple of print_freq, and a checkpoint is written each time
# it passes a multiple of save_freq.
total_iterations = 0
train_start_time = time.time()

# Index of the next print / save threshold that has not been crossed yet.
n_print = 1
n_save = 1

for epoch in range(n_epochs):
    start_time = time.time()

    # Samples processed so far in the current epoch.
    epoch_iter = 0

    for data in dataloader:
        model.set_input(data)
        model.optimize_parameters()

        batch_len = len(data['img_A'])
        total_iterations += batch_len
        epoch_iter += batch_len

        if total_iterations > (print_freq * n_print):
            # One batch can cross several thresholds when print_freq is not larger
            # than the batch size; jump to the highest threshold crossed so the
            # reported iteration does not lag behind the real sample count.
            n_print = (total_iterations - 1) // print_freq
            time_taken = time.time() - train_start_time

            print("--------------------E%d-----------------------" % (epoch+1))
            print("Current Iteration: %05d | Epoch Iteration: %05d" % (print_freq * n_print, epoch_iter))
            print("Current Time Taken: %07ds | Current Epoch Running Time: %07ds" % (time_taken, time.time() - start_time))
            print("SPA Loss: %.7f | Color Loss: %.7f" % (model.loss_spa, model.loss_color))
            print("RAGAN Loss for Global D: %.7f | Local D: %.7f" % (model.loss_D, model.loss_patch_D))
            print("RAGAN Loss for Global G: %.7f | Local G: %.7f" % (model.loss_G, model.loss_G_patch))
            print("SFP Loss for Global G  : %.7f | Local G: %.7f" % (model.loss_G_SFP, model.loss_G_SFP_patch))
            print(f"Total generator loss: {model.total_loss_G}")
            n_print += 1

        if total_iterations > (save_freq * n_save):
            # Same catch-up as above so the checkpoint label matches the highest
            # save threshold actually reached.
            n_save = (total_iterations - 1) // save_freq
            print("Saving models...")
            model.save_model(checkpoint_dir, save_freq * n_save)
            n_save += 1


# Final report and a last checkpoint labelled "trained".
print(f"Total time taken: {time.time() - train_start_time}")
print("Saving trained model ...")
model.save_model(checkpoint_dir, epoch="trained")
