## Meerkat Tutorial: Build Interfaces for Diffusion Models

In this demo, we will build an interface (from scratch) for generating chest X-rays with diffusion models.

We will use [RoentGen](https://arxiv.org/abs/2211.12737), a Stable Diffusion-based model fine-tuned to generate chest X-rays. Model weights can be requested [here](https://t.co/uYEY1cO3SU).

This demo also uses the Stanford CheXpert dataset, which can be downloaded [here](https://stanfordmlgroup.github.io/competitions/chexpert/).



In [1]:
from tqdm.auto import tqdm
from functools import partialmethod
# Disable progress bars: rebind __init__ on the tqdm class imported from
# tqdm.auto so that instances of that class are constructed with disable=True.
tqdm.__init__ = partialmethod(tqdm.__init__, disable=True)

from typing import List, Union
import os

from datetime import datetime
import warnings

import torch
import torchvision.transforms as tv_tfms
import PIL
from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline

import uuid

import meerkat as mk

In [2]:
# Launch the Meerkat API server and frontend.
# When working on a remote machine, forward both ports below to your local machine.
mk.gui.start(
    api_port=8886,
    frontend_port=8887,
    dev=False,
    skip_build=True,
)

(APIInfo(api=<fastapi.applications.FastAPI object at 0x7f152fcfd760>, port=8886, server=<meerkat.interactive.server.Server object at 0x7f16200b2c10>, name='127.0.0.1', shared=False, process=None, _url=None),
 FrontendInfo(package_manager='npm', port=8887, name='localhost', shared=False, process=<Popen: returncode: None args: ['python', '-m', 'http.server', '8887']>, _url=None))

In [3]:
# The path to the model weights.
model_path = os.path.expanduser("~/models/roentgen")

# The directory where images generated by this demo are cached.
cache_path = os.path.expanduser("~/.cache/images/roentgen")

# The path to the CheXpert dataset.
chexpert_path = os.path.expanduser("~/datasets/CheXpert-v1.0-small")

# We recommend using a cuda device. Fall back to the CPU when CUDA is not
# available so the notebook still runs; set this by hand to use another
# backend (e.g. "mps").
device = "cuda" if torch.cuda.is_available() else "cpu"

In [4]:
# Module-level handles to the lazily loaded pipelines. Both start as None so
# that the "already loaded" check in _load_roentgen never reads an unbound name.
roentgen_text2img: StableDiffusionPipeline = None
roentgen_img2img: StableDiffusionImg2ImgPipeline = None


def _load_roentgen():
    """Load the RoentGen pipelines if they haven't been loaded yet.

    Populates the module-level ``roentgen_text2img`` and ``roentgen_img2img``
    globals. The image-to-image pipeline is assembled from the components of
    the text-to-image pipeline, so ``from_pretrained`` is called only once.
    Subsequent calls return immediately once both globals are set.
    """
    global roentgen_text2img, roentgen_img2img

    if roentgen_text2img is not None and roentgen_img2img is not None:
        return

    pipe = StableDiffusionPipeline.from_pretrained(model_path).to(torch.float32).to(device)
    # Swap the safety checker for a pass-through that returns the images
    # unchanged together with a False flag.
    pipe.safety_checker = lambda images, clip_input: (images, False)

    img2img = StableDiffusionImg2ImgPipeline(
        vae=pipe.vae,
        text_encoder=pipe.text_encoder,
        tokenizer=pipe.tokenizer,
        unet=pipe.unet,
        scheduler=pipe.scheduler,
        safety_checker=pipe.safety_checker,
        feature_extractor=pipe.feature_extractor,
    )

    # Publish both pipelines only after both were built, so a failure above
    # leaves the globals as None and the next call retries the load.
    roentgen_text2img = pipe
    roentgen_img2img = img2img


In [5]:
def _save_generated_images(images) -> List[str]:
    """Write each generated image into ``cache_path`` and return the file paths."""
    os.makedirs(cache_path, exist_ok=True)
    paths = []
    for image in images:
        # A random UUID keeps file names unique across calls.
        path = os.path.join(cache_path, str(uuid.uuid4()) + ".jpg")
        image.save(path)
        paths.append(path)
    return paths


@mk.endpoint()
def generate_images(
    df: mk.DataFrame,
    text: str,
    num_images: int,
    num_inference_steps: int,
    selected: mk.Store[List[str]] = [],
    height: int = 512,
    width: int = 512,
    guidance_scale: float = 4.0,
    negative_prompt: str = "",
):
    """Generate images from a prompt.

    This function is an endpoint that will update the dataframe `df` in-place.
    The new rows are placed before the existing rows of `df`.

    Args:
        df: The dataframe to append the new images to.
        text: The prompt to generate images from.
        num_images: The number of images to generate.
        num_inference_steps: The number of inference steps to run.
        selected: Primary keys (paths) of the selected rows. When exactly one
            row is selected, its image is passed to the image-to-image
            pipeline; when empty, the text-to-image pipeline is used. The
            list default is never mutated by this function.
        height: The height of the generated images (text-to-image only).
        width: The width of the generated images (text-to-image only).
        guidance_scale: The guidance scale forwarded to the pipeline.
        negative_prompt: Optional negative prompt. An empty string means no
            negative prompt is passed to the pipeline.

    Raises:
        ValueError: If `text` is empty or more than one row is selected.
    """
    if not text:
        raise ValueError("Please enter a prompt.")
    if len(selected) > 1:
        raise ValueError("Only one image can be selected at a time.")

    _load_roentgen()

    prompt = text
    num_images = int(num_images)
    num_inference_steps = int(num_inference_steps)

    # The pipelines receive the prompts as single-element lists.
    if negative_prompt:
        negative_prompt = [negative_prompt]
    else:
        negative_prompt = None

    with warnings.catch_warnings():
        # Escalate any warning raised during generation to an exception.
        warnings.simplefilter("error")
        if len(selected) > 0:
            output = roentgen_img2img(
                prompt=[prompt],
                negative_prompt=negative_prompt,
                # Look up the selected row by primary key and call its "img"
                # cell (presumably this loads the deferred image).
                image=df.loc[selected[0]]["img"](),
                num_images_per_prompt=num_images,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
            )
        else:
            output = roentgen_text2img(
                prompt=[prompt],
                negative_prompt=negative_prompt,
                num_images_per_prompt=num_images,
                num_inference_steps=num_inference_steps,
                height=height,
                width=width,
                guidance_scale=guidance_scale,
            )

    paths = _save_generated_images(output["images"])

    # Create a new dataframe with the new images. Column lengths follow the
    # number of images actually saved rather than the requested count, so the
    # columns always line up.
    num_generated = len(paths)
    new_df = mk.DataFrame(
        {
            "Path": paths,
            "source": ["generated"] * num_generated,
            "prompt": [prompt] * num_generated,
        }
    )
    new_df["img"] = mk.files(new_df["Path"], base_dir="")

    # Embed these images with CLIP to make them searchable
    new_df["img_clip"] = mk.embed(new_df["img"], encoder="clip", modality="image")

    new_df.set_primary_key("Path")
    df.set_primary_key("Path")
    if len(df) > 0:
        new_df = mk.concat([new_df, df])
    new_df.set_primary_key("Path")

    df.set(new_df)
    # Clear the selection only when a store was passed; a plain list (as in a
    # direct call) has nothing to reset.
    if isinstance(selected, mk.Store):
        selected.set([])


@mk.endpoint()
def delete_images(df: mk.DataFrame, selected: mk.Store):
    """Drop the selected rows from `df` in-place and clear the selection.

    Args:
        df: The dataframe to remove rows from.
        selected: Store holding the primary keys of the rows to delete.
    """
    remaining_keys = []
    for pkey in df.primary_key:
        if pkey in selected:
            continue
        remaining_keys.append(pkey)
    kept_rows = df.loc[remaining_keys]

    selected.set([])
    df.set(kept_rows)

In [6]:
# Make a dataframe from the CheXpert dataset (train and validation CSVs combined).
train = mk.DataFrame.from_csv(os.path.join(chexpert_path, "train.csv"))
valid = mk.DataFrame.from_csv(os.path.join(chexpert_path, "valid.csv"))
chexpert = mk.concat([train, valid])

# Only keep the Path column to simplify adding generated images. The source and
# prompt columns match the ones generate_images creates for generated rows.
chexpert = chexpert[["Path"]]
chexpert["source"] = "chexpert"
chexpert["prompt"] = ""  # chexpert scans don't have prompts

# Get the clip embeddings and join them on Path. The merge runs before Path is
# rewritten below, so it uses the Path values exactly as they appear in the CSVs.
df = mk.read("https://huggingface.co/datasets/meerkat-ml/meerkat-dataframes/resolve/main/embeddings/CheXpert-v1.0-small_clip-embeddings.mk.tar.gz")
chexpert = chexpert.merge(df, on="Path")

# Add the image column. Each Path is first prefixed with the parent directory
# of chexpert_path.
# NOTE(review): this assumes the CSV paths already start with the dataset
# folder name - confirm against the downloaded CSVs.
chexpert["Path"] = os.path.dirname(chexpert_path) + "/" + chexpert["Path"]
chexpert["img"] = mk.files(chexpert["Path"], base_dir="")

  for name, column in merged_df.iteritems():


In [7]:
# NOTE(review): mark() presumably flags the dataframe so Meerkat tracks changes
# to it reactively - confirm in the Meerkat documentation.
chexpert.mark()
# Store holding the current selection; it starts empty and is passed to the
# endpoints as `selected`.
selected = mk.Store([])

In [8]:
# Generate one sample image from text only (nothing selected); the new row is
# added to the chexpert dataframe.
generate_images(
    df=chexpert,
    text="small left-sided pleural effusion",
    num_images=1,
    num_inference_steps=75,
    selected=[],
    height=512,
    width=512,
)

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 53.15it/s]


(None, [])

In [9]:
# Rebind `df` (previously the CLIP embeddings frame) to the CheXpert dataframe;
# the components below are built against this name.
df = chexpert

### Components

In [10]:
# Sliders: how many images to generate, and how many diffusion steps to run.
num_images = mk.gui.Slider(min=1, max=10, step=1, value=1)
num_inference_steps = mk.gui.Slider(min=1, max=200, step=1, value=50)

In [11]:
# Textboxes
# The negative prompt box is created first so that the prompt box can bind its
# text into the generation endpoint.
negative_prompt = mk.gui.Textbox(
    text="",
    placeholder="Write a negative prompt (optional)...",
    classes="grow h-10 px-3 rounded-md shadow-md my-1 border-gray-400 w-full"
)

# Pressing Enter in the prompt box triggers generation. `text` is not bound
# here; presumably the key-enter event supplies it - confirm in the Meerkat docs.
prompt = mk.gui.Textbox(
    text="",
    placeholder="Write a prompt...",
    on_keyenter=generate_images.partial(
        df=df,
        selected=selected,
        num_images=num_images.value,
        num_inference_steps=num_inference_steps.value,
        # Forward the negative prompt; an empty string is treated by
        # generate_images as "no negative prompt".
        negative_prompt=negative_prompt.text,
    ),
    classes="grow h-10 px-3 rounded-md shadow-md my-1 border-gray-400 w-full"
)

In [20]:
# Buttons
# "Generate" runs generate_images with the current values of the sliders and
# both textboxes.
generate = mk.gui.Button(
    title="Generate",
    icon="Magic",
    on_click=generate_images.partial(
        df=df,
        selected=selected,
        num_images=num_images.value,
        num_inference_steps=num_inference_steps.value,
        text=prompt.text,
        # Forward the negative prompt so the textbox actually affects
        # generation; an empty string is treated by generate_images as
        # "no negative prompt".
        negative_prompt=negative_prompt.text,
    ),
    classes="bg-slate-100 py-1 rounded-md flex flex-col hover:bg-slate-200 w-full"
)

# "Delete" removes the currently selected rows from the dataframe.
delete = mk.gui.Button(
    title="Delete",
    icon="TrashFill",
    on_click=delete_images.partial(df=df, selected=selected),
    classes="bg-slate-100 py-1 rounded-md flex flex-col hover:bg-slate-200 w-full"
)

In [21]:
# Gallery over the dataframe: shows the "img" column and queries against the
# "img_clip" embedding column, with row selection enabled.
# NOTE(review): the `selected` store is not passed here - confirm how gallery
# selections reach the generate/delete endpoints.
gallery = mk.gui.contrib.GalleryQuery(
    df=df,
    allow_selection=True,
    main_column="img",
    against="img_clip",
)

In [22]:
# Putting it together.

# Overview panel: title, instructions, both prompt boxes and the action buttons.
overview_panel = mk.gui.html.flexcol(
    [
        mk.gui.Markdown(
            "Generate chest X-rays with [RoentGen](https://arxiv.org/abs/2211.12737)",
            classes="font-bold text-slate-600 text-sm",
        ),
        mk.gui.Markdown(
            "Specify the impressions to generate (*prompt*) and avoid (*negative prompt*)",
            classes="text-slate-600 text-sm",
        ),
        prompt,
        negative_prompt,
        mk.gui.html.gridcols2([generate, delete], classes="gap-x-4"),
    ],
    classes="justify-items-start mx-4 gap-1",
)

# Slider panel: each slider is paired with a label describing it.
sliders = mk.gui.html.flexcol(
    [
        mk.gui.html.grid(
            [
                mk.gui.Markdown(
                    "Number of images to generate",
                    classes="text-slate-600 text-sm",
                ),
                num_images,
            ]
        ),
        mk.gui.html.grid(
            [
                mk.gui.Text(
                    "Number of diffusion steps",
                    classes="text-slate-600 text-sm",
                ),
                num_inference_steps,
            ]
        ),
    ],
    classes="items-stretch justify-items-start gap-x-4 gap-y-4 justify-content-space-between",
)

# Full page: the two panels side by side in the first row, the gallery below.
view = mk.gui.html.div(
    [
        mk.gui.html.grid(
            [overview_panel, sliders],
            classes="grid grid-cols-[1fr_1fr] space-x-5",
        ),
        gallery,
    ],
    classes="gap-4 h-screen grid grid-rows-[auto_1fr]",
)

In [23]:
# Size the window: replace the height hook with one that always returns "900px".
# NOTE(review): presumably Meerkat calls _get_ipython_height when rendering the
# component inside a notebook - confirm in the Meerkat source.
view._get_ipython_height = lambda: "900px"

# Leave `view` as the cell's last expression so the notebook displays it.
view