In [3]:
import torch
import collections

# --- Configuration ---
# Input: the checkpoint as downloaded (attention projections still fused).
ORIGINAL_MODEL_PATH = "ckpts/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt"
# Output: where the converted (split-projection) checkpoint is written.
NEW_MODEL_PATH = "ckpts/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states_refactored.pt"

# How many `double_blocks.{i}` / `single_blocks.{i}` indices the converter visits.
NUM_DOUBLE_BLOCKS = 20
NUM_SINGLE_BLOCKS = 40

def _split_qkv(fused, key):
    """Split a fused QKV tensor into three equal parts along dim 0.

    Args:
        fused: Tensor whose first dimension holds the Q, K and V parts
            back to back (interpreted in that order).
        key: State-dict key of ``fused``; used only in the error message.

    Returns:
        A ``(q, k, v)`` tuple of views into ``fused``.

    Raises:
        ValueError: If dim 0 is not divisible by 3. ``torch.chunk`` would
            otherwise silently return unequal pieces.
    """
    rows = fused.shape[0]
    if rows % 3 != 0:
        raise ValueError(
            f"Cannot split '{key}' into Q/K/V: dim 0 has size {rows}, "
            "which is not divisible by 3."
        )
    return torch.chunk(fused, 3, dim=0)


def _convert_double_block_attn(source, target, block_idx, stream):
    """Replace one fused ``{stream}_attn_qkv`` layer with q/k/v entries.

    Reads ``double_blocks.{block_idx}.{stream}_attn_qkv.{weight,bias}`` from
    ``source`` and writes ``..._attn_q/_k/_v.{weight,bias}`` into ``target``,
    deleting the fused keys from ``target``. Does nothing when the fused
    weight is absent; a missing bias is skipped rather than raising.
    """
    base = f"double_blocks.{block_idx}.{stream}_attn_"
    if f"{base}qkv.weight" not in source:
        return
    for suffix in ("weight", "bias"):
        fused_key = f"{base}qkv.{suffix}"
        if fused_key not in source:
            continue  # layer without this parameter (e.g. no bias)
        parts = _split_qkv(source[fused_key], fused_key)
        for proj, part in zip("qkv", parts):
            target[f"{base}{proj}.{suffix}"] = part
        del target[fused_key]


def _convert_single_block(source, target, block_idx, hidden_size, mlp_hidden_dim):
    """Replace one fused ``linear1`` layer with q/k/v/mlp projection entries.

    ``single_blocks.{block_idx}.linear1.{weight,bias}`` is split along dim 0
    into a QKV part of ``3 * hidden_size`` rows followed by an MLP part of
    ``mlp_hidden_dim`` rows; the QKV part is then split into thirds. Does
    nothing when the fused weight is absent; a missing bias is skipped.
    """
    prefix = f"single_blocks.{block_idx}."
    if f"{prefix}linear1.weight" not in source:
        return
    split_sizes = [hidden_size * 3, mlp_hidden_dim]
    for suffix in ("weight", "bias"):
        fused_key = f"{prefix}linear1.{suffix}"
        if fused_key not in source:
            continue  # layer without this parameter (e.g. no bias)
        # torch.split raises if split_sizes do not sum to the size of dim 0,
        # so a wrong hidden_size / mlp_hidden_dim fails loudly here.
        qkv_part, mlp_part = torch.split(source[fused_key], split_sizes, dim=0)
        q_part, k_part, v_part = _split_qkv(qkv_part, fused_key)
        target[f"{prefix}q_proj.{suffix}"] = q_part
        target[f"{prefix}k_proj.{suffix}"] = k_part
        target[f"{prefix}v_proj.{suffix}"] = v_part
        target[f"{prefix}mlp_proj.{suffix}"] = mlp_part
        del target[fused_key]


def convert_weights(
    original_path=None,
    new_path=None,
    num_double_blocks=None,
    num_single_blocks=None,
    hidden_size=3072,
    mlp_hidden_dim=None,
):
    """
    Loads the original HunyuanVideo weights, splits the fused QKV layers,
    and saves a new state dictionary compatible with the refactored model.

    The checkpoint is expected to be a dict with a ``'module'`` entry holding
    the state dict. For every double block the fused ``img_attn_qkv`` and
    ``txt_attn_qkv`` parameters become ``*_attn_q`` / ``*_attn_k`` /
    ``*_attn_v``; for every single block the fused ``linear1`` parameters
    become ``q_proj`` / ``k_proj`` / ``v_proj`` / ``mlp_proj``. All other
    entries are carried over unchanged. Blocks whose fused weight key is
    missing are skipped silently.

    Called with no arguments, this behaves as before and uses the module-level
    configuration constants.

    Args:
        original_path: Checkpoint to read. Defaults to ``ORIGINAL_MODEL_PATH``.
        new_path: Where to write the result. Defaults to ``NEW_MODEL_PATH``.
        num_double_blocks: Number of ``double_blocks`` indices to visit.
            Defaults to ``NUM_DOUBLE_BLOCKS``.
        num_single_blocks: Number of ``single_blocks`` indices to visit.
            Defaults to ``NUM_SINGLE_BLOCKS``.
        hidden_size: Rows per Q/K/V projection inside ``linear1``.
        mlp_hidden_dim: Rows of the MLP part of ``linear1``. Defaults to
            ``hidden_size * 4`` (12288 for the default hidden size).

    Raises:
        ValueError: If a fused QKV tensor's dim 0 is not divisible by 3.
        RuntimeError: From ``torch.split`` if ``linear1`` does not have
            ``3 * hidden_size + mlp_hidden_dim`` rows.
    """
    # Resolve defaults at call time so edits to the notebook constants made
    # after this function was defined are still picked up.
    if original_path is None:
        original_path = ORIGINAL_MODEL_PATH
    if new_path is None:
        new_path = NEW_MODEL_PATH
    if num_double_blocks is None:
        num_double_blocks = NUM_DOUBLE_BLOCKS
    if num_single_blocks is None:
        num_single_blocks = NUM_SINGLE_BLOCKS
    if mlp_hidden_dim is None:
        mlp_hidden_dim = hidden_size * 4

    print("Loading original model state dictionary...")
    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # from a trusted source. Consider passing weights_only=True once it is
    # confirmed that this checkpoint loads under that restriction.
    original_state_dict = torch.load(original_path, map_location='cpu')['module']
    # Shallow copy: tensors are shared with the loaded dict, only the key set
    # is edited.
    new_state_dict = original_state_dict.copy()

    print("Converting DoubleStreamBlocks...")
    for i in range(num_double_blocks):
        for stream in ("img", "txt"):
            _convert_double_block_attn(original_state_dict, new_state_dict, i, stream)

    print("Converting SingleStreamBlocks...")
    for i in range(num_single_blocks):
        _convert_single_block(
            original_state_dict, new_state_dict, i, hidden_size, mlp_hidden_dim
        )

    print(f"Saving new refactored state dictionary to {new_path}...")
    torch.save({'module': new_state_dict}, new_path)
    print("Conversion complete!")

# Kick off the conversion only when this code runs as the main module,
# not when it is imported from elsewhere.
if __name__ == "__main__":
    convert_weights()

Loading original model state dictionary...


  original_state_dict = torch.load(ORIGINAL_MODEL_PATH, map_location='cpu')['module']


Converting DoubleStreamBlocks...
Converting SingleStreamBlocks...
Saving new refactored state dictionary to ckpts/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states_refactored.pt...
Conversion complete!


In [4]:
import torch

# Schedule table for K, keyed by block name: 'double_0'..'double_19' followed
# by 'single_0'..'single_39'. Each value is an ascending list of integer
# indices in the range 0..14, and every list starts with 0.
# NOTE(review): presumably these are the step indices at which a block's K is
# recomputed rather than reused -- confirm against the code that loads
# recompute_schedule_k.pt.
recompute_schedule_k = {'double_0': [0, 6, 12],
 'double_1': [0, 4, 11, 14],
 'double_2': [0, 1, 5, 11],
 'double_3': [0, 1, 2, 4, 7, 10, 13, 14],
 'double_4': [0, 1, 2, 4, 8, 12, 14],
 'double_5': [0, 1, 2, 3, 5, 8, 11, 13, 14],
 'double_6': [0, 1, 2, 3, 5, 8, 11, 13, 14],
 'double_7': [0, 1, 2, 3, 5, 8, 11, 13, 14],
 'double_8': [0, 1, 2, 3, 4, 6, 9, 12, 13, 14],
 'double_9': [0, 1, 2, 3, 4, 6, 9, 12, 13, 14],
 'double_10': [0, 1, 2, 3, 4, 6, 9, 12, 13, 14],
 'double_11': [0, 1, 2, 3, 4, 5, 7, 9, 11, 12, 13, 14],
 'double_12': [0, 1, 2, 3, 4, 5, 6, 8, 10, 12, 13, 14],
 'double_13': [0, 1, 2, 3, 4, 5, 6, 8, 10, 12, 13, 14],
 'double_14': [0, 1, 2, 3, 4, 5, 6, 8, 10, 12, 13, 14],
 'double_15': [0, 1, 2, 3, 4, 5, 6, 7, 9, 11, 12, 13, 14],
 'double_16': [0, 1, 2, 3, 4, 5, 6, 8, 10, 12, 13, 14],
 'double_17': [0, 1, 2, 3, 4, 5, 7, 9, 11, 13, 14],
 'double_18': [0, 1, 2, 3, 4, 5, 6, 8, 10, 12, 13, 14],
 'double_19': [0, 1, 2, 3, 4, 5, 6, 8, 10, 12, 13, 14],
 'single_0': [0, 1, 2, 3, 4, 5, 6, 7, 9, 11, 12, 13, 14],
 'single_1': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14],
 'single_2': [0, 1, 2, 3, 4, 5, 6, 7, 9, 11, 12, 13, 14],
 'single_3': [0, 1, 2, 3, 4, 5, 6, 8, 10, 12, 13, 14],
 'single_4': [0, 1, 2, 3, 4, 5, 7, 10, 12, 13, 14],
 'single_5': [0, 1, 2, 3, 4, 5, 6, 8, 10, 12, 13, 14],
 'single_6': [0, 1, 2, 3, 4, 5, 6, 8, 10, 12, 13, 14],
 'single_7': [0, 1, 2, 3, 4, 5, 7, 10, 12, 14],
 'single_8': [0, 1, 2, 3, 4, 5, 7, 10, 12, 13, 14],
 'single_9': [0, 1, 2, 3, 4, 5, 6, 8, 10, 12, 13, 14],
 'single_10': [0, 1, 2, 3, 4, 5, 6, 7, 9, 11, 12, 13, 14],
 'single_11': [0, 1, 2, 3, 4, 5, 7, 10, 12, 13, 14],
 'single_12': [0, 1, 2, 3, 4, 5, 7, 10, 12, 14],
 'single_13': [0, 1, 2, 3, 4, 5, 7, 10, 13, 14],
 'single_14': [0, 1, 2, 3, 4, 5, 7, 10, 12, 14],
 'single_15': [0, 1, 2, 3, 4, 6, 9, 12, 14],
 'single_16': [0, 1, 2, 3, 4, 5, 7, 10, 13, 14],
 'single_17': [0, 1, 2, 3, 4, 5, 7, 10, 13, 14],
 'single_18': [0, 1, 2, 3, 5, 9, 13],
 'single_19': [0, 1, 2, 3, 4, 5, 7, 10, 13, 14],
 'single_20': [0, 1, 2, 3, 4, 6, 10, 13],
 'single_21': [0, 1, 2, 3, 4, 6, 10, 13],
 'single_22': [0, 1, 2, 3, 4, 6, 10, 13],
 'single_23': [0, 1, 2, 3, 5, 9, 13],
 'single_24': [0, 1, 2, 3, 4, 6, 10, 14],
 'single_25': [0, 1, 3, 10],
 'single_26': [0, 2, 14],
 'single_27': [0, 12],
 'single_28': [0, 1, 3, 11],
 'single_29': [0, 1, 2, 3, 5, 10, 14],
 'single_30': [0, 1, 2, 4, 9, 14],
 'single_31': [0, 1, 2, 3, 4, 5, 7, 11, 13, 14],
 'single_32': [0, 1, 2, 3, 6, 12],
 'single_33': [0, 1, 2, 3, 5, 10, 14],
 'single_34': [0, 1, 2, 6, 14],
 'single_35': [0, 1, 3, 11],
 'single_36': [0, 1, 3, 12],
 'single_37': [0, 1, 2, 7],
 'single_38': [0, 1, 6, 14],
 'single_39': [0, 1, 11]}

# Schedule table for V, keyed by block name: 'double_0'..'double_19' followed
# by 'single_0'..'single_39'. Each value is an ascending list of integer
# indices in the range 0..14, and every list starts with 0.
# NOTE(review): presumably these are the step indices at which a block's V is
# recomputed rather than reused -- confirm against the code that loads
# recompute_schedule_v.pt.
recompute_schedule_v = {'double_0': [0, 2, 7, 11, 13, 14],
 'double_1': [0, 1, 4, 8, 12, 14],
 'double_2': [0, 1, 2, 4, 7, 11, 13, 14],
 'double_3': [0, 1, 2, 3, 5, 7, 10, 12, 14],
 'double_4': [0, 1, 2, 3, 5, 7, 10, 12, 13, 14],
 'double_5': [0, 1, 2, 3, 4, 5, 7, 9, 11, 12, 13, 14],
 'double_6': [0, 1, 2, 3, 4, 5, 6, 8, 10, 11, 12, 13, 14],
 'double_7': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 11, 12, 13, 14],
 'double_8': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14],
 'double_9': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14],
 'double_10': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14],
 'double_11': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14],
 'double_12': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14],
 'double_13': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14],
 'double_14': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14],
 'double_15': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14],
 'double_16': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14],
 'double_17': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14],
 'double_18': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14],
 'double_19': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14],
 'single_0': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14],
 'single_1': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14],
 'single_2': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14],
 'single_3': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14],
 'single_4': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14],
 'single_5': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14],
 'single_6': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14],
 'single_7': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14],
 'single_8': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14],
 'single_9': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14],
 'single_10': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14],
 'single_11': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14],
 'single_12': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14],
 'single_13': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14],
 'single_14': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14],
 'single_15': [0, 1, 2, 3, 4, 5, 6, 8, 10, 12, 13, 14],
 'single_16': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14],
 'single_17': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14],
 'single_18': [0, 1, 2, 3, 4, 5, 6, 7, 9, 11, 13, 14],
 'single_19': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14],
 'single_20': [0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 13, 14],
 'single_21': [0, 1, 2, 3, 4, 5, 7, 10, 13, 14],
 'single_22': [0, 1, 2, 3, 4, 5, 6, 7, 9, 11, 13, 14],
 'single_23': [0, 1, 2, 3, 4, 5, 6, 8, 11, 13, 14],
 'single_24': [0, 1, 2, 3, 4, 5, 7, 11, 14],
 'single_25': [0, 1, 2, 4, 9, 14],
 'single_26': [0, 1, 2, 4, 9, 14],
 'single_27': [0, 5],
 'single_28': [0, 1, 2, 3, 5, 10, 14],
 'single_29': [0, 1, 2, 3, 4, 5, 8, 12, 14],
 'single_30': [0, 1, 2, 3, 4, 5, 8, 12, 14],
 'single_31': [0, 1, 2, 3, 4, 5, 6, 9, 12, 14],
 'single_32': [0, 1, 2, 3, 4, 5, 8, 12, 14],
 'single_33': [0, 1, 2, 3, 4, 5, 7, 11, 14],
 'single_34': [0, 1, 2, 3, 7, 12, 14],
 'single_35': [0, 1, 2, 3, 8, 14],
 'single_36': [0, 1, 2, 3, 4, 5, 9, 14],
 'single_37': [0, 1, 2, 3, 4, 5, 8, 13],
 'single_38': [0, 1, 6],
 'single_39': [0, 1, 2, 5, 13]}


# Persist each schedule dict to its own .pt file in the working directory.
for _schedule_file, _schedule in (
    ("recompute_schedule_k.pt", recompute_schedule_k),
    ("recompute_schedule_v.pt", recompute_schedule_v),
):
    torch.save(_schedule, _schedule_file)

print("✅ Saved recompute_schedule_k.pt and recompute_schedule_v.pt")


✅ Saved recompute_schedule_k.pt and recompute_schedule_v.pt


In [5]:
import torch
# Load the original checkpoint onto the CPU and show its top-level keys.
state = torch.load(
    "ckpts/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt",
    map_location="cpu",
)
print(state.keys())

  state = torch.load("ckpts/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt", map_location='cpu')


dict_keys(['module'])
