In [None]:
!pip install git+https://github.com/huggingface/diffusers.git transformers accelerate xformers==0.0.16 datasets==2.21.0

In [None]:
from pathlib import Path

import gradio as gr
import numpy as np
import torch
import torchvision.transforms as T
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
from PIL import Image

# -------- CONFIG --------
# Model identifiers handed to `from_pretrained` below: the ControlNet repo (weights
# live in SUBFOLDER inside it) and the Stable Diffusion v1.5 base model.
CONTROLNET_REPO = "swetha3456/thermal-rgb-controlnet-v2"
SUBFOLDER = "sketch2webpage_synthetic_checkpoint-4000/controlnet"
BASE_MODEL = "stable-diffusion-v1-5/stable-diffusion-v1-5"

# Side length, in pixels, of the square conditioning image (see cond_transform).
RESOLUTION = 384

# Half precision when a GPU is present; plain float32 on CPU.
if torch.cuda.is_available():
    DEVICE, DTYPE = "cuda", torch.float16
else:
    DEVICE, DTYPE = "cpu", torch.float32
# ------------------------

# ControlNet weights, loaded from the checkpoint subfolder in the working dtype.
controlnet = ControlNetModel.from_pretrained(
    CONTROLNET_REPO, subfolder=SUBFOLDER, torch_dtype=DTYPE
)

# Base Stable Diffusion pipeline with the ControlNet attached. No safety checker
# is loaded. The finished pipeline is then moved onto DEVICE.
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    BASE_MODEL,
    controlnet=controlnet,
    safety_checker=None,
    torch_dtype=DTYPE,
)
pipe = pipe.to(DEVICE)

# xformers memory-efficient attention is an optional GPU-only speed-up. diffusers
# raises when CUDA is unavailable and when xformers is missing or cannot run, so
# only try it on CUDA and fall back to the pipeline's default attention otherwise
# (DEVICE may legitimately be "cpu", see CONFIG).
if DEVICE == "cuda":
    try:
        pipe.enable_xformers_memory_efficient_attention()
    except (ImportError, RuntimeError, ValueError) as err:
        print(f"xformers attention unavailable, using default attention: {err}")

# Conditioning-image preprocessing. Per the original author this must match the
# training-time transform EXACTLY, so keep the steps and their order unchanged.
cond_transform = T.Compose(
    transforms=[
        T.Resize(RESOLUTION),      # shorter side -> RESOLUTION, aspect ratio kept
        T.CenterCrop(RESOLUTION),  # central RESOLUTION x RESOLUTION square
        T.ToTensor(),              # PIL -> float CHW tensor in [0, 1]; no mean/std normalization
    ]
)

def thermal_to_rgb(
    thermal_img,
    prompt="",
    *,
    num_inference_steps=20,
    guidance_scale=7.5,
    controlnet_conditioning_scale=1.0,
    seed=None,
):
    """Run the ControlNet pipeline on a conditioning image and return one image.

    The name refers to thermal -> RGB. The UI in this notebook uses the function
    for sketch -> webpage generation. The name is kept so existing callers keep
    working.

    Args:
        thermal_img: Conditioning image as a PIL image or a NumPy array (the
            Gradio input uses ``type="numpy"``), or None.
        prompt: Text prompt for the pipeline. The default is an empty prompt.
            None is treated the same as "".
        num_inference_steps: Number of denoising steps forwarded to the pipeline.
        guidance_scale: Classifier-free guidance scale forwarded to the pipeline.
        controlnet_conditioning_scale: Weight of the ControlNet residuals.
        seed: Optional integer seed for reproducible sampling. None leaves
            sampling unseeded.

    Returns:
        The first image produced by the pipeline, or None when no input image
        was supplied.
    """
    if thermal_img is None:
        return None

    if isinstance(thermal_img, np.ndarray):
        thermal_img = Image.fromarray(thermal_img)

    # Force 3 channels so grayscale / RGBA uploads become RGB.
    thermal_img = thermal_img.convert("RGB")

    # Training-time preprocessing, then add a batch dimension for the pipeline.
    control = cond_transform(thermal_img).unsqueeze(0).to(device=DEVICE, dtype=DTYPE)

    generator = None
    if seed is not None:
        generator = torch.Generator(device=DEVICE).manual_seed(int(seed))

    result = pipe(
        # The pipeline needs a prompt (or prompt embeddings), so map None to "".
        prompt=prompt or "",
        image=control,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        controlnet_conditioning_scale=controlnet_conditioning_scale,
        generator=generator,
    )
    return result.images[0]

# -------- GRADIO UI --------
# Sketch pre-loaded into the input on start-up. It is optional: if the file is
# not found (e.g. the notebook runs from another working directory), the input
# simply starts empty rather than receiving a path that does not exist.
EXAMPLE_SKETCH = Path("../dataset_sketch2webpage/10018_0.jpeg")

with gr.Blocks() as demo:
    gr.Markdown("## Sketch → Webpage Generation")

    with gr.Row():
        inp = gr.Image(
            label="Sketch",
            type="numpy",  # thermal_to_rgb receives a NumPy array
            value=str(EXAMPLE_SKETCH) if EXAMPLE_SKETCH.is_file() else None,
        )
        out = gr.Image(label="Generated Webpage from Sketch")

    btn = gr.Button("Generate Webpage")
    # There is no prompt textbox in this UI, so only the sketch is passed and
    # thermal_to_rgb falls back to its default empty prompt.
    btn.click(
        fn=thermal_to_rgb,
        inputs=[inp],
        outputs=out,
    )

demo.launch()