In [1]:
%cd "/gscratch/xlab/alisaliu/superbpe"
import olmo

/mmfs1/gscratch/xlab/alisaliu/superbpe


  self.shell.db['dhist'] = compress_dhist(dhist)[-100:]


In [2]:
import functools
import itertools as it
import os
from copy import deepcopy

import numpy as np
import torch
import warnings
import math

import olmo

# Placeholder value for SCRATCH_DIR. Presumably the training config loaded later
# references this environment variable and only needs it to be defined — TODO confirm.
os.environ["SCRATCH_DIR"] = "no_exist"

In [3]:
def scale_config(
    config,
    flops_ratio,
    axis_divisor=128,
    ceil=False,
    mode="inference-flops",
    other_updates=None,
):
    """Search for a resized model config whose cost is ``flops_ratio`` times the original.

    The search moves along three discrete axes:

    * ``d_model`` in multiples of the head dimension, so that
      ``n_heads = d_model // head_dim`` keeps the original head size;
    * ``n_layers`` one layer at a time;
    * ``mlp_hidden_size`` in multiples of ``axis_divisor``.

    The cost of a candidate is modelled by a polynomial fitted on a 3x3x3 grid of
    real cost evaluations. For every (d_model, n_layers) candidate the MLP step
    count is solved so that the modelled cost lands on the target, and the
    candidate whose three log scale factors have the smallest standard deviation
    (i.e. the most evenly scaled one) is selected.

    Args:
        config: Model config accepted by ``olmo.model.OLMo``. It is deep-copied;
            the caller's object is not modified.
        flops_ratio: Target cost divided by the cost of ``config``. Must be truthy.
        axis_divisor: Step size used for ``mlp_hidden_size``.
        ceil: Round the solved MLP step count up instead of down.
        mode: Cost measure: ``"inference-flops"``, ``"train-flops"``, ``"params"``
            or ``"params-non-embedding"``.
        other_updates: Optional mapping of attribute name -> value that is set on
            every candidate config. The reference cost of the original config is
            computed *without* these updates.

    Returns:
        A 4-tuple ``(sizes, ratios, rel_flops, new_config)`` where ``sizes`` is
        ``(d_model, n_heads, n_layers, mlp_hidden_size)``, ``ratios`` holds the
        natural-log scale factors of (d_model, n_layers, mlp_hidden_size),
        ``rel_flops`` is ``(actual - target) / target`` measured with the real
        cost function, and ``new_config`` is the selected config with
        ``other_updates`` applied.

    Raises:
        AssertionError: If ``flops_ratio`` is falsy or the fitted polynomial does
            not reproduce the real cost at the spot-check points.
        NotImplementedError: If ``mode`` is not one of the supported values.
        ValueError: If the search does not find any valid candidate.
    """
    assert flops_ratio, "flops_ratio must be non-zero"
    # Use None as the default instead of a shared mutable ``{}``.
    updates = {} if other_updates is None else other_updates

    config = deepcopy(config)
    config.init_device = "meta"
    head_dim = config.d_model // config.n_heads
    if not config.mlp_hidden_size:
        # Make the MLP width explicit so it can be stepped independently of d_model.
        config.mlp_hidden_size = config.d_model * config.mlp_ratio

    # estimate the flops of a config
    def f(C):
        # Re-assert the meta device in case ``other_updates`` overrode it.
        C.init_device = "meta"
        model = olmo.model.OLMo(C)
        if mode == "inference-flops":
            return model.num_fwd_flops
        elif mode == "params":
            return model.num_params()
        elif mode == "params-non-embedding":
            return model.num_params(include_embedding=False)
        elif mode == "train-flops":
            return model.num_fwd_flops + model.num_bck_flops
        else:
            raise NotImplementedError(f"Unknown mode {mode}")

    def make_config(d_model, n_layers, mlp_hidden_size, do_updates=True):
        C = deepcopy(config)
        C.d_model = d_model
        C.n_heads = C.d_model // head_dim
        C.mlp_hidden_size = mlp_hidden_size
        C.n_layers = n_layers

        if do_updates:
            for key, val in updates.items():
                setattr(C, key, val)
        return C

    # reparameterize so only valid configs are reachable:
    # (d, n, h) are integer step counts relative to the original config
    def param(d, n, h):
        return (
            config.d_model + head_dim * d,
            config.n_layers + n,
            config.mlp_hidden_size + axis_divisor * h,
        )

    # natural log of the per-axis scale factors for step counts (d, n, h)
    def r(d, n, h):
        ratios = np.array(
            [
                1 + head_dim * d / config.d_model,
                1 + n / config.n_layers,
                1 + axis_divisor * h / config.mlp_hidden_size,
            ]
        )
        # Non-positive ratios produce nan/-inf; silence numpy's warnings about
        # them, such candidates are filtered out by the search below.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            return np.log(ratios)

    # real cost at step counts (d, n, h)
    def g(d, n, h, do_updates=True):
        return f(make_config(*param(d, n, h), do_updates=do_updates))

    base_flops = g(0, 0, 0, do_updates=False)
    target_flops = base_flops * flops_ratio

    # fit a polynomial to g: Q doubles as the sample grid {0,1,2}^3 and as the
    # list of monomial exponents (a, b, c) for d**a * n**b * h**c
    Q = np.array(list(it.product(*[[0, 1, 2]] * 3)))
    QQ = np.vstack([Q[:, 0] ** a * Q[:, 1] ** b * Q[:, 2] ** c for a, b, c in Q]).T
    Qg = np.array([g(*row) / base_flops for row in Q])
    coeff = np.linalg.lstsq(QQ, Qg, rcond=None)[0]

    # polynomial surrogate of g (cheap, no model construction)
    def g2(d, n, h):
        return (
            np.array([d**a * n**b * h**c for a, b, c in Q]).dot(coeff)
            * base_flops
        )

    # double check the predictions are matching
    assert round(g2(3, 4, 5)) == g(3, 4, 5)
    assert round(g2(5, 4, 6)) == g(5, 4, 6)
    assert round(g2(2, 7, 3)) == g(2, 7, 3)

    # given d and n, solve for h (linear interpolation between h=0 and h=1)
    def solve_h(d, n):
        f0, f1 = g2(d, n, 0), g2(d, n, 1)
        slope = f1 - f0
        rounder = np.ceil if ceil else np.floor
        return int(rounder((target_flops - f0) / slope))

    best, best_l = None, float("inf")

    # score one (d, n) candidate and keep it if it is the most balanced so far
    def consider(d, n):
        nonlocal best, best_l
        h = solve_h(d, n)
        log_ratios = r(d, n, h)
        # nan / inf mean some axis would be scaled to a non-positive size.
        if not np.isfinite(log_ratios).all():
            return
        spread = log_ratios.std()
        if spread < best_l:
            best_l, best = spread, (d, n, h)

    # enumerate all viable d and n, growing the model (used when scaling up)
    d2 = 0
    while True:
        if g2(d2, 0, 0) > target_flops:
            break

        n2 = 0
        while True:
            if g2(d2, n2, 0) > target_flops:
                break

            consider(d2, n2)
            n2 += 1
        d2 += 1

    # ... and shrinking the model (used when scaling down)
    d2 = 0
    while True:
        if g2(d2 - 1, 0, 0) < target_flops:
            break

        n2 = 0
        while True:
            if g2(d2, n2 - 1, 0) < target_flops:
                break

            consider(d2, n2)
            n2 -= 1
        d2 -= 1

    if best is None:
        raise ValueError(
            f"no valid configuration found for flops_ratio={flops_ratio!r} (mode={mode!r})"
        )

    opt_d, opt_n, opt_hsize = param(*best)
    ratios = tuple(r(*best).tolist())
    # Relative error is measured with the real cost function, not the surrogate.
    rel_flops = (g(*best) - target_flops) / target_flops
    return (
        (opt_d, opt_d // head_dim, opt_n, opt_hsize),
        ratios,
        rel_flops,
        make_config(*param(*best)),
    )

In [4]:
def tokens_per_step(config):
    """Return the number of tokens in one global training batch.

    Computed as ``global_train_batch_size`` sequences times
    ``model.max_sequence_length`` tokens per sequence.
    """
    sequences = config.global_train_batch_size
    tokens_per_sequence = config.model.max_sequence_length
    return sequences * tokens_per_sequence

def test_flops_per_step(config):
    """Return forward-pass FLOPs for one global batch.

    Builds an ``olmo.model.OLMo`` from ``config.model`` and multiplies its
    ``num_fwd_flops`` by the number of tokens per step.
    """
    fwd_flops = olmo.model.OLMo(config.model).num_fwd_flops
    return tokens_per_step(config) * fwd_flops
    
def train_flops_per_step(config):
    """Return forward-plus-backward FLOPs for one global batch.

    Builds an ``olmo.model.OLMo`` from ``config.model`` and multiplies the sum of
    its ``num_fwd_flops`` and ``num_bck_flops`` by the number of tokens per step.
    """
    net = olmo.model.OLMo(config.model)
    fwd_and_bck_flops = net.num_fwd_flops + net.num_bck_flops
    return tokens_per_step(config) * fwd_and_bck_flops
    
def bytes_per_step(config, encoding_efficiency):
    """Return, as a float, the batch size times (sequence length * ``encoding_efficiency``).

    With ``encoding_efficiency`` expressed in bytes per token this is the number
    of bytes of text covered by one global batch.
    """
    # Keep the (sequence_length * efficiency) product grouped first so float
    # rounding matches the straightforward per-sequence computation.
    per_sequence = config.model.max_sequence_length * encoding_efficiency
    return float(config.global_train_batch_size * per_sequence)
    
def num_fwd_flops(model):
    """Estimate forward-pass FLOPs per token, split into two components.

    Returns:
        ``(params_flops_per_token, attn_flops_per_token)``.
    """
    # The embedding table is only a lookup in the forward pass, so it is excluded.
    non_embedding_params = model.num_params(include_embedding=False)
    # Parameter count approximates the number of multiply-accumulates per token,
    # and each multiply-accumulate counts as 2 FLOPs.
    param_term = 2 * non_embedding_params

    cfg = model.config
    # Two matmuls per layer (A = Q*K^T and out = A*V), 2 FLOPs per multiply-accumulate.
    attn_term = cfg.n_layers * 2 * 2 * (cfg.d_model * cfg.max_sequence_length)
    return param_term, attn_term

def num_bck_flops(model):
    """Estimate backward-pass FLOPs per token, split into two components.

    Uses the full parameter count (``model.num_params()`` with its defaults).

    Returns:
        ``(params_flops_per_token, attn_flops_per_token)``.
    """
    cfg = model.config
    param_term = 4 * model.num_params()
    attn_term = cfg.n_layers * 8 * (cfg.d_model * cfg.max_sequence_length)
    return param_term, attn_term
        
def model_num_params(config):
    """Return ``num_params()`` of an ``olmo.model.OLMo`` built from ``config.model``."""
    return olmo.model.OLMo(config.model).num_params()

def max_steps(config):
    """Convert ``config.max_duration`` into a number of training steps.

    Accepted forms:
        * ``int`` — returned unchanged;
        * ``"<number>T"`` — a token budget, divided by the tokens per step and rounded up;
        * ``"<number>ep"`` — not supported;
        * any other string — parsed as a number of steps.

    Raises:
        NotImplementedError: For durations ending in ``"ep"``.
        TypeError: If ``max_duration`` is neither ``int`` nor ``str``.
    """
    duration = config.max_duration
    if isinstance(duration, int):
        return duration
    if not isinstance(duration, str):
        raise TypeError(f"expected int or str for 'max_duration', found {type(config.max_duration)}")

    if duration.endswith("T"):
        # Go through float first so scientific notation such as "1e9T" parses.
        token_budget = int(float(duration[:-1].strip()))
        step_tokens = config.global_train_batch_size * config.model.max_sequence_length
        return math.ceil(token_budget / step_tokens)
    if duration.endswith("ep"):
        raise NotImplementedError
    # Go through float first so scientific notation parses.
    return int(float(duration))


In [5]:
# Define the "base" config we are going to scale
BASE_CONFIG = olmo.config.TrainConfig.load("configs/OLMo2-7B-generic-200k.yaml")
# Total training compute of the baseline run: number of steps times train FLOPs per step.
BASE_TRAIN_FLOPS = max_steps(BASE_CONFIG) * train_flops_per_step(BASE_CONFIG)
# Encoding efficiency of the baseline tokenizer. Presumably bytes per token, since
# bytes_per_step multiplies a token count by a value of this kind — TODO confirm units.
BASE_ENCODING_EFFICIENCY = 4.458679110036542

# Specify the desired scaling parameters
# Vocabulary size of the target tokenizer; here kept equal to the baseline's.
TOKENIZER_VOCAB_SIZE = BASE_CONFIG.model.vocab_size
# Encoding efficiency of the target tokenizer, in the same units as BASE_ENCODING_EFFICIENCY.
TOKENIZER_ENCODING_EFFICIENCY = 6.0887434010717465

# Falsy keeps the baseline architecture; otherwise this is the ratio handed to scale_config.
TARGET_MODEL_SCALE = None
# Training compute budget for the target run; here matched to the baseline.
TARGET_TRAIN_FLOPS = BASE_TRAIN_FLOPS

In [6]:
# Do the calculations
TARGET_CONFIG = deepcopy(BASE_CONFIG)
TARGET_CONFIG.model.vocab_size = TOKENIZER_VOCAB_SIZE
# Round the embedding table size up to the next multiple of 128.
TARGET_CONFIG.model.embedding_size = (
    TOKENIZER_VOCAB_SIZE + (-TOKENIZER_VOCAB_SIZE) % 128
)
# Shrink/grow the context (in tokens) by the ratio of encoding efficiencies,
# rounding up, so that sequence_length * encoding_efficiency stays about the same.
TARGET_CONFIG.model.max_sequence_length = int(
    np.ceil(
        BASE_CONFIG.model.max_sequence_length
        * BASE_ENCODING_EFFICIENCY
        / TOKENIZER_ENCODING_EFFICIENCY
    )
)

if TARGET_MODEL_SCALE:
    # scale_config starts from the *base* model config and its result replaces
    # TARGET_CONFIG.model wholesale, so every tokenizer-dependent field set above
    # must be forwarded through other_updates or it would be silently dropped.
    params, scales, error, new_model_config = scale_config(
        BASE_CONFIG.model,
        TARGET_MODEL_SCALE,
        mode="train-flops",
        other_updates=dict(
            vocab_size=TARGET_CONFIG.model.vocab_size,
            embedding_size=TARGET_CONFIG.model.embedding_size,
            max_sequence_length=TARGET_CONFIG.model.max_sequence_length,
        ),
    )
    TARGET_CONFIG.model = new_model_config

# Number of whole steps that fit in the training compute budget.
TARGET_CONFIG.max_duration = int(
    np.floor(TARGET_TRAIN_FLOPS / train_flops_per_step(TARGET_CONFIG))
)
TARGET_NUM_PARAMS = model_num_params(TARGET_CONFIG)
TARGET_TOKENS = TARGET_CONFIG.max_duration * tokens_per_step(TARGET_CONFIG)
TARGET_TOKEN_PARAM_RATIO = TARGET_TOKENS / TARGET_NUM_PARAMS
TARGET_TOTAL_BYTES = TARGET_CONFIG.max_duration * bytes_per_step(TARGET_CONFIG, TOKENIZER_ENCODING_EFFICIENCY)

# Report the resulting architecture, tokenizer-dependent settings and training length.
print("Model config:")
print(f"{TARGET_CONFIG.model.d_model=}")
print(f"{TARGET_CONFIG.model.n_heads=}")
print(f"{TARGET_CONFIG.model.n_layers=}")
# Show the explicit MLP width when one is set, otherwise the ratio it is derived from.
if TARGET_CONFIG.model.mlp_hidden_size:
    print(f"{TARGET_CONFIG.model.mlp_hidden_size=}")
else:
    print(f"{TARGET_CONFIG.model.mlp_ratio=}")

# Without a model scale the architecture fields above are copied from BASE_CONFIG.
if not TARGET_MODEL_SCALE:
    print("[The above should be unchanged from the baseline.]")
print(f"{TARGET_CONFIG.model.weight_tying=}")
print(f"{TARGET_CONFIG.model.max_sequence_length=}")
print(f"{TARGET_CONFIG.model.vocab_size=}")
print(f"{TARGET_CONFIG.model.embedding_size=}")
print(f"{TARGET_CONFIG.max_duration=}")
print()
# `scales` only exists when scale_config was run in the calculation above.
if TARGET_MODEL_SCALE:
    print(f"Model ratios: {scales}")
print(f"Tokens: {TARGET_TOKENS:,}")
print(f"Params: {TARGET_NUM_PARAMS:,}")
# The tokens-per-parameter ratio is also shown relative to 22, the value this
# notebook uses as the "Chinchilla" reference ratio.
print(
    f"T/P ratio: {TARGET_TOKEN_PARAM_RATIO:.06} ({TARGET_TOKEN_PARAM_RATIO/22:.05}x Chinchilla)"
)

Model config:
TARGET_CONFIG.model.d_model=4096
TARGET_CONFIG.model.n_heads=32
TARGET_CONFIG.model.n_layers=32
TARGET_CONFIG.model.mlp_hidden_size=22016
[The above should be unchanged from the baseline.]
TARGET_CONFIG.model.weight_tying=False
TARGET_CONFIG.model.max_sequence_length=3000
TARGET_CONFIG.model.vocab_size=200005
TARGET_CONFIG.model.embedding_size=200064
TARGET_CONFIG.max_duration=107972

Tokens: 331,689,984,000
Params: 8,115,458,048
T/P ratio: 40.8714 (1.8578x Chinchilla)
