You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
When I used scripts/convert_vae_pt_to_diffusers.py to convert my trained VAE to Diffusers, I found that the current script cannot handle downsampling and upsampling blocks with attention layers. Therefore, I had to manually convert it using another method.
Reproduction
import argparse
import io
import requests
import torch
import yaml
from diffusers import AutoencoderKL
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
assign_to_checkpoint,
conv_attn_to_linear,
create_vae_diffusers_config,
renew_vae_attention_paths,
renew_vae_resnet_paths,
)
def custom_convert_ldm_vae_checkpoint(checkpoint, config):
vae_state_dict = checkpoint
new_checkpoint = {}
new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]
new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]
new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
# Retrieves the keys for the encoder down blocks only
num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
down_blocks = {
layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
}
# Retrieves the keys for the decoder up blocks only
num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
up_blocks = {
layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
}
for i in range(num_down_blocks):
resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
f"encoder.down.{i}.downsample.conv.weight"
)
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
f"encoder.down.{i}.downsample.conv.bias"
)
paths = renew_vae_resnet_paths(resnets)
meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
num_mid_res_blocks = 2
for i in range(1, num_mid_res_blocks + 1):
resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
paths = renew_vae_resnet_paths(resnets)
meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
paths = renew_vae_attention_paths(mid_attentions)
meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
conv_attn_to_linear(new_checkpoint)
for i in range(num_up_blocks):
block_id = num_up_blocks - 1 - i
resnets = [
key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
]
if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
f"decoder.up.{block_id}.upsample.conv.weight"
]
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
f"decoder.up.{block_id}.upsample.conv.bias"
]
paths = renew_vae_resnet_paths(resnets)
meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
num_mid_res_blocks = 2
for i in range(1, num_mid_res_blocks + 1):
resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
paths = renew_vae_resnet_paths(resnets)
meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
paths = renew_vae_attention_paths(mid_attentions)
meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
conv_attn_to_linear(new_checkpoint)
return new_checkpoint
def vae_pt_to_vae_diffuser(
checkpoint_path: str,
output_path: str,
):
# Only support V1
r = requests.get(
" https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
)
io_obj = io.BytesIO(r.content)
# original_config = yaml.safe_load(io_obj)
original_config = yaml.safe_load(open("configs/convert.yaml", "r"))
image_size = 256
device = "cuda" if torch.cuda.is_available() else "cpu"
if checkpoint_path.endswith("safetensors"):
from safetensors import safe_open
checkpoint = {}
with safe_open(checkpoint_path, framework="pt", device="cpu") as f:
for key in f.keys():
checkpoint[key] = f.get_tensor(key)
else:
checkpoint = torch.load(checkpoint_path, map_location=device)["state_dict"]
# Convert the VAE model.
vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
vae_config["down_block_types"] = ('DownEncoderBlock2D', 'DownEncoderBlock2D', 'DownEncoderBlock2D', 'DownEncoderBlock2D', 'AttnDownEncoderBlock2D')
vae_config["up_block_types"] = ('AttnUpDecoderBlock2D', 'UpDecoderBlock2D', 'UpDecoderBlock2D', 'UpDecoderBlock2D', 'UpDecoderBlock2D')
print(vae_config)
converted_vae_checkpoint = custom_convert_ldm_vae_checkpoint(checkpoint, vae_config)
new_checkpoint = {}
for key in converted_vae_checkpoint.keys():
if "attn" in key:
new_key = key.replace("attn", "attentions").replace("down", "down_blocks").replace("up", "up_blocks").replace("norm", "group_norm").replace(".q.", ".to_q.").replace(".v.", ".to_v.").replace(".k.", ".to_k.").replace(".proj_out.", ".to_out.0.").replace("up_blocks.4", "up_blocks.0")
if converted_vae_checkpoint[key].shape[-2:] == (1, 1):
new_checkpoint[new_key] = converted_vae_checkpoint[key].squeeze(-1).squeeze(-1)
else:
new_checkpoint[new_key] = converted_vae_checkpoint[key]
else:
new_checkpoint[key] = converted_vae_checkpoint[key]
vae = AutoencoderKL(**vae_config)
print(vae)
vae.load_state_dict(new_checkpoint)
vae.save_pretrained(output_path)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--vae_pt_path", default=None, type=str, required=True, help="Path to the VAE.pt to convert.")
parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the VAE.pt to convert.")
args = parser.parse_args()
vae_pt_to_vae_diffuser(args.vae_pt_path, args.dump_path)
Describe the bug
When I used
scripts/convert_vae_pt_to_diffusers.py
to convert my trained VAE to Diffusers, I found that the current script cannot handle downsampling and upsampling blocks with attention layers. Therefore, I had to manually convert it using another method.Reproduction
Logs
System Info
NVIDIA A100 80GB PCIe, 81920 MiB
NVIDIA A100 80GB PCIe, 81920 MiB
NVIDIA A100 80GB PCIe, 81920 MiB
Who can help?
@sayakpaul @DN6 @yiyixuxu
The text was updated successfully, but these errors were encountered: