# openNanoBanana -- Free Colab Backend

Run **FLUX.2-klein-4B** on a free Google Colab T4 GPU as the image generation backend for [openNanoBanana](https://github.com/GeeveGeorge/openNanoBanana).

**What this does:**
1. Loads FLUX.2-klein-4B (4B params, Apache 2.0, 4-step inference)
2. Exposes a RunPod-compatible API via ngrok tunnel
3. Your openNanoBanana app calls this instead of RunPod -- completely free

**Before you start:**
- Runtime > Change runtime type > **T4 GPU**
- Add `NGROK_TOKEN` to Colab Secrets (get one free at [ngrok.com](https://ngrok.com))
- Optionally add `HF_TOKEN` for faster model downloads

In [None]:
# Cell 1: Check GPU + Install Dependencies
!nvidia-smi

import torch
print(f"\nPyTorch: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name()}")
    print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")
    print(f"Compute capability: {torch.cuda.get_device_capability()}")

print("\nInstalling dependencies...")
!pip install -q git+https://github.com/huggingface/diffusers.git
!pip install -q transformers accelerate safetensors
!pip install -q flask pyngrok Pillow requests
print("\nDone! All dependencies installed.")

In [None]:
# Cell 2: Configure Secrets
from google.colab import userdata

# ngrok auth token (REQUIRED for public API tunnel)
try:
    NGROK_TOKEN = userdata.get('NGROK_TOKEN')
    print("NGROK_TOKEN found in Colab Secrets")
except Exception:
    NGROK_TOKEN = None
    print("WARNING: NGROK_TOKEN not found in Colab Secrets!")
    print("Go to: https://dashboard.ngrok.com/get-started/your-authtoken")
    print("Then: Click the key icon in Colab's left sidebar > Add NGROK_TOKEN")

# HuggingFace token (optional, for faster downloads)
try:
    HF_TOKEN = userdata.get('HF_TOKEN')
    from huggingface_hub import login
    login(token=HF_TOKEN, add_to_git_credential=False)
    print("Logged in to HuggingFace")
except Exception:
    print("No HF_TOKEN -- downloading anonymously (may be slower)")

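Both try/except blocks above follow the same pattern. Below is a minimal sketch of a shared helper. It assumes you might also run this notebook outside Colab (for example, in a local Jupyter session) and want an environment variable as the fallback. `get_secret` is a hypothetical name. With it, Cell 2 would reduce to `NGROK_TOKEN = get_secret("NGROK_TOKEN")` plus the existing warning and login logic:

```python
import os


def get_secret(name):
    """Return a secret from Colab Secrets, else from os.environ, else None.

    Hypothetical helper; the environment-variable fallback is an assumption
    for running this notebook somewhere other than Colab.
    """
    try:
        from google.colab import userdata  # only importable inside Colab
        value = userdata.get(name)
        if value:
            return value
    except Exception:
        # ImportError outside Colab; a missing secret or denied access inside it
        pass
    return os.environ.get(name) or None
```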
In [None]:
# Cell 3: Load FLUX.2-klein-4B
import torch
import gc

# T4 (Turing, compute capability 7.5) has FP16 tensor cores (~65 TFLOPS peak).
# It has no native BF16 support, so BF16 math falls back to much slower
# non-tensor-core paths (T4's FP32 peak is ~8 TFLOPS).
dtype = torch.float16

print("Loading FLUX.2-klein-4B (this takes 2-4 minutes on first run)...")
print("The model will be cached for subsequent runs.\n")

from diffusers import Flux2KleinPipeline

pipe = Flux2KleinPipeline.from_pretrained(
    "black-forest-labs/FLUX.2-klein-4B",
    torch_dtype=dtype,
)

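# CPU offloading: only the submodule currently running (text_encoder, transformer,
# or VAE) is moved to the GPU; the others wait in CPU RAM.
# At 2 bytes per parameter in FP16, 4B parameters alone take ~7.5 GiB, so keeping
# peak VRAM at the largest single component (not the sum of all) matters on a T4.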
pipe.enable_model_cpu_offload()

gc.collect()
torch.cuda.empty_cache()

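print("\nModel loaded!")
# With model CPU offload, weights stay in CPU RAM until a forward pass,
# so these numbers are expected to be near zero at this point.
print(f"GPU memory allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
print(f"GPU memory reserved:  {torch.cuda.memory_reserved() / 1024**3:.2f} GB")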

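The two decisions in Cell 3 (FP16 and CPU offloading) come down to a little arithmetic. Below is a sketch of that reasoning with hypothetical helper names. It assumes native BF16 starts at compute capability 8.0 (Ampere) and counts weights only. In Cell 3, the dtype could then be derived with `getattr(torch, pick_dtype_name(torch.cuda.get_device_capability()))` instead of being hard-coded:

```python
def pick_dtype_name(capability):
    """Pick a half-precision dtype name from a CUDA (major, minor) capability.

    Assumes native BF16 tensor cores start at compute capability 8.0 (Ampere);
    a T4 reports (7, 5), so it gets FP16.
    """
    major, _minor = capability
    return "bfloat16" if major >= 8 else "float16"


def weight_gib(num_params, bytes_per_param=2):
    """GiB needed just to hold the weights (FP16/BF16 use 2 bytes per parameter).

    Ignores activations and any other pipeline components.
    """
    return num_params * bytes_per_param / 1024**3


print(pick_dtype_name((7, 5)))       # float16
print(f"{weight_gib(4e9):.1f} GiB")  # 7.5 GiB for 4B parameters in FP16
```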
In [None]:
# Cell 4: Test Generation
import torch
import requests
from PIL import Image
from io import BytesIO
from IPython.display import display
import time

# --- Test 1: Text-to-image ---
print("Test 1: Text-to-image (512x512, 4 steps)...")
t0 = time.time()

with torch.inference_mode():
    image = pipe(
        prompt="A cat holding a sign that says hello world",
        height=512,
        width=512,
        guidance_scale=4.0,
        num_inference_steps=4,
        generator=torch.Generator(device="cpu").manual_seed(42),
    ).images[0]

elapsed = time.time() - t0
print(f"Generated in {elapsed:.1f}s")
print(f"GPU memory: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB peak")
display(image)

torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()

# --- Test 2: Image editing with reference ---
print("\nTest 2: Image editing with reference image (512x512, 4 steps)...")
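ref_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/3/3a/Cat03.jpg/1200px-Cat03.jpg"
# Wikimedia may reject requests without a descriptive User-Agent (the string below is
# just an example); raise_for_status() surfaces a 403 instead of failing in Image.open
resp = requests.get(ref_url, headers={"User-Agent": "openNanoBanana-colab-test/0.1"}, timeout=15)
resp.raise_for_status()
ref_image = Image.open(BytesIO(resp.content)).convert("RGB").resize((512, 512))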

t0 = time.time()
with torch.inference_mode():
    edited = pipe(
        prompt="A cat in a cyberpunk city at night with neon lights",
        image=[ref_image],
        height=512,
        width=512,
        guidance_scale=4.0,
        num_inference_steps=4,
        generator=torch.Generator(device="cpu").manual_seed(42),
    ).images[0]

elapsed = time.time() - t0
print(f"Generated in {elapsed:.1f}s")
print(f"GPU memory: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB peak")
display(edited)

torch.cuda.empty_cache()
print("\nBoth tests finished. Check that the two images above match their prompts.")

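Both the test above and `_download_images` in Cell 5 call `.resize((width, height))` directly, which stretches non-square photos. One alternative, sketched here as an assumption rather than what the backend does, is to center-crop to the target aspect ratio before resizing. The hypothetical helper below only computes the crop box. Usage would be `img.crop(center_crop_box(*img.size, 512, 512)).resize((512, 512))`, and PIL's `ImageOps.fit` offers a similar crop-and-resize in one call:

```python
def center_crop_box(src_w, src_h, dst_w, dst_h):
    """Return (left, top, right, bottom) of the largest centered region of a
    src_w x src_h image that has the same aspect ratio as dst_w x dst_h."""
    src_ratio = src_w / src_h
    dst_ratio = dst_w / dst_h
    if src_ratio > dst_ratio:
        # Source is too wide: trim equal amounts from left and right
        crop_w = round(src_h * dst_ratio)
        left = (src_w - crop_w) // 2
        return (left, 0, left + crop_w, src_h)
    # Source is too tall (or already matches): trim top and bottom
    crop_h = round(src_w / dst_ratio)
    top = (src_h - crop_h) // 2
    return (0, top, src_w, top + crop_h)


# A 1200x900 photo cropped for a 512x512 target
print(center_crop_box(1200, 900, 512, 512))  # (150, 0, 1050, 900)
```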
In [None]:
# Cell 5: Start API Server with ngrok tunnel
import threading
import base64
import uuid
import time
import torch
import requests as http_requests
from io import BytesIO
from PIL import Image
from flask import Flask, request as flask_request, jsonify
from pyngrok import ngrok

app = Flask(__name__)

# Resolution mapping -- conservative for T4 VRAM
RESOLUTION_MAP = {
    "1k": (512, 512),
    "2k": (768, 768),
}

# In-memory job store for async /run + /status endpoints
jobs = {}
# Serialize pipeline calls: one pipeline with CPU-offload hooks should not be driven
# from two threads at once, and parallel inferences could exceed T4 VRAM
gpu_lock = threading.Lock()


def _download_images(urls, width, height):
    """Download and resize reference images."""
    images = []
    for url in urls:
        try:
            resp = http_requests.get(url, timeout=15)
            resp.raise_for_status()
            img = Image.open(BytesIO(resp.content)).convert("RGB")
            img = img.resize((width, height))
            images.append(img)
        except Exception as e:
            print(f"  Failed to download {url[:80]}: {e}")
    return images


def _generate(input_data):
    """Run the FLUX.2 pipeline. Returns dict with 'result' key."""
    prompt = input_data.get("prompt", "")
    image_urls = input_data.get("images", [])
    resolution = input_data.get("resolution", "1k")
    output_format = input_data.get("output_format", "jpeg")

    if not prompt:
        raise ValueError("prompt is required")

    h, w = RESOLUTION_MAP.get(resolution, (512, 512))

    # Download reference images
    ref_images = _download_images(image_urls, w, h)

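    # Fresh non-deterministic seed per request (a time.time()-based seed
    # repeats for requests that arrive within the same second)
    generator = torch.Generator(device="cpu")
    generator.seed()

    # Build pipeline kwargs
    pipe_kwargs = dict(
        prompt=prompt,
        height=h,
        width=w,
        guidance_scale=4.0,
        num_inference_steps=4,
        generator=generator,
    )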
    if ref_images:
        pipe_kwargs["image"] = ref_images

    # Run inference with GPU lock
    with gpu_lock:
        with torch.inference_mode():
            result = pipe(**pipe_kwargs)
        torch.cuda.empty_cache()

    # Encode output to base64
    output_image = result.images[0]
    buffer = BytesIO()
    fmt = "JPEG" if output_format.lower() in ("jpeg", "jpg") else "PNG"
    output_image.save(buffer, format=fmt, quality=90)
    b64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
    data_uri = f"data:image/{fmt.lower()};base64,{b64}"

    return {"result": data_uri}


# ===== Health endpoint =====
@app.route("/health", methods=["GET"])
def health():
    return jsonify({
        "status": "ok",
        "model": "FLUX.2-klein-4B",
        "gpu": torch.cuda.get_device_name() if torch.cuda.is_available() else "none",
        "vram_allocated_gb": round(torch.cuda.memory_allocated() / 1024**3, 2),
    })


# ===== Synchronous generate endpoint =====
@app.route("/generate", methods=["POST"])
def generate():
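    try:
        # silent=True returns None (instead of raising) for a missing or invalid JSON body
        data = flask_request.get_json(silent=True) or {}
        result = _generate(data)
        return jsonify(result)
    except torch.cuda.OutOfMemoryError:
        torch.cuda.empty_cache()
        return jsonify({"error": "GPU out of memory. Try resolution '1k'."}), 503
    except ValueError as e:
        # Bad input such as a missing prompt is a client error
        return jsonify({"error": str(e)}), 400
    except Exception as e:
        return jsonify({"error": str(e)}), 500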


# ===== RunPod-compatible async endpoints =====
# These mimic RunPod's /run + /status/{id} API so the openNanoBanana
# Next.js app can use them with minimal changes.

@app.route("/run", methods=["POST"])
def run_async():
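    # silent=True returns None (instead of raising) for a missing or invalid JSON body
    data = flask_request.get_json(silent=True) or {}
    # Accept both RunPod-style {"input": {...}} and a bare payload
    input_data = data.get("input", data)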
    job_id = str(uuid.uuid4())
    jobs[job_id] = {"status": "IN_QUEUE", "output": None, "error": None}

    def process():
        jobs[job_id]["status"] = "IN_PROGRESS"
        try:
            output = _generate(input_data)
            jobs[job_id]["output"] = output
            jobs[job_id]["status"] = "COMPLETED"
        except Exception as e:
            jobs[job_id]["error"] = str(e)
            jobs[job_id]["status"] = "FAILED"

    threading.Thread(target=process, daemon=True).start()
    return jsonify({"id": job_id, "status": "IN_QUEUE"})


@app.route("/status/<job_id>", methods=["GET"])
def status(job_id):
    job = jobs.get(job_id)
    if not job:
        return jsonify({"error": "Job not found"}), 404
    response = {"id": job_id, "status": job["status"]}
    if job["output"]:
        response["output"] = job["output"]
    if job["error"]:
        response["error"] = job["error"]
    return jsonify(response)


# ===== Start server + tunnel =====
if not NGROK_TOKEN:
    print("ERROR: NGROK_TOKEN is not set. Cannot create public tunnel.")
    print("Add it to Colab Secrets (key icon in left sidebar).")
else:
    ngrok.set_auth_token(NGROK_TOKEN)
    public_url = ngrok.connect(5000).public_url

    print("="*60)
    print(f"  API is live!")
    print(f"")
    print(f"  Base URL: {public_url}")
    print(f"")
    print(f"  Endpoints:")
    print(f"    GET  {public_url}/health")
    print(f"    POST {public_url}/generate  (synchronous)")
    print(f"    POST {public_url}/run       (async, RunPod-compatible)")
    print(f"    GET  {public_url}/status/<id>")
    print(f"")
    print(f"  To use with openNanoBanana:")
    print(f"    Set COLAB_URL={public_url} in your .env.local")
    print(f"    Or paste this URL in the BYOK panel")
    print("="*60)

    # Run Flask in background thread so cell doesn't block
    threading.Thread(
        target=lambda: app.run(port=5000, use_reloader=False),
        daemon=True,
    ).start()

    print("\nServer running in a background thread (this cell returns immediately).")
    print("Keep the Colab runtime connected; restarting it stops the server and changes the URL.")

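In Cell 5, every `/run` request starts its own thread, and `gpu_lock` serializes the actual inference. As a side effect, a job that is still waiting on the lock already reports `IN_PROGRESS`. Below is a minimal alternative sketch that uses a single worker thread fed by `queue.Queue`, so `IN_QUEUE` stays accurate until the GPU is actually free. `JobRunner` and its methods are hypothetical names, and `generate_fn` stands in for `_generate`:

```python
import queue
import threading
import uuid


class JobRunner:
    """One worker thread processes jobs in FIFO order, so only one job is ever
    IN_PROGRESS and the rest genuinely wait as IN_QUEUE."""

    def __init__(self, generate_fn):
        self.generate_fn = generate_fn
        self.jobs = {}
        self._queue = queue.Queue()
        threading.Thread(target=self._worker, daemon=True).start()

    def submit(self, input_data):
        """Register a job and enqueue it; returns the job id immediately."""
        job_id = str(uuid.uuid4())
        self.jobs[job_id] = {"status": "IN_QUEUE", "output": None, "error": None}
        self._queue.put((job_id, input_data))
        return job_id

    def _worker(self):
        while True:
            job_id, input_data = self._queue.get()
            job = self.jobs[job_id]
            job["status"] = "IN_PROGRESS"
            try:
                job["output"] = self.generate_fn(input_data)
                job["status"] = "COMPLETED"
            except Exception as e:
                job["error"] = str(e)
                job["status"] = "FAILED"
            finally:
                self._queue.task_done()
```

Wiring it in would mean creating `runner = JobRunner(_generate)` once, calling `runner.submit(input_data)` inside `/run`, and reading `runner.jobs` in `/status`. Because `_generate` still takes `gpu_lock`, the synchronous `/generate` endpoint and queued jobs still cannot overlap.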
## How to use with openNanoBanana

1. Copy the **Base URL** printed above (e.g. `https://xxxx-xx-xx.ngrok-free.app`)

2. Point openNanoBanana at that URL in one of two ways:

   - **Option A -- Environment variable:**
     ```bash
     # Add to your .env.local
     COLAB_URL=https://xxxx-xx-xx.ngrok-free.app
     ```

   - **Option B -- BYOK panel:**
     Paste the URL into the "Colab URL" field in the openNanoBanana web UI.

### Limitations

| Constraint | Detail |
|---|---|
| Session duration | Free Colab sessions typically disconnect after ~90 min idle and are capped at ~12 hours total |
| URL changes | ngrok URL changes every time you restart. Copy the new one. |
| Resolution | 512x512 (`"1k"`) recommended. 768x768 (`"2k"`) may work. 1024x1024 is not offered and would likely OOM. |
| Concurrency | One inference at a time: a GPU lock serializes requests, so extra requests wait their turn |
| Speed | ~8-20 seconds per image (T4 is slower than A100/H100) |

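The resolution guidance follows from how the transformer's input grows with image size. Below is a back-of-the-envelope sketch. It assumes each image token covers a 16x16-pixel patch (an 8x VAE downsample plus 2x2 patch packing, as in FLUX.1; not verified here for FLUX.2-klein). It also assumes activations grow with the token count while attention compute grows roughly with its square. By this estimate, 768x768 has 2,304 tokens and roughly 5x the attention cost of 512x512, while 1024x1024 has 4,096 tokens and roughly 16x. `latent_tokens` is a hypothetical helper:

```python
def latent_tokens(width, height, pixels_per_token=16):
    """Image tokens the transformer sees, assuming one token per 16x16 pixel patch."""
    return (width // pixels_per_token) * (height // pixels_per_token)


base = latent_tokens(512, 512)
for side in (512, 768, 1024):
    n = latent_tokens(side, side)
    # Attention compute grows roughly with the square of the sequence length
    print(f"{side}x{side}: {n} tokens, ~{(n / base) ** 2:.1f}x attention cost of 512x512")
```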
### API Reference

**POST /generate** (synchronous)
```json
{
  "prompt": "A cat in cyberpunk city",
  "images": ["https://example.com/reference.jpg"],
  "resolution": "1k",
  "output_format": "jpeg"
}
// Response: { "result": "data:image/jpeg;base64,..." }
```
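
Below is a minimal client-side sketch for `/generate` that uses only the Python standard library. It builds the POST request and turns the returned data URI back into image bytes. `build_generate_request` and `decode_data_uri` are hypothetical helper names, and `base_url` is the ngrok Base URL printed by Cell 5:

```python
import base64
import json
import urllib.request


def build_generate_request(base_url, prompt, images=(), resolution="1k",
                           output_format="jpeg"):
    """Build (but do not send) a POST /generate request."""
    body = {
        "prompt": prompt,
        "images": list(images),
        "resolution": resolution,
        "output_format": output_format,
    }
    return urllib.request.Request(
        base_url.rstrip("/") + "/generate",
        data=json.dumps(body).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def decode_data_uri(uri):
    """Split 'data:image/jpeg;base64,...' into (mime_type, raw_bytes)."""
    header, _, payload = uri.partition(",")
    if not header.startswith("data:") or not header.endswith(";base64"):
        raise ValueError("expected a base64 data URI")
    mime = header[len("data:"):-len(";base64")]
    return mime, base64.b64decode(payload)
```

To send the request, call `urllib.request.urlopen(req, timeout=120)`. Then parse the body with `json.load` and pass its `"result"` field to `decode_data_uri`.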

**POST /run** (async, RunPod-compatible)
```json
{ "input": { "prompt": "...", "images": ["..."], "resolution": "1k" } }
// Response: { "id": "job-uuid", "status": "IN_QUEUE" }
```

**GET /status/{job_id}**
```json
// Response: { "id": "job-uuid", "status": "COMPLETED", "output": { "result": "data:..." } }
// status is one of IN_QUEUE, IN_PROGRESS, COMPLETED, FAILED; FAILED responses carry "error"
```
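
The `/run` + `/status` pair is meant to be polled. Below is a sketch of that loop, assuming a fixed polling interval and an overall deadline (both hypothetical parameters). `wait_for_job` takes the status-fetching function as an argument, so it can be exercised without a live server. `make_fetch` and `submit_job` are hypothetical wrappers around the two endpoints:

```python
import json
import time
import urllib.request


def submit_job(base_url, input_data):
    """POST /run with a RunPod-style {"input": ...} body; returns the job id."""
    req = urllib.request.Request(
        base_url.rstrip("/") + "/run",
        data=json.dumps({"input": input_data}).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    with urllib.request.urlopen(req, timeout=30) as resp:
        return json.load(resp)["id"]


def make_fetch(base_url):
    """Return a function that GETs /status/<job_id> and parses the JSON body."""
    def fetch(job_id):
        url = f"{base_url.rstrip('/')}/status/{job_id}"
        with urllib.request.urlopen(url, timeout=30) as resp:
            return json.load(resp)
    return fetch


def wait_for_job(fetch, job_id, interval=2.0, timeout=300.0, sleep=time.sleep):
    """Poll fetch(job_id) until the job is COMPLETED or FAILED.

    Returns the job's output dict; raises RuntimeError on FAILED (or an
    unknown status) and TimeoutError if the deadline passes first.
    """
    deadline = time.monotonic() + timeout
    while True:
        job = fetch(job_id)
        status = job.get("status")
        if status == "COMPLETED":
            return job["output"]
        if status == "FAILED":
            raise RuntimeError(job.get("error", "job failed"))
        if status not in ("IN_QUEUE", "IN_PROGRESS"):
            raise RuntimeError(job.get("error", f"unexpected status {status!r}"))
        if time.monotonic() >= deadline:
            raise TimeoutError(f"job {job_id} still {status} after {timeout}s")
        sleep(interval)
```

Against a live backend, usage would look like `wait_for_job(make_fetch(BASE_URL), submit_job(BASE_URL, {"prompt": "A cat in cyberpunk city"}))`, where `BASE_URL` is the printed ngrok URL. The job store lives in memory, so job IDs do not survive a Colab restart.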