In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import pickle
from collections import defaultdict

import numpy as np
import pandas as pd
import torch

from law import ScalingLaw

# Functions

In [3]:
def load_break_losses(dirs, slice_list, task_name):
    """Load the per-skill losses of every "break" run found under ``dirs``.

    Every entry of each directory in ``dirs`` is considered a run directory.
    A run directory is kept when its path contains ``task_name``, every string
    in ``slice_list`` and the substring "break", and does not contain ".log".
    The break step is parsed from the text following "break_" and the nine
    mixture weights from the text following "weights_" in the directory name
    (either nine single digits, or nine concatenated "0.xx" decimals).

    Inside a kept run directory, every file whose name does not contain
    "test_" or one of the excluded artifact names is read with
    ``pd.read_pickle``; the seed is parsed from the text after "seed_" and the
    checkpoint from the text between the last "-" and the following ".".

    Args:
        dirs: iterable of directory paths to scan.
        slice_list: substrings that must all occur in a run directory's path.
        task_name: substring that must occur in a run directory's path.

    Returns:
        A DataFrame indexed by ``checkpoint`` holding the pickled columns
        (``task_idx``/``task_loss`` renamed to ``skill``/``loss``) plus
        ``method``, ``seed``, ``break_steps`` and ``p1``..``p9``, sorted by
        checkpoint, weights and seed.

    Raises:
        ValueError: if a weight string does not yield nine values, or if no
            loss file matched at all (previously an opaque
            ``KeyError: 'checkpoint'`` from the final sort).
    """
    files_to_exclude = ['trace', 'val_inputs', 'labels', 'proportions', 'emb', 'rouge', 'generations', 'gradient', 'acc']
    weight_columns = [f"p{k}" for k in range(1, 10)]
    # Collected per-file frames; concatenated once at the end instead of
    # growing a DataFrame inside the loop (which re-copies all rows each time).
    frames = []

    for run_root in dirs:
        run_dirs = [os.path.join(run_root, entry) for entry in os.listdir(run_root)]
        run_dirs.sort(key=os.path.getmtime, reverse=True)

        for run_dir in run_dirs:
            if ".log" in run_dir or task_name not in run_dir:
                continue
            if any(skill not in run_dir for skill in slice_list):
                continue
            if "break" not in run_dir:
                continue

            break_steps = int(run_dir.split("break_")[-1].split("_")[0])
            method = os.path.basename(run_dir)

            weight_str = method.split("weights_")[-1].split("_")[0]
            if len(weight_str) == 9:
                # Nine single-digit integer weights, e.g. "100000000".
                weights = [int(digit) for digit in weight_str]
            else:
                # Concatenated decimals, e.g. "0.10.20.3..." -> 0.1, 0.2, 0.3, ...
                weights = [float(f"0.{frac}") for frac in weight_str.split("0.")[1:]]
            if len(weights) != len(weight_columns):
                raise ValueError(
                    f"expected {len(weight_columns)} weights in {method!r}, "
                    f"parsed {len(weights)} from {weight_str!r}"
                )

            print(*weights)

            for run in os.listdir(run_dir):
                if "test_" in run:
                    continue
                if any(excluded in run for excluded in files_to_exclude):
                    continue

                seed = int(run.split("seed_")[-1].split("_")[0])
                checkpoint = int(run.split("-")[-1].split(".")[0])

                # NOTE: unpickling can execute arbitrary code; only point
                # `dirs` at run outputs you trust.
                df = pd.read_pickle(os.path.join(run_dir, run))
                df = df.rename(columns={"task_idx": "skill", "task_loss": "loss"})

                df["method"] = method
                df["seed"] = seed
                df["checkpoint"] = checkpoint
                df["break_steps"] = break_steps
                for column, weight in zip(weight_columns, weights):
                    df[column] = weight

                frames.append(df.set_index("checkpoint"))

    if not frames:
        raise ValueError(
            f"no break-run loss files found for task {task_name!r} under {list(dirs)!r}"
        )

    df_all = pd.concat(frames)
    # "checkpoint" is the index level name; sort_values accepts it alongside columns.
    return df_all.sort_values(by=["checkpoint", *weight_columns, "seed"])


In [4]:
def load_resume_losses(dirs, slice_list, task_name):
    """Load the per-skill losses of every "resume" run found under ``dirs``.

    Every entry of each directory in ``dirs`` is considered a run directory.
    A run directory is kept when its path contains ``task_name``, every string
    in ``slice_list`` and the substring "resume", and does not contain ".log".
    Its name is expected to look like
    ``...resume_<old weights>_<break steps>_...weights_<new weights>_...``:
    the old weights are the text after "resume_", the break step is the text
    after ``resume_<old weights>_``, and the new weights are the text after
    "weights_".  Each weight string is either nine single digits or nine
    concatenated "0.xx" decimals; a run whose old-weight string holds a single
    decimal is skipped.

    Inside a kept run directory, every file whose name does not contain
    "test_" or one of the excluded artifact names is read with
    ``pd.read_pickle``; the seed is parsed from the text after "seed_" and the
    checkpoint from the text between the last "-" and the following ".".

    Args:
        dirs: iterable of directory paths to scan.
        slice_list: substrings that must all occur in a run directory's path.
        task_name: substring that must occur in a run directory's path.

    Returns:
        A DataFrame indexed by ``checkpoint`` holding the pickled columns
        (``task_idx``/``task_loss`` renamed to ``skill``/``loss``) plus
        ``method``, ``seed``, ``break_steps``, ``new_p1``..``new_p9`` and
        ``p1``..``p9``, sorted by checkpoint, old weights, new weights, seed.

    Raises:
        ValueError: if a weight string does not yield nine values, or if no
            loss file matched at all (previously an opaque
            ``KeyError: 'checkpoint'`` from the final sort).
    """
    files_to_exclude = ['trace', 'val_inputs', 'labels', 'proportions', 'emb', 'rouge', 'generations', 'gradient', 'acc']
    old_columns = [f"p{k}" for k in range(1, 10)]
    new_columns = [f"new_p{k}" for k in range(1, 10)]
    # Collected per-file frames; concatenated once at the end instead of
    # growing a DataFrame inside the loop (which re-copies all rows each time).
    frames = []

    for run_root in dirs:
        run_dirs = [os.path.join(run_root, entry) for entry in os.listdir(run_root)]
        run_dirs.sort(key=os.path.getmtime, reverse=True)

        for run_dir in run_dirs:
            if ".log" in run_dir or task_name not in run_dir:
                continue
            if any(skill not in run_dir for skill in slice_list):
                continue
            if "resume" not in run_dir:
                continue

            old_weight_str = run_dir.split("resume_")[-1].split("_")[0]
            print(run_dir)
            if len(old_weight_str) == 9:
                # Nine single-digit integer weights, e.g. "100000000".
                old_weights = [int(digit) for digit in old_weight_str]
            else:
                fractions = old_weight_str.split("0.")
                if len(fractions) == 2:
                    # Only one decimal after "resume_": not a nine-way mixture, skip the run.
                    continue
                old_weights = [float(f"0.{frac}") for frac in fractions[1:]]
            if len(old_weights) != len(old_columns):
                raise ValueError(
                    f"expected {len(old_columns)} old weights in {run_dir!r}, "
                    f"parsed {len(old_weights)} from {old_weight_str!r}"
                )

            print(*old_weights)

            method = os.path.basename(run_dir)

            new_weight_str = method.split("weights_")[-1].split("_")[0]
            if len(new_weight_str) == 9:
                new_weights = [int(digit) for digit in new_weight_str]
            else:
                new_weights = [float(f"0.{frac}") for frac in new_weight_str.split("0.")[1:]]
            if len(new_weights) != len(new_columns):
                raise ValueError(
                    f"expected {len(new_columns)} new weights in {method!r}, "
                    f"parsed {len(new_weights)} from {new_weight_str!r}"
                )

            break_steps = int(run_dir.split(f"resume_{old_weight_str}_")[-1].split("_")[0])

            for run in os.listdir(run_dir):
                if "test_" in run:
                    continue
                if any(excluded in run for excluded in files_to_exclude):
                    continue

                seed = int(run.split("seed_")[-1].split("_")[0])
                checkpoint = int(run.split("-")[-1].split(".")[0])

                # NOTE: unpickling can execute arbitrary code; only point
                # `dirs` at run outputs you trust.
                df = pd.read_pickle(os.path.join(run_dir, run))
                df = df.rename(columns={"task_idx": "skill", "task_loss": "loss"})

                df["method"] = method
                df["seed"] = seed
                df["checkpoint"] = checkpoint
                df["break_steps"] = break_steps
                for column, weight in zip(new_columns, new_weights):
                    df[column] = weight
                for column, weight in zip(old_columns, old_weights):
                    df[column] = weight

                frames.append(df.set_index("checkpoint"))

    if not frames:
        raise ValueError(
            f"no resume-run loss files found for task {task_name!r} under {list(dirs)!r}"
        )

    df_all = pd.concat(frames)
    # "checkpoint" is the index level name; sort_values accepts it alongside columns.
    return df_all.sort_values(by=["checkpoint", *old_columns, *new_columns, "seed"])


In [5]:
def calculate_r_squared(actuals, predictions):
    """Return the coefficient of determination R^2 = 1 - SS_res / SS_tot.

    Args:
        actuals: observed values, as a torch tensor or any array-like.
        predictions: predicted values with a shape compatible with ``actuals``.

    Returns:
        The R^2 score as a NumPy scalar.  When ``actuals`` is constant the
        total sum of squares is zero and the division yields ``nan`` or
        ``-inf`` (NumPy semantics), as before.
    """
    def _to_numpy(values):
        # Tensors that require grad or live on another device cannot be
        # converted with a bare .numpy(); detach and move them to CPU first.
        if isinstance(values, torch.Tensor):
            return values.detach().cpu().numpy()
        return np.asarray(values)

    actuals, predictions = _to_numpy(actuals), _to_numpy(predictions)
    # Total sum of squares: spread of the observations around their mean.
    total_sum_of_squares = np.sum((actuals - np.mean(actuals)) ** 2)
    # Residual sum of squares: what the predictions fail to explain.
    residual_sum_of_squares = np.sum((actuals - predictions) ** 2)
    return 1 - (residual_sum_of_squares / total_sum_of_squares)


In [6]:
def mixing_law(x, param):
    """Linear mixing law: ``param[0] + x[:, :k] @ param[1:]`` with ``k = len(param) - 1``.

    One parameter vector is fitted per skill.  Only as many leading columns of
    ``x`` as there are slope coefficients are used, so with nine columns and
    nine parameters the last column is ignored.  The slice width follows the
    parameter vector instead of a hard-coded 8, which keeps the law usable
    when ``init_params_law`` is called with a different ``num_domains``.

    Args:
        x: 2-D tensor, one row per sample (in this notebook the
            ``new_p1``..``new_p9`` columns).
        param: 1-D tensor ``[c_i, t_1, ..., t_k]``.

    Returns:
        1-D tensor with one prediction per row of ``x``.
    """
    c_i = param[0]
    t_i = param[1:]
    return c_i + torch.matmul(x[:, :t_i.shape[0]], t_i)

def init_params_law(idx, num_domains=9):
    """Yield candidate starting points ``[c_i, t_1, ..., t_{num_domains-1}]``.

    For each of the ten intercepts in ``np.linspace(0.5, 5, 10)``, thirty
    random slope vectors are produced (300 candidates in total).  The slope at
    position ``idx`` is drawn as ``-U[0, 1)``; every other slope is drawn as
    ``0.1 * U[0, 1)``.  Draws come from NumPy's global random state.
    """
    n_slopes = num_domains - 1
    for intercept in np.linspace(0.5, 5, 10):
        for _restart in range(30):
            slopes = []
            for position in range(n_slopes):
                draw = np.random.rand()
                if position == idx:
                    slopes.append(-draw)
                else:
                    slopes.append(draw * 0.1)
            yield [intercept, *slopes]

In [8]:
def make_individual_xy(df_break, df_resume, skill, break_steps, p1):
    """Build the regression inputs/targets for one (skill, break step, p1) cell.

    Selects the rows of ``df_resume`` matching ``break_steps``, ``p1`` and
    ``skill``, keeps those at the largest index value (the last checkpoint)
    and returns their new mixture weights and losses.

    Args:
        df_break: accepted for signature parity with
            ``make_individual_xy_log``; not read here.
        df_resume: DataFrame indexed by checkpoint with columns
            ``break_steps``, ``p1``, ``skill``, ``loss`` and
            ``new_p1``..``new_p9``.
        skill, break_steps, p1: values to filter on.

    Returns:
        ``(x, y)`` where ``x`` has shape ``(n_rows, 9)`` and ``y`` has shape
        ``(n_rows,)``.

    Raises:
        KeyError: if no row of ``df_resume`` matches the filter.
    """
    weight_columns = ['new_p1', 'new_p2', 'new_p3', 'new_p4', 'new_p5', 'new_p6', 'new_p7', 'new_p8', 'new_p9']

    matches = (df_resume.break_steps == break_steps) & (df_resume.p1 == p1) & (df_resume.skill == skill)
    resume_rows = df_resume[matches]
    if resume_rows.empty:
        raise KeyError(
            f"no resume rows for skill={skill!r}, break_steps={break_steps!r}, p1={p1!r}"
        )

    # Index with a one-element list so the result stays a DataFrame even when
    # only one row sits at the last checkpoint; a scalar label would return a
    # Series there and the `.values` lookups below would break.
    final_rows = resume_rows.loc[[resume_rows.index.max()]]

    x = final_rows[weight_columns].values
    y = final_rows['loss'].values

    return x, y

In [48]:
def make_individual_xy_log(df_break, df_resume, skill, break_steps, p1):
    """Build log-ratio regression data for one (skill, break step, p1) cell.

    The baseline ``L0`` is the ``loss`` of the ``df_break`` row(s) matching the
    filter at the largest index value (the last checkpoint); the targets are
    ``log(loss / L0)`` for the matching ``df_resume`` rows at their largest
    index value.

    Args:
        df_break: DataFrame indexed by checkpoint with columns
            ``break_steps``, ``p1``, ``skill`` and ``loss``.
        df_resume: DataFrame indexed by checkpoint with the same columns plus
            ``new_p1``..``new_p9``.
        skill, break_steps, p1: values to filter on.

    Returns:
        ``(x, y, L0)`` where ``x`` has shape ``(n_rows, 9)`` and ``y`` holds
        one log-ratio per resume row.

    Raises:
        KeyError: if no row of ``df_break`` or ``df_resume`` matches.

    NOTE(review): when exactly one break row sits at the last checkpoint
    ``L0`` is a scalar.  If several do (e.g. more than one seed), ``L0`` is a
    Series and the division pairs rows by position - filter ``df_break`` down
    to one run per cell before calling, as the notebook does with seed 0.
    """
    weight_columns = ['new_p1', 'new_p2', 'new_p3', 'new_p4', 'new_p5', 'new_p6', 'new_p7', 'new_p8', 'new_p9']

    break_matches = (df_break.break_steps == break_steps) & (df_break.p1 == p1) & (df_break.skill == skill)
    break_rows = df_break[break_matches]
    break_final = break_rows.loc[break_rows.index.max()]

    resume_matches = (df_resume.break_steps == break_steps) & (df_resume.p1 == p1) & (df_resume.skill == skill)
    resume_rows = df_resume[resume_matches]
    if resume_rows.empty:
        raise KeyError(
            f"no resume rows for skill={skill!r}, break_steps={break_steps!r}, p1={p1!r}"
        )
    # Index with a one-element list so the result stays a DataFrame even when
    # only one row sits at the last checkpoint; a scalar label would return a
    # Series there and the `.values` lookups below would break.
    resume_final = resume_rows.loc[[resume_rows.index.max()]]

    x = resume_final[weight_columns].values

    L0 = break_final['loss']

    y = np.log(resume_final['loss'].values / L0)

    return x, y, L0

# Instruction tuning

In [None]:
# REPLACE WITH YOUR RUN OUTPUT DIRECTORIES
dirs = [
    "../output/11062024/",
    "../output/10292024/",
    "../output/10302024/",
    "../output/10312024/",
]
task_name = "instruction"
# The empty string is a substring of every path, so no run is dropped by the slice filter.
slice_list = [""]
# Load every break run, then keep only the rows recorded for seed 0.
df_break = load_break_losses(dirs, slice_list, task_name)
df_break = df_break.loc[df_break["seed"] == 0]

In [11]:
# Distinct skill identifiers present in the break runs, in ascending order.
skills = sorted(df_break["skill"].unique())

In [None]:
# Load the resume runs from the same directories, with the same filters as the break runs.
df_resume = load_resume_losses(dirs=dirs, slice_list=slice_list, task_name=task_name)

In [13]:
# Grid for the fits: break steps seen in the resume runs, and the
# first-domain weights (p1) seen in the break runs.
break_steps = sorted(df_resume["break_steps"].unique())
probs = sorted(df_break["p1"].unique())

In [None]:
# Fit one linear mixing law per (skill, break step, initial p1) cell and record
# its training MSE and R^2.  NOTE: init_params_law draws its starting points
# from NumPy's global random state, so fits vary between runs unless it is seeded.
params = {skill : {bs: {p: {} for p in probs} for bs in break_steps} for skill in skills }

x_per_skill = {skill : {bs: {p: {} for p in probs} for bs in break_steps} for skill in skills }
y_per_skill = {skill : {bs: {p: {} for p in probs} for bs in break_steps} for skill in skills }

mses = []
r2s = []
for i, skill in enumerate(skills):
    for bs in break_steps:
        for p1 in probs:
            print(skill, bs, p1)
            x, y = make_individual_xy(df_break, df_resume, skill, bs, p1)

            x_per_skill[skill][bs][p1] = x
            y_per_skill[skill][bs][p1] = y

            law = ScalingLaw(mixing_law)
            # The skill's position `i` selects which slope starts out negative.
            p = law.fit(x, y, init_params_law(i, num_domains=len(skills)), max_step=100, delta=0.02)
            params[skill][bs][p1] = p # param

            prediction_train = mixing_law(torch.tensor(x, dtype=torch.float), torch.tensor(p, dtype=torch.float))
            mse_train = torch.nn.functional.mse_loss(prediction_train, torch.tensor(y, dtype=torch.float)).item()
            # prediction_train is already a tensor: detach it rather than
            # re-wrapping it in torch.tensor(), which copy-constructs and
            # emits a UserWarning on every iteration.
            r2_train = calculate_r_squared(torch.tensor(y), prediction_train.detach())

            mses.append(mse_train)
            r2s.append(r2_train)

            print(f"MSE: {mse_train}, R2: {r2_train}")


mses = np.array(mses)
r2s = np.array(r2s)

# Mean and standard deviation over all fitted cells.
print(mses.mean(), mses.std())
print(r2s.mean(), r2s.std())


In [None]:
# Re-evaluate the stored linear fits (no refitting) and report the training
# MSE / R^2 overall and per skill.
x_per_skill = {skill : {bs: {p: {} for p in probs} for bs in break_steps} for skill in skills }
y_per_skill = {skill : {bs: {p: {} for p in probs} for bs in break_steps} for skill in skills }

mses = []
r2s = []

mse_per_skill = defaultdict(list)
r2_per_skill = defaultdict(list)

for i, skill in enumerate(skills):
    for bs in break_steps:
        for p1 in probs:
            x, y = make_individual_xy(df_break, df_resume, skill, bs, p1)

            x_per_skill[skill][bs][p1] = x
            y_per_skill[skill][bs][p1] = y

            # Parameters fitted by the previous cell.
            p = params[skill][bs][p1]

            prediction_train = mixing_law(torch.tensor(x, dtype=torch.float), torch.tensor(p, dtype=torch.float))
            mse_train = torch.nn.functional.mse_loss(prediction_train, torch.tensor(y, dtype=torch.float)).item()
            # prediction_train is already a tensor: detach it rather than
            # re-wrapping it in torch.tensor(), which copy-constructs and
            # emits a UserWarning on every iteration.
            r2_train = calculate_r_squared(torch.tensor(y), prediction_train.detach())

            mses.append(mse_train)
            r2s.append(r2_train)

            mse_per_skill[skill].append(mse_train)
            r2_per_skill[skill].append(r2_train)


mses = np.array(mses)
r2s = np.array(r2s)

# Mean and standard deviation over all cells.
print(mses.mean(), mses.std())
print(r2s.mean(), r2s.std())
print("\n")
# Same statistics restricted to each skill.
for skill in skills:
    print(skill)
    mses = np.array(mse_per_skill[skill])
    r2s = np.array(r2_per_skill[skill])

    print(mses.mean(), mses.std())
    print(r2s.mean(), r2s.std())

    print("\n")


In [15]:
# Persist the fitted linear-law parameters.  Create the output directory first
# so a fresh checkout does not lose the fit results to a FileNotFoundError.
os.makedirs("./law_results/instruction", exist_ok=True)
with open("./law_results/instruction/params_dynamic.pkl", "wb") as f:
    pickle.dump(params, f)



### Log fit

In [44]:
# Keep only the seed-0 break rows (a no-op if the frame was already filtered).
df_break = df_break.loc[df_break["seed"] == 0]

In [None]:
# Fit the same linear law to log(loss / L0), where L0 is the loss at the break
# checkpoint, for every (skill, break step, initial p1) cell.  NOTE:
# init_params_law draws its starting points from NumPy's global random state,
# so fits vary between runs unless it is seeded.
params_log = {skill : {bs: {p: {} for p in probs} for bs in break_steps} for skill in skills }

x_per_skill = {skill : {bs: {p: {} for p in probs} for bs in break_steps} for skill in skills }
y_per_skill = {skill : {bs: {p: {} for p in probs} for bs in break_steps} for skill in skills }

mses = []
r2s = []
for i, skill in enumerate(skills):
    for bs in break_steps:
        for p1 in probs:
            print(skill, bs, p1)
            x, y, _ = make_individual_xy_log(df_break, df_resume, skill, bs, p1)

            x_per_skill[skill][bs][p1] = x
            y_per_skill[skill][bs][p1] = y

            law = ScalingLaw(mixing_law)
            # The skill's position `i` selects which slope starts out negative.
            p = law.fit(x, y, init_params_law(i, num_domains=len(skills)), max_step=100, delta=0.02)
            params_log[skill][bs][p1] = p # param

            prediction_train = mixing_law(torch.tensor(x, dtype=torch.float), torch.tensor(p, dtype=torch.float))
            mse_train = torch.nn.functional.mse_loss(prediction_train, torch.tensor(y, dtype=torch.float)).item()
            # prediction_train is already a tensor: detach it rather than
            # re-wrapping it in torch.tensor(), which copy-constructs and
            # emits a UserWarning on every iteration.
            r2_train = calculate_r_squared(torch.tensor(y), prediction_train.detach())

            mses.append(mse_train)
            r2s.append(r2_train)

            print(f"MSE: {mse_train}, R2: {r2_train}")


mses = np.array(mses)
r2s = np.array(r2s)

# Mean and standard deviation over all fitted cells.
print(mses.mean(), mses.std())
print(r2s.mean(), r2s.std())


In [46]:
# Persist the fitted log-law parameters.  Create the output directory first so
# a fresh checkout does not lose the fit results to a FileNotFoundError.
os.makedirs("./law_results/instruction", exist_ok=True)
with open("./law_results/instruction/params_log_dynamic.pkl", "wb") as f:
    pickle.dump(params_log, f)



In [None]:

# Re-evaluate the stored log-law fits (no refitting) and report the training
# MSE / R^2 overall and per skill.
x_per_skill = {skill : {bs: {p: {} for p in probs} for bs in break_steps} for skill in skills }
y_per_skill = {skill : {bs: {p: {} for p in probs} for bs in break_steps} for skill in skills }

mses = []
r2s = []

mse_per_skill = defaultdict(dict)
r2_per_skill = defaultdict(dict)

for i, skill in enumerate(skills):
    for bs in break_steps:
        for p1 in probs:
            x, y, _ = make_individual_xy_log(df_break, df_resume, skill, bs, p1)

            x_per_skill[skill][bs][p1] = x
            y_per_skill[skill][bs][p1] = y

            # Parameters fitted by the log-fit cell.
            p = params_log[skill][bs][p1]

            prediction_train = mixing_law(torch.tensor(x, dtype=torch.float), torch.tensor(p, dtype=torch.float))
            mse_train = torch.nn.functional.mse_loss(prediction_train, torch.tensor(y, dtype=torch.float)).item()
            # prediction_train is already a tensor: detach it rather than
            # re-wrapping it in torch.tensor(), which copy-constructs and
            # emits a UserWarning on every iteration.
            r2_train = calculate_r_squared(torch.tensor(y), prediction_train.detach())

            mses.append(mse_train)
            r2s.append(r2_train)

            # Key by (break step, p1): keying by p1 alone let every break step
            # overwrite the previous one, so the per-skill summary below only
            # covered the last break step.
            mse_per_skill[skill][(bs, p1)] = mse_train
            r2_per_skill[skill][(bs, p1)] = r2_train


mses = np.array(mses)
r2s = np.array(r2s)

# Mean and standard deviation over all cells.
print(mses.mean(), mses.std())
print(r2s.mean(), r2s.std())

print("\n")
# Same statistics restricted to each skill (all break steps and p1 values).
for skill in skills:
    print(skill)
    mses = np.array(list(mse_per_skill[skill].values()))
    r2s = np.array(list(r2_per_skill[skill].values()))

    print(mses.mean(), mses.std())
    print(r2s.mean(), r2s.std())

    print("\n")

