In [1]:
import ipywidgets as widgets
import numpy as np
import torch
from IPython.display import clear_output, display
from transformers import AutoTokenizer

from eli.config import cfg, encoder_cfg
from eli.encoder import (
    PROMPT_PREFIX,
    PROMPT_SUFFIX,
    Encoder,
    EncoderDecoder,
    EncoderTrainer,
    calculate_dinalar_loss,
    get_embeddings_from_decoder,
    kl_div,
)

cfg.buffer_size_samples = cfg.target_model_batch_size_samples = cfg.train_batch_size_samples

ImportError: cannot import name 'PROMPT_PREFIX' from 'eli.encoder' (/root/eli/src/eli/encoder.py)

In [2]:
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(cfg.target_model_name).to(cfg.device)
tokenizer = AutoTokenizer.from_pretrained(cfg.target_model_name)

In [3]:
# Load eval data

from eli.data import DataCollector

cfg.use_data_collector_workers = False

print("Initializing data collector")
data_collector = DataCollector(cfg)

print("Collecting data")
data_collector.collect_data()

data = data_collector.data

target_generated_tokens = data["target_generated_tokens"]

batch_size = target_generated_tokens.shape[0]

Initializing data collector


INFO:root:target_generated_tokens size: 16384 bytes (0.02 MB)
INFO:root:target_acts size: 262144 bytes (0.25 MB)
INFO:root:input_tokens size: 65536 bytes (0.06 MB)
INFO:root:Total shared memory size: 0.00 GB


Collecting data


Resolving data files:   0%|          | 0/1024 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/1024 [00:00<?, ?it/s]

INFO:root:Tokenize and concatenate called
INFO:root:Full text length: 43727350
INFO:root:Num tokens: 9672977
  return torch.tensor(value, **{**default_dtype, **self.torch_tensor_kwargs})
INFO:root:Processing data directly on cuda without workers
INFO:root:Processing chunk 0:256 on cuda
INFO:root:Processing batch 0:256


Loaded pretrained model EleutherAI/pythia-31m into HookedTransformer
Moving model to device:  cuda


INFO:root:CHUNK 0:256 COMPLETED, Max GPU memory allocated: 2.14 GB, Max GPU memory reserved: 2.39 GB
INFO:root:Direct data processing completed


In [11]:
# target_generated_tokens = tokenizer(". My style was great. How did it happen?", return_tensors="pt").input_ids.to(cfg.device)
# print(target_generated_tokens)
target_generated_tokens = torch.tensor([[  15, 2752, 3740,  369, 1270,   15, 1359,  858,  352, 5108,   32,  187,
           42,  812, 1611,  562]])

In [15]:
import ipywidgets as widgets
from IPython.display import display, clear_output

# Create output widgets
context_output = widgets.Output()
button_area = widgets.Output()

# Create sample counter
current_sample = 0

from eli.encoder import PROMPT_PREFIX, PROMPT_SUFFIX, prepend_bos_token

def assemble_decoder_context_without_virtual(text, sample_idx=0):
    """Assemble decoder context using only user-provided text (no virtual embeddings)"""
    with torch.no_grad():
        with torch.autocast(device_type=cfg.device.type, dtype=cfg.dtype):
            # Get a sample from the data
            sample_tokens = target_generated_tokens[sample_idx:sample_idx+1].to(cfg.device)
            print("Sample tokens shape:", sample_tokens.shape)
            
            # Tokenize the user text
            # user_text_tokens = tokenizer(text, return_tensors="pt", add_special_tokens=False).input_ids.to(cfg.device)
            user_text_tokens = torch.tensor([[3, 2, 3, 1]], device=cfg.device)
            print("User text tokens shape:", user_text_tokens.shape)
            
            # Generate tokens for prompt components
            prefix_tokens = tokenizer(PROMPT_PREFIX, return_tensors="pt", add_special_tokens=False).input_ids.to(cfg.device)
            prefix_tokens = prepend_bos_token(prefix_tokens, tokenizer)
            suffix_tokens = tokenizer(PROMPT_SUFFIX, return_tensors="pt", add_special_tokens=False).input_ids.to(cfg.device)
            
            # Repeat for batch size
            batch_size = sample_tokens.shape[0]
            prefix_tokens = prefix_tokens.repeat(batch_size, 1)
            suffix_tokens = suffix_tokens.repeat(batch_size, 1)
            user_text_tokens = user_text_tokens.repeat(batch_size, 1)
            
            # Concatenate all tokens to create the full context
            # Format: [prefix] + [user_text] + [suffix] + [target_tokens]
            input_tokens = torch.cat([
                prefix_tokens, 
                user_text_tokens, 
                suffix_tokens, 
                sample_tokens
            ], dim=1)
            
            # Get embeddings from decoder
            embeddings = model.get_input_embeddings()
            input_embeds = embeddings(input_tokens)
            
            # Create attention mask (all 1s since we're not using padding)
            attention_mask = torch.ones(
                input_embeds.shape[0],
                input_embeds.shape[1],
                device=input_embeds.device,
            )
            
            # Get decoder output
            decoder_output = model(
                input_ids=input_tokens, 
                # attention_mask=attention_mask
            )
            
            # Extract target logits for prediction loss calculation
            decoder_logits_target = decoder_output.logits[:, -sample_tokens.shape[1] - 1: -1, :]
            print(decoder_logits_target.shape)
            
            # Calculate per-token cross entropy loss
            per_token_ce = []
            overall_ce_loss = 0.0
            valid_token_count = 0
            
            for i in range(sample_tokens.shape[1]):
                # Calculate cross entropy for this position
                ce_loss = torch.nn.functional.cross_entropy(
                    decoder_logits_target[:, i], 
                    sample_tokens[:, i].long(),
                    reduction='mean'
                )
                per_token_ce.append(ce_loss.item())
                overall_ce_loss += ce_loss.item()
                valid_token_count += 1
            
            # Calculate average cross entropy loss
            overall_ce_loss = overall_ce_loss / valid_token_count if valid_token_count > 0 else float('nan')
            
            # Decode tokens for display
            context_decoded = tokenizer.decode(input_tokens[0])
            target_decoded = tokenizer.decode(sample_tokens[0])
            
            # Return results for display
            return {
                "input_tokens": input_tokens,
                "context_decoded": context_decoded,
                "target_decoded": target_decoded,
                "per_token_ce": per_token_ce,  # Add per-token cross entropy
                "overall_ce_loss": overall_ce_loss,  # Add overall cross entropy loss
                "sample_tokens": sample_tokens,
                "prefix_len": prefix_tokens.shape[1],
                "user_text_len": user_text_tokens.shape[1],
                "suffix_len": suffix_tokens.shape[1],
                "decoder_logits": decoder_logits_target[0],
            }

def display_context_and_results(results):
    with context_output:
        context_output.clear_output(wait=True)
        
        print(f"Sample {current_sample+1}/{batch_size}")
        print(f"User text: \"{text_input.value}\"")
        print(f"Overall cross entropy loss: {results['overall_ce_loss']:.6f}")
        
        # Display token lengths
        print("\nContext Structure:")
        print(f"Prefix tokens: {results['prefix_len']} tokens")
        print(f"User text tokens: {results['user_text_len']} tokens")
        print(f"Suffix tokens: {results['suffix_len']} tokens")
        print(f"Target tokens: {results['sample_tokens'].shape[1]} tokens")
        print(f"Total tokens: {results['input_tokens'].shape[1]} tokens")
        
        # Display the entire assembled context
        print("\nAssembled Context (decoded):")
        print("-" * 80)
        print(results["context_decoded"])
        print("-" * 80)
        
        # Display the target part with both KL divergence and cross entropy
        print("\nToken-by-token metrics:")
        print("(Lower values indicate better predictions)")
        
        # Create color representation
        target_tokens = results['sample_tokens'][0]
        per_token_ce = results['per_token_ce']
        
        # Create a color scale from blue (good) to red (bad)
        def get_color_for_value(value, max_value=5.0):
            import math
            if math.isnan(value):
                return "\033[90m"  # gray for NaN
            
            # Transform value to a 0-1 scale with a logarithmic mapping
            # Values > max_value will be bright red, values near 0 will be bright blue
            normalized = min(1.0, max(0.0, value / max_value))
            
            # RGB interpolation from blue (0,0,255) to red (255,0,0)
            if normalized < 0.5:
                # Blue to purple
                r = int(255 * (normalized * 2))
                g = 0
                b = 255
            else:
                # Purple to red
                r = 255
                g = 0
                b = int(255 * (1 - (normalized - 0.5) * 2))
            
            return f"\033[38;2;{r};{g};{b}m"
        
        # Reset color code
        reset_color = "\033[0m"
        
        # Display tokens with their KL and CE values, colorized
        print("       Token         |  KL Div  |  CE Loss ")
        print("-" * 45)
        
        for i, token_id in enumerate(target_tokens):
            token_str = tokenizer.decode([token_id.item()])
            ce_value = per_token_ce[i] if i < len(per_token_ce) else float('nan')
            
            # Format token string for display (handle whitespace)
            token_repr = repr(token_str)
            # Pad for alignment
            padded_token = token_repr.ljust(16)
            
            ce_color = get_color_for_value(ce_value)
            
            print(f"{i:2d}:  {padded_token} | {ce_color}{ce_value:8.4f}{reset_color}")
        
        # Also show overall prediction
        print("\nTarget tokens (what the model should predict):")
        print(results["target_decoded"])

# Create widgets
text_input = widgets.Textarea(
    value="The model is thinking about the relationship between cause and effect.",
    placeholder="Enter text to use as context...",
    description="Input Text:",
    disabled=False,
    layout=widgets.Layout(width="100%", height="100px")
)

run_button = widgets.Button(
    description='Run Test',
    disabled=False,
    button_style='primary', 
    tooltip='Run the test with the provided text'
)

next_button = widgets.Button(
    description='Next Sample',
    disabled=False,
    button_style='info',
    tooltip='Move to the next sample'
)

def on_run_button_clicked(b):
    with context_output:
        print("Running test...")
    results = assemble_decoder_context_without_virtual(text_input.value, current_sample)
    display_context_and_results(results)

def on_next_button_clicked(b):
    global current_sample
    if current_sample < batch_size - 1:
        current_sample += 1
        results = assemble_decoder_context_without_virtual(text_input.value, current_sample)
        display_context_and_results(results)
    else:
        with context_output:
            context_output.clear_output(wait=True)
            print("End of batch reached!")

# Connect the buttons to handlers
run_button.on_click(on_run_button_clicked)
next_button.on_click(on_next_button_clicked)

# Layout
controls = widgets.VBox([text_input, widgets.HBox([run_button, next_button])])
with button_area:
    display(controls)
    
display(button_area)
display(context_output)


# Run initial test with default text
if batch_size > 0:
    results = assemble_decoder_context_without_virtual(text_input.value, current_sample)
    display_context_and_results(results)


else:
    with context_output:
        print("No samples in batch!")


Output()

Output()

Sample tokens shape: torch.Size([1, 16])
User text tokens shape: torch.Size([1, 4])
torch.Size([1, 16, 50304])


In [26]:
def analyze_logits(logits, tokens, tokenizer, top_k=3, title="Analyzing logits"):
    """
    Analyze logits for a given sequence of tokens.
    
    Args:
        logits: Tensor of logits to analyze (shape: [seq_len, vocab_size])
        tokens: Tensor of token IDs corresponding to the actual tokens (shape: [seq_len])
        tokenizer: Tokenizer to decode tokens
        top_k: Number of top tokens to show for each position
        title: Title for the analysis output
    """
    # Get the true tokens for reference
    true_tokens = tokens.tolist() if isinstance(tokens, torch.Tensor) else tokens
    
    print(f"\n{title}:")
    print("=" * 80)
    print(f"Full token sequence: {tokenizer.decode(tokens)}")
    print("=" * 80)
    print(f"Sequence length: {len(true_tokens)} tokens")
    print("-" * 80)
    
    # For each position, find the top-k tokens by logit value
    for pos in range(min(len(true_tokens), logits.shape[0])):
        logits_at_pos = logits[pos]
        
        # Convert logits to probabilities using softmax
        probs = torch.nn.functional.softmax(logits_at_pos, dim=0)
        
        # Get top k tokens
        top_values, top_indices = torch.topk(probs, top_k)
        
        # Get the actual token at this position
        actual_token_id = true_tokens[pos]
        actual_token_str = tokenizer.decode([actual_token_id])
        actual_prob = probs[actual_token_id].item()
        
        # Display results
        print(f"\nPosition {pos} - Actual token: '{actual_token_str}' (ID: {actual_token_id}, Prob: {actual_prob:.4f})")
        print("-" * 50)
        
        for i, (token_id, prob) in enumerate(zip(top_indices.tolist(), top_values.tolist())):
            token_str = tokenizer.decode([token_id])
            print(f"  {i+1}. '{token_str}' (ID: {token_id}, Prob: {prob:.4f})")

def analyze_target_logits(sample_idx=0, top_k=3):
    """Analyze the target logits for a specific sample."""
    # Get target tokens and logits for the specified sample
    target_tokens_sample = target_generated_tokens[sample_idx].to(cfg.device)
    target_logits_sample = target_logits[sample_idx].to(cfg.device)
    
    analyze_logits(
        target_logits_sample, 
        target_tokens_sample, 
        tokenizer, 
        top_k=top_k, 
        title=f"Target logits for sample {sample_idx}"
    )

def analyze_decoder_logits(text, sample_idx=0, top_k=3):
    """Analyze the decoder logits for a specific input text and sample."""
    with torch.no_grad():
        with torch.autocast(device_type=cfg.device.type, dtype=cfg.dtype):
            # Get current target tokens
            target_tokens = target_generated_tokens[sample_idx:sample_idx+1].to(cfg.device)
            
            # Build the full context
            prefix_tokens = tokenizer(PROMPT_PREFIX, return_tensors="pt", add_special_tokens=False).input_ids.to(cfg.device)
            user_text_tokens = tokenizer(text, return_tensors="pt", add_special_tokens=False).input_ids.to(cfg.device)
            suffix_tokens = tokenizer(PROMPT_SUFFIX, return_tensors="pt", add_special_tokens=False).input_ids.to(cfg.device)
            
            # Concatenate all parts except the target tokens (we'll predict them)
            input_tokens = torch.cat([
                prefix_tokens, 
                user_text_tokens, 
                suffix_tokens
            ], dim=1)
            
            # Get the starting position for prediction
            start_pos = input_tokens.shape[1]
            
            # Run the model to get decoder output
            full_tokens = torch.cat([input_tokens, target_tokens], dim=1)
            decoder_output = model(input_ids=full_tokens)
            
            # Extract logits for the positions where target tokens should be predicted
            decoder_logits = decoder_output.logits[0, start_pos-1:start_pos-1+len(target_tokens[0])]
            
            analyze_logits(
                decoder_logits,
                target_tokens[0],
                tokenizer,
                top_k=top_k,
                title=f"Decoder logits for sample {sample_idx} with text: '{text}'"
            )

# Run both analyses
print("Analyzing target logits...")
analyze_target_logits(current_sample)

print("\n\nAnalyzing decoder logits...")
print(text_input.value)
analyze_decoder_logits(text_input.value, current_sample)

Analyzing target logits...

Target logits for sample 0:
Full token sequence:  poet, Robert Frost, wrote this poem in 1936. It is
Sequence length: 15 tokens
--------------------------------------------------------------------------------

Position 0 - Actual token: ' poet' (ID: 40360, Prob: 0.3603)
--------------------------------------------------
  1. ' poet' (ID: 40360, Prob: 0.3603)
  2. ' D' (ID: 423, Prob: 0.0804)
  3. ' Po' (ID: 14128, Prob: 0.0357)

Position 1 - Actual token: ',' (ID: 11, Prob: 0.2068)
--------------------------------------------------
  1. ',' (ID: 11, Prob: 0.2068)
  2. ' and' (ID: 323, Prob: 0.1513)
  3. ' Robert' (ID: 8563, Prob: 0.0862)

Position 2 - Actual token: ' Robert' (ID: 8563, Prob: 0.1292)
--------------------------------------------------
  1. ' Robert' (ID: 8563, Prob: 0.1292)
  2. ' Sylvia' (ID: 89406, Prob: 0.0506)
  3. ' Walt' (ID: 36367, Prob: 0.0506)

Position 3 - Actual token: ' Frost' (ID: 42320, Prob: 0.8940)
-----------------------------

In [14]:
import pandas as pd

# Print the token information as a plain text table
def print_token_table(sample_idx=0):
    # Get components
    prefix_tokens = tokenizer(PROMPT_PREFIX, return_tensors="pt", add_special_tokens=False).input_ids[0]
    user_text_tokens = tokenizer(text_input.value, return_tensors="pt", add_special_tokens=False).input_ids[0]
    suffix_tokens = tokenizer(PROMPT_SUFFIX, return_tensors="pt", add_special_tokens=False).input_ids[0]
    target_tokens = target_generated_tokens[sample_idx]
    
    # Combine all tokens
    all_tokens = torch.cat([prefix_tokens, user_text_tokens, suffix_tokens, target_tokens])
    
    # Create rows for the table
    rows = []
    for i, id in enumerate(all_tokens):
        token_str = tokenizer.decode([id.item()])
        section = "PREFIX" if i < len(prefix_tokens) else \
                 "USER_TEXT" if i < len(prefix_tokens) + len(user_text_tokens) else \
                 "SUFFIX" if i < len(prefix_tokens) + len(user_text_tokens) + len(suffix_tokens) else \
                 "TARGET"
        rows.append({
            "position": i,
            "token_id": id.item(),
            "token_str": repr(token_str),  # Use repr to show whitespace
            "section": section,
            "is_special": id.item() in tokenizer.all_special_ids
        })
    
    # Convert to dataframe and print as string
    df = pd.DataFrame(rows)
    print(df.to_string())
    
    # Also print section counts for quick reference
    print("\nToken count by section:")
    print(f"PREFIX: {len(prefix_tokens)} tokens")
    print(f"USER_TEXT: {len(user_text_tokens)} tokens")
    print(f"SUFFIX: {len(suffix_tokens)} tokens")
    print(f"TARGET: {len(target_tokens)} tokens")
    print(f"TOTAL: {len(all_tokens)} tokens")

# Execute the function to print the table
print_token_table(current_sample)

     position  token_id              token_str    section  is_special
0           0    128000    '<|begin_of_text|>'     PREFIX        True
1           1       271                 '\n\n'     PREFIX       False
2           2    128006  '<|start_header_id|>'     PREFIX       False
3           3      9125               'system'     PREFIX       False
4           4    128007    '<|end_header_id|>'     PREFIX       False
5           5       271                 '\n\n'     PREFIX       False
6           6      2675                  'You'     PREFIX       False
7           7       990                ' work'     PREFIX       False
8           8       439                  ' as'     PREFIX       False
9           9       961                ' part'     PREFIX       False
10         10       315                  ' of'     PREFIX       False
11         11       459                  ' an'     PREFIX       False
12         12       445                   ' L'     PREFIX       False
13         13     11