In [11]:
import matplotlib.pyplot as plt
from torchvision.datasets.mnist import MNIST
from torch.utils.data import DataLoader
import torchvision.transforms as T
import torch
import torch.nn as nn
import numpy as np
from datasets import load_dataset
import PIL
from torchvision import transforms
from PIL import Image
from transformers import AutoProcessor, CLIPVisionModel, CLIPTextModel, AutoTokenizer
print(PIL.__version__)


11.2.1


In [3]:
# Seed torch's RNG so every run of the notebook is repeatable
torch.manual_seed(42)

# Image preprocessing pipeline: currently a single step that turns an image into a tensor
transform = transforms.Compose([transforms.ToTensor()])

batch_size = 4

# Load Flickr30k from the hub and take its 'test' split
ds = load_dataset("nlphuji/flickr30k")
test_dataset = ds["test"]

def apply_transform(example):
    """Run the module-level ``transform`` on ``example['image']``.

    The transformed image is written back under the same ``'image'`` key, so the
    mapping that was passed in is modified in place; that same mapping is returned.
    """
    transformed_image = transform(example['image'])
    example['image'] = transformed_image
    return example


# The dataset only comes with one split, so hold out 20% of it ourselves (fixed seed)
split_dataset = test_dataset.train_test_split(test_size=0.2, seed=42)
train_dataset, test_dataset = split_dataset['train'], split_dataset['test']

print(f"Train size: {len(train_dataset)}, Test size: {len(test_dataset)}")


Train size: 24811, Test size: 6203


In [7]:
# Slicing the dataset returns a dict of columns ('image', 'caption'), one list per column
first_4_images = train_dataset[:4]
print(first_4_images)

# Convert every image of the slice to a tensor; captions are kept untouched
images = list(map(transform, first_4_images['image']))
captions = first_4_images['caption']

# Sanity checks: one tensor shape, the first caption entry and its Python type
print(images[2].shape)
print(captions[0])
print(type(captions[0]))

print(images[0])


{'image': [<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=417x500 at 0x287197200>, <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=333x500 at 0x287197950>, <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=500x323 at 0x287195DF0>, <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=476x500 at 0x2871958E0>], 'caption': [['Three people are standing on the dock of a boat looking at something off-camera.', 'Three men stand atop a boat from Seattle Washington.', 'Three people stand on a boat that is docked.', 'A blue and white sailboat sits by a dock.', 'A few people park their boat at a dock.'], ['A man with a gray beard and a little boy are sitting on the floor looking over some papers in a room with a bunk bed.', 'A man in a blue shirt playing with a young boy in a red shirt on the ground in a bedroom.', "A dad and his son are playing with some Legos in the child's bedroom.", 'A man and bot playing on the floor with bunk beds in the background.', "Father and son ar

In [31]:
# Load the pretrained CLIP vision encoder together with its matching preprocessor
model = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32")
processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")

image = images[0]

# `images` was built with `transforms.ToTensor()`, so the pixel values are already
# floats in [0, 1]. The CLIP image processor multiplies by 1/255 by default, which
# would rescale the image a second time; switch that step off for pre-scaled tensors.
inputs = processor(images=image, return_tensors="pt", do_rescale=False)

outputs = model(**inputs)
last_hidden_state = outputs.last_hidden_state  # one hidden vector per image token
print(last_hidden_state.shape)


# making images the same embedding dim as caption
class Projection(nn.Module):
    """Single linear layer that maps features from ``input_dim`` to ``output_dim``."""

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.linear = nn.Linear(in_features=input_dim, out_features=output_dim)

    def forward(self, x):
        """Apply the linear map to the last dimension of ``x`` and return the result."""
        projected = self.linear(x)
        return projected
    
# Map the 768-dim vision features to 512 dims so they match the caption embedding size
projection = Projection(input_dim=768, output_dim=512)
image_embedding = projection(last_hidden_state)
print(image_embedding.shape)


torch.Size([1, 50, 768])
torch.Size([1, 50, 512])


In [51]:
# Use only CLIP's input embedding layers (token + position), not the encoder's hidden states
caption = captions[0][0]
print(caption)

model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
print(model)
tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")

caption_inputs = tokenizer(caption, return_tensors="pt", padding=True)
input_ids = caption_inputs['input_ids']

# One position index per token, with a leading batch dimension of size 1
position_ids = torch.arange(input_ids.shape[1], dtype=torch.long)[None, :]

# Look both embeddings up directly on the text model's embedding module
token_embeddings = model.text_model.embeddings.token_embedding(input_ids)
position_embeddings = model.text_model.embeddings.position_embedding(position_ids)

# The model input is the element-wise sum of token and position embeddings
input_embeddings = token_embeddings + position_embeddings

print(input_embeddings.shape)


Three people are standing on the dock of a boat looking at something off-camera.
CLIPTextModel(
  (text_model): CLIPTextTransformer(
    (embeddings): CLIPTextEmbeddings(
      (token_embedding): Embedding(49408, 512)
      (position_embedding): Embedding(77, 512)
    )
    (encoder): CLIPEncoder(
      (layers): ModuleList(
        (0-11): 12 x CLIPEncoderLayer(
          (self_attn): CLIPSdpaAttention(
            (k_proj): Linear(in_features=512, out_features=512, bias=True)
            (v_proj): Linear(in_features=512, out_features=512, bias=True)
            (q_proj): Linear(in_features=512, out_features=512, bias=True)
            (out_proj): Linear(in_features=512, out_features=512, bias=True)
          )
          (layer_norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (mlp): CLIPMLP(
            (activation_fn): QuickGELUActivation()
            (fc1): Linear(in_features=512, out_features=2048, bias=True)
            (fc2): Linear(in_features=2048, out_f

In [61]:
# Decoder only
# Masked multi-head attention
import math

class MaskedAttention(nn.Module):
    """Causal (masked) multi-head self-attention.

    Each position may only attend to itself and to earlier positions.

    Args:
        embedding_dim: size of the embedding dimension of the input and the output.
        head_size: size of a single attention head; ``num_heads * head_size`` must
            equal ``embedding_dim``.
        max_seq_len: sequence length covered by the precomputed causal mask.
        num_heads: number of attention heads.
        bias: whether to use bias in the linear layers.
        dropout: dropout probability (used on the attention weights and on the output).

    Raises:
        AssertionError: if ``embedding_dim`` is not divisible by ``num_heads``.
        ValueError: if ``num_heads * head_size`` differs from ``embedding_dim``.
    """

    def __init__(self, embedding_dim, head_size, max_seq_len, num_heads=1, bias=False, dropout=0.2):
        super().__init__()

        assert embedding_dim % num_heads == 0, "embedding_dim must be divisible by num_heads"
        # Fail at construction time instead of with an opaque `view` error in forward()
        if num_heads * head_size != embedding_dim:
            raise ValueError(
                f"num_heads * head_size ({num_heads} * {head_size}) must equal "
                f"embedding_dim ({embedding_dim})"
            )

        self.embedding_dim = embedding_dim
        self.max_seq_len = max_seq_len
        self.num_heads = num_heads
        self.head_size = head_size
        self.bias = bias
        self.dropout = dropout

        # One fused projection that produces queries, keys and values in a single matmul
        self.c_attn = nn.Linear(embedding_dim, 3 * embedding_dim, bias=bias)

        # Mixes the concatenated heads back together
        self.output_projection = nn.Linear(embedding_dim, embedding_dim, bias=bias)

        self.attention_dropout = nn.Dropout(dropout)
        self.resid_dropout = nn.Dropout(dropout)

        # Lower-triangular boolean mask of shape [1, 1, max_seq_len, max_seq_len];
        # registered as a buffer so it follows the module across devices.
        self.register_buffer("mask", torch.tril(torch.ones(max_seq_len, max_seq_len)).bool().unsqueeze(0).unsqueeze(0))

    def _causal_mask(self, seq_len, device):
        """Return a [1, 1, seq_len, seq_len] boolean mask, True where attention is allowed."""
        if seq_len <= self.max_seq_len:
            return self.mask[:, :, :seq_len, :seq_len]
        # Inputs longer than the precomputed buffer were accepted before; keep
        # accepting them by building a mask of the right size on the input's device.
        return torch.tril(torch.ones(seq_len, seq_len, device=device)).bool().unsqueeze(0).unsqueeze(0)

    def forward(self, x):
        """Apply causal self-attention to ``x`` of shape [batch_size, seq_len, embedding_dim].

        Returns a tensor of the same shape as ``x``.
        """
        batch_size, seq_len, _ = x.size()

        # compute query, key and value vectors for all heads in one pass, then
        # split the fused projection into its three parts
        Q, K, V = self.c_attn(x).split(self.embedding_dim, dim=2)  # each [batch_size, seq_len, embedding_dim]

        # give every head its own slice of the embedding dimension
        # -> [batch_size, num_heads, seq_len, head_size]
        Q = Q.view(batch_size, seq_len, self.num_heads, self.head_size).transpose(1, 2)
        K = K.view(batch_size, seq_len, self.num_heads, self.head_size).transpose(1, 2)
        V = V.view(batch_size, seq_len, self.num_heads, self.head_size).transpose(1, 2)

        # scaled dot-product scores: [batch_size, num_heads, seq_len, seq_len]
        attention = (Q @ K.transpose(-2, -1)) * (1.0 / math.sqrt(self.head_size))
        # block attention to future positions before normalising
        attention = attention.masked_fill(~self._causal_mask(seq_len, x.device), float("-inf"))
        attention = torch.softmax(attention, dim=-1)
        attention = self.attention_dropout(attention)

        hidden_state = attention @ V  # [batch_size, num_heads, seq_len, head_size]

        # merge the heads back: [batch_size, seq_len, embedding_dim]
        hidden_state = hidden_state.transpose(1, 2).contiguous().view(batch_size, seq_len, self.embedding_dim)
        # the output projection was defined but never applied before
        hidden_state = self.resid_dropout(self.output_projection(hidden_state))

        return hidden_state


In [53]:
# Smoke test: run the causal attention layer on the caption embeddings computed above
x = input_embeddings

masked_attention = MaskedAttention(
    embedding_dim=512,
    head_size=64,
    max_seq_len=19,
    num_heads=8,
)

output = masked_attention(x)
print(output.shape)



Q.shape before reshape: torch.Size([1, 19, 512])
torch.Size([1, 19, 512])


In [48]:
class FNN(nn.Module):
    """Feed-forward network: Linear -> GELU -> Linear -> Dropout.

    The hidden layer is four times as wide as ``embedding_dim``; the output has the
    same last dimension as the input.

    Args:
        embedding_dim: size of the input and output feature dimension.
        bias: whether the two linear layers use a bias term.
        dropout: dropout probability applied to the output.
    """

    def __init__(self, embedding_dim, bias=False, dropout=0.2):
        super().__init__()

        hidden_dim = 4 * embedding_dim
        self.linear1 = nn.Linear(embedding_dim, hidden_dim, bias=bias)
        self.gelu = nn.GELU()
        self.linear2 = nn.Linear(hidden_dim, embedding_dim, bias=bias)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Expand, apply GELU, project back to ``embedding_dim`` and apply dropout."""
        hidden = self.gelu(self.linear1(x))
        return self.dropout(self.linear2(hidden))

In [49]:
# putting it all together
class DecoderBlock(nn.Module):
    """One decoder layer: masked self-attention followed by a feed-forward network.

    Both sub-layers receive a layer-normalised input and their output is added back
    onto the running tensor (residual connection).
    """

    def __init__(self, embedding_dim, head_size, max_seq_len, num_heads=1, bias=False, dropout=0.2):
        super().__init__()

        self.masked_attention = MaskedAttention(
            embedding_dim, head_size, max_seq_len, num_heads=num_heads, bias=bias, dropout=dropout
        )
        self.fnn = FNN(embedding_dim, bias=bias, dropout=dropout)
        self.norm1 = nn.LayerNorm(embedding_dim)
        self.norm2 = nn.LayerNorm(embedding_dim)

    def forward(self, x):
        """Run attention and feed-forward sub-layers, each with a residual connection."""
        attended = self.masked_attention(self.norm1(x))
        x = x + attended

        fed_forward = self.fnn(self.norm2(x))
        return x + fed_forward
    

In [63]:
# Decoder class 
import torch.nn.functional as F

class TransformerDecoder(nn.Module):
    """Decoder-only transformer that runs on top of CLIP's text input embeddings.

    Captions are tokenised with the CLIP tokenizer, embedded with the CLIP text
    model's token and position embeddings, passed through ``num_layers`` decoder
    blocks and a final layer norm, and mapped to ``size_of_vocab`` logits.

    Args:
        model_name: Hugging Face identifier of the CLIP checkpoint to load.
        embedding_dim: width of the decoder; must equal the CLIP text hidden size,
            because the CLIP embeddings are fed straight into the decoder blocks.
        num_heads: number of attention heads per block.
        max_seq_len: maximum sequence length handed to the decoder blocks.
        size_of_vocab: number of output classes of the final linear head.
            NOTE(review): the default of 20 is far smaller than the tokenizer's
            vocabulary; confirm the intended value before training on token ids.
        num_layers: number of decoder blocks.
        bias: whether the decoder's linear layers use a bias term.
        dropout: dropout probability.
        head_size: size of a single attention head.

    Raises:
        ValueError: if ``embedding_dim`` differs from the CLIP text hidden size.
    """

    def __init__(self, model_name="openai/clip-vit-base-patch32", embedding_dim=512, num_heads=8, max_seq_len=19, size_of_vocab=20, num_layers=6, bias=False, dropout=0.2, head_size=64):
        super().__init__()

        self.clip_model = CLIPTextModel.from_pretrained(model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

        # The CLIP embeddings are consumed directly by the decoder blocks, so the
        # two widths have to agree; report a mismatch here rather than as a shape
        # error in the first forward pass.
        clip_hidden_size = self.clip_model.config.hidden_size
        if embedding_dim != clip_hidden_size:
            raise ValueError(
                f"embedding_dim ({embedding_dim}) must match the CLIP text hidden size "
                f"({clip_hidden_size})"
            )

        self.embedding_dim = embedding_dim
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.max_seq_len = max_seq_len
        self.size_of_vocab = size_of_vocab
        self.bias = bias
        self.dropout = dropout
        self.head_size = head_size

        self.transformer = nn.ModuleDict(dict(
            dropout = nn.Dropout(dropout),
            blocks = nn.ModuleList([DecoderBlock(embedding_dim, head_size, max_seq_len, num_heads, bias, dropout) for _ in range(num_layers)]),
            layer_norm = nn.LayerNorm(embedding_dim),
            head = nn.Linear(embedding_dim, size_of_vocab, bias=bias)
        ))

    def forward(self, captions, targets=None):
        """Compute logits (and optionally the loss) for one caption or a list of captions.

        Args:
            captions: a string or list of strings to tokenise.
            targets: optional tensor of class indices with one entry per token
                position; positions equal to ``-1`` are ignored by the loss.

        Returns:
            ``(logits, loss)``. With ``targets`` the logits cover every position and
            ``loss`` is the cross-entropy; without, only the last position's logits
            are returned (shape ``[batch, 1, size_of_vocab]``) and ``loss`` is ``None``.
        """
        embeddings = self.clip_model.text_model.embeddings
        # Build the index tensors on the device the embedding tables live on, so the
        # model keeps working after `.to(device)`.
        device = embeddings.token_embedding.weight.device

        # Truncate to the size of the position table; longer captions would
        # otherwise index past the end of the position embedding.
        caption_inputs = self.tokenizer(
            captions,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=embeddings.position_embedding.num_embeddings,
        )

        input_ids = caption_inputs['input_ids'].to(device)
        position_ids = torch.arange(input_ids.shape[1], dtype=torch.long, device=device).unsqueeze(0)

        token_embeddings = embeddings.token_embedding(input_ids)
        position_embeddings = embeddings.position_embedding(position_ids)

        x = token_embeddings + position_embeddings
        x = self.transformer['dropout'](x)

        for block in self.transformer['blocks']:
            x = block(x)
        x = self.transformer['layer_norm'](x)

        if targets is not None:
            # compute the loss if we are given targets
            logits = self.transformer['head'](x)
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                targets.view(-1),
                ignore_index=-1,
            )

        else:
            # only look at last token if performing inference
            logits = self.transformer['head'](x[:, [-1], :])
            loss = None

        return logits, loss
    
# Smoke test: no targets are passed, so only the last position's logits come back
# and the loss is None
caption = captions[0][0]

decoder = TransformerDecoder()
logits, loss = decoder(caption)

print("Logits shape:", logits.shape)
print("Loss:", loss)


Logits shape: torch.Size([1, 1, 20])
Loss: None
