In [48]:
!git clone https://github.com/milistu/VLM-from-Scratch.git

fatal: destination path 'VLM-from-Scratch' already exists and is not an empty directory.


In [49]:
!pip install --upgrade torch torchvision torchaudio



In [50]:
!pip install datasets
!pip install pandas
!pip install transformers
!pip install sentencepiece
!pip install tqdm
from tqdm import tqdm



In [51]:
import psutil
import os

def print_memory_usage():
    process = psutil.Process(os.getpid())
    mem = process.memory_info().rss / (1024 * 1024)  # in MB
    print(f"Memory usage: {mem:.2f} MB")

print_memory_usage()
import sentencepiece as spm
print_memory_usage()
tokenizer = spm.SentencePieceProcessor(model_file='./spm.model')
print_memory_usage()
ki = tokenizer.encode("Hello world asdasd dasd ")
print_memory_usage()
print(tokenizer.decode(ki))


Memory usage: 4253.23 MB
Memory usage: 4253.23 MB
Memory usage: 4253.23 MB
Memory usage: 4253.23 MB
Hello world asdasd dasd


### Dataset_reward

In [52]:
from datasets import load_dataset

# Login using e.g. `huggingface-cli login` to access this dataset
ds2 = load_dataset("MMInstruction/VL-RewardBench")

In [53]:
ds2

DatasetDict({
    test: Dataset({
        features: ['id', 'query', 'response', 'image', 'human_ranking', 'models', 'judge', 'rationale', 'query_source', 'ground_truth'],
        num_rows: 1250
    })
})

In [54]:
ds2 = ds2.filter(lambda x: x['human_ranking'] == [0,1])

In [55]:
len(ds2['test'])

1244

In [56]:
ds2 = ds2['test'].map(lambda x: {'image': x['image'], 'prompt': x['query'], 'chosen': x['response'][0], 'rejected': x['response'][1]}
                      , remove_columns=['human_ranking', 'response', 'id','models','judge','rationale', 'ground_truth','query_source','query'])

In [57]:
ds2

Dataset({
    features: ['image', 'prompt', 'chosen', 'rejected'],
    num_rows: 1244
})

In [58]:
# Image preprocessing
import transformers
import torchvision.transforms as transforms
image_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

In [59]:
# Dạng dữ liệu mới: [{'image': PIL.Image, 'prompt': str, 'caption': str}, ...]
from PIL import Image
import numpy as np
import torch

def preprocess2(example):
    img_data = example["image"]

    # Đảm bảo ảnh là PIL.Image
    if isinstance(img_data, Image.Image):
        img = img_data.convert("RGB")
    elif isinstance(img_data, np.ndarray):
        img = Image.fromarray(img_data).convert("RGB")
    else:
        raise ValueError(f"Unsupported image format: {type(img_data)}")

    # Transform ảnh sang Tensor (3, 224, 224)
    image = image_transform(img)
    chosen = example["prompt"] + " " + example["chosen"]
    rejected = example["prompt"] + " " + example["rejected"]

    pad_id = tokenizer.pad_id() if tokenizer.pad_id() >= 0 else 0
    chosen_input_ids = tokenizer.encode(chosen)
    chosen_tokens = chosen_input_ids[:256]
    chosen_tokens += [pad_id] * (256 - len(chosen_tokens))
    chosen_input_ids = torch.tensor(chosen_tokens, dtype=torch.long)

    rejected_input_ids = tokenizer.encode(rejected)
    rejected_tokens = rejected_input_ids[:256]
    rejected_tokens += [pad_id] * (256 - len(rejected_tokens))
    rejected_input_ids = torch.tensor(rejected_tokens, dtype=torch.long)
    return {
        "image": image,
        "chosen_input_ids": chosen_input_ids,
        "reject_input_ids": rejected_input_ids
    }

# # Apply preprocessing
# dataset2 = []
# for i in range(len(ds2)):
#     dataset2.append(preprocess2(ds2[i]))
dataset2 = list(map(preprocess2, ds2))



In [60]:
def collate_fn2(batch):
    imgs = torch.stack([torch.tensor(item['image']) if not isinstance(item['image'], torch.Tensor) else item['image'] for item in batch])
    input_ids = torch.stack([item["chosen_input_ids"] for item in batch])
    target_ids = torch.stack([item["reject_input_ids"] for item in batch])
    return imgs, input_ids, target_ids


#1 Base model

In [61]:
import torch
from torch import nn
from torch.nn import functional as F
from typing import Tuple
# ---
class PatchEmbeddings(nn.Module):
    def __init__(
        self, img_size: int = 96, patch_size: int = 16, hidden_dim: int = 512
    ) -> None:
        super().__init__()
        # Store the input image size, the patch size and hidden dimension
        self.img_size = img_size
        self.patch_size = patch_size
        self.hidden_dim = hidden_dim

        self.num_patches = (self.img_size // self.patch_size) ** 2
        self.conv = nn.Conv2d(
            in_channels=3,
            out_channels=self.hidden_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
        )

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        X = self.conv(X)
        X = X.flatten(2)
        X = X.transpose(1, 2)

        return X
    
B, C, H, W = 1, 3, 96, 96  # Batch size, Channels, Height, Width
X = torch.randn(B, C, H, W)

patch_size = 16
hidden_dim = 512

patch_embeddings = PatchEmbeddings(
    img_size=H, patch_size=patch_size, hidden_dim=hidden_dim
)

patches = patch_embeddings(X)
num_patches = (H // patch_size) ** 2
assert patches.shape == (B, num_patches, hidden_dim), "Output shape is incorrect"
print("Test passed!")

Test passed!


## Attention Mechanism

In [62]:
class Head(nn.Module):
    def __init__(
        self,
        n_embed: int,
        head_size: int,
        dropout: float = 0.1,
        is_decoder: bool = False,
    ) -> None:
        super().__init__()

        # Linear layer for Key projection
        self.key = nn.Linear(in_features=n_embed, out_features=head_size, bias=False)

        # Linear layer for Query projection
        self.query = nn.Linear(in_features=n_embed, out_features=head_size, bias=False)

        # Linear layer for Value projection
        self.value = nn.Linear(in_features=n_embed, out_features=head_size, bias=False)

        # Dropout layer for regularization to prevent overfitting
        self.dropout = nn.Dropout(p=dropout)

        # Flag indicating wheter the head is used as a decoder
        self.is_decoder = is_decoder

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Get batch size (B), sequence length (T), and embedding dimension (C) from the input tensor
        B, T, C = x.shape

        # Compute Key, Query, and Value projections
        k = self.key(x)  # Shape: (B, T, head_size)
        q = self.query(x)  # Shape: (B, T, head_size)
        v = self.value(x)  # SHape: (B, T, head_size)

        # Compute attention scores by taking the dot product of Query and Key
        # and scaling by the square root of the embedding dimension
        wei = q @ k.transpose(-2, -1) * (C**-0.5)  # Shape: (B, T, T)

        if self.is_decoder:
            # If this head is used in the decoder, apply causal mask to the attention scores
            # to prevent attention to future positions
            tril = torch.tril(torch.ones(T, T, dtype=torch.bool, device=x.device))
            wei = wei.masked_fill(mask=tril == 0, value=float("-inf"))

        # Apply softmax to the attention scores to obtain attention probabilities
        # Sum of probabilities for each row will be 1
        wei = F.softmax(input=wei, dim=-1)  # Shape: (B, T, T)

        # Apply Dropout to the attention probabilities for regularization
        wei = self.dropout(wei)

        # Perform a weighted aggregation of values using the attention probabilities
        out = wei @ v  # Shape: (B, T, head_size)

        return out
B, T, C = patches.shape  # Batch size, Sequence length, Embedding dimension
head_size = 16  # Size of the attention head

head = Head(n_embed=C, head_size=head_size)
output = head(patches)
print(f"Shape of output tensor: {output.shape}")
assert output.shape == (B, T, head_size), "Output shape is incorrect"
print("Test passed!")

Shape of output tensor: torch.Size([1, 36, 16])
Test passed!


In [63]:
class MultiHeadAttention(nn.Module):
    def __init__(
        self,
        n_embed: int,
        num_heads: int,
        dropout: float = 0.1,
        is_decoder: bool = False,
    ) -> None:
        super().__init__()

        # Ensure that the embedding dimension is divisible by the number of heads
        assert n_embed % num_heads == 0, "n_embed must be divisible by num_heads!"

        # Create a ModuleList of attention heads
        self.heads = nn.ModuleList(
            modules=[
                Head(
                    n_embed=n_embed,
                    head_size=n_embed // num_heads,
                    dropout=dropout,
                    is_decoder=is_decoder,
                )
                for _ in range(num_heads)
            ]
        )

        # Linear layer for projecting the concatenated head outputs
        self.proj = nn.Linear(in_features=n_embed, out_features=n_embed)

        # Dropout layer for regularization
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Apply each attention head to the input tensor
        head_outputs = [
            h(x) for h in self.heads
        ]  # Shape: num_heads * (B, T, head_size)

        # Concatenate the outputs from all heads along the last dimension
        out = torch.cat(tensors=head_outputs, dim=-1)  # Shape: (B, T, m_embed)

        # Apply the projection layer to the concatenated outputs
        out = self.proj(out)  # Shape: (B, T, m_embed)

        # Apply Dropout to the projected outputs for regularization
        out = self.dropout(out)

        return out
num_heads = 2
dropout = 0.1
mha = MultiHeadAttention(n_embed=C, num_heads=num_heads, dropout=dropout)
output = mha(patches)
print(f"Shape of output tensor: {output.shape}")
assert output.shape == (B, T, C), "Output shape is incorrect"
print("Test passed!")

Shape of output tensor: torch.Size([1, 36, 512])
Test passed!


In [64]:
class MLP(nn.Module):
    def __init__(
        self, n_embed: int, dropout: float = 0.1, is_decoder: bool = False
    ) -> None:
        super().__init__()

        # Define the layers of the MLP
        layers = [
            # First linear layer that expands the input dimension from n_embed to 4 * n_embed
            nn.Linear(in_features=n_embed, out_features=4 * n_embed),
            # Activation function: ReLU if is_decoder is True, else GELU
            nn.ReLU() if is_decoder else nn.GELU(),
            # Second linear layer that projects the intermediate dimension back to n_embed
            nn.Linear(in_features=4 * n_embed, out_features=n_embed),
            # Dropout layer for regularization
            nn.Dropout(p=dropout),
        ]

        # Create the MLP as a sequence of layers
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Pass the input through the MLP layers
        return self.net(x)
dropout = 0.1
mlp = MLP(n_embed=C, dropout=dropout)
output = mlp(output)  # Previous output of the Multihead Attention
print(f"Shape of output tensor: {output.shape}")
assert output.shape == (B, T, C), "Output shape is incorrect"
print("Test passed!")

Shape of output tensor: torch.Size([1, 36, 512])
Test passed!


In [65]:
class Block(nn.Module):
    def __init__(
        self,
        n_embed: int,
        num_heads: int,
        dropout: float = 0.1,
        is_decoder: bool = False,
    ) -> None:
        super().__init__()

        # Layer normalization for the input to the attention layer
        self.ln1 = nn.LayerNorm(normalized_shape=n_embed)

        # Multi-head attention module
        self.mhattn = MultiHeadAttention(
            n_embed=n_embed, num_heads=num_heads, dropout=dropout, is_decoder=is_decoder
        )

        # Layer normalization for the input to the FFN
        self.ln2 = nn.LayerNorm(normalized_shape=n_embed)

        # Feed-forward neural network (FFN)
        self.ffn = nn.Sequential(
            nn.Linear(in_features=n_embed, out_features=4 * n_embed),
            nn.GELU(),  # Activation function
            nn.Linear(
                in_features=4 * n_embed, out_features=n_embed
            ),  # Projection back to the original dimension
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Saving the input for residual connection
        original_x = x

        # Apply layer normalization to the input
        x = self.ln1(x)

        # Apply multi-head attention
        mhattn_output = self.mhattn(x)

        # Add the residual connection (original input) to the attention output
        x = original_x + mhattn_output

        # Apply later normalization to the input to the FFN
        x = self.ln2(x)

        # Apply the FFN
        ffn_output = self.ffn(x)

        # Apply the residual connection (input to the FFN) to the FFN output
        x = x + ffn_output

        return x
num_heads = 2
dropout = 0.1
block = Block(n_embed=C, num_heads=num_heads, dropout=dropout)
output = block(patches)
print(f"Shape of output tensor: {output.shape}")
assert output.shape == (B, T, C), "Output shape is incorrect"
print("Test passed!")

Shape of output tensor: torch.Size([1, 36, 512])
Test passed!


## Vit

In [66]:
class ViT(nn.Module):
    def __init__(
        self,
        img_size: int,
        patch_size: int,
        num_hiddens: int,
        num_heads: int,
        num_blocks: int,
        emb_dropout: float,
        block_dropout: float,
    ) -> None:
        super().__init__()

        # Patch embedding layer to convert the input image into patches
        self.patch_embedding = PatchEmbeddings(
            img_size=img_size, patch_size=patch_size, hidden_dim=num_hiddens
        )

        # Learnable classification token
        self.cls_token = nn.Parameter(data=torch.zeros(size=(1, 1, num_hiddens)))

        # Calculate the number of patches
        num_patches = (img_size // patch_size) ** 2

        # Learnable position embedding
        self.pos_embedding = nn.Parameter(
            data=torch.randn(size=(1, num_patches + 1, num_hiddens))
        )

        # Dropout layer for the embeddings
        self.dropout = nn.Dropout(p=emb_dropout)

        # Stack of transformer blocks
        self.blocks = nn.ModuleList(
            [
                Block(
                    n_embed=num_hiddens,
                    num_heads=num_heads,
                    dropout=block_dropout,
                    is_decoder=False,
                )
                for _ in range(num_blocks)
            ]
        )

        # Layer normalization for the final representation
        self.layer_norm = nn.LayerNorm(normalized_shape=num_hiddens)

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        # Convert the input image into patch embeddings
        x = self.patch_embedding(X)  # Shape: (B, num_patches, num_hiddens)

        # Expand the classification token to match the batch size
        cls_tokens = self.cls_token.expand(
            x.shape[0], -1, -1
        )  # Shape: (B, 1, num_hiddens)

        # Concatenate the classification token with the patch embeddings
        x = torch.cat(
            tensors=(cls_tokens, x), dim=1
        )  # Shape: (B, num_patches + 1, num_hiddens)

        # Add the position embedding to the patch embeddings
        x += self.pos_embedding  # Shape: (B, num_patches + 1, num_hiddens)

        # Apply dropout to the embeddings
        x = self.dropout(x)  # Shape: (B, num_patches + 1, num_hiddens)

        # Pass the embeddings through the transformer blocks
        for block in self.blocks:
            x = block(x)  # Shape: (B, num_patches + 1, num_hiddens)

        # Apply layer normalization to the `[CLS]` token's final representation
        x = self.layer_norm(x[:, 0])  # Shape: (B, num_hiddens)

        return x
B, C, H, W = 2, 3, 96, 96  # Batch size, Channels, Height, Width
X = torch.randn(B, C, H, W)
vit = ViT(
    img_size=H,
    patch_size=16,
    num_hiddens=64,
    num_heads=2,
    num_blocks=2,
    emb_dropout=0.1,
    block_dropout=0.1,
)
output = vit(X)
print(f"Output shape: {output.shape}")
assert output.shape == (B, 64), "Output shape is incorrect"
print("Test passed!")

Output shape: torch.Size([2, 64])
Test passed!


In [67]:
class MultiModalProjector(nn.Module):
    def __init__(
        self,
        n_embed: int,
        img_embed_dim: int,
        dropout: float = 0.1,
    ) -> None:
        super().__init__()

        # Define the projection network
        self.net = nn.Sequential(
            # Linear layer to expand the image embedding dimension
            nn.Linear(in_features=img_embed_dim, out_features=4 * img_embed_dim),
            # GELU activation function
            nn.GELU(),
            # Linear layer to project the expanded image embeddings to the text embedding dimension
            nn.Linear(in_features=4 * img_embed_dim, out_features=n_embed),
            # Dropout layer for regularization
            nn.Dropout(p=dropout),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Pass the input through the projection network
        x = self.net(x)  # Shape: (B, img_embed_dim) --> (B, n_embed)
        return x
B, n_embed, img_embed_dim = 2, 64, 128
X = torch.randn(size=(B, img_embed_dim))

projector = MultiModalProjector(
    n_embed=n_embed, img_embed_dim=img_embed_dim, dropout=0.1
)
output = projector(X)
print(f"Output shape: {output.shape}")
assert output.shape == (B, n_embed), "Output shape is incorrect"
print("Test passed!")

Output shape: torch.Size([2, 64])
Test passed!


## Decode

In [68]:
class DecoderLanguageModel(nn.Module):
    def __init__(
        self,
        n_embed: int,
        img_embed_dim: int,
        vocab_size: int,
        num_heads: int,
        n_layer: int,
        use_images: bool = False,
    ) -> None:
        super().__init__()

        self.use_images = use_images

        # Token embedding table
        self.token_embedding_table = nn.Embedding(
            num_embeddings=vocab_size, embedding_dim=n_embed
        )

        # Position embedding table
        self.position_embedding_table = nn.Embedding(
            num_embeddings=1000, embedding_dim=n_embed
        )

        if use_images:
            # Image projection layer to align image embeddings with text embeddings
            self.image_projection = MultiModalProjector(
                n_embed=n_embed, img_embed_dim=img_embed_dim
            )

        # Stack of transformer decoder blocks
        self.blocks = nn.Sequential(
            *[
                Block(n_embed=n_embed, num_heads=num_heads, is_decoder=True)
                for _ in range(n_layer)
            ]
        )

        # Final layer normalization
        self.ln_f = nn.LayerNorm(normalized_shape=n_embed)

        # Language modeling head
        self.lm_head = nn.Linear(in_features=n_embed, out_features=vocab_size)

    def forward(
        self,
        idx: torch.Tensor,
        img_embeds: torch.Tensor = None,
        targets: torch.Tensor = None,
    ) -> torch.Tensor:
        # Get token embeddings from the input indices
        tok_emb = self.token_embedding_table(idx)

        if self.use_images:
            # Project and concatenate image embeddings with token embeddings
            img_emb = self.image_projection(img_embeds).unsqueeze(1)
            tok_emb = torch.cat([img_emb, tok_emb], dim=1)

        # Get position embeddings
        pos_emb = self.position_embedding_table(
            torch.arange(tok_emb.size(1), device=idx.device)
        )

        # Add position embeddings to token embeddings
        x = tok_emb + pos_emb

        # Pass through the transformer decoder blocks
        x = self.blocks(x)

        # Apply final layer normalization
        x = self.ln_f(x)

        # Get the logits from the language modeling head
        logits = self.lm_head(x)

        if targets is not None:
            if self.use_images and img_embeds is not None:
                # Prepare targets by concatenating a dummy target for the image embedding
                batch_size = idx.size(0)
                targets = torch.cat(
                    [
                        torch.full(
                            (batch_size, 1), -100, dtype=torch.long, device=idx.device
                        ),
                        targets,
                    ],
                    dim=1,
                )

            # Compute the cross-entropy loss
            loss = F.cross_entropy(
                input=logits.view(-1, logits.size(-1)),
                target=targets.view(-1),
                ignore_index=-100,
            )
            return logits, loss

        return logits

    def generate(
        self, idx: torch.Tensor, img_embeds: torch.Tensor, max_new_tokens: int
    ) -> torch.Tensor:
        # Get the batch size and sequence length
        B, T = idx.shape

        # Initialize the generated sequence with the input indices
        generated = idx
        tok_emb = self.token_embedding_table(idx)

        if self.use_images and img_embeds is not None:
            # Project and concatenate image embeddings with token embeddings
            img_emb = self.image_projection(img_embeds).unsqueeze(1)
            current_output = torch.cat([img_emb, tok_emb], dim=1)
        else:
            current_output = tok_emb

        # Generate new tokens iteratevely
        for i in range(max_new_tokens):
            # Get the current sequence length
            T_current = current_output.shape[1]

            # Get position embeddings for the current sequence length
            current_pos_emb = self.position_embedding_table(
                torch.arange(T_current, device=idx.device)
            ).unsqueeze(0)

            # Add position embeddings to the current output
            current_output += current_pos_emb

            # Pass through the transformer decoder blocks
            for block in self.blocks:
                current_output = block(current_output)

            # Get the logits for the last token
            logits = self.lm_head(current_output[:, -1, :])

            # Apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1)

            # Sample the next token based on the probability
            idx_next = torch.multinomial(input=probs, num_samples=1)

            # Concatenate the generated token to the generated sequence
            generated = torch.cat([generated, idx_next], dim=1)

            # Get the embeddings for the generated token
            idx_next_emb = self.token_embedding_table(idx_next)

            # Concatenate the generated token embeddings to the current output
            current_output = torch.cat([current_output, idx_next_emb], dim=1)

        return generated


In [69]:
import torch
import torch.nn.functional as F

# Simulated inputs
batch_size = 2
seq_len = 4
vocab_size = 10

# Fake target labels (batch_size x seq_len)
targets = torch.tensor([
    [1, 2, 3, 4],
    [2, 3, 4, 5]
], dtype=torch.long)  # shape: [2, 4]

# Dummy index tensor for alignment (same batch_size)
idx = torch.arange(batch_size)

# Simulated logits output from a model (batch_size x seq_len+1 x vocab_size)
# We assume image embeddings are used, so seq_len + 1
logits = torch.randn(batch_size, seq_len + 1, vocab_size)  # shape: [2, 5, 10]
print(logits)

# Add dummy target token for image embedding (-100 will be ignored in loss)
targets = torch.cat(
    [
        torch.full(
            (batch_size, 1), -100, dtype=torch.long, device=targets.device
        ),
        targets
    ],
    dim=1
)  # shape: [2, 5]

print(logits.view(-1, vocab_size).shape)
print(targets.view(-1).shape)

# Confirm shapes match
assert logits.shape[:2] == targets.shape, f"Shape mismatch: {logits.shape[:2]} vs {targets.shape}"

# Compute cross-entropy loss
loss = F.cross_entropy(
    input=logits.view(-1, vocab_size),  # shape: [2*5, 10]
    target=targets.view(-1),            # shape: [2*5]
    ignore_index=-100
)

print("Loss:", loss.item())


tensor([[[-1.9717,  0.2538,  1.5162,  0.3013,  1.2697, -0.1159, -0.1931,
           0.7020, -0.0179,  2.0949],
         [-1.0747, -0.8158,  0.9019, -0.9119, -1.1437, -1.2513,  2.2801,
           0.3489, -0.1461,  0.8202],
         [ 0.2504, -0.5564, -1.3196, -1.0037, -0.9294,  0.6707,  1.2467,
           2.3581,  0.2327,  0.4532],
         [-1.5703,  0.7633,  0.2696,  1.0340,  0.6696,  0.3486,  0.0156,
           0.0284, -0.5442,  2.0323],
         [ 0.3278,  0.3139, -1.7571,  2.7767,  1.0911, -0.8164, -0.2428,
           0.5178, -1.1838, -1.3295]],

        [[-0.6479, -0.0668,  0.3694, -0.6553,  1.0205, -1.1558, -0.5911,
          -0.1538, -0.9652,  0.0478],
         [ 0.2931,  1.3820,  1.0517,  0.2992, -1.0632, -2.0147,  0.8976,
          -0.9075,  0.1654, -1.2044],
         [ 0.5730,  0.0580, -0.7704, -0.1815,  0.5685,  0.7436, -0.5249,
          -2.2182,  0.1672,  0.0692],
         [-0.8057, -0.1025,  0.7853,  1.0227,  1.3004,  1.6454, -0.6180,
          -0.5352,  0.2914,  2.0800],

In [70]:
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
device

device(type='cpu')

In [71]:
n_embed, img_embed_dim, vocab_size, num_heads, n_layer = 128, 256, 1000, 8, 6
# `n_layer` is used to represent number of decoder transformer blocks and num_blocks for the vision encoder to avoid confusion
model = DecoderLanguageModel(
    n_embed=n_embed,
    img_embed_dim=img_embed_dim,
    vocab_size=vocab_size,
    num_heads=num_heads,
    n_layer=n_layer,
    use_images=True,
)
model.to(device)


# Dummy input
B, T = 10, 50
idx = torch.randint(low=0, high=vocab_size, size=(B, T)).to(device)
image_embeds = torch.randn(B, 256).to(device)  # Assume img_embed_dim is 256

targets = torch.randint(0, vocab_size, (B, T)).to(
    device
)  # Only if you want to compute loss

# Test forward pass
# Check if you need to calculate loss by providing targets
if targets is not None:
    logits, loss = model(idx, image_embeds, targets)
    print(f"Logits shape: {logits.shape}, Loss: {loss}")
else:
    logits = model(idx, image_embeds)  # Call without targets
    print(f"Logits shape: {logits.shape}")

# Test generation
generated = model.generate(idx, image_embeds, max_new_tokens=20)
print(f"Generated sequence shape: {generated.shape}")

Logits shape: torch.Size([10, 51, 1000]), Loss: 7.060498237609863


Generated sequence shape: torch.Size([10, 70])


## Put everything

In [72]:
class VisionLanguageModel(nn.Module):
    def __init__(
        self,
        n_embed: int,
        img_embed_dim: int,
        vocab_size: int,
        n_layer: int,
        img_size: int,
        patch_size: int,
        num_heads: int,
        num_blocks: int,
        emb_dropout: float,
        block_dropout: float,
    ) -> None:
        super().__init__()

        # Set num_hiddens equal to img_embed_dim
        num_hiddens = img_embed_dim

        # Assert that num_hiddens is divisible by num_heads
        assert num_hiddens % num_heads == 0, ValueError(
            "num_hiddens must be divisible by num_heads!"
        )

        # Initialize the Vision Transformer (ViT) encoder
        self.vision_encoder = ViT(
            img_size=img_size,
            patch_size=patch_size,
            num_hiddens=num_hiddens,
            num_heads=num_heads,
            num_blocks=num_blocks,
            emb_dropout=emb_dropout,
            block_dropout=block_dropout,
        )

        # Initialize the Language Model Decoder (DecoderLanguageModel)
        self.decoder = DecoderLanguageModel(
            n_embed=n_embed,
            img_embed_dim=img_embed_dim,
            vocab_size=vocab_size,
            num_heads=num_heads,
            n_layer=n_layer,
            use_images=True,
        )

    def _check_image_embeddings(self, image_embeds: torch.Tensor) -> None:
        """Chek if image embeddings are valid."""
        if image_embeds.nelement() == 0 or image_embeds.shape[1] == 0:
            raise ValueError(
                "Something is wrong with the ViT model. It's returning an empty tensor or the embedding dimension is empty."
            )

    def forward(
        self, img_array: torch.Tensor, idx: torch.Tensor, targets: torch.Tensor = None
    ) -> torch.Tensor | Tuple[torch.Tensor, torch.Tensor]:
        # Get the image embeddings from the Vision Encoder
        image_embeds = self.vision_encoder(img_array)

        # Check if image embeddings are valid
        self._check_image_embeddings(image_embeds)

        if targets is not None:
            # If targets are provided, compute the logits and loss
            logits, loss = self.decoder(idx, image_embeds, targets)
            return logits, loss
        else:
            # If targets are not provided, compute only the logits
            logits = self.decoder(idx, image_embeds)
            return logits

    def generate(
        self, img_array: torch.Tensor, idx: torch.Tensor, max_new_tokens: int
    ) -> torch.Tensor:
        # Get the image embeddings from the Vision Encoder
        image_embeds = self.vision_encoder(img_array)

        # Check if image embeddings are valid
        self._check_image_embeddings(image_embeds)

        # Generate new tokens using the Language Model Decoder
        generated_tokens = self.decoder.generate(
            idx=idx, img_embeds=image_embeds, max_new_tokens=max_new_tokens
        )
        return generated_tokens
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(f"Running on device: {device}")

Running on device: cpu


In [73]:
n_embed, num_hiddens, vocab_size, num_heads, n_layer = 128, 512, 1000, 8, 8
image_embed_dim = num_hiddens
img_size = 96
patch_size = 16
num_blocks = 2

n_layer, block_size, num_hiddens = 8, 32, 512

# Initialize the model
model = VisionLanguageModel(
    n_embed=n_embed,
    img_embed_dim=image_embed_dim,
    vocab_size=vocab_size,
    n_layer=n_layer,
    img_size=img_size,
    patch_size=patch_size,
    num_heads=num_heads,
    num_blocks=num_blocks,
    emb_dropout=0.1,
    block_dropout=0.1,
)
model.to(device)

# Create dummy data with correct dimensions
dummy_img = torch.randn(1, 3, img_size, img_size).to(
    device
)  # Correct shape for image input
dummy_idx = torch.randint(0, vocab_size, (1, block_size)).to(
    device
)  # Correct shape for text input

# Forward pass to initialize all parameters
try:
    output = model(dummy_img, dummy_idx)  # Output for debugging
    print("Output from initialization forward pass:", output)
except RuntimeError as e:
    print(f"Runtime Error during forward pass: {str(e)}")
    print("Check layer configurations and input shapes.")
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

Output from initialization forward pass: tensor([[[-0.6109, -0.6073,  0.3757,  ...,  0.7373, -0.3653,  0.2470],
         [-1.0122,  0.9071,  1.0353,  ...,  0.7759, -0.5381, -0.6131],
         [-0.1407, -0.8277,  0.5460,  ...,  1.3337, -0.1829,  0.0894],
         ...,
         [-0.6025,  1.1103,  1.0675,  ..., -0.9632, -0.1273, -0.4883],
         [-0.0513, -1.3979, -0.6534,  ...,  0.4909, -0.4671, -0.0014],
         [ 0.3264,  0.5641, -0.2365,  ...,  0.5534, -1.0300,  0.7593]]],
       grad_fn=<ViewBackward0>)


In [74]:
# TRAIN
num_epochs = 1
img_arrays = torch.randn(100, 3, img_size, img_size).to(
    device
)

idxs = torch.randint(0, vocab_size, (100, block_size)).to(
    device
)

targetss = torch.randint(0, vocab_size, (100, block_size)).to(
    device
)

batch_size = 1
num_batches = len(img_arrays) // batch_size

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0

    for i in range(num_batches):
        start = i * batch_size
        end = start + batch_size

        img_array = img_arrays[start:end]
        idx = idxs[start:end]
        targets = targetss[start:end]

        optimizer.zero_grad()
        logits, loss = model(img_array, idx, targets)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    avg_loss = running_loss / num_batches
    print(f"Epoch {epoch+1} - Loss: {avg_loss:.4f}")

torch.save(model.state_dict(), "my_model.pth")
print("Model saved to my_model.pth")


Epoch 1 - Loss: 7.0728
Model saved to my_model.pth


#1.2 base model vs ref_model load

In [75]:
n_embed, num_hiddens, num_heads, n_layer = 128, 512, 8, 8
image_embed_dim = num_hiddens
img_size = 96
patch_size = 16
num_blocks = 2

n_layer, block_size, num_hiddens = 8, 32, 512

# Initialize the model
vlm = VisionLanguageModel(
    n_embed=n_embed,
    img_embed_dim=image_embed_dim,
    vocab_size=tokenizer.vocab_size(),
    n_layer=n_layer,
    img_size=img_size,
    patch_size=patch_size,
    num_heads=num_heads,
    num_blocks=num_blocks,
    emb_dropout=0.1,
    block_dropout=0.1,
)
device = torch.device('cpu')
vlm.to(device)

# Optimizer, chọn bộ phù hợp, chưa thử nhiều nên không bt bộ nào tốt
# optimizer = torch.optim.AdamW(vlm.parameters(), lr=1e-4)
optimizer = torch.optim.SGD(vlm.parameters(), lr=0.001, momentum=0.9)


In [76]:
# Create DataLoader
from torch.utils.data import DataLoader
dataloader = DataLoader(dataset2, batch_size=8, shuffle=True, collate_fn=collate_fn2)

In [77]:
# ... (previous code) ...
import torch.optim as optim
from tqdm import tqdm
# Use AdamW optimizer
optimizer = torch.optim.AdamW(vlm.parameters(), lr=1e-4)

# Learning rate scheduler
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)  # Reduce LR by a factor of 0.1 every 3 epochs

# Gradient accumulation
accumulation_steps = 2

vlm.train()
for epoch in range(10):
    pbar = tqdm(dataloader, desc=f"Epoch {epoch + 1}")
    total_loss = 0
    for i, (imgs, input_ids, target_ids) in enumerate(pbar):
        input_ids = input_ids.to(device)
        imgs = imgs.to(device)
        target_ids = target_ids.to(device)

        _, loss = vlm(imgs, input_ids, targets=target_ids)
        loss = loss / accumulation_steps  # Normalize the loss
        loss.backward()

        if (i + 1) % accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()

        total_loss += loss.item()
        pbar.set_postfix({"loss": loss.item()})

    scheduler.step()  # Update learning rate

    print(f"Epoch {epoch+1} - Avg Loss: {total_loss / len(dataloader):.4f}")

Epoch 1:   0%|                                                                                                                               | 0/156 [00:00<?, ?it/s]


RuntimeError: The size of tensor a (197) must match the size of tensor b (37) at non-singleton dimension 1

In [None]:
torch.save(vlm.state_dict(), "model_10.pth")
print("Model saved to model.pth")

Model saved to model.pth


# 1.1 SFT base model

In [81]:
model = VisionLanguageModel(
    n_embed=n_embed,
    img_embed_dim=image_embed_dim,
    vocab_size=tokenizer.vocab_size(),
    n_layer=n_layer,
    img_size=img_size,
    patch_size=patch_size,
    num_heads=num_heads,
    num_blocks=num_blocks,
    emb_dropout=0.1,
    block_dropout=0.1,
)
model.to('cpu')
model.load_state_dict(torch.load("my_model.pth", map_location=torch.device('cpu')))

RuntimeError: Error(s) in loading state_dict for VisionLanguageModel:
	size mismatch for decoder.token_embedding_table.weight: copying a param with shape torch.Size([1000, 128]) from checkpoint, the shape in current model is torch.Size([128000, 128]).
	size mismatch for decoder.lm_head.weight: copying a param with shape torch.Size([1000, 128]) from checkpoint, the shape in current model is torch.Size([128000, 128]).
	size mismatch for decoder.lm_head.bias: copying a param with shape torch.Size([1000]) from checkpoint, the shape in current model is torch.Size([128000]).

In [None]:
model.train()
loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
for epoch in range(num_epochs):
    for batch in dataloader:
        inputs = {k: v.squeeze(1).to(device) for k, v in batch.items()}
        outputs = model(**inputs)
        loss = loss_fn(outputs.logits.view(-1, vocab_size), inputs["labels"].view(-1))
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()


#2 Reward model

## Reward

In [82]:
# Create DataLoader
from torch.utils.data import DataLoader
dataloader2 = DataLoader(dataset2, batch_size=4, shuffle=True, collate_fn=collate_fn2)
dataset2[4]

{'image': tensor([[[0.2196, 0.2353, 0.3569,  ..., 0.8588, 0.8549, 0.8471],
          [0.2510, 0.3451, 0.4078,  ..., 0.8667, 0.8588, 0.8510],
          [0.3725, 0.4353, 0.4353,  ..., 0.8706, 0.8627, 0.8549],
          ...,
          [0.5294, 0.5490, 0.5765,  ..., 0.6314, 0.6667, 0.6667],
          [0.5490, 0.5765, 0.5882,  ..., 0.6196, 0.6353, 0.6392],
          [0.5569, 0.5922, 0.5922,  ..., 0.6392, 0.6353, 0.6275]],
 
         [[0.2353, 0.2353, 0.3490,  ..., 0.9922, 0.9922, 0.9882],
          [0.2745, 0.3569, 0.4157,  ..., 0.9922, 0.9922, 0.9882],
          [0.4118, 0.4706, 0.4745,  ..., 0.9961, 0.9922, 0.9922],
          ...,
          [0.1608, 0.1765, 0.2000,  ..., 0.5922, 0.6275, 0.6275],
          [0.1725, 0.2000, 0.2039,  ..., 0.5804, 0.5961, 0.6000],
          [0.1804, 0.2118, 0.2039,  ..., 0.6000, 0.5961, 0.5882]],
 
         [[0.2510, 0.2627, 0.3882,  ..., 0.9961, 0.9922, 0.9882],
          [0.3020, 0.3922, 0.4667,  ..., 0.9961, 0.9961, 0.9922],
          [0.4510, 0.5137, 0.53

In [83]:
class DecoderLanguageModel(nn.Module):
    def __init__(
        self,
        n_embed: int,
        img_embed_dim: int,
        vocab_size: int,
        num_heads: int,
        n_layer: int,
        use_images: bool = False,
    ) -> None:
        super().__init__()

        self.use_images = use_images

        # Token embedding table
        self.token_embedding_table = nn.Embedding(
            num_embeddings=vocab_size, embedding_dim=n_embed
        )

        # Position embedding table
        self.position_embedding_table = nn.Embedding(
            num_embeddings=1000, embedding_dim=n_embed
        )

        if use_images:
            # Image projection layer to align image embeddings with text embeddings
            self.image_projection = MultiModalProjector(
                n_embed=n_embed, img_embed_dim=img_embed_dim
            )

        # Stack of transformer decoder blocks
        self.blocks = nn.Sequential(
            *[
                Block(n_embed=n_embed, num_heads=num_heads, is_decoder=True)
                for _ in range(n_layer)
            ]
        )

        # Final layer normalization
        self.ln_f = nn.LayerNorm(normalized_shape=n_embed)

        # Language modeling head
        self.lm_head = nn.Linear(in_features=n_embed, out_features=vocab_size)

    def forward(
        self,
        idx: torch.Tensor,
        img_embeds: torch.Tensor = None,
        targets: torch.Tensor = None,
        return_hidden:bool = True,
    ) -> torch.Tensor:
        # Get token embeddings from the input indices
        tok_emb = self.token_embedding_table(idx)

        if self.use_images:
            # Project and concatenate image embeddings with token embeddings
            img_emb = self.image_projection(img_embeds).unsqueeze(1)
            tok_emb = torch.cat([img_emb, tok_emb], dim=1)

        # Get position embeddings
        pos_emb = self.position_embedding_table(
            torch.arange(tok_emb.size(1), device=idx.device)
        )

        # Add position embeddings to token embeddings
        x = tok_emb + pos_emb

        # Pass through the transformer decoder blocks
        x = self.blocks(x)

        # Apply final layer normalization
        x = self.ln_f(x)
        if return_hidden:
            return x  # Return hidden state (B, T, D)
        # Get the logits from the language modeling head
        logits = self.lm_head(x)

        if targets is not None:
            if self.use_images and img_embeds is not None:
                # Prepare targets by concatenating a dummy target for the image embedding
                batch_size = idx.size(0)
                targets = torch.cat(
                    [
                        torch.full(
                            (batch_size, 1), -100, dtype=torch.long, device=idx.device
                        ),
                        targets,
                    ],
                    dim=1,
                )

            # Compute the cross-entropy loss
            loss = F.cross_entropy(
                input=logits.view(-1, logits.size(-1)),
                target=targets.view(-1),
                ignore_index=-100,
            )
            return logits, loss

        return logits

    def generate(
        self, idx: torch.Tensor, img_embeds: torch.Tensor, max_new_tokens: int
    ) -> torch.Tensor:
        # Get the batch size and sequence length
        B, T = idx.shape

        # Initialize the generated sequence with the input indices
        generated = idx
        tok_emb = self.token_embedding_table(idx)

        if self.use_images and img_embeds is not None:
            # Project and concatenate image embeddings with token embeddings
            img_emb = self.image_projection(img_embeds).unsqueeze(1)
            current_output = torch.cat([img_emb, tok_emb], dim=1)
        else:
            current_output = tok_emb

        # Generate new tokens iteratevely
        for i in range(max_new_tokens):
            # Get the current sequence length
            T_current = current_output.shape[1]

            # Get position embeddings for the current sequence length
            current_pos_emb = self.position_embedding_table(
                torch.arange(T_current, device=idx.device)
            ).unsqueeze(0)

            # Add position embeddings to the current output
            current_output += current_pos_emb

            # Pass through the transformer decoder blocks
            for block in self.blocks:
                current_output = block(current_output)

            # Get the logits for the last token
            logits = self.lm_head(current_output[:, -1, :])

            # Apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1)

            # Sample the next token based on the probability
            idx_next = torch.multinomial(input=probs, num_samples=1)

            # Concatenate the generated token to the generated sequence
            generated = torch.cat([generated, idx_next], dim=1)

            # Get the embeddings for the generated token
            idx_next_emb = self.token_embedding_table(idx_next)

            # Concatenate the generated token embeddings to the current output
            current_output = torch.cat([current_output, idx_next_emb], dim=1)

        return generated


In [84]:
import torch.nn as nn

class RewardModel(nn.Module):
    def __init__(self, vlm: VisionLanguageModel, hidden_dim: int = 768):
        super().__init__()
        self.vlm = vlm

        # You can modify this head (e.g., deeper MLP) if needed
        self.reward_head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )
    def forward(self, img_array: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor:
        # Encode image using ViT
        image_embeds = self.vlm.vision_encoder(img_array)

        # Get hidden states from decoder
        hidden_states = self.vlm.decoder(
            idx=input_ids,
            img_embeds=image_embeds,
            return_hidden=True  # This is the key addition
        )  # (B, T+1, D) — if image embed is prepended

        # Option 1: Use image token (first token after img_embeds)
        pooled = hidden_states[:, 0, :]  # (B, D)

        # Option 2: Use last token — hidden_states[:, -1, :]
        # Option 3: Mean pooling — hidden_states.mean(dim=1)

        reward = self.reward_head(pooled).squeeze(-1)  # (B,)
        return reward


In [86]:
# Instantiate base VLM
vlm = VisionLanguageModel(
    n_embed=768,
    img_embed_dim=768,
    vocab_size=30522,
    n_layer=6,
    img_size=224,
    patch_size=16,
    num_heads=12,
    num_blocks=6,
    emb_dropout=0.1,
    block_dropout=0.1
).to(device)

# Wrap in RewardModel
reward_model = RewardModel(vlm=vlm, hidden_dim=768).to(device)

# Example input
img_tensor = torch.randn(4, 3, 224, 224).to(device)      # batch of 2 images
text_input = torch.randint(0, 30522, (4, 64)).to(device)  # batch of 2 text sequences

# Forward pass to get reward values
rewards = reward_model(img_tensor, text_input)
print(img_tensor.shape)
print(text_input.shape)
print("Rewards:", rewards)


torch.Size([4, 3, 224, 224])
torch.Size([4, 64])
Rewards: tensor([0.1923, 0.0498, 0.1760, 0.1713], grad_fn=<SqueezeBackward1>)


In [87]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

t = torch.tensor([1,2], device=device)

In [93]:
# Instantiate base VLM
vlm = VisionLanguageModel(
    n_embed=768,
    img_embed_dim=768,
    vocab_size=tokenizer.vocab_size(),
    n_layer=6,
    img_size=224,
    patch_size=16,
    num_heads=12,
    num_blocks=6,
    emb_dropout=0.1,
    block_dropout=0.1
).to(device)

# Wrap in RewardModel
reward_model = RewardModel(vlm=vlm, hidden_dim=768).to(device)

for epoch in range(1):
        pbar = tqdm(dataloader2, desc=f"Epoch {epoch + 1}")
        total_loss = 0
        for i, (imgs, chosen_ids, rejected_ids) in enumerate(pbar):
            imgs = imgs.to(device)
            chosen_ids = chosen_ids.to(device)
            rejected_ids = rejected_ids.to(device)
            print(imgs.shape)
            print(chosen_ids.shape)
            print(rejected_ids.shape)
            r_chosen = reward_model(imgs, chosen_ids)  # (B,) [B] [B,C, H,W]
            r_rejected = reward_model(imgs, rejected_ids)  # (B,)

Epoch 1:   0%|                                                                              | 0/311 [00:00<?, ?it/s]

torch.Size([4, 3, 224, 224])
torch.Size([4, 256])
torch.Size([4, 256])


Epoch 1:   0%|▏                                                                     | 1/311 [00:02<13:55,  2.69s/it]

torch.Size([4, 3, 224, 224])
torch.Size([4, 256])
torch.Size([4, 256])


Epoch 1:   1%|▍                                                                     | 2/311 [00:05<13:10,  2.56s/it]

torch.Size([4, 3, 224, 224])
torch.Size([4, 256])
torch.Size([4, 256])


Epoch 1:   1%|▍                                                                     | 2/311 [00:07<18:19,  3.56s/it]


KeyboardInterrupt: 

In [None]:
from torch.nn.functional import sigmoid
from torch.optim import AdamW
import torch.nn as nn
import os
from tqdm import tqdm
def train_reward_model(reward_model, df, epochs=10, batch_size=4, lr=1e-5, tokenizer=None):
    optimizer = AdamW(reward_model.parameters(), lr=lr)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    reward_model.to(device)
    reward_model.train()
    for epoch in range(epochs):
        pbar = tqdm(df, desc=f"Epoch {epoch + 1}")
        total_loss = 0
        for i, (imgs, chosen_ids, rejected_ids) in enumerate(pbar):
            imgs = imgs.to(device)
            chosen_ids = chosen_ids.to(device)
            rejected_ids = rejected_ids.to(device)
            # print(imgs.shape)
            # print(chosen_ids.shape)
            # print(rejected_ids.shape)
            r_chosen = reward_model(imgs, chosen_ids)  # (B,) [B] [B,C, H,W]
            r_rejected = reward_model(imgs, rejected_ids)  # (B,)

            # Pairwise ranking loss
            loss = -torch.log(sigmoid(r_chosen - r_rejected)).mean()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        print(f"Epoch {epoch+1}/{epochs} | Loss: {total_loss / len(dataloader):.4f}")
vlm = VisionLanguageModel(
    n_embed=768,
    img_embed_dim=768,
    vocab_size=tokenizer.vocab_size(),
    n_layer=6,
    img_size=224,
    patch_size=16,
    num_heads=12,
    num_blocks=6,
    emb_dropout=0.1,
    block_dropout=0.1
).to(device)
reward_model = RewardModel(vlm=vlm, hidden_dim=768).to(device)
train_reward_model(reward_model, dataloader2 , epochs=10, batch_size=4, lr=1e-5, tokenizer = tokenizer)

Epoch 1: 100%|████████████████████████████████████████████████████████████████████| 311/311 [23:47<00:00,  4.59s/it]


Epoch 1/10 | Loss: 1.3983


Epoch 2: 100%|████████████████████████████████████████████████████████████████████| 311/311 [23:09<00:00,  4.47s/it]


Epoch 2/10 | Loss: 1.3886


Epoch 3: 100%|████████████████████████████████████████████████████████████████████| 311/311 [23:11<00:00,  4.47s/it]


Epoch 3/10 | Loss: 1.3842


Epoch 4: 100%|████████████████████████████████████████████████████████████████████| 311/311 [23:15<00:00,  4.49s/it]


Epoch 4/10 | Loss: 1.3869


Epoch 5: 100%|██████████████████████████████████████████████████████████████████| 311/311 [1:00:25<00:00, 11.66s/it]


Epoch 5/10 | Loss: 1.3867


Epoch 6:  40%|██████████████████████████▎                                       | 124/311 [47:57<1:13:16, 23.51s/it]

# 3 PPO model

## Gen cap with log prob

In [None]:
import torch
import torch.nn.functional as F

def generate_caption_with_log_probs(vlm_model, tokenizer, image_tensor, prompt_ids, max_length=20):
    """
    Sinh caption từ VLM + tính log-prob từng token (dùng trong PPO)

    Args:
        vlm_model: Vision-Language model
        tokenizer: Tokenizer dùng để encode/decode
        image_tensor: Tensor ảnh, shape (B, C, H, W)
        prompt_ids: Tensor (B, T), token ids của prompt
        max_length: Độ dài caption tối đa

    Returns:
        output_ids: (B, max_len), ids caption được sinh
        log_probs: (B, max_len), log-prob từng token
    """
    device = image_tensor.device
    B = image_tensor.size(0)
    output_ids = []
    log_probs = []

    input_ids = prompt_ids.clone()
    # ....... fix lai
    for _ in range(max_length):
        logits = vlm_model(image_tensor, input_ids)  # (B, T, V)
        next_token_logits = logits[:, -1, :]  # (B, V)

        probs = F.softmax(next_token_logits, dim=-1)
        log_prob = F.log_softmax(next_token_logits, dim=-1)

        # Sampling (để PPO có đa dạng response)
        next_tokens = torch.multinomial(probs, num_samples=1)  # (B, 1)

        # Lưu log-probs theo token đã chọn
        selected_log_probs = log_prob.gather(1, next_tokens)  # (B, 1)

        # Append kết quả
        output_ids.append(next_tokens)
        log_probs.append(selected_log_probs)

        # Update input
        input_ids = torch.cat([input_ids, next_tokens], dim=1)

    # Kết quả dạng (B, max_len)
    output_ids = torch.cat(output_ids, dim=1)
    log_probs = torch.cat(log_probs, dim=1)

    return output_ids, log_probs
# Example:
# # Dữ liệu input
# image_tensor = encode_image(image).unsqueeze(0).to(device)
# prompt_ids = tokenizer("A photo of", return_tensors="pt").input_ids.to(device)

# # Sinh caption + log_probs
# caption_ids, log_probs = generate_caption_with_log_probs(vlm_model, tokenizer, image_tensor, prompt_ids)

# # Decode caption
# caption_text = tokenizer.decode(caption_ids[0], skip_special_tokens=True)
# print("Caption:", caption_text)


## Compute loss ppo

In [None]:
import torch
import torch.nn.functional as F

def compute_ppo_loss(old_log_probs, new_log_probs, rewards, clip_epsilon=0.2):
    """
    Tính PPO loss cho policy model

    Args:
        old_log_probs: (B, T) log-probs từ policy cũ
        new_log_probs: (B, T) log-probs từ policy mới
        rewards: (B,) reward scalar cho từng sample
        clip_epsilon: hệ số clip (PPO)

    Returns:
        loss: scalar PPO loss
    """
    # Chuyển rewards từ (B,) thành (B, T) để match với log_probs
    rewards = rewards.unsqueeze(1).expand_as(old_log_probs)  # (B, T)

    # Tính ratio
    ratio = torch.exp(new_log_probs - old_log_probs)  # (B, T)

    # Tính advantage = reward (không có value baseline)
    advantage = rewards

    # PPO loss (clipped)
    unclipped = ratio * advantage
    clipped = torch.clamp(ratio, 1 - clip_epsilon, 1 + clip_epsilon) * advantage
    loss = -torch.min(unclipped, clipped).mean()

    return loss
#Example:
# loss = compute_ppo_loss(old_log_probs, new_log_probs, rewards)
# loss.backward()
# optimizer.step()


## ppo update

In [None]:
def ppo_update(policy_model, tokenizer, optimizer, reward_model, images, prompts, clip_epsilon=0.2, max_length=20):
    """
    Args:
        policy_model: mô hình VLM
        tokenizer: tokenizer của model
        optimizer: optimizer cho policy
        reward_model: mô hình reward đã huấn luyện
        images: (B, C, H, W)
        prompts: list[str]
        clip_epsilon: hệ số PPO clip
        max_length: độ dài caption tối đa

    Returns:
        loss: giá trị loss PPO sau 1 bước update
    """
    policy_model.train()

    # Encode prompt
    prompt_ids = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True).input_ids.to(images.device)

    # === STEP 1: Generate captions with old policy ===
    with torch.no_grad():
        old_output_ids, old_log_probs = generate_caption_with_log_probs(
            policy_model, tokenizer, images, prompt_ids, max_length=max_length
        )

    # === STEP 2: Compute reward for each (image, prompt, caption) ===
    rewards = []
    for i in range(len(images)):
        prompt = prompts[i]
        caption = tokenizer.decode(old_output_ids[i], skip_special_tokens=True)
        reward = reward_model(images[i].unsqueeze(0), prompt, caption)  # scalar
        rewards.append(reward)
    rewards = torch.tensor(rewards).to(images.device)
# ????? chỗ này chưa update model nên chưa có old-new !
    # === STEP 3: Re-run policy to get new_log_probs ===
    new_output_ids, new_log_probs = generate_caption_with_log_probs(
        policy_model, tokenizer, images, prompt_ids, max_length=max_length
    )

    # === STEP 4: Compute PPO loss ===
    loss = compute_ppo_loss(
        old_log_probs=old_log_probs,
        new_log_probs=new_log_probs,
        rewards=rewards,
        clip_epsilon=clip_epsilon
    )

    # === STEP 5: Update ===
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    return loss.item()
# Example
# loss = ppo_update(
#     policy_model=vlm_model,
#     tokenizer=vlm_tokenizer,
#     optimizer=vlm_optimizer,
#     reward_model=reward_model_fn,
#     images=image_batch,
#     prompts=["A photo of", "A scene of"],
#     clip_epsilon=0.2
# )
# print("PPO loss:", loss)