# Main File
## This is the Main file to Train the network

In [1]:
import os
import numpy as np
import time
import torch
from torch import optim
import torch.nn as nn
import torch.nn.functional as F
import visdom
import rawpy
import glob
from PIL import Image
import matplotlib.pyplot as plt

## Model - Based on U-Net Architecture

In [2]:
class LeakyReLU(nn.Module):

    def __init__(self):
        super(LeakyReLU, self).__init__()

    def forward(self, x):
        return torch.max(x * 0.2, x)

class UNetConvBlock(nn.Module):

    def __init__(self, in_channel, out_channel):
        super(UNetConvBlock, self).__init__()
        self.UNetConvBlock = torch.nn.Sequential(
            nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=3, padding=1),
            LeakyReLU(),
            nn.Conv2d(in_channels=out_channel, out_channels=out_channel, kernel_size=3, padding=1),
            LeakyReLU()
        )

    def forward(self, x):
        return self.UNetConvBlock(x)

class UNet(nn.Module):

    def __init__(self):
        super(UNet, self).__init__()
        self.conv1 = UNetConvBlock(4, 32)   #We have 4 Channel (R, G, B G)- Bayer Pattern Input
        self.conv2 = UNetConvBlock(32, 64)
        self.conv3 = UNetConvBlock(64, 128)
        self.conv4 = UNetConvBlock(128, 256)
        self.conv5 = UNetConvBlock(256, 512)
        self.up6 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.conv6 = UNetConvBlock(512, 256)
        self.up7 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.conv7 = UNetConvBlock(256, 128)
        self.up8 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.conv8 = UNetConvBlock(128, 64)
        self.up9 = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)
        self.conv9 = UNetConvBlock(64, 32)
        self.conv10 = nn.Conv2d(in_channels=32, out_channels=12, kernel_size=1)

    def forward(self, x):
        conv1 = self.conv1(x)
        pool1 = F.max_pool2d(conv1, kernel_size=2)

        conv2 = self.conv2(pool1)
        pool2 = F.max_pool2d(conv2, kernel_size=2)

        conv3 = self.conv3(pool2)
        pool3 = F.max_pool2d(conv3, kernel_size=2)

        conv4 = self.conv4(pool3)
        pool4 = F.max_pool2d(conv4, kernel_size=2)

        conv5 = self.conv5(pool4)

        up6 = self.up6(conv5)
        up6 = torch.cat([up6, conv4], 1)
        conv6 = self.conv6(up6)

        up7 = self.up7(conv6)
        up7 = torch.cat([up7, conv3], 1)
        conv7 = self.conv7(up7)
        
        up8 = self.up8(conv7)
        up8 = torch.cat([up8, conv2], 1)
        conv8 = self.conv8(up8)

        up9 = self.up9(conv8)
        up9 = torch.cat([up9, conv1], 1)
        conv9 = self.conv9(up9)

        conv10 = self.conv10(conv9)
        out = F.pixel_shuffle(conv10, 2)

        return out

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                m.weight.data.normal_(0.0, 0.02)
                if m.bias is not None:
                    m.bias.data.normal_(0.0, 0.02)
            if isinstance(m, nn.ConvTranspose2d):
                m.weight.data.normal_(0.0, 0.02)


### Location of Dataset

In [3]:
ShortExposure = './Sony/short/'	#Training Data
LongExposure = './Sony/long/'	#Referrance Data
ResultFolder = './Results/'			#Save result and model
listImage = glob.glob(LongExposure + '0*.ARW')
imageList = [int(os.path.basename(singleImage)[0:5]) for singleImage in listImage]

PatchSize = 512  						# 512X512 size is considered for training

### Detarming Black Lebel

In [4]:
imgBlack = rawpy.imread('./Sony/short/00001_00_0.04s.ARW')
BlackCh = imgBlack.black_level_per_channel[0]
BlackMax = np.max(imgBlack.raw_image)
print(BlackCh, BlackMax)

512 16383


### Convert Bayer pattern 4 channels R,G,B,G before passing to U-Net

In [5]:
def rgbg(imgRaw):
    img = imgRaw.raw_image_visible.astype(np.float32)
    img = np.maximum(img - BlackCh, 0) / (BlackMax - BlackCh)
    img = np.expand_dims(img, axis=2)
    S0, S1 = img.shape[0], img.shape[1]

    grbgCh = np.concatenate((img[0:S0:2, 0:S1:2, :], img[0:S0:2, 1:S1:2, :], img[1:S0:2, 1:S1:2, :], img[1:S0:2, 0:S1:2, :]), axis=2)
    return grbgCh

In [6]:
def randomTrue():
    isTrue = np.random.randint(2, size=1)[0] == 1
    return isTrue

### Determine PSNR value from two Images

In [7]:
def psnrValue(inp, avgOut):

    totalPsnr = 0 
    Tcnt, Ch, Hig, Wid = inp.shape

    for i in range(Tcnt):
        avgOut[i] = torch.clamp(avgOut[i], min=0.0, max=1.0)
        mse = torch.sum((inp[i] - avgOut[i])**2)/(Ch*Hig*Wid)
        psnr =  -10*torch.log10(mse)
        totalPsnr += psnr

    AvgPsnr = totalPsnr/Tcnt
    return AvgPsnr

### L1 Loss function. 
Can try L2, but L2 is little bit slower and doesn't provide much improvement here.


In [8]:
def CalcLoss(Im1, Im2):
    lossval = torch.mean(torch.abs(Im1 - Im2))
    return lossval

### Allocating Spaces in Memory

In [9]:
LongExp = [None] * len(imageList)    # Allocating spaces for Long Explsure Images. For Training Sony Dataset, it's 161.
ShortExp = {}
ShortExp['300'] = [None] * len(imageList)
ShortExp['250'] = [None] * len(imageList)
ShortExp['100'] = [None] * len(imageList)

GradientLoss = np.zeros((len(imageList), 1))

allfolders = glob.glob(ResultFolder + '*0')

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
U_Net = UNet()
U_Net.to(device)
U_Net.train()

l_rate = 1e-4
GradientOutput = optim.Adam(U_Net.parameters(), lr=l_rate)


## Training The Model
Training Learning to See in dark Require 64 GB Ram. But, This program is optimized to run in 32 GB. Pleae don't run anything else. Even in my New pc it's comsuming 31GB Ram

In [14]:
Epoch_Cnt = 2

In [15]:
import warnings
warnings.filterwarnings('ignore')

for epoch in range(Epoch_Cnt):
    
    #Calculating total loss
    etime = time.time()
    eloss = 0
    epsnr = 0
	
    for i in np.random.permutation(len(imageList)):
        # get the path from image id
        ImageId = imageList[i]
        SEimages = glob.glob(ShortExposure + '%05d_00*.ARW' % ImageId)
        SEpath = SEimages[np.random.random_integers(0, len(SEimages) - 1)]
        SEname = os.path.basename(SEpath)

        LEimages = glob.glob(LongExposure + '%05d_00*.ARW' % ImageId)
        LEpath = LEimages[0]
        LEname = os.path.basename(LEpath)
        SEexposure = float(SEname[9:-5])
        LEexposure = float(LEname[9:-5])
        inratio = LEexposure / SEexposure
        Exposure = min(inratio, 300)

        if ShortExp[str(Exposure)[0:3]][i] is None:
            imgRaw = rawpy.imread(SEpath)
            ShortExp[str(Exposure)[0:3]][i] = np.expand_dims(rgbg(imgRaw), axis=0) * Exposure

            LERaw = rawpy.imread(LEpath)
            im = LERaw.postprocess(use_camera_wb=True, half_size=False, no_auto_bright=True, output_bps=16)
            LongExp[i] = np.expand_dims(np.float32(im / 65535.0), axis=0)


        Dim1, Dim2 = ShortExp[str(Exposure)[0:3]][i].shape[1], ShortExp[str(Exposure)[0:3]][i].shape[2]
        Ax1, Ax2 = np.random.randint(0, Dim2 - PatchSize), np.random.randint(0, Dim1 - PatchSize)
        SEpatch = ShortExp[str(Exposure)[0:3]][i][:, Ax2:Ax2 + PatchSize, Ax1:Ax1 + PatchSize, :]
        LEpatch = LongExp[i][:, Ax2 * 2:Ax2 * 2 + PatchSize * 2, Ax1 * 2:Ax1 * 2 + PatchSize * 2, :]
        
        if randomTrue():  # random flip
            SEpatch = np.flip(SEpatch, axis=1)
            LEpatch = np.flip(LEpatch, axis=1)
        if randomTrue():
            SEpatch = np.flip(SEpatch, axis=2)
            LEpatch = np.flip(LEpatch, axis=2)
        if randomTrue():  # random transpose
            SEpatch = np.transpose(SEpatch, (0, 2, 1, 3))
            LEpatch = np.transpose(LEpatch, (0, 2, 1, 3))

        SEpatch, LEpatch = np.minimum(SEpatch, 1.0), np.maximum(LEpatch, 0.0)
        ImageIn = torch.from_numpy(SEpatch).permute(0,3,1,2).to(device)
        LEimageOut = torch.from_numpy(LEpatch).permute(0,3,1,2).to(device)

        GradientOutput.zero_grad()
        ImageOut = U_Net(ImageIn)

        final = ImageOut.permute(0, 2, 3, 1).cpu().data.numpy()
        final = np.minimum(np.maximum(final,0),1)

        loss = CalcLoss(ImageOut, LEimageOut)
        eloss = eloss+loss #Total Loss
        PSNR = psnrValue(ImageOut, LEimageOut)
        epsnr = epsnr+PSNR #Total psnr
        
        loss.backward()
        GradientOutput.step()
        GradientLoss[i] = loss.item()

        print("#", end="")
    
    
    # Saving Snapshot of Model with different name for each 100 epoch
    if np.mod(epoch, 100):
        ModelName = ResultFolder + "ModelSnapshot_"+str(epoch)+"_epoch.pth"
        torch.save(U_Net.state_dict(), ModelName)
        
    # Saving Snapshot of Model with different name for each 100 epoch
    if np.mod(epoch, 5):
        torch.save(U_Net.state_dict(), ResultFolder + 'ModelSnapshot.pth')
    
    #Calculate Average Loss & PSNR
    esize = len(imageList)
    aloss = eloss/esize
    apsnr = epsnr/esize
    
    print(f"\nEpoch = {epoch}. \tLoss = {aloss}, \tPSNR = {apsnr}, \tTime = {time.time() - etime}")

#################################################################################################################################################################
Epoch = 0. 	Loss = 0.08368094265460968, 	PSNR = 20.18464469909668, 	Time = 59.651766300201416
#################################################################################################################################################################
Epoch = 1. 	Loss = 0.06980279833078384, 	PSNR = 21.591625213623047, 	Time = 39.538325548172
