### Import Libraries

In [1]:
import os
import sys
import yaml
import torch
import joblib
import torch.nn as nn

### Utility Functions

In [7]:
def config_files():
    """Parse ``../config.yml`` and return its contents.

    The path is resolved relative to the current working directory and the
    file is parsed with ``yaml.safe_load``.

    Returns:
        Whatever ``yaml.safe_load`` produces for the file.
    """
    with open("../config.yml", "r") as stream:
        parsed = yaml.safe_load(stream)
    return parsed


def dump_file(value=None, filename=None):
    """Persist ``value`` to ``filename`` using ``joblib.dump``.

    Args:
        value: Object to serialize. ``None`` is treated as "not provided".
        filename: Destination handed to ``joblib.dump``. ``None`` is treated
            as "not provided".

    Returns:
        None. If either argument is ``None`` nothing is written and an error
        message is printed instead.
    """
    if value is None or filename is None:
        # Best-effort contract: report the problem on stdout, do not raise.
        print("Error: 'value' and 'filename' must be provided.".capitalize())
        return

    joblib.dump(value=value, filename=filename)


def load_file(filename=None):
    """Load and return an object previously stored with ``joblib.dump``.

    Args:
        filename: Source handed to ``joblib.load``. ``None`` is treated as
            "not provided".

    Returns:
        The object returned by ``joblib.load``, or ``None`` (after printing
        an error message) when ``filename`` is ``None``.
    """
    if filename is None:
        # Best-effort contract: report the problem on stdout, do not raise.
        print("Error: 'filename' must be provided.".capitalize())
        return None

    # NOTE: joblib.load unpickles the file, which can execute arbitrary code.
    # Only call this on files from a trusted source.
    return joblib.load(filename=filename)


def weight_init(m):
    """Initialise one module in place; written to be used with ``Module.apply``.

    Modules whose class name contains ``"Linear"`` or ``"Conv"`` get their
    weight filled by ``nn.init.kaiming_normal_`` and, when a bias exists, the
    bias set to zero. Every other module is left untouched.

    Args:
        m: The module to initialise.
    """
    layer_name = m.__class__.__name__

    # Matching is done on the class name, not on the type.
    if "Linear" not in layer_name and "Conv" not in layer_name:
        return

    nn.init.kaiming_normal_(m.weight.data)
    if m.bias is not None:
        nn.init.constant_(m.bias.data, 0.0)


def device_init(device: str = "cuda"):
    """Resolve a device name to a ``torch.device``, falling back to CPU.

    Args:
        device: ``"cuda"`` or ``"mps"`` to request that accelerator; any other
            value selects the CPU.

    Returns:
        The requested accelerator device when its availability check passes,
        otherwise ``torch.device("cpu")``.
    """
    if device == "cuda" and torch.cuda.is_available():
        return torch.device("cuda")

    if device == "mps" and torch.backends.mps.is_available():
        return torch.device("mps")

    return torch.device("cpu")

### Patch Embedding

In [None]:
class PatchEmbedding(nn.Module):
    """Cut an image into non-overlapping square patches and embed each one.

    A ``Conv2d`` whose kernel size and stride both equal ``patch_size``
    projects every patch to a ``dimension``-sized vector; a learnable
    positional embedding is then added to each patch vector.

    Args:
        channels: Number of channels of the input image.
        patch_size: Side length of one square patch.
        image_size: Side length of the (square) input image.
        dimension: Size of each patch vector. When ``None`` it defaults to
            ``patch_size ** 2 * channels``.
    """

    def __init__(
        self,
        channels: int = 3,
        patch_size: int = 16,
        image_size: int = 128,
        dimension: int = None,
    ):
        super().__init__()

        self.in_channels = channels
        self.patch_size = patch_size
        self.image_size = image_size
        self.dimension = dimension

        if self.dimension is None:
            self.dimension = (self.patch_size**2) * self.in_channels

        # The misspelled attribute name is kept so existing callers keep working.
        self.number_of_pathches = (self.image_size // self.patch_size) ** 2

        self.kernel_size, self.stride_size = self.patch_size, self.patch_size
        # Patches tile the image exactly, so the convolution must not pad.
        # Padding by one pixel shifts every patch, drops the last row/column
        # and, for small patch sizes, produces more patches than
        # ``number_of_pathches``.
        self.padding_size = 0

        self.projection = nn.Conv2d(
            in_channels=self.in_channels,
            out_channels=self.dimension,
            kernel_size=self.kernel_size,
            stride=self.stride_size,
            padding=self.padding_size,
            bias=False,
        )

        # Leading dimension of 1 broadcasts the embedding over the batch.
        self.positional_embeddings = nn.Parameter(
            torch.randn(1, self.number_of_pathches, self.dimension),
            requires_grad=True,
        )

    def forward(self, x: torch.Tensor):
        """Embed a batch of images.

        Args:
            x: Tensor of shape ``(batch, channels, image_size, image_size)``.

        Returns:
            Tensor of shape ``(batch, number_of_patches, dimension)``.

        Raises:
            ValueError: If ``x`` is not a ``torch.Tensor``.
        """
        if not isinstance(x, torch.Tensor):
            raise ValueError("Input must be a torch tensor.".capitalize())

        x = self.projection(x)  # (batch, dimension, rows, cols)
        x = x.flatten(2).transpose(1, 2)  # (batch, rows * cols, dimension)
        return x + self.positional_embeddings

    @staticmethod
    def total_params(model):
        """Return the number of parameter elements in ``model``.

        Returns ``None`` when ``model`` is not a ``PatchEmbedding``.
        """
        if isinstance(model, PatchEmbedding):
            return sum(params.numel() for params in model.parameters())


if __name__ == "__main__":
    # Smoke test: build the module from the "patchEmbeddings" section of the
    # config and check the output shape for a single random image.

    # Read and parse the config once instead of once per key.
    patch_config = config_files()["patchEmbeddings"]

    image_channels = patch_config["channels"]
    image_size = patch_config["image_size"]
    patch_size = patch_config["patch_size"]
    dimension = patch_config["dimension"]

    pathEmbedding = PatchEmbedding(
        channels=image_channels,
        patch_size=patch_size,
        image_size=image_size,
        dimension=dimension,
    )

    batch_size = 1
    image = torch.randn((batch_size, image_channels, image_size, image_size))

    # Compare against the module's resolved dimension so the check also holds
    # when the config leaves "dimension" empty and the module derives it.
    expected_shape = (
        batch_size,
        (image_size // patch_size) ** 2,
        pathEmbedding.dimension,
    )
    assert pathEmbedding(image).size() == expected_shape