<a href="https://colab.research.google.com/github/aguzel/computational_imaging/blob/main/image_color_space.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
from PIL import Image
import numpy as np_cpu
import matplotlib.pyplot as plt
import matplotlib.colors

import torch
import torch.optim as optim

from torchvision import transforms


%matplotlib inline

In [84]:
class image_color_space():
    """
    A class for manipulating the color space of an image.

    The image path is stored at construction time; the conversion methods
    operate on torch tensors whose colour channels lie along dimension 1.
    """

    def __init__(self, image_location):
        # Path of the image file read by load_image().
        self.image_location = image_location

    def load_image(self):
        """
        Loads the image stored at ``self.image_location`` as a torch tensor.
        Returns
        ----------
        image        : torch.tensor
                       Tensor of shape (1, C, W, H) whose dtype follows the decoded
                       pixel array. Note: the two spatial axes are transposed relative
                       to NCHW, because ``swapaxes(0, 2)`` is applied to the (H, W, C)
                       pixel array. Swap axes 0 and 2 of the squeezed result to get
                       back an (H, W, C) array. Single-channel images get C = 1.
        """
        # Use a context manager so the underlying file handle is released.
        with Image.open(self.image_location) as opened:
            pixels = np_cpu.array(opened)
        if pixels.ndim == 2:
            # Single-channel modes decode to (H, W); add a channel axis so the
            # axis swap below behaves as it does for multi-channel images.
            pixels = pixels[:, :, np_cpu.newaxis]
        image = torch.from_numpy(pixels)
        image = torch.swapaxes(image, 0, 2)
        return image.unsqueeze(0)

    def rgb_2_ycrcb(self, image):
        """
        Converts an image from RGB colourspace to YCrCb colourspace.
        Parameters
        ----------
        image   : torch.tensor
                    Input image. Should be an RGB floating-point image with values in the range [0, 1]
                    Should be in NCHW format.
        Returns
        -------
        ycrcb   : torch.tensor
                    Image converted to YCrCb colourspace (channels Y, Cr, Cb), float32.
        """
        # Allocate directly on the input's device (as ycrcb_2_rgb does) instead of
        # allocating on the CPU and copying.
        ycrcb = torch.zeros(image.size(), device=image.device)
        ycrcb[:, 0, :, :] = (0.299 * image[:, 0, :, :] + 0.587 * image[:, 1, :, :]
                             + 0.114 * image[:, 2, :, :])
        ycrcb[:, 1, :, :] = 0.5 + 0.713 * (image[:, 0, :, :] - ycrcb[:, 0, :, :])
        ycrcb[:, 2, :, :] = 0.5 + 0.564 * (image[:, 2, :, :] - ycrcb[:, 0, :, :])
        return ycrcb

    def ycrcb_2_rgb(self, image):
        """
        Converts an image from YCrCb colourspace to RGB colourspace.
        Parameters
        ----------
        image   : torch.tensor
                    Input image. Should be a YCrCb floating-point image with values in the range [0, 1]
                    Should be in NCHW format.
        Returns
        -------
        rgb     : torch.tensor
                    Image converted to RGB colourspace, float32.
        """
        rgb = torch.zeros(image.size(), device=image.device)
        rgb[:, 0, :, :] = image[:, 0, :, :] + 1.403 * (image[:, 1, :, :] - 0.5)
        rgb[:, 1, :, :] = (image[:, 0, :, :] - 0.714 * (image[:, 1, :, :] - 0.5)
                           - 0.344 * (image[:, 2, :, :] - 0.5))
        rgb[:, 2, :, :] = image[:, 0, :, :] + 1.773 * (image[:, 2, :, :] - 0.5)
        return rgb

    def rgb_2_hsv(self, image, eps=1e-8):
        """
        Converts an image from RGB colourspace to HSV colourspace.
        Parameters
        ----------
        image   : torch.tensor
                    Input RGB image with values in the range [0, 255]. It is divided
                    by 255 internally. Channels are read along dimension 1 in R, G, B
                    order. Channels beyond the first three (e.g. alpha) are ignored.
        eps     : float
                    Small constant added to denominators to avoid division by zero.
        Returns
        -------
        hsv     : torch.tensor
                    Tensor with three channels along dimension 1: hue, saturation and
                    value, each in the range [0, 1].
        """
        # Only R, G, B take part; an alpha channel would otherwise dominate max/min.
        rgb = image[:, :3] / 255.0
        red, green, blue = rgb[:, 0], rgb[:, 1], rgb[:, 2]
        max_c = rgb.max(1)[0]
        min_c = rgb.min(1)[0]
        delta = max_c - min_c
        denominator = delta + eps

        hue = torch.zeros_like(max_c)
        # Later assignments overwrite earlier ones, so on ties red takes precedence
        # over green, and green over blue.
        is_max = blue == max_c
        hue[is_max] = 4.0 + ((red - green) / denominator)[is_max]
        is_max = green == max_c
        hue[is_max] = 2.0 + ((blue - red) / denominator)[is_max]
        is_max = red == max_c
        hue[is_max] = ((green - blue) / denominator)[is_max] % 6
        # Achromatic pixels (max == min) have no defined hue; use 0.
        hue[delta == 0] = 0.0
        hue = hue / 6

        saturation = delta / (max_c + eps)
        saturation[max_c == 0] = 0
        value = max_c
        return torch.stack([hue, saturation, value], dim=1)

    def hsv_2_rgb(self, hsv_image):
        """
        Converts an image from HSV colourspace to RGB colourspace.
        Parameters
        ----------
        hsv_image : torch.tensor
                    Image with hue, saturation and value along dimension 1 (as
                    produced by rgb_2_hsv). Hue is wrapped into [0, 1); saturation
                    and value are clamped to [0, 1].
        Returns
        -------
        rgb_image : torch.tensor
                    Image with R, G, B channels along dimension 1, values in [0, 1].
        """
        h = hsv_image[:, 0, :, :] % 1
        s = torch.clamp(hsv_image[:, 1, :, :], 0, 1)
        v = torch.clamp(hsv_image[:, 2, :, :], 0, 1)

        # Split the hue circle into six sectors; f is the position inside a sector.
        sector = torch.floor(h * 6)
        f = h * 6 - sector
        p = v * (1 - s)
        q = v * (1 - (f * s))
        t = v * (1 - ((1 - f) * s))
        # `h % 1` can round up to exactly 1.0 in floating point (e.g. for a tiny
        # negative hue), giving sector 6. That is the same colour as sector 0
        # with f == 0, so wrap it instead of leaving the pixel black.
        sector = sector % 6

        # (red, green, blue) sources for sectors 0..5.
        sector_sources = (
            (v, t, p),
            (q, v, p),
            (p, v, t),
            (p, q, v),
            (t, p, v),
            (v, p, q),
        )
        r = torch.zeros_like(h)
        g = torch.zeros_like(h)
        b = torch.zeros_like(h)
        for index, (r_src, g_src, b_src) in enumerate(sector_sources):
            in_sector = sector == index
            r[in_sector] = r_src[in_sector]
            g[in_sector] = g_src[in_sector]
            b[in_sector] = b_src[in_sector]

        return torch.stack([r, g, b], dim=1)




In [85]:
# Point the converter at the sample image; nothing is read until load_image().
color_space = image_color_space(image_location='/content/parrot.png')

In [86]:
# Load the image as a (1, C, W, H) tensor (load_image swaps axes 0 and 2 of the HWC pixel array).
image = color_space.load_image()

In [None]:
# Display the loaded tensor as the cell output.
image

In [88]:
# rgb_2_hsv divides its input by 255, so it presumably expects the 8-bit pixel values loaded above.
hsv_tensor_image = color_space.rgb_2_hsv(image, eps=1e-8)

In [90]:
# Convert back to RGB; the result has values in [0, 1] rather than [0, 255].
rgb_image = color_space.hsv_2_rgb(hsv_image=hsv_tensor_image)

In [None]:
# Display the reconstructed RGB tensor as the cell output.
rgb_image

In [50]:
# Inspect the tensor dimensions (same torch.Size as .size()).
rgb_image.shape

torch.Size([1, 3, 1001, 1001])

In [None]:
# Move to host memory, drop the batch axis, and swap axes 0 and 2 so the
# (3, W, H) tensor becomes an (H, W, 3) array that matplotlib can show.
numpy_image = rgb_image.detach().cpu().numpy().squeeze(0).swapaxes(0, 2)
plt.imshow(numpy_image)

In [59]:
import matplotlib.colors as mcolors