In [43]:
%load_ext autoreload
%autoreload 2

import numpy as np
import yaml
import torch

import jax
import jax.numpy as jnp
import optax
import matplotlib.pyplot as plt
import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset

from basin_volume import *

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
def make_kl_fn(probs_p, apply_fn, x, l2_reg):
    def kl_fn(a, b):
        params_q = a + b
        logits_q = apply_fn(params_q, x)
        logprobs_q = jax.nn.log_softmax(logits_q)
        kl_term = optax.kl_divergence(logprobs_q, probs_p).mean()
        l2_term = 1/2 * l2_reg * jnp.sum(b**2)
        return kl_term + l2_term
    return kl_fn

def make_kl_fn_params(params_p, apply_fn, x, *, l2_reg):
    logits_p = apply_fn(params_p, x)
    probs_p = jax.nn.softmax(logits_p)
    return make_kl_fn(probs_p, apply_fn, x, l2_reg=l2_reg)

# Tokenizer

In [3]:

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-14m")

In [4]:
len(tokenizer.vocab)

50277

# Model

In [5]:

model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-14m")

The `GPTNeoXSdpaAttention` class is deprecated in favor of simply modifying the `config._attn_implementation`attribute of the `GPTNeoXAttention` class! It will be removed in v4.48


In [22]:
# for name, p in model.named_parameters():
#     print(name, p.shape)
# print()

# sort by number of parameters and print
sorted_params = sorted(model.named_parameters(), key=lambda x: np.prod(x[1].shape))
# compute total number of parameters
total_params = sum(np.prod(p.shape) for name, p in sorted_params)
print(f"Total number of parameters: {total_params}")

# compute total embedding parameters
embedding_params = sum(np.prod(p.shape) for name, p in sorted_params if "embed" in name)
print(f"Total number of embedding parameters: {embedding_params}")
print(f"Total number of non-embedding parameters: {total_params - embedding_params}")

for name, p in reversed(sorted_params):
    print(f"{np.prod(p.shape) / total_params:.2%}", p.shape, name)

Total number of parameters: 14067712
Total number of embedding parameters: 12877824
Total number of non-embedding parameters: 1189888
45.77% torch.Size([50304, 128]) embed_out.weight
45.77% torch.Size([50304, 128]) gpt_neox.embed_in.weight
0.47% torch.Size([128, 512]) gpt_neox.layers.5.mlp.dense_4h_to_h.weight
0.47% torch.Size([512, 128]) gpt_neox.layers.5.mlp.dense_h_to_4h.weight
0.47% torch.Size([128, 512]) gpt_neox.layers.4.mlp.dense_4h_to_h.weight
0.47% torch.Size([512, 128]) gpt_neox.layers.4.mlp.dense_h_to_4h.weight
0.47% torch.Size([128, 512]) gpt_neox.layers.3.mlp.dense_4h_to_h.weight
0.47% torch.Size([512, 128]) gpt_neox.layers.3.mlp.dense_h_to_4h.weight
0.47% torch.Size([128, 512]) gpt_neox.layers.2.mlp.dense_4h_to_h.weight
0.47% torch.Size([512, 128]) gpt_neox.layers.2.mlp.dense_h_to_4h.weight
0.47% torch.Size([128, 512]) gpt_neox.layers.1.mlp.dense_4h_to_h.weight
0.47% torch.Size([512, 128]) gpt_neox.layers.1.mlp.dense_h_to_4h.weight
0.47% torch.Size([128, 512]) gpt_neox.la

In [46]:
# functions to convert Torch params to flat array and back
def torch_to_flat(model):
    """Convert PyTorch model parameters to flat array"""
    return np.concatenate([p.detach().numpy().ravel() for p in model.parameters()])

def flat_to_torch(flat_params, model):
    """Convert flat array back to PyTorch model parameters"""
    pointer = 0
    for param in model.parameters():
        num_params = param.numel()
        param.data = torch.from_numpy(
            flat_params[pointer:pointer + num_params].reshape(param.shape)
        )
        pointer += num_params

In [58]:
trained_params = torch_to_flat(model)

# Dataset

In [59]:
print(X_train.shape)

NameError: name 'X_train' is not defined

# Basins

In [19]:
# https://github.com/EleutherAI/pythia/blob/main/models/14M/pythia-14m.yml
with open("../data/pythia-14m.yml", "r") as f:
    config = yaml.load(f, Loader=yaml.FullLoader)
l2_reg = config['weight-decay']
# TODO: determine if this is the correct train size
train_size = config['train-iters'] * config['train_micro_batch_size_per_gpu']
sigma_epoch = 1/jnp.sqrt(l2_reg * train_size)

In [40]:
sigma_params = jnp.sqrt(jnp.mean(params**2))

In [41]:
print(sigma_epoch)  # way too small!
print(sigma_params)

0.001478281
0.3292918


In [57]:
def apply_fn(params, x):
    # assign params to model
    flat_to_torch(params, model)
    return model(x)

In [None]:
kl_fn = make_kl_fn_params(trained_params, apply_fn, X_train, l2_reg=0.)   

In [7]:
CUTOFF = 1e-2

RESULTS = {} # estimates, props, mults, deltas, logabsint

In [None]:
RESULTS['naive'] = get_estimates_vectorized_gauss(1, 
                                                  sigma=sigma_params,
                                                  fn=kl_fn, 
                                                  params=trained_params, 
                                                  cutoff=CUTOFF,
                                                  tol=0.1,
                                                  debug=False,
                                                 )