In [15]:
import ipywidgets as widgets
import numpy as np
import torch
from IPython.display import clear_output, display
from transformers import AutoTokenizer

import importlib

import eli.encoder

importlib.reload(eli.encoder)
from eli.config import cfg, encoder_cfg
from eli.encoder import (
    PROMPT_PREFIX,
    PROMPT_SUFFIX,
    Encoder,
    EncoderDecoder,
    EncoderTrainer,
    calculate_dinalar_loss,
    get_embeddings_from_decoder,
    kl_div,
    calculate_target_prediction_loss,
)
from einops import einsum

In [17]:
# Load encoder and decoder

cfg.buffer_size_samples = cfg.target_model_batch_size_samples = cfg.train_batch_size_samples

tokenizer = AutoTokenizer.from_pretrained(cfg.decoder_model_name)

encoder_decoder = EncoderDecoder(cfg, encoder_cfg, tokenizer).to(cfg.device)

encoder = Encoder(cfg, encoder_cfg).to(cfg.device)

encoder_path = "encoder-dec-16-enc-4-pythia-31m.pt"
encoder.load_state_dict(torch.load(encoder_path))

encoder_decoder.encoder = encoder

encoder_decoder = encoder_decoder.eval()

In [3]:
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(cfg.target_model_name)

In [4]:
# Load eval data

from eli.data import DataCollector

cfg.use_data_collector_workers = False

print("Initializing data collector")
data_collector = DataCollector(cfg)

print("Collecting data")
data_collector.collect_data()

data = data_collector.data

Initializing data collector


INFO:root:target_generated_tokens size: 16384 bytes (0.02 MB)
INFO:root:target_acts size: 262144 bytes (0.25 MB)
INFO:root:input_tokens size: 65536 bytes (0.06 MB)
INFO:root:Total shared memory size: 0.00 GB


Collecting data


Resolving data files:   0%|          | 0/1024 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/1024 [00:00<?, ?it/s]

INFO:root:Tokenize and concatenate called
INFO:root:Full text length: 43727350
INFO:root:Num tokens: 9672977
  return torch.tensor(value, **{**default_dtype, **self.torch_tensor_kwargs})
INFO:root:Processing data directly on cuda without workers
INFO:root:Processing chunk 0:256 on cuda
INFO:root:Processing batch 0:256


Loaded pretrained model EleutherAI/pythia-31m into HookedTransformer
Moving model to device:  cuda


INFO:root:CHUNK 0:256 COMPLETED, Max GPU memory allocated: 2.24 GB, Max GPU memory reserved: 2.52 GB
INFO:root:Direct data processing completed


In [9]:
# Print loss statistics
target_generated_tokens = data["target_generated_tokens"]
target_acts = data["target_acts"]

buffer_size = target_acts.shape[0]
batch_size = cfg.train_batch_size_samples
num_batches = buffer_size // batch_size

target_prediction_losses = []
dinalar_losses = []

for batch_idx in range(num_batches):
    start_idx = batch_idx * batch_size
    end_idx = start_idx + batch_size

    # Extract batch data and move to device
    batch_tokens = target_generated_tokens[start_idx:end_idx].to(cfg.device)
    batch_acts = target_acts[start_idx:end_idx].to(cfg.device)

    loss, target_prediction_loss, dinalar_loss, encoder_output_logits = EncoderTrainer.loss(
        cfg, encoder_decoder, batch_tokens, batch_acts, -1
    )

    target_prediction_losses.append(target_prediction_loss.item())
    dinalar_losses.append(dinalar_loss.item())

print(f"Target prediction loss: {np.mean(target_prediction_losses)}")
print(f"Dinalar loss: {np.mean(dinalar_losses)}")

Target prediction loss: 3.610379219055176
Dinalar loss: 5.70062255859375


In [31]:
# Create output widgets for displaying sample information
sample_output = widgets.Output()
button_output = widgets.Output()

# Create a counter and button
current_sample = 0


def on_button_click(b):
    global current_sample
    if current_sample < batch_size:
        display_sample(current_sample)
        current_sample += 1
    else:
        with sample_output:
            print("End of batch reached!")


def create_table(title, headers, rows, col_widths=None):
    """Helper function to create formatted tables

    Args:
        title: Table title string
        headers: List of header strings
        rows: List of rows, where each row is a list of values
        col_widths: List of column widths (defaults to 15 for all columns)

    Returns:
        Formatted table string
    """
    if col_widths is None:
        col_widths = [15] * len(headers)

    # Ensure first column width accommodates row labels
    col_widths[0] = max(col_widths[0], 8)

    # Create table string
    table = f"{title}\n"

    # Create header
    header_row = headers[0].ljust(col_widths[0])
    for i, header in enumerate(headers[1:], 1):
        header_row += header.ljust(col_widths[i])
    table += header_row + "\n"

    # Add separator
    table += "-" * len(header_row) + "\n"

    # Add rows
    for row in rows:
        row_str = str(row[0]).ljust(col_widths[0])
        for i, cell in enumerate(row[1:], 1):
            row_str += str(cell).ljust(col_widths[i])
        table += row_str + "\n"

    return table


# Function to display a single sample
def display_sample(sample_idx):
    with torch.no_grad():
        with torch.autocast(device_type=cfg.device.type, dtype=cfg.dtype):
            # Extract single sample as a "batch" of size 1
            sample_tokens = target_generated_tokens[sample_idx : sample_idx + 1].to(
                cfg.device
            )
            sample_acts = target_acts[sample_idx : sample_idx + 1].to(cfg.device)

            encoder = encoder_decoder.encoder
            encoder_output_logits = encoder(sample_acts) # [batch tok vocab]

            # Convert logits to one-hot-like by making the max value very large and others small
            # max_values, max_indices = torch.max(encoder_output_logits, dim=-1, keepdim=True)
            # one_hot_logits = torch.ones_like(encoder_output_logits) * -100.0  # Set all values to a small number
            # one_hot_logits.scatter_(dim=-1, index=max_indices, value=100.0)   # Set max values to a large number
            # encoder_output_logits = one_hot_logits

            # Get model outputs for this single sample
            (decoder_logits_target, decoder_logits_encoding, virtual_embs, encoder_output_logits) = (
                encoder_decoder(sample_acts, sample_tokens, encoder_output_logits=encoder_output_logits, train_iter=-1)
            )

            # Calculate losses using existing functions
            pred_loss = calculate_target_prediction_loss(decoder_logits_target, sample_tokens).item()
            din_loss = calculate_dinalar_loss(
                decoder_logits_encoding,
                encoder_output_logits,
            ).item()


            print("encoder_output_logits.shape", encoder_output_logits.shape)

            encoder_output_probs = torch.nn.functional.softmax(encoder_output_logits, dim=-1)

            # Get top 3 tokens by encoder output logits
            top_k = 3
            top_values, top_indices = torch.topk(
                encoder_output_probs[0], k=top_k, dim=-1
            )  # [batch tok]

            print(top_indices)
            print(top_indices.shape)

            print("top_values.shape", top_values.shape)

            # Decode tokens for display
            sample_decoded = tokenizer.decode(sample_tokens[0])

            # Display results
            with sample_output:
                sample_output.clear_output(wait=True)
                print(f"Sample {sample_idx+1}/{batch_size}")
                print(f"Target prediction loss: {pred_loss:.6f}")
                print(f"Dinalar loss: {din_loss:.6f}")
                print("\nTarget tokens:")
                print(sample_decoded, "\n")

                # Also decode and display the prefix and suffix tokens
                prefix_tokens = tokenizer(PROMPT_PREFIX, return_tensors="pt").input_ids[
                    0
                ]
                suffix_tokens = tokenizer(PROMPT_SUFFIX, return_tensors="pt").input_ids[
                    0
                ]

                print(tokenizer.decode(prefix_tokens))

                # Prepare data for the tokens table
                col_width = 15
                headers = ["Token"] + [f"Emb {i}" for i in range(virtual_embs.shape[1])]

                token_rows = []
                for k in range(top_k):
                    row = [f"Top {k+1}:"]
                    for j in range(virtual_embs.shape[1]):
                        token_id = top_indices[j, k].item()
                        token_text = tokenizer.decode([token_id])
                        # Replace newlines and tabs for cleaner display
                        token_text = token_text.replace("\n", "\\n").replace(
                            "\t", "\\t"
                        )
                        # Truncate to fit in column
                        token_display = token_text[: col_width - 2]
                        row.append(token_display)
                    token_rows.append(row)

                # Create and display the token table
                token_table = create_table(
                    "", headers, token_rows, [8] + [col_width] * virtual_embs.shape[1]
                )
                print(token_table)

                # Create table for logit values
                logit_rows = []
                for k in range(top_k):
                    row = [f"Top {k+1}:"]
                    for j in range(virtual_embs.shape[1]):
                        # Format logit value with 5 decimal places
                        prob = top_values[j, k].item()
                        row.append(f"{prob:.5f}")
                    logit_rows.append(row)

                # Create and display the logits table
                logit_table = create_table(
                    "Logit Values:",
                    headers,
                    logit_rows,
                    [8] + [col_width] * virtual_embs.shape[1],
                )
                print(logit_table)

                print(tokenizer.decode(suffix_tokens))


# Interactive sample investigation
next_button = widgets.Button(description="Next Sample")
next_button.on_click(on_button_click)

# Display the button and sample output in separate areas
with button_output:
    display(next_button)
display(button_output)
display(sample_output)

# Show the first sample
if batch_size > 0:
    display_sample(current_sample)
    current_sample += 1
else:
    with sample_output:
        print("No samples in batch!")

Output()

Output()

encoder_output_logits.shape torch.Size([1, 4, 50304])
tensor([[  187,  5453, 43472],
        [41768,  6123,     3],
        [   13, 40125, 43327],
        [41972,   247,    15]], device='cuda:0')
torch.Size([4, 3])
top_values.shape torch.Size([4, 3])


encoder_output_logits.shape torch.Size([1, 4, 50304])
tensor([[  187, 24239,  1947],
        [15436,   253, 30696],
        [   13,   669, 13858],
        [  253, 30611, 48763]], device='cuda:0')
torch.Size([4, 3])
top_values.shape torch.Size([4, 3])
encoder_output_logits.shape torch.Size([1, 4, 50304])
tensor([[  187, 14457,  5453],
        [  253, 22778, 30081],
        [15032,  1058,  9436],
        [13826,   253, 36634]], device='cuda:0')
torch.Size([4, 3])
top_values.shape torch.Size([4, 3])
encoder_output_logits.shape torch.Size([1, 4, 50304])
tensor([[  187,  3182,  4144],
        [  253, 24645,  6123],
        [16372, 15739,    13],
        [35135, 28076,  8772]], device='cuda:0')
torch.Size([4, 3])
top_values.shape torch.Size([4, 3])
encoder_output_logits.shape torch.Size([1, 4, 50304])
tensor([[  187, 48278,  5453],
        [40394,   273, 20585],
        [ 9646, 43648, 29626],
        [ 8772, 33721, 15699]], device='cuda:0')
torch.Size([4, 3])
top_values.shape torch.Size([4, 