In [1]:
!pip install datasets torchvision sentencepiece

Collecting datasets
  Downloading datasets-3.5.0-py3-none-any.whl.metadata (19 kB)
Collecting torchvision
  Downloading torchvision-0.21.0-cp312-cp312-manylinux1_x86_64.whl.metadata (6.1 kB)
Collecting sentencepiece
  Downloading sentencepiece-0.2.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (7.7 kB)
Collecting filelock (from datasets)
  Downloading filelock-3.18.0-py3-none-any.whl.metadata (2.9 kB)
Collecting numpy>=1.17 (from datasets)
  Downloading numpy-2.2.5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (62 kB)
Collecting pyarrow>=15.0.0 (from datasets)
  Downloading pyarrow-19.0.1-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (3.3 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting pandas (from datasets)
  Downloading pandas-2.2.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (89 kB)
Collecting tqdm>=4.66.3 (from datasets)
  Downloading tqdm-4.

In [2]:
import torch
from torch import nn
from torch.nn import functional as F
from typing import Tuple

from torch.utils.data import DataLoader
from torchvision import transforms
from datasets import load_dataset
from PIL import Image
import sentencepiece as spm
from tqdm import tqdm
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


## Converting Image to a Sequence of Patches

In [3]:
class PatchEmbeddings(nn.Module):
    def __init__(
        self, img_size: int = 96, patch_size: int = 16, hidden_dim: int = 512
    ) -> None:
        super().__init__()
        # Store the input image size, the patch size and hidden dimension
        self.img_size = img_size
        self.patch_size = patch_size
        self.hidden_dim = hidden_dim

        # Calculate the total number of patches
        self.num_patches = (self.img_size // self.patch_size) ** 2

        # Create a convolution to extract patch embeddings
        # in_channels=3 asummes a 3-channel image (RGB)
        # outp_channels=hidden_dim sets the number of output channels to match the hidden dimension
        # kernel_size=patch_size and stride=patch_size ensuring each patch is embedded separately
        self.conv = nn.Conv2d(
            in_channels=3,
            out_channels=self.hidden_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
        )

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        # Extract patch embeddings from the input image
        # Output shape: (batch_size, hidden_dim, (self.img_size // self.patch_size), (self.img_size // self.patch_size))
        X = self.conv(X)

        # Flatten the spatial dimensions (height and width) of the patch embeddings
        # This step flattens the patch dimensions to a single dimension
        # Output shape: (batch_size, hidden_dim, self.num_patches)
        X = X.flatten(2)

        # Transpose the dimensions to obtain the shape (batch_size, num_patches, hidden_dim)
        # This step brings the num_patches dimension to the second position
        # Output shape: (batch_size, self.num_patches, hidden_dim)
        X = X.transpose(1, 2)

        return X

In [4]:
B, C, H, W = 1, 3, 96, 96  # Batch size, Channels, Height, Width
X = torch.randn(B, C, H, W)

patch_size = 16
hidden_dim = 512

patch_embeddings = PatchEmbeddings(
    img_size=H, patch_size=patch_size, hidden_dim=hidden_dim
)
patches = patch_embeddings(X)
print(f"Shape of image patches: {patches.shape}")

Shape of image patches: torch.Size([1, 36, 512])


In [5]:
num_patches = (H // patch_size) ** 2
assert patches.shape == (B, num_patches, hidden_dim), "Output shape is incorrect"
print("Test passed!")

Test passed!


## Attention Mechanism
Attention Mechanism across both the vision encoder and language decoder

### The implementation of the Attention Head

In [6]:
class Head(nn.Module):
    def __init__(
        self,
        n_embed: int,
        head_size: int,
        dropout: float = 0.1,
        is_decoder: bool = False,
    ) -> None:
        super().__init__()

        # Linear layer for Key projection
        self.key = nn.Linear(in_features=n_embed, out_features=head_size, bias=False)

        # Linear layer for Query projection
        self.query = nn.Linear(in_features=n_embed, out_features=head_size, bias=False)

        # Linear layer for Value projection
        self.value = nn.Linear(in_features=n_embed, out_features=head_size, bias=False)

        # Dropout layer for regularization to prevent overfitting
        self.dropout = nn.Dropout(p=dropout)

        # Flag indicating wheter the head is used as a decoder
        self.is_decoder = is_decoder

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Get batch size (B), sequence length (T), and embedding dimension (C) from the input tensor
        B, T, C = x.shape

        # Compute Key, Query, and Value projections
        k = self.key(x)  # Shape: (B, T, head_size)
        q = self.query(x)  # Shape: (B, T, head_size)
        v = self.value(x)  # SHape: (B, T, head_size)

        # Compute attention scores by taking the dot product of Query and Key
        # and scaling by the square root of the embedding dimension
        wei = q @ k.transpose(-2, -1) * (C**-0.5)  # Shape: (B, T, T)

        if self.is_decoder:
            # If this head is used in the decoder, apply causal mask to the attention scores
            # to prevent attention to future positions
            tril = torch.tril(torch.ones(T, T, dtype=torch.bool, device=x.device))
            wei = wei.masked_fill(mask=tril == 0, value=float("-inf"))

        # Apply softmax to the attention scores to obtain attention probabilities
        # Sum of probabilities for each row will be 1
        wei = F.softmax(input=wei, dim=-1)  # Shape: (B, T, T)

        # Apply Dropout to the attention probabilities for regularization
        wei = self.dropout(wei)

        # Perform a weighted aggregation of values using the attention probabilities
        out = wei @ v  # Shape: (B, T, head_size)

        return out

In [7]:
B, T, C = patches.shape  # Batch size, Sequence length, Embedding dimension
head_size = 16  # Size of the attention head

head = Head(n_embed=C, head_size=head_size)
output = head(patches)
print(f"Shape of output tensor: {output.shape}")

Shape of output tensor: torch.Size([1, 36, 16])


In [8]:
assert output.shape == (B, T, head_size), "Output shape is incorrect"
print("Test passed!")

Test passed!


### The implementation of Multihead Attention

In [9]:
class MultiHeadAttention(nn.Module):
    def __init__(
        self,
        n_embed: int,
        num_heads: int,
        dropout: float = 0.1,
        is_decoder: bool = False,
    ) -> None:
        super().__init__()

        # Ensure that the embedding dimension is divisible by the number of heads
        assert n_embed % num_heads == 0, "n_embed must be divisible by num_heads!"

        # Create a ModuleList of attention heads
        self.heads = nn.ModuleList(
            modules=[
                Head(
                    n_embed=n_embed,
                    head_size=n_embed // num_heads,
                    dropout=dropout,
                    is_decoder=is_decoder,
                )
                for _ in range(num_heads)
            ]
        )

        # Linear layer for projecting the concatenated head outputs
        self.proj = nn.Linear(in_features=n_embed, out_features=n_embed)

        # Dropout layer for regularization
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Apply each attention head to the input tensor
        head_outputs = [
            h(x) for h in self.heads
        ]  # Shape: num_heads * (B, T, head_size)

        # Concatenate the outputs from all heads along the last dimension
        out = torch.cat(tensors=head_outputs, dim=-1)  # Shape: (B, T, m_embed)

        # Apply the projection layer to the concatenated outputs
        out = self.proj(out)  # Shape: (B, T, m_embed)

        # Apply Dropout to the projected outputs for regularization
        out = self.dropout(out)

        return out

In [10]:
num_heads = 2
dropout = 0.1
mha = MultiHeadAttention(n_embed=C, num_heads=num_heads, dropout=dropout)

In [11]:
output = mha(patches)
print(f"Shape of output tensor: {output.shape}")

Shape of output tensor: torch.Size([1, 36, 512])


In [12]:
assert output.shape == (B, T, C), "Output shape is incorrect"
print("Test passed!")

Test passed!


### The Multilayer Perceptron

In [13]:
class MLP(nn.Module):
    def __init__(
        self, n_embed: int, dropout: float = 0.1, is_decoder: bool = False
    ) -> None:
        super().__init__()

        # Define the layers of the MLP
        layers = [
            # First linear layer that expands the input dimension from n_embed to 4 * n_embed
            nn.Linear(in_features=n_embed, out_features=4 * n_embed),
            # Activation function: ReLU if is_decoder is True, else GELU
            nn.ReLU() if is_decoder else nn.GELU(),
            # Second linear layer that projects the intermediate dimension back to n_embed
            nn.Linear(in_features=4 * n_embed, out_features=n_embed),
            # Dropout layer for regularization
            nn.Dropout(p=dropout),
        ]

        # Create the MLP as a sequence of layers
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Pass the input through the MLP layers
        return self.net(x)

In [14]:
dropout = 0.1
mlp = MLP(n_embed=C, dropout=dropout)

In [15]:
output = mlp(output)  # Previous output of the Multihead Attention
print(f"Shape of output tensor: {output.shape}")

Shape of output tensor: torch.Size([1, 36, 512])


In [16]:
assert output.shape == (B, T, C), "Output shape is incorrect"
print("Test passed!")

Test passed!


### Transformer Blocks

In [17]:
class Block(nn.Module):
    def __init__(
        self,
        n_embed: int,
        num_heads: int,
        dropout: float = 0.1,
        is_decoder: bool = False,
    ) -> None:
        super().__init__()

        # Layer normalization for the input to the attention layer
        self.ln1 = nn.LayerNorm(normalized_shape=n_embed)

        # Multi-head attention module
        self.mhattn = MultiHeadAttention(
            n_embed=n_embed, num_heads=num_heads, dropout=dropout, is_decoder=is_decoder
        )

        # Layer normalization for the input to the FFN
        self.ln2 = nn.LayerNorm(normalized_shape=n_embed)

        # Feed-forward neural network (FFN)
        self.ffn = nn.Sequential(
            nn.Linear(in_features=n_embed, out_features=4 * n_embed),
            nn.GELU(),  # Activation function
            nn.Linear(
                in_features=4 * n_embed, out_features=n_embed
            ),  # Projection back to the original dimension
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Saving the input for residual connection
        original_x = x

        # Apply layer normalization to the input
        x = self.ln1(x)

        # Apply multi-head attention
        mhattn_output = self.mhattn(x)

        # Add the residual connection (original input) to the attention output
        x = original_x + mhattn_output

        # Apply later normalization to the input to the FFN
        x = self.ln2(x)

        # Apply the FFN
        ffn_output = self.ffn(x)

        # Apply the residual connection (input to the FFN) to the FFN output
        x = x + ffn_output

        return x

In [18]:
num_heads = 2
dropout = 0.1
block = Block(n_embed=C, num_heads=num_heads, dropout=dropout)

In [19]:
output = block(patches)
print(f"Shape of output tensor: {output.shape}")

Shape of output tensor: torch.Size([1, 36, 512])


In [20]:
assert output.shape == (B, T, C), "Output shape is incorrect"
print("Test passed!")

Test passed!


## Vision Encoder - Vision Transformer (ViT)

Combining patchification logic and attention block in to ViT

In [21]:
class ViT(nn.Module):
    def __init__(
        self,
        img_size: int,
        patch_size: int,
        num_hiddens: int,
        num_heads: int,
        num_blocks: int,
        emb_dropout: float,
        block_dropout: float,
    ) -> None:
        super().__init__()

        # Patch embedding layer to convert the input image into patches
        self.patch_embedding = PatchEmbeddings(
            img_size=img_size, patch_size=patch_size, hidden_dim=num_hiddens
        )

        # Learnable classification token
        self.cls_token = nn.Parameter(data=torch.zeros(size=(1, 1, num_hiddens)))

        # Calculate the number of patches
        num_patches = (img_size // patch_size) ** 2

        # Learnable position embedding
        self.pos_embedding = nn.Parameter(
            data=torch.randn(size=(1, num_patches + 1, num_hiddens))
        )

        # Dropout layer for the embeddings
        self.dropout = nn.Dropout(p=emb_dropout)

        # Stack of transformer blocks
        self.blocks = nn.ModuleList(
            [
                Block(
                    n_embed=num_hiddens,
                    num_heads=num_heads,
                    dropout=block_dropout,
                    is_decoder=False,
                )
                for _ in range(num_blocks)
            ]
        )

        # Layer normalization for the final representation
        self.layer_norm = nn.LayerNorm(normalized_shape=num_hiddens)

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        # Convert the input image into patch embeddings
        x = self.patch_embedding(X)  # Shape: (B, num_patches, num_hiddens)

        # Expand the classification token to match the batch size
        cls_tokens = self.cls_token.expand(
            x.shape[0], -1, -1
        )  # Shape: (B, 1, num_hiddens)

        # Concatenate the classification token with the patch embeddings
        x = torch.cat(
            tensors=(cls_tokens, x), dim=1
        )  # Shape: (B, num_patches + 1, num_hiddens)

        # Add the position embedding to the patch embeddings
        x += self.pos_embedding  # Shape: (B, num_patches + 1, num_hiddens)

        # Apply dropout to the embeddings
        x = self.dropout(x)  # Shape: (B, num_patches + 1, num_hiddens)

        # Pass the embeddings through the transformer blocks
        for block in self.blocks:
            x = block(x)  # Shape: (B, num_patches + 1, num_hiddens)

        # Apply layer normalization to the `[CLS]` token's final representation
        x = self.layer_norm(x[:, 0])  # Shape: (B, num_hiddens)

        return x

In [22]:
B, C, H, W = 2, 3, 96, 96  # Batch size, Channels, Height, Width
X = torch.randn(B, C, H, W)
vit = ViT(
    img_size=H,
    patch_size=16,
    num_hiddens=64,
    num_heads=2,
    num_blocks=2,
    emb_dropout=0.1,
    block_dropout=0.1,
)

In [23]:
output = vit(X)
print(f"Output shape: {output.shape}")

Output shape: torch.Size([2, 64])


In [24]:
assert output.shape == (B, 64), "Output shape is incorrect"
print("Test passed!")

Test passed!


## Vision-Language Projection Module

Unfortunatelly, we can't directly concatenate ViT output to the text embeddings. <br>
We need to project this from dimensionality of image embeddings from the vision transformer to the dimensionality of text embeddings.

Why MLP for this part? If you want to train VLM with low resources you can do so by keeping both the pretrained vision encoder and language decoder frozen during the VLM training. Therefore, allocating more parameters to the connection module could enhance the overall VLM's ability to generalize and help in the downstream instruction-tuning process.

In [25]:
class MultiModalProjector(nn.Module):
    def __init__(
        self,
        n_embed: int,
        img_embed_dim: int,
        dropout: float = 0.1,
    ) -> None:
        super().__init__()

        # Define the projection network
        self.net = nn.Sequential(
            # Linear layer to expand the image embedding dimension
            nn.Linear(in_features=img_embed_dim, out_features=4 * img_embed_dim),
            # GELU activation function
            nn.GELU(),
            # Linear layer to project the expanded image embeddings to the text embedding dimension
            nn.Linear(in_features=4 * img_embed_dim, out_features=n_embed),
            # Dropout layer for regularization
            nn.Dropout(p=dropout),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Pass the input through the projection network
        x = self.net(x)  # Shape: (B, img_embed_dim) --> (B, n_embed)
        return x

In [26]:
B, n_embed, img_embed_dim = 2, 64, 128
X = torch.randn(size=(B, img_embed_dim))

projector = MultiModalProjector(
    n_embed=n_embed, img_embed_dim=img_embed_dim, dropout=0.1
)

In [27]:
output = projector(X)
print(f"Output shape: {output.shape}")

Output shape: torch.Size([2, 64])


In [28]:
assert output.shape == (B, n_embed), "Output shape is incorrect"
print("Test passed!")

Test passed!


## Building the Decoder Language Model

Only thing that deviates from origianl implementation is that here projection module is integrated into decoder model class. <br>
In contrary, when using pretrained models with HuggingFace (or any other library), you can directly feed embeddings as input to the model.

In [29]:
class DecoderLanguageModel(nn.Module):
    def __init__(
        self,
        n_embed: int,
        img_embed_dim: int,
        vocab_size: int,
        num_heads: int,
        n_layer: int,
        use_images: bool = False,
    ) -> None:
        super().__init__()

        self.use_images = use_images

        # Token embedding table
        self.token_embedding_table = nn.Embedding(
            num_embeddings=vocab_size, embedding_dim=n_embed
        )

        # Position embedding table
        self.position_embedding_table = nn.Embedding(
            num_embeddings=1000, embedding_dim=n_embed
        )

        if use_images:
            # Image projection layer to align image embeddings with text embeddings
            self.image_projection = MultiModalProjector(
                n_embed=n_embed, img_embed_dim=img_embed_dim
            )

        # Stack of transformer decoder blocks
        self.blocks = nn.Sequential(
            *[
                Block(n_embed=n_embed, num_heads=num_heads, is_decoder=True)
                for _ in range(n_layer)
            ]
        )

        # Final layer normalization
        self.ln_f = nn.LayerNorm(normalized_shape=n_embed)

        # Language modeling head
        self.lm_head = nn.Linear(in_features=n_embed, out_features=vocab_size)

    def forward(
        self,
        idx: torch.Tensor,
        img_embeds: torch.Tensor = None,
        targets: torch.Tensor = None,
    ) -> torch.Tensor:
        # Get token embeddings from the input indices
        tok_emb = self.token_embedding_table(idx)

        if self.use_images:
            # Project and concatenate image embeddings with token embeddings
            img_emb = self.image_projection(img_embeds).unsqueeze(1)
            tok_emb = torch.cat([img_emb, tok_emb], dim=1)

        # Get position embeddings
        pos_emb = self.position_embedding_table(
            torch.arange(tok_emb.size(1), device=idx.device)
        )

        # Add position embeddings to token embeddings
        x = tok_emb + pos_emb

        # Pass through the transformer decoder blocks
        x = self.blocks(x)

        # Apply final layer normalization
        x = self.ln_f(x)

        # Get the logits from the language modeling head
        logits = self.lm_head(x)

        if targets is not None:
            if self.use_images and img_embeds is not None:
                # Prepare targets by concatenating a dummy target for the image embedding
                batch_size = idx.size(0)
                targets = torch.cat(
                    [
                        torch.full(
                            (batch_size, 1), -100, dtype=torch.long, device=idx.device
                        ),
                        targets,
                    ],
                    dim=1,
                )

            # Compute the cross-entropy loss
            loss = F.cross_entropy(
                input=logits.view(-1, logits.size(-1)),
                target=targets.view(-1),
                ignore_index=-100,
            )
            return logits, loss

        return logits

    def generate(
        self, idx: torch.Tensor, img_embeds: torch.Tensor, max_new_tokens: int
    ) -> torch.Tensor:
        # Get the batch size and sequence length
        B, T = idx.shape

        # Initialize the generated sequence with the input indices
        generated = idx
        tok_emb = self.token_embedding_table(idx)

        if self.use_images and img_embeds is not None:
            # Project and concatenate image embeddings with token embeddings
            img_emb = self.image_projection(img_embeds).unsqueeze(1)
            current_output = torch.cat([img_emb, tok_emb], dim=1)
        else:
            current_output = tok_emb

        # Generate new tokens iteratevely
        for i in range(max_new_tokens):
            # Get the current sequence length
            T_current = current_output.shape[1]

            # Get position embeddings for the current sequence length
            current_pos_emb = self.position_embedding_table(
                torch.arange(T_current, device=idx.device)
            ).unsqueeze(0)

            # Add position embeddings to the current output
            current_output += current_pos_emb

            # Pass through the transformer decoder blocks
            for block in self.blocks:
                current_output = block(current_output)

            # Get the logits for the last token
            logits = self.lm_head(current_output[:, -1, :])

            # Apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1)

            # Sample the next token based on the probability
            idx_next = torch.multinomial(input=probs, num_samples=1)

            # Concatenate the generated token to the generated sequence
            generated = torch.cat([generated, idx_next], dim=1)

            # Get the embeddings for the generated token
            idx_next_emb = self.token_embedding_table(idx_next)

            # Concatenate the generated token embeddings to the current output
            current_output = torch.cat([current_output, idx_next_emb], dim=1)

        return generated

In [30]:
import torch
import torch.nn.functional as F

# Simulated inputs
batch_size = 2
seq_len = 4
vocab_size = 10

# Fake target labels (batch_size x seq_len)
targets = torch.tensor([
    [1, 2, 3, 4],
    [2, 3, 4, 5]
], dtype=torch.long)  # shape: [2, 4]

# Dummy index tensor for alignment (same batch_size)
idx = torch.arange(batch_size)

# Simulated logits output from a model (batch_size x seq_len+1 x vocab_size)
# We assume image embeddings are used, so seq_len + 1
logits = torch.randn(batch_size, seq_len + 1, vocab_size)  # shape: [2, 5, 10]
print(logits)

# Add dummy target token for image embedding (-100 will be ignored in loss)
targets = torch.cat(
    [
        torch.full(
            (batch_size, 1), -100, dtype=torch.long, device=targets.device
        ),
        targets
    ],
    dim=1
)  # shape: [2, 5]

print(logits.view(-1, vocab_size).shape)
print(targets.view(-1).shape)

# Confirm shapes match
assert logits.shape[:2] == targets.shape, f"Shape mismatch: {logits.shape[:2]} vs {targets.shape}"

# Compute cross-entropy loss
loss = F.cross_entropy(
    input=logits.view(-1, vocab_size),  # shape: [2*5, 10]
    target=targets.view(-1),            # shape: [2*5]
    ignore_index=-100
)

print("Loss:", loss.item())


tensor([[[ 0.4043, -0.2360,  0.7757, -0.6890,  0.5713,  0.7772,  0.7611,
           0.7728, -1.7342,  2.3012],
         [-1.2186, -0.5990,  0.6929,  0.2358, -1.5885, -0.0234,  0.4622,
           0.2673,  0.6569,  0.1141],
         [-1.4403, -0.2727,  1.9735, -0.0568,  0.6660, -0.4947, -0.5999,
           0.4363, -0.3913, -0.5853],
         [-1.3703,  0.4677,  0.9049, -0.5474, -0.6246, -1.3839,  1.9559,
           0.5793,  1.0454, -0.0923],
         [-0.4327, -1.2240,  0.0360,  1.7603,  0.5940, -0.4259, -0.0973,
           0.0160, -0.9349, -1.7356]],

        [[-0.9396,  1.5042,  1.2620,  0.5159, -0.4417,  1.2174, -0.5435,
          -1.5183, -0.3332, -0.9561],
         [ 1.3663, -1.3814, -0.7335,  2.0053,  0.2647, -1.0977, -0.1515,
           1.3244, -0.1602,  1.2682],
         [-0.2830,  0.2802,  1.0564,  0.5530, -0.2721, -0.9782, -0.2651,
           0.8758, -0.8674, -1.5603],
         [ 1.6397,  1.5049, -2.1373,  0.3688, -1.2621,  0.1268, -1.0480,
           0.9555,  0.1263, -0.3048],

In [31]:
device = torch.device("mps" if torch.backends.mps.is_available() else "cuda")
device

device(type='cuda')

Testing

In [32]:
n_embed, img_embed_dim, vocab_size, num_heads, n_layer = 128, 256, 1000, 8, 6
# `n_layer` is used to represent number of decoder transformer blocks and num_blocks for the vision encoder to avoid confusion
model = DecoderLanguageModel(
    n_embed=n_embed,
    img_embed_dim=img_embed_dim,
    vocab_size=vocab_size,
    num_heads=num_heads,
    n_layer=n_layer,
    use_images=True,
)
model.to(device)


# Dummy input
B, T = 10, 50
idx = torch.randint(low=0, high=vocab_size, size=(B, T)).to(device)
image_embeds = torch.randn(B, 256).to(device)  # Assume img_embed_dim is 256

targets = torch.randint(0, vocab_size, (B, T)).to(
    device
)  # Only if you want to compute loss

# Test forward pass
# Check if you need to calculate loss by providing targets
if targets is not None:
    logits, loss = model(idx, image_embeds, targets)
    print(f"Logits shape: {logits.shape}, Loss: {loss}")
else:
    logits = model(idx, image_embeds)  # Call without targets
    print(f"Logits shape: {logits.shape}")

# Test generation
generated = model.generate(idx, image_embeds, max_new_tokens=20)
print(f"Generated sequence shape: {generated.shape}")

Logits shape: torch.Size([10, 51, 1000]), Loss: 7.057732582092285
Generated sequence shape: torch.Size([10, 70])


## Putting everything together: Simple Vision Language Model

In [33]:
class VisionLanguageModel(nn.Module):
    def __init__(
        self,
        n_embed: int,
        img_embed_dim: int,
        vocab_size: int,
        n_layer: int,
        img_size: int,
        patch_size: int,
        num_heads: int,
        num_blocks: int,
        emb_dropout: float,
        block_dropout: float,
    ) -> None:
        super().__init__()

        # Set num_hiddens equal to img_embed_dim
        num_hiddens = img_embed_dim

        # Assert that num_hiddens is divisible by num_heads
        assert num_hiddens % num_heads == 0, ValueError(
            "num_hiddens must be divisible by num_heads!"
        )

        # Initialize the Vision Transformer (ViT) encoder
        self.vision_encoder = ViT(
            img_size=img_size,
            patch_size=patch_size,
            num_hiddens=num_hiddens,
            num_heads=num_heads,
            num_blocks=num_blocks,
            emb_dropout=emb_dropout,
            block_dropout=block_dropout,
        )

        # Initialize the Language Model Decoder (DecoderLanguageModel)
        self.decoder = DecoderLanguageModel(
            n_embed=n_embed,
            img_embed_dim=img_embed_dim,
            vocab_size=vocab_size,
            num_heads=num_heads,
            n_layer=n_layer,
            use_images=True,
        )

    def _check_image_embeddings(self, image_embeds: torch.Tensor) -> None:
        """Chek if image embeddings are valid."""
        if image_embeds.nelement() == 0 or image_embeds.shape[1] == 0:
            raise ValueError(
                "Something is wrong with the ViT model. It's returning an empty tensor or the embedding dimension is empty."
            )

    def forward(
        self, img_array: torch.Tensor, idx: torch.Tensor, targets: torch.Tensor = None
    ) -> torch.Tensor | Tuple[torch.Tensor, torch.Tensor]:
        # Get the image embeddings from the Vision Encoder
        image_embeds = self.vision_encoder(img_array)

        # Check if image embeddings are valid
        self._check_image_embeddings(image_embeds)

        if targets is not None:
            # If targets are provided, compute the logits and loss
            logits, loss = self.decoder(idx, image_embeds, targets)
            return logits, loss
        else:
            # If targets are not provided, compute only the logits
            logits = self.decoder(idx, image_embeds)
            return logits

    def generate(
        self, img_array: torch.Tensor, idx: torch.Tensor, max_new_tokens: int
    ) -> torch.Tensor:
        # Get the image embeddings from the Vision Encoder
        image_embeds = self.vision_encoder(img_array)

        # Check if image embeddings are valid
        self._check_image_embeddings(image_embeds)

        # Generate new tokens using the Language Model Decoder
        generated_tokens = self.decoder.generate(
            idx=idx, img_embeds=image_embeds, max_new_tokens=max_new_tokens
        )
        return generated_tokens

Testing

In [34]:
n_embed, num_hiddens, vocab_size, num_heads, n_layer = 128, 512, 1000, 8, 8
image_embed_dim = num_hiddens
img_size = 96
patch_size = 16
num_blocks = 2

n_layer, block_size, num_hiddens = 8, 32, 512

# Initialize the model
model = VisionLanguageModel(
    n_embed=n_embed,
    img_embed_dim=image_embed_dim,
    vocab_size=vocab_size,
    n_layer=n_layer,
    img_size=img_size,
    patch_size=patch_size,
    num_heads=num_heads,
    num_blocks=num_blocks,
    emb_dropout=0.1,
    block_dropout=0.1,
)
model.to(device)

# Create dummy data with correct dimensions
dummy_img = torch.randn(1, 3, img_size, img_size).to(
    device
)  # Correct shape for image input
dummy_idx = torch.randint(0, vocab_size, (1, block_size)).to(
    device
)  # Correct shape for text input

# Forward pass to initialize all parameters
try:
    output = model(dummy_img, dummy_idx)  # Output for debugging
    print("Output from initialization forward pass:", output)
except RuntimeError as e:
    print(f"Runtime Error during forward pass: {str(e)}")
    print("Check layer configurations and input shapes.")

Output from initialization forward pass: tensor([[[ 0.1523,  0.2643, -0.7298,  ...,  0.9584,  0.7303, -0.4125],
         [ 0.5711,  0.1386,  0.6068,  ...,  0.2104, -0.2524, -0.1402],
         [ 0.2705, -0.0751,  0.0375,  ..., -0.5754,  0.3984, -0.7458],
         ...,
         [-0.5594,  0.6825, -1.0122,  ...,  0.5349, -0.2477, -0.0164],
         [-0.0354,  0.6391,  0.2511,  ...,  1.8003, -0.0955, -0.3355],
         [ 0.3329,  0.3013,  0.5336,  ...,  0.2116, -0.6363,  0.1007]]],
       device='cuda:0', grad_fn=<ViewBackward0>)


In [35]:
print(output.shape)

torch.Size([1, 33, 1000])


## Train

In [36]:
# Load tokenizer
tokenizer = spm.SentencePieceProcessor(model_file='spm.model')
max_len = 256  # giới hạn token

In [37]:
# Image preprocessing
image_transform = transforms.Compose([
    transforms.Resize((96, 96)),
    transforms.ToTensor(),
])

In [39]:
# Load dataset
dataset = load_dataset("HuggingFaceM4/the_cauldron", "ai2d", split="train")

# Lọc những entry có ảnh và text hợp lệ
dataset = dataset.filter(lambda x: x["images"] and x["texts"] and "user" in x["texts"][0])

# # Lấy 100 sample đầu tiên, lấy full dataset thì bỏ qua dòng này
# dataset = dataset.select(range(100))

Filter:   0%|          | 0/2434 [00:00<?, ? examples/s]

In [None]:
# Preprocessing function
def preprocess(example):
    img_data = example["images"][0]

    # Đảm bảo ảnh là PIL.Image
    if isinstance(img_data, Image.Image):
        img = img_data.convert("RGB")
    elif isinstance(img_data, np.ndarray):
        img = Image.fromarray(img_data).convert("RGB")
    else:
        raise ValueError(f"Unsupported image format: {type(img_data)}")

    # Transform để ra Tensor
    image = image_transform(img)  # Tensor (3, 224, 224)

    # Tokenize prompt và target
    prompt = example["texts"][0]["user"]
    target = example["texts"][0].get("assistant", "")

    full_input = prompt + "\n" + target if target else prompt

    pad_id = tokenizer.pad_id() if tokenizer.pad_id() >= 0 else 0
    tokens = tokenizer.encode(full_input)
    tokens = tokens[:max_len]
    tokens += [pad_id] * (max_len - len(tokens))
    input_ids = torch.tensor(tokens, dtype=torch.long)

    # Tokenize target
    if target:
        target_tokens = tokenizer.encode(target)
        target_tokens = target_tokens[:max_len]
        target_tokens += [pad_id] * (max_len - len(target_tokens))
        target_ids = torch.tensor(target_tokens, dtype=torch.long)
    else:
        target_ids = torch.full_like(input_ids, fill_value=pad_id)

    return {
        "image": image,
        "input_ids": input_ids,
        "target_ids": target_ids
    }

# Apply preprocessing
dataset = dataset.map(preprocess)

Filter:   0%|          | 0/2434 [00:00<?, ? examples/s]

Map:   0%|          | 0/100 [00:00<?, ? examples/s]

In [None]:
# Bắt buộc để giữ tensor thay vì list!
dataset.set_format(type="torch")

In [None]:
# Kiểm tra xem data có đúng tensor ko
example = dataset[0]
print(type(example['image']))          # <class 'torch.Tensor'>
print(example['image'].shape)          # torch.Size([3, 96, 96])
print(example['input_ids'].shape)      # torch.Size([256])
print(example['target_ids'].shape)

<class 'torch.Tensor'>
torch.Size([3, 96, 96])
torch.Size([256])
torch.Size([256])


In [None]:
def collate_fn(batch):
    # Đảm bảo chuyển về tensor đúng shape
    imgs = torch.stack([torch.tensor(item['image']) if not isinstance(item['image'], torch.Tensor) else item['image'] for item in batch])
    input_ids = torch.stack([item['input_ids'] for item in batch])
    target_ids = torch.stack([item['target_ids'] for item in batch])
    return imgs, input_ids, target_ids

In [None]:
n_embed, num_hiddens, num_heads, n_layer = 128, 512, 8, 8
image_embed_dim = num_hiddens
img_size = 96
patch_size = 16
num_blocks = 2

n_layer, block_size, num_hiddens = 8, 32, 512

# Initialize the model
vlm = VisionLanguageModel(
    n_embed=n_embed,
    img_embed_dim=image_embed_dim,
    vocab_size=tokenizer.vocab_size(),
    n_layer=n_layer,
    img_size=img_size,
    patch_size=patch_size,
    num_heads=num_heads,
    num_blocks=num_blocks,
    emb_dropout=0.1,
    block_dropout=0.1,
)
device = torch.device('cuda')
vlm.to(device)

# Optimizer, chọn bộ phù hợp, chưa thử nhiều nên không bt bộ nào tốt
# optimizer = torch.optim.AdamW(vlm.parameters(), lr=1e-4)
optimizer = torch.optim.SGD(vlm.parameters(), lr=0.001, momentum=0.9)


In [None]:
# Create DataLoader
dataloader = DataLoader(dataset, batch_size=8, shuffle=True, collate_fn=collate_fn)

In [None]:
# Training loop
vlm.train()
for epoch in range(60):
    pbar = tqdm(dataloader, desc=f"Epoch {epoch + 1}")
    total_loss = 0
    for imgs, input_ids, target_ids in pbar:
        input_ids = input_ids.to(device)
        imgs = imgs.to(device)
        target_ids = target_ids.to(device)

        optimizer.zero_grad()
        _, loss = vlm(imgs, input_ids, targets=target_ids)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        pbar.set_postfix({"loss": loss.item()})

    print(f"Epoch {epoch+1} - Avg Loss: {total_loss / len(dataloader):.4f}")

Epoch 1: 100%|██████████| 13/13 [00:04<00:00,  2.71it/s, loss=4.59]
Epoch 34: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:25<00:00,  1.99s/it, loss=0.108]


Epoch 1 - Avg Loss: 8.9425


Epoch 2: 100%|██████████| 13/13 [00:04<00:00,  3.15it/s, loss=0.192]


Epoch 2 - Avg Loss: 1.1465


Epoch 3: 100%|██████████| 13/13 [00:04<00:00,  3.01it/s, loss=0.22]


Epoch 3 - Avg Loss: 0.2088


Epoch 4: 100%|██████████| 13/13 [00:04<00:00,  2.85it/s, loss=0.227]


Epoch 4 - Avg Loss: 0.2247


Epoch 5: 100%|██████████| 13/13 [00:04<00:00,  3.12it/s, loss=0.224]


Epoch 5 - Avg Loss: 0.2253


Epoch 6: 100%|██████████| 13/13 [00:05<00:00,  2.36it/s, loss=0.22]


Epoch 6 - Avg Loss: 0.2216


Epoch 7: 100%|██████████| 13/13 [00:04<00:00,  3.14it/s, loss=0.214]


Epoch 7 - Avg Loss: 0.2164


Epoch 8: 100%|██████████| 13/13 [00:05<00:00,  2.25it/s, loss=0.206]


Epoch 8 - Avg Loss: 0.2106


Epoch 9: 100%|██████████| 13/13 [00:04<00:00,  2.70it/s, loss=0.2]


Epoch 9 - Avg Loss: 0.2043


Epoch 10: 100%|██████████| 13/13 [00:04<00:00,  3.12it/s, loss=0.192]

Epoch 10 - Avg Loss: 0.1976





Epoch 34 - Avg Loss: 0.1069


Epoch 35: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:25<00:00,  1.93s/it, loss=0.111]


Epoch 35 - Avg Loss: 0.1059


Epoch 36: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:24<00:00,  1.91s/it, loss=0.105]


Epoch 36 - Avg Loss: 0.1042


Epoch 37: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:24<00:00,  1.91s/it, loss=0.104]


Epoch 37 - Avg Loss: 0.1032


Epoch 38: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:24<00:00,  1.90s/it, loss=0.0992]


Epoch 38 - Avg Loss: 0.1018


Epoch 39: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:24<00:00,  1.91s/it, loss=0.101]


Epoch 39 - Avg Loss: 0.1006


Epoch 40:  38%|█████████████████████████████████████████▏                                                                 | 5/13 [00:11<00:18,  2.31s/it, loss=0.103]


KeyboardInterrupt: 

In [None]:
torch.save(vlm.state_dict(), "model_10.pth")
print("Model saved to model.pth")

Model saved to model.pth


## Eval
### define model phải giống với lúc train, img_size=96, nếu đổi img_size phải đổi ở hàm def preprocess(example) và image_transform


In [None]:
device = torch.device("mps" if torch.backends.mps.is_available() else "cuda")
device

device(type='cuda')

In [None]:
tokenizer = spm.SentencePieceProcessor(model_file='/content/spm.model')

In [None]:
n_embed, num_hiddens, num_heads, n_layer = 128, 512, 8, 8
image_embed_dim = num_hiddens
img_size = 96
patch_size = 16
num_blocks = 2

n_layer, block_size, num_hiddens = 8, 32, 512

# Initialize the model
model = VisionLanguageModel(
    n_embed=n_embed,
    img_embed_dim=image_embed_dim,
    vocab_size=tokenizer.vocab_size(),
    n_layer=n_layer,
    img_size=img_size,
    patch_size=patch_size,
    num_heads=num_heads,
    num_blocks=num_blocks,
    emb_dropout=0.1,
    block_dropout=0.1,
)
model.to(device)
model.load_state_dict(torch.load("/content/model_10.pth", map_location=torch.device('cuda')))
model.eval()  # set to eval mode if you're going to do inference

# Load image
img_path = '/content/image-1d100e9.jpg'  # 🔁 Replace with your actual image path
image = Image.open(img_path).convert("RGB")

# Preprocessing image
transform = transforms.Compose([
    transforms.Resize((96, 96)),  # make sure this matches ViT input
    transforms.ToTensor(),
])
img_tensor = transform(image).unsqueeze(0)  # shape: [1, 3, 96, 96]

# Move img_tensor to the same device as the model
img_tensor = img_tensor.to(device)

# Prepare prompt and tokenize it
prompt = "Question: What do respiration and combustion give out\nChoices:\nA. Oxygen\nB. Carbon dioxide\nC. Nitrogen\nD. Heat\nAnswer with the letter."
tokens = tokenizer.encode(prompt)

# Convert tokens list to a PyTorch tensor
tokens_tensor = torch.tensor(tokens).unsqueeze(0).to(device)  # Add batch dimension and move to device

# --- 4. Run inference ---
with torch.no_grad():
    output_tokens = model.generate(
        img_array=img_tensor,
        idx=tokens_tensor,  # Pass tensor instead of list
        max_new_tokens=50
    )



In [None]:
# --- 5. Decode and handle special tokens ---
# Convert tensor to list and decode
output_tokens_list = output_tokens[0].cpu().numpy().tolist()

# Remove special tokens manually (if needed)
# For example, let's assume that 0 is the token for padding (common in many models)
# Modify the list to remove any special tokens, if necessary
output_tokens_list = [token for token in output_tokens_list if token != tokenizer.pad_id()]

# Now, decode the remaining tokens
output_text = tokenizer.decode(output_tokens_list)
print("Answer:", output_text)

Answer: Question: What do respiration and combustion give out Choices: A. Oxygen B. Carbon dioxide C. Nitrogen D. Heat Answer with the letter.
