In [None]:
# --- Cell 1:importing necessary libraries ---
!pip install -q diffusers transformers accelerate bitsandbytes sentencepiece ftfy gradio torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu121


import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os
from tqdm.auto import tqdm
import json
import random
import numpy as np
from google.colab import drive
from accelerate import Accelerator
from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler
from diffusers.optimization import get_scheduler
from transformers import CLIPTextModel, CLIPTokenizer, BlipProcessor, BlipForConditionalGeneration, GPT2Config, GPT2LMHeadModel, AutoTokenizer
from peft import LoraConfig
import bitsandbytes as bnb
import torch.nn as nn

def set_seed(seed=42):
    """Seed every RNG this notebook touches and make cuDNN deterministic.

    Seeds Python's ``random``, NumPy's global generator, and PyTorch on the CPU and on
    every CUDA device.

    Args:
        seed: integer seed applied to each generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Trade cuDNN autotuning speed for run-to-run reproducibility.
    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False

set_seed(42)
print("Libraries installed and imported successfully.")

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m1.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m188.7/188.7 MB[0m [31m5.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m72.9/72.9 MB[0m [31m10.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.8/44.8 kB[0m [31m2.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m4.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m35.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m53.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m44.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [None]:
# --- Cell 2: mount Google Drive (Colab only) so the dataset and outputs under /content/drive are reachable ---
from google.colab import drive
drive.mount(mountpoint='/content/drive')

Mounted at /content/drive


In [None]:
# --- Cell 3: Model Parameters and Paths ---
# Every tunable setting lives here; later cells only read these names.

# --- Paths on the mounted Google Drive ---
DRIVE_BASE_PATH = "/content/drive/My Drive/"
DATASET_FOLDER_NAME = "painting"
OUTPUT_FOLDER_NAME = "colab_outputs/painting_style_model"
DATASET_PATH = os.path.join(DRIVE_BASE_PATH, "datasets", DATASET_FOLDER_NAME)
OUTPUT_DIR = os.path.join(DRIVE_BASE_PATH, OUTPUT_FOLDER_NAME)

# Separate checkpoint folders per model so interrupted sessions can resume each one.
LORA_CHECKPOINT_DIR = os.path.join(OUTPUT_DIR, "checkpoints_lora")
CAPTIONER_CHECKPOINT_DIR = os.path.join(OUTPUT_DIR, "checkpoints_captioner")
# JSONL caption metadata read by the training cells.
METADATA_FILE = os.path.join(OUTPUT_DIR, "metadata.jsonl")

# --- Style ---
STYLE_NAME = "painting"
STYLE_TRIGGER = "style_painting"  # added to every prompt at generation time

# --- Hugging Face Hub model ids ---
MODEL_NAME = "runwayml/stable-diffusion-v1-5"
CAPTIONING_MODEL_NAME = "Salesforce/blip-image-captioning-large"

# --- Training hyperparameters (chosen for a Colab T4 GPU) ---
IMAGE_RESOLUTION = 512
BATCH_SIZE = 2
GRADIENT_ACCUMULATION_STEPS = 4  # effective batch = BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS
LEARNING_RATE_LORA = 2e-4
NUM_TRAIN_EPOCHS_LORA = 10
LORA_RANK = 16
MIXED_PRECISION = "fp16"

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Create the output folders up front (no-op when they already exist).
os.makedirs(OUTPUT_DIR, exist_ok=True)
os.makedirs(LORA_CHECKPOINT_DIR, exist_ok=True)
os.makedirs(CAPTIONER_CHECKPOINT_DIR, exist_ok=True)

print("\n".join([
    "--- Configuration ---",
    f"Dataset path: {DATASET_PATH}",
    f"Output path: {OUTPUT_DIR}",
    f"LoRA Checkpoint path: {LORA_CHECKPOINT_DIR}",
    "---------------------",
]))

--- Configuration ---
Dataset path: /content/drive/My Drive/datasets/painting
Output path: /content/drive/My Drive/colab_outputs/painting_style_model
LoRA Checkpoint path: /content/drive/My Drive/colab_outputs/painting_style_model/checkpoints_lora
---------------------


In [None]:
# --- Cell 4: Custom Dataset for LoRA Training ---

class DreamBoothDataset(Dataset):
    """Image/caption pairs for LoRA fine-tuning.

    Each metadata entry must provide ``"file_name"`` (an image path opened with PIL) and
    ``"text"`` (its caption). Images are resized, center-cropped to ``size`` and normalized
    with mean/std 0.5. Captions are padded/truncated to the tokenizer's ``model_max_length``.
    """

    def __init__(self, metadata, tokenizer, size=512):
        self.metadata = metadata
        self.tokenizer = tokenizer
        self.size = size
        image_steps = [
            transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(size),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ]
        self.transforms = transforms.Compose(image_steps)

    def __len__(self):
        return len(self.metadata)

    def __getitem__(self, index):
        record = self.metadata[index]
        rgb_image = Image.open(record["file_name"]).convert("RGB")
        pixel_values = self.transforms(rgb_image)

        encoding = self.tokenizer(
            record["text"],
            padding="max_length",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        )
        # The tokenizer returns a (1, max_length) batch; drop the batch axis for the DataLoader.
        token_ids = encoding.input_ids.squeeze(0)

        return {"pixel_values": pixel_values, "input_ids": token_ids}

# Initialize tokenizer
tokenizer = CLIPTokenizer.from_pretrained(MODEL_NAME, subfolder="tokenizer")

# Load the caption metadata in this cell so it does not depend on a later cell having run
# (Restart & Run All). Each non-empty line of the JSONL file is one record with the
# "file_name" and "text" keys that DreamBoothDataset reads.
if not os.path.exists(METADATA_FILE):
    raise FileNotFoundError(
        f"Caption metadata not found at {METADATA_FILE}; create metadata.jsonl before training."
    )
with open(METADATA_FILE, "r", encoding="utf-8") as f:
    metadata = [json.loads(line) for line in f if line.strip()]

# Create the dataset and dataloader
train_dataset_lora = DreamBoothDataset(metadata, tokenizer, size=IMAGE_RESOLUTION)
train_dataloader_lora = DataLoader(train_dataset_lora, batch_size=BATCH_SIZE, shuffle=True)

print(f"LoRA training dataset created with {len(train_dataset_lora)} samples.")

In [None]:
# --- Cell 5: LoRA Training Loop ---
# Trains a LoRA adapter on the SD UNet and writes a checkpoint to Drive after every epoch.
# A rerun resumes from the newest *complete* checkpoint. The final adapter goes to OUTPUT_DIR.

# Initialize Accelerator
accelerator = Accelerator(
    gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
    mixed_precision=MIXED_PRECISION,
    log_with="tensorboard",
    project_dir=os.path.join(OUTPUT_DIR, "logs")
)

# Load models in standard float32 precision
text_encoder = CLIPTextModel.from_pretrained(MODEL_NAME, subfolder="text_encoder")
vae = AutoencoderKL.from_pretrained(MODEL_NAME, subfolder="vae")
unet = UNet2DConditionModel.from_pretrained(MODEL_NAME, subfolder="unet")
tokenizer = CLIPTokenizer.from_pretrained(MODEL_NAME, subfolder="tokenizer")
noise_scheduler = PNDMScheduler.from_pretrained(MODEL_NAME, subfolder="scheduler")

# Freeze all pretrained weights explicitly; only the LoRA parameters added below are trained.
unet.requires_grad_(False)
vae.requires_grad_(False)
text_encoder.requires_grad_(False)

# Add LoRA adapter
lora_config = LoraConfig(r=LORA_RANK, lora_alpha=LORA_RANK, target_modules=["to_q", "to_k", "to_v", "to_out.0", "proj_in", "proj_out"])
unet.add_adapter(lora_config)

# Memory savings
unet.enable_gradient_checkpointing()
unet.set_attention_slice("auto")
print("Using memory-efficient attention slicing and gradient checkpointing.")


# Create dataset and dataloader (metadata.jsonl: one JSON record per non-empty line)
with open(METADATA_FILE, "r", encoding="utf-8") as f:
    metadata = [json.loads(line) for line in f if line.strip()]
train_dataset = DreamBoothDataset(metadata, tokenizer, size=IMAGE_RESOLUTION)
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

# Optimizer and Scheduler.
# Inside accelerator.accumulate() the prepared LR scheduler only advances when the optimizer
# really updates: once every GRADIENT_ACCUMULATION_STEPS batches, and num_processes times per
# update. Sizing the cosine schedule in batches would traverse only about
# 1/GRADIENT_ACCUMULATION_STEPS of the curve, so size it in optimizer updates.
lora_params = [p for p in unet.parameters() if p.requires_grad]
optimizer = bnb.optim.AdamW8bit(lora_params, lr=LEARNING_RATE_LORA)
batches_per_process = -(-len(train_dataloader) // accelerator.num_processes)  # ceil division
updates_per_epoch = -(-batches_per_process // GRADIENT_ACCUMULATION_STEPS)  # ceil division
lr_scheduler = get_scheduler(
    "cosine",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=updates_per_epoch * NUM_TRAIN_EPOCHS_LORA * accelerator.num_processes,
)

# Prepare with accelerator
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
vae.to(accelerator.device, dtype=torch.float16)
text_encoder.to(accelerator.device, dtype=torch.float16)


def _latest_complete_lora_checkpoint(checkpoint_root):
    """Return ``(epoch, path)`` of the newest ``epoch_<N>`` checkpoint that finished saving.

    The LoRA weights file is written after ``accelerator.save_state`` in the save step below.
    A directory without that file is therefore an interrupted save and is skipped in favour
    of an older one. Returns ``(0, None)`` when no complete checkpoint exists.
    """
    epochs = []
    for entry in os.listdir(checkpoint_root):
        prefix, _, suffix = entry.partition("_")
        if prefix == "epoch" and suffix.isdecimal():
            epochs.append(int(suffix))
    for epoch_num in sorted(epochs, reverse=True):
        candidate = os.path.join(checkpoint_root, f"epoch_{epoch_num}")
        if os.path.exists(os.path.join(candidate, "pytorch_lora_weights.safetensors")):
            return epoch_num, candidate
        print(f"Skipping incomplete checkpoint at epoch {epoch_num} (no LoRA weights file).")
    return 0, None


# Resume from the newest complete checkpoint. Never skip training because of a broken
# checkpoint: that would save an untrained adapter as the "final" weights below.
if os.path.isdir(LORA_CHECKPOINT_DIR):
    first_epoch, resume_dir = _latest_complete_lora_checkpoint(LORA_CHECKPOINT_DIR)
else:
    first_epoch, resume_dir = 0, None

if resume_dir is not None:
    print("Resuming from checkpoint...")
    # load_state restores the objects passed to accelerator.prepare above (the UNet including
    # its LoRA weights, the optimizer and the LR scheduler), so no separate LoRA reload is
    # needed. Reloading the LoRA file on top would inject a second adapter.
    accelerator.load_state(resume_dir)
    print(f"Resumed successfully from epoch {first_epoch}")
else:
    print("Starting new training run.")


print("Starting LoRA training...")
for epoch in range(first_epoch, NUM_TRAIN_EPOCHS_LORA):
    unet.train()
    # Checkpoints are only written at epoch boundaries, so every epoch (including the first
    # one after a resume) starts from its first batch.
    progress_bar = tqdm(total=len(train_dataloader), desc=f"Epoch {epoch + 1}")

    for batch in train_dataloader:
        with accelerator.accumulate(unet):
            with torch.no_grad():
                latents = vae.encode(batch["pixel_values"].to(dtype=torch.float16)).latent_dist.sample()
                latents = latents * vae.config.scaling_factor
                encoder_hidden_states = text_encoder(batch["input_ids"])[0]

            # Noise-prediction objective: the UNet predicts the noise added at a random timestep.
            noise = torch.randn_like(latents)
            timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (latents.shape[0],), device=latents.device).long()
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
            model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
            loss = F.mse_loss(model_pred.float(), noise.float(), reduction="mean")
            accelerator.backward(loss)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

        progress_bar.update(1)
        progress_bar.set_postfix(loss=loss.detach().item())
    progress_bar.close()

    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        epoch_checkpoint_dir = os.path.join(LORA_CHECKPOINT_DIR, f"epoch_{epoch+1}")
        os.makedirs(epoch_checkpoint_dir, exist_ok=True)

        accelerator.save_state(epoch_checkpoint_dir)

        # Written last on purpose: its presence marks this checkpoint as complete for resuming.
        unwrapped_unet = accelerator.unwrap_model(unet)
        unwrapped_unet.save_attn_procs(epoch_checkpoint_dir, safe_serialization=True, weight_name="pytorch_lora_weights.safetensors")

        print(f"Checkpoint and LoRA weights saved for epoch {epoch + 1} at {epoch_checkpoint_dir}")


# --- SAVING THE FINAL WEIGHTS ---
print("Saving final LoRA weights...")
unwrapped_unet = accelerator.unwrap_model(unet)
unwrapped_unet.save_attn_procs(OUTPUT_DIR, safe_serialization=True, weight_name="pytorch_lora_weights.safetensors")
print(f"Final LoRA weights saved successfully to: {OUTPUT_DIR}")
print("LoRA training finished.")

Using memory-efficient attention slicing and gradient checkpointing.
Resuming from checkpoint...
Corrupted checkpoint found at epoch 5. Please delete it and restart.
Starting LoRA training...
Saving final LoRA weights...
Final LoRA weights saved successfully to: /content/drive/My Drive/colab_outputs/painting_style_model
LoRA training finished.


In [None]:
# --- Cell 6: Captioner Model and Dataset ---

# 1. Define the Captioner Model (a simple GPT-2-like Transformer Decoder)
class LatentCaptioner(nn.Module):
    """Small GPT-2 decoder that predicts a caption given a latent "seed" tensor.

    The latent is flattened, linearly projected to ``embed_dim`` and prepended to the token
    embeddings as a single prefix position. The causal GPT-2 stack attends to it from every
    token.

    NOTE(review): ``CaptionerDataset`` draws fresh random noise for every item, independent
    of the image. The prefix therefore carries no caption information, and this model
    effectively learns an unconditional caption language model.
    """

    def __init__(self, latent_dim, vocab_size, embed_dim=768, num_heads=8, num_layers=4):
        super().__init__()
        self.latent_dim = latent_dim
        # Stored for reference only; the GPT-2 vocabulary size comes from the tokenizer below.
        self.vocab_size = vocab_size

        # Project the flat latent noise into the model's embedding dimension
        self.latent_projection = nn.Linear(latent_dim, embed_dim)

        # GPT-2 has no pad token; reuse EOS so fixed-length padding works.
        self.tokenizer = AutoTokenizer.from_pretrained("gpt2")
        self.tokenizer.pad_token = self.tokenizer.eos_token

        config = GPT2Config(
            vocab_size=self.tokenizer.vocab_size,
            n_embd=embed_dim,
            n_head=num_heads,
            n_layer=num_layers,
            n_positions=1024,  # max sequence length (prefix + tokens)
        )
        # The latent enters as an input-embedding prefix, not via cross-attention. Setting
        # config.add_cross_attention after construction would create no cross-attention
        # layers, so it is not set.
        self.transformer = GPT2LMHeadModel(config)

    def forward(self, input_ids, latent_noise, attention_mask=None, labels=None):
        """Run the decoder on ``[latent prefix] + token embeddings``.

        Args:
            input_ids: token ids, shape (batch, seq_len).
            latent_noise: conditioning tensor; each sample's elements are flattened to
                ``latent_dim`` values.
            attention_mask: optional 0/1 mask over ``input_ids``, (batch, seq_len). The
                prefix position is always attended.
            labels: optional target ids, (batch, seq_len). Positions equal to -100 are
                ignored by the loss.

        Returns:
            dict with ``"loss"`` (None when ``labels`` is None) and ``"logits"`` of shape
            (batch, seq_len + 1, vocab).
        """
        batch_size = latent_noise.shape[0]
        # reshape (unlike view) also accepts non-contiguous latents.
        flat_latent = latent_noise.reshape(batch_size, -1)
        latent_embedding = self.latent_projection(flat_latent).unsqueeze(1)  # (batch, 1, embed_dim)
        token_embeddings = self.transformer.transformer.wte(input_ids)  # (batch, seq_len, embed_dim)
        embeddings = torch.cat([latent_embedding, token_embeddings], dim=1)

        combined_attention_mask = None
        if attention_mask is not None:
            # Same dtype/device as the caller's mask, so the concatenation needs no promotion.
            prefix_mask = attention_mask.new_ones(attention_mask.shape[0], 1)
            combined_attention_mask = torch.cat([prefix_mask, attention_mask], dim=1)

        outputs = self.transformer(
            inputs_embeds=embeddings,
            attention_mask=combined_attention_mask,
            labels=None,  # loss is computed below because of the extra prefix position
        )

        loss = None
        if labels is not None:
            # Output position i (prefix at 0, token i-1 after it) predicts token i, so the
            # first seq_len logits line up with the *unshifted* labels.
            logits = outputs.logits[:, :-1, :].contiguous()
            loss_fct = nn.CrossEntropyLoss()  # ignore_index defaults to -100
            loss = loss_fct(logits.view(-1, logits.size(-1)), labels.reshape(-1))

        return {"loss": loss, "logits": outputs.logits}

# 2. Define the Dataset for the Captioner
class CaptionerDataset(Dataset):
    """Pairs each caption with a freshly drawn Gaussian "seed" latent for the captioner.

    The noise is resampled on every access and does not depend on the image. Captions get an
    explicit EOS token appended, then are padded with EOS to 77 tokens. Padding positions are
    labelled -100 so a CrossEntropyLoss with its default ``ignore_index`` skips them.
    """

    def __init__(self, metadata, tokenizer, latent_shape=(4, 64, 64)):
        self.metadata = metadata
        self.tokenizer = tokenizer
        self.latent_shape = latent_shape
        # GPT-2 style tokenizers have no pad token; pad with EOS.
        self.tokenizer.pad_token = self.tokenizer.eos_token

    def __len__(self):
        return len(self.metadata)

    def __getitem__(self, index):
        # This is the "seed"
        latent_noise = torch.randn(self.latent_shape)

        # Append EOS so the model learns where a caption ends. That EOS has attention_mask 1,
        # unlike the EOS padding that follows it.
        caption = self.metadata[index]["text"] + self.tokenizer.eos_token
        tokenized = self.tokenizer(
            caption,
            padding="max_length",
            truncation=True,
            max_length=77,  # Standard SD length
            return_tensors="pt"
        )
        input_ids = tokenized.input_ids.squeeze(0)
        attention_mask = tokenized.attention_mask.squeeze(0)

        # Language-modeling targets are the inputs themselves, except padding (mask == 0).
        # Padding becomes -100 so it does not dominate the loss.
        labels = input_ids.masked_fill(attention_mask == 0, -100)

        return {
            "latent_noise": latent_noise,
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels
        }

# Initialize model and dataset.
# Latents are 4 channels at 1/8 of IMAGE_RESOLUTION. The dataset's noise shape and the
# projection layer's input size are both derived from this one tuple, so they cannot drift
# apart if IMAGE_RESOLUTION changes.
latent_shape = (4, IMAGE_RESOLUTION // 8, IMAGE_RESOLUTION // 8)
latent_dim = latent_shape[0] * latent_shape[1] * latent_shape[2]
captioner = LatentCaptioner(latent_dim=latent_dim, vocab_size=50257).to(DEVICE)
captioner_dataset = CaptionerDataset(metadata, captioner.tokenizer, latent_shape=latent_shape)
captioner_dataloader = DataLoader(captioner_dataset, batch_size=BATCH_SIZE, shuffle=True)

print("Captioner model and dataset are ready.")
print(f"Latent dimension: {latent_dim}")

In [None]:
# --- Cell 7: Captioner Model training loop ---
NUM_TRAIN_EPOCHS_CAPTIONER = 30
LEARNING_RATE_CAPTIONER = 3e-5
CAPTIONER_MODEL_PATH = os.path.join(OUTPUT_DIR, "captioner_model.pth")
# Sidecar recording how many captioner epochs the saved accelerator state covers. A resumed
# run continues from the next epoch instead of starting over at epoch 1.
CAPTIONER_PROGRESS_FILE = os.path.join(CAPTIONER_CHECKPOINT_DIR, "captioner_progress.json")

# Use a standard AdamW optimizer for the captioner
optimizer_captioner = torch.optim.AdamW(captioner.parameters(), lr=LEARNING_RATE_CAPTIONER)

# Prepare everything with the same accelerator.
# NOTE: this accelerator also holds the UNet/optimizer/scheduler prepared in the LoRA cell,
# so save_state/load_state below include those objects too.
captioner, optimizer_captioner, captioner_dataloader = accelerator.prepare(
    captioner, optimizer_captioner, captioner_dataloader
)

# --- Checkpointing for the Captioner ---
first_epoch_captioner = 0
if os.path.isdir(CAPTIONER_CHECKPOINT_DIR) and os.listdir(CAPTIONER_CHECKPOINT_DIR):
    print("Resuming captioner training from checkpoint...")
    accelerator.load_state(CAPTIONER_CHECKPOINT_DIR)
    # Checkpoints written before the progress file existed resume at epoch 1, as before.
    if os.path.exists(CAPTIONER_PROGRESS_FILE):
        with open(CAPTIONER_PROGRESS_FILE, "r", encoding="utf-8") as f:
            first_epoch_captioner = int(json.load(f).get("completed_epochs", 0))
        print(f"Captioner state covers {first_epoch_captioner} completed epoch(s).")
else:
    print("Starting new captioner training run.")


print("Starting Captioner training...")
for epoch in range(first_epoch_captioner, NUM_TRAIN_EPOCHS_CAPTIONER):
    captioner.train()
    progress_bar = tqdm(total=len(captioner_dataloader), desc=f"Captioner Epoch {epoch + 1}")

    for batch in captioner_dataloader:
        with accelerator.accumulate(captioner):
            # The batch is already on the correct device thanks to accelerator
            outputs = captioner(
                input_ids=batch["input_ids"],
                latent_noise=batch["latent_noise"],
                attention_mask=batch["attention_mask"],
                labels=batch["labels"]
            )

            loss = outputs["loss"]

            accelerator.backward(loss)

            optimizer_captioner.step()
            optimizer_captioner.zero_grad()

        progress_bar.update(1)
        progress_bar.set_postfix(loss=loss.detach().item())
    progress_bar.close()

    # --- Save a checkpoint after each epoch ---
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        accelerator.save_state(CAPTIONER_CHECKPOINT_DIR)
        # Written after save_state so it never claims more epochs than the saved state holds.
        with open(CAPTIONER_PROGRESS_FILE, "w", encoding="utf-8") as f:
            json.dump({"completed_epochs": epoch + 1}, f)
        print(f"Captioner checkpoint saved for epoch {epoch + 1}")

# --- SAVING THE FINAL CAPTIONER MODEL ---
print("Saving final Captioner model...")
unwrapped_captioner = accelerator.unwrap_model(captioner)
accelerator.save(unwrapped_captioner.state_dict(), CAPTIONER_MODEL_PATH)

print(f"Captioner model saved successfully to: {CAPTIONER_MODEL_PATH}")
print("Captioner training finished.")


In [None]:
# --- Cell 8: Inference — LoRA-styled image generation + BLIP captioning ---

from diffusers import StableDiffusionPipeline
import torch

# --- 1. Load the Base Pipeline and Trained LoRA Weights ---
# Fail fast with a clear message if the LoRA training cell has not written its final weights,
# before spending time loading the full pipeline.
lora_weights_path = os.path.join(OUTPUT_DIR, "pytorch_lora_weights.safetensors")
if not os.path.isfile(lora_weights_path):
    raise FileNotFoundError(
        f"Trained LoRA weights not found at {lora_weights_path}. Run the LoRA training cell first."
    )
pipe = StableDiffusionPipeline.from_pretrained(MODEL_NAME, torch_dtype=torch.float16).to(accelerator.device)
pipe.load_lora_weights(OUTPUT_DIR, weight_name="pytorch_lora_weights.safetensors")
print("Image generation pipeline with LoRA weights is ready.")

# --- 2. Load the Pre-trained BLIP Model for Captioning ---
# Captions are produced by the pretrained BLIP model from the *generated* image, so they
# describe what was actually drawn. The LatentCaptioner trained above is not used here.
blip_processor = BlipProcessor.from_pretrained(CAPTIONING_MODEL_NAME)
blip_model = BlipForConditionalGeneration.from_pretrained(CAPTIONING_MODEL_NAME, torch_dtype=torch.float16).to(accelerator.device)
blip_model.eval()
print("Accurate captioning model (BLIP) is ready.")


# --- 3. Define the Generation Function ---
def generate_multimodal_output(prompt, seed):
    """Render a styled image for ``prompt`` and caption that image with BLIP.

    Args:
        prompt: free-text subject. The style trigger phrase is appended automatically.
        seed: anything ``int()`` accepts. It seeds the diffusion generator, so the same
            (prompt, seed) pair reproduces the same image.

    Returns:
        dict with ``"image"`` (first entry of the pipeline's ``.images``) and ``"caption"``
        (the decoded BLIP text).
    """
    device = accelerator.device
    seeded_generator = torch.Generator(device=device).manual_seed(int(seed))
    styled_prompt = f"{prompt}, in the style of {STYLE_TRIGGER}"

    pipeline_result = pipe(
        prompt=styled_prompt,
        generator=seeded_generator,
        num_inference_steps=30,
        guidance_scale=7.5,
    )
    image = pipeline_result.images[0]

    # Caption the generated picture itself (not the prompt) so the text matches the output.
    with torch.no_grad():
        blip_inputs = blip_processor(images=image, return_tensors="pt").to(device, dtype=torch.float16)
        caption_token_ids = blip_model.generate(**blip_inputs, max_new_tokens=77)
        caption = blip_processor.decode(caption_token_ids[0], skip_special_tokens=True)

    return {"image": image, "caption": caption}

# --- Test the function with a seed and a prompt ---
# Smoke test of the full prompt -> image -> caption path with a fixed seed.
result = generate_multimodal_output("A majestic lion", 42)
print("\n--- Test Generation ---")
print(f"Generated Caption: {result['caption']}")
# Bare final expression: Jupyter displays the returned image as this cell's output.
result['image']

In [None]:
# --- Cell 9: Creating a gradio web app demo for GUI ---

import gradio as gr
import random

def gradio_interface(prompt, seed):
    """Click handler for the Generate button.

    Returns ``(image, caption)`` in the same order as the ``outputs=[...]`` components the
    button is wired to.
    """
    # gr.Number may hand the seed over as a float; normalize it to int before generating.
    generated = generate_multimodal_output(prompt, int(seed))
    return generated['image'], generated['caption']

# --- Function for the new "Random Seed" button ---
def get_random_seed():
    """Return a uniformly random integer seed in ``[0, 2**32 - 1]`` from Python's ``random``."""
    upper_exclusive = 2 ** 32
    return random.randrange(upper_exclusive)


# --- Create the Gradio interface ---
# Custom CSS passed to gr.Blocks: an animated dark gradient page background, a translucent
# blurred container, and rounded/shadowed buttons, inputs and images.
css = """
body {
    background: linear-gradient(120deg, #1f2937, #374151, #4b5563);
    background-size: 200% 200%;
    animation: gradient 15s ease infinite;
}
@keyframes gradient {
    0% {background-position: 0% 50%;}
    50% {background-position: 100% 50%;}
    100% {background-position: 0% 50%;}
}
.gradio-container {
    border-radius: 20px !important;
    background-color: rgba(255, 255, 255, 0.05) !important;
    backdrop-filter: blur(10px);
}
.gr-button {
    background: linear-gradient(90deg, #4f46e5, #a855f7);
    color: white;
    border: none;
    border-radius: 8px;
    box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
    transition: all 0.2s ease-in-out;
}
.gr-button:hover {
    transform: translateY(-2px);
    box-shadow: 0 6px 12px rgba(0, 0, 0, 0.2);
}
.gr-input {
    border-radius: 8px !important;
}
.gr-image {
    border-radius: 12px !important;
    box-shadow: 0 8px 25px rgba(0,0,0,0.2);
}
"""

# The layout follows the order in which components are created inside the nested Row/Column
# contexts: the left column (scale=1) holds the inputs, and the right column (scale=2, twice
# the relative width) holds the outputs.
with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
    gr.Markdown(f"<h1 style='text-align: center; color: white;'>🎨 Unified Multimodal {STYLE_NAME.title()} Generator 🎨</h1>")
    gr.Markdown("<p style='text-align: center; color: #d1d5db;'>Enter a prompt and a seed to generate a unique, stylized image and an accurate caption describing it.</p>")

    with gr.Row(elem_id="main-row"):
        with gr.Column(scale=1):
            prompt_input = gr.Textbox(label="Prompt", value="A majestic lion", lines=3, placeholder="Describe the image you want to create...")

            # Seed box with a dice button beside it.
            with gr.Row():
                # precision=0 requests integer seed values from the Number component.
                seed_input = gr.Number(label="Seed", value=42, precision=0)
                random_seed_btn = gr.Button("🎲")

            generate_btn = gr.Button("✨ Generate ✨", variant="primary")

        with gr.Column(scale=2):
            output_image = gr.Image(label="Generated Image")
            output_caption = gr.Textbox(label="Generated Caption", interactive=False)

    # --- Define the behavior of the buttons ---
    # Generate: (prompt, seed) -> (image, caption) through gradio_interface.
    generate_btn.click(
        fn=gradio_interface,
        inputs=[prompt_input, seed_input],
        outputs=[output_image, output_caption]
    )

    # Dice: writes a fresh random seed into the Seed box; it does not trigger generation.
    random_seed_btn.click(
        fn=get_random_seed,
        inputs=[],
        outputs=seed_input
    )

# Launch the demo
# Starts the Gradio app defined above.
demo.launch()