# Project: Learning to See in the Dark
Paul-Vinh LÊ, Sébastien Morel

# Setup

## Fetching the code

### Clone our code repository

In [1]:
# Fresh clone of our code repository
!rm -rf pytorch-Learning-to-See-in-the-Dark
!git clone --quiet https://github.com/Paul-Vinh/pytorch-Learning-to-See-in-the-Dark.git

### Clone the repository implementing the HDR+ paper

In [2]:
# Fresh clone of the HDR+ repository (imported below as hdr_plus_pytorch)
!rm -rf hdr_plus_pytorch
!git clone --quiet https://github.com/martin-marek/hdr-plus-pytorch.git hdr_plus_pytorch

### Mount our Google Drive

This is useful for saving the results of our experiments, the checkpoints of our training runs, the loss curves, etc. The data can also be stored there.

In [None]:
# Connect Colab with Google Drive (outside Colab, results stay in a local folder)
import os

persistent_storage = 'trainings/'
try:
    # Load the Drive helper and mount
    from google.colab import drive

    # This will prompt for authorization.
    drive.mount('Drive')
    persistent_storage = 'Drive/MyDrive/MVA/Imagerie_numérique'
except ImportError:
    # Not running on Colab: keep the local default folder
    pass
os.makedirs(persistent_storage, exist_ok=True)

## Install the required packages

In [None]:
!pip install rawpy
!pip install bm3d
!pip install multiprocess
!pip install pytorch_ssim
!pip install gdown
!pip install pytorch_msssim
!pip install torchgeometry

## Download the dataset

In [None]:
# Get the zip file
!wget https://storage.googleapis.com/isl-datasets/SID/Sony.zip

In [6]:
# Unzip it
!mkdir -p dataset/
!unzip -q Sony.zip -d dataset/
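The same extraction can also be done from Python with the standard `zipfile` module (which the imports cell below already loads), for example outside a shell. This is a minimal sketch; the helper name `extract_archive` is ours, and the guarded usage line simply mirrors the shell cells above.

```python
import os
import zipfile


def extract_archive(zip_path, dest_dir):
    """Extract every member of zip_path into dest_dir and return the member names."""
    os.makedirs(dest_dir, exist_ok=True)  # same role as `mkdir -p dataset/`
    with zipfile.ZipFile(zip_path) as archive:
        names = archive.namelist()
        archive.extractall(dest_dir)
    return names


if __name__ == "__main__":
    # Usage mirroring the shell cells: extract Sony.zip into dataset/ if it is present
    if os.path.exists("Sony.zip"):
        extract_archive("Sony.zip", "dataset/")
```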

In [7]:
# Remove the zip file
!rm Sony.zip

# Training

To train the model, you need to specify:

*   The folder where the training curves and checkpoints are stored (`--models_dir`). It can also contain a model whose training you want to continue for longer.
*   The loss to use (`--loss`: "ssim", "L1" or "L2").
*   The number of training epochs (`--epoch`).
*   The checkpoint saving frequency `save_freq` and the validation frequency `val_freq`, which is used to build the training curve when `plot_loss` is True.
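The exact argument parsing of `train_Sony.py` is not reproduced here. The following is a minimal sketch, assuming an `argparse`-based command line with the option names used in the command below; the defaults and the `str2bool` helper are hypothetical. It also illustrates a common pitfall for flags written as `--plot_loss True` (or `--generalization False` further down): with `type=bool`, any non-empty string, including `"False"`, becomes `True`, so an explicit string-to-boolean converter is needed.

```python
import argparse


def str2bool(value):
    """Convert strings such as 'True'/'false'/'1'/'no' to a bool (hypothetical helper)."""
    if value.lower() in ("true", "1", "yes", "y"):
        return True
    if value.lower() in ("false", "0", "no", "n"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")


def build_parser():
    # Option names follow the training command; default values are placeholders
    parser = argparse.ArgumentParser(description="Train the SID model (sketch)")
    parser.add_argument("--models_dir", default="trainings/",
                        help="where checkpoints and loss curves are stored")
    parser.add_argument("--loss", choices=["ssim", "L1", "L2"], default="L1")
    parser.add_argument("--epoch", type=int, default=100)
    parser.add_argument("--save_freq", type=int, default=10,
                        help="save a checkpoint every save_freq epochs")
    parser.add_argument("--val_freq", type=int, default=1,
                        help="run validation every val_freq epochs")
    parser.add_argument("--plot_loss", type=str2bool, default=False)
    return parser


if __name__ == "__main__":
    args = build_parser().parse_args(
        ["--loss", "ssim", "--epoch", "5", "--plot_loss", "True"])
    print(args.loss, args.epoch, args.plot_loss)  # ssim 5 True
```

Note that `bool("False")` evaluates to `True` in Python, which is why the converter matches the string explicitly instead of relying on `bool`.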

In [None]:
!python pytorch-Learning-to-See-in-the-Dark/train_Sony.py --models_dir "/content/Drive/MyDrive/MVA/Image_bis/" --loss "ssim" --save_freq 10 --val_freq 1  --plot_loss True --epoch 5


# Testing the model

## Test on the Sony test set
To evaluate a model, we can directly run the test_Sony script, which tests it on all the images of the test set.
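The test set is identified through the file names. In the SID dataset, a name such as `10016_00_0.1s.ARW` encodes the split in its first digit (`0` for training, `1` for test, `2` for validation), followed by the image ID, the index of the frame within the burst and the exposure time; this is why the test files are later selected with the pattern `'1*_*_0.1s.ARW'`. A minimal parser, assuming exactly this naming scheme (the `parse_sid_name` helper is ours):

```python
import os

# First digit of the file name -> dataset split
SPLITS = {"0": "train", "1": "test", "2": "validation"}


def parse_sid_name(path):
    """Split a SID file name like '10016_00_0.1s.ARW' into its fields."""
    name = os.path.basename(path)
    stem, _ext = os.path.splitext(name)          # '10016_00_0.1s', '.ARW'
    image_id, burst, exposure = stem.split("_")
    return {
        "id": image_id,                              # e.g. '10016'
        "split": SPLITS[image_id[0]],                # first digit gives the split
        "burst_index": int(burst),                   # position in the burst
        "exposure_s": float(exposure.rstrip("s")),   # exposure time in seconds
    }
```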

In [None]:
!python pytorch-Learning-to-See-in-the-Dark/test_Sony.py --dataset_dir "/content/" --model_dir "/content/pytorch-Learning-to-See-in-the-Dark/"

## Applying the model to a specific file

We can also apply our model to a specific file. To do so, we use the *file* argument, which gives the path to that file.
<br>
This makes it possible to quickly generate an output for a chosen image.
<br> We can also test how well our model generalizes to images that are not in the Sony dataset. Since there is no ground truth for such images, we define a *generalization* argument that skips the metric computations, which are impossible without ground truth.

In [None]:
!python pytorch-Learning-to-See-in-the-Dark/test_Sony.py --file "/content/dataset/Sony/short/00001_00_0.1s.ARW" --generalization False --dataset_dir "/content/" --model_dir "/content/pytorch-Learning-to-See-in-the-Dark/"

# State of the Art

We want to compare the Learning to See in the Dark method with more classical methods, namely BM3D and HDR+.
<br> We therefore generate the results of these methods on the test set, save them if needed, and measure their distance to the ground truth with SSIM.

## Imports

In [8]:
import os
import torch
import torchvision
import numpy as np
import pandas as pd
from skimage import exposure
from skimage.metrics import structural_similarity as ssim
from hdr_plus_pytorch import align
import rawpy
import imageio
from glob import glob
import matplotlib.pyplot as plt
import zipfile
from PIL import Image

import bm3d
import cv2
import multiprocessing as mp



## Some useful functions

In [9]:
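def compute_ssim(gt_im, pred_im):
    """ Compute SSIM between the ground-truth image and the predicted image.
    """
    # multichannel=True: the last axis holds the color channels
    # (recent scikit-image versions use channel_axis=-1 instead)
    return(ssim(gt_im, pred_im,
                  data_range=pred_im.max() - pred_im.min(), multichannel=True))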
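For reference, SSIM compares the means $\mu_x, \mu_y$, the variances $\sigma_x^2, \sigma_y^2$ and the covariance $\sigma_{xy}$ of the two images:
$\mathrm{SSIM}(x, y) = \frac{(2\mu_x\mu_y + C_1)(2\sigma_{xy} + C_2)}{(\mu_x^2 + \mu_y^2 + C_1)(\sigma_x^2 + \sigma_y^2 + C_2)}$ with $C_1 = (0.01L)^2$ and $C_2 = (0.03L)^2$, where $L$ is the `data_range`. `structural_similarity` averages this index over small sliding windows; the sketch below is a simplification that computes a single global SSIM over the whole image, only to make the role of `data_range` visible. The function name `global_ssim` is ours.

```python
from statistics import fmean


def global_ssim(x, y, data_range=255.0, k1=0.01, k2=0.03):
    """Single-window SSIM between two equally sized flat sequences of pixel values."""
    if len(x) != len(y) or not x:
        raise ValueError("inputs must be non-empty and of equal length")
    # Stabilizing constants depend on the dynamic range L = data_range
    c1 = (k1 * data_range) ** 2
    c2 = (k2 * data_range) ** 2
    mu_x, mu_y = fmean(x), fmean(y)
    var_x = fmean((a - mu_x) ** 2 for a in x)
    var_y = fmean((b - mu_y) ** 2 for b in y)
    cov_xy = fmean((a - mu_x) * (b - mu_y) for a, b in zip(x, y))
    return ((2 * mu_x * mu_y + c1) * (2 * cov_xy + c2)) / (
        (mu_x ** 2 + mu_y ** 2 + c1) * (var_x + var_y + c2))
```

Since `compute_ssim` above derives `data_range` from the prediction (`pred_im.max() - pred_im.min()`), $L$ changes from one image to the next; for 8-bit images, a fixed `data_range=255` is the usual choice.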

In [10]:
def load_raw_images(image_paths):
    """loads bayer pixels from raw images"""
    images = []
    for path in image_paths:
        with rawpy.imread(path) as raw:
            image = raw.raw_image.copy().astype(np.float32)
            images.append(image)

    # store the pixels in a tensor with an added "channel" dimension
    images_s = np.stack(images)
    images_s = torch.from_numpy(images_s)[:, None, :, :]

    print(f'burst of shape {list(images_s.shape)} loaded')
    return images_s


In [11]:
def get_grountruth(image_index, gt_dir):
  """
  Get the ground-truth (long-exposure) RGB image of image_index as uint8, together
  with the camera white balance, the white level and the per-channel black levels.
  """
  gt_files = glob(gt_dir + '%s_00*.ARW' % image_index)
  gt_path = gt_files[0]
  _, gt_fn = os.path.split(gt_path)
  # Amplification ratio between the ground truth and the 0.1 s input (not used below)
  in_exposure = 0.1
  gt_exposure = float(gt_fn[9:-5])
  ratio = min(gt_exposure / in_exposure, 300)

  with rawpy.imread(gt_path) as gt_raw:
    sat = gt_raw.white_level
    # White-balance multipliers (camera_white_level_per_channel holds saturation levels, not a white balance)
    wb = gt_raw.camera_whitebalance
    bl = gt_raw.black_level_per_channel
    im = gt_raw.postprocess(use_camera_wb=True, half_size=False, no_auto_bright=True, output_bps=16)

  # 16-bit output -> [0, 1] -> 8-bit
  gt_full = np.float32(im / 65535.0)
  return (gt_full * 255).astype('uint8'), wb, sat, bl
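`get_grountruth` also returns the white level `sat` and the per-channel black levels `bl`. These two values map raw sensor counts to a normalized scale: subtract the black level, divide by the white level minus the black level, and clip. A minimal sketch on a flat list of Bayer values, assuming a single black level shared by all channels; `normalize_raw` is a hypothetical helper, and the values 512 and 16383 used in the check are only illustrative 14-bit numbers, not read from a file.

```python
def normalize_raw(values, black_level, white_level):
    """Map raw sensor values to [0, 1] using the black and white (saturation) levels."""
    scale = float(white_level - black_level)
    if scale <= 0:
        raise ValueError("white_level must exceed black_level")
    # Remove the sensor offset, rescale, and clip values below black or above saturation
    return [min(max((v - black_level) / scale, 0.0), 1.0) for v in values]
```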


In [12]:
def get_rgb_values(image_path, bayer_array=None, **kwargs):
    """using a raw file [and modified bayer pixels], get rgb pixels"""
    # open the raw image
    with rawpy.imread(image_path) as raw:
        # overwrite the original bayer array
        if bayer_array is not None:
            raw.raw_image[:] = bayer_array
        # get postprocessed rgb pixels
        rgb = raw.postprocess(**kwargs)
    return rgb
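`get_rgb_values` writes a (possibly merged) Bayer mosaic back into `raw.raw_image` before demosaicing. For intuition, here is how a Bayer mosaic splits into four half-resolution color planes. This sketch assumes an RGGB layout (red at even row and even column); the actual layout of a given file is described by rawpy's `raw_pattern` and `color_desc` attributes, so it should be checked before relying on this order. The helper `split_bayer_rggb` is ours.

```python
def split_bayer_rggb(mosaic):
    """Split a 2D Bayer mosaic (list of rows) into R, G1, G2, B planes, assuming RGGB."""
    if len(mosaic) % 2 or any(len(row) % 2 for row in mosaic):
        raise ValueError("mosaic dimensions must be even")
    r = [row[0::2] for row in mosaic[0::2]]    # even rows, even columns
    g1 = [row[1::2] for row in mosaic[0::2]]   # even rows, odd columns
    g2 = [row[0::2] for row in mosaic[1::2]]   # odd rows, even columns
    b = [row[1::2] for row in mosaic[1::2]]    # odd rows, odd columns
    return r, g1, g2, b
```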

In [13]:
def extract_index(in_files):
  """
  Get the indexs of the images in the in_files
  Return the list of the indexs sorted, whithout dupplicate
  """
  indexs = [i.split('/')[-1] for i in in_files]
  indexs = [i.split('_', 1)[0] for i in indexs]
  indexs = list(dict.fromkeys(indexs))
  indexs.sort()
  return indexs
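An equivalent, more compact way to obtain the sorted, de-duplicated IDs uses `os.path.basename` and a set. A sketch (the name `extract_index_compact` is ours), with a small usage example on SID-style paths:

```python
import os


def extract_index_compact(in_files):
    """Return the sorted, de-duplicated image IDs (prefix before the first '_')."""
    return sorted({os.path.basename(f).split('_', 1)[0] for f in in_files})


files = ['/content/dataset/Sony/short/10016_00_0.1s.ARW',
         '/content/dataset/Sony/short/10016_01_0.1s.ARW',
         '/content/dataset/Sony/short/10003_00_0.1s.ARW']
print(extract_index_compact(files))  # ['10003', '10016']
```

Like `extract_index`, this sorts the IDs as strings, which matches numeric order here because all SID IDs have the same length.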

In [38]:
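def compare_methods(indexs, input_dir, gt_dir, device, use_BM3D = False, use_HDR = False, scores = None, result_dir = None):
  """
  Run BM3D and/or HDR+ on the short-exposure bursts of each image ID in indexs,
  optionally record their SSIM against the ground truth in `scores`
  and save the outputs as PNG files in `result_dir`.
  """
  N = len(indexs)
  for i in range(N):
    print('Image %d of %d (ID %s)' % (i + 1, N, indexs[i]))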

    ## Get file 
    # Stack images
    in_file = glob(input_dir+indexs[i]+'_*_0.1s.ARW')
    images = load_raw_images(in_file)
    # One image for BM3D
    if use_BM3D:
      bm_raw = rawpy.imread(in_file[0])

    # Get Ground Truth and some parameters
    ref_rgb, wb, sat, bl = get_grountruth(indexs[i], gt_dir)
    
    ## Compute scale of gt
    ref_rgb = ref_rgb.astype('float32')
    means_gt = np.mean(ref_rgb, axis = (1,0))

    if use_HDR:
      ## HDR +
      print("Begin HDR method")
      # Alignment and merging of the burst
      merged_image = align.align_and_merge(images, device=device)
      # Convert the merged Bayer image to RGB (16-bit output)
      merged_rgb = get_rgb_values(in_file[0], merged_image[0], user_wb=wb, user_sat=sat, user_black=bl[0], half_size=False, no_auto_bright=True, output_bps=16)
      # Rescale the 16-bit values to the 8-bit range of ref_rgb (65535 / 255 = 257)
      merged_rgb = merged_rgb.astype('float32') / 257.0
      # Optional: scale each channel to the ground-truth means
      #means_rgb = np.mean(merged_rgb, axis = (1,0))
      #means =  means_gt/means_rgb
      #merged_rgb *= means[None, None, :]
      #merged_rgb = exposure.match_histograms(merged_rgb, ref_rgb, multichannel=True)
      merged_rgb = np.clip(merged_rgb, 0, 255)

    if use_BM3D:
      ## BM3D
      print("Begin BM3D method")
      # Convert raw images to rgb images
      bm3d_rgb = bm_raw.postprocess(user_wb = wb, user_sat = sat, user_black= bl[0], half_size=False, no_auto_bright=True, output_bps=16)
      bm3d_rgb = bm3d_rgb.astype('float32')
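      # Scale each channel so that its mean matches the ground truth (8-bit range)
      means_bm3d = np.mean(bm3d_rgb, axis=(1, 0))
      bm3d_rgb = (means_gt / means_bm3d) * bm3d_rgb
      # Denoising (sigma_psd is the noise standard deviation, in the same units as bm3d_rgb),
      # then clip to [0, 255] before any uint8 conversion
      denoised_rgb = bm3d.bm3d(bm3d_rgb, sigma_psd=0.2, stage_arg=bm3d.BM3DStages.ALL_STAGES)
      denoised_rgb = np.clip(denoised_rgb, 0, 255)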

    if scores is not None:
      scores['id'].append(indexs[i])
      print("Compute SSIM")
      ## SSIM against the ground truth
      if use_BM3D:
        score_bm3d = compute_ssim(ref_rgb.astype('uint8'), denoised_rgb.astype('uint8'))
        print("Method BM3D, image {}: SSIM = {}".format(indexs[i], score_bm3d))
        scores['BM3D SSIM'].append(score_bm3d)
      if use_HDR:
        score_hdr = compute_ssim(ref_rgb.astype('uint8'), merged_rgb.astype('uint8'))
        print("Method HDR+, image {}: SSIM = {}".format(indexs[i], score_hdr))
        scores['HDR+ SSIM'].append(score_hdr)

    if result_dir:
      print("Saving files")
      os.makedirs(result_dir, exist_ok=True)
      ## Save images
      if use_HDR:
        Image.fromarray(merged_rgb.astype('uint8')).save(result_dir + indexs[i] + '_hdr.png')
      if use_BM3D:
        Image.fromarray(denoised_rgb.astype('uint8')).save(result_dir + indexs[i] + '_BM3DD.png')
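In `compare_methods`, the BM3D input is rescaled so that each color channel has the same mean as the ground truth (`means_gt / means_bm3d`), and the results are cast to `uint8` before computing SSIM. Values should be clipped to [0, 255] first, since NumPy's `astype('uint8')` does not clip out-of-range values. A minimal pure-Python sketch of this per-channel gain matching on a list of RGB tuples; the function name is ours and it mirrors the logic rather than the exact NumPy code.

```python
def match_channel_means(pixels, reference):
    """Scale each RGB channel of `pixels` so its mean equals that of `reference`,
    then clip to [0, 255] and round to 8-bit integers."""
    n_channels = 3
    means_pred = [sum(p[c] for p in pixels) / len(pixels) for c in range(n_channels)]
    means_ref = [sum(p[c] for p in reference) / len(reference) for c in range(n_channels)]
    # One gain per channel; a channel that is entirely zero is left at zero
    gains = [mr / mp if mp > 0 else 0.0 for mr, mp in zip(means_ref, means_pred)]
    return [tuple(int(round(min(max(v * g, 0.0), 255.0))) for v, g in zip(p, gains))
            for p in pixels]
```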


In [None]:
help(bm3d.bm3d)

## Score each method on the test set

In [15]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [16]:
input_dir = r"/content/dataset/Sony/short/"
gt_dir = r"/content/dataset/Sony/long/"
in_files = glob(input_dir + '1*_*_0.1s.ARW')

In [17]:
indexs = extract_index(in_files)

In [18]:
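# Optional: restrict the evaluation to a few images for a quick run
# (skip this cell to evaluate the whole test set)
indexs = ['10016', '10074', '10106', '10167', '10199', '10187']
indexs = indexs[:1]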

In [19]:
result_dir = "/content/Drive/MyDrive/MVA/Imagerie_numérique/State_of_the_art/"

In [20]:
N = len(indexs)

In [21]:
scores = {'id': [], 'BM3D SSIM' : [], 'HDR+ SSIM': []}

In [None]:
compare_methods(indexs, input_dir, gt_dir, device, use_BM3D = True, use_HDR = True, scores = scores)  # pass result_dir=result_dir to also save the PNG outputs

In [None]:
scores_df = pd.DataFrame(scores, columns= ['id', 'BM3D SSIM', 'HDR+ SSIM'])
scores_df.set_index(['id'], inplace = True)
scores_df.mean()
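`pd.DataFrame(scores)` requires all lists to have the same length, so if `compare_methods` is run with only one of `use_BM3D` / `use_HDR`, the unused column stays empty and the constructor raises an error. A small standard-library sketch that averages only the methods that were actually run (the helper `summarize_scores` is ours):

```python
from statistics import fmean


def summarize_scores(scores):
    """Mean score per method, skipping methods that were not run (empty lists)."""
    n_images = len(scores['id'])
    summary = {}
    for method, values in scores.items():
        if method == 'id' or not values:
            continue  # skip the ID column and methods that were disabled
        if len(values) != n_images:
            raise ValueError(f"{method}: {len(values)} scores for {n_images} images")
        summary[method] = fmean(values)
    return summary
```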

In [None]:
scores_df.to_csv(result_dir + 'Scores_BM3D_HDR.csv')