In [15]:
import ipywidgets as widgets
import numpy as np
import torch
from IPython.display import clear_output, display
from transformers import AutoTokenizer

from eli.config import cfg, encoder_cfg
from eli.encoder import (
    PROMPT_PREFIX,
    PROMPT_SUFFIX,
    Encoder,
    EncoderDecoder,
    EncoderTrainer,
    calculate_dinalar_loss,
    get_embeddings_from_decoder,
    kl_div,
)

# Load encoder and decoder

# Size the data buffer (and the target-model batch) to exactly one training batch.
# NOTE(review): presumably DataCollector reads these cfg fields, so this cell must
# run before the data-collection cell — confirm (the saved execution counts show
# the collection cell last ran before this one).
cfg.buffer_size_samples = cfg.target_model_batch_size_samples = cfg.train_batch_size_samples

tokenizer = AutoTokenizer.from_pretrained(cfg.decoder_model_name)

encoder_decoder = EncoderDecoder(cfg, encoder_cfg, tokenizer).to(cfg.device)

# Load the trained encoder weights and swap them into the encoder-decoder.
# map_location makes the load independent of the device the checkpoint was saved
# on (e.g. a different GPU index, or a GPU checkpoint opened on a CPU-only host).
encoder = Encoder(cfg, encoder_cfg).to(cfg.device)
encoder_path = "encoder-dinalar.pt"
encoder.load_state_dict(torch.load(encoder_path, map_location=cfg.device))

encoder_decoder.encoder = encoder

# Evaluation only from here on: put every submodule in eval mode.
encoder_decoder = encoder_decoder.eval()

In [6]:
# Load eval data: run the collector in-process (no worker processes) and keep
# a handle to the buffers it fills.
from eli.data import DataCollector

print("Initializing data collector")
data_collector = DataCollector(use_workers=False)

print("Collecting data")
data_collector.collect_data()

# Collected tensors, looked up by string key in the cells below.
data = data_collector.data

Initializing data collector


INFO:root:target_generated_tokens size: 30720 bytes (0.03 MB)
INFO:root:target_logits size: 1648361472 bytes (1572.00 MB)
INFO:root:target_acts size: 262144 bytes (0.25 MB)
INFO:root:input_tokens size: 131072 bytes (0.12 MB)
INFO:root:Total shared memory size: 1.54 GB


Collecting data


Resolving data files:   0%|          | 0/1024 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/1024 [00:00<?, ?it/s]

INFO:root:Tokenize and concatenate called
INFO:root:Full text length: 43727350
INFO:root:Num tokens: 9672977
INFO:root:Processing data directly on cuda without workers
INFO:root:Processing chunk 0:512 on cuda
INFO:root:Processing batch 0:512


Loaded pretrained model EleutherAI/pythia-14m into HookedTransformer
Moving model to device:  cuda


INFO:root:CHUNK 0:512 COMPLETED, Max GPU memory allocated: 13.29 GB, Max GPU memory reserved: 19.27 GB
INFO:root:Direct data processing completed


In [16]:
# Print loss statistics
# Average the encoder-decoder losses over every full training-sized batch in the
# collected buffer. This is pure evaluation, so autograd is disabled.
target_generated_tokens = data["target_generated_tokens"]
target_logits = data["target_logits"]
target_acts = data["target_acts"]

buffer_size = target_acts.shape[0]
batch_size = cfg.train_batch_size_samples
# Trailing samples that do not fill a whole batch are skipped.
num_batches = buffer_size // batch_size
assert num_batches > 0, (
    f"Buffer holds {buffer_size} samples, fewer than one batch of {batch_size}"
)

target_prediction_losses = []
dinalar_losses = []

with torch.no_grad():
    for batch_idx in range(num_batches):
        start_idx = batch_idx * batch_size
        end_idx = start_idx + batch_size

        # Extract batch data and move to device (logits upcast to float32)
        batch_tokens = target_generated_tokens[start_idx:end_idx].to(cfg.device)
        batch_logits = target_logits[start_idx:end_idx].to(
            cfg.device, dtype=torch.float32
        )
        batch_acts = target_acts[start_idx:end_idx].to(cfg.device)

        # NOTE(review): the meaning of the trailing -1 argument is defined by
        # EncoderTrainer.loss, which is not visible here — confirm.
        loss, target_prediction_loss, dinalar_loss = EncoderTrainer.loss(
            cfg, encoder_decoder, batch_tokens, batch_logits, batch_acts, -1
        )

        target_prediction_losses.append(target_prediction_loss.item())
        dinalar_losses.append(dinalar_loss.item())

print(f"Target prediction loss: {np.mean(target_prediction_losses)}")
print(f"Dinalar loss: {np.mean(dinalar_losses)}")

Target prediction loss: 0.4922386109828949
Dinalar loss: 1.1478568315505981


In [17]:
# Two output areas: one hosts the "Next Sample" button, the other shows the
# currently selected sample (cleared and redrawn for every sample).
sample_output, button_output = widgets.Output(), widgets.Output()

# Index of the next sample to show; advanced each time a sample is displayed.
current_sample = 0


def on_button_click(b):
    """Click handler for the "Next Sample" button.

    Shows the sample at ``current_sample`` and advances the counter, or
    prints an end-of-batch notice into ``sample_output`` once ``batch_size``
    samples have been shown.

    Args:
        b: The clicked button widget (unused).
    """
    global current_sample
    if current_sample >= batch_size:
        with sample_output:
            print("End of batch reached!")
        return
    display_sample(current_sample)
    current_sample += 1


def create_table(title, headers, rows, col_widths=None):
    """Helper function to create formatted fixed-width tables.

    Every cell is converted with ``str`` and left-justified to its column
    width. No extra separator is inserted between columns, so a cell at least
    as long as its width runs directly into the next one.

    Args:
        title: Table title string, emitted as the first line (an empty title
            yields a blank first line).
        headers: List of header strings; ``headers[0]`` labels the row-label
            column. If empty, only the title line is returned.
        rows: List of rows, where each row is a list of values and ``row[0]``
            is the row label. An empty row produces an empty line.
        col_widths: List of column widths, one per header (defaults to 15 for
            all columns). The first width is raised to at least 8. The
            caller's list is never modified.

    Returns:
        Formatted table string: title, header line, a dashed separator as wide
        as the header line, then one line per row; every line is
        newline-terminated.

    Raises:
        IndexError: If ``headers`` or a row has more entries than there are
            column widths.
    """
    if not headers:
        return f"{title}\n"

    # Work on a copy so widening the label column never mutates the caller's list.
    widths = [15] * len(headers) if col_widths is None else list(col_widths)
    # Ensure first column width accommodates row labels
    widths[0] = max(widths[0], 8)

    def format_line(cells):
        # Left-justify each cell to the width of its column.
        return "".join(str(cell).ljust(widths[i]) for i, cell in enumerate(cells))

    header_row = format_line(headers)
    lines = [title, header_row, "-" * len(header_row)]
    lines.extend(format_line(row) for row in rows)
    return "".join(line + "\n" for line in lines)


# Function to display a single sample


def _unwrap_decoder(model):
    """Return ``model.decoder``, looking through a ``torch.nn.DataParallel`` wrapper."""
    if isinstance(model, torch.nn.DataParallel):
        return model.module.decoder
    return model.decoder


def display_sample(sample_idx, top_k=3):
    """Render one sample of the collected data into ``sample_output``.

    Runs ``encoder_decoder`` on sample ``sample_idx`` as a batch of size 1, then
    prints into ``sample_output`` (clearing it first):

    - the sample's ``kl_div`` target-prediction loss and Dinalar loss;
    - the decoded target tokens;
    - the decoded ``PROMPT_PREFIX``, then a table of the ``top_k`` vocabulary
      tokens whose embeddings are closest (L2) to each virtual embedding, a
      table of the corresponding squared L2 distances, and the decoded
      ``PROMPT_SUFFIX``.

    Args:
        sample_idx: Index into the module-level ``target_generated_tokens``,
            ``target_logits`` and ``target_acts`` buffers.
        top_k: Number of nearest tokens shown per virtual embedding (default 3).
    """
    decoder = _unwrap_decoder(encoder_decoder)
    sample_rows = slice(sample_idx, sample_idx + 1)

    with torch.no_grad():
        with torch.autocast(device_type=cfg.device.type, dtype=cfg.dtype):
            # Extract single sample as a "batch" of size 1
            sample_tokens = target_generated_tokens[sample_rows].to(cfg.device)
            sample_logits = target_logits[sample_rows].to(
                cfg.device, dtype=torch.float32
            )
            sample_acts = target_acts[sample_rows].to(cfg.device)

            # Get model outputs for this single sample
            decoder_logits_target, decoder_logits_encoding, virtual_embs = (
                encoder_decoder(sample_acts, sample_tokens, -1)
            )

            # Calculate losses using existing functions
            pred_loss = kl_div(decoder_logits_target, sample_logits).item()
            din_loss = calculate_dinalar_loss(
                decoder_logits_encoding, virtual_embs, decoder
            ).item()

            # Token embedding matrix, [vocab_size, d_embed]
            token_embeddings = get_embeddings_from_decoder(decoder).weight

            # Squared L2 distance from every vocabulary embedding to every
            # virtual embedding -> [vocab_size, encoding_len]; virtual_embs[0]
            # is [encoding_len, d_embed]. cdist avoids materialising the
            # [vocab_size, encoding_len, d_embed] difference tensor, and the
            # non-matmul mode keeps small near-neighbour distances accurate.
            # Both inputs are upcast so they share a float32 dtype.
            distances = (
                torch.cdist(
                    token_embeddings.float(),
                    virtual_embs[0].float(),
                    compute_mode="donot_use_mm_for_euclid_dist",
                )
                ** 2
            )

            # Smallest distances per virtual embedding (column)
            top_values, top_indices = torch.topk(
                distances, k=top_k, dim=0, largest=False
            )

    # Everything below is plain-Python formatting of the results.
    sample_decoded = tokenizer.decode(sample_tokens[0])
    prefix_text = tokenizer.decode(
        tokenizer(PROMPT_PREFIX, return_tensors="pt").input_ids[0]
    )
    suffix_text = tokenizer.decode(
        tokenizer(PROMPT_SUFFIX, return_tensors="pt").input_ids[0]
    )

    col_width = 15
    num_embs = virtual_embs.shape[1]
    headers = ["Token"] + [f"Emb {i}" for i in range(num_embs)]
    widths = [8] + [col_width] * num_embs

    # One GPU->host transfer each instead of an .item() per cell
    top_ids = top_indices.tolist()
    top_dists = top_values.tolist()

    token_rows = []
    distance_rows = []
    for k in range(top_k):
        label = f"Top {k+1}:"
        token_row = [label]
        distance_row = [label]
        for j in range(num_embs):
            token_text = tokenizer.decode([top_ids[k][j]])
            # Escape line/tab control characters so each cell stays on one line
            token_text = (
                token_text.replace("\n", "\\n")
                .replace("\t", "\\t")
                .replace("\r", "\\r")
            )
            # Truncate so at least two spaces separate it from the next column
            token_row.append(token_text[: col_width - 2])
            distance_row.append(f"{top_dists[k][j]:.5f}")
        token_rows.append(token_row)
        distance_rows.append(distance_row)

    # Display results
    with sample_output:
        sample_output.clear_output(wait=True)
        print(f"Sample {sample_idx+1}/{batch_size}")
        print(f"Target prediction loss: {pred_loss:.6f}")
        print(f"Dinalar loss: {din_loss:.6f}")
        print("\nTarget tokens:")
        print(sample_decoded, "\n")

        # Prompt layout: prefix, virtual-embedding slots, suffix
        print(prefix_text)
        print(create_table("", headers, token_rows, list(widths)))
        print(
            create_table(
                "Squared L2 Distances:", headers, distance_rows, list(widths)
            )
        )
        print(suffix_text)


# Interactive sample investigation: wire the "Next Sample" button to its click
# handler, then render the button area above the sample area.
next_button = widgets.Button(description="Next Sample")
next_button.on_click(on_button_click)

with button_output:
    display(next_button)
display(button_output, sample_output)

# Show the first sample straight away so the viewer does not start empty.
if batch_size <= 0:
    with sample_output:
        print("No samples in batch!")
else:
    display_sample(current_sample)
    current_sample += 1

Output()

Output()