In [1]:
import ipywidgets as widgets
import numpy as np
import torch
from IPython.display import clear_output, display
from transformers import AutoTokenizer

from eli.config import cfg, encoder_cfg
from eli.encoder import (
    PROMPT_PREFIX,
    PROMPT_SUFFIX,
    Encoder,
    EncoderDecoder,
    EncoderTrainer,
    calculate_dinalar_loss,
    get_embeddings_from_decoder,
    kl_div,
)

# Set the buffer size and the target-model batch size to the training batch
# size, so all three sample counts in cfg agree.
cfg.buffer_size_samples = cfg.train_batch_size_samples
cfg.target_model_batch_size_samples = cfg.train_batch_size_samples

In [5]:
from transformers import AutoModelForCausalLM

# Tokenizer and causal LM are both loaded from the checkpoint named in cfg;
# the model is then moved to the configured device.
tokenizer = AutoTokenizer.from_pretrained(cfg.target_model_name)
model = AutoModelForCausalLM.from_pretrained(cfg.target_model_name)
model = model.to(cfg.device)

In [3]:
# Load eval data

from eli.data import DataCollector

print("Initializing data collector")
data_collector = DataCollector(use_workers=False)

print("Collecting data")
data_collector.collect_data()

# Pull out the two entries the interactive cell below reads from.
data = data_collector.data
target_generated_tokens, target_logits = (
    data[key] for key in ("target_generated_tokens", "target_logits")
)

# Number of collected samples = size of the leading dimension.
batch_size = len(target_generated_tokens)

Initializing data collector


INFO:root:target_generated_tokens size: 480 bytes (0.00 MB)
INFO:root:target_logits size: 65667072 bytes (62.62 MB)
INFO:root:target_acts size: 65536 bytes (0.06 MB)
INFO:root:input_tokens size: 2048 bytes (0.00 MB)
INFO:root:Total shared memory size: 0.06 GB


Collecting data


Resolving data files:   0%|          | 0/1024 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/1024 [00:00<?, ?it/s]

INFO:root:Tokenize and concatenate called
INFO:root:Full text length: 43667353
Token indices sequence length is longer than the specified maximum sequence length for this model (472110 > 131072). Running this sequence through the model will result in indexing errors
INFO:root:Num tokens: 9321501
  return torch.tensor(value, **{**default_dtype, **self.torch_tensor_kwargs})
INFO:root:Processing data directly on cuda without workers
INFO:root:Processing chunk 0:8 on cuda
INFO:root:Processing batch 0:8


Loaded pretrained model meta-llama/Llama-3.2-1B-Instruct into HookedTransformer
Moving model to device:  cuda


INFO:root:CHUNK 0:8 COMPLETED, Max GPU memory allocated: 13.68 GB, Max GPU memory reserved: 13.78 GB
INFO:root:Direct data processing completed


In [9]:
import ipywidgets as widgets
from IPython.display import display, clear_output

# Two output areas: one receives the printed results, the other hosts the
# input controls.
context_output, button_area = widgets.Output(), widgets.Output()

# Index of the sample currently being shown.
current_sample = 0

# Prompt text placed before the user-supplied description. Written as explicit
# pieces so the leading newline and the trailing '" ' + newline are visible.
# NOTE: these two assignments replace the PROMPT_PREFIX / PROMPT_SUFFIX names
# imported from eli.encoder in the first cell.
PROMPT_PREFIX = (
    "\n"
    "User: Your task is to predict what another LLM will say, given the following\n"
    'description of what the LLM is currently thinking: " \n'
)

# Prompt text placed after the user-supplied description; closes the quote.
PROMPT_SUFFIX = (
    '". Provide your prediction and nothing else.\n'
    "Assistant:\n"
)

def assemble_decoder_context_without_virtual(text, sample_idx=0):
    """Score a plain-text description against one stored sample.

    Builds the token sequence ``[prefix] + [text] + [suffix] + [target tokens]``,
    runs it through ``model`` and compares the last
    ``cfg.decoder_pred_len_toks`` logit positions with the stored target logits
    of the same sample via ``kl_div``. The description is ordinary tokenized
    text; no virtual embeddings are involved.

    Args:
        text: User-provided description inserted between ``PROMPT_PREFIX`` and
            ``PROMPT_SUFFIX``. May be empty.
        sample_idx: Row of ``target_generated_tokens`` / ``target_logits`` to
            evaluate against.

    Returns:
        dict with the keys ``input_tokens`` (full context token ids),
        ``context_decoded`` and ``target_decoded`` (decoded strings),
        ``prediction_loss`` (Python float from ``kl_div(...).item()``),
        ``sample_tokens`` (the target token ids), and the segment lengths
        ``prefix_len``, ``user_text_len`` and ``suffix_len``.

    Raises:
        IndexError: If ``sample_idx`` does not select a row of the loaded data.
    """

    def _encode(fragment, add_special_tokens):
        # Force an integer dtype: an empty string may tokenize to a zero-length
        # tensor that is not integer typed, and torch.cat would then promote
        # the whole context away from integer token ids.
        ids = tokenizer(
            fragment, return_tensors="pt", add_special_tokens=add_special_tokens
        ).input_ids
        return ids.to(cfg.device, dtype=torch.long)

    with torch.no_grad(), torch.autocast(device_type=cfg.device.type, dtype=cfg.dtype):
        # A one-row slice keeps the batch dimension; it is empty when the index
        # is out of range, which is reported here instead of failing later.
        sample_tokens = target_generated_tokens[sample_idx:sample_idx + 1]
        if sample_tokens.shape[0] == 0:
            raise IndexError(f"sample_idx {sample_idx} does not select a sample")
        sample_tokens = sample_tokens.to(cfg.device)
        sample_logits = target_logits[sample_idx:sample_idx + 1].to(
            cfg.device, dtype=torch.float32
        )

        # Only the prefix is tokenized with special tokens enabled; the user
        # text and suffix continue the same sequence and get none.
        prefix_tokens = _encode(PROMPT_PREFIX, True)
        user_text_tokens = _encode(text, False)
        suffix_tokens = _encode(PROMPT_SUFFIX, False)

        # Every part has exactly one row, so they concatenate along the
        # sequence dimension directly.
        # Format: [prefix] + [user_text] + [suffix] + [target_tokens]
        input_tokens = torch.cat(
            [prefix_tokens, user_text_tokens, suffix_tokens, sample_tokens],
            dim=1,
        )

        # No attention mask is passed: there is no padding in this context.
        decoder_output = model(input_ids=input_tokens)

        # Keep the trailing positions that are compared with the stored logits.
        decoder_logits_target = decoder_output.logits[:, -cfg.decoder_pred_len_toks:, :]

        # NOTE(review): argument order follows the existing call; confirm
        # against kl_div's definition in eli.encoder.
        prediction_loss = kl_div(decoder_logits_target, sample_logits).item()

        return {
            "input_tokens": input_tokens,
            "context_decoded": tokenizer.decode(input_tokens[0]),
            "target_decoded": tokenizer.decode(sample_tokens[0]),
            "prediction_loss": prediction_loss,
            "sample_tokens": sample_tokens,
            "prefix_len": prefix_tokens.shape[1],
            "user_text_len": user_text_tokens.shape[1],
            "suffix_len": suffix_tokens.shape[1],
        }

def display_context_and_results(results):
    """Render one evaluation result into ``context_output``.

    Replaces whatever the output area currently shows with the sample
    position, the text currently held by ``text_input``, the prediction loss,
    the token count of each context segment, the decoded full context and the
    decoded target tokens.

    Args:
        results: Dict as returned by ``assemble_decoder_context_without_virtual``.
    """
    separator = "-" * 80

    with context_output:
        context_output.clear_output(wait=True)

        # Header: which sample, which text, and the resulting loss.
        print(f"Sample {current_sample+1}/{batch_size}")
        print(f'User text: "{text_input.value}"')
        print(f"Prediction loss: {results['prediction_loss']:.6f}")

        # Per-segment token counts, in the order the segments are concatenated.
        print("\nContext Structure:")
        segment_counts = (
            ("Prefix", results["prefix_len"]),
            ("User text", results["user_text_len"]),
            ("Suffix", results["suffix_len"]),
            ("Target", results["sample_tokens"].shape[1]),
            ("Total", results["input_tokens"].shape[1]),
        )
        for label, count in segment_counts:
            print(f"{label} tokens: {count} tokens")

        # Full decoded context between two horizontal rules.
        print("\nAssembled Context (decoded):")
        print(separator)
        print(results["context_decoded"])
        print(separator)

        # The target portion on its own.
        print("\nTarget tokens (what the model should predict):")
        print(results["target_decoded"])

# Free-text box holding the description that is inserted into the prompt.
text_input = widgets.Textarea(
    value="The model is thinking about the relationship between cause and effect.",
    placeholder="Enter text to use as context...",
    description="Input Text:",
    disabled=False,
    layout=widgets.Layout(width="100%", height="100px"),
)

# The two action buttons differ only in label, style and tooltip.
run_button, next_button = (
    widgets.Button(
        description=label,
        disabled=False,
        button_style=style,
        tooltip=tooltip,
    )
    for label, style, tooltip in (
        ("Run Test", "primary", "Run the test with the provided text"),
        ("Next Sample", "info", "Move to the next sample"),
    )
)

def on_run_button_clicked(b):
    """Click handler for the run button: evaluate the current sample.

    Prints a short progress line into ``context_output``, then scores the text
    currently in ``text_input`` against the sample at ``current_sample`` and
    renders the result.

    Args:
        b: Argument supplied by the button's ``on_click`` callback; unused.
    """
    with context_output:
        print("Running test...")

    display_context_and_results(
        assemble_decoder_context_without_virtual(text_input.value, current_sample)
    )

def on_next_button_clicked(b):
    """Click handler for the next button: advance one sample and evaluate it.

    When the current sample is already the last one, the output area is
    replaced by an end-of-batch notice and ``current_sample`` is left as is.

    Args:
        b: Argument supplied by the button's ``on_click`` callback; unused.
    """
    global current_sample

    # Already on the last sample: report it and stop.
    if current_sample >= batch_size - 1:
        with context_output:
            context_output.clear_output(wait=True)
            print("End of batch reached!")
        return

    current_sample += 1
    display_context_and_results(
        assemble_decoder_context_without_virtual(text_input.value, current_sample)
    )

# Wire each button to its click handler.
run_button.on_click(on_run_button_clicked)
next_button.on_click(on_next_button_clicked)

# Controls: text box on top, the two buttons side by side underneath,
# rendered inside the dedicated controls output area.
controls = widgets.VBox(
    children=[text_input, widgets.HBox(children=[run_button, next_button])]
)
with button_area:
    display(controls)

# Controls first, results area below them.
display(button_area, context_output)

# Show a first result right away using the default text, if there is any data.
if batch_size <= 0:
    with context_output:
        print("No samples in batch!")
else:
    results = assemble_decoder_context_without_virtual(text_input.value, current_sample)
    display_context_and_results(results)

Output()

Output()