In [21]:
import torch
from transformers import GenerationConfig
from recdep.utils.model import load_model_and_tokenizer

device = torch.device("cuda:0")

In [None]:
model, tokenizer = load_model_and_tokenizer("tomg-group-umd", "huginn-0125")
model = model.to(device)


config = GenerationConfig(max_length=1024, stop_strings=["<|end_text|>", "<|end_turn|>"], 
                          do_sample=False, temperature=None, top_k=None, top_p=None, min_p=None, 
                          return_dict_in_generate=True,
                          eos_token_id=65505,bos_token_id=65504,pad_token_id=65509)

Loading checkpoint shards: 100%|██████████| 4/4 [00:05<00:00,  1.35s/it]


Using Huginn tokenizer settings.


In [3]:
from datasets import load_dataset

ds = load_dataset("openai/gsm8k", "socratic")

Generating train split: 100%|██████████| 7473/7473 [00:00<00:00, 481578.74 examples/s]
Generating test split: 100%|██████████| 1319/1319 [00:00<00:00, 237243.75 examples/s]


In [None]:
def tokenize_dataset(vals):
    return tokenizer.apply_chat_template(vals['question'], tokenize=True, add_generation_prompt=True, return_tensors='pt')

tokenized_dataset = ds.map(
    tokenize_dataset
)

In [None]:
print(tokenized_dataset['test'][1])

{'question': {'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'input_ids': [65504, 65, 933, 1353, 2849, 402, 59416, 286, 4201, 6305, 295, 3434, 337, 2360, 5564, 6305, 46, 19779, 1523, 59416, 291, 1543, 1364, 431, 1972, 63], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]}, 'answer': {'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'input_ids': [65504, 2395, 1523, 59416, 286, 5564, 6305, 1364, 431, 1972, 63, 935, 1147, 2849, 402, 47, 50, 61, 5539, 50, 47, 50, 61, 49, 4616, 49, 48351, 286, 5564, 6305, 10, 2395, 1523, 59416, 291, 1543, 1364, 431, 1972, 63, 935, 2127, 264, 1543, 3353, 286, 12026, 305, 402, 43, 49, 61, 5539, 50, 43, 49, 61, 51, 4616, 51, 59416, 286, 12026, 10, 1319, 532], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0

In [None]:
@torch.no_grad()
def compute_latents(model, outputs, num_steps=128):
    # Get initial state and compute trajectory
    embedded_inputs, _,_ = model.embed_inputs(outputs.sequences)
    input_states = model.initialize_state(embedded_inputs, deterministic=False)

    # Initialize storage for normalized latents
    latents = []
    current_latents = input_states
    latents.append(model.transformer.ln_f(current_latents).cpu().float().numpy())

    # Collect all latent states
    for step in range(num_steps):
        current_latents, _,_ = model.iterate_one_step(embedded_inputs, current_latents)
        normalized_latents = model.transformer.ln_f(current_latents)
        latents.append(normalized_latents.cpu().float().numpy())

    # Stack all latents
    latents = np.stack(latents)  # [num_steps+1, batch, seq_len, hidden_dim]
    return latents


latents = compute_latents(model, outputs, num_steps=128)

