# Dataloader

## Includes

In [None]:
# mass includes
import os, sys
import cv2
import pickle
import rawpy as rp
import numpy as np
import torch as t
from torch.utils import data
from torch.distributions.multivariate_normal import MultivariateNormal

## Dataset for r2rNet

In [None]:
class r2rSet(data.Dataset):
    """Paired raw / sRGB dataset for r2rNet, stored as one pickle per sample.

    Samples are the ``.pkl`` files in ``<opt.data_root>/<mode>``; each holds a
    dict with the keys ``'raw'``, ``'img'``, ``'blk_level'``, ``'sat_level'``
    and ``'cam_wb'``. In ``'train'`` mode a random ``r2r_size`` x ``r2r_size``
    raw patch is cropped together with the sRGB patch at twice those
    coordinates; in any other mode the whole sample is returned.

    NOTE(review): ``pickle.load`` can execute arbitrary code, so only point
    ``data_root`` at trusted data.
    """

    def __init__(self, opt, mode='train'):
        self.mode = mode
        self.data_root = os.path.join(opt.data_root, self.mode)
        self.r2r_size = opt.r2r_size
        # endswith (not a substring test) so backups such as 'x.pkl.bak' are
        # skipped; sorted so index -> sample does not depend on the arbitrary
        # os.listdir order.
        self.file_list = sorted(
            file for file in os.listdir(self.data_root)
            if file.endswith('.pkl'))

    def __getitem__(self, index):
        """Return ``(raw_patch, srgb_patch, cam_wb)`` as float32 tensors.

        ``raw_patch`` is black-level subtracted, divided by
        ``sat_level - blk_level`` and clipped to [0, 1]. ``srgb_patch`` is
        divided by 65535. ``cam_wb`` (3 values) is broadcast to
        ``(3, H, W)``, where ``H, W`` are the raw patch's last two dimensions.
        """
        # load a new sample
        with open(os.path.join(self.data_root, self.file_list[index]),
                  'rb') as file:
            data_dict = pickle.load(file)

        # read from file
        raw_data = data_dict['raw'].astype(np.float32)
        srgb_data = data_dict['img'].astype(np.float32)
        blk_level = data_dict['blk_level'].astype(np.float32)
        sat_level = data_dict['sat_level'].astype(np.float32)
        cam_wb = data_dict['cam_wb'].astype(np.float32)

        if self.mode == 'train':
            # np.random.randint excludes `high`: the +1 makes the last valid
            # offset reachable and lets a sample exactly r2r_size wide work.
            crop_h = np.random.randint(0,
                                       raw_data.shape[1] - self.r2r_size + 1)
            crop_w = np.random.randint(0,
                                       raw_data.shape[2] - self.r2r_size + 1)
            raw_patch = raw_data[:, crop_h:crop_h + self.r2r_size,
                                 crop_w:crop_w + self.r2r_size]
            # the sRGB image is indexed at twice the raw coordinates
            srgb_patch = srgb_data[:, 2 * crop_h:2 * (crop_h + self.r2r_size),
                                   2 * crop_w:2 * (crop_w + self.r2r_size)]
        else:
            raw_patch = raw_data
            srgb_patch = srgb_data

        # normalization; the black level is reshaped to (4, 1, 1) so it
        # broadcasts over the raw patch's leading (channel) axis
        blk = np.resize(blk_level, [4, 1, 1])
        raw_patch = np.clip((raw_patch - blk) / (sat_level - blk), 0.0, 1.0)
        srgb_patch = srgb_patch / 65535.0

        # to pyTorch tensor
        raw_patch = t.from_numpy(raw_patch)
        srgb_patch = t.from_numpy(srgb_patch)
        # In train mode raw_patch is r2r_size x r2r_size, so this is the same
        # size the training branch used to hard-code.
        cam_wb = t.from_numpy(cam_wb).view([3, 1, 1]).expand(
            [3, raw_patch.size(1), raw_patch.size(2)])

        return raw_patch, srgb_patch, cam_wb

    def __len__(self):
        return len(self.file_list)

## Datasets for fivekNight (with augmentation helpers) and validation

In [None]:
# random cropping and flipping
def randTransform(bg_img, fg_img, mask, img_size):
    """Randomly crop the background and randomly flip background / foreground.

    Args:
        bg_img: background image indexed as ``(H, W, C)``; it is cropped to a
            random window that is strictly smaller than it in both dimensions.
        fg_img: foreground image; only (possibly) flipped, never cropped.
        mask: foreground mask; always flipped together with ``fg_img``.
        img_size: ``img_size[1]`` is the minimum crop height for landscape
            backgrounds, ``img_size[0]`` the minimum crop width otherwise.
            Presumably ``(width, height)`` as passed to ``cv2.resize`` later --
            TODO confirm.

    Returns:
        ``(bg_img, fg_img, mask)`` after the random transforms.

    Raises:
        ValueError: (from ``np.random.randint``) if the background is too small
            for the requested minimum crop size.
    """
    bg_h, bg_w = bg_img.shape[0], bg_img.shape[1]

    # cropping
    if bg_w > bg_h:
        # Cap crop_h so that crop_w = round(crop_h / 0.75) also stays below
        # bg_w. Without the cap, landscape images narrower than 4:3 could give
        # crop_w > bg_w, and the randint for crop_x below then raised. For
        # images at least 4:3 wide the cap equals bg_h (the original range).
        crop_h = np.random.randint(img_size[1], min(bg_h, int(bg_w * 0.75)))
        crop_w = np.round(crop_h / 0.75)
    else:
        crop_w = np.random.randint(img_size[0], bg_w)
        # NOTE(review): 1 / 0.75 = 1.333..., so this branch uses a very
        # slightly different aspect ratio than the one above; kept as-is.
        crop_h = np.round(crop_w / 1.33)
    crop_h = int(crop_h)
    crop_w = int(crop_w)
    crop_y = np.random.randint(0, bg_h - crop_h)
    crop_x = np.random.randint(0, bg_w - crop_w)
    bg_img = bg_img[crop_y:crop_y + crop_h, crop_x:crop_x + crop_w, :]

    # flipping (horizontal): each outcome has probability 1/4 --
    # background only, foreground + mask only, everything, or nothing
    rand_var = np.random.rand()
    if rand_var < 0.25:
        bg_img = cv2.flip(bg_img, 1)
    elif rand_var < 0.5:
        fg_img = cv2.flip(fg_img, 1)
        mask = cv2.flip(mask, 1)
    elif rand_var < 0.75:
        bg_img = cv2.flip(bg_img, 1)
        fg_img = cv2.flip(fg_img, 1)
        mask = cv2.flip(mask, 1)

    return bg_img, fg_img, mask


# image blending and rescaling
def imgBlend(bg_img, fg_img, mask, img_size):
    """Composite a randomly scaled foreground over the background, then resize.

    The foreground and its mask are scaled by a random factor (0.5 to 1.0 times
    the largest scale whose mask still fits the background). They are placed
    flush with the bottom edge at a random horizontal offset and alpha-blended
    over ``bg_img`` with ``mask / 255`` as alpha.

    Returns:
        ``(syth_img, syth_mask)``: float arrays resized to ``img_size`` and
        clipped to [0, 1]. The blended image is divided by 65535; the alpha
        mask is replicated over 3 channels.
    """
    canvas_h, canvas_w = bg_img.shape[0], bg_img.shape[1]

    # random scale, then a random left offset for the scaled foreground
    fit_scale = np.minimum(canvas_h / mask.shape[0], canvas_w / mask.shape[1])
    scale = np.random.uniform(0.5, 1.0) * fit_scale
    offset = np.random.randint(0, canvas_w - int(scale * mask.shape[1]))

    def _rescale(arr):
        # bicubic resize by the shared random factor
        return cv2.resize(arr, None, fx=scale, fy=scale,
                          interpolation=cv2.INTER_CUBIC)

    mask = _rescale(mask)
    fg_img = _rescale(fg_img)

    # region of the canvas covered by the pasted foreground (bottom-aligned)
    rows = slice(canvas_h - mask.shape[0], canvas_h)
    cols = slice(offset, offset + mask.shape[1])

    alpha = np.zeros((canvas_h, canvas_w), dtype=np.float32)
    alpha[rows, cols] = mask
    alpha = np.repeat(alpha[:, :, np.newaxis], 3, axis=2) / 255.0

    layer = np.zeros((canvas_h, canvas_w, 3), dtype=np.float32)
    layer[rows, cols, :] = fg_img
    blended = (alpha * layer + (1 - alpha) * bg_img) / 65535.0

    # fixed output size; bicubic can overshoot, hence the final clip
    blended = cv2.resize(blended, img_size, interpolation=cv2.INTER_CUBIC)
    alpha = cv2.resize(alpha, img_size, interpolation=cv2.INTER_CUBIC)

    return np.clip(blended, 0.0, 1.0), np.clip(alpha, 0.0, 1.0)


class fivekNight(data.Dataset):
    """Synthetic composites of a foreground image over a background scene.

    Backgrounds are ``<data_root>/scene/raw/*.png`` and foregrounds are
    ``<data_root>/people/raw/*.png``. Each foreground has a pickled mask at
    ``<data_root>/people/mask/<name>.pkl``. Every (background, foreground)
    pair is one sample, so the length is ``len(bg_list) * len(fg_list)``.

    NOTE(review): masks are read with ``pickle.load``, which can execute
    arbitrary code; only use trusted data directories.
    """

    def __init__(self, opt):
        # get sample list
        self.img_size = opt.isp_size
        self.bg_path = os.path.join(opt.data_root, 'scene')
        self.fg_path = os.path.join(opt.data_root, 'people')
        self.bg_list = self._listStems(os.path.join(self.bg_path, 'raw'))
        self.fg_list = self._listStems(os.path.join(self.fg_path, 'raw'))

    @staticmethod
    def _listStems(folder):
        """Return the sorted names, without extension, of ``*.png`` files."""
        # endswith (not a substring test) so 'x.png.bak' is skipped instead of
        # yielding the bogus stem 'x.png'; sorted so index -> sample does not
        # depend on the arbitrary os.listdir order.
        return sorted(file[:-4] for file in os.listdir(folder)
                      if file.endswith('.png'))

    @staticmethod
    def _readRGB(path):
        """Read an image unchanged with OpenCV and convert it BGR -> RGB."""
        img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
        if img is None:
            # cv2.imread reports failure by returning None instead of raising
            raise FileNotFoundError('cannot read image: %s' % path)
        return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    def __getitem__(self, index):
        """Return ``(syth_img, syth_mask)`` for one (background, foreground) pair.

        Backgrounds form the outer loop of ``index`` and foregrounds the inner
        loop. ``syth_img`` is a channels-first float tensor. ``syth_mask``
        stacks the composite mask and its complement ``1 - mask``.
        """
        bg_index, fg_index = divmod(int(index), len(self.fg_list))
        fg_name = self.fg_list[fg_index]

        # read images (converted to RGB) and the foreground mask
        bg_img = self._readRGB(
            os.path.join(self.bg_path, 'raw', self.bg_list[bg_index] + '.png'))
        fg_img = self._readRGB(
            os.path.join(self.fg_path, 'raw', fg_name + '.png'))
        with open(os.path.join(self.fg_path, 'mask', fg_name + '.pkl'),
                  'rb') as pkl:
            mask = pickle.load(pkl)

        # random transforms
        bg_img, fg_img, mask = randTransform(bg_img, fg_img, mask,
                                             self.img_size)

        # image blending and rescaling
        syth_img, syth_mask = imgBlend(bg_img, fg_img, mask, self.img_size)

        # convert to tensors: HWC -> CHW image, 2-channel (mask, 1 - mask)
        syth_img = t.tensor(syth_img, dtype=t.float).permute(2, 0, 1)
        syth_mask = t.tensor(np.stack(
            [syth_mask[:, :, 0], 1.0 - syth_mask[:, :, 0]], axis=0),
                             dtype=t.float)

        return syth_img, syth_mask

    def __len__(self):
        return len(self.fg_list) * len(self.bg_list)


class valSet(data.Dataset):
    """Pre-generated validation samples stored as one pickle per sample.

    Samples are the ``.pkl`` files in ``<opt.data_root>/val<opt.amp_range[1]>``.
    Each file is a dict; ``__getitem__`` returns the values stored under
    ``_KEYS``, in that order, unchanged.

    NOTE(review): ``pickle.load`` can execute arbitrary code; only load
    trusted files.
    """

    # dict keys read by __getitem__, in the order they are returned
    _KEYS = ('syth_img', 'thumb_img', 'struct_img', 'seg_mask', 'amp',
             'noisy_raw', 'sorted_mask', 'wb')

    def __init__(self, opt):
        self.data_root = os.path.join(opt.data_root,
                                      'val%d' % opt.amp_range[1])
        # endswith (not a substring test) so 'x.pkl.tmp' files are skipped;
        # sorted so validation order is reproducible across machines.
        self.file_list = sorted(
            file for file in os.listdir(self.data_root)
            if file.endswith('.pkl'))

    def __getitem__(self, index):
        """Return the tuple ``(syth_img, thumb_img, struct_img, seg_mask, amp,
        noisy_raw, sorted_mask, wb)`` read from the sample's pickle."""
        with open(os.path.join(self.data_root, self.file_list[index]),
                  'rb') as file:
            data_dict = pickle.load(file)

        return tuple(data_dict[key] for key in self._KEYS)

    def __len__(self):
        return len(self.file_list)