# In [None]:
import os
import torch
from safetensors.torch import load_file, save_file
from tqdm import tqdm
import gc

def get_component_type(key):
    """Classify a checkpoint state-dict key by the component it belongs to.

    Args:
        key: Name of a tensor in a single-file Stable Diffusion checkpoint.

    Returns:
        One of ``'UNET'``, ``'VAE'``, ``'TEXT_ENCODER'`` or ``'OTHER'``.
    """
    # Text-encoder weights live under different roots depending on the
    # checkpoint family: SDXL-style checkpoints use ``conditioner.embedders.*``
    # and SD 1.x/2.x use ``cond_stage_model.*``. The legacy ``transformer_``
    # prefix is kept so keys that matched before still match.
    text_encoder_prefixes = ('conditioner.', 'cond_stage_model.', 'transformer_')

    if key.startswith('model.diffusion_model'):
        return 'UNET'
    if key.startswith('first_stage_model'):
        return 'VAE'
    if key.startswith(text_encoder_prefixes):
        return 'TEXT_ENCODER'
    return 'OTHER'

def merge_tensors(tensor_a, tensor_b, tensor_c, alpha, beta, component_type):
    """Merge three tensors with component-specific logic.

    VAE and text-encoder tensors are taken from model A untouched. Everything
    else is the average of two linear interpolations (A toward B by ``alpha``,
    A toward C by ``beta``), rescaled so its mean absolute value matches that
    of ``tensor_a``.

    Args:
        tensor_a: Tensor from the base model A.
        tensor_b: Corresponding tensor from model B.
        tensor_c: Corresponding tensor from model C.
        alpha: Interpolation weight toward B (0 keeps A, 1 takes B).
        beta: Interpolation weight toward C (0 keeps A, 1 takes C).
        component_type: ``'UNET'``, ``'VAE'``, ``'TEXT_ENCODER'`` or ``'OTHER'``.

    Returns:
        The merged tensor. ``tensor_a`` itself is returned for VAE and
        text-encoder components and for non-floating-point tensors.
    """
    if component_type in ('VAE', 'TEXT_ENCODER'):
        # Keep VAE and text encoder from model A unchanged
        return tensor_a

    # Integer/bool buffers (e.g. index tables) cannot be meaningfully
    # interpolated, and torch.mean rejects integer dtypes outright.
    if not tensor_a.is_floating_point():
        return tensor_a

    # A donor tensor with a different shape would either broadcast silently or
    # raise; treat it as "no contribution" by substituting A's own tensor.
    if tensor_b.shape != tensor_a.shape:
        tensor_b = tensor_a
    if tensor_c.shape != tensor_a.shape:
        tensor_c = tensor_a

    # For UNET and other components, use weighted average
    weighted_b = (1 - alpha) * tensor_a + alpha * tensor_b
    weighted_c = (1 - beta) * tensor_a + beta * tensor_c

    # Take the average of the two interpolations
    merged = 0.5 * (weighted_b + weighted_c)

    # Ensure the weights don't deviate too far from the original scale.
    # Skip the rescale when the merged magnitude is zero, NaN or infinite
    # (all-zero or empty tensors) so we never divide by zero and emit NaNs.
    merged_magnitude = torch.mean(torch.abs(merged))
    if bool(torch.isfinite(merged_magnitude)) and bool(merged_magnitude > 0):
        scale_factor = torch.mean(torch.abs(tensor_a)) / merged_magnitude
        merged = merged * scale_factor

    return merged

def save_incrementally(merged_model, path):
    """Flush a batch of merged tensors to ``path`` and empty the batch.

    If ``path`` already exists, its tensors are loaded and combined with the
    batch before the file is rewritten; on a key collision the freshly merged
    tensor wins, so data already on disk can never overwrite new results.

    Note that every call re-reads and rewrites the whole file, so the cost of
    a flush grows with the amount already saved.

    Args:
        merged_model: Dict mapping tensor names to tensors. It is cleared in
            place once the data has been written.
        path: Destination ``.safetensors`` file.
    """
    if os.path.exists(path):
        combined = load_file(path)
        # New tensors take precedence over whatever is already on disk.
        combined.update(merged_model)
    else:
        combined = dict(merged_model)

    save_file(combined, path)

    # Drop every reference to the written tensors before collecting garbage.
    merged_model.clear()
    del combined
    gc.collect()
    torch.cuda.empty_cache()

# Paths to models
checkpoint_dir = "/content/stable-diffusion-webui-reForge/models/Stable-diffusion"
model_a_path = os.path.join(checkpoint_dir, "waiNSFWIllustrious_v80.safetensors")
model_b_path = os.path.join(checkpoint_dir, "hassakuXLIllustrious_v12Style.safetensors")
model_c_path = os.path.join(checkpoint_dir, "illustriousXL_smoothftSOLID.safetensors")
output_path = os.path.join(checkpoint_dir, "merged_model.safetensors")

# Use more conservative merge parameters:
# alpha pulls model A toward model B, beta pulls model A toward model C.
alpha = 0.36
beta = 0.25

# Validate model paths
print("Validating model paths...")
for path in [model_a_path, model_b_path, model_c_path]:
    if not os.path.exists(path):
        raise FileNotFoundError(f"Model not found at {path}")

# Load models
print("Loading models...")
torch.cuda.empty_cache()
model_a = load_file(model_a_path)

torch.cuda.empty_cache()
model_b = load_file(model_b_path)

torch.cuda.empty_cache()
model_c = load_file(model_c_path)

print("Model statistics:")
print(f"Model A: {len(model_a)} keys")
print(f"Model B: {len(model_b)} keys")
print(f"Model C: {len(model_c)} keys")

# Component tracking
component_counts = {'UNET': 0, 'VAE': 0, 'TEXT_ENCODER': 0, 'OTHER': 0}

# Model A defines the key set of the merged checkpoint
all_keys = set(model_a.keys())

# The incremental saver folds the existing output file into each batch, so an
# output left over from a previous run must not survive into this one.
if os.path.exists(output_path):
    print(f"Removing stale output at {output_path}")
    os.remove(output_path)

# Merge process
print("Merging models incrementally...")
merged_model = {}
# Sorted iteration keeps batch contents and progress reproducible across runs.
for key in tqdm(sorted(all_keys), desc="Merging keys"):
    tensor_a = model_a[key]
    # A key missing from B or C falls back to A's tensor, so that model simply
    # contributes no change (a zero tensor would drag the weights toward 0).
    tensor_b = model_b.get(key, tensor_a)
    tensor_c = model_c.get(key, tensor_a)

    # Determine component type
    component_type = get_component_type(key)
    component_counts[component_type] += 1

    if tensor_a.is_floating_point():
        # Merge with component-specific logic and store as fp16
        merged_tensor = merge_tensors(tensor_a, tensor_b, tensor_c, alpha, beta, component_type)
        merged_model[key] = merged_tensor.half()
    else:
        # Integer/bool buffers are copied verbatim: they cannot be averaged
        # and casting them to fp16 would corrupt their values.
        merged_model[key] = tensor_a

    # Save incrementally
    if len(merged_model) >= 1000:
        save_incrementally(merged_model, output_path)

# Save any remaining tensors
if merged_model:
    save_incrementally(merged_model, output_path)

print("\nMerge statistics by component:")
for component, count in component_counts.items():
    print(f"{component}: {count} keys processed")

print(f"\nMerged model saved at {output_path}")

# Load and check final model
print("\nValidating merged model...")
merged_model = load_file(output_path)
print(f"Final merged model contains {len(merged_model)} keys")