In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import pickle

import numpy as np
import pandas as pd

from law import ScalingLaw, MultiObjScalingLaw

# Functions

In [3]:
def parse_weights(weight_str):
    """Split a concatenated two-weight string into its two numeric parts.

    Supported encodings:

    * two characters: one digit per weight, e.g. ``"37"`` -> ``(3, 7)``;
    * three characters where one side is the single digit ``0`` and the other
      side is a two-digit integer, e.g. ``"010"`` -> ``(0, 10)`` and
      ``"100"`` -> ``(10, 0)``;
    * anything else: two decimal numbers written back to back, where the
      second number has exactly one character before its decimal point,
      e.g. ``"2.57.5"`` -> ``(2.5, 7.5)``.

    Parameters
    ----------
    weight_str : str
        The encoded weights.

    Returns
    -------
    tuple
        ``(a, b)`` as ints for the 2- and 3-character forms, floats otherwise.

    Raises
    ------
    ValueError
        If the string cannot be split unambiguously (e.g. ``"555"``) or does
        not contain two decimal points in the decimal form.
    """
    n_chars = len(weight_str)

    if n_chars == 2:
        return int(weight_str[0]), int(weight_str[1])

    if n_chars == 3:
        if weight_str[0] == "0":
            return 0, int(weight_str[1:])
        if weight_str[-1] == "0":
            return int(weight_str[:2]), 0
        # Neither side is a lone "0": "1|23" and "12|3" are equally plausible.
        raise ValueError(f"ambiguous 3-character weight string: {weight_str!r}")

    dot_positions = [i for i, ltr in enumerate(weight_str) if ltr == "."]
    if len(dot_positions) < 2:
        raise ValueError(f"cannot parse weight string: {weight_str!r}")

    # The second number starts one character before the second decimal point.
    idx = dot_positions[1] - 1
    return float(weight_str[:idx]), float(weight_str[idx:])


In [4]:
def load_break_losses(dirs, slice_list, task_name):
    """Collect loss tables from DoReMi "break" run directories.

    Every entry of each directory in ``dirs`` is considered, newest
    modification time first. An entry is used when it is a directory whose
    path contains ``task_name``, every string in ``slice_list``, ``"break"``
    and ``"doremi"``, and does not contain ``".log"``. Inside such a
    directory every file is read with ``pd.read_pickle`` unless its name
    contains ``"test_"`` or one of the auxiliary-artifact tokens listed in
    ``files_to_exclude``.

    Parameters
    ----------
    dirs : iterable of str
        Output directories to scan.
    slice_list : iterable of str
        Substrings that must all appear in a run directory's path.
    task_name : str
        Substring that must appear in a run directory's path.

    Returns
    -------
    pandas.DataFrame
        Concatenation of the pickled frames, indexed by ``checkpoint``, with
        ``task_idx``/``task_loss`` renamed to ``skill``/``loss`` and the extra
        columns ``method`` (run directory name), ``seed`` and ``break_steps``
        (both parsed from names). Sorted by checkpoint, then seed. When the
        same (method, seed, checkpoint) is found more than once, only the
        first one encountered is kept. An empty DataFrame is returned when
        nothing matches.
    """
    files_to_exclude = ['trace', 'val_inputs', 'labels', 'proportions', 'emb', 'rouge', 'generations', 'gradient', 'acc', 'matrix', 'matrices']
    frames = []
    # (method, seed, checkpoint) triples already loaded; replaces a scan of the
    # accumulated frame on every file.
    seen = set()

    for root in dirs:
        entries = [os.path.join(root, name) for name in os.listdir(root)]
        entries.sort(key=os.path.getmtime, reverse=True)

        for entry in entries:
            if ".log" in entry or task_name not in entry:
                continue
            if any(skill not in entry for skill in slice_list):
                continue
            if "break" not in entry or "doremi" not in entry:
                continue
            if not os.path.isdir(entry):
                # A stray file with a matching name holds no runs.
                continue

            break_steps = int(entry.split("break_")[-1].split("_")[0])
            method = os.path.basename(entry)

            for run in os.listdir(entry):
                if "test_" in run:
                    continue
                if any(token in run for token in files_to_exclude):
                    continue

                seed = int(run.split("seed_")[-1].split("_")[0])
                checkpoint = int(run.split("-")[-1].split(".")[0])

                key = (method, seed, checkpoint)
                if key in seen:
                    continue

                # NOTE: unpickling executes arbitrary code; only point `dirs`
                # at run outputs you trust.
                df = pd.read_pickle(os.path.join(entry, run))
                df = df.rename(columns={"task_idx": "skill", "task_loss": "loss"})

                df["method"] = method
                df["seed"] = seed
                df["checkpoint"] = checkpoint
                df["break_steps"] = break_steps
                df = df.set_index("checkpoint")

                if len(df) != 0:
                    # An empty table contributes no rows, so it must not
                    # shadow a later non-empty duplicate.
                    seen.add(key)
                frames.append(df)

    if not frames:
        return pd.DataFrame()

    return pd.concat(frames).sort_values(by=["checkpoint", "seed"])


In [5]:
def load_resume_losses(dirs, slice_list, task_name):
    """Collect loss tables from DoReMi "resume" run directories.

    Every entry of each directory in ``dirs`` is considered, newest
    modification time first. An entry is used when it is a directory whose
    path contains ``task_name``, every string in ``slice_list``, ``"resume"``
    and ``"doremi"``, and does not contain ``".log"``. The token following
    ``"weights_"`` in the directory name is decoded with ``parse_weights``;
    runs are kept only when the two weights sum to 10 and both are non-zero.
    Inside a kept directory every file is read with ``pd.read_pickle`` unless
    its name contains ``"test_"`` or one of the auxiliary-artifact tokens
    listed in ``files_to_exclude``.

    Parameters
    ----------
    dirs : iterable of str
        Output directories to scan.
    slice_list : iterable of str
        Substrings that must all appear in a run directory's path.
    task_name : str
        Substring that must appear in a run directory's path.

    Returns
    -------
    pandas.DataFrame
        Concatenation of the pickled frames, indexed by ``checkpoint``, with
        ``task_idx``/``task_loss`` renamed to ``skill``/``loss`` and the extra
        columns ``method`` (run directory name), ``seed``, ``break_steps``,
        ``new_p1`` and ``new_p2``. Sorted by checkpoint, ``new_p1``, seed.
        When the same (method, seed, checkpoint) is found more than once,
        only the first one encountered is kept. An empty DataFrame is
        returned when nothing matches.
    """
    files_to_exclude = ['trace', 'val_inputs', 'labels', 'proportions', 'emb', 'rouge', 'generations', 'gradient', 'acc', 'matrix', 'matrices']
    frames = []
    # (method, seed, checkpoint) triples already loaded; replaces a scan of the
    # accumulated frame on every file.
    seen = set()

    for root in dirs:
        entries = [os.path.join(root, name) for name in os.listdir(root)]
        entries.sort(key=os.path.getmtime, reverse=True)

        for entry in entries:
            if ".log" in entry or task_name not in entry:
                continue
            if any(skill not in entry for skill in slice_list):
                continue
            if "resume" not in entry or "doremi" not in entry:
                continue
            if not os.path.isdir(entry):
                # A stray file with a matching name holds no runs.
                continue

            method = os.path.basename(entry)

            weight_str = method.split("weights_")[-1].split("_")[0]
            a, b = parse_weights(weight_str)

            # Weights may be floats, so compare the sum with a tolerance.
            if not np.isclose(a + b, 10):
                continue
            if a == 0 or b == 0:
                continue

            print(a, b)
            break_steps = int(entry.split("resume_doremi_")[-1].split("_")[0])

            for run in os.listdir(entry):
                if "test_" in run:
                    continue
                if any(token in run for token in files_to_exclude):
                    continue

                seed = int(run.split("seed_")[-1].split("_")[0])
                checkpoint = int(run.split("-")[-1].split(".")[0])

                key = (method, seed, checkpoint)
                if key in seen:
                    continue

                # NOTE: unpickling executes arbitrary code; only point `dirs`
                # at run outputs you trust.
                df = pd.read_pickle(os.path.join(entry, run))
                df = df.rename(columns={"task_idx": "skill", "task_loss": "loss"})

                df["method"] = method
                df["seed"] = seed
                df["checkpoint"] = checkpoint
                df["break_steps"] = break_steps
                df["new_p1"] = a
                df["new_p2"] = b
                df = df.set_index("checkpoint")

                if len(df) != 0:
                    # An empty table contributes no rows, so it must not
                    # shadow a later non-empty duplicate.
                    seen.add(key)
                frames.append(df)

    if not frames:
        return pd.DataFrame()

    return pd.concat(frames).sort_values(by=["checkpoint", "new_p1", "seed"])


In [6]:
def load_doremi_matrices(dirs, slice_list, break_steps):
    """Load the averaged DoReMi matrix of each seed for one break point.

    Looks, in every directory of ``dirs``, for entries whose path contains
    ``"break"`` and ``"stratified_<slices joined by '_'>_doremi"`` and whose
    step count (the token after ``"_break_"``) equals ``break_steps``. Within
    those, each file whose name contains both ``"avg"`` and
    ``"drm_matrices.npy"`` is loaded with ``np.load`` and averaged over its
    first axis. The path of every loaded file is printed.

    Returns
    -------
    dict
        Maps seed (parsed from the file name) to the averaged array. A later
        file for the same seed overwrites an earlier one.
    """
    slice_str = "_".join(slice_list)
    wanted_tag = f"stratified_{slice_str}_doremi"
    per_seed = {}

    for root in dirs:
        for name in os.listdir(root):
            candidate = os.path.join(root, name)

            if "break" not in candidate or wanted_tag not in candidate:
                continue
            if int(candidate.split("_break_")[-1].split("_")[0]) != break_steps:
                continue

            for run in os.listdir(candidate):
                if "avg" in run and "drm_matrices.npy" in run:
                    full_path = os.path.join(candidate, run)
                    print(full_path)
                    seed = int(run.split("seed_")[-1].split("_")[0])
                    per_seed[seed] = np.load(full_path).mean(axis=0)

    return per_seed

In [7]:
def inter_law(x, param):
    """Evaluate the line ``slope * x + intercept`` for ``param = (slope, intercept)``."""
    slope, intercept = param
    return slope * x + intercept

def param_generator():
    """Yield ``[k, b]`` starting points on an 11 x 11 grid.

    ``k`` sweeps ``np.linspace(-2.4, -1.6, 11)`` in the outer loop and ``b``
    sweeps ``np.linspace(-1.0, -0.1, 11)`` in the inner loop.
    """
    slopes = np.linspace(-2.4, -1.6, 11)
    intercepts = np.linspace(-1.0, -0.1, 11)
    yield from ([slope, intercept] for slope in slopes for intercept in intercepts)

In [8]:
def make_individual_xy(df_break, df_resume, skill, skills, break_steps, seed):
    """Build the (mixture proportion, loss) samples for one skill.

    Selects the rows of ``df_resume`` matching ``break_steps``, ``seed`` and
    ``skill``, keeps only those at the largest index value (the index is the
    checkpoint), and returns the skill's own normalized proportion together
    with its loss.

    Parameters
    ----------
    df_break : pandas.DataFrame
        Unused; kept so existing call sites keep working.
    df_resume : pandas.DataFrame
        Needs columns ``break_steps``, ``seed``, ``skill``, ``loss``,
        ``new_p1_normalized`` and ``new_p2_normalized``.
    skill : object
        Skill to extract.
    skills : sequence
        ``skills[0]`` is read from ``new_p1_normalized``; any other skill is
        read from ``new_p2_normalized``.
    break_steps, seed : int
        Run selectors.

    Returns
    -------
    tuple of numpy.ndarray
        ``(x, y)``: proportions and losses, one entry per selected row.

    Raises
    ------
    KeyError
        If no row of ``df_resume`` matches the selectors.
    """
    selected = df_resume[
        (df_resume.break_steps == break_steps)
        & (df_resume.seed == seed)
        & (df_resume.skill == skill)
    ]
    if selected.empty:
        raise KeyError(
            f"no resume rows for skill={skill!r}, break_steps={break_steps!r}, seed={seed!r}"
        )

    # A boolean mask always yields a DataFrame; `.loc[label]` would collapse
    # to a Series when a single row matches and break `.values` below.
    final_rows = selected[selected.index == selected.index.max()]

    p_col = 'new_p1_normalized' if skill == skills[0] else 'new_p2_normalized'

    return final_rows[p_col].values, final_rows['loss'].values

In [9]:
def make_xy_joint(df_break, df_resume, matrix, break_steps, seed):
    """Build the joint (transformed proportions, per-skill losses) samples.

    Selects the rows of ``df_resume`` matching ``break_steps`` and ``seed``
    and keeps only those at the largest index value (the index is the
    checkpoint). The distinct ``(new_p1_normalized, new_p2_normalized)`` pairs
    are stacked column-wise and left-multiplied by ``matrix``; the losses are
    reshaped to two columns.

    Parameters
    ----------
    df_break : pandas.DataFrame
        Unused; kept so existing call sites keep working.
    df_resume : pandas.DataFrame
        Needs columns ``break_steps``, ``seed``, ``loss``,
        ``new_p1_normalized`` and ``new_p2_normalized``.
        Assumes the selected rows come in consecutive pairs, one row per
        skill for each mixture -- TODO confirm against the loader's sort order.
    matrix : numpy.ndarray
        Applied as ``matrix.dot(x)`` with ``x`` of shape ``(2, n_mixtures)``.
    break_steps, seed : int
        Run selectors.

    Returns
    -------
    tuple of numpy.ndarray
        ``(x, y)`` where ``y`` has shape ``(-1, 2)``.

    Raises
    ------
    KeyError
        If no row of ``df_resume`` matches the selectors.
    """
    selected = df_resume[(df_resume.break_steps == break_steps) & (df_resume.seed == seed)]
    if selected.empty:
        raise KeyError(f"no resume rows for break_steps={break_steps!r}, seed={seed!r}")

    # A boolean mask always yields a DataFrame, unlike `.loc[label]`, which
    # collapses to a Series when only one row carries that label.
    final_rows = selected[selected.index == selected.index.max()]

    mixtures = (
        final_rows[['new_p1_normalized', 'new_p2_normalized']]
        .drop_duplicates(keep='first')
        .values.T
    )
    x = matrix.dot(mixtures)
    y = final_rows['loss'].values.reshape(-1, 2)

    return x, y

In [10]:
def law_1(x, param):
    """Linear law for the first output: ``param[1] + param[0] * x[0]``."""
    slope, offset = param[0], param[1]
    return offset + slope * x[0]

def law_2(x, param):
    """Linear law for the second output: ``param[2] + param[0] * x[1]``."""
    slope, offset = param[0], param[2]
    return offset + slope * x[1]


def param_generator_joint():
    """Yield ``[b, c1, c2]`` starting points on an 11 x 11 x 11 grid.

    ``b`` sweeps ``np.linspace(-10, 0, 11)`` (outermost), while ``c1`` and
    ``c2`` (innermost) each sweep ``np.linspace(0.0, 1.0, 11)``.
    """
    slopes = np.linspace(-10, 0, 11)
    offsets = np.linspace(0.0, 1.0, 11)
    yield from (
        [slope, first, second]
        for slope in slopes
        for first in offsets
        for second in offsets
    )

# Arxiv, Stackexchange

In [None]:
# Which runs to analyse: the task tag and the two data slices that must appear
# in a run directory's name, plus where the run outputs live.
task_name = "slimpj"
slice_list = ['arxiv', 'stackexchange']
dirs = ["../output/09252024/"] # REPLACE WITH YOUR RUN OUTPUT DIRECTORIES

# Losses recorded by the "break" runs found under `dirs`.
df_break = load_break_losses(dirs, slice_list=slice_list, task_name=task_name)

In [12]:
# Distinct skill labels present in the break-run losses, in sorted order.
skills = sorted(df_break["skill"].unique())

In [None]:
df_resume = load_resume_losses(dirs, skills, task_name)

# Turn the raw weight pair into proportions that sum to 1. Plain column
# arithmetic replaces the row-by-row `apply`, which is needlessly slow.
weight_total = df_resume['new_p1'] + df_resume['new_p2']
df_resume['new_p1_normalized'] = df_resume['new_p1'] / weight_total
df_resume['new_p2_normalized'] = df_resume['new_p2'] / weight_total

In [14]:
# Seeds are taken from the break runs; break points from the resumed runs.
seeds = sorted(df_break["seed"].unique())
break_steps = sorted(df_resume["break_steps"].unique())


## Get $A^{t \star}$ 

In [None]:
# Three independent {skill: {break_step: {seed: ...}}} containers: the fitted
# parameters and the x / y samples each fit was computed from.
params, x_per_skill, y_per_skill = (
    {skill: {bs: {seed: {} for seed in seeds} for bs in break_steps} for skill in skills}
    for _ in range(3)
)

# Fit one linear law (loss vs. the skill's own proportion) per skill and seed.
# Only break point 500 is fitted; other break points keep their empty slots.
for skill in skills:
    for bs in [500]:
        for seed in seeds:
            print(skill, bs, seed)

            x, y = make_individual_xy(df_break, df_resume, skill, skills, bs, seed)
            x_per_skill[skill][bs][seed] = x
            y_per_skill[skill][bs][seed] = y

            law = ScalingLaw(inter_law)
            p = law.fit(x, y, param_generator(), max_step=100, delta=0.02)
            params[skill][bs][seed] = [p[0], p[1]]


In [17]:
# Persist the per-skill fits. `pickle` is imported in the notebook's import cell.
out_path = "./law_results/arxiv_stackexchange/params_doremi_trajectory_opt_500.pkl"

# Create the results directory on first use; `open` alone fails if it is missing.
os.makedirs(os.path.dirname(out_path), exist_ok=True)

with open(out_path, "wb") as f:
    pickle.dump(params, f)

## Get $\tilde{A}^t = b^t A^t$

In [None]:
# Averaged DoReMi matrix per seed for the runs that broke at step 500.
doremi_matrices = load_doremi_matrices(dirs, slice_list, break_steps=500)

In [None]:
# Three independent {break_step: {seed: ...}} containers: the jointly fitted
# parameters and the x / y samples each fit was computed from.
params_doremi, x_per_skill_doremi, y_per_skill_doremi = (
    {bs: {seed: {} for seed in seeds} for bs in break_steps}
    for _ in range(3)
)

# Fit both skills jointly (shared slope, one offset each) on the proportions
# transformed by that seed's DoReMi matrix. Only break point 500 is fitted.
for bs in [500]:
    for seed in seeds:
        print(bs, seed)

        x, ys = make_xy_joint(df_break, df_resume, doremi_matrices[seed], bs, seed)
        x_per_skill_doremi[bs][seed] = x
        y_per_skill_doremi[bs][seed] = ys

        law = MultiObjScalingLaw([law_1, law_2])
        p = law.fit(x, ys.T, param_generator_joint(), max_step=100, delta=0.02)
        params_doremi[bs][seed] = p

In [21]:
# Persist the joint fits. `pickle` is imported in the notebook's import cell.
out_path = "./law_results/arxiv_stackexchange/params_doremi_trajectory_doremi_matrix_500.pkl"

# Create the results directory on first use; `open` alone fails if it is missing.
os.makedirs(os.path.dirname(out_path), exist_ok=True)

with open(out_path, "wb") as f:
    pickle.dump(params_doremi, f)