In [None]:
import os
import re
import warnings
from pathlib import Path

# Suppress every warning for the rest of the session so that library
# warnings do not clutter the cell outputs.
warnings.filterwarnings("ignore")
import matplotlib
import matplotlib.pyplot as plt
# matplotlib.font_manager._rebuild()
# Render figure text through LaTeX; the labels defined later in this notebook
# contain LaTeX markup such as \textbf{...} and $...$.
# NOTE(review): presumably needs a working LaTeX installation on the machine — confirm.
plt.rcParams['text.usetex'] = True
# Default font size for figure text.
matplotlib.rcParams.update({'font.size': 12})
# Request Times New Roman as the font family for figure text.
plt.rc('font', family='Times New Roman')
import numpy as np
import pandas as pd
import seaborn as sns
# Pass the usetex flag through seaborn's style settings as well.
# NOTE(review): seaborn documents that `rc` here only overrides parameters that
# belong to its style definition, so this may have no effect on
# 'text.usetex' — confirm.
sns.set_style(rc={'text.usetex' : True})

In [None]:
# os.chdir('..')
# Work from the "clm" directory under the user's home so that the relative
# paths used in later cells resolve against it.
os.chdir(Path("~/clm").expanduser())

In [None]:
# fig_vis_head_entro
# Display label (LaTeX markup) for each experiment identifier. The insertion
# order of this mapping is also the order of `exps` below.
name_map: dict[str, str] = {
    "NoPE50k(4k)": r'\textbf{NoPE}',
    "NoPE50k_s1.2(4k)": r'\textbf{NoPE} $\lambda=\frac{1.2}{\sqrt{d}}$',
    "NoPE50k_HS4k(4k)": r'\textbf{NoPE} $\lambda^{(h)}$',
}
# Experiment identifiers to visualise; later cells use them as the
# "<identifier>|" prefix when selecting result folders.
exps: list[str] = [*name_map]

In [None]:
# Root directory that holds one sub-folder per experiment run.
ent_root = Path("path/to/data")
# Fail fast with a specific exception type (FileNotFoundError is still an
# Exception subclass, so broad handlers keep working).
if not ent_root.exists():
    raise FileNotFoundError("Path does not exist: {}".format(ent_root))
# get all folders in root_folder
# iterdir() also yields plain files; keep directories only, because every
# selected entry is later expected to contain an "entropy.npy" file.
folders = [entry for entry in ent_root.iterdir() if entry.is_dir()]


# filter using exps
def filter_name(p: Path):
    """Tell whether an entry belongs to one of the selected experiments.

    Returns True when ``p.name`` begins with ``"<exp>|"`` for at least one
    ``exp`` in the module-level ``exps`` list, and False otherwise.
    """
    wanted_prefixes = tuple(f"{exp}|" for exp in exps)
    # str.startswith accepts a tuple and succeeds if any prefix matches.
    return p.name.startswith(wanted_prefixes)


# Keep only the entries that belong to a selected experiment, in sorted order.
folders = sorted(filter(filter_name, folders))
# Group 1 captures everything before the first "|" of a name.
pattern = re.compile(r"^([^|]+).*")
# Map that captured prefix to its entry; when several entries share a prefix,
# the last one in sorted order wins.
ent_lines: dict[str, Path] = dict(
    (pattern.match(entry.name).group(1), entry) for entry in folders
)
print(ent_lines)

In [None]:
def load_ent(path: Path):
    """Load ``entropy.npy`` from ``path`` as a long-format DataFrame.

    The stored array is read as ``(n_layers, n_heads, seq_len)`` and cast to
    float32. Every element becomes one row, ordered layer-major, then head,
    then position (the array's C order).

    Args:
        path: Directory that contains the ``entropy.npy`` file.

    Returns:
        A DataFrame with the columns ``layer``, ``head``, ``Position``
        (int64 indices along the three axes) and ``Entropy`` (the float32
        value at that index, stored as float64).

    Raises:
        ValueError: If the stored array does not have exactly three axes.
    """
    arr: np.ndarray = np.load(path / "entropy.npy").astype(np.float32)
    if arr.ndim != 3:
        raise ValueError(
            "entropy.npy must have shape (n_layers, n_heads, seq_len), "
            "got shape {}".format(arr.shape)
        )
    # np.indices yields one index grid per axis; flattening each grid in C
    # order gives the same row order as nested layer/head/position loops,
    # without building one Python dict per array element.
    layer_idx, head_idx, pos_idx = np.indices(arr.shape, dtype=np.int64).reshape(3, -1)
    df = pd.DataFrame(
        {
            "layer": layer_idx,
            "head": head_idx,
            "Position": pos_idx,
            # float32 -> float64 is exact, so the values are unchanged.
            "Entropy": arr.reshape(-1).astype(np.float64),
        }
    )
    return df

In [None]:
# Layers shown as columns of the figure and attention heads drawn as lines.
layers = [0, 10, 20]
# heads = list(range(10))
heads = list(range(22))
# A4_paper_size_inches = (8.27, 11.69)
# fig_size = list(size / 2 for size in A4_paper_size_inches)
# One row per configured experiment and one column per selected layer, so the
# grid follows `exps` / `layers` instead of a hard-coded 3x3. Every key of
# `ent_lines` is one of `exps`, so row index i below never exceeds the grid.
# squeeze=False keeps `axs` two-dimensional even when either count is 1.
fig, axs = plt.subplots(len(exps), len(layers), figsize=(9, 5), squeeze=False)
# NOTE(review): fig.tight_layout() below recomputes the subplot spacing, so
# these negative values likely do not affect the final layout — confirm
# before removing.
plt.subplots_adjust(wspace=-1,hspace=-1)
for i, (name, path) in enumerate(ent_lines.items()):
    df = load_ent(path)
    # The head selection is the same for every layer: build the mask once.
    head_mask = df["head"].isin(heads)
    for j, layer in enumerate(layers):
        ax = axs[i, j]
        # Combine both conditions into a single mask. Indexing the
        # layer-filtered frame with a mask built from the unfiltered frame
        # only works because pandas re-aligns the mask (and warns about it).
        plot_df = df[(df["layer"] == layer) & head_mask]
        # NOTE(review): "Paired" provides fewer colours than there are heads,
        # so colours presumably repeat across heads — confirm this is intended.
        sns.lineplot(data=plot_df, ax=ax, x="Position", y="Entropy", hue="head", legend=False, palette=sns.color_palette("Paired"))
        # Only the middle-layer panel carries the experiment label, so the
        # label appears once per row.
        if layer == layers[1]:
            ax.set_title(f"{name_map[name]} - Layer {layer}")
        else:
            ax.set_title(f"Layer {layer}")
        # x = np.arange(0, 4096)
        # y = np.log2(x + 1)
        # ax.plot(x, y, color="black", linestyle="dashed")
        # Axis labels are blanked to keep the panels compact.
        ax.set_xlabel('')
        ax.set_ylabel('')
fig.tight_layout()  # re-space the panels so titles and tick labels do not overlap

In [None]:
# Write the figure to logs/fig_vis_head_entro.pdf, cropped tightly to its
# contents with no extra padding.
fig.savefig(
    Path("logs") / "fig_vis_head_entro.pdf",
    dpi=1000,
    pad_inches=0.0,
    bbox_inches='tight',
)
# plt.savefig(os.path.join("logs", "loss.png"))