In [None]:
import os
# Move the working directory one level up; presumably this makes the project
# packages imported below (opt/, data/, util/, model/, memory/) and the relative
# ./results path resolvable -- verify against the repository layout.
# NOTE(review): the path is relative, so re-running this cell moves up one more
# level each time. Run it exactly once per kernel session.
os.chdir('../')

In [None]:
import os
import torch
import torch.nn as nn

import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
import copy
import random
import numpy as np
from opt.option_hrlr_cycle_scheduler import args
from data.LQGT_dataset_hrlr_BICUBIC_scale2 import LQGTDataset, Testdataset
from util.utils import calculate_psnr, _ssim, Logger, RandCrop, RandHorizontalFlip, RandRotate, ToTensor, VGG19PerceptualLoss
from util.utils_common import calculate_psnr as PSNR
from util.utils_common import calculate_ssim as SSIM
from util.utils_common import calc_metrics as CALC

from model import encoder_x2 as encoder
from model import decoder_scale2 as decoder
from model import discriminator_scale2 as discriminator
from memory.storage_s_feature15_rede2 import Storage
from tqdm import tqdm
from torch.cuda.amp import autocast
from torch.cuda.amp import GradScaler
import torch.nn.functional as F

## Setting hyper-parameters

In [None]:
# --- Run identity and file locations (placeholders: replace before running) ---
name_file = 'Input the name with path'  # prefix of the ./results/<...> directory and of the log file name
snap_path = 'input the path for saving weight file'  # directory that receives the .pth checkpoints
# NOTE(review): this is a non-None string, so the checkpoint-loading branch in the
# "Declaration or Load Pre-models" cell always runs; set it to None to train from scratch.
checkpoint = 'weight file For Transfer Learning'
args.gpu_id = '1,2,3'  # exported as CUDA_VISIBLE_DEVICES in the device-setting cell
device='cuda'

# --- Optimisation settings (override the defaults that came with `args`) ---
args.lr_G = 5e-5  # learning rate of the generator optimizer
args.lr_D = 5e-5  # learning rate of the discriminator optimizer
args.cycle_mode=True  # enables the cycle losses in the training loop
args.shuffle_mode=True  # forwarded to LQGTDataset(shuffle_mode=...)
args.n_gen=2  # number of generator forward passes per batch in the training loop
args.scale = 2  # SR factor: used for the HR discriminator input size and the bicubic downsample
args.lambda_align = 0.01  # weight of the feature-alignment term in loss_G_total
# NOTE(review): the training cell reassigns args.epochs = 840, so this value is not the one used there.
args.epochs = 200
# Must stay after args.scale / args.lambda_align: the f-string reads both values.
args.sub_file_name = f'hrlr_lrisorigin_encodreX{args.scale}_alignlambdais{args.lambda_align}_storage_softmax_bicubic'

### Logging, seed, and device setup

In [None]:
log = Logger()
data_name = 'DIV2K'
print(f'data name is {data_name}')

# Logging: the text log of this run goes to ./results/<name_file>_<sub_file_name>/
out_dir = f'./results/{name_file}_{args.sub_file_name}'
os.makedirs(out_dir, exist_ok=True)
log.open(os.path.join(out_dir,'[log_train]'+f'{name_file}_{args.sub_file_name}.txt'))
log.write('\n')

# Seed every random number generator used by the pipeline.
seed = 2020
os.environ['PYTHONHASHSEED'] = str(seed)
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
# Several GPUs are used (DataParallel), so seed all of them explicitly.
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
# benchmark=False keeps cuDNN from auto-tuning kernels; True would be faster
# but makes runs non-deterministic.
torch.backends.cudnn.benchmark = False

# Device setting: restrict the process to the requested GPUs.
if args.gpu_id is not None:
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id
    log.write('using GPU %s' % args.gpu_id)
else:
    # Raise instead of calling exit(): in a notebook exit() shuts the kernel down,
    # whereas an exception stops "Run All" and leaves the session inspectable.
    raise ValueError('use --gpu_id to specify GPU ID to use')


# Make the directory for saving weights. makedirs also creates missing parent
# directories and does not fail when the directory already exists.
os.makedirs(snap_path, exist_ok=True)

### Load dataset 

In [None]:
# Load the training dataset (random-order batches, incomplete last batch dropped).
# The scale comes from args.scale so the datasets stay consistent with the
# discriminator size and the bicubic downsampling, which also read args.scale.
train_dataset = LQGTDataset(
    args.dir_lr, args.dir_gt,
    transform=transforms.Compose([ToTensor()]),
    patch_size = args.patch_size,
    scale=args.scale,
    shuffle_mode = args.shuffle_mode
)
train_loader = DataLoader(
    train_dataset,
    batch_size=args.batch_size,
    num_workers=args.num_workers,
    drop_last=True,
    shuffle=True
)
# Load the validation dataset (one full image per batch, fixed order).
# NOTE(review): the directories are passed as (dir_gt, dir_lr) here but as
# (dir_lr, dir_gt) to LQGTDataset above -- confirm both constructors' signatures.
test_dataset = Testdataset(
    args.dir_gt,args.dir_lr,
    transform=transforms.Compose([ToTensor()]),
    patch_size = args.patch_size,
    scale=args.scale,
    crop_mode = False,
)
test_loader = DataLoader(
    test_dataset,
    batch_size=1,
    num_workers=args.num_workers,
    drop_last=False,
    shuffle=False
)
print(len(train_dataset))

In [None]:
def _as_parallel_eval(module):
    """Put ``module`` on the GPU, wrap it in ``nn.DataParallel`` and set eval mode."""
    return nn.DataParallel(module.cuda()).eval().to(device)


# Generator side: one shared encoder and two decoders.
model_Enc = _as_parallel_eval(encoder.Encoder_RRDB(num_feat=args.n_hidden_feats))
# Gt: identity decoder
model_Dec_Id = _as_parallel_eval(decoder.Decoder_Id_RRDB(num_in_ch=args.n_hidden_feats))
# Gsr: super-resolution decoder
model_Dec_SR = _as_parallel_eval(
    decoder.Decoder_SR_RRDB(num_in_ch=args.n_hidden_feats, num_block=args.n_sr_block)
)

# Discriminator side: one on encoder features, one on LR-size images and one on
# HR-size images (patch_size * scale).
model_Disc_feat = _as_parallel_eval(
    discriminator.DiscriminatorVGG(in_ch=args.n_hidden_feats, image_size=args.patch_size)
)
model_Disc_img_LR = _as_parallel_eval(
    discriminator.DiscriminatorVGG(in_ch=3, image_size=args.patch_size)
)
model_Disc_img_HR = _as_parallel_eval(
    discriminator.DiscriminatorVGG(in_ch=3, image_size=args.patch_size*args.scale)
)

# Declare training components or load pre-trained models

In [None]:
# Losses
loss_L1 = nn.L1Loss().cuda()
loss_MSE = nn.MSELoss().cuda()
loss_adversarial = nn.BCEWithLogitsLoss().cuda()
loss_percept = VGG19PerceptualLoss().cuda()


# Optimizers: one Adam for the generator side (encoder + both decoders) and one
# for the three discriminators.
params_G = list(model_Enc.parameters()) + list(model_Dec_Id.parameters()) + list(model_Dec_SR.parameters())
optimizer_G = optim.Adam(
    params_G,
    lr=args.lr_G,
    betas=(args.beta1, args.beta2),
    weight_decay=args.weight_decay,
    amsgrad=True
)
params_D = list(model_Disc_feat.parameters()) + list(model_Disc_img_LR.parameters()) + list(model_Disc_img_HR.parameters())
optimizer_D = optim.Adam(
    params_D,
    lr=args.lr_D,
    betas=(args.beta1, args.beta2),
    weight_decay=args.weight_decay,
    amsgrad=True
)

# Schedulers: halve the learning rate at each milestone. The training loop steps
# them once per batch, so the milestones count iterations rather than epochs.
iter_indices = [args.interval1, args.interval2, args.interval3]
scheduler_G = optim.lr_scheduler.MultiStepLR(
    optimizer=optimizer_G,
    milestones=iter_indices,
    gamma=0.5
)
scheduler_D = optim.lr_scheduler.MultiStepLR(
    optimizer=optimizer_D,
    milestones=iter_indices,
    gamma=0.5
)

# Gradient scaler for mixed-precision training.
# NOTE(review): its state is neither saved to nor restored from checkpoints.
scaler = GradScaler()

args.checkpoint = checkpoint
# Load model weights, optimizer and scheduler states.
# The loaded dict gets its own name: rebinding `checkpoint` (the path) to the
# dict would make a second run of this cell call torch.load() on a dict.
if args.checkpoint is not None:
    # NOTE(review): the file also holds pickled Storage objects, so torch.load
    # unpickles arbitrary objects -- only load checkpoints from a trusted source.
    ckpt_state = torch.load(args.checkpoint)

    model_Enc.load_state_dict(ckpt_state['model_Enc'])
    model_Dec_Id.load_state_dict(ckpt_state['model_Dec_Id'])
    model_Dec_SR.load_state_dict(ckpt_state['model_Dec_SR'])
    model_Disc_feat.load_state_dict(ckpt_state['model_Disc_feat'])
    model_Disc_img_LR.load_state_dict(ckpt_state['model_Disc_img_LR'])
    model_Disc_img_HR.load_state_dict(ckpt_state['model_Disc_img_HR'])

    optimizer_D.load_state_dict(ckpt_state['optimizer_D'])
    optimizer_G.load_state_dict(ckpt_state['optimizer_G'])

    scheduler_D.load_state_dict(ckpt_state['scheduler_D'])
    scheduler_G.load_state_dict(ckpt_state['scheduler_G'])

    start_epoch = ckpt_state['epoch']
    _Storage_disc = ckpt_state['storage_disc']
    _Storage_gene = ckpt_state['storage_gene']
else:
    # Fresh run: start at epoch 0 with empty feature storages.
    start_epoch = 0
    _Storage_disc = Storage(2, 'euclid')
    _Storage_gene = Storage(2, 'euclid')

### Write the run arguments to the log

In [None]:

# Record every option as "key=value", one per line, so the log file documents
# the exact configuration of this run.
arg_entries = [f'{k}={v} \n' for k, v in vars(args).items()]
log.write(' '.join(arg_entries))

In [None]:
def Test_v2(loader,
         save_path,
         log,
         models,
         curr_epoch = 0,
         show_mode = False,
         n_interval = 100,
         save_mode = False):
    """Score an encoder/decoder pair on ``loader`` and log the mean PSNR / SSIM.

    Args:
        loader: iterable yielding ``(data, filename)`` where ``data['img_LQ']`` and
            ``data['img_GT']`` are image batches. The squeeze/permute steps below
            assume batch size 1 and values in [0, 1] -- TODO confirm against the dataset.
        save_path: directory for the PNG dumps written when ``save_mode`` is True;
            it is created if missing.
        log: logger exposing ``write(str)``.
        models: sequence ``[encoder, decoder]``; ``decoder(encoder(LQ))`` is scored.
            Both are switched to eval mode and left in that mode.
        curr_epoch: epoch number used in the saved file names. For epochs 0 and 1
            the plainly upsampled LQ input is scored as well; it is logged for epoch 0.
        show_mode: open the LQ / output / GT images in a viewer for sampled items.
        n_interval: every ``n_interval``-th sample is shown and/or saved.
        save_mode: write the LQ / output / GT images of sampled items as PNG files.

    Returns:
        Mean PSNR over the loader as computed by ``PSNR`` on 0-255 HWC arrays.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    # Use the models that were passed in rather than notebook globals.
    enc, dec = models[0], models[1]
    enc.eval()
    dec.eval()

    list_psnr, list_psnr_v2, list_ssim_v2 = [], [], []
    list_psnr_origin, list_ssim_origin = [], []
    min_max = (0, 1)
    # NOTE(review): the baseline is collected for epochs 0 and 1 but only logged
    # for epoch 0 (original behaviour kept).
    score_baseline = curr_epoch in [0,1]
    print(f'save path : {save_path}')
    if save_mode:
        os.makedirs(save_path, exist_ok=True)

    with torch.no_grad():
        for idx, (data, _) in enumerate(loader):
            X_t = data['img_LQ'].cuda(non_blocking=True)
            Y = data['img_GT'].cuda(non_blocking=True)

            # inference: encoder features -> decoder (generator) output
            out = dec(enc(X_t))
            assert(Y.shape== out.shape)

            # Bring prediction and target to the CPU so every metric below receives
            # tensors on one device; only the prediction is clamped to [0, 1].
            out = out.detach().float().cpu().clamp_(*min_max).squeeze(0)
            Y = Y.detach().cpu().squeeze(0)

            if (show_mode or save_mode) and idx % n_interval == 0 :
                img_lq = transforms.ToPILImage()(X_t.squeeze(0))
                img_output = transforms.ToPILImage()(out)
                img_y = transforms.ToPILImage()(Y)
                if show_mode:
                    img_lq.show(); img_output.show(); img_y.show()
                if save_mode:
                    img_lq.save(f'{save_path}/LQ_e{curr_epoch}_{idx}.png')
                    img_output.save(f'{save_path}/OUTPUT_e{curr_epoch}_{idx}.png')
                    img_y.save(f'{save_path}/HQ_e{curr_epoch}_{idx}.png')

            # Metrics are computed on rounded 0-255 values in HWC layout.
            out_255 = (out*255.0).round().permute(1,2,0)
            Y_255 = (Y*255.0).round().permute(1,2,0)
            out_np, Y_np = np.array(out_255), np.array(Y_255)

            list_psnr.append(calculate_psnr(out_255, Y_255, crop_border=0, input_order='HWC'))
            list_psnr_v2.append(PSNR(out_np, Y_np, border=0))
            list_ssim_v2.append(SSIM(out_np, Y_np, border=0))

            if score_baseline:
                # Resize the LQ input to the GT resolution instead of using a fixed
                # x4 factor, so the baseline matches Y for any SR scale.
                lq_up = F.interpolate(X_t, size=tuple(Y.shape[-2:]))
                lq_up = (lq_up.squeeze(0)*255.0).round().permute(1,2,0)
                lq_np = np.array(lq_up.detach().cpu())
                list_psnr_origin.append(PSNR(lq_np, Y_np, border=0))
                list_ssim_origin.append(SSIM(lq_np, Y_np, border=0))

    if not list_psnr_v2:
        raise ValueError('Test_v2 received an empty loader; nothing to evaluate')

    final_psnr_temp = sum(list_psnr)/len(list_psnr)
    final_psnr_v2 = sum(list_psnr_v2)/len(list_psnr_v2)
    final_ssim_v2 = sum(list_ssim_v2)/len(list_ssim_v2)
    if curr_epoch ==0 :
        final_psnr_v2_origin = sum(list_psnr_origin)/len(list_psnr_origin)
        final_ssim_v2_origin = sum(list_ssim_origin)/len(list_ssim_origin)
        log.write(f'===> final_psnr_v2_oigin: {final_psnr_v2_origin}, final_ssim_v2_origin: {final_ssim_v2_origin} \n')

    log.write(f'===> final_psnr_temp(origin matric) {final_psnr_temp} \n')
    log.write(f'===> psnr_v2: {final_psnr_v2}, ssim_v2: {final_ssim_v2} \n')
    return final_psnr_v2
    

# Training 

In [None]:
print(f"weight file would be saved at {snap_path}")


def _build_checkpoint(epoch_to_store):
    """Collect model / optimizer / scheduler states and both storages for torch.save."""
    return {
        'epoch': epoch_to_store,

        'model_Enc': model_Enc.state_dict(),
        'model_Dec_Id': model_Dec_Id.state_dict(),
        'model_Dec_SR': model_Dec_SR.state_dict(),
        'model_Disc_feat': model_Disc_feat.state_dict(),
        'model_Disc_img_LR': model_Disc_img_LR.state_dict(),
        'model_Disc_img_HR': model_Disc_img_HR.state_dict(),

        'optimizer_D': optimizer_D.state_dict(),
        'optimizer_G': optimizer_G.state_dict(),

        'scheduler_D': scheduler_D.state_dict(),
        'scheduler_G': scheduler_G.state_dict(),

        'storage_disc': _Storage_disc,
        'storage_gene': _Storage_gene,
    }


# training
best_psnr = 0
# Derive the first epoch instead of incrementing start_epoch in place, so that
# re-running this cell does not shift the starting epoch again.
# NOTE(review): checkpoints store epoch+1 and this adds 1 more, so a resumed run
# skips one epoch index -- confirm the intended numbering.
first_epoch = start_epoch + 1
# NOTE(review): this overrides the args.epochs value set in the hyper-parameter cell.
args.epochs = 840
print(f'start epoch is {first_epoch} | final epoch is {args.epochs}')

# Validation images are written next to the weights ('weights' -> 'results' in the path).
results_path = snap_path.replace('weights','results')
os.makedirs(results_path, exist_ok=True)

# Bicubic downsampler producing the synthetic LR input X_s from the HR target Y_s.
downsample = nn.Upsample(scale_factor=1/args.scale, mode='bicubic')

for epoch in range(first_epoch, args.epochs):

    psnr_all=[]
    # generator
    model_Enc.train()
    model_Dec_Id.train()
    model_Dec_SR.train()

    # discriminator
    model_Disc_feat.train()
    model_Disc_img_LR.train()
    model_Disc_img_HR.train()

    running_loss_D_total = 0.0
    running_loss_G_total = 0.0

    running_loss_align = 0.0
    running_loss_rec = 0.0
    running_loss_res = 0.0
    running_loss_sty = 0.0
    running_loss_idt = 0.0
    running_loss_cyc = 0.0

    num_iters = 0  # batches processed in this epoch
    for data,path in (train_loader):
        num_iters += 1
        X_t, Y_s = data['img_LQ'], data['img_GT']
        # Downsample on the CPU before the transfer, as the original code did.
        X_s = downsample(Y_s)

        X_t = X_t.cuda(non_blocking=True)
        X_s = X_s.cuda(non_blocking=True)
        Y_s = Y_s.cuda(non_blocking=True)
        X_s_store=None

        # real label and fake label
        batch_size = X_t.size(0)
        real_label = torch.full((batch_size, 1), 1, dtype=X_t.dtype).cuda(non_blocking=True)
        fake_label = torch.full((batch_size, 1), 0, dtype=X_t.dtype).cuda(non_blocking=True)

        # ------------------------- discriminator update(s) -------------------------
        for i in range(args.n_disc):
            # Zero the gradients before every step so that they do not accumulate
            # across the n_disc inner steps.
            model_Disc_feat.zero_grad()
            model_Disc_img_LR.zero_grad()
            model_Disc_img_HR.zero_grad()

            # autocast wraps only the forward pass and the loss computation.
            with autocast(enabled=True):
                # generator output (feature domain)
                F_t = model_Enc(X_t)
                F_s = model_Enc(X_s)
                list_output_mixed_X_s=[]
                for _xt, _xs in zip(F_t, F_s):
                    _path, _ = _Storage_disc.get_minimum_scalar(_xs.detach())
                    if _path:
                        if i==0: # fetch once and keep it so later iterations can reuse it
                            _, X_s_store = _Storage_disc.get_img(_path)
                        mixed_Xt = _Storage_disc.get_mixed_img(_xt,copy.deepcopy(X_s_store),alpha=0.1)
                    else:
                        mixed_Xt = _xt
                        X_s_store = copy.deepcopy(mixed_Xt.detach())
                    list_output_mixed_X_s.append(np.array(mixed_Xt.detach().cpu()))
                # NOTE(review): rebuilding F_t through NumPy cuts it off from the
                # encoder graph; harmless here because the discriminator branch
                # detaches its inputs anyway.
                F_t = torch.tensor(np.array(list_output_mixed_X_s,dtype=np.float32), requires_grad=True, dtype=torch.float32).cuda(non_blocking=True)

                # feature-level discriminator: target features are fake, source features real
                output_Disc_F_t = model_Disc_feat(F_t.detach())
                output_Disc_F_s = model_Disc_feat(F_s.detach())
                loss_Disc_F_t = loss_MSE(output_Disc_F_t, fake_label)
                loss_Disc_F_s = loss_MSE(output_Disc_F_s, real_label)
                loss_Disc_feat_align = (loss_Disc_F_t + loss_Disc_F_s) / 2

                # HR image discriminator: SR reconstruction vs. ground truth
                Y_s_s = model_Dec_SR(F_s)

                output_Disc_Y_s_s = model_Disc_img_HR(Y_s_s.detach())
                output_Disc_Y_s = model_Disc_img_HR(Y_s)
                loss_Disc_Y_s_s = loss_MSE(output_Disc_Y_s_s, fake_label)
                loss_Disc_Y_s = loss_MSE(output_Disc_Y_s, real_label)
                loss_Disc_img_rec = (loss_Disc_Y_s_s + loss_Disc_Y_s) / 2

                # LR image discriminator: identity-decoded source vs. real target LR
                X_s_t = model_Dec_Id(F_s)
                output_Disc_X_s_t = model_Disc_img_LR(X_s_t.detach())
                output_Disc_X_t = model_Disc_img_LR(X_t)
                loss_Disc_X_s_t = loss_MSE(output_Disc_X_s_t, fake_label)
                loss_Disc_X_t = loss_MSE(output_Disc_X_t, real_label)
                loss_Disc_img_sty = (loss_Disc_X_s_t + loss_Disc_X_t) / 2

                # The cycle term is added exactly once and only when it was computed;
                # adding it unconditionally counted it twice in cycle mode and raised
                # NameError when cycle mode was off.
                loss_D_total = loss_Disc_feat_align + loss_Disc_img_rec + loss_Disc_img_sty

                if args.cycle_mode:
                    Y_s_t_s = model_Dec_SR(model_Enc(model_Dec_Id(F_s)))
                    output_Disc_Y_s_t_s = model_Disc_img_HR(Y_s_t_s.detach())
                    output_Disc_Y_s = model_Disc_img_HR(Y_s)
                    loss_Disc_Y_s_t_s = loss_MSE(output_Disc_Y_s_t_s, fake_label)
                    loss_Disc_Y_s = loss_MSE(output_Disc_Y_s, real_label)
                    loss_Disc_img_cyc = (loss_Disc_Y_s_t_s + loss_Disc_Y_s) / 2
                    loss_D_total = loss_D_total + loss_Disc_img_cyc

            # Backward pass and optimizer step run outside autocast.
            scaler.scale(loss_D_total).backward()
            scaler.step(optimizer_D)
            scaler.update()
        scheduler_D.step()

        # ---------------------------- generator update -----------------------------
        model_Enc.zero_grad()
        model_Dec_Id.zero_grad()
        model_Dec_SR.zero_grad()
        with autocast(enabled=True):
            # NOTE(review): the forward pass runs n_gen times but only the losses of
            # the last pass are back-propagated below (original structure kept);
            # confirm whether a step per pass was intended.
            for i in range(args.n_gen):
                list_output_mixed_X_s=[]
                list_dist_softmax=[]
                F_t = model_Enc(X_t)
                F_s = model_Enc(X_s)
                for idx, (_xt, _xs) in enumerate(zip(F_t, F_s)):
                    _path, val = _Storage_disc.get_minimum_scalar(_xs.detach())
                    if _path and i==0:
                        mixed_Xt = _Storage_disc.get_mixed_img(_xt,copy.deepcopy(X_s_store),alpha=0.1)
                    else:
                        mixed_Xt = _xt
                    # soft adversarial target derived from the storage distance
                    dist = 0.5-val if val else 0.5
                    list_dist_softmax.append(dist)
                    list_output_mixed_X_s.append(np.array(mixed_Xt.detach().cpu()))
                    _Storage_disc.update_storage_by_representation_redesign(path['img_GT'][idx], feature=_xt.detach(),mode='maximum')
                # One target per sample: use the actual batch size, not a fixed 16.
                list_dist_softmax = torch.reshape(torch.tensor(list_dist_softmax),[batch_size,1]).cuda(non_blocking=True)
                # NOTE(review): as in the discriminator branch, this F_t is detached
                # from the encoder, so losses on F_t do not reach the encoder weights.
                F_t = torch.tensor(np.array(list_output_mixed_X_s,dtype=np.float32), requires_grad=True, dtype=torch.float32).cuda(non_blocking=True)

                # feature alignment
                output_Disc_F_t = model_Disc_feat(F_t)
                output_Disc_F_s = model_Disc_feat(F_s)
                loss_G_F_t = loss_MSE(output_Disc_F_t, list_dist_softmax) #mhkim
                loss_G_F_s = loss_MSE(output_Disc_F_s, abs(1-list_dist_softmax))
                L_align_E = loss_G_F_t + loss_G_F_s

                # SR reconstruction: L1 + perceptual + adversarial
                Y_s_s = model_Dec_SR(F_s)

                output_Disc_Y_s_s = model_Disc_img_HR(Y_s_s)
                loss_L1_rec = loss_L1(Y_s.detach(), Y_s_s)
                loss_percept_rec = loss_percept(Y_s.detach(), Y_s_s)
                loss_G_Y_s_s = loss_MSE(output_Disc_Y_s_s, real_label)
                L_rec_G_SR = loss_L1_rec + args.lambda_percept*loss_percept_rec + args.lambda_adv*loss_G_Y_s_s

                # target restoration through the identity decoder
                X_t_t = model_Dec_Id(F_t)
                L_res_G_t = loss_L1(X_t, X_t_t)

                # style: the identity-decoded source should look like a real LR image
                X_s_t = model_Dec_Id(F_s)
                output_Disc_X_s_t = model_Disc_img_LR(X_s_t)
                loss_G_X_s_t = loss_MSE(output_Disc_X_s_t, real_label)
                L_sty_G_t = loss_G_X_s_t

                # feature identity after a decode/encode round trip
                F_s_tilda = model_Enc(model_Dec_Id(F_s))
                L_idt_G_t = loss_L1(F_s, F_s_tilda)

                if args.cycle_mode:
                    Y_s_t_s = model_Dec_SR(model_Enc(model_Dec_Id(F_s)))
                    output_Disc_Y_s_t_s = model_Disc_img_HR(Y_s_t_s)
                    loss_L1_cyc = loss_L1(Y_s.detach(), Y_s_t_s)
                    loss_percept_cyc = loss_percept(Y_s.detach(), Y_s_t_s)
                    loss_Y_s_t_s = loss_MSE(output_Disc_Y_s_t_s, real_label)
                    L_cyc_G_t_G_SR = loss_L1_cyc + args.lambda_percept*loss_percept_cyc + args.lambda_adv*loss_Y_s_t_s
            loss_G_total = args.lambda_align*L_align_E + args.lambda_rec*L_rec_G_SR + args.lambda_res*L_res_G_t + args.lambda_sty*L_sty_G_t + args.lambda_idt*L_idt_G_t
            if args.cycle_mode: loss_G_total += args.lambda_cyc*L_cyc_G_t_G_SR
        # Backward pass and optimizer step run outside autocast.
        scaler.scale(loss_G_total).backward()
        scaler.step(optimizer_G)
        scaler.update()
        scheduler_G.step()

        running_loss_D_total += loss_D_total.item()
        running_loss_G_total += loss_G_total.item()

        running_loss_align += L_align_E.item()
        running_loss_rec += L_rec_G_SR.item()
        running_loss_res += L_res_G_t.item()
        running_loss_sty += L_sty_G_t.item()
        running_loss_idt += L_idt_G_t.item()
        running_loss_cyc += L_cyc_G_t_G_SR.item() if args.cycle_mode else 0
        # Detach the prediction so the stored value cannot keep the autograd graph alive.
        # NOTE(review): input_order='HWC' is passed for batched tensors -- confirm
        # what calculate_psnr expects here.
        psnr = calculate_psnr(Y_s * 255, Y_s_s.detach() * 255, crop_border=0, input_order='HWC')
        psnr_all.append(psnr)

    # periodic checkpoint
    if (epoch+1) % args.save_freq == 0:
        weights_file_name = 'epoch_%d.pth' % (epoch+1)
        weights_file = os.path.join(snap_path,  weights_file_name)
        torch.save(_build_checkpoint(epoch+1), weights_file)
        log.write('save weights of epoch %d' % (epoch+1))

    ### validation
    # NOTE(review): validation and the per-epoch training summary below are only
    # produced after epoch 460 (original behaviour kept).
    if epoch > 460:
        mean_psnr = Test_v2(test_loader, results_path, log, \
                            [model_Enc, model_Dec_SR], curr_epoch = epoch+1, show_mode=False, n_interval=10, save_mode = True)

        # keep a separate checkpoint whenever the validation PSNR improves
        if mean_psnr>best_psnr:
            weights_file_name = 'epoch_%d_BEST_PSNR.pth' % (epoch+1)
            weights_file = os.path.join(snap_path,  weights_file_name)
            best_psnr = mean_psnr
            torch.save(_build_checkpoint(epoch+1), weights_file)
            log.write('save weights of epoch %d \n ' % (epoch+1))
        log.write('===> TRAIN epoch:%d, psnr: %f, lr:%f, loss_D_total:%f, loss_G_total:%f, loss_align:%f, loss_rec:%f, loss_res:%f, loss_sty:%f, loss_idt:%f, loss_cyc:%f \n ' %\
            (epoch, sum(psnr_all)/len(psnr_all), optimizer_G.param_groups[0]['lr'], running_loss_D_total/num_iters, running_loss_G_total/num_iters, running_loss_align/num_iters, running_loss_rec/num_iters, running_loss_res/num_iters, running_loss_sty/num_iters, running_loss_idt/num_iters, running_loss_cyc/num_iters))


