In [3]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up live, without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os

# Expose only the GPU with index 1 to CUDA for code run in this kernel.
# NOTE(review): the index is machine-specific — confirm it matches the host.
os.environ.update(CUDA_VISIBLE_DEVICES="1")

In [1]:
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

from basin_volume import VolumeConfig, VolumeEstimator

  from .autonotebook import tqdm as notebook_tqdm


In [7]:
# Checkpoint shared by the model and its tokenizer.
MODEL_NAME = "EleutherAI/pythia-14m"

model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Pythia-specific special-token ids.
tokenizer.pad_token_id = 1
tokenizer.eos_token_id = 0

# LAMBADA (OpenAI variant): "en" config, test split.
# NOTE(review): trust_remote_code=True permits the dataset repo's loading
# script to run locally — keep only while the source is trusted.
dataset = load_dataset(
    "EleutherAI/lambada_openai",
    name="en",
    split="test",
    trust_remote_code=True,
)

# Settings for the volume estimation run.
cfg = VolumeConfig(
    model=model,
    tokenizer=tokenizer,
    dataset=dataset,
    text_key="text",    # must match the dataset's text field
    n_samples=10,       # number of Monte Carlo samples
    cutoff=1e-2,        # KL-divergence cutoff, in nats
    max_seq_len=2048,   # sequence length used when chunking the dataset
    val_size=10,        # number of sequences (chunks) used in the estimation
)
estimator = VolumeEstimator.from_config(cfg)

In [5]:
# Run the estimator built from `cfg` and keep its return value for inspection.
# NOTE(review): the recorded execution counts show this cell ran as In [5],
# before the config cell's latest run (In [7]); re-run top to bottom on a
# fresh kernel so `result` reflects the configuration shown in this notebook.
result = estimator.run()

In [6]:
# Bare expression: shows the VolumeResult repr as this cell's output.
result

VolumeResult(estimates=Array([-1.1057281e+08, -1.0925299e+08, -1.0863679e+08, -1.1000776e+08,
       -1.0962531e+08, -1.0930558e+08, -1.0951793e+08, -1.0957152e+08,
       -1.1011897e+08, -1.0973353e+08], dtype=float32), props=Array([1.       , 1.       , 1.       , 1.       , 1.0000001, 1.0000001,
       1.0000001, 1.0000001, 1.       , 1.       ], dtype=float32), mults=Array([0.4765625 , 0.5234375 , 0.546875  , 0.49609375, 0.5097656 ,
       0.5214844 , 0.5136719 , 0.51171875, 0.4921875 , 0.5058594 ],      dtype=float32), deltas=Array([0.01001176, 0.00992114, 0.00997438, 0.00995031, 0.01002949,
       0.00997159, 0.00996048, 0.00993844, 0.00995879, 0.00998593],      dtype=float32), logabsint=Array([-17460246., -16140436., -15524230., -16895202., -16512752.,
       -16193021., -16405366., -16458958., -17006410., -16620968.],      dtype=float32))