In [None]:
import numpy as np
try:
    import rawpy
except ModuleNotFoundError:
    !pip3 install rawpy
    import rawpy

In [None]:
import torch
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim

In [None]:
input_dir = '/content/drive/My Drive/ImageDataset/Sony/Sony/short/'
gt_dir = '/content/drive/My Drive/ImageDataset/Sony/Sony/long/'
checkpoint_dir = '/content/drive/My Drive/CheckPoint3/'
result_dir = '/content/drive/My Drive/Results3/'

In [None]:
import glob
import os

In [None]:
# Create the checkpoint and result folders if they are missing. exist_ok=True avoids the
# check-then-create race of an isdir() test, and still raises if a path exists as a file.
for folder in (checkpoint_dir, result_dir):
    os.makedirs(folder, exist_ok=True)

In [None]:
gt_dir_2 = '/content/drive/My Drive/ImageDataset/Sony/Sony_gt_16bitPNG'

In [None]:
# Ground-truth files matching "0*.ARW" form the training set (presumably the SID
# convention that ids starting with 0 are the train split -- confirm against the dataset).
# glob() order is filesystem-dependent; sorting keeps train_ids reproducible across runs.
train_fns = sorted(glob.glob(gt_dir + "0*.ARW"))
# The first five characters of each file name are the numeric scene id.
train_ids = [int(os.path.basename(train_fn)[0:5]) for train_fn in train_fns]

In [None]:
ps = 512
save_freq = 10

In [None]:
# Scratch check of int truthiness (a non-zero int is truthy, so `a` becomes 2).
# Nothing else visible in this notebook appears to use `a` or `d`.
d = 10
a = 2 if d else 1
a

2

In [None]:
# Debug switch: save result patches every 2 epochs and train on a random subset of scenes.
# Sampling is without replacement (np.random.choice defaults to replace=True), so the
# subset contains no duplicated scene ids; the size is capped at the number of ids available.
DEBUG = 0
if DEBUG:
    save_freq = 2
    train_ids = np.random.choice(train_ids, min(161, len(train_ids)), replace=False)

In [None]:
train_ids[80:88]

[110, 112, 113, 114, 117, 118, 119, 121]

In [None]:
def double_conv(in_channels, out_channels):
    """Encoder stage: two 3x3 'same'-padded convolutions, each followed by LeakyReLU then BatchNorm.

    Height and width are preserved; channels go in_channels -> out_channels.
    The module order is fixed because checkpoints store parameters by Sequential index.
    """
    layers = []
    # First unit maps in_channels -> out_channels, second keeps out_channels.
    for c_in in (in_channels, out_channels):
        layers += [
            nn.Conv2d(c_in, out_channels, kernel_size=3, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm2d(out_channels),
        ]
    return nn.Sequential(*layers)

In [None]:
def expansive_block(in_channels, mid_channels, out_channels):
    """Decoder stage: two 3x3 conv/LeakyReLU/BatchNorm units, then a stride-2 transposed conv.

    Channels go in_channels -> mid_channels -> out_channels, and height/width are doubled
    by the final ConvTranspose2d (kernel 3, stride 2, padding 1, output_padding 1).
    """
    def conv_unit(c_in):
        # One 3x3 'same' conv into mid_channels followed by its activation and norm.
        return [
            nn.Conv2d(kernel_size=3, in_channels=c_in, out_channels=mid_channels, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm2d(mid_channels),
        ]

    # Layers are created in the original order so weight initialisation is unchanged.
    layers = conv_unit(in_channels) + conv_unit(mid_channels)
    layers.append(
        nn.ConvTranspose2d(in_channels=mid_channels, out_channels=out_channels,
                           kernel_size=3, stride=2, padding=1, output_padding=1)
    )
    return nn.Sequential(*layers)

In [None]:
def final_block(in_channels, mid_channels, out_channels, kernel_size=3):
    """Output head: two conv/LeakyReLU/BatchNorm units followed by a 1x1 projection.

    Args:
        in_channels: channels of the incoming feature map.
        mid_channels: width of the two hidden convolutions.
        out_channels: channels produced by the final 1x1 convolution.
        kernel_size: size of the two hidden convolutions. Should be odd so that the
            'same' padding of kernel_size // 2 preserves height and width.

    Returns:
        nn.Sequential with the same 7-layer structure for every kernel_size.
    """
    # Padding used to be hard-coded to 1, which only preserved the spatial size for
    # kernel_size == 3; larger kernels silently shrank the map and broke alignment.
    pad = kernel_size // 2
    return nn.Sequential(
            nn.Conv2d(kernel_size=kernel_size, in_channels=in_channels, out_channels=mid_channels, padding=pad),
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm2d(mid_channels),
            nn.Conv2d(kernel_size=kernel_size, in_channels=mid_channels, out_channels=mid_channels, padding=pad),
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm2d(mid_channels),
            nn.Conv2d(kernel_size=1, in_channels=mid_channels, out_channels=out_channels, padding=0),
            )

In [None]:
class DepthToSpace(nn.Module):
    """Move channel blocks into spatial blocks (TensorFlow-style depth_to_space).

    Input (N, C, H, W) -> output (N, C // bs**2, H * bs, W * bs). Output pixel
    (h*bs + i, w*bs + j) of channel c is taken from input channel
    (i*bs + j) * (C // bs**2) + c. This channel ordering differs from
    torch.nn.functional.pixel_shuffle, so the two are not interchangeable.
    """

    def __init__(self, block_size):
        super().__init__()
        self.bs = block_size

    def forward(self, x):
        n, c, h, w = x.size()
        bs = self.bs
        c_out = c // (bs ** 2)
        # Split channels into (row offset i, column offset j, output channel).
        blocks = x.view(n, bs, bs, c_out, h, w)
        # Interleave offsets with spatial dims: (n, c_out, h, i, w, j).
        blocks = blocks.permute(0, 3, 4, 1, 5, 2).contiguous()
        return blocks.view(n, c_out, h * bs, w * bs)

In [None]:
def pack_raw(raw, black_level=512, white_level=16383):
    """Pack a 2-D Bayer mosaic into a 4-channel, half-resolution float32 array.

    Args:
        raw: object exposing ``raw_image_visible`` as a 2-D array (e.g. a rawpy image).
        black_level: value subtracted before scaling; results below it clamp to 0.
            Default 512 keeps the previous hard-coded behaviour.
        white_level: value that maps to 1.0 after scaling (values above it are not
            clipped here). Default 16383 keeps the previous hard-coded behaviour.

    Returns:
        float32 array of shape (4, H/2, W/2) for an H x W mosaic (H and W assumed even).
        Channels are the (even row, even col), (even row, odd col), (odd row, odd col)
        and (odd row, even col) sites of each 2x2 cell, in that order.
    """
    im = raw.raw_image_visible.astype(np.float32)
    im = np.maximum(im - black_level, 0) / (white_level - black_level)  # subtract black level, normalise
    H, W = im.shape
    return np.stack((im[0:H:2, 0:W:2],
                     im[0:H:2, 1:W:2],
                     im[1:H:2, 1:W:2],
                     im[1:H:2, 0:W:2]), axis=0)

In [None]:
# gt_images = [None] * 6000
# input_images = {}
# input_images['300'] = [None] * len(train_ids)
# input_images['250'] = [None] * len(train_ids)
# input_images['100'] = [None] * len(train_ids)

In [None]:
g_loss = np.zeros((5000, 1))

In [None]:
allfolders = glob.glob(result_dir + '*0')

In [None]:
lr = 1e-6

In [None]:
import time

In [None]:
class UNet(nn.Module):
    """U-Net mapping a packed 4-channel raw patch to an image at twice its resolution.

    Encoder: five double_conv stages (4->32->64->128->256->512) separated by 2x max-pools.
    Decoder: a transposed conv (512->256) and three expansive_blocks. Each consumes the
    concatenation of the upsampled features with the matching encoder output. These are
    followed by final_block and DepthToSpace(2).

    Args:
        n_out_channels: channels produced by final_block. DepthToSpace(2) turns them into
            n_out_channels // 4 channels at twice the input height/width (UNet(12) -> 3).

    Input height and width should be multiples of 16, so the skip connections line up
    after the four max-pools.

    Attribute names and construction order are unchanged, so existing checkpoints'
    state_dict keys still match.
    """

    def __init__(self, n_out_channels):
        super(UNet, self).__init__()
        self.dconv_down1 = double_conv(4, 32)
        self.dconv_down2 = double_conv(32, 64)
        self.dconv_down3 = double_conv(64, 128)
        self.dconv_down4 = double_conv(128, 256)
        self.dconv_down5 = double_conv(256, 512)
        self.maxpool = nn.MaxPool2d(2)
        self.conv2d_t = nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.dconv_up1 = expansive_block(in_channels = 512, mid_channels = 256, out_channels = 128)
        self.dconv_up2 = expansive_block(in_channels = 256, mid_channels = 128, out_channels = 64)
        self.dconv_up3 = expansive_block(in_channels = 128, mid_channels = 64, out_channels = 32)
        self.final_layer = final_block(in_channels = 64, mid_channels = 32, out_channels = n_out_channels)
        # Parameter-free, so registering it adds no state_dict keys. It is built once
        # here instead of being re-instantiated on every forward() call.
        self.depth_to_space = DepthToSpace(2)

    def forward(self, x):
        # Encoder: keep each stage's output for the skip connections.
        conv1 = self.dconv_down1(x)
        conv2 = self.dconv_down2(self.maxpool(conv1))
        conv3 = self.dconv_down3(self.maxpool(conv2))
        conv4 = self.dconv_down4(self.maxpool(conv3))
        conv5 = self.dconv_down5(self.maxpool(conv4))

        # Decoder: upsample, concatenate with same-resolution encoder features, refine.
        x = self.conv2d_t(conv5)                            # 256 channels
        x = self.dconv_up1(torch.cat([x, conv4], dim=1))    # 256+256 in -> 128 out
        x = self.dconv_up2(torch.cat([x, conv3], dim=1))    # 128+128 in -> 64 out
        x = self.dconv_up3(torch.cat([x, conv2], dim=1))    # 64+64 in -> 32 out
        x = self.final_layer(torch.cat([x, conv1], dim=1))  # 32+32 in -> n_out_channels
        return self.depth_to_space(x)

Find the most recent checkpoint in `checkpoint_dir` (if any) and load it so training can resume from where it stopped.

In [None]:
# Checkpoints are written as '<epoch:04d>-model.ckpt'. List only those, so stray files
# (notebook autosaves, result folders, ...) can never be mistaken for a checkpoint.
# The zero-padded epoch makes a reverse lexicographic sort newest-first (for epochs < 10000).
dirs = sorted(glob.glob(checkpoint_dir + '*-model.ckpt'), reverse=True)

In [None]:
dirs

['/content/drive/My Drive/CheckPoint3/0010-model.ckpt']

In [None]:
# Epoch number of the newest checkpoint, or None when there is none. It is parsed from
# the '<epoch>-model.ckpt' file name instead of fixed character offsets, which broke
# as soon as the epoch needed more than four digits.
if not dirs:
    train_epoch = None
else:
    train_epoch = int(os.path.basename(dirs[0]).split('-')[0])

In [None]:
train_epoch

10

In [None]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
net = UNet(12)
net = net.to(device)
optimizer = optim.Adam(net.parameters(), lr = 1e-4)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.8)
# Defined even on a fresh start, so later cells that display `loss` do not raise NameError.
loss = None
if train_epoch is not None:
    # map_location lets a checkpoint saved on a GPU be resumed on a CPU-only runtime.
    checkpoint = torch.load(checkpoint_dir+'/%04d-model.ckpt'%train_epoch, map_location=device)
    net.load_state_dict(checkpoint['model_state_dict'])
    # NOTE: this also restores the optimizer's saved learning rate.
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    # Older checkpoints did not store the scheduler; restore it only when present.
    if 'scheduler_state_dict' in checkpoint:
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    train_epoch = checkpoint['epoch']
    loss = checkpoint['loss']

In [None]:
loss

tensor(0.1735, device='cuda:0', requires_grad=True)

In [None]:
# Move every tensor held in the (possibly checkpoint-loaded) optimizer state onto `device`.
for param_state in optimizer.state.values():
    for key in list(param_state):
        value = param_state[key]
        if isinstance(value, torch.Tensor):
            param_state[key] = value.to(device)

In [None]:
# Fresh start (no checkpoint): the epoch loop below then begins at train_epoch + 1 == 0.
if train_epoch is None:
    train_epoch = -1

In [None]:
def abs_L1_loss(output, target):
    """Mean absolute error between `output` and `target`, as a differentiable scalar tensor.

    Operands broadcast as in ordinary tensor subtraction.
    """
    return (output - target).abs().mean()

In [None]:
len(train_ids)

161

In [None]:
import cv2
from PIL import Image

In [None]:
def _load_training_pair(train_id):
    """Read one randomly chosen short-exposure input and the long-exposure GT of a scene.

    Returns (ip_raw, gt_image, ratio), or None when the scene has no input or no GT file.
      ip_raw   -- pack_raw output with a leading batch axis, multiplied by `ratio`
      gt_image -- 16-bit postprocessed GT scaled by 1/65535, leading batch axis, channels last
      ratio    -- GT/input exposure ratio, capped at 300
    """
    in_files = glob.glob(input_dir + "%05d_00*.ARW" % train_id)
    gt_files = glob.glob(gt_dir + "%05d_00*.ARW" % train_id)
    if not in_files or not gt_files:
        return None
    # randint's upper bound is exclusive: randint(len) can pick every file, whereas
    # the previous randint(0, len - 1) never chose the last one.
    in_path = in_files[np.random.randint(len(in_files))]
    gt_path = gt_files[0]

    # Exposure time is read from file-name characters [9:-5]
    # (presumably '<id>_00_<seconds>s.ARW' -- confirm against the dataset).
    in_exposure = float(os.path.basename(in_path)[9:-5])
    gt_exposure = float(os.path.basename(gt_path)[9:-5])
    ratio = min(gt_exposure / in_exposure, 300)  # 300 can be varied

    # Context managers release the LibRaw buffers as soon as the data is extracted,
    # instead of keeping them alive through the forward/backward pass.
    with rawpy.imread(in_path) as raw:
        ip_raw = np.expand_dims(pack_raw(raw), axis=0) * ratio
    with rawpy.imread(gt_path) as gt_raw:
        im = gt_raw.postprocess(use_camera_wb=True, half_size=False, no_auto_bright=True, output_bps=16)
    gt_image = np.expand_dims(np.float32(im / 65535.0), axis=0)
    return ip_raw, gt_image, ratio


def _random_patch(ip_raw, gt_image, ps):
    """Crop an aligned random patch: ps x ps from ip_raw and 2ps x 2ps from gt_image.

    Returns (input_patch, gt_patch), both in NCHW layout.
    """
    H = ip_raw.shape[2]
    W = ip_raw.shape[3]
    # +1 because randint excludes its upper bound, and offset W - ps is still a valid crop.
    xx = np.random.randint(0, W - ps + 1)
    yy = np.random.randint(0, H - ps + 1)
    input_patch = ip_raw[:, :, yy:yy + ps, xx:xx + ps]
    # pack_raw halves each spatial dimension, so GT coordinates are doubled
    # (assumes postprocess with half_size=False returns exactly twice the packed size).
    gt_patch = gt_image[:, 2 * yy:2 * yy + 2 * ps, 2 * xx:2 * xx + 2 * ps, :]
    return input_patch, np.transpose(gt_patch, (0, 3, 1, 2))


def _augment(input_patch, gt_patch):
    """Apply the same random flips / height-width swap to an NCHW input patch and its GT."""
    if np.random.randint(2, size=1)[0] == 1:  # flip along height
        input_patch = np.flip(input_patch, axis=2)
        gt_patch = np.flip(gt_patch, axis=2)
    if np.random.randint(2, size=1)[0] == 1:  # flip along width
        input_patch = np.flip(input_patch, axis=3)
        gt_patch = np.flip(gt_patch, axis=3)
    if np.random.randint(2, size=1)[0] == 1:  # swap height and width
        input_patch = np.transpose(input_patch, (0, 1, 3, 2))
        gt_patch = np.transpose(gt_patch, (0, 1, 3, 2))
    return input_patch, gt_patch


for epoch in range(train_epoch+1, 50+train_epoch+1):
    # Skip epochs that already have a checkpoint (same file name torch.save uses below).
    if os.path.isfile(checkpoint_dir + '/%04d-model.ckpt' % epoch):
        continue
    cnt = 0
    for ind in np.random.permutation(len(train_ids)):
        train_id = train_ids[ind]
        st = time.time()
        pair = _load_training_pair(train_id)
        if pair is None:
            print("skipping %05d: missing input or ground-truth file" % train_id)
            continue
        ip_raw, gt_image, ratio = pair
        cnt += 1

        input_patch, gt_patch = _random_patch(ip_raw, gt_image, ps)
        input_patch, gt_patch = _augment(input_patch, gt_patch)
        # Amplified input can exceed 1; clip it. np.minimum also yields a fresh
        # contiguous array, which torch.tensor needs after np.flip.
        input_patch = np.minimum(input_patch, 1.0)
        input_patch = torch.tensor(input_patch, device=device).float()
        gt_patch = torch.from_numpy(np.ascontiguousarray(gt_patch)).float().to(device)

        optimizer.zero_grad()
        out = net(input_patch)
        loss = abs_L1_loss(out, gt_patch)
        loss.backward()
        optimizer.step()
        out = np.clip(out.detach().cpu().numpy(), 0, 1)
        g_loss[ind] = loss.item()

        print("%d %d Loss=%.5f Time=%.3f" % (epoch, cnt, np.mean(g_loss[np.where(g_loss)]), time.time() - st))
        if epoch % save_freq == 0:
            os.makedirs(result_dir + '/%04d' % epoch, exist_ok=True)
            gt_np = gt_patch.cpu().numpy()
            # GT and prediction side by side along the width axis of the CHW arrays.
            temp = np.concatenate((gt_np[0, :, :, :], out[0, :, :, :]), axis=2)
            temp = np.transpose(temp, (1, 2, 0))
            Image.fromarray((temp * 255).astype(np.uint8)).save(
                result_dir + '/%04d/%05d_00_train_%d.jpg' % (epoch, train_id, ratio))

    # StepLR counts scheduler.step() calls. Stepping once per epoch, after the optimizer
    # updates, decays the LR by gamma every step_size epochs. It used to be stepped before
    # optimizer.step() on every iteration, shrinking the LR 0.8x every 10 patches.
    scheduler.step()
    torch.save({
        'epoch': epoch,
        'model_state_dict': net.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'loss': loss
        }, checkpoint_dir+'/%04d-model.ckpt' %epoch)
    

11




11 1 Loss=0.27304 Time=1.787
11 2 Loss=0.16755 Time=2.789
11 3 Loss=0.14418 Time=1.681
11 4 Loss=0.13681 Time=3.016
11 5 Loss=0.12381 Time=3.122
11 6 Loss=0.13835 Time=1.639
11 7 Loss=0.14443 Time=3.064
11 8 Loss=0.13840 Time=2.914
11 9 Loss=0.14754 Time=2.725
11 10 Loss=0.14364 Time=3.047
11 11 Loss=0.13723 Time=1.621
11 12 Loss=0.12994 Time=1.677
11 13 Loss=0.12302 Time=2.822
11 14 Loss=0.12197 Time=2.814
11 15 Loss=0.11913 Time=3.196
11 16 Loss=0.11483 Time=2.673
11 17 Loss=0.11277 Time=2.729
11 18 Loss=0.11324 Time=1.774
11 19 Loss=0.11192 Time=2.734
11 20 Loss=0.12160 Time=1.701
11 21 Loss=0.11906 Time=2.268
11 22 Loss=0.11639 Time=3.134
11 23 Loss=0.11571 Time=2.231
11 24 Loss=0.11895 Time=2.796
11 25 Loss=0.12114 Time=1.681
11 26 Loss=0.12068 Time=2.845
11 27 Loss=0.12002 Time=1.713
11 28 Loss=0.11906 Time=2.868
11 29 Loss=0.11930 Time=3.065
11 30 Loss=0.12100 Time=2.915
11 31 Loss=0.12460 Time=2.088
11 32 Loss=0.12305 Time=2.241
11 33 Loss=0.12319 Time=2.906
11 34 Loss=0.12222 