# In [None]:
import os
import torch
from safetensors.torch import load_file, save_file
from collections import defaultdict
from tqdm import tqdm
import gc

def load_lora(lora_path):
    """Load a LoRA file and group its tensors by module name.

    ``.safetensors`` files are read with ``safetensors.torch.load_file``; any
    other extension is read with ``torch.load`` on the CPU.

    Keys are split into a module name and a part, e.g.::

        <module>.lora_down.weight  -> lora_layers[<module>]['down']
        <module>.lora_up.weight    -> lora_layers[<module>]['up']
        <module>.alpha             -> lora_layers[<module>]['alpha']

    The separator joining module name and suffix (``.`` or ``_``) is removed,
    so down/up factors and alpha of one module share a single entry. Keys that
    contain neither ``lora_down`` nor ``lora_up`` and do not end in
    ``.alpha``/``_alpha`` are ignored.

    Args:
        lora_path: Path to the LoRA file.

    Returns:
        ``defaultdict(dict)`` mapping each module name to a dict holding any
        of ``'down'``, ``'up'`` and ``'alpha'``.
    """
    if lora_path.endswith('.safetensors'):
        lora_state_dict = load_file(lora_path)
    else:
        # NOTE(security): torch.load unpickles the file and can execute
        # arbitrary code -- only load non-safetensors LoRAs from trusted sources.
        lora_state_dict = torch.load(lora_path, map_location='cpu')

    lora_layers = defaultdict(dict)
    for key, tensor in lora_state_dict.items():
        if 'lora_down' in key:
            # Everything after the marker ('.weight') is dropped, not kept in the name.
            base_key, part = key.rsplit('lora_down', 1)[0], 'down'
        elif 'lora_up' in key:
            base_key, part = key.rsplit('lora_up', 1)[0], 'up'
        elif key.endswith(('.alpha', '_alpha')):
            base_key, part = key[:-len('alpha')], 'alpha'
        else:
            continue
        # Strip the joining separator ('<module>.' -> '<module>') so all parts
        # of a module land under the same name.
        lora_layers[base_key.rstrip('._')][part] = tensor

    return lora_layers

# Leading checkpoint-key prefixes removed before matching. They are removed
# only at the start of a key (a plain str.replace('model.', '') would also cut
# into 'cond_stage_model.' and corrupt the name).
_MODEL_KEY_PREFIXES = (
    'model.diffusion_model.',
    'cond_stage_model.transformer.',
    'conditioner.embedders.0.transformer.',
    'first_stage_model.',
)

# Leading LoRA module-name prefixes removed before matching.
_LORA_KEY_PREFIXES = ('lora_unet_', 'lora_te1_', 'lora_te_')


def _strip_first_prefix(name, prefixes):
    """Return ``name`` without the first entry of ``prefixes`` it starts with."""
    for prefix in prefixes:
        if name.startswith(prefix):
            return name[len(prefix):]
    return name


def _model_match_name(model_key):
    """Map a checkpoint ``*.weight`` key to an underscore-joined module path.

    e.g. ``model.diffusion_model.input_blocks.4.1.proj_in.weight`` ->
    ``input_blocks_4_1_proj_in``.
    """
    module = model_key[:-len('.weight')]
    return _strip_first_prefix(module, _MODEL_KEY_PREFIXES).replace('.', '_')


def _lora_module_name(lora_key):
    """Reduce a ``load_lora`` key to a bare module name.

    Tolerates leftover suffix fragments such as ``'x..weight'`` or
    ``'x.alpha'`` so factors and alpha stored under different keys are
    regrouped under one name.
    """
    for suffix in ('.weight', '.alpha', '_alpha'):
        if lora_key.endswith(suffix):
            lora_key = lora_key[:-len(suffix)]
            break
    return lora_key.rstrip('._')


def _lora_match_name(module):
    """Map a LoRA module name to the same form ``_model_match_name`` produces."""
    return _strip_first_prefix(module, _LORA_KEY_PREFIXES).replace('.', '_')


def _lora_delta(parts, base_shape):
    """Return ``(alpha / rank) * (up @ down)`` in float32, shaped ``base_shape``.

    Trailing dims of both factors are flattened, so 2-D (Linear) factors and
    4-D (Conv) factors whose ``up`` has 1x1 trailing dims are both handled.
    Returns None when the factors cannot produce a tensor of ``base_shape``.
    """
    down = parts['down'].float()
    up = parts['up'].float()
    rank = down.shape[0]
    up_2d = up.reshape(up.shape[0], -1)
    down_2d = down.reshape(rank, -1)
    if up_2d.shape[1] != rank:
        return None
    delta = up_2d @ down_2d  # (out, in * kernel elements)
    # Require the same output dim and the same per-output element count, so a
    # row-major reshape lines up with the base weight's layout.
    if (len(base_shape) < 2 or base_shape[0] != delta.shape[0]
            or base_shape[1:].numel() != delta.shape[1]):
        return None
    alpha = parts.get('alpha')
    # Without an alpha the update is applied unscaled (alpha == rank convention).
    scale = 1.0 if alpha is None else float(alpha) / rank
    return delta.reshape(base_shape) * scale


def merge_lora_into_model(model_path, lora_path, output_path, lora_weight=1.0,
                          save_dtype=torch.float16):
    """Merge LoRA weights into a ``.safetensors`` checkpoint and save it.

    For every LoRA module that has both factors and matches a checkpoint
    weight ``W``, the saved tensor is
    ``W + lora_weight * (alpha / rank) * (up @ down)`` (scale 1.0 when the
    LoRA stores no alpha).

    Matching: checkpoint ``*.weight`` keys and LoRA module names are reduced
    to underscore-joined module paths, with the prefixes in
    ``_MODEL_KEY_PREFIXES`` / ``_LORA_KEY_PREFIXES`` removed, and compared for
    equality. Names are not translated between naming schemes; a LoRA whose
    module names follow a different scheme than the checkpoint is reported as
    unmatched instead of merged.

    Args:
        model_path: Base checkpoint (``.safetensors``).
        lora_path: LoRA file, read with ``load_lora``.
        output_path: Destination ``.safetensors`` path (overwritten).
        lora_weight: Multiplier for the LoRA contribution.
        save_dtype: dtype every floating-point tensor is written as (default
            float16, as before). ``None`` keeps each tensor's original dtype.
            Non-floating tensors (e.g. integer ids) are never converted.
    """
    from safetensors import safe_open  # only used to validate the output

    print(f"Loading base model from {model_path}")
    model = load_file(model_path)

    print(f"Loading LoRA from {lora_path}")
    lora_layers = load_lora(lora_path)

    print(f"Starting merge process with weight {lora_weight}")

    # Regroup LoRA parts by bare module name (also joins alpha entries that a
    # loader may have stored under a separate key).
    modules = defaultdict(dict)
    for lora_key, parts in lora_layers.items():
        modules[_lora_module_name(lora_key)].update(parts)

    # Index checkpoint weights once so matching is a dict lookup per module.
    by_match_name = {}
    for model_key in model:
        if model_key.endswith('.weight'):
            by_match_name.setdefault(_model_match_name(model_key), model_key)

    targets = defaultdict(list)
    unmatched_count = 0
    for module, parts in modules.items():
        if 'down' not in parts or 'up' not in parts:
            continue
        model_key = by_match_name.get(_lora_match_name(module))
        if model_key is None:
            unmatched_count += 1
        else:
            targets[model_key].append(parts)

    merged_model = {}
    modified_count = 0
    skipped_count = 0
    for key in tqdm(list(model), desc="Merging LoRA"):
        tensor = model.pop(key)
        original_dtype = tensor.dtype
        for parts in targets.get(key, ()):
            delta = _lora_delta(parts, tensor.shape)
            if delta is None:
                skipped_count += 1
                continue
            tensor = tensor.float() + delta * lora_weight
            modified_count += 1

        if tensor.is_floating_point():
            tensor = tensor.to(save_dtype if save_dtype is not None else original_dtype)
        merged_model[key] = tensor.contiguous()

    del model
    # Single write: safetensors files cannot be appended to, so saving in
    # chunks to the same path would keep only the last chunk.
    save_file(merged_model, output_path)

    print("\nMerge completed:")
    print(f"Modified {modified_count} layers with LoRA weights")
    if unmatched_count:
        print(f"Warning: {unmatched_count} LoRA modules matched no checkpoint weight")
    if skipped_count:
        print(f"Warning: skipped {skipped_count} LoRA modules with incompatible shapes")
    print(f"Output saved to {output_path}")

    # Validate via the file header only, without loading every tensor again.
    print("\nValidating merged model...")
    with safe_open(output_path, framework="pt", device="cpu") as saved:
        key_count = len(list(saved.keys()))
    print(f"Final model contains {key_count} keys")

if __name__ == "__main__":
    # reForge install locations (Colab layout); change these to match your setup.
    checkpoint_dir = "/content/stable-diffusion-webui-reForge/models/Stable-diffusion"
    lora_dir = "/content/stable-diffusion-webui-reForge/models/Lora"

    # Placeholder filenames: substitute the checkpoint and LoRA you want to merge.
    base_model_path, output_path = (
        os.path.join(checkpoint_dir, filename)
        for filename in ("your_base_model.safetensors", "merged_with_lora.safetensors")
    )
    lora_path = os.path.join(lora_dir, "your_lora.safetensors")

    # LoRA strength multiplier: 0.0 leaves the base model unchanged, 1.0 applies
    # the LoRA at full strength.
    lora_weight = 0.75

    merge_lora_into_model(base_model_path, lora_path, output_path, lora_weight)