# Dataset Generation

## Imports

In [None]:
# mass includes
import os
import pickle
import pyexiv2 as exiv2
import rawpy as rp
import numpy as np
import torch as t
from rawpy import HighlightMode
from tqdm.notebook import tqdm
from torch.utils import data

## Initialization

In [None]:
# configuration

# Input folder with the raw captures, and output folder for the pickled
# patch dataset (absolute paths on the original lab machine).
data_root, save_root = (
    '/home/lab/Documents/ssd/DJI',
    '/home/lab/Documents/ssd/r2rSet',
)

# Extension used to pick the raw files out of data_root.
file_ext = '.DNG'

# The first `train_num` files (in sorted order) form the training split;
# the remaining files go to validation.
train_num = 710

# Patch extent as (width, height) on the half-resolution packed Bayer grid;
# the matching ground-truth patch is twice as large in each direction.
patch_size = (400, 300)

## RAW data manipulation: split each RAW image into training/validation patches

In [None]:
def _read_levels(path):
    """Read black level, saturation level and white balance from EXIF.

    Args:
        path: path of the raw file.

    Returns:
        ``(blk_level, sat_level, cam_wb)`` as ``uint16``, ``uint16`` and
        ``float32`` NumPy arrays, taken from ``Exif.SubImage1.BlackLevel``,
        ``Exif.SubImage1.WhiteLevel`` and ``Exif.Image.AsShotNeutral``.
    """
    # NOTE(review): the SubImage1 tags look specific to this camera's DNG
    # layout -- confirm before feeding files from other cameras
    img_md = exiv2.ImageMetadata(path)
    img_md.read()
    blk_level = np.array(img_md['Exif.SubImage1.BlackLevel'].value,
                         dtype=np.uint16)
    sat_level = np.array(img_md['Exif.SubImage1.WhiteLevel'].value,
                         dtype=np.uint16)
    cam_wb = np.array(img_md['Exif.Image.AsShotNeutral'].value,
                      dtype=np.float32)
    return blk_level, sat_level, cam_wb


def _load_raw_and_gt(path):
    """Load a raw file as packed Bayer planes plus its rendered image.

    Args:
        path: path of the raw file.

    Returns:
        ``(raw_data, gt_img)``. ``raw_data`` has shape (h, w, 4) and holds
        the samples at offsets (0, 0), (0, 1), (1, 0), (1, 1) of every 2x2
        CFA cell. ``gt_img`` is the 16-bit ``postprocess`` output; it is
        expected to be twice the size of ``raw_data`` (the patch loop bounds
        by both shapes regardless).
    """
    # the context manager releases the LibRaw handle for every file
    with rp.imread(path) as raw_img:
        flat_bayer = raw_img.raw_image_visible
        # crop to even dimensions so the four sub-planes have equal shapes
        even_h = flat_bayer.shape[0] // 2 * 2
        even_w = flat_bayer.shape[1] // 2 * 2
        flat_bayer = flat_bayer[:even_h, :even_w]
        # assumes an RGGB CFA layout -- TODO confirm via raw_img.raw_pattern.
        # np.stack copies, so raw_data stays valid after the file is closed.
        raw_data = np.stack((flat_bayer[0::2, 0::2], flat_bayer[0::2, 1::2],
                             flat_bayer[1::2, 0::2], flat_bayer[1::2, 1::2]),
                            axis=2)
        # user_flip=0 disables the EXIF-orientation rotation so the rendered
        # ground truth stays pixel-aligned with raw_data
        gt_img = raw_img.postprocess(use_camera_wb=True,
                                     output_bps=16,
                                     no_auto_bright=True,
                                     adjust_maximum_thr=0.0,
                                     highlight_mode=HighlightMode.Ignore,
                                     user_flip=0)
    return raw_data, gt_img


def _save_patches(raw_data, gt_img, blk_level, sat_level, cam_wb,
                  out_dir, stem, patch_size):
    """Cut aligned raw / ground-truth patches and pickle each to ``out_dir``.

    Args:
        raw_data: packed Bayer array of shape (h, w, 4).
        gt_img: rendered image at twice the resolution of ``raw_data``.
        blk_level, sat_level, cam_wb: per-image metadata copied into every
            patch.
        out_dir: destination folder.
        stem: file-name prefix; patches are written as '<stem>_pNNN.pkl',
            numbered in row-major order.
        patch_size: (width, height) of a raw patch; the matching ground-truth
            patch covers (2 * width, 2 * height) pixels.

    Each pickle holds a dict with keys 'blk_level', 'sat_level', 'cam_wb',
    'raw' (shape (4, height, width)) and 'img' (channels-first ground truth).
    """
    patch_w, patch_h = patch_size
    # bound the grid by both arrays so every saved patch is full-sized even
    # if the two shapes ever disagree (numpy slicing would silently truncate)
    grid_h = min(raw_data.shape[0], gt_img.shape[0] // 2)
    grid_w = min(raw_data.shape[1], gt_img.shape[1] // 2)
    part_idx = 0
    for i in range(grid_h // patch_h):
        for j in range(grid_w // patch_w):
            crop_h = i * patch_h
            crop_w = j * patch_w
            raw_patch = raw_data[crop_h:crop_h + patch_h,
                                 crop_w:crop_w + patch_w, :]
            gt_patch = gt_img[2 * crop_h:2 * (crop_h + patch_h),
                              2 * crop_w:2 * (crop_w + patch_w), :]

            # save to files
            patch = {
                'blk_level': blk_level,
                'sat_level': sat_level,
                'cam_wb': cam_wb,
                'raw': np.transpose(raw_patch, (2, 0, 1)),
                'img': np.transpose(gt_patch, (2, 0, 1)),
            }
            file_path = os.path.join(out_dir,
                                     '%s_p%03d.pkl' % (stem, part_idx))
            with open(file_path, 'wb') as pkl_file:
                pickle.dump(patch, pkl_file)

            # update part index
            part_idx += 1


# get file list; match the extension as a suffix so that names which merely
# contain it (e.g. sidecar files such as 'IMG.DNG.xmp') are skipped
file_list = sorted(file for file in os.listdir(data_root)
                   if file.endswith(file_ext))

# make new folders; os.makedirs raises FileExistsError if a folder already
# exists, so a rerun never mixes patches from two different runs
train_path = os.path.join(save_root, 'train')
os.makedirs(train_path)
val_path = os.path.join(save_root, 'val')
os.makedirs(val_path)

for index, file in tqdm(enumerate(file_list),
                        desc='progress',
                        total=len(file_list)):
    src_path = os.path.join(data_root, file)
    # find black, saturation, and white balance
    blk_level, sat_level, cam_wb = _read_levels(src_path)
    # packed Bayer tensor and ground-truth sRGB image
    raw_data, gt_img = _load_raw_and_gt(src_path)
    # the first train_num files (in sorted order) form the training split
    out_dir = train_path if index < train_num else val_path
    _save_patches(raw_data, gt_img, blk_level, sat_level, cam_wb,
                  out_dir, os.path.splitext(file)[0], patch_size)