In [1]:
from __future__ import division
import os, time
import numpy as np
import glob
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
import sys
from optparse import OptionParser
import numpy as np
from torch import optim
from PIL import Image
from torch.autograd import Function, Variable
import matplotlib.pyplot as plt
from torchvision import transforms
from torch.utils.data import Dataset
# import cv2
import pickle
from tqdm import tqdm
import rawpy
%matplotlib inline

from pytorch_msssim import ssim, ms_ssim, SSIM, MS_SSIM

In [2]:
# Pick the compute device once: the GPU when PyTorch can see one, otherwise the CPU.
deviceTag = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(deviceTag)

cuda


In [3]:
def pack_raw(raw, black_level = 512, white_level = 16383):
    """Pack a Bayer mosaic into a half-resolution, 4-channel float image.

    Args:
        raw: object exposing ``raw_image_visible`` (e.g. the result of
            ``rawpy.imread``) holding the 2-D sensor mosaic.
        black_level: value subtracted from every pixel before scaling.
            Defaults to 512, the value this notebook was written with.
        white_level: value mapped to 1.0 after black-level subtraction.
            Defaults to 16383.

    Returns:
        float32 array of shape (H/2, W/2, 4). Channel order is
        (even row, even col), (even row, odd col), (odd row, odd col),
        (odd row, even col) of the mosaic.
        NOTE(review): for an RGGB sensor that is presumably R, G, B, G --
        confirm against the camera's CFA layout.

    Raises:
        ValueError: if ``white_level`` is not greater than ``black_level``.
    """
    if white_level <= black_level:
        raise ValueError('white_level must be greater than black_level')

    # pack Bayer image to 4 channels
    im = raw.raw_image_visible.astype(np.float32)
    # Subtract the black level, clip negatives, and normalise by the usable range.
    # Values above white_level are not clipped, so they map above 1.0.
    im = np.maximum(im - black_level, 0) / (white_level - black_level)

    im = np.expand_dims(im, axis=2)
    H = im.shape[0]
    W = im.shape[1]

    # Each slice picks one of the four positions of every 2x2 mosaic cell.
    out = np.concatenate((im[0:H:2, 0:W:2, :],
                          im[0:H:2, 1:W:2, :],
                          im[1:H:2, 1:W:2, :],
                          im[1:H:2, 0:W:2, :]), axis=2)
    return out

def process_img(input_raw_img, model, ratio):
    """Run ``model`` on a batch of packed raw images and return NumPy output.

    Args:
        input_raw_img: numpy array laid out as (Batch, Height, Width, Channel).
        model: torch module; it is switched to eval mode and moved to the
            module-level ``deviceTag`` device as a side effect.
        ratio: numpy array of shape (Batch,), one amplification factor per
            sample, multiplied into the input before inference.

    Returns:
        numpy array with the model output transposed back to
        (Batch, Height, Width, Channel).
    """
    model.eval()
    model.to(deviceTag)

    # Reshape the per-sample gain so it broadcasts over (C, H, W).
    gain = ratio.reshape(ratio.shape[0], 1, 1, 1)
    channels_first = np.transpose(input_raw_img, [0, 3, 1, 2]).astype('float32')
    amplified = channels_first * gain
    net_input = torch.from_numpy(amplified.copy()).float().to(deviceTag)

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        prediction = model(net_input)

    # Bring the result back to the host and to channels-last layout.
    return np.transpose(prediction.cpu().numpy(), [0, 2, 3, 1])

def ssim_numpy(img1, img2):
    """Compute SSIM and MS-SSIM between two channels-last NumPy images.

    Both images are transposed to channels-first, given a leading batch
    axis of size 1, and scored with ``data_range=1``.

    Returns:
        (ssim, ms_ssim) as NumPy values.
    """
    def _to_batched_tensor(img):
        # (H, W, C) -> (1, C, H, W)
        return torch.tensor(np.transpose(img, (2, 0, 1))).unsqueeze(0)

    first = _to_batched_tensor(img1)
    second = _to_batched_tensor(img2)
    ssim_calc = ssim(first, second, data_range=1, size_average=True)
    msssim_calc = ms_ssim(first, second, data_range=1, size_average=True)
    return ssim_calc.numpy(), msssim_calc.numpy()
    
def validate(model, input_list, gt_list, block_size = None, batch_size = 8, save_img_dir = None):
    """Evaluate ``model`` on paired raw inputs and ground-truth raw files.

    For every pair the input is packed with ``pack_raw``, amplified by the
    exposure ratio (capped at 300) inside ``process_img``, and the model
    output is compared with the post-processed ground truth.

    Args:
        model: torch module to evaluate; put in eval mode.
        input_list: paths of the input raw files.
        gt_list: paths of the ground-truth raw files, same length and order
            as ``input_list``.
        block_size: if given, a random ``block_size`` x ``block_size`` crop of
            the packed input is evaluated; the ground truth is cropped at
            twice those coordinates because the packed input is half
            resolution. ``None`` evaluates the full frame.
        batch_size: number of image pairs per forward pass. A trailing
            partial batch is evaluated as well.
        save_img_dir: if given, string prefix (typically a directory path
            ending in a separator) under which ``{i}_gt.png`` and
            ``{i}_out.png`` of the first image of batch ``i`` are written.
            Its parent directory is created when missing.

    Returns:
        (Val_PSNR, Val_SSIM, Val_MSSSIM): means over all evaluated images.
        With an empty ``input_list`` the means are NaN.

    Notes:
        The exposure time is parsed from characters ``[9:-5]`` of each file
        name. NOTE(review): presumably names look like
        ``00001_00_0.1s.ARW`` -- confirm against the dataset.
    """
    assert len(input_list) == len(gt_list)

    model.eval()
    PSNR_list = []
    SSIM_list = []
    MSSSIM_list = []

    if save_img_dir is not None:
        # save_img_dir is used as a string prefix below; make sure its
        # directory part exists so plt.imsave does not fail.
        save_parent = os.path.dirname(save_img_dir)
        if save_parent:
            os.makedirs(save_parent, exist_ok=True)

    num_items = len(input_list)
    # Ceil division so that a trailing partial batch is not silently dropped.
    num_batches = (num_items + batch_size - 1) // batch_size

    for i in range(num_batches):
        if i%10 == 0:
            print(i)
        input_raw_img_batch = []
        gt_img_batch = []
        ratio_batch = []
        for idx in range(i*batch_size, min((i+1)*batch_size, num_items)):
            in_path = input_list[idx]
            gt_path = gt_list[idx]
            in_fn = os.path.basename(in_path)
            gt_fn = os.path.basename(gt_path)
            in_exposure = float(in_fn[9:-5])
            gt_exposure = float(gt_fn[9:-5])
            # Amplification factor, capped at 300.
            ratio = min(gt_exposure / in_exposure, 300)

            # Context managers release the underlying raw decoder handles;
            # pack_raw / postprocess return fresh arrays, so nothing dangles.
            with rawpy.imread(in_path) as raw:
                input_raw_img = pack_raw(raw)

            with rawpy.imread(gt_path) as gt_raw:
                gt_img = gt_raw.postprocess(use_camera_wb=True, half_size=False, no_auto_bright=True, output_bps=16)
            gt_img = np.float32(gt_img / 65535.0)

            if block_size is not None:
                # randint's upper bound is exclusive: +1 makes the last valid
                # offset reachable and allows block_size == image size.
                i_cut = np.random.randint(0, input_raw_img.shape[0] - block_size + 1)
                j_cut = np.random.randint(0, input_raw_img.shape[1] - block_size + 1)
                gt_img = gt_img[i_cut*2:i_cut*2+block_size*2, j_cut*2:j_cut*2+block_size*2, :]
                input_raw_img = input_raw_img[i_cut:i_cut+block_size, j_cut:j_cut+block_size, :]

            ratio_batch.append(ratio)
            input_raw_img_batch.append(input_raw_img)
            gt_img_batch.append(gt_img)

        input_raw_img_batch = np.array(input_raw_img_batch)
        ratio_batch = np.array(ratio_batch)
        gt_img_batch = np.array(gt_img_batch)

        output_img_batch = process_img(input_raw_img_batch, model, ratio_batch)
        if save_img_dir is not None:
            plt.imsave(save_img_dir+'{}_gt.png'.format(i),gt_img_batch[0,:,:,:])
            plt.imsave(save_img_dir+'{}_out.png'.format(i),output_img_batch[0,:,:,:])

        # Per-image MSE over all pixels and channels (images are in [0, 1]).
        MSE = np.mean((output_img_batch.reshape(output_img_batch.shape[0],-1) - gt_img_batch.reshape(gt_img_batch.shape[0],-1))**2, axis = 1)
        PSNR_batch = 10*np.log10(1/MSE)
        # Flat list of per-image values: stays rectangular even when the
        # last batch is smaller than the others.
        PSNR_list.extend(PSNR_batch)

        # Score every image of the batch, not only the first one.
        for gt_single, out_single in zip(gt_img_batch, output_img_batch):
            ssim_calc, msssim_calc = ssim_numpy(gt_single, out_single)
            SSIM_list.append(ssim_calc)
            MSSSIM_list.append(msssim_calc)

    Val_PSNR = np.mean(PSNR_list)
    Val_SSIM = np.mean(SSIM_list)
    Val_MSSSIM = np.mean(MSSSIM_list)
    return Val_PSNR, Val_SSIM, Val_MSSSIM

## PRIDNet

In [4]:
class conv_lrelu(nn.Module):
    """3x3 convolution (padding 1, spatial size preserved) followed by LeakyReLU."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        layers = [
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.LeakyReLU(),  # default negative slope
        ]
        # Attribute name and layer order are kept so checkpoint keys
        # ('conv.0.weight', 'conv.0.bias') stay valid.
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv(x)
    
class down(nn.Module):
    """Encoder stage: two ``conv_lrelu`` layers, then 2x2 max pooling."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv1 = conv_lrelu(in_ch, out_ch)
        self.conv2 = conv_lrelu(out_ch, out_ch)
        self.down = nn.MaxPool2d(kernel_size=(2, 2))

    def forward(self, x):
        features = self.conv2(self.conv1(x))
        return self.down(features)
    

class up(nn.Module):
    """Decoder stage: 2x bilinear upsampling, skip concatenation, two conv_lrelu layers.

    Args:
        in_ch: channel count AFTER concatenation, i.e. channels of the
            upsampled tensor plus channels of the skip tensor.
        out_ch: channel count produced by the stage.
    """

    def __init__(self, in_ch, out_ch):
        super(up, self).__init__()
        self.up =  nn.UpsamplingBilinear2d(scale_factor = 2)
        self.conv1 = conv_lrelu(in_ch,out_ch)
        self.conv2 = conv_lrelu(out_ch,out_ch)

    def forward(self, x1, x2):
        """Upsample ``x1``, align it with skip tensor ``x2`` and fuse both.

        Args:
            x1: coarser feature map (N, C1, h, w).
            x2: skip feature map (N, C2, H, W).

        Returns:
            Tensor (N, out_ch, H, W).
        """
        x1 = self.up(x1)
        # Compare spatial dims only. x1 and x2 generally have different
        # channel counts, so comparing full shapes was always True and
        # forced a resize on every call. A real mismatch happens e.g. when
        # an encoder stage max-pooled an odd-sized map.
        if x1.shape[2:] != x2.shape[2:]:
            x1 = transforms.functional.resize(x1, list(x2.shape[2:]))
        x = torch.cat([x2, x1], dim=1)
        x = self.conv1(x)
        x = self.conv2(x)
        return x

class UNet(nn.Module):
    """Four-level U-Net built from ``conv_lrelu``, ``down`` and ``up`` stages.

    Args:
        in_ch: channels of the input tensor.
        CH_PER_SCALE: channel widths for the five resolution levels, from
            finest (index 0) to coarsest (index 4).
        out_ch: channels produced by the final 1x1 convolution.

    The forward pass returns the raw ``out_ch``-channel map; no pixel
    shuffle or clamping is applied inside this module.
    """

    def __init__(self, in_ch = 4, CH_PER_SCALE = [32,64,128,256,512], out_ch = 12):
        super().__init__()
        widths = CH_PER_SCALE

        # Stem at full resolution.
        self.inc = conv_lrelu(in_ch, widths[0])
        self.inc2 = conv_lrelu(widths[0], widths[0])

        # Encoder: down1..down4, registered in the same order as before so
        # checkpoint keys and parameter ordering are unchanged.
        for level in range(1, 5):
            setattr(self, 'down{}'.format(level), down(widths[level - 1], widths[level]))

        # Decoder: up1 works at the coarsest level, up4 at the finest. Each
        # stage consumes the upsampled map concatenated with its skip.
        for step, level in enumerate(range(4, 0, -1), start=1):
            setattr(self, 'up{}'.format(step),
                    up(widths[level] + widths[level - 1], widths[level - 1]))

        self.outc = nn.Conv2d(widths[0], out_ch, kernel_size=1, padding=0)

    def forward(self, x):
        # skips[k] holds the encoder feature map at level k.
        skips = [self.inc2(self.inc(x))]
        for encoder in (self.down1, self.down2, self.down3, self.down4):
            skips.append(encoder(skips[-1]))

        # Start from the coarsest map and merge skips from coarse to fine.
        merged = skips.pop()
        for decoder in (self.up1, self.up2, self.up3, self.up4):
            merged = decoder(merged, skips.pop())

        return self.outc(merged)
    
class PRIDNet(nn.Module):
    """Multi-scale network: shared features, five U-Nets on pooled copies, fused output.

    A four-layer ``conv_lrelu`` stack produces a 32-channel feature map. One
    U-Net runs on it at full resolution and four more run on copies
    average-pooled by 2, 4, 8 and 16; their 12-channel outputs are
    upsampled back, concatenated with the feature map (32 + 5*12 channels)
    and reduced to ``out_ch`` channels by a 1x1 convolution. The result is
    pixel-shuffled by a factor of 2 and clamped to [0, 1].

    Args:
        in_ch: channels of the input tensor.
        out_ch: channels before the pixel shuffle (12 yields 3 channels at
            twice the spatial resolution).
    """

    def __init__(self, in_ch = 4, out_ch = 12):
        super().__init__()

        stem = [conv_lrelu(in_ch, 32)]
        stem.extend(conv_lrelu(32, 32) for _ in range(3))
        self.feature_extraction = nn.Sequential(*stem)

        # unet0 .. unet4; names and creation order are kept for checkpoints.
        for branch in range(5):
            setattr(self, 'unet{}'.format(branch), UNet(in_ch = 32, out_ch = 12))

        # avgpool1 .. avgpool4 shrink by 2, 4, 8, 16.
        for level in range(1, 5):
            factor = 2 ** level
            setattr(self, 'avgpool{}'.format(level), nn.AvgPool2d((factor, factor)))

        # up4 .. up1 undo the matching pooling factor.
        for level in range(4, 0, -1):
            setattr(self, 'up{}'.format(level),
                    nn.UpsamplingBilinear2d(scale_factor = 2 ** level))

        self.out = nn.Conv2d(32 + 12 * 5, out_ch, kernel_size=1, padding=0)

    def forward(self, x):
        x_feat = self.feature_extraction(x)
        full_res = self.unet0(x_feat)

        branches = [full_res]
        for level in range(1, 5):
            pool = getattr(self, 'avgpool{}'.format(level))
            net = getattr(self, 'unet{}'.format(level))
            upsample = getattr(self, 'up{}'.format(level))
            scaled = upsample(net(pool(x_feat)))
            # Pooling then upsampling can miss the original size when it is
            # not a multiple of the factor; resize to the full-res branch.
            if scaled.shape != full_res.shape:
                scaled = transforms.functional.resize(scaled, full_res.shape[2:])
            branches.append(scaled)

        fused = self.out(torch.cat([x_feat] + branches, dim=1))

        # Final step from the paper: rearrange 12 channels into 3 RGB channels.
        fused = F.pixel_shuffle(fused, 2)
        # Clamp to [0, 1] since pixel values can only lie in this range.
        return F.hardtanh(fused, min_val=0, max_val=1)

    def load_my_state_dict(self, state_dict):
        """Copy matching entries of ``state_dict`` into this model, ignoring unknown keys."""
        own_state = self.state_dict()
        for name, param in state_dict.items():
            if name in own_state:
                # ``.data`` keeps backwards compatibility with serialized parameters.
                own_state[name].copy_(param.data)

In [5]:
# Alternative baseline (plain U-Net), kept for reference:
# model = UNet(in_ch = 4, CH_PER_SCALE = [32,64,128,256,512], out_ch = 12)
# model = model.to(deviceTag)
# # model.load_state_dict(torch.load('./results_Sony/sony3830.pth'))
# model.load_state_dict(torch.load('./results_UNet/net_weights/sony3996.pth'))
# Alternative PRIDNet checkpoint:
# model.load_state_dict(torch.load('./results_SSIM/net_weights/sony2601.pth'))

model = PRIDNet()
# Use the device selected earlier instead of hard-coding CUDA, so the
# notebook also runs on CPU-only machines.
model = model.to(deviceTag)
# map_location remaps tensors saved on a GPU to the current device.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
model.load_state_dict(torch.load('./results_Sony/sony3830.pth', map_location=deviceTag))

<All keys matched successfully>

In [6]:
# The pickles hold paths relative to ./Dataset/; prefix them so they resolve
# from the notebook's working directory.
# NOTE(review): pickle.load can execute arbitrary code -- only use trusted list files.
with open('./Dataset/Sony_val_raw_list.pickle','rb') as f:
    val_raw_list = ['./Dataset/'+rel_path for rel_path in pickle.load(f)]
with open('./Dataset/Sony_val_gt_list.pickle','rb') as f:
    val_gt_list = ['./Dataset/'+rel_path for rel_path in pickle.load(f)]

# Full-frame evaluation, one image per batch; result images are written out.
Val_PSNR, Val_SSIM, Val_MSSSIM = validate(
    model,
    val_raw_list,
    val_gt_list,
    block_size = None,
    batch_size = 1,
    save_img_dir = './Output_Images_PRIDNet_MSE/',
)

0
10
20
30
40
50
60
70
80
90
100
110
120
130
140
150
160
170
180
190
200
210
220
230


In [7]:
# Report the validation metrics in order: PSNR, SSIM, MS-SSIM.
for metric_value in (Val_PSNR, Val_SSIM, Val_MSSSIM):
    print(metric_value)

28.202051
0.7219911
0.8361999
