# N-Average merger #

Also known as "one-click Uniform Merge".

## Abstract ##

- Self-explanatory. Uses `sd-mecha` as the main library. **Thank you [@ljleb](https://github.com/ljleb/) for the codebase.**
- **No need to waste 1TB+ of disk space on pairwise merging.** However, you should know your "model pool"; otherwise the merge is likely to result in a worse model. It takes around 8 minutes to merge 40 SDXL models, compared to 47 minutes on A1111 WebUI.
- VRAM usage: *A lot. It will fill up VRAM, but without OOM errors.*
- I intentionally made this a Python notebook so that I can keep explaining things in place, like most AI / ML articles. [Base code is available here.](https://github.com/ljleb/sd-mecha/blob/main/examples/n_average.py) ~~I know this is also a nice testing script / example for a library.~~

## Required libraries ##

- `torch>=2.0.1`
- `tensordict`
- `sd-mecha` (I prefer to [clone](https://github.com/ljleb/sd-mecha/tree/main) the source code in place)
- [safetensors](https://huggingface.co/docs/safetensors/index)
- [diffusers](https://huggingface.co/docs/diffusers/installation)
- [pytorch](https://pytorch.org/get-started/locally/#windows-python)

## WTF why and will it work? ##

- Yes. [It is part of my research](./README_XL.md).
- Image comparisons are listed there.

## Importing libraries ##

In [1]:
# Built-in
import time
import os

# Are the dependencies fulfilled?
import torch
import tensordict

In [2]:
torch.__version__

'2.2.0+cu121'

In [3]:
tensordict.__version__

'0.3.0'

In [4]:
# Import the main module.
import sd_mecha

In [5]:
# Fix for OMP: Error #15
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

I'll disable pruning to let [toolkit](https://github.com/arenasys/stable-diffusion-webui-model-toolkit) support the merged model.

In [6]:
g_device = "cuda:0" #"cpu"
g_prune = False
g_merged_model = "x46a-e2e" #.safetensors
g_precision = 16 #fp16, forwarded from sd-meh

In [7]:
model_folder = "../stable-diffusion-webui/tmp/astolfo_mix/sdxl/_x01/"
model_type = torch.float16 if "cuda" in g_device else torch.float # CPU doesn't support FP16 / FP8

Exploring the models inside a folder.

In [8]:
model_list = os.listdir(model_folder)
# Keep only model checkpoints (this excludes yaml and any other files).
model_list = [p for p in model_list if p.endswith((".ckpt", ".safetensors", ".bin"))]
if len(model_list) < 2:
    raise Exception("Need at least 2 models for merge.")

#model_list = list(map(lambda p: os.path.splitext(os.path.basename(p))[0], model_list))

In [9]:
#model_list
print("{} models found.".format(len(model_list)))

47 models found.


Setting up the merge recipe and merge scheduler.

In [10]:
models = model_list

merge = models[0]
for i, model in enumerate(models[1:], start=2):
    merge = sd_mecha.weighted_sum(merge, model, alpha=1/i)

scheduler = sd_mecha.MergeScheduler(
    base_dir=model_folder,
    device=g_device,
    prune=g_prune,
    precision=g_precision,
)
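Why `alpha=1/i` gives a uniform average: if `merge` already holds the mean of the first `i-1` models, then `(1 - 1/i) * merge + (1/i) * model` is the mean of the first `i` models. The sketch below checks this on toy data with plain Python lists. It assumes `sd_mecha.weighted_sum(a, b, alpha)` computes `(1 - alpha) * a + alpha * b`; the local `weighted_sum` helper and the random "models" are stand-ins for illustration only, not part of `sd-mecha`.

```python
import random


def weighted_sum(a, b, alpha):
    """Stand-in for the merge op (assumed): (1 - alpha) * a + alpha * b, elementwise."""
    return [(1 - alpha) * x + alpha * y for x, y in zip(a, b)]


random.seed(0)
# Five toy "models", each with four weights.
models = [[random.uniform(-1, 1) for _ in range(4)] for _ in range(5)]

# Same loop shape as the recipe above: the i-th model is blended in with weight 1/i.
merged = models[0]
for i, model in enumerate(models[1:], start=2):
    merged = weighted_sum(merged, model, alpha=1 / i)

# Direct uniform average of all models, weight by weight.
uniform = [sum(column) / len(models) for column in zip(*models)]

# The two results should agree up to floating-point rounding.
max_error = max(abs(m - u) for m, u in zip(merged, uniform))
print("max abs difference:", max_error)
```

Because only the running result and the next model take part in each step, no intermediate pairwise merges need to be written to disk.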

Time for action.

In [11]:
ts = time.time()
scheduler.merge_and_save(merge, output_path=g_merged_model)

stage 1: 100%|██████████| 2515/2515 [00:00<00:00, 4009.10it/s]
stage 2: 100%|██████████| 2515/2515 [00:00<00:00, 280199.61it/s]
stage 1: 100%|██████████| 2515/2515 [00:00<00:00, 5731.19it/s]
stage 2: 100%|██████████| 2516/2516 [00:00<00:00, 360387.57it/s]
stage 1: 100%|██████████| 2516/2516 [00:00<00:00, 5299.83it/s]
stage 2: 100%|██████████| 2516/2516 [00:00<00:00, 360424.50it/s]
stage 1: 100%|██████████| 2516/2516 [00:00<00:00, 4108.68it/s]
stage 2: 100%|██████████| 2515/2515 [00:00<00:00, 360232.03it/s]
stage 1: 100%|██████████| 2516/2516 [00:00<00:00, 4537.27it/s]
stage 2: 100%|██████████| 2515/2515 [00:00<00:00, 420282.66it/s]
stage 1: 100%|██████████| 2516/2516 [00:00<00:00, 4671.72it/s]
stage 2: 100%|██████████| 2515/2515 [00:00<00:00, 504383.41it/s]
stage 1: 100%|██████████| 2516/2516 [00:00<00:00, 4724.21it/s]
stage 2: 100%|██████████| 2515/2515 [00:00<00:00, 504311.07it/s]
stage 1: 100%|██████████| 2516/2516 [00:00<00:00, 3875.15it/s]
stage 2: 100%|██████████| 2515/2515 [00:0

In [12]:
te = time.time()
print("time: {} sec".format(int(te - ts)))

time: 662 sec
