Support full SD3 loras.
comfyanonymous committed Jun 19, 2024
1 parent 55f0dc1 commit 3914d5a
Showing 2 changed files with 79 additions and 10 deletions.
19 changes: 9 additions & 10 deletions comfy/lora.py
@@ -252,15 +252,14 @@ def model_lora_keys_unet(model, key_map={}):
                 key_map[diffusers_lora_key] = unet_key
 
     if isinstance(model, comfy.model_base.SD3): #Diffusers lora SD3
-        for i in range(model.model_config.unet_config.get("depth", 0)):
-            k = "transformer.transformer_blocks.{}.attn.".format(i)
-            qkv = "diffusion_model.joint_blocks.{}.x_block.attn.qkv.weight".format(i)
-            proj = "diffusion_model.joint_blocks.{}.x_block.attn.proj.weight".format(i)
-            if qkv in sd:
-                offset = sd[qkv].shape[0] // 3
-                key_map["{}to_q".format(k)] = (qkv, (0, 0, offset))
-                key_map["{}to_k".format(k)] = (qkv, (0, offset, offset))
-                key_map["{}to_v".format(k)] = (qkv, (0, offset * 2, offset))
-                key_map["{}to_out.0".format(k)] = proj
+        diffusers_keys = comfy.utils.mmdit_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
+        for k in diffusers_keys:
+            if k.endswith(".weight"):
+                to = diffusers_keys[k]
+                key_lora = "transformer.{}".format(k[:-len(".weight")]) #regular diffusers sd3 lora format
+                key_map[key_lora] = to
+
+                key_lora = "base_model.model.{}".format(k[:-len(".weight")]) #format for flash-sd3 lora and others?
+                key_map[key_lora] = to
 
     return key_map
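The net effect of the lora.py change: instead of hand-building only the x_block to_q/to_k/to_v weight keys, every Diffusers-style SD3 LoRA key now resolves through mmdit_to_diffusers to either a plain model key or a (key, (dim, start, length)) tuple addressing one slice of a fused qkv weight, covering biases and the context-attention projections as well. A minimal sketch of how a loader could consume such an entry; the (dim, start, length) reading of the tuple and the apply_lora_delta helper are illustrative assumptions, not ComfyUI's actual patching code:

import torch

def apply_lora_delta(state_dict, mapping, delta):
    # Hypothetical helper: "mapping" is one value out of model_lora_keys_unet's
    # key_map -- either a plain key string or (key, (dim, start, length)).
    if isinstance(mapping, str):
        state_dict[mapping] += delta  # patch the whole tensor
    else:
        key, (dim, start, length) = mapping
        # narrow() returns a view, so add_() writes into the fused qkv weight
        state_dict[key].narrow(dim, start, length).add_(delta)

# Toy state dict: depth=2 model, so offset = 2 * 64 = 128 and qkv has 3 * 128 rows
sd = {"diffusion_model.joint_blocks.0.x_block.attn.qkv.weight": torch.zeros(384, 128)}

# The entry key_map would hold for "transformer.transformer_blocks.0.attn.to_k":
entry = ("diffusion_model.joint_blocks.0.x_block.attn.qkv.weight", (0, 128, 128))
apply_lora_delta(sd, entry, torch.ones(128, 128))
assert sd["diffusion_model.joint_blocks.0.x_block.attn.qkv.weight"][128:256].sum() == 128 * 128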
70 changes: 70 additions & 0 deletions comfy/utils.py
@@ -249,6 +249,76 @@ def unet_to_diffusers(unet_config):
 
     return diffusers_unet_map
 
+MMDIT_MAP_BASIC = {
+    ("context_embedder.bias", "context_embedder.bias"),
+    ("context_embedder.weight", "context_embedder.weight"),
+    ("t_embedder.mlp.0.bias", "time_text_embed.timestep_embedder.linear_1.bias"),
+    ("t_embedder.mlp.0.weight", "time_text_embed.timestep_embedder.linear_1.weight"),
+    ("t_embedder.mlp.2.bias", "time_text_embed.timestep_embedder.linear_2.bias"),
+    ("t_embedder.mlp.2.weight", "time_text_embed.timestep_embedder.linear_2.weight"),
+    ("x_embedder.proj.bias", "pos_embed.proj.bias"),
+    ("x_embedder.proj.weight", "pos_embed.proj.weight"),
+    ("y_embedder.mlp.0.bias", "time_text_embed.text_embedder.linear_1.bias"),
+    ("y_embedder.mlp.0.weight", "time_text_embed.text_embedder.linear_1.weight"),
+    ("y_embedder.mlp.2.bias", "time_text_embed.text_embedder.linear_2.bias"),
+    ("y_embedder.mlp.2.weight", "time_text_embed.text_embedder.linear_2.weight"),
+    ("pos_embed", "pos_embed.pos_embed"),
+    ("final_layer.adaLN_modulation.1.bias", "norm_out.linear.bias"),
+    ("final_layer.adaLN_modulation.1.weight", "norm_out.linear.weight"),
+    ("final_layer.linear.bias", "proj_out.bias"),
+    ("final_layer.linear.weight", "proj_out.weight"),
+}
+
+MMDIT_MAP_BLOCK = {
+    ("context_block.adaLN_modulation.1.bias", "norm1_context.linear.bias"),
+    ("context_block.adaLN_modulation.1.weight", "norm1_context.linear.weight"),
+    ("context_block.attn.proj.bias", "attn.to_add_out.bias"),
+    ("context_block.attn.proj.weight", "attn.to_add_out.weight"),
+    ("context_block.mlp.fc1.bias", "ff_context.net.0.proj.bias"),
+    ("context_block.mlp.fc1.weight", "ff_context.net.0.proj.weight"),
+    ("context_block.mlp.fc2.bias", "ff_context.net.2.bias"),
+    ("context_block.mlp.fc2.weight", "ff_context.net.2.weight"),
+    ("x_block.adaLN_modulation.1.bias", "norm1.linear.bias"),
+    ("x_block.adaLN_modulation.1.weight", "norm1.linear.weight"),
+    ("x_block.attn.proj.bias", "attn.to_out.0.bias"),
+    ("x_block.attn.proj.weight", "attn.to_out.0.weight"),
+    ("x_block.mlp.fc1.bias", "ff.net.0.proj.bias"),
+    ("x_block.mlp.fc1.weight", "ff.net.0.proj.weight"),
+    ("x_block.mlp.fc2.bias", "ff.net.2.bias"),
+    ("x_block.mlp.fc2.weight", "ff.net.2.weight"),
+    ("", ""),
+}
+
+def mmdit_to_diffusers(mmdit_config, output_prefix=""):
+    key_map = {}
+
+    depth = mmdit_config.get("depth", 0)
+    for i in range(depth):
+        block_from = "transformer_blocks.{}".format(i)
+        block_to = "{}joint_blocks.{}".format(output_prefix, i)
+
+        offset = depth * 64
+
+        for end in ("weight", "bias"):
+            k = "{}.attn.".format(block_from)
+            qkv = "{}.x_block.attn.qkv.{}".format(block_to, end)
+            key_map["{}to_q.{}".format(k, end)] = (qkv, (0, 0, offset))
+            key_map["{}to_k.{}".format(k, end)] = (qkv, (0, offset, offset))
+            key_map["{}to_v.{}".format(k, end)] = (qkv, (0, offset * 2, offset))
+
+            qkv = "{}.context_block.attn.qkv.{}".format(block_to, end)
+            key_map["{}add_q_proj.{}".format(k, end)] = (qkv, (0, 0, offset))
+            key_map["{}add_k_proj.{}".format(k, end)] = (qkv, (0, offset, offset))
+            key_map["{}add_v_proj.{}".format(k, end)] = (qkv, (0, offset * 2, offset))
+
+        for k in MMDIT_MAP_BLOCK:
+            key_map["{}.{}".format(block_from, k[1])] = "{}.{}".format(block_to, k[0])
+
+    for k in MMDIT_MAP_BASIC:
+        key_map[k[1]] = "{}{}".format(output_prefix, k[0])
+
+    return key_map
+
 def repeat_to_batch_size(tensor, batch_size, dim=0):
     if tensor.shape[dim] > batch_size:
         return tensor.narrow(dim, 0, batch_size)
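As a quick sanity check, the new helper can be exercised on its own. A sketch assuming a hypothetical {"depth": 2} config for readability (SD3-Medium's MMDiT uses depth 24, so offset would be 24 * 64 = 1536):

import comfy.utils

# Hypothetical tiny config: depth=2 gives offset = 2 * 64 = 128.
key_map = comfy.utils.mmdit_to_diffusers({"depth": 2}, output_prefix="diffusion_model.")

# Fused-attention keys map to (qkv_key, (dim, start, length)) slice tuples:
print(key_map["transformer_blocks.0.attn.to_q.weight"])
# ('diffusion_model.joint_blocks.0.x_block.attn.qkv.weight', (0, 0, 128))
print(key_map["transformer_blocks.0.attn.add_v_proj.bias"])
# ('diffusion_model.joint_blocks.0.context_block.attn.qkv.bias', (0, 256, 128))

# Keys from MMDIT_MAP_BLOCK / MMDIT_MAP_BASIC map one-to-one:
print(key_map["transformer_blocks.1.ff.net.0.proj.weight"])
# 'diffusion_model.joint_blocks.1.x_block.mlp.fc1.weight'
print(key_map["pos_embed.pos_embed"])
# 'diffusion_model.pos_embed'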
