In [1]:
import importlib

import ipywidgets as widgets
import numpy as np
import torch
from IPython.display import clear_output, display
from transformers import AutoTokenizer

import eli.encoder

importlib.reload(eli.encoder)
from einops import einsum

from eli.config import cfg, encoder_cfg
from eli.encoder import (
    PROMPT_DECODER,
    Encoder,
    EncoderDecoder,
    EncoderTrainer,
    calculate_target_prediction_loss,
    get_embeddings_from_decoder,
)

In [2]:
# Load encoder and decoder
cfg.buffer_size_samples = cfg.target_model_batch_size_samples = (
    cfg.train_batch_size_samples
)

tokenizer = AutoTokenizer.from_pretrained(cfg.decoder_model_name)
tokenizer.pad_token = tokenizer.eos_token

encoder_decoder = EncoderDecoder(cfg, encoder_cfg, tokenizer).to(cfg.device)

encoder = Encoder(cfg, encoder_cfg).to(cfg.device)

encoder_path = "saved_models/encoder.pt"
encoder.load_state_dict(torch.load(encoder_path))

encoder_decoder.encoder = encoder

encoder_decoder = encoder_decoder.eval()

In [3]:
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(cfg.target_model_name)

In [4]:
# Load eval data

from eli.data import DataCollector

cfg.use_data_collector_workers = False

print("Initializing data collector")
data_collector = DataCollector(cfg)

print("Collecting data")
data_collector.collect_data()

data = data_collector.data

Initializing data collector


INFO:root:target_generated_tokens size: 1024 bytes (0.00 MB)
INFO:root:target_acts size: 786432 bytes (0.75 MB)
INFO:root:input_tokens size: 65536 bytes (0.06 MB)
INFO:root:Total shared memory size: 0.00 GB


Collecting data


Resolving data files:   0%|          | 0/1024 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/1024 [00:00<?, ?it/s]

INFO:root:Tokenize and concatenate called
INFO:root:Full text length: 43727350
Token indices sequence length is longer than the specified maximum sequence length for this model (488415 > 1024). Running this sequence through the model will result in indexing errors
INFO:root:Num tokens: 9655191
  return torch.tensor(value, **{**default_dtype, **self.torch_tensor_kwargs})
INFO:root:Processing data directly on cuda without workers
INFO:root:Processing chunk 0:256 on cuda
INFO:root:Processing batch 0:256


Loaded pretrained model gpt2 into HookedTransformer
Moving model to device:  cuda


INFO:root:Direct data processing completed


In [15]:
# Print loss statistics
target_generated_tokens = data["target_generated_tokens"]
target_acts = data["target_acts"]

buffer_size = target_acts.shape[0]
batch_size = cfg.train_batch_size_samples
num_batches = buffer_size // batch_size

target_prediction_losses = []
dinalar_losses = []

def recode_and_strip(tokens, tokenizer):
    decoded = tokenizer.batch_decode(
        sequences=tokens,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )
    decoded = [r.strip() for r in decoded]
    encoded = tokenizer(
        decoded, add_special_tokens=False, return_tensors="pt", padding=True
    )
    output_tokens = encoded.input_ids.long()
    attention_mask = encoded.attention_mask
    return output_tokens, attention_mask

for batch_idx in range(num_batches):
    start_idx = batch_idx * batch_size
    end_idx = start_idx + batch_size

    # Extract batch data and move to device
    batch_tokens = target_generated_tokens[start_idx:end_idx].to(cfg.device)
    batch_acts = target_acts[start_idx:end_idx].to(cfg.device)

    batch_tokens, attention_mask = recode_and_strip(batch_tokens, tokenizer)
    batch_tokens = batch_tokens.to(cfg.device)

    print(batch_tokens.shape, attention_mask.shape)

    attention_mask = attention_mask.to(cfg.device)

    loss = EncoderTrainer.loss(
        cfg, encoder_decoder, batch_tokens, attention_mask, batch_acts, tokenizer, train_iter=-1
    )
    target_prediction_losses.append(loss.item())

print(f"Target prediction loss: {np.mean(target_prediction_losses)}")

torch.Size([256, 3]) torch.Size([256, 3])
Target prediction loss: 3.6078619956970215


In [17]:
from jaxtyping import Float
from torch import Tensor

# Create output widgets for displaying sample information
sample_output = widgets.Output()
button_output = widgets.Output()

# Create a counter and button
current_sample = 0

def get_distances_to_embeddings(
    embeddings: Float[Tensor, "batch tok d_embed"],
    target_embeddings: Float[Tensor, "vocab d_embed"],
) -> Float[Tensor, "batch tok vocab"]:
    """Computes the L2 distance between each token embedding and each target embedding."""
    # embeddings shape: [batch, tok, d_embed]
    # target_embeddings shape: [vocab, d_embed]

    # Reshape embeddings to [batch*tok, d_embed]
    batch_size, seq_len, d_embed = embeddings.shape
    embeddings_flat = embeddings.reshape(-1, d_embed)

    # Compute pairwise distances between all embeddings and target embeddings
    # Returns tensor of shape [batch*tok, vocab]
    distances = torch.cdist(embeddings_flat, target_embeddings, p=2)

    # Reshape back to [batch, tok, vocab]
    return distances.reshape(batch_size, seq_len, -1)

def on_button_click(b):
    global current_sample
    if current_sample < batch_size:
        display_sample(current_sample)
        current_sample += 1
    else:
        with sample_output:
            print("End of batch reached!")


def create_table(title, headers, rows, col_widths=None):
    """Helper function to create formatted tables

    Args:
        title: Table title string
        headers: List of header strings
        rows: List of rows, where each row is a list of values
        col_widths: List of column widths (defaults to 15 for all columns)

    Returns:
        Formatted table string
    """
    if col_widths is None:
        col_widths = [15] * len(headers)

    # Ensure first column width accommodates row labels
    col_widths[0] = max(col_widths[0], 8)

    # Create table string
    table = f"{title}\n"

    # Create header
    header_row = headers[0].ljust(col_widths[0])
    for i, header in enumerate(headers[1:], 1):
        header_row += header.ljust(col_widths[i])
    table += header_row + "\n"

    # Add separator
    table += "-" * len(header_row) + "\n"

    # Add rows
    for row in rows:
        row_str = str(row[0]).ljust(col_widths[0])
        for i, cell in enumerate(row[1:], 1):
            row_str += str(cell).ljust(col_widths[i])
        table += row_str + "\n"

    return table


def shuffle_data(tokens, acts, seed=None):
    """
    Shuffle the tokens and acts tensors in the same order.

    Args:
        tokens: Tensor of token IDs
        acts: Tensor of activations
        seed: Optional random seed for reproducibility

    Returns:
        Tuple of (shuffled_tokens, shuffled_acts)
    """
    if seed is not None:
        torch.manual_seed(seed)

    # Get the number of samples
    num_samples = tokens.size(0)

    # Generate random permutation indices
    indices = torch.randperm(num_samples)

    # Shuffle both tensors using the same indices
    shuffled_tokens = tokens[indices]
    shuffled_acts = acts[indices]

    return shuffled_tokens, shuffled_acts


# Function to display a single sample
def display_sample(sample_idx):
    with torch.no_grad():
        with torch.autocast(device_type=cfg.device.type, dtype=cfg.dtype):
            # Extract single sample as a "batch" of size 1
            sample_tokens = target_generated_tokens[sample_idx : sample_idx + 1].to(
                cfg.device, dtype=torch.long
            )
            sample_tokens, attention_mask = recode_and_strip(sample_tokens, tokenizer)
            sample_tokens = sample_tokens.to(cfg.device)
            attention_mask = attention_mask.to(cfg.device)
            sample_acts = target_acts[sample_idx : sample_idx + 1].to(cfg.device)

            # encoder = encoder_decoder.encoder
            # encoder_output_logits = encoder(sample_acts) # [batch tok vocab]

            # Convert logits to one-hot-like by making the max value very large and others small
            # max_values, max_indices = torch.max(encoder_output_logits, dim=-1, keepdim=True)
            # one_hot_logits = torch.ones_like(encoder_output_logits) * -100.0  # Set all values to a small number
            # one_hot_logits.scatter_(dim=-1, index=max_indices, value=100.0)   # Set max values to a large number
            # encoder_output_logits = one_hot_logits

            # Get model outputs for this single sample
            (decoder_logits_target, decoder_logits_encoding, virtual_embs) = (
                encoder_decoder(sample_acts, sample_tokens, attention_mask, train_iter=-1)
            )

            # Calculate losses using existing functions
            pred_loss = calculate_target_prediction_loss(
                decoder_logits_target, sample_tokens, tokenizer
            ).item()
            # din_loss = calculate_dinalar_loss(
            #     decoder_logits_encoding,
            #     encoder_output_logits,
            # ).item()

            embeddings = get_embeddings_from_decoder(encoder_decoder.decoder).weight
            distances = get_distances_to_embeddings(
                virtual_embs, embeddings
            )  # [batch tok vocab]

            # print("distances.shape", distances.shape)

            # Get top 3 tokens by encoder output logits
            top_k = 5
            top_values, top_indices = torch.topk(
                distances[0], k=top_k, dim=-1
            )  # [batch tok]

            print(top_indices)
            print(top_indices.shape)

            print("top_values.shape", top_values.shape)

            # Decode tokens for display
            sample_decoded = tokenizer.decode(sample_tokens[0])

            # Display results
            with sample_output:
                sample_output.clear_output(wait=True)
                print(f"Sample {sample_idx+1}/{batch_size}")
                print(f"Target prediction loss: {pred_loss:.6f}")
                print("\nTarget tokens:")
                print(sample_decoded, "\n")

                prompt_prefix, prompt_suffix = PROMPT_DECODER.split("<thought>")
                # Also decode and display the prefix and suffix tokens
                prefix_tokens = tokenizer(prompt_prefix, return_tensors="pt").input_ids[
                    0
                ]
                suffix_tokens = tokenizer(prompt_suffix, return_tensors="pt").input_ids[
                    0
                ]

                print(tokenizer.decode(prefix_tokens))

                # Prepare data for the tokens table
                col_width = 15
                headers = ["Token"] + [f"Emb {i}" for i in range(virtual_embs.shape[1])]

                token_rows = []
                for k in range(top_k):
                    row = [f"Top {k+1}:"]
                    for j in range(virtual_embs.shape[1]):
                        token_id = top_indices[j, k].item()
                        token_text = tokenizer.decode([token_id])
                        # Replace newlines and tabs for cleaner display
                        token_text = token_text.replace("\n", "\\n").replace(
                            "\t", "\\t"
                        )
                        # Truncate to fit in column
                        token_display = token_text[: col_width - 2]
                        row.append(token_display)
                    token_rows.append(row)

                # Create and display the token table
                token_table = create_table(
                    "", headers, token_rows, [8] + [col_width] * virtual_embs.shape[1]
                )
                print(token_table)

                # Create table for logit values
                logit_rows = []
                for k in range(top_k):
                    row = [f"Top {k+1}:"]
                    for j in range(virtual_embs.shape[1]):
                        # Format logit value with 5 decimal places
                        prob = top_values[j, k].item()
                        row.append(f"{prob:.5f}")
                    logit_rows.append(row)

                # Create and display the logits table
                logit_table = create_table(
                    "Logit Values:",
                    headers,
                    logit_rows,
                    [8] + [col_width] * virtual_embs.shape[1],
                )
                print(logit_table)

                print(tokenizer.decode(suffix_tokens))


# Interactive sample investigation
next_button = widgets.Button(description="Next Sample")
next_button.on_click(on_button_click)

# Display the button and sample output in separate areas
with button_output:
    display(next_button)
display(button_output)
display(sample_output)

# Show the first sample
if batch_size > 0:
    display_sample(current_sample)
    current_sample += 1
else:
    with sample_output:
        print("No samples in batch!")

Output()

Output()

tensor([[46136, 45348, 16207, 18658, 17773]], device='cuda:0')
torch.Size([1, 5])
top_values.shape torch.Size([1, 5])


tensor([[36473, 46136, 17576, 21364, 28599]], device='cuda:0')
torch.Size([1, 5])
top_values.shape torch.Size([1, 5])
tensor([[17773, 15838, 39021, 41407, 36622]], device='cuda:0')
torch.Size([1, 5])
top_values.shape torch.Size([1, 5])
tensor([[41407, 36473, 14369, 23785, 17773]], device='cuda:0')
torch.Size([1, 5])
top_values.shape torch.Size([1, 5])
tensor([[15838, 11411, 23785, 16207, 45348]], device='cuda:0')
torch.Size([1, 5])
top_values.shape torch.Size([1, 5])
tensor([[41407, 36622, 30202, 41230, 45999]], device='cuda:0')
torch.Size([1, 5])
top_values.shape torch.Size([1, 5])
tensor([[23785, 15838, 36473, 14369, 36622]], device='cuda:0')
torch.Size([1, 5])
top_values.shape torch.Size([1, 5])
tensor([[29823, 48382,  6909, 45260, 19593]], device='cuda:0')
torch.Size([1, 5])
top_values.shape torch.Size([1, 5])
tensor([[23785, 50179, 11411, 15838, 39021]], device='cuda:0')
torch.Size([1, 5])
top_values.shape torch.Size([1, 5])
tensor([[23785, 46659, 28420, 46136, 17773]], device='cu

In [None]:
# text = "Bob"
# tokens = tokenizer(text, add_special_tokens=False, return_tensors="pt")


# Print each token separately
# print("Tokens for text:", text)



In [47]:
model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

text = "Alice jacobson was a good person. Alice "
tokens = tokenizer(text, add_special_tokens=True, return_tensors="pt")
print(tokens)
print(tokens.input_ids)
tokens = torch.cat([torch.tensor([[tokenizer.bos_token_id]]), tokens.input_ids], dim=1)

# Get model outputs
outputs = model(tokens)
logits = outputs.logits

# Get the logits for the last token
last_token_logits = logits[0, -1, :]
print(last_token_logits.shape)

# Get the top 3 token indices and their corresponding logits
k = 10
top_values, top_indices = torch.topk(last_token_logits, k)

# Print the top 3 tokens and their logits
print(f"Top {k} tokens by logit:")
for i, (token_id, logit_value) in enumerate(zip(top_indices, top_values)):
    token_text = tokenizer.decode([token_id.item()])
    print(
        f"{i+1}. Token: '{token_text}', ID: {token_id.item()}, Logit: {logit_value.item():.4f}"
    )

{'input_ids': tensor([[44484,   474,   330,   672,  1559,   373,   257,   922,  1048,    13,
         14862,   220]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}
tensor([[44484,   474,   330,   672,  1559,   373,   257,   922,  1048,    13,
         14862,   220]])
torch.Size([50257])
Top 10 tokens by logit:
1. Token: ' ', ID: 1849, Logit: -46.7644
2. Token: 'iced', ID: 3711, Logit: -47.3270
3. Token: 'ich', ID: 488, Logit: -48.3128
4. Token: 'ive', ID: 425, Logit: -48.6372
5. Token: 'irl', ID: 1901, Logit: -48.7395
6. Token: 'ix', ID: 844, Logit: -49.0847
7. Token: 'iz', ID: 528, Logit: -49.1540
8. Token: 'izzy', ID: 40593, Logit: -49.2239
9. Token: 'ike', ID: 522, Logit: -49.2902
10. Token: '________', ID: 2602, Logit: -49.3509
