In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import argparse 
import pandas as pd 
import os 
import numpy as np
import torch 
from collections import defaultdict
import pickle

from law import ScalingLaw

In [3]:
def _parse_mixture_weights(weight_str):
    """Decode the token that follows ``weights_`` in a run name into ``(p1, p2, p3)``.

    Raises:
        ValueError: If the token does not match any known encoding.
    """
    if len(weight_str) == 3:
        # One integer digit per domain, e.g. "100" -> (1, 0, 0).
        return int(weight_str[0]), int(weight_str[1]), int(weight_str[2])
    if len(weight_str) == 4:
        # Special-cased one-hot names; the selected domain is assigned 10.
        # NOTE(review): these use 10 rather than 1 -- confirm that is intended.
        one_hot = {"0010": (0, 0, 10), "0100": (0, 10, 0), "1000": (10, 0, 0)}
        if weight_str not in one_hot:
            raise ValueError(f"Unrecognized 4-character weight string: {weight_str!r}")
        return one_hot[weight_str]
    # Concatenated decimals that each start with "0.", e.g. "0.10.20.7" -> (0.1, 0.2, 0.7).
    parsed = [float(f"0.{weight}") for weight in weight_str.split("0.")[1:]]
    if len(parsed) != 3:
        raise ValueError(f"Expected 3 mixture weights in {weight_str!r}, got {len(parsed)}")
    return tuple(parsed)


def load_losses(dirs, slice_list, task_name, filter_word=None, exclude_word=None):
    """Collect per-skill evaluation losses from a set of run-output directories.

    Every entry of each directory in ``dirs`` is treated as one training run: a
    sub-directory whose name encodes the mixture weights after ``weights_``. Each
    pickled DataFrame inside a run directory is loaded; its ``task_idx`` /
    ``task_loss`` columns are renamed to ``skill`` / ``loss``, and it is tagged
    with the run name, seed, checkpoint, and mixture proportions.

    Args:
        dirs: Directories that contain run sub-directories.
        slice_list: Substrings that must ALL appear in a run path.
        task_name: Substring that must appear in a run path.
        filter_word: If given, only runs whose name contains it are kept.
        exclude_word: If given, runs whose name contains it are skipped.

    Returns:
        A DataFrame indexed by ``checkpoint`` with columns including ``skill``,
        ``loss``, ``method``, ``seed``, ``p1``, ``p2``, ``p3``, sorted by
        checkpoint, p1, p2 and seed. Empty if no matching run was found.

    Raises:
        ValueError: If a run name carries a weight token that cannot be parsed.
    """
    # Per-run artifacts that are not loss pickles.
    files_to_exclude = ['trace', 'val_inputs', 'labels', 'proportions', 'emb', 'rouge', 'generations', 'gradient', 'acc']
    # Runs from these mixing methods are not part of the static sweep.
    excluded_methods = ("stratified", "doge", "skillit", "aioli")
    frames = []

    for run_root in dirs:
        run_dirs = [os.path.join(run_root, name) for name in os.listdir(run_root)]
        # Most recently modified runs first.
        run_dirs.sort(key=os.path.getmtime, reverse=True)

        for run_dir in run_dirs:
            if ".log" in run_dir or task_name not in run_dir:
                continue
            if any(skill not in run_dir for skill in slice_list):
                continue

            # basename is portable, unlike splitting on "/".
            method = os.path.basename(run_dir)

            if filter_word is not None and filter_word not in method:
                continue
            if exclude_word is not None and exclude_word in method:
                continue
            if any(excluded in run_dir for excluded in excluded_methods):
                continue

            weight_str = method.split("weights_")[-1].split("_")[0]
            a, b, c = _parse_mixture_weights(weight_str)

            for run in os.listdir(run_dir):
                if "test_" in run:
                    continue
                if any(exclude_file in run for exclude_file in files_to_exclude):
                    continue

                seed = int(run.split("seed_")[-1].split("_")[0])
                checkpoint = int(run.split("-")[-1].split(".")[0])

                # NOTE: pickle can execute code -- only point this at trusted run outputs.
                run_df = pd.read_pickle(os.path.join(run_dir, run))
                run_df = run_df.rename(columns={"task_idx": "skill", "task_loss": "loss"})
                run_df = run_df.assign(method=method, seed=seed, checkpoint=checkpoint, p1=a, p2=b, p3=c)
                frames.append(run_df.set_index("checkpoint"))

    if not frames:
        # Nothing matched: sorting an empty frame by missing columns would raise KeyError.
        return pd.DataFrame()

    # Concatenate once instead of growing a frame inside the loop (quadratic).
    df_all = pd.concat(frames)
    return df_all.sort_values(by=["checkpoint", "p1", "p2", "seed"])


### The math of this mixing law

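The generic form reported in the paper is $L_i(p) = c_i + k_i \exp\left(\sum_{j=1}^m A_{ij} p_j\right)$, where $L_i$ is the loss on domain $i$ and $p = (p_1, \dots, p_m)$ is the mixture, with $\sum_j p_j = 1$. For $m=3$, substituting $p_3 = 1 - p_1 - p_2$ leaves only two free coefficients per row of $A$, so we can rewrite the law as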
\begin{align}
L_i(p) &= c_i + \exp(\log k_i + A_{i1} p_1 + A_{i2} p_2 + A_{i3} (1 - p_1 - p_2))  \\
&= c_i + \exp(\log k_i + A_{i3} + (A_{i1} - A_{i3}) p_1 + (A_{i2} - A_{i3}) p_2) \\
&= c_i + \exp(b_i + t_{i1}p_1 + t_{i2}p_2)
\end{align}

where $b_i = \log k_i + A_{i3}$, $t_{i1} = A_{i1} - A_{i3}, t_{i2} = A_{i2} - A_{i3}$.

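That is, each per-domain regression has 4 parameters: $c_i$, $b_i$, $t_{i1}$, and $t_{i2}$. In the code below, $c_i$ is parameterized as $\exp(\log c_i)$ so that it stays positive.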

In [4]:
def mixing_law(x, param):
    """Evaluate the exponential data-mixing law for one target domain.

    Computes ``exp(log_c) + exp(b + x[:, :k] @ t)`` where ``param`` is laid out
    as ``[log_c, b, t_1, ..., t_k]``. Only the first ``k = len(param) - 2``
    mixture proportions are used; the last proportion is implied by the
    proportions summing to 1 (k = 2 for three domains).

    Args:
        x: 2-D tensor of mixture proportions, one row per run.
        param: 1-D tensor ``[log_c_i, b_i, t_i1, ..., t_ik]``.

    Returns:
        Tensor of predicted losses, one per row of ``x``.
    """
    log_c_i, b_i = param[0], param[1]
    t_i = param[2:]
    # Use as many proportion columns as there are slope parameters, which matches
    # init_params_law's num_domains - 1 slopes instead of hard-coding 2.
    num_free = t_i.shape[0]
    return torch.exp(log_c_i) + torch.exp(b_i + torch.matmul(x[:, :num_free], t_i))

def init_params_law(idx, num_domains=3):
    """Yield candidate initial parameter vectors ``[log_c_i, b_i, t_1, ..., t_{m-1}]``.

    Sweeps a grid of 10 ``log_c_i`` values in [-2, 1.5] and 20 ``b_i`` values in
    [-10, 1]. For each grid point it yields 30 random slope draws: slope ``idx``
    is drawn from (-1, 0], and every other slope from [0, 0.1). Randomness comes
    from NumPy's global RNG, which is not seeded here.
    """
    n_free = num_domains - 1
    log_c_grid = np.linspace(-2, 1.5, 10)
    b_grid = np.linspace(-10, 1, 20)
    for log_c_i in log_c_grid:
        for b_i in b_grid:
            for _draw in range(30):
                slopes = []
                for j in range(n_free):
                    if j == idx:
                        slopes.append(-np.random.rand())
                    else:
                        slopes.append(np.random.rand() * 0.1)
                yield [log_c_i, b_i] + slopes

In [6]:
def calculate_r_squared(actuals, predictions):
    """Coefficient of determination R^2 of ``predictions`` against ``actuals``.

    Accepts torch tensors (including ones that require grad or are not on the
    CPU) as well as NumPy arrays and other array-likes.

    Returns:
        ``1 - SS_res / SS_tot`` as a NumPy scalar. If ``actuals`` is constant,
        ``SS_tot`` is zero and NumPy's division yields ``nan`` or ``-inf``.
    """
    def _to_numpy(values):
        # Tensor.numpy() fails on tensors that require grad or live off the CPU.
        if isinstance(values, torch.Tensor):
            return values.detach().cpu().numpy()
        return np.asarray(values)

    actuals, predictions = _to_numpy(actuals), _to_numpy(predictions)
    total_sum_of_squares = np.sum((actuals - np.mean(actuals)) ** 2)
    residual_sum_of_squares = np.sum((actuals - predictions) ** 2)
    return 1 - (residual_sum_of_squares / total_sum_of_squares)


# Arxiv, Books, Stackexchange

In [None]:
task_name = "slimpj"
slice_list = ["arxiv", "book", "stackexchange"]
# REPLACE WITH YOUR RUN OUTPUT DIRECTORIES
dirs = [f"../output/{run_date}/" for run_date in ("09032024", "09042024", "09052024", "09062024")]
df = load_losses(dirs, slice_list, task_name)

In [8]:
# Filter down the results to be only from the Dirichlet sweep.
import re

grid = 9
all_weights = []
for seed in range(5):
    with open(f"../dirichlet_weights/k_3_n_{grid}_seed_{seed}.txt", "r") as f:
        all_weights.extend(f.readlines())

# Normalize each weight line to the token used in run names, e.g. (assuming
# comma-separated values without spaces) "0.10,0.20,0.70" -> "0.100.200.70"
# -> "0.100.200.7" -> "0.10.20.7".
normalized_weights = []
for weight_line in all_weights:
    token = weight_line.strip().replace(",", "").rstrip("0")
    while "00." in token:
        token = token.replace("00.", "0.")
    # A blank line would become an empty regex alternative that matches every method.
    if token:
        normalized_weights.append(token)
all_weights = normalized_weights
assert all_weights, "no sweep weights were loaded"

# Escape "." so it is matched literally rather than as a regex wildcard.
sweep_pattern = "|".join(re.escape(token) for token in all_weights)
df = df[df.method.str.contains(sweep_pattern, regex=True)]



In [10]:
# Checkpoints at which laws are fitted.
indices = [5000]  # end of training

# Sorted unique skills and training seeds present in the filtered results.
skills = sorted(df["skill"].unique())
seeds = sorted(df["seed"].unique())

In [None]:
# Fit one mixing law per (skill, checkpoint, seed) and report training-fit quality.
# init_params_law draws slope initializations from NumPy's global RNG; seed it so
# that re-running the notebook reproduces the same fits.
np.random.seed(0)

params = {skill: {idx: {seed: {} for seed in seeds} for idx in indices} for skill in skills}

mses = []
r2s = []

for i, skill in enumerate(skills):
    for index in indices:
        # Independent of the seed, so filter once per (skill, checkpoint).
        df_at_index = df.loc[df.index == index]
        df_skill = df_at_index[df_at_index.skill == skill]

        for seed in seeds:
            print(f"Skill = {skill}, index = {index}, seed = {seed}")

            df_subset = df_skill[df_skill.seed == seed]
            if df_subset.empty:
                raise ValueError(f"No losses for skill={skill}, checkpoint={index}, seed={seed}")

            x = df_subset[['p1', 'p2', 'p3']].values
            y = df_subset['loss'].values

            law = ScalingLaw(mixing_law)
            # `i` selects which slope init_params_law initializes as negative.
            p = law.fit(x, y, init_params_law(i, num_domains=len(skills)), max_step=100, delta=0.02)
            params[skill][index][seed] = p

            with torch.no_grad():
                prediction_train = mixing_law(torch.tensor(x, dtype=torch.float), torch.tensor(p, dtype=torch.float))
            # Compute all error metrics in one dtype so that RMSE == sqrt(MSE).
            residuals = prediction_train - torch.tensor(y, dtype=torch.float)
            mse_train = torch.mean(residuals ** 2).item()
            rmse_train = mse_train ** 0.5
            mae_train = torch.mean(torch.abs(residuals)).item()
            # prediction_train is already a tensor; re-wrapping it in torch.tensor() only copies it with a warning.
            r2_train = calculate_r_squared(torch.tensor(y), prediction_train)

            mses.append(mse_train)
            r2s.append(r2_train)

            print(f"RMSE: {rmse_train}, MAE: {mae_train}, MSE: {mse_train}, R2: {r2_train}")

mses = np.array(mses)
r2s = np.array(r2s)

print(mses.mean(), mses.std())
print(r2s.mean(), r2s.std())

In [13]:
# Persist the fitted per-skill law parameters for later analysis.
# NOTE: pickle can execute code on load -- only unpickle this file from trusted sources.
output_dir = "./law_results/arxiv_books_stackexchange"
# open(..., "wb") fails if the directory does not exist yet.
os.makedirs(output_dir, exist_ok=True)
with open(os.path.join(output_dir, "params_static.pkl"), "wb") as f:
    pickle.dump(params, f)

## Find the optimal mixture according to the fitted laws

In [None]:
import torch.optim as optim

# For each seed, minimize the predicted total loss over mixtures (p1, p2, 1 - p1 - p2)
# using that seed's fitted laws. The constants c_i do not depend on p, so they are dropped.
final_index = indices[-1]  # checkpoint whose fitted laws are used (5000 = end of training)
n_iterations = 10000
log_every = 1000


def predicted_loss_objective(p1, p2, b, t1, t2):
    """Sum over skills of exp(b_i + t_i1 * p1 + t_i2 * p2): the fitted losses minus c_i."""
    return torch.sum(torch.exp(b + t1 * p1 + t2 * p2))


def project_onto_simplex(p1, p2):
    """Make (p1, p2, 1 - p1 - p2) a valid mixture by modifying p1 and p2 in place.

    Clamps both to [0, 1] and, if p1 + p2 > 1, rescales them to sum to 1. This keeps
    the iterate feasible but is not the Euclidean projection onto the simplex.
    """
    p1.data.clamp_(0, 1)
    p2.data.clamp_(0, 1)
    sum_p = p1 + p2
    if sum_p > 1:
        factor = 1 / sum_p
        p1.data.mul_(factor)
        p2.data.mul_(factor)
    return p1, p2


all_ps = []

# Iterate over the seeds that were actually fitted rather than assuming 0..4.
for seed in seeds:
    # Same initial point for every seed, so seeds differ only through their fitted laws.
    torch.manual_seed(0)

    seed_params = [values[final_index][seed] for values in params.values()]
    b = torch.tensor([law_params[1] for law_params in seed_params])
    t1 = torch.tensor([law_params[2] for law_params in seed_params])
    t2 = torch.tensor([law_params[3] for law_params in seed_params])

    # Random initial proportions in [0, 1).
    p1 = torch.rand(1, requires_grad=True)
    p2 = torch.rand(1, requires_grad=True)
    optimizer = optim.Adam([p1, p2], lr=0.001)

    for i in range(n_iterations):
        optimizer.zero_grad()
        loss = predicted_loss_objective(p1, p2, b, t1, t2)
        loss.backward()
        optimizer.step()

        # Project onto the constraint set after each gradient step.
        with torch.no_grad():
            project_onto_simplex(p1, p2)

        if (i + 1) % log_every == 0:
            print(f'Iteration {i+1}/{n_iterations}, Loss: {loss.item():.4f}, p1: {p1.item():.4f}, p2: {p2.item():.4f}')

    mixture = [p1.item(), p2.item(), 1 - p1.item() - p2.item()]
    print(f'Final result: p = {mixture[0]}, {mixture[1]}, {mixture[2]}')
    all_ps.append(mixture)
    print(f'Final objective value: {predicted_loss_objective(p1, p2, b, t1, t2).item():.4f}')

In [None]:
# One row per optimized mixture (one per seed); columns are the proportions p1, p2, p3.
pd.DataFrame(all_ps, columns=["p1", "p2", "p3"])