# VLM Notebook 1: Using Pre-trained Vision-Language Models

## 1. Environment Setup

First, let's install the required libraries and verify GPU availability.

In [None]:
# Install required packages
!pip install -q transformers accelerate pillow requests torch torchvision
!pip install -q bitsandbytes  # For efficient loading

print("✓ All packages installed successfully!")

In [None]:
# Import libraries
import torch
import requests
from PIL import Image
from io import BytesIO
import matplotlib.pyplot as plt
import numpy as np
from transformers import (
    AutoProcessor,
    AutoModelForCausalLM,
    Blip2Processor,
    Blip2ForConditionalGeneration,
    CLIPProcessor,
    CLIPModel,
)

# Pick the compute device: the first CUDA GPU when one is visible, else the CPU.
has_cuda = torch.cuda.is_available()
device = "cuda" if has_cuda else "cpu"
print(f"Using device: {device}")

if device != "cuda":
    print("⚠️ Warning: Running on CPU. Inference will be slow.")
else:
    # Report which GPU (index 0) will be used and its total memory in GB.
    gpu_name = torch.cuda.get_device_name(0)
    print(f"GPU: {gpu_name}")
    gpu_props = torch.cuda.get_device_properties(0)
    print(f"Memory: {gpu_props.total_memory / 1e9:.2f} GB")

### Helper Functions

Let's create utility functions for loading and displaying images.

In [None]:
def load_image(image_path_or_url, timeout=30):
    """
    Load an image from a local path or URL and convert it to RGB.

    Args:
        image_path_or_url: Local file path (str or os.PathLike) or HTTP(S) URL
        timeout: Seconds to wait for the HTTP server before giving up
            (only used for URLs). Without it, requests can block forever.

    Returns:
        PIL Image object in RGB mode

    Raises:
        requests.HTTPError: If the URL responds with an error status code.
        requests.Timeout: If the server does not respond within ``timeout``.
    """
    import os

    # Accept pathlib.Path and other os.PathLike objects, not only str.
    source = os.fspath(image_path_or_url)

    if source.startswith(('http://', 'https://')):
        response = requests.get(source, timeout=timeout)
        # Fail with a clear HTTP error instead of letting PIL choke on an
        # error page with "cannot identify image file".
        response.raise_for_status()
        image = Image.open(BytesIO(response.content)).convert('RGB')
    else:
        # The context manager closes the file once the RGB copy has been made.
        with Image.open(source) as img:
            image = img.convert('RGB')
    return image

def display_image(image, title=None, figsize=(8, 6)):
    """
    Render one image in a new matplotlib figure with the axes hidden.

    Args:
        image: PIL Image object (anything ``imshow`` accepts)
        title: Optional bold title; nothing is drawn when it is falsy
        figsize: Figure size tuple (width, height)
    """
    _fig, ax = plt.subplots(figsize=figsize)
    ax.imshow(image)
    ax.axis('off')
    if title:
        ax.set_title(title, fontsize=14, fontweight='bold')
    plt.show()

# Signal that this cell ran to completion without errors.
print("✓ Helper functions defined")

### Load Sample Images

For demonstration, we'll stream one cat image and one dog image from the `microsoft/cats_vs_dogs` dataset on the Hugging Face Hub.

In [None]:
from datasets import load_dataset

print("Loading images from microsoft/cats_vs_dogs dataset...")

# Load dataset in streaming mode to avoid downloading the entire dataset
dataset = load_dataset("microsoft/cats_vs_dogs", split="train", streaming=True)

# Dataset label ids -> keys used in `images` (0 = Cat, 1 = Dog)
label_to_name = {0: 'cat', 1: 'dog'}
images = {}

# Stream examples until we have collected one image per class.
for example in dataset:
    name = label_to_name.get(example['labels'])
    if name is not None and name not in images:
        images[name] = example['image'].convert("RGB")
        print(f"✓ Loaded {name} image from dataset")

    if len(images) == len(label_to_name):
        break

# Later cells index images["cat"] / images["dog"] directly, so report a
# missing class here rather than as a KeyError further down.
missing = [name for name in label_to_name.values() if name not in images]
if missing:
    print(f"⚠️ Warning: no image found for: {', '.join(missing)}")

# Display sample images
if images:
    # squeeze=False always yields a 2-D array of Axes, so a single image
    # works without special-casing.
    fig, axes = plt.subplots(1, len(images), figsize=(15, 5), squeeze=False)
    for ax, (name, img) in zip(axes[0], images.items()):
        ax.imshow(img)
        ax.set_title(name.capitalize(), fontsize=12, fontweight='bold')
        ax.axis('off')
    plt.tight_layout()
    plt.show()
else:
    print("No images loaded to display.")

## 2. Image Captioning with BLIP-2

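BLIP-2 is a vision-language model that connects a frozen image encoder to a frozen large language model (here OPT-2.7B) through a lightweight Querying Transformer (Q-Former). We'll use it first for image captioning and then for visual question answering.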

**Key Concepts (from Module 2):**
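- The Q-Former extracts visual features from the frozen image encoder using a small set of learnable query embeddings
- Because there are only a few queries, they act as an information bottleneck between vision and language
- The Q-Former outputs are fed into a frozen LLM, which generates the text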

In [None]:
# Load the BLIP-2 processor and model (OPT-2.7B language backbone) from the Hub
blip2_checkpoint = "Salesforce/blip2-opt-2.7b"
print("Loading BLIP-2 model... (this may take a few minutes)")

blip_processor = Blip2Processor.from_pretrained(blip2_checkpoint)
blip_model = Blip2ForConditionalGeneration.from_pretrained(
    blip2_checkpoint,
    torch_dtype=torch.float16,  # FP16 weights: half the memory of FP32
    device_map="auto",          # let accelerate place weights on available devices
)

n_params = sum(param.numel() for param in blip_model.parameters())
print("✓ BLIP-2 model loaded successfully")
print(f"Model size: {n_params / 1e9:.2f}B parameters")

### Generate Captions with Different Decoding Strategies

In [None]:
def generate_caption(image, decoding_strategy="greedy", max_length=50):
    """
    Generate a caption for an image using BLIP-2.

    Args:
        image: PIL Image
        decoding_strategy: One of ["greedy", "beam_search", "nucleus"]
        max_length: Maximum caption length in tokens, forwarded to
            ``generate(max_length=...)``

    Returns:
        Generated caption string, stripped of surrounding whitespace

    Raises:
        ValueError: If ``decoding_strategy`` is not a supported name.
    """
    # Per-strategy generation settings; the shared arguments are added below.
    strategy_kwargs = {
        # Greedy: select the most probable token at each step (deterministic).
        "greedy": {"do_sample": False},
        # Beam search: keep the 5 best partial captions at each step.
        "beam_search": {"num_beams": 5, "do_sample": False},
        # Nucleus (top-p) sampling: sample from the smallest token set whose
        # cumulative probability reaches 0.9 (non-deterministic).
        "nucleus": {"do_sample": True, "top_p": 0.9, "temperature": 1.0},
    }
    # Validate before doing any preprocessing work.
    if decoding_strategy not in strategy_kwargs:
        raise ValueError(f"Unknown strategy: {decoding_strategy}")

    # Cast pixel values to the model's own dtype instead of hard-coding
    # float16, so this keeps working if the model is loaded in another
    # precision. BatchFeature.to only casts floating-point tensors.
    inputs = blip_processor(images=image, return_tensors="pt").to(device, blip_model.dtype)

    generated_ids = blip_model.generate(
        **inputs,
        max_length=max_length,
        **strategy_kwargs[decoding_strategy],
    )

    # Decode the first (only) sequence to text.
    caption = blip_processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
    return caption

# Compare the three decoding strategies on the same dog image
test_image = images["dog"]
separator = "=" * 60

print("\n" + separator)
print("CAPTION GENERATION COMPARISON")
print(separator + "\n")

strategy_names = ("greedy", "beam_search", "nucleus")
for strategy_name in strategy_names:
    strategy_caption = generate_caption(test_image, decoding_strategy=strategy_name)
    label = strategy_name.upper()
    print(f"{label:15s}: {strategy_caption}")

display_image(test_image, "Test Image")

### Caption All Sample Images

In [None]:
# Generate a beam-search caption for every loaded image
captions = {}

for name, image in images.items():
    caption = generate_caption(image, decoding_strategy="beam_search")
    captions[name] = caption
    print(f"\n{name.upper()}:")
    print(f"  Caption: {caption}")

# Visualize images with captions.
# squeeze=False keeps `axes` 2-D even for a single image; without it
# plt.subplots returns a bare Axes that zip() cannot iterate. The guard
# avoids plt.subplots(1, 0), which raises when no images were loaded.
if images:
    fig, axes = plt.subplots(1, len(images), figsize=(18, 6), squeeze=False)
    for ax, (name, img) in zip(axes[0], images.items()):
        ax.imshow(img)
        ax.set_title(f"{name.capitalize()}\n{captions[name]}",
                     fontsize=10, wrap=True)
        ax.axis('off')
    plt.tight_layout()
    plt.show()
else:
    print("No images loaded to display.")

### Single-Turn Visual Question Answering (VQA)

In [None]:
def answer_question(image, question, max_new_tokens=20):
    """
    Answer a question about an image using BLIP-2 with greedy decoding.

    Args:
        image: PIL Image
        question: Question string
        max_new_tokens: Maximum number of tokens generated for the answer

    Returns:
        Answer string (without the prompt), stripped of surrounding whitespace
    """
    # BLIP-2 OPT checkpoints are prompted in a "Question: ... Answer:" format.
    prompt = f"Question: {question} Answer:"

    # Cast floating-point inputs to the model's own dtype instead of
    # hard-coding float16, so this follows however the model was loaded.
    inputs = blip_processor(images=image, text=prompt, return_tensors="pt").to(device, blip_model.dtype)

    generated_ids = blip_model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=False  # Deterministic greedy decoding
    )

    decoded = blip_processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()

    # Depending on the transformers version, a decoder-only BLIP-2 model may
    # return the prompt tokens followed by the answer. Drop an echoed prompt
    # so only the answer is returned. This is a no-op when there is no echo.
    if decoded.startswith(prompt):
        decoded = decoded[len(prompt):]

    answer = decoded.strip()
    return answer

# (image key, question) pairs to put to BLIP-2
vqa_examples = [
    ("dog", "What breed is this dog?"),
    ("dog", "Is the dog sitting or standing?"),
    ("cat", "What color is the cat?"),
]

banner = "=" * 60
print("\n" + banner)
print("VISUAL QUESTION ANSWERING EXAMPLES (BLIP-2)")
print(banner + "\n")

for image_name, question in vqa_examples:
    # Skip questions whose image failed to load in the sample-image cell.
    if image_name not in images:
        print(f"Skipping {image_name}: Image not loaded.")
        continue

    answer = answer_question(images[image_name], question)
    print(f"Image: {image_name.upper()}")
    print(f"Q: {question}")
    print(f"A: {answer}")
    print()