<a href="https://colab.research.google.com/github/Loki-33/Mergex/blob/main/model_merging.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSequenceClassification

In [None]:
#SLERP METHOD

In [None]:
# BERT checkpoint with a 2-class sequence-classification head; the head weights
# are not part of the checkpoint (see the warning printed below this cell).
bert_checkpoint = "bert-base-uncased"
model1 = AutoModelForSequenceClassification.from_pretrained(
    bert_checkpoint,
    num_labels=2,
)
tokenizer1 = AutoTokenizer.from_pretrained(bert_checkpoint)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
# DistilBERT checkpoint with a 2-class sequence-classification head; the
# classifier and pre_classifier weights are not part of the checkpoint (see the
# warning printed below this cell).
distil_checkpoint = "distilbert-base-uncased"
model2 = AutoModelForSequenceClassification.from_pretrained(
    distil_checkpoint,
    num_labels=2,
)
tokenizer2 = AutoTokenizer.from_pretrained(distil_checkpoint)

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
# Total number of scalar parameters in the BERT classifier.
parameters1 = sum(param.numel() for param in model1.parameters())

In [None]:
# Total number of scalar parameters in the DistilBERT classifier.
parameters2 = sum(param.numel() for param in model2.parameters())

In [None]:
#LERP METHOD

In [None]:
# Accumulator for the exploratory embedding merge below; entries are keyed with
# DistilBERT-style names ("distilbert.embeddings.<param>").
merged_state_dict = {}

In [None]:
# Blend weight used by the exploratory embedding merge below: the BERT tensor
# is scaled by alpha and the DistilBERT tensor by (1 - alpha).
alpha=0.5

In [None]:
# Blend every tensor of the DistilBERT embedding module with the BERT tensor of
# the same name. The state dicts are fetched once instead of on every iteration.
bert_embeddings = model1.bert.embeddings.state_dict()
distil_embeddings = model2.distilbert.embeddings.state_dict()
for key, B in distil_embeddings.items():
  A = bert_embeddings[key]
  merged_state_dict[f"distilbert.embeddings.{key}"] = alpha * A + (1 - alpha) * B

In [None]:
# Plain Python lists of the transformer blocks of each model, in order.
model1_layers = [*model1.bert.encoder.layer]
model2_layers = [*model2.distilbert.transformer.layer]

In [None]:
# Walk the two layer lists in lockstep (zip stops at the shorter list) and print
# any pair that compares equal.
# NOTE(review): `==` on two module objects is presumably an identity check, in
# which case a BERT block and a DistilBERT block can never match and nothing is
# printed - confirm what comparison was intended here.
for i,j in zip(model1_layers, model2_layers):
  if i == j:
    print(i, j)

In [None]:
# for i, distil_layer in enumerate(model2_layers):
#   bert_idx1, bert_idx2 = 2*i, 2*i+1
#   bert_layer_avg = {
#       k: 0.5 * v1 + 0.5 * v2 for (k, v1), (_, v2) in \
#       zip(model1_layers[bert_idx1].state_dict().items(), \
#           model2_layers[i].state_dict().items())
#   }
#   for k,v in distil_layer.state_dict().items():
#     A = bert_layer_avg[k]
#     B = v
#     merged_state_dict[f"distilbert.transformer.layer.{i}.{k}"] = alpha * A + (1 - alpha) * B


In [None]:
# Display BERT's pooler sub-module for inspection.
model1.bert.pooler

BertPooler(
  (dense): Linear(in_features=768, out_features=768, bias=True)
  (activation): Tanh()
)

In [None]:
def lerp(x1, x2, alpha=0.5):
  """Linearly interpolate between ``x1`` and ``x2``.

  Args:
    x1: Start value; it carries the weight ``1 - alpha``.
    x2: End value; it carries the weight ``alpha``.
    alpha: Interpolation weight given to ``x2``.

  Returns:
    ``(1 - alpha) * x1 + alpha * x2``.
  """
  weight_x1 = 1 - alpha
  weight_x2 = alpha
  return weight_x1 * x1 + weight_x2 * x2

In [None]:
def merge(bert_model, distil_model, alpha=0.5):
  """Merge BERT weights into a DistilBERT-shaped model by linear interpolation.

  DistilBERT block ``i`` is paired with BERT blocks ``2*i`` and ``2*i + 1``: the
  two BERT tensors are averaged and the average is blended with the DistilBERT
  tensor through ``lerp``. Word and position embeddings are blended directly.
  Every other tensor (embedding LayerNorm, ``pre_classifier``, ``classifier``)
  keeps its DistilBERT value.

  Args:
    bert_model: Model exposing ``.bert.encoder.layer`` and ``bert.*`` state-dict
      keys.
    distil_model: Model exposing ``.distilbert.transformer.layer``. It is left
      unmodified.
    alpha: Forwarded to ``lerp``, where it is the weight of the DistilBERT
      tensor (the BERT side gets ``1 - alpha``).

  Returns:
    A deep copy of ``distil_model`` carrying the merged weights.

  Raises:
    ValueError: If BERT has fewer than two blocks per DistilBERT block.
  """
  import copy  # local import: not provided by the notebook's import cell

  # DistilBERT sub-module name -> BERT sub-module path inside one block.
  # Each DistilBERT key contains at most one of these fragments.
  name_map = (
      ("q_lin", "attention.self.query"),
      ("k_lin", "attention.self.key"),
      ("v_lin", "attention.self.value"),
      ("out_lin", "attention.output.dense"),
      ("ffn.lin1", "intermediate.dense"),
      ("ffn.lin2", "output.dense"),
      ("sa_layer_norm", "attention.output.LayerNorm"),
      ("output_layer_norm", "output.LayerNorm"),
  )

  bert_state = bert_model.state_dict()
  distil_state = distil_model.state_dict()
  new_state = distil_state.copy()

  # Embeddings. The target key needs the "distilbert." prefix: without it the
  # merged tensor is an unknown key that a non-strict load silently drops.
  for key in ("embeddings.word_embeddings.weight", "embeddings.position_embeddings.weight"):
    new_state[f"distilbert.{key}"] = lerp(
        bert_state[f"bert.{key}"],
        distil_state[f"distilbert.{key}"],
        alpha=alpha,
    )

  # Transformer layers.
  bert_layers = bert_model.bert.encoder.layer
  distil_layers = distil_model.distilbert.transformer.layer
  if len(bert_layers) < 2 * len(distil_layers):
    raise ValueError(
        f"need two BERT layers per DistilBERT layer, got {len(bert_layers)} "
        f"and {len(distil_layers)}"
    )

  for distil_idx, distil_block in enumerate(distil_layers):
    bert_layer1 = bert_layers[2 * distil_idx].state_dict()
    bert_layer2 = bert_layers[2 * distil_idx + 1].state_dict()

    for k, distil_tensor in distil_block.state_dict().items():
      bert_prefix = next((dst for src, dst in name_map if src in k), None)
      if bert_prefix is None:
        continue  # no BERT counterpart: keep the DistilBERT tensor
      k_bert = f"{bert_prefix}.{k.split('.')[-1]}"  # suffix is "weight" or "bias"

      bert_avg = 0.5 * bert_layer1[k_bert] + 0.5 * bert_layer2[k_bert]
      new_state[f"distilbert.transformer.layer.{distil_idx}.{k}"] = lerp(
          bert_avg, distil_tensor, alpha=alpha
      )

  # Load into a copy so the caller's model stays pristine and successive merges
  # do not compound. Every key originates from the model's own state dict, so a
  # strict load is safe and catches key typos.
  merged_model = copy.deepcopy(distil_model)
  merged_model.load_state_dict(new_state, strict=True)
  return merged_model

In [None]:
# Run the LERP-based merge defined above with alpha = 0.5.
merged_model = merge(model1, model2, alpha=0.5)

In [None]:
# Display the module tree of the DistilBERT classifier.
model2

DistilBertForSequenceClassification(
  (distilbert): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0-5): 6 x TransformerBlock(
          (attention): DistilBertSdpaAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear(in_features=768, out_features=768, bias=True)
            (k_lin): Linear(in_features=768, out_features=768, bias=True)
            (v_lin): Linear(in_features=768, out_features=768, bias=True)
            (out_lin): Linear(in_features=768, out_features=768, bias=True)
          )
          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
            (dropout): Dropout(p=0.1, inplace=False)


In [None]:
# Only fall back to the EOS token when no pad token exists. BERT tokenizers
# normally ship with "[PAD]" and no EOS token, so an unconditional assignment
# would replace the valid pad token with None.
if tokenizer1.pad_token is None and tokenizer1.eos_token is not None:
    tokenizer1.pad_token = tokenizer1.eos_token

In [None]:
# Tiny hand-labelled batch used as a smoke test for the merged models.
sample_texts = [
    "I love this movie, it was fantastic!",   # positive
    "This product is terrible and I hate it", # negative
    "The book was okay, nothing special",     # neutral
    "Absolutely wonderful experience",        # positive
    "Worst service I’ve ever had",            # negative
]

# Class indices are integers, so store them as int64 rather than float32.
# NOTE(review): label 2 ("neutral") cannot be predicted by the 2-class heads
# loaded above, so the accuracy on this batch is capped at 80% - confirm whether
# the neutral sample should stay.
sample_labels = torch.tensor([1, 0, 2, 1, 0], dtype=torch.long)

In [None]:
# Only fall back to the EOS token when no pad token exists. BERT tokenizers
# normally ship with "[PAD]" and no EOS token, so an unconditional assignment
# would replace the valid pad token with None.
if tokenizer1.pad_token is None and tokenizer1.eos_token is not None:
    tokenizer1.pad_token = tokenizer1.eos_token

In [None]:
# Make sure the DistilBERT tokenizer pads with "[PAD]", then encode the toy
# batch as PyTorch tensors. Despite its name, `input_ids` holds the complete
# tokenizer output, which is later unpacked into the model call with `**`.
if tokenizer2.pad_token is None:
    tokenizer2.add_special_tokens(dict(pad_token='[PAD]'))
tokenizer2.pad_token = '[PAD]'
input_ids = tokenizer2(
    sample_texts,
    return_tensors="pt",
    padding=True,
    truncation=True,
)

In [None]:
# Score `merged_model` on the toy batch; gradients are not needed for inference.
with torch.no_grad():
  outputs = merged_model(**input_ids)
  preds = outputs.logits.argmax(dim=-1)

# Fraction of samples whose predicted class index equals the label.
matches = preds.eq(sample_labels)
correct = matches.sum().item()
accuracy = correct / len(sample_labels)

print(f"SImple Accuracy TEST: {accuracy * 100:.2f}%")

SImple Accuracy TEST: 40.00%


In [None]:
#SLERP

In [None]:
def slerp(A,B, alpha=0.5):
    """Spherical linear interpolation between two tensors of the same shape.

    Both tensors are flattened and treated as vectors. The angle between them is
    measured on the normalised vectors, and the result is
    ``sin((1 - alpha) * theta) / sin(theta) * A + sin(alpha * theta) / sin(theta) * B``,
    so ``alpha=0`` yields ``A`` and ``alpha=1`` yields ``B`` even when their
    norms differ.

    When the direction is undefined (a zero-norm input) or the vectors are
    (anti-)parallel, the spherical weights are ill-conditioned and the function
    falls back to plain linear interpolation instead of producing NaN/inf.

    Args:
        A: Start tensor.
        B: End tensor with the same number of elements as ``A``.
        alpha: Interpolation parameter; 0 returns ``A``, 1 returns ``B``.

    Returns:
        The interpolated tensor, shaped like ``A``.
    """
    eps = 1e-8
    # reshape() rather than view() so non-contiguous tensors are accepted too.
    A_flat = A.reshape(-1)
    B_flat = B.reshape(-1)
    linear = (1 - alpha) * A_flat + alpha * B_flat

    norm_a = A_flat.norm()
    norm_b = B_flat.norm()
    if norm_a < eps or norm_b < eps:
        return linear.reshape(A.shape)

    cos_theta = torch.dot(A_flat, B_flat) / (norm_a * norm_b)
    cos_theta = torch.clamp(cos_theta, -1.0, 1.0)  # guard acos against round-off
    theta = torch.acos(cos_theta)
    sin_theta = torch.sin(theta)
    if sin_theta.abs() < 1e-6:
        return linear.reshape(A.shape)

    weight_a = torch.sin((1 - alpha) * theta) / sin_theta
    weight_b = torch.sin(alpha * theta) / sin_theta
    res = weight_a * A_flat + weight_b * B_flat
    return res.reshape(A.shape)

In [None]:
def merge(bert_model, distil_model, alpha=0.5):
  """Merge BERT weights into a DistilBERT-shaped model with SLERP.

  This definition replaces the earlier LERP-based ``merge``. DistilBERT block
  ``i`` is paired with BERT blocks ``2*i`` and ``2*i + 1``: the two BERT tensors
  are averaged and the average is interpolated with the DistilBERT tensor
  through ``slerp``. Word and position embeddings are interpolated directly.
  Every other tensor (embedding LayerNorm, ``pre_classifier``, ``classifier``)
  keeps its DistilBERT value.

  Args:
    bert_model: Model exposing ``.bert.encoder.layer`` and ``bert.*`` state-dict
      keys.
    distil_model: Model exposing ``.distilbert.transformer.layer``. It is left
      unmodified.
    alpha: Forwarded to ``slerp`` with the BERT tensor as the start point and
      the DistilBERT tensor as the end point.

  Returns:
    A deep copy of ``distil_model`` carrying the merged weights.

  Raises:
    ValueError: If BERT has fewer than two blocks per DistilBERT block.
  """
  import copy  # local import: not provided by the notebook's import cell

  # DistilBERT sub-module name -> BERT sub-module path inside one block.
  # Each DistilBERT key contains at most one of these fragments.
  name_map = (
      ("q_lin", "attention.self.query"),
      ("k_lin", "attention.self.key"),
      ("v_lin", "attention.self.value"),
      ("out_lin", "attention.output.dense"),
      ("ffn.lin1", "intermediate.dense"),
      ("ffn.lin2", "output.dense"),
      ("sa_layer_norm", "attention.output.LayerNorm"),
      ("output_layer_norm", "output.LayerNorm"),
  )

  bert_state = bert_model.state_dict()
  distil_state = distil_model.state_dict()
  new_state = distil_state.copy()

  # Embeddings. The target key needs the "distilbert." prefix: without it the
  # merged tensor is an unknown key that a non-strict load silently drops.
  for key in ("embeddings.word_embeddings.weight", "embeddings.position_embeddings.weight"):
    new_state[f"distilbert.{key}"] = slerp(
        bert_state[f"bert.{key}"],
        distil_state[f"distilbert.{key}"],
        alpha=alpha,
    )

  # Transformer layers.
  bert_layers = bert_model.bert.encoder.layer
  distil_layers = distil_model.distilbert.transformer.layer
  if len(bert_layers) < 2 * len(distil_layers):
    raise ValueError(
        f"need two BERT layers per DistilBERT layer, got {len(bert_layers)} "
        f"and {len(distil_layers)}"
    )

  for distil_idx, distil_block in enumerate(distil_layers):
    bert_layer1 = bert_layers[2 * distil_idx].state_dict()
    bert_layer2 = bert_layers[2 * distil_idx + 1].state_dict()

    for k, distil_tensor in distil_block.state_dict().items():
      bert_prefix = next((dst for src, dst in name_map if src in k), None)
      if bert_prefix is None:
        continue  # no BERT counterpart: keep the DistilBERT tensor
      k_bert = f"{bert_prefix}.{k.split('.')[-1]}"  # suffix is "weight" or "bias"

      bert_avg = 0.5 * bert_layer1[k_bert] + 0.5 * bert_layer2[k_bert]
      new_state[f"distilbert.transformer.layer.{distil_idx}.{k}"] = slerp(
          bert_avg, distil_tensor, alpha=alpha
      )

  # Load into a copy so the caller's model stays pristine and successive merges
  # do not compound. Every key originates from the model's own state dict, so a
  # strict load is safe and catches key typos.
  merged_model = copy.deepcopy(distil_model)
  merged_model.load_state_dict(new_state, strict=True)
  return merged_model

In [None]:
# Run the SLERP-based merge (the most recent definition of merge) with alpha = 0.5.
slerp_model = merge(model1, model2, alpha=0.5)

In [None]:
def testy(model):
  with torch.no_grad():
    outputs = model(**input_ids)
    preds = torch.argmax(outputs.logits, dim=-1)

  correct = (preds == sample_labels).sum().item()
  accuracy = correct / len(sample_labels)

  acc = f"SImple Accuracy TEST: {accuracy * 100:.2f}%"
  return acc

In [None]:
#TIES METHOD

In [None]:
def ties_merge(param1, param2, trim_frac=0.2, alpha=None):
    """Merge two same-shaped tensors with a TIES-style trim / elect / merge.

    Steps, applied element-wise on the flattened tensors:
      1. Trim: in each tensor keep only the ``int(n * trim_frac)``
         largest-magnitude entries (ties at the threshold are all kept) and
         zero the rest.
      2. Elect sign: per position, take the sign of whichever trimmed entry has
         the larger magnitude (``param1`` wins ties).
      3. Disjoint merge: average the trimmed entries whose sign agrees with the
         elected sign; positions with no such entry become 0.

    NOTE(review): TIES is usually described on task vectors (differences from a
    shared base model); this function is applied to whatever tensors it is
    given - confirm that raw weights are the intended input.

    Args:
        param1: First tensor.
        param2: Second tensor with the same number of elements.
        trim_frac: Despite the name, the fraction of entries that is *kept* in
            each tensor. Must lie in ``[0, 1]``.
        alpha: Ignored; accepted so the function matches the ``merge_fn``
            call signature used by ``merge_loop``.

    Returns:
        The merged tensor, shaped like ``param1``.

    Raises:
        ValueError: If ``trim_frac`` is outside ``[0, 1]``.
    """
    if not 0.0 <= trim_frac <= 1.0:
        raise ValueError(f"trim_frac must be in [0, 1], got {trim_frac}")

    def _keep_top_magnitude(flat):
        """Zero everything except the top ``trim_frac`` share by magnitude."""
        n = flat.numel()
        keep = int(n * trim_frac)
        if keep <= 0:
            return torch.zeros_like(flat)
        if keep >= n:
            return flat
        magnitude = flat.abs()
        # The keep-th largest value is the (n - keep + 1)-th smallest one.
        thresh = magnitude.kthvalue(n - keep + 1).values
        return torch.where(magnitude >= thresh, flat, torch.zeros_like(flat))

    # Flatten; reshape() also accepts non-contiguous tensors.
    p1, p2 = param1.reshape(-1), param2.reshape(-1)

    # ---- TRIM ----
    p1_trim = _keep_top_magnitude(p1)
    p2_trim = _keep_top_magnitude(p2)

    # ---- ELECT SIGN ----
    dominant = torch.where(p1_trim.abs() >= p2_trim.abs(), p1_trim, p2_trim)
    sign_vector = dominant.sign()

    # ---- DISJOINT MERGE ----
    aligned_p1 = torch.where(p1_trim.sign() == sign_vector, p1_trim, torch.zeros_like(p1_trim))
    aligned_p2 = torch.where(p2_trim.sign() == sign_vector, p2_trim, torch.zeros_like(p2_trim))

    # Number of contributors per position; clamp avoids 0/0 where both are zero.
    count = (aligned_p1.abs() > 0).int() + (aligned_p2.abs() > 0).int()
    merged = (aligned_p1 + aligned_p2) / count.clamp(min=1)

    return merged.reshape(param1.shape)

In [None]:
def merge_loop(merge_fn, bert_model, distil_model, alpha=None):
  """Merge BERT weights into a DistilBERT-shaped model with a pluggable rule.

  DistilBERT block ``i`` is paired with BERT blocks ``2*i`` and ``2*i + 1``: the
  two BERT tensors are averaged and the average is combined with the DistilBERT
  tensor through ``merge_fn``. Word and position embeddings are combined
  directly. Every other tensor (embedding LayerNorm, ``pre_classifier``,
  ``classifier``) keeps its DistilBERT value.

  Args:
    merge_fn: Callable invoked as ``merge_fn(bert_tensor, distil_tensor,
      alpha=alpha)`` that returns a tensor shaped like ``distil_tensor``.
    bert_model: Model exposing ``.bert.encoder.layer`` and ``bert.*`` state-dict
      keys.
    distil_model: Model exposing ``.distilbert.transformer.layer``. It is left
      unmodified.
    alpha: Passed through to ``merge_fn`` unchanged.

  Returns:
    A deep copy of ``distil_model`` carrying the merged weights.

  Raises:
    ValueError: If BERT has fewer than two blocks per DistilBERT block.
  """
  import copy  # local import: not provided by the notebook's import cell

  # DistilBERT sub-module name -> BERT sub-module path inside one block.
  # Each DistilBERT key contains at most one of these fragments.
  name_map = (
      ("q_lin", "attention.self.query"),
      ("k_lin", "attention.self.key"),
      ("v_lin", "attention.self.value"),
      ("out_lin", "attention.output.dense"),
      ("ffn.lin1", "intermediate.dense"),
      ("ffn.lin2", "output.dense"),
      ("sa_layer_norm", "attention.output.LayerNorm"),
      ("output_layer_norm", "output.LayerNorm"),
  )

  bert_state = bert_model.state_dict()
  distil_state = distil_model.state_dict()
  new_state = distil_state.copy()

  # Embeddings. The target key needs the "distilbert." prefix: without it the
  # merged tensor is an unknown key that a non-strict load silently drops.
  for key in ("embeddings.word_embeddings.weight", "embeddings.position_embeddings.weight"):
    new_state[f"distilbert.{key}"] = merge_fn(
        bert_state[f"bert.{key}"],
        distil_state[f"distilbert.{key}"],
        alpha=alpha,
    )

  # Transformer layers.
  bert_layers = bert_model.bert.encoder.layer
  distil_layers = distil_model.distilbert.transformer.layer
  if len(bert_layers) < 2 * len(distil_layers):
    raise ValueError(
        f"need two BERT layers per DistilBERT layer, got {len(bert_layers)} "
        f"and {len(distil_layers)}"
    )

  for distil_idx, distil_block in enumerate(distil_layers):
    bert_layer1 = bert_layers[2 * distil_idx].state_dict()
    bert_layer2 = bert_layers[2 * distil_idx + 1].state_dict()

    for k, distil_tensor in distil_block.state_dict().items():
      bert_prefix = next((dst for src, dst in name_map if src in k), None)
      if bert_prefix is None:
        continue  # no BERT counterpart: keep the DistilBERT tensor
      k_bert = f"{bert_prefix}.{k.split('.')[-1]}"  # suffix is "weight" or "bias"

      bert_avg = 0.5 * bert_layer1[k_bert] + 0.5 * bert_layer2[k_bert]
      new_state[f"distilbert.transformer.layer.{distil_idx}.{k}"] = merge_fn(
          bert_avg, distil_tensor, alpha=alpha
      )

  # Load into a copy so the caller's model stays pristine and successive merges
  # do not compound. Every key originates from the model's own state dict, so a
  # strict load is safe and catches key typos.
  merged_model = copy.deepcopy(distil_model)
  merged_model.load_state_dict(new_state, strict=True)
  return merged_model

In [None]:
# Run the generic merge loop with ties_merge as the per-tensor rule (alpha left at None).
ties_model = merge_loop(ties_merge, model1, model2)

In [None]:
# Score the TIES-merged model on the toy batch; the bare name echoes the result string.
bb = testy(ties_model)
bb

'SImple Accuracy TEST: 40.00%'

In [None]:
#DARE METHOD

In [None]:
from copy import deepcopy

In [None]:
def dare_merge(bert_model, distil_model, drop_prob=0.2, generator=None):
  """Merge BERT weights into a DistilBERT-shaped model with random drop-and-rescale.

  DistilBERT block ``i`` is paired with BERT blocks ``2*i`` and ``2*i + 1``. The
  two BERT tensors are averaged, each entry of the average is zeroed with
  probability ``drop_prob`` and the survivors are scaled by
  ``1 / (1 - drop_prob)``; the result is then averaged with the DistilBERT
  tensor. Word and position embeddings are plain 50/50 averages. Every other
  tensor (embedding LayerNorm, ``pre_classifier``, ``classifier``) keeps its
  DistilBERT value.

  NOTE(review): DARE is usually described on fine-tuning deltas relative to a
  shared base model; here the mask is applied to raw averaged BERT weights -
  confirm this is intended.

  Args:
    bert_model: Model exposing ``.bert.encoder.layer`` and ``bert.*`` state-dict
      keys.
    distil_model: Model exposing ``.distilbert.transformer.layer``. It is left
      unmodified.
    drop_prob: Probability of zeroing a BERT entry; must lie in ``[0, 1)``.
    generator: Optional ``torch.Generator`` for a reproducible mask. ``None``
      uses the global RNG.

  Returns:
    A deep copy of ``distil_model`` carrying the merged weights.

  Raises:
    ValueError: If ``drop_prob`` is outside ``[0, 1)`` or BERT has fewer than
      two blocks per DistilBERT block.
  """
  import copy  # local import: executed at call time, independent of cell order

  if not 0.0 <= drop_prob < 1.0:
    # drop_prob == 1 would make the rescale factor a division by zero.
    raise ValueError(f"drop_prob must be in [0, 1), got {drop_prob}")

  # DistilBERT sub-module name -> BERT sub-module path inside one block.
  # Each DistilBERT key contains at most one of these fragments.
  name_map = (
      ("q_lin", "attention.self.query"),
      ("k_lin", "attention.self.key"),
      ("v_lin", "attention.self.value"),
      ("out_lin", "attention.output.dense"),
      ("ffn.lin1", "intermediate.dense"),
      ("ffn.lin2", "output.dense"),
      ("sa_layer_norm", "attention.output.LayerNorm"),
      ("output_layer_norm", "output.LayerNorm"),
  )

  bert_state = bert_model.state_dict()
  distil_state = distil_model.state_dict()
  new_state = distil_state.copy()

  # Embeddings. The target key needs the "distilbert." prefix: without it the
  # merged tensor is an unknown key that a non-strict load silently drops.
  for key in ("embeddings.word_embeddings.weight", "embeddings.position_embeddings.weight"):
    new_state[f"distilbert.{key}"] = 0.5 * (
        bert_state[f"bert.{key}"] + distil_state[f"distilbert.{key}"]
    )

  bert_layers = bert_model.bert.encoder.layer
  distil_layers = distil_model.distilbert.transformer.layer
  if len(bert_layers) < 2 * len(distil_layers):
    raise ValueError(
        f"need two BERT layers per DistilBERT layer, got {len(bert_layers)} "
        f"and {len(distil_layers)}"
    )
  keep_prob = 1 - drop_prob
  scale = 1.0 / keep_prob

  for distil_idx, distil_block in enumerate(distil_layers):
    bert_layer1 = bert_layers[2 * distil_idx].state_dict()
    bert_layer2 = bert_layers[2 * distil_idx + 1].state_dict()

    for k, distil_tensor in distil_block.state_dict().items():
      bert_prefix = next((dst for src, dst in name_map if src in k), None)
      if bert_prefix is None:
        continue  # no BERT counterpart: keep the DistilBERT tensor
      k_bert = f"{bert_prefix}.{k.split('.')[-1]}"  # suffix is "weight" or "bias"

      bert_avg = 0.5 * bert_layer1[k_bert] + 0.5 * bert_layer2[k_bert]
      # Bernoulli(keep_prob) mask; rescaling keeps the expected value unchanged.
      mask = (
          torch.rand(
              bert_avg.shape,
              generator=generator,
              dtype=bert_avg.dtype,
              device=bert_avg.device,
          )
          < keep_prob
      ).float()
      bert_avg = bert_avg * mask * scale

      new_state[f"distilbert.transformer.layer.{distil_idx}.{k}"] = (
          bert_avg + distil_tensor
      ) / 2.0

  # Load into a copy so the caller's model stays pristine and successive merges
  # do not compound. Every key originates from the model's own state dict, so a
  # strict load is safe and catches key typos.
  merged_model = copy.deepcopy(distil_model)
  merged_model.load_state_dict(new_state, strict=True)
  return merged_model



In [None]:
# Run the drop-and-rescale merge with its default drop probability.
dare_model = dare_merge(model1, model2)

In [None]:
# Score the DARE-merged model on the toy batch; the bare name echoes the result string.
cc = testy(dare_model)
cc

'SImple Accuracy TEST: 40.00%'