## Pearson correlation calculation between Fisher information and model perplexity

# Base functions

In [17]:
import json
import os
import re

import matplotlib.pyplot as plt
import numpy as np


def draw_fisher_bar(
    file_path,
    percentage,
    output_dir="/root/autodl-tmp/methods/mix_quantize/visualization/llama2-7b",
    y_cap=50,
):
    """Save a grid of bar charts of Fisher information per transformer block.

    The JSON file is expected to map consecutive block indices ("0", "1",
    ...) to a dict holding one number for each projection weight listed in
    ``keys`` below. One subplot is drawn per projection weight and every
    subplot uses the same y-axis limits.

    Args:
        file_path: Path to the Fisher-information JSON file. Its file name
            must contain a ``<digits>_<digits>`` tag; the first such tag is
            printed and embedded in the output file name.
        percentage: Value embedded verbatim in the output file name.
        output_dir: Directory the PNG is written to; created if missing.
        y_cap: Upper bound applied to the shared y-axis maximum (taller
            bars are clipped by the axis). ``None`` disables the cap.

    Raises:
        ValueError: If the file name carries no tag or the JSON is empty.
    """
    # Search the file name only, so digits in directory names can never be
    # mistaken for the tag.
    file_name = os.path.basename(file_path)
    tag_match = re.search(r"(\d+_\d+)", file_name)
    if tag_match is None:
        raise ValueError(
            f"no '<digits>_<digits>' tag found in file name {file_name!r}"
        )
    feature = tag_match.group(1)
    print(feature)

    with open(file_path, "r") as file:
        data = json.load(file)
    if not data:
        raise ValueError(f"{file_path!r} contains no transformer blocks")

    # Projection weights to plot, one subplot each.
    keys = [
        "self_attn.q_proj.weight", "self_attn.k_proj.weight", "self_attn.v_proj.weight",
        "self_attn.o_proj.weight", "mlp.gate_proj.weight", "mlp.up_proj.weight",
        "mlp.down_proj.weight"
    ]

    # Collect each series once; it feeds both the shared limits and the bars.
    block_ids = [str(idx) for idx in range(len(data))]
    values_by_key = {
        key: [data[block_id][key] for block_id in block_ids] for key in keys
    }

    global_min = min(min(values) for values in values_by_key.values())
    global_max = max(max(values) for values in values_by_key.values())
    if y_cap is not None:
        global_max = min(global_max, y_cap)

    fig, axes = plt.subplots(nrows=4, ncols=2, figsize=(15, 20))
    try:
        axes = axes.flatten()

        for ax, key in zip(axes, keys):
            values = values_by_key[key]
            ax.bar(range(len(values)), values, color="skyblue", edgecolor="black")
            ax.set_title(key)
            ax.set_xlabel("Transformer block ID")
            ax.set_ylabel("Fisher Info")
            ax.set_ylim(global_min, global_max)

        # The 4x2 grid has more cells than there are keys; drop the spares.
        for unused_ax in axes[len(keys):]:
            fig.delaxes(unused_ax)

        fig.tight_layout()
        os.makedirs(output_dir, exist_ok=True)
        fig.savefig(
            os.path.join(
                output_dir,
                f"Total_Fisher_bar_chart_{feature}_{percentage}.png",
            )
        )
    finally:
        # Release the figure even when plotting or saving fails.
        plt.close(fig)

In [7]:
def draw_perplexity(
    file_path,
    percentage,
    output_dir="/root/autodl-tmp/methods/mix_quantize/visualization/llama2-7b",
):
    """Save a grid of bar charts of perplexity change per transformer block.

    The JSON file is expected to map consecutive block indices ("0", "1",
    ...) to a dict holding one perplexity for each projection weight listed
    in ``keys`` below. The baseline perplexity parsed from the file name is
    subtracted from every value, so the bars show the difference from the
    baseline. Every subplot uses the same y-axis limits.

    Args:
        file_path: Path to the perplexity JSON file. Its file name must
            contain a ``<digits>_<digits>`` tag and a decimal number
            (``<digits>.<digits>``) giving the baseline perplexity; the
            first match of each is used.
        percentage: Value embedded verbatim in the output file name.
        output_dir: Directory the PNG is written to; created if missing.

    Raises:
        ValueError: If the file name lacks the tag or the baseline
            perplexity, or the JSON is empty.
    """
    # Search the file name only: a directory such as "v1.5/" would otherwise
    # be picked up as the baseline perplexity.
    file_name = os.path.basename(file_path)
    tag_match = re.search(r"(\d+_\d+)", file_name)
    if tag_match is None:
        raise ValueError(
            f"no '<digits>_<digits>' tag found in file name {file_name!r}"
        )
    baseline_match = re.search(r"(\d+\.\d+)", file_name)
    if baseline_match is None:
        raise ValueError(
            f"no baseline perplexity (decimal number) found in file name {file_name!r}"
        )
    feature = tag_match.group(1)
    original_perplexity = float(baseline_match.group(1))
    print(feature, original_perplexity)

    with open(file_path, "r") as file:
        data = json.load(file)
    if not data:
        raise ValueError(f"{file_path!r} contains no transformer blocks")

    # Projection weights to plot, one subplot each.
    keys = [
        "self_attn.q_proj.weight", "self_attn.k_proj.weight", "self_attn.v_proj.weight",
        "self_attn.o_proj.weight", "mlp.gate_proj.weight", "mlp.up_proj.weight",
        "mlp.down_proj.weight"
    ]

    # Collect each baseline-relative series once; it feeds both the shared
    # limits and the bars.
    block_ids = [str(idx) for idx in range(len(data))]
    deltas_by_key = {
        key: [data[block_id][key] - original_perplexity for block_id in block_ids]
        for key in keys
    }

    global_min = min(min(deltas) for deltas in deltas_by_key.values())
    global_max = max(max(deltas) for deltas in deltas_by_key.values())

    fig, axes = plt.subplots(nrows=4, ncols=2, figsize=(15, 20))
    try:
        axes = axes.flatten()

        for ax, key in zip(axes, keys):
            deltas = deltas_by_key[key]
            ax.bar(range(len(deltas)), deltas, color="skyblue", edgecolor="black")
            ax.set_title(key)
            ax.set_xlabel("Transformer block ID")
            ax.set_ylabel("Perplexity")
            ax.set_ylim(global_min, global_max)

        # The 4x2 grid has more cells than there are keys; drop the spares.
        for unused_ax in axes[len(keys):]:
            fig.delaxes(unused_ax)

        fig.tight_layout()
        os.makedirs(output_dir, exist_ok=True)
        # NOTE: "Perplrxity" is misspelled, but the name is kept unchanged so
        # previously generated files and anything reading them still line up.
        fig.savefig(
            os.path.join(
                output_dir,
                f"Total_Perplrxity_bar_chart_{feature}_{percentage}.png",
            )
        )
    finally:
        # Release the figure even when plotting or saving fails.
        plt.close(fig)

# Directory holding the JSON inputs. The cells below prepend it to a file name
# by plain string concatenation, so the trailing slash is required.
root_path = "/root/autodl-tmp/methods/mix_quantize/model_info/llama2-7b/"

# 4096 2 0.2

In [23]:
# Fisher bar charts for the 4096_2 file, tagged with percentage 0.2.
filename = "fisher_data_4096_2_2024-12-20-22-26-03.json"
draw_fisher_bar(f"{root_path}{filename}", percentage=0.2)

4096_2


In [24]:
# Perplexity bar charts for the 4096_2 file, tagged with percentage 0.2.
filename = "modified_perplexitys_4096_2_2024-12-21-11-15-11_2.556640625.json"
draw_perplexity(f"{root_path}{filename}", percentage=0.2)

4096_2 2.556640625


# 4096 2 0.5

In [25]:
# Fisher bar charts for the 4096_2 file, tagged with percentage 0.5.
filename = "fisher_data_4096_2_2024-12-21-11-35-11.json"
draw_fisher_bar(f"{root_path}{filename}", percentage=0.5)

4096_2


In [26]:
# Perplexity bar charts for the 4096_2 file, tagged with percentage 0.5.
filename = "modified_perplexitys_4096_2_2024-12-21-11-35-11_2.556640625.json"
draw_perplexity(f"{root_path}{filename}", percentage=0.5)

4096_2 2.556640625


# 4096 20 0.5

In [27]:
# Fisher bar charts for the 4096_20 file, tagged with percentage 0.5.
filename = "fisher_data_4096_20_2024-12-21-15-40-48.json"
draw_fisher_bar(f"{root_path}{filename}", percentage=0.5)

4096_20


In [28]:
# Perplexity bar charts for the 4096_20 file, tagged with percentage 0.5.
filename = "modified_perplexitys_4096_20_2024-12-21-15-40-48_2.44140625.json"
draw_perplexity(f"{root_path}{filename}", percentage=0.5)

4096_20 2.44140625


# 2048 2 0.2

In [19]:
# Fisher bar charts for the 2048 file. draw_fisher_bar requires a percentage
# argument; 0.2 is taken from this section's heading ("2048 2 0.2").
# NOTE: this file name has no second numeric field, so the tag regex picks up
# the year and the feature resolves to "2048_2024".
filename = "fisher_data_2048_2024-12-20-21-59-30.json"
draw_fisher_bar(root_path + filename, 0.2)

2048_2024


In [20]:
# Perplexity bar charts for the 2048 file. draw_perplexity requires a
# percentage argument; 0.2 is taken from this section's heading ("2048 2 0.2").
# NOTE: this file name has no second numeric field, so the tag regex picks up
# the year and the feature resolves to "2048_2024".
filename = "modified_perplexitys_2048_2024-12-20-21-59-30_2.49609375.json"
draw_perplexity(root_path + filename, 0.2)

2048_2024 2.49609375


# 1024 2 0.2

In [5]:
# Fisher bar charts for the 1024_2 file, tagged with percentage 0.2.
filename = "fisher_data_1024_2_2024-12-21-15-55-27.json"
draw_fisher_bar(f"{root_path}{filename}", percentage=0.2)

1024_2


In [6]:
# Perplexity bar charts for the 1024_2 file, tagged with percentage 0.2.
filename = "modified_perplexitys_1024_2_2024-12-21-15-55-27_2.7421875.json"
draw_perplexity(f"{root_path}{filename}", percentage=0.2)

1024_2 2.7421875


# 1024 2 0.5

In [9]:
# Fisher bar charts for the 1024_2 file, tagged with percentage 0.5.
filename = "fisher_data_1024_2_2024-12-21-16-24-25.json"
draw_fisher_bar(f"{root_path}{filename}", percentage=0.5)

1024_2


In [10]:
# Perplexity bar charts for the 1024_2 file, tagged with percentage 0.5.
filename = "modified_perplexitys_1024_2_2024-12-21-16-24-25_2.7421875.json"
draw_perplexity(f"{root_path}{filename}", percentage=0.5)

1024_2 2.7421875


# 1024 2 0.75

In [14]:
# Fisher bar charts for the 1024_2 file, tagged with percentage 0.75.
filename = "fisher_data_1024_2_2024-12-21-17-22-47.json"
draw_fisher_bar(f"{root_path}{filename}", percentage=0.75)

1024_2


In [13]:
# Perplexity bar charts for the 1024_2 file, tagged with percentage 0.75.
filename = "modified_perplexitys_1024_2_2024-12-21-17-22-47_2.7421875.json"
draw_perplexity(f"{root_path}{filename}", percentage=0.75)

1024_2 2.7421875


# 1024 10 0.2

In [18]:
# Fisher bar charts for the 1024_10 file, tagged with percentage 0.2.
filename = "fisher_data_1024_10_2024-12-21-16-16-37.json"
draw_fisher_bar(f"{root_path}{filename}", percentage=0.2)

1024_10


In [16]:
# Perplexity bar charts for the 1024_10 file, tagged with percentage 0.2.
filename = "modified_perplexitys_1024_10_2024-12-21-16-16-37_2.60546875.json"
draw_perplexity(f"{root_path}{filename}", percentage=0.2)

1024_10 2.60546875


# 1024 10 0.5

In [11]:
# Fisher bar charts for the 1024_10 file, tagged with percentage 0.5.
filename = "fisher_data_1024_10_2024-12-21-16-44-29.json"
draw_fisher_bar(f"{root_path}{filename}", percentage=0.5)

1024_10


In [12]:
# Perplexity bar charts for the 1024_10 file, tagged with percentage 0.5.
filename = "modified_perplexitys_1024_10_2024-12-21-16-44-29_2.60546875.json"
draw_perplexity(f"{root_path}{filename}", percentage=0.5)

1024_10 2.60546875
