**Merging SDXL checkpoint via Layer block weight**

**Tips when merging models:**

Lower IN blocks (00-03): Affect basic features, textures, and patterns
Middle IN blocks (04-06): Affect object parts and compositions
Higher IN blocks (07-08): Affect scene-level features
Middle block (M00): Affects overall style and composition
Higher OUT blocks (06-08): Affect major compositional elements
Middle OUT blocks (03-05): Affect object details and relationships
Lower OUT blocks (00-02): Affect final details and refinements

**Tips for adjustment:**

For style transfer: Focus on middle blocks and higher input blocks
For detail preservation: Adjust lower output blocks
For composition: Focus on middle block and higher output blocks
For base features: Adjust lower input blocks

In [None]:
import os
import torch
from safetensors.torch import load_file, save_file
from tqdm import tqdm
import gc

# Configuration for block weights.
#
# Each entry maps a substring of a checkpoint tensor name to a merge weight:
#   0.0 -> keep the tensor from model A (the base model) unchanged
#   1.0 -> replace model A's contribution entirely with the B/C blend
# Entries are tested in the order written here and the first match wins, so
# more specific names must come before generic ones such as 'model.diffusion'.
# Modify these values to adjust the merge weights for different components.
BLOCK_WEIGHTS = {
    # Input blocks (early layers)
    'input_blocks.0': 0.0,  # Initial layer, processes basic image features like edges and colors
    'input_blocks.1': 0.0,  # Low-level feature extraction
    'input_blocks.2': 0.0,  # Basic shapes and patterns
    'input_blocks.3': 0.0,  # Texture details
    'input_blocks.4': 0.0,  # Simple object parts
    'input_blocks.5': 0.0,  # More complex object parts
    'input_blocks.6': 0.0,  # Basic object compositions
    'input_blocks.7': 0.0,  # Object relationships
    'input_blocks.8': 0.0,  # Higher-level scene features
    # Indices 9-11 are harmless when a checkpoint has no such blocks.
    'input_blocks.9': 0.0,
    'input_blocks.10': 0.0,
    'input_blocks.11': 0.0,

    # Middle block (M00): the bottleneck layer that processes the most
    # abstract representations. Handles global context and relationships
    # between elements; very important for overall composition and style.
    'middle_block': 0.5,

    # Output blocks (later layers)
    'output_blocks.0': 0.0,  # Final image details and cleanup
    'output_blocks.1': 0.0,  # Color refinement
    'output_blocks.2': 0.0,  # Fine details
    'output_blocks.3': 0.0,  # Texture refinement
    'output_blocks.4': 0.0,  # Object refinement
    'output_blocks.5': 0.0,  # Complex object details
    'output_blocks.6': 0.0,  # Object relationships and positioning
    'output_blocks.7': 0.0,  # Highest level abstract features
    'output_blocks.8': 0.0,  # Complex scene composition
    # Indices 9-11 are harmless when a checkpoint has no such blocks.
    'output_blocks.9': 0.0,
    'output_blocks.10': 0.0,
    'output_blocks.11': 0.0,

    # Other components
    'time_embed': 0.5,
    'label_emb': 0.5,
    'model.diffusion': 0.5,  # Any remaining UNet tensor not matched above

    # Keep these from base model
    'first_stage_model': 0.0,  # VAE
    'transformer_': 0.0,       # Text Encoder (legacy prefix, kept for compatibility)
    # Text-encoder tensors are stored under 'conditioner.' in SDXL checkpoints
    # and under 'cond_stage_model.' in SD 1.x checkpoints; 'transformer_'
    # alone never matches them, so they would otherwise be merged at the
    # default weight instead of being kept from the base model.
    'conditioner': 0.0,
    'cond_stage_model': 0.0,
}

def get_component_type(key):
    """Classify a checkpoint tensor name by the model component it belongs to.

    Args:
        key: Full tensor name from the checkpoint state dict.

    Returns:
        One of 'UNET', 'VAE', 'TEXT_ENCODER' or 'OTHER'.
    """
    # Text-encoder tensors live under 'conditioner.' in SDXL checkpoints and
    # under 'cond_stage_model.' in SD 1.x checkpoints. The original
    # 'transformer_' prefix is kept so existing behaviour is a strict subset.
    text_encoder_prefixes = ('transformer_', 'conditioner.', 'cond_stage_model.')

    if key.startswith('model.diffusion_model'):
        return 'UNET'
    if key.startswith('first_stage_model'):
        return 'VAE'
    if key.startswith(text_encoder_prefixes):
        return 'TEXT_ENCODER'
    return 'OTHER'

def _block_key_in(key, block_key):
    """Return True if block_key occurs in key without cutting a number in half.

    A block key ending in a digit (e.g. 'input_blocks.1') must not be followed
    by another digit in the tensor name, otherwise it would also claim
    'input_blocks.10' and 'input_blocks.11'. Keys that do not end in a digit
    (e.g. 'model.diffusion') keep plain substring semantics.
    """
    needs_boundary = block_key[-1:].isdigit()
    start = key.find(block_key)
    while start != -1:
        end = start + len(block_key)
        if not (needs_boundary and key[end:end + 1].isdigit()):
            return True
        # This occurrence was a prefix of a longer index; look further along.
        start = key.find(block_key, start + 1)
    return False


def get_block_weight(key, block_weights=None):
    """Get the merge weight for a tensor name based on the block it belongs to.

    Args:
        key: Full tensor name from the checkpoint state dict.
        block_weights: Optional mapping of name fragment -> weight. Defaults
            to the module-level BLOCK_WEIGHTS table. Entries are tested in
            mapping order and the first match wins.

    Returns:
        The weight of the first matching entry; otherwise the component-level
        weight for VAE / text-encoder tensors; otherwise 0.5.
    """
    weights = BLOCK_WEIGHTS if block_weights is None else block_weights

    # Default weight if no specific match
    default_weight = 0.5

    for block_key, weight in weights.items():
        if _block_key_in(key, block_key):
            return weight

    # Component-based fallbacks. These are only reached when the table has no
    # matching entry, so a plain [] lookup could raise KeyError here; fall
    # back to 0.0 ("keep from base model") instead.
    component = get_component_type(key)
    if component == 'VAE':
        return weights.get('first_stage_model', 0.0)
    if component == 'TEXT_ENCODER':
        return weights.get('transformer_', 0.0)

    return default_weight

def merge_tensors(tensor_a, tensor_b, tensor_c, key):
    """Merge three tensors using the block-specific weight for ``key``.

    The result is ``(1 - w) * A + 0.6 * w * B + 0.4 * w * C`` where ``w`` comes
    from get_block_weight(key), rescaled so that its mean absolute value
    matches that of ``tensor_a``.

    Args:
        tensor_a: Tensor from the base model; returned as-is when w == 0 or
            when it is not a floating-point tensor.
        tensor_b: Corresponding tensor from the second model.
        tensor_c: Corresponding tensor from the third model.
        key: Tensor name, used to look up the merge weight.

    Returns:
        The merged tensor, or ``tensor_a`` itself when nothing is merged.
    """
    # Integer / bool tensors (e.g. index buffers) cannot be blended, and
    # torch.mean() raises on them, so pass them through from the base model.
    if not tensor_a.is_floating_point():
        return tensor_a

    weight = get_block_weight(key)

    # If weight is 0, keep tensor_a unchanged
    if weight == 0:
        return tensor_a

    # Perform weighted merge
    merged = (1 - weight) * tensor_a + (weight * 0.6) * tensor_b + (weight * 0.4) * tensor_c

    # Ensure the weights don't deviate too far from the original scale.
    # Skip the rescale when the merged magnitude is zero, NaN or infinite
    # (all-zero or empty tensors): dividing would fill the tensor with NaN.
    base_scale = torch.mean(torch.abs(tensor_a))
    merged_scale = torch.mean(torch.abs(merged))
    if torch.isfinite(merged_scale) and merged_scale > 0:
        merged = merged * (base_scale / merged_scale)

    return merged

def save_incrementally(merged_model, path):
    """Append the tensors in ``merged_model`` to the safetensors file at ``path``.

    If the file already exists its contents are loaded, the new tensors are
    layered on top, and the combined dict is written back. ``merged_model`` is
    emptied afterwards so the caller can keep filling it.

    Note that the existing file is reloaded in full on every call, so peak
    memory still grows with the size of the output.

    Args:
        merged_model: Dict of tensor name -> tensor to persist; cleared in place.
        path: Destination safetensors file.
    """
    if os.path.exists(path):
        combined = load_file(path)
        # Freshly merged tensors must win over whatever is already on disk;
        # otherwise a leftover file from an earlier run overrides this run.
        combined.update(merged_model)
    else:
        combined = dict(merged_model)

    save_file(combined, path)

    # Release the batch before the caller starts accumulating the next one.
    merged_model.clear()
    del combined
    gc.collect()
    torch.cuda.empty_cache()

def merge_models(model_a_path, model_b_path, model_c_path, output_path):
    """Merge three safetensors checkpoints block by block and save the result.

    Model A is the base: only its tensor names are written to the output.
    Each tensor is blended with the same-named tensors of models B and C via
    merge_tensors(), stored as float16 if it is a floating-point tensor, and
    flushed to ``output_path`` in batches of 1000 tensors.

    Args:
        model_a_path: Path to the base checkpoint.
        model_b_path: Path to the second checkpoint.
        model_c_path: Path to the third checkpoint.
        output_path: Destination file; an existing file is replaced.

    Raises:
        FileNotFoundError: If any of the three input files does not exist.
    """
    # Validate model paths
    print("Validating model paths...")
    for path in [model_a_path, model_b_path, model_c_path]:
        if not os.path.exists(path):
            raise FileNotFoundError(f"Model not found at {path}")

    # Load models
    print("Loading models...")
    torch.cuda.empty_cache()
    model_a = load_file(model_a_path)

    torch.cuda.empty_cache()
    model_b = load_file(model_b_path)

    torch.cuda.empty_cache()
    model_c = load_file(model_c_path)

    print("Model statistics:")
    print(f"Model A: {len(model_a)} keys")
    print(f"Model B: {len(model_b)} keys")
    print(f"Model C: {len(model_c)} keys")

    # The incremental saver appends to whatever is at output_path, so a file
    # left over from a previous run would leak stale tensors into this merge.
    if os.path.exists(output_path):
        os.remove(output_path)

    # Component tracking
    component_counts = {'UNET': 0, 'VAE': 0, 'TEXT_ENCODER': 0, 'OTHER': 0}

    # Merge process
    print("Merging models incrementally...")
    merged_model = {}
    for key in tqdm(model_a, desc="Merging keys"):
        tensor_a = model_a[key]
        # A tensor missing from B or C falls back to the base tensor, so the
        # missing model contributes "no change" rather than pulling the
        # weights toward zero.
        tensor_b = model_b.get(key, tensor_a)
        tensor_c = model_c.get(key, tensor_a)

        # Determine component type for tracking
        component_counts[get_component_type(key)] += 1

        # Merge with block-specific weights
        merged_tensor = merge_tensors(tensor_a, tensor_b, tensor_c, key)

        # Only floating-point tensors are stored as fp16; integer buffers
        # must keep their dtype.
        if merged_tensor.is_floating_point():
            merged_tensor = merged_tensor.half()
        merged_model[key] = merged_tensor

        # Save incrementally
        if len(merged_model) >= 1000:
            save_incrementally(merged_model, output_path)

    # Save any remaining tensors
    if merged_model:
        save_incrementally(merged_model, output_path)

    print("\nMerge statistics by component:")
    for component, count in component_counts.items():
        print(f"{component}: {count} keys processed")

    print(f"\nMerged model saved at {output_path}")

    # Load and check final model
    print("\nValidating merged model...")
    merged_model = load_file(output_path)
    print(f"Final merged model contains {len(merged_model)} keys")

if __name__ == "__main__":
    # Script entry point: merge modelA/B/C from the web UI checkpoint folder
    # and write the result next to them.
    checkpoint_dir = "/content/stable-diffusion-webui/models/Stable-diffusion"

    # Input checkpoints, in merge order (A is the base model).
    model_a_path, model_b_path, model_c_path = (
        os.path.join(checkpoint_dir, file_name)
        for file_name in (
            "modelA.safetensors",
            "modelB.safetensors",
            "modelC.safetensors",
        )
    )
    output_path = os.path.join(checkpoint_dir, "merged_block_weighted.safetensors")

    merge_models(
        model_a_path,
        model_b_path,
        model_c_path,
        output_path,
    )