In [2]:
import numpy as np
import cv2

In [3]:
import torch
from torch.utils.data import Dataset
import torch.nn.functional as F
import torchvision
import torchvision.transforms as T
import cv2
from pathlib import Path
import os
import numpy as np

In [4]:
%pip install pytorch_lightning

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting pytorch_lightning
  Downloading pytorch_lightning-1.8.0.post1-py3-none-any.whl (796 kB)
[K     |████████████████████████████████| 796 kB 31.1 MB/s 
Collecting torchmetrics>=0.7.0
  Downloading torchmetrics-0.10.2-py3-none-any.whl (529 kB)
[K     |████████████████████████████████| 529 kB 66.5 MB/s 
Collecting lightning-utilities==0.3.*
  Downloading lightning_utilities-0.3.0-py3-none-any.whl (15 kB)
Collecting lightning-lite==1.8.0.post1
  Downloading lightning_lite-1.8.0.post1-py3-none-any.whl (136 kB)
[K     |████████████████████████████████| 136 kB 72.2 MB/s 
[?25hCollecting fire
  Downloading fire-0.4.0.tar.gz (87 kB)
[K     |████████████████████████████████| 87 kB 7.3 MB/s 
Building wheels for collected packages: fire
  Building wheel for fire (setup.py) ... [?25l[?25hdone
  Created wheel for fire: filename=fire-0.4.0-py2.py3-none-any.whl size=115942 sha256=eeaa3b36f

In [5]:
from torch.utils.data import DataLoader, random_split
from pytorch_lightning import LightningDataModule
from pathlib import Path

In [6]:
def _get_ab_hist(img: "np.ndarray", num_bin: int) -> "np.ndarray":
    """Get ab-space histogram of an image
    Parameters
    ----------
    img : np.ndarray
        Image numpy array
    num_bin : int
        Number of bins
    Returns
    -------
    np.ndarray
        Ab-space histogram
    """

    H = cv2.calcHist(
        [img.astype(np.float32)],
        channels=[1, 2],
        mask=None,
        histSize=[num_bin, num_bin],
        ranges=[0, 256, 0, 256],
    )
    H = H[None, ...]
    H = H / np.sum(H, axis=None)

    # arr = img.astype(float)

    # # Exclude Zeros and Make value 0 ~ 1
    # arr1 = (arr[1].ravel()[np.flatnonzero(arr[1])] + 1) / 2
    # arr2 = (arr[2].ravel()[np.flatnonzero(arr[2])] + 1) / 2

    # if arr1.shape[0] != arr2.shape[0]:
    #     if arr2.shape[0] < arr1.shape[0]:
    #         arr2 = np.concatenate([arr2, np.array([0])])
    #     else:
    #         arr1 = np.concatenate([arr1, np.array([0])])

    # # AB space
    # arr_new = [arr1, arr2]
    # H, edges = np.histogramdd(arr_new, bins=[num_bin, num_bin], range=((0, 1), (0, 1)))

    # H = np.rot90(H)
    # H = np.flip(H, 0)

    # H = H[None, ...].astype(float)
    # H = H / np.sum(H, axis=None)

    return H

In [7]:
def _get_l_hist(img: "np.ndarray", num_bin: int) -> "np.ndarray":
    """Get luminance histogram of an image
    Parameters
    ----------
    img : np.ndarray
        Image numpy array
    num_bin : int
        Number of bins
    Returns
    -------
    np.ndarray
        Luminance histogram
    """
    H = cv2.calcHist(
        [img.astype(np.float32)],
        channels=[0],
        mask=None,
        histSize=[num_bin],
        ranges=[0, 256],
    )
    H = H[..., None]
    H = H / np.sum(H, axis=None)

    return H

    # # Preprocess
    # arr = img.astype(float)
    # arr0 = (arr[0].ravel()[np.flatnonzero(arr[0])] + 1) / 2
    # arr1 = np.zeros(arr0.size)

    # arr_new = [arr0, arr1]
    # H, edges = np.histogramdd(arr_new, bins=[num_bin, 1], range=((0, 1), (-1, 2)))
    # H = np.transpose(H[None, ...], (1, 0, 2)).astype(float)

    # H = H / np.sum(H, axis=None)

    # return H


In [8]:
def get_histogram(img: "np.ndarray", l_bin: int, ab_bin: int) -> "np.ndarray":
    """_summary_
    Parameters
    ----------
    img : np.ndarray
        Image numpy array
    l_bin : int
        Size of luminance bin
    ab_bin : int
        Size of ab bin
    Returns
    -------
    np.ndarray
        Histogram
    """
    l_hist = _get_l_hist(img, l_bin)
    ab_hist = _get_ab_hist(img, ab_bin)

    l_hist = np.tile(l_hist, (1, ab_bin, ab_bin))

    hist = np.concatenate([ab_hist, l_hist], axis=0)

    return hist

In [9]:
def get_segwise_hist(
    img: "np.ndarray", l_bin: int, ab_bin: int, seg: "np.ndarray", num_classses: int
) -> "np.ndarray":
    """Get segmentation-wise histogram of an image
    Parameters
    ----------
    img : np.ndarray
        Image numpy array
    l_bin : int
        Size of luminance bin
    ab_bin : int
        Size of ab bin
    seg : np.ndarray
        Segementation map
    num_classses : int
        Number of segmentation labels
    Returns
    -------
    np.ndarray
        Histogram
    """
    l = []
    for i in range(num_classses):
        mask_img = img * (seg == i)
        mask_hist = get_histogram(mask_img, l_bin, ab_bin)
        l.append(mask_hist[None, :])

    return np.concatenate(l, axis=0)


In [10]:
def one_hot(seg: "np.ndarray[int]", num_classes: int) -> "np.ndarray[int]":
    """One-hot encode segmentation map
    Parameters
    ----------
    seg : np.ndarray[int]
        Segmentation map
    num_classes : int
        Number of segmentation labels
    Returns
    -------
    np.ndarray[int]
        One-hot encoded segmentation map with shape of (num_classes, w, h)
    """
    w, h = seg.shape
    res = np.tile(seg[None, ...], (num_classes, 1, 1))
    mask = np.ones((num_classes, w, h)) * np.arange(num_classes)[..., None, None]
    return (res == mask).astype(int)

In [11]:
def gen_common_seg_map(
    input_seg: "np.ndarray[int]", ref_seg: "np.ndarray[int]", num_classes: int
) -> "np.ndarray[int]":
    """_summary_
    Parameters
    ----------
    input_seg : np.ndarray[int]
        Segmentation label of input image.
    ref_seg : np.ndarray[int]
        Segmentation label of reference image.
    num_classes : int
        Number of segmentation labels.
    Returns
    -------
    np.ndarray[int]
        One-hot encoded input img seg map, only preserve common seg labels
    """
    in_uni = np.unique(input_seg)
    ref_uni = np.unique(ref_seg)
    common = np.intersect1d(in_uni, ref_uni)  # * common segmentation labels

    input_oh = one_hot(input_seg, num_classes)  # (num_labels, w1, h1)
    input_oh[~np.isin(np.arange(num_classes), common), :, :] = 0

    return input_oh

In [12]:
class Adobe5kDataset(Dataset):
    def __init__(self, data_dir, l_bin, ab_bin, num_classes):
        super(Dataset, self).__init__()

        self.data_dir = Path(data_dir)
        self.l_bin = l_bin
        self.ab_bin = ab_bin
        self.num_classes = num_classes

        self.in_img_paths = list((self.data_dir / "input" / "imgs").glob("**/*.png"))
        self.in_img_segs = list((self.data_dir / "input" / "segs").glob("**/*.npy"))
        self.ref_img_paths = list(
            (self.data_dir / "reference" / "imgs").glob("**/*.png")
        )
        self.ref_img_segs = list(
            (self.data_dir / "reference" / "segs").glob("**/*.npy")
        )

        target_size = (512, 512)
        self.img_transform = T.Compose(
            [
                T.ToTensor(),
                T.Resize(target_size),  # ! TBD
                # T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ]
        )
        self.seg_transform = T.Compose(
            [
                T.ToTensor(),
                T.Resize(target_size),  # ! TBD
            ]
        )

    def __len__(self):
        return len(self.in_img_paths)

    def __getitem__(self, index):
        Indexx = str(self.in_img_paths[index])
        indexx = Indexx[49:len(Indexx)-4]
        in_img = cv2.imread(str(self.in_img_paths[index]))
        in_img = cv2.cvtColor(in_img, cv2.COLOR_RGB2LAB)
        in_img = self.img_transform(in_img).numpy()
        in_seg = np.load(self.in_img_segs[index])[0]
        in_seg = self.seg_transform(in_seg[..., None]).numpy()[0]

        ref_img = cv2.imread(str(self.ref_img_paths[index]))
        ref_img = cv2.cvtColor(ref_img, cv2.COLOR_RGB2LAB)
        ref_img = self.img_transform(ref_img).numpy()
        ref_seg = np.load(self.ref_img_segs[index])[0]
        ref_seg = self.seg_transform(ref_seg[..., None]).numpy()[0]

        # in_img = self._rescale_img(in_img)
        # ref_img = self._rescale_img(ref_img)

        # * In case of mis-alignment

        # ! tmp fix
        # in_hist = np.random.rand(self.l_bin + 1, self.ab_bin, self.ab_bin)
        # ref_hist = np.random.rand(self.l_bin + 1, self.ab_bin, self.ab_bin)
        # ref_seg_hist = np.random.rand(
        #     self.num_classes, self.l_bin + 1, self.ab_bin, self.ab_bin
        # )
        in_hist = get_histogram(in_img, self.l_bin, self.ab_bin)
        ref_hist = get_histogram(ref_img, self.l_bin, self.ab_bin)
        ref_seg_hist = get_segwise_hist(
            ref_img, self.l_bin, self.ab_bin, ref_seg, self.num_classes
        )

        in_common_seg = gen_common_seg_map(in_seg, ref_seg, self.num_classes)

        in_img = torch.from_numpy(in_img).float()
        in_hist = torch.from_numpy(in_hist).float()
        in_common_seg = torch.from_numpy(in_common_seg).long()
        ref_img = torch.from_numpy(ref_img).float()
        ref_hist = torch.from_numpy(ref_hist).float()
        ref_seg_hist = torch.from_numpy(ref_seg_hist).float()

        # print(
        #     in_img.size(),
        #     in_hist.size(),
        #     in_common_seg.size(),
        #     ref_img.size(),
        #     ref_hist.size(),
        #     ref_seg_hist.size(),
        # )
        return in_img, indexx, in_hist, in_common_seg, ref_img, ref_hist, ref_seg_hist

    def _rescale_img(self, img, max_length=700):
        if (img.shape[1] > max_length) or (img.shape[2] > max_length):
            aspect_ratio = img.shape[1] / img.shape[2]
            if img.shape[1] > img.shape[2]:
                img = (
                    F.upsample(
                        torch.Tensor(img).unsqueeze(0),
                        size=(max_length, int(max_length / aspect_ratio)),
                        mode="bilinear",
                    )
                    .cpu()
                    .numpy()
                    .astype(int)[0]
                )
            else:
                img = (
                    F.upsample(
                        torch.Tensor(img).unsqueeze(0),
                        size=(int(max_length * aspect_ratio), max_length),
                        mode="bilinear",
                    )
                    .cpu()
                    .numpy()
                    .astype(int)[0]
                )
        return img

In [13]:
in_colab = False
try:
    from google.colab import drive
    drive.mount('/content/drive')
    in_colab = True
except:
    pass

Mounted at /content/drive


In [14]:
%cd /content/drive/My Drive/DCT

/content/drive/My Drive/DCT


In [15]:
data_dir = '/content/drive/MyDrive/DCT/data/train'

In [16]:
data = Adobe5kDataset(data_dir, 64, 64, 50)

In [17]:
in_histt = {}
in_common_segg = {}
ref_histt = {}
ref_seg_histt = {}
for i in range(50):
    in_img, indexx, in_hist, in_common_seg, ref_img, ref_hist, ref_seg_hist = data.__getitem__(i)
    torchvision.utils.save_image(in_img, '/content/drive/MyDrive/DCT/Histogram/in_img/'+ indexx +'.jpg')
    torchvision.utils.save_image(ref_img, '/content/drive/MyDrive/DCT/Histogram/ref_img/'+ indexx +'.jpg')
    in_histt.update({indexx : in_hist.clone()})
    in_common_segg.update({ indexx : in_common_seg.clone()})
    ref_histt.update({ indexx : ref_hist.clone()})
    ref_seg_histt.update({ indexx : ref_seg_hist.clone()})

In [18]:
torch.save(in_histt,'/content/drive/MyDrive/DCT/Histogram/in_hist.pt')
torch.save(in_common_segg,'/content/drive/MyDrive/DCT/Histogram/in_common_seg.pt')
torch.save(ref_histt,'/content/drive/MyDrive/DCT/Histogram/ref_hist.pt')
torch.save(ref_seg_histt,'/content/drive/MyDrive/DCT/Histogram/ref_seg_hist.pt')

In [19]:
input_histogram = torch.load('/content/drive/MyDrive/DCT/Histogram/in_hist.pt')

In [20]:
input_histogram

{'a3994-DSC_0033': tensor([[[1., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]],
 
         [[1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          ...,
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.]],
 
         [[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]],
 
         ...,
 
         [[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 