# Merge as unprune #

## Abstract ##

- Special case of `uniform_merge`. Some pruned SDXL models cannot be merged in A1111 or opened by [toolkit](https://github.com/arenasys/stable-diffusion-webui-model-toolkit). We use [OpenDalle](https://huggingface.co/dataautogpt3/OpenDalleV1.1) as the example and [SDXL Base 1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) as the base model, then "merge" the two with `alpha` set to either 0 or 1 (the model list is unsorted, so which value applies depends on the order), i.e. the foreign model's weights are used directly.
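To make the `alpha` = 0 or 1 case concrete, here is a minimal sketch that uses plain Python dictionaries in place of real checkpoints. It assumes the conventional weighted sum `(1 - alpha) * a + alpha * b` and, as a further assumption, that a key missing from the pruned model falls back to the base model. `weighted_sum_dicts` is a hypothetical helper for illustration, not part of `sd-mecha`.

```python
def weighted_sum_dicts(a, b, alpha):
    """Hypothetical helper: per-key (1 - alpha) * a + alpha * b.

    Assumption: a key that exists in only one model is copied from that model.
    """
    merged = {}
    for key in sorted(set(a) | set(b)):
        if key not in a:
            merged[key] = b[key]
        elif key not in b:
            merged[key] = a[key]
        else:
            merged[key] = (1 - alpha) * a[key] + alpha * b[key]
    return merged


# Toy "checkpoints": one float per key instead of a tensor.
base = {"unet.w": 1.0, "vae.w": 2.0, "clip.w": 3.0}  # complete base model
pruned = {"unet.w": 10.0, "vae.w": 20.0}             # "clip.w" was pruned away

# alpha=0 with the pruned model on the left keeps its weights wherever they exist.
fixed = weighted_sum_dicts(pruned, base, alpha=0)
print(fixed)
```

Swapping the two arguments and passing `alpha=1` gives the same result in this sketch, which is why the order of the model list decides which value to use.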

## Recipe ##

- `_x25-sd_xl_base_1.0`: Complete base model.

- `_x04-OpenDalleV1.1`: Model to be un-pruned.

## Required libraries ##

- `torch>=2.0.1`
- `tensordict`
- `sd-mecha` (Commit `ead8ad7caba900ab0a40e2dfcae04b9d50fae2e6` only! `git clone` the repository, run `git checkout ead8ad7caba900ab0a40e2dfcae04b9d50fae2e6`, then copy it one level up to `./sd_mecha/`)
- [safetensors](https://huggingface.co/docs/safetensors/index)
- [diffusers](https://huggingface.co/docs/diffusers/installation)
- [pytorch](https://pytorch.org/get-started/locally/#windows-python)
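Before running the notebook, it can help to confirm that the modules above are importable. The following is a minimal sketch using the standard library's `importlib.util.find_spec`; `missing_modules` is a hypothetical helper, and it only checks that a module can be found, not its version or the `sd-mecha` commit.

```python
import importlib.util


def missing_modules(names):
    """Return the module names that cannot be found in the current environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]


# Import names of the libraries listed above.
required = ["torch", "tensordict", "sd_mecha", "safetensors", "diffusers"]
missing = missing_modules(required)
if missing:
    print("Missing modules:", ", ".join(missing))
else:
    print("All required modules are importable.")
```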

## WTF why and will it work? ##

- Yes. [It is part of my research](./README_XL.md).
- Image comparisons will be listed there.

## Importing libraries ##

In [1]:
# Built-in
import time
import os

# Are the dependencies fulfilled?
import torch
import tensordict

In [2]:
torch.__version__

'2.2.0+cu121'

In [3]:
tensordict.__version__

'0.3.0'

In [4]:
# Import the main module.
import sd_mecha

In [5]:
# Fix for OMP: Error #15
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

I'll disable pruning to let [toolkit](https://github.com/arenasys/stable-diffusion-webui-model-toolkit) support the merged model.

In [6]:
g_device = "cuda:0" #"cpu"
g_prune = False
g_merged_model = "_x04-fixed" #.safetensors
g_precision = 16 #fp16, forwarded from sd-meh

In [7]:
model_folder = "../stable-diffusion-webui/tmp/astolfo_mix/sdxl/_x04/"
model_type = torch.float16 if "cuda" in g_device else torch.float # CPU doesn't support FP16 / FP8

Exploring the models inside the folder.

In [8]:
model_list = os.listdir(model_folder)
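# Keep checkpoint files only (this excludes yaml and other files).
model_list = [p for p in model_list if p.endswith((".ckpt", ".safetensors", ".bin"))]
if not model_list:
    # Without this check, model_list[0] below would raise an IndexError.
    raise FileNotFoundError("No model found in {}".format(model_folder))
if len(model_list) < 2:
    # Special case: model fix. Merge the single model with itself.
    model_list.append(model_list[0])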

#model_list = list(map(lambda p: os.path.splitext(os.path.basename(p))[0], model_list))

In [9]:
#model_list
print("{} models found.".format(len(model_list)))

2 models found.
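`os.listdir` returns entries in arbitrary order, which is why the Abstract leaves `alpha` as "either 0 or 1". One way to make the order predictable is to sort the filtered names, as in this minimal sketch. `pick_models` is a hypothetical helper, and the `.safetensors` extensions and the `config.yaml` entry are assumptions added for illustration.

```python
MODEL_EXTENSIONS = (".ckpt", ".safetensors", ".bin")


def pick_models(filenames):
    """Hypothetical helper: keep checkpoint files and return them in sorted order."""
    return sorted(name for name in filenames if name.endswith(MODEL_EXTENSIONS))


# Names from the Recipe section (extensions assumed), plus a config file to skip.
found = pick_models([
    "_x25-sd_xl_base_1.0.safetensors",
    "config.yaml",
    "_x04-OpenDalleV1.1.safetensors",
])
print(found)
```

With sorted names, the `_x04` model always comes before the `_x25` base model, so the position of each model in the list no longer depends on the file system.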


Setting up the merge recipe and the merge scheduler.

In [10]:
models = model_list

# Chain the models pairwise; alpha is 0 or 1 depending on the list order (see Abstract).
merge = models[0]
for model in models[1:]:
    merge = sd_mecha.weighted_sum(merge, model, alpha=0)

scheduler = sd_mecha.MergeScheduler(
    base_dir=model_folder,
    device=g_device,
    prune=g_prune,
    precision=g_precision,
)

Time for action.

In [11]:
ts = time.time()
scheduler.merge_and_save(merge, output_path=g_merged_model)

stage 1: 100%|██████████| 2516/2516 [00:00<00:00, 5448.66it/s]
stage 2: 100%|██████████| 2515/2515 [00:00<00:00, 420316.16it/s]


In [12]:
te = time.time()
print("time: {} sec".format(int(te - ts)))

time: 18 sec
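The elapsed time above is measured across two cells, so when the cells are run by hand it also includes the pause between them. A minimal sketch of keeping the measurement in one place with `time.perf_counter` follows; `timed` is a hypothetical helper and the `sum(...)` call is only a stand-in for the merge.

```python
import time
from contextlib import contextmanager


@contextmanager
def timed(label):
    """Hypothetical helper: print the wall-clock time spent inside the block."""
    start = time.perf_counter()
    try:
        yield
    finally:
        print("{}: {:.2f} sec".format(label, time.perf_counter() - start))


with timed("time"):
    # Stand-in workload; in the notebook this would be the merge_and_save call.
    total = sum(range(1000))
```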
