In [1]:
import torch
from torch import nn

from torchvision.transforms import Resize, Compose, ToTensor
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import random


In [2]:
# Root folder of the image dataset; the Train/Test/Validation folders live below it
data_dir: str = "./data/Dataset"

# Number of samples in each mini-batch produced by the DataLoaders
BATCH_SIZE: int = 32

In [3]:
# Preprocessing for training images: resize to 224x224, then convert to a tensor
train_transform = Compose(
    transforms=[
        Resize(size=(224, 224)),
        ToTensor(),
    ]
)

# Preprocessing for evaluation images (test and validation): identical steps,
# kept as a separate pipeline object from the training one
test_transform = Compose(
    transforms=[
        Resize(size=(224, 224)),
        ToTensor(),
    ]
)


In [4]:
# Training split: images under <data_dir>/Train, one sub-folder per class
training_dataset = ImageFolder(f"{data_dir}/Train", transform=train_transform)

# Test split: images under <data_dir>/Test, using the evaluation transform
testing_dataset = ImageFolder(f"{data_dir}/Test", transform=test_transform)

# Validation split: images under <data_dir>/Validation, using the evaluation transform
validation_dataset = ImageFolder(f"{data_dir}/Validation", transform=test_transform)


In [5]:
# Training loader: samples are reshuffled on every pass over the data
training_dataloader = DataLoader(
    training_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=12,
)

# Test loader: fixed sample order
testing_dataloader = DataLoader(
    testing_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=12,
)

# Validation loader: fixed sample order
validation_dataloader = DataLoader(
    validation_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=12,
)

In [6]:
def showRandomDataFromTraining(num_rows: int = 5):
    """Plot a square grid of randomly chosen training images with their class names.

    Args:
        num_rows: Number of rows in the grid; the same number of columns is
            used, so ``num_rows ** 2`` images are drawn.

    The samples are drawn (with replacement) from the module-level
    ``training_dataset`` using the ``random`` module.
    """
    num_cols = num_rows

    # squeeze=False keeps `axs` a 2-D array even for a 1x1 grid; without it
    # plt.subplots returns a bare Axes object and indexing it fails.
    fig, axs = plt.subplots(num_rows, num_cols, figsize=(10, 10), squeeze=False)

    # axs.flat walks the grid row by row, one random sample per subplot
    for ax in axs.flat:
        image_index = random.randrange(len(training_dataset))

        # Fetch the sample once: each dataset access loads and transforms the
        # image again, so reuse the (image, label) pair for picture and title.
        image, label = training_dataset[image_index]

        # The dataset yields channel-first tensors; imshow expects channel-last
        ax.imshow(image.permute((1, 2, 0)))

        # Title each subplot with the class name of the sample
        ax.set_title(training_dataset.classes[label], color="white")

        # Hide ticks and frame for a cleaner picture
        ax.axis(False)

    # Figure-level title
    fig.suptitle(
        f"Random {num_rows * num_cols} images from the training dataset",
        fontsize=16,
        color="white",
    )

    # Black background so the white titles are readable
    fig.set_facecolor(color="black")

    plt.show()

In [7]:
# Side length, in pixels, of each square image patch
PATCH_SIZE = 16
# Input image size; height is tied to width, so images are square
IMAGE_WIDTH = 224
IMAGE_HEIGHT = IMAGE_WIDTH
IMAGE_CHANNELS = 3
# Number of values in one flattened patch: channels * patch_size * patch_size
EMBEDDING_DIMS = IMAGE_CHANNELS * PATCH_SIZE**2
# Whole patches per image, counted per axis with integer division. This stays
# an exact int and does not overcount when the image size is not a multiple of
# the patch size (float division of the areas followed by int() would).
NUM_OF_PATCHES = (IMAGE_WIDTH // PATCH_SIZE) * (IMAGE_HEIGHT // PATCH_SIZE)


## ViT Submodules

In [8]:
class PatchEmbeddingLayer(nn.Module):
    """Convert a batch of images into a sequence of patch tokens.

    A convolution whose kernel size and stride both equal ``patch_size`` maps
    every non-overlapping patch to an ``embedding_dim``-long vector. A learnable
    class token is placed in front of the patch tokens and a learnable position
    embedding is added to the whole sequence.

    Args:
        in_channels: Number of channels of the input images.
        patch_size: Side length of each square patch (conv kernel and stride).
        embedding_dim: Length of every token vector.
        num_patches: Number of patches per image. When ``None`` (the default)
            the notebook-level constant ``NUM_OF_PATCHES`` is used, so existing
            three-argument calls keep working.

    Input:  ``(batch, in_channels, height, width)``
    Output: ``(batch, num_patches + 1, embedding_dim)``
    """

    def __init__(
        self,
        in_channels,
        patch_size,
        embedding_dim,
        num_patches=None,
    ) -> None:
        super().__init__()
        self.patch_size = patch_size
        self.embedding_dim = embedding_dim
        self.in_channels = in_channels
        self.num_patches = NUM_OF_PATCHES if num_patches is None else num_patches
        self.conv_layer = nn.Conv2d(
            in_channels=in_channels,
            out_channels=embedding_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )
        self.flatten_layer = nn.Flatten(start_dim=1, end_dim=2)
        # A single class token shared by all samples. It is sized from the
        # constructor argument rather than from the global batch size or the
        # global embedding size, so the layer works for any batch size and any
        # embedding_dim. nn.Parameter already enables gradients.
        self.class_token_embeddings = nn.Parameter(torch.rand((1, 1, embedding_dim)))
        # One position vector per token (class token + patches), broadcast over
        # the batch dimension when added in forward().
        self.position_embeddings = nn.Parameter(
            torch.rand((1, self.num_patches + 1, embedding_dim))
        )

    def forward(self, x):
        """Return the token sequence (class token first) for the image batch ``x``."""
        # conv output is (batch, dim, rows, cols); move dim last and merge the
        # two spatial axes into one sequence axis -> (batch, patches, dim)
        patch_tokens = self.flatten_layer(self.conv_layer(x).permute((0, 2, 3, 1)))
        # Repeat the shared class token for each sample in this batch. expand()
        # creates a view, so a smaller final batch or a single image works.
        class_tokens = self.class_token_embeddings.expand(patch_tokens.shape[0], -1, -1)
        output = torch.cat((class_tokens, patch_tokens), dim=1) + self.position_embeddings
        return output



In [9]:
class MultiHeadSelfAttentionBlock(nn.Module):
    """Layer normalisation followed by multi-head self-attention.

    Args:
        embedding_dims: Length of each token vector.
        num_heads: Number of attention heads.
        attn_dropout: Dropout probability passed to the attention module.

    The input is expected batch-first, since the attention module is built
    with ``batch_first=True``. No residual connection is applied here.
    """

    def __init__(self, embedding_dims=768, num_heads=12, attn_dropout=0.0) -> None:
        super().__init__()
        # Keep the hyper-parameters on the instance for later inspection
        self.embedding_dims, self.num_head, self.attn_dropout = (
            embedding_dims,
            num_heads,
            attn_dropout,
        )

        self.layernorm = nn.LayerNorm(embedding_dims)
        self.multiheadattention = nn.MultiheadAttention(
            embed_dim=embedding_dims,
            num_heads=num_heads,
            dropout=attn_dropout,
            batch_first=True,
        )

    def forward(self, x):
        """Normalise ``x`` and attend over it with query = key = value."""
        normed = self.layernorm(x)
        # Attention weights are not requested; only the attended values are used
        attended = self.multiheadattention(normed, normed, normed, need_weights=False)[0]
        return attended


In [10]:
class MultiLayerPerceptronBlock(nn.Module):
    """Layer normalisation followed by a two-layer feed-forward network.

    Args:
        embedding_dims: Length of each input and output token vector.
        mlp_size: Width of the hidden layer.
        mlp_dropout: Dropout probability applied after the activation and
            after the second linear layer.

    No residual connection is applied here.
    """

    def __init__(self, embedding_dims, mlp_size, mlp_dropout) -> None:
        super().__init__()
        # Keep the hyper-parameters on the instance for later inspection
        self.embedding_dims, self.mlp_size, self.dropout = (
            embedding_dims,
            mlp_size,
            mlp_dropout,
        )

        self.layernorm = nn.LayerNorm(embedding_dims)

        # expand -> GELU -> dropout -> project back -> dropout
        feed_forward_layers = [
            nn.Linear(embedding_dims, mlp_size),
            nn.GELU(),
            nn.Dropout(mlp_dropout),
            nn.Linear(mlp_size, embedding_dims),
            nn.Dropout(mlp_dropout),
        ]
        self.mlp = nn.Sequential(*feed_forward_layers)

    def forward(self, x):
        """Normalise ``x`` and pass it through the feed-forward network."""
        normed = self.layernorm(x)
        return self.mlp(normed)

In [11]:
class TransformerBlock(nn.Module):
    """One encoder layer: self-attention, then an MLP, each with a skip connection.

    Args:
        embedding_dims: Length of each token vector.
        mlp_dropout: Dropout probability used inside the MLP sub-block.
        attn_dropout: Dropout probability used inside the attention sub-block.
        mlp_size: Hidden width of the MLP sub-block.
        num_heads: Number of attention heads.
    """

    def __init__(
        self,
        embedding_dims=768,
        mlp_dropout=0.1,
        attn_dropout=0.0,
        mlp_size=3072,
        num_heads=12,
    ) -> None:
        super().__init__()
        # Attention sub-block (normalises its own input)
        self.msa_block = MultiHeadSelfAttentionBlock(
            embedding_dims=embedding_dims,
            num_heads=num_heads,
            attn_dropout=attn_dropout,
        )
        # Feed-forward sub-block (normalises its own input)
        self.mlp_block = MultiLayerPerceptronBlock(
            embedding_dims=embedding_dims,
            mlp_size=mlp_size,
            mlp_dropout=mlp_dropout,
        )

    def forward(self, x):
        """Apply both sub-blocks, adding each sub-block's input to its output."""
        after_attention = self.msa_block(x) + x
        after_mlp = self.mlp_block(after_attention) + after_attention
        return after_mlp

## ViT 

In [12]:
class ViT(nn.Module):
    """Vision Transformer classifier: patch embedding, encoder stack, linear head.

    Args:
        img_size: Accepted for API completeness; not used by this class.
        in_channels: Number of channels of the input images.
        patch_size: Side length of each square patch.
        embedding_dims: Length of each token vector.
        num_transformer_layers: Number of stacked ``TransformerBlock`` layers.
        mlp_dropout: Dropout probability inside each block's MLP.
        attn_dropout: Dropout probability inside each block's attention.
        mlp_size: Hidden width of each block's MLP.
        num_heads: Number of attention heads per block.
        num_classes: Number of output logits.
    """

    def __init__(
        self,
        img_size=224,
        in_channels=3,
        patch_size=16,
        embedding_dims=768,
        num_transformer_layers=6,
        mlp_dropout=0.1,
        attn_dropout=0.0,
        mlp_size=3072,
        num_heads=12,
        num_classes=2,
    ) -> None:
        super().__init__()

        # Images -> token sequence with the class token at position 0
        self.patch_embedding_layer = PatchEmbeddingLayer(
            in_channels=in_channels,
            patch_size=patch_size,
            embedding_dim=embedding_dims,
        )

        # Stack of identical encoder layers applied one after another
        encoder_blocks = []
        for _ in range(num_transformer_layers):
            encoder_blocks.append(
                TransformerBlock(
                    embedding_dims=embedding_dims,
                    mlp_dropout=mlp_dropout,
                    attn_dropout=attn_dropout,
                    mlp_size=mlp_size,
                    num_heads=num_heads,
                )
            )
        self.transformer_encoder = nn.Sequential(*encoder_blocks)

        # Classification head. The attribute name (including its spelling) is
        # kept as-is because it determines the parameter names in state_dict.
        head_layers = [
            nn.LayerNorm(embedding_dims),
            nn.Linear(embedding_dims, num_classes),
        ]
        self.classifer = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)`` for images ``x``."""
        tokens = self.patch_embedding_layer(x)
        encoded = self.transformer_encoder(tokens)
        # Only the first token of each sequence (the class token) feeds the head
        class_token = encoded[:, 0]
        return self.classifer(class_token)