# Other Useful Functions

## Includes

In [None]:
# mass includes
import numpy as np
import math as m
import torch as t
import torchvision as tv
from numpy.random import randint, uniform
from scipy import signal
from colour_demosaicing import demosaicing_CFA_Bayer_DDFAPD as DDFAPD
from torch.nn.functional import grid_sample, mse_loss, pad, conv2d, interpolate, avg_pool2d

## Bayer pattern unification

In [None]:
def unifyBayerPtn(in_cfa, cfa_type):
    """Re-align a Bayer mosaic so that it starts on an RGGB 2x2 tile.

    Args:
        in_cfa: 2-D (or higher, spatial axes first) numpy mosaic.
        cfa_type: array-like whose first four flattened entries are colour
            codes for the 2x2 tile (0 -> R, 1 -> G, 2 -> B, 3 -> G).

    Returns:
        A copy of ``in_cfa``. For GRBG/GBRG/BGGR patterns, one row and/or
        column is trimmed from each side of the shifted axis. This moves the
        tile origin by one pixel and keeps the size even (2 smaller along
        each shifted axis). Any other pattern is copied unchanged.
    """
    code_to_colour = {0: 'R', 1: 'G', 2: 'B', 3: 'G'}
    flat_codes = cfa_type.reshape(-1)
    pattern = ''.join(code_to_colour[flat_codes[pos]] for pos in range(4))

    # (row window, column window) that turns each pattern into RGGB
    realign = {
        'GRBG': (slice(None), slice(1, -1)),
        'GBRG': (slice(1, -1), slice(None)),
        'BGGR': (slice(1, -1), slice(1, -1)),
    }
    window = realign.get(pattern)
    if window is None:
        # already RGGB (or unrecognised): plain copy
        return in_cfa.copy()
    return in_cfa[window].copy()

## Random operations

In [None]:
# random cropping
def batchCrop(in_imgs, crop_size, centred=False):
    """Crop the same square window out of one image or a list of images.

    Args:
        in_imgs: a (hei, wid, chn) numpy array, or a list of such arrays.
            The window position is taken from the first array's shape.
        crop_size: side length of the square window. If it exceeds an image
            axis, the crop covers that whole axis.
        centred: if True, centre the window; otherwise pick its top-left
            corner uniformly among all positions where it fits.

    Returns:
        A cropped copy (or list of copies, in input order).
    """
    reference = in_imgs[0] if isinstance(in_imgs, list) else in_imgs
    hei, wid, _ = reference.shape

    # Largest valid top-left offset per axis (computed independently per axis
    # so non-square images never yield truncated crops); 0 if the crop covers
    # the whole axis.
    max_y = max(hei - crop_size, 0)
    max_x = max(wid - crop_size, 0)
    if not centred:
        # randint's upper bound is exclusive: +1 makes the last offset reachable
        crop_y = randint(max_y + 1)
        crop_x = randint(max_x + 1)
    else:
        crop_y = round(max_y / 2)
        crop_x = round(max_x / 2)

    def _crop(img):
        # copy so the result does not alias the caller's array
        return img[crop_y:crop_y + crop_size,
                   crop_x:crop_x + crop_size, :].copy()

    if isinstance(in_imgs, list):
        return [_crop(img) for img in in_imgs]
    return _crop(in_imgs)


# random horizontal flipping
def randHorFlip(in_imgs):
    """Mirror an image (or every image in a list) left-right with p = 0.5.

    A single ``uniform()`` draw decides for the whole batch, so all images in
    a list are flipped together. Flipped results are fresh copies. When no
    flip happens, the input object itself is returned.
    """
    if uniform() <= 0.5:
        return in_imgs
    if isinstance(in_imgs, list):
        return [np.fliplr(img).copy() for img in in_imgs]
    return np.fliplr(in_imgs).copy()


# random vertical flipping
def randVerFlip(in_imgs):
    """Mirror an image (or every image in a list) upside-down with p = 0.5.

    A single ``uniform()`` draw decides for the whole batch, so all images in
    a list are flipped together. Flipped results are fresh copies. When no
    flip happens, the input object itself is returned.
    """
    if uniform() <= 0.5:
        return in_imgs
    if isinstance(in_imgs, list):
        return [np.flipud(img).copy() for img in in_imgs]
    return np.flipud(in_imgs).copy()

## Image downsizing

In [None]:
def downsize(in_img, out_size=256):
    """Anti-aliased resize of a (bch, chn, hei, wid) tensor to out_size x out_size.

    A separable Gaussian blur is applied with sigma = (hei / out_size) / 3,
    followed by nearest-neighbour resampling. Notes:
    - Sigma is derived from the height only.
    - The output is always square.
    - Reflect padding of round(3 * sigma) is used, which torch only allows
      when the padding is smaller than the spatial size.

    The kernel is built in the input's floating dtype and device so that
    float64 / float16 inputs work with ``conv2d``.
    """
    _, chn, hei, _ = in_img.size()
    ratio = hei / out_size
    sigma = 2 * ratio / 6

    # 1-D Gaussian taps in the input's dtype (default dtype for non-float input)
    knl_dtype = (in_img.dtype if in_img.is_floating_point()
                 else t.get_default_dtype())
    padding = round(3 * sigma)
    taps = t.arange(-padding, padding + 1, device=in_img.device,
                    dtype=knl_dtype)
    gauss_knl = t.exp(-0.5 * taps**2 / sigma**2)
    gauss_knl = gauss_knl / t.sum(gauss_knl)
    # one copy of the kernel per channel for depthwise (grouped) convolution
    gauss_knl = gauss_knl.repeat(chn, 1)

    out_img = pad(in_img, (padding, padding, padding, padding), mode='reflect')

    # x direction, then y direction
    out_img = conv2d(out_img, gauss_knl.view(chn, 1, 1, -1), groups=chn)
    out_img = conv2d(out_img, gauss_knl.view(chn, 1, -1, 1), groups=chn)

    return interpolate(out_img, size=(out_size, out_size), mode='nearest')

## Color space transform

In [None]:
def cam2sRGB(in_img, cam2xyz):
    """Convert camera-space images to linear sRGB via XYZ.

    Args:
        in_img: (bch, 3, hei, wid) RGB or (bch, 4, hei, wid) packed Bayer
            tensor. For 4 channels, channels 1 and 2 are averaged into G and
            channels 0 / 3 are taken as R / B.
        cam2xyz: camera-to-XYZ matrix, (3, 3) or batched (bch, 3, 3); anything
            ``t.matmul`` broadcasts against (bch, 3, hei*wid).

    Returns:
        (bch, 3, hei, wid) tensor. XYZ is clamped to [0, 1] before the fixed
        XYZ->sRGB matrix, and the result is clamped again. No gamma curve is
        applied here.
    """
    xyz2srgb = cam2xyz.new_tensor([[3.1339, -1.6169, -0.4907],
                                   [-0.9784, 1.9158, 0.0334],
                                   [0.0720, -0.2290, 1.4057]])

    bch, ch, hei, wid = in_img.size()
    if ch == 4:
        # collapse the two green planes into one
        rgb_img = t.stack((in_img[:, 0, :, :],
                           (in_img[:, 1, :, :] + in_img[:, 2, :, :]) / 2,
                           in_img[:, 3, :, :]), dim=1)
    else:
        rgb_img = in_img

    # reshape (not view): works for non-contiguous inputs such as transposes
    out_img = rgb_img.reshape(bch, 3, -1)
    out_img = t.clamp(t.matmul(cam2xyz, out_img), 0.0, 1.0)
    out_img = t.clamp(t.matmul(xyz2srgb, out_img), 0.0, 1.0)

    return out_img.reshape(bch, 3, hei, wid)


def rgb2lumin(in_img, cam2xyz):
    """Compute a single-channel luminance map clamped to [0, 1].

    Args:
        in_img: (bch, 3, hei, wid) RGB or (bch, 4, hei, wid) packed Bayer
            tensor. For 4 channels, channels 1 and 2 are averaged into G.
        cam2xyz: used only for 4-channel input. Its row 1 (the row that
            produces Y when multiplied with camera RGB, as in ``cam2sRGB``)
            gives the weights. Accepts a (3, 3) matrix or a batched
            (bch, 3, 3) stack. For 3-channel input, fixed weights
            (0.2224, 0.7169, 0.0606) are used and cam2xyz is ignored.

    Returns:
        (bch, 1, hei, wid) luminance tensor.
    """
    bch, ch, hei, wid = in_img.size()
    if ch == 4:
        # Y row -> (1, 3) for an unbatched matrix, (bch, 1, 3) for a batch
        lmn_coes = cam2xyz[..., 1, :].unsqueeze(-2)
        rgb_img = t.stack((in_img[:, 0, :, :],
                           (in_img[:, 1, :, :] + in_img[:, 2, :, :]) / 2,
                           in_img[:, 3, :, :]), dim=1)
    else:
        lmn_coes = in_img.new_tensor([[0.2224, 0.7169, 0.0606]])
        rgb_img = in_img

    # reshape (not view): works for non-contiguous inputs
    lmn_img = t.matmul(lmn_coes, rgb_img.reshape(bch, 3, -1))
    lmn_img = lmn_img.reshape(bch, 1, hei, wid)

    return t.clamp(lmn_img, 0.0, 1.0)

## Bayer CFA Demosaic

In [None]:
def demosaic(in_raw, ptn='RGGB'):
    """Demosaic a batch of packed Bayer tensors with DDFAPD.

    Args:
        in_raw: (bch, 4, h, w) tensor. The channels are unpacked into the
            2x2 tile positions (0,0), (0,1), (1,0), (1,1) of a (2h, 2w) mosaic.
        ptn: Bayer pattern string passed through to DDFAPD.

    Returns:
        (bch, 3, 2h, 2w) tensor with in_raw's dtype and device, assuming
        DDFAPD returns (2h, 2w, 3) — TODO confirm. Demosaicing runs in numpy,
        so the result carries no autograd history back to in_raw.
    """
    img_list = []
    for packed_raw in in_raw:
        # detach: .numpy() refuses tensors that require grad;
        # one host transfer per image instead of one per plane
        planes = packed_raw.detach().cpu().numpy()

        # unfold to a flat bayer mosaic
        flat_raw = np.zeros((planes.shape[1] * 2, planes.shape[2] * 2))
        flat_raw[0::2, 0::2] = planes[0]
        flat_raw[0::2, 1::2] = planes[1]
        flat_raw[1::2, 0::2] = planes[2]
        flat_raw[1::2, 1::2] = planes[3]

        lin_img = DDFAPD(flat_raw, pattern=ptn)
        img_list.append(in_raw.new_tensor(lin_img).permute(2, 0, 1))

    return t.stack(img_list, dim=0)

## Illumination adjustment

In [None]:
def applyIlmCoes(in_img, coes, cam2xyz=None):
    """Iteratively brighten an image with luminance-weighted curves.

    For each slice k along dim 1 of ``coes``:
        img <- clamp(img + coes_k * (1 - lumin(img)) * img, 0, 1)
    where lumin is ``rgb2lumin(img, cam2xyz)``. The input is not modified.
    """
    enhanced = in_img.clone()
    for step in range(coes.size(1)):
        # keep a singleton channel axis so the gain broadcasts over channels
        gain = coes[:, step:step + 1, :, :]
        lumin = rgb2lumin(enhanced, cam2xyz)
        enhanced = t.clamp(enhanced + gain * (1.0 - lumin) * enhanced,
                           0.0, 1.0)
    return enhanced


def applyIlmDen(in_img, noise_map, denoiser, coes, cam2xyz):
    """Iterative luminance-weighted enhancement where a denoiser forms each step.

    For each slice k along dim 1 of ``coes``, a residual
    ``coes_k * (1 - lumin(img)) * img`` is computed from the current estimate
    and handed to ``denoiser(in_img, noise_map, residual)``, whose return value
    becomes the new estimate. If ``coes`` has no slices, a clone of in_img is
    returned.
    """
    estimate = in_img.clone()
    for step in range(coes.size(1)):
        gain = coes[:, step:step + 1, :, :]
        residual = gain * (1.0 - rgb2lumin(estimate, cam2xyz)) * estimate
        # NOTE(review): the denoiser always receives the ORIGINAL in_img rather
        # than the running estimate; confirm this is intended.
        estimate = denoiser(in_img, noise_map, residual)
    return estimate

## Color adjustment

In [None]:
def applyClrCoes(in_img, coes, cam2xyz=None):
    """Apply a per-image 3rd-order polynomial colour correction.

    Args:
        in_img: (bch, 3, H, W) RGB tensor, or (bch, 4, H, W) packed Bayer,
            which is demosaiced first.
        coes: bch x 60 coefficients (any shape viewable as (bch, 60, 1, 1)).
            Indices [0, 20) drive R, [20, 40) drive G and [40, 60) drive B.
            Within each group the order matches ``pcc`` below.
        cam2xyz: if given, the image is first converted with ``cam2sRGB``.

    Returns:
        (bch, 3, H, W) tensor clamped to [0, 1].

    Raises:
        ValueError: if a polynomial term contains an unknown channel letter.
    """
    pcc = [
        '1',
        'r', 'g', 'b',
        'rr', 'gg', 'bb', 'rg', 'gb', 'rb',
        'rrr', 'ggg', 'bbb', 'rgg', 'gbb', 'rbb', 'grr', 'bgg', 'brr', 'rgb'
    ]
    chn_of = {'r': 0, 'g': 1, 'b': 2}

    # perform demosaicing if needed
    if in_img.size(1) == 4:
        srgb_img = demosaic(in_img)
    else:
        srgb_img = in_img.clone()

    # convert to sRGB color space if needed
    if cam2xyz is not None:
        srgb_img = cam2sRGB(srgb_img, cam2xyz)

    # polynomial transforms: out_c = sum_k coes[c * len(pcc) + k] * term_k
    pcc_len = len(pcc)
    out_img = t.zeros_like(srgb_img)
    coes = coes.view(coes.size(0), -1, 1, 1)
    for index, poly in enumerate(pcc):
        poly_img = t.ones_like(srgb_img[:, 0, :, :])
        for char in poly:
            if char == '1':
                continue
            if char not in chn_of:
                # raise instead of sys.exit: `sys` is never imported here and a
                # library helper must not terminate the process
                raise ValueError('Unrecognized polynominal term: %s' % char)
            poly_img = poly_img * srgb_img[:, chn_of[char], :, :]

        for chn in range(3):
            cur_coes = coes[:, index + chn * pcc_len, :, :]
            out_img[:, chn, :, :] = out_img[:, chn, :, :] + cur_coes * poly_img

    return t.clamp(out_img, 0.0, 1.0)

## Loss functions

In [None]:
class vgg16Loss(t.nn.Module):
    def __init__(self, device):
        super(vgg16Loss, self).__init__()

        # VGG input normalization
        self.mean = t.tensor([0.485, 0.456, 0.406], device=device)
        self.std = t.tensor([0.229, 0.224, 0.225], device=device)
        self.mean = self.mean.view(1, -1, 1, 1)
        self.std = self.std.view(1, -1, 1, 1)

        # pretrained weights
        features = list(tv.models.vgg16(pretrained=True).features)[:30]
        self.features = t.nn.ModuleList(features).to(device).eval()
        for param in self.parameters():
            param.requires_grad = False

        # instance normalization
        self.insNorm = {
            3: t.nn.InstanceNorm2d(64, affine=False),
            8: t.nn.InstanceNorm2d(128, affine=False),
            15: t.nn.InstanceNorm2d(256, affine=False),
            22: t.nn.InstanceNorm2d(512, affine=False),
            29: t.nn.InstanceNorm2d(512, affine=False)
        }

        # indices of layers to be extracted
        self.layer_list = [3, 8, 15, 22, 29]

    def forward(self, img1, img2):
        x = (img1 - self.mean.expand_as(img1)) / self.std.expand_as(img1)
        y = (img2 - self.mean.expand_as(img2)) / self.std.expand_as(img2)
        vgg_loss = 0.0

        # compute VGG perceptual loss
        for index, layer in enumerate(self.features):
            x = layer(x)
            y = layer(y)
            if index in self.layer_list:
                vgg_loss += mse_loss(self.insNorm[index](x),
                                     self.insNorm[index](y))
        vgg_loss = vgg_loss / len(self.layer_list)

        return vgg_loss


class expLoss(t.nn.Module):
    """Exposure loss pulling the channel-mean brightness towards 0.5.

    Per pixel the penalty is ``1 - exp(-(mean_c(out) - 0.5)^2 / (2 sigma^2))``.
    When ``in_img`` is given, pixels whose input channel-mean is below
    ``shadow`` instead receive the L1 loss between out_img and in_img.
    """

    def __init__(self, shadow=0.1, sigma=0.2):
        super().__init__()
        self.shadow = shadow
        self.sigma = sigma

    def forward(self, in_img, out_img):
        brightness = out_img.mean(dim=1)
        exposure = 1.0 - t.exp(-0.5 * (brightness - 0.5)**2 / self.sigma**2)
        if in_img is None:
            return exposure.mean()

        is_dark = in_img.mean(dim=1) < self.shadow
        # NOTE(review): l1_loss uses its default mean reduction, so every dark
        # pixel gets the same global scalar rather than its own |out - in|;
        # confirm this is intended.
        fidelity = t.nn.functional.l1_loss(out_img, in_img)
        return t.where(is_dark, fidelity, exposure).mean()


class tvLoss(t.nn.Module):
    """Anisotropic total-variation loss on a (bch, chn, H, W) tensor.

    It averages |vertical diff| + |horizontal diff| over the top-left
    (H-1) x (W-1) region, where both forward differences exist.
    """

    def __init__(self):
        super().__init__()

    def forward(self, in_feats):
        anchor = in_feats[:, :, :-1, :-1]
        d_vert = in_feats[:, :, 1:, :-1] - anchor
        d_horz = in_feats[:, :, :-1, 1:] - anchor
        return (d_vert.abs() + d_horz.abs()).mean()

In [None]:
def applyRGBCoes(in_img, coes, cam2xyz=None):
    """Iteratively apply per-channel quadratic curves to an image.

    For each slice k along dim 1 of ``coes``:
        img <- clamp(img + coes_k * (1 - img) * img, 0, 1)
    ``cam2xyz`` is accepted for signature parity with the other helpers and
    is unused. The input is not modified.
    """
    curve = in_img.clone()
    for step in range(coes.size(1)):
        # singleton channel axis so the gain broadcasts over channels
        gain = coes[:, step:step + 1, :, :]
        curve = t.clamp(curve + gain * (1.0 - curve) * curve, 0.0, 1.0)
    return curve