# Data Loader

## Includes

In [None]:
# imports used throughout the notebook
import os
import rawpy as rp
import numpy as np
import torch as t
from torch.utils import data

## User-defined dataset

In [None]:
class sonySet(data.Dataset):
    """Paired short-/long-exposure raw image dataset.

    Each non-blank line of ``<data_root>/Sony_<mode>_list.txt`` describes one
    sample.  The first two whitespace-separated fields are the paths of the
    short-exposure (input) and long-exposure (ground-truth) raw files; any
    further fields are ignored.  Only the base names of those paths are used;
    the files are read from ``<data_root>/Sony/short`` and
    ``<data_root>/Sony/long``.

    Args:
        data_root: Directory holding the list file and the ``Sony`` folder.
        bk_level: Black level subtracted from the raw sensor values.
        img_size: Side length of the square patch, measured on the packed
            (half-resolution) raw image.  The ground-truth patch is twice as
            large along each spatial axis.
        mode: ``'train'`` enables random crop / flip / transpose; any other
            value selects a fixed crop.  It also selects the list file.
    """

    # Upper bound applied to the long/short exposure ratio.
    MAX_RATIO = 300.0
    # Raw value mapped to 1.0 after black-level subtraction (2**14 - 1).
    WHITE_LEVEL = 16383

    def __init__(self, data_root, bk_level, img_size, mode='train'):
        super().__init__()
        self.short_root = os.path.join(data_root, 'Sony', 'short')
        self.long_root = os.path.join(data_root, 'Sony', 'long')
        self.bk_level = bk_level
        self.img_size = img_size
        self.mode = mode
        list_path = os.path.join(data_root, 'Sony_%s_list.txt' % mode)
        with open(list_path, 'r') as txt:
            # Drop blank lines (e.g. a trailing newline pair) so that they do
            # not turn into samples that fail later in __getitem__.
            self.sample_list = [
                line for line in txt.read().splitlines() if line.strip()
            ]

    @staticmethod
    def _exposure_ratio(raw_name, gt_name):
        """Return the capped ratio of long to short exposure.

        The exposure is read from characters ``[9:-5]`` of each file name,
        e.g. ``'00001_00_0.1s.ARW'`` -> ``0.1``.
        """
        # NOTE(review): relies on the fixed-width '<id>_<nn>_' prefix and the
        # 's.<3-char ext>' suffix of the file names - confirm for new data.
        ratio = float(gt_name[9:-5]) / float(raw_name[9:-5])
        return min(ratio, sonySet.MAX_RATIO)

    @staticmethod
    def _pack_bayer(raw_flat):
        """Pack a 2-D mosaic of shape (H, W) into (H/2, W/2, 4).

        Channel order is the 2x2 cell visited as (0,0), (0,1), (1,1), (1,0),
        presumably R, G, B, G for this sensor.
        """
        return np.stack(
            (raw_flat[0::2, 0::2], raw_flat[0::2, 1::2],
             raw_flat[1::2, 1::2], raw_flat[1::2, 0::2]),
            axis=2)

    def _train_patches(self, raw_4d, gt_img):
        """Return a randomly cropped, flipped and transposed patch pair."""
        size = self.img_size
        half_h, half_w = raw_4d.shape[:2]
        if size > half_h or size > half_w:
            raise ValueError(
                'img_size %d exceeds packed raw size %dx%d' %
                (size, half_h, half_w))

        # random crop; randint's upper bound is exclusive, hence the +1 so
        # that the last valid offset can be drawn as well
        crop_h = np.random.randint(0, half_h - size + 1)
        crop_w = np.random.randint(0, half_w - size + 1)
        raw_patch = raw_4d[crop_h:crop_h + size, crop_w:crop_w + size, :]
        # the ground truth is full resolution, so its window is doubled
        gt_patch = gt_img[crop_h * 2:(crop_h + size) * 2,
                          crop_w * 2:(crop_w + size) * 2, :]

        # random flip: 0 -> vertical, 1 -> horizontal, 2 -> none
        op = np.random.randint(0, 3)
        if op in (0, 1):
            # copy() removes the negative strides that torch.from_numpy
            # cannot handle
            raw_patch = np.flip(raw_patch, axis=op).copy()
            gt_patch = np.flip(gt_patch, axis=op).copy()

        # random transpose of the two spatial axes (probability 1/2)
        if np.random.randint(0, 2) == 0:
            raw_patch = np.transpose(raw_patch, (1, 0, 2))
            gt_patch = np.transpose(gt_patch, (1, 0, 2))
        return raw_patch, gt_patch

    def _eval_patches(self, raw_4d, gt_img):
        """Return the fixed patch pair used outside of training.

        The window starts at the midpoint of the packed image along both
        axes; slicing silently truncates it if it runs past the border.
        """
        size = self.img_size
        mid_h = raw_4d.shape[0] // 2
        mid_w = raw_4d.shape[1] // 2
        raw_patch = raw_4d[mid_h:mid_h + size, mid_w:mid_w + size, :]
        gt_patch = gt_img[mid_h * 2:(mid_h + size) * 2,
                          mid_w * 2:(mid_w + size) * 2, :]
        return raw_patch, gt_patch

    def __getitem__(self, index):
        """Load sample ``index``.

        Returns:
            A tuple ``(raw, gt)`` of float tensors in channel-first layout:
            ``raw`` is the packed, black-level corrected input multiplied by
            the exposure ratio and clamped to [0, 1]; ``gt`` is the
            post-processed long exposure scaled to [0, 1].

        Raises:
            ValueError: If the list entry has fewer than two fields, or if
                ``img_size`` does not fit the image in training mode.
        """
        # load a sample
        fields = self.sample_list[index].split()
        if len(fields) < 2:
            raise ValueError(
                'malformed sample entry: %r' % self.sample_list[index])
        raw = fields[0].split('/')[-1]
        gt = fields[1].split('/')[-1]
        ratio = self._exposure_ratio(raw, gt)

        # read from files; the context managers release the decoder
        # resources as soon as the pixel data has been copied out
        with rp.imread(os.path.join(self.short_root, raw)) as raw_data:
            # astype() copies, so the array stays valid after closing
            raw_flat = raw_data.raw_image_visible.astype(np.float32)
        with rp.imread(os.path.join(self.long_root, gt)) as gt_data:
            gt_img = gt_data.postprocess(
                use_camera_wb=True,
                half_size=False,
                no_auto_bright=True,
                output_bps=16)

        # convert flat image to a 4-channel half-resolution array
        raw_4d = self._pack_bayer(raw_flat)

        if self.mode == 'train':
            raw_patch, gt_patch = self._train_patches(raw_4d, gt_img)
        else:
            raw_patch, gt_patch = self._eval_patches(raw_4d, gt_img)

        # normalize
        raw_patch = np.maximum(raw_patch - self.bk_level,
                               0) / (self.WHITE_LEVEL - self.bk_level)
        # 16-bit output of postprocess -> [0, 1]
        gt_patch = np.float32(gt_patch / 65535.0)

        # to pyTorch tensor
        raw_patch = t.from_numpy(raw_patch)
        gt_patch = t.from_numpy(gt_patch)

        # amplify raw data and clamp
        raw_patch = t.clamp(raw_patch * ratio, 0.0, 1.0)

        # (H, W, C) -> (C, H, W)
        return raw_patch.permute(2, 0, 1), gt_patch.permute(2, 0, 1)

    def __len__(self):
        """Return the number of samples listed for this mode."""
        return len(self.sample_list)