# In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import matplotlib.pyplot as plt
import seaborn as sns

# Hugging Face Hub ids of the checkpoints involved, followed by the model loads.
long_context_model_name = "gradientai/Llama-3-8B-Instruct-262k" # Long-context model
base_model_name = "kuotient/Meta-Llama-3-8B-Instruct" # Base model; per the original note, identical weights to the gated official repo
target_model_name = ""  # Model whose context should be expanded (empty by default)
long_context_model = AutoModelForCausalLM.from_pretrained(long_context_model_name)
# NOTE(review): this rebinds `base_model_name` from the repo-id string to the
# loaded model object. Every later read of `base_model_name` in this file sees
# the model, not the string.
base_model_name = AutoModelForCausalLM.from_pretrained(base_model_name)

def calculate_weight_diff(base_weight, chat_weight):
    """Return the mean absolute element-wise difference of two tensors.

    The two tensors are subtracted, the absolute value is taken, and the
    mean over all elements is returned as a plain Python number.
    """
    elementwise_gap = (base_weight - chat_weight).abs()
    return elementwise_gap.mean().item()

def calculate_layer_diffs(long_context_model, base_model_name):
    """Compare two models layer by layer and summarize each weight's change.

    Both arguments are walked in lockstep through ``.model.layers``; for every
    pair of layers, ``calculate_weight_diff`` is applied to each of the nine
    tracked ``.weight`` tensors.

    Note: despite its name, ``base_model_name`` must be a model object here,
    because its ``.model.layers`` attribute is read.

    Returns:
        A list with one dict per layer pair, mapping a component key (for
        example ``'mlp_down_proj'``) to the value returned by
        ``calculate_weight_diff`` for that component.
    """
    # Result key -> attribute path from a layer to the module owning `.weight`.
    # Dict order fixes both the evaluation order and the order of result keys.
    component_paths = {
        'input_layernorm': ('input_layernorm',),
        'mlp_down_proj': ('mlp', 'down_proj'),
        'mlp_gate_proj': ('mlp', 'gate_proj'),
        'mlp_up_proj': ('mlp', 'up_proj'),
        'post_attention_layernorm': ('post_attention_layernorm',),
        'self_attn_q_proj': ('self_attn', 'q_proj'),
        'self_attn_k_proj': ('self_attn', 'k_proj'),
        'self_attn_v_proj': ('self_attn', 'v_proj'),
        'self_attn_o_proj': ('self_attn', 'o_proj'),
    }

    def weight_at(layer, path):
        # Follow the attribute path down to the module, then take its weight.
        module = layer
        for attr in path:
            module = getattr(module, attr)
        return module.weight

    return [
        {
            key: calculate_weight_diff(weight_at(lc_layer, path), weight_at(ref_layer, path))
            for key, path in component_paths.items()
        }
        for lc_layer, ref_layer in zip(long_context_model.model.layers, base_model_name.model.layers)
    ]

def apply_layer_diffs(target_model, layer_diffs):
    """Add per-component diff values onto the weights of ``target_model``.

    ``target_model.model.layers`` and ``layer_diffs`` are iterated in lockstep;
    for each layer, the value stored under every component key is added in
    place to the matching ``.weight.data``. Nothing is returned.

    NOTE(review): the values produced by ``calculate_layer_diffs`` in this file
    are scalars (mean absolute differences), so every element of a given
    weight is shifted by the same amount. Confirm this is the intended
    update rather than adding the element-wise weight delta.
    """
    # (diff key, attribute path to the module owning `.weight`), in update order.
    targets = (
        ('input_layernorm', ('input_layernorm',)),
        ('mlp_down_proj', ('mlp', 'down_proj')),
        ('mlp_gate_proj', ('mlp', 'gate_proj')),
        ('mlp_up_proj', ('mlp', 'up_proj')),
        ('post_attention_layernorm', ('post_attention_layernorm',)),
        ('self_attn_q_proj', ('self_attn', 'q_proj')),
        ('self_attn_k_proj', ('self_attn', 'k_proj')),
        ('self_attn_v_proj', ('self_attn', 'v_proj')),
        ('self_attn_o_proj', ('self_attn', 'o_proj')),
    )
    for block, offsets in zip(target_model.model.layers, layer_diffs):
        for key, path in targets:
            module = block
            for attr in path:
                module = getattr(module, attr)
            module.weight.data += offsets[key]

def visualize_layer_diffs(layer_diffs, title=None):
    """Plot one single-column heatmap per weight component, one row per layer.

    Args:
        layer_diffs: List of per-layer dicts mapping a component name to a
            numeric difference. The keys of the first dict determine which
            components are drawn.
        title: Optional figure title. When omitted, the title is built from
            the module-level ``long_context_model_name`` and
            ``base_model_name``.

    Returns:
        None. The figure is shown with ``plt.show()``. If ``layer_diffs`` is
        empty there is nothing to draw and the function returns immediately.
    """
    if not layer_diffs:
        return

    components = list(layer_diffs[0])
    num_layers = len(layer_diffs)

    if title is None:
        def display_name(obj):
            # `base_model_name` is rebound to the loaded model object earlier
            # in this file, so formatting it directly would dump the whole
            # module repr into the title. Prefer the model's `name_or_path`
            # when present, falling back to the class name.
            if isinstance(obj, str):
                return obj
            return getattr(obj, "name_or_path", None) or type(obj).__name__

        title = f"{display_name(long_context_model_name)} <> {display_name(base_model_name)}"

    # squeeze=False keeps `axs` two-dimensional even with a single component,
    # so row 0 can always be iterated.
    fig, axs = plt.subplots(1, len(components), figsize=(24, 8), squeeze=False)
    fig.suptitle(title, fontsize=16)

    # Heatmap cell `i` spans [i, i + 1]; ticks belong at the cell centres.
    tick_positions = [row + 0.5 for row in range(num_layers)]

    for ax, component in zip(axs[0], components):
        column = [[layer_diff[component]] for layer_diff in layer_diffs]
        sns.heatmap(column, annot=True, fmt=".6f", cmap="YlGnBu", ax=ax, cbar_kws={"shrink": 0.8})
        ax.set_title(component)
        # Rows are layers (y axis); the single column holds the difference.
        ax.set_xlabel("Difference")
        ax.set_ylabel("Layer")
        ax.set_xticks([])
        ax.set_yticks(tick_positions)
        ax.set_yticklabels(range(num_layers))
        # Flip the heatmap's default row order so layer 0 sits at the bottom.
        ax.invert_yaxis()

    plt.tight_layout()
    plt.show()

# Fail fast with a clear message: the diff must be applied to the model named
# by `target_model_name`, which is empty until the user fills it in.
if not target_model_name:
    raise ValueError(
        "target_model_name is empty; set it to the model whose context should be expanded"
    )

# Compute per-layer differences between the long-context and base models.
# At this point `base_model_name` holds the loaded base model (it is rebound
# from the repo-id string when the models are loaded above).
layer_diffs = calculate_layer_diffs(long_context_model, base_model_name)

# Apply the differences to the target model. The original loaded
# `long_context_model_name` here, which modified the long-context model
# itself and left `target_model_name` unused.
target_model = AutoModelForCausalLM.from_pretrained(target_model_name)
apply_layer_diffs(target_model, layer_diffs)

# Save the modified target model.
target_model.save_pretrained('./expanded_target')

# Plot the per-layer differences.
visualize_layer_diffs(layer_diffs)