## Test The Model
This notebook computes the PSNR of the trained model on the test set and saves the predicted output images.

In [1]:
import os
import numpy as np
import time
import torch
from torch import optim
import torch.nn as nn
import torch.nn.functional as F
import visdom
import rawpy
import glob
from PIL import Image
import matplotlib.pyplot as plt

### Model

In [2]:
class LeakyReLU(nn.Module):
    """Leaky ReLU with a fixed negative slope of 0.2.

    Returns ``x`` where ``x >= 0`` and ``0.2 * x`` elsewhere. This is computed as the
    element-wise maximum of ``x`` and ``0.2 * x``, because ``0.2 * x > x`` exactly when
    ``x < 0``. The module has no parameters, so it contributes nothing to the state_dict.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        scaled = x * 0.2
        return torch.max(scaled, x)

class UNetConvBlock(nn.Module):
    """Two 3x3 convolutions, each followed by LeakyReLU.

    With padding 1, both convolutions preserve height and width.

    Args:
        in_channel: number of channels of the input tensor.
        out_channel: number of channels produced by both convolutions.

    The Sequential is stored under the attribute name ``UNetConvBlock``. That name is part
    of the state_dict keys of saved checkpoints, so renaming it would break loading them.
    """

    def __init__(self, in_channel, out_channel):
        super().__init__()
        layers = [
            nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=3, padding=1),
            LeakyReLU(),
            nn.Conv2d(in_channels=out_channel, out_channels=out_channel, kernel_size=3, padding=1),
            LeakyReLU(),
        ]
        self.UNetConvBlock = nn.Sequential(*layers)

    def forward(self, x):
        stack = self.UNetConvBlock
        return stack(x)

class UNet(nn.Module):
    """U-Net that maps a packed 4-channel Bayer image to a 3-channel image at 2x resolution.

    Architecture:
    - Encoder: four ``UNetConvBlock`` stages (32 -> 256 channels), each followed by a 2x max-pool.
    - Bottleneck: 512 channels.
    - Decoder: four transposed-conv upsampling stages, each concatenating the matching encoder
      features (skip connections).
    - Output: a 1x1 conv gives 12 channels, which ``F.pixel_shuffle(..., 2)`` rearranges into
      3 channels at twice the height and width.

    Input height and width must be divisible by 16; otherwise the skip concatenations do not
    line up. Submodule attribute names (conv1..conv10, up6..up9) are the state_dict keys of
    saved checkpoints and must not be renamed.
    """

    def __init__(self):
        super().__init__()
        # Encoder. The input has 4 channels: the packed Bayer mosaic (R, G, B, G).
        self.conv1 = UNetConvBlock(4, 32)
        self.conv2 = UNetConvBlock(32, 64)
        self.conv3 = UNetConvBlock(64, 128)
        self.conv4 = UNetConvBlock(128, 256)
        # Bottleneck at 1/16 of the input resolution.
        self.conv5 = UNetConvBlock(256, 512)
        # Decoder. Each upN doubles height/width and halves channels. convN then takes that
        # output concatenated with the same-resolution encoder features, hence twice the channels.
        self.up6 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.conv6 = UNetConvBlock(512, 256)
        self.up7 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.conv7 = UNetConvBlock(256, 128)
        self.up8 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.conv8 = UNetConvBlock(128, 64)
        self.up9 = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)
        self.conv9 = UNetConvBlock(64, 32)
        # 12 = 3 output channels x (2 x 2) sub-pixels, consumed by pixel_shuffle in forward().
        self.conv10 = nn.Conv2d(in_channels=32, out_channels=12, kernel_size=1)

    def forward(self, x):
        """Run the U-Net on ``x`` of shape (N, 4, H, W); returns (N, 3, 2H, 2W)."""
        encoders = (self.conv1, self.conv2, self.conv3, self.conv4)
        decoders = (
            (self.up6, self.conv6),
            (self.up7, self.conv7),
            (self.up8, self.conv8),
            (self.up9, self.conv9),
        )

        # Contracting path: keep each stage's output for the matching decoder stage.
        skips = []
        feat = x
        for encode in encoders:
            feat = encode(feat)
            skips.append(feat)
            feat = F.max_pool2d(feat, kernel_size=2)

        feat = self.conv5(feat)

        # Expanding path: deepest skip first.
        for (upsample, fuse), skip in zip(decoders, reversed(skips)):
            feat = fuse(torch.cat([upsample(feat), skip], 1))

        return F.pixel_shuffle(self.conv10(feat), 2)

    def _initialize_weights(self):
        """Redraw the weights (and conv biases) from a normal distribution, in place.

        Conv2d and ConvTranspose2d weights, and Conv2d biases, are drawn from a normal
        distribution with mean 0 and std 0.02. ConvTranspose2d biases are left untouched.
        ``__init__`` does not call this method; it must be invoked explicitly.
        """
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                module.weight.data.normal_(0.0, 0.02)
                if module.bias is not None:
                    module.bias.data.normal_(0.0, 0.02)
            elif isinstance(module, nn.ConvTranspose2d):
                module.weight.data.normal_(0.0, 0.02)

### Dataset & Paths

In [3]:
ShortExposure = './Sony/short/'
LongExposure = './Sony/long/'
ResultFolder = './Results/'

# Scenes whose long-exposure (ground-truth) file names start with '1' are used as the test set here.
# os.path.join avoids a doubled '/' in the pattern. sorted() makes the evaluation order
# deterministic, since glob returns files in arbitrary order.
listImage = sorted(glob.glob(os.path.join(LongExposure, '1*.ARW')))
# The first 5 characters of each file name are the numeric scene ID.
imageList = [int(os.path.basename(singleImage)[0:5]) for singleImage in listImage]

### Determining the Black and White Levels

In [4]:
# Read the sensor's black level and saturation (white) level from one raw frame. rgbg()
# uses them to map raw values to [0, 1].
# white_level is the saturation value reported by LibRaw. It is used instead of the
# brightest pixel of this particular dark frame, which is data-dependent.
# The context manager releases the LibRaw handle once the values have been read.
with rawpy.imread('./Sony/short/00001_00_0.04s.ARW') as imgBlack:
    # NOTE(review): only channel 0's black level is used; this assumes all four Bayer
    # channels share the same black level — confirm for other cameras.
    BlackCh = imgBlack.black_level_per_channel[0]
    BlackMax = imgBlack.white_level
print(BlackCh, BlackMax)

512 16383


### Pack the Bayer pattern into 4 channels (R, G, B, G) before passing it to the U-Net

In [5]:
def rgbg(imgRaw, black_level=None, white_level=None):
    """Normalise a raw Bayer frame to [0, 1] and pack it into 4 half-resolution channels.

    Args:
        imgRaw: object with a 2-D ``raw_image_visible`` array (e.g. a ``rawpy`` image).
        black_level: value subtracted before scaling. Defaults to the module-level ``BlackCh``.
        white_level: saturation value mapped to 1.0. Defaults to the module-level ``BlackMax``.

    Returns:
        float32 array of shape (H // 2, W // 2, 4). The channels are taken from the pixel
        positions (0, 0), (0, 1), (1, 1), (1, 0) of each 2x2 cell. For an RGGB mosaic these
        are R, G, B, G (assumed for the Sony sensor used here — TODO confirm for other cameras).
        Values below the black level are clamped to 0.
    """
    if black_level is None:
        black_level = BlackCh
    if white_level is None:
        white_level = BlackMax

    img = imgRaw.raw_image_visible.astype(np.float32)
    img = np.maximum(img - black_level, 0) / (white_level - black_level)

    # Crop to even dimensions; otherwise the four sub-sampled planes differ in size
    # and np.concatenate fails.
    S0 = img.shape[0] - img.shape[0] % 2
    S1 = img.shape[1] - img.shape[1] % 2
    img = img[:S0, :S1, np.newaxis]

    rgbgCh = np.concatenate(
        (
            img[0:S0:2, 0:S1:2, :],
            img[0:S0:2, 1:S1:2, :],
            img[1:S0:2, 1:S1:2, :],
            img[1:S0:2, 0:S1:2, :],
        ),
        axis=2,
    )
    return rgbgCh

## Calculate PSNR

In [6]:
def testPsnr(A, B):
    Ch, Hight, Width = A.shape
    #sum_psnr = 0 
    output = np.clip(B, 0.0, 1.0)
    mse = np.sum((A - B)**2)/(Ch*Hight*Width)
    psnr =  -10*np.log10(mse)
    return psnr
    #print(psnr)

## Save Image
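The image-saving code below was taken from Stack Overflow. It is a copy of the `bytescale` and `toimage` helpers that were removed from `scipy.misc`. It handles images with different numbers of channels and clips values to the output range, so I use it here.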

In [7]:
import numpy as np
from PIL import Image


# Shared message for the ValueError that toimage() raises when the requested mode is
# unknown or does not fit the array's shape.
_errstr = "Mode is unknown or incompatible with input array shape."


def bytescale(data, cmin=None, cmax=None, high=255, low=0):
    """Linearly rescale an array to the ``[low, high]`` byte range and return it as uint8.

    Values are mapped so that ``cmin`` goes to ``low`` and ``cmax`` goes to ``high``. The
    result is clipped to ``[low, high]`` and rounded to the nearest integer. Arrays that are
    already uint8 are returned unchanged.

    Parameters
    ----------
    data : ndarray
        Image data.
    cmin : scalar, optional
        Input value mapped to ``low``. Defaults to ``data.min()``.
    cmax : scalar, optional
        Input value mapped to ``high``. Defaults to ``data.max()``.
    high : scalar, optional
        Upper end of the output range (at most 255). Default 255.
    low : scalar, optional
        Lower end of the output range (at least 0). Default 0.

    Returns
    -------
    ndarray of uint8

    Raises
    ------
    ValueError
        If ``high > 255``, ``low < 0``, ``high < low`` or ``cmax < cmin``.

    Examples
    --------
    >>> img = np.array([[ 91.06794177,   3.39058326,  84.4221549 ],
    ...                 [ 73.88003259,  80.91433048,   4.88878881],
    ...                 [ 51.53875334,  34.45808177,  27.5873488 ]])
    >>> bytescale(img)
    array([[255,   0, 236],
           [205, 225,   4],
           [140,  90,  70]], dtype=uint8)
    >>> bytescale(img, cmin=0, cmax=255)
    array([[91,  3, 84],
           [74, 81,  5],
           [52, 34, 28]], dtype=uint8)
    """
    if data.dtype == np.uint8:
        return data

    # Validate the requested output range.
    if high > 255:
        raise ValueError("`high` should be less than or equal to 255.")
    if low < 0:
        raise ValueError("`low` should be greater than or equal to 0.")
    if high < low:
        raise ValueError("`high` should be greater than or equal to `low`.")

    lo_in = data.min() if cmin is None else cmin
    hi_in = data.max() if cmax is None else cmax

    span = hi_in - lo_in
    if span < 0:
        raise ValueError("`cmax` should be larger than `cmin`.")
    if span == 0:
        # Constant input: avoid dividing by zero.
        span = 1

    factor = float(high - low) / span
    rescaled = (data - lo_in) * factor + low
    # +0.5 before truncation rounds to the nearest integer.
    return (rescaled.clip(low, high) + 0.5).astype(np.uint8)


def toimage(arr, high=255, low=0, cmin=None, cmax=None, pal=None, mode=None, channel_axis=None):
    """Convert a NumPy array into a PIL image.

    Adapted from the removed ``scipy.misc.toimage``. Arrays are converted to raw bytes with
    ``ndarray.tobytes()``; ``ndarray.tostring()`` no longer exists in NumPy 2.0.

    2-D arrays:
        mode 'F'           -> 32-bit float image; values are not rescaled.
        mode None/'L'/'P'  -> 8-bit image rescaled by ``bytescale(cmin, cmax, low, high)``.
                              If ``pal`` (an (N, 3) byte array of RGB values) is given, it is
                              attached as the palette (mode 'P'). Mode 'P' without ``pal``
                              gets a grey palette.
        mode '1'           -> pixels greater than ``high`` are set.
        mode 'I'           -> values linearly rescaled to [low, high] and cast to uint32.

    3-D arrays:
        One axis must have length 3 or 4, or it can be given via ``channel_axis``. The data is
        rescaled with ``bytescale``; ``cmin``/``cmax`` default to the array's own min/max,
        i.e. the contrast is stretched. The result is 'RGB' (3 channels) or 'RGBA'
        (4 channels) by default; 'YCbCr' or 'CMYK' may be requested via ``mode``.

    Raises:
        ValueError: for complex data, unsupported shapes, or a mode incompatible with the data.
    """
    data = np.asarray(arr)
    if np.iscomplexobj(data):
        raise ValueError("Cannot convert a complex-valued array.")
    shape = list(data.shape)
    valid = len(shape) == 2 or ((len(shape) == 3) and
                                ((3 in shape) or (4 in shape)))
    if not valid:
        raise ValueError("'arr' does not have a suitable array shape for "
                         "any mode.")
    if len(shape) == 2:
        shape = (shape[1], shape[0])  # PIL expects (width, height)
        if mode == 'F':
            data32 = data.astype(np.float32)
            image = Image.frombytes(mode, shape, data32.tobytes())
            return image
        if mode in [None, 'L', 'P']:
            bytedata = bytescale(data, high=high, low=low,
                                 cmin=cmin, cmax=cmax)
            image = Image.frombytes('L', shape, bytedata.tobytes())
            if pal is not None:
                image.putpalette(np.asarray(pal, dtype=np.uint8).tobytes())
                # Attaching a palette turns the image into mode 'P'.
            elif mode == 'P':  # default gray-scale palette
                pal = (np.arange(0, 256, 1, dtype=np.uint8)[:, np.newaxis] *
                       np.ones((3,), dtype=np.uint8)[np.newaxis, :])
                image.putpalette(np.asarray(pal, dtype=np.uint8).tobytes())
            return image
        if mode == '1':  # `high` acts as the threshold
            bytedata = (data > high)
            # NOTE(review): a bool array serialises to one byte per pixel, while PIL's '1'
            # raw mode expects bit-packed rows — verify the output before relying on mode='1'.
            image = Image.frombytes('1', shape, bytedata.tobytes())
            return image
        if cmin is None:
            cmin = np.amin(np.ravel(data))
        if cmax is None:
            cmax = np.amax(np.ravel(data))
        data = (data*1.0 - cmin)*(high - low)/(cmax - cmin) + low
        if mode == 'I':
            data32 = data.astype(np.uint32)
            image = Image.frombytes(mode, shape, data32.tobytes())
        else:
            raise ValueError(_errstr)
        return image

    # From here on: a 3-D array with a 3 or a 4 in its shape.
    # Find the channel axis, preferring an axis of length 3 ('RGB' / 'YCbCr').
    if channel_axis is None:
        if (3 in shape):
            ca = np.flatnonzero(np.asarray(shape) == 3)[0]
        else:
            ca = np.flatnonzero(np.asarray(shape) == 4)
            if len(ca):
                ca = ca[0]
            else:
                raise ValueError("Could not find channel dimension.")
    else:
        ca = channel_axis

    numch = shape[ca]
    if numch not in [3, 4]:
        raise ValueError("Channel axis dimension is not valid.")

    bytedata = bytescale(data, high=high, low=low, cmin=cmin, cmax=cmax)
    # Move the channel axis last (interleaved pixels) and compute PIL's (width, height).
    if ca == 2:
        strdata = bytedata.tobytes()
        shape = (shape[1], shape[0])
    elif ca == 1:
        strdata = np.transpose(bytedata, (0, 2, 1)).tobytes()
        shape = (shape[2], shape[0])
    elif ca == 0:
        strdata = np.transpose(bytedata, (1, 2, 0)).tobytes()
        shape = (shape[2], shape[1])
    if mode is None:
        if numch == 3:
            mode = 'RGB'
        else:
            mode = 'RGBA'

    if mode not in ['RGB', 'RGBA', 'YCbCr', 'CMYK']:
        raise ValueError(_errstr)

    if mode in ['RGB', 'YCbCr']:
        if numch != 3:
            raise ValueError("Invalid array shape for mode.")
    if mode in ['RGBA', 'CMYK']:
        if numch != 4:
            raise ValueError("Invalid array shape for mode.")

    # Here we know the data and mode are consistent.
    image = Image.frombytes(mode, shape, strdata)
    return image

## Testing Part
### Select any model snapshot for testing (a snapshot was saved every 100 epochs)

In [8]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
U_Net = UNet()
# map_location remaps the stored tensors onto `device`, so a checkpoint saved on a GPU
# also loads on a CPU-only machine.
# NOTE: torch.load unpickles the file — only load checkpoints from a trusted source.
U_Net.load_state_dict(torch.load(os.path.join(ResultFolder, 'ModelSnapshot.pth'), map_location=device))
U_Net.to(device)  # move the model to the GPU if one is available

UNet(
  (conv1): UNetConvBlock(
    (UNetConvBlock): Sequential(
      (0): Conv2d(4, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): LeakyReLU()
      (2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (3): LeakyReLU()
    )
  )
  (conv2): UNetConvBlock(
    (UNetConvBlock): Sequential(
      (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): LeakyReLU()
      (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (3): LeakyReLU()
    )
  )
  (conv3): UNetConvBlock(
    (UNetConvBlock): Sequential(
      (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): LeakyReLU()
      (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (3): LeakyReLU()
    )
  )
  (conv4): UNetConvBlock(
    (UNetConvBlock): Sequential(
      (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): LeakyReLU()
      (2): Conv2d(256, 2

### Now Test the Model

In [None]:
def _postprocess_to_unit_range(raw):
    """Demosaic ``raw`` with rawpy and scale the 16-bit RGB result to [0, 1].

    Uses camera white balance, no auto-brightening and 16-bit output, then divides by 65535.
    """
    rgb16 = raw.postprocess(use_camera_wb=True, half_size=False, no_auto_bright=True, output_bps=16)
    return np.float32(rgb16 / 65535.0)


# Saving the PNGs fails if the output folder does not exist, so create it up front.
OutputFolder = os.path.join(ResultFolder, 'Predicted_Output')
os.makedirs(OutputFolder, exist_ok=True)
allPSNR = []  # PSNR of every processed short-exposure image, for the test-set mean

with torch.no_grad():
    U_Net.eval()
    for aImage in imageList:
        SEimages = sorted(glob.glob(ShortExposure + '%05d_00*.ARW' % aImage))
        LEimages = glob.glob(LongExposure + '%05d_00*.ARW' % aImage)
        if not SEimages or not LEimages:
            print(f"Skipping scene {aImage:05d}: no short- or long-exposure file found")
            continue

        # The ground truth is the same for every short exposure of a scene, so load it once.
        # In names such as '10003_00_0.04s.ARW', characters [9:-5] hold the exposure in seconds.
        LEpath = LEimages[0]
        LEexposure = float(os.path.basename(LEpath)[9:-5])
        with rawpy.imread(LEpath) as LERaw:
            LENormalizedIm = _postprocess_to_unit_range(LERaw)

        scenePSNR = []
        Names = ""
        for SEpath in SEimages:
            SEname = os.path.basename(SEpath)
            Names += SEname + "\t"
            print("#", end="")
            SEexposure = float(SEname[9:-5])
            # Amplification ratio applied to the dark input, capped at 300.
            Exposure = min(LEexposure / SEexposure, 300)

            with rawpy.imread(SEpath) as imgRaw:
                NormalizedIm = _postprocess_to_unit_range(imgRaw)
                ExpImage = np.minimum(np.expand_dims(rgbg(imgRaw), axis=0) * Exposure, 1.0)

            ImageIn = torch.from_numpy(ExpImage).permute(0, 3, 1, 2).to(device)
            # Model output (1, 3, H, W) -> (H, W, 3), clipped to the valid [0, 1] range.
            final = np.clip(U_Net(ImageIn).permute(0, 2, 3, 1).cpu().numpy()[0], 0, 1)

            PSNR = testPsnr(final, LENormalizedIm)
            scenePSNR.append(PSNR)

            # Scale the low-light image to the ground truth's mean, for visual comparison only.
            ScaledIm = NormalizedIm * np.mean(LENormalizedIm) / np.mean(NormalizedIm)

            # cmin=0, cmax=1 maps [0, 1] onto [0, 255]. Without them, toimage() stretches each
            # image to its own min..max, so the PNGs would not show the true intensities.
            prefix = os.path.join(OutputFolder, '%05d_00_%d' % (aImage, Exposure))
            toimage(final, cmin=0, cmax=1).save(prefix + '_out.png')
            toimage(ScaledIm, cmin=0, cmax=1).save(prefix + '_scale.png')
            toimage(LENormalizedIm, cmin=0, cmax=1).save(prefix + '_gt.png')

        allPSNR.extend(scenePSNR)
        print(f"\tImage Names:- {Names}\nPSNR of these Predicted Images = {sum(scenePSNR) / len(scenePSNR)}")

if allPSNR:
    print(f"Mean PSNR over {len(allPSNR)} test images = {sum(allPSNR) / len(allPSNR)}")

##	Image Names:- 10003_00_0.04s.ARW	10003_00_0.1s.ARW	
PSNR of these Predicted Images = 25.092809362742344
##	Image Names:- 10006_00_0.04s.ARW	10006_00_0.1s.ARW	
PSNR of these Predicted Images = 25.661068776958608
##	Image Names:- 10011_00_0.04s.ARW	10011_00_0.1s.ARW	
PSNR of these Predicted Images = 26.718432771155452
#