# Web Stable Diffusion - TVMCon 2023

https://mlc.ai/web-stable-diffusion

This project brings stable diffusion models to web browsers. **Everything runs inside the browser with no server support.** To our knowledge, this is the world’s first stable diffusion running completely in the browser. Now let’s get started.

![workflow](site/img/fig/workflow.svg)

## Install packages

To import and build the model, we first need to install the in-development version of TVM Unity and other dependencies with the following pip commands.

In [None]:
!python3 -m pip install --pre torch --upgrade --index-url https://download.pytorch.org/whl/nightly/cpu
!python3 -m pip install diffusers transformers accelerate
!python3 -m pip install mlc-ai-nightly -f https://mlc.ai/wheels

We import necessary packages and set up the artifact directory.

In [1]:
import tvm
from tvm import relax
from tvm.script import relax as R

import torch
from torch import fx

from web_stable_diffusion import trace
from web_stable_diffusion import utils

torch_dev_key = "mps"
target = tvm.target.Target(
    "webgpu", host="llvm -mtriple=wasm32-unknown-unknown-wasm"
)

In [2]:
!mkdir -p dist

## Import stable diffusion models

With the necessary packages imported, the first step is to import the stable diffusion PyTorch models into TVM. Here we again leverage the techniques and interfaces of [TorchDynamo](https://pytorch.org/tutorials/intermediate/dynamo_tutorial.html) and [Torch FX](https://pytorch.org/docs/stable/fx.html) introduced before.

![pipeline](site/img/fig/pipeline.svg)

Compared with building and deploying the stable diffusion model to a local CUDA backend, building and deploying the model to the web with the WebGPU runtime is special, mostly because the runtime environment is different:

**Web runtime has no Python (let alone PyTorch, NumPy). We can only leverage JavaScript for model deployment on web.**

Considering this major difference, we need to _simplify the runtime stable diffusion pipeline (written in JavaScript) as much as possible_, and move as many tasks as possible to the build stage, ahead of the web model deployment. Our workflow demonstrates this principle in the following two aspects:
1. capture more computation to TVM’s IRModule,
2. separate models’ weights from the IRModule.

### 1. Capture more computation to TVM’s IRModule

In the previous CUDA deployment, we only imported the CLIP text encoder, the UNet and the VAE decoder to Relax, and left all other steps as PyTorch operations.

However, in the web runtime we do not have PyTorch. As a result, it is necessary to cover those additional operations in our IRModule as well. There are two approaches to this goal, and both are simple:
1. implement the operations manually with existing Relax infrastructure, or
2. write a wrapper `torch.nn.Module` which contains both the ML model and the operations appended or prepended to it.

In practice, we adopt both approaches. For some operations we use a wrapper `nn.Module`, while for others we write the operation manually. Let’s walk through each of them.

#### ①. The CLIP text encoder

In the entire pipeline, the text encoder is used twice: once for the prompt and once for the negative prompt. Since it comes right after the tokenization (which we do not import) and is followed by the concatenation of both prompts’ embeddings, the encoder is a standalone phase.

Therefore, we use an `nn.Module` to wrap the single CLIP forward, and use `dynamo_capture_subgraphs` to import the `nn.Module`.

In [3]:
from tvm.relax.frontend.torch import dynamo_capture_subgraphs
from tvm.relax.frontend.torch import from_fx

def clip_to_text_embeddings(pipe) -> tvm.IRModule:
    # Define the wrapper torch.nn.Module for CLIP.
    class CLIPModelWrapper(torch.nn.Module):
        def __init__(self, clip):
            super().__init__()
            self.clip = clip

        def forward(self, text_input_ids):
            text_embeddings = self.clip(text_input_ids)[0]
            return text_embeddings

    clip = pipe.text_encoder
    clip_to_text_embeddings = CLIPModelWrapper(clip)

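    # Create a dummy input (77 is the maximum length). Only its shape and dtype
    # matter for graph capture: values from torch.rand lie in [0, 1) and all
    # truncate to 0 after the cast to int32.
    text_input_ids = torch.rand((1, 77)).to(torch.int32)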
    # Capture CLIP's computational graph.
    mod = dynamo_capture_subgraphs(
        clip_to_text_embeddings.forward,
        text_input_ids,
        keep_params_as_input=True,
    )
    assert len(mod.functions) == 1

    return tvm.IRModule({"clip": mod["subgraph_0"]})

#### ②. The embedding concatenation

This stage concatenates the embeddings of both prompts, and is followed by the huge UNet iteration. It is also standalone, and here we choose to implement the concatenation by hand.

In [4]:
def concat_embeddings() -> tvm.IRModule:
    bb = relax.BlockBuilder()
    cond_embeddings = relax.Var("cond_embeddings", R.Tensor([1, 77, 768], "float32"))
    uncond_embeddings = relax.Var(
        "uncond_embeddings", R.Tensor([1, 77, 768], "float32")
    )
    with bb.function("concat_embeddings", [cond_embeddings, uncond_embeddings]):
        res = bb.emit(
            relax.op.concat([cond_embeddings, uncond_embeddings], axis=0)
        )
        bb.emit_func_output(res)
    return bb.get()

#### ③. Latent concat + UNet + classifier-free guidance

The third stage is the first part of the UNet loop body. It is mostly the UNet forward pass. Before the UNet forward there is a latent concatenation step, and after it one step of [classifier-free guidance](https://github.com/huggingface/diffusers/blob/79eb3d07d07a2dada172c5958d6fca478c860f16/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L674-L677) is performed to push the generation to match the prompt more closely. Since the latent concatenation and the guidance come immediately before and after the UNet forward, which we have to import in any case, we use a wrapper `nn.Module` to import all of them.

In [5]:
def unet_latents_to_noise_pred(pipe, device_str: str) -> tvm.IRModule:
    class UNetModelWrapper(torch.nn.Module):
        def __init__(self, unet):
            super().__init__()
            self.unet = unet
            # Default guidance scale factor in stable diffusion.
            self.guidance_scale = 7.5

        def forward(self, latents, timestep_tensor, text_embeddings):
            # Latent concatenation.
            latent_model_input = torch.cat([latents] * 2, dim=0)
            # UNet forward.
            noise_pred = self.unet(latent_model_input, timestep_tensor, text_embeddings)
            # Classifier-free guidance.
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + self.guidance_scale * (
                noise_pred_text - noise_pred_uncond
            )
            return noise_pred

    unet = utils.get_unet(pipe, device_str)
    unet_to_noise_pred = UNetModelWrapper(unet)
    graph = fx.symbolic_trace(unet_to_noise_pred)
    mod = from_fx(
        graph,
        [((1, 4, 64, 64), "float32"), ((), "int32"), ((2, 77, 768), "float32")],
        keep_params_as_input=True,
    )
    return tvm.IRModule({"unet": mod["main"]})
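
The guidance arithmetic inside `UNetModelWrapper.forward` can be checked by hand. The following is a minimal plain-Python sketch of the same formula, with flat lists standing in for tensors; the helper name `apply_guidance` and the toy values are assumptions made for illustration and are not part of the project.

```python
def apply_guidance(noise_uncond, noise_text, guidance_scale=7.5):
    """Blend two noise predictions element-wise, as in classifier-free guidance."""
    return [
        u + guidance_scale * (t - u)
        for u, t in zip(noise_uncond, noise_text)
    ]


# Toy values, not real UNet output.
noise_uncond = [0.0, 1.0, -2.0]
noise_text = [1.0, 1.0, 0.0]
guided = apply_guidance(noise_uncond, noise_text)
print(guided)  # [7.5, 1.0, 13.0]
```

Where the two predictions agree (the middle element), the result is unchanged. Where they differ, the result is pushed past the text-conditioned prediction by the guidance scale.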

#### ④. Scheduler step

The scheduler step is the other part of the UNet iteration, and it is essential for updating the latents in the denoising direction. There are many kinds of schedulers, and their implementations can differ substantially. Here we use the multi-step DPM solver scheduler.

One feature of schedulers is that they usually maintain a history of UNet outputs, and the scheduler step takes this history as input internally. Since the step operation is history dependent, we cannot combine the scheduler step with the previous UNet part and have to implement it separately. Because the scheduler step is also mostly standalone, we implement it by hand.

In [6]:
def dpm_solver_multistep_scheduler_steps() -> tvm.IRModule:
    bb = relax.BlockBuilder()

    # convert_model_output, the first function in multi-step DPM solver.
    sample = relax.Var("sample", R.Tensor((1, 4, 64, 64), "float32"))
    model_output = relax.Var("model_output", R.Tensor((1, 4, 64, 64), "float32"))
    alpha = relax.Var("alpha", R.Tensor((), "float32"))
    sigma = relax.Var("sigma", R.Tensor((), "float32"))
    with bb.function(
        "dpm_solver_multistep_scheduler_convert_model_output",
        [sample, model_output, alpha, sigma],
    ):
        converted_model_output = bb.emit(
            (sample - sigma * model_output) / alpha, "converted_model_output"
        )
        bb.emit_func_output(converted_model_output)

    # step, the second function.
    sample = relax.Var("sample", R.Tensor((1, 4, 64, 64), "float32"))
    model_output = relax.Var("model_output", R.Tensor((1, 4, 64, 64), "float32"))
    last_model_output = relax.Var(
        "last_model_output", R.Tensor((1, 4, 64, 64), "float32")
    )
    consts = [relax.Var(f"c{i}", R.Tensor((), "float32")) for i in range(3)]

    with bb.function(
        "dpm_solver_multistep_scheduler_step",
        [sample, model_output, last_model_output, *consts],
    ):
        prev_sample = bb.emit(
            consts[0] * sample
            - consts[1] * model_output
            - consts[2] * (model_output - last_model_output),
            "prev_sample",
        )
        bb.emit_func_output(prev_sample)

    return bb.get()
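
To make the history dependence concrete, here is a small scalar sketch of the two expressions emitted above. It assumes that the converted model output is what gets passed to the step function and kept as history, which is how we read the multi-step DPM solver; the function names and the toy numbers are illustrative only and are not real scheduler constants.

```python
def convert_model_output(sample, model_output, alpha, sigma):
    # Mirrors dpm_solver_multistep_scheduler_convert_model_output.
    return (sample - sigma * model_output) / alpha


def scheduler_step(sample, model_output, last_model_output, c0, c1, c2):
    # Mirrors dpm_solver_multistep_scheduler_step.
    return c0 * sample - c1 * model_output - c2 * (model_output - last_model_output)


# Toy scalars chosen so that the arithmetic is exact.
sample = 2.0
last_output = 1.0  # history kept from the previous iteration
converted = convert_model_output(sample, model_output=1.0, alpha=0.5, sigma=0.5)
prev_sample = scheduler_step(sample, converted, last_output, c0=1.0, c1=0.5, c2=0.25)
last_output = converted  # the caller updates the history for the next iteration
print(converted, prev_sample)  # 3.0 0.0
```

Because `last_output` lives outside both functions, the caller (the JavaScript pipeline in the web deployment) is responsible for carrying it from one iteration to the next.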

#### ⑤. VAE + image normalization

The second-to-last stage is the VAE step followed by an image normalization, which maps values from the range `[-1, 1]` to integers in `[0, 255]`. For the same reason as ③, we use a wrapper `nn.Module`.

In [7]:
def vae_to_image(pipe) -> tvm.IRModule:
    class VAEModelWrapper(torch.nn.Module):
        def __init__(self, vae):
            super().__init__()
            self.vae = vae

        def forward(self, latents):
            # Scale the latents so that it can be decoded by VAE.
            latents = 1 / 0.18215 * latents
            # VAE decode
            z = self.vae.post_quant_conv(latents)
            image = self.vae.decoder(z)
            # Image normalization
            image = (image / 2 + 0.5).clamp(min=0, max=1)
            image = (image.permute(0, 2, 3, 1) * 255).round()
            return image

    vae = pipe.vae
    vae_to_image = VAEModelWrapper(vae)

    z = torch.rand((1, 4, 64, 64), dtype=torch.float32)
    mod = dynamo_capture_subgraphs(
        vae_to_image.forward,
        z,
        keep_params_as_input=True,
    )
    assert len(mod.functions) == 1

    return tvm.IRModule({"vae": mod["subgraph_0"]})
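
The normalization at the end of `VAEModelWrapper.forward` can be reproduced for a single value in plain Python. This is only a sketch of the arithmetic: `normalize_pixel` is a hypothetical helper, and unlike `torch.round`, Python's `round` returns an `int` rather than a float.

```python
def normalize_pixel(value):
    """Map a decoder output value to an integer pixel intensity."""
    # Shift [-1, 1] to [0, 1], clamping anything outside that range.
    scaled = min(max(value / 2 + 0.5, 0.0), 1.0)
    # Scale to [0, 255] and round to the nearest integer.
    return round(scaled * 255)


pixels = [normalize_pixel(v) for v in (-1.0, 0.0, 1.0, 3.0)]
print(pixels)  # [0, 128, 255, 255]
```

Out-of-range values such as `3.0` are clamped before scaling, so the result always fits in one byte.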

#### ⑥. Image conversion to RGBA

To display the image, we need to convert the image to RGBA mode that can be directly rendered by the web runtime. This conversion requires dtype `uint32`, which PyTorch doesn’t support. Therefore, we are unable to combine this stage with the previous one, and need to implement it by hand with Relax.

In [8]:
def image_to_rgba() -> tvm.IRModule:
    from tvm import te

    def f_image_to_rgba(A):
        def fcompute(y, x):
            return (
                A[0, y, x, 0].astype("uint32")
                | (A[0, y, x, 1].astype("uint32") << 8)
                | (A[0, y, x, 2].astype("uint32") << 16)
                | tvm.tir.const(255 << 24, "uint32")
            )

        return te.compute((512, 512), fcompute, name="image_to_rgba")

    bb = relax.BlockBuilder()
    x = relax.Var("x", R.Tensor([1, 512, 512, 3], "float32"))
    with bb.function("image_to_rgba", [x]):
        image = bb.emit(
            bb.call_te(f_image_to_rgba, x, primfunc_name_hint="tir_image_to_rgba")
        )
        bb.emit_func_output(image)
    return bb.get()
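
The bit layout computed by `fcompute` can be verified with ordinary Python integers. The sketch below uses hypothetical helper names and takes integer channel values in `[0, 255]`. It also shows that, stored as a little-endian 32-bit word, the packed pixel has the byte order R, G, B, A.

```python
import struct


def pack_rgba(r, g, b):
    """Pack three 8-bit channels and a fully opaque alpha into one 32-bit word."""
    return r | (g << 8) | (b << 16) | (255 << 24)


def unpack_rgba(word):
    """Recover (r, g, b, a) from a packed 32-bit word."""
    return tuple((word >> shift) & 0xFF for shift in (0, 8, 16, 24))


word = pack_rgba(255, 128, 0)
print(hex(word))                       # 0xff0080ff
print(unpack_rgba(word))               # (255, 128, 0, 255)
print(list(struct.pack("<I", word)))   # [255, 128, 0, 255]
```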

#### Combine every piece together

We have described how we import every part of the stable diffusion pipeline into Relax. Now we can combine all of them together:

In [9]:
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
clip = clip_to_text_embeddings(pipe)
unet = unet_latents_to_noise_pred(pipe, torch_dev_key)
vae = vae_to_image(pipe)
# Use distinct names so the builder functions are not shadowed and the cell can be re-run.
concat_embeddings_mod = concat_embeddings()
image_to_rgba_mod = image_to_rgba()
schedulers = [dpm_solver_multistep_scheduler_steps()]

mod: tvm.IRModule = utils.merge_irmodules(
    clip,
    unet,
    vae,
    concat_embeddings_mod,
    image_to_rgba_mod,
    *schedulers,
)
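
To see how the six pieces fit together at runtime, the following sketch walks through the pipeline at the level of tensor shapes only. Each stub takes and returns a shape tuple, using the shapes declared in the functions above. The stub names and `num_steps` are hypothetical, and the real runtime pipeline is written in JavaScript, so this is an illustration of the data flow rather than the actual implementation.

```python
def clip_shape(text_input_ids):
    assert text_input_ids == (1, 77)
    return (1, 77, 768)


def concat_embeddings_shape(cond, uncond):
    assert cond == uncond == (1, 77, 768)
    return (2, 77, 768)


def unet_shape(latents, timestep, text_embeddings):
    assert latents == (1, 4, 64, 64) and timestep == ()
    assert text_embeddings == (2, 77, 768)
    return (1, 4, 64, 64)  # guided noise prediction


def scheduler_shape(sample, model_output, last_model_output):
    assert sample == model_output == (1, 4, 64, 64)
    return sample  # updated latents keep the same shape


def vae_shape(latents):
    assert latents == (1, 4, 64, 64)
    return (1, 512, 512, 3)


def image_to_rgba_shape(image):
    assert image == (1, 512, 512, 3)
    return (512, 512)


def run_pipeline(num_steps):
    # The text encoder runs twice: prompt and negative prompt.
    cond, uncond = clip_shape((1, 77)), clip_shape((1, 77))
    text_embeddings = concat_embeddings_shape(cond, uncond)
    latents, last_output = (1, 4, 64, 64), None
    for _ in range(num_steps):
        noise_pred = unet_shape(latents, (), text_embeddings)
        latents = scheduler_shape(latents, noise_pred, last_output)
        last_output = noise_pred  # history carried to the next iteration
    return image_to_rgba_shape(vae_shape(latents))


print(run_pipeline(num_steps=3))  # (512, 512)
```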

### 2. Separate models’ weights from the IRModule

To reduce the size of the built artifact so that it can be deployed anywhere (including the web), we separate the models’ weights from the IRModule. Each weight tensor is represented in the IRModule by a placeholder instead of residing in it as a constant tensor. We will save the separated weights to disk later, and load them from disk into memory at the beginning of deployment.

The separation is implemented by the function `relax.frontend.detach_params`.

In [10]:
mod, params = relax.frontend.detach_params(mod)
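
As a rough mental model of what the separation returns, here is a toy version in plain Python. The dict layout and the name `detach_params_toy` are hypothetical and are not TVM's data structures. The sketch only assumes that weights travel with each function as an attached list until they are detached, and that the result maps each function name to its list of weights.

```python
def detach_params_toy(module):
    """Split a toy module into (functions without weights, weights by function name)."""
    detached, params = {}, {}
    for name, func in module.items():
        func = dict(func)  # shallow copy so the input module is left untouched
        weights = func.pop("params", None)
        if weights is not None:
            params[name] = weights
        detached[name] = func
    return detached, params


toy_module = {
    "clip": {"inputs": ["inp_0", "w0", "w1"], "params": [[0.1, 0.2], [0.3]]},
    "concat_embeddings": {"inputs": ["cond_embeddings", "uncond_embeddings"]},
}
toy_detached, toy_params = detach_params_toy(toy_module)
print(sorted(toy_params))                 # ['clip']
print(toy_params["clip"][0])              # [0.1, 0.2]
print("params" in toy_detached["clip"])   # False
```

Functions without weights, such as the hand-written `concat_embeddings`, pass through unchanged and get no entry in the weights mapping.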

We can try to print out the entire IRModule via 
```python
mod.show()
```
to see the models and other functions we have imported in the IRModule. The output will be thousands of lines long, so we do not run it live here. If you try it out, the printed output should look like the following:
```python
# from tvm.script import ir as I
# from tvm.script import relax as R

@I.ir_module
class Module:
    @R.function
    def clip(
        inp_0: R.Tensor((1, 77), dtype="int32"),
        self_clip_text_model_embeddings_position_embedding_weight: R.Tensor((77, 768), dtype="float32"),
        self_clip_text_model_embeddings_token_embedding_weight: R.Tensor((49408, 768), dtype="float32"),
        ...
    ) -> R.Tensor((1, 77, 768), dtype="float32"):
        R.func_attr({"num_input": 1})
        with R.dataflow():
            lv: R.Tensor((1, 77), dtype="int32") = R.reshape(inp_0, R.shape([1, 77]))
            lv1: R.Tensor((1, 77), dtype="int32") = R.astype(lv, dtype="int32")
            lv2: R.Tensor((77,), dtype="int32") = R.reshape(lv1, R.shape([77]))
            lv3: R.Tensor((77, 768), dtype="float32") = R.take(self_clip_text_model_embeddings_token_embedding_weight, lv2, axis=0)
            lv4: R.Tensor((1, 77, 768), dtype="float32") = R.reshape(lv3, R.shape([1, 77, 768]))
            lv5: R.Tensor((1, 77), dtype="int32") = R.astype(metadata["relax.expr.Constant"][0], dtype="int32")
            lv6: R.Tensor((77,), dtype="int32") = R.reshape(lv5, R.shape([77]))
            lv7: R.Tensor((77, 768), dtype="float32") = R.take(self_clip_text_model_embeddings_position_embedding_weight, lv6, axis=0)
            lv8: R.Tensor((1, 77, 768), dtype="float32") = R.reshape(lv7, R.shape([1, 77, 768]))
            lv9: R.Tensor((1, 77, 768), dtype="float32") = R.add(lv4, lv8)
            ...
```

Instead of printing out the whole IRModule, we can print out just the names of the functions.

In [11]:
def print_relax_funcnames(mod: tvm.IRModule):
    for global_var, func in mod.functions.items():
        if isinstance(func, relax.Function):
            print(global_var.name_hint)
    print()


print_relax_funcnames(mod)

dpm_solver_multistep_scheduler_step
dpm_solver_multistep_scheduler_convert_model_output
image_to_rgba
clip
unet
concat_embeddings
vae



We can also print out one of the weight tensors to see what we have captured for model weights.

In [12]:
# Print the first weight parameter of the CLIP model.
params["clip"][0]

<tvm.nd.NDArray shape=(77, 768), cpu(0)>
array([[ 0.00158362,  0.0020091 ,  0.00020799, ..., -0.00130294,
         0.0007798 ,  0.00150727],
       [ 0.00423452,  0.00287621,  0.00020198, ...,  0.00103357,
         0.0014911 , -0.00119652],
       [ 0.00183514,  0.00073841, -0.00124233, ..., -0.00294402,
        -0.00091987,  0.00255763],
       ...,
       [ 0.02157524,  0.00553936, -0.01014109, ..., -0.00649147,
        -0.00294858,  0.00372774],
       [ 0.0188203 ,  0.00729219, -0.00766407, ..., -0.00251736,
        -0.00087413,  0.00567614],
       [ 0.03300093,  0.02810323,  0.0288674 , ...,  0.01597873,
         0.01021753, -0.03095413]], dtype=float32)

By now, we have gone through all the steps of importing the stable diffusion model to Relax for web deployment.

## Optimize and build the model

This section talks about how we optimize the stable diffusion model in TVM Unity, and how we build it to the WebGPU backend. We will very briefly go through these steps as they are not the focus of this tutorial. If you are interested, we have a walkthrough notebook which focuses more on the optimization and build at https://github.com/mlc-ai/web-stable-diffusion/blob/main/walkthrough.ipynb.

### Optimization

The optimization for the model is mainly kernel fusion and constant folding.

We also apply the pre-tuned MetaSchedule database to the IRModule, so that each operator in the IRModule is well-optimized and ready to be built to GPU backends.

In [13]:
model_names = ["clip", "unet", "vae"]
scheduler_func_names = [
    "dpm_solver_multistep_scheduler_convert_model_output",
    "dpm_solver_multistep_scheduler_step",
]
entry_funcs = (
    model_names + scheduler_func_names + ["image_to_rgba", "concat_embeddings"]
)

# The default Relax pipeline contains kernel fusion and constant folding.
mod = relax.pipeline.get_pipeline()(mod)

# Some other transformations.
mod = relax.transform.RemoveUnusedFunctions(entry_funcs)(mod)
mod = relax.transform.LiftTransformParams()(mod)
mod_transform, mod_deploy = utils.split_transform_deploy_mod(
    mod, model_names, entry_funcs
)



In [14]:
# Apply MetaSchedule database.
from tvm import meta_schedule as ms

db = ms.database.create(work_dir="log_db")
with target, db, tvm.transform.PassContext(opt_level=3):
    mod_deploy = relax.transform.MetaScheduleApplyDatabase()(mod_deploy)

### Preparation for build

As previously mentioned, we need to save the models’ weights to the disk. In addition to the weights, we also store the constants used by the scheduler at each step of iteration to the disk.

In [15]:
trace.compute_save_scheduler_consts(artifact_path="dist")
new_params = utils.transform_params(mod_transform, params)
utils.save_params(new_params, artifact_path="dist")

Start storing to cache dist/params
[1054/1054] saving clip_195
All finished, 61 total shards committed, record saved to dist/params/ndarray-cache.json
Also saved a bf16 record to dist/params/ndarray-cache-b16.json


### Build

We build the model for the WebGPU backend with `relax.build`.

In [16]:
ex = relax.build(mod_deploy, target)



We then export the built artifact to disk so that the web runtime can load it.

In [17]:
ex.export_library("dist/stable_diffusion_webgpu.wasm")

## Deploy on web

As the last step, let’s deploy stable diffusion to the web with the WebGPU runtime.
* We have implemented the stable diffusion pipeline described above [in JavaScript](https://github.com/mlc-ai/web-stable-diffusion/blob/main/web/stable_diffusion.js) in advance. It connects everything together and has about 500 lines of code.
* To deploy stable diffusion to the web, we actually need to run a shell script before the build. Here we assume the script has been run, and just set up the site and show the final demo.

For detailed instructions, please refer to our GitHub repo https://github.com/mlc-ai/web-stable-diffusion.

### Set up the website

The last thing to do is to set up the site by running the following command in a terminal session:
```shell
./scripts/local_deploy_site.sh
```

Once the website is set up, open `localhost:8888/web-stable-diffusion/` in Chrome Canary to try out the demo on your local machine!

In [None]:
!./scripts/local_deploy_site.sh

+ scripts/build_site.sh web/local-config.json
+ [[ ! -f web/local-config.json ]]
+ rm -rf site/dist
+ mkdir -p site/dist site/_inlcudes
+ echo 'Copy local configurations..'
Copy local configurations..
+ cp web/local-config.json site/stable-diffusion-config.json
+ echo 'Copy files...'
Copy files...
+ cp web/stable_diffusion.html site/_includes
+ cp web/stable_diffusion.js site/dist
+ cp dist/scheduler_pndm_consts.json site/dist
+ cp dist/scheduler_dpm_solver_multistep_consts.json site/dist
+ cp dist/stable_diffusion_webgpu.wasm site/dist
+ cp dist/tvmjs_runtime.wasi.js site/dist
+ cp dist/tvmjs.bundle.js site/dist
+ cp -r dist/tokenizers-wasm site/dist
+ cd site
+ jekyll b
Configuration file: /Users/ruihang-macstudio/Workspace/web-stable-diffusion/site/_config.yml
            Source: /Users/ruihang-macstudio/Workspace/web-stable-diffusion/site
       Destination: /Users/ruihang-macstudio/Workspace/web-stable-diffusion/site/_site
 Incremental build: disabled. Enable with --incremental
  