In [None]:
import random
import sys
import time
import gc

import numpy as np

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

# Keyword arguments passed to every AutoModelForCausalLM.from_pretrained call
# in this notebook: bfloat16 weights from safetensors files, no remote code,
# low-CPU-memory loading and the FlashAttention-2 attention backend.
params = dict(
    low_cpu_mem_usage=True,
    trust_remote_code=False,
    torch_dtype=torch.bfloat16,
    use_safetensors=True,
    attn_implementation="flash_attention_2",
)


def norm_model_weights(model):
    """Rebalance paired projection weights in place, keeping the model's outputs unchanged.

    Walks ``model.named_parameters()`` in order. For each layer it moves a
    scale factor between two projections that only meet through an
    operation linear in the scaled one (assuming the usual attention and
    gated-MLP wiring; that wiring is not checked here):

    * ``q_proj`` / ``k_proj`` -- q (weight and bias) is divided and k (weight
      and bias) multiplied by one scalar, so every score ``q . k`` is kept.
    * ``v_proj`` / ``o_proj`` -- v (weight and bias) is divided and the
      ``o_proj`` weight multiplied by one negative scalar (the two sign
      flips cancel), so ``o_proj(attn @ v)`` is kept. ``o_proj.bias`` is
      added after that product and is therefore left untouched.
    * ``up_proj`` / ``down_proj`` -- each ``up_proj`` row (and bias entry) is
      divided and the matching ``down_proj`` column multiplied by a positive
      per-neuron factor, so ``down_proj(act(gate_proj(x)) * up_proj(x))`` is
      kept. ``down_proj.bias`` and ``gate_proj`` are never rescaled.

    Each factor is the square root of a ratio of mean absolute weights, so
    afterwards the two sides of a pair have roughly comparable magnitudes.
    ``merge`` calls this (``norm0`` / ``norm1``) before averaging two models.

    Ordering assumptions: within a layer ``q_proj`` is yielded before
    ``k_proj``, ``v_proj`` before ``o_proj``, ``up_proj`` before
    ``down_proj``, and each module's weight before its bias (the
    ``nn.Linear`` registration order).

    NOTE(review): the q/k factor divides a length-``hidden`` vector by a
    length-(``k_proj`` rows) vector, and the v/o factor a length-``hidden``
    vector by a length-(``o_proj`` columns) vector. These only broadcast when
    the sizes match (no grouped-query attention). Confirm before using on
    GQA models.

    NOTE: arithmetic runs in the parameters' own dtype (bfloat16 when loaded
    with the module-level ``params``). The rescaled weights are re-rounded
    in that dtype.

    Args:
        model: a ``torch.nn.Module``; its parameters are modified in place.

    Returns:
        The same ``model`` object.

    Raises:
        ValueError: if a ``k_proj`` / ``o_proj`` / ``down_proj`` weight
            appears before any partner ``q_proj`` / ``v_proj`` / ``up_proj``
            weight.
    """
    last_q = q_bias = qk_mult = None
    last_v = v_bias = None
    last_up = up_bias = None

    for name, param in model.named_parameters():
        is_bias = "bias" in name

        if "q_proj" in name:
            if is_bias:
                q_bias = param
            else:
                last_q = param
                # Reset so a bias-less layer never reuses the previous layer's bias.
                q_bias = None

        elif "k_proj" in name:
            if is_bias:
                # k = W_k x + b_k: the bias takes the same factor as W_k.
                param.data = param.data * qk_mult
            else:
                if last_q is None:
                    raise ValueError(f"{name} encountered before any q_proj weight")
                mult = torch.sqrt(
                    torch.mean(torch.abs(last_q.data), dim=0, keepdim=True).transpose(0, 1)
                    / torch.mean(torch.abs(param.data), dim=1, keepdim=True)
                )
                mult = torch.mean(mult)
                last_q.data = last_q.data / mult
                if q_bias is not None:
                    q_bias.data = q_bias.data / mult
                param.data = param.data * mult
                qk_mult = mult

        elif "v_proj" in name:
            if is_bias:
                v_bias = param
            else:
                last_v = param
                v_bias = None

        elif "o_proj" in name:
            # o_proj.bias is added after W_o @ (attn @ v); rescaling it would
            # change the residual-stream output, so only the weight is touched.
            if not is_bias:
                if last_v is None:
                    raise ValueError(f"{name} encountered before any v_proj weight")
                mult = torch.sqrt(
                    torch.mean(torch.abs(last_v.data), dim=0, keepdim=True)
                    / torch.mean(torch.abs(param.data), dim=0, keepdim=True)
                )
                # Negative factor: the sign flip in v is undone by the one in o.
                mult = torch.mean(mult) * -1
                last_v.data = last_v.data / mult
                if v_bias is not None:
                    v_bias.data = v_bias.data / mult
                param.data = param.data * mult

        elif "up_proj" in name:
            if is_bias:
                up_bias = param
            else:
                last_up = param
                up_bias = None

        elif "down_proj" in name:
            # down_proj.bias is added after the product and must stay unchanged.
            if not is_bias:
                if last_up is None:
                    raise ValueError(f"{name} encountered before any up_proj weight")
                # Shape (1, intermediate): one factor per MLP neuron.
                mult = torch.sqrt(
                    torch.mean(torch.abs(last_up.data), dim=1, keepdim=True).transpose(0, 1)
                    / torch.mean(torch.abs(param.data), dim=0, keepdim=True)
                )
                last_up.data = last_up.data / mult.transpose(0, 1)
                if up_bias is not None:
                    up_bias.data = up_bias.data / mult.reshape(-1)
                param.data = param.data * mult

    return model



def _load_causal_lm(model_or_name):
    """Return ``model_or_name`` loaded with ``from_pretrained`` if it is a str, else unchanged.

    Strings are loaded with the module-level ``params`` options and cached
    under ``Models``.
    """
    if isinstance(model_or_name, str):
        return AutoModelForCausalLM.from_pretrained(model_or_name, **params, cache_dir="Models")
    return model_or_name


def _merge_ratio_for(name, ratio, embed_ratio, norm_ratio, fc_ratio):
    """Return the weight given to model0 when averaging parameter ``name``.

    Groups by substring of the parameter name:
    ``embed`` -> embed_ratio; ``up_proj``/``down_proj`` -> fc_ratio;
    ``gate_proj``/``q_proj``/``k_proj``/``v_proj``/``o_proj`` -> ratio;
    anything else -> norm_ratio.
    """
    if "embed" in name:
        return embed_ratio
    if "up_proj" in name or "down_proj" in name:
        return fc_ratio
    if any(key in name for key in ("gate_proj", "q_proj", "k_proj", "v_proj", "o_proj")):
        return ratio
    # NOTE(review): this catch-all covers every parameter outside the groups
    # above, not only norm layers (e.g. an output head such as lm_head.*).
    # Confirm that is intended.
    return norm_ratio


def merge(model_name0, model_name1, merge_name="merged", ratio=0.5,
          embed_ratio=None, norm_ratio=None, fc_ratio=None, norm0=False, norm1=False, return_model=False,
          tokenizer_name="stabilityai/stablelm-2-zephyr-1_6b", bos_token_id=2, eos_token_id=1):
    """Linearly interpolate two causal LMs parameter by parameter.

    Every parameter of model1 is overwritten in place with
    ``w * p0 + (1 - w) * p1``. The weight ``w`` is chosen per parameter group
    by ``_merge_ratio_for``. A higher ratio keeps more of model0.

    Args:
        model_name0, model_name1: hub id or local path (loaded with the
            module-level ``params``), or an already-loaded model. model0 must
            have every parameter name that model1 has.
        merge_name: output directory name under ``Models/`` when saving.
        ratio: model0 weight for q/k/v/o and gate projections. It is also the
            default for the three ratios below.
        embed_ratio: model0 weight for parameters whose name contains ``embed``.
        norm_ratio: model0 weight for parameters outside every other group.
        fc_ratio: model0 weight for ``up_proj`` / ``down_proj``.
        norm0, norm1: run ``norm_model_weights`` on model0 / model1 first.
            This mutates those models in place.
        return_model: return ``(model1, tokenizer)`` instead of saving to disk.
        tokenizer_name: tokenizer loaded and saved with the merged model.
        bos_token_id, eos_token_id: ids written into model1's config and
            generation config. NOTE(review): confirm they match the tokenizer.

    Returns:
        ``(model1, tokenizer)`` if ``return_model`` is true. Otherwise ``None``,
        and the merged model and tokenizer are saved to ``Models/<merge_name>``.

    Raises:
        KeyError: if model1 has a parameter name that model0 lacks.
    """
    if embed_ratio is None:
        embed_ratio = ratio
    if norm_ratio is None:
        norm_ratio = ratio
    if fc_ratio is None:
        fc_ratio = ratio

    model0 = _load_causal_lm(model_name0)
    model1 = _load_causal_lm(model_name1)

    if norm0:
        model0 = norm_model_weights(model0)
    if norm1:
        model1 = norm_model_weights(model1)

    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, trust_remote_code=False, use_fast=True, cache_dir="Models")

    params0 = dict(model0.named_parameters())
    for name, param in model1.named_parameters():
        w0 = _merge_ratio_for(name, ratio, embed_ratio, norm_ratio, fc_ratio)
        param.data = (params0[name].data * w0) + (param.data * (1 - w0))
    # params0 still references model0's tensors. Drop it so that deleting
    # model0 below can release them (when nothing else holds them).
    del params0

    model1.config.bos_token_id = bos_token_id
    model1.config.eos_token_id = eos_token_id
    model1.generation_config.bos_token_id = bos_token_id
    model1.generation_config.eos_token_id = eos_token_id

    if return_model:
        del model0
        gc.collect()
        return model1, tokenizer

    save_path = "Models/" + merge_name
    model1.save_pretrained(save_path)
    tokenizer.save_pretrained(save_path)
    del model0, model1, tokenizer
    gc.collect()

In [None]:
# Equal-weight merge of zeta-Anisian and beta_s03, both normalized first;
# written to Models/merged20.
merge(
    "MesozoicMetallurgist/zeta-Anisian",
    "0x0dad0/beta_s03",
    merge_name="merged20",
    ratio=0.5,
    embed_ratio=0.5,
    norm_ratio=0.5,
    fc_ratio=0.5,
    norm0=True,
    norm1=True,
)

In [None]:
# Equal-weight merge of nous-Burdigalian and nous-Langhian, both normalized
# first; written to Models/merged17.
merge(
    "MesozoicMetallurgist/nous-Burdigalian",
    "MesozoicMetallurgist/nous-Langhian",
    merge_name="merged17",
    ratio=0.5,
    embed_ratio=0.5,
    norm_ratio=0.5,
    fc_ratio=0.5,
    norm0=True,
    norm1=True,
)