In [3]:
from mmedit.models.restorers import basicvsr
from mmedit.models.restorers import real_basicvsr
import torch
import matplotlib.pyplot as plt
import numpy as np
import torchvision
import os
import mmcv
from mmedit.core import tensor2img 
from mmcv import Config


# Pretrained SPyNet (optical-flow sub-network) checkpoint shared by every generator below.
SPYNET_URL = ('https://download.openmmlab.com/mmediting/restorers/'
              'basicvsr/spynet_20210409-c6c1bd09.pth')


def _make_basicvsr(mid_channels, num_blocks, **generator_extra):
    """Build a BasicVSR restorer around a low-res-input BasicVSRPlusPlus generator.

    Args:
        mid_channels: generator feature width.
        num_blocks: number of residual blocks in the generator.
        **generator_extra: additional keys merged into the generator config
            (e.g. ``max_residue_magnitude``, ``cpu_cache_length``).

    Returns:
        A ``basicvsr.BasicVSR`` instance trained/evaluated with a mean
        Charbonnier pixel loss.
    """
    generator_cfg = dict(
        type='BasicVSRPlusPlus',
        mid_channels=mid_channels,
        num_blocks=num_blocks,
        is_low_res_input=True,
        spynet_pretrained=SPYNET_URL,
        **generator_extra)
    loss_cfg = dict(type='CharbonnierLoss', loss_weight=1.0, reduction='mean')
    return basicvsr.BasicVSR(generator=generator_cfg, pixel_loss=loss_cfg)


# Index matters: PTHLIST refers to these models by position.
models = [
    # 0: c64n10
    _make_basicvsr(64, 10, max_residue_magnitude=10, cpu_cache_length=100),
    # 1: c128n25
    _make_basicvsr(128, 25, max_residue_magnitude=10, cpu_cache_length=100),
    # 2: c64n7, remaining generator options left at their defaults
    _make_basicvsr(64, 7),
]

# Checkpoints to run. Each entry is a tuple of
#   (checkpoint .pth path, run name used as the output folder, model from `models`).
# Uncomment an entry to include it in the next run.
PTHLIST = [
    # (r"D:\mmediting\Trained\basicvsr_plusplus_c64n7_8x1_600k_reds4_LLNoNoise_BlueCrystal\iter_440000.pth","basicvsr_plusplus_c64n7_8x1_440k_reds4_VemioResume_LLNN",models[2]),
    # (r"D:\mmediting\Trained\basicvsr_plusplus_c64n7_8x1_600k_reds4_NLNN\iter_50000.pth","basicvsr_plusplus_c64n7_8x1_600k_reds4_FromScratch_NLNN",models[2]),
    # (r"D:\mmediting\Trained\BasicVSRPP_lowres_Resume_NLNN\iter_601800.pth","basicvsr_plusplus_c64n7_8x1_600k_reds4_REDSResume_NLNN",models[2]),
    # (r"D:\result\LLNN\iter_25000.pth","basicvsr_plusplus_c64n7_8x1_600k_reds4_FromScratch_LLNN",models[2]),
    # (r"D:\mmediting\Trained\basicvsr_plusplus_c64n7_8x1_600k_reds4_FromScratch_LLWN\iter_25000.pth","basicvsr_plusplus_c64n7_8x1_600k_reds4_FromScratch_LLWN",models[2]),
    # (r"D:\pretrained\basicvsr_plusplus_c64n7_8x1_300k_vimeo90k_bi_20210305-4ef437e2.pth","vimeoPretrained",models[2]),
    # (r"D:\pretrained\basicvsr_plusplus_c64n7_8x1_600k_reds4_20210217-db622b2f.pth","REDSPretrained",models[2]),
    (
        r"D:\mmediting\work_dirs\basicvsr_plusplus_c64n7_8x1_600k_reds4_NLWN\iter_20000.pth",
        # NOTE(review): checkpoint dir says NLWN but the run name says NLNN — confirm which is correct.
        "basicvsr_plusplus_c64n7_8x1_600k_reds4_REDSResume_NLNN_20000iter",
        models[2],
    ),
    # (r"D:\pretrained\basicvsr_plusplus_c128n25_ntire_decompress_track1_20210223-7b2eba02.pth","REDSPretrained_c128n5",models[1])
]



def calculate(pth_file,lq_source,model,num_img):
    """Load a checkpoint into ``model`` and restore a folder of PNG frames.

    Args:
        pth_file: path to a checkpoint whose ``"state_dict"`` entry is loaded
            into ``model``.
        lq_source: directory searched recursively (``os.walk``) for ``.png``
            frames.
        model: restorer exposing ``load_state_dict``, ``to``, ``eval`` and
            ``forward_test``; ``forward_test`` must return a mapping with an
            ``"output"`` entry.
        num_img: maximum number of frames (in sorted path order) fed to the
            model as one sequence.

    Returns:
        ``model.forward_test(...)["output"]`` for a batch containing a single
        sequence of shape ``(1, T, C, H, W)``.

    Raises:
        FileNotFoundError: if no ``.png`` frame is found under ``lq_source``.
    """
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    # Load tensors onto the CPU so checkpoints saved from a GPU also open on
    # CPU-only machines; the model is moved to `device` below.
    sd = torch.load(pth_file, map_location='cpu')["state_dict"]
    model.load_state_dict(sd)

    # os.walk yields entries in filesystem order, but a recurrent VSR model
    # needs frames in temporal order, so sort the collected paths (REDS-style
    # zero-padded frame names sort correctly as strings).
    frame_paths = sorted(
        os.path.join(root, name)
        for root, _dirs, files in os.walk(lq_source)
        for name in files
        if name.endswith(".png")
    )
    if not frame_paths:
        raise FileNotFoundError(f"no .png frames found under {lq_source!r}")

    # Decode only the frames that will actually be used.
    frames = []
    for path in frame_paths[:num_img]:
        img = mmcv.imread(path, channel_order="rgb")
        # HWC uint8 -> CHW float32 rescaled to [0, 1]
        img = img.astype(np.float32) / 255.0
        frames.append(torch.from_numpy(img.transpose(2, 0, 1)))

    # Add the batch dimension: (1, T, C, H, W).
    bat = torch.stack(frames).unsqueeze(0)

    model.to(device)
    model.eval()

    # Free cached GPU blocks left over from a previous run (no-op without CUDA).
    torch.cuda.empty_cache()
    with torch.no_grad():
        result = model.forward_test(bat.to(device))
    return result["output"]


def saveimg(imgs,img_save_dir):
    """Write each frame of the first sequence in ``imgs`` as a numbered PNG.

    Args:
        imgs: batched model output; ``imgs[0]`` is iterated frame by frame
            (assumed ``(N, T, C, H, W)`` as returned by ``calculate`` — confirm).
        img_save_dir: output directory (created if missing); frames are
            written as ``0.png``, ``1.png``, ... inside it.
    """
    os.makedirs(img_save_dir, exist_ok=True)
    for idx, frame in enumerate(imgs[0]):
        # NOTE(review): mmedit's tensor2img is expected to clamp to [0, 1],
        # convert to uint8 and flip RGB -> BGR for mmcv.imwrite — confirm.
        mmcv.imwrite(tensor2img(frame), os.path.join(img_save_dir, f"{idx}.png"))

def isInbonded(imgs, low=0, high=255):
    """Check that every frame of the first sequence in ``imgs`` lies in ``[low, high]``.

    Args:
        imgs: batched frames; ``imgs[0]`` is iterated frame by frame and each
            frame must support elementwise comparison plus ``.all()``,
            ``.min()`` and ``.max()`` (torch tensors and numpy arrays do).
        low: inclusive lower bound (default 0).
        high: inclusive upper bound (default 255). NOTE(review): ``calculate``
            feeds the model frames rescaled to [0, 1], so its output is likely
            on that scale and ``high=1`` may be the meaningful bound — confirm.

    Returns:
        False, after printing the offending frame's index and value range, as
        soon as a frame has a value outside ``[low, high]`` (NaN counts as
        outside). Otherwise prints "Pass" and returns True.
    """
    for idx, frame in enumerate(imgs[0]):
        # Both bounds must hold everywhere; written this way NaNs also fail.
        if not ((frame >= low).all() and (frame <= high).all()):
            print(f"frame {idx} out of bounds: "
                  f"min={float(frame.min())}, max={float(frame.max())}")
            return False
    # Report success only after every frame has been checked.
    print("Pass")
    return True



2022-08-04 16:35:41,310 - mmedit - INFO - load checkpoint from http path: https://download.openmmlab.com/mmediting/restorers/basicvsr/spynet_20210409-c6c1bd09.pth
2022-08-04 16:35:41,439 - mmedit - INFO - load checkpoint from http path: https://download.openmmlab.com/mmediting/restorers/basicvsr/spynet_20210409-c6c1bd09.pth
2022-08-04 16:35:41,959 - mmedit - INFO - load checkpoint from http path: https://download.openmmlab.com/mmediting/restorers/basicvsr/spynet_20210409-c6c1bd09.pth


In [4]:


lq_source = r"D:\Dataset\dataset_REDS\Synthetic4x_LLNoNoise\test\000"
num_img = 100
proc_type = "LLNNImages"
# Root folder under which each run's restored frames are written.
out_root = r"D:\mmediting\Images\Demoinput"
# cfg = Config.fromfile(r"D:\mmediting\configs\restorers\basicvsr_plusplus\basicvsr_plusplus_c64n7_8x1_600k_reds4.py")

# Output folders are named after the last two components of lq_source,
# e.g. ...\test\000 -> <out_root>\<run_name>\test\000_<proc_type>.
# normpath drops any trailing separator so the leaf name is never empty.
lq_parent, lq_leaf = os.path.split(os.path.normpath(lq_source))
lq_parent_name = os.path.basename(lq_parent)

for pth_file, run_name, model in PTHLIST:
    out_path = os.path.join(out_root, run_name, lq_parent_name, lq_leaf + "_" + proc_type)
    # Skip runs that already produced output so the cell is cheap to re-run.
    if os.path.exists(out_path):
        continue
    out = calculate(pth_file, lq_source, model, num_img)
    saveimg(out, out_path)
    if not isInbonded(out):
        print("OUTBOUNDED")

Pass
