# Meme Understanding with CLIP and BLIP-2

This notebook demonstrates a multimodal analysis of internet memes using two powerful vision-language models:

- **CLIP (Contrastive Language-Image Pre-training)**: Scores how well each candidate caption matches an image, so we can pick the most probable caption from a fixed set of options.
- **BLIP-2 (Bootstrapping Language-Image Pre-training)**: Generates free-form captions and explanations conditioned on the image and an optional text prompt.

We aim to:
1. Evaluate CLIP's ability to select the correct caption from multiple options.
2. Generate free-form meme captions using BLIP-2.
3. Generate reasoning for humor detection using BLIP-2.

In [None]:
# 1. Imports and Setup
# Standard library
import json
import os

# Third-party
import clip  # OpenAI CLIP package (pip install git+https://github.com/openai/CLIP.git)
import torch
from PIL import Image
from transformers import Blip2Processor, Blip2ForConditionalGeneration

# Check device (GPU if available)
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", device)

## 2. Load Pretrained Models

We load:
- `CLIP` (ViT-B/32) model and its preprocessing pipeline.
- `BLIP-2` (Flan-T5-XL) with both processor and model for captioning and reasoning.

In [None]:
# Load CLIP
clip_model, clip_preprocess = clip.load("ViT-B/32", device=device)

# Load BLIP-2
blip_processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xl")
blip_model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-flan-t5-xl").to(device)

## 3. Load Meme Dataset

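Each meme in the dataset contains:
- `id`: a unique identifier for the meme
- `image_filename`: the image file name inside the image folder
- `captions`: a list of candidate captions
- `correct_caption_index`: the index of the correct caption in `captions`
- `correct_caption`: the text of the correct caption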

Images are stored in the local `./memes` folder.
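
For reference, a record of this shape could look like the sketch below. The field values are invented for illustration and the `validate_meme` helper is hypothetical. The `id` and `correct_caption` fields are included because the processing loop reads them, and the check that `correct_caption` equals `captions[correct_caption_index]` is an assumption about how the dataset was built.

```python
# Hypothetical example record; the values are illustrative, not taken from the real dataset.
example_meme = {
    "id": 1,
    "image_filename": "example.jpg",
    "captions": [
        "When the code works on the first try",
        "A quiet walk in the park",
        "Monday morning meetings",
    ],
    "correct_caption_index": 0,
    "correct_caption": "When the code works on the first try",
}

REQUIRED_KEYS = {"id", "image_filename", "captions", "correct_caption_index", "correct_caption"}


def validate_meme(meme):
    """Return a list of problems found in one record (an empty list means it looks fine)."""
    missing = REQUIRED_KEYS - meme.keys()
    if missing:
        return [f"missing keys: {sorted(missing)}"]

    problems = []
    captions = meme["captions"]
    index = meme["correct_caption_index"]
    if not isinstance(captions, list) or not captions:
        problems.append("captions must be a non-empty list")
    elif not isinstance(index, int) or not 0 <= index < len(captions):
        problems.append("correct_caption_index is out of range")
    elif captions[index] != meme["correct_caption"]:
        # Assumption: correct_caption duplicates the caption at correct_caption_index.
        problems.append("correct_caption does not match captions[correct_caption_index]")
    return problems


print(validate_meme(example_meme))
```

Running a check like this over every record before the main loop can surface malformed entries early, rather than partway through a long inference run.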

In [None]:
with open("real_meme_dataset_clip_blip.json", "r", encoding="utf-8") as f:
    meme_data = json.load(f)

image_folder = "./memes"
results = []

## 4. Process Each Meme

For each meme:
- Use CLIP to select the best-matching caption from the given options.
- Use BLIP-2 to:
  - Generate a caption without a prompt
  - Generate a caption with the prompt: *"Caption this meme"*
  - Explain the humor with a prompt that restates the generated caption and then asks: *"Explain why it is funny."*
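
To make the CLIP step concrete, the sketch below reproduces the scoring arithmetic in plain Python: cosine similarity between an image embedding and each caption embedding, scaled and passed through a softmax. The 3-dimensional vectors are made up for illustration, the helper names are hypothetical, and the scale of 100.0 is an assumption that approximates CLIP's learned logit scale.

```python
import math


def cosine_similarity(a, b):
    """Cosine of the angle between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def softmax(values):
    """Numerically stable softmax over a list of floats."""
    peak = max(values)
    exps = [math.exp(v - peak) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


def caption_probabilities(image_embedding, caption_embeddings, logit_scale=100.0):
    """Turn image-caption similarities into a probability over the captions."""
    logits = [
        logit_scale * cosine_similarity(image_embedding, caption)
        for caption in caption_embeddings
    ]
    return softmax(logits)


# Toy embeddings (illustrative only; real CLIP embeddings have hundreds of dimensions).
image_embedding = [0.9, 0.1, 0.2]
caption_embeddings = [
    [0.8, 0.2, 0.1],
    [0.1, 0.9, 0.3],
    [0.2, 0.1, 0.9],
]

probs = caption_probabilities(image_embedding, caption_embeddings)
predicted_index = max(range(len(probs)), key=probs.__getitem__)
print("Predicted caption index:", predicted_index)
```

Because the similarities are multiplied by a large scale before the softmax, small differences in cosine similarity become very confident probabilities, which is worth keeping in mind when reading `clip_probs`.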

In [None]:
for meme in meme_data:
    image_path = os.path.join(image_folder, meme["image_filename"])
    if not os.path.exists(image_path):
        print(f"Image not found: {image_path}")
        continue

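    raw_image = Image.open(image_path).convert("RGB")

    # CLIP: score the image against every candidate caption
    image_input_clip = clip_preprocess(raw_image).unsqueeze(0).to(device)
    # truncate=True avoids an error on captions longer than CLIP's 77-token context
    text_inputs = clip.tokenize(meme["captions"], truncate=True).to(device)

    with torch.no_grad():
        logits_per_image, _ = clip_model(image_input_clip, text_inputs)
        # Cast to float32 before the softmax (CLIP runs in half precision on GPU)
        probs = logits_per_image.float().softmax(dim=-1).cpu().numpy()[0]

    predicted_index = int(probs.argmax())
    correct = predicted_index == meme["correct_caption_index"]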

    # BLIP-2: caption from the image alone (no text prompt)
    blip_inputs = blip_processor(raw_image, return_tensors="pt").to(device)
    with torch.no_grad():
        out = blip_model.generate(**blip_inputs)
    blip_caption = blip_processor.decode(out[0], skip_special_tokens=True).strip()

    # BLIP-2: caption guided by a text prompt
    prompt = "Caption this meme"
    blip_inputs_prompt = blip_processor(raw_image, prompt, return_tensors="pt").to(device)
    with torch.no_grad():
        out_prompt = blip_model.generate(**blip_inputs_prompt, max_new_tokens=100)
    blip_caption_with_prompt = blip_processor.decode(out_prompt[0], skip_special_tokens=True).strip()

    # BLIP-2: humor explanation, conditioned on the caption generated above
    reason = f"Caption of the meme is: {blip_caption_with_prompt}. Explain why it is funny."
    blip_inputs_reason = blip_processor(raw_image, reason, return_tensors="pt").to(device)
    with torch.no_grad():
        out_reason = blip_model.generate(**blip_inputs_reason, max_new_tokens=100)
    blip_reason = blip_processor.decode(out_reason[0], skip_special_tokens=True).strip()

    results.append({
        "meme_id": meme["id"],
        "clip_prediction_index": predicted_index,
        "clip_predicted_caption": meme["captions"][predicted_index],
        "clip_correct": correct,
        "clip_probs": [float(p) for p in probs],
        "blip_caption": blip_caption,
        "blip_caption_with_prompt": blip_caption_with_prompt,
        "blip_reason": blip_reason,
        "correct_caption": meme["correct_caption"]
    })

## 5. Accuracy & Results Preview

Check CLIP's top-1 accuracy and preview predictions and generations from both models.
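
Accuracy alone does not show how confident CLIP was in its choices. As a hedged illustration, the sketch below computes accuracy together with the mean top probability for correct and incorrect predictions, using only the `clip_correct` and `clip_probs` fields stored in each result. The `summarize_clip` helper and the sample records are hypothetical.

```python
def summarize_clip(results):
    """Return accuracy and mean top-1 confidence, split by correctness."""

    def mean(values):
        return sum(values) / len(values) if values else None

    if not results:
        return {"accuracy": None, "mean_conf_correct": None, "mean_conf_incorrect": None}

    conf_correct = [max(r["clip_probs"]) for r in results if r["clip_correct"]]
    conf_incorrect = [max(r["clip_probs"]) for r in results if not r["clip_correct"]]
    return {
        "accuracy": len(conf_correct) / len(results),
        "mean_conf_correct": mean(conf_correct),
        "mean_conf_incorrect": mean(conf_incorrect),
    }


# Hypothetical records with the same fields the processing loop stores.
sample_results = [
    {"clip_correct": True, "clip_probs": [0.7, 0.2, 0.1]},
    {"clip_correct": False, "clip_probs": [0.5, 0.3, 0.2]},
    {"clip_correct": True, "clip_probs": [0.9, 0.05, 0.05]},
]

summary = summarize_clip(sample_results)
print(summary)
```

If the mean confidence on incorrect predictions is close to that on correct ones, the probabilities are a weak signal for flagging likely mistakes.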

In [None]:
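# Guard against an empty results list (e.g. no images were found)
if results:
    clip_accuracy = sum(r["clip_correct"] for r in results) / len(results)
    print(f"\nCLIP Accuracy: {clip_accuracy:.2%}\n")
else:
    print("No memes were processed; check the image folder and dataset paths.")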

for r in results[:5]:
    print(f"Meme ID: {r['meme_id']}")
    print(f"CLIP Prediction: {r['clip_predicted_caption']}")
    print(f"BLIP Caption (Image Only): {r['blip_caption']}")
    print(f"BLIP Caption (With Prompt): {r['blip_caption_with_prompt']}")
    print(f"BLIP Reasoning: {r['blip_reason']}")
    print(f"Correct Caption: {r['correct_caption']}")
    print("--" * 30)

## 6. Save Results

Save the complete results as a JSON file for further analysis or visualization.
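
As one possible follow-up, the saved file can be read back to inspect the memes CLIP got wrong. The sketch below is self-contained: it writes a small hypothetical sample to a temporary directory rather than touching the real output file, and the `load_misclassified` helper is an illustrative name, not part of the notebook.

```python
import json
import os
import tempfile


def load_misclassified(path):
    """Load a results file and keep only the records CLIP predicted incorrectly."""
    with open(path, "r", encoding="utf-8") as f:
        records = json.load(f)
    return [r for r in records if not r["clip_correct"]]


# Hypothetical records using a subset of the fields saved by the notebook.
sample = [
    {"meme_id": 1, "clip_correct": True,
     "clip_predicted_caption": "caption A", "correct_caption": "caption A"},
    {"meme_id": 2, "clip_correct": False,
     "clip_predicted_caption": "caption B", "correct_caption": "caption C"},
]

with tempfile.TemporaryDirectory() as tmp_dir:
    sample_path = os.path.join(tmp_dir, "meme_analysis_results.json")
    with open(sample_path, "w", encoding="utf-8") as f:
        json.dump(sample, f, indent=2)
    misclassified = load_misclassified(sample_path)

for record in misclassified:
    print(f"Meme {record['meme_id']}: predicted '{record['clip_predicted_caption']}', "
          f"expected '{record['correct_caption']}'")
```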

In [None]:
with open("meme_analysis_results.json", "w") as out_file:
    json.dump(results, out_file, indent=2)

print("Analysis complete. Results saved to meme_analysis_results.json")