In [1]:
import torch
from pathlib import Path
import sys
import os
import ipywidgets as widgets
import numpy as np


from basicsr.models import create_model
from basicsr.utils import tensor2img
from basicsr.utils.options import parse
from basicsr.data import create_dataset
import matplotlib.pyplot as plt
from basicsr.metrics import calculate_psnr
from basicsr.train import parse_options

from basicsr.utils import (get_env_info, get_root_logger, get_time_str,
                           make_exp_dirs)
from os import path as osp
import logging
from basicsr.utils.options import dict2str
import matplotlib.image

In [2]:
def load_model(yml_path, model_path):
    """Create a model from an options file, configured for non-training use.

    Args:
        yml_path: path to the YAML options file (the original note asks for
            the *training* options file).
        model_path: checkpoint location; stored in the parsed options under
            ``opt['path']['pretrain_network_g']``.

    Returns:
        Whatever ``create_model`` builds from the adjusted options.
    """
    options = parse(Path(yml_path))

    # Override what the options file says: no distributed run, weights taken
    # from the supplied checkpoint, and training mode switched off.
    options['dist'] = False
    options['path']['pretrain_network_g'] = model_path
    options['is_train'] = False

    return create_model(options)
 

def draw_stuff(model, dataset, i, name):
    """Plot the noisy input, the network output and the ground truth side by side.

    The network is run in inference conditions: eval mode, gradients disabled,
    and the input moved to the device the network's parameters live on. The
    network's previous train/eval mode is restored afterwards.

    Args:
        model: object exposing the restoration network as ``model.net_g``
            (assumed to be a ``torch.nn.Module`` — TODO confirm).
        dataset: indexable dataset whose items are dicts holding ``'lq'`` and
            ``'gt'`` tensors.
        i: index of the sample to visualize.
        name: label of the model. Kept for interface compatibility; it is not
            used by the plot.
    """
    fig, axs = plt.subplots(1, 3, figsize=(16, 6))
    for ax in axs.flatten():
        ax.axis('off')

    # Read the sample once so 'lq' and 'gt' come from the same fetch and the
    # backing store is not decoded twice.
    sample = dataset[i]
    lq = sample['lq'].unsqueeze(0)
    gt = sample['gt'].unsqueeze(0)

    net = model.net_g
    # Use the device of the network's parameters; a parameter-free network
    # simply runs on the input's own device.
    first_param = next(net.parameters(), None)
    device = lq.device if first_param is None else first_param.device

    # Eval mode so layers such as dropout / batch-norm behave deterministically;
    # no_grad avoids building an autograd graph that is never used.
    was_training = net.training
    net.eval()
    try:
        with torch.no_grad():
            # Bring the result back to the CPU so it sits next to `gt`.
            y = net(lq.to(device)).cpu()
    finally:
        net.train(was_training)

    psnr = np.round(calculate_psnr(y, gt, crop_border=0), 2)

    axs[0].imshow(tensor2img(lq, rgb2bgr=False))
    axs[0].set_title('noisy')

    axs[1].imshow(tensor2img(y, rgb2bgr=False))
    axs[1].set_title(f'output, PSNR={psnr}')

    axs[2].imshow(tensor2img(gt, rgb2bgr=False))
    axs[2].set_title('gt')

    fig.suptitle(f'SIDD_val img {i}', fontsize=14)

In [3]:
# Build the SIDD validation dataset from the paired LMDB crop stores.

dataset = create_dataset({
    'name': 'SIDD',
    'type': 'PairedImageDataset',
    'dataroot_gt': '/CascadedGaze/datasets/SIDD/val/gt_crops.lmdb',
    'dataroot_lq': '/CascadedGaze/datasets/SIDD/val/input_crops.lmdb',
    'filename_tmpl': '{}',
    'io_backend': {'type': 'lmdb'},
    'scale': None,
    'phase': None,
})


# Interactive browser: choose a sample index and which image of the pair to show.
@widgets.interact
def f(i=(0, len(dataset) - 1), version=['lq', 'gt']):
    """Show the ``version`` ('lq' or 'gt') image of sample ``i`` with axes hidden."""
    image = tensor2img(dataset[i][version], rgb2bgr=False)
    plt.imshow(image)
    plt.axis('off')

2023-08-18 13:40:54,016 INFO: Dataset PairedImageDataset - SIDD is created.


interactive(children=(IntSlider(value=639, description='i', max=1279), Dropdown(description='version', options…

In [4]:
# Index of the dataset sample that the evaluation cell below passes to draw_stuff.
i = 541

In [4]:
# Options file describing the network to build.
yml_path = "/CascadedGaze/options/test/SIDD/CascadedGaze-SIDD.yml"
# NOTE(review): this looks like a placeholder — replace it with the path to the
# trained checkpoint before running the cell.
model_path = "trained model path"
# Label for the model; passed to draw_stuff.
name = "CascadedGaze"
model = load_model(yml_path, model_path)
# Plot noisy input / network output / ground truth for sample `i`.
draw_stuff(model, dataset, i, name)