# AstolfoMix E2E merger #

Extensive use of "N-Average" and "Split UNET / TE".

## Abstract ##

- Base code: [n_average.py](https://github.com/ljleb/sd-mecha/blob/main/examples/n_average.py), [split_unet_text_encoder.py](https://github.com/ljleb/sd-mecha/blob/main/examples/split_unet_text_encoder.py) and [serialize_recipe.py](https://github.com/ljleb/sd-mecha/blob/main/examples/serialize_recipe.py)
- Uses [sd-mecha](https://github.com/ljleb/sd-mecha) as the main library. **Thank you [@ljleb](https://github.com/ljleb/) for the codebase, and [@illyaeater](https://github.com/Enferlain) for alpha testing.**
- Each generated model will have its own model metadata and `*.mecha` ~~assembly like~~ [recipe](https://github.com/ljleb/sd-mecha/blob/main/examples/recipes/test_split_unet_text_encoder.mecha). Open it with a text editor.
- **No need to waste 1TB+ of disk space on pairwise merging and iterating with WebUI.** However, you should know the "model pool" well, otherwise the result is likely to be a worse model.
- Required time: 25 (RAW) + 150 * 4 (TE) + 25 (SELECTED_TE) + 50 * 4 (UNET) + 25 (FINAL) + 45 (E2E) minutes = 920 minutes, or **about 15.3 hours**, for 52 models
- CPU usage: **100% with AVX2.**
- RAM usage: *Around 32GB*.
- VRAM usage: *Around 4GB*. 
- Storage usage: $5N+3$ SDXL models, including $N$ raw models. For $N=52$, it will use **1.66TB** for the most efficient approach.
- I intentionally made this a Python notebook because I need to switch modes this time.
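
The time and storage figures above can be re-derived with a few lines of arithmetic. This is only a sanity-check sketch: the per-stage minutes and the $5N+3$ model count are copied from the list above, and the variable names are mine.

```python
# Per-stage minutes, copied from the "Required time" bullet above.
stage_minutes = {
    "RAW": 25,
    "TE": 150 * 4,
    "SELECTED_TE": 25,
    "UNET": 50 * 4,
    "FINAL": 25,
    "E2E": 45,
}
total_minutes = sum(stage_minutes.values())
hours, minutes = divmod(total_minutes, 60)
print("{} min = {} h {} min".format(total_minutes, hours, minutes))

# Storage: 5N + 3 SDXL checkpoints, including the N raw models.
n_models = 52
stored_models = 5 * n_models + 3
print("{} checkpoints for N = {}".format(stored_models, n_models))
```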

## Required libraries ##

- `torch>=2.0.1`
- `tensordict`
- `sd-mecha` (I prefer to [clone](https://github.com/ljleb/sd-mecha/tree/main) the source code in place; current version as of 240222, commit `afdab8b003730f58b9127228ef68b0014a3c487d`)
- [safetensors](https://huggingface.co/docs/safetensors/index)
- [diffusers](https://huggingface.co/docs/diffusers/installation)
- [pytorch](https://pytorch.org/get-started/locally/#windows-python)

## Model naming schema ##

- `RAW` as `_x01`: Place all raw models here. Will generate `x51a` as the averaged model regardless of components.
- `CLIP` as `_x01te`: Will generate all models as `x51a` replaced with `_x01`'s TE. Will be a set of `te0`, `te1`, `te2`. Use these models for model selection. 
- `UNET` as `x51a-x39te0x39te1`: *Requires selected TEs.* Will generate all models as `_x01`'s UNET and the average of selected `te0` and `te1`. VAE will be `x51a`.
- `FINAL` as `e2e`: Final model as `x45`.

## Recommended directories to make ##

- `raw`: Store the raw $N$ models
- `clip`: Store $3N$ models for CLIP selection
- `unet`: Store $N$ models for UNET selection
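
A minimal way to create these three directories is sketched below. `make_workdirs` is a hypothetical helper, not part of the notebook; it assumes you pass the same base directory that the notebook later calls `DIR_BASE`.

```python
import os

def make_workdirs(base_dir):
    """Create raw/, clip/ and unet/ under base_dir (hypothetical helper)."""
    created = []
    for name in ("raw", "clip", "unet"):
        path = os.path.join(base_dir, name)
        os.makedirs(path, exist_ok=True)  # no error if it already exists
        created.append(path)
    return created
```

Calling it twice on the same base directory is harmless because of `exist_ok=True`.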

## Operation Mode ##

- [`RAW`, `CLIP`, `UNET`, `FINAL`]. Procedure will be *mutually exclusive*. I will keep restarting the whole notebook.

## Limitation ##

- ~~VAE remains unmanaged.~~ VAE can be picked from one of the raw models.
- SDXL models only. I don't need this for SD1 and SD2.
- Safetensors only. 

## WTF why and will it work? ##

- Yes. [It is part of my research](./README_XL.md).
- Image comparison will be listed there.

## Appendix: Pseudocode of sd-mecha ##

- Note that the core concept is different from WebUI or supermerger. It focuses on [serialization](https://www.geeksforgeeks.org/serialization-in-java/), along with *multiple merging methods* and *custom applied areas*.

- It will pick `model_b_as_recipe` for every merge key and `model_a` for every passthrough key

- Sample code: `model_b_as_recipe = sd_mecha.merge_methods(model_a, model_b_as_recipe, alpha, beta, etc) #returns model_a`

- For example, in `n_average`, `alpha` tends to `1`, instead of `0` in WebUI. 

- Also, `pick_vae` will show a special case on "bake VAE": `model_a_instead = sd_mecha.merge_methods(model_a, model_b_as_recipe, alpha=1) #returns model_a`
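
To make the `alpha` direction and the merge / passthrough split concrete, here is a toy sketch in plain Python. It does not call sd-mecha: `toy_weighted_sum`, the key names and the `MERGE_KEYS` set are made up for illustration, and real checkpoints hold tensors rather than single floats.

```python
# Toy stand-in for a checkpoint: component name -> single weight.
model_a = {"unet": 1.0, "te": 1.0, "vae": 1.0}
model_b = {"unet": 3.0, "te": 3.0, "vae": 3.0}

# Assumption: UNET and text encoder are merge keys, VAE is a passthrough key.
MERGE_KEYS = {"unet", "te"}

def toy_weighted_sum(a, b, alpha):
    """alpha=0 keeps a, alpha=1 takes b; passthrough keys always come from a."""
    out = {}
    for key in a:
        if key in MERGE_KEYS:
            out[key] = (1.0 - alpha) * a[key] + alpha * b[key]
        else:
            out[key] = a[key]
    return out

# The "bake VAE" case: everything from model_b, except the VAE from model_a.
merged = toy_weighted_sum(model_a, model_b, alpha=1.0)
print(merged)
```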

Algorithm `SD-MECHA`:

------

- Let

$\{model_A, model_B, model_C\} \in models$ and $arch_{models} \in arch_{SD}$ and is consistent (i.e. $arch_{model_A}=arch_{model_B}$)

$\{SumAverage,AddDiff,Rotate,ReBasin, etc.\} \in merge$

$\{CLIP, UNET, VAE\} \in models$, but $\{CLIP, UNET\} \in \alpha, \{VAE\} \notin \alpha$

$\alpha \in [0,1], \alpha=0 \implies model_A, \alpha=1 \implies model_B$

- Repeat:

$model_A, model_B, merge, \alpha, \beta, etc. \leftarrow deserialize(recipe)$ or user defined

$model_A \leftarrow merge(model_A, model_B, \alpha, \beta, etc.)$

$model_B \leftarrow model_A$

$recipe \leftarrow serialize(model_B)$

- Until no more $model_B$

- Return $recipe$

------
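
The loop above can be sketched as building a nested recipe, serializing it to text, and only evaluating it at the end. The JSON layout, `node` and `evaluate` below are hypothetical and are not the `*.mecha` format; sd-mecha has its own serializer. Models are reduced to single floats to keep the example runnable.

```python
import json

def node(method, a, b, alpha):
    """A recipe node: operands are either model names or nested nodes."""
    return {"method": method, "a": a, "b": b, "alpha": alpha}

def evaluate(recipe, models):
    """Recursively resolve a recipe against a dict of toy scalar 'models'."""
    if isinstance(recipe, str):  # leaf: a model name
        return models[recipe]
    if recipe["method"] != "weighted_sum":
        raise ValueError("unknown merge method: " + recipe["method"])
    a = evaluate(recipe["a"], models)
    b = evaluate(recipe["b"], models)
    return (1.0 - recipe["alpha"]) * a + recipe["alpha"] * b

# Each step reuses the previous result as the second operand (model_B).
recipe = "model_a"
recipe = node("weighted_sum", "model_b", recipe, 1 / 2)
recipe = node("weighted_sum", "model_c", recipe, 2 / 3)

text = json.dumps(recipe)    # serialize(recipe)
restored = json.loads(text)  # deserialize(recipe)
result = evaluate(restored, {"model_a": 0.0, "model_b": 3.0, "model_c": 6.0})
print(result)  # close to 3.0, the mean of the three toy models
```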

## Importing libraries ##

In [1]:
# Built-in
import time
import os
import math

# Is dependency fulfilled?
import torch

from tqdm import tqdm

In [2]:
torch.__version__

'2.2.0+cu121'

In [3]:
# Import the main module.
import sd_mecha

In [4]:
# Fix for OMP: Error #15
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

## User input session starts here. ##

Specify all the paths.

In [5]:
DIR_BASE = "../../stable-diffusion-webui/tmp/astolfo_mix/sdxl/" #To set up merger

DIR_RAW = "raw/" #To load N models
DIR_CLIP = "clip/"  #To write 3N models
DIR_UNET = "unet/" #To write N models
DIR_FINAL = "./" #To write 1 model

Quick check on directory and make the model name prefix.

In [6]:
MECHA_RECIPE_EXT = ".mecha"
MECHA_MODEL_EXT = ".safetensors"

MODEL_LIST_RAW = os.listdir("{}{}".format(DIR_BASE,DIR_RAW))
# Exclude yaml.
MODEL_LIST_RAW = list(filter(lambda p: p.endswith(MECHA_MODEL_EXT), MODEL_LIST_RAW)) #p.endswith(".ckpt") or p.endswith(".safetensors") or p.endswith(".bin")
if len(MODEL_LIST_RAW) < 2:
    raise Exception("Need at least 2 models for merge.")
#model_list = list(map(lambda p: os.path.splitext(os.path.basename(p))[0], model_list))

In [7]:
print("{} raw models found.".format(len(MODEL_LIST_RAW)))

52 raw models found.


Define model selection. Indices start at 1. Check the model list for ordering!
```
te0: --,--,07,--,--,--,--,15,--,--,--,--,29,--,--,--,--,40,--,44,48,--=-6
te1: --,--,--,09,--,--,--,15,--,--,--,--,29,--,31,34,--,--,--,--,48,--=-6
te2: --,06,--,09,--,11,12,15,--,--,--,--,29,30,31,34,38,40,42,44,48,49=-15
=sd: 03,--,--,--,10,--,12,--,16,18,19,25=-7

te0: 03,07,10,12,15,16,18,19,29,40,44,48=-12
te1: 03,09,10,12,15,16,18,19,29,31,34,48=-12
unet: 03,09,14,24,48,50=-6
```

In [8]:
MODEL_SELECTION_TE0 = [i+1 for i in range(len(MODEL_LIST_RAW)) if i+1 not in [3,7,10,12,15,16,18,19,29,40,44,48]] 
MODEL_SELECTION_TE1 = [i+1 for i in range(len(MODEL_LIST_RAW)) if i+1 not in [3,9,10,12,15,16,18,19,29,31,34,48]] 
MODEL_SELECTION_UNET = [i+1 for i in range(len(MODEL_LIST_RAW)) if i+1 not in [3,9,14,24,48,50]]

#25 is the Original SDXL
MODEL_SELECTION_VAE = 25

In [9]:
print("TE0:{},TE1:{},UNET:{}".format(len(MODEL_SELECTION_TE0),len(MODEL_SELECTION_TE1),len(MODEL_SELECTION_UNET)))

TE0:40,TE1:40,UNET:46
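
The selection lists in the cell above are written as exclusions: the literal list holds the 1-based indices to drop, and everything else is kept. A small sketch of that convention, using a hypothetical `keep_all_except` helper together with the model count and exclusion lists from this run:

```python
def keep_all_except(total, excluded):
    """Return the 1-based indices in 1..total that are not in `excluded`."""
    dropped = set(excluded)
    return [i for i in range(1, total + 1) if i not in dropped]

TOTAL = 52  # raw models found in this run
te0 = keep_all_except(TOTAL, [3, 7, 10, 12, 15, 16, 18, 19, 29, 40, 44, 48])
te1 = keep_all_except(TOTAL, [3, 9, 10, 12, 15, 16, 18, 19, 29, 31, 34, 48])
unet = keep_all_except(TOTAL, [3, 9, 14, 24, 48, 50])

print("TE0:{},TE1:{},UNET:{}".format(len(te0), len(te1), len(unet)))
```

Because the indices are 1-based, any later lookup into the 0-based model list has to subtract one, which is what `raw_models[i-1]` does in the recipe functions.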


Specify all the keywords (I'll avoid hardcoding because they will be everywhere)

In [10]:
MODE_RAW = 'MODE_RAW'
MODE_CLIP = 'MODE_CLIP'
MODE_UNET = 'MODE_UNET'
MODE_FINAL = 'MODE_FINAL'

MODE_ACTIVATED = [MODE_RAW,MODE_CLIP,MODE_UNET,MODE_FINAL] #[MODE_RAW,MODE_CLIP,MODE_UNET,MODE_FINAL]

Insert version number, and the... *"AstolfoMix"*.

In [11]:
MODEL_NAME_SUFFIX = "240222-60d0764" #yymmdd-commit
MODEL_NAME_KEYWORD = "AstolfoMix"

Change if your PC is in trouble.

My PC: i9-7960X, X299 Dark, 128GB DDR4, 2x RTX3090, P4510. Overkill for a merger.

In [12]:
g_device = "cuda:0" #"cpu"
g_precision_while_merge = torch.float64 if "cuda" in g_device else torch.float #I have RAM
g_precision_final_model = torch.float16 if "cuda" in g_device else torch.float #fp16
g_total_buffer_size=2**32 #4GB

## User input should end here. ##

Define output model name. I want to keep the format, however I need to manage the name manually.

In [13]:
FORMAT_BYPASS = "{}"

In [14]:
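# Auto zfill to the digit count of the total model count
def az(n):
    # len(str(N)) stays correct when N is exactly 10, 100, ... (log10 would give one digit too few)
    return str(n).zfill(len(str(len(MODEL_LIST_RAW))))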

In [15]:
N_CLIP = ("te0","te1","te2")
N_RAW = "_x{}"
N_ITR = "{}a"
N_PICKED = "x{}"

MODEL_NAME_RAW_PREFIX = (N_ITR.format(N_PICKED)).format(az(len(MODEL_LIST_RAW)-1)) #x51a

MODEL_NAME_TE = ("{}{}{}{}".format(N_PICKED,N_CLIP[0],N_PICKED,N_CLIP[1])).format(az(len(MODEL_SELECTION_TE0)-1),az(len(MODEL_SELECTION_TE1)-1)) #x39te0x39te1
MODEL_NAME_FINAL_PREFIX = N_PICKED.format(az(len(MODEL_SELECTION_UNET)-1)) #x45

MODEL_NAME_RAW = "{}-{}-{}".format(MODEL_NAME_RAW_PREFIX,MODEL_NAME_KEYWORD,MODEL_NAME_SUFFIX) #x51a-AstolfoMix-240222-60d0764
MODEL_NAME_TE_ITR = "{}-{}-{}{}-{}".format(MODEL_NAME_RAW_PREFIX,MODEL_NAME_KEYWORD,FORMAT_BYPASS,FORMAT_BYPASS,MODEL_NAME_SUFFIX) #x51a-AstolfoMix-{}{}-240222-60d0764
MODEL_NAME_SELECTED_TE = "{}-{}-{}-{}".format(MODEL_NAME_RAW_PREFIX,MODEL_NAME_KEYWORD,MODEL_NAME_TE,MODEL_NAME_SUFFIX) #x51a-AstolfoMix-x39te0x39te1-240222-60d0764
MODEL_NAME_UNET_ITR = "{}-{}-{}-{}".format(FORMAT_BYPASS,MODEL_NAME_KEYWORD,MODEL_NAME_TE,MODEL_NAME_SUFFIX) #{}-AstolfoMix-x39te0x39te1-240222-60d0764
MODEL_NAME_FINAL = "{}-{}-{}-{}".format(MODEL_NAME_FINAL_PREFIX,MODEL_NAME_KEYWORD,MODEL_NAME_TE,MODEL_NAME_SUFFIX) #x45-AstolfoMix-x39te0x39te1-240222-60d0764
MODEL_NAME_E2E = "{}-{}-{}-e2e-{}".format(MODEL_NAME_FINAL_PREFIX,MODEL_NAME_KEYWORD,MODEL_NAME_TE,MODEL_NAME_SUFFIX) #x45-AstolfoMix-x39te0x39te1-e2e-240222-60d0764

In [16]:
print("Naive average model:                     {}".format(MODEL_NAME_RAW))
print("CLIP models to iterate:                  {}".format(MODEL_NAME_TE_ITR))
print("Naive average model with selected CLIP:  {}".format(MODEL_NAME_SELECTED_TE))
print("UNET models to iterate:                  {}".format(MODEL_NAME_UNET_ITR))
print("Final merged model (staged):             {}".format(MODEL_NAME_FINAL))
print("Final merged model (e2e):                {}".format(MODEL_NAME_E2E))

Naive average model:                     x51a-AstolfoMix-240222-60d0764
CLIP models to iterate:                  x51a-AstolfoMix-{}{}-240222-60d0764
Naive average model with selected CLIP:  x51a-AstolfoMix-x39te0x39te1-240222-60d0764
UNET models to iterate:                  {}-AstolfoMix-x39te0x39te1-240222-60d0764
Final merged model (staged):             x45-AstolfoMix-x39te0x39te1-240222-60d0764
Final merged model (e2e):                x45-AstolfoMix-x39te0x39te1-e2e-240222-60d0764
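
The names above rely on formatting a template twice: the first `format` call fills in another template (or a bare `{}` placeholder, which is what `FORMAT_BYPASS` is for), and the second call fills in the actual number. A stand-alone sketch with the values from this run hardcoded; `pad` is a hypothetical stand-in for `az`, and the shortened name template is only for illustration:

```python
N_ITR = "{}a"
N_PICKED = "x{}"
FORMAT_BYPASS = "{}"

def pad(n, total=52):
    """Zero-pad n to the number of digits in total."""
    return str(n).zfill(len(str(total)))

# Stage 1: "{}a".format("x{}") gives "x{}a"; stage 2 fills in the number.
template = N_ITR.format(N_PICKED)
raw_prefix = template.format(pad(52 - 1))
print(raw_prefix)  # x51a

# A bare "{}" survives the first format call as a placeholder for later.
unet_itr = "{}-{}-{}".format(FORMAT_BYPASS, "AstolfoMix", "240222-60d0764")
print(unet_itr)  # {}-AstolfoMix-240222-60d0764
print(unet_itr.format(template.format(pad(2))))  # x02a-AstolfoMix-240222-60d0764
```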


## Setting up merge recipe and merge scheduler ##

- I'm still a bit nervous about hardcoding. Getters / setters will be fine. ~~No, you won't see OOP in python notebook.~~
- Will always run. `MODE_ACTIVATED` controls actual merge process only.
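
Every merge cell later in this notebook follows the same shape: check the mode, skip if the output file already exists, otherwise merge, then report the elapsed time. One way to factor that out is sketched below. `run_stage` and its `merge_fn` callback are hypothetical; in the notebook the callback would wrap the `scheduler.merge_and_save(...)` call.

```python
import os
import time

def run_stage(mode, activated, output_file, merge_fn):
    """Run merge_fn unless the mode is inactive or output_file already exists.

    Returns "inactive", "skipped" or "merged" so callers can log or test it.
    """
    started = time.time()
    if mode not in activated:
        status = "inactive"
    elif os.path.isfile(output_file):
        status = "skipped"
    else:
        merge_fn()
        status = "merged"
    print("{}: {} ({} sec)".format(mode, status, int(time.time() - started)))
    return status
```

Returning a status string instead of only printing keeps the recipe-building cells free to run unconditionally while still making the gating easy to check.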

In [17]:
def rmk_raw():
    return 'RAW'
def rmk_ste():
    return 'SELECTED_TE'
def rmk_f():
    return 'FINAL'
def rmk_e2e():
    return 'E2E'
def rmk_te(i,j):
    return 'CLIP{}_TE{}'.format(i,j) #CLIP1_TE0
def rmk_unet(i):
    return 'UNET{}'.format(i)  #UNET1

In [18]:
recipe_mapping = {}

def set_rmk(k, v):
    recipe_mapping[k] = v

In [19]:
def reset_rm():
    set_rmk(rmk_raw(), None)
    set_rmk(rmk_ste(), None)
    set_rmk(rmk_f(), None)
    set_rmk(rmk_e2e(), None)  # the E2E recipe is stored in recipe_mapping too

    for i in range(len(MODEL_LIST_RAW)):
        set_rmk(rmk_unet(i+1), None)
        for j in range(3):
            set_rmk(rmk_te(i+1,j), None)

In [20]:
reset_rm()

Single merger should be fine.

In [21]:
scheduler = sd_mecha.RecipeMerger(
    models_dir=DIR_BASE,
    default_device=g_device,
    default_dtype=g_precision_while_merge,
)

Define recipe extension, and make the model output path (Note that it is still being formatted)

In [22]:
OS_MODEL_PATH_RAW = "{}{}{}".format(DIR_BASE,MODEL_NAME_RAW,MECHA_MODEL_EXT)
RECIPE_PATH_RAW = "{}{}{}".format(DIR_BASE,MODEL_NAME_RAW,MECHA_RECIPE_EXT)
MECHA_MODEL_PATH_RAW = "{}{}".format(DIR_FINAL,MODEL_NAME_RAW)

OS_MODEL_PATH_TE_ITR = "{}{}{}{}".format(DIR_BASE,DIR_CLIP,MODEL_NAME_TE_ITR,MECHA_MODEL_EXT)
RECIPE_PATH_TE_ITR = "{}{}{}{}".format(DIR_BASE,DIR_CLIP,MODEL_NAME_TE_ITR,MECHA_RECIPE_EXT)
MECHA_MODEL_PATH_TE_ITR =  "{}{}{}".format(DIR_FINAL,DIR_CLIP,MODEL_NAME_TE_ITR)

OS_MODEL_PATH_SELECTED_TE = "{}{}{}".format(DIR_BASE,MODEL_NAME_SELECTED_TE,MECHA_MODEL_EXT)
RECIPE_PATH_SELECTED_TE = "{}{}{}".format(DIR_BASE,MODEL_NAME_SELECTED_TE,MECHA_RECIPE_EXT)
MECHA_MODEL_PATH_SELECTED_TE =  "{}{}".format(DIR_FINAL,MODEL_NAME_SELECTED_TE)

OS_MODEL_PATH_UNET_ITR = "{}{}{}{}".format(DIR_BASE,DIR_UNET,MODEL_NAME_UNET_ITR,MECHA_MODEL_EXT)
RECIPE_PATH_UNET_ITR = "{}{}{}{}".format(DIR_BASE,DIR_UNET,MODEL_NAME_UNET_ITR,MECHA_RECIPE_EXT)
MECHA_MODEL_PATH_UNET_ITR =  "{}{}{}".format(DIR_FINAL,DIR_UNET,MODEL_NAME_UNET_ITR)

OS_MODEL_PATH_FINAL = "{}{}{}".format(DIR_BASE,MODEL_NAME_FINAL,MECHA_MODEL_EXT)
RECIPE_PATH_FINAL = "{}{}{}".format(DIR_BASE,MODEL_NAME_FINAL,MECHA_RECIPE_EXT)
MECHA_MODEL_PATH_FINAL = "{}{}".format(DIR_FINAL,MODEL_NAME_FINAL)

OS_MODEL_PATH_E2E = "{}{}{}".format(DIR_BASE,MODEL_NAME_E2E,MECHA_MODEL_EXT)
RECIPE_PATH_E2E = "{}{}{}".format(DIR_BASE,MODEL_NAME_E2E,MECHA_RECIPE_EXT)
MECHA_MODEL_PATH_E2E = "{}{}".format(DIR_FINAL,MODEL_NAME_E2E)

In [23]:
# Better test the ugly full file path
def get_te_itr_path(s: str,i: int,j: int):
    return s.format(N_RAW.format(az(i+1)),N_CLIP[j])

def get_unet_itr_path(s: str,i: int):
    return s.format((N_ITR.format(N_PICKED)).format(az(i+1)))

In [24]:
print("Sample TE_ITR recipe path:   {}".format(get_te_itr_path(RECIPE_PATH_TE_ITR, 1, 1)))
print("Sample TE_ITR model path:    {}".format(get_te_itr_path(OS_MODEL_PATH_TE_ITR, 1, 1)))
print("Sample UNET_ITR recipe path: {}".format(get_unet_itr_path(RECIPE_PATH_UNET_ITR, 1)))
print("Sample UNET_ITR model path:  {}".format(get_unet_itr_path(OS_MODEL_PATH_UNET_ITR, 1)))
print("Does RAW model exist:        {}".format(os.path.isfile(OS_MODEL_PATH_RAW)))
print("Final model path:            {}".format(MECHA_MODEL_PATH_FINAL))

Sample TE_ITR recipe path:   ../../stable-diffusion-webui/tmp/astolfo_mix/sdxl/clip/x51a-AstolfoMix-_x02te1-240222-60d0764.mecha
Sample TE_ITR model path:    ../../stable-diffusion-webui/tmp/astolfo_mix/sdxl/clip/x51a-AstolfoMix-_x02te1-240222-60d0764.safetensors
Sample UNET_ITR recipe path: ../../stable-diffusion-webui/tmp/astolfo_mix/sdxl/unet/x02a-AstolfoMix-x39te0x39te1-240222-60d0764.mecha
Sample UNET_ITR model path:  ../../stable-diffusion-webui/tmp/astolfo_mix/sdxl/unet/x02a-AstolfoMix-x39te0x39te1-240222-60d0764.safetensors
Does RAW model exist:        True
Final model path:            ./x45-AstolfoMix-x39te0x39te1-240222-60d0764


### Pick VAE ###
- It will pick `cur_model` for every merge key and `vae_model` for every passthrough key

In [25]:
def pick_vae(cur_model):
    vae_model =  "{}{}".format(DIR_RAW,MODEL_LIST_RAW[MODEL_SELECTION_VAE - 1]) 
    return sd_mecha.weighted_sum(vae_model, cur_model, alpha=1.0, dtype=g_precision_while_merge, device=g_device)

### Naive Average ###

- Pay attention to `alpha`. It is the opposite of A1111: it is "A merge to B" instead of "Merge A with B".
- Also, the recipe is a set of `RecipeNode` objects in a [tree structure](https://en.wikipedia.org/wiki/Tree_(data_structure)). Therefore you can expect quite a lot of recursion (nodes returning themselves).

In [26]:
def make_recipe_naive_merge():
    models = list(map(lambda p: "{}{}".format(DIR_RAW,p),MODEL_LIST_RAW))

    recipe = models[0]
    for i, model in enumerate(models[1:], start=2):
        recipe = sd_mecha.weighted_sum(model, recipe, alpha=(i-1)/i, dtype=g_precision_while_merge, device=g_device)

    return pick_vae(recipe)
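
Why `alpha=(i-1)/i` gives a plain average: at step `i` the running recipe already holds the mean of the first `i-1` models, so it gets weight `(i-1)/i` and the new model gets `1/i`. A toy check with scalars instead of checkpoints; `toy_weighted_sum` is a stand-in with the same argument order as the call above, not the sd-mecha function:

```python
def toy_weighted_sum(a, b, alpha):
    """(1 - alpha) * a + alpha * b, so alpha=1 returns b."""
    return (1.0 - alpha) * a + alpha * b

models = [4.0, 8.0, 15.0, 16.0, 23.0, 42.0]

recipe = models[0]
for i, model in enumerate(models[1:], start=2):
    # The new model is `a` (weight 1/i), the running mean is `b` (weight (i-1)/i).
    recipe = toy_weighted_sum(model, recipe, alpha=(i - 1) / i)

mean = sum(models) / len(models)
print(recipe, mean)
assert abs(recipe - mean) < 1e-9
```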

In [27]:
set_rmk(rmk_raw(), make_recipe_naive_merge())
sd_mecha.serialize_and_save(recipe_mapping[rmk_raw()], RECIPE_PATH_RAW)

### CLIP Models to test ###
- Note that `alpha` is using a special operator `|` which is "Bitwise OR".
- Also recall that "TE0 uses ViT-G", "TE1 uses ViT-L" and "TE2 uses both".
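
If the `sd_mecha.sdxl_*_classes(...)` helpers return dictionaries (an assumption on my side, not checked against the pinned commit), then `|` here acts as Python's dict union (Python 3.9+), which simply combines the per-component weights into one mapping. The sketch below imitates that with made-up helpers and made-up key names:

```python
# Hypothetical stand-ins; the real helpers and their keys may differ.
def txt_l(weight):
    return {"txt_l": weight}

def txt_g14(weight):
    return {"txt_g14": weight}

def unet(weight):
    return {"unet": weight}

# alpha=0 keeps the first model (the raw model whose TE is under test),
# alpha=1 takes the second model (the averaged model).
alpha_te0 = txt_l(1.0) | txt_g14(0.0) | unet(1.0)  # only ViT-G from the raw model
alpha_te1 = txt_l(0.0) | txt_g14(1.0) | unet(1.0)  # only ViT-L from the raw model
alpha_te2 = txt_l(0.0) | txt_g14(0.0) | unet(1.0)  # both TEs from the raw model

print(alpha_te0)
```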

In [28]:
def make_recipe_te_itr(p):
    clip_model = "{}{}".format(DIR_RAW, p)
    unet_model = "{}{}".format(MECHA_MODEL_PATH_RAW, MECHA_MODEL_EXT)
    recipe_te0 = sd_mecha.weighted_sum(clip_model, unet_model, alpha=(sd_mecha.sdxl_txt_classes(1.0) | sd_mecha.sdxl_txt_g14_classes(0.0) | sd_mecha.sdxl_unet_classes(1.0)), dtype=g_precision_while_merge, device=g_device)
    recipe_te1 = sd_mecha.weighted_sum(clip_model, unet_model, alpha=(sd_mecha.sdxl_txt_classes(0.0) | sd_mecha.sdxl_txt_g14_classes(1.0) | sd_mecha.sdxl_unet_classes(1.0)), dtype=g_precision_while_merge, device=g_device)
    recipe_te2 = sd_mecha.weighted_sum(clip_model, unet_model, alpha=(sd_mecha.sdxl_txt_classes(0.0) | sd_mecha.sdxl_txt_g14_classes(0.0) | sd_mecha.sdxl_unet_classes(1.0)), dtype=g_precision_while_merge, device=g_device)
    return (pick_vae(recipe_te0), pick_vae(recipe_te1), pick_vae(recipe_te2))

In [29]:
# 3N models
for i in range(len(MODEL_LIST_RAW)):
    rte = make_recipe_te_itr(MODEL_LIST_RAW[i])
    for j in range(len(N_CLIP)):    
        set_rmk(rmk_te(i+1,j), rte[j])
        sd_mecha.serialize_and_save(recipe_mapping[rmk_te(i+1,j)], get_te_itr_path(RECIPE_PATH_TE_ITR,i,j))

### Picked CLIP Model ###

In [30]:
def make_recipe_ste():
    raw_models = list(map(lambda p: "{}{}".format(DIR_RAW,p),MODEL_LIST_RAW))
    models_te0 = [raw_models[i-1] for i in MODEL_SELECTION_TE0]
    models_te1 = [raw_models[i-1] for i in MODEL_SELECTION_TE1]

    unet_model = "{}{}".format(MECHA_MODEL_PATH_RAW, MECHA_MODEL_EXT)

    recipe_te0 = models_te0[0]
    for i, model in enumerate(models_te0[1:], start=2):
        recipe_te0 = sd_mecha.weighted_sum(model, recipe_te0, alpha=(i-1)/i, dtype=g_precision_while_merge, device=g_device)
    recipe_te1 = models_te1[0]
    for i, model in enumerate(models_te1[1:], start=2):
        recipe_te1 = sd_mecha.weighted_sum(model, recipe_te1, alpha=(i-1)/i, dtype=g_precision_while_merge, device=g_device)

    unet_model = sd_mecha.weighted_sum(recipe_te0, unet_model, alpha=(sd_mecha.sdxl_txt_classes(1.0) | sd_mecha.sdxl_txt_g14_classes(0.0) | sd_mecha.sdxl_unet_classes(1.0)), dtype=g_precision_while_merge, device=g_device)
    unet_model = sd_mecha.weighted_sum(recipe_te1, unet_model, alpha=(sd_mecha.sdxl_txt_classes(0.0) | sd_mecha.sdxl_txt_g14_classes(1.0) | sd_mecha.sdxl_unet_classes(1.0)), dtype=g_precision_while_merge, device=g_device)

    return pick_vae(unet_model)

In [31]:
set_rmk(rmk_ste(), make_recipe_ste())
sd_mecha.serialize_and_save(recipe_mapping[rmk_ste()], RECIPE_PATH_SELECTED_TE)

### UNET Models to test ###

In [32]:
def make_recipe_unet_itr(p):
    unet_model = "{}{}".format(DIR_RAW, p)
    clip_model = "{}{}".format(MECHA_MODEL_PATH_SELECTED_TE, MECHA_MODEL_EXT)
    recipe_te2 = sd_mecha.weighted_sum(clip_model, unet_model, alpha=(sd_mecha.sdxl_txt_classes(0.0) | sd_mecha.sdxl_txt_g14_classes(0.0) | sd_mecha.sdxl_unet_classes(1.0)), dtype=g_precision_while_merge, device=g_device)
    return pick_vae(recipe_te2)

In [33]:
# N models
for i in range(len(MODEL_LIST_RAW)):
    set_rmk(rmk_unet(i+1), make_recipe_unet_itr(MODEL_LIST_RAW[i]))
    sd_mecha.serialize_and_save(recipe_mapping[rmk_unet(i+1)], get_unet_itr_path(RECIPE_PATH_UNET_ITR,i))

### Final model ###

- 2 models will be produced for validation. The real e2e merge and the staged merge should yield the same model within floating-point error.
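
A comparison of the two outputs could look like the sketch below. It uses plain Python lists in place of tensors so it runs anywhere; `max_abs_diff`, the key names, the sample values and the tolerance are all hypothetical, and for real checkpoints you would load both files and compare tensor by tensor.

```python
def max_abs_diff(state_a, state_b):
    """Largest element-wise absolute difference between two toy state dicts."""
    if state_a.keys() != state_b.keys():
        raise KeyError("the two models do not share the same keys")
    return max(
        abs(x - y)
        for key in state_a
        for x, y in zip(state_a[key], state_b[key])
    )

# Made-up values standing in for the staged and e2e checkpoints.
staged = {"unet.w": [0.1250, -0.5000], "te.w": [0.3333]}
e2e = {"unet.w": [0.1251, -0.5000], "te.w": [0.3333]}

TOLERANCE = 1e-3  # arbitrary threshold for this toy example
diff = max_abs_diff(staged, e2e)
print("same within tolerance:", diff <= TOLERANCE)
```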

In [34]:
def make_recipe_final():
    raw_models = list(map(lambda p: "{}{}".format(DIR_RAW,p),MODEL_LIST_RAW))
    models_unet = [raw_models[i-1] for i in MODEL_SELECTION_UNET]

    clip_model = "{}{}".format(MECHA_MODEL_PATH_SELECTED_TE, MECHA_MODEL_EXT)

    recipe_unet = models_unet[0]
    for i, model in enumerate(models_unet[1:], start=2):
        recipe_unet = sd_mecha.weighted_sum(model, recipe_unet, alpha=(i-1)/i, dtype=g_precision_while_merge, device=g_device)

    final_model = sd_mecha.weighted_sum(clip_model, recipe_unet, alpha=(sd_mecha.sdxl_txt_classes(0.0) | sd_mecha.sdxl_txt_g14_classes(0.0) | sd_mecha.sdxl_unet_classes(1.0)), dtype=g_precision_while_merge, device=g_device)

    return pick_vae(final_model)

def make_recipe_e2e():
    raw_models = list(map(lambda p: "{}{}".format(DIR_RAW,p),MODEL_LIST_RAW))
    models_te0 = [raw_models[i-1] for i in MODEL_SELECTION_TE0]
    models_te1 = [raw_models[i-1] for i in MODEL_SELECTION_TE1]
    models_unet = [raw_models[i-1] for i in MODEL_SELECTION_UNET]

    recipe_te0 = models_te0[0]
    for i, model in enumerate(models_te0[1:], start=2):
        recipe_te0 = sd_mecha.weighted_sum(model, recipe_te0, alpha=(i-1)/i, dtype=g_precision_while_merge, device=g_device)
    recipe_te1 = models_te1[0]
    for i, model in enumerate(models_te1[1:], start=2):
        recipe_te1 = sd_mecha.weighted_sum(model, recipe_te1, alpha=(i-1)/i, dtype=g_precision_while_merge, device=g_device)
    recipe_unet = models_unet[0]
    for i, model in enumerate(models_unet[1:], start=2):
        recipe_unet = sd_mecha.weighted_sum(model, recipe_unet, alpha=(i-1)/i, dtype=g_precision_while_merge, device=g_device)

    e2e_model = "{}{}".format(DIR_RAW,MODEL_LIST_RAW[MODEL_SELECTION_VAE - 1]) 

    e2e_model = sd_mecha.weighted_sum(e2e_model, recipe_unet, alpha=(sd_mecha.sdxl_txt_classes(0.0) | sd_mecha.sdxl_txt_g14_classes(0.0) | sd_mecha.sdxl_unet_classes(1.0)), dtype=g_precision_while_merge, device=g_device)
    e2e_model = sd_mecha.weighted_sum(e2e_model, recipe_te0, alpha=(sd_mecha.sdxl_txt_classes(0.0) | sd_mecha.sdxl_txt_g14_classes(1.0) | sd_mecha.sdxl_unet_classes(0.0)), dtype=g_precision_while_merge, device=g_device)
    e2e_model = sd_mecha.weighted_sum(e2e_model, recipe_te1, alpha=(sd_mecha.sdxl_txt_classes(1.0) | sd_mecha.sdxl_txt_g14_classes(0.0) | sd_mecha.sdxl_unet_classes(0.0)), dtype=g_precision_while_merge, device=g_device)

    #Note that there is no pick_vae
    return e2e_model

In [35]:
set_rmk(rmk_f(), make_recipe_final())
sd_mecha.serialize_and_save(recipe_mapping[rmk_f()], RECIPE_PATH_FINAL)
set_rmk(rmk_e2e(), make_recipe_e2e())
sd_mecha.serialize_and_save(recipe_mapping[rmk_e2e()], RECIPE_PATH_E2E)

## Time for action ##

In [36]:
ts = time.time()

### Naive Average ###

In [37]:
tss = time.time()

if MODE_RAW in MODE_ACTIVATED:
    if os.path.isfile(OS_MODEL_PATH_RAW):
        print("Merged model is present. Skipping.")
    else:
        scheduler.merge_and_save(recipe_mapping[rmk_raw()], output_path=MECHA_MODEL_PATH_RAW, save_dtype=g_precision_final_model, total_buffer_size=g_total_buffer_size)
else:
    print("This session is not activated. Skipping.")

tse = time.time()
print("Merge time: {} sec".format(int(tse - tss)))    

Merged model is present. Skipping.
Merge time: 0 sec


### CLIP Models to test ###

In [38]:
tss = time.time()

if MODE_CLIP in MODE_ACTIVATED:
    # 3N models
    #for i in tqdm(range(len(MODEL_LIST_RAW)), desc="Iterating raw model list to swap TEs"):
    #    for j in tqdm(range(len(N_CLIP)), desc="Making models with swapped raw TEs"):
    for i in range(len(MODEL_LIST_RAW)):
        for j in range(len(N_CLIP)):    
            if os.path.isfile(get_te_itr_path(OS_MODEL_PATH_TE_ITR,i,j)):
                print("Merged model is present. Skipping.")
            else:
                scheduler.merge_and_save(recipe_mapping[rmk_te(i+1,j)], output_path=get_te_itr_path(MECHA_MODEL_PATH_TE_ITR,i,j), save_dtype=g_precision_final_model, total_buffer_size=g_total_buffer_size)
else:
    print("This session is not activated. Skipping.")

tse = time.time()
print("Merge time: {} sec".format(int(tse - tss)))    

Merged model is present. Skipping.

### Picked CLIP Model ###

In [39]:
tss = time.time()

if MODE_CLIP in MODE_ACTIVATED:
    if os.path.isfile(OS_MODEL_PATH_SELECTED_TE):
        print("Merged model is present. Skipping.")
    else:
        scheduler.merge_and_save(recipe_mapping[rmk_ste()], output_path=MECHA_MODEL_PATH_SELECTED_TE, save_dtype=g_precision_final_model, total_buffer_size=g_total_buffer_size)
else:
    print("This session is not activated. Skipping.")

tse = time.time()
print("Merge time: {} sec".format(int(tse - tss)))    

Merged model is present. Skipping.
Merge time: 0 sec


### UNET Models to test ###

In [40]:
tss = time.time()

if MODE_UNET in MODE_ACTIVATED:
    # N models
    #for i in tqdm(range(len(MODEL_LIST_RAW)), desc="Iterating raw model list to swap UNETs"):
    for i in range(len(MODEL_LIST_RAW)):
        if os.path.isfile(get_unet_itr_path(OS_MODEL_PATH_UNET_ITR,i)):
            print("Merged model is present. Skipping.")
        else:
            scheduler.merge_and_save(recipe_mapping[rmk_unet(i+1)], output_path=get_unet_itr_path(MECHA_MODEL_PATH_UNET_ITR,i), save_dtype=g_precision_final_model, total_buffer_size=g_total_buffer_size)
else:
    print("This session is not activated. Skipping.")

tse = time.time()
print("Merge time: {} sec".format(int(tse - tss)))    

Merged model is present. Skipping.

### Final model ###

In [41]:
print (MECHA_MODEL_PATH_SELECTED_TE, MECHA_MODEL_PATH_FINAL)

./x51a-AstolfoMix-x39te0x39te1-240222-60d0764 ./x45-AstolfoMix-x39te0x39te1-240222-60d0764


In [42]:
tss = time.time()

if MODE_FINAL in MODE_ACTIVATED:
    if os.path.isfile(OS_MODEL_PATH_FINAL):
        print("Merged model is present. Skipping.")
    else:
        scheduler.merge_and_save(recipe_mapping[rmk_f()], output_path=MECHA_MODEL_PATH_FINAL, save_dtype=g_precision_final_model, total_buffer_size=g_total_buffer_size)
else:
    print("This session is not activated. Skipping.")

tse = time.time()
print("Merge time: {} sec".format(int(tse - tss)))    

Merged model is present. Skipping.
Merge time: 0 sec


In [43]:
tss = time.time()

if MODE_FINAL in MODE_ACTIVATED:
    if os.path.isfile(OS_MODEL_PATH_E2E):
        print("Merged model is present. Skipping.")
    else:
        scheduler.merge_and_save(recipe_mapping[rmk_e2e()], output_path=MECHA_MODEL_PATH_E2E, save_dtype=g_precision_final_model, total_buffer_size=g_total_buffer_size)
else:
    print("This session is not activated. Skipping.")

tse = time.time()
print("Merge time: {} sec".format(int(tse - tss)))    

Merged model is present. Skipping.
Merge time: 0 sec


Full operation time.

In [44]:
te = time.time()
print("Total time: {} sec".format(int(te - ts)))

Total time: 0 sec
