> **Note:** This notebook documents the development and experimentation process for the vision-language model (VLM). The final, refactored training and inference code lives in `train.py` and `inference.py`.

### Step 0: Setup

In [None]:
# Cell 1: Imports and Device Setup
import csv
import json
import math
import os
import random
import re

import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
from tqdm.notebook import tqdm

# Prefer the GPU when one is visible to PyTorch; every later cell moves tensors to `device`.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

### Step 1: New SPIN Attention and Decoder Architecture

In [None]:
# Cell 2: The Transformer Decoder Architecture with SPIN

# NEW: SPINMultiHeadAttention inspired by the paper
class SPINMultiHeadAttention(nn.Module):
    """Multi-head attention with optional SPIN-style head suppression.

    In eval mode, and only for cross-attention calls, each query position keeps
    the outputs of its ``top_k_heads`` heads that attend most to the image
    tokens. The outputs of all other heads are zeroed for that position.
    Training mode and self-attention behave like plain multi-head attention.

    Args:
        d_model: model width; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        top_k_heads: heads kept per query position when SPIN is active
            (clamped to ``num_heads``).
        num_image_tokens: number of leading memory positions treated as image
            tokens. The default 197 is this file's assumption for ViT-Tiny
            (196 patch tokens + 1 CLS token); a more robust implementation
            would track token types explicitly.
    """

    def __init__(self, d_model, num_heads, top_k_heads=4, num_image_tokens=197):
        super().__init__()
        if d_model % num_heads != 0:
            raise ValueError(f"d_model ({d_model}) must be divisible by num_heads ({num_heads})")
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.top_k_heads = top_k_heads  # Number of image-attentive heads to keep
        self.num_image_tokens = num_image_tokens

        self.W_Q = nn.Linear(d_model, d_model)
        self.W_K = nn.Linear(d_model, d_model)
        self.W_V = nn.Linear(d_model, d_model)
        self.W_O = nn.Linear(d_model, d_model)

    def forward(self, Q, K, V, mask=None, is_cross_attention=False):
        """Attend from ``Q`` to ``K``/``V``.

        Args:
            Q: (batch, target_len, d_model) queries.
            K, V: (batch, source_len, d_model) keys/values.
            mask: optional bool tensor broadcastable to
                (batch, heads, target_len, source_len); True marks blocked positions.
            is_cross_attention: enables SPIN head suppression when the module is
                in eval mode.

        Returns:
            (batch, target_len, d_model) attention output.
        """
        B = Q.size(0)

        Q_proj = self.W_Q(Q).view(B, -1, self.num_heads, self.d_k).transpose(1, 2)
        K_proj = self.W_K(K).view(B, -1, self.num_heads, self.d_k).transpose(1, 2)
        V_proj = self.W_V(V).view(B, -1, self.num_heads, self.d_k).transpose(1, 2)

        scores = torch.matmul(Q_proj, K_proj.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            scores = scores.masked_fill(mask, torch.finfo(scores.dtype).min)
        attn = F.softmax(scores, dim=-1)

        head_out = torch.matmul(attn, V_proj)  # (B, heads, target_len, d_k)

        if is_cross_attention and not self.training:
            # Per-head output scaling is equivalent to scaling that head's V, but it
            # allows a different head selection for each query position.
            head_out = head_out * self._spin_head_mask(attn).to(head_out.dtype)

        out = head_out.transpose(1, 2).contiguous().view(B, -1, self.d_model)
        return self.W_O(out)

    def _spin_head_mask(self, attn):
        """Return a (B, heads, target_len, 1) 0/1 mask selecting, per query, the top-k image-attentive heads."""
        k = min(self.top_k_heads, self.num_heads)
        # Mean attention each head pays to the image tokens, per query position.
        image_attention = attn[..., : self.num_image_tokens].mean(dim=-1)  # (B, heads, target_len)
        top_heads = torch.topk(image_attention, k, dim=1).indices  # (B, k, target_len)
        keep = torch.zeros_like(image_attention).scatter_(1, top_heads, 1.0)
        return keep.unsqueeze(-1)

class SinusoidalPositionalEncoding(nn.Module):
    """Add fixed sine/cosine positional encodings to a (batch, seq, d_model) tensor.

    Even feature columns hold ``sin(pos * w_i)`` and odd columns hold
    ``cos(pos * w_i)``, with ``w_i = 10000^(-2i / d_model)``. Odd ``d_model`` is
    supported; the last column is then a sine column.

    Args:
        d_model: feature width.
        max_len: longest sequence the precomputed table supports.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer odd column than even column.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # Buffer: moves with .to(device) and is saved in state_dict, but is not trained.
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``x.size(1)`` positions.

        Raises:
            ValueError: if the sequence is longer than ``max_len``.
        """
        seq_len = x.size(1)
        if seq_len > self.pe.size(1):
            raise ValueError(f"sequence length {seq_len} exceeds max_len {self.pe.size(1)}")
        return x + self.pe[:, :seq_len]

class TransformerDecoderLayer(nn.Module):
    """One post-norm decoder layer: masked self-attention, cross-attention, feed-forward.

    Each sub-layer output goes through dropout, is added to its input, and is
    then layer-normalised. The cross-attention call is flagged with
    ``is_cross_attention=True`` so SPIN head suppression can apply in eval mode.
    """

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        # Registration order is kept identical so parameter init order is unchanged.
        self.self_attn = SPINMultiHeadAttention(d_model, num_heads)
        self.cross_attn = SPINMultiHeadAttention(d_model, num_heads)
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model),
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, memory, tgt_mask):
        """Run the three residual sub-layers over ``x`` using ``memory`` as cross-attention context."""
        attended = self.self_attn(x, x, x, mask=tgt_mask)
        x = self.norm1(x + self.dropout(attended))

        attended = self.cross_attn(Q=x, K=memory, V=memory, is_cross_attention=True)
        x = self.norm2(x + self.dropout(attended))

        return self.norm3(x + self.dropout(self.ff(x)))

class TransformerDecoder(nn.Module):
    """Token embedding + sinusoidal positions, a stack of decoder layers, and a vocab projection.

    ``forward`` returns unnormalised logits of shape (batch, tgt_len, vocab_size).
    """

    def __init__(self, vocab_size, d_model, num_heads, d_ff, num_layers, dropout=0.1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_enc = SinusoidalPositionalEncoding(d_model)
        stack = []
        for _ in range(num_layers):
            stack.append(TransformerDecoderLayer(d_model, num_heads, d_ff, dropout))
        self.layers = nn.ModuleList(stack)
        self.output_linear = nn.Linear(d_model, vocab_size)

    def forward(self, tgt, memory, tgt_mask):
        """Decode token ids ``tgt`` against ``memory`` under the causal ``tgt_mask``."""
        hidden = self.pos_enc(self.embedding(tgt))
        for decoder_layer in self.layers:
            hidden = decoder_layer(hidden, memory, tgt_mask)
        return self.output_linear(hidden)

### Step 2: VLM Architecture and Generation Logic

In [None]:
# Cell 3: VLM Architecture (Revised for VQA)

class VisionEncoder(nn.Module):
    """Thin wrapper around a timm vision backbone that returns token features.

    The backbone is created with ``num_classes=0`` (no classification head), and
    ``forward`` returns ``forward_features`` rather than pooled logits.
    NOTE(review): for ViT models this output is assumed to be
    (batch, num_tokens, embed_dim); the SPIN attention assumes 197 tokens for
    ViT-Tiny — confirm if the backbone changes.

    Args:
        model_name: timm model identifier.
        pretrained: whether timm should load pretrained weights.
    """

    def __init__(self, model_name='vit_tiny_patch16_224', pretrained=True):
        super().__init__()
        backbone = timm.create_model(model_name, pretrained=pretrained, num_classes=0)
        self.model = backbone
        # Per-token feature width; VLM sizes its vision projection from this.
        self.embed_dim = backbone.embed_dim

    def forward(self, x):
        features = self.model.forward_features(x)
        return features

def sample_next_token(logits, temperature=1.0, top_k=50, top_p=0.9, repetition_penalty=1.1, generated_ids=None):
    """Sample one token id from a 1-D logits vector.

    The steps run in order: temperature scaling, repetition penalty, top-k
    filtering, nucleus (top-p) filtering, then multinomial sampling.

    Args:
        logits: 1-D tensor of unnormalised scores over the vocabulary. It is not
            modified; all work happens on a scaled copy.
        temperature: divisor for the logits (clamped below at 1e-8, which
            approaches greedy decoding).
        top_k: keep only the ``top_k`` highest logits; ``None`` or <= 0 disables it.
        top_p: keep the smallest prefix of probability-sorted tokens whose mass
            reaches ``top_p``; active only when ``0 < top_p < 1``.
        repetition_penalty: for every id in ``generated_ids``, positive logits
            are divided by this value and non-positive logits are multiplied by it.
        generated_ids: previously generated token ids (list); ids not below the
            vocabulary size are ignored.

    Returns:
        The sampled token id as a Python int.
    """
    logits = logits / max(1e-8, temperature)  # new tensor: the caller's logits stay untouched

    if repetition_penalty != 1.0 and generated_ids:
        vocab_size = logits.size(-1)
        for token_id in set(generated_ids):
            if token_id < vocab_size:
                value = logits[token_id]
                logits[token_id] = value / repetition_penalty if value > 0 else value * repetition_penalty

    if top_k is not None and top_k > 0:
        topk_idx = torch.topk(logits, min(top_k, logits.size(-1))).indices
        filtered = torch.full_like(logits, float('-inf'))
        filtered.scatter_(-1, topk_idx, logits.gather(-1, topk_idx))
        logits = filtered

    if top_p is not None and 0.0 < top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        sorted_probs = F.softmax(sorted_logits, dim=-1)
        # Probability mass of all better-ranked tokens. A token is dropped only when
        # the tokens before it already exceed top_p, so the token that crosses the
        # threshold is kept. The top-ranked token always has 0 here and survives.
        mass_before = torch.cumsum(sorted_probs, dim=-1) - sorted_probs
        sorted_logits = sorted_logits.masked_fill(mass_before > top_p, float('-inf'))
        logits = logits.scatter(-1, sorted_indices, sorted_logits)

    probs = F.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1).item()

class VLM(nn.Module):
    """Vision-language model.

    Projected vision tokens and question embeddings are concatenated into the
    decoder's cross-attention memory. The decoder generates the answer starting
    from ``<start>``.

    Args:
        vision_encoder: module exposing ``embed_dim`` and returning token features.
        text_decoder: decoder exposing ``embedding`` and
            ``forward(tgt, memory, tgt_mask)``.
        text_dim: decoder width; vision features are projected to this size.
    """

    def __init__(self, vision_encoder, text_decoder, text_dim):
        super().__init__()
        self.vision_encoder = vision_encoder
        self.text_decoder = text_decoder
        self.vision_projection = nn.Linear(vision_encoder.embed_dim, text_dim)
        # Same nn.Embedding object as the decoder's, so question tokens in the memory
        # and answer tokens in the decoder share one embedding table.
        self.text_embedding = text_decoder.embedding

    def _build_memory(self, image, question_ids):
        """Concatenate projected vision tokens and question embeddings along the sequence axis."""
        vision_memory = self.vision_projection(self.vision_encoder(image))
        question_embed = self.text_embedding(question_ids)
        return torch.cat([vision_memory, question_embed], dim=1)

    def forward(self, image, question_ids, answer_ids, tgt_mask):
        """Teacher-forced decoding of ``answer_ids``; returns (batch, answer_len, vocab) logits."""
        memory = self._build_memory(image, question_ids)
        return self.text_decoder(answer_ids, memory, tgt_mask)

    def generate(self, image, tokenizer, text_prompt, max_len=50,
                 temperature=1.0, top_k=50, top_p=0.9, repetition_penalty=1.2):
        """Generate an answer for a single image and prompt.

        Args:
            image: one preprocessed image tensor without a batch dimension.
            tokenizer: tokenizer with ``vocab`` (containing ``'<start>'`` and
                ``'<end>'``), ``tokenize`` and ``ids_to_sentence``.
            text_prompt: user question, inserted into the VQA template.
            max_len: maximum number of new tokens to generate.
            temperature, top_k, top_p, repetition_penalty: passed to ``sample_next_token``.

        Returns:
            The decoded answer text (tokens after ``<start>``, stopping at ``<end>``).

        The train/eval flag of every submodule is restored afterwards.
        """
        start_token_id = tokenizer.vocab['<start>']
        end_token_id = tokenizer.vocab['<end>']
        model_device = next(self.parameters()).device

        saved_modes = [(module, module.training) for module in self.modules()]
        self.eval()

        # Format the prompt using the VQA template
        full_prompt = f"USER: <image>\n{text_prompt}\nASSISTANT:"
        prompt_ids = torch.tensor([tokenizer.tokenize(full_prompt)], device=model_device, dtype=torch.long)
        # Training feeds the decoder "<start> + answer" and keeps the prompt only in
        # the memory, so decoding must also begin from <start>, not from the prompt.
        current_ids = torch.tensor([[start_token_id]], device=model_device, dtype=torch.long)
        # NOTE(review): during training the question in memory is padded to MAX_LEN with
        # '<pad>' and cross-attention has no padding mask; here the prompt is unpadded.

        try:
            with torch.no_grad():
                memory = self._build_memory(image.unsqueeze(0).to(model_device), prompt_ids)

                for _ in range(max_len):
                    sz = current_ids.size(1)
                    tgt_mask = torch.triu(torch.ones(sz, sz, device=model_device), diagonal=1).bool()
                    next_logits = self.text_decoder(current_ids, memory, tgt_mask)[0, -1, :]

                    next_tok = sample_next_token(next_logits, temperature=temperature,
                                                 top_k=top_k, top_p=top_p,
                                                 repetition_penalty=repetition_penalty,
                                                 generated_ids=current_ids[0, 1:].tolist())
                    if next_tok == end_token_id:
                        break
                    next_col = torch.tensor([[next_tok]], device=model_device, dtype=torch.long)
                    current_ids = torch.cat([current_ids, next_col], dim=1)
        finally:
            for module, was_training in saved_modes:
                module.training = was_training

        # Decode only the generated tokens, excluding the leading <start>.
        return tokenizer.ids_to_sentence(current_ids[0, 1:].tolist())

### Step 3: Synthetic VQA Dataset (Flickr8k Captions as Answers) and Tokenizer

In [None]:
# Cell 4: NEW VQA-Style Dataset and Tokenizer

class Flickr8kVQADataset(Dataset):
    """Flickr8k captions recast as single-turn VQA samples.

    Each caption becomes the answer to a question drawn at random, on every
    access, from a fixed list of description prompts.

    Args:
        image_dir: directory containing the Flickr8k image files.
        captions_file: ``captions.txt`` with a header row followed by
            ``image,caption`` CSV rows.
        transform: callable applied to the RGB ``PIL.Image``.
        tokenizer: object with ``tokenize(text) -> list[int]`` and a ``vocab``
            dict containing ``'<start>'``, ``'<end>'`` and ``'<pad>'``.
        max_len: length every id sequence is padded or truncated to.
    """

    def __init__(self, image_dir, captions_file, transform, tokenizer, max_len):
        self.image_dir = image_dir
        self.transform = transform
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.data = self._load_data(captions_file)
        self.questions = [
            "Describe the image in detail.",
            "What is happening in this picture?",
            "Provide a detailed description of the image.",
            "What is in the image?",
            "Can you describe this image for me?"
        ]

    def _load_data(self, file_path):
        """Parse the captions file into ``[{'image': ..., 'caption': ...}]``.

        The ``csv`` module removes the quotes around captions that contain commas.
        Blank or single-field rows are skipped, and captions are lower-cased.
        """
        data = []
        with open(file_path, 'r', encoding='utf-8', newline='') as f:
            reader = csv.reader(f)
            next(reader, None)  # skip the header row
            for row in reader:
                if len(row) < 2 or not row[0].strip():
                    continue
                image_file = row[0].strip()
                # Re-join in case an unquoted caption still contained commas.
                caption = ','.join(row[1:]).strip()
                data.append({'image': image_file, 'caption': caption.lower()})
        return data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, question_ids, answer_input_ids, answer_target_ids)``; id tensors have length ``max_len``."""
        item = self.data[idx]
        image_path = os.path.join(self.image_dir, item['image'])
        # The context manager closes the file handle that Image.open keeps open lazily.
        with Image.open(image_path) as img:
            rgb = img.convert("RGB")
        image = self.transform(rgb)

        # --- VQA Formatting ---
        # The question is re-drawn on every access, so validation samples are not fixed.
        question = random.choice(self.questions)
        answer = item['caption']
        prompt = f"USER: <image>\n{question}\nASSISTANT:"

        prompt_tokens = self.tokenizer.tokenize(prompt)
        answer_tokens = self.tokenizer.tokenize(answer)

        vocab = self.tokenizer.vocab
        # Teacher forcing: the decoder reads "<start> + answer" and predicts "answer + <end>".
        answer_input_ids = [vocab['<start>']] + answer_tokens
        answer_target_ids = answer_tokens + [vocab['<end>']]

        pad_id = vocab['<pad>']

        def pad(ids):
            # Right-pad with <pad>, then truncate to max_len (long answers lose <end>).
            return (ids + [pad_id] * self.max_len)[:self.max_len]

        return image, torch.tensor(pad(prompt_tokens)), torch.tensor(pad(answer_input_ids)), torch.tensor(pad(answer_target_ids))

class SimpleTokenizer:
    """Word-level tokenizer backed by a JSON ``{token: id}`` vocabulary file.

    Missing special tokens are added on load. ``<image>``, ``<start>`` and
    ``<end>`` get fresh ids. ``<pad>`` gets id 0 when that id is unused and a
    fresh id otherwise, so it never collides with an existing token.
    """

    # '<image>' is one token; otherwise whole words (\w+) or single non-space characters.
    _TOKEN_RE = re.compile(r'<image>|\b\w+\b|\S')

    def __init__(self, vocab_path):
        with open(vocab_path, 'r', encoding='utf-8') as f:
            self.vocab = json.load(f)
        for special in ('<image>', '<start>', '<end>'):
            if special not in self.vocab:
                self.vocab[special] = self._next_free_id()
        if '<pad>' not in self.vocab:
            self.vocab['<pad>'] = 0 if 0 not in set(self.vocab.values()) else self._next_free_id()
        self.inv_vocab = {v: k for k, v in self.vocab.items()}

    def _next_free_id(self):
        """Return an id larger than every id currently in the vocabulary."""
        return max(self.vocab.values(), default=-1) + 1

    def tokenize(self, text):
        """Lower-case ``text`` and map each token to its id.

        Unknown tokens map to ``'<unk>'`` if the vocabulary has it. Otherwise they
        fall back to id 1 (kept for compatibility).
        NOTE(review): without '<unk>', id 1 may be an ordinary word.
        """
        # Pad '<image>' with spaces so it is matched as its own token.
        text = text.replace('<image>', ' <image> ')
        unk_id = self.vocab.get('<unk>', 1)
        return [self.vocab.get(t, unk_id) for t in self._TOKEN_RE.findall(text.lower())]

    def ids_to_sentence(self, ids):
        """Join the words for ``ids`` with spaces.

        ``<start>``, ``<pad>`` and ``<image>`` are skipped, and decoding stops at ``<end>``.
        """
        skipped = {'<start>', '<pad>', '<image>'}
        words = []
        for i in ids:
            word = self.inv_vocab.get(i, '<unk>')
            if word == '<end>':
                break
            if word not in skipped:
                words.append(word)
        return ' '.join(words)

### Step 4: Main Setup and Two-Stage Training

In [None]:
# Cell 5: Main Setup for Two-Stage Training

# --- Hyperparameters ---
D_MODEL = 256  # Decoder width; ViT features are projected to this size (text_dim=D_MODEL below).
NUM_HEADS = 8  # Must divide D_MODEL (each head has D_MODEL // NUM_HEADS dims).
D_FF = 1024  # Hidden width of each decoder layer's feed-forward block.
NUM_DECODER_LAYERS = 4
BATCH_SIZE = 16
MAX_LEN = 60 # Pad/truncate length for BOTH the prompt ids and the answer ids in Flickr8kVQADataset.

# --- Two-Stage Training Hyperparameters ---
# Stage 1 (warm-up): vision encoder frozen; only the decoder and vision projection train.
WARMUP_EPOCHS = 10 
# Stage 2 (fine-tune): all parameters train, at a much smaller learning rate.
FINETUNE_EPOCHS = 5
WARMUP_LR = 1e-4
FINETUNE_LR = 1e-6

# --- Robust Training Hyperparameters ---
# Batches per optimizer step; effective batch size = BATCH_SIZE * ACCUMULATION_STEPS.
ACCUMULATION_STEPS = 4
# Epochs without a validation-loss improvement before a stage stops early.
PATIENCE = 3

# --- Paths ---
IMAGE_DIR = './flickr8k/Images'
CAPTIONS_FILE = './flickr8k/captions.txt'
VOCAB_PATH = 'finetuned_vocab.json'
# state_dict checkpoint, rewritten whenever the validation loss improves.
BEST_MODEL_PATH = 'vlm_spin_best_model.pth'

# --- Tokenizer, Transforms ---
tokenizer = SimpleTokenizer(VOCAB_PATH)
# Computed after SimpleTokenizer has added any missing special tokens.
VOCAB_SIZE = len(tokenizer.vocab)
# 224x224 input (matching the 224 in the ViT model name) with the standard ImageNet mean/std.
image_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# --- Dataset and DataLoader with VQA-style data ---
full_dataset = Flickr8kVQADataset(IMAGE_DIR, CAPTIONS_FILE, image_transforms, tokenizer, MAX_LEN)

# Split by IMAGE, not by caption row. Flickr8k has several captions per image, so
# a row-level random_split would put the same photo in both train and val and make
# the validation loss optimistic. Roughly 90% of images (and their captions) go to train.
unique_images = sorted({record['image'] for record in full_dataset.data})
image_order = torch.randperm(len(unique_images)).tolist()
num_train_images = int(0.9 * len(unique_images))
train_images = {unique_images[j] for j in image_order[:num_train_images]}
train_indices = [i for i, record in enumerate(full_dataset.data) if record['image'] in train_images]
val_indices = [i for i, record in enumerate(full_dataset.data) if record['image'] not in train_images]
# Subset exposes `.indices` (mapping back into full_dataset.data), like random_split's output.
train_dataset = torch.utils.data.Subset(full_dataset, train_indices)
val_dataset = torch.utils.data.Subset(full_dataset, val_indices)
train_size, val_size = len(train_dataset), len(val_dataset)

def collate_fn(batch):
    """Stack a list of (image, question_ids, answer_input_ids, answer_target_ids) samples field by field.

    Returns a 4-tuple of tensors, each with a new leading batch dimension.
    """
    image_col, question_col, answer_in_col, answer_out_col = zip(*batch)
    stacked = [torch.stack(column, dim=0) for column in (image_col, question_col, answer_in_col, answer_out_col)]
    return tuple(stacked)

# Pinned host memory only speeds up host-to-GPU copies, so enable it only on CUDA.
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn, num_workers=2, pin_memory=device.type == 'cuda')
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn, num_workers=2, pin_memory=device.type == 'cuda')
print(f"Dataset split: {len(train_dataset)} training samples, {len(val_dataset)} validation samples.")

# --- Model Instantiation ---
vision_encoder = VisionEncoder(model_name='vit_tiny_patch16_224')
text_decoder = TransformerDecoder(VOCAB_SIZE, D_MODEL, NUM_HEADS, D_FF, NUM_DECODER_LAYERS)
vlm = VLM(vision_encoder, text_decoder, text_dim=D_MODEL).to(device)

# --- Loss Function ---
# Padded target positions are ignored; label smoothing 0.1 softens the one-hot targets.
loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer.vocab['<pad>'], label_smoothing=0.1)

# --- Optimizers and Schedulers ---
# Stage 1 optimizes only the decoder (including the shared text embedding) and the
# vision projection; the vision encoder is frozen during warm-up.
warmup_params = list(vlm.text_decoder.parameters()) + list(vlm.vision_projection.parameters())
optimizer_warmup = torch.optim.AdamW(warmup_params, lr=WARMUP_LR)
# Halve the LR after one epoch without validation-loss improvement.
scheduler_warmup = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_warmup, 'min', patience=1, factor=0.5)

# Stage 2 optimizes every parameter.
optimizer_finetune = torch.optim.AdamW(vlm.parameters(), lr=FINETUNE_LR)
scheduler_finetune = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_finetune, 'min', patience=1, factor=0.5)

In [None]:
# Cell 6: Two-Stage Training Loop

from torch.cuda.amp import GradScaler, autocast

def create_causal_mask(sz, device):
    """Return an (sz, sz) bool mask that is True where column > row (future positions to block)."""
    positions = torch.arange(sz, device=device)
    return positions.unsqueeze(0) > positions.unsqueeze(1)

# torch.cuda.amp only does anything on CUDA; enable it explicitly there and nowhere else.
amp_enabled = device.type == 'cuda'
scaler = GradScaler(enabled=amp_enabled)
best_val_loss = float('inf')
epochs_no_improve = 0


def _train_one_epoch(optimizer, clip_params, desc):
    """Run one epoch of gradient-accumulated mixed-precision training over train_dataloader.

    Args:
        optimizer: optimizer owning this stage's trainable parameters.
        clip_params: list of parameters whose gradient norm is clipped to 1.0.
        desc: tqdm progress-bar label.

    Returns:
        Mean per-batch training loss (undoing the accumulation scaling).
    """
    total_loss = 0.0
    num_batches = len(train_dataloader)
    pbar = tqdm(train_dataloader, desc=desc)
    optimizer.zero_grad()

    for i, (images, q_ids, a_in_ids, a_out_ids) in enumerate(pbar):
        images, q_ids, a_in_ids, a_out_ids = images.to(device), q_ids.to(device), a_in_ids.to(device), a_out_ids.to(device)
        tgt_mask = create_causal_mask(a_in_ids.size(1), device)

        with autocast(enabled=amp_enabled):
            logits = vlm(images, q_ids, a_in_ids, tgt_mask)
            loss = loss_fn(logits.view(-1, VOCAB_SIZE), a_out_ids.view(-1))
            loss = loss / ACCUMULATION_STEPS

        scaler.scale(loss).backward()
        total_loss += loss.item() * ACCUMULATION_STEPS

        # Also step on the final batch, so a trailing partial accumulation group is
        # applied instead of being wiped by the next epoch's zero_grad().
        if (i + 1) % ACCUMULATION_STEPS == 0 or (i + 1) == num_batches:
            scaler.unscale_(optimizer)  # clip on true gradient magnitudes
            torch.nn.utils.clip_grad_norm_(clip_params, max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
        pbar.set_postfix({"loss": f"{total_loss / (i + 1):.4f}"})

    return total_loss / num_batches


def _validate(desc):
    """Return the mean validation loss with the whole model in eval mode (so SPIN is active)."""
    vlm.eval()
    total_loss = 0.0
    with torch.no_grad():
        for images, q_ids, a_in_ids, a_out_ids in tqdm(val_dataloader, desc=desc):
            images, q_ids, a_in_ids, a_out_ids = images.to(device), q_ids.to(device), a_in_ids.to(device), a_out_ids.to(device)
            tgt_mask = create_causal_mask(a_in_ids.size(1), device)
            logits = vlm(images, q_ids, a_in_ids, tgt_mask)
            loss = loss_fn(logits.view(-1, VOCAB_SIZE), a_out_ids.view(-1))
            total_loss += loss.item()
    return total_loss / len(val_dataloader)


def _step_scheduler(optimizer, scheduler, val_loss, label):
    """Step a ReduceLROnPlateau scheduler and report when it lowers the learning rate."""
    old_lr = optimizer.param_groups[0]['lr']
    scheduler.step(val_loss)
    new_lr = optimizer.param_groups[0]['lr']
    if new_lr < old_lr:
        print(f"{label} learning rate reduced to {new_lr}")


def _checkpoint_if_best(val_loss, best_loss, stale_epochs):
    """Save vlm if val_loss improves on best_loss; return the updated (best_loss, stale_epochs)."""
    if val_loss < best_loss:
        torch.save(vlm.state_dict(), BEST_MODEL_PATH)
        print(f"New best model saved to {BEST_MODEL_PATH}")
        return val_loss, 0
    return best_loss, stale_epochs + 1


print("\n--- STARTING STAGE 1: DECODER WARM-UP ---")

# The vision encoder stays frozen for the whole warm-up stage.
for param in vlm.vision_encoder.parameters():
    param.requires_grad = False

for epoch in range(WARMUP_EPOCHS):
    print(f"\n--- Warm-up Epoch {epoch + 1}/{WARMUP_EPOCHS} ---")
    vlm.vision_encoder.eval()
    vlm.text_decoder.train()
    vlm.vision_projection.train()

    avg_train_loss = _train_one_epoch(optimizer_warmup, warmup_params, f"Training Warm-up {epoch+1}")
    print(f"Average Warm-up Training Loss: {avg_train_loss:.4f}")

    avg_val_loss = _validate("Validating Warm-up")
    print(f"Average Warm-up Validation Loss: {avg_val_loss:.4f}")

    _step_scheduler(optimizer_warmup, scheduler_warmup, avg_val_loss, "Warm-up")
    best_val_loss, epochs_no_improve = _checkpoint_if_best(avg_val_loss, best_val_loss, epochs_no_improve)

    if epochs_no_improve >= PATIENCE:
        print("Early stopping triggered for warm-up.")
        break

print("\n--- LOADING BEST MODEL FOR FINE-TUNING ---")
vlm.load_state_dict(torch.load(BEST_MODEL_PATH, map_location=device))

print("\n--- STARTING STAGE 2: FULL FINE-TUNING ---")
# best_val_loss is deliberately carried over from stage 1 (same validation metric),
# so the stage-1 checkpoint is only overwritten by a genuinely better model.
epochs_no_improve = 0

for param in vlm.vision_encoder.parameters():
    param.requires_grad = True
finetune_params = list(vlm.parameters())

for epoch in range(FINETUNE_EPOCHS):
    print(f"\n--- Fine-tuning Epoch {epoch + 1}/{FINETUNE_EPOCHS} ---")
    vlm.train()  # _validate() switches to eval, so re-enable training each epoch

    avg_train_loss = _train_one_epoch(optimizer_finetune, finetune_params, f"Training Fine-tune {epoch+1}")
    print(f"Average Fine-tuning Training Loss: {avg_train_loss:.4f}")

    avg_val_loss = _validate("Validating Fine-tune")
    print(f"Average Fine-tuning Validation Loss: {avg_val_loss:.4f}")

    _step_scheduler(optimizer_finetune, scheduler_finetune, avg_val_loss, "Fine-tuning")
    best_val_loss, epochs_no_improve = _checkpoint_if_best(avg_val_loss, best_val_loss, epochs_no_improve)

    if epochs_no_improve >= PATIENCE:
        print("Early stopping triggered for fine-tuning.")
        break

print("\nFull training finished.")

In [None]:
# Cell 7: Final Corrected Inference Cell

import matplotlib.pyplot as plt
from IPython.display import display
import traceback
import inspect

# Load the best model weights for inference
if os.path.exists(BEST_MODEL_PATH):
    print(f"Loading best model from {BEST_MODEL_PATH} for inference.")
    # map_location lets a checkpoint saved from a GPU run load on the current device.
    vlm.load_state_dict(torch.load(BEST_MODEL_PATH, map_location=device))
else:
    print("No trained model found. Using initial model for inference.")

vlm.eval()
sample_image = None
raw_image = None

# Pick a random validation sample and load its image from disk.
if len(val_dataset) > 0:
    try:
        random_idx = random.randint(0, len(val_dataset) - 1)
        # val_dataset is a Subset; .indices maps back to rows of full_dataset.data.
        original_idx = val_dataset.indices[random_idx]
        image_path = os.path.join(IMAGE_DIR, full_dataset.data[original_idx]['image'])

        if os.path.exists(image_path):
            # The context manager closes the file handle that Image.open keeps open.
            with Image.open(image_path) as img:
                raw_image = img.convert("RGB")
            sample_image = image_transforms(raw_image).to(device)
    except Exception as e:
        # Notebook boundary: report the problem and continue without a sample image.
        print(f"An error occurred while loading the image: {e}")

if raw_image is not None:
    print("--- Displaying Random Image for Inference ---")
    plt.imshow(raw_image)
    plt.axis('off')
    plt.show()
else:
    print("Could not load a sample image to display.")

# --- 1. Image + Text (VQA) ---
print("\n--- Testing Multimodal VQA ---")
if sample_image is not None:
    available_params = list(inspect.signature(vlm.generate).parameters.keys())

    generation_kwargs = {
        'top_p': 0.9,
        'repetition_penalty': 1.2
    }

    # Probe generate()'s signature for a length-limit argument under any of these names.
    possible_names = ['max_length', 'max_tokens', 'max_new_tokens', 'max_len']
    length_param_found = next((name for name in possible_names if name in available_params), None)

    if length_param_found:
        generation_kwargs[length_param_found] = 128
        print(f"✅ Found valid length parameter: '{length_param_found}'. Using it for generation.")
    else:
        print(f"⚠️ WARNING: Could not find a known length parameter in {available_params}.")

    prompts = ["what is in the image?", "Describe the scene."]

    for question in prompts:
        print("-----------------------------------------")
        try:
            vqa_answer = vlm.generate(
                sample_image,
                tokenizer,
                text_prompt=question,
                **generation_kwargs
            )

            print(f"Image + Prompt: '{question}'")
            print(f"--> Generated Output: {vqa_answer}")

        except Exception:
            # Keep going with the next prompt, but show the full traceback.
            print(f"!!! AN ERROR OCCURRED DURING GENERATION FOR PROMPT: '{question}' !!!")
            traceback.print_exc()
else:
    print("Sample image not found. Skipping multimodal test.")

In [None]:
# Search the tokenizer's vocabulary for its end-of-sequence (EOS) token.
# SimpleTokenizer registers '<end>' as its end token (the id VLM.generate stops on),
# so it is checked alongside the names used by other common tokenizers.
EOS_CANDIDATES = ('<end>', '<eos>', '<|endoftext|>', '</s>')

try:
    vocab = getattr(tokenizer, 'vocab', None)
    # Check that the vocab attribute exists and is a dictionary
    if isinstance(vocab, dict):
        print("Searching the tokenizer.vocab dictionary...")

        found_eos = False
        # Iterate through all token-ID pairs; substring match, as before.
        for token_str, token_id in vocab.items():
            if any(candidate in token_str for candidate in EOS_CANDIDATES):
                print(f"✅ Found it! The EOS token is '{token_str}' with ID: {token_id}")
                found_eos = True
                break  # Stop after finding the first one

        if not found_eos:
            print("Could not find a specific EOS token, but here are other special tokens:")
            # Fallback: list every <...>-style special token.
            for token_str, token_id in vocab.items():
                if token_str.startswith('<') and token_str.endswith('>'):
                    print(f"--> Found special token: '{token_str}' with ID: {token_id}")

    else:
        print("Error: 'tokenizer.vocab' is not a dictionary or does not exist.")

except Exception as e:
    print(f"An error occurred: {e}")