In [1]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules (e.g. the local `law` module imported below) are picked up live,
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import argparse 
import pandas as pd 
import os 
import numpy as np
import torch 
from collections import defaultdict


from law import ScalingLaw

import matplotlib.patches as mpatches
import matplotlib.pyplot as plt 
from sklearn.metrics import r2_score

import pickle 
from scipy.stats import spearmanr

from sklearn.metrics.pairwise import cosine_similarity



# $sim(\tilde{A}^t, A^{t \star})$

In [3]:
def parse_weights(weight_str):
    """Split a string holding two concatenated weights into its two parts.

    The encoding is chosen from the length of ``weight_str``:

    * 2 characters: one digit per weight, e.g. ``"37"`` -> ``(3, 7)``.
    * 3 characters: one weight is the single digit ``0`` and the other a
      two-character integer, e.g. ``"010"`` -> ``(0, 10)`` and
      ``"100"`` -> ``(10, 0)``. A leading ``"0"`` takes precedence over a
      trailing one.
    * any other length: two decimal numbers written back to back. The split
      point is one character before the second ``"."``, i.e. the second
      number must have exactly one digit before its decimal point,
      e.g. ``"0.50.5"`` -> ``(0.5, 0.5)``.

    Args:
        weight_str: The encoded weight pair.

    Returns:
        A tuple ``(a, b)``: ints for the 2- and 3-character forms, floats
        for the decimal form.

    Raises:
        ValueError: If the string cannot be split (a 3-character string with
            neither a leading nor a trailing ``"0"``, or a longer/shorter
            string with fewer than two ``"."``), or if a piece is not numeric.
    """
    n_chars = len(weight_str)

    if n_chars == 2:
        return int(weight_str[0]), int(weight_str[1])

    if n_chars == 3:
        if weight_str[0] == "0":
            return 0, int(weight_str[1:])
        if weight_str[-1] == "0":
            return int(weight_str[:2]), 0
        # Previously this fell through with `a`/`b` unbound (UnboundLocalError).
        raise ValueError(
            f"Ambiguous 3-character weight string {weight_str!r}: "
            "expected a leading or trailing '0'"
        )

    dot_positions = [i for i, ch in enumerate(weight_str) if ch == "."]
    if len(dot_positions) < 2:
        # Previously surfaced as a bare IndexError on the list lookup.
        raise ValueError(
            f"Cannot parse weight string {weight_str!r}: "
            "expected two decimal numbers"
        )

    # The second number starts one character before its decimal point.
    split_at = dot_positions[1] - 1
    return float(weight_str[:split_at]), float(weight_str[split_at:])


In [4]:
def load_params(slice_list, file_name):
    """Unpickle a saved parameter object for one combination of slices.

    The file is read from ``./law_results/<slices joined by "_">/<file_name>``,
    relative to the current working directory.

    Args:
        slice_list: Slice names; joined with ``"_"`` to form the directory name.
        file_name: Name of the pickle file inside that directory.

    Returns:
        Whatever object the pickle file contains.

    NOTE(review): ``pickle.load`` can execute arbitrary code embedded in the
    file, so this should only be pointed at trusted result files.
    """
    params_file = "./law_results/{}/{}".format("_".join(slice_list), file_name)
    with open(params_file, "rb") as handle:
        return pickle.load(handle)

In [6]:
def get_A_matrix_dynamic_3(params, slice_list, step, n_seeds):
    """Assemble one 3x3 matrix per seed from per-slice parameter entries.

    For seed ``i``, row ``k`` of the matrix is
    ``params[slice_list[k]][step][i]`` for ``k`` in 0..2; only the first
    three entries of ``slice_list`` are used.

    Args:
        params: Nested mapping ``slice name -> step -> per-seed sequence``.
        slice_list: Slice names; at least three are required.
        step: Key selecting the step inside each slice's mapping.
        n_seeds: Number of seeds (leading dimension of the result).

    Returns:
        A float array of shape ``(n_seeds, 3, 3)``.
    """
    A = np.zeros((n_seeds, 3, 3))

    for seed_idx in range(n_seeds):
        rows = [params[slice_list[k]][step][seed_idx] for k in range(3)]
        A[seed_idx] = np.array(rows)

    return A

In [7]:
def load_break_losses(dirs, slice_list, task_name, filter_word=None):
    """Collect pickled loss tables from "break" run directories into one frame.

    Every entry of every directory in ``dirs`` is considered, most recently
    modified first. An entry is used only if its full path

    * does not contain ``".log"``,
    * contains ``task_name``, every name in ``slice_list`` and ``"break"``,
    * contains ``filter_word`` (when one is given), and
    * is a directory.

    Inside a selected directory, files whose name contains ``"test_"`` or any
    of the artifact keywords in ``files_to_exclude`` are skipped. Each
    remaining file is read with ``pd.read_pickle``; its ``task_idx`` /
    ``task_loss`` columns are renamed to ``skill`` / ``loss`` and the columns
    ``method`` (directory name), ``seed``, ``checkpoint`` and ``break_steps``
    (all parsed from the path / file name) are attached.

    Args:
        dirs: Output directories to scan.
        slice_list: Slice names that must all appear in a run path.
        task_name: Task name that must appear in a run path.
        filter_word: Optional extra substring a run path must contain.

    Returns:
        A DataFrame indexed by ``checkpoint`` and sorted by checkpoint, then
        seed. An empty DataFrame is returned when nothing matches.

    Raises:
        ValueError: If the break step, seed or checkpoint cannot be parsed
            as an integer from a selected path / file name.

    NOTE(review): ``pd.read_pickle`` unpickles arbitrary objects; only use it
    on trusted output directories.
    """
    files_to_exclude = ['trace', 'val_inputs', 'labels', 'proportions', 'emb', 'rouge', 'generations', 'gradient', 'acc', 'matrices']
    # Frames are gathered here and concatenated once at the end; calling
    # pd.concat inside the loop re-copies all earlier rows on every iteration.
    frames = []

    for out_dir in dirs:
        entries = [os.path.join(out_dir, name) for name in os.listdir(out_dir)]
        entries.sort(key=os.path.getmtime, reverse=True)

        for run_dir in entries:
            if ".log" in run_dir or task_name not in run_dir:
                continue
            if any(skill not in run_dir for skill in slice_list):
                continue
            if "break" not in run_dir:
                continue
            if filter_word is not None and filter_word not in run_dir:
                continue
            # Stray files that pass the name filters cannot be listed below.
            if not os.path.isdir(run_dir):
                continue

            # e.g. ".../<...>_break_1000_<...>" -> 1000
            break_steps = int(run_dir.split("break_")[-1].split("_")[0])
            method = run_dir.split("/")[-1]

            for run in os.listdir(run_dir):
                if "test_" in run:
                    continue
                if any(keyword in run for keyword in files_to_exclude):
                    continue

                # e.g. "<...>seed_0_<...>-500.pkl" -> seed 0, checkpoint 500
                seed = int(run.split("seed_")[-1].split("_")[0])
                checkpoint = int(run.split("-")[-1].split(".")[0])

                df = pd.read_pickle(os.path.join(run_dir, run))
                df = df.rename(columns={"task_idx": "skill", "task_loss": "loss"})

                df["method"] = method
                df["seed"] = seed
                df["checkpoint"] = checkpoint
                df["break_steps"] = break_steps

                frames.append(df.set_index("checkpoint"))

    if not frames:
        # Sorting an empty frame by "checkpoint"/"seed" would raise KeyError.
        return pd.DataFrame()

    df_all = pd.concat(frames)
    return df_all.sort_values(by=["checkpoint", "seed"])


In [11]:
def load_doge_matrices(dirs, slice_list, break_steps):
    """Load saved DoGE matrices, averaged over their first axis, keyed by seed.

    A run directory is used when its full path contains ``"break"`` and
    ``"stratified_<slices joined by '_'>_doge_trainer"`` and the integer
    following ``"_break_"`` equals ``break_steps``. Inside it, only files
    whose name contains both ``"avg"`` and ``"doge_matrices.npy"`` are read.
    Each loaded path is printed.

    Args:
        dirs: Output directories to scan.
        slice_list: Slice names used to build the directory-name pattern.
        break_steps: Break step the run directory must correspond to.

    Returns:
        Dict mapping seed (parsed from the file name after ``"seed_"``) to the
        loaded array averaged over axis 0. If a seed is found more than once,
        the file visited last wins.
    """
    slice_str = "_".join(slice_list)
    run_tag = f"stratified_{slice_str}_doge_trainer"
    matrices = {}

    for out_dir in dirs:
        for entry in os.listdir(out_dir):
            run_dir = os.path.join(out_dir, entry)

            if "break" not in run_dir or run_tag not in run_dir:
                continue
            if int(run_dir.split("_break_")[-1].split("_")[0]) != break_steps:
                continue

            for run in os.listdir(run_dir):
                if "avg" not in run or "doge_matrices.npy" not in run:
                    continue

                path = os.path.join(run_dir, run)
                print(path)
                seed = int(run.split("seed_")[-1].split("_")[0])
                matrices[seed] = np.load(path).mean(axis=0)

    return matrices

In [12]:
def load_doremi_matrices(dirs, slice_list, break_steps):
    """Load saved DoReMi matrices, averaged over their first axis, keyed by seed.

    A run directory is used when its full path contains ``"break"`` and
    ``"stratified_<slices joined by '_'>_doremi"`` and the integer following
    ``"_break_"`` equals ``break_steps``. Inside it, only files whose name
    contains both ``"avg"`` and ``"drm_matrices.npy"`` are read. Each loaded
    path is printed.

    Args:
        dirs: Output directories to scan.
        slice_list: Slice names used to build the directory-name pattern.
        break_steps: Break step the run directory must correspond to.

    Returns:
        Dict mapping seed (parsed from the file name after ``"seed_"``) to the
        loaded array averaged over axis 0. If a seed is found more than once,
        the file visited last wins.
    """
    slice_str = "_".join(slice_list)
    run_tag = f"stratified_{slice_str}_doremi"
    matrices = {}

    for out_dir in dirs:
        for entry in os.listdir(out_dir):
            run_dir = os.path.join(out_dir, entry)

            if "break" not in run_dir or run_tag not in run_dir:
                continue
            if int(run_dir.split("_break_")[-1].split("_")[0]) != break_steps:
                continue

            for run in os.listdir(run_dir):
                if "avg" not in run or "drm_matrices.npy" not in run:
                    continue

                path = os.path.join(run_dir, run)
                print(path)
                seed = int(run.split("seed_")[-1].split("_")[0])
                matrices[seed] = np.load(path).mean(axis=0)

    return matrices

In [96]:
def load_aioli_matrices(dirs, slice_list, break_steps):
    """Load saved Aioli matrices, averaging the last four entries, keyed by seed.

    A run directory is used when its full path contains ``"break"``,
    ``"aioli"`` and the slice names joined by ``"_"``, and the integer
    following ``"_break_"`` equals ``break_steps``. Inside it, files whose
    name contains ``"aioli_matrices.npy"`` are read. Each loaded path is
    printed.

    Args:
        dirs: Output directories to scan.
        slice_list: Slice names; their ``"_"``-joined form must appear in the
            run directory path.
        break_steps: Break step the run directory must correspond to.

    Returns:
        Dict mapping seed (parsed from the file name after ``"seed_"``) to the
        mean over axis 0 of the last four entries of the loaded array. If a
        seed is found more than once, the file visited last wins.
    """
    slice_str = "_".join(slice_list)
    matrices = {}

    for out_dir in dirs:
        for entry in os.listdir(out_dir):
            run_dir = os.path.join(out_dir, entry)

            is_aioli_break = "break" in run_dir and "aioli" in run_dir
            if not (is_aioli_break and slice_str in run_dir):
                continue
            if int(run_dir.split("_break_")[-1].split("_")[0]) != break_steps:
                continue

            for run in os.listdir(run_dir):
                if "aioli_matrices.npy" not in run:
                    continue

                path = os.path.join(run_dir, run)
                print(path)
                seed = int(run.split("seed_")[-1].split("_")[0])
                stacked = np.load(path)
                matrices[seed] = stacked[-4:].mean(axis=0)

    return matrices

In [14]:
def get_A_matrix_doge_3(slice_list, params_doge, steps, dirs):
    """Stack the per-seed DoGE matrices and scale each one by its ``b_t``.

    Matrices come from ``load_doge_matrices(dirs, slice_list, steps)`` and are
    stacked in ascending seed order. The scale for each matrix is the first
    element of the corresponding entry of ``params_doge[steps]``.

    NOTE(review): the matrices are ordered by sorted seed, while the scales
    follow the iteration order of ``params_doge[steps]``; the two line up only
    if that mapping iterates in ascending seed order - confirm.

    Returns:
        Array with one scaled matrix per seed along the first axis.
    """
    by_seed = load_doge_matrices(dirs, slice_list, steps)
    stacked = np.array([by_seed[seed] for seed in sorted(by_seed)])

    b_t = np.array([entry[0] for entry in params_doge[steps].values()])

    # Broadcast one scalar per seed over that seed's whole matrix.
    return stacked * b_t[:, None, None]

In [15]:
def get_A_matrix_drm_3(slice_list, params_doremi, steps, dirs):
    """Stack the per-seed DoReMi matrices and scale each one by its ``b_t``.

    Matrices come from ``load_doremi_matrices(dirs, slice_list, steps)`` and
    are stacked in ascending seed order. The scale for each matrix is the
    first element of the corresponding entry of ``params_doremi[steps]``.

    NOTE(review): the matrices are ordered by sorted seed, while the scales
    follow the iteration order of ``params_doremi[steps]``; the two line up
    only if that mapping iterates in ascending seed order - confirm.

    Returns:
        Array with one scaled matrix per seed along the first axis.
    """
    by_seed = load_doremi_matrices(dirs, slice_list, steps)
    stacked = np.array([by_seed[seed] for seed in sorted(by_seed)])

    b_t = np.array([entry[0] for entry in params_doremi[steps].values()])

    # Broadcast one scalar per seed over that seed's whole matrix.
    return stacked * b_t[:, None, None]

In [16]:
def get_A_matrix_aioli_3(slice_list, params_doremi, steps, dirs):
    """Stack the per-seed Aioli matrices and scale each one by its ``b_t``.

    Matrices come from ``load_aioli_matrices(dirs, slice_list, steps)`` and
    are stacked in ascending seed order. The scale for each matrix is the
    first element of the corresponding entry of ``params_doremi[steps]``.
    The parameter keeps the name ``params_doremi`` so that existing keyword
    callers continue to work; nothing in this function is DoReMi-specific.

    NOTE(review): the matrices are ordered by sorted seed, while the scales
    follow the iteration order of ``params_doremi[steps]``; the two line up
    only if that mapping iterates in ascending seed order - confirm.

    Returns:
        Array with one scaled matrix per seed along the first axis.
    """
    by_seed = load_aioli_matrices(dirs, slice_list, steps)
    A_aioli = np.array([by_seed[seed] for seed in sorted(by_seed)])

    b_t = np.array([entry[0] for entry in params_doremi[steps].values()])

    # Broadcast one scalar per seed over that seed's whole matrix.
    return A_aioli * b_t[:, None, None]

In [17]:
def load_skillit_matrices(slice_list, seed, task_name):
    """Load the array saved for one seed under ``../skillit_graphs/``.

    The file name is
    ``<task_name>_<slices joined by "_">_normalized_seed_<seed>.npy``,
    resolved relative to the current working directory.

    Returns:
        The array stored in that ``.npy`` file.
    """
    slice_str = "_".join(slice_list)
    graph_path = f"../skillit_graphs/{task_name}_{slice_str}_normalized_seed_{seed}.npy"
    return np.load(graph_path)

In [18]:
def get_A_matrix_skillit_3(slice_list, df_break, task_name, params_skillit, steps, n_seeds):
    """Build one scaled Skill-It matrix per seed.

    For each seed in ``range(n_seeds)`` the transposed matrix from
    ``load_skillit_matrices`` is multiplied, in this order, by

    * the ``loss`` values of that seed's rows in ``df_break`` with
      ``break_steps == steps`` at the largest index value, as a column
      vector (so row ``k`` of the matrix is scaled by loss ``k``), and
    * ``params_skillit[steps][seed][0]``.

    Args:
        slice_list: Slice names, forwarded to ``load_skillit_matrices``.
        df_break: Frame with ``seed``, ``break_steps`` and ``loss`` columns;
            its index is used to pick the last checkpoint.
        task_name: Task name, forwarded to ``load_skillit_matrices``.
        params_skillit: Nested container indexed as ``[steps][seed][0]``.
        steps: Break step to select.
        n_seeds: Number of seeds to process.

    Returns:
        Array with one scaled matrix per seed along the first axis.
    """
    per_seed = []

    for seed in range(n_seeds):
        graph_t = load_skillit_matrices(slice_list, seed, task_name).T

        selected = (df_break.seed == seed) & (df_break.break_steps == steps)
        seed_rows = df_break[selected]
        last_checkpoint = seed_rows.index.max()

        # The intermediate (-1, 3) reshape is kept from the original: it fails
        # unless the number of loss values is a multiple of three.
        loss_column = (
            seed_rows.loc[last_checkpoint]['loss'].values.reshape(-1, 3).reshape(-1, 1)
        )
        b_t = params_skillit[steps][seed][0]

        per_seed.append(graph_t * loss_column * b_t)

    return np.array(per_seed)


In [19]:
from scipy.stats import spearmanr

def compute_metric(A_star, A_hat):
    """Print and return agreement metrics between paired matrices.

    ``A_star`` and ``A_hat`` are iterated together. For every pair, each
    matrix is summed over axis 0 and the resulting vector is scaled to unit
    norm. The two vectors are then compared with

    * cosine similarity (reported under the label ``L2`` in the printed
      output, which is kept unchanged), and
    * Spearman rank correlation.

    The combined metric of a pair is the plain average of the two.

    The per-pair Spearman and cosine arrays and the three means are printed.

    Args:
        A_star: Sequence/array of reference matrices, one per seed.
        A_hat: Sequence/array of estimated matrices, paired with ``A_star``.

    Returns:
        Dict with the mean values as floats under the keys ``"spearman"``,
        ``"l2"`` (mean cosine similarity) and ``"metric"``. Previously the
        function returned ``None``, so the numbers could only be read off
        the printed output.
    """
    spearman = []
    l2 = []
    metric = []

    for a, b in zip(A_star, A_hat):
        a_sum = a.sum(axis=0)
        b_sum = b.sum(axis=0)

        a_sum = a_sum / np.linalg.norm(a_sum)
        b_sum = b_sum / np.linalg.norm(b_sum)

        # cosine_similarity returns a (1, 1) matrix for two row vectors; take
        # the scalar so the collected arrays are 1-D rather than (n, 1).
        cosim = cosine_similarity(a_sum.reshape(1, -1), b_sum.reshape(1, -1))[0, 0]
        rank_corr, _ = spearmanr(a_sum, b_sum)

        l2.append(cosim)
        spearman.append(rank_corr)
        metric.append(cosim * 0.5 + rank_corr * 0.5)

    metric = np.array(metric)
    spearman = np.array(spearman)
    l2 = np.array(l2)

    print(spearman)
    print(l2)
    print(f"Spearman: {spearman.mean()}")
    print(f"L2: {l2.mean()}")
    print(f"Metric: {metric.mean()}")

    return {
        "spearman": float(spearman.mean()),
        "l2": float(l2.mean()),
        "metric": float(metric.mean()),
    }

## Arxiv / Book / Stackexchange

In [None]:
# The fitted-parameter pickles live under ./law_results/arxiv_books_stackexchange/
# (plural "books"), whereas the parameter keys and the run-directory names
# are matched with the singular "book" - hence the two different lists.
slice_list = ['arxiv', 'book', 'stackexchange']
dirs = ["../output/09242024/", "../output/09252024/"]

params_doge_opt = load_params(['arxiv', 'books', 'stackexchange'], "params_doge_trajectory_opt_1000.pkl")
A_doge_opt = get_A_matrix_dynamic_3(params_doge_opt, slice_list, 1000, 5)

params_doge = load_params(['arxiv', 'books', 'stackexchange'], "params_doge_trajectory_doge_matrix_1000.pkl")
A_doge = get_A_matrix_doge_3(slice_list, params_doge, 1000, dirs)


In [None]:
# Print the per-seed and mean Spearman / cosine agreement between the matrices
# built from the "opt" parameter file (A_doge_opt) and the DoGE ones (A_doge).
compute_metric(A_doge_opt, A_doge)

In [None]:
# The fitted-parameter pickles live under ./law_results/arxiv_books_stackexchange/
# (plural "books"), whereas the parameter keys and the run-directory names
# are matched with the singular "book" - hence the two different lists.
slice_list = ['arxiv', 'book', 'stackexchange']
dirs = ["../output/09262024/", "../output/09252024/"]

params_doremi_opt = load_params(['arxiv', 'books', 'stackexchange'], "params_doremi_trajectory_opt_1000_steps_500.pkl")
A_doremi_opt = get_A_matrix_dynamic_3(params_doremi_opt, slice_list, 1000, 5)

params_doremi = load_params(['arxiv', 'books', 'stackexchange'], "params_doremi_trajectory_doremi_matrix_1000_steps_500.pkl")
A_doremi = get_A_matrix_drm_3(slice_list, params_doremi, 1000, dirs)


In [None]:
# Print the per-seed and mean Spearman / cosine agreement between the matrices
# built from the "opt" parameter file (A_doremi_opt) and the DoReMi ones (A_doremi).
compute_metric(A_doremi_opt, A_doremi)

In [28]:
# The fitted-parameter pickles live under ./law_results/arxiv_books_stackexchange/
# (plural "books"), whereas the parameter keys, run-directory names and
# Skill-It graph files are matched with the singular "book" - hence the two
# different lists.
slice_list = ['arxiv', 'book', 'stackexchange']
dirs = ["../output/09262024/", "../output/09272024/"]

params_skillit_opt = load_params(['arxiv', 'books', 'stackexchange'], "params_skillit_trajectory_opt_1000.pkl")
A_skillit_opt = get_A_matrix_dynamic_3(params_skillit_opt, slice_list, 1000, 5)

# Losses of the "greedy" break runs; used to scale the Skill-It matrices below.
df_break = load_break_losses(dirs, slice_list, 'slimpj', filter_word="greedy")

params_skillit = load_params(['arxiv', 'books', 'stackexchange'], "params_skillit_trajectory_skillit_matrix_1000.pkl")
A_skillit = get_A_matrix_skillit_3(slice_list, df_break, 'slimpj', params_skillit, 1000, 5)


In [None]:
# Print the per-seed and mean Spearman / cosine agreement between the matrices
# built from the "opt" parameter file (A_skillit_opt) and the Skill-It ones (A_skillit).
compute_metric(A_skillit_opt, A_skillit)

In [None]:
# The fitted-parameter pickles live under ./law_results/arxiv_books_stackexchange/
# (plural "books"), whereas the parameter keys and the run-directory names
# are matched with the singular "book" - hence the two different lists.
slice_list = ['arxiv', 'book', 'stackexchange']
dirs = ["../output/09272024/", "../output/09282024/", "../output/09292024/"]

params_aioli_opt = load_params(['arxiv', 'books', 'stackexchange'], "params_aioli_trajectory_opt_1500.pkl")
A_aioli_opt = get_A_matrix_dynamic_3(params_aioli_opt, slice_list, 1500, 5)

params_aioli = load_params(['arxiv', 'books', 'stackexchange'], "params_aioli_trajectory_aioli_matrix_1500.pkl")
A_aioli = get_A_matrix_aioli_3(slice_list, params_aioli, 1500, dirs)


In [None]:
# Print the per-seed and mean Spearman / cosine agreement between the matrices
# built from the "opt" parameter file (A_aioli_opt) and the Aioli ones (A_aioli).
compute_metric(A_aioli_opt, A_aioli)