In [None]:
# imports
import torch
import torch.nn as nn
import torch.functional as F
from torchvision import transforms
import math
import cv2
from PIL import Image
import numpy as np

from typing import Optional

## Colour Correction and Normalization

In [None]:
def ColourCorrection(image_path: Optional[str] = None, gamma: float = 1.0) -> Optional[np.ndarray]:
    """
    Performs colour correction and normalization on input image.

    The image is read with OpenCV, passed through a gamma lookup table for
    brightness normalization, and then min-max normalized to float32 values
    between 0 and 1.

    args:
    - image_path (str): path to input image
    - gamma (float): gamma value for brightness correction; values above 1
      brighten the image, values below 1 darken it (default: 1.0, no change)

    returns:
    - corrected and normalized image data (as numpy array), or None if no
      path was given or the image could not be read

    raises:
    - ValueError: if gamma is not a positive number

    Gamma correction code referenced from:
    https://pyimagesearch.com/2015/10/05/opencv-gamma-correction/
    """
    if not image_path:
        return None

    if gamma <= 0:
        raise ValueError(f"gamma must be positive, got {gamma}")

    img = cv2.imread(image_path)
    # cv2.imread signals an unreadable/missing file by returning None rather
    # than raising, so guard before handing the result to cv2.LUT
    if img is None:
        return None

    # brightness normalization by gamma correction: build a 256-entry lookup
    # table mapping every possible 8-bit intensity to its corrected value
    inv_gamma = 1.0 / gamma
    table = (((np.arange(256) / 255.0) ** inv_gamma) * 255).astype("uint8")

    img_gamma = cv2.LUT(img, table)

    # image values normalized to between 0 and 1
    img_normalized = cv2.normalize(
        img_gamma,
        None,
        alpha=0,
        beta=1,
        norm_type=cv2.NORM_MINMAX,
        dtype=cv2.CV_32F,
    )

    return img_normalized

## Patching Module
Turns a palm image into a sequence of fixed-size patch embeddings.

In [None]:
class Patching(nn.Module):
    """
    Projects sequence of patches of palm / palm regions into low-dimensional space to produce a patch embedding sequence.

    The projection is a strided convolution whose kernel size and stride both
    equal the patch size, so each output position corresponds to one
    non-overlapping patch of the input.
    """
    def __init__(self, image_size: int, patch_size: int, in_channels: int = 3, embed_dim: int = 768):
        """
        Initializes an instance of the Patching module.

        args:
        - image_size (int): size of image
        - patch_size (int): size of patches
        - in_channels (int): number of input channels (default: 3 for RGB)
        - embed_dim (int): dimension of output vector embeddings
        """
        super().__init__()

        self.image_size = image_size
        self.patch_size = patch_size
        self.in_channels = in_channels
        self.embed_dim = embed_dim

        self.model = nn.Sequential(
            # input image: (B x) C x H x W
            nn.Conv2d(in_channels=in_channels,
                      out_channels=embed_dim,
                      kernel_size=patch_size,
                      stride=patch_size),
            # after projection as convolution:
            # (B x) embed_dim x num of patches (H) x num of patches (W)
            # flatten the two trailing spatial dims; indexing from the end
            # keeps this correct whether or not a batch dim is present
            nn.Flatten(-2, -1),
            # (B x) embed_dim x num of patches in total
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for input image and converts image into patch embeddings.

        args:
        - x (torch.Tensor): input image (C, H, W) or batch of images (B, C, H, W)

        returns:
        - embedded patches of the image, shaped (num of patches, embed_dim)
          for a single image or (B, num of patches, embed_dim) for a batch
        """
        x = self.model(x)           # (B x) embed_dim x num of patches in total
        x = x.transpose(-2, -1)     # (B x) num of patches in total x embed_dim
        return x

## Positional Embedding Module
Adds a positional embedding to the sequence of patch embeddings.

Positional embedding formula:
$$PE_{(\text{pos}, 2i)} = \sin (\dfrac{\text{pos}}{10000^{2i / d_{\text{model}}}})$$
$$PE_{(\text{pos}, 2i+1)} = \cos (\dfrac{\text{pos}}{10000^{2i / d_{\text{model}}}})$$

In [None]:
class PositionalEmbedding(nn.Module):
    """
    Provides positional embedding for patches to identify location of patch in original image.

    A fixed sinusoidal table of shape (max_len, d_model) is precomputed at
    construction and stored as a buffer; even columns hold sines and odd
    columns hold cosines of position / 10000^(2i / d_model).
    """
    def __init__(self, d_model: int, max_len: int = 2048):
        """
        Initializes an instance of the PositionalEmbedding module.

        args:
        - d_model (int): dimensions of embedding
        - max_len (int): maximum length of sequence
        """
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        denominator_terms = torch.pow(10000, torch.arange(0, d_model, 2).float() / d_model)
        angles = position / denominator_terms       # max_len x ceil(d_model / 2)

        pe[:, 0::2] = torch.sin(angles)
        # for odd d_model there is one fewer odd column than even columns,
        # so trim the angle table to the number of cosine slots available
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        # buffer (not parameter): follows the module across devices and is
        # included in state_dict, but is not updated by the optimizer
        self.register_buffer('pe', pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Adds positional embedding to an input sequence.

        args:
        - x (torch.Tensor): input sequence, shaped (sequence length, d_model)
          or (batch, sequence length, d_model)

        returns:
        - input sequence with positional embedding added

        raises:
        - ValueError: if the sequence is longer than the precomputed table
        """
        # sequence length is the second-to-last dim for both batched and
        # unbatched input
        seq_len = x.size(-2)
        if seq_len > self.pe.size(0):
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len {self.pe.size(0)}"
            )
        return x + self.pe[:seq_len, :]