In [None]:
import os
import re
import warnings
from pathlib import Path

warnings.filterwarnings("ignore")
import matplotlib
import matplotlib.pyplot as plt

import numpy as np
import pandas as pd
import seaborn as sns

from scipy.optimize import curve_fit

In [None]:
# Work from the parent of the notebook's directory so that the relative paths
# used by later cells (e.g. "figs") resolve there.
# A bare os.chdir("..") climbs one more level every time this cell is re-run;
# remembering the target on first execution keeps the cell idempotent.
if "NOTEBOOK_ROOT" not in globals():
    NOTEBOOK_ROOT = Path.cwd().parent
os.chdir(NOTEBOOK_ROOT)

In [None]:
# NOTE(review): `job` is not read by any cell visible here; presumably a label
# for the experiment variant being analysed — confirm before relying on it.
job = "gs"
# job = "gs_ent"

# `file` is the output figure name (joined onto "figs" when saving);
# `root_folder` is the directory whose entries are loaded one by one.
# The commented pair is the alternative configuration, switched by hand.
# file = "fig_fit10k.pdf"
# root_folder = Path("path/to/data")
file = "fig_fit50k.pdf"
root_folder = Path("path/to/data")

# Fail early with a specific exception type (still an Exception subclass, so
# any existing `except Exception` keeps working).
if not root_folder.exists():
    raise FileNotFoundError(f"Path does not exist: {root_folder}")

# sorted() accepts the iterator directly; sorting makes the order deterministic
# instead of depending on the filesystem's listing order.
folders = sorted(root_folder.iterdir())
print(folders)

In [None]:
# Captures the leading run of non-underscore characters of a folder name,
# i.e. everything before the first "_" (or the whole name if there is none).
pattern = re.compile(r"^([^_]+).*")


def get_loss(path: Path):
    """Load the loss array stored in one run directory.

    Args:
        path: Directory that contains a ``loss.npy`` file.

    Returns:
        A ``(scale, loss)`` tuple: ``scale`` is the folder-name prefix matched
        by ``pattern`` (returned as a string), ``loss`` is the array read from
        ``loss.npy`` without any padding or truncation.

    Raises:
        AssertionError: if ``path`` is not a directory.
    """
    assert path.is_dir(), path
    loss_values: np.ndarray = np.load(path / "loss.npy")
    name_match = pattern.match(path.name)
    return name_match.group(1), loss_values
def get_entropy(path: Path, fallback_len: int = 8192):
    """Load the entropy array of one run, averaged down to one value per position.

    Args:
        path: Run directory expected to contain ``entropy.npy``.
        fallback_len: Length of the all-zero placeholder returned when
            ``entropy.npy`` is missing. Defaults to 8192, the value that was
            previously hard-coded.

    Returns:
        A ``(scale, entropy)`` tuple. When the file exists, ``scale`` is the
        folder-name prefix matched by ``pattern`` and ``entropy`` is the mean
        over the first two axes of the stored array. When it does not exist,
        the sentinel ``("0", zeros(fallback_len))`` is returned instead of
        raising; note that several missing runs therefore share the same
        ``"0"`` scale.
    """
    entropy_file = path / "entropy.npy"
    if not entropy_file.is_file():
        return "0", np.zeros(fallback_len)
    entropy: np.ndarray = np.load(entropy_file)  # [n_layers, n_heads, seq_len]
    entropy = entropy.mean(axis=(0, 1))  # [seq_len]
    scale = pattern.match(path.name).group(1)
    return scale, entropy
# Load every folder lazily and key its array by the numeric value of the
# folder-name prefix; each key becomes one DataFrame column, each row is one
# array position.
data = {
    float(prefix): curve
    for prefix, curve in (get_loss(run_dir) for run_dir in folders)
}
# Entropy variant of the same table:
# data = {float(scale): loss for scale, loss in map(get_entropy, folders)}
df = pd.DataFrame(data)
# print(df)

In [None]:
# Reduce across the columns of each row:
#   argmin -> label of the column holding the row's smallest value (best scale)
#   valmin -> that smallest value itself (best loss)
argmin = df.idxmin(axis="columns")
valmin = df.min(axis="columns")
print(argmin)

In [None]:
length = 2048  # reference position: func(length, a) == 1 for every a
d = 64  # NOTE(review): not read by any visible cell; presumably the d in the sqrt(d) tick labels — confirm


def func(x, a):
    """One-parameter log law ``a * ln(x / length) + 1``.

    Args:
        x: Position(s); a scalar or anything ``np.log`` can broadcast over.
        a: Slope of the logarithmic term.

    Returns:
        The model value(s), same shape as ``x``.
    """
    log_ratio = np.log(x / length)
    return a * log_ratio + 1

# Fit only on positions at or beyond `length`.
data = argmin[argmin.index >= length]
# data = data[data.index < 8192]

# Hand curve_fit plain float ndarrays. A pandas Index is not one of the types
# curve_fit validates/converts for xdata, so it would be forwarded to `func`
# as-is and every later expression would run through pandas Index arithmetic.
x = data.index.to_numpy(dtype=float)
y = data.to_numpy(dtype=float)

popt, pcov = curve_fit(func, x, y)
for rst in popt:
    print(f"{rst:.4f}")

# Coefficient of determination of the fit, computed on the fitted points.
residuals = y - func(x, *popt)
ss_res = np.sum(residuals**2)
ss_tot = np.sum((y - np.mean(y)) ** 2)
r_squared = 1 - (ss_res / ss_tot)
print(f"{r_squared=:.4f}")

In [None]:
# Left axis (blue): best scale per position plus the fitted curve.
# Right axis (red): best loss per position.
fig, ax1 = plt.subplots()
ax1.set_ylabel(r'$\lambda$', color='tab:blue')
ax1.tick_params(axis='y', labelcolor='tab:blue')
ax2 = ax1.twinx()  # second y-axis sharing the same x-axis
ax2.set_ylabel("log Perplexity", color='tab:red')
ax2.tick_params(axis='y', labelcolor='tab:red')

# Evaluate the fitted law over the plotted range. Start at 1 rather than 0:
# func takes log(x / length), which is -inf at x == 0.
x = np.arange(1, 16384)
y = func(x, *popt)

sns.lineplot(data=argmin, ax=ax1, color='tab:blue')
sns.lineplot(data=valmin, ax=ax2, color='tab:red')
sns.lineplot(x=x, y=y, ax=ax1, color='black')
ax1.set_ylim(0.9, 2.0)


def _over_sqrt_d(value, _pos):
    """Render a tick value v as the fraction v / sqrt(d)."""
    return r'$\frac{' + f'{value:g}' + r'}{\sqrt{d}}$'


# Use a formatter instead of reading get_yticklabels() and freezing the strings
# with set_yticklabels(): tick-label text is not guaranteed to be populated
# before the figure has been drawn, and fixed labels go stale as soon as the
# tick locations change.
ax1.yaxis.set_major_formatter(matplotlib.ticker.FuncFormatter(_over_sqrt_d))
ax1.set_xlabel(r"Position $i$")

# Address the axes explicitly rather than through pyplot's implicit "current
# axes" (the most recently created one, ax2). The marker sits at the same
# `length` the fit starts from.
ax2.axvline(x=length, color='k', linestyle='--')
ax2.set_title("NoPE 50k steps")

# Run the layout pass once every artist and label exists, so the right-hand
# y-label is accounted for.
fig.tight_layout()
plt.show()

In [None]:
# savefig does not create missing parent directories, so make sure "figs"
# exists (relative to the current working directory) before writing.
os.makedirs("figs", exist_ok=True)
fig.savefig(os.path.join("figs", file), bbox_inches='tight', pad_inches=0.0, dpi=1000)