In [1]:
%pylab inline
import sys

sys.path.insert(0, '..')

import torch
from torch.autograd import Variable
from torch import FloatTensor
from torch.utils.data import DataLoader
import torch.utils.data
from skimage import data
import matplotlib.pyplot as plt
import skimage
import numpy as np
import os
from tqdm import tqdm
import glob

from utils.training_util import calculate_psnr, calculate_ssim, prep_for_vis
from utils.visualize_util import transpose_table, write_table_of_images

from utils.image_utils import center_crop_tensor
from utils import HTML
from models.model_utils import get_model
from data_generation.pipeline import ImageDegradationPipeline
from data_generation.constants import XYZ2sRGB, ProPhotoRGB2XYZ
from data_generation.data_utils import random_crop


# Inference-only notebook: switch autograd off globally.
# A bare `torch.no_grad()` call merely constructs a context manager and throws
# it away, so it never disables gradient tracking; `set_grad_enabled(False)`
# called as a function does.
torch.set_grad_enabled(False)


def numpy2tensor(arr):
    """Turn an image array into a batched float tensor scaled by 1/255.

    Arrays with fewer than three dimensions get a trailing axis appended.
    The array is then permuted with (2, 0, 1), given a leading batch axis
    and divided by 255.0.
    """
    image = np.expand_dims(arr, -1) if len(arr.shape) < 3 else arr
    channels_first = FloatTensor(image).permute(2, 0, 1)
    batched = channels_first.unsqueeze(0).float()
    return batched / 255.0

def tensor2numpy(t, idx=None):
    """Convert one image of a batched tensor into a uint8 numpy array.

    Values are clamped to [0, 1] first. The image at position ``idx`` along
    the first axis is selected (position 0 when ``idx`` is None), permuted
    with (1, 2, 0), squeezed, scaled by 255 and cast to uint8.
    """
    which = 0 if idx is None else idx
    picked = torch.clamp(t, 0, 1)[which, ...]
    hwc = picked.detach().permute(1, 2, 0).cpu().squeeze()
    scaled = hwc.numpy() * 255.0
    return scaled.astype('uint8')


class RealCellphoneDSLRDataset(torch.utils.data.Dataset):
    """Dataset of paired noisy ("degraded") and clean ("target") images.

    Each pair shares a file basename; the noisy image lives in
    ``noisy_img_root`` and the clean one in ``clean_img_root``. Items are
    dicts with keys ``'degraded_img'`` and ``'original_img'``.

    The dataset length is ``count * n_patch_per_img``: index ``i`` maps to
    image ``i % count`` and crop number ``i // count``.
    """

    def __init__(self,
                 clean_img_root,
                 noisy_img_root,
                 clean_img_ext='jpg',
                 noisy_img_ext='npy',
                 im_size=None,
                 crop_type='center',
                 clean_exp_adj=0.0,
                 noisy_exp_adj=0.0,
                 file_list = None,
                 n_patch_per_img=1,
                ):
        """
        Args:
            clean_img_root: directory holding the clean (target) images.
            noisy_img_root: directory holding the noisy (degraded) images.
            clean_img_ext: extension (without dot) of the clean images.
            noisy_img_ext: extension (without dot) of the noisy images.
            im_size: crop size as a 2-tuple, or None to return whole images.
                The random crop unpacks it as ``(w, h)``.
            crop_type: 'center' or 'random'.
            clean_exp_adj: ``nstops`` passed to the 'ExposureAdjustment'
                stage applied to the clean image.
            noisy_exp_adj: ``nstops`` passed to the 'ExposureAdjustment'
                stage applied to the noisy image.
            file_list: basenames (no directory, no extension) to use, or
                None to take every ``*.<noisy_img_ext>`` file found in
                ``noisy_img_root``.
            n_patch_per_img: number of crops exposed per image.
        """
        super().__init__()

        self._DEGRADED_EXT = noisy_img_ext
        self._TARGET_EXT = clean_img_ext
        self._DEGRADED_DIR = noisy_img_root
        self._TARGET_DIR = clean_img_root

        # Applied to images decoded from .jpg/.png files.
        self.undo_srgb = ImageDegradationPipeline([
            ('UndosRGBGamma', {}),
        ])

        if file_list is None:
            # Discover basenames from the noisy directory.
            pattern = os.path.join(self._DEGRADED_DIR,
                                   '*.' + self._DEGRADED_EXT)
            file_list = [os.path.splitext(os.path.basename(f))[0]
                         for f in glob.glob(pattern)]
        self.file_list = sorted(file_list)
        self.count = len(self.file_list)
        assert self.count > 0
        self.im_size = im_size
        self.crop_type = crop_type
        self.clean_exp_adj = ImageDegradationPipeline(
                [
                    ('ExposureAdjustment', {'nstops': clean_exp_adj}),
                    ('PixelClip', {}),
                ]
        )
        self.noisy_exp_adj = ImageDegradationPipeline(
                [
                    ('ExposureAdjustment', {'nstops': noisy_exp_adj}),
                    ('PixelClip', {}),
                ]
        )
        self.n_patch_per_img = n_patch_per_img
        # Per-image list of memorized random crop windows (filled lazily).
        self.crop_windows = [None] * self.count

    def _npy_loader(self, path):
        """Load a .npy array as float32 and permute it with (2, 0, 1)."""
        img = np.load(path).astype('float32')
        img = FloatTensor(img).permute(2, 0, 1)
        return img

    def _t7_loader(self, path):
        """Load a serialized tensor with ``torch.load`` and return it as is."""
        # NOTE(review): torch.load unpickles the file; only use trusted data.
        img = torch.load(path)
        return img

    def _jpg_loader(self, path):
        """Read an image file, scale it by 1/255 and undo the sRGB gamma."""
        img = skimage.io.imread(path).astype('float32') / 255.0
        img = FloatTensor(img).permute(2, 0, 1).unsqueeze(0)
        return self.undo_srgb(img).squeeze(0)

    def _load_img(self, path):
        """Dispatch to the loader matching the file extension of ``path``.

        Raises:
            ValueError: if the extension is not .jpg, .png, .npy or .t7.
        """
        _, ext = os.path.splitext(path)
        if ext == ".jpg" or ext == ".png":
            img = self._jpg_loader(path)
        elif ext == ".npy":
            img = self._npy_loader(path)
        elif ext == ".t7":
            img = self._t7_loader(path)
        else:
            raise ValueError("Unrecognized extension received: {}".format(ext))
        return img

    def __getitem__(self, index):
        """Return the (optionally cropped) noisy/clean pair for ``index``.

        Raises:
            ValueError: on an unknown ``crop_type``, or when a random crop
                is requested that is larger than the image.
        """
        img_idx = index % self.count
        crop_idx = index // self.count

        basename = self.file_list[img_idx]
        degraded_path = os.path.join(self._DEGRADED_DIR,
                                     basename + '.' + self._DEGRADED_EXT)
        target_path = os.path.join(self._TARGET_DIR,
                                   basename + '.' + self._TARGET_EXT)

        degraded = self._load_img(degraded_path)
        target = self._load_img(target_path)

        if self.im_size is not None:
            # Stack both images along dim 0 so they receive the same crop.
            im = torch.cat([degraded, target], 0)

            crop_sz = self.im_size

            if self.crop_type == 'center':
                im = center_crop_tensor(im.unsqueeze(0), crop_sz[0], crop_sz[1])[0]
            elif self.crop_type == 'random':
                if self.crop_windows[img_idx] is None:
                    # Randomize the windows once and memorize them.
                    w, h = self.im_size
                    tw = im.size(-1)
                    th = im.size(-2)
                    if h > th or w > tw:
                        raise ValueError(
                            "Crop size {} exceeds image size {}".format(
                                (w, h), (tw, th)))
                    wn = []
                    for _ in range(self.n_patch_per_img):
                        # randint's upper bound is exclusive: the +1 makes
                        # the last valid offset reachable and allows an image
                        # that is exactly the crop size (offset 0).
                        h0 = np.random.randint(th - h + 1)
                        w0 = np.random.randint(tw - w + 1)
                        wn.append((h0, w0, h0 + h, w0 + w))
                    self.crop_windows[img_idx] = wn
                h0, w0, h1, w1 = self.crop_windows[img_idx][crop_idx]
                im = im[..., h0:h1, w0:w1]
            else:
                raise ValueError("Invalid crop type received: {}".format(self.crop_type))
            im = im.squeeze()
            # Undo the stacking: first half is the noisy image, second the clean.
            degraded, target = torch.split(im, int(im.size(0) / 2), dim=0)

        degraded = self.noisy_exp_adj(degraded.unsqueeze(0)).squeeze(0)
        target = self.clean_exp_adj(target.unsqueeze(0)).squeeze(0)

        data = {'degraded_img': degraded,
                'original_img': target}

        return data

    def __len__(self):
        """Number of items: one per (image, crop number) combination."""
        return self.count * self.n_patch_per_img

Populating the interactive namespace from numpy and matplotlib


# Set These Values

In [6]:
# Directory under which the HTML comparison page is written
# (replace with your own PATH_TO_OUTPUT_HTML).
base_html = '/afs/csail.mit.edu/u/t/tiam/public_html/phone2dslr/sandbox'

# Sample test data. Note that these clean images are averaged RAW and are
# not properly tonemapped.
noisy_img_root = "../samples/sample_test/noisy"
clean_img_root = "../samples/sample_test/clean"

# File extensions (without the dot) of the noisy and clean images.
noisy_ext = 't7'
clean_ext = 't7'

# List of file basenames (strings), or None to detect them automatically.
file_list = None

In [7]:
from configobj import ConfigObj
from validate import Validator

from utils.training_util import load_statedict_runtime


def _to_display_image(postprocess, img):
    """Post-process an image batch and convert it to a squeezed uint8 array.

    The batch is run through ``postprocess``, scaled to [0, 255], clamped,
    moved to the CPU, permuted with (0, 2, 3, 1) and stripped of singleton
    dimensions.
    """
    vis = torch.clamp(postprocess(img) * 255.0, 0, 255)
    vis = vis.cpu().permute(0, 2, 3, 1).detach().numpy().astype('uint8')
    return np.squeeze(vis)


def run_test(expname,
             max_test_batch,
             max_visual_batch,
             test_data,
             manual_seed=3,
             iteration='best',
             max_model_patch_size=None,
             model_patch_buffer=None,
             use_context=False,
            ):
    """Evaluate one experiment's model on ``test_data``.

    Args:
        expname: experiment name; its config is read from
            ``../denoiser_specs/<expname>.conf``.
        max_test_batch: upper bound on the number of batches evaluated.
        max_visual_batch: the first ``max_visual_batch`` batches are also
            converted to uint8 images and returned for visualization.
        test_data: dataset yielding dicts with 'degraded_img' and
            'original_img' (and optionally 'context_img').
        manual_seed: seed for the torch and NumPy global RNGs.
        iteration: checkpoint identifier handed to
            ``load_statedict_runtime``; None skips loading weights.
        max_model_patch_size: forwarded to ``get_model``.
        model_patch_buffer: forwarded to ``get_model``.
        use_context: when True and the batch contains 'context_img', it is
            passed to the model through ``extra_args``. (This name was
            previously referenced without being defined.)

    Returns:
        Tuple ``(input_imglist, gt_imglist, output_imglist, psnr,
        input_psnr, ssim, input_ssim, global_iter)``. The metrics are
        averages over the evaluated batches (NaN if none were evaluated).
    """
    # Get Configs
    config_file = '../denoiser_specs/{}.conf'.format(expname)
    config_spec = '../denoiser_specs/configspec.conf'

    configspec = ConfigObj(config_spec, raise_errors=True)
    config = ConfigObj(config_file, configspec=configspec, raise_errors=True, file_error=True)
    config.validate(Validator())

    # Build the model
    use_gpu = torch.cuda.is_available()
    device = torch.device('cuda' if use_gpu else 'cpu')
    model = get_model(config["architecture"], max_model_patch_size, model_patch_buffer)
    if use_gpu:
        model = model.cuda()

    # Load Checkpoint
    checkpoint_dir = config["training"]["checkpoint_dir"]
    print(expname)
    if iteration is not None:
        state_dict, global_iter = load_statedict_runtime(checkpoint_dir, iteration)
        model.load_state_dict(state_dict)
        print('Iteration = ', global_iter)
    else:
        global_iter = 0
        print('Not loading pretrained network')

    batch_size = 1

    torch.manual_seed(manual_seed)
    # Seed NumPy as well: datasets that pick random crops through np.random
    # (as RealCellphoneDSLRDataset does) would otherwise not follow the seed.
    np.random.seed(manual_seed)
    test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=True, num_workers=1)

    pipeline_postprocess = ImageDegradationPipeline([
        ('PixelClip', {}),
        ('sRGBGamma', {}),
    ])

    n = 0.0
    psnr = 0.0
    input_psnr = 0.0
    ssim = 0.0
    input_ssim = 0.0
    n_test_batch = min(len(test_loader), max_test_batch)
    pbar = tqdm(total=n_test_batch, desc="Evaluating batches")
    input_imglist = []
    gt_imglist = []
    output_imglist = []
    model = model.eval()
    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        for batch_idx, batch in enumerate(test_loader):
            # `>=` so that exactly n_test_batch batches are evaluated.
            if batch_idx >= n_test_batch:
                break
            degraded_img = batch['degraded_img'].to(device)
            target_img = batch['original_img'].to(device)
            if use_context and 'context_img' in batch:
                extra_args = {'context_img': batch['context_img'].to(device)}
            else:
                extra_args = {}

            output_img = model(degraded_img, extra_args=extra_args)

            psnr += calculate_psnr(output_img, target_img)
            input_psnr += calculate_psnr(degraded_img, target_img)
            ssim += calculate_ssim(output_img, target_img)
            input_ssim += calculate_ssim(degraded_img, target_img)
            n += 1.0
            if max_visual_batch > batch_idx:
                input_imglist.append(_to_display_image(pipeline_postprocess, degraded_img))
                gt_imglist.append(_to_display_image(pipeline_postprocess, target_img))
                output_imglist.append(_to_display_image(pipeline_postprocess, output_img))
            pbar.update(1)
    pbar.close()
    if n == 0:
        # Nothing was evaluated; avoid a ZeroDivisionError.
        psnr = input_psnr = ssim = input_ssim = float('nan')
    else:
        psnr /= n
        input_psnr /= n
        ssim /= n
        input_ssim /= n
    return input_imglist, \
           gt_imglist, \
           output_imglist, \
           psnr, input_psnr, \
           ssim, input_ssim, \
           global_iter



## Get Visual Comparison

In [8]:
# Experiments to compare, one tuple per experiment:
# (experiment name, iteration to load, display name).
expnames = [('full_dataset_n3net', 'latest', 'N3Net')]


# Dataset

In [9]:
# Identifiers that are combined into the output HTML directory name.
DATASET_NAME = "tiam_iphone8"
url_name = "n3net"

# Crop size and number of crops taken from every image.
n_patch_per_img = 3
crop_size = (200, 200)

test_data = RealCellphoneDSLRDataset(
    noisy_img_root=noisy_img_root,
    clean_img_root=clean_img_root,
    noisy_img_ext=noisy_ext,
    clean_img_ext=clean_ext,
    im_size=crop_size,
    n_patch_per_img=n_patch_per_img,
    file_list=file_list,
)

In [10]:
# Destination of the generated comparison page.
output_html = os.path.join(base_html,
                           '{}/index.html'.format(DATASET_NAME + '_' + url_name))

manual_seed = 5
n_images = 25        # passed to run_test as max_visual_batch
n_psnr_image = 500   # passed to run_test as max_test_batch

# One row per experiment: a header cell followed by the output images.
table = []
for expname, iter_to_load, disp_name in expnames:
    (input_img, gt_img, output_img,
     psnr, input_psnr,
     ssim, input_ssim,
     global_iter) = run_test(expname,
                             n_psnr_image,
                             n_images,
                             test_data=test_data,
                             manual_seed=manual_seed,
                             iteration=iter_to_load)
    output_img.insert(
        0,
        '{} <br> PSNR = {:0.4g} dB <br> SSIM = {:0.4g} <br> ({} iter = {})'.format(
            disp_name, psnr, ssim, str(iter_to_load), global_iter))
    table.append(output_img)

# The clean and noisy rows go first; they come from the last experiment run.
gt_img.insert(0, 'Clean Images')
input_img.insert(0, 'Noisy Images <br> PSNR = {:0.4g} dB <br> SSIM = {:0.4g}'.format(input_psnr, input_ssim))
table = [gt_img, input_img] + table

table = transpose_table(table)
write_table_of_images(output_html, table)
print(output_html)
print('Done')

Using No Normalization


Evaluating batches:   0%|          | 0/30 [00:00<?, ?it/s]

full_dataset_n3net
Iteration =  62000


Evaluating batches: 100%|██████████| 30/30 [00:21<00:00,  1.40it/s]
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)


/afs/csail.mit.edu/u/t/tiam/public_html/phone2dslr/sandbox/tiam_iphone8_n3net/index.html
Done


In [None]:
# Show the path of the generated HTML page again.
print(output_html)