# ESRGAN

## Imports

## Create LR images

In [None]:
'''
import cv2

scale = 1/4

inputdir = './data/DIV2K/ESRGAN/valid'
outputdir = './data/DIV2K/ESRGAN/validLR'
try:
  os.makedirs(outputdir)
except:
  print('lr dir exists')
# Get all image paths
image_file_names = os.listdir(inputdir)

for image_file_name in image_file_names:
  image = cv2.imread(f"{inputdir}/{image_file_name}", cv2.IMREAD_UNCHANGED)
  img_resized = cv2.resize(image, None, fx=scale, fy=scale, interpolation=cv2.INTER_CUBIC)
  cv2.imwrite(f"{outputdir}/{image_file_name.split('.')[-2]}.{image_file_name.split('.')[-1]}", img_resized)
'''

## Divide into train valid test

In [None]:
import numpy as np
import random
import cv2

percent_train = 0.8
percent_valid = 0.6

path = "./data/Bubbles"
path_full = path+"/bubblesHR"

os.makedirs(path+"/train")
os.makedirs(path+"/valid")
os.makedirs(path+"/test")

image_size = 480
'''
full_dataset = TrainValidImageDataset(path_full,
                                            image_size,
                                            esrgan_config.upscale_factor,
                                            "Valid")


train_size = int(percent_train * len(full_dataset))
valid_size = int((len(full_dataset) - train_size)*percent_valid)
test_size = int((len(full_dataset) - train_size)*(1-percent_valid))
if train_size+valid_size+test_size < len(full_dataset):
  train_size+=1
elif train_size+valid_size+test_size > len(full_dataset):
  train_size-=1
train_dataset, valid_dataset, test_dataset = torch.utils.data.random_split(full_dataset, [train_size, valid_size, test_size])

torch.save(data[0], '/content/train_loader/img/train_transformed_img{}'.format(i))
'''

# Get all image paths
image_file_names = os.listdir(path_full)

for image_file_name in image_file_names:
  image = cv2.imread(f"{path_full}/{image_file_name}", cv2.IMREAD_UNCHANGED)

  save_dir = ""
  if percent_train > random.random(): #train
    save_dir = path+"/train"
  elif percent_valid > random.random(): # valid
    save_dir = path+"/valid"
  else:
    save_dir = path+"/test"
  
  # Save image
  
  cv2.imwrite(f"{save_dir}/{image_file_name}", image)


## Train

In [1]:
# Copyright 2021 Dakewe Biotech Corporation. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
#   you may not use this file except in compliance with the License.
#   You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import os
import time

import torch
from torch import nn
from torch import optim
from torch.cuda import amp
from torch.optim import lr_scheduler
from torch.optim.swa_utils import AveragedModel
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

import mlflow

import esrgan_config


import model
from dataset import CUDAPrefetcher, TrainValidImageDataset, TestImageDataset
from image_quality_assessment import PSNR, SSIM, NIQE
from lpips import LPIPS
from utils import load_state_dict, make_directory, save_checkpoint, AverageMeter, ProgressMeter

model_names = sorted(
    name for name in model.__dict__ if
    name.islower() and not name.startswith("__") and callable(model.__dict__[name]))


def main():
    # Initialize the number of training epochs
    start_epoch = 0

    # Initialize training to generate network evaluation indicators
    best_psnr = 0.0
    best_ssim = 0.0

    train_prefetcher, valid_prefetcher = load_dataset()
    print("Load all datasets successfully.")

    d_model, g_model, ema_g_model = build_model()
    print(f"Build `{esrgan_config.g_arch_name}` model successfully.")

    pixel_criterion, content_criterion, adversarial_criterion = define_loss()
    print("Define all loss functions successfully.")

    # Start MLFlow Tracking
    try:
        mlflow.set_experiment(esrgan_config.experience_name)
    except:
        experiment_id= mlflow.create_experiment(esrgan_config.experience_name)
        print("New Experiment created with name: " + esrgan_config.experience_name + " and ID: " + str(experiment_id))
    

    
    print("Check whether to load pretrained d model weights...")
    if esrgan_config.pretrained_d_model_weights_path:
        #d_model = mlflow.pytorch.load_model(esrgan_config.pretrained_d_model_weights_path)
        d_model = load_state_dict(d_model, esrgan_config.pretrained_d_model_weights_path)
        print(f"Loaded `{esrgan_config.pretrained_d_model_weights_path}` pretrained model weights successfully.")
    else:
        print("Pretrained d model weights not found.")

    print("Check whether to load pretrained g model weights...")
    if esrgan_config.pretrained_g_model_weights_path:
        #g_model = mlflow.pytorch.load_model(esrgan_config.pretrained_g_model_weights_path)
        g_model = load_state_dict(g_model, esrgan_config.pretrained_g_model_weights_path)
        print(f"Loaded `{esrgan_config.pretrained_g_model_weights_path}` pretrained model weights successfully.")
    else:
        print("Pretrained g model weights not found.")

    # Define optimizers
    d_optimizer, g_optimizer = define_optimizer(d_model, g_model)
    print("Define all optimizer functions successfully.")

    d_scheduler, g_scheduler = define_scheduler(d_optimizer, g_optimizer)
    print("Define all optimizer scheduler functions successfully.")

    '''
    print("Check whether the pretrained d model is restored...")
    if esrgan_config.resume_d_model_weights_path:
        d_model, _, start_epoch, best_psnr, best_ssim, optimizer, scheduler = load_state_dict(
            d_model,
            esrgan_config.resume_d_model_weights_path,
            optimizer=d_optimizer,
            scheduler=d_scheduler,
            load_mode="resume")
        print("Loaded pretrained model weights.")
    else:
        print("Resume training d model not found. Start training from scratch.")

    print("Check whether the pretrained g model is restored...")
    if esrgan_config.resume_g_model_weights_path:
        lsrresnet_model, ema_lsrresnet_model, start_epoch, best_psnr, best_ssim, optimizer, scheduler = load_state_dict(
            g_model,
            esrgan_config.resume_g_model_weights_path,
            ema_model=ema_g_model,
            optimizer=g_optimizer,
            scheduler=g_scheduler,
            load_mode="resume")
        print("Loaded pretrained model weights.")
    else:
        print("Resume training g model not found. Start training from scratch.")
    '''

    # Create a experiment results
    samples_dir = os.path.join("samples", esrgan_config.exp_name)
    results_dir = os.path.join("results", esrgan_config.exp_name)
    make_directory(samples_dir)
    make_directory(results_dir)

    # Create training process log file
    writer = SummaryWriter(os.path.join("samples", "logs", esrgan_config.exp_name))


    # Initialize the gradient scaler
    scaler = amp.GradScaler()

    # Create an IQA evaluation model
    psnr_model = PSNR(esrgan_config.upscale_factor, esrgan_config.only_test_y_channel)
    ssim_model = SSIM(esrgan_config.upscale_factor, esrgan_config.only_test_y_channel)
    niqe_model = NIQE(esrgan_config.upscale_factor, esrgan_config.niqe_model_path)
    lpips_model = LPIPS(net=esrgan_config.lpips_net)

    # Transfer the IQA model to the specified device
    psnr_model = psnr_model.to(device=esrgan_config.device)
    ssim_model = ssim_model.to(device=esrgan_config.device)
    niqe_model = niqe_model.to(device=esrgan_config.device, non_blocking=True)
    lpips_model = lpips_model.to(device=esrgan_config.device, non_blocking=True)


    best_lpips_metrics = 1.0

    # Start MLflow run & log parameters 
    try:
      mlflow.start_run(run_name=esrgan_config.run_name, tags=esrgan_config.tags, description=esrgan_config.description)
    except: # If last session was not ended
      mlflow.end_run()
      mlflow.start_run(run_name=esrgan_config.run_name, tags=esrgan_config.tags, description=esrgan_config.description)
    
    run = mlflow.active_run()
    print("Active run_id: {}".format(run.info.run_id))

    #mlflow.log_params({'exp_name':esrgan_config.exp_name,'d_arch_name':esrgan_config.d_arch_name,'g_arch_name':esrgan_config.g_arch_name,'in_channels':esrgan_config.in_channels,'out_channels':esrgan_config.out_channels,'channels':esrgan_config.channels,'growth_channels':esrgan_config.growth_channels,'num_blocks':esrgan_config.num_blocks,'upscale_factor':esrgan_config.upscale_factor,'gt_image_size':esrgan_config.gt_image_size,'batch_size':esrgan_config.batch_size,'train_gt_images_dir':esrgan_config.train_gt_images_dir,'test_gt_images_dir':esrgan_config.test_gt_images_dir,'test_lr_images_dir':esrgan_config.test_lr_images_dir,'pretrained_d_model_weights_path':esrgan_config.pretrained_d_model_weights_path,'pretrained_g_model_weights_path':esrgan_config.pretrained_g_model_weights_path,'resume_d_model_weights_path':esrgan_config.resume_d_model_weights_path,'resume_g_model_weights_path':esrgan_config.resume_g_model_weights_path,'epochs':esrgan_config.epochs,'pixel_weight':esrgan_config.pixel_weight,'content_weight':esrgan_config.content_weight,'adversarial_weight':esrgan_config.adversarial_weight,'feature_model_extractor_node':esrgan_config.feature_model_extractor_node,'feature_model_normalize_mean':esrgan_config.feature_model_normalize_mean,'feature_model_normalize_std':esrgan_config.feature_model_normalize_std,'model_lr':esrgan_config.model_lr,'model_betas':esrgan_config.model_betas,'model_eps':esrgan_config.model_eps,'model_weight_decay':esrgan_config.model_weight_decay,'model_ema_decay':esrgan_config.model_ema_decay,'lr_scheduler_milestones':esrgan_config.lr_scheduler_milestones,'lr_scheduler_gamma':esrgan_config.lr_scheduler_gamma,'lpips_net':esrgan_config.lpips_net,'niqe_model_path':esrgan_config.niqe_model_path})
    mlflow.log_params({'exp_name':esrgan_config.exp_name,'d_arch_name':esrgan_config.d_arch_name,'g_arch_name':esrgan_config.g_arch_name,'in_channels':esrgan_config.in_channels,'out_channels':esrgan_config.out_channels,'channels':esrgan_config.channels,'growth_channels':esrgan_config.growth_channels,'num_blocks':esrgan_config.num_blocks,'upscale_factor':esrgan_config.upscale_factor,'gt_image_size':esrgan_config.gt_image_size,'batch_size':esrgan_config.batch_size,'train_gt_images_dir':esrgan_config.train_gt_images_dir,'valid_gt_images_dir':esrgan_config.valid_gt_images_dir,
                       'pretrained_d_model_weights_path':esrgan_config.pretrained_d_model_weights_path,'pretrained_g_model_weights_path':esrgan_config.pretrained_g_model_weights_path,'resume_d_model_weights_path':esrgan_config.resume_d_model_weights_path,'resume_g_model_weights_path':esrgan_config.resume_g_model_weights_path,'epochs':esrgan_config.epochs,'pixel_weight':esrgan_config.pixel_weight,'content_weight':esrgan_config.content_weight,'adversarial_weight':esrgan_config.adversarial_weight,'feature_model_extractor_node':esrgan_config.feature_model_extractor_node,'feature_model_normalize_mean':esrgan_config.feature_model_normalize_mean,'feature_model_normalize_std':esrgan_config.feature_model_normalize_std,'model_lr':esrgan_config.model_lr,'model_betas':esrgan_config.model_betas,'model_eps':esrgan_config.model_eps,'model_weight_decay':esrgan_config.model_weight_decay,'model_ema_decay':esrgan_config.model_ema_decay,'lr_scheduler_milestones':esrgan_config.lr_scheduler_milestones,'lr_scheduler_gamma':esrgan_config.lr_scheduler_gamma,'lpips_net':esrgan_config.lpips_net,'niqe_model_path':esrgan_config.niqe_model_path})


    for epoch in range(start_epoch, esrgan_config.epochs):
        pixel_loss, content_loss, adversarial_loss, d_gt_probabilities, d_sr_probabilities= train(d_model,
              g_model,
              ema_g_model,
              train_prefetcher,
              pixel_criterion,
              content_criterion,
              adversarial_criterion,
              d_optimizer,
              g_optimizer,
              epoch,
              scaler,
              writer)

        psnr_val, ssim_val, niqe_val, lpips_val = validate(g_model,
                              valid_prefetcher,
                              epoch,
                              writer,
                              psnr_model,
                              ssim_model,
                              niqe_model,
                              lpips_model,
                              "Valid")

        print("\n")

        log_epoch(g_model, d_model, pixel_loss, content_loss, adversarial_loss, d_gt_probabilities, d_sr_probabilities, psnr_val, ssim_val, niqe_val, lpips_val, epoch)

        # Update LR
        d_scheduler.step()
        g_scheduler.step()

        # Save the best model with the highest LPIPS score in validation dataset
        is_best = lpips_val < best_lpips_metrics
        best_lpips_metrics = min(lpips_val, best_lpips_metrics)

        if is_best:
          print("Saving best model...")
          mlflow.pytorch.log_model(g_model, "g_model")
          mlflow.pytorch.log_model(d_model, "d_model")
          print("Finished Saving")
        else:
          print("Was not the best")

    # End logging
    mlflow.end_run()




        
def log_epoch(g_model, d_model, g_pixel_loss, g_content_loss, g_adversarial_loss, d_gt_probabilities, d_sr_probabilities, psnr_val, ssim_val, niqe_val, lpips_val, epoch):
    '''
    g_pixel_loss, g_content_loss, g_adversarial_loss: train generator loss
    d_gt_probabilities, d_sr_probabilities: descriminator probabilities
    psnr, ssim, niqe, lpips: validation metrics
    '''

    print('\nLogging epoch data...')

    g_train_loss = g_pixel_loss + g_content_loss + g_adversarial_loss

    mlflow.log_metrics({'g_train_loss':g_train_loss, 'g_pixel_loss':g_pixel_loss, 'g_content_loss':g_content_loss, 'g_adversarial_loss':g_adversarial_loss, 'd_gt_probabilities':d_gt_probabilities, 'd_sr_probabilities':d_sr_probabilities, 'psnr_val':psnr_val, 'ssim_val':ssim_val, 'niqe_val':niqe_val, 'lpips_val':lpips_val}, step=epoch)

    print('Finished Logging\n')


def load_dataset() -> [CUDAPrefetcher, CUDAPrefetcher]:
    # Load train, test and valid datasets
    train_datasets = TrainValidImageDataset(esrgan_config.train_gt_images_dir,
                                            esrgan_config.gt_image_size,
                                            esrgan_config.upscale_factor,
                                            "Train")
    '''valid_datasets = TestImageDataset(esrgan_config.valid_gt_images_dir, esrgan_config.valid_lr_images_dir)
    test_datasets = TestImageDataset(esrgan_config.test_gt_images_dir, esrgan_config.test_lr_images_dir)'''

    valid_datasets = TrainValidImageDataset(esrgan_config.valid_gt_images_dir,
                                        esrgan_config.gt_image_size,
                                        esrgan_config.upscale_factor,
                                        "Valid")

    # Generator all dataloader
    train_dataloader = DataLoader(train_datasets,
                                  batch_size=esrgan_config.batch_size,
                                  shuffle=True,
                                  num_workers=esrgan_config.num_workers,
                                  pin_memory=True,
                                  drop_last=True,
                                  persistent_workers=True)
    
    valid_dataloader = DataLoader(valid_datasets,
                                 batch_size=1,
                                 shuffle=False,
                                 num_workers=1,
                                 pin_memory=True,
                                 drop_last=False,
                                 persistent_workers=True)

    # Place all data on the preprocessing data loader
    train_prefetcher = CUDAPrefetcher(train_dataloader, esrgan_config.device)
    valid_prefetcher = CUDAPrefetcher(valid_dataloader, esrgan_config.device)

    return train_prefetcher, valid_prefetcher


def build_model() -> [nn.Module, nn.Module, nn.Module]:
    d_model = model.__dict__[esrgan_config.d_arch_name]()
    g_model = model.__dict__[esrgan_config.g_arch_name](in_channels=esrgan_config.in_channels,
                                                        out_channels=esrgan_config.out_channels,
                                                        channels=esrgan_config.channels,
                                                        growth_channels=esrgan_config.growth_channels,
                                                        num_blocks=esrgan_config.num_blocks)
    d_model = d_model.to(device=esrgan_config.device)
    g_model = g_model.to(device=esrgan_config.device)

    # Create an Exponential Moving Average Model
    ema_avg = lambda averaged_model_parameter, model_parameter, num_averaged: (1 - esrgan_config.model_ema_decay) * averaged_model_parameter + esrgan_config.model_ema_decay * model_parameter
    ema_g_model = AveragedModel(g_model, avg_fn=ema_avg)

    return d_model, g_model, ema_g_model


def define_loss() -> [nn.L1Loss, model.content_loss, nn.BCEWithLogitsLoss]:
    pixel_criterion = nn.L1Loss()
    content_criterion = model.content_loss(esrgan_config.feature_model_extractor_node,
                                           esrgan_config.feature_model_normalize_mean,
                                           esrgan_config.feature_model_normalize_std)
    adversarial_criterion = nn.BCEWithLogitsLoss()

    # Transfer to CUDA
    pixel_criterion = pixel_criterion.to(device=esrgan_config.device)
    content_criterion = content_criterion.to(device=esrgan_config.device)
    adversarial_criterion = adversarial_criterion.to(device=esrgan_config.device)

    return pixel_criterion, content_criterion, adversarial_criterion


def define_optimizer(d_model, g_model) -> [optim.Adam, optim.Adam]:
    d_optimizer = optim.Adam(d_model.parameters(),
                             esrgan_config.model_lr,
                             esrgan_config.model_betas,
                             esrgan_config.model_eps,
                             esrgan_config.model_weight_decay)
    g_optimizer = optim.Adam(g_model.parameters(),
                             esrgan_config.model_lr,
                             esrgan_config.model_betas,
                             esrgan_config.model_eps,
                             esrgan_config.model_weight_decay)

    return d_optimizer, g_optimizer


def define_scheduler(
        d_optimizer: optim.Adam,
        g_optimizer: optim.Adam
) -> [lr_scheduler.MultiStepLR, lr_scheduler.MultiStepLR]:
    d_scheduler = lr_scheduler.MultiStepLR(d_optimizer,
                                           esrgan_config.lr_scheduler_milestones,
                                           esrgan_config.lr_scheduler_gamma)
    g_scheduler = lr_scheduler.MultiStepLR(g_optimizer,
                                           esrgan_config.lr_scheduler_milestones,
                                           esrgan_config.lr_scheduler_gamma)
    return d_scheduler, g_scheduler


def train(
        d_model: nn.Module,
        g_model: nn.Module,
        ema_g_model: nn.Module,
        train_prefetcher: CUDAPrefetcher,
        pixel_criterion: nn.L1Loss,
        content_criterion: model.content_loss,
        adversarial_criterion: nn.BCEWithLogitsLoss,
        d_optimizer: optim.Adam,
        g_optimizer: optim.Adam,
        epoch: int,
        scaler: amp.GradScaler,
        writer: SummaryWriter
):
    '''
    Returns average of key metrics (all in progress meter)
    '''
    print("Training")
    # Calculate how many batches of data are in each Epoch
    batches = len(train_prefetcher)
    # Print information of progress bar during training
    batch_time = AverageMeter("Time", ":6.3f")
    data_time = AverageMeter("Data", ":6.3f")
    pixel_losses = AverageMeter("Pixel loss", ":6.6f")
    content_losses = AverageMeter("Content loss", ":6.6f")
    adversarial_losses = AverageMeter("Adversarial loss", ":6.6f")
    d_gt_probabilities = AverageMeter("D(GT)", ":6.3f")
    d_sr_probabilities = AverageMeter("D(SR)", ":6.3f")
    progress = ProgressMeter(batches,
                             [batch_time, data_time,
                              pixel_losses, content_losses, adversarial_losses,
                              d_gt_probabilities, d_sr_probabilities],
                             prefix=f"Epoch: [{epoch + 1}]")

    # Put the generative network model in training mode
    d_model.train()
    g_model.train()

    # Initialize the number of data batches to print logs on the terminal
    batch_index = 0

    # Initialize the data loader and load the first batch of data
    train_prefetcher.reset()
    batch_data = train_prefetcher.next()

    # Get the initialization training time
    end = time.time()

    #limit = 12

    while batch_data is not None:
        # Calculate the time it takes to load a batch of data
        data_time.update(time.time() - end)

        # Transfer in-memory data to CUDA devices to speed up training
        gt = batch_data["gt"].to(device=esrgan_config.device, non_blocking=True)
        lr = batch_data["lr"].to(device=esrgan_config.device, non_blocking=True)

        # Set the real sample label to 1, and the false sample label to 0
        batch_size, _, _, _ = gt.shape
        real_label = torch.full([batch_size, 1], 1.0, dtype=gt.dtype, device=esrgan_config.device)
        fake_label = torch.full([batch_size, 1], 0.0, dtype=gt.dtype, device=esrgan_config.device)

        # Start training the generator model
        # During generator training, turn off discriminator backpropagation
        for d_parameters in d_model.parameters():
            d_parameters.requires_grad = False

        # Initialize generator model gradients
        g_model.zero_grad(set_to_none=True)

        # Calculate the perceptual loss of the generator, mainly including pixel loss, feature loss and adversarial loss
        with amp.autocast():
            # Use the generator model to generate fake samples
            sr = g_model(lr)
            # Output discriminator to discriminate object probability
            gt_output = d_model(gt.detach().clone())
            sr_output = d_model(sr)
            pixel_loss = esrgan_config.pixel_weight * pixel_criterion(sr, gt)
            content_loss = esrgan_config.content_weight * content_criterion(sr, gt)
            # Computational adversarial network loss
            d_loss_gt = adversarial_criterion(gt_output - torch.mean(sr_output), fake_label) * 0.5
            d_loss_sr = adversarial_criterion(sr_output - torch.mean(gt_output), real_label) * 0.5
            adversarial_loss = esrgan_config.adversarial_weight * (d_loss_gt + d_loss_sr)
            # Calculate the generator total loss value
            g_loss = pixel_loss + content_loss + adversarial_loss
        # Call the gradient scaling function in the mixed precision API to
        # back-propagate the gradient information of the fake samples
        scaler.scale(g_loss).backward()
        # Encourage the generator to generate higher quality fake samples, making it easier to fool the discriminator
        scaler.step(g_optimizer)
        scaler.update()

        # Update EMA
        ema_g_model.update_parameters(g_model)
        # Finish training the generator model

        # Start training the discriminator model
        # During discriminator model training, enable discriminator model backpropagation
        for d_parameters in d_model.parameters():
            d_parameters.requires_grad = True

        # Initialize the discriminator model gradients
        d_model.zero_grad(set_to_none=True)

        # Calculate the classification score of the discriminator model for real samples
        with amp.autocast():
            gt_output = d_model(gt)
            sr_output = d_model(sr.detach().clone())
            d_loss_gt = adversarial_criterion(gt_output - torch.mean(sr_output), real_label) * 0.5
        # Call the gradient scaling function in the mixed precision API to
        # back-propagate the gradient information of the fake samples
        scaler.scale(d_loss_gt).backward(retain_graph=True)

        # Calculate the classification score of the discriminator model for fake samples
        with amp.autocast():
            sr_output = d_model(sr.detach().clone())
            d_loss_sr = adversarial_criterion(sr_output - torch.mean(gt_output), fake_label) * 0.5
        # Call the gradient scaling function in the mixed precision API to
        # back-propagate the gradient information of the fake samples
        scaler.scale(d_loss_sr).backward()

        # Calculate the total discriminator loss value
        d_loss = d_loss_gt + d_loss_sr

        # Improve the discriminator model's ability to classify real and fake samples
        scaler.step(d_optimizer)
        scaler.update()
        # Finish training the discriminator model

        # Calculate the score of the discriminator on real samples and fake samples,
        # the score of real samples is close to 1, and the score of fake samples is close to 0
        d_gt_probability = torch.sigmoid_(torch.mean(gt_output.detach()))
        d_sr_probability = torch.sigmoid_(torch.mean(sr_output.detach()))

        # Statistical accuracy and loss value for terminal data output
        pixel_losses.update(pixel_loss.item(), lr.size(0))
        content_losses.update(content_loss.item(), lr.size(0))
        adversarial_losses.update(adversarial_loss.item(), lr.size(0))
        d_gt_probabilities.update(d_gt_probability.item(), lr.size(0))
        d_sr_probabilities.update(d_sr_probability.item(), lr.size(0))

        # Calculate the time it takes to fully train a batch of data
        batch_time.update(time.time() - end)
        end = time.time()

        # Write the data during training to the training log file
        if batch_index % esrgan_config.train_print_frequency == 0:
            iters = batch_index + epoch * batches + 1
            writer.add_scalar("Train/D_Loss", d_loss.item(), iters)
            writer.add_scalar("Train/G_Loss", g_loss.item(), iters)
            writer.add_scalar("Train/Pixel_Loss", pixel_loss.item(), iters)
            writer.add_scalar("Train/Content_Loss", content_loss.item(), iters)
            writer.add_scalar("Train/Adversarial_Loss", adversarial_loss.item(), iters)
            writer.add_scalar("Train/D(GT)_Probability", d_gt_probability.item(), iters)
            writer.add_scalar("Train/D(SR)_Probability", d_sr_probability.item(), iters)
            progress.display(batch_index + 1)

        # Preload the next batch of data
        batch_data = train_prefetcher.next()

        # After training a batch of data, add 1 to the number of data batches to ensure that the
        # terminal print data normally
        batch_index += 1

        #if batch_index>limit:
        #  print('Batch limit reached')
        #  return pixel_losses.avg, content_losses.avg, adversarial_losses.avg, d_gt_probabilities.avg, d_sr_probabilities.avg

    return pixel_losses.avg, content_losses.avg, adversarial_losses.avg, d_gt_probabilities.avg, d_sr_probabilities.avg


def validate(
        g_model: nn.Module,
        data_prefetcher: CUDAPrefetcher,
        epoch: int,
        writer: SummaryWriter,
        psnr_model: nn.Module,
        ssim_model: nn.Module,
        niqe_model: nn.Module,
        lpips_model: nn.Module,
        mode: str
) -> [float, float, float]:
    # Calculate how many batches of data are in each Epoch
    batch_time = AverageMeter("Time", ":6.3f")
    psnres = AverageMeter("PSNR", ":4.2f")
    ssimes = AverageMeter("SSIM", ":4.4f")
    niqees = AverageMeter("NIQE", ":4.2f")
    lpipses = AverageMeter("LPIPS", ":4.4f")
    progress = ProgressMeter(len(data_prefetcher), [batch_time, psnres, ssimes, niqees, lpipses], prefix=f"{mode}: ")

    print_freq = 1
    if mode == "Valid":
      print_freq = esrgan_config.valid_print_frequency
    else:
      print_freq = esrgan_config.test_print_frequency

    # Put the adversarial network model in validation mode
    g_model.eval()

    # Initialize the number of data batches to print logs on the terminal
    batch_index = 0

    #limit = 20

    # Initialize the data loader and load the first batch of data
    data_prefetcher.reset()
    batch_data = data_prefetcher.next()

    # Get the initialization test time
    end = time.time()

    with torch.no_grad():
        while batch_data is not None:
            # Transfer the in-memory data to the CUDA device to speed up the test
            gt = batch_data["gt"].to(device=esrgan_config.device, non_blocking=True)
            lr = batch_data["lr"].to(device=esrgan_config.device, non_blocking=True)

            # Use the generator model to generate a fake sample
            with amp.autocast():
                sr = g_model(lr)

            # Statistical loss value for terminal data output
            psnr = psnr_model(sr, gt)
            ssim = ssim_model(sr, gt)
            niqe = niqe_model(sr)
            sr_tensor = 2*sr - 1 # Normalize from [0,1] to [-1,1]
            gt_tensor = 2*gt - 1
            lpips = lpips_model(sr, gt)

            psnres.update(psnr.item(), lr.size(0))
            ssimes.update(ssim.item(), lr.size(0))
            niqees.update(niqe.item(), lr.size(0))
            lpipses.update(lpips.item(), lr.size(0))

            # Calculate the time it takes to fully test a batch of data
            batch_time.update(time.time() - end)
            end = time.time()

            # Record training log information
            if batch_index % print_freq == 0:
                progress.display(batch_index + 1)

            # Preload the next batch of data
            batch_data = data_prefetcher.next()

            # After training a batch of data, add 1 to the number of data batches to ensure that the
            # terminal print data normally
            batch_index += 1

            #if batch_index > limit:
            #  print("Limit reached")
            #  break

    # print metrics
    progress.display_summary()

    if mode == "Valid" or mode == "Test":
        writer.add_scalar(f"{mode}/PSNR", psnres.avg, epoch + 1)
        writer.add_scalar(f"{mode}/SSIM", ssimes.avg, epoch + 1)
        writer.add_scalar(f"{mode}/NIQE", niqees.avg, epoch + 1)
        writer.add_scalar(f"{mode}/LPIPS", lpipses.avg, epoch + 1)
    else:
        raise ValueError("Unsupported mode, please use `Valid` or `Test`.")

    return psnres.avg, ssimes.avg, niqees.avg, lpipses.avg


if __name__ == "__main__":
    main()

Train
Total Epochs -> 10
Load all datasets successfully.
Build `rrdbnet_x4` model successfully.
Define all loss functions successfully.
Check whether to load pretrained d model weights...
Loaded `./results/Discriminator/Discrminator_x4-DFO2K-e74d7ca1.pth.tar` pretrained model weights successfully.
Check whether to load pretrained g model weights...
Loaded `./results/RRDBNet_x4/RRDBNet_x4-DFO2K-2e2a91f4.pth.tar` pretrained model weights successfully.
Define all optimizer functions successfully.
Define all optimizer scheduler functions successfully.
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]


The git executable must be specified in one of the following ways:
    - be included in your $PATH
    - be set via $GIT_PYTHON_GIT_EXECUTABLE
    - explicitly set via git.refresh()

All git commands will error until this is rectified.

$GIT_PYTHON_REFRESH environment variable. Use one of the following values:
    - error|e|raise|r|2: for a raised exception

Example:
    export GIT_PYTHON_REFRESH=quiet



Loading model from: /home/miguelneves/anaconda3/envs/superres/lib/python3.10/site-packages/lpips/weights/v0.1/alex.pth
Active run_id: 2f4d7eef4e784292892a39599f91f03c
Training
Epoch: [1][  1/211]	Time  2.267 ( 2.267)	Data  0.000 ( 0.000)	Pixel loss 0.000234 (0.000234)	Content loss 1.230294 (1.230294)	Adversarial loss 0.331025 (0.331025)	D(GT)  0.000 ( 0.000)	D(SR)  0.000 ( 0.000)
Epoch: [1][101/211]	Time  0.236 ( 0.263)	Data  0.000 ( 0.000)	Pixel loss 0.000418 (0.000293)	Content loss 1.087008 (1.012662)	Adversarial loss 0.098711 (0.063264)	D(GT)  0.000 ( 0.000)	D(SR)  0.000 ( 0.000)
Epoch: [1][201/211]	Time  0.238 ( 0.253)	Data  0.000 ( 0.000)	Pixel loss 0.000229 (0.000276)	Content loss 0.873789 (0.942434)	Adversarial loss 0.023458 (0.045047)	D(GT)  1.000 ( 0.379)	D(SR)  1.000 ( 0.366)
Valid: [  1/519]	Time  1.520 ( 1.520)	PSNR 30.56 (30.56)	SSIM 0.9144 (0.9144)	NIQE 6.42 (6.42)	LPIPS 0.0610 (0.0610)
Valid: [101/519]	Time  0.074 ( 0.087)	PSNR 23.97 (28.98)	SSIM 0.8270 (0.8734)	NIQE 5.7



Finished Saving
Training
Epoch: [2][  1/211]	Time  0.308 ( 0.308)	Data  0.000 ( 0.000)	Pixel loss 0.000246 (0.000246)	Content loss 0.919975 (0.919975)	Adversarial loss 0.010529 (0.010529)	D(GT)  1.000 ( 1.000)	D(SR)  0.999 ( 0.999)
Epoch: [2][101/211]	Time  0.243 ( 0.244)	Data  0.000 ( 0.000)	Pixel loss 0.000326 (0.000251)	Content loss 0.879672 (0.816349)	Adversarial loss 0.048472 (0.023195)	D(GT)  0.000 ( 0.408)	D(SR)  0.000 ( 0.296)
Epoch: [2][201/211]	Time  0.237 ( 0.245)	Data  0.000 ( 0.000)	Pixel loss 0.000204 (0.000243)	Content loss 0.747880 (0.813557)	Adversarial loss 0.025175 (0.028038)	D(GT)  0.000 ( 0.247)	D(SR)  0.000 ( 0.156)
Valid: [  1/519]	Time  0.088 ( 0.088)	PSNR 29.55 (29.55)	SSIM 0.8838 (0.8838)	NIQE 6.19 (6.19)	LPIPS 0.0797 (0.0797)
Valid: [101/519]	Time  0.073 ( 0.075)	PSNR 22.70 (28.06)	SSIM 0.8018 (0.8516)	NIQE 6.95 (9.70)	LPIPS 0.1093 (0.0917)
Valid: [201/519]	Time  0.073 ( 0.075)	PSNR 30.11 (28.49)	SSIM 0.8845 (0.8590)	NIQE 7.79 (9.93)	LPIPS 0.0537 (0.0888)
Val

Valid: [301/519]	Time  0.076 ( 0.075)	PSNR 30.65 (29.43)	SSIM 0.9206 (0.8788)	NIQE 7.65 (9.75)	LPIPS 0.0324 (0.0570)
Valid: [401/519]	Time  0.076 ( 0.075)	PSNR 24.29 (29.24)	SSIM 0.7663 (0.8754)	NIQE 6.31 (9.70)	LPIPS 0.0701 (0.0582)
Valid: [501/519]	Time  0.074 ( 0.075)	PSNR 23.41 (29.16)	SSIM 0.8221 (0.8739)	NIQE 5.74 (9.67)	LPIPS 0.0738 (0.0586)
 * Time 0.08 PSNR 29.12 SSIM 0.87 NIQE 9.67 LPIPS 0.06



Logging epoch data...
Finished Logging

Saving best model...
Finished Saving
Training
Epoch: [8][  1/211]	Time  0.279 ( 0.279)	Data  0.000 ( 0.000)	Pixel loss 0.000144 (0.000144)	Content loss 0.622071 (0.622071)	Adversarial loss 0.033756 (0.033756)	D(GT)  0.000 ( 0.000)	D(SR)  0.000 ( 0.000)
Epoch: [8][101/211]	Time  0.249 ( 0.247)	Data  0.000 ( 0.000)	Pixel loss 0.000173 (0.000227)	Content loss 0.661598 (0.763354)	Adversarial loss 0.045723 (0.024454)	D(GT)  0.000 ( 0.000)	D(SR)  0.000 ( 0.000)
Epoch: [8][201/211]	Time  0.251 ( 0.248)	Data  0.000 ( 0.000)	Pixel loss 0.000273 (0.000230

In [2]:
 !mlflow ui

[2023-03-21 16:14:55 +0000] [3880] [INFO] Starting gunicorn 20.1.0
[2023-03-21 16:14:55 +0000] [3880] [INFO] Listening at: http://127.0.0.1:5000 (3880)
[2023-03-21 16:14:55 +0000] [3880] [INFO] Using worker: sync
[2023-03-21 16:14:55 +0000] [3881] [INFO] Booting worker with pid: 3881
[2023-03-21 16:14:55 +0000] [3882] [INFO] Booting worker with pid: 3882
[2023-03-21 16:14:55 +0000] [3883] [INFO] Booting worker with pid: 3883
[2023-03-21 16:14:55 +0000] [3884] [INFO] Booting worker with pid: 3884
Traceback (most recent call last):
  File "/home/miguelneves/anaconda3/envs/superres/lib/python3.10/site-packages/mlflow/store/tracking/file_store.py", line 857, in _list_run_infos
    run_info = self._get_run_info_from_dir(r_dir)
  File "/home/miguelneves/anaconda3/envs/superres/lib/python3.10/site-packages/mlflow/store/tracking/file_store.py", line 668, in _get_run_info_from_dir
    meta = FileStore._read_yaml(run_dir, FileStore.META_DATA_FILE_NAME)
  File "/home/miguelneves/anaconda3/envs/su

Traceback (most recent call last):
  File "/home/miguelneves/anaconda3/envs/superres/lib/python3.10/site-packages/mlflow/store/tracking/file_store.py", line 857, in _list_run_infos
    run_info = self._get_run_info_from_dir(r_dir)
  File "/home/miguelneves/anaconda3/envs/superres/lib/python3.10/site-packages/mlflow/store/tracking/file_store.py", line 668, in _get_run_info_from_dir
    meta = FileStore._read_yaml(run_dir, FileStore.META_DATA_FILE_NAME)
  File "/home/miguelneves/anaconda3/envs/superres/lib/python3.10/site-packages/mlflow/store/tracking/file_store.py", line 1083, in _read_yaml
    return _read_helper(root, file_name, attempts_remaining=retries)
  File "/home/miguelneves/anaconda3/envs/superres/lib/python3.10/site-packages/mlflow/store/tracking/file_store.py", line 1076, in _read_helper
    result = read_yaml(root, file_name)
  File "/home/miguelneves/anaconda3/envs/superres/lib/python3.10/site-packages/mlflow/utils/file_utils.py", line 214, in read_yaml
    raise MissingC

Traceback (most recent call last):
  File "/home/miguelneves/anaconda3/envs/superres/lib/python3.10/site-packages/mlflow/store/tracking/file_store.py", line 857, in _list_run_infos
    run_info = self._get_run_info_from_dir(r_dir)
  File "/home/miguelneves/anaconda3/envs/superres/lib/python3.10/site-packages/mlflow/store/tracking/file_store.py", line 668, in _get_run_info_from_dir
    meta = FileStore._read_yaml(run_dir, FileStore.META_DATA_FILE_NAME)
  File "/home/miguelneves/anaconda3/envs/superres/lib/python3.10/site-packages/mlflow/store/tracking/file_store.py", line 1083, in _read_yaml
    return _read_helper(root, file_name, attempts_remaining=retries)
  File "/home/miguelneves/anaconda3/envs/superres/lib/python3.10/site-packages/mlflow/store/tracking/file_store.py", line 1076, in _read_helper
    result = read_yaml(root, file_name)
  File "/home/miguelneves/anaconda3/envs/superres/lib/python3.10/site-packages/mlflow/utils/file_utils.py", line 214, in read_yaml
    raise MissingC

Traceback (most recent call last):
  File "/home/miguelneves/anaconda3/envs/superres/lib/python3.10/site-packages/mlflow/store/tracking/file_store.py", line 857, in _list_run_infos
    run_info = self._get_run_info_from_dir(r_dir)
  File "/home/miguelneves/anaconda3/envs/superres/lib/python3.10/site-packages/mlflow/store/tracking/file_store.py", line 668, in _get_run_info_from_dir
    meta = FileStore._read_yaml(run_dir, FileStore.META_DATA_FILE_NAME)
  File "/home/miguelneves/anaconda3/envs/superres/lib/python3.10/site-packages/mlflow/store/tracking/file_store.py", line 1083, in _read_yaml
    return _read_helper(root, file_name, attempts_remaining=retries)
  File "/home/miguelneves/anaconda3/envs/superres/lib/python3.10/site-packages/mlflow/store/tracking/file_store.py", line 1076, in _read_helper
    result = read_yaml(root, file_name)
  File "/home/miguelneves/anaconda3/envs/superres/lib/python3.10/site-packages/mlflow/utils/file_utils.py", line 214, in read_yaml
    raise MissingC

Traceback (most recent call last):
  File "/home/miguelneves/anaconda3/envs/superres/lib/python3.10/site-packages/mlflow/store/tracking/file_store.py", line 857, in _list_run_infos
    run_info = self._get_run_info_from_dir(r_dir)
  File "/home/miguelneves/anaconda3/envs/superres/lib/python3.10/site-packages/mlflow/store/tracking/file_store.py", line 668, in _get_run_info_from_dir
    meta = FileStore._read_yaml(run_dir, FileStore.META_DATA_FILE_NAME)
  File "/home/miguelneves/anaconda3/envs/superres/lib/python3.10/site-packages/mlflow/store/tracking/file_store.py", line 1083, in _read_yaml
    return _read_helper(root, file_name, attempts_remaining=retries)
  File "/home/miguelneves/anaconda3/envs/superres/lib/python3.10/site-packages/mlflow/store/tracking/file_store.py", line 1076, in _read_helper
    result = read_yaml(root, file_name)
  File "/home/miguelneves/anaconda3/envs/superres/lib/python3.10/site-packages/mlflow/utils/file_utils.py", line 214, in read_yaml
    raise MissingC

## Test

In [1]:
# Copyright 2022 Dakewe Biotech Corporation. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
#   you may not use this file except in compliance with the License.
#   You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import os

import cv2
import torch
from natsort import natsorted

import esrgan_config
import mlflow
import imgproc
import model
from image_quality_assessment import PSNR, SSIM, NIQE
from lpips import LPIPS
from utils import make_directory
from dataset import CUDAPrefetcher, TrainValidImageDataset
from torch.utils.data import DataLoader

model_names = sorted(
    name for name in model.__dict__ if
    name.islower() and not name.startswith("__") and callable(model.__dict__[name]))

save_images = True


def main() -> None:

    # Set MLflow experiment & run
    mlflow.set_experiment(esrgan_config.experience_name)
    try:
      mlflow.start_run(run_id=esrgan_config.run_id, tags=esrgan_config.tags, description=esrgan_config.description)
    except: # If last session was not ended
      mlflow.end_run()
      mlflow.start_run(run_id=esrgan_config.run_id, tags=esrgan_config.tags, description=esrgan_config.description)


    # Load Test Dataset
    test_datasets = TrainValidImageDataset(esrgan_config.gt_dir,
                                        0,
                                        esrgan_config.upscale_factor,
                                        "Valid")
    
    test_dataloader = DataLoader(test_datasets,
                                 batch_size=1,
                                 shuffle=False,
                                 num_workers=1,
                                 pin_memory=True,
                                 drop_last=False,
                                 persistent_workers=True)
    
    test_prefetcher = CUDAPrefetcher(test_dataloader, esrgan_config.device)
    # Initialize the data loader and load the first batch of data
    test_prefetcher.reset()



    # Initialize the super-resolution bsrgan_model
    esrgan_model = model.__dict__[esrgan_config.g_arch_name](in_channels=esrgan_config.in_channels,
                                                             out_channels=esrgan_config.out_channels,
                                                             channels=esrgan_config.channels,
                                                             growth_channels=esrgan_config.growth_channels,
                                                             num_blocks=esrgan_config.num_blocks)
    '''
    esrgan_model = esrgan_model.to(device=esrgan_config.device)
    print(f"Build `{esrgan_config.g_arch_name}` model successfully.")

    # Load the super-resolution bsrgan_model weights
    checkpoint = torch.load(esrgan_config.g_model_weights_path, map_location=lambda storage, loc: storage)
    esrgan_model.load_state_dict(checkpoint["state_dict"])
    print(f"Load `{esrgan_config.g_arch_name}` model weights "
          f"`{os.path.abspath(esrgan_config.g_model_weights_path)}` successfully.")
    '''

    
    # Load Generator Model
    g_model = mlflow.pytorch.load_model(esrgan_config.g_model_weights_path)
    esrgan_model = g_model.to(device=esrgan_config.device)

    # Create a folder of super-resolution experiment results
    #make_directory(esrgan_config.sr_dir)

    # Start the verification mode of the bsrgan_model.
    esrgan_model.eval()

    # Initialize the sharpness evaluation function
    psnr = PSNR(esrgan_config.upscale_factor, esrgan_config.only_test_y_channel)
    ssim = SSIM(esrgan_config.upscale_factor, esrgan_config.only_test_y_channel)
    niqe = NIQE(esrgan_config.upscale_factor, esrgan_config.niqe_model_path)
    lpips = LPIPS(net='alex')

    # Set the sharpness evaluation function calculation device to the specified model
    psnr = psnr.to(device=esrgan_config.device, non_blocking=True)
    ssim = ssim.to(device=esrgan_config.device, non_blocking=True)
    niqe = niqe.to(device=esrgan_config.device, non_blocking=True)
    lpips = lpips.to(device=esrgan_config.device, non_blocking=True)

    # Initialize IQA metrics
    psnr_metrics = 0.0
    ssim_metrics = 0.0
    niqe_metrics = 0.0
    lpips_metrics = 0.0

    # Get a list of test image file names.
    file_names = os.listdir(esrgan_config.gt_dir)
    # Get the number of test image files.
    total_files = int(len(file_names))

    pathLR = "testImagesLR/"
    pathTest = "testImages/"

    for index in range(total_files):
        '''
        lr_image_path = os.path.join(esrgan_config.lr_dir, file_names[index])
        sr_image_path = os.path.join(esrgan_config.sr_dir, file_names[index])
        gt_image_path = os.path.join(esrgan_config.gt_dir, file_names[index])

        print(f"Processing `{os.path.abspath(lr_image_path)}`...")
        lr_tensor = imgproc.preprocess_one_image(lr_image_path, esrgan_config.device)
        gt_tensor = imgproc.preprocess_one_image(gt_image_path, esrgan_config.device)
        '''

        batch_data = test_prefetcher.next()
        gt_tensor = batch_data["gt"].to(device=esrgan_config.device, non_blocking=True)
        lr_tensor = batch_data["lr"].to(device=esrgan_config.device, non_blocking=True)

        lr_image = imgproc.tensor_to_image(lr_tensor, False, False)
        lr_image = cv2.cvtColor(lr_image, cv2.COLOR_RGB2BGR)
        mlflow.log_image(lr_image, pathLR+file_names[index])

        # Only reconstruct the Y channel image data.
        with torch.no_grad():
            sr_tensor = esrgan_model(lr_tensor)

        # Save image
        if save_images:
          sr_image = imgproc.tensor_to_image(sr_tensor, False, False)
          sr_image = cv2.cvtColor(sr_image, cv2.COLOR_RGB2BGR)
          mlflow.log_image(sr_image, pathTest+file_names[index])
          #cv2.imwrite(sr_image_path, sr_image)

        # Cal IQA metrics
        psnr_metrics += psnr(sr_tensor, gt_tensor).item()
        ssim_metrics += ssim(sr_tensor, gt_tensor).item()
        niqe_metrics += niqe(sr_tensor).item()
        
        sr_tensor = 2*sr_tensor - 1 # Normalize from [0,1] to [-1,1]
        gt_tensor = 2*gt_tensor - 1
        lpips_metrics += lpips(sr_tensor, gt_tensor).item()

    # Calculate the average value of the sharpness evaluation index,
    # and all index range values are cut according to the following values
    # PSNR range value is 0~100
    # SSIM range value is 0~1
    # NIQE range value is 0~100 although it can go to infinite. Typically a score higher than 10 is bad and lower than 2 is excelent
    avg_psnr = 100 if psnr_metrics / total_files > 100 else psnr_metrics / total_files
    avg_ssim = 1 if ssim_metrics / total_files > 1 else ssim_metrics / total_files
    avg_niqe = 100 if niqe_metrics / total_files > 100 else niqe_metrics / total_files
    avg_lpips = 100 if lpips_metrics / total_files > 100 else lpips_metrics / total_files
 
    print(f"PSNR: {avg_psnr:4.2f} [dB]\n"
          f"SSIM: {avg_ssim:4.4f} [u]\n"
          f"NIQE: {avg_niqe:4.2f} [100u]\n"
          f"LPIPS: {avg_lpips:4.2f} [100u]")
    
    mlflow.end_run()


if __name__ == "__main__":
    main()


Test
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]




Loading model from: /home/miguelneves/anaconda3/envs/superres/lib/python3.10/site-packages/lpips/weights/v0.1/alex.pth
PSNR: 29.40 [dB]
SSIM: 0.8759 [u]
NIQE: 9.19 [100u]
LPIPS: 0.07 [100u]
