In [2]:
!pip install git+https://github.com/huggingface/diffusers.git transformers accelerate xformers==0.0.16 datasets==2.21.0

Collecting git+https://github.com/huggingface/diffusers.git
  Cloning https://github.com/huggingface/diffusers.git to /tmp/pip-req-build-dockgl0t
  Running command git clone --filter=blob:none --quiet https://github.com/huggingface/diffusers.git /tmp/pip-req-build-dockgl0t
  Resolved https://github.com/huggingface/diffusers.git to commit 5e48f466b9c0d257f2650e8feec378a0022f2402
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting xformers==0.0.16
  Downloading xformers-0.0.16.tar.gz (7.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.3/7.3 MB[0m [31m32.5 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting datasets==2.21.0
  Downloading datasets-2.21.0-py3-none-any.whl.metadata (21 kB)
Collecting pyre-extensions==0.0.23 (from xformers==0.0.16)
  Downloading pyre_extensions-0.0.23-py3-none

In [50]:
import os

import gradio as gr
import torch
import torchvision.transforms as T
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
from PIL import Image
import numpy as np

# -------- CONFIG --------
# Hub repository and subfolder that hold the fine-tuned ControlNet weights.
CONTROLNET_REPO = "swetha3456/thermal-rgb-controlnet-v2"
SUBFOLDER = "checkpoint-5000-contrast-x2/controlnet"
# Base Stable Diffusion checkpoint the ControlNet is plugged into.
BASE_MODEL = "stable-diffusion-v1-5/stable-diffusion-v1-5"
# Side length, in pixels, of the square conditioning image.
RESOLUTION = 384
# Run on the GPU when one is visible; half precision is only selected on CUDA.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32
# ------------------------

# ControlNet
controlnet = ControlNetModel.from_pretrained(
    CONTROLNET_REPO,
    subfolder=SUBFOLDER,
    torch_dtype=DTYPE
)

# Pipeline
# NOTE(review): safety_checker=None disables output filtering. diffusers warns
# against exposing unfiltered results on a public endpoint; reconsider this if
# the demo is served through a public share link.
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    BASE_MODEL,
    controlnet=controlnet,
    torch_dtype=DTYPE,
    safety_checker=None
).to(DEVICE)

# xFormers attention is an optional, CUDA-only optimisation. Skip it on CPU and
# fall back to the default attention if the installed xformers build is missing
# or unusable, instead of aborting the whole cell.
if DEVICE == "cuda":
    try:
        pipe.enable_xformers_memory_efficient_attention()
    except Exception as exc:
        print(f"xFormers attention unavailable, using default attention: {exc}")

# Conditioning transform (per the original author's note, this mirrors the
# training-time preprocessing): resize the shorter side to RESOLUTION, take a
# centred RESOLUTION x RESOLUTION crop, and convert to a float tensor.
cond_transform = T.Compose([
    T.Resize(RESOLUTION),
    T.CenterCrop(RESOLUTION),
    T.ToTensor(),   # [0,1], no normalization
])

def thermal_to_rgb(thermal_img, prompt):
    """Generate an RGB image from a thermal image with the ControlNet pipeline.

    Args:
        thermal_img: Thermal input as a NumPy array or a PIL image, or None
            when nothing has been supplied.
        prompt: Text prompt guiding generation. A generic fallback prompt is
            used when it is empty or only whitespace.

    Returns:
        The first image produced by the pipeline, or None when ``thermal_img``
        is None.
    """
    # Nothing uploaded yet: leave the output empty.
    if thermal_img is None:
        return None

    # Accept either an array or a PIL image, then force three channels so the
    # conditioning tensor always has the same layout.
    source = (
        Image.fromarray(thermal_img)
        if isinstance(thermal_img, np.ndarray)
        else thermal_img
    )
    source = source.convert("RGB")

    # Add a batch dimension and move to the pipeline's device and precision.
    batch = cond_transform(source).unsqueeze(0)
    control = batch.to(device=DEVICE, dtype=DTYPE)

    # Substitute the default prompt when the user left the field blank.
    if not (prompt and prompt.strip()):
        prompt = "a realistic RGB photo"

    generated = pipe(
        prompt=prompt,
        image=control,
        num_inference_steps=20,
        guidance_scale=7.5,
        controlnet_conditioning_scale=1.0,
    )
    return generated.images[0]

# -------- GRADIO UI --------
# Optional sample image pre-loaded into the input widget. It is only used when
# the file is present, so the UI still builds in a working directory without it.
EXAMPLE_IMAGE = "FLIR_02280.jpeg"

with gr.Blocks() as demo:
    gr.Markdown("## Thermal → RGB Translation")

    with gr.Row():
        inp = gr.Image(
            label="Thermal Image",
            type="numpy",
            value=EXAMPLE_IMAGE if os.path.isfile(EXAMPLE_IMAGE) else None
        )
        out = gr.Image(label="Generated RGB Image")

    prompt_inp = gr.Textbox(
        label="Prompt",
        value="road scene with trees electric poles and cables"
    )

    # Clicking the button sends the image and prompt to thermal_to_rgb and
    # shows the returned image in the output widget.
    btn = gr.Button("Generate RGB")
    btn.click(
        fn=thermal_to_rgb,
        inputs=[inp, prompt_inp],
        outputs=out
    )

demo.launch()

Loading pipeline components...:   0%|          | 0/6 [00:00<?, ?it/s]

You have disabled the safety checker for <class 'diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline'> by passing `safety_checker=None`. Ensure that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered results in services or applications open to the public. Both the diffusers team and Hugging Face strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling it only for use-cases that involve analyzing network behavior or auditing its results. For more information, please have a look at https://github.com/huggingface/diffusers/pull/254 .


It looks like you are running Gradio on a hosted Jupyter notebook, which requires `share=True`. Automatically setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).

Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
* Running on public URL: https://a45ec3fd1f8353674a.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


