In [74]:
import io
import math
import torch
import torchvision
import numpy as np
import torch.nn as nn
import torch.optim as opt
import matplotlib.pyplot as plt
from datasets import load_dataset
from PIL import Image as PILImage
import torchvision.transforms as T
from datasets import Dataset as DatasetData
from sklearn.preprocessing import LabelEncoder
from torch.utils.data import DataLoader, Dataset, TensorDataset

In [75]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cpu')

In [76]:
textTransform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(0.5, 0.5)
])

In [77]:
class CustomPokemonDataset(Dataset):
    def __init__(self, split='train', transform=None):
        self.dataset = load_dataset("lambdalabs/pokemon-blip-captions", split=split)
        self.transform = transform
        self.label_encoder = LabelEncoder()
        unique_captions = set(item["text"] for item in self.dataset)
        self.label_encoder.fit(list(unique_captions))

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        image = item["image"]
        caption = item["text"]
        print(caption, len(item["text"]))

        if self.transform:
            image = self.transform(image)
        # caption = textTransform(caption)
        label = self.label_encoder.transform([caption])[0]

        return image, label

In [78]:
image_size = (256, 256)
batch_size = 128
stats = (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)

In [79]:
PATCH_SIZE = 128
IMAGE_WIDTH = 256
IMAGE_HEIGHT = IMAGE_WIDTH
IMAGE_CHANNELS = 3
EMBEDDING_DIMS = IMAGE_CHANNELS * PATCH_SIZE ** 2
NUM_OF_PATCHES = int((IMAGE_WIDTH * IMAGE_HEIGHT) / PATCH_SIZE ** 2)

assert IMAGE_WIDTH % PATCH_SIZE == 0 and IMAGE_HEIGHT % PATCH_SIZE == 0, print(
    "Image Width is not divisible by patch size")


In [80]:
transform = T.Compose([
    T.Resize(image_size),
    T.CenterCrop(image_size),
    T.ToTensor(),
    T.Normalize(0.5, 0.5)
])

In [81]:
train_dataset = CustomPokemonDataset(split='train', transform=transform)
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)

In [82]:
examples = enumerate(train_loader)
batch_idx, data = next(examples)
print(len(data), data[1])  # 2

a cartoon picture of a small furry animal 41
a cartoon character is holding a large object 45
a drawing of an animal with wings and a tail 44
a picture of a cartoon character in a costume 45
an image of a cartoon character flying through the air 54
a cartoon of a deer with a tree on its back 43
a drawing of a cartoon character holding a flame 48
a drawing of a spider with big eyes 35
2 tensor([242, 135, 541, 590, 728, 214, 319, 503], dtype=torch.int32)


In [83]:
class MultiHeadAttention(torch.nn.Module):
    r"""Multi-headed Attention for input Query, Key, Value

    Multi-headed Attention is a module for attention mechanisms which runs through attention in several times in
    parallel, then the multiple outputs are concatenated and linearly transformed

    Args:
        embed_size  (int): Max embedding size
        num_heads   (int): Number of heads in multi-headed attention; Number of splits in the embedding size
        dropout     (float, optional): Percentage of Dropout to be applied in range 0 <= dropout <=1
        batch_dim   (int, optional): The dimension in which batch dimensions is

    """

    def __init__(
            self, embed_size: int, num_heads: int, dropout: float = 0.2, batch_dim: int = 0
    ):
        super(MultiHeadAttention, self).__init__()

        self.embed_size = embed_size
        self.num_heads = num_heads
        self.dropout = dropout
        self.batch_dim = batch_dim

        self.dropout_layer = torch.nn.Dropout(dropout)

        self.head_size = self.embed_size // self.num_heads

        assert (
                self.head_size * self.num_heads == self.embed_size
        ), "Heads cannot split Embedding size equally"

        self.Q = torch.nn.Linear(self.embed_size, self.embed_size)
        self.K = torch.nn.Linear(self.embed_size, self.embed_size)
        self.V = torch.nn.Linear(self.embed_size, self.embed_size)

        self.linear = torch.nn.Linear(self.embed_size, self.embed_size)

    def forward(self, q, k, v, mask=None):
        if self.batch_dim == 0:
            out = self.batch_0(q, k, v, mask)
        elif self.batch_dim == 1:
            out = self.batch_1(q, k, v, mask)

        return out

    def batch_0(self, q, k, v, mask=None):
        q_batch_size, q_seq_len, q_embed_size = q.size()
        k_batch_size, k_seq_len, k_embed_size = k.size()
        v_batch_size, v_seq_len, v_embed_size = v.size()

        q = self.Q(q).reshape(q_batch_size, q_seq_len, self.num_heads, self.head_size)
        k = self.K(k).reshape(k_batch_size, k_seq_len, self.num_heads, self.head_size)
        v = self.V(v).reshape(v_batch_size, v_seq_len, self.num_heads, self.head_size)

        attention = self.attention(q, k, v, mask=mask)
        concatenated = attention.reshape(v_batch_size, -1, self.embed_size)
        out = self.linear(concatenated)

        return out

    def batch_1(self, q, k, v, mask=None):
        q_seq_len, q_batch_size, q_embed_size = q.size()
        k_seq_len, k_batch_size, k_embed_size = k.size()
        v_seq_len, v_batch_size, v_embed_size = v.size()

        q = (
            self.Q(q)
            .reshape(q_seq_len, q_batch_size, self.num_heads, self.head_size)
            .transpose(0, 1)
        )
        k = (
            self.K(k)
            .reshape(k_seq_len, k_batch_size, self.num_heads, self.head_size)
            .transpose(0, 1)
        )
        v = (
            self.V(v)
            .reshape(v_seq_len, v_batch_size, self.num_heads, self.head_size)
            .transpose(0, 1)
        )

        attention = self.attention(q, k, v, mask=mask)
        concatenated = attention.reshape(-1, v_batch_size, self.embed_size)

        out = self.linear(concatenated)

        return out

    def attention(self, q, k, v, mask=None):
        scores = torch.einsum("bqhe,bkhe->bhqk", [q, k])

        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)

        scores /= math.sqrt(self.embed_size)
        scores = torch.nn.functional.softmax(scores, dim=-1)
        scores = self.dropout_layer(scores)
        attention = torch.einsum("bhql,blhd->bqhd", [scores, v])
        return attention


In [84]:
class VisionEncoder(nn.Module):
    def __init__(self, embed_size, nb_heads, hidden_size, dropout):
        super(VisionEncoder, self).__init__()

        self.attention = torch.nn.MultiheadAttention(
            embed_size, nb_heads, dropout=dropout
        )
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(embed_size, hidden_size),
            torch.nn.ReLU(),
            torch.nn.Dropout(dropout),
            torch.nn.Linear(hidden_size, embed_size),
        )

        self.norm1 = torch.nn.LayerNorm(embed_size)
        self.norm2 = torch.nn.LayerNorm(embed_size)
        self.dropout = torch.nn.Dropout(dropout)

    def forward(self, x):
        # Attention
        attn_output, _ = self.attention(x, x, x)
        x = x + attn_output
        x = self.norm1(x)
        # MLP
        mlp_output = self.mlp(x)
        x = x + mlp_output
        x = self.norm2(x)
        x = self.dropout(x)

        return x

In [85]:
class ViT(torch.nn.Module):
    def __init__(
            self,
            image_size,
            channel_size,
            patch_size,
            embed_size,
            nb_heads,
            classes,
            nb_layers,
            hidden_size,
            dropout,
    ):
        super(ViT, self).__init__()

        self.patch_size = patch_size
        self.embed_size = embed_size
        self.nb_patches = int((image_size * image_size) / patch_size ** 2)
        self.pixels_per_patch = channel_size * (patch_size ** 2)
        self.nb_heads = nb_heads
        self.classes = classes
        self.nb_layers = nb_layers
        self.hidden_size = hidden_size
        self.dropout = dropout

        self.patch_embedings = nn.Linear(self.pixels_per_patch, embed_size)
        self.class_token = nn.Parameter(torch.randn(1, 1, embed_size))
        self.positional_encoding = nn.Parameter(
            torch.randn(1, self.nb_patches, embed_size)
        )

        self.encoders = torch.nn.ModuleList(
            [
                VisionEncoder(embed_size, nb_heads, hidden_size, dropout)
                for _ in range(nb_layers)
            ]
        )

        self.classifier = torch.nn.Linear(embed_size, classes)

    def forward(self, img_torch):
        b, c, h, w = img_torch.size()
        img_torch_reshape = img_torch.reshape(
            b,
            int((h / self.patch_size) * (w / self.patch_size)),
            c * self.patch_size * self.patch_size,
        ).float()

        patch_embedings = self.patch_embedings(img_torch_reshape)
        patch_embedings = patch_embedings + self.positional_encoding
        img_with_class_token = torch.cat(
            (self.class_token.expand(b, -1, -1), patch_embedings), dim=1
        )

        for encoder in self.encoders:
            img_with_positional_encoding = encoder(img_with_class_token)

        fwd_cls = img_with_positional_encoding[:, -1, :]

        fwd_norm = torch.nn.LayerNorm(self.embed_size)(fwd_cls)

        fwd_classifier = self.classifier(fwd_norm)

        fwd_softmax = torch.nn.functional.log_softmax(fwd_classifier, dim=-1)

        return fwd_softmax


model = ViT(
    image_size=256,
    channel_size=3,
    patch_size=128,
    embed_size=256,
    nb_heads=8,
    classes=812,
    nb_layers=3,
    hidden_size=256,
    dropout=0.2,
).to(device)

In [86]:
loss_function = nn.NLLLoss()
optimizer = opt.Adam(model.parameters(), 0.00005)

In [None]:
losses = []
accuracies = []
nb_epochs = 10
for epoch in range(nb_epochs):
    model.train()

    epoch_loss = 0

    y_pred = []
    y_true = []
    for imgs, labels in train_loader:
        imgs = imgs.to(device)
        labels = labels.to(device).long()  # Convert labels to Long
    
        print("Shape before linear layer:", imgs.shape)
    
        predictions = model(imgs)
    
        loss = loss_function(predictions, labels)
    
        optimizer.zero_grad()
    
        loss.backward()
    
        optimizer.step()
    
        y_pred.extend(predictions.detach().argmax(dim=1).tolist())
        y_true.extend(labels.detach().tolist())
    
        epoch_loss += loss.item()

    losses.append(epoch_loss)

    nb_imgs = len(y_pred)
    total_correct = 0
    for i in range(nb_imgs):
        if y_pred[i] == y_true[i]:
            total_correct += 1
    accuracy = total_correct * 100 / nb_imgs

    accuracies.append(accuracy)

    print("----------")
    print("Epoch:", epoch)
    print("Loss:", epoch_loss)
    print(f"Accuracy: {accuracy} % ({total_correct} / {nb_imgs})")

a green cartoon character with a big smile 42
a cartoon character with a hat on his head 42
a picture of a piece of clothing with a scarf on top of it 58
a cartoon of a white and gray animal 36
a cartoon bird flying through the air 37
a cartoon car with a yellow top on it 37
a drawing of a cartoon character with big eyes 46
a very cute looking cartoon character with big eyes 51
Shape before linear layer: torch.Size([8, 3, 256, 256])
a drawing of an animal's head with spikes on it 47
a stylized image of a robot with a helmet on 44
a drawing of a yellow and black cat 35
a drawing of a key with a face on it 36
a drawing of a plant with leaves on it 38
a cartoon slotty sleeping on its back 37
a cartoon bird with a pink hat on its head 42
a bird with a long beak flying through the air 46
Shape before linear layer: torch.Size([8, 3, 256, 256])
a cartoon green lizard with a pink belly 40
a green and yellow cartoon character holding a flower 53
a cartoon deer with a yellow bow on its head 44
a