In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

# Load Models
# The informative model supplies the weight delta (informative - base) that is
# later merged into the target model.
informative_model_name = "gradientai/Llama-3-8B-Instruct-262k"
base_model_name = "kuotient/Meta-Llama-3-8B-Instruct"
target_model_name = "beomi/Llama-3-Open-Ko-8B-Instruct-preview"  # target model.

# NOTE(review): no torch_dtype is passed, so precision follows the transformers
# default; three 8B models are held in memory at once — confirm the RAM budget.
informative_model = AutoModelForCausalLM.from_pretrained(informative_model_name)
base_model = AutoModelForCausalLM.from_pretrained(base_model_name)
target_model = AutoModelForCausalLM.from_pretrained(target_model_name)

In [None]:
def calculate_weight_diff(a, b):
    """Return the element-wise difference ``a - b`` (e.g. two weight tensors)."""
    delta = a - b
    return delta

def calculate_model_diffs(model_a, model_b):
    """Compute per-parameter deltas ``model_a - model_b``.

    Only state-dict keys present in both models are diffed; keys that exist in
    just one of them are left out of the result. A progress line is printed for
    every key that is diffed.

    Returns:
        dict mapping state-dict key -> ``calculate_weight_diff(a_tensor, b_tensor)``.
    """
    params_a = model_a.state_dict()
    params_b = model_b.state_dict()
    deltas = {}
    for name, tensor_a in params_a.items():
        if name not in params_b:
            continue
        deltas[name] = calculate_weight_diff(tensor_a, params_b[name])
        print(f"Diff calculated for {name}")
    return deltas

def calculate_sigmoid_ratios(base_model, target_model, epsilon=1e-6):
    """Map how much each target weight moved away from base onto (0, 1).

    For every state-dict key shared by both models, take ``|target - base|``,
    min-max normalise it to [0, 1] per tensor, and pass it through
    ``sigmoid(12 * x - 6)``, so results lie roughly in
    [sigmoid(-6), sigmoid(6)] ~= [0.0025, 0.9975]. When a tensor's spread
    (max - min) is below ``epsilon``, an all-zero tensor is returned for it.
    Min/max statistics are printed per key.

    Returns:
        dict mapping state-dict key -> ratio tensor shaped like the parameter.
    """
    # The absolute value is taken below, so diff direction is irrelevant.
    deltas = calculate_model_diffs(target_model, base_model)
    ratios = {}
    for name, delta in deltas.items():
        magnitude = abs(delta)
        lo = magnitude.min().item()
        hi = magnitude.max().item()
        print(f"Key: {name}")
        print(f"  Diff Min: {lo}")
        print(f"  Diff Max: {hi}")

        spread = hi - lo
        if abs(spread) < epsilon:
            # Degenerate tensor: nothing to normalise against.
            print("  All values are the same. Setting sigmoid_diff to 0.")
            ratio = torch.zeros_like(magnitude)
        else:
            unit = (magnitude - lo) / spread
            # Stretch [0, 1] to [-6, 6] so the sigmoid spans nearly (0, 1).
            ratio = torch.sigmoid(unit * 12 - 6)

        ratios[name] = ratio
        print(f"  Sigmoid Diff Min: {ratio.min().item()}")
        print(f"  Sigmoid Diff Max: {ratio.max().item()}")
    return ratios

def apply_model_diffs(target_model, model_diffs, sigmoid_ratios):
    """Add ratio-scaled weight deltas to ``target_model`` in place.

    For every key in ``model_diffs`` the target tensor is updated as
    ``target[key] += model_diffs[key] * (1 - sigmoid_ratios[key])``: elements
    whose ratio is near 1 receive almost none of the delta, elements whose
    ratio is near 0 receive almost all of it.

    Args:
        target_model: module whose ``state_dict()`` entries are updated.
        model_diffs: mapping from state-dict key to the delta to add.
        sigmoid_ratios: mapping from the same keys to per-element ratios
            (broadcastable against the matching delta).

    Raises:
        KeyError: if any key of ``model_diffs`` is missing from
            ``sigmoid_ratios`` or from the target's state dict. This is checked
            before anything is modified, so the model is left untouched.
    """
    target_state_dict = target_model.state_dict()

    # The state-dict tensors alias the live parameters, so a failure halfway
    # through the loop would leave the model partially merged. Validate first.
    missing = [
        key for key in model_diffs
        if key not in sigmoid_ratios or key not in target_state_dict
    ]
    if missing:
        raise KeyError(
            f"Keys missing from sigmoid_ratios or target state dict: {missing}"
        )

    # Tied parameters (e.g. shared input/output embeddings) appear under
    # several keys but share one storage; add the delta to each storage once.
    updated_ptrs = set()
    for key, diff in model_diffs.items():
        target_tensor = target_state_dict[key]
        ptr = target_tensor.data_ptr()
        if ptr in updated_ptrs:
            print(f"Skipping {key}: shares storage with an already-updated tensor")
            continue
        scaled_diff = diff * (1 - sigmoid_ratios[key])
        target_tensor += scaled_diff
        updated_ptrs.add(ptr)
        print(f"Diff applied for {key}")
    target_model.load_state_dict(target_state_dict)

OUTPUT_DIR = "./done"

print("Calculating model diffs...")
model_diffs = calculate_model_diffs(informative_model, base_model)  # informative - base
print("Model diffs calculated.")

print("Calculating sigmoid ratios...")
# |target - base| is taken inside, so argument order does not affect the ratios.
sigmoid_ratios = calculate_sigmoid_ratios(base_model, target_model)
print("Sigmoid ratios calculated.")

print("Applying model diffs...")
apply_model_diffs(target_model, model_diffs, sigmoid_ratios)
print("Model diffs applied.")

print("Saving target model...")
target_model.save_pretrained(OUTPUT_DIR)
# Save the target's tokenizer next to the merged weights so OUTPUT_DIR also
# contains tokenizer files and can be loaded with AutoTokenizer directly.
AutoTokenizer.from_pretrained(target_model_name).save_pretrained(OUTPUT_DIR)
print("Target model saved.")