# Gemma 2 plus BLIP Meme Rewriter

This notebook walks through the **Gemma 2 + BLIP** pipeline step by step.

Goal:  
Take a meme image and its original text, then:

1. Use **BLIP** to generate a short description of the image.  
2. Combine that description with the original meme text in a prompt.  
3. Call **Gemma 2** (a text-only model) to rewrite the meme text so that it is non-offensive while keeping as much of the original meaning as possible.  
4. Run this over a CSV file of memes.

Assumptions:

- You have a machine with a GPU, or you are running in Colab.  
- You have a Hugging Face account and an access token with permission to use Gemma 2.  
- You have a CSV file `memes.csv` with columns: `id`, `image_path`, `text`.
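
The column assumption can be checked before any model is loaded. The following is a minimal sketch using only the standard library; `missing_columns` and `REQUIRED_COLUMNS` are hypothetical names introduced here for illustration, and the in-memory sample stands in for `memes.csv`.

```python
import csv
import io

# Column names taken from the assumptions list above
REQUIRED_COLUMNS = ("id", "image_path", "text")


def missing_columns(csv_file, required=REQUIRED_COLUMNS):
    """Return the required column names that are absent from the CSV header."""
    reader = csv.DictReader(csv_file)
    header = reader.fieldnames or []  # fieldnames is None for an empty file
    return [name for name in required if name not in header]


# In-memory stand-in for memes.csv so the sketch runs without any files
sample = io.StringIO('id,image_path,text\n1,img/example.png,"some meme text"\n')
print(missing_columns(sample))
```

With a real file you would pass `open("memes.csv", newline="", encoding="utf-8")` instead of the `StringIO` object and stop early if the returned list is not empty.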


## 1. Install libraries

Run this cell once to install the required packages. You can skip it if they are already installed.


In [None]:
# If you are in Colab, uncomment the next line
# !pip install -q transformers huggingface_hub accelerate safetensors pillow torch torchvision




## 2. Imports

In [None]:
import os
import torch
import pandas as pd
from PIL import Image

from transformers import (
    BlipProcessor,
    BlipForConditionalGeneration,
    AutoTokenizer,
    AutoModelForCausalLM,
)
from huggingface_hub import login

## 3. Device setup and Hugging Face login

You need a Hugging Face token with access to Gemma 2.

You can either:

- Set `HUGGINGFACE_TOKEN` as an environment variable, or  
- Pass the token directly into `login()` in this cell.


In [None]:
# Choose device (GPU if available)
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", device)

# Hugging Face login
# Option 1: read from environment variable
hf_token = os.environ.get("HUGGINGFACE_TOKEN")

# Option 2: uncomment and paste directly (not recommended in shared notebooks)
# hf_token = "hf_..."

if not hf_token:
    print("WARNING: No HUGGINGFACE_TOKEN found. Set it in the environment or here.")
else:
    login(hf_token)
    print("Logged into Hugging Face.")

## 4. Define the shared system prompt

This is the same prompt we use across models. It explains the rewriting task.


In [None]:
SYSTEM_PROMPT = (
    "You are a content safety editor for internet memes.\n"
    "Your job is to rewrite short meme text so that it is safe and non offensive "
    "while keeping the original meaning, target, and joke structure as much as possible.\n\n"
    "Rules:\n"
    "1. Remove or soften slurs, insults, and explicit hate toward any group or person.\n"
    "2. Keep the same basic situation, characters, and point of view.\n"
    "3. Keep the text short, punchy, and meme like.\n"
    "4. Do not add new events or new facts. Small filler words are fine.\n"
    "5. If the input text is already safe and non offensive, return it unchanged.\n"
    "6. Reply with the rewritten meme text only."
)

## 5. Load BLIP image captioning model

BLIP will turn each meme image into a short text description that we feed into Gemma.


In [None]:
print("Loading BLIP captioning model...")
blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
blip_model = BlipForConditionalGeneration.from_pretrained(
    "Salesforce/blip-image-captioning-base"
).to(device)

print("BLIP loaded.")

### Helper: caption one image

In [None]:
def caption_image(image_path: str) -> str:
    """Generate a short caption for the image using BLIP."""
    image = Image.open(image_path).convert("RGB")
    inputs = blip_processor(image, return_tensors="pt").to(device)
    with torch.no_grad():
        out = blip_model.generate(**inputs, max_new_tokens=40)
    caption = blip_processor.decode(out[0], skip_special_tokens=True).strip()
    return caption

# Example (update path to a real file before running)
# test_caption = caption_image("img/example.png")
# print("BLIP caption:", test_caption)
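
If the same image appears in several rows, or the notebook is rerun with different rewriting models, the caption only needs to be computed once per path. The following is a minimal sketch of that idea; `make_cached_captioner` is a hypothetical helper, and `fake_caption` is a stand-in so the example runs without BLIP. In the notebook you would wrap `caption_image` instead.

```python
from typing import Callable, Dict


def make_cached_captioner(caption_fn: Callable[[str], str]) -> Callable[[str], str]:
    """Wrap a captioning function so each image path is captioned at most once."""
    cache: Dict[str, str] = {}

    def cached(image_path: str) -> str:
        # Only call the underlying (expensive) captioner on a cache miss
        if image_path not in cache:
            cache[image_path] = caption_fn(image_path)
        return cache[image_path]

    return cached


# Stand-in captioner (assumption: replaces caption_image for this demo only)
calls = []


def fake_caption(path: str) -> str:
    calls.append(path)
    return f"caption for {path}"


cached_caption = make_cached_captioner(fake_caption)
cached_caption("img/a.png")
cached_caption("img/a.png")
print(len(calls))  # the underlying captioner ran once for the repeated path
```

The cache is keyed by path only, so it assumes that a file's contents do not change during a run.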

## 6. Load Gemma 2 model

We use the instruction-tuned Gemma 2 model (`google/gemma-2-2b-it`).

In [None]:
gemma_model_id = "google/gemma-2-2b-it"

print("Loading Gemma 2 model...")
tokenizer = AutoTokenizer.from_pretrained(gemma_model_id)
model = AutoModelForCausalLM.from_pretrained(
    gemma_model_id,
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
).to(device)  # move the whole model explicitly; combining device_map="auto" with .to() conflicts

print("Gemma 2 loaded.")

## 7. Function to rewrite a single meme

This function:

1. Captions the image with BLIP.  
2. Builds a prompt combining the system prompt, image description, and original text.  
3. Calls Gemma 2 to generate a safer rewrite.  


In [None]:
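def rewrite_meme_with_gemma(image_path: str, original_text: str) -> str:
    """Caption the image with BLIP, then ask Gemma 2 for a safer rewrite of the text."""
    # 1) Caption image
    caption = caption_image(image_path)

    # 2) Build full prompt
    prompt = (
        f"{SYSTEM_PROMPT}\n\n"
        f"Image description: {caption}\n"
        f"Original meme text: {original_text}\n\n"
        "Task:\nRewrite this text so that it is safe and non offensive, "
        "but still funny and as close as possible to the original meaning.\n"
    )
    # The instruction-tuned model expects its chat format. Gemma 2's chat template
    # has no system role, so the instructions go into a single user turn.
    messages = [{"role": "user", "content": prompt}]

    # 3) Tokenize with the chat template and generate
    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    ).to(device)
    with torch.no_grad():
        out = model.generate(
            **inputs,
            max_new_tokens=64,
            do_sample=True,
            top_p=0.9,
            temperature=0.7,
        )

    # Keep only the newly generated tokens. Slicing the decoded string by
    # len(prompt) is fragile because decoding may not reproduce the prompt exactly.
    prompt_len = inputs["input_ids"].shape[1]
    rewritten = tokenizer.decode(out[0][prompt_len:], skip_special_tokens=True).strip()

    return rewritten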

# Quick example (fill in real values before running)
# sample_image = "img/example.png"
# sample_text = "mississippi wind chime"
# print("Rewritten:", rewrite_meme_with_gemma(sample_image, sample_text))
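
Even with rule 6 in the system prompt, a model may wrap its answer in quotes or append an explanation. The following is a minimal post-processing sketch; `clean_rewrite` is a hypothetical helper, and keeping only the first non-empty line is an assumption that does not suit memes whose text spans several lines.

```python
def clean_rewrite(raw: str) -> str:
    """Reduce raw model output to a single line of meme text."""
    # Keep the first non-empty line; later lines are assumed to be commentary
    lines = [line.strip() for line in raw.splitlines() if line.strip()]
    if not lines:
        return ""
    text = lines[0]
    # Drop one pair of matching quotes wrapped around the whole line
    if len(text) >= 2 and text[0] == text[-1] and text[0] in "\"'":
        text = text[1:-1].strip()
    return text


print(clean_rewrite('"Character matters more than color."\n\nI softened the wording.'))
```

It could be applied to the return value of `rewrite_meme_with_gemma` before the result is stored.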

## 8. Run over a CSV of memes

Assumes a CSV like:

id,image_path,text  
1,img/42953.png,"its their character not their color that matters"


In [None]:
input_csv = "memes.csv"          # change to your path
output_csv = "memes_gemma2_blip.csv"

df = pd.read_csv(input_csv)
print("Loaded", len(df), "rows.")

rewrites = []
for idx, row in df.iterrows():
    image_path = row["image_path"]
    text = row["text"]
    print(f"[Row {idx}] {image_path}")
    try:
        new_text = rewrite_meme_with_gemma(image_path, text)
    except Exception as e:
        print("Error:", e)
        new_text = ""
    rewrites.append(new_text)

df["gemma2_blip_rewrite"] = rewrites
df.to_csv(output_csv, index=False)
print("Saved rewritten memes to", output_csv)
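
The loop above writes the output only once, at the end, so an interrupted run loses everything. The following is a hedged sketch of a resumable variant that appends each row as it is produced and skips ids already present in the output file. It uses the standard library `csv` module rather than pandas; `process_rows` and `fake_rewrite` are hypothetical names, and the stub stands in for `rewrite_meme_with_gemma`.

```python
import csv
import os
import tempfile

FIELDNAMES = ["id", "image_path", "text", "gemma2_blip_rewrite"]


def process_rows(rows, rewrite_fn, output_path):
    """Append one output row per input row, skipping ids already written."""
    done = set()
    has_content = os.path.exists(output_path) and os.path.getsize(output_path) > 0
    if has_content:
        with open(output_path, newline="", encoding="utf-8") as f:
            done = {r["id"] for r in csv.DictReader(f)}

    with open(output_path, "a", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=FIELDNAMES)
        if not has_content:
            writer.writeheader()
        for row in rows:
            if str(row["id"]) in done:  # ids read back from CSV are strings
                continue
            writer.writerow({
                "id": row["id"],
                "image_path": row["image_path"],
                "text": row["text"],
                "gemma2_blip_rewrite": rewrite_fn(row["image_path"], row["text"]),
            })
            f.flush()  # push each finished row to disk before starting the next


# Demo with a stub rewriter (assumption: stands in for rewrite_meme_with_gemma)
calls = []


def fake_rewrite(image_path, text):
    calls.append(image_path)
    return text.upper()


rows = [
    {"id": 1, "image_path": "img/a.png", "text": "first"},
    {"id": 2, "image_path": "img/b.png", "text": "second"},
]
out_path = os.path.join(tempfile.mkdtemp(), "out.csv")
process_rows(rows[:1], fake_rewrite, out_path)  # simulate an interrupted run
process_rows(rows, fake_rewrite, out_path)      # resume: only id 2 is processed
with open(out_path, newline="", encoding="utf-8") as f:
    print(len(list(csv.DictReader(f))))
```

In this sketch a row whose rewrite raised an exception is simply not written, so it is retried on the next run.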