# Multimodal Grounding with MLLM (LLaVA + PEFT)

This notebook implements a Multimodal Large Language Model (MLLM) aiming for visual accuracy and linguistic fluency, as outlined in the project plan.

## Key Components:
1.  **Architecture**: LLaVA-style (Frozen Vision Encoder + Trainable Projection + Frozen LLM).
2.  **Optimization**: PEFT (LoRA) for efficient fine-tuning.
3.  **Data**: Flickr8k dataset.
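4.  **Grounding**: Analysis of the LLM's self-attention maps, i.e. how text tokens attend to the image-patch tokens prepended to the sequence (the architecture has no separate cross-attention layers).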

In [None]:
# Install necessary libraries if not present
!pip install -q torch transformers peft accelerate bitsandbytes pillow matplotlib pandas pycocoevalcap

In [None]:
import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import (
    AutoTokenizer, 
    AutoModelForCausalLM, 
    CLIPVisionModel, 
    CLIPImageProcessor,
    BitsAndBytesConfig
)
from peft import LoraConfig, get_peft_model, TaskType
from PIL import Image
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from tqdm.notebook import tqdm

# Pick the compute device: the default CUDA GPU when one is visible, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

## 1. Data Preparation (Flickr8k)
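We load the Flickr8k dataset and turn each caption into a visual-instruction-tuning sample: a fixed "Describe this image." user prompt followed by the caption as the assistant's answer.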

In [None]:
# --- Dataset locations (Kaggle Flickr8k layout, same paths as the original notebook) ---
IMAGES_DIR = '/kaggle/input/flickr8k/Images/'
CAPTIONS_FILE = '/kaggle/input/flickr8k/captions.txt'

# --- Pretrained checkpoints (open-source Hugging Face Hub models) ---
VISION_MODEL_CKPT = "openai/clip-vit-large-patch14-336"
# Small chat LLM so the demo runs on modest hardware; a larger model such as Llama-2-7b can be swapped in.
LLM_CKPT = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

# --- Hyperparameters: text length budget, batch size, number of epochs ---
MAX_LENGTH, BATCH_SIZE, EPOCHS = 128, 4, 1

In [None]:
class Flickr8kDataset(Dataset):
    """Flickr8k image/caption pairs formatted as single-turn visual instructions.

    Each row of ``captions_file`` (a CSV read with pandas, with ``image`` and
    ``caption`` columns) becomes one sample: the preprocessed image plus the
    tokenized text ``prompt_template + caption``.

    Args:
        captions_file: path to the captions CSV.
        images_dir: directory containing the image files named in the ``image`` column.
        tokenizer: Hugging Face tokenizer for the LLM (must have a pad token set,
            since samples are padded to ``max_length``).
        image_processor: image processor producing ``pixel_values``.
        max_length: fixed tokenized text length (padding/truncation target).
    """

    def __init__(self, captions_file, images_dir, tokenizer, image_processor, max_length=128):
        self.images_dir = images_dir
        self.df = pd.read_csv(captions_file)
        self.tokenizer = tokenizer
        self.image_processor = image_processor
        self.max_length = max_length

        # Visual instruction template. "<image>" is only literal text here; the
        # model prepends the projected image embeddings before the text tokens.
        self.prompt_template = "User: <image>\nDescribe this image.\nAssistant: "

    def __len__(self):
        return len(self.df)

    @staticmethod
    def _build_labels(input_ids, attention_mask):
        """Return LM labels: a copy of ``input_ids`` with padded positions set to -100.

        -100 is the ignore index of the causal-LM loss, so padding tokens (which
        share their id with EOS in this notebook) do not contribute to training.
        """
        labels = input_ids.clone()
        labels[attention_mask == 0] = -100
        return labels

    def __getitem__(self, idx):
        """Return one sample as a dict of tensors.

        Keys: ``pixel_values`` (processed image, batch dim removed),
        ``input_ids`` / ``attention_mask`` (length ``max_length``), and
        ``labels`` (``input_ids`` with padding masked to -100).
        """
        row = self.df.iloc[idx]
        image_path = os.path.join(self.images_dir, row['image'])

        # The context manager closes the file handle; convert() yields an
        # in-memory image that stays valid after the file is closed.
        with Image.open(image_path) as img:
            image = img.convert("RGB")
        pixel_values = self.image_processor(image, return_tensors="pt").pixel_values.squeeze(0)

        full_text = self.prompt_template + str(row['caption'])
        tokenized = self.tokenizer(
            full_text,
            return_tensors="pt",
            max_length=self.max_length,
            padding="max_length",
            truncation=True
        )
        input_ids = tokenized.input_ids.squeeze(0)
        attention_mask = tokenized.attention_mask.squeeze(0)

        return {
            "pixel_values": pixel_values,
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": self._build_labels(input_ids, attention_mask)
        }

## 2. Architecture Setup (LLaVA-like)

- **Vision Encoder**: CLIP ViT-L/14 at 336px input resolution (Frozen)
- **Projection Layer**: Linear layer mapping Visual Dim -> LLM Dim (Trainable)
- **LLM**: TinyLlama (Frozen, except LoRA adapters)

In [None]:
class SimpleMLLM(nn.Module):
    """LLaVA-style model: frozen CLIP vision encoder -> trainable linear projection -> causal LM.

    The LLM's own weights are frozen here; trainable adapters (e.g. LoRA) are
    expected to be attached to ``self.llm`` afterwards. Projected image patch
    embeddings are prepended to the text token embeddings, so the LLM sees the
    sequence ``[image tokens..., text tokens...]``.

    Args:
        vision_ckpt: checkpoint name/path for ``CLIPVisionModel``.
        llm_ckpt: checkpoint name/path for ``AutoModelForCausalLM``.
    """

    def __init__(self, vision_ckpt, llm_ckpt):
        super().__init__()

        # 1. Vision encoder (frozen).
        self.vision_encoder = CLIPVisionModel.from_pretrained(vision_ckpt)
        for param in self.vision_encoder.parameters():
            param.requires_grad = False

        # 2. LLM: 4-bit NF4 quantization with fp16 compute when CUDA is available
        #    (requires bitsandbytes), otherwise an unquantized CPU model.
        if torch.cuda.is_available():
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.float16
            )
            self.llm = AutoModelForCausalLM.from_pretrained(
                llm_ckpt,
                quantization_config=bnb_config,
                device_map={"": 0}
            )
        else:
            print("WARNING: CUDA not available. Loading model on CPU without quantization.")
            self.llm = AutoModelForCausalLM.from_pretrained(
                llm_ckpt
            )

        # Freeze all base LLM weights; trainable capacity is added later via adapters.
        for param in self.llm.parameters():
            param.requires_grad = False

        # 3. Projection layer (the "bridge"): vision hidden size -> LLM hidden size. Trainable.
        vision_dim = self.vision_encoder.config.hidden_size
        llm_dim = self.llm.config.hidden_size
        self.projection = nn.Linear(vision_dim, llm_dim)

    @staticmethod
    def _merge_modalities(image_embeddings, inputs_embeds, attention_mask, labels=None):
        """Prepend image embeddings to text embeddings and extend mask/labels to match.

        Image embeddings are moved to the text embeddings' device and dtype (the
        LLM may run in fp16 while the projection outputs fp32). Image positions
        get attention-mask value 1 and label -100 so they are attended to but
        excluded from the loss.

        Returns:
            ``(combined_embeds, combined_mask, combined_labels)``; ``combined_labels``
            is None when ``labels`` is None.
        """
        batch_size, num_image_tokens = image_embeddings.shape[:2]
        target_device = inputs_embeds.device

        image_embeddings = image_embeddings.to(device=target_device, dtype=inputs_embeds.dtype)
        combined_embeds = torch.cat([image_embeddings, inputs_embeds], dim=1)

        # Keep the mask's integer dtype instead of promoting it to float.
        image_mask = torch.ones(
            (batch_size, num_image_tokens), device=target_device, dtype=attention_mask.dtype
        )
        combined_mask = torch.cat([image_mask, attention_mask.to(target_device)], dim=1)

        combined_labels = None
        if labels is not None:
            image_labels = torch.full(
                (batch_size, num_image_tokens), -100, device=target_device, dtype=labels.dtype
            )
            combined_labels = torch.cat([image_labels, labels.to(target_device)], dim=1)

        return combined_embeds, combined_mask, combined_labels

    def forward(self, pixel_values, input_ids, attention_mask, labels=None, output_attentions=True):
        """Run the LLM on ``[projected image patches, text tokens]``.

        Args:
            pixel_values: preprocessed images for the vision encoder.
            input_ids: text token ids, shape [batch, text_len].
            attention_mask: text attention mask aligned with ``input_ids``.
            labels: optional LM targets aligned with ``input_ids``; image positions
                are prepended as -100 so they are ignored by the loss.
            output_attentions: forwarded to the LLM. Defaults to True (the previous,
                always-on behaviour); pass False during training to save memory.

        Returns:
            The LLM's output object (includes ``loss`` when ``labels`` is given).
        """
        # 1. Visual features from the frozen encoder (no autograd graph needed).
        with torch.no_grad():
            image_features = self.vision_encoder(pixel_values).last_hidden_state  # [B, patches, vision_dim]

        # 2. Trainable projection into the LLM embedding space.
        image_embeddings = self.projection(image_features)  # [B, patches, llm_dim]

        # 3. Text token embeddings from the LLM's own embedding table.
        inputs_embeds = self.llm.get_input_embeddings()(input_ids)

        # 4. Concatenate [image, text] and extend mask/labels accordingly.
        combined_embeds, combined_mask, combined_labels = self._merge_modalities(
            image_embeddings, inputs_embeds, attention_mask, labels
        )

        # 5. Language model pass.
        return self.llm(
            inputs_embeds=combined_embeds,
            attention_mask=combined_mask,
            labels=combined_labels,
            output_attentions=output_attentions
        )

## 3. Optimization and Training: PEFT (LoRA)

We apply LoRA to the LLM decoder to enable parameter-efficient fine-tuning.

In [None]:
# Tokenizer for the LLM; EOS doubles as the padding token so samples can be padded to a fixed length.
tokenizer = AutoTokenizer.from_pretrained(LLM_CKPT)
tokenizer.pad_token = tokenizer.eos_token
# Image preprocessing matching the CLIP vision checkpoint.
image_processor = CLIPImageProcessor.from_pretrained(VISION_MODEL_CKPT)

# Build the multimodal model and move it to the selected device (nn.Module.to returns the module itself).
model = SimpleMLLM(VISION_MODEL_CKPT, LLM_CKPT).to(device)

# LoRA adapters (rank 8, alpha 32) on the LLM's query and value projections only.
peft_config = LoraConfig(
    r=8,
    lora_alpha=32,
    lora_dropout=0.1,
    target_modules=["q_proj", "v_proj"],
    task_type=TaskType.CAUSAL_LM,
    inference_mode=False,
)
model.llm = get_peft_model(model.llm, peft_config)
model.llm.print_trainable_parameters()

In [None]:
# Training setup: optimise only parameters left trainable (the projection layer and the LoRA adapters).
optimizer = torch.optim.AdamW((p for p in model.parameters() if p.requires_grad), lr=1e-4)
# Pass the configured MAX_LENGTH so the config cell actually controls the text length.
dataset = Flickr8kDataset(CAPTIONS_FILE, IMAGES_DIR, tokenizer, image_processor, max_length=MAX_LENGTH)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

# Mock training loop: iterates over epochs but does not yet run forward/backward passes.
model.train()
for epoch in range(EPOCHS):
    print(f"Starting Epoch {epoch+1}")
    # TODO: per batch -> move tensors to `device`, loss = model(**batch).loss,
    #       loss.backward(), optimizer.step(), optimizer.zero_grad()
    print("Epoch completed (Mock)")

## 4. Enforcing Grounding & Evaluation

We use the LLM's **self-attention maps** to inspect how text positions attend to the prepended image-patch tokens, i.e. which parts of the image the model focuses on when predicting the next word.

In [None]:
def analyze_attention(model, image_path, prompt, tokenizer, image_processor):
    model.eval()
    
    # Prepare inputs
    image = Image.open(image_path).convert("RGB")
    pixel_values = image_processor(image, return_tensors="pt").pixel_values.to(device)
    
    text_input = "User: <image>\n" + prompt + "\nAssistant: "
    inputs = tokenizer(text_input, return_tensors="pt").to(device)
    
    with torch.no_grad():
        outputs = model(pixel_values, inputs.input_ids, inputs.attention_mask)
    
    # Extract LAST layer attention
    last_layer_attentions = outputs.attentions[-1] # [Batch, Heads, Seq, Seq]
    avg_attention = last_layer_attentions.mean(dim=1).squeeze(0) # [Seq, Seq]
    
    print("Attention analysis complete.")


## 5. Evaluation Metrics (BLEU, METEOR, ROUGE-L, CIDEr, SPICE)

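We use `pycocoevalcap` to compute standard captioning metrics. It takes two dictionaries keyed by the same image ids: reference captions (a list of strings per image) and generated hypotheses (a list containing exactly one caption per image). METEOR and SPICE additionally require a Java runtime.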

In [None]:
from pycocoevalcap.bleu.bleu import Bleu
from pycocoevalcap.meteor.meteor import Meteor
from pycocoevalcap.rouge.rouge import Rouge
from pycocoevalcap.cider.cider import Cider
from pycocoevalcap.spice.spice import Spice

def score_captions(ref, hypo):
    """Compute COCO captioning metrics for hypotheses against references.

    Args:
        ref: dict mapping image id -> list of reference caption strings.
        hypo: dict mapping image id -> list of hypothesis caption strings.

    Returns:
        dict of corpus-level scores keyed by metric name, in the order
        Bleu_1, Bleu_2, Bleu_3, Bleu_4, METEOR, ROUGE_L, CIDEr, SPICE.
    """
    # Every scorer is instantiated before any scoring starts.
    scorer_specs = [
        (Bleu(4), ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4"]),
        (Meteor(), "METEOR"),
        (Rouge(), "ROUGE_L"),
        (Cider(), "CIDEr"),
        (Spice(), "SPICE")
    ]

    results = {}
    for metric_scorer, metric_names in scorer_specs:
        # compute_score yields (corpus_score, per_image_scores); only the first is kept.
        corpus_score, _per_image = metric_scorer.compute_score(ref, hypo)
        if not isinstance(metric_names, list):
            results[metric_names] = corpus_score
            continue
        # BLEU reports one score per n-gram order.
        results.update(zip(metric_names, corpus_score))

    return results

def evaluate_metrics_demo():
    """Score a tiny hand-written example with ``score_captions`` and print each metric.

    Returns:
        The metric-name -> score dict on success, or None if scoring failed
        (e.g. pycocoevalcap or Java missing); failures are printed, not raised.
    """
    print("Running Evaluation Metrics on Mock Data...")

    # Mock data in pycocoevalcap's format: both dicts must share the same image-id
    # keys; references hold several captions per image, hypotheses exactly one.
    references = {
        0: ["a photo of a dog running", "the dog runs fast", "a puppy running"],
        1: ["a cat sitting on a sofa", "a kitten resting", "cat on couch"]
    }

    hypotheses = {
        0: ["a dog runs"],
        1: ["a cat plays"]
    }

    try:
        scores = score_captions(references, hypotheses)
    except Exception as e:  # notebook demo boundary: report instead of aborting the run
        print(f"Evaluation failed (likely missing dependencies or java): {e}")
        print("Please ensure pycocoevalcap and Java are installed.")
        return None

    for metric, score in scores.items():
        print(f"{metric}: {score:.4f}")
    return scores


# Keep the result in a variable so the cell does not echo the dict as its output.
metric_scores = evaluate_metrics_demo()