In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import argparse 
import pandas as pd 
import os 
import numpy as np
import torch 
from collections import defaultdict


from law import ScalingLaw

# Helper functions

In [3]:
def parse_weights(weight_str):
    """Parse the two mixture weights encoded in a run-name fragment.

    Supported encodings:
      * two characters, one digit per weight: ``"37"`` -> ``(3, 7)``;
      * three characters with a ``"0"`` on one side:
        ``"010"`` -> ``(0, 10)``, ``"100"`` -> ``(10, 0)``;
      * otherwise, two decimal numbers written back to back, where the second
        number has exactly one digit before its ``"."``:
        ``"0.59.5"`` -> ``(0.5, 9.5)``.

    Args:
        weight_str: the weight fragment cut out of a run name.

    Returns:
        Tuple ``(a, b)``; ints for the 2/3-character forms, floats otherwise.

    Raises:
        ValueError: if ``weight_str`` matches none of the encodings above
            (or contains non-numeric characters).
    """
    if len(weight_str) == 2:
        return int(weight_str[0]), int(weight_str[1])

    if len(weight_str) == 3:
        if weight_str[0] == "0":
            return 0, int(weight_str[1:])
        if weight_str[-1] == "0":
            return int(weight_str[:2]), 0
        # Without a zero on either side the split point is ambiguous.
        raise ValueError(
            f"Cannot parse weight string {weight_str!r}: a 3-character "
            "weight must start or end with '0'"
        )

    dot_positions = [i for i, ch in enumerate(weight_str) if ch == "."]
    if len(dot_positions) < 2:
        raise ValueError(
            f"Cannot parse weight string {weight_str!r}: expected two "
            "decimal numbers such as '0.59.5'"
        )
    # The second number starts one character before its decimal point.
    split_at = dot_positions[1] - 1
    return float(weight_str[:split_at]), float(weight_str[split_at:])


In [4]:
def load_break_losses(dirs, slice_list, task_name, filter_word=None, include_extrema=False):
    """Load per-checkpoint loss pickles from "break" runs into one DataFrame.

    For every directory in ``dirs``, each entry whose path contains
    ``task_name``, every name in ``slice_list``, ``"break"``, ``"weights_"``
    and (optionally) ``filter_word`` is treated as a run directory (``.log``
    entries are skipped). The break step count and the mixture weights
    ``(a, b)`` are parsed from its name. Runs whose weights do not sum to 10
    are skipped, as are runs with a zero weight unless ``include_extrema``.

    Args:
        dirs: directories to scan.
        slice_list: substrings (skill names) that must all appear in the path.
        task_name: substring that must appear in the path.
        filter_word: optional extra substring that must appear in the path.
        include_extrema: keep runs where one of the two weights is 0.

    Returns:
        DataFrame indexed by ``checkpoint`` holding the pickled columns
        (``task_idx``/``task_loss`` renamed to ``skill``/``loss``) plus
        ``method``, ``seed``, ``break_steps``, ``p1`` and ``p2``, sorted by
        checkpoint, p1 and seed. An empty DataFrame is returned when no run
        matches (instead of failing on the final sort).
    """
    files_to_exclude = ['trace', 'val_inputs', 'labels', 'proportions', 'emb', 'rouge', 'generations', 'gradient', 'acc']
    # Collect per-file frames and concatenate once; concatenating inside the
    # loop copies the accumulated frame every time (quadratic).
    frames = []

    for directory in dirs:
        files = [os.path.join(directory, f) for f in os.listdir(directory)]
        # Newest first; this only affects the order frames are concatenated in.
        files.sort(key=os.path.getmtime, reverse=True)

        for file in files:
            if ".log" in file or task_name not in file:
                continue
            if any(skill not in file for skill in slice_list):
                continue
            if "break" not in file or "weights_" not in file:
                continue
            if filter_word is not None and filter_word not in file:
                continue

            break_steps = int(file.split("break_")[-1].split("_")[0])

            method = file.split("/")[-1]
            weight_str = method.split("weights_")[-1].split("_")[0]
            a, b = parse_weights(weight_str)
            if a + b != 10:
                continue
            if not include_extrema and (a == 0 or b == 0):
                continue

            print(a, b)

            for run in os.listdir(file):
                if "test_" in run:
                    continue
                if any(exclude_file in run for exclude_file in files_to_exclude):
                    continue

                seed = int(run.split("seed_")[-1].split("_")[0])
                checkpoint = int(run.split("-")[-1].split(".")[0])

                # NOTE: unpickling can execute arbitrary code; only point this
                # loader at run outputs you produced yourself.
                df = pd.read_pickle(os.path.join(file, run))
                df = df.rename(columns={"task_idx": "skill", "task_loss": "loss"})

                df["method"] = method
                df["seed"] = seed
                df["checkpoint"] = checkpoint
                df["break_steps"] = break_steps
                df["p1"] = a
                df["p2"] = b

                frames.append(df.set_index("checkpoint"))

    if not frames:
        return pd.DataFrame()

    df_all = pd.concat(frames)
    return df_all.sort_values(by=["checkpoint", "p1", "seed"])


In [5]:
def load_resume_losses(dirs, slice_list, task_name, filter_word=None, include_extrema=False):
    """Load per-checkpoint loss pickles from "resume" runs into one DataFrame.

    For every directory in ``dirs``, each entry whose path contains
    ``task_name``, every name in ``slice_list``, ``"resume"``, ``"_weights"``
    and (optionally) ``filter_word`` is treated as a run directory (``.log``
    entries are skipped). Two weight pairs are parsed from its name: one from
    the ``resume_<w>`` segment and one from the ``weights_<w>`` segment. Runs
    where either pair does not sum to 10 are skipped, as are runs with a zero
    weight in either pair unless ``include_extrema``.

    The break step count is the token that follows ``resume_<w>_``.
    ``remaining_steps`` is parsed from a ``remaining_<n>`` token and defaults
    to 100 when the name has none.

    Args:
        dirs: directories to scan.
        slice_list: substrings (skill names) that must all appear in the path.
        task_name: substring that must appear in the path.
        filter_word: optional extra substring that must appear in the path.
        include_extrema: keep runs where a weight is 0.

    Returns:
        DataFrame indexed by ``checkpoint`` holding the pickled columns
        (``task_idx``/``task_loss`` renamed to ``skill``/``loss``) plus
        ``method``, ``seed``, ``break_steps``, ``remaining_steps``,
        ``p1``/``p2`` (the ``resume_`` pair) and ``new_p1``/``new_p2`` (the
        ``weights_`` pair), sorted by checkpoint, p1, new_p1 and seed. An empty
        DataFrame is returned when no run matches.
    """
    files_to_exclude = ['trace', 'val_inputs', 'labels', 'proportions', 'emb', 'rouge', 'generations', 'gradient', 'acc']
    # Collect per-file frames and concatenate once (concat-in-loop is quadratic).
    frames = []

    for directory in dirs:
        files = [os.path.join(directory, f) for f in os.listdir(directory)]
        # Newest first; this only affects the order frames are concatenated in.
        files.sort(key=os.path.getmtime, reverse=True)

        for file in files:
            if ".log" in file or task_name not in file:
                continue
            if any(skill not in file for skill in slice_list):
                continue
            if "resume" not in file or "_weights" not in file:
                continue
            if filter_word is not None and filter_word not in file:
                continue

            resume_weight_str = file.split("resume_")[-1].split("_")[0]
            resume_a, resume_b = parse_weights(resume_weight_str)
            if resume_a + resume_b != 10:
                continue
            print(resume_a, resume_b)

            if not include_extrema and (resume_a == 0 or resume_b == 0):
                continue

            method = file.split("/")[-1]

            weights_weight_str = method.split("weights_")[-1].split("_")[0]
            weights_a, weights_b = parse_weights(weights_weight_str)
            if weights_a + weights_b != 10:
                continue
            if not include_extrema and (weights_a == 0 or weights_b == 0):
                continue

            break_steps = int(file.split(f"resume_{resume_weight_str}_")[-1].split("_")[0])

            print(weights_a, weights_b)

            if "remaining" in file:
                remaining_steps = int(file.split("remaining_")[-1].split("_")[0])
            else:
                remaining_steps = 100  # default when the run name carries no remaining_<n>

            for run in os.listdir(file):
                if "test_" in run:
                    continue
                if any(exclude_file in run for exclude_file in files_to_exclude):
                    continue

                seed = int(run.split("seed_")[-1].split("_")[0])
                checkpoint = int(run.split("-")[-1].split(".")[0])

                # NOTE: unpickling can execute arbitrary code; only point this
                # loader at run outputs you produced yourself.
                df = pd.read_pickle(os.path.join(file, run))
                df = df.rename(columns={"task_idx": "skill", "task_loss": "loss"})

                df["method"] = method
                df["seed"] = seed
                df["checkpoint"] = checkpoint
                df["break_steps"] = break_steps
                df["remaining_steps"] = remaining_steps
                # NOTE(review): the run name does not say which pair is the
                # pre-break mixture and which is the post-resume one. The
                # resume_ pair is stored as p1/p2 because the notebook matches
                # df_resume.p1_normalized against the break runs'
                # p1_normalized. Confirm this against the run-naming scheme.
                df["new_p1"] = weights_a
                df["new_p2"] = weights_b
                df["p1"] = resume_a
                df["p2"] = resume_b

                frames.append(df.set_index("checkpoint"))

    if not frames:
        return pd.DataFrame()

    df_all = pd.concat(frames)
    return df_all.sort_values(by=["checkpoint", "p1", "new_p1", "seed"])


In [7]:
def inter_law(x, param):
    """Linear mixing law: predicted loss as a linear function of the proportion.

    Args:
        x: proportion(s) of a skill; a scalar, NumPy array or torch tensor.
        param: two-element sequence (or tensor) ``(k, b)``, the slope and the
            intercept.

    Returns:
        ``k * x + b``, broadcast over ``x``.
    """
    k, b = param
    # y = k * x + b   (a straight line, not an exponential)
    return k * x + b

def param_generator(k_range=(-2.4, -1.6), b_range=(-1.0, -0.1), num=11):
    """Yield candidate ``[k, b]`` parameter pairs for ``inter_law`` on a grid.

    The grid is the Cartesian product of ``num`` evenly spaced slopes over
    ``k_range`` and ``num`` evenly spaced intercepts over ``b_range`` (both
    endpoints inclusive), iterated slope-major. The defaults reproduce the
    original 11 x 11 grid.

    Args:
        k_range: ``(start, stop)`` for the slope ``k``.
        b_range: ``(start, stop)`` for the intercept ``b``.
        num: number of grid points along each axis.

    Yields:
        A fresh list ``[k, b]`` for every grid point.
    """
    # The intercept axis is the same for every slope, so build it once.
    b_values = np.linspace(b_range[0], b_range[1], num)
    for k in np.linspace(k_range[0], k_range[1], num):
        for b in b_values:
            yield [k, b]

In [9]:
def make_individual_xy(df_break, df_resume, skill, skills, break_steps, p1):
    """Collect the (proportion, loss) points used to fit one dynamic mixing law.

    Selects the rows of ``df_resume`` with the given ``break_steps``,
    ``p1_normalized == p1`` and ``skill``, then keeps only the rows at the
    largest checkpoint (index) among them.

    Args:
        df_break: accepted for interface compatibility; not used.
        df_resume: resume-run losses indexed by checkpoint, with columns
            ``break_steps``, ``p1_normalized``, ``skill``, ``loss``,
            ``new_p1_normalized`` and ``new_p2_normalized``.
        skill: the skill to select.
        skills: ordered skill list. ``skills[0]`` is paired with
            ``new_p1_normalized``; any other skill is paired with
            ``new_p2_normalized``.
        break_steps: break step count to select.
        p1: value of ``p1_normalized`` to select.

    Returns:
        ``(x, y)`` NumPy arrays: the new proportion of ``skill`` and the
        matching loss, one entry per selected row.

    Raises:
        ValueError: if no resume rows match the selection.
    """
    resume_subset = df_resume[
        (df_resume.break_steps == break_steps)
        & (df_resume.p1_normalized == p1)
        & (df_resume.skill == skill)
    ]
    if resume_subset.empty:
        raise ValueError(
            f"No resume rows for skill={skill!r}, break_steps={break_steps}, p1={p1}"
        )

    last_checkpoint = resume_subset.index.max()
    # Index with a list so that a single matching row still yields a
    # DataFrame; a scalar label would return a Series and break `.values`.
    final_rows = resume_subset.loc[[last_checkpoint]]

    p_col = 'new_p1_normalized' if skill == skills[0] else 'new_p2_normalized'
    return final_rows[p_col].values, final_rows['loss'].values

# ArXiv / StackExchange: fitting the linear dynamic mixing law

In [None]:
dirs = ["../output/09012024/", "../output/09022024/"] # REPLACE WITH YOUR RUN OUTPUT DIRECTORIES
task_name = "slimpj"
slice_list = ['arxiv', 'stackexchange']

df_break = load_break_losses(dirs, slice_list, task_name)
# Normalize the two mixture weights so they sum to 1 (vectorized column
# arithmetic instead of a slow row-wise DataFrame.apply).
df_break['p1_normalized'] = df_break['p1'] / (df_break['p1'] + df_break['p2'])
df_break['p2_normalized'] = df_break['p2'] / (df_break['p1'] + df_break['p2'])

In [14]:
# Distinct skill identifiers present in the break runs, in sorted order.
skills = sorted(df_break["skill"].unique())

In [None]:
df_resume = load_resume_losses(dirs, skills, task_name, include_extrema=False)
# Normalize both weight pairs so each sums to 1 (vectorized column arithmetic
# instead of four row-wise DataFrame.apply calls).
df_resume['p1_normalized'] = df_resume['p1'] / (df_resume['p1'] + df_resume['p2'])
df_resume['p2_normalized'] = df_resume['p2'] / (df_resume['p1'] + df_resume['p2'])
df_resume['new_p1_normalized'] = df_resume['new_p1'] / (df_resume['new_p1'] + df_resume['new_p2'])
df_resume['new_p2_normalized'] = df_resume['new_p2'] / (df_resume['new_p1'] + df_resume['new_p2'])

In [16]:
# Fit grid axes: the initial (normalized) first-skill proportions seen in the
# break runs, and the break step counts seen in the resumed runs.
probs = sorted(df_break["p1_normalized"].unique())
break_steps = sorted(df_resume["break_steps"].unique())

In [22]:
# Three independent nested {skill: {break_step: {p1: ...}}} tables: the fitted
# law parameters and the (x, y) points each fit was computed from.
params, x_per_skill, y_per_skill = (
    {skill: {bs: {p: {} for p in probs} for bs in break_steps} for skill in skills}
    for _ in range(3)
)

for skill in skills:
    for bs in break_steps:
        for p1 in probs:
            x, y = make_individual_xy(df_break, df_resume, skill, skills, bs, p1)
            x_per_skill[skill][bs][p1], y_per_skill[skill][bs][p1] = x, y

            # param_generator() supplies candidate [k, b] values; presumably
            # ScalingLaw.fit uses them as starting points — confirm in law.py.
            law = ScalingLaw(inter_law)
            p = law.fit(x, y, param_generator(), max_step=100, delta=0.02)
            params[skill][bs][p1] = p


In [21]:
import pickle

params_path = "./law_results/arxiv_stackexchange/params_dynamic.pkl"
# Create the output directory if needed; otherwise open() raises
# FileNotFoundError on a fresh checkout.
os.makedirs(os.path.dirname(params_path), exist_ok=True)
with open(params_path, "wb") as f:
    pickle.dump(params, f)

In [None]:
import matplotlib
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
from sklearn.metrics import r2_score

matplotlib.rcParams.update({'font.size': 14})
# "DejaVu Serif" ships with matplotlib; a misspelled name silently falls back
# to a different font.
plt.rcParams["font.family"] = "DejaVu Serif"

# Curve colours (RGB 0-255 scaled to 0-1), indexed by position in test_probs.
c = np.array([
    [31, 119, 180],   # Blue
    [253, 191, 96],   # Golden yellow
    [255, 137, 17],   # Orange
    [214, 39, 40],    # Red
    [255, 152, 150],  # Light pink
    [227, 119, 194],  # Pink
    [148, 103, 189],  # Purple
    [44, 160, 44],    # Green
    [140, 86, 75],    # Brown
    [127, 127, 127],  # Gray
    [23, 190, 207],   # Teal
]) / 255.0

fig, axes = plt.subplots(1, 2, figsize=(11, 4))

xs = np.linspace(0, 1, 11)
test_break_steps = 2000
# The first entry is drawn opaque (highlighted); the rest are drawn faded.
test_probs = [0.7, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.8, 0.9]

mses = []
r2s = []

for k, skill in enumerate(skills):
    ax = axes[k]

    for j, p1 in enumerate(test_probs):

        print(f"Skill = {skill}")

        x = x_per_skill[skill][test_break_steps][p1]
        y = y_per_skill[skill][test_break_steps][p1]

        p = params[skill][test_break_steps][p1]

        alpha = None if j == 0 else 0.2
        ax.scatter(x, y, color=c[j], s=300, alpha=alpha)
        plot_preds = inter_law(torch.tensor(xs), torch.tensor(p))
        ax.plot(xs, plot_preds, color=c[j], lw=3, alpha=alpha)

        eval_preds = inter_law(torch.tensor(x), torch.tensor(p))
        mse = torch.nn.functional.mse_loss(eval_preds, torch.tensor(y)).item()
        # r2_score(y_true, y_pred): observed losses first, predictions second
        # (R^2 is not symmetric in its arguments).
        r2 = r2_score(torch.tensor(y), eval_preds)
        print(f"MSE for {skill}: {mse}")
        print(f"R2 score for {skill}: {r2}\n")

        mses.append(mse)
        r2s.append(r2)

    # Style each axis once, after all of its curves are drawn.
    ax.set_xlabel(f"Proportion of {skill}", fontsize=18)
    ax.set_ylabel(f"Next-step Loss on {skill}", fontsize=14)
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    # ax.grid() with no arguments toggles visibility; request it explicitly.
    ax.grid(True)

fig.subplots_adjust(bottom=0.1, top=0.85)
fig.suptitle('Linear dynamic mixing law on Arxiv/StackExchange', fontsize=20)

#plt.savefig('../figs/arxiv_stackexchange_dynamic.pdf', bbox_inches="tight")

mses = np.array(mses)
r2s = np.array(r2s)
print(f"MSE: {mses.mean()}, {mses.std()}")
print(f"R2: {r2s.mean()}, {r2s.std()}")

