In [None]:
import cv2
import numpy as np
import torch
from skimage import color

# Conversion between RGB and CIE Lab colour spaces (via skimage.color); other inputs pass through unchanged
def channel_convert(in_c, tar_type, img_list):
    """Convert a list of 3-channel images between RGB and CIE Lab.

    Args:
        in_c: Number of channels in the input images; only ``3`` triggers a
            conversion.
        tar_type: Target colour space: ``'LAB'`` (RGB -> Lab) or ``'RGB'``
            (Lab -> RGB). Matched case-insensitively, so ``'lab'`` works too.
        img_list: Sequence of HWC images. Per skimage's documentation, float
            RGB input to ``rgb2lab`` is interpreted as values in [0, 1].

    Returns:
        A new list of converted images, or ``img_list`` itself (unchanged)
        when ``in_c != 3`` or ``tar_type`` is neither ``'LAB'`` nor ``'RGB'``.
    """
    # Normalise case so callers passing 'lab' / 'rgb' are not silently ignored.
    target = tar_type.upper() if isinstance(tar_type, str) else tar_type

    if in_c != 3:
        return img_list
    if target == 'LAB':  # RGB -> Lab
        return [color.rgb2lab(img) for img in img_list]
    if target == 'RGB':  # Lab -> RGB
        return [color.lab2rgb(img) for img in img_list]
    return img_list


In [None]:
# Read an image from disk with OpenCV (cv2.imread), e.g. a PNG or JPEG file
# return: NumPy array, HWC, RGB, values in [0, 255] (float32 for colour images)
def read_img(path, size=None):
    """Read an image file as an RGB, HWC, float32 array resized to a fixed size.

    Args:
        path: Image file path, read with ``cv2.imread(..., cv2.IMREAD_UNCHANGED)``.
        size: Target size in ``cv2.resize`` ``dsize`` order, i.e.
            ``(width, height)``. ``None`` (default) keeps the historical
            ``(384, 512)``, giving 512 rows by 384 columns.

    Returns:
        np.ndarray of shape ``(H, W, C)`` with ``C`` equal to 1 (grayscale) or
        3 (RGB), dtype float32. Values keep the source range, i.e. roughly
        [0, 255] for 8-bit files (cubic interpolation may overshoot slightly).
        Any alpha channel is dropped.

    Raises:
        FileNotFoundError: if ``cv2.imread`` cannot read ``path`` (returns None).
    """
    if size is None:
        size = (384, 512)

    img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    if img is None:
        raise FileNotFoundError(f"cv2.imread could not read image: {path!r}")

    if img.ndim == 2:
        img = img[:, :, np.newaxis]
    else:
        # OpenCV decodes colour as BGR(A); [2, 1, 0] reorders to RGB and drops alpha.
        img = img[:, :, [2, 1, 0]]
    img = img.astype(np.float32)

    resized = cv2.resize(img, tuple(size), interpolation=cv2.INTER_CUBIC)
    # cv2.resize drops a trailing singleton channel axis; restore HWC layout.
    if resized.ndim == 2:
        resized = resized[:, :, np.newaxis]
    return resized



In [None]:
# Reads up to the first four image frames of a folder, converts them to CIE Lab, and returns two PyTorch tensors laid out as (frames, channels, height, width): the L channel alone and the full Lab image.
def read_img_lab_seq(folder_name, n_frames=4):
    """Load the first frames of a folder as Lab tensors.

    Image files directly inside ``folder_name`` (matched by extension) are
    sorted by name; the first ``n_frames`` are read with ``read_img``,
    converted RGB -> CIE Lab with ``channel_convert`` and stacked.

    Args:
        folder_name: Directory containing the frames.
        n_frames: Maximum number of frames to load (default 4, the value that
            was previously hard-coded). ``None`` loads every frame.

    Returns:
        ``(imgs_l, imgs_lab)``: float32 torch tensors shaped
        ``(frames, 1, H, W)`` (L channel only) and ``(frames, 3, H, W)``
        (full Lab).

    Raises:
        ValueError: if ``folder_name`` contains no matching image files.
    """
    import os

    image_exts = ('.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff')
    frame_paths = sorted(
        os.path.join(folder_name, name)
        for name in os.listdir(folder_name)
        if name.lower().endswith(image_exts)
    )[:n_frames]
    if not frame_paths:
        raise ValueError(f"no image files found in {folder_name!r}")

    rgb_frames = []
    for path in frame_paths:
        img = np.asarray(read_img(path), dtype=np.float32)
        if img.ndim == 2:
            img = img[:, :, np.newaxis]
        if img.shape[2] == 1:
            # Grayscale frame: replicate to RGB so every frame yields 3 Lab channels
            # and all frames can be stacked together.
            img = np.repeat(img, 3, axis=2)
        # read_img yields values in [0, 255]; skimage treats float RGB as [0, 1].
        rgb_frames.append(img / 255.0)

    lab_frames = channel_convert(rgb_frames[0].shape[2], 'LAB', rgb_frames)

    imgs_lab = np.stack(lab_frames, axis=0)  # (N, H, W, 3): axis 0 indexes frames
    imgs_l = imgs_lab[:, :, :, :1]           # (N, H, W, 1): keep only L, retain channel axis

    # NHWC -> NCHW (PyTorch layout). ascontiguousarray copies the transposed view
    # into C order so the resulting tensors are contiguous.
    imgs_l = torch.from_numpy(np.ascontiguousarray(np.transpose(imgs_l, (0, 3, 1, 2)))).float()
    imgs_lab = torch.from_numpy(np.ascontiguousarray(np.transpose(imgs_lab, (0, 3, 1, 2)))).float()

    return imgs_l, imgs_lab
