In [1]:
# Auto-reload imported modules before each cell executes, so edits to local
# modules such as `law` take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import argparse 
import pandas as pd 
import os 
import numpy as np
import torch 
from collections import defaultdict


from law import ScalingLaw

import matplotlib.patches as mpatches
import matplotlib.pyplot as plt 
from sklearn.metrics import r2_score



# Setup

In [3]:
def _parse_weight_pair(weight_str):
    """Split the token that follows ``weights_`` in a run name into ``(a, b)``.

    Encodings handled:
      * two characters, one digit per weight: ``"37"`` -> ``(3, 7)``;
      * three characters with a leading or trailing ``"0"``:
        ``"010"`` -> ``(0, 10)``, ``"100"`` -> ``(10, 0)``;
      * any other length is read as two concatenated decimals and split one
        character before the second ``"."``: ``"2.57.5"`` -> ``(2.5, 7.5)``.

    Returns ``None`` for a three-character token that has neither a leading
    nor a trailing ``"0"``, because it cannot be split unambiguously.
    """
    if len(weight_str) == 2:
        return int(weight_str[0]), int(weight_str[1])

    if len(weight_str) == 3:
        if weight_str[0] == "0":
            return 0, int(weight_str[1:])
        if weight_str[-1] == "0":
            return int(weight_str[:2]), 0
        return None

    dot_positions = [i for i, ltr in enumerate(weight_str) if ltr == "."]
    split_at = dot_positions[1] - 1
    return float(weight_str[:split_at]), float(weight_str[split_at:])


def load_losses(dirs, slice_list, task_name, filter_word=None, run_subset=None, include_extrema=False):
    """Collect per-skill loss records from pickled run outputs into one DataFrame.

    Every entry of each directory in ``dirs`` is treated as a candidate run
    folder. A folder is kept only if its path
      * does not contain ``".log"``,
      * contains ``task_name``, ``"weights_"`` and every string in ``slice_list``,
      * contains ``filter_word`` (when given),
    and its parsed weight pair ``(a, b)`` sums to 10.

    Args:
        dirs: directories whose entries are scanned.
        slice_list: strings that must all appear in a run folder's path.
        task_name: string that must appear in a run folder's path.
        filter_word: optional extra substring the path must contain.
        run_subset: optional collection; folders whose first weight ``a`` is
            not in it are skipped.
        include_extrema: when False, folders with ``a == 0`` or ``b == 0``
            are skipped.

    Returns:
        A DataFrame indexed by ``checkpoint`` with the pickled columns
        (``task_idx``/``task_loss`` renamed to ``skill``/``loss``) plus
        ``method``, ``seed``, ``p1`` and ``p2``, sorted by checkpoint, p1 and
        seed. An empty DataFrame is returned when nothing matches.
    """
    files_to_exclude = ['trace', 'val_inputs', 'labels', 'proportions', 'emb', 'rouge', 'generations', 'gradient', 'acc']
    # Frames are gathered in a list and concatenated once at the end; growing
    # a DataFrame inside the loop copies all previous rows on every iteration.
    frames = []

    for run_dir in dirs:
        files = [os.path.join(run_dir, f) for f in os.listdir(run_dir)]
        files.sort(key=os.path.getmtime, reverse=True)

        for file in files:
            if ".log" in file or task_name not in file:
                continue

            if "weights_" not in file:
                continue

            if any(skill not in file for skill in slice_list):
                continue

            if filter_word is not None and filter_word not in file:
                continue

            method = file.split("/")[-1]
            weight_str = method.split("weights_")[-1].split("_")[0]

            weights = _parse_weight_pair(weight_str)
            if weights is None:
                # Ambiguous token: skip rather than reuse the weights parsed
                # for a previous folder.
                continue
            a, b = weights

            if not include_extrema and (a == 0 or b == 0):
                continue

            if a + b != 10:
                continue

            if run_subset is not None and a not in run_subset:
                continue

            print(a, b)

            for run in os.listdir(file):
                if "test_" in run:
                    continue

                if any(exclude_file in run for exclude_file in files_to_exclude):
                    continue

                # Seed and checkpoint step are encoded in the file name.
                seed = int(run.split("seed_")[-1].split("_")[0])
                checkpoint = int(run.split("-")[-1].split(".")[0])

                # NOTE: pickle files can execute arbitrary code when loaded;
                # only point this at run outputs you trust.
                df = pd.read_pickle(os.path.join(file, run))
                df = df.rename(columns={"task_idx": "skill", "task_loss": "loss"})

                df["method"] = method
                df["seed"] = seed
                df["checkpoint"] = checkpoint
                df["p1"] = a
                df["p2"] = b

                frames.append(df.set_index("checkpoint"))

    if not frames:
        return pd.DataFrame()

    df_all = pd.concat(frames)
    return df_all.sort_values(by=["checkpoint", "p1", "seed"])


# mixing law math

$L_i(p) = c_i + b_i \exp(k_i p_i)$

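For $m=2$ domains, this can be derived from the more general form $L_i(p) = c_i + b_i \exp(\sum_{j=1}^m A_{ij} p_j)$. The general form appears to have 4 parameters per domain, but because $p$ lies on the simplex ($p_1 + p_2 = 1$), the exponent reduces to a constant plus a multiple of $p_i$. The constant can be absorbed into $b_i$, so only the 3 parameters $(c_i, b_i, k_i)$ in the equation above are needed.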


In [4]:
def inter_law(x, param):
    """Evaluate the exponential term ``exp(k*x + b)`` of the mixing law.

    ``param`` unpacks into the slope ``k`` and the offset ``b``. The constant
    ``c`` of ``y = c + exp(k*x + b)`` is not added here; the value returned is
    only the exponential part.
    """
    slope, offset = param
    exponent = slope * x + offset
    return torch.exp(exponent)

def param_generator():
    """Yield every ``[k, b]`` pair of an 11 x 11 grid of starting values.

    ``k`` runs over 11 evenly spaced points in [-2.4, -1.6] (outer loop) and
    ``b`` over 11 evenly spaced points in [-1.0, -0.1] (inner loop).
    """
    slopes = np.linspace(-2.4, -1.6, 11)
    offsets = np.linspace(-1.0, -0.1, 11)
    for slope in slopes:
        yield from ([slope, offset] for offset in offsets)

In [6]:
def make_individual_xy(df, skill, skills, break_steps):
    """Return the mixing proportions and losses of one skill at a given step.

    Args:
        df: DataFrame whose index holds the step, with columns ``skill``,
            ``loss``, ``p1_normalized`` and ``p2_normalized``.
        skill: the skill whose rows are selected.
        skills: list of skills; ``skills[0]`` is read from ``p1_normalized``,
            any other skill from ``p2_normalized``.
        break_steps: index value (step) at which rows are selected.

    Returns:
        ``(x, y)``: arrays of the skill's proportion and loss, one entry per
        matching row. Both are empty when no row matches.
    """
    df_resume_subset = df[(df.index == break_steps) & (df.skill == skill)]

    if not df_resume_subset.empty:
        # A boolean mask always yields a DataFrame. `.loc[label]` collapses a
        # single matching row into a Series, which breaks the `.values`
        # lookups below.
        latest_step = df_resume_subset.index.max()
        df_resume_subset = df_resume_subset[df_resume_subset.index == latest_step]

    p_col = 'p1_normalized' if skill == skills[0] else 'p2_normalized'

    x = df_resume_subset[p_col].values
    y = df_resume_subset['loss'].values

    return x, y

# Arxiv, Stackexchange

In [None]:
ckpt_0_dirs = ["../output/08232024/", "../output/08242024/", "../output/08252024/", "../output/08262024/", 
        "../output/08292024/", "../output/08302024/"] # REPLACE WITH YOUR RUN OUTPUT DIRECTORIES

task_name = "slimpj"
slice_list = ['arxiv', 'stackexchange']
# run_subset keeps only runs whose first weight is 1..9.
df = load_losses(ckpt_0_dirs, slice_list, task_name, run_subset=[1, 2, 3, 4, 5, 6, 7, 8, 9])

# Convert the raw weights into proportions that sum to 1. Column arithmetic
# gives the same values as a row-wise apply without a Python call per row.
df['p1_normalized'] = df['p1'] / (df['p1'] + df['p2'])
df['p2_normalized'] = df['p2'] / (df['p1'] + df['p2'])

In [10]:
# Distinct skills and seeds present in the loaded runs, in sorted order.
skills = sorted(df["skill"].unique())
seeds = sorted(df["seed"].unique())

indices = [5000] # fit at end of training

In [None]:

grid_size = 1000
# params[skill][index][seed] ends up holding [c, k, b] for that fit.
params = {skill: {index: {seed: {} for seed in seeds} for index in indices} for skill in skills}

for skill in skills:
    # The first slice reads its proportion from p1, any other slice from p2.
    p_col = 'p1_normalized' if skill == slice_list[0] else 'p2_normalized'

    for index in indices:
        for seed in seeds:
            print(f"Skill = {skill}, index = {index}, seed = {seed}")

            row_mask = (df.index == index) & (df.skill == skill) & (df.seed == seed)
            df_subset = df[row_mask]

            x = df_subset[p_col].values
            y = df_subset['loss'].values

            # Grid-search the constant c in [0, min loss): keep the value for
            # which log(y - c) is most strongly linearly correlated with x.
            # The last grid point (c == min loss) is dropped because it would
            # put a zero inside the log.
            best_c, max_corr = 0, 0
            for c in np.linspace(0, df_subset.loss.min(), grid_size)[:-1]:
                corr = np.abs(np.corrcoef(x, np.log(y - c))[0, 1])
                if corr > max_corr:
                    max_corr, best_c = corr, c

            # Fit the exponential term to the losses with c removed.
            # NOTE(review): ScalingLaw.fit presumably searches from each
            # param_generator() start and returns the best (k, b) - confirm in law.py.
            law = ScalingLaw(inter_law)
            p = law.fit(x, y - best_c, param_generator(), max_step=100, delta=0.02)
            params[skill][index][seed] = [best_c, p[0], p[1]]


In [10]:
import pickle 

params_path = "./law_results/arxiv_stackexchange/params_static.pkl"
# Create the results directory if needed; open(..., "wb") does not create
# missing parent directories and would raise FileNotFoundError.
os.makedirs(os.path.dirname(params_path), exist_ok=True)

with open(params_path, "wb") as f:
    pickle.dump(params, f)

In [None]:
import matplotlib

matplotlib.rcParams.update({'font.size': 14})
plt.rcParams["font.family"] = "DeJavu Serif"

# One RGB colour per seed, scaled to [0, 1].
c = np.array([[31, 119, 180], [255, 137, 17], [214, 39, 40], [255, 152, 150], [227, 119, 194]])/255

fig, axes = plt.subplots(1, 2, figsize=(11, 4))

# Proportions at which the fitted curve is drawn.
xs = np.linspace(0, 1, 20)


indices = [5000]
seeds = [0, 1, 2, 3, 4]

mses = []
r2s = [] 

for k, skill in enumerate(slice_list):
    ax = axes[k]

    # The first slice reads its proportion from p1, the other from p2.
    if skill == slice_list[0]:
        p_col = 'p1_normalized'
    else:
        p_col = 'p2_normalized'

    for index in indices:
        for j, seed in enumerate(seeds):
            df_subset = df.loc[df.index == index]
            df_subset = df_subset[(df_subset.skill == skill) & (df_subset.seed == seed)]
            x = df_subset[p_col].values 
            y = df_subset['loss'].values

            # p = [c, k, b]; the law models (loss - c), so remove c first.
            p = params[skill][index][seed]
            target = y - p[0]

            # The first seed is drawn opaque, the remaining seeds faded.
            alpha = 1.0 if j == 0 else 0.2

            ax.scatter(x, target, color=c[j], s=300, alpha=alpha)

            plot_preds = inter_law(torch.tensor(xs), torch.tensor(p[1:]))
            ax.plot(xs, plot_preds, color=c[j], lw=3, alpha=alpha)

            eval_preds = inter_law(torch.tensor(x), torch.tensor(p[1:]))

            mse = torch.nn.functional.mse_loss(eval_preds, torch.tensor(target)).item()
            # r2_score expects (y_true, y_pred); R^2 is not symmetric, so the
            # observed values must come first.
            r2 = r2_score(target, eval_preds.detach().numpy())

            mses.append(mse)
            r2s.append(r2)

            print(f"MSE for {skill}, step {index}: {mse}")
            print(f"R2 score for {skill}, step {index}: {r2}\n")

    # Axis styling is applied once per panel. In particular a bare ax.grid()
    # toggles the grid, so calling it once per seed would hide the grid for
    # an even number of seeds.
    ax.set_yscale("log")
    ax.set_xlabel(f"Proportion of {skill}", fontsize=18)
    ax.set_ylabel(f"Log (Loss - c) on {skill}", fontsize=15)
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.grid(True)


fig.subplots_adjust(bottom=0.1, top=0.85)  
fig.suptitle('Log-linear static mixing law on Arxiv/StackExchange', fontsize=20)

#plt.savefig('../figs/arxiv_stackexchange_static.pdf', bbox_inches="tight")

mses = np.array(mses)
r2s = np.array(r2s)

# Mean and standard deviation of each metric over all skills and seeds.
print(mses.mean(), mses.std())
print(r2s.mean(), r2s.std())

## Get optimal mixing proportions according to the fitted model

In [None]:
t = 5000
for seed in range(5):
    # Each params entry is [c, k, b]; only k and b enter the formula below.
    k1, b1 = params['arxiv'][t][seed][1:]
    k2, b2 = params['stackexchange'][t][seed][1:]

    # Stationary point in p1 of exp(k1*p1 + b1) + exp(k2*(1 - p1) + b2).
    numerator = k2 + b2 - b1 + np.log(k2 / k1)
    p1 = numerator / (k1 + k2) # closed form
    print(p1, 1-p1)