### Run This Section Only on Google Colab

In [None]:
# Mount Google Drive at /content/drive so the project folder stored on the
# drive can be reached from the Colab runtime.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Make the project folder on the mounted drive the working directory; the
# relative paths used later ("./datasets/...", "./checkpoints") resolve from here.
%cd /content/drive/MyDrive/github/FYP_low_light_image_enhancement/

In [None]:
# List the contents of the current working directory.
%ls

# Low-Light Image Enhancement

### Import libraries

In [None]:
from data.custom_image_dataset import CustomImageDataset
from models.cycleGAN import CycleGANModel
from configs.option import Option
from torch.utils.data import DataLoader

import torch
import time

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

### Hyperparameters

In [None]:
# --- Data --------------------------------------------------------------
# Dataset root handed to CustomImageDataset.
img_dir = "./datasets/summer2winter_yosemite"
# Both values are forwarded to the DataLoader.
batch_size, batch_shuffle = 2, True

# --- Optimisation ------------------------------------------------------
# Passed to CycleGANModel as `lr`.
lr = 1e-4
# Passed to CycleGANModel as `lamda_A` / `lamda_B`; presumably the weights of
# the two cycle-consistency terms -- confirm in models/cycleGAN.
lambda_A = lambda_B = 10

# --- Schedule ----------------------------------------------------------
# Number of passes over the dataloader.
n_epochs = 10000
# Log / checkpoint every N optimisation steps (counted across all epochs).
print_freq = save_freq = 10

### Load Dataset

In [None]:
# Build the dataset from the images under `img_dir`, configured for the
# "train" phase.
train_opt = Option(phase="train")
dataset = CustomImageDataset(img_dir=img_dir, opt=train_opt)

In [None]:
# Wrap the dataset in a DataLoader; batch size and shuffling come from the
# hyperparameter cell.
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=batch_shuffle,
)

In [None]:
# len(dataloader) is the number of batches per epoch, not the number of images.
dataloader_size = len(dataloader)

# Take the image count from the dataset itself so the message below is
# accurate for any batch size (the batch count would under-report it).
print("The number of training images = %d" % len(dataset))

### Load Model

In [None]:
# Build the CycleGAN model on the selected device.
# NOTE: the keyword arguments are spelled `lamda_A` / `lamda_B` (no "b") in
# this call, while the notebook variables are `lambda_A` / `lambda_B`.
model = CycleGANModel(
    lr=lr,
    lamda_A=lambda_A,
    lamda_B=lambda_B,
    device=device,
)

### Start Training

In [None]:
# Train for `n_epochs` passes over `dataloader`. Losses are printed every
# `print_freq` optimisation steps and a checkpoint is written every
# `save_freq` steps; both are counted across all epochs.
total_iterations = 0

for epoch in range(n_epochs):
    # `start_time` anchors the whole epoch; `iters_time` is re-armed at the
    # end of every step so it marks where the current step began.
    start_time = time.time()
    iters_time = start_time

    # Grows by len(data['img_A']) on every step of this epoch.
    epoch_iter = 0

    for data in dataloader:
        batch_A = data['img_A']
        batch_B = data['img_B']
        model.optimize_parameters(batch_A, batch_B)

        total_iterations += 1
        epoch_iter += len(batch_A)

        # Timestamp taken right after the optimisation step has returned.
        step_done = time.time()

        if total_iterations % print_freq == 0:
            step_seconds = step_done - iters_time
            epoch_seconds = step_done - start_time

            print("--------------------E%d-----------------------" % epoch)
            print("Current Iteration: %05d | Epoch Iteration: %05d" % (total_iterations, epoch_iter))
            # %d truncates the float durations to whole seconds.
            print("Current Iteration Time Taken: %07ds | Current Epoch Running Time: %07ds" % (step_seconds, epoch_seconds))

            # Label/value pairs, printed in the same order as before.
            loss_report = (
                ("G(x) BCE Loss:", model.loss_G_X),
                ("D_Y  BCE Loss:", model.loss_D_Y),
                ("F(y) BCE Loss:", model.loss_F_Y),
                ("D_X  BCE Loss:", model.loss_D_X),
                ("X Collection Cycle L1 Loss:", model.loss_cycle_X),
                ("Y Collection Cycle L1 Loss:", model.loss_cycle_Y),
                ("Total Generators Loss:", model.loss_G),
            )
            for label, value in loss_report:
                print(label, value)

        if total_iterations % save_freq == 0:
            print("Saving models...")
            # Periodic checkpoints are tagged with the global step count,
            # passed in the position the final save names `epoch`.
            model.save_model("./checkpoints", total_iterations)

        # Restart the per-step clock only after logging and saving, so that
        # time is not charged to the next step.
        iters_time = time.time()

print("Saving trained model ...")
model.save_model("./checkpoints", epoch="trained")
