<a href="https://colab.research.google.com/github/MrDys/stable-diffusion-pytorch-to-flax/blob/main/Stable_Diffusion_PyTorch_to_Flax.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Stable Diffusion PyTorch to Flax
This notebook will convert a PyTorch-formatted Stable Diffusion model to a Flax model, optionally in `bfloat16` format, for use with TPUs.
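
As a rough illustration of what the `bfloat16` option means for the weights, the sketch below uses only the Python standard library to keep the top 16 bits of a `float32` value (sign, 8 exponent bits, 7 mantissa bits) with round-to-nearest-even. The helper names are hypothetical and are not part of this notebook or of JAX; NaN inputs are not handled specially.

```python
import struct

def float32_to_bfloat16_bits(x: float) -> int:
    """Return the 16-bit bfloat16 pattern for x (hypothetical helper).

    bfloat16 keeps the upper 16 bits of the float32 encoding. NaN is not
    special-cased in this sketch.
    """
    (bits,) = struct.unpack("<I", struct.pack("<f", x))
    # Round to nearest, ties to even: add 0x7FFF plus the lowest kept bit.
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)
    return ((bits + rounding_bias) >> 16) & 0xFFFF

def bfloat16_bits_to_float(b: int) -> float:
    """Expand a bfloat16 bit pattern back to a Python float."""
    (x,) = struct.unpack("<f", struct.pack("<I", (b & 0xFFFF) << 16))
    return x

def round_trip(x: float) -> float:
    """Value of x after being stored as bfloat16."""
    return bfloat16_bits_to_float(float32_to_bfloat16_bits(x))

print(round_trip(3.14159))  # 3.140625: only about 3 significant decimal digits survive
print(round_trip(1.0))      # 1.0: exactly representable
```

The exponent range matches `float32`, so large and small magnitudes survive the conversion; what is lost is mantissa precision, and the stored weights take half the space.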

This notebook should be run with a GPU runtime. Check your runtime type by going to **Runtime ⮕ Change Runtime Type**. If, during conversion, you encounter memory errors (likely, to be honest), change your **Runtime Shape** (from the same menu) to be **High-RAM**.

Notebook by [Sean Hannan](https://www.seaphantasm.com). The key to figuring a lot of this out was buried in [a GitHub issue comment](https://github.com/huggingface/diffusers/issues/1015#issuecomment-1297504201).

In [None]:
#@title Connect to Google Drive (optional)
#@markdown If your model is on Google Drive or you wish to save the output to Google Drive, run this cell to connect.
from google.colab import drive # type: ignore
try:
    drive_path = "/content/drive" #@param {type:"string"}
    drive.mount(drive_path, force_remount=True)
except Exception as err:
    # Report the underlying cause rather than hiding it behind a bare except.
    print(f"Error mounting drive: {err}\n")

In [None]:
#@title Connect to Hugging Face (optional)
#@markdown If you want to download a model from [Hugging Face](https://huggingface.co) as part of the conversion, run this cell. It will prompt you for an API token. Make sure you have agreed to any license agreement on the model card, or you may run into errors.
from IPython.display import clear_output, display
!pip install huggingface_hub==0.10.0 gradio
clear_output()

from huggingface_hub import notebook_login
!git config --global credential.helper store

notebook_login()


In [None]:
#@title Convert Model
#@markdown Set the location where the model will be saved.
output_path = "/content/drive/MyDrive/output/path" #@param {type:"string"}
import os
os.makedirs(output_path, exist_ok=True)

#@markdown Select the format of your model. `bfloat16` (or `bf16`) is a 16-bit floating-point format well suited to running on TPUs.
format = "bfloat16" #@param ["bfloat16", "float32"]

#@markdown The model to convert. It can be a path to a model on Google Drive or it can be on Hugging Face in the form of "&lt;account&gt;/&lt;model name&gt;".
model = "" #@param {type:"string"}

#@markdown If the model above is a path to a checkpoint (.ckpt) file, we first need to convert it to diffusers format.
is_checkpoint = False #@param {type:"boolean"}

!pip install --upgrade jax jaxlib
!pip install -U jax[cuda11_cudnn82] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
!pip install flax diffusers transformers ftfy

if is_checkpoint:
    !pip install OmegaConf
    ![ ! -d "/content/diffusers" ] && git clone https://github.com/huggingface/diffusers.git /content/diffusers
    !python /content/diffusers/scripts/convert_original_stable_diffusion_to_diffusers.py --checkpoint_path="$model" --dump_path=/content/diffconversion
    model = "/content/diffconversion"

# Adapted from huggingface's transformers library
# https://github.com/huggingface/transformers/blob/b9a0ede6ab2558197d919e7a77a96dfd1c466b3f/src/transformers/modeling_flax_utils.py#L294-L355
import jax
import jax.numpy as jnp
from typing import Dict, Union
from flax.core.frozen_dict import FrozenDict
from diffusers import FlaxStableDiffusionPipeline

def to_bf16(params: Union[Dict, FrozenDict]):
    def conditional_cast(param):
        if isinstance(param, jnp.ndarray) and jnp.issubdtype(param.dtype, jnp.floating):
            param = param.astype(jnp.bfloat16)
        return param

    return jax.tree_util.tree_map(conditional_cast, params)

pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(model, from_pt=True)
if format == "bfloat16":
    params = to_bf16(params=params)

pipeline.save_pretrained(output_path, params=params)
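
To sanity-check the cast without loading a full model, the sketch below applies the same `to_bf16` logic to a small, made-up parameter tree. The `toy_params` dictionary and its keys are hypothetical stand-ins for real model parameters; the point is that floating-point leaves become `bfloat16` while integer leaves are left alone.

```python
import jax
import jax.numpy as jnp

def to_bf16(params):
    # Same logic as the conversion cell: cast only floating-point arrays.
    def conditional_cast(param):
        if isinstance(param, jnp.ndarray) and jnp.issubdtype(param.dtype, jnp.floating):
            param = param.astype(jnp.bfloat16)
        return param

    return jax.tree_util.tree_map(conditional_cast, params)

# Hypothetical toy tree; real Stable Diffusion params are far larger.
toy_params = {
    "dense": {
        "kernel": jnp.ones((2, 3), dtype=jnp.float32),
        "bias": jnp.zeros((3,), dtype=jnp.float32),
    },
    "step": jnp.array(7, dtype=jnp.int32),  # integer leaf, should not be cast
}

converted = to_bf16(toy_params)

# Print the dtype of every leaf to confirm what was cast.
for path, leaf in jax.tree_util.tree_leaves_with_path(converted):
    print(jax.tree_util.keystr(path), leaf.dtype)
```

Shapes are unchanged by the cast; only the dtype of the floating-point leaves differs.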