# AstolfoMix E2E merger #

Extensive use of "N-Average" and "Split UNET / TE".

## Abstract ##

- Base code: [n_average.py](https://github.com/ljleb/sd-mecha/blob/main/examples/n_average.py), [split_unet_text_encoder.py](https://github.com/ljleb/sd-mecha/blob/main/examples/split_unet_text_encoder.py) and [serialize_recipe.py](https://github.com/ljleb/sd-mecha/blob/main/examples/serialize_recipe.py)
- Uses [sd-mecha](https://github.com/ljleb/sd-mecha) as the main library. **Thank you [@ljleb](https://github.com/ljleb/) for the codebase, and [@illyaeater](https://github.com/Enferlain) for alpha testing.**
- Each generated model will have its own model metadata and a `*.mecha` ~~assembly like~~ [recipe](https://github.com/ljleb/sd-mecha/blob/main/examples/recipes/test_split_unet_text_encoder.mecha). Open it with a text editor.
- **No need to waste 1TB+ of disk space on pairwise merging and iterating with the WebUI.** However, you should know your "model pool"; otherwise the result is likely to be a worse model.
- Required time: *More than an hour.*
- CPU usage: **100% with AVX2.**
- RAM usage: *Around 32GB*.
- VRAM usage: *Around 4GB*. 
- Storage usage: $5N+3$ SDXL models, including $N$ raw models. For $N=52$, it will use **1.66TB** for the most efficient approach.
- I intentionally made this a Python notebook because I need to switch modes this time.
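
The $5N+3$ figure can be reproduced by counting checkpoints per stage. This is a minimal sketch, assuming the "+3" are the three averaged outputs (naive average, average with selected TEs, final model). The helper name `count_models` and the per-model size are illustrative assumptions, not part of the merger.

```python
def count_models(n_raw: int) -> dict:
    """Count SDXL checkpoints on disk per stage (hypothetical helper)."""
    return {
        "raw": n_raw,        # the N source models
        "clip": 3 * n_raw,   # te0, te1, te2 variants per source model
        "unet": n_raw,       # one UNET variant per source model
        "averaged": 3,       # assumed: naive average, selected-TE average, final
    }

counts = count_models(52)
total = sum(counts.values())  # 5N + 3

# Assumed size of one fp16 SDXL checkpoint in GiB; adjust to your own files.
ASSUMED_GIB_PER_MODEL = 6.46
total_tib = total * ASSUMED_GIB_PER_MODEL / 1024

print(total, round(total_tib, 2))
```

With these assumptions, $N=52$ gives 263 checkpoints, which lands close to the 1.66TB quoted above.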

## Required libraries ##

- `torch>=2.0.1`
- `tensordict`
- `sd-mecha` (I prefer to [clone](https://github.com/ljleb/sd-mecha/tree/main) the source code in place; current version as of 240222)
- [safetensors](https://huggingface.co/docs/safetensors/index)
- [diffusers](https://huggingface.co/docs/diffusers/installation)
- [pytorch](https://pytorch.org/get-started/locally/#windows-python)

## Model naming schema ##

- `RAW` as `_x01`: Place all raw models here. Generates `x49a`, the average of all models regardless of component.
- `CLIP` as `_x01te`: Generates every model as `x49a` with its TE replaced by `_x01`'s TE. This gives a set of `te0`, `te1`, `te2`. Use these models for model selection.
- `UNET` as `x49a-x22te0x31te1`: *Requires selected TEs.* Generates every model as `_x01`'s UNET with the average of the selected `te0` and `te1`. The VAE comes from `x49a`.
- `FINAL` as `e2e`: Final model as `x43`.

## Recommended directories to make ##

- `raw`: Store the raw $N$ models
- `clip`: Store $3N$ models for CLIP selection
- `unet`: Store $N$ models for UNET selection

## Operation Mode ##

- [`RAW`, `CLIP`, `UNET`, `FINAL`]. The modes are *mutually exclusive*. I restart the whole notebook each time I switch mode.

## Limitation ##

- VAE remains unmanaged. It may cause glitched images if no VAE is specified in the WebUI.
- SDXL models only. I don't need this for SD1 and SD2.
- Safetensors only. 

## WTF why and will it work? ##

- Yes. [It is part of my research](./README_XL.md).
- Image comparisons will be listed there.

## Importing libraries ##

In [1]:
# Built-in
import time
import os

# Is the dependency fulfilled?
import torch

In [2]:
torch.__version__

'2.2.0+cu121'

In [3]:
# Import the main module.
import sd_mecha

In [4]:
# Fix for OMP: Error #15
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

## User input session starts here. ##

Specify all the paths.

In [5]:
DIR_BASE = "../../stable-diffusion-webui/tmp/astolfo_mix/sdxl/" #To set up merger

DIR_RAW = "raw/" #To load N models
DIR_CLIP = "clip/"  #To write 3N models
DIR_UNET = "unet/" #To write N models
DIR_FINAL = "./" #To write 1 model

Quick check on directory and make the model name prefix.

In [6]:
MODEL_LIST_RAW = os.listdir("{}raw/".format(DIR_BASE))
# Exclude yaml.
MODEL_LIST_RAW = list(filter(lambda p: p.endswith(".safetensors"), MODEL_LIST_RAW)) #p.endswith(".ckpt") or p.endswith(".safetensors") or p.endswith(".bin")
if len(MODEL_LIST_RAW) < 2:
    raise Exception("Need at least 2 models for merge.")

#model_list = list(map(lambda p: os.path.splitext(os.path.basename(p))[0], model_list))

In [7]:
print("{} raw models found.".format(len(MODEL_LIST_RAW)))

52 raw models found.


Define the model selection. Indices start at 1. Check the model list for ordering!
```
te0: --,--,03,04,05,--,08,10,11,--,14,16,17,18,19,20,21,22,23,24,25,26,27,--,32,--,35,36,--,--,--,41,--,43,--,45,46,48,--,--=27
te1: 01,02,03,04,05,06,--,10,--,12,14,16,17,18,19,20,21,22,23,24,25,26,27,30,32,33,35,36,37,38,40,41,42,--,44,45,46,48,49,50=37
te2: 01,--,03,04,05,06,--,10,--,--,--,16,17,18,19,20,21,22,23,24,25,--,27,--,32,33,35,36,37,38,40,41,--,--,--,--,--,--,--,--=25
=sd: --,--,03,--,--,--,--,10,--,12,--,16,--,18,19,--,--,--,--,--,25,--,--,--,--,--,--,--,--,--,--,--,--,--,--,--,--,--,--,--=7

te0: 04,05,08,11,14,17,20,21,22,23,24,25,26,27,32,35,36,41,43,45,46,48=22
te1: 01,02,04,05,06,14,17,20,21,22,23,24,25,26,27,30,32,33,35,36,37,38,40,41,42,44,45,46,48,49,50=31

unet: 01,02,--,04,05,06,07,08,09,10,11,12,13,--,15,--,17,18,19,20,21,22,23,--,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,--,49,--=44

```

In [8]:
MODEL_SELECTION_TE0 = [4,5,8,11,14,17,20,21,22,23,24,25,26,27,32,35,36,41,43,45,46,48]
MODEL_SELECTION_TE1 = [1,2,4,5,6,14,17,20,21,22,23,24,25,26,27,30,32,33,35,36,37,38,40,41,42,44,45,46,48,49,50]
MODEL_SELECTION_UNET = [1,2,4,5,6,7,8,9,10,11,12,13,15,17,18,19,20,21,22,23,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,49]

In [9]:
print("TE0:{},TE1:{},UNET:{}".format(len(MODEL_SELECTION_TE0),len(MODEL_SELECTION_TE1),len(MODEL_SELECTION_UNET)))

TE0:22,TE1:31,UNET:44
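
Since the selection table is 1-indexed and depends on the order of the model list, it is worth checking it mechanically. This is a minimal sketch with a hypothetical helper `parse_selection` and made-up file names; it parses one table row, verifies the count after `=`, and maps the indices back to file names.

```python
def parse_selection(line: str) -> list[int]:
    """Parse a row like 'te0: 01,--,03=2' into [1, 3] and check the count."""
    _, _, body = line.partition(":")
    items, _, expected = body.strip().partition("=")
    # "--" marks a model that was not selected.
    picked = [int(tok) for tok in items.split(",") if tok.strip() != "--"]
    if expected and len(picked) != int(expected):
        raise ValueError("count mismatch: {} != {}".format(len(picked), expected))
    return picked

# Hypothetical file names, sorted so that the indices are reproducible.
example_models = sorted(["b.safetensors", "a.safetensors", "c.safetensors"])

selection = parse_selection("te0: 01,--,03=2")

# Table indices start at 1, Python lists start at 0.
selected_files = [example_models[i - 1] for i in selection]
print(selection, selected_files)
```

Printing the resolved file names before a long merge is a cheap way to catch an off-by-one or a reordered directory listing.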


Specify all the keywords (I avoid hardcoding them because they will be used everywhere).

In [10]:
MODE_RAW = 'MODE_RAW'
MODE_CLIP = 'MODE_CLIP'
MODE_UNET = 'MODE_UNET'
MODE_FINAL = 'MODE_FINAL'

MODE_ACTIVATED = [MODE_RAW] #[MODE_RAW,MODE_CLIP,MODE_UNET,MODE_FINAL]

Insert version number, and the... *"AstolfoMix"*.

In [11]:
MODEL_NAME_SUFFIX = "e2e-240222-60d0764" #e2e-yymmdd-commit
MODEL_NAME_KEYWORD = "AstolfoMix"

Change these if your PC is struggling.

My PC: i9-7960X, X299 Dark, 128GB DDR4, 2x RTX3090, P4510. Overkill for a merger.

In [12]:
g_device = "cuda:0" #"cpu"
g_precision_while_merge = torch.float64 if "cuda" in g_device else torch.float #I have RAM
g_precision_final_model = torch.float16 if "cuda" in g_device else torch.float #fp16
g_total_buffer_size=2**32 #4GB

## User input should end here. ##

Define output model name. I want to keep the format, however I need to manage the name manually.

In [13]:
MODEL_NAME_RAW_PREFIX = "x{}a".format(len(MODEL_LIST_RAW)-1) #x51a

MODEL_NAME_TE = "x{}te0x{}te1".format(len(MODEL_SELECTION_TE0)-1,len(MODEL_SELECTION_TE1)-1) #x21te0x30te1
MODEL_NAME_FINAL_PREFIX = "x{}".format(len(MODEL_SELECTION_UNET)-1) #x43

MODEL_NAME_RAW = "{}-{}-{}".format(MODEL_NAME_RAW_PREFIX,MODEL_NAME_KEYWORD,MODEL_NAME_SUFFIX) #x51a-AstolfoMix-e2e-240222-60d0764
MODEL_NAME_TE_ITR = "{}-{}-{}{}-{}".format(MODEL_NAME_RAW_PREFIX,MODEL_NAME_KEYWORD,"{}","{}",MODEL_NAME_SUFFIX) #x51a-AstolfoMix-_x01te0-e2e-240222-60d0764
MODEL_NAME_TE_FINAL = "{}-{}-{}-{}".format(MODEL_NAME_RAW_PREFIX,MODEL_NAME_KEYWORD,MODEL_NAME_TE,MODEL_NAME_SUFFIX) #x51a-AstolfoMix-x21te0x30te1-e2e-240222-60d0764
MODEL_NAME_UNET_ITR = "{}-{}-{}-{}".format("{}",MODEL_NAME_KEYWORD,MODEL_NAME_TE,MODEL_NAME_SUFFIX) #_x01a-AstolfoMix-x21te0x30te1-e2e-240222-60d0764
MODEL_NAME_FINAL = "{}-{}-{}-{}".format(MODEL_NAME_FINAL_PREFIX,MODEL_NAME_KEYWORD,MODEL_NAME_TE,MODEL_NAME_SUFFIX) #x43-AstolfoMix-x21te0x30te1-e2e-240222-60d0764

In [14]:
print("Naive average model:                     {}".format(MODEL_NAME_RAW))
print("CLIP models to iterate:                  {}".format(MODEL_NAME_TE_ITR))
print("Naive average model with selected CLIP:  {}".format(MODEL_NAME_TE_FINAL))
print("UNET models to iterate:                  {}".format(MODEL_NAME_UNET_ITR))
print("Final merged model:                      {}".format(MODEL_NAME_FINAL))

Naive average model:                     x51a-AstolfoMix-e2e-240222-60d0764
CLIP models to iterate:                  x51a-AstolfoMix-{}{}-e2e-240222-60d0764
Naive average model with selected CLIP:  x51a-AstolfoMix-x21te0x30te1-e2e-240222-60d0764
UNET models to iterate:                  {}-AstolfoMix-x21te0x30te1-e2e-240222-60d0764
Final merged model:                      x43-AstolfoMix-x21te0x30te1-e2e-240222-60d0764
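
The names that still contain `{}` are templates that get filled in a second `format` pass. This is a minimal sketch of that two-pass trick. The zero-padded `_x01` / `te0` fill-in is an assumption taken from the comment in the naming cell, since the CLIP stage itself is not written yet.

```python
suffix = "e2e-240222-60d0764"
keyword = "AstolfoMix"
raw_prefix = "x51a"

# First pass: literal "{}" arguments survive as placeholders.
template = "{}-{}-{}{}-{}".format(raw_prefix, keyword, "{}", "{}", suffix)

# Second pass: hypothetical fill-in for source model 1, text encoder 0.
name = template.format("_x{:02d}".format(1), "te{}".format(0))

print(template)
print(name)
```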


## Setting up merge recipe and merge scheduler ##

- I'm still a bit paranoid about hardcoding. Getters / setters will be fine. ~~No, you won't see OOP in python notebook.~~
- This always runs. `MODE_ACTIVATED` only controls the actual merge process.

In [15]:
def rmk_raw():
    return 'RAW'
def rmk_ste():
    return 'SELECTED_TE'
def rmk_f():
    return 'FINAL'
def rmk_te(i,j):
    return 'CLIP{}_TE{}'.format(i,j) #CLIP1_TE0
def rmk_unet(i):
    return 'UNET{}'.format(i)  #UNET1

In [16]:
recipe_mapping = {}

def set_rmk(k, v):
    recipe_mapping[k] = v

In [17]:
def reset_rm():
    set_rmk(rmk_raw(), None)
    set_rmk(rmk_ste(), None)
    set_rmk(rmk_f(), None)

    for i in range(len(MODEL_LIST_RAW)):
        set_rmk(rmk_unet(i+1), None)
        for j in range(3):
            set_rmk(rmk_te(i+1,j), None)

In [18]:
reset_rm()

A single merger should be fine.

In [19]:
scheduler = sd_mecha.RecipeMerger(
    models_dir=DIR_BASE,
    default_device=g_device,
    default_dtype=g_precision_while_merge,
)

Define recipe extension.

In [20]:
RECIPE_EXT = ".mecha"

### Naive Average ###

- Pay attention to `alpha`. It is the opposite of A1111: it is "A merge to B" instead of "Merge A with B".
- Also, the recipe is a set of `RecipeNode` objects in a [tree structure](https://en.wikipedia.org/wiki/Tree_(data_structure)). Therefore you can expect quite a lot of recursion (returning itself).

In [21]:
def make_recipe_naive_merge():
    models = list(map(lambda p: "{}{}".format(DIR_RAW,p),MODEL_LIST_RAW))

    recipe = models[0]
    for i, model in enumerate(models[1:], start=2):
        recipe = sd_mecha.weighted_sum(model, recipe, alpha=(i-1)/i, dtype=g_precision_while_merge, device=g_device)

    return recipe
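
The loop above is a running mean: at step $i$ the new model gets weight $1/i$ and the accumulated recipe keeps $(i-1)/i$, so $\bar{x}_i = \frac{1}{i}x_i + \frac{i-1}{i}\bar{x}_{i-1}$. This is a minimal sketch with plain floats. The local `weighted_sum` is a stand-in that assumes `alpha` is the weight of the second argument; it is not the sd-mecha function and has not been checked against the library.

```python
def weighted_sum(a, b, alpha):
    # Assumed convention for this sketch: alpha is the weight of b.
    return (1 - alpha) * a + alpha * b

values = [2.0, 4.0, 9.0, 1.0]  # stand-ins for one weight from each model

running = values[0]
for i, value in enumerate(values[1:], start=2):
    # The new value gets weight 1/i, the running mean keeps (i-1)/i.
    running = weighted_sum(value, running, alpha=(i - 1) / i)

mean = sum(values) / len(values)
print(running, mean)
```

Chaining pairwise merges this way yields the same result as averaging all $N$ models at once, up to floating point error, which is why the merge precision is set high.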

In [22]:
set_rmk(rmk_raw(), make_recipe_naive_merge())
sd_mecha.serialize_and_save(recipe_mapping[rmk_raw()], "{}{}{}".format(DIR_BASE,MODEL_NAME_RAW, RECIPE_EXT))

### CLIP Models to test ###

In [23]:
# coming soon

## Time for action ##

In [24]:
ts = time.time()

### Naive Average ###

In [25]:
if MODE_RAW in MODE_ACTIVATED:
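    output_file = "{}{}{}".format(DIR_BASE,DIR_FINAL,MODEL_NAME_RAW)
    # The saved file may carry a ".safetensors" suffix (assumption); check both spellings.
    if os.path.isfile(output_file) or os.path.isfile(output_file + ".safetensors"):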
        print("Merged model is present. Skipping.")
    else:
        tss = time.time()
        scheduler.merge_and_save(recipe_mapping[rmk_raw()], output_path="{}{}".format(DIR_FINAL,MODEL_NAME_RAW), save_dtype=g_precision_final_model, total_buffer_size=g_total_buffer_size)
        tse = time.time()
        print("Merge time: {} sec".format(int(tse - tss)))    
else:
    print("This session is not activated. Skipping.")

Merging recipe: 100%|██████████| 2520/2520 [12:39<00:00,  3.32it/s, key=conditioner.embedders.1.model.text_projection.weight, shape=[1280, 1280]]                                  

Merge time: 767 sec





Full operation time.

In [26]:
te = time.time()
print("Total time: {} sec".format(int(te - ts)))

Total time: 767 sec
