###### **© 2025** | Licensed under the [MIT License](LICENSE)

# Interactive Gemini Image Generator

This notebook is a hands-on tool to generate images with Gemini 2.5 Flash Image (aka Nano Banana) using different prompting methods.

## The Image Reference Approach

This notebook is designed to let you use the "Color Image Reference" method, which our separate evaluation study found to be the most accurate way to achieve color adherence. The idea is the following:

1.  Start from a color picker or any color code value (HEX, RGB, or HSL).
2.  Generate a plain color image from this color code, i.e., a small square (256×256 pixels by default) filled entirely with the selected color. This is done in the background with Pillow.
3.  The square is then converted into a PNG and encoded so it can be sent to the model.
4.  Pass the plain color image as an input reference to Nano Banana: along with the user text prompt (and any other uploaded images), the plain color image is included in the request.
5.  Because the model can literally "see" the color instead of just reading it as text, it follows the reference much more accurately.

### How to Use This Notebook:

* **To use the Color image reference method:**
    1.  Pick a color in **"2. Choose Color"**.
    2.  Check the **"Use color as a reference image"** box.
    3.  Write your prompt and click **Generate Image**.

* **To test the Text-only method:**
    1.  Leave the "Use color as a reference image" box **unchecked**.
    2.  Make sure to write the color (e.g., "a red car") in your text prompt, then click **Generate Image**.

**Install Dependencies**

In [None]:
# 1. Install correct libraries
!pip install -q google-genai pillow ipywidgets

**Import Libraries**

In [None]:
import google.genai as genai
from PIL import Image, ImageDraw
import io
import os
import getpass
import ipywidgets as widgets
from IPython.display import display, clear_output, HTML # <-- Added HTML
import base64
import time

**Helper Functions**

In [None]:
# These are the helper functions to create the color swatch
def hex_to_rgb(h):
    """Convert a hex color string to an (R, G, B) tuple of ints in 0-255.

    Accepted forms (the leading ``#`` is optional, case is ignored):
    ``RGB``, ``RGBA``, ``RRGGBB`` and ``RRGGBBAA``. Any alpha component is
    ignored.

    Args:
        h: Hex color string such as ``"#3498DB"`` or ``"#abc"``.

    Returns:
        A 3-tuple ``(red, green, blue)``.

    Raises:
        ValueError: If ``h`` is not a hex color in one of the accepted forms.
    """
    digits = h.strip().lstrip("#")

    # Expand CSS shorthand (#RGB / #RGBA) by doubling every digit.
    if len(digits) in (3, 4):
        digits = "".join(ch * 2 for ch in digits)

    # Validate explicitly: int(..., 16) would otherwise accept oddities such
    # as "+f" or fail with a message that does not mention the color.
    if len(digits) not in (6, 8) or not set(digits) <= set("0123456789abcdefABCDEF"):
        raise ValueError(f"Invalid hex color: {h!r}")

    return tuple(int(digits[i : i + 2], 16) for i in (0, 2, 4))

def create_color_swatch_bytes(hex_color, size=(256, 256)):
    """Render a solid-color swatch in memory and return it as PNG bytes.

    Args:
        hex_color: The fill color. Hex strings are parsed with
            ``hex_to_rgb``; anything that function rejects (for example a
            CSS color name such as ``"blue"``) is handed to Pillow's own
            color-string parser instead.
        size: ``(width, height)`` of the swatch in pixels.

    Returns:
        The encoded PNG file contents as ``bytes``.

    Raises:
        ValueError: If neither ``hex_to_rgb`` nor Pillow can interpret the
            color.
    """
    try:
        fill = hex_to_rgb(hex_color)
    except ValueError:
        # The color picker's text field may hold a color name rather than a
        # hex code; Pillow accepts color strings for RGB images.
        fill = hex_color

    swatch = Image.new("RGB", size, color=fill)
    with io.BytesIO() as buf:
        swatch.save(buf, format="PNG")
        return buf.getvalue()

**Define the UI Widgets**

In [None]:
# --- Create all the interactive UI elements ---

# 1. API Key
# A Password field is used so the key is masked on screen (screenshots,
# screen sharing). It exposes the same `.value` attribute as a Text field.
api_key_input = widgets.Password(
    description='1. Gemini API Key:',
    placeholder='Enter your API key here',
    layout=widgets.Layout(width='400px'),
    style={'description_width': 'initial'}
)

# Button that triggers a lightweight API call to check the key.
validate_button = widgets.Button(
    description="Validate Key",
    button_style='info',
    tooltip='Click to check if your API key is valid'
)

# Area below the key field where the validation result is printed.
validation_output = widgets.Output(layout=widgets.Layout(margin='5px 0 0 110px'))

# 2. Color: picker for the target color plus the switch that decides whether
# the color is also sent to the model as a reference image.
color_picker = widgets.ColorPicker(
    value="#3498DB",
    description="2. Choose Color:",
    concise=False,
    style={"description_width": "initial"},
)

use_color_ref = widgets.Checkbox(
    description="Use color as a reference image",
    value=False,
    style={"description_width": "initial"},
)

# 3. Image Upload: optional input images (several may be selected).
image_uploader = widgets.FileUpload(
    description="3. Upload Images",
    accept="image/*",
    multiple=True,
    style={"description_width": "initial"},
)

# 4. Prompt: free-text instruction for the model.
_prompt_layout = widgets.Layout(width="500px", height="100px")
prompt_input = widgets.Textarea(
    description="4. Your Prompt:",
    placeholder="A high-tech sphere levitating in a futuristic lab",
    layout=_prompt_layout,
    style={"description_width": "initial"},
)

# 5. Generate Button
generate_button = widgets.Button(
    description="Generate Image",
    tooltip="Click to generate the image",
    button_style="primary",
)

# 6. Output Area: fixed-size bordered box with centered content.
_output_area_layout = widgets.Layout(
    width="512px",
    height="512px",
    border="1px solid #ccc",
    margin="10px 0 0 0",
    display="flex",
    align_items="center",
    justify_content="center",
)
output_area = widgets.Output(layout=_output_area_layout)

# Placeholder text shown until the first image is generated.
with output_area:
    print("Your generated image will appear here.")

# 7. Download Button
download_button = widgets.Button(
    description="Download Last Image",
    tooltip="Click to generate a download link",
    button_style="success",
)

# 8. Download Link Output Area
_download_area_layout = widgets.Layout(margin="10px 0 0 0")
download_link_area = widgets.Output(layout=_download_area_layout)

**Define the Generation Logic**

In [None]:
# --- App State ---
# Shared mutable state for the button handlers: the raw bytes of the most
# recently generated image, or None while nothing has been generated yet.
app_state = dict(last_image_bytes=None)

# --- Validation Function ---
def on_validate_button_clicked(b):
    """Check the entered API key by listing models and report the outcome.

    The result is written to ``validation_output``. ``b`` is the clicked
    button supplied by ``on_click`` and is not used.
    """
    key = api_key_input.value

    with validation_output:
        clear_output()
        if not key:
            print("❌ Key is empty.")
            return
        print("Validating...")

    try:
        # Exhausting the model listing forces at least one authenticated call.
        list(genai.Client(api_key=key).models.list())
        with validation_output:
            clear_output()
            print("✅ Valid!")
    except Exception as err:
        with validation_output:
            clear_output()
            print("❌ Invalid Key or API Error.")
            print(f"\nDetails: {err}")

# --- Generation Function ---
def on_generate_button_clicked(b):
    """Generate an image from the current widget values and display it.

    Reads the API key, prompt, uploaded images and color settings from the
    module-level widgets, sends them to the Gemini image model and shows the
    first image found in the response inside ``output_area``. The raw bytes
    of that image are stored in ``app_state['last_image_bytes']`` so the
    download handler can use them. All failures are reported as text in
    ``output_area``.

    Args:
        b: The clicked button, supplied by ``on_click``; unused.
    """
    # Local import: the module only binds the name `genai`, so the SDK's
    # exception classes are brought into scope explicitly here.
    from google.genai import errors as genai_errors

    with output_area:
        clear_output()
        print("Generating, please wait...")

    # Clear any old download links
    with download_link_area:
        clear_output()

    api_key = api_key_input.value
    color = color_picker.value
    use_color = use_color_ref.value
    uploads = image_uploader.value
    prompt = prompt_input.value

    if not api_key:
        with output_area:
            clear_output()
            print("Error: Please enter your Gemini API Key.")
        return

    try:
        client = genai.Client(api_key=api_key)
        model_name = "gemini-2.5-flash-image-preview"

        parts = []
        if prompt.strip():
            parts.append(prompt)

        # ipywidgets 8 exposes uploads as a sequence of dicts, while
        # ipywidgets 7 uses a dict keyed by file name; support both.
        upload_items = uploads.values() if isinstance(uploads, dict) else uploads
        for file_info in upload_items:
            parts.append(Image.open(io.BytesIO(file_info['content'])))

        if use_color:
            parts.append("Recolor or apply the color from the following color swatch.")
            color_swatch = Image.open(io.BytesIO(create_color_swatch_bytes(color)))
            parts.append(color_swatch)

        # Nothing to send: tell the user instead of issuing an empty request.
        if not parts:
            with output_area:
                clear_output()
                print("Error: Please enter a prompt, upload an image, or enable the color reference.")
            return

        response = client.models.generate_content(
            model=model_name,
            contents=parts
        )

        # A candidate may come back without content or parts, so guard each
        # level before iterating.
        generated_image_bytes = None
        candidates = response.candidates or []
        content = candidates[0].content if candidates else None
        response_parts = (content.parts if content is not None else None) or []
        for part in response_parts:
            if part.inline_data and part.inline_data.data:
                generated_image_bytes = part.inline_data.data
                break

        if generated_image_bytes:
            # Save the bytes for download
            app_state['last_image_bytes'] = generated_image_bytes

            with output_area:
                clear_output()
                display(Image.open(io.BytesIO(generated_image_bytes)))
        else:
            with output_area:
                clear_output()
                error_text = response.text if (hasattr(response, 'text') and response.text) else "No image data found."
                print(f"Generation failed. Response: {error_text}")

    except genai_errors.APIError as e:
        # Heuristic: treat 401/403 responses, or any error text mentioning the
        # API key, as a credentials problem; everything else is generic.
        is_auth_error = getattr(e, 'code', None) in (401, 403) or "api key" in str(e).lower()
        with output_area:
            clear_output()
            if is_auth_error:
                print("Authentication Error: Your API Key is invalid or has permissions issues.")
                print(f"\nDetails: {e}")
            else:
                print(f"An error occurred: {e}")
    except Exception as e:
        with output_area:
            clear_output()
            print(f"An error occurred: {e}")

def on_download_button_clicked(b):
    """Show a download link for the most recently generated image.

    The image bytes stored in ``app_state`` are embedded as a Base64 data
    URI in an HTML anchor rendered inside ``download_link_area``. If no
    image has been generated yet, a hint is printed instead. ``b`` is the
    clicked button and is not used.
    """
    with download_link_area:
        clear_output()

        payload = app_state.get('last_image_bytes')
        if not payload:
            print("You must generate an image before you can download it.")
            return

        encoded = base64.b64encode(payload).decode()

        # Timestamp in the name keeps successive downloads from colliding.
        filename = f"generated_image_{int(time.time())}.png"

        link_html = (
            f'<a href="data:image/png;base64,{encoded}" download="{filename}" '
            'target="_blank" style="color: #3498DB; font-size: 16px;">'
            f'Click Here to Download: {filename}</a>'
        )
        display(HTML(link_html))

# Register each click handler on its button (same order as the buttons are
# wired: generate, validate, download).
for _button, _handler in (
    (generate_button, on_generate_button_clicked),
    (validate_button, on_validate_button_clicked),
    (download_button, on_download_button_clicked),
):
    _button.on_click(_handler)

**Display the UI**

In [None]:
# Key field and its validate button sit side by side.
api_key_box = widgets.HBox(children=[api_key_input, validate_button])

# Everything else is stacked top to bottom in the order the user works
# through it: key, color, uploads, prompt, generate, result, download.
ui = widgets.VBox(children=[
    api_key_box,
    validation_output,
    color_picker,
    use_color_ref,
    image_uploader,
    prompt_input,
    generate_button,
    output_area,
    download_button,
    download_link_area,
])

display(ui)