In [1]:
! free -m

              total        used        free      shared  buff/cache   available
Mem:         128830       10871       98477        2665       19480      112514
Swap:         97162        2477       94685


In [2]:
from __future__ import division
import os, time
import numpy as np
import glob
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
import sys
import os
from optparse import OptionParser
import numpy as np
from torch import optim
from PIL import Image
from torch.autograd import Function, Variable
import matplotlib.pyplot as plt
from torchvision import transforms
from torch.utils.data import Dataset
import cv2
import glob
import pickle
from tqdm import tqdm
import rawpy
from pytorch_msssim import ssim, ms_ssim, SSIM, MS_SSIM
%matplotlib inline

input_dir = './dataset/Sony/short/'
gt_dir = './dataset/Sony/long/'
checkpoint_dir = './result_Sony/'
result_dir = './result_Sony/'

In [3]:
if torch.cuda.is_available():
    deviceTag = torch.device('cuda')
else:
    deviceTag = torch.device('cpu')
print(deviceTag)

cuda


# LOAD DATASET

In [4]:
# get train IDs
train_fns = glob.glob(gt_dir + '0*.ARW')
train_ids = [int(os.path.basename(train_fn)[0:5]) for train_fn in train_fns]

ps = 512  # patch size for training
save_freq = 500

DEBUG = 0
if DEBUG == 1:
    save_freq = 2
    train_ids = train_ids[0:5]


In [5]:
print(len(train_fns))

161


# UNET MODULES

In [6]:
import torch.nn.functional as F

class conv_lrelu(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(conv_lrelu, self).__init__()
        self.conv = nn.Sequential(nn.Conv2d(in_ch,out_ch,3, padding = 1),nn.LeakyReLU())

    def forward(self, x):
        x = self.conv(x)
        return x
    
class down(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(down, self).__init__()
        self.conv1 = conv_lrelu(in_ch,out_ch)
        self.conv2 = conv_lrelu(out_ch,out_ch)
        self.down =  nn.MaxPool2d((2,2))
    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.down(x)
        return x
    

class up(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(up, self).__init__()       
        self.up =  nn.UpsamplingBilinear2d(scale_factor = 2)
        self.conv1 = conv_lrelu(in_ch,out_ch) 
        self.conv2 = conv_lrelu(out_ch,out_ch) 

    def forward(self, x1, x2):
        x1 = self.up(x1)
        if x1.shape != x2.shape:
            x1 = transforms.functional.resize(x1, x2.shape[2:])
        x = torch.cat([x2, x1], dim=1)
        x = self.conv1(x)
        x = self.conv2(x)
        return x

class luminanceGain(nn.Module):
    def __init__(self, in_ch = 4, CH_PER_SCALE = [32,64,64,64,64]):
        super(luminanceGain, self).__init__()
        self.inc = conv_lrelu(in_ch, CH_PER_SCALE[0])
        self.inc2 = conv_lrelu(CH_PER_SCALE[0], CH_PER_SCALE[0])
        self.down1 = down(CH_PER_SCALE[0], CH_PER_SCALE[1]) 
        self.down2 = down(CH_PER_SCALE[1],CH_PER_SCALE[2])
        self.down3 = down(CH_PER_SCALE[2],CH_PER_SCALE[3])                
        self.down4 = down(CH_PER_SCALE[3],CH_PER_SCALE[4]) 
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Linear(64,1)
    def forward(self, x):
        x = self.inc(x)
        x = self.inc2(x)
        x = self.down1(x)
        x = self.down2(x)
        x = self.down3(x)
        x = self.down4(x)
        x = self.pool(x) # This adaptive pool all the 4x4x512 channels to 1x1x512.
        #THis greatly speeds up the training speed of our network
        x = x.view(-1, 64)
        x = F.relu(self.fc1(x))
        out = F.sigmoid(x)*500+30
        return out

class UNet(nn.Module):
    def __init__(self, in_ch = 4, CH_PER_SCALE = [32,64,128,256,512], out_ch = 12):
        super(UNet, self).__init__()
        self.inc = conv_lrelu(in_ch, CH_PER_SCALE[0])
        self.inc2 = conv_lrelu(CH_PER_SCALE[0], CH_PER_SCALE[0])
        self.down1 = down(CH_PER_SCALE[0], CH_PER_SCALE[1]) 
        self.down2 = down(CH_PER_SCALE[1],CH_PER_SCALE[2])
        self.down3 = down(CH_PER_SCALE[2],CH_PER_SCALE[3])                
        self.down4 = down(CH_PER_SCALE[3],CH_PER_SCALE[4])                
        self.up1 = up(CH_PER_SCALE[4]+CH_PER_SCALE[3],CH_PER_SCALE[3])
        self.up2 = up(CH_PER_SCALE[3]+CH_PER_SCALE[2],CH_PER_SCALE[2])
        self.up3 = up(CH_PER_SCALE[2]+CH_PER_SCALE[1],CH_PER_SCALE[1])
        self.up4 = up(CH_PER_SCALE[1]+CH_PER_SCALE[0],CH_PER_SCALE[0])
        self.outc = nn.Conv2d(CH_PER_SCALE[0], out_ch, 1, padding = 0)

    def forward(self, x):
        x0 = self.inc(x)
        x0 = self.inc2(x0)
        x1 = self.down1(x0)
        x2 = self.down2(x1)
        x3 = self.down3(x2)
        x4 = self.down4(x3)
        x3_up = self.up1(x4,x3)
        x2_up = self.up2(x3_up,x2)
        x1_up = self.up3(x2_up,x1)
        out = self.up4(x1_up,x0)
        out = self.outc(out)
#         out = F.pixel_shuffle(out,2) ## Paper final step rearranges 12 channes to 3 RGB channels
#         out = F.hardtanh(out, min_val=0, max_val=1) #Clamp the top and bottom to 0,1 since pixels can only be in this value
        return out
    
class PRIDNet(nn.Module):
    def __init__(self, in_ch = 4, out_ch = 12):
        super(PRIDNet, self).__init__()
        self.gain = luminanceGain()
        self.feature_extraction = nn.Sequential(conv_lrelu(in_ch, 32), *[conv_lrelu(32, 32) for i in range(3)])
        self.unet0 = UNet(in_ch = 32, out_ch = 12)
        self.unet1 = UNet(in_ch = 32, out_ch = 12)
        self.unet2 = UNet(in_ch = 32, out_ch = 12)
        self.unet3 = UNet(in_ch = 32, out_ch = 12)
        self.unet4 = UNet(in_ch = 32, out_ch = 12)
        self.avgpool1 = nn.AvgPool2d((2,2))
        self.avgpool2 = nn.AvgPool2d((4,4))
        self.avgpool3 = nn.AvgPool2d((8,8))
        self.avgpool4 = nn.AvgPool2d((16,16))
        self.up4 =  nn.UpsamplingBilinear2d(scale_factor = 16)
        self.up3 =  nn.UpsamplingBilinear2d(scale_factor = 8)
        self.up2 =  nn.UpsamplingBilinear2d(scale_factor = 4)
        self.up1 =  nn.UpsamplingBilinear2d(scale_factor = 2)
        self.out =  nn.Conv2d(32+12*5, out_ch, 1, padding = 0)

    def forward(self, x):
        amp_factor = self.gain(x)
       # print(x.shape)
        #print(amp_factor.shape)
        x = amp_factor[0,0]*x
        #x[1] = amp_factor[1,0]*x[1];
        #x[2] = amp_factor[2,0]*x[2];
        #x[3] = amp_factor[3,0]*x[3];
        #x[4] = amp_factor[4,0]*x[4];
        #x[5] = amp_factor[5,0]*x[5];
        #x[6] = amp_factor[6,0]*x[6];
        #x[7] = amp_factor[7,0]*x[7];

        x_feat = self.feature_extraction(x)
        x0 = self.unet0(x_feat)
        x1 = self.up1(self.unet1(self.avgpool1(x_feat)))
        x2 = self.up2(self.unet2(self.avgpool2(x_feat)))
        x3 = self.up3(self.unet3(self.avgpool3(x_feat)))
        x4 = self.up4(self.unet4(self.avgpool4(x_feat)))
        x_unet_all = torch.cat([x_feat,x0,x1,x2,x3,x4], axis = 1)
        out = self.out(x_unet_all)
        
        out = F.pixel_shuffle(out,2) ## Paper final step rearranges 12 channes to 3 RGB channels
        out = F.hardtanh(out, min_val=0, max_val=1) #Clamp the top and bottom to 0,1 since pixels can only be in this value
        return out
    
    def load_my_state_dict(self, state_dict):
        own_state = self.state_dict()
        for name, param in state_dict.items():
            if name not in own_state:
                 continue
            #if isinstance(param, self.Parameter):
            else:
                # backwards compatibility for serialized parameters
                param = param.data
            own_state[name].copy_(param)

# Helper Functions for packing raw and saving images

In [7]:
def pack_raw(raw):
    # pack Bayer image to 4 channels
    im = raw.raw_image_visible.astype(np.float32)
    im = np.maximum(im - 512, 0) / (16383 - 512)  # subtract the black level

    im = np.expand_dims(im, axis=2)
    img_shape = im.shape
    H = img_shape[0]
    W = img_shape[1]

    out = np.concatenate((im[0:H:2, 0:W:2, :],
                          im[0:H:2, 1:W:2, :],
                          im[1:H:2, 1:W:2, :],
                          im[1:H:2, 0:W:2, :]), axis=2)
    return out

In [8]:
# Raw data takes long time to load. Keep them in memory after loaded.
gt_images = [None] * 6000
input_images = {}
input_images['300'] = [None] * len(train_ids)
input_images['250'] = [None] * len(train_ids)
input_images['100'] = [None] * len(train_ids)

g_loss = np.zeros((5000, 1))

allfolders = glob.glob(result_dir + '*0')

for folder in allfolders:
    lastepoch = np.maximum(epochs, int(folder[-4:]))

# Training 

In [9]:
def pack_raw(raw):
    # pack Bayer image to 4 channels
    im = raw.raw_image_visible.astype(np.float32)
    im = np.maximum(im - 512, 0) / (16383 - 512)  # subtract the black level

    im = np.expand_dims(im, axis=2)
    img_shape = im.shape
    H = img_shape[0]
    W = img_shape[1]

    out = np.concatenate((im[0:H:2, 0:W:2, :],
                          im[0:H:2, 1:W:2, :],
                          im[1:H:2, 1:W:2, :],
                          im[1:H:2, 0:W:2, :]), axis=2)
    return out

def process_img(input_raw_img, model):
    ## Process image(s) using the given model
    # input_raw_img: numpy array, dimension: (Batch,Height,Width,Channel)
    # ratio: numpy array, dimension: (Batch,)
    model.eval();
    model.to(deviceTag)
    #ratio = ratio.reshape(ratio.shape[0],1,1,1)
    input_raw_img = np.transpose(input_raw_img, [0,3,1,2]).astype('float32')#*ratio
    input_tensor = torch.from_numpy(input_raw_img.copy()).float().to(deviceTag)
    with torch.no_grad():
        output_tensor = model(input_tensor)
    output_img = output_tensor.cpu().numpy()
    output_img = np.transpose(output_img, [0,2,3,1])
    
    return output_img

def ssim_numpy(img1, img2):
    img1, img2 = torch.tensor(np.transpose(img1,(2,0,1))).unsqueeze(0), torch.tensor(np.transpose(img2,(2,0,1))).unsqueeze(0)
    ssim_calc = ssim(img1,img2, data_range=1, size_average=True)
    msssim_calc = ms_ssim(img1,img2, data_range=1, size_average=True)
    return ssim_calc.numpy(), msssim_calc.numpy()
    
def validate(model, input_list, gt_list, block_size = None, batch_size = 8, save_img_dir = None):
    assert len(input_list) == len(gt_list)
    
    model.eval();
    PSNR_list = []
    SSIM_list = []
    MSSSIM_list = []
    
    for i in range(len(input_list)//batch_size):
        if i%10 == 0:
            print(i)
        input_raw_img_batch = []
        gt_img_batch = []
        ratio_batch = []
        for b in range(batch_size):
            if i*batch_size+b < len(input_list):
                in_path = input_list[i*batch_size+b]
                gt_path = gt_list[i*batch_size+b]
            else:
                break
            in_fn = os.path.basename(in_path)
            gt_fn = os.path.basename(gt_path)
            in_exposure = float(in_fn[9:-5])
            gt_exposure = float(gt_fn[9:-5])
            #ratio = min(gt_exposure / in_exposure, 300)
        
            raw = rawpy.imread(in_path)
            input_raw_img = pack_raw(raw)
            
            gt_raw = rawpy.imread(gt_path)
            gt_img = gt_raw.postprocess(use_camera_wb=True, half_size=False, no_auto_bright=True, output_bps=16)
            gt_img = np.float32(gt_img / 65535.0)
            
            if block_size is not None:
                i_cut, j_cut = np.random.randint(0,input_raw_img.shape[0]-block_size), np.random.randint(0,input_raw_img.shape[1]-block_size)
                gt_img = gt_img[i_cut*2:i_cut*2+block_size*2, j_cut*2:j_cut*2+block_size*2, :]
                input_raw_img = input_raw_img[i_cut:i_cut+block_size, j_cut:j_cut+block_size, :]
            
            #ratio_batch.append(ratio)
            input_raw_img_batch.append(input_raw_img)
            gt_img_batch.append(gt_img)
        
        input_raw_img_batch = np.array(input_raw_img_batch)
        #ratio_batch = np.array(ratio_batch)
        gt_img_batch = np.array(gt_img_batch)
        
        output_img_batch = process_img(input_raw_img_batch, model)
        if save_img_dir is not None and i < 20:
            plt.imsave(save_img_dir+'{}_gt.png'.format(i),gt_img_batch[0,:,:,:])
            plt.imsave(save_img_dir+'{}_out.png'.format(i),output_img_batch[0,:,:,:])
        MSE = np.mean((output_img_batch.reshape(output_img_batch.shape[0],-1) - gt_img_batch.reshape(gt_img_batch.shape[0],-1))**2, axis = 1)
        PSNR_batch = 10*np.log10(1/MSE)
        PSNR_list.append(list(PSNR_batch))
        ssim_calc, msssim_calc = ssim_numpy(gt_img_batch[0,:,:,:], output_img_batch[0,:,:,:])
        SSIM_list.append(ssim_calc)
        MSSSIM_list.append(msssim_calc)
    
    Val_PSNR = np.mean(PSNR_list)
    Val_SSIM = np.mean(SSIM_list)
    Val_MSSSIM = np.mean(MSSSIM_list)
    return Val_PSNR, Val_SSIM, Val_MSSSIM

In [10]:
#PATH_A = './results_Sony/Imp_Models/Pyramid_network.pth'
PATH_B = './results_Sony/sony{}.pth'.format(999)
model_Avg = torch.load(PATH_B,deviceTag)
#n_models = 10
#
#for i in range(n_models-1):
#    PATH_B = './results_Sony/sony{}.pth'.format(3829-i)
#    model_load = torch.load(PATH_B,deviceTag)
#    for key in model_Avg:
#        model_Avg[key] = model_Avg[key] + model_load[key]
#for key in model_Avg:
#    model_Avg[key] = model_Avg[key]/n_models
# model.train();

In [11]:
model = PRIDNet().to(deviceTag)
model.load_my_state_dict(model_Avg) ## Model Amplification 
#Print the number of parameters good for characterizing the size
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
params = sum([np.prod(p.size()) for p in model_parameters])
print(params);

39594393


In [15]:
Start_epoch = 0
epochs = 1
learning_rate = 1e-5
model_save_path = './results_Sony/'  # directory to same the model after each epoch. 
batch_num = 1;
optimizer = optim.Adam(model.parameters(), lr=learning_rate, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)
# note that although we want to use DICE for evaluation, we use BCELoss for training in this example
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=500, gamma=0.1) #Step Scheduler. 
TrainingLossData = np.zeros(epochs)
# The loss function we use is binary cross entropy: nn.BCELoss()
#criterion = nn.L1Loss()
criterion = nn.MSELoss()
# note that although we want to use DICE for evaluation, we use BCELoss for training in this example
trainF= open("./epoch_loss/TrainingEpoch.txt","w+")



for epoch in range(Start_epoch, Start_epoch+epochs):
    print('Starting epoch {}/{}.'.format(epoch + 1, epochs))
    epoch_loss = 0 ## Set Epoch Loss
    count = 0;
    batches_processed = 0
    ##This version has a batch size of 1. In the future conside increasing batchsize
    for ind in np.random.permutation(len(train_ids)):    
        # get the path from image id
        train_id = train_ids[ind]
        in_files = glob.glob(input_dir + '%05d_00*.ARW' % train_id)
        in_path = in_files[np.random.random_integers(0, len(in_files) - 1)]
        in_fn = os.path.basename(in_path)

        gt_files = glob.glob(gt_dir + '%05d_00*.ARW' % train_id)
        gt_path = gt_files[0]
        gt_fn = os.path.basename(gt_path)
        in_exposure = float(in_fn[9:-5])
        gt_exposure = float(gt_fn[9:-5])
        ratio = min(gt_exposure / in_exposure, 300)

        st = time.time()

        if input_images[str(ratio)[0:3]][ind] is None:
            raw = rawpy.imread(in_path)
            input_images[str(ratio)[0:3]][ind] = np.expand_dims(pack_raw(raw), axis=0)# * ratio
#
            gt_raw = rawpy.imread(gt_path)
            im = gt_raw.postprocess(use_camera_wb=True, half_size=False, no_auto_bright=True, output_bps=16)
            gt_images[ind] = np.expand_dims(np.float32(im / 65535.0), axis=0)

        # crop
        H = input_images[str(ratio)[0:3]][ind].shape[1]
        W = input_images[str(ratio)[0:3]][ind].shape[2]

        xx = np.random.randint(0, W - ps)
        yy = np.random.randint(0, H - ps)
        input_patch = input_images[str(ratio)[0:3]][ind][:, yy:yy + ps, xx:xx + ps, :]
        gt_patch = gt_images[ind][:, yy * 2:yy * 2 + ps * 2, xx * 2:xx * 2 + ps * 2, :]

        if np.random.randint(2, size=1)[0] == 1:  # random flip
            input_patch = np.flip(input_patch, axis=1)
            gt_patch = np.flip(gt_patch, axis=1)
        if np.random.randint(2, size=1)[0] == 1:
            input_patch = np.flip(input_patch, axis=2)
            gt_patch = np.flip(gt_patch, axis=2)
        if np.random.randint(2, size=1)[0] == 1:  # random transpose
            input_patch = np.transpose(input_patch, (0, 2, 1, 3))
            gt_patch = np.transpose(gt_patch, (0, 2, 1, 3))
        #(1, 512, 512, 4)
        #(1, 1024, 1024, 3)
        input_patch = np.transpose(input_patch, (0,3,1,2))
        input_patch = torch.from_numpy(input_patch.copy()).cuda()
        gt_patch = np.transpose(gt_patch, (0,3,1,2))
        gt_patch = torch.from_numpy(gt_patch.copy()).cuda()
        ##Batch concatenation
        if count%(batch_num)==0:
            input_patch_all = input_patch
            gt_patch_all = gt_patch
        else:
            input_patch_all = torch.cat([input_patch_all, input_patch], dim=0)
            gt_patch_all = torch.cat([gt_patch_all, gt_patch], dim=0)
        ##Every N batches we ship it back 
        if count%(batch_num)==batch_num-1:
            #print(input_patch_all.shape)
            img_pred = model.forward(input_patch_all)
            loss = criterion(img_pred, gt_patch_all)
            epoch_loss += loss.item()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            batches_processed += 1
            ##print(epoch_loss/count);
        count = count +1
    scheduler.step()
    print('Epoch finished ! Loss: {}'.format(epoch_loss / batches_processed))
    trainF.write('Epoch finished ! Loss: {}\r\n'.format(epoch_loss / batches_processed))
    TrainingLossData[epoch] = epoch_loss / batches_processed ## Save for plotting
    ################################################ [TODO] ###################################################
    # Perform validation with eval_net() on the validation data
    # Save the model after every 10 epochs. This save our Memory on HPC.
    ##Save Top results after 95%
    if epoch > epochs*0.95:
        if os.path.isdir(model_save_path):
            torch.save(model.state_dict(),model_save_path + 'sony{}.pth'.format(epoch + 1))
        else:
            os.makedirs(model_save_path, exist_ok=True)
            torch.save(model.state_dict(),model_save_path + 'sony{}.pth'.format(epoch + 1))
        print('Checkpoint {} saved !'.format(epoch + 1))
    if epoch%99 == 0:
        if os.path.isdir(model_save_path):
            torch.save(model.state_dict(),model_save_path + 'sony{}.pth'.format(epoch + 1))
        else:
            os.makedirs(model_save_path, exist_ok=True)
            torch.save(model.state_dict(),model_save_path + 'sony{}.pth'.format(epoch + 1))
        print('Checkpoint {} saved !'.format(epoch + 1))

trainF.close() ## Close your files to write to them

Starting epoch 1/1.


  in_path = in_files[np.random.random_integers(0, len(in_files) - 1)]
  in_path = in_files[np.random.random_integers(0, len(in_files) - 1)]
  in_path = in_files[np.random.random_integers(0, len(in_files) - 1)]


Epoch finished ! Loss: 0.0023224143762761865
Checkpoint 1 saved !


In [None]:
if os.path.isdir(model_save_path):
            torch.save(model.state_dict(),model_save_path + 'sony{}.pth'.format(epoch + 1))
else:
    os.makedirs(model_save_path, exist_ok=True)
    torch.save(model.state_dict(),model_save_path + 'sony{}.pth'.format(epoch + 1))

# Code to display 10 images.

In [12]:
with open('./dataset/Sony_val_raw_list.pickle','rb') as f:
    val_raw_list = pickle.load(f)
val_raw_list = ['./dataset/'+path for path in val_raw_list]
with open('./dataset/Sony_val_gt_list.pickle','rb') as f:
    val_gt_list = pickle.load(f)
val_gt_list = ['./dataset/'+path for path in val_gt_list]

## Leave save_img_dir as None, if you don't want to save the result images
Val_PSNR, Val_SSIM, Val_MSSSIM = validate(model, val_raw_list, val_gt_list,block_size = None, batch_size = 1, 
                                          save_img_dir = './images_AutoLG/')

0




10
20
30
40
50
60
70
80
90
100
110
120
130
140
150
160
170
180
190
200
210
220
230


In [13]:
print(Val_PSNR)
print(Val_SSIM)
print(Val_MSSSIM)

23.030909
0.6923857
0.8146631
