# Run if using Colab

In [None]:
# Colab-only: mount Google Drive so the project folder under
# /content/drive/MyDrive is reachable from the cells below.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
cd '/content/drive/MyDrive/fun deep learning/image-enhancement'

/content/drive/MyDrive/fun deep learning/image-enhancement


In [None]:
! pip install -r requirements.txt

Collecting opencv_python==4.5.5.64
  Downloading opencv_python-4.5.5.64-cp36-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (60.5 MB)
[K     |████████████████████████████████| 60.5 MB 1.4 MB/s 
[?25hCollecting pillow_heif
  Downloading pillow_heif-0.1.11-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (10.6 MB)
[K     |████████████████████████████████| 10.6 MB 34.8 MB/s 
Collecting sk_video
  Downloading sk_video-1.1.10-py2.py3-none-any.whl (2.3 MB)
[K     |████████████████████████████████| 2.3 MB 26.2 MB/s 
Installing collected packages: sk-video, pillow-heif, opencv-python
  Attempting uninstall: opencv-python
    Found existing installation: opencv-python 4.1.2.30
    Uninstalling opencv-python-4.1.2.30:
      Successfully uninstalled opencv-python-4.1.2.30
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
albumentations 0.1.12 requires

In [None]:
cd '/content/drive/MyDrive/fun deep learning/image-enhancement/low-light/RetinexDIP'

/content/drive/MyDrive/fun deep learning/image-enhancement/low-light/RetinexDIP


# Main code

In [None]:
import sys
# The helper modules live two directories above the current working directory
# (the RetinexDIP folder selected by the preceding `cd` cell), so that
# directory must be on sys.path before they can be imported.
sys.path.append('../../')
import general_utils as gu
import img_utils as iu

import importlib
# Re-import img_utils so that edits made to the file are picked up without
# restarting the kernel.
importlib.reload(iu)

<module 'img_utils' from '../../img_utils.py'>

In [None]:
from collections import namedtuple
from net import *
from net.downsampler import *
from net.losses import StdLoss, GradientLoss, ExtendedL1Loss, GrayLoss
from net.losses import ExclusionLoss, TVLoss
from net.noise import get_noise
from PIL import Image, ImageFont, ImageDraw
from pillow_heif import register_heif_opener
register_heif_opener()

import numpy as np
import math
import torch
import cv2
from torchvision import transforms
import time
from pathlib import Path
from tqdm.notebook import tqdm

# Run on the GPU when one is visible to torch, otherwise stay on the CPU.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

# Fix torch's RNG seed so repeated runs start from the same random state.
torch.manual_seed(0)

def scale_fill(im, target_width, target_height, return_offset=False):
    '''
    Resize PIL image keeping ratio and adding zero-valued background.

    The image is shrunk/enlarged so that it fits entirely inside a
    ``target_width`` x ``target_height`` canvas, then pasted centred on a
    zero-filled canvas of exactly that size.

    im: PIL image to fit.
    target_width, target_height: size of the returned canvas in pixels.
    return_offset: when True, also return where the image sits in the canvas.

    Returns the RGB canvas, or ``(canvas, offset, resized_size)`` when
    ``return_offset`` is True, where ``offset`` is the (x, y) of the pasted
    image's top-left corner and ``resized_size`` is its (width, height).
    '''
    target_ratio = target_height / target_width
    im_ratio = im.height / im.width
    if target_ratio > im_ratio:
        # Image is relatively wider than the target: width is the limiting side.
        resize_width = target_width
        # Clamp to 1 so extreme aspect ratios never round down to an invalid 0.
        resize_height = max(1, round(resize_width * im_ratio))
    else:
        # Image is relatively taller (or same ratio): height is the limiting side.
        resize_height = target_height
        resize_width = max(1, round(resize_height / im_ratio))

    # Image.ANTIALIAS was an alias of LANCZOS and no longer exists in Pillow >= 10.
    image_resize = im.resize((resize_width, resize_height), Image.LANCZOS)
    background = Image.new('RGBA', (target_width, target_height), 0)
    offset = (round((target_width - resize_width) / 2), round((target_height - resize_height) / 2))
    background.paste(image_resize, offset)
    result = background.convert('RGB')

    if return_offset:
        return result, offset, image_resize.size

    return result


def append_img(im1, im2, new_size, labels=('im1', 'im2')):
    '''
    Build a side-by-side comparison of two PIL images with a text label on each.

    im1, im2: PIL images; both are resized to the same size, which is im1's
        aspect-preserving fit inside a ``new_size`` x ``new_size`` square
        (as computed by ``scale_fill``).
    new_size: side length in pixels of the square im1 is fitted into.
    labels: two strings drawn at the top-left corner of the left and right
        halves respectively.

    Returns an RGB image twice as wide as one resized input.
    '''
    # Only the fitted (width, height) is needed from scale_fill here.
    _, _, fitted_size = scale_fill(im1, new_size, new_size, True)
    left = im1.resize(fitted_size)
    right = im2.resize(fitted_size)
    w, h = left.size

    canvas = Image.new('RGBA', (w*2, h), 0)
    canvas.paste(left, (0,0))
    canvas.paste(right, (w, 0))
    im = canvas.convert('RGB')

    draw = ImageDraw.Draw(im)
    draw.text((0,0), labels[0], (100,100,100))
    draw.text((w,0), labels[1], (100,100,100))
    return im

def reverse_scale_fill(im, original_size, offset, image_resize_size):
    '''
    Undo a scale-fill: cut the padded border away and resize back.

    im: image produced by a scale-fill (content centred on a padded canvas)
    original_size: (W, H) to resize the cropped content to
    offset: (x, y) of the content's top-left corner inside ``im``
    image_resize_size: (W, H) of the content inside ``im``
    Example: scalefill from 720, 960 to 960, 960 will have offset=120,0 and image_resize_size=720,960
    '''
    left, top = offset[0], offset[1]
    right = left + image_resize_size[0]
    bottom = top + image_resize_size[1]
    content = im.crop((left, top, right, bottom))
    return content.resize(original_size)

# Pair of decomposition outputs: the reflection map and the illumination map.
EnhancementResult = namedtuple("EnhancementResult", "reflection illumination")

class Enhancement(object):
    """Enhance one low-light image by jointly fitting a reflection net and an illumination net.

    On construction the image is letterboxed to ``target_size`` x ``target_size``
    with ``scale_fill``, an initial illumination map (per-pixel max over R, G, B)
    and an initial reflection are derived from it, and two ``skip`` networks fed
    with fixed noise inputs are created. ``optimize`` trains both networks and
    writes the enhanced image, plus a side-by-side comparison, next to
    ``image_name``; intermediate results are written to ``output/``.

    NOTE(review): ``skip``, ``get_noise``, ``pil_to_np``, ``np_to_torch``,
    ``torch_to_np`` and ``nn`` are expected to come from the project's ``net``
    package (star imports); their exact array layouts and value ranges are not
    visible here and should be confirmed there.
    """

    def __init__(self, image_name, image, target_size=960, plot_during_training=False, show_every=10, num_iter=300):
        """
        image_name: path (``str`` or ``pathlib.Path``) the final enhanced image is saved to
        image: PIL image to enhance
        target_size: side length of the square canvas the image is letterboxed to
        plot_during_training: when True, ``gu.folder('output/')`` is called and kept in ``output_path``
        show_every: a progress line / intermediate image is produced every ``show_every`` iterations
        num_iter: number of optimisation iterations run by ``optimize``
        """
        self.original_image = image
        self.size = image.size  # PIL convention: (width, height)
        self.target_size = target_size
        self.image_np = None
        self.images_torch = None
        self.plot_during_training = plot_during_training
        if plot_during_training:
            self.output_path = gu.folder('output/')
        # self.ratio = ratio
        self.psnrs = []
        self.show_every = show_every
        # Stored as a Path because .stem / .parent / .name are used when saving.
        self.image_name = Path(image_name)
        self.num_iter = num_iter
        self.loss_function = None
        # self.ratio_net = None
        self.parameters = None
        self.learning_rate = 0.01
        self.input_depth = 3  # This value could affect the performance. 3 is ok for natural image, if your
                            #images are extremely dark, you may consider 8 for the value.
        self.data_type = torch.FloatTensor
        # Fixed noise inputs of the two nets (filled in by _init_inputs).
        self.reflection_net_input = None
        self.illumination_net_input = None
        self.original_illumination = None
        self.original_reflection = None
        self.reflection_net = None
        self.illumination_net = None
        self.total_loss = None
        self.reflection_out = None
        self.illumination_out = None
        self.current_result = None
        self.best_result = None
        self._init_all()

        print('Original image size:', self.size)

    def _init_all(self):
        """Run every initialisation step; the order matters (later steps read earlier results)."""
        self._init_images()
        self._init_decomposition()
        self._init_nets()
        self._init_inputs()
        self._init_parameters()
        self._init_losses()

    def _maxRGB(self):
        '''
        self.image: pil image, input low-light image
        :return: np, initial illumination (per-pixel maximum over the R, G and B bands)
        '''
        (R, G, B) = self.image.split()
        I_0 = np.array(np.maximum(np.maximum(R, G), B))
        return I_0

    def _init_decomposition(self):
        """Derive the initial illumination and reflection from the letterboxed image."""
        temp = self._maxRGB() # numpy
        # get initial illumination map (the max-RGB map repeated for 3 channels)
        #IMPORTANT: min clip value is a hyper param
        self.original_illumination = np.clip(np.asarray([temp for _ in range(3)]), 2, 255) #range [0,255]
        # get initial reflection
        # NOTE(review): the [0, 1] range below assumes pil_to_np keeps values in [0, 255] - confirm.
        self.original_reflection = self.image_np / self.original_illumination #range [0, 255] / [0,255] = [0, 1]

        self.original_illumination = np_to_torch(self.original_illumination).type(self.data_type).to(device)
        self.original_reflection = np_to_torch(np.asarray(self.original_reflection)).type(self.data_type).to(device)

    def _init_images(self):
        """Letterbox the input to the target square and keep numpy / torch copies of it."""
        self.original_image_torch = transforms.PILToTensor()(self.original_image)
        # The offset and resized size are kept so the padding can be undone in get_enhanced.
        self.image, self.scalefill_offset, self.image_resize_size = scale_fill(self.original_image, self.target_size, self.target_size, return_offset=True)
        # self.image = transforms.Resize((512, 512))(self.original_image)
        self.image_np = pil_to_np(self.image)  # pil image to numpy
        self.image_torch = np_to_torch(self.image_np).type(self.data_type).to(device)

    def _init_inputs(self):
        """Create the fixed noise inputs of both nets, matching the image's spatial size.

        Raises RuntimeError if ``_init_images`` has not populated ``image_torch``.
        """
        if self.image_torch is None:
            raise RuntimeError('_init_images() must run before _init_inputs()')
        # Spatial size is taken from dimensions 2 and 3 of the image tensor.
        size = (self.image_torch.shape[2], self.image_torch.shape[3])
        input_type = 'noise'
        # input_type = 'meshgrid'
        self.reflection_net_input = get_noise(self.input_depth,
                                                  input_type, size).type(self.data_type).detach().to(device)
        self.illumination_net_input = get_noise(self.input_depth,
                                             input_type, size).type(self.data_type).detach().to(device)

    def _init_parameters(self):
        """Collect the parameters of both nets into one Adam optimizer with a step LR schedule."""
        self.parameters = [p for p in self.reflection_net.parameters()] + \
                          [p for p in self.illumination_net.parameters()]
        self.optimizer = torch.optim.Adam(self.parameters, lr=self.learning_rate)
        self.scheduler = torch.optim.lr_scheduler.MultiStepLR(self.optimizer,
                                                              milestones=[300, 600])

    def weight_init(self, m):
        """Initialise one module in place; meant to be passed to ``nn.Module.apply``.

        Modules are matched by class name: 'Conv' layers get zero-mean normal
        weights scaled by their fan-out and zero bias, 'BatchNorm' layers get
        weight 1 / bias 0, 'Linear' layers get N(0, 0.01) weights and bias 1.
        Missing (``None``) weights or biases are skipped.
        """
        classname = m.__class__.__name__
        if classname.find('Conv') != -1:
            n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            m.weight.data.normal_(0.0, 0.5 * math.sqrt(2. / n))
            if m.bias is not None:
                m.bias.data.zero_()
        elif classname.find('BatchNorm') != -1:
            # Non-affine batch norm has neither weight nor bias.
            if m.weight is not None:
                m.weight.data.fill_(1)
            if m.bias is not None:
                m.bias.data.zero_()
        elif classname.find('Linear') != -1:
            m.weight.data.normal_(0, 0.01)
            if m.bias is not None:
                m.bias.data.fill_(1)

    def _build_net(self):
        """Build one ``skip`` network, initialise its weights and move it to the device."""
        net = skip(self.input_depth, 3,
               num_channels_down = [8, 16, 32, 64, 128, 256, 512, 1024],
               num_channels_up   = [8, 16, 32, 64, 128, 256, 512, 1024],
               num_channels_skip = [0, 0, 0, 0, 0, 0, 0, 0],
               filter_size_down = 3, filter_size_up = 3, filter_skip_size=1,
               upsample_mode='bilinear',
               downsample_mode='avg',
               need_sigmoid=True, need_bias=True, pad='zero')
        net.apply(self.weight_init).type(self.data_type).to(device)
        return net

    def _init_nets(self):
        """Create the reflection and illumination networks (identical architecture)."""
        # Built in this order so the random initialisation sequence is unchanged.
        self.reflection_net = self._build_net()
        self.illumination_net = self._build_net()

    def _init_losses(self):
        """Instantiate the loss modules; only l1, mse and tv are used by calculate_loss."""
        self.l1_loss = nn.SmoothL1Loss().type(self.data_type) # for illumination
        self.mse_loss = nn.MSELoss().type(self.data_type)     # for reflection and reconstruction
        self.exclusion_loss =  ExclusionLoss().type(self.data_type)
        self.tv_loss = TVLoss().type(self.data_type)
        self.gradient_loss = GradientLoss().type(self.data_type)

    def optimize(self):
        """Run ``num_iter`` optimisation steps, then save the final and comparison images.

        The enhanced image goes to ``image_name`` and the side-by-side image to
        ``compared_<name>`` in the same folder. Nothing is saved when no
        iteration ran (``num_iter <= 0``).
        """
        # torch.backends.cudnn.enabled = True
        # torch.backends.cudnn.benchmark = True
        # optimizer = SGLD(self.parameters, lr=self.learning_rate)

        start = time.time()
        for j in tqdm(range(self.num_iter)):
            self.optimizer.zero_grad()
            self.calculate_loss()
            # Uses the outputs of the forward pass above, i.e. before this step's update.
            self._save_result(j)
            self.optimizer.step()
            self.scheduler.step()

        end = time.time()
        print("time:%.4f"%(end-start))

        if self.best_result is None:
            # get_enhanced never ran, so there is no image to write.
            return

        self.image_name.parent.mkdir(parents=True, exist_ok=True)
        self.best_result.save(str(self.image_name))
        comparison = append_img(self.original_image, self.best_result, 1024, ('original', 'enhanced'))
        comparison.save(self.image_name.parent/('compared_'+self.image_name.name))

    def calculate_loss(self):
        """Forward both nets, build the total loss and back-propagate it."""
        # Only needed by the (currently disabled) input perturbation below.
        reg_noise_std = 1e-4

        illumination_net_input = self.illumination_net_input #+ self.illumination_net_input.clone().normal_(mean=0, std=reg_noise_std).to(device)
        reflection_net_input = self.reflection_net_input #+ self.reflection_net_input.clone().normal_(mean=0, std=reg_noise_std).to(device)

        self.illumination_out = self.illumination_net(illumination_net_input)
        self.reflection_out = self.reflection_net(reflection_net_input)

        # weighted with the gradient of latent reflectance
        self.total_loss = 0.5*self.tv_loss(self.illumination_out, self.reflection_out)
        self.total_loss += 0.0001*self.tv_loss(self.reflection_out)
        # self.total_loss += 0.0001*self.tv_loss(self.illumination_out)
        # Keep the illumination close to the initial max-RGB estimate.
        self.total_loss += self.l1_loss(self.illumination_out, self.original_illumination/255)
        # Reconstruction: illumination * reflection should reproduce the input image.
        self.total_loss += self.mse_loss(
            self.illumination_out*self.reflection_out,
            self.image_torch/255
        )
        self.total_loss.backward()

    def _obtain_current_result(self, step):
        """
        puts in self.current result the current result.
        Only acts on every 8th step and on the last step.
        Not called by any other method of this class.
        :return:
        """
        if step == self.num_iter - 1 or step % 8 == 0:
            reflection_out_np = np.clip(torch_to_np(self.reflection_out),0,1)
            illumination_out_np = np.clip(torch_to_np(self.illumination_out),0,1)
            # psnr = compare_psnr(np.clip(self.image_np,0,1),  reflection_out_np * illumination_out_np)
            # self.psnrs.append(psnr)

            self.current_result = EnhancementResult(reflection=reflection_out_np, illumination=illumination_out_np)
            # if self.best_result is None or self.best_result.psnr < self.current_result.psnr:
            #     self.best_result = self.current_result

    def _save_result(self, step):
        """Log the loss and refresh the enhanced image on step 0, every ``show_every`` steps and the last step."""
        if (step % self.show_every == self.show_every - 1) or (step==0) or (step==self.num_iter-1):
            print('Iteration {:5d}    Loss {:5f}'.format(step,self.total_loss.item()))

            self.get_enhanced(step)

    def gamma_trans(self, img, gamma):
        """Apply ``out = 255 * (in / 255) ** gamma`` to a uint8 image through a 256-entry lookup table."""
        gamma_table = [np.power(x / 255.0, gamma) * 255.0 for x in range(256)]
        gamma_table = np.round(np.array(gamma_table)).astype(np.uint8)
        return cv2.LUT(img, gamma_table)

    def adjust_gammma(self,img):
        """Gamma-correct ``img`` with a fixed gamma of 0.5."""
        image_gamma_correct = self.gamma_trans(img, 0.5)
        return image_gamma_correct

    def get_enhanced(self, step, flag=False):
        """Build the enhanced image from the current illumination and save an intermediate copy.

        Each colour band of the letterboxed input is divided by the (adjusted)
        illumination, the padding is removed and the result is resized to the
        original size. The image is stored in ``best_result`` and written to
        ``output/<stem>_<step>.jpeg``.

        step: iteration index, used only in the file name
        flag: when True, use the per-pixel channel maximum of the illumination
            instead of gamma-correcting each channel (suggested for extremely
            dark inputs)
        """
        (R, G, B) = self.image.split()
        ini_illumination = self.illumination_out * 255.0
        ini_illumination = torch_to_np(ini_illumination).transpose(1, 2, 0) # H x W x C
        # If the input image is extremely dark, setting the flag as True can produce promising result.
        if flag:
            ini_illumination = np.clip(np.max(ini_illumination, axis=2, keepdims=True), 0.0000002, 255)
            # The max leaves a single channel; repeat it so each band below has its own slice.
            ini_illumination = np.repeat(ini_illumination, 3, axis=2)
        else:
            # initial_shape = ini_illumination.shape
            # ini_illumination = np.max(ini_illumination, axis=2, keepdims=True)
            # ini_illumination = np.broadcast_to(ini_illumination, shape=initial_shape)
            ini_illumination = np.clip(self.adjust_gammma(ini_illumination.astype(np.uint8)), 0.0000002, 255)

        # The lower clip bound above keeps these divisions away from zero.
        R = np.asarray(R) / ini_illumination[:,:,0] # R range [0,255], ini_illu supposedly range [0, 255]
        G = np.asarray(G) / ini_illumination[:,:,1]
        B = np.asarray(B) / ini_illumination[:,:,2]
        self.best_result = Image.fromarray(np.clip(cv2.merge([R, G, B])*255, 0.000002, 255).astype(np.uint8), mode='RGB')#.resize(self.size)
        self.best_result = reverse_scale_fill(self.best_result, self.size, self.scalefill_offset, self.image_resize_size)

        # Make sure the folder exists even when plot_during_training is False.
        out_dir = Path('output/')
        out_dir.mkdir(parents=True, exist_ok=True)
        self.best_result.save(str(out_dir / (self.image_name.stem + '_{}.jpeg'.format(step))))




In [None]:
# Source location: a folder of images (or a single image file).
img_folder = Path('/content/drive/MyDrive/fun deep learning/image-enhancement/original_images/5-4-2022/')

# Results go to ./result/<folder name>; for a single file, the name of its
# containing folder is used instead.
if img_folder.is_file():
    result_subdir = img_folder.parts[-2]
else:
    result_subdir = img_folder.parts[-1]
result_path = gu.folder(Path('./result/' + result_subdir))
print('Output path = ', result_path)

img_path_list = iu.get_image_paths(img_folder)
print(img_path_list)

for img_path in img_path_list:
    print('-------------Processing '+img_path.name)
    image = Image.open(img_path)
    output_name = result_path/(img_path.stem + '.jpeg')
    s = Enhancement(output_name,
                    image,
                    target_size=1024,
                    plot_during_training=True,
                    show_every=10,
                    num_iter=1000)
    s.optimize()
    # NOTE(review): only the first image is processed because of this break -
    # confirm whether that is intentional before removing it.
    break
    

    


Output path =  result/5-4-2022
[PosixPath('/content/drive/MyDrive/fun deep learning/image-enhancement/original_images/5-4-2022/IMG_3985.HEIC'), PosixPath('/content/drive/MyDrive/fun deep learning/image-enhancement/original_images/5-4-2022/IMG_3979.HEIC'), PosixPath('/content/drive/MyDrive/fun deep learning/image-enhancement/original_images/5-4-2022/IMG_3983.HEIC'), PosixPath('/content/drive/MyDrive/fun deep learning/image-enhancement/original_images/5-4-2022/IMG_3974.HEIC'), PosixPath('/content/drive/MyDrive/fun deep learning/image-enhancement/original_images/5-4-2022/IMG_3973.HEIC'), PosixPath('/content/drive/MyDrive/fun deep learning/image-enhancement/original_images/5-4-2022/IMG_3975.HEIC'), PosixPath('/content/drive/MyDrive/fun deep learning/image-enhancement/original_images/5-4-2022/IMG_3976.HEIC')]
-------------Processing IMG_3985.HEIC
Original image size: (3024, 4032)


  0%|          | 0/1000 [00:00<?, ?it/s]

Iteration     0    Loss 0.113814
Iteration     9    Loss 0.052970
Iteration    19    Loss 0.024481
Iteration    29    Loss 0.015282
Iteration    39    Loss 0.010796
Iteration    49    Loss 0.008719
Iteration    59    Loss 0.008117
Iteration    69    Loss 0.007588
Iteration    79    Loss 0.007067
Iteration    89    Loss 0.006503
Iteration    99    Loss 0.007227
Iteration   109    Loss 0.006461
Iteration   119    Loss 0.007147
Iteration   129    Loss 0.006233
Iteration   139    Loss 0.007018
Iteration   149    Loss 0.005928
Iteration   159    Loss 0.005844
Iteration   169    Loss 0.005847
Iteration   179    Loss 0.005739
Iteration   189    Loss 0.005773
Iteration   199    Loss 0.005557
Iteration   209    Loss 0.005537
Iteration   219    Loss 0.005754
Iteration   229    Loss 0.005367
Iteration   239    Loss 0.005367
Iteration   249    Loss 0.005156
Iteration   259    Loss 0.005220
Iteration   269    Loss 0.005287
Iteration   279    Loss 0.005314
Iteration   289    Loss 0.005277
Iteration 