# Other Useful Functions

## Includes

In [None]:
# mass includes
import torch as t
import torchvision as tv

## Manual post-processing of raw data

In [None]:
# convert RAW to sRGB image
def raw2Img(raw_data, wb, cam_matrix):
    """Render a 4-plane raw tensor into a display-ready sRGB image.

    Pipeline: white balance -> simple demosaic -> camera-to-sRGB colour
    conversion -> gamma encoding.

    Args:
        raw_data: raw tensor whose planes 0..3 are R, G1, G2, B, as consumed
            by ``applyWB`` and ``demosaic``.
        wb: white-balance gains; ``wb[0]`` scales R and ``wb[2]`` scales B.
        cam_matrix: 3x3 colour matrix accepted by ``cam2sRGB``.

    Returns:
        A ``(3, H, W)`` sRGB tensor with values in ``[0, 1]``.
    """
    raw_data = applyWB(raw_data, wb)
    img = demosaic(raw_data)
    img = cam2sRGB(img, cam_matrix)
    # The inverted colour matrix can yield negative values. A fractional power
    # of a negative number is NaN, and NaN survives t.clamp, so clip to the
    # displayable range *before* gamma encoding.
    img = t.clamp(img, 0.0, 1.0)
    img = applyGamma(img)

    return t.clamp(img, 0.0, 1.0)


# apply white balancing
def applyWB(raw_data, wb):
    """Return a white-balanced copy of a 4-plane raw tensor.

    Plane 0 (R) is multiplied by ``wb[0]`` and plane 3 (B) by ``wb[2]``.
    Planes 1 and 2 (the two greens) are left untouched, so this presumably
    expects gains normalised to green (``wb[1] == 1``) -- TODO confirm.
    The input tensor itself is not modified.

    Args:
        raw_data: 3-D raw tensor ``(C, H, W)`` with R at index 0, B at index 3.
        wb: indexable gains; only ``wb[0]`` and ``wb[2]`` are read.

    Returns:
        A new tensor with the same shape as ``raw_data``.
    """
    balanced = raw_data.clone()
    red_gain, blue_gain = wb[0], wb[2]
    balanced[0, :, :].mul_(red_gain)
    balanced[3, :, :].mul_(blue_gain)
    return balanced


# demosaicing
def demosaic(raw_data):
    """Collapse a 4-plane RGGB raw tensor into a 3-channel RGB image.

    This is a same-resolution "demosaic": the R and B planes are copied
    through and the two green planes (indices 1 and 2) are averaged.

    Args:
        raw_data: tensor ``(C, H, W)`` with ``C >= 4``; only planes 0-3 are read.

    Returns:
        New ``(3, H, W)`` tensor with the dtype/device of the input.
    """
    _, height, width = raw_data.size()
    rgb = raw_data.new_empty([3, height, width])
    red, green_a, green_b, blue = (raw_data[k, :, :] for k in range(4))
    rgb[0, :, :] = red
    rgb[1, :, :] = (green_a + green_b) / 2
    rgb[2, :, :] = blue
    return rgb


# color space conversion
def cam2sRGB(img, cam_matrix):
    """Convert a camera-RGB image to linear sRGB.

    ``cam_matrix`` is multiplied with the standard linear-sRGB -> XYZ (D65)
    matrix. Each row of the product is normalised to sum to 1, so camera
    (1, 1, 1) maps to sRGB (1, 1, 1). The product is then inverted to obtain
    the camera -> sRGB transform, which is applied per pixel.

    Args:
        img: tensor ``(..., 3, H, W)``; colour channels on the third-from-last
            axis. Leading batch dimensions are broadcast (the previous
            implementation only accepted ``(3, H, W)``).
        cam_matrix: 3x3 matrix (nested list, numpy array or tensor),
            presumably XYZ -> camera as in dcraw -- TODO confirm with caller.

    Returns:
        New tensor with the same shape, dtype and device as ``img``.
    """
    # as_tensor avoids new_tensor's copy-construct warning for tensor input.
    cam_matrix = t.as_tensor(cam_matrix, dtype=img.dtype, device=img.device)
    xyz_matrix = img.new_tensor([[0.4124564, 0.3575761, 0.1804375],
                                 [0.2126729, 0.7151522, 0.0721750],
                                 [0.0193339, 0.1191920, 0.9503041]])
    trans_matrix = t.matmul(cam_matrix, xyz_matrix)
    # Row-normalise; the (3, 1) row sums broadcast, no explicit repeat needed.
    trans_matrix = trans_matrix / t.sum(trans_matrix, 1, keepdim=True)
    trans_matrix = t.inverse(trans_matrix)

    # out[..., i, h, w] = sum_j trans[i, j] * img[..., j, h, w]
    return t.einsum('ij,...jhw->...ihw', trans_matrix, img)


# gamma correction
def applyGamma(img, gamma=2.2):
    """Apply power-law gamma encoding, ``img ** (1 / gamma)``.

    Args:
        img: linear-light tensor, nominally in ``[0, 1]``.
        gamma: display gamma; defaults to 2.2 (previously hard-coded).

    Returns:
        New tensor of the same shape. Negative inputs are clipped to 0 first.
        A fractional power of a negative number is NaN, and NaN would survive
        any later clamp.
    """
    return t.pow(t.clamp(img, min=0.0), 1.0 / gamma)

## Raw data manipulation

In [None]:
# normalization
def normalize(raw_data, bk_level, sat_level):
    """Rescale raw sensor values using per-channel black levels.

    For every channel ``c``: ``(x - bk_level[c]) / (sat_level - bk_level[c])``.
    The result is not clipped.

    Args:
        raw_data: 4-D tensor ``(batch, channels, H, W)``.
        bk_level: per-channel black levels, indexable by channel index.
        sat_level: saturation (white) level shared by all channels.

    Returns:
        New tensor with the shape and dtype of ``raw_data``.
    """
    normal_raw = t.empty_like(raw_data)
    if raw_data.size(0) == 0:
        # Empty batch: nothing to fill, and bk_level is never indexed.
        return normal_raw
    # Levels are scalars per channel, so each channel is done for the whole
    # batch at once with the same element-wise arithmetic.
    for ch in range(raw_data.size(1)):
        black = bk_level[ch]
        normal_raw[:, ch, :, :] = (raw_data[:, ch, :, :] - black) / (sat_level - black)
    return normal_raw


# resize Bayer pattern
def downSample(raw_data, struct_img_size):
    """Turn a batch of 4-plane raw tensors into fixed-size RGB thumbnails.

    Planes 0 and 3 become R and B, and planes 1 and 2 are averaged into G.
    The result is bicubically resized to ``struct_img_size`` when either
    dimension differs. It is then clipped to ``[0, 1]``, because bicubic
    interpolation can overshoot.

    Args:
        raw_data: tensor ``(batch, C, H, W)`` with ``C >= 4``.
        struct_img_size: target size given as ``(width, height)``.

    Returns:
        Tensor ``(batch, 3, height, width)``.
    """
    batch, _, hei, wid = raw_data.size()
    raw_img = raw_data.new_empty((batch, 3, hei, wid))
    raw_img[:, 0, :, :] = raw_data[:, 0, :, :]  # R
    raw_img[:, 1, :, :] = (raw_data[:, 1, :, :] + raw_data[:, 2, :, :]) / 2.0  # G
    raw_img[:, 2, :, :] = raw_data[:, 3, :, :]  # B

    out_wid, out_hei = struct_img_size[0], struct_img_size[1]
    # Resize when EITHER dimension differs. The previous `and` skipped the
    # resize if only one dimension mismatched, returning the wrong size.
    if hei != out_hei or wid != out_wid:
        raw_img = t.nn.functional.interpolate(raw_img,
                                              size=(out_hei, out_wid),
                                              mode='bicubic')

    return t.clamp(raw_img, 0.0, 1.0)


# image standardization (mean 0, std 1)
def standardize(srgb_img):
    """Standardise each image in a batch to zero mean and unit std.

    Per sample: ``(x - mean) / max(std, 1 / sqrt(N))``, where ``N`` is the
    number of elements in one sample. The floor keeps near-constant images
    from being divided by ~0, which is the same safeguard as TensorFlow's
    ``per_image_standardization``. ``std`` uses torch's default (unbiased)
    estimator.

    Args:
        srgb_img: tensor ``(batch, C, H, W)``.

    Returns:
        New tensor with the same shape as ``srgb_img``.
    """
    flat = srgb_img.flatten(1)
    min_std = 1.0 / (flat.size(1) ** 0.5)
    mean = flat.mean(dim=1, keepdim=True)
    # The floor is applied per sample. The previous loop overwrote the floor
    # variable, so each sample was divided by the largest std seen so far in
    # the batch instead of its own.
    std = t.clamp(flat.std(dim=1, keepdim=True), min=min_std)

    return ((flat - mean) / std).reshape(srgb_img.shape)

## Training sample synthesis

In [None]:
# convert sRGB image to RGGB pattern
def toRGGB(srgb_img):
    """Mosaic a batch of RGB images into four half-resolution RGGB planes.

    Within each 2x2 cell, the planes are taken from these positions:
    - R: (even row, even col)
    - G: (even row, odd col) and (odd row, even col)
    - B: (odd row, odd col)

    Args:
        srgb_img: tensor ``(batch, 3, H, W)``. H and W presumably must be even,
            since the four slices need equal sizes for ``t.stack``.

    Returns:
        Tensor ``(batch, 4, H/2, W/2)`` with planes ordered R, G, G, B.
    """
    # (source channel, row offset, column offset) for each output plane
    sites = ((0, 0, 0), (1, 0, 1), (1, 1, 0), (2, 1, 1))
    planes = [srgb_img[:, ch, row::2, col::2] for ch, row, col in sites]
    return t.stack(planes, dim=1)


# add noise to Bayer pattern
def addPGNoise(raw_data, noise_stat):
    """Add random Poisson-Gaussian (shot + read) noise to each sample.

    For each sample, two log noise levels are drawn:
    - ``log_shot`` uniformly from ``[noise_stat['min'], noise_stat['max'])``.
    - ``log_read`` from a normal distribution with mean
      ``noise_stat['slope'] * log_shot + noise_stat['const']`` and std
      ``noise_stat['std']``.

    The per-pixel noise std is ``sqrt(exp(log_shot) * x + exp(log_read))``.

    Args:
        raw_data: tensor ``(batch, C, H, W)``, presumably non-negative (the
            value enters a square root).
        noise_stat: dict with keys 'min', 'max', 'slope', 'const', 'std'.

    Returns:
        Noisy tensor of the same shape, clipped to ``[0, 1]``.
    """
    noisy_raw = t.empty_like(raw_data)
    for idx in range(raw_data.size(0)):
        sample = raw_data[idx, :, :, :]
        # Draw order (uniform, normal, randn) is kept fixed for reproducibility.
        log_shot = sample.new_empty(1).uniform_(noise_stat['min'],
                                                noise_stat['max'])
        read_mean = noise_stat['slope'] * log_shot.item() + noise_stat['const']
        log_read = sample.new_empty(1).normal_(mean=read_mean,
                                               std=noise_stat['std'])
        sigma = t.sqrt(t.exp(log_shot) * sample + t.exp(log_read))
        noisy_raw[idx, :, :, :] = sample + sigma * t.randn_like(sample)
    return t.clamp(noisy_raw, 0.0, 1.0)


def _sampleWB(rggb_raw, wb_stat):
    """Draw one random (R, G, B) white-balance gain triple per sample.

    For each sample:
    - ``log(R gain)`` is drawn from ``U[wb_stat['min'], wb_stat['max'])``.
    - ``log(B gain)`` is normal with mean
      ``wb_stat['slope'] * log(R gain) + wb_stat['const']`` and std
      ``wb_stat['std']``.
    - G gain is 1.

    Gains are written as constant ``(batch, 3, h, w)`` maps.
    """
    batch, _, hei, wid = rggb_raw.size()
    wb = rggb_raw.new_empty((batch, 3, hei, wid))
    for index in range(batch):
        wb_r = rggb_raw.new_empty(1).uniform_(wb_stat['min'], wb_stat['max'])
        wb_b = rggb_raw.new_empty(1).normal_(
            mean=wb_stat['slope'] * wb_r.item() + wb_stat['const'],
            std=wb_stat['std'])
        wb[index, 0, :, :] = t.exp(wb_r)
        wb[index, 1, :, :] = 1.0
        wb[index, 2, :, :] = t.exp(wb_b)
    return wb


def _randomExposure(org_raw, syth_img, syth_mask, amp_range):
    """Darken the two mask regions of each sample by random amplification ratios.

    Returns ``(amp, clean_raw, sorted_mask)``:
      amp: ``(batch, 2)`` ratios for mask channels 0 and 1. Each is drawn from
        ``U[0, amp_range[1])`` and clamped to ``[1, amp_range[1]]``.
        NOTE(review): ``amp_range[0]`` is not used -- confirm this is intended.
      clean_raw: mask-weighted blend of ``org_raw / amp[:, 0]`` and
        ``org_raw / amp[:, 1]``, using the half-resolution mask.
      sorted_mask: copy of ``syth_mask`` whose channels are reversed for
        samples where ``amp[:, 0] < amp[:, 1]``.
    """
    batch = syth_img.size(0)
    sorted_mask = syth_mask.clone()
    half_mask = t.nn.functional.interpolate(syth_mask, scale_factor=0.5)
    half_mask = t.clamp(half_mask, 0.0, 1.0)
    amp = org_raw.new_empty((batch, 2))
    clean_raw = t.empty_like(org_raw)
    for index in range(batch):
        amp[index, :] = t.clamp(
            syth_img.new_empty((2, )).uniform_(0.0, amp_range[1]), 1.0,
            amp_range[1])
        mask_a = half_mask[index, 0, :, :].unsqueeze(0)
        mask_b = half_mask[index, 1, :, :].unsqueeze(0)
        sample = org_raw[index, :, :, :]
        clean_raw[index, :, :, :] = (mask_a * sample / amp[index, 0] +
                                     mask_b * sample / amp[index, 1])
        if amp[index, 0] < amp[index, 1]:
            sorted_mask[index, :, :, :] = t.flip(sorted_mask[index, :, :, :],
                                                 [0])
    return amp, clean_raw, sorted_mask


# blend weighted fg & bg and convert to Bayer pattern
def toRaw(r2rNet, syth_img, syth_mask, opt):
    """Synthesise a noisy raw training sample from a composited sRGB image.

    Steps:
    1. Mosaic sRGB into RGGB planes.
    2. Draw a random white balance.
    3. Map the mosaic through ``r2rNet``.
    4. Darken each mask region by a random ratio.
    5. Force near-saturated pixels to at least 1.
    6. Add Poisson-Gaussian noise.
    7. Build a fixed-size thumbnail, its standardised version and a resized
       mask.

    The random draws happen in the same order as before (WB, amplification,
    noise), so seeded runs reproduce earlier results.

    Args:
        r2rNet: callable ``r2rNet(rggb_raw, wb)``; presumably a network whose
            output is shaped like ``rggb_raw`` -- TODO confirm.
        syth_img: sRGB batch ``(batch, 3, H, W)``; H and W must be even for
            ``toRGGB``.
        syth_mask: mask batch ``(batch, >=2, H, W)``; channels 0 and 1 weight
            the two regions.
        opt: options with ``wb_stat``, ``amp_range``, ``noise_stat`` and
            ``att_size`` (given as ``(width, height)``) attributes.

    Returns:
        ``(thumb_img, struct_img, seg_mask, amp, noisy_raw, sorted_mask, wb)``.
    """
    # convert sRGB image to half-size RGGB pattern
    rggb_raw = toRGGB(syth_img)

    # Saturation mask: 1.0 where the mean over the RGGB planes exceeds 0.95.
    # .to() casts bool -> float without new_tensor's copy-construct warning.
    sat_mask = (t.mean(rggb_raw, 1, keepdim=True) > 0.95).to(rggb_raw.dtype)

    # random white balance
    wb = _sampleWB(rggb_raw, opt.wb_stat)

    # convert to Bayer pattern (no gradients through the conversion net)
    with t.no_grad():
        org_raw = r2rNet(rggb_raw, wb)
        org_raw = t.clamp(org_raw, 0.0, 1.0)

    # random amplification ratio per mask region
    amp, clean_raw, sorted_mask = _randomExposure(org_raw, syth_img,
                                                  syth_mask, opt.amp_range)

    # preserve saturation: saturated pixels are forced to at least 1
    clean_raw = t.max(clean_raw, sat_mask)

    # add noise
    noisy_raw = addPGNoise(clean_raw, opt.noise_stat)

    # down-sample to fixed size
    thumb_img = downSample(clean_raw, opt.att_size)
    struct_img = standardize(thumb_img)
    seg_mask = t.nn.functional.interpolate(syth_mask,
                                           size=(opt.att_size[1],
                                                 opt.att_size[0]))
    seg_mask = t.clamp(seg_mask, 0.0, 1.0)

    return thumb_img, struct_img, seg_mask, amp, noisy_raw, sorted_mask, wb

## Loss function

In [None]:
class vgg16Loss(t.nn.Module):
    """Perceptual loss on frozen, ImageNet-pretrained VGG-16 features.

    The first 23 modules of torchvision's VGG-16 ``features`` (up to and
    including relu4_3) are kept and frozen. The loss is the mean of the MSEs
    between prediction and target activations after relu1_2, relu2_2, relu3_3
    and relu4_3.

    NOTE(review): inputs are fed to VGG as-is, with no ImageNet mean/std
    normalisation; the pretrained weights expect normalised input.
    """

    # indices of relu1_2, relu2_2, relu3_3, relu4_3 in VGG-16 ``features``
    _LOSS_LAYERS = frozenset({3, 8, 15, 22})

    def __init__(self, device):
        super(vgg16Loss, self).__init__()
        features = list(self._pretrainedVGG16().features)[:23]
        self.features = t.nn.ModuleList(features).to(device).eval()
        for param in self.parameters():
            param.requires_grad = False

    @staticmethod
    def _pretrainedVGG16():
        """Build ImageNet-pretrained VGG-16 across torchvision versions.

        torchvision >= 0.13 deprecates ``pretrained=True`` in favour of the
        ``weights=`` enum; ``IMAGENET1K_V1`` is the checkpoint the legacy flag
        selects. Older versions lack the enum and only accept ``pretrained``.
        """
        weights_enum = getattr(tv.models, 'VGG16_Weights', None)
        if weights_enum is None:
            return tv.models.vgg16(pretrained=True)
        return tv.models.vgg16(weights=weights_enum.IMAGENET1K_V1)

    def forward(self, pred_img, gt_img):
        """Return the averaged feature MSE between ``pred_img`` and ``gt_img``."""
        x = pred_img
        y = gt_img
        vgg_loss = 0.0

        # use outputs of relu1_2, relu2_2, relu3_3, relu4_3 as loss
        for index, layer in enumerate(self.features):
            x = layer(x)
            y = layer(y)
            if index in self._LOSS_LAYERS:
                vgg_loss += t.nn.functional.mse_loss(x, y)

        return vgg_loss / 4.0


class imgLoss(t.nn.Module):
    """Reconstruction loss: averaged pixel MSE plus averaged VGG perceptual loss.

    The masked prediction and the fused prediction are each compared with the
    same ground truth. Each term is the mean over the two predictions.
    """

    def __init__(self, device):
        super(imgLoss, self).__init__()
        self.l2_loss = t.nn.MSELoss()
        self.vgg_loss = vgg16Loss(device)

    def forward(self, masked_img, fused_img, gt_img):
        """Return ``mean_p MSE(p, gt) + mean_p VGG(p, gt)`` for p in (masked, fused)."""
        predictions = (masked_img, fused_img)
        pixel_term = sum(self.l2_loss(p, gt_img) for p in predictions) / 2.0
        perceptual_term = sum(self.vgg_loss(p, gt_img) for p in predictions) / 2.0
        return pixel_term + perceptual_term