In [None]:
import os

# Move to the parent of the notebook's start directory, presumably so the project
# packages imported below (`opt`, `data`, `model`, `util`, ...) resolve from there.
# The start directory is remembered on first run, so re-running this cell does not
# keep climbing the directory tree.
NOTEBOOK_DIR = globals().get('NOTEBOOK_DIR', os.getcwd())
os.chdir(os.path.join(NOTEBOOK_DIR, '..'))

In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
import cv2
import torchvision.transforms as transforms
import PIL
import matplotlib.pyplot as plt
import modules.dist_model
from utils import util
import glob
import torch.nn.functional as F
from modules import * 
import modules
import copy
import random
import numpy as np

from util.utils import calculate_psnr, _ssim, Logger, RandCrop, RandHorizontalFlip, RandRotate, ToTensor, VGG19PerceptualLoss
from util.utils_common import calculate_psnr as PSNR
from util.utils_common import calculate_ssim as SSIM
from util.utils_common import calc_metrics as CALC
from memory.storage import Storage

from tqdm import tqdm
from torch.cuda.amp import autocast
from torch.cuda.amp import GradScaler
import torch.nn.functional as F


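## x4 configuration — run either this cell or the x2 cell below, not both (whichever runs last overwrites `args` and the model modules)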

In [None]:
# x4 evaluation setup: options, x4 datasets and x4 network definitions.
# NOTE: the "When X2" cell rebinds `args`, the dataset classes and the model
# modules, so under "Run All" the x2 settings win. Run only one of the two cells.
from opt.option_hrlr_cycle_scheduler import args
from data.LQGT_dataset_hrlr_BICUBIC_scale4 import LQGTDataset, Testdataset
from model import (
    encoder_x4 as encoder,
    decoder_scale4 as decoder,
    discriminator_scale4 as discriminator,
)

args.ratio_x, args.scale = 'x4', 4

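## x2 configuration — run either this cell or the x4 cell above, not both (whichever runs last overwrites `args` and the model modules)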

In [None]:
# x2 evaluation setup: options, x2 datasets and x2 network definitions.
# NOTE: this cell rebinds `args`, the dataset classes and the model modules set by
# the "When X4" cell, so under "Run All" these x2 settings win. Run only one of the two.
from opt.option_hrlr_cycle import args
from data.LQGT_dataset_hrlr_BICUBIC_scale2 import LQGTDataset, Testdataset
from model import (
    encoder_x2 as encoder,
    decoder_scale2 as decoder,
    discriminator_scale2 as discriminator,
)

args.ratio_x, args.scale = 'x2', 2


## Hyper-parameters and random seed

In [None]:
# Evaluation settings (override the defaults loaded from the `opt` module).
args.gpu_id = '0'
args.lr_G = 5e-5
args.lr_D = 5e-5
device = 'cuda'
args.cycle_mode = True
args.shuffle_mode = True
args.n_gen = 2
args.lambda_align = 0.01
args.eval_mode = True
args.sub_file_name = 'TEMP'

data_name_test = ['Deepfake', 'Face2Face', 'FaceSwap', 'NeuralTextures', 'DFDC_trans']  ### please input the name you will evaluate

# Device setting comes first: CUDA_VISIBLE_DEVICES only takes effect if it is set
# before CUDA is initialised in this process.
if args.gpu_id is None:
    # Raise instead of calling exit(), which would shut down the notebook kernel.
    raise RuntimeError('use --gpu_id to specify GPU ID to use')
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id

# Seeds for reproducibility.
seed = 2020
# Only affects Python processes started after this point; the running kernel's
# hash seed was fixed at interpreter startup.
os.environ['PYTHONHASHSEED'] = str(seed)
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False  # True can be faster, but is not deterministic


In [None]:
def _wrap_for_eval(net):
    # Wrap a network in DataParallel, switch it to eval mode and move it to `device`.
    return nn.DataParallel(net).eval().to(device)


# Encoder plus the two decoders: identity decoder (Gt) and SR decoder (Gsr).
model_Enc = _wrap_for_eval(encoder.Encoder_RRDB(num_feat=args.n_hidden_feats).cuda())
model_Dec_Id = _wrap_for_eval(decoder.Decoder_Id_RRDB(num_in_ch=args.n_hidden_feats).cuda())
model_Dec_SR = _wrap_for_eval(
    decoder.Decoder_SR_RRDB(num_in_ch=args.n_hidden_feats, num_block=args.n_sr_block).cuda()
)

# Discriminators on latent features, LR images and HR images (HR patch size is
# `args.patch_size * args.scale`).
model_Disc_feat = _wrap_for_eval(
    discriminator.DiscriminatorVGG(in_ch=args.n_hidden_feats, image_size=args.patch_size).cuda()
)
model_Disc_img_LR = _wrap_for_eval(
    discriminator.DiscriminatorVGG(in_ch=3, image_size=args.patch_size).cuda()
)
model_Disc_img_HR = _wrap_for_eval(
    discriminator.DiscriminatorVGG(in_ch=3, image_size=args.patch_size * args.scale).cuda()
)

In [None]:
# Loss modules. NOTE(review): none of these are referenced by the evaluation code
# in this notebook.
loss_L1 = nn.L1Loss().cuda()
loss_MSE = nn.MSELoss().cuda()
loss_adversarial = nn.BCEWithLogitsLoss().cuda()
loss_percept = VGG19PerceptualLoss().cuda()

# Optimizers: AMSGrad Adam for the generator side (encoder + both decoders) and for
# the three discriminators. Both share betas and weight decay.
_adam_shared = {
    'betas': (args.beta1, args.beta2),
    'weight_decay': args.weight_decay,
    'amsgrad': True,
}

params_G = [*model_Enc.parameters(), *model_Dec_Id.parameters(), *model_Dec_SR.parameters()]
optimizer_G = optim.Adam(params_G, lr=args.lr_G, **_adam_shared)

params_D = [*model_Disc_feat.parameters(), *model_Disc_img_LR.parameters(), *model_Disc_img_HR.parameters()]
optimizer_D = optim.Adam(params_D, lr=args.lr_D, **_adam_shared)

# Schedulers: halve the learning rate at each of the three milestones.
iter_indices = [args.interval1, args.interval2, args.interval3]
scheduler_G = optim.lr_scheduler.MultiStepLR(
    optimizer=optimizer_G,
    milestones=iter_indices,
    gamma=0.5
)
scheduler_D = optim.lr_scheduler.MultiStepLR(
    optimizer=optimizer_D,
    milestones=iter_indices,
    gamma=0.5
)

scaler = GradScaler()
args.checkpoint = './task1_storage.pth' ### please input the path of weight you will evaluate

args.snap_path = os.path.dirname(args.checkpoint).replace('weights', 'results')
# Load model weights & optimizer & scheduler states.
if args.checkpoint is not None:
    # map_location remaps tensors saved on any GPU index onto `device`. Without it,
    # a checkpoint saved from e.g. cuda:1 fails to load when only GPU 0 is visible
    # (CUDA_VISIBLE_DEVICES is restricted above).
    # NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
    checkpoint = torch.load(args.checkpoint, map_location=device)

    model_Enc.load_state_dict(checkpoint['model_Enc'])
    model_Dec_Id.load_state_dict(checkpoint['model_Dec_Id'])
    model_Dec_SR.load_state_dict(checkpoint['model_Dec_SR'])
    model_Disc_feat.load_state_dict(checkpoint['model_Disc_feat'])
    model_Disc_img_LR.load_state_dict(checkpoint['model_Disc_img_LR'])
    model_Disc_img_HR.load_state_dict(checkpoint['model_Disc_img_HR'])

    optimizer_D.load_state_dict(checkpoint['optimizer_D'])
    optimizer_G.load_state_dict(checkpoint['optimizer_G'])

    scheduler_D.load_state_dict(checkpoint['scheduler_D'])
    scheduler_G.load_state_dict(checkpoint['scheduler_G'])

    start_epoch = checkpoint['epoch']
else:
    start_epoch = 0

In [None]:
def make_eval_log(path_folder):
    """Open a new evaluation log in ``path_folder`` without overwriting old ones.

    Creates ``path_folder`` if needed, then picks the first file name
    ``eval_result_v<N>.txt`` (N = 0, 1, 2, ...) that does not exist yet. It opens
    that file with ``Logger`` and writes a blank line followed by the checkpoint
    path taken from ``args.checkpoint``.

    Returns:
        The opened ``Logger`` instance.
    """
    os.makedirs(path_folder, exist_ok=True)
    version = 0
    while True:
        name_file = f'eval_result_v{version}'
        log_path = os.path.join(path_folder, f'{name_file}.txt')
        if not os.path.exists(log_path):
            break
        version += 1
    print(f'name_file : {name_file}')

    log = Logger()
    log.open(log_path)
    log.write('\n')
    log.write(f'checkpoint path: {args.checkpoint} \n')
    return log

class PerceptualLoss(torch.nn.Module):
    """Perceptual (LPIPS-style) distance built on ``modules.dist_model.DistModel``.

    NOTE(review): this class appears unused in this notebook. The LPIPS metric
    objects are created from ``modules.PerceptualLoss``, not from this class.

    Args:
        model: DistModel model type. ``'net-lin'`` uses the perceptually learned
            linear weights (LPIPS metric); ``'net'`` uses the backbone features
            directly (the "default" VGG perceptual loss).
        net: backbone name (e.g. ``'vgg'``, ``'alex'``, ``'squeeze'``).
        colorspace: passed through to ``DistModel.initialize``.
        spatial: passed through to ``DistModel.initialize``; stored on ``self``.
        use_gpu: passed through to ``DistModel.initialize``.
        gpu_ids: GPU ids for ``DistModel``; defaults to ``[0]``. ``None`` is used as
            the default so instances do not share one mutable list.
    """

    def __init__(self, model='net-lin', net='vgg', colorspace='rgb', spatial=False, use_gpu=True, gpu_ids=None):
        super(PerceptualLoss, self).__init__()
        if gpu_ids is None:
            gpu_ids = [0]
        print('Setting up Perceptual loss...')
        print(f" LPIPS net is {net}")
        self.use_gpu = use_gpu
        self.spatial = spatial
        self.gpu_ids = gpu_ids
        # Use the explicitly imported submodule rather than a bare `dist_model`
        # name that is only available if `from modules import *` happens to export it.
        self.model = modules.dist_model.DistModel()
        self.model.initialize(model=model, net=net, use_gpu=use_gpu, colorspace=colorspace, spatial=self.spatial, gpu_ids=gpu_ids)
        print('...[%s] initialized' % self.model.name())
        print('...Done')

    def forward(self, pred, target, normalize=False):
        """Return the model's distance between ``target`` and ``pred``.

        If ``normalize`` is True, both inputs are assumed to be in [0, 1] and are
        mapped to [-1, 1] via ``2 * x - 1``. Otherwise they are assumed to be in
        [-1, 1] already. Inputs are expected as Nx3xHxW.

        Note the argument order passed to the model: ``(target, pred)``.
        """
        if normalize:
            target = 2 * target - 1
            pred = 2 * pred - 1

        return self.model.forward(target, pred)
    
# LPIPS metric objects used during evaluation: `lpips` with a VGG backbone and
# `lpips_alex` with AlexNet (other options: squeeze). Both come from
# modules.PerceptualLoss, not from the local PerceptualLoss class.
lpips, lpips_alex = (
    modules.PerceptualLoss(model='net-lin', net=backbone, use_gpu=True).eval()
    for backbone in ('vgg', 'alex')
)


In [None]:
# Root of the evaluation data; each split is read from
# <TEST_DATA_ROOT>/<ratio_x>/<split>/test. Change as needed.
TEST_DATA_ROOT = '/media/data1/DS'

# Results go next to the checkpoint: its path with the extension dropped and
# 'weights' replaced by 'results'. splitext handles any extension length,
# unlike slicing off a fixed 4 characters.
path_imgs = os.path.splitext(args.checkpoint)[0].replace('weights', 'results')
log = make_eval_log(path_imgs)

# (source test directory, per-split output directory) pairs.
list_path_test = []
for _data_name in data_name_test:
    full_path = os.path.join(path_imgs, _data_name)
    _path_test = os.path.join(TEST_DATA_ROOT, args.ratio_x, _data_name, 'test')
    list_path_test.append((_path_test, full_path))
    os.makedirs(full_path, exist_ok=True)
    

In [None]:
# One non-shuffled, batch-size-1 loader per test split, in the same order as list_path_test.
list_test_loader = []
for _path, _save_path in list_path_test:
    test_dataset = Testdataset(
        os.path.join(_path, 'HR'),
        os.path.join(_path, 'LR'),
        transform=transforms.Compose([ToTensor()]),
        patch_size=args.patch_size,
        # Use the configured SR factor instead of a fixed 2, so the x4
        # configuration passes 4.
        # NOTE(review): confirm how Testdataset uses `scale` when crop_mode=False.
        scale=args.scale,
        crop_mode=False,
    )
    # A fresh dataset and loader are built for every split, so nothing is shared
    # between iterations and no deepcopy is needed.
    test_loader = DataLoader(
        test_dataset,
        batch_size=1,
        num_workers=args.num_workers,
        drop_last=False,
        shuffle=False
    )
    list_test_loader.append(test_loader)
    print(f'path is : {_path}')
    print(f'len of dataset is : {len(test_dataset)}')

# Evaluating 

In [None]:
def _mean(values):
    # Arithmetic mean of a list, or NaN when it is empty (avoids ZeroDivisionError on an empty loader).
    return sum(values) / len(values) if values else float('nan')


def _to_numpy_255(img):
    # Batch-size-1 NCHW tensor -> rounded HWC float array scaled by 255 (input assumed in [0, 1]).
    return np.array(img.squeeze().permute(1, 2, 0).detach().cpu() * 255.0).round()


def Test_v2(loader,
            save_path,
            log,
            models,
            curr_epoch=0,
            show_mode=False,
            n_interval=100,
            save_mode=False):
    """Evaluate an encoder / SR-decoder pair and log PSNR, SSIM and LPIPS.

    Each sample's low-quality input ``data['img_LQ']`` is passed through
    ``models[0]`` (encoder) and then ``models[1]`` (SR decoder). The output,
    clamped to [0, 1], is compared with ``data['img_GT']``. A baseline
    ("origin") is scored as well: the low-quality input resized to the
    ground-truth size with ``F.interpolate``'s default mode.

    Args:
        loader: iterable of ``(data, filename)`` pairs; ``data`` holds the
            ``'img_LQ'`` and ``'img_GT'`` tensors. Batch size 1 is assumed
            (tensors are squeezed to a single image before scoring).
        save_path: directory for saved images; created if missing.
        log: object with a ``write(str)`` method; receives two summary lines.
        models: ``[encoder, sr_decoder]``; both are switched to eval mode.
        curr_epoch: epoch number embedded in saved file names.
        show_mode: if True, display LQ / output / GT every ``n_interval`` samples.
        n_interval: display period used by ``show_mode``.
        save_mode: if True, save LQ / output / GT PNGs for every sample.

    Returns:
        dict of averaged metrics with keys ``psnr``, ``ssim``, ``lpips``,
        ``lpips_alex`` and the baseline ``psnr_origin``, ``ssim_origin``,
        ``lpips_origin``, ``lpips_alex_origin``. Values are NaN for an empty loader.
    """
    encoder_net, sr_decoder = models[0], models[1]
    # Put the models that are actually evaluated into eval mode (not module globals).
    encoder_net.eval()
    sr_decoder.eval()

    psnr_sr, ssim_sr, lpips_sr, lpips_alex_sr = [], [], [], []
    psnr_origin, ssim_origin, lpips_origin, lpips_alex_origin = [], [], [], []

    n_scale = 4 if args.ratio_x == 'x4' else 2
    min_max = (0, 1)
    to_pil = transforms.ToPILImage()
    os.makedirs(save_path, exist_ok=True)

    with torch.no_grad():
        for idx, (data, _filename) in enumerate(loader):
            X_t = data['img_LQ'].cuda(non_blocking=True)
            Y = data['img_GT'].cuda(non_blocking=True)
            gt_size = (Y.shape[2], Y.shape[3])

            # Baseline computed before the forward pass, so it is independent of
            # anything the models might do to their input.
            lq_resized = F.interpolate(X_t, size=gt_size)

            out = sr_decoder(encoder_net(X_t))
            if Y.shape != out.shape:
                out = F.interpolate(out, size=gt_size)
            out = out.detach().float().cpu().clamp_(*min_max)

            if show_mode or save_mode:
                img_lq = to_pil(F.interpolate(X_t.detach().cpu(), scale_factor=n_scale).squeeze())
                img_output = to_pil(out.squeeze())
                img_y = to_pil(Y.squeeze().cpu())
                if show_mode and idx % n_interval == 0:
                    img_lq.show()
                    img_output.show()
                    img_y.show()
                if save_mode:
                    img_lq.save(f'{save_path}/LQ_e{curr_epoch}_{idx}.png')
                    img_output.save(f'{save_path}/OUTPUT_e{curr_epoch}_{idx}.png')
                    img_y.save(f'{save_path}/HQ_e{curr_epoch}_{idx}.png')

            # PSNR / SSIM on rounded x255 HWC arrays, without border cropping.
            out_np = _to_numpy_255(out)
            lq_np = _to_numpy_255(lq_resized)
            gt_np = _to_numpy_255(Y)
            psnr_sr.append(PSNR(out_np, gt_np, border=0))
            psnr_origin.append(PSNR(lq_np, gt_np, border=0))
            ssim_sr.append(SSIM(out_np, gt_np, border=0))
            ssim_origin.append(SSIM(lq_np, gt_np, border=0))

            # NOTE(review): images go through modules.normalize_tensor before LPIPS.
            # In the reference LPIPS code that helper presumably L2-normalises along
            # the channel dimension (it is meant for feature maps), whereas LPIPS
            # normally expects images scaled to [-1, 1]. Confirm this is intended
            # before comparing with published LPIPS numbers.
            gt_lp = modules.normalize_tensor(Y)
            lq_lp = modules.normalize_tensor(lq_resized)
            lpips_origin.append(lpips.forward(gt_lp, lq_lp).item())
            lpips_alex_origin.append(lpips_alex.forward(gt_lp, lq_lp).item())

            out_lp = modules.normalize_tensor(out)
            lpips_sr.append(lpips.forward(gt_lp, out_lp).item())
            lpips_alex_sr.append(lpips_alex.forward(gt_lp, out_lp).item())

    metrics = {
        'psnr_origin': _mean(psnr_origin),
        'ssim_origin': _mean(ssim_origin),
        'lpips_origin': _mean(lpips_origin),
        'lpips_alex_origin': _mean(lpips_alex_origin),
        'psnr': _mean(psnr_sr),
        'ssim': _mean(ssim_sr),
        'lpips': _mean(lpips_sr),
        'lpips_alex': _mean(lpips_alex_sr),
    }
    log.write(f"===> psnr_origin: {metrics['psnr_origin']}, ssim_origin: {metrics['ssim_origin']}, lpips_origin: {metrics['lpips_origin']}, lpips_alex: {metrics['lpips_alex_origin']} \n")
    log.write(f"===> psnr: {metrics['psnr']}, ssim_v2: {metrics['ssim']}, lpips_v2: {metrics['lpips']}, lpips_alex: {metrics['lpips_alex']} \n")
    return metrics


best_psnr = 0
if not args.eval_mode:
    args.epochs = start_epoch + 1
print('path img : ', path_imgs)

# Evaluate every test split. Each block of results in the log is headed by the
# split's source directory. (No global `iter` is assigned here, so the builtin
# `iter()` stays usable in later cells.)
for _loader, (_path_test, _path_save) in zip(list_test_loader, list_path_test):
    log.write(f'{_path_test} \n')
    Test_v2(_loader, _path_save, log, [model_Enc, model_Dec_SR],
            curr_epoch=start_epoch, show_mode=False, n_interval=20, save_mode=True)
