# N-Average merger #

Also known as "one-click Uniform Merge".

## Abstract ##

- Self-explanatory. Uses `sd-mecha` as the main library. **Thank you [@ljleb](https://github.com/ljleb/) for the codebase.**
- **No need to waste 1TB+ of disk space on pairwise merging.** However, you should know your "model pool"; otherwise the result is likely to be a worse model.
- I intentionally made this a Python notebook so I can keep explaining things in place, like most AI / ML articles. [Base code is available here.](https://github.com/ljleb/sd-mecha/blob/main/examples/n_average.py) ~~I know this also doubles as a nice test script / example for the library.~~
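To illustrate why no intermediate checkpoints are needed, here is a toy sketch (not sd-mecha's actual implementation) that averages N "state dicts" in a single pass with an incremental mean. Assumptions: each checkpoint is modelled as a plain dict mapping key names to flat lists of floats, every model has exactly the same keys and shapes, and `streaming_average` is a hypothetical helper. If the models were yielded one at a time (e.g. by a generator that loads each file lazily), only the current model plus the running mean would need to be in memory, and nothing would be written to disk until the final result.

```python
from typing import Dict, Iterable, List

# Toy stand-in for a checkpoint: key name -> flattened weights.
StateDict = Dict[str, List[float]]


def streaming_average(models: Iterable[StateDict]) -> StateDict:
    """Uniformly average N toy state dicts in a single pass.

    Only the running mean is kept between models; no intermediate
    "pairwise" result is ever saved. Assumes identical keys and shapes.
    """
    mean: StateDict = {}
    count = 0
    for model in models:
        count += 1
        if count == 1:
            # Start from a copy of the first model so the input is not mutated.
            mean = {key: list(weights) for key, weights in model.items()}
            continue
        for key, weights in model.items():
            acc = mean[key]
            for j, w in enumerate(weights):
                # Incremental mean: m_i = m_{i-1} + (x_i - m_{i-1}) / i
                acc[j] += (w - acc[j]) / count
    if count < 2:
        raise ValueError("Need at least 2 models for merge.")
    return mean


pool = [{"w": [1.0, 2.0]}, {"w": [3.0, 4.0]}, {"w": [5.0, 6.0]}]
print(streaming_average(pool))  # {'w': [3.0, 4.0]}
```

This toy version does not handle models with mismatched keys; a real merger has to decide what to do in that case.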

## Required libraries ##

- `torch>=2.0.1`
- `tensordict`
- `sd-mecha` (I prefer to [clone](https://github.com/ljleb/sd-mecha/tree/main) the source code in place)
- [safetensors](https://huggingface.co/docs/safetensors/index)
- [diffusers](https://huggingface.co/docs/diffusers/installation)
- [pytorch](https://pytorch.org/get-started/locally/#windows-python)

## WTF why and will it work? ##

- Yes. [It is part of my research](./README_XL.md).
- Image comparisons will be listed there.

## Importing libraries ##

In [1]:
# Built-in
import glob 
import os

# Are the dependencies fulfilled?
import torch
import tensordict

In [2]:
torch.__version__

'2.2.0+cu121'
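A rough way to check the `torch>=2.0.1` requirement is to compare the numeric part of `torch.__version__` against the minimum. The helpers `parse_version` and `meets_minimum` below are hypothetical, written only for this notebook: they drop the local build tag (such as `+cu121`) and ignore pre-release suffixes, so `packaging.version.Version` is the more robust choice if the `packaging` package is installed.

```python
from itertools import takewhile
from typing import Tuple


def parse_version(version: str) -> Tuple[int, ...]:
    """Turn e.g. '2.2.0+cu121' into (2, 2, 0)."""
    public = version.split("+", 1)[0]  # drop the local build tag, e.g. "+cu121"
    parts = []
    for piece in public.split("."):
        digits = "".join(takewhile(str.isdigit, piece))  # "0a0" -> "0"
        parts.append(int(digits) if digits else 0)
    return tuple(parts)


def meets_minimum(version: str, minimum: str = "2.0.1") -> bool:
    """Rough check of version >= minimum (pre-release tags are ignored)."""
    have, need = parse_version(version), parse_version(minimum)
    width = max(len(have), len(need))
    # Pad with zeros so "2.1" compares like "2.1.0".
    have += (0,) * (width - len(have))
    need += (0,) * (width - len(need))
    return have >= need


print(meets_minimum("2.2.0+cu121"))  # True
print(meets_minimum("1.13.1"))       # False
```

In the notebook, the check would become `assert meets_minimum(torch.__version__), torch.__version__`.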

In [3]:
# Import the main module.
import sd_mecha

In [4]:
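# Workaround for "OMP: Error #15" (OpenMP runtime initialized twice); the error message itself calls this workaround unsafe.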
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

I'll disable pruning so that the [toolkit](https://github.com/arenasys/stable-diffusion-webui-model-toolkit) can support the merged model.

In [5]:
g_device = "cuda:0" #"cpu"
g_prune = False
g_merged_model = "x45a-e2e" #.safetensors
g_precision = 16 #fp16, forwarded from sd-meh

In [6]:
model_folder = "../stable-diffusion-webui/tmp/astolfo_mix/sdxl/_x01/"
model_type = torch.float16 if "cuda" in g_device else torch.float # FP32 on CPU: FP16 / FP8 op support there is limited

List the models inside the folder.

In [7]:
model_list = os.listdir(model_folder)
# Keep only checkpoint files (this also skips .yaml configs).
model_list = [p for p in model_list if p.endswith((".ckpt", ".safetensors", ".bin"))]
if len(model_list) < 2:
    raise Exception("Need at least 2 models for merge.")

#model_list = list(map(lambda p: os.path.splitext(os.path.basename(p))[0], model_list))
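`os.listdir` returns entries in arbitrary order, so the merge order (and therefore the floating-point rounding of the result) may differ between machines. Below is a sketch of the same filtering with `pathlib`, sorted by file name so that the `_xNN-` prefixes define the order. `list_models` is a hypothetical helper, and its case-insensitive suffix match is a small deviation from the cell above.

```python
from pathlib import Path
from typing import List

# Same extensions the notebook keeps; .yaml configs and other files are skipped.
MODEL_SUFFIXES = (".ckpt", ".safetensors", ".bin")


def list_models(folder: str) -> List[str]:
    """Return checkpoint file names in `folder`, sorted by name."""
    names = sorted(
        path.name
        for path in Path(folder).iterdir()
        if path.is_file() and path.suffix.lower() in MODEL_SUFFIXES
    )
    if len(names) < 2:
        raise ValueError("Need at least 2 models for merge.")
    return names
```

Usage would be `model_list = list_models(model_folder)`.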

In [8]:
model_list

['_x01-deepDarkHentaiMixNSFW_v12.safetensors',
 '_x02-animeAntifreezingSolutionXL_v10.safetensors',
 '_x03-hsxl_base_1.0.f16.safetensors',
 '_x04-OpenDalleV1.1.safetensors',
 '_x05-copaxTimelessxlSDXL1_v8.safetensors',
 '_x06-juggernautXL_v8Rundiffusion.safetensors',
 '_x07-kohakuXLBeta_beta7.safetensors',
 '_x08-animagineXLV3_v30.safetensors',
 '_x09-animeboysxl_v10.safetensors',
 '_x10-dreamshaperXL_alpha2Xl10.safetensors',
 '_x11-SDXLRonghua_v40.safetensors',
 '_x12-bluePencilXL_v310.safetensors',
 '_x13-leosamsHelloworldSDXL_helloworldSDXL32DPO.safetensors',
 '_x14-ponyDiffusionV6XL_v6.safetensors',
 '_x15-animagineXL_v20.safetensors',
 '_x16-wdxl-aesthetic-0.9.safetensors',
 '_x17-leosamsHelloworldSDXLModel_helloworldSDXL10.safetensors',
 '_x18-nekoray-xl-1.5m-fp16mixed_e02.safetensors',
 '_x19-nekoray-xl-1.5m-pdg32_e02.safetensors',
 '_x20-explicitFreedomNSFW_beta.safetensors',
 '_x21-nd-run8-weighted-3.safetensors',
 '_x22-kohakuXL_alpha7.safetensors',
 '_x23-nekorayxl_v06W3.saf

Set up the merge recipe and the merge scheduler.

In [9]:
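models = model_list

# Running average: with alpha = 1/i, step i gives the newly added model a weight of 1/i,
# so after the last step every model contributes equally (1/N each).
merge = models[0]
for i, model in enumerate(models[1:], start=2):
    merge = sd_mecha.weighted_sum(merge, model, alpha=1/i)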

scheduler = sd_mecha.MergeScheduler(
    base_dir=model_folder,
    device=g_device,
    prune=g_prune,
    precision=g_precision,
)
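Why `alpha=1/i` yields a uniform average: assuming `sd_mecha.weighted_sum(a, b, alpha)` computes $(1-\alpha)a + \alpha b$, step $i$ produces $m_i = \frac{i-1}{i} m_{i-1} + \frac{1}{i} x_i$ with $m_1 = x_1$. If $m_{i-1} = \frac{1}{i-1}\sum_{k=1}^{i-1} x_k$, then $m_i = \frac{1}{i}\sum_{k=1}^{i} x_k$, so by induction the final recipe is the plain mean of all $N$ models. The scalar check below uses exact `Fraction` arithmetic; its `weighted_sum` is a stand-in for illustration, not sd-mecha's function.

```python
from fractions import Fraction
from typing import Sequence


def weighted_sum(a: Fraction, b: Fraction, alpha: Fraction) -> Fraction:
    # Scalar stand-in for a weighted-sum merge: (1 - alpha) * a + alpha * b.
    return (1 - alpha) * a + alpha * b


def chained_average(values: Sequence[Fraction]) -> Fraction:
    # Mirrors the recipe cell: fold the models left to right with alpha = 1/i.
    merge = values[0]
    for i, value in enumerate(values[1:], start=2):
        merge = weighted_sum(merge, value, alpha=Fraction(1, i))
    return merge


weights = [Fraction(v) for v in (3, 7, 1, 9, 5)]
assert chained_average(weights) == sum(weights) / len(weights)  # both equal 5
```

For contrast, a chain with a constant `alpha=0.5` would give the last model a weight of 1/2 and the first two models only $1/2^{N-1}$ each, which is the imbalance the $1/i$ schedule avoids.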

Time for action.

In [10]:
scheduler.merge_and_save(merge, output_path=g_merged_model)

stage 1: 100%|██████████| 2515/2515 [00:00<00:00, 3607.55it/s]
stage 2: 100%|██████████| 2515/2515 [00:00<00:00, 420282.66it/s]
stage 1: 100%|██████████| 2515/2515 [00:00<00:00, 5757.19it/s]
stage 2: 100%|██████████| 2516/2516 [00:00<00:00, 504559.83it/s]
stage 1: 100%|██████████| 2516/2516 [00:00<00:00, 4879.55it/s]
stage 2: 100%|██████████| 2516/2516 [00:00<00:00, 420416.27it/s]
stage 1: 100%|██████████| 2516/2516 [00:00<00:00, 4414.17it/s]
stage 2: 100%|██████████| 2515/2515 [00:00<00:00, 504214.64it/s]
stage 1: 100%|██████████| 2516/2516 [00:00<00:00, 4349.53it/s]
stage 2: 100%|██████████| 2515/2515 [00:00<00:00, 420249.18it/s]
stage 1: 100%|██████████| 2516/2516 [00:00<00:00, 4553.66it/s]
stage 2: 100%|██████████| 2515/2515 [00:00<00:00, 504311.07it/s]
stage 1: 100%|██████████| 2516/2516 [00:00<00:00, 5096.42it/s]
stage 2: 100%|██████████| 2515/2515 [00:00<00:00, 420249.18it/s]
stage 1: 100%|██████████| 2516/2516 [00:00<00:00, 4052.53it/s]
stage 2: 100%|██████████| 2515/2515 [00:0