In [1]:
import wandb
import pandas as pd

import torch
import torch.optim as optim
import numpy as np

In [2]:
entity = "miki-and-tml"
project = "scaling-tests"
api = wandb.Api()

runs = api.runs(f"{entity}/{project}")

# Target dtypes for the summary columns; applied once incomplete runs have been
# filtered out (a NaN cannot be cast to int).
dtypes = {
    'run_id': str,
    'val_loss': float,
    'val_loss_ci_lower': float,
    'val_loss_ci_upper': float,
    'tokens_seen': int,
    'compute': float,
    'params': int,
}

# Summary keys copied from each run's final logged values (None when absent).
summary_keys = ['val_loss', 'val_loss_ci_lower', 'val_loss_ci_upper',
                'tokens_seen', 'compute', 'params']

# Collect plain records and build the frame once: concatenating inside the loop
# is quadratic and warns about concatenating with an empty frame.
records = []
for run in runs:
    summary = run.summary._json_dict
    record = {'run_id': run.id, 'run_name': run.name}
    record.update({key: summary.get(key) for key in summary_keys})
    records.append(record)

# Newest-first, as before (each run used to be prepended at index 0). The explicit
# column list keeps the same column order and yields the right columns when no run exists.
runs_df = pd.DataFrame.from_records(records[::-1],
                                    columns=['run_id', 'run_name', *summary_keys])

runs_df


[34m[1mwandb[0m: Currently logged in as: [33mvanousekmikulas[0m ([33mvanousekmikulas-epfl[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin
  runs_df = pd.concat([pd.DataFrame([row]), runs_df]).reset_index(drop=True)


Unnamed: 0,run_id,run_name,val_loss,val_loss_ci_lower,val_loss_ci_upper,tokens_seen,compute,params
0,xl2usnof,model_6.8M_tokens_120.0M,0.032641,0.032532,0.03276,119996416,979201473642496,6840320
1,vwci9mym,model_6.8M_tokens_60.0M,0.044834,0.044665,0.044998,59998208,489600736821248,6840320
2,agrcb37f,model_6.8M_tokens_30.0M,0.252477,0.251721,0.253218,29999104,244800368410624,6840320
3,2zf1ghwn,model_6.8M_tokens_240.0M,0.028772,0.02866,0.028878,239992832,1958402947284992,6840320
4,h09i8eaz,model_6.8M_tokens_480.0M,0.025872,0.025769,0.02598,479993856,3916872743387136,6840320
5,buol3dn1,model_13.8M_tokens_960.0M,0.021329,0.021235,0.021425,959995904,14015632999710720,13797376
6,bh0nkk09,model_13.8M_tokens_480.0M,0.023663,0.023562,0.023761,479993856,7007756699566080,13797376
7,i8ikrs6s,model_13.8M_tokens_240.0M,0.026442,0.026337,0.026546,239992832,3503818549493760,13797376
8,g7yc8q47,model_13.8M_tokens_120.0M,0.029614,0.029507,0.02973,119996416,1751909274746880,13797376
9,27l1qho0,model_13.8M_tokens_60.0M,0.034794,0.034669,0.034935,59998208,875954637373440,13797376


In [3]:
print("There are ", len(runs_df), " runs")
# A run is usable only when every column the fit needs was logged:
# `params` and `tokens_seen` are cast to int below (a missing value cannot be),
# and `val_loss` is the regression target (a NaN there turns the fit loss into NaN).
required = ['params', 'tokens_seen', 'val_loss']
runs_df_finished = runs_df.dropna(subset=required)
print("There are ", len(runs_df_finished), " finished runs")
runs_df = runs_df_finished.astype(dtypes)

There are  15  runs
There are  15  finished runs


In [9]:
import torch.nn as nn
class ScalingLaw(torch.nn.Module):
    r"""
    Parametric model of the loss as a function of parameter count N and dataset size D:

    $$
    \hat{L}(N, D) \triangleq E+\frac{A}{N^\alpha}+\frac{B}{D^\beta}
    $$

    The amplitudes are stored in log space (``a = log A``, ``b = log B``,
    ``e = log E``), so A, B and E stay positive under unconstrained optimisation.
    The exponents ``alpha`` and ``beta`` are stored directly.

    Args:
        a, b, e: initial values of log A, log B and log E.
        alpha, beta: initial values of the exponents.
    """
    def __init__(self, a=0.0, b=0.0, e=0.0, alpha=0.5, beta=0.5):
        super().__init__()
        self.a = nn.Parameter(torch.tensor(a, dtype=torch.float32))
        self.b = nn.Parameter(torch.tensor(b, dtype=torch.float32))
        self.e = nn.Parameter(torch.tensor(e, dtype=torch.float32))
        self.alpha = nn.Parameter(torch.tensor(alpha, dtype=torch.float32))
        self.beta = nn.Parameter(torch.tensor(beta, dtype=torch.float32))

    def forward(self, N, D):
        """
        Predict the loss for each (N, D) pair.

        The sum E + A/N^alpha + B/D^beta is evaluated in log space as
        exp(logsumexp([a - alpha*log N, b - beta*log D, e])). This is mathematically
        identical but never materialises N**alpha or D**beta. Those powers overflow
        float32 once the exponents grow (e.g. 1e9**10), which makes the gradient
        w.r.t. alpha/beta NaN.

        N: model sizes (number of parameters); tensor or array-like.
        D: dataset sizes; tensor or array-like, broadcastable against N.
        returns: predicted loss with the broadcast shape of N and D.
        """
        log_n = torch.log(torch.as_tensor(N))
        log_d = torch.log(torch.as_tensor(D))

        # log(A / N^alpha) and log(B / D^beta); log E is the constant third term.
        term_n = self.a - self.alpha * log_n
        term_d = self.b - self.beta * log_d
        term_n, term_d = torch.broadcast_tensors(term_n, term_d)
        term_e = self.e.expand_as(term_n)

        log_pred = torch.logsumexp(torch.stack([term_n, term_d, term_e], dim=0), dim=0)
        return torch.exp(log_pred)

model = ScalingLaw()

total_epochs = 50000
log_every = 100  # print the training loss every `log_every` epochs
optimizer = optim.AdamW(model.parameters(), lr=1e-1, weight_decay=1e-9)
# One full cosine cycle over the whole run (T_max = total_epochs)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=total_epochs)

# Convert to numeric types before creating tensors
N = torch.tensor(pd.to_numeric(runs_df['params']).values, dtype=torch.float32)
D = torch.tensor(pd.to_numeric(runs_df['tokens_seen']).values, dtype=torch.float32)
L = torch.tensor(runs_df['val_loss'].values, dtype=torch.float32)
# The targets are fixed, so take their log once rather than every epoch.
log_L = torch.log(L)
# Huber loss on log-space residuals; with delta=1e-3 it is essentially L1 beyond 1e-3.
criterion = torch.nn.HuberLoss(delta=1e-3)

# ScalingLaw has no train/eval-dependent layers; set the mode once, outside the loop.
model.train()
for epoch in range(total_epochs):
    optimizer.zero_grad()
    L_pred = model(N, D)
    loss = criterion(torch.log(L_pred), log_L)
    loss.backward()
    optimizer.step()
    lr_scheduler.step()
    if epoch % log_every == 0:
        print(f"Epoch {epoch}, Loss: {loss.item()}")

  \hat{L}(N, D) \triangleq E+\frac{A}{N^\alpha}+\frac{B}{D^\beta}


Epoch 0, Loss: 0.002642521169036627
Epoch 100, Loss: 0.0012824843870475888
Epoch 200, Loss: 0.0011908862506970763
Epoch 300, Loss: 0.0011416436173021793
Epoch 400, Loss: 0.001081848400644958
Epoch 500, Loss: 0.0011192939709872007
Epoch 600, Loss: 0.0009848072659224272
Epoch 700, Loss: 0.0009640994830988348
Epoch 800, Loss: 0.0009265423868782818
Epoch 900, Loss: 0.0008907653973437846
Epoch 1000, Loss: 0.0008774850284680724
Epoch 1100, Loss: 0.000845086935441941
Epoch 1200, Loss: 0.0008061127737164497
Epoch 1300, Loss: 0.0007821861072443426
Epoch 1400, Loss: 0.0007560413214378059
Epoch 1500, Loss: 0.0007248680922202766
Epoch 1600, Loss: 0.0007083348464220762
Epoch 1700, Loss: 0.0006841178983449936
Epoch 1800, Loss: 0.0006829246995039284
Epoch 1900, Loss: 0.0006677730707451701
Epoch 2000, Loss: 0.0006668615387752652
Epoch 2100, Loss: 0.0006651955191046
Epoch 2200, Loss: 0.0006743497797288001
Epoch 2300, Loss: 0.0006788269383832812
Epoch 2400, Loss: 0.0006664912798441947
Epoch 2500, Loss: 

In [None]:
# Evaluate the torch fit on every run; no_grad avoids building an autograd graph for inference.
with torch.no_grad():
    preds = model(N, D)
runs_df['pred'] = preds.detach().numpy()
runs_df['error_abs'] = (runs_df.pred - runs_df.val_loss).abs()
runs_df['error_rel'] = runs_df.error_abs / runs_df.val_loss
# Worst absolute errors first; use by='error_rel' to rank by relative error instead.
runs_df[['run_name', 'val_loss', 'pred', 'error_abs', 'error_rel']].sort_values(by='error_abs', ascending=False)

Unnamed: 0,run_name,val_loss,pred,error_abs,error_rel
12,model_1.8M_tokens_27.0M,3.854408,0.310839,3.54357,0.919355
13,model_1.8M_tokens_13.5M,4.683319,1.295973,3.387346,0.723279
11,model_1.8M_tokens_54.0M,1.129691,0.088154,1.041536,0.921966
14,model_1.8M_tokens_6.8M,5.147617,5.657372,0.509755,0.099027
10,model_1.8M_tokens_108.0M,0.168284,0.037799,0.130485,0.775385
9,model_13.8M_tokens_60.0M,0.034794,0.074985,0.040191,1.155118
1,model_6.8M_tokens_60.0M,0.044834,0.074985,0.030151,0.672495
8,model_13.8M_tokens_120.0M,0.029614,0.034823,0.005209,0.175897
3,model_6.8M_tokens_240.0M,0.028772,0.025737,0.003035,0.105482
4,model_6.8M_tokens_480.0M,0.025872,0.023682,0.00219,0.084652


In [11]:
# The model stores log-amplitudes; exponentiate them back to natural scale.
A, B, E = (torch.exp(log_amp).item() for log_amp in (model.a, model.b, model.e))
alpha, beta = model.alpha.item(), model.beta.item()
print(f"A={A}, B={B}, E={E}, alpha={alpha}, beta={beta}")

A=0.032006070017814636, B=2474230801235968.0, E=0.02308114990592003, alpha=3.169862985610962, beta=2.1442453861236572


In [14]:
# jax_fit.py
import numpy as np
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp
from scipy.optimize import minimize

# Convert to float64 on the host before handing the columns to JAX. In JAX's default
# 32-bit mode an int64 column is converted to int32, which cannot represent token
# counts of 2**31 or more (~2.1B). Only logs of these values are used, so floats suffice.
N_arr = jnp.asarray(runs_df['params'].to_numpy(dtype=np.float64))
D_arr = jnp.asarray(runs_df['tokens_seen'].to_numpy(dtype=np.float64))
L_arr = jnp.asarray(runs_df['val_loss'].to_numpy(dtype=np.float64))

def huber(z, delta=1e-3):
    """
    Elementwise Huber penalty scaled by 1/delta.

    Quadratic z**2 / (2*delta) for |z| <= delta and linear |z| - delta/2 beyond;
    the two pieces meet (value delta/2) at |z| == delta.
    """
    abs_z = jnp.abs(z)
    quadratic = 0.5 * z * z / delta
    linear = abs_z - 0.5 * delta
    return jnp.where(abs_z <= delta, quadratic, linear)

def obj_and_grad(x, N_np, D_np, L_np, delta=1e-3):
    """
    Huber objective of the log-space scaling-law fit.

    Despite the name, this returns the objective value only; differentiate it with
    JAX autodiff to get the gradient.

    x: [a, b, e, alpha, beta] where a = log A, b = log B, e = log E.
    N_np, D_np, L_np: per-run parameter counts, dataset sizes and observed losses.
    delta: Huber threshold applied to log-residuals.
    Returns sum over runs of huber(log L_hat - log L).
    """
    a, b, e, alpha, beta = x
    log_n = jnp.log(jnp.asarray(N_np))
    log_d = jnp.log(jnp.asarray(D_np))

    term_n = a - alpha * log_n
    term_d = b - beta * log_d
    term_e = jnp.full_like(term_n, e)
    # log L_hat = log(exp(term_n) + exp(term_d) + exp(term_e)), evaluated stably per run.
    log_pred = logsumexp(jnp.stack([term_n, term_d, term_e]), axis=0)

    log_obs = jnp.log(jnp.asarray(L_np))
    return jnp.sum(huber(log_pred - log_obs, delta=delta))

# jax wrapper returning value and gradient for scipy
def value_and_grad_jax(x, N_np, D_np, L_np, delta=1e-3):
    """
    Return (objective, gradient w.r.t. x) of obj_and_grad, as scipy's
    minimize(..., jac=True) expects.

    jax.value_and_grad evaluates the forward pass once and reuses it for the
    backward pass, instead of evaluating the objective a second time.
    """
    return jax.value_and_grad(obj_and_grad)(x, N_np, D_np, L_np, delta)

def fit_jax(N_arr, D_arr, L_arr, x0=None, delta=1e-3, bounds=None):
    """
    Fit L(N, D) = E + A/N^alpha + B/D^beta with L-BFGS-B on the Huber objective
    of log-space residuals.

    N_arr, D_arr, L_arr: per-run parameter counts, dataset sizes and losses.
    x0: a single start [a, b, e, alpha, beta] (a = log A, b = log B, e = log E),
        or a 2-D array of shape (k, 5) holding k starts. Each start is optimised
        and the result with the lowest objective is kept. Defaults to a single
        start at [10, 10, 0, 0.5, 0.5].
    delta: Huber threshold on log-residuals.
    bounds: optional per-coordinate bounds forwarded to L-BFGS-B (e.g. to keep
        alpha and beta positive). None leaves the problem unbounded.

    Returns a dict with the log-space parameters (a, b, e), the exponents
    (alpha, beta), the amplitudes A, B, E = exp(a), exp(b), exp(e), and the scipy
    OptimizeResult of the best start under "res".

    Raises ValueError if x0 is not shaped (5,) or (k, 5) with k >= 1.
    """
    if x0 is None:
        # reasonable default starting point (paper uses a grid; pick middle)
        x0 = np.array([10.0, 10.0, 0.0, 0.5, 0.5], dtype=float)
    starts = np.atleast_2d(np.asarray(x0, dtype=float))
    if starts.ndim != 2 or starts.shape[1] != 5 or starts.shape[0] == 0:
        raise ValueError(f"x0 must have shape (5,) or (k, 5) with k >= 1, got {np.shape(x0)}")

    def fun_and_grad_np(x):
        val, grad = value_and_grad_jax(x, N_arr, D_arr, L_arr, delta)
        # scipy expects float64 NumPy values, not JAX arrays.
        return np.asarray(val, dtype=float), np.asarray(grad, dtype=float)

    best = None
    for start in starts:
        res = minimize(fun_and_grad_np, start, method="L-BFGS-B", jac=True, bounds=bounds)
        # Prefer any finite objective over a non-finite one, then the lowest.
        if best is None or not np.isfinite(best.fun) or res.fun < best.fun:
            best = res

    a, b, e, alpha, beta = best.x
    A, B, E = np.exp(a), np.exp(b), np.exp(e)
    return {"a": a, "b": b, "e": e, "alpha": alpha, "beta": beta,
            "A": A, "B": B, "E": E, "res": best}

# Example usage: fit once and keep the result around as `out`.
out = fit_jax(N_arr, D_arr, L_arr)
print(out)


def pred(N, D, params=out):
    """
    Evaluate E + A / N**alpha + B / D**beta with fitted parameters.

    params defaults to the `out` dict that exists when this def runs (default
    arguments are bound at definition time). Works elementwise on scalars or arrays.
    """
    amp_n, amp_d, floor = params['A'], params['B'], params['E']
    exp_n, exp_d = params['alpha'], params['beta']
    return floor + amp_n / (N ** exp_n) + amp_d / (D ** exp_d)

{'a': np.float64(9.881793208823115), 'b': np.float64(21.47996702245739), 'e': np.float64(-40.44384477437448), 'alpha': np.float64(2.336459634245038), 'beta': np.float64(1.261894611433897), 'A': np.float64(19570.7857206063), 'B': np.float64(2131234065.6864407), 'E': np.float64(2.7255953996100603e-18), 'res':   message: CONVERGENCE: RELATIVE REDUCTION OF F <= FACTR*EPSMCH
  success: True
   status: 0
      fun: 14.033634185791016
        x: [ 9.882e+00  2.148e+01 -4.044e+01  2.336e+00  1.262e+00]
      nit: 32
      jac: [-5.497e-10  1.265e-03 -3.299e-16  8.046e-09 -2.119e-01]
     nfev: 88
     njev: 88
 hess_inv: <5x5 LbfgsInvHessProduct with dtype=float64>}


In [18]:
# pred2: prediction of the JAX/L-BFGS fit. `pred` is plain elementwise arithmetic,
# so it can take whole columns instead of a row-wise apply.
runs_df['pred2'] = pred(runs_df['params'], runs_df['tokens_seen'])
# better is True where the JAX fit (pred2) is strictly closer to val_loss than the torch fit (pred).
torch_fit_error = (runs_df['pred'] - runs_df['val_loss']).abs()
jax_fit_error = (runs_df['pred2'] - runs_df['val_loss']).abs()
runs_df['better'] = torch_fit_error > jax_fit_error
runs_df[['run_name', 'val_loss', 'pred', 'pred2', 'better']]

Unnamed: 0,run_name,val_loss,pred,pred2,better
0,model_6.8M_tokens_120.0M,0.032641,0.034823,0.13601,False
1,model_6.8M_tokens_60.0M,0.044834,0.074985,0.326167,False
2,model_6.8M_tokens_30.0M,0.252477,0.25253,0.782183,False
3,model_6.8M_tokens_240.0M,0.028772,0.025737,0.056716,False
4,model_6.8M_tokens_480.0M,0.025872,0.023682,0.02365,False
5,model_13.8M_tokens_960.0M,0.021329,0.023217,0.009862,False
6,model_13.8M_tokens_480.0M,0.023663,0.023682,0.02365,True
7,model_13.8M_tokens_240.0M,0.026442,0.025737,0.056716,False
8,model_13.8M_tokens_120.0M,0.029614,0.034823,0.13601,False
9,model_13.8M_tokens_60.0M,0.034794,0.074985,0.326167,False
