![BridgingAI Logo](../bridgingai_logo.png)

# Deep Learning - Exercise 6.2: Efficient Fine-Tuning 

---
1. [Finetuning pretrained models for Image Captioning](#implementation)

2. [Experiments](#experiments)
   
3. [References](#references)

---

In [None]:
import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoModel

from config import ExperimentConfig
from trainer import Trainer
from utils import compute_bleu

# silence warnings and avoid deadlocks due to HF tokenizer
import os

os.environ["TOKENIZERS_PARALLELISM"] = "false"

In this assignment, you will train an image captioning model by integrating a vision backbone with a language model and fine-tuning it using LoRA. You will use the [Flickr30k](https://huggingface.co/datasets/nlphuji/flickr30k) dataset for this task. Running the experiments in this notebook requires around 6GB of disk space for the dataset and the model checkpoints, and a GPU to achieve reasonable training times.

The vision backbone is a ViT that transforms images into a sequence of embeddings, then the language model will process both the image embeddings and the text embeddings to generate captions. This is similar to translation tasks, where the source text is replaced by the image embeddings.

We choose [AIMv2](https://huggingface.co/apple/aimv2-large-patch14-224) as the vision backbone, and [SmolLM2](https://huggingface.co/HuggingFaceTB/SmolLM2-135M) as the language model. AIMv2 has about 300M parameters and SmolLM2 has 135M parameters (comparable in size to the smallest GPT-2 model). The overall architecture closely resembles Figure 1 in the [PaliGemma paper](https://arxiv.org/abs/2407.07726), where the output of the vision backbone is projected to the same dimension as the language model via a linear layer. The projected image embeddings are concatenated with the text embeddings, and the combined sequence is then processed by the language model to produce captions.
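To make the shapes concrete, here is a small bookkeeping sketch in plain Python. It takes the image and patch size from the checkpoint name and the hidden sizes from the `ExperimentConfig` used in the experiments section. It assumes the encoder emits exactly one embedding per patch (no extra class token), and `text_len = 20` is a made-up caption length for illustration only.

```python
# Shape bookkeeping for the [image, text] sequence fed to the language model.
image_size, patch_size = 224, 14   # from the checkpoint name aimv2-large-patch14-224
encoder_hidden_size = 1024         # value passed to ExperimentConfig in this notebook
decoder_hidden_size = 576          # value passed to ExperimentConfig in this notebook

# Assumption: one embedding per image patch, no additional class token.
patches_per_side = image_size // patch_size
img_len = patches_per_side ** 2    # 16 * 16 = 256 image tokens

# The linear projection maps each image embedding from 1024 to 576 dimensions.
proj_params = encoder_hidden_size * decoder_hidden_size + decoder_hidden_size  # weight + bias

text_len = 20                      # hypothetical caption length
seq_len = img_len + text_len       # length of the concatenated sequence

print(f"image tokens: {img_len}")                 # 256
print(f"projection parameters: {proj_params}")    # 590400
print(f"decoder sequence length: {seq_len}")      # 276
```

Under these assumptions the image contributes far more positions to the sequence than the caption does, which is why the attention mask between the two parts matters.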

This assignment demonstrates that:
1. With transformers, it is trivial to combine different modalities (images and text).
2. LoRA enables us to fine-tune large models even with limited resources.
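
The saving behind the second point can be estimated with simple arithmetic. LoRA keeps a weight matrix of shape `(d_out, d_in)` frozen and trains two low-rank factors of shapes `(r, d_in)` and `(d_out, r)` instead. The sketch below counts trainable parameters for a single square layer of width 576 (the decoder hidden size used in this notebook). The rank `r = 8` is an assumed value for illustration; the actual rank comes from `config.lora_r`.

```python
def full_param_count(d_in, d_out):
    """Trainable weights when the whole matrix is fine-tuned (bias ignored)."""
    return d_in * d_out


def lora_param_count(d_in, d_out, r):
    """Trainable weights of the two LoRA factors A (r x d_in) and B (d_out x r)."""
    return r * (d_in + d_out)


d_in = d_out = 576   # decoder hidden size from ExperimentConfig
r = 8                # assumed rank, not taken from the notebook

full_params = full_param_count(d_in, d_out)
lora_params = lora_param_count(d_in, d_out, r)
ratio = lora_params / full_params

print(full_params)       # 331776
print(lora_params)       # 9216
print(f"{ratio:.2%}")    # 2.78%
```

With these numbers, the layer trains about 1/36 of the weights it would need under full fine-tuning, and the optimizer state shrinks accordingly.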

# 1. Implementation Tasks <a id="implementation"></a>

A critical aspect of combining the vision backbone with the language model is defining an appropriate attention mask between the image embeddings and the text embeddings. This attention mask is illustrated in Figure 3 of [GIT: A Generative Image-to-text Transformer for Vision and Language](https://arxiv.org/abs/2205.14100). In this setup, image embeddings attend only to other image embeddings and not to the text, while text embeddings attend to all image embeddings as well as the preceding text embeddings.

This masking strategy is widely used in tasks such as image captioning and visual question answering. Many well-known models, including PaliGemma, adopt this approach to effectively handle the interplay between vision and language representations.
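
The `MaskMixin` helpers in this notebook represent the mask additively: `0` means "attend" and `-inf` means "ignore". The following plain-Python sketch shows why this works for a single query row. The scores are made-up numbers, and the `softmax` helper is written here only for illustration; it is not part of the notebook's code.

```python
import math

ATTEND = 0.0
IGNORE = float("-inf")


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical raw attention scores of one query against three keys.
scores = [1.0, 2.0, 3.0]
# The query may see the first two keys but not the third.
mask_row = [ATTEND, ATTEND, IGNORE]

# Adding the mask pushes forbidden positions to -inf before the softmax.
masked_scores = [s + m for s, m in zip(scores, mask_row)]
weights = softmax(masked_scores)
print([round(w, 3) for w in weights])   # [0.269, 0.731, 0.0]

# Adding two additive masks behaves like a logical AND of "may attend":
# a position stays visible only if both masks allow it.
print(ATTEND + ATTEND)   # 0.0
print(ATTEND + IGNORE)   # -inf
print(IGNORE + IGNORE)   # -inf
```

The masked key receives exactly zero attention weight, and the remaining weights still sum to one. The same addition is what lets the image-text mask and the padding mask be merged into a single tensor.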

---

**TODO**: 
- Complete the `image_text_attention_mask` method in the `MaskMixin` class to implement the attention mask between the image embeddings and the text embeddings.
- Note that you should use the constants `ATTEND` and `IGNORE` to fill the attention mask.


In [None]:
class MaskMixin:
    """Handles generation of attention masks for huggingface (especially LLaMA) models that accept image and text embeddings.

    Some notes:
    - Attention mask produced by HF tokenizer: torch.int64, {0, 1}
    - Attention mask accepted by PyTorch sdpa attention: float or bool
    - But attention mask used by HF LLaMA models: float, {-inf, 0} since it's additive
    """

    ATTEND = 0
    IGNORE = float("-inf")

    @staticmethod
    def check_mask(mask):
        assert (
            mask.dtype == torch.float32
        ), f"Attention mask must be float32, got {mask.dtype}"
        assert torch.all(
            (mask == MaskMixin.ATTEND) | (mask == MaskMixin.IGNORE)
        ), f"Attention mask must be {MaskMixin.ATTEND} or {MaskMixin.IGNORE}, got {mask.unique()}"

    @staticmethod
    def image_text_attention_mask(batch_size, img_len, text_len, device):
        """Create an attention mask for [image, text] -> [image, text] attention.

        Args:
            batch_size: int, the number of samples in the batch
            img_len: int, the length of the image token sequence
            text_len: int, the maximum length of the text token sequence

        Returns:
            attn_mask: float tensor of shape (batch_size, 1, img_len+text_len, img_len+text_len)
                It looks like this (for img_len=2 and text_len=3):
                    [ 0,  0, -inf, -inf, -inf]
                    [ 0,  0, -inf, -inf, -inf]
                    [ 0,  0,    0, -inf, -inf]
                    [ 0,  0,    0,    0, -inf]
                    [ 0,  0,    0,    0,    0]
                where image can see all image but cannot see text, and text can see image,
                and current text can see all previous text.
        """
        # YOUR CODE HERE
        raise NotImplementedError()
        attn_mask = attn_mask.to(device)
        return attn_mask

    @staticmethod
    def padding_mask(batch_size, img_len, text_padding_mask, device):
        """Extend the text padding mask to include the image tokens (which are always attended to).

        Args:
            batch_size: int, the number of samples in the batch
            img_len: int, the length of the image token sequence
            text_padding_mask: float tensor of shape (batch_size, text_seq_len)

        Returns:
            padding_mask: tensor of shape (batch_size, 1, 1, img_len+text_len)
        """
        assert text_padding_mask.device == device, "Device mismatch"
        MaskMixin.check_mask(text_padding_mask)

        # Image tokens are never padding, so every image position is attended to.
        image_padding_mask = torch.full(
            (batch_size, 1, 1, img_len),
            MaskMixin.ATTEND,
            dtype=torch.float32,
            device=device,
        )

        # Broadcast the text mask to (batch_size, 1, 1, text_len) and append it.
        text_padding_mask = text_padding_mask[:, None, None, :]
        padding_mask = torch.cat([image_padding_mask, text_padding_mask], dim=-1)
        return padding_mask

    @staticmethod
    def convert_to_float(mask):
        """Converts an int64 or bool mask ({0, 1}) to the additive float mask ({-inf, 0}) expected by LLaMA models."""
        if mask.dtype not in (torch.int64, torch.bool):
            raise ValueError(f"Unknown mask dtype: {mask.dtype}")

        # 1 (or True) marks a real token, 0 (or False) marks padding.
        return_mask = mask.clone().float()
        return_mask[mask == 1] = MaskMixin.ATTEND
        return_mask[mask == 0] = MaskMixin.IGNORE
        MaskMixin.check_mask(return_mask)
        return return_mask

    @staticmethod
    def combine_masks(mask1, mask2):
        """Encode the logical AND operation between two masks."""
        MaskMixin.check_mask(mask1)
        MaskMixin.check_mask(mask2)
        return mask1 + mask2

In [None]:
class ImageCaptioningModel(nn.Module, MaskMixin):
    """Image captioning model that uses a language model to process both image embeddings and text tokens."""

    def __init__(self, config):
        super().__init__()
        self.train_encoder = config.train_encoder
        self.processor = config.processor

        self.image_encoder = AutoModel.from_pretrained(
            config.image_encoder_checkpoint,
            trust_remote_code=True,
        )

        self.decoder = AutoModelForCausalLM.from_pretrained(
            config.decoder_checkpoint,
            trust_remote_code=True,
        )

        # Extract the text embedding layer from the language model
        self.text_embedding = self.decoder.get_input_embeddings()

        # Project the image embeddings to the same size as the text embeddings
        img_hidden_size = config.encoder_hidden_size
        text_hidden_size = config.decoder_hidden_size
        self.image_out_proj = nn.Linear(img_hidden_size, text_hidden_size)
        print("Image encoder hidden size:", img_hidden_size)
        print("Decoder hidden size:", text_hidden_size)

    @property
    def device(self):
        return next(self.parameters()).device

    def encode_images(self, images=None, image_features=None):
        """Encode images or process pre-computed features.

        Args:
            images: Optional float tensor of shape (batch_size, 3, height, width)
            image_features: Optional float tensor of shape (batch_size, seq_len, hidden_size). Must be the output from the same image encoder architecture

        Returns:
            image_embeds: float tensor of shape (batch_size, seq_len, hidden_size)
        """
        if images is not None and image_features is not None:
            raise ValueError("Only one of images or image_features should be provided")

        if images is None and image_features is None:
            raise ValueError("Either images or image_features must be provided")

        if image_features is not None:
            return image_features

        if images is not None:
            with torch.set_grad_enabled(self.train_encoder):
                image_features = self.image_encoder(
                    pixel_values=images
                ).last_hidden_state

            if image_features.dim() == 4:
                # this means (B, C, H, W). We need to flatten it to (B, H*W, C)
                B, C, H, W = image_features.shape
                image_features = image_features.permute(0, 2, 3, 1)
                image_features = image_features.reshape(B, H * W, C)
                image_features = image_features.contiguous()

            img_embeds = self.image_out_proj(image_features)

            return img_embeds

    def forward(self, texts, text_padding_mask, images=None, image_features=None):
        """
        Args:
            texts: int tensor of shape (batch_size, seq_len)
            text_padding_mask: tensor of shape (batch_size, seq_len) generated by the tokenizer
            images: float tensor of shape (batch_size, 3, 224, 224)
            image_features: float tensor of shape (batch_size, seq_len, hidden_size)

        Returns:
            logits: float tensor of shape (batch_size, seq_len, vocab_size)
        """
        # encode images
        image_embeds = self.encode_images(images, image_features)

        # encode texts
        text_embeds = self.text_embedding(texts)

        # concatenate image and text embeddings
        inputs_embeds = torch.cat([image_embeds, text_embeds], dim=1)

        # create attention mask
        B = image_embeds.shape[0]
        img_len = image_embeds.shape[1]
        text_len = text_embeds.shape[1]
        attn_mask = self.image_text_attention_mask(B, img_len, text_len, self.device)
        # combine with the padding mask (a logical AND, expressed by adding the additive masks)
        text_padding_mask = self.convert_to_float(text_padding_mask)
        padding_mask = self.padding_mask(B, img_len, text_padding_mask, self.device)
        attn_mask = self.combine_masks(attn_mask, padding_mask)

        # forward pass through the decoder
        outputs = self.decoder(inputs_embeds=inputs_embeds, attention_mask=attn_mask)
        return outputs.logits[:, img_len:, :].contiguous()

    @torch.no_grad()
    def generate(self, images, max_length=50):
        """Given a batch of images, generate captions using greedy decoding.

        Args:
            images: list of PIL images
            max_length: int, the maximum length of the generated sequence

        Returns:
            captions: list of predicted caption strings, with special tokens removed
        """
        B = len(images)
        # Convert images and start tokens to tensors
        start_token = [self.processor.tokenizer.bos_token for _ in range(B)]
        inputs = self.processor(images=images, text=start_token, return_tensors="pt")

        pixel_values = inputs["pixel_values"].to(self.device)
        generated_ids = inputs["input_ids"].to(self.device)
        text_padding_mask = inputs["attention_mask"].to(self.device)

        # encode images
        image_embeds = self.encode_images(images=pixel_values)

        # maintain a list of finished sequences
        is_finished = torch.zeros((B, 1), dtype=torch.bool, device=self.device)

        # generate captions
        for i in range(max_length):
            logits = self.forward(
                texts=generated_ids,
                text_padding_mask=text_padding_mask,
                image_features=image_embeds,
            )
            next_token_logits = logits[:, -1, :]
            next_token_id = next_token_logits.argmax(dim=-1, keepdim=True)

            # update finished sequences
            is_finished |= next_token_id == self.processor.tokenizer.eos_token_id
            next_token_id[is_finished] = self.processor.tokenizer.pad_token_id

            generated_ids = torch.cat([generated_ids, next_token_id], dim=1)
            _next_mask = (~is_finished).to(torch.int64)
            text_padding_mask = torch.cat([text_padding_mask, _next_mask], dim=1)

            if is_finished.all():
                break

        # cut off the start token
        generated_ids = generated_ids[:, 1:]

        # decode the generated token ids
        captions = self.processor.tokenizer.batch_decode(
            generated_ids, skip_special_tokens=True
        )
        return captions

In [None]:
from peft import LoraConfig, get_peft_model


def build_model(config):
    # configures the model for finetuning with LoRA
    model = ImageCaptioningModel(config)
    if config.train_decoder is False:
        # Freeze the decoder
        model.decoder.requires_grad_(False)
    elif config.decoder_lora_modules is not None:
        lora_config = LoraConfig(
            r=config.lora_r,
            lora_alpha=config.lora_alpha,
            target_modules=config.decoder_lora_modules,
            lora_dropout=0.1,
            bias="none",
        )
        model.decoder = get_peft_model(model.decoder, lora_config)
    else:
        print("No LoRA modules for decoder provided")

    if config.train_encoder is False:
        # Freeze the image encoder
        model.image_encoder.requires_grad_(False)
    elif config.image_encoder_lora_modules is not None:
        # Train the image encoder modules with LoRA
        lora_config = LoraConfig(
            r=config.lora_r,
            lora_alpha=config.lora_alpha,
            target_modules=config.image_encoder_lora_modules,
            lora_dropout=0.1,
            bias="none",
        )
        model.image_encoder = get_peft_model(model.image_encoder, lora_config)
    else:
        print("No LoRA modules for image encoder provided")

    return model

# 2. Experiments <a id="experiments"></a>

In this section, you will train the image captioning model using the Flickr30k dataset. Run the following cells to fine-tune the model using LoRA. If your implementation is correct, you should achieve a validation loss of around 2.5 after 10k steps, and 2.1 after 40k steps.

Note that this experiment takes a long time to run. Training the model for 10k steps takes about 2 hours on a GPU, so the 40k-step run configured below would take roughly 8 hours at that rate.
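
To get a feeling for these loss values, you can convert them to perplexity. The sketch below assumes the reported validation loss is the mean token-level cross-entropy measured in nats, which is the usual convention for PyTorch's cross-entropy loss; check `trainer.py` if you need to be sure. The `perplexity` helper is defined here for illustration only.

```python
import math


def perplexity(mean_cross_entropy):
    """Perplexity corresponding to a mean per-token cross-entropy in nats."""
    return math.exp(mean_cross_entropy)


for steps, loss in [("10k", 2.5), ("40k", 2.1)]:
    print(f"{steps} steps: loss {loss} -> perplexity {perplexity(loss):.2f}")
# 10k steps: loss 2.5 -> perplexity 12.18
# 40k steps: loss 2.1 -> perplexity 8.17
```

Under that assumption, a loss of 2.1 means the model is on average about as uncertain as if it were choosing uniformly among roughly 8 tokens at each position.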

In [None]:
config = ExperimentConfig(
    "aimv2_smollm2_lora",
    encoder_hidden_size=1024,
    decoder_hidden_size=576,
    train_encoder=True,
    image_encoder_lora_modules=[
        "qkv",
        "proj",
    ],
    train_decoder=True,
    decoder_lora_modules=[
        "q_proj",
        "k_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ],
    compute_bleu=False,
    max_steps=40000,
    save_model=True,
)
# if you have a powerful GPU, you can also try to finetune the full encoder
model = build_model(config)
trainer = Trainer(model, config)
trainer.run_experiment()

We can compute the BLEU score of a trained model. If trained for 40k steps, the BLEU score should be above 30. 
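
As a reminder of what BLEU measures, here is a simplified single-sentence sketch: the geometric mean of clipped n-gram precisions, multiplied by a brevity penalty and scaled to 0-100. It is an illustration only. It uses unigrams and bigrams on one made-up caption pair, whereas standard BLEU uses up to 4-grams over a whole corpus, and the `compute_bleu` helper from `utils` may differ in its details. All function names below are hypothetical.

```python
import math
from collections import Counter


def ngrams(tokens, n):
    """Count all n-grams of length n in a token list."""
    return Counter(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))


def modified_precision(candidate, reference, n):
    """Fraction of candidate n-grams found in the reference, with clipped counts."""
    cand, ref = ngrams(candidate, n), ngrams(reference, n)
    overlap = sum(min(count, ref[gram]) for gram, count in cand.items())
    total = sum(cand.values())
    return overlap / total if total else 0.0


def sentence_bleu(candidate, reference, max_n=2):
    """Simplified BLEU for one candidate and one reference, on a 0-100 scale."""
    precisions = [modified_precision(candidate, reference, n) for n in range(1, max_n + 1)]
    if min(precisions) == 0:
        return 0.0
    log_avg = sum(math.log(p) for p in precisions) / max_n
    # The brevity penalty punishes candidates shorter than the reference.
    if len(candidate) >= len(reference):
        brevity_penalty = 1.0
    else:
        brevity_penalty = math.exp(1 - len(reference) / len(candidate))
    return 100 * brevity_penalty * math.exp(log_avg)


candidate = "a dog runs on the grass".split()
reference = "a dog runs across the grass".split()
score = sentence_bleu(candidate, reference)
print(f"{score:.2f}")   # 70.71
```

Here 5 of 6 unigrams and 3 of 5 bigrams match, so the score is 100 * sqrt(5/6 * 3/5) ≈ 70.71. Scores on real validation data are much lower because longer n-grams rarely match exactly.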

In [None]:
# If you want to load the saved model
# trainer.model.load_state_dict(torch.load("model_base.pth", weights_only=True))
trainer.model.eval()
bleu_score = compute_bleu(trainer.model, trainer.val_loader)
print(f"BLEU score: {bleu_score:.2f}")

# 3. References <a id="references"></a>

- [PaliGemma: A versatile 3B VLM for transfer](https://arxiv.org/abs/2407.07726)
- [GIT: A Generative Image-to-text Transformer for Vision and Language](https://arxiv.org/abs/2205.14100)
- [AIMv2](https://huggingface.co/apple/aimv2-large-patch14-224)
- [SmolLM2](https://huggingface.co/HuggingFaceTB/SmolLM2-135M)