In [2]:
!pip install rawpy

Collecting rawpy
  Downloading rawpy-0.24.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.2 kB)
Downloading rawpy-0.24.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.9 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.9/1.9 MB[0m [31m21.6 MB/s[0m eta [36m0:00:00[0m00:01[0m0:01[0m
[?25hInstalling collected packages: rawpy
Successfully installed rawpy-0.24.0


In [5]:
import os
import time
import scipy.io
import tensorflow as tf
from tensorflow.keras import layers, models, regularizers, initializers
from tensorflow.keras.preprocessing.image import load_img, img_to_array
import numpy as np
import glob
import rawpy
from PIL import Image

In [6]:
# Switch every visible GPU to on-demand memory growth, or report that
# the notebook is running on the CPU.
gpus = tf.config.list_physical_devices('GPU')

if not gpus:
    print('No GPU detected! Running on CPU')
else:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as err:
        # report the problem and carry on rather than aborting the notebook
        print(err)
    else:
        print(f'TensorFlow is using GPU: {gpus}')

TensorFlow is using GPU: [PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]


In [17]:
# All directory strings keep their trailing slash: file paths are built
# further down by plain string concatenation (e.g. gt_dir + '0*.png').
dataset_root = '/kaggle/input/aml-sld/'
working_root = '/kaggle/working/'

# short-exposure RAW (.ARW) input images
input_dir = dataset_root + 'Sony/Sony/short/'

# ground-truth images, already converted from RAW to PNG
gt_dir = dataset_root + 'Sony_gt_16bitPNG/gt/'

# preview images written while training
result_dir = working_root + 'results/'

# saved model files
saving_dir = working_root + 'saved_model/'

In [13]:
# Integer IDs of the training images: every ground-truth PNG whose file
# name starts with '0'.

train_fns = glob.glob(gt_dir + '0*.png')

# os.path.basename(path) drops the directory part and keeps only
# 'filename.extension'; the first five characters of that name are the
# zero-padded image ID.
train_ids = [int(os.path.basename(train_fn)[0:5]) for train_fn in train_fns]

In [None]:
patch_size = 512  # side length of the square crop taken from the packed RAW input; the ground-truth crop is twice this size

# save_freq = 500   # model saving frequency (disabled here; the training loop defines its own save_freq)

# One decoder step of the U-Net: upsample, then attach the skip connection.

def upsample_and_concatenate(x1, x2, output_channels):
    """Upsample `x1` by a factor of two and concatenate it with `x2`.

    x1: tensor from the previous decoder stage (lower resolution).
    x2: skip-connection tensor from the encoder (higher resolution).
    output_channels: number of filters of the transposed convolution
        applied to `x1`.

    Returns the channel-wise (last axis) concatenation of the upsampled
    `x1` and `x2`.
    """
    transposed_conv = layers.Conv2DTranspose(
        filters=output_channels,
        kernel_size=(2, 2),
        strides=(2, 2),
        padding='same',
        kernel_initializer='he_normal',
    )
    upsampled = transposed_conv(x1)

    return layers.Concatenate(axis=-1)([upsampled, x2])

In [None]:
# building the U-Net

def _double_conv(x, filters):
    """Apply two 3x3 'same' convolutions, each followed by a PReLU."""
    for _ in range(2):
        x = layers.Conv2D(filters, (3, 3), padding='same', kernel_initializer='he_normal')(x)
        # the PReLU slope is shared over both spatial axes (one slope per channel)
        x = layers.PReLU(shared_axes=[1, 2])(x)
    return x


def U_net(input_tensor):
    """Build the U-Net as a Keras functional model.

    input_tensor: a `tf.keras.Input` placeholder.

    The encoder has five double-convolution stages (32 to 512 filters)
    separated by 2x2 max-pooling; the decoder mirrors it with four
    upsample-and-concatenate steps. A final 1x1 convolution produces 12
    channels, which depth_to_space (block size 2) rearranges into
    3 channels at twice the input resolution.

    Returns the `tf.keras.Model` mapping `input_tensor` to that output.
    """
    # Encoder layers
    conv1 = _double_conv(input_tensor, 32)
    pool1 = layers.MaxPooling2D((2, 2), padding='same')(conv1)

    conv2 = _double_conv(pool1, 64)
    pool2 = layers.MaxPooling2D((2, 2), padding='same')(conv2)

    conv3 = _double_conv(pool2, 128)
    pool3 = layers.MaxPooling2D((2, 2), padding='same')(conv3)

    conv4 = _double_conv(pool3, 256)
    pool4 = layers.MaxPooling2D((2, 2), padding='same')(conv4)

    conv5 = _double_conv(pool4, 512)

    # Decoder layers
    up6 = upsample_and_concatenate(conv5, conv4, 256)
    conv6 = _double_conv(up6, 256)

    up7 = upsample_and_concatenate(conv6, conv3, 128)
    conv7 = _double_conv(up7, 128)

    up8 = upsample_and_concatenate(conv7, conv2, 64)
    conv8 = _double_conv(up8, 64)

    up9 = upsample_and_concatenate(conv8, conv1, 32)
    conv9 = _double_conv(up9, 32)

    conv10 = layers.Conv2D(12, (1, 1), padding='same')(conv9)

    # The op is wrapped in a Lambda layer instead of being called directly:
    # Keras 3 refuses raw tf.* functions on symbolic Keras tensors, while
    # the wrapped form builds under Keras 2 as well.
    # NOTE(review): reloading a saved model that contains a Lambda layer
    # may require safe_mode=False - confirm for the Keras version in use.
    final_output = layers.Lambda(
        lambda t: tf.nn.depth_to_space(t, block_size=2)
    )(conv10)

    return tf.keras.Model(inputs=input_tensor, outputs=final_output)

In [None]:
# image preprocessing

def pack_raw(raw, black_level=512, white_level=16383):
    """Pack a Bayer mosaic into a half-resolution, 4-channel float image.

    raw: object exposing the 2-D sensor mosaic as `raw.raw_image_visible`
        (for example the result of `rawpy.imread`).
    black_level: sensor value subtracted from every pixel; results below
        zero are clamped to zero.
    white_level: sensor value that maps to 1.0 after normalisation.

    Returns a float32 array of shape (H / 2, W / 2, 4). The channels hold
    the four positions of each 2x2 cell in the order
    (row 0, col 0), (row 0, col 1), (row 1, col 1), (row 1, col 0).
    Raises ValueError if H or W is odd, because the four sub-images then
    differ in shape.
    """
    im = raw.raw_image_visible.astype(np.float32)
    # subtract the black level and scale so that white_level maps to 1.0
    im = np.maximum(im - black_level, 0) / (white_level - black_level)

    # one sub-image per position inside the 2x2 mosaic cell
    out = np.stack((im[0::2, 0::2],
                    im[0::2, 1::2],
                    im[1::2, 1::2],
                    im[1::2, 0::2]), axis=-1)

    return out

In [None]:

# Model, optimiser and loss used by train_step and the training loop.

learning_rate = 1e-4

# Height and width are left as None so the network accepts any spatial
# size; the last axis holds the 4 input channels.
input_tensor = tf.keras.Input(shape=(None, None, 4))

model = U_net(input_tensor)

loss_fn = tf.keras.losses.MeanSquaredError()
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)

In [None]:
# A single optimisation step (not a whole epoch) on one batch.

def train_step(input_tensor_batch, gt_tensor_batch):
    """Run one forward/backward pass and update the model weights.

    input_tensor_batch: batch fed to `model`.
    gt_tensor_batch: target batch that `loss_fn` compares the output with.

    Returns `(loss, output)`: the loss computed for this batch and the
    model output that produced it (computed before the weight update).
    """
    with tf.GradientTape() as tape:
        prediction = model(input_tensor_batch, training=True)
        batch_loss = loss_fn(gt_tensor_batch, prediction)

    weights = model.trainable_variables
    weight_grads = tape.gradient(batch_loss, weights)
    optimizer.apply_gradients(zip(weight_grads, weights))

    return batch_loss, prediction

num_epochs = 50

# Placeholder: an iterable of (input, gt) batches. It is empty here, so
# the sample loop below performs no training steps.
train_dataset = []

# sample training loop

for epoch in range(num_epochs):
    for input_batch, gt_batch in train_dataset:  # train_dataset yields batches of (input, gt)
        # train_step returns (loss, output); only the loss is reported here
        loss, _ = train_step(input_batch, gt_batch)
        print(f'Epoch {epoch}, Loss: {loss.numpy():.4f}')
        
# Buffers filled lazily during training.

# one slot per ground-truth image (not referenced elsewhere in the visible notebook)
gt_images = [None] * 6000

# One cache per amplification ratio, each with one slot per training ID.
# The training loop stores the packed, ratio-scaled RAW image in the slot
# the first time that (ratio, image) pair is drawn.
_300_list = [None] * len(train_ids)
_250_list = [None] * len(train_ids)
_100_list = [None] * len(train_ids)

train_images = {
    '300': _300_list,
    '250': _250_list,
    '100': _100_list,
}

# loss history buffer (not referenced elsewhere in the visible notebook)
g_loss = np.zeros((5000, 1))


# training loop!

save_freq = 3          # preview images are written during every save_freq-th epoch
save_model_freq = 25   # the model is saved at the end of every save_model_freq-th epoch

# NOTE(review): with num_epochs = 50 this threshold is never reached, so
# the learning rate is never lowered - confirm the intended value.
lr_decay_epoch = 200

# Keras does not necessarily create missing parent directories on save
os.makedirs(saving_dir, exist_ok=True)

for epoch in range(1, num_epochs + 1):

    if epoch > lr_decay_epoch:
        # Learning-rate scheduling: the new value has to be pushed into the
        # optimizer; rebinding the Python variable alone changes nothing.
        learning_rate = 1e-5
        optimizer.learning_rate = learning_rate

    # losses of the current epoch, so the printed value is a running epoch mean
    losses = []

    # visit the training images in a fresh random order every epoch
    for iteration_count, index in enumerate(np.random.permutation(len(train_ids)), start=1):
        train_id = train_ids[index]

        # pick one of the short-exposure shots of this scene at random
        in_files = glob.glob(input_dir + '%05d_00*.ARW' % train_id)
        in_path = in_files[np.random.randint(0, len(in_files))]
        in_fn = os.path.basename(in_path)

        # the matching ground-truth image (the first match is used)
        gt_files = glob.glob(gt_dir + '%05d_00*.png' % train_id)
        gt_path = gt_files[0]
        gt_fn = os.path.basename(gt_path)

        # characters 9..-5 of the file name hold the exposure value
        in_exposure = float(in_fn[9:-5])    # exposure of input image
        gt_exposure = float(gt_fn[9:-5])    # exposure of ground truth image

        # amplification ratio, capped at 300
        ratio = min(gt_exposure / in_exposure, 300)
        ratio_key = str(ratio)[0:3]         # '300', '250' or '100'

        # pack and amplify the RAW image once, then reuse it from the cache
        if train_images[ratio_key][index] is None:
            # the context manager closes the RAW file once it has been packed
            with rawpy.imread(in_path) as raw:
                train_images[ratio_key][index] = np.expand_dims(pack_raw(raw), axis = 0) * ratio

        packed = train_images[ratio_key][index]

        # cropping the training image

        H = packed.shape[1]   # image height
        W = packed.shape[2]   # image width

        # randint excludes its upper bound, hence the + 1 so that the
        # last valid crop position can be drawn as well
        xx = np.random.randint(0, W - patch_size + 1)
        yy = np.random.randint(0, H - patch_size + 1)

        # train image patch

        train_image_patch = packed[:, yy:yy + patch_size, xx:xx + patch_size, :]

        # ground truth image patch: add a batch axis and scale to [0, 1]
        # so that it matches the network output and the `* 255` used when
        # previews are saved.
        # NOTE(review): load_img is assumed to return 8-bit RGB (0-255)
        # here - confirm for the 16-bit PNG files.
        gt_image_array = np.expand_dims(img_to_array(load_img(gt_path)) / 255.0, axis = 0)

        # the ground truth has twice the resolution of the packed input
        gt_patch = gt_image_array[:, yy * 2:yy * 2 + patch_size * 2, xx * 2:xx * 2 + patch_size * 2, :]

        if np.random.rand() >= 0.5:
            # random flipping along 1st axis

            train_image_patch = np.flip(train_image_patch, axis = 1)
            gt_patch = np.flip(gt_patch, axis = 1)

        if np.random.rand() >= 0.5:
            # random flipping along 2nd axis

            train_image_patch = np.flip(train_image_patch, axis = 2)
            gt_patch = np.flip(gt_patch, axis = 2)

        if np.random.rand() >= 0.5:
            # random matrix transpose operation

            train_image_patch = np.transpose(train_image_patch, (0, 2, 1, 3))
            gt_patch = np.transpose(gt_patch, (0, 2, 1, 3))

        # amplification can push values above 1; clamp them
        train_image_patch = np.minimum(train_image_patch, 1.0)

        start_time = time.time()

        loss_value, output = train_step(train_image_patch, gt_patch)

        output = tf.clip_by_value(output, 0.0, 1.0)  # equivalent to np.minimum/maximum

        losses.append(loss_value.numpy())

        print(f'-> Epoch {epoch} iteration {iteration_count} Loss = {np.mean(losses):.3f} Time taken = {time.time() - start_time:.3f}')
        print('')

        # saving a preview image for every iteration of every 'save_freq'-th epoch

        if epoch % save_freq == 0:
            epoch_dir = os.path.join(result_dir, f"{epoch:04d}")
            os.makedirs(epoch_dir, exist_ok = True)

            # gt_patch is already a NumPy array; only the model output is a tensor
            temp = np.concatenate((gt_patch[0], output[0].numpy()), axis = 1)  # concatenate along width
            temp = np.clip(temp * 255, 0, 255).astype(np.uint8)

            img = Image.fromarray(temp)
            img.save(os.path.join(epoch_dir, f"{train_id:05d}_00_train_{ratio}.jpg"))

    # saving the model once at the end of every 'save_model_freq'-th epoch
    # (a '.keras' extension is used because Keras 3 rejects other suffixes)

    if epoch % save_model_freq == 0:
        model.save(os.path.join(saving_dir, 'model_illumination.keras'))
            
        
                        