In [None]:
import os
import re
from pathlib import Path
import warnings

# Silence every warning category for the rest of the kernel session
# (presumably to hide library deprecation noise).
# NOTE(review): this also hides actionable warnings, e.g. pandas'
# SettingWithCopyWarning or matplotlib's FixedFormatter warning — consider
# narrowing it with `category=`/`module=` filters.
warnings.filterwarnings("ignore")

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from transformers import AutoConfig, AutoModelForCausalLM, PreTrainedModel

In [None]:
# matplotlib.rcParams.update(matplotlib.rcParamsDefault)
# plt.rcParams['text.usetex'] = True
# plt.rc('text.latex', preamble=r'\usepackage{ulem}')
# matplotlib.rcParams.update({'font.size': 14})
# plt.rc('font', family='Times New Roman')
# sns.set_style(rc={'text.usetex' : True})

In [None]:
# Work from the directory above the notebook's starting directory: the project
# imports below (`args`, `models`) and the relative "logs" output path are
# presumably resolved from there.
# The starting directory is remembered on first execution so re-running this
# cell never climbs an additional level.
if "NOTEBOOK_DIR" not in globals():
    NOTEBOOK_DIR = Path.cwd()
os.chdir(NOTEBOOK_DIR.parent)
# os.chdir(os.path.expanduser("~/clm"))

In [None]:
from args.model_args import ModelArguments, SoftMaxScaleType
from models.llama_nope import monkey_patch_before

# Patch the model code before any model is instantiated in the cells below
# (the "before" naming suggests it must precede model construction — confirm
# in models.llama_nope). HS softmax scaling, flash attention enabled.
patch_args = ModelArguments(
    use_flash_attention=True,
    softmax_scale_type=SoftMaxScaleType.HS,
)
monkey_patch_before(patch_args)

# Variant without flash attention:
# monkey_patch_before(ModelArguments(softmax_scale_type=SoftMaxScaleType.HS))

In [None]:
# Trained checkpoint to inspect (placeholder path — point it at the real run).
path = Path("path/to/data")
config = AutoConfig.from_pretrained(path)

# Weight-loading options gathered in one place so they are easy to tweak.
load_kwargs = {"use_flash_attention_2": True}
model = AutoModelForCausalLM.from_pretrained(path, **load_kwargs)

# Dump the module tree for a quick visual inspection.
print(model)

In [None]:
# Directory holding the saved attention-entropy array (placeholder path).
path = Path("path/to/data")
entropy_file = path / "entropy.npy"
# Documented layout: [n_layers, n_heads, seq_len].
entropy = np.load(entropy_file)
# Alternatives: keep only the first 2048 positions via entropy[:, :, :2048],
# or average over positions via entropy.mean(axis=2) -> [n_layers, n_heads].

In [None]:
# One row per (layer, head): the head's learned softmax scale and its entropy.
n_layers = config.num_hidden_layers
n_heads = config.num_attention_heads

# Fail fast when the entropy file does not belong to this model: surplus
# layers/heads would otherwise be silently ignored, and missing ones would only
# surface as an opaque IndexError halfway through the loop.
assert entropy.ndim == 3 and entropy.shape[:2] == (n_layers, n_heads), (
    f"entropy has shape {entropy.shape}, expected ({n_layers}, {n_heads}, seq_len)"
)

data = []
for i in range(n_layers):
    # assumes `scale_param` (presumably installed by the monkey patch) holds
    # one value per attention head — TODO confirm
    scale = model.model.layers[i].self_attn.scale_param.detach().tolist()
    for j in range(n_heads):
        data.append(
            {
                "layer": i,
                "head": j,
                "scale": scale[j],
                # Entropy value at the last sequence position.
                "entropy": entropy[i, j, -1],
            }
        )
df = pd.DataFrame(data)
print(df)

In [None]:
# Keep every third layer (1, 4, ..., 19) so the per-layer legend stays
# readable; 22 is presumably config.num_hidden_layers — confirm.
layers = list(range(1, 22, 3))

# The renamed column is the hue key used by the plotting cell. Built as a
# non-mutating chain: `rename(..., inplace=True)` on the filtered slice
# triggers pandas' SettingWithCopyWarning, and leaves `df` reusable as-is.
plot_df = (
    df[df["layer"].isin(layers)]
    .rename(columns={"layer": "Layers 0-21"})
)

In [None]:
# Scatter plus per-layer linear fit of learned softmax scale vs. attention entropy.
g = sns.lmplot(
    plot_df,
    x="scale",
    y="entropy",
    hue="Layers 0-21",
    palette="tab10",
    legend=False,
)
ax = g.ax

# Single-row legend placed above the axes.
ax.legend(
    bbox_to_anchor=(0.8, 1.2),
    ncol=7,
    title="Layers (0~21)",
    handletextpad=0,
    columnspacing=0,
)

ax.set_ylim(0, 13)
ax.set_ylabel(r"Entropy $\bar{\mathcal{H}}_i$")
ax.set_xlabel(r"$\lambda$")

# Render each x tick value v as v/sqrt(d). A formatter stays in sync with the
# tick locator. Reading the current label texts and writing them back as fixed
# strings detaches labels from tick positions (matplotlib warns about a
# FixedFormatter without a FixedLocator), and on older matplotlib versions the
# texts are still empty before the first draw.
ax.xaxis.set_major_formatter(
    matplotlib.ticker.FuncFormatter(lambda v, _pos: rf"$\frac{{{v:g}}}{{\sqrt{{d}}}}$")
)

g.tight_layout()

In [None]:
# Save the figure as a PDF under logs/. The directory is created first;
# otherwise savefig fails on a fresh checkout where logs/ does not exist yet.
out_dir = "logs"
os.makedirs(out_dir, exist_ok=True)
g.savefig(os.path.join(out_dir, "fig.pdf"), bbox_inches="tight", pad_inches=0.0, dpi=1000)