# openNanoBanana -- Free Colab Edition

The full openNanoBanana pipeline running on a free Google Colab T4 GPU.

**Pipeline:**
1. Gemini 3 Flash extracts the real-world subject from your prompt
2. Serper.dev searches Google Images for reference photos
3. Gemini 3 Flash verifies the images match the subject
4. FLUX.2-klein-4B (quantized to 4-bit NF4) generates the final image locally on your T4 GPU

**Before you start:**
- Runtime > Change runtime type > **T4 GPU**
- Paste your API keys in Cell 2

**Get API keys (free):**
- Gemini: [aistudio.google.com/app/apikey](https://aistudio.google.com/app/apikey)
- Serper: [serper.dev](https://serper.dev) (2,500 free queries/month, no credit card)

In [None]:
# Cell 1: Check GPU + Install Dependencies
# Shell command (IPython `!` magic): prints the NVIDIA driver/GPU status for this runtime.
!nvidia-smi

# Report the PyTorch build and what CUDA device (if any) it can see.
import torch
print(f"\nPyTorch: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name()}")
    # total_memory is in bytes; dividing by 1024**3 converts it to GiB.
    print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")
    print(f"Compute capability: {torch.cuda.get_device_capability()}")

print("\nInstalling dependencies...")
# diffusers is installed from the GitHub main branch rather than a PyPI release
# (presumably because the FLUX.2 klein pipeline classes used in Cell 3 are only
# available there -- TODO confirm and pin a release once one exists).
!pip install -q git+https://github.com/huggingface/diffusers.git
!pip install -q transformers accelerate safetensors
# bitsandbytes provides the 4-bit NF4 quantization configured in Cell 3.
!pip install -q bitsandbytes
!pip install -q gradio Pillow requests
print("\nDone!")

In [None]:
# Cell 2: Configure API Keys
# Paste your API keys below. Get them for free:
#   Gemini: https://aistudio.google.com/app/apikey
#   Serper: https://serper.dev (no credit card needed)
#
# These module-level names are read later by the Gradio handler in Cell 5,
# so re-run this cell after editing a key.

GEMINI_API_KEY = ""  # <-- paste your Gemini API key here
SERPER_API_KEY = ""  # <-- paste your Serper API key here

# Optional: HuggingFace token for faster model downloads
HF_TOKEN = ""  # <-- paste your HF token here (or leave empty)

# Only log in when a token was supplied; an empty string skips this entirely.
if HF_TOKEN:
    from huggingface_hub import login
    login(token=HF_TOKEN, add_to_git_credential=False)
    print("Logged in to HuggingFace")

# Fail fast with a readable message if either required key is still empty.
assert GEMINI_API_KEY, "Please paste your Gemini API key above"
assert SERPER_API_KEY, "Please paste your Serper API key above"
print("API keys configured!")

# Gemini model id used for both query extraction and image verification.
# NOTE: this is assigned after the asserts, so it is not defined if a key is
# missing; Cell 4 uses it as a default argument and needs this cell to succeed.
GEMINI_MODEL = "gemini-3-flash-preview"

In [None]:
# Cell 3: Load FLUX.2-klein-4B with 4-bit quantization (BOTH transformer AND text encoder)
#
# Why previous approaches crashed:
#   - Full fp16: transformer (8GB) + text encoder (8GB) = 16GB CPU RAM > 12.7GB → OOM
#   - FP8 conversion: FP8→fp16 cast = ~12GB CPU RAM peak → OOM
#   - Transformer-only NF4: transformer on GPU (2GB) BUT text encoder fp16 on CPU (8GB)
#     → System RAM 12.1/12.7 GB → crashes during inference
#
# Solution: Quantize BOTH to 4-bit NF4. Both stay on GPU permanently.
#   - Transformer NF4: ~2 GB GPU
#   - Text encoder NF4: ~2 GB GPU
#   - VAE fp16: ~0.5 GB (offloaded to CPU, moved to GPU only during decode)
#   - Total: ~4 GB GPU, <1 GB CPU → massive headroom on both
#
# NOTE(review): the memory figures above are the author's estimates; the prints
# below report the actual allocation. Also, `pipe.enable_model_cpu_offload()` is
# called at the end of this cell, which looks at odds with "Both stay on GPU
# permanently" -- confirm how offloading treats the bitsandbytes-quantized modules.

import torch, gc

# Compute dtype for the quantized layers and load dtype for everything else.
dtype = torch.float16

# diffusers and transformers each ship their own BitsAndBytesConfig class; the
# transformers one is aliased so each model gets the config from its own library.
from diffusers import Flux2KleinPipeline, Flux2Transformer2DModel, BitsAndBytesConfig
from transformers import Qwen3ForCausalLM, BitsAndBytesConfig as TfBnBConfig

# NF4 configs (identical settings, one per library)
nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=dtype,
)
tf_nf4_config = TfBnBConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=dtype,
)

# ── Step 1: Transformer → NF4 on GPU (~2 GB) ─────────────────────
print("Step 1/3: Loading transformer with 4-bit NF4 quantization...")
transformer = Flux2Transformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.2-klein-4B",
    subfolder="transformer",
    quantization_config=nf4_config,
    torch_dtype=dtype,
)
# Drop any temporary Python-side objects left over from loading before the next model.
gc.collect()
print(f"  Transformer on GPU: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")

# ── Step 2: Text encoder → NF4 on GPU (~2 GB) ────────────────────
# This was the crash cause: at fp16 it took ~8 GB of CPU RAM.
# NF4 puts it on GPU at ~2 GB instead.
print("\nStep 2/3: Loading text encoder with 4-bit NF4 quantization...")
text_encoder = Qwen3ForCausalLM.from_pretrained(
    "black-forest-labs/FLUX.2-klein-4B",
    subfolder="text_encoder",
    quantization_config=tf_nf4_config,
    torch_dtype=dtype,
)
gc.collect()
print(f"  Transformer + text encoder on GPU: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")

# ── Step 3: Assemble pipeline (only downloads VAE, tokenizer, scheduler) ──
# The pre-quantized transformer and text encoder are passed in, so from_pretrained
# reuses them instead of loading those two components again.
print("\nStep 3/3: Loading VAE, tokenizer, scheduler...")
pipe = Flux2KleinPipeline.from_pretrained(
    "black-forest-labs/FLUX.2-klein-4B",
    transformer=transformer,
    text_encoder=text_encoder,
    torch_dtype=dtype,
)

# `pipe` is the global the Gradio handler in Cell 5 passes to the pipeline.
pipe.enable_model_cpu_offload()
gc.collect()
# Return cached-but-unused CUDA blocks so the numbers printed below are meaningful.
torch.cuda.empty_cache()

print(f"\nModel loaded!")
print(f"GPU memory allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
print(f"GPU memory reserved:  {torch.cuda.memory_reserved() / 1024**3:.2f} GB")
print(f"System RAM is now free — both big models live on GPU at 4-bit.")

In [None]:
# Cell 4: Pipeline Functions
# Python port of the openNanoBanana TypeScript pipeline

import requests
import json
import base64
import time
import re
import torch
from io import BytesIO
from PIL import Image

# REST endpoints. call_gemini() appends "/{model}:generateContent" to the Gemini base URL.
GEMINI_BASE_URL = "https://generativelanguage.googleapis.com/v1beta/models"
SERPER_URL = "https://google.serper.dev/images"

# Maps the UI resolution label to the two pixel dimensions used for generation
# (generate_image unpacks each pair as height, width; both entries are square).
# The labels are nominal: "1k" is 512x512 and "2k" is 768x768.
RESOLUTION_MAP = {
    "1k": (512, 512),
    "2k": (768, 768),
}

# Instruction sent to Gemini (ahead of the user's prompt) by extract_search_query().
# It asks for a JSON object with two keys that the parser reads back:
#   "searchQuery" -- the subject to look up on Google Images
#   "subjectType" -- a generic description used later to verify candidate images
EXTRACT_PROMPT = """You are a search query extractor. Given an image generation prompt, extract TWO things:

1. **searchQuery**: The real-world subject to search for as a reference image. Remove artistic style descriptors, effects, transformations, or hypothetical modifiers (e.g. "as a baby", "in cyberpunk style", "watercolor painting").
2. **subjectType**: A SHORT, generic visual description of what the subject IS -- the kind of thing a person could identify by looking at a photo (e.g. "a person", "a bridge", "a building", "a cat", "a street crossing"). Do NOT use the specific name. This is used to verify search results visually.

Return ONLY valid JSON, no markdown fences, no explanation.

Examples:
- "hkust entrance piazza in cyberpunk future" -> {"searchQuery":"hkust entrance piazza","subjectType":"an outdoor plaza at a university"}
- "golden gate bridge at sunset watercolor painting" -> {"searchQuery":"golden gate bridge","subjectType":"a suspension bridge"}
- "dr ct abraham as a baby" -> {"searchQuery":"dr ct abraham","subjectType":"a person"}
- "my cat wearing a top hat in van gogh style" -> {"searchQuery":"cat wearing top hat","subjectType":"a cat"}
- "tokyo shibuya crossing in anime style" -> {"searchQuery":"tokyo shibuya crossing","subjectType":"a busy street crossing"}
- "labrador puppy in a spacesuit" -> {"searchQuery":"labrador puppy","subjectType":"a dog"}"""


# ===== Gemini Client =====

def call_gemini(api_key, model, contents):
    """Send ``contents`` to a Gemini model and return the reply's first text part.

    Args:
        api_key: Gemini API key, sent in the ``x-goog-api-key`` header.
        model: Model id inserted into the ``:generateContent`` endpoint URL.
        contents: Value placed under the ``"contents"`` key of the JSON body.

    Returns:
        The ``text`` of the first part of the first candidate.

    Raises:
        Exception: On HTTP 429, HTTP 401/403, any other non-OK status, a
            response without candidates (optionally with a block reason), a
            ``SAFETY`` finish reason, or a first part that carries no text.
    """
    endpoint = f"{GEMINI_BASE_URL}/{model}:generateContent"
    request_headers = {"x-goog-api-key": api_key, "Content-Type": "application/json"}
    res = requests.post(
        endpoint,
        headers=request_headers,
        json={"contents": contents},
        timeout=30,
    )

    # Translate the common HTTP failures into readable messages first.
    status = res.status_code
    if status == 429:
        raise Exception("Gemini rate limit hit. Wait a moment and try again.")
    if status in (401, 403):
        raise Exception("Invalid Gemini API key.")
    if not res.ok:
        raise Exception(f"Gemini API error ({res.status_code}): {res.text[:200]}")

    payload = res.json()
    candidate_list = payload.get("candidates") or [None]
    first = candidate_list[0]

    if not first:
        # No candidate at all: report the block reason when the API gave one.
        block_reason = payload.get("promptFeedback", {}).get("blockReason")
        if block_reason:
            raise Exception(f"Content blocked by Gemini safety filters: {block_reason}")
        raise Exception("Gemini returned no candidates.")

    if first.get("finishReason") == "SAFETY":
        raise Exception("Content blocked by Gemini safety filters.")

    parts = first.get("content", {}).get("parts") or [{}]
    reply = parts[0].get("text")
    if isinstance(reply, str):
        return reply
    raise Exception("Gemini returned no text in response.")


# ===== Step 1: Extract Search Query =====

def extract_search_query(api_key, model, user_prompt):
    """Ask Gemini for the search query and generic subject type behind a prompt.

    Args:
        api_key: Gemini API key.
        model: Gemini model id.
        user_prompt: The user's raw image-generation prompt.

    Returns:
        A ``(query, subject_type)`` tuple of strings. ``subject_type`` falls
        back to ``"a subject"`` when the model omits it or returns a blank.

    Raises:
        Exception: If the reply is not a JSON object, or if no usable search
            query (at least two characters) can be read from it.
    """
    text = call_gemini(api_key, model, [
        {"parts": [{"text": EXTRACT_PROMPT}, {"text": f'User prompt: "{user_prompt}"'}]}
    ])

    # The model is told not to use markdown fences but sometimes does anyway.
    # Strip an opening fence with or without a language tag ("```json" or a
    # bare "```") and the closing fence.
    cleaned = re.sub(r"^```\w*\s*|```\s*$", "", text.strip()).strip()
    try:
        parsed = json.loads(cleaned)
    except json.JSONDecodeError as err:
        raise Exception(f"Could not parse extraction result: {cleaned[:200]}") from err

    # Valid JSON that is not an object (a list, a bare string, ...) has no
    # fields to read; report it like any other unparseable reply.
    if not isinstance(parsed, dict):
        raise Exception(f"Could not parse extraction result: {cleaned[:200]}")

    # Guard against non-string values so .strip() cannot raise AttributeError.
    raw_query = parsed.get("searchQuery")
    query = raw_query.strip().strip('"\'') if isinstance(raw_query, str) else ""

    raw_subject = parsed.get("subjectType")
    subject_type = raw_subject.strip() if isinstance(raw_subject, str) else ""
    if not subject_type:
        subject_type = "a subject"

    if len(query) < 2:
        raise Exception("Could not extract a search query from the prompt.")
    return query, subject_type


# ===== Step 2: Search Images =====

def search_images(api_key, query, num=5):
    """Query Google Images through Serper and return the candidate images.

    Args:
        api_key: Serper API key, sent in the ``X-API-KEY`` header.
        query: Search text.
        num: Number of results requested from Serper.

    Returns:
        A non-empty list of dicts with the keys ``title``, ``imageUrl`` and
        ``link`` (each defaulting to an empty string when absent).

    Raises:
        Exception: On HTTP 429, HTTP 401/403, any other non-OK status, or when
            the response contains no images.
    """
    request_headers = {"X-API-KEY": api_key, "Content-Type": "application/json"}
    res = requests.post(
        SERPER_URL,
        headers=request_headers,
        json={"q": query, "num": num},
        timeout=15,
    )

    status = res.status_code
    if status == 429:
        raise Exception("Serper rate limit hit.")
    if status in (401, 403):
        raise Exception("Invalid Serper API key.")
    if not res.ok:
        raise Exception(f"Serper API error ({res.status_code}): {res.text[:200]}")

    # Keep only the three fields the rest of the pipeline reads.
    hits = []
    for entry in res.json().get("images", []):
        hits.append({
            "title": entry.get("title", ""),
            "imageUrl": entry.get("imageUrl", ""),
            "link": entry.get("link", ""),
        })

    if hits:
        return hits
    raise Exception(f'No images found for "{query}". Try a more specific prompt.')


# ===== Image Download =====

# File extension (lowercase, with the dot) -> MIME type, used when the
# Content-Type header cannot be trusted.
MIME_MAP = {".jpg": "image/jpeg", ".jpeg": "image/jpeg", ".png": "image/png", ".webp": "image/webp", ".gif": "image/gif"}


def infer_mime(url, content_type=None):
    """Return the image MIME type for a downloaded resource.

    The Content-Type header wins when it names an image type; otherwise the
    URL's file extension is looked up in ``MIME_MAP``.

    Args:
        url: The URL the image was fetched from.
        content_type: Raw ``Content-Type`` header value, if any.

    Returns:
        A lowercase ``image/...`` MIME type; ``"image/jpeg"`` when nothing
        better can be determined.
    """
    if content_type:
        # Media types are case-insensitive and may carry parameters
        # ("IMAGE/PNG; charset=binary"), so normalise before testing.
        media_type = content_type.split(";", 1)[0].strip().lower()
        if media_type.startswith("image/"):
            return media_type

    # Drop the fragment and the query string so neither hides the extension.
    path = url.split("#", 1)[0].split("?", 1)[0]
    ext_match = re.search(r"\.\w+$", path)
    if ext_match:
        return MIME_MAP.get(ext_match.group(0).lower(), "image/jpeg")
    return "image/jpeg"

def fetch_image_as_base64(url, timeout=8):
    """Download ``url`` and return it base64-encoded, or ``None`` on any problem.

    Args:
        url: Image URL to fetch.
        timeout: Request timeout in seconds.

    Returns:
        A ``(base64_str, mime_type)`` tuple, or ``None`` when the request
        fails, the status is not OK, the Content-Type is not ``image/*``, the
        body is shorter than 100 bytes, or any exception is raised.
    """
    try:
        response = requests.get(url, timeout=timeout, headers={"User-Agent": "openNanoBanana/1.0"})
        if not response.ok:
            return None

        content_type = response.headers.get("Content-Type", "")
        # Reject non-image responses and bodies too small to be a real image.
        usable = content_type.startswith("image/") and len(response.content) >= 100
        if not usable:
            return None

        encoded = base64.b64encode(response.content).decode("utf-8")
        return encoded, infer_mime(url, content_type)
    except Exception:
        # Best effort by design: a bad candidate is simply skipped by the caller.
        return None

def download_pil_image(url, timeout=10):
    """Fetch ``url`` and decode it into an RGB PIL image.

    Args:
        url: Image URL to fetch.
        timeout: Request timeout in seconds.

    Returns:
        The decoded image converted to ``"RGB"``, or ``None`` if the download,
        the HTTP status check, or the decoding raises.
    """
    try:
        response = requests.get(url, timeout=timeout, headers={"User-Agent": "openNanoBanana/1.0"})
        response.raise_for_status()
        raw = BytesIO(response.content)
        decoded = Image.open(raw)
        return decoded.convert("RGB")
    except Exception:
        # Best effort by design: callers treat None as "image unavailable".
        return None


# ===== Step 3: Verify Images =====

def verify_images(api_key, model, subject_type, image_results):
    """Return the first search result that Gemini confirms shows the subject.

    Each candidate is downloaded and sent to Gemini with a yes/no question.
    Candidates that cannot be downloaded are skipped.

    Args:
        api_key: Gemini API key.
        model: Gemini model id.
        subject_type: Generic description of the subject (e.g. "a bridge").
        image_results: List of dicts with at least an ``imageUrl`` key.

    Returns:
        A dict with the keys ``imageUrl``, ``base64``, ``mimeType`` and
        ``index`` (position in ``image_results``) for the first image Gemini
        answers "yes" to, or ``None`` if no image was accepted.

    Raises:
        Exception: If every Gemini call failed, so no candidate was actually
            checked. This surfaces an invalid key or a rate limit instead of
            reporting that no image matched.
    """
    question = f'Does this image clearly contain {subject_type}? Answer with ONLY "yes" or "no".'
    answered = 0        # number of candidates Gemini actually gave a verdict on
    last_error = None   # most recent Gemini failure, kept for the final report

    for i, img in enumerate(image_results):
        result = fetch_image_as_base64(img["imageUrl"])
        if result is None:
            continue
        b64, mime = result

        try:
            answer = call_gemini(api_key, model, [{
                "parts": [
                    {"inline_data": {"mime_type": mime, "data": b64}},
                    {"text": question},
                ]
            }])
        except Exception as err:
            # One bad candidate (unsupported format, safety block, ...) must
            # not abort the search; remember the error and try the next image.
            last_error = err
            continue

        answered += 1
        if answer.strip().lower().startswith("yes"):
            return {"imageUrl": img["imageUrl"], "base64": b64, "mimeType": mime, "index": i}

    if answered == 0 and last_error is not None:
        # Nothing was verified at all: the problem is the API, not the images.
        raise Exception(f"Image verification failed: {last_error}") from last_error
    return None


# ===== Step 4: Generate Image Locally =====

def generate_image(flux_pipe, prompt, reference_image_url, resolution="1k"):
    """Run the local FLUX.2 pipeline on a prompt and one reference image.

    Args:
        flux_pipe: The loaded pipeline object; called with keyword arguments.
        prompt: Text prompt passed through unchanged.
        reference_image_url: URL of the reference image, downloaded here.
        resolution: Key into ``RESOLUTION_MAP``; unknown keys use 512x512.

    Returns:
        The first image of the pipeline result.

    Raises:
        Exception: If the reference image cannot be downloaded.
    """
    height, width = RESOLUTION_MAP.get(resolution, (512, 512))

    reference = download_pil_image(reference_image_url)
    if reference is None:
        raise Exception("Failed to download reference image for generation.")
    # The reference is resized to exactly the output dimensions.
    reference = reference.resize((width, height))

    with torch.inference_mode():
        # Seed from the wall clock so each run differs; kept within 32 bits.
        seed = int(time.time()) % 2**32
        rng = torch.Generator(device="cpu").manual_seed(seed)
        output = flux_pipe(
            prompt=prompt,
            image=[reference],
            height=height,
            width=width,
            guidance_scale=1.0,
            num_inference_steps=4,
            generator=rng,
        )

    # Hand cached CUDA memory back after generation.
    torch.cuda.empty_cache()
    return output.images[0]


# ===== Full Pipeline Orchestrator =====

def run_pipeline(prompt, gemini_key, serper_key, flux_pipe, model=GEMINI_MODEL, resolution="1k"):
    """Run extraction, image search, verification and generation in order.

    Args:
        prompt: The user's image-generation prompt.
        gemini_key: Gemini API key.
        serper_key: Serper API key.
        flux_pipe: Loaded FLUX.2 pipeline used for the final generation.
        model: Gemini model id used for extraction and verification.
        resolution: Resolution label forwarded to ``generate_image``.

    Returns:
        ``(final_image, gallery_images, verified_index, log_text)`` where
        ``gallery_images`` is a list of ``(PIL image, title)`` pairs for the
        candidates that could be downloaded, and ``verified_index`` is the
        position of the accepted image in the search results (not in the
        gallery list, which omits failed downloads).

    Raises:
        Exception: If no candidate image is verified, or if any stage raises.
    """
    lines = []

    # Step 1: Extract search query
    lines.append("[1/4] Extracting search query...")
    query, subject_type = extract_search_query(gemini_key, model, prompt)
    lines.append(f'  Search query: "{query}"')
    lines.append(f'  Subject type: "{subject_type}"')

    # Step 2: Search images
    lines.append(f'\n[2/4] Searching Google Images for "{query}"...')
    results = search_images(serper_key, query, num=5)
    lines.append(f"  Found {len(results)} candidate images")

    # Download thumbnails for the gallery, keeping only those that loaded.
    gallery = []
    for hit in results:
        thumb = download_pil_image(hit["imageUrl"])
        if thumb:
            gallery.append((thumb, hit["title"][:60]))

    # Step 3: Verify images
    lines.append("\n[3/4] Verifying images match the subject...")
    verified = verify_images(gemini_key, model, subject_type, results)
    if not verified:
        raise Exception("None of the found images match the subject. Try rephrasing your prompt.")
    chosen = verified["index"]
    lines.append(f'  Verified image #{chosen + 1}: {results[chosen]["title"][:60]}')

    # Step 4: Generate
    lines.append("\n[4/4] Generating image with FLUX.2-klein-4B (this may take 10-30s)...")
    final_image = generate_image(flux_pipe, prompt, verified["imageUrl"], resolution)
    lines.append("  Done!")

    return final_image, gallery, verified["index"], "\n".join(lines)


print("Pipeline functions loaded.")

In [None]:
# Cell 5: Gradio UI
import gradio as gr

def generate_handler(prompt, resolution):
    """Gradio click handler: validate the inputs and run the full pipeline.

    Args:
        prompt: Text from the prompt textbox.
        resolution: Selected resolution label from the dropdown.

    Returns:
        ``(final_image, gallery, log)`` matching the three output components.

    Raises:
        gr.Error: If an API key is missing, the prompt is too short, or any
            pipeline stage fails (the original message is shown in the UI).
    """
    # The keys are the module-level values pasted into Cell 2 of this notebook.
    if not GEMINI_API_KEY:
        raise gr.Error("GEMINI_API_KEY not set. Paste it in Cell 2 and re-run that cell.")
    if not SERPER_API_KEY:
        raise gr.Error("SERPER_API_KEY not set. Paste it in Cell 2 and re-run that cell.")
    if not prompt or len(prompt.strip()) < 3:
        raise gr.Error("Please enter a prompt (at least 3 characters).")

    try:
        # The verified index is returned by the pipeline but not shown in the UI.
        final_image, gallery, _verified_idx, log = run_pipeline(
            prompt=prompt.strip(),
            gemini_key=GEMINI_API_KEY,
            serper_key=SERPER_API_KEY,
            flux_pipe=pipe,
            resolution=resolution,
        )
    except Exception as e:
        # Convert any failure into a UI-visible error, keeping the cause chained.
        raise gr.Error(str(e)) from e
    return final_image, gallery, log


# Build the UI. Components are laid out in the order they are created inside
# the nested Row/Column context managers.
with gr.Blocks(title="openNanoBanana", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# openNanoBanana")
    gr.Markdown("Real-time grounded image generation. Type a prompt with a real-world subject -- the pipeline searches for reference images, verifies them, and generates a new image grounded in reality.")

    # Top row: prompt on the left (wider), resolution + button on the right.
    with gr.Row():
        with gr.Column(scale=3):
            prompt_input = gr.Textbox(
                label="Prompt",
                placeholder="andrej karpathy in a gta v poster",
                lines=2,
            )
        with gr.Column(scale=1):
            # Choices must be keys of RESOLUTION_MAP from Cell 4.
            resolution_input = gr.Dropdown(
                choices=["1k", "2k"],
                value="1k",
                label="Resolution",
            )
            generate_btn = gr.Button("Generate", variant="primary")

    # Bottom row: generated image on the left, references and log on the right.
    with gr.Row():
        with gr.Column():
            output_image = gr.Image(label="Generated Image", type="pil")
        with gr.Column():
            gallery_output = gr.Gallery(label="Reference Images Found", columns=3, height=300)
            log_output = gr.Textbox(label="Pipeline Log", lines=12, interactive=False)

    # The outputs list matches the (image, gallery, log) tuple generate_handler returns.
    generate_btn.click(
        fn=generate_handler,
        inputs=[prompt_input, resolution_input],
        outputs=[output_image, gallery_output, log_output],
    )

    # Clickable sample prompts; each row fills the prompt and resolution inputs.
    gr.Examples(
        examples=[
            ["hkust entrance piazza in cyberpunk future", "1k"],
            ["andrej karpathy in a gta v poster", "1k"],
            ["golden gate bridge at sunset watercolor painting", "1k"],
            ["tokyo shibuya crossing in anime style", "1k"],
        ],
        inputs=[prompt_input, resolution_input],
    )

# share=True requests a public share link in addition to the local server.
demo.launch(share=True)