In [None]:
import os
import re
import warnings
from pathlib import Path

# Silence every warning. This runs before the remaining imports, so warnings
# raised while those modules are loaded are hidden as well.
warnings.filterwarnings("ignore")
import matplotlib
import matplotlib.pyplot as plt
# Render all figure text through LaTeX.
plt.rcParams['text.usetex'] = True
# Load the ulem package; the right-hand axis label in the plotting cell uses
# its \dashuline macro.
plt.rc('text.latex', preamble=r'\usepackage{ulem}')
# Global font size and family for every figure in this notebook.
matplotlib.rcParams.update({'font.size': 14})
plt.rc('font', family='Times New Roman')


import numpy as np
import pandas as pd
import seaborn as sns
# NOTE(review): text.usetex is already enabled through plt.rcParams above;
# confirm whether seaborn's set_style applies this non-style rc key or ignores it.
sns.set_style(rc={'text.usetex' : True})

# Number of positions to plot. None means "not set by a figure config"; the
# data-building cell then falls back to 4096.
max_len = None

In [None]:
# os.chdir('..')
# Make ~/clm (expanded against the current user's home) the working directory;
# relative paths used later in the notebook, such as "logs", resolve against it.
project_root = os.path.expanduser("~/clm")
os.chdir(project_root)

In [None]:
# fig_vis_entro
# exps: list[str] = [
#     "NoPE50k(4k)",
#     "NoPE50k_s1.2(4k)",
#     "RoPE50k(4k)",
#     "RoPE50k_NTK2(4k)",
# ]
# loss_lines: dict[str, Path] = {
#     "NoPE50k(4k)": Path("path/to/data"),
#     "NoPE50k_s1.2(4k)": Path("path/to/data"),
#     "RoPE50k(4k)": Path("path/to/data"),
#     "RoPE50k_NTK2(4k)": Path("path/to/data"),
# }
# name_map: dict[str, str] = {
#     "NoPE50k(4k)": 'NoPE',
#     "NoPE50k_s1.2(4k)" : r'NoPE $\lambda=\frac{1.2}{\sqrt{d}}$',
#     "RoPE50k(4k)": 'RoPE',
#     "RoPE50k_NTK2(4k)": 'RoPE NTK',
# }

In [None]:
# fig_vis_uni_scale_entro_nope
# exps: list[str] = [
#     "NoPE50k_s0.9(4k)",
#     "NoPE50k(4k)",
#     "NoPE50k_s1.1(4k)",
#     "NoPE50k_s1.2(4k)",
# ]
# loss_lines: dict[str, Path] = {
#     "NoPE50k_s0.9(4k)": Path("path/to/data"),
#     "NoPE50k(4k)": Path("path/to/data"),
#     "NoPE50k_s1.1(4k)": Path("path/to/data"),
#     "NoPE50k_s1.2(4k)": Path("path/to/data"),
# }
# name_map: dict[str, str] = {
#     "NoPE50k_s0.9(4k)" : r'NoPE $\lambda=\frac{0.9}{\sqrt{d}}$',
#     "NoPE50k(4k)" : r'NoPE $\lambda=\frac{1.0}{\sqrt{d}}$',
#     "NoPE50k_s1.1(4k)" : r'NoPE $\lambda=\frac{1.1}{\sqrt{d}}$',
#     "NoPE50k_s1.2(4k)" : r'NoPE $\lambda=\frac{1.2}{\sqrt{d}}$',
# }

In [None]:
# fig_vis_uni_scale_entro_rope
# One row per curve, in plotting order:
#   (experiment key, directory holding loss.npy, LaTeX legend label)
rope_runs: list[tuple[str, Path, str]] = [
    ("RoPE50k_s0.8(4k)", Path("path/to/data"), r'\textbf{RoPE} $\lambda=\frac{0.8}{\sqrt{d}}$'),
    ("RoPE50k(4k)", Path("path/to/data"), r'\textbf{RoPE} $\lambda=\frac{1.0}{\sqrt{d}}$'),
    ("RoPE50k_s1.2(4k)", Path("path/to/data"), r'\textbf{RoPE} $\lambda=\frac{1.2}{\sqrt{d}}$'),
    ("RoPE50k_s1.4(4k)", Path("path/to/data"), r'\textbf{RoPE} $\lambda=\frac{1.4}{\sqrt{d}}$'),
]

# Experiment keys selected for this figure.
exps: list[str] = [run_key for run_key, _loss_dir, _label in rope_runs]
# Experiment key -> directory containing that run's loss.npy.
loss_lines: dict[str, Path] = {run_key: loss_dir for run_key, loss_dir, _label in rope_runs}

# Experiment key -> label shown in the legend.
name_map: dict[str, str] = {run_key: label for run_key, _loss_dir, label in rope_runs}

In [None]:
# fig_vis_head_vs_uni_scale
# exps: list[str] = [
#     "NoPE50k_s1.6(8k)",
#     "NoPE50k_HS8k(8k)",
# ]
# loss_lines: dict[str, Path] = {
#     "NoPE50k_s1.6(8k)": Path("path/to/data"),
#     "NoPE50k_HS8k(8k)": Path("path/to/data"),
# }
# max_len = 8192

In [None]:
# Root directory whose entries are named "<experiment>|..." and hold entropy dumps.
ent_root = Path("path/to/data")
if not ent_root.exists():
    # FileNotFoundError is the specific built-in for a missing path; it is still
    # an Exception subclass, so any broad handler keeps working.
    raise FileNotFoundError("Path does not exist: {}".format(ent_root))
# Lazily enumerate every entry directly under ent_root (files as well as
# directories); the experiment filter is applied afterwards.
folders = ent_root.iterdir()


# filter using exps
def filter_name(p: Path):
    """Return True if ``p.name`` begins with ``"<exp>|"`` for any entry of the global ``exps``."""
    for exp in exps:
        if p.name.startswith(f"{exp}|"):
            return True
    return False


# Keep only the entries that belong to a selected experiment, in sorted order.
folders = sorted(filter(filter_name, folders))
# Captures everything before the first "|" in an entry name.
pattern = re.compile(r"^([^|]+).*")
# Experiment key -> entry path. If several entries share a key, the one that
# sorts last wins.
ent_lines: dict[str, Path] = dict(
    (pattern.match(path.name).group(1), path) for path in folders
)
print(ent_lines)

In [None]:
def moving_average(data: np.ndarray, w: int):
    """Smooth ``data`` with a centered moving average.

    Each output element ``i`` is the mean of ``data[i - w//2 : i + w//2 + 1]``,
    i.e. up to ``w + 1`` samples. Near either end the window is clipped to the
    array bounds and the mean is taken over the samples that remain.

    Args:
        data: Array to smooth; sliced along its first axis.
        w: Window parameter; must be even.

    Returns:
        A float array shaped like ``data`` holding the smoothed values.

    Raises:
        AssertionError: If ``w`` is odd.
    """
    assert w % 2 == 0
    half = w // 2
    n = len(data)
    smoothed = np.zeros_like(data, dtype=float)
    for center in range(n):
        # Clip the window to the valid index range at both borders.
        lo = max(0, center - half)
        hi = min(n, center + half + 1)
        smoothed[center] = np.sum(data[lo:hi]) / (hi - lo)
    return smoothed


def load_arr(path: Path, seq_len_max: int, ent: bool = False):
    """Load a per-position curve from a ``.npy`` file.

    Args:
        path: Location of the ``.npy`` file.
        seq_len_max: Maximum number of leading positions to keep.
        ent: If True, the loaded array is averaged over its first two axes and
            returned without smoothing (assumes the remaining leading axis is
            the position axis -- TODO confirm against the dump format). If
            False, the truncated array is smoothed with ``moving_average``
            using a window parameter of 100.

    Returns:
        A numpy array with at most ``seq_len_max`` leading entries. When the
        file does not exist, a message is printed and a float32 array of
        ``seq_len_max`` zeros is returned, so callers always receive an
        ndarray rather than a plain list.
    """
    if not os.path.isfile(path):
        print(f"{str(path)} not exists")
        # Same length and values as the former ``[0] * seq_len_max`` placeholder,
        # but as an ndarray with the dtype of the normal return path.
        return np.zeros(max(seq_len_max, 0), dtype=np.float32)
    arr: np.ndarray = np.load(path).astype(np.float32)
    if ent:
        arr = arr.mean(axis=(0, 1))
    # The file may hold fewer positions than requested; keep what is available.
    seq_len = min(arr.shape[0], seq_len_max)
    arr = arr[:seq_len]
    if not ent:
        arr = moving_average(arr, 100)
    return arr

In [None]:
# Default plot length when no figure config set max_len.
if max_len is None:
    max_len = 4096
# Long-format records: one row per (model, position) with both metrics.
data = []
for name, ent_path in ent_lines.items():
    ent = load_arr(ent_path / "entropy.npy", seq_len_max=max_len, ent=True)
    loss = load_arr(loss_lines[name] / "loss.npy", seq_len_max=max_len)
    # load_arr returns fewer than max_len values when a file holds fewer
    # positions; indexing up to max_len would then raise IndexError, so only
    # the positions present in both curves are used.
    n_pos = min(len(ent), len(loss))
    if n_pos < max_len:
        print(f"{name}: only {n_pos} of {max_len} positions available; using those")
    label = name_map[name]
    data.extend(
        {"model": label, "ent": ent_val, "loss": loss_val, "pos": pos}
        for pos, (ent_val, loss_val) in enumerate(zip(ent, loss))
    )
df = pd.DataFrame(data)
print(df)

In [None]:
# Left axis (ax1): entropy curves, solid. Right axis (ax2): loss curves, dashed.
fig, ax1 = plt.subplots(figsize=(8, 4))
assert isinstance(ax1, plt.Axes)
# ax1.set_xlabel('Position')
ax1.set_xlabel(r'$\textbf{Position}\  i$', fontweight='bold')

# NOTE(review): labels are set before plotting, presumably so seaborn does not
# substitute the column names -- keep this order unless verified otherwise.
ax1.set_ylabel(' ')
ax1.tick_params(axis='y')
# plt.vlines(2400, 0, 8, color=plt.cm.Paired(0), linestyle="dotted", linewidth=2)

# Second y-axis sharing the same x-axis.
ax2 = ax1.twinx()
# ax2.set_ylabel('log Perplexity')
ax2.set_ylabel(r'$\textbf{\dashuline{log Perplexity}}$')
ax2.tick_params(axis='y')
# ax2.set_ylim(2, 10)
# x = np.arange(0, max_len)
# y = np.log2(x + 1)
# ax1.plot(x, y, color="black", linestyle="dashed")

# Legend / colour order of the curves; entries must match the "model" column.
hue_order = [
    r'\textbf{RoPE} $\lambda=\frac{0.8}{\sqrt{d}}$',
    r'\textbf{RoPE} $\lambda=\frac{1.0}{\sqrt{d}}$',
    r'\textbf{RoPE} $\lambda=\frac{1.2}{\sqrt{d}}$',
    r'\textbf{RoPE} $\lambda=\frac{1.4}{\sqrt{d}}$',
]

# One colour per hue_order entry, shared by both axes so each model keeps
# the same colour for its entropy and loss curves.
curve_palette = (
    sns.color_palette("Reds", 3)[:1]
    + sns.color_palette("Paired", 5)[-1:]
    + sns.color_palette("Reds", 3)[-1:]
    + sns.color_palette("Reds", 12)[-1:]
)
# Arguments common to both lineplot calls.
shared_kwargs = dict(x="pos", hue="model", hue_order=hue_order, palette=curve_palette, lw=2)

sns.lineplot(df, ax=ax1, y="ent", **shared_kwargs)
sns.lineplot(df, ax=ax2, y="loss", linestyle="dashed", legend=False, **shared_kwargs)

ax1.legend(loc="upper left")
# ax1.legend(loc="upper left", bbox_to_anchor=(0, -0.1), ncol=4)
# fig.tight_layout()  # to ensure that the right y-label is not slightly clipped 1 4 7 10

In [None]:
# savefig does not create missing directories, so make sure the output
# directory exists before writing the figure.
os.makedirs("logs", exist_ok=True)
fig.savefig(os.path.join("logs", "fig_vis_uni_scale_entro_rope.pdf"),bbox_inches='tight', pad_inches=0.0, dpi=1000)