In [1]:
from transformers import AutoModelForCausalLM, AutoTokenizer

# Baseline checkpoint: tokenizer and causal-LM weights come from the same hub repo.
model_id = "facebook/opt-125m"
tokenizer, model = (
    AutoTokenizer.from_pretrained(model_id),
    AutoModelForCausalLM.from_pretrained(model_id),
)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from datasets import load_dataset
from torch.utils.data import DataLoader

SEED = 42         # shuffle seed, so the subset is reproducible across runs
N_SAMPLES = 1000  # number of rows kept for the experiments below

# Load the raw wikitext-2 training split from the Hugging Face hub.
dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
# Deterministic random subset: shuffle with a fixed seed, then keep the first
# N_SAMPLES rows (1000 rows, not "8 sentences" as previously commented).
subset = dataset.shuffle(seed=SEED).select(range(N_SAMPLES))


In [25]:
print(subset[6])  # Spot-check one sample (index 6, i.e. the seventh row) to verify loading

{'text': ' " Peace process or peace panic ? - The scourge of Palestinian moderation " , Middle East Report , 19 ( 1989 ) 3 / 158 , pp. 25 – 26 @,@ 28 @-@ 30 @,@ 42 \n'}


In [None]:
def tokenize(example, max_length: int = 128):
    """Tokenize a batch of rows with the notebook-global `tokenizer`.

    Intended for `Dataset.map(..., batched=True)`, so `example["text"]` holds a
    batch of strings. Every sequence is truncated or padded to exactly
    `max_length` tokens so batches stack into fixed-size tensors.

    NOTE: `tokenizer` is looked up at call time. A later cell rebinds it to the
    Pythia tokenizer, so re-running this after that cell tokenizes with Pythia,
    not OPT.
    """
    return tokenizer(example["text"], truncation=True, padding="max_length", max_length=max_length)

tokenized = subset.map(tokenize, batched=True, remove_columns=["text"])

# Keep only the model inputs and return them as torch tensors so a DataLoader
# can collate them:
#   'input_ids'      - token ids of each fixed-length (padded) sequence
#   'attention_mask' - marks which positions are real tokens and which are padding
tokenized.set_format(columns=["input_ids", "attention_mask"], type="torch")

loader = DataLoader(dataset=tokenized, batch_size=16)  # sequential order, no shuffling

Map: 100%|██████████| 1000/1000 [00:00<00:00, 5845.31 examples/s]


In [3]:
from tqdm import tqdm
import torch

In [None]:
from devinterp.optim.sgld import SGLD
from devinterp.slt.sampler import estimate_learning_coeff_with_summary

# Estimate the local learning coefficient (LLC) of early Pythia checkpoints.
# The repo id must use an ASCII hyphen; a non-breaking hyphen (U+2011) in
# "pythia-70m" is not a valid hub repo id.
model_name = "EleutherAI/pythia-70m"  # replace size as needed
revision_list = [0, 1, 2, 4, 8, 16, 32, 64, 128, 256, 512]  # loaded as hub revisions "step{n}"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

MAX_LENGTH = 128  # same truncation/padding length as the OPT tokenization cell
BATCH_SIZE = 16
SEED = 42

# The data must be tokenized with the Pythia tokenizer: ids from the OPT
# tokenizer loaded earlier do not index Pythia's vocabulary.
# NOTE: this rebinds the notebook-global `tokenizer`, as the original cell did.
tokenizer = AutoTokenizer.from_pretrained(model_name)
if tokenizer.pad_token is None:
    # padding="max_length" requires a pad token; reuse EOS (padding is masked
    # out of the loss in _causal_lm_ce).
    tokenizer.pad_token = tokenizer.eos_token


def _tokenize_for_llc(batch):
    """Tokenize a batch of rows with the Pythia tokenizer into fixed-length sequences."""
    return tokenizer(batch["text"], truncation=True, padding="max_length", max_length=MAX_LENGTH)


def _causal_lm_ce(model, batch):
    """Mean next-token cross-entropy of a Hugging Face causal LM on one batch.

    `batch` is a dict holding 'input_ids' and 'attention_mask' tensors (the
    format of `llc_data` below). Padded positions get label -100 so they are
    excluded from the loss. Returns a scalar loss tensor.
    """
    model_device = next(model.parameters()).device
    input_ids = batch["input_ids"].to(model_device)
    attention_mask = batch["attention_mask"].to(model_device)
    labels = input_ids.masked_fill(attention_mask == 0, -100)
    return model(input_ids=input_ids, attention_mask=attention_mask, labels=labels).loss


# Drop blank rows: a sequence with no real tokens would contribute only padding.
llc_data = subset.filter(lambda row: row["text"].strip() != "").map(
    _tokenize_for_llc, batched=True, remove_columns=["text"]
)
llc_data.set_format(type="torch", columns=["input_ids", "attention_mask"])

llc_values = []
for step in tqdm(revision_list):
    revision = f"step{step}"
    model = AutoModelForCausalLM.from_pretrained(model_name, revision=revision).to(device)
    # Fresh, identically seeded loader per checkpoint, so every revision sees
    # the same batch order and the estimates are comparable.
    llc_loader = DataLoader(
        llc_data,
        batch_size=BATCH_SIZE,
        shuffle=True,
        generator=torch.Generator().manual_seed(SEED),
    )
    torch.manual_seed(SEED)  # NOTE(review): presumably SGLD noise draws from the global torch RNG — confirm
    llc = estimate_learning_coeff_with_summary(
        model=model,
        loader=llc_loader,
        evaluate=_causal_lm_ce,
        sampling_method=SGLD,
        optimizer_kwargs=dict(lr=0.003, nbeta=2.0, localization=5.0),
        num_chains=1,
        num_draws=500,
        device=device,
        online=False,
    )["llc/mean"]
    llc_values.append((revision, llc))
