diff --git a/comfy/lora.py b/comfy/lora.py index 082a8b3cba..9e3e7a8bb1 100644 --- a/comfy/lora.py +++ b/comfy/lora.py @@ -252,15 +252,14 @@ def model_lora_keys_unet(model, key_map={}): key_map[diffusers_lora_key] = unet_key if isinstance(model, comfy.model_base.SD3): #Diffusers lora SD3 - for i in range(model.model_config.unet_config.get("depth", 0)): - k = "transformer.transformer_blocks.{}.attn.".format(i) - qkv = "diffusion_model.joint_blocks.{}.x_block.attn.qkv.weight".format(i) - proj = "diffusion_model.joint_blocks.{}.x_block.attn.proj.weight".format(i) - if qkv in sd: - offset = sd[qkv].shape[0] // 3 - key_map["{}to_q".format(k)] = (qkv, (0, 0, offset)) - key_map["{}to_k".format(k)] = (qkv, (0, offset, offset)) - key_map["{}to_v".format(k)] = (qkv, (0, offset * 2, offset)) - key_map["{}to_out.0".format(k)] = proj + diffusers_keys = comfy.utils.mmdit_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.") + for k in diffusers_keys: + if k.endswith(".weight"): + to = diffusers_keys[k] + key_lora = "transformer.{}".format(k[:-len(".weight")]) #regular diffusers sd3 lora format + key_map[key_lora] = to + + key_lora = "base_model.model.{}".format(k[:-len(".weight")]) #format for flash-sd3 lora and others? + key_map[key_lora] = to return key_map diff --git a/comfy/utils.py b/comfy/utils.py index 884404cceb..0d24fe7308 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -249,6 +249,76 @@ def unet_to_diffusers(unet_config): return diffusers_unet_map +MMDIT_MAP_BASIC = { + ("context_embedder.bias", "context_embedder.bias"), + ("context_embedder.weight", "context_embedder.weight"), + ("t_embedder.mlp.0.bias", "time_text_embed.timestep_embedder.linear_1.bias"), + ("t_embedder.mlp.0.weight", "time_text_embed.timestep_embedder.linear_1.weight"), + ("t_embedder.mlp.2.bias", "time_text_embed.timestep_embedder.linear_2.bias"), + ("t_embedder.mlp.2.weight", "time_text_embed.timestep_embedder.linear_2.weight"), + ("x_embedder.proj.bias", "pos_embed.proj.bias"), + ("x_embedder.proj.weight", "pos_embed.proj.weight"), + ("y_embedder.mlp.0.bias", "time_text_embed.text_embedder.linear_1.bias"), + ("y_embedder.mlp.0.weight", "time_text_embed.text_embedder.linear_1.weight"), + ("y_embedder.mlp.2.bias", "time_text_embed.text_embedder.linear_2.bias"), + ("y_embedder.mlp.2.weight", "time_text_embed.text_embedder.linear_2.weight"), + ("pos_embed", "pos_embed.pos_embed"), + ("final_layer.adaLN_modulation.1.bias", "norm_out.linear.bias"), + ("final_layer.adaLN_modulation.1.weight", "norm_out.linear.weight"), + ("final_layer.linear.bias", "proj_out.bias"), + ("final_layer.linear.weight", "proj_out.weight"), +} + +MMDIT_MAP_BLOCK = { + ("context_block.adaLN_modulation.1.bias", "norm1_context.linear.bias"), + ("context_block.adaLN_modulation.1.weight", "norm1_context.linear.weight"), + ("context_block.attn.proj.bias", "attn.to_add_out.bias"), + ("context_block.attn.proj.weight", "attn.to_add_out.weight"), + ("context_block.mlp.fc1.bias", "ff_context.net.0.proj.bias"), + ("context_block.mlp.fc1.weight", "ff_context.net.0.proj.weight"), + ("context_block.mlp.fc2.bias", "ff_context.net.2.bias"), + ("context_block.mlp.fc2.weight", "ff_context.net.2.weight"), + ("x_block.adaLN_modulation.1.bias", "norm1.linear.bias"), + ("x_block.adaLN_modulation.1.weight", "norm1.linear.weight"), + ("x_block.attn.proj.bias", "attn.to_out.0.bias"), + ("x_block.attn.proj.weight", "attn.to_out.0.weight"), + ("x_block.mlp.fc1.bias", "ff.net.0.proj.bias"), + ("x_block.mlp.fc1.weight", "ff.net.0.proj.weight"), + ("x_block.mlp.fc2.bias", "ff.net.2.bias"), + ("x_block.mlp.fc2.weight", "ff.net.2.weight"), + ("", ""), +} + +def mmdit_to_diffusers(mmdit_config, output_prefix=""): + key_map = {} + + depth = mmdit_config.get("depth", 0) + for i in range(depth): + block_from = "transformer_blocks.{}".format(i) + block_to = "{}joint_blocks.{}".format(output_prefix, i) + + offset = depth * 64 + + for end in ("weight", "bias"): + k = "{}.attn.".format(block_from) + qkv = "{}.x_block.attn.qkv.{}".format(block_to, end) + key_map["{}to_q.{}".format(k, end)] = (qkv, (0, 0, offset)) + key_map["{}to_k.{}".format(k, end)] = (qkv, (0, offset, offset)) + key_map["{}to_v.{}".format(k, end)] = (qkv, (0, offset * 2, offset)) + + qkv = "{}.context_block.attn.qkv.{}".format(block_to, end) + key_map["{}add_q_proj.{}".format(k, end)] = (qkv, (0, 0, offset)) + key_map["{}add_k_proj.{}".format(k, end)] = (qkv, (0, offset, offset)) + key_map["{}add_v_proj.{}".format(k, end)] = (qkv, (0, offset * 2, offset)) + + for k in MMDIT_MAP_BLOCK: + key_map["{}.{}".format(block_from, k[1])] = "{}.{}".format(block_to, k[0]) + + for k in MMDIT_MAP_BASIC: + key_map[k[1]] = "{}{}".format(output_prefix, k[0]) + + return key_map + def repeat_to_batch_size(tensor, batch_size, dim=0): if tensor.shape[dim] > batch_size: return tensor.narrow(dim, 0, batch_size)