In [1]:
import importlib

import ipywidgets as widgets
import numpy as np
import torch
from IPython.display import clear_output, display
from transformers import AutoTokenizer
from jaxtyping import Float
from torch import Tensor
from einops import einsum

import eli.train
importlib.reload(eli.train)

from eli.train.config import encoder_cfg, train_cfg
from eli.train.encoder import EncoderDecoder, Encoder, PROMPT_DECODER
from eli.train.train import pull_dataset_config

In [2]:
# Prefer the first CUDA GPU, but fall back to the CPU so the notebook still
# runs (slowly) on machines without one instead of failing on the first `.to(device)`.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load encoder

In [3]:
import boto3

def load_encoder_from_s3(s3_key: str, target_path, map_location=None):
    """Download an encoder checkpoint from S3 and load it with ``torch.load``.

    Args:
        s3_key: Object key inside the bucket named by ``train_cfg.s3_bucket``.
        target_path: Local path (``str`` or path-like) the object is written to.
        map_location: Forwarded unchanged to ``torch.load``. The default
            (``None``) keeps ``torch.load``'s own default placement; pass a
            device to load a checkpoint saved on a device that is not
            available on this machine.

    Returns:
        The object stored in the checkpoint file (the notebook feeds it to
        ``Encoder.load_state_dict``).
    """
    # Normalise path-like objects to a plain string for both consumers.
    local_path = str(target_path)

    s3_client = boto3.client("s3")
    s3_client.download_file(train_cfg.s3_bucket, s3_key, local_path)

    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code. Only point this at buckets you trust (or consider
    # `weights_only=True` if the installed torch version supports it).
    return torch.load(local_path, map_location=map_location)

In [4]:

# Dataset configuration (the cell output shows it being downloaded from S3).
dataset_cfg = pull_dataset_config(train_cfg)

# Decoder tokenizer; the EOS token doubles as the padding token.
decoder_tokenizer = AutoTokenizer.from_pretrained(train_cfg.decoder_model_name)
decoder_tokenizer.pad_token = decoder_tokenizer.eos_token

encoder_decoder = EncoderDecoder(decoder_tokenizer, dataset_cfg, encoder_cfg, train_cfg).to(device)

encoder = Encoder(dataset_cfg, train_cfg, encoder_cfg).to(device)

# Local checkpoint cache and the S3 object it is fetched from when missing.
encoder_path = "encoder.pt"
ENCODER_S3_KEY = "models/EleutherAI-pythia-70m-resid_post-4-5-100000000-pythia-70m-encoder.pt"

try:
    # map_location makes the load independent of the device the checkpoint
    # was saved from.
    state_dict = torch.load(encoder_path, map_location=device)
except FileNotFoundError:
    # Fresh checkout / fresh kernel: fetch the checkpoint instead of relying on
    # a manually executed download, so "Restart & Run All" works.
    state_dict = load_encoder_from_s3(ENCODER_S3_KEY, encoder_path)

encoder.load_state_dict(state_dict)

# Attach the encoder carrying the loaded weights to the encoder-decoder wrapper.
encoder_decoder.encoder = encoder

# Evaluation mode for all analysis below.
encoder_decoder = encoder_decoder.eval()

INFO:botocore.credentials:Found credentials in environment variables.


Downloaded dataset configuration from s3://eli-datasets/datasets/EleutherAI-pythia-70m-resid_post-4-5-100000000/config.json


# Load dataset

In [5]:
from eli.train.download import download_dataset
from eli.train.train import preprocess_acts, preprocess_target_generated_tokens

# Tokenizer of the target model named in the dataset config; EOS doubles as
# the padding token.
target_tokenizer = AutoTokenizer.from_pretrained(dataset_cfg.target_model_name)
target_tokenizer.pad_token = target_tokenizer.eos_token

# Wrapped in iter() so later cells can pull batches one at a time with next().
# The iterator is shared: every next() call in the notebook consumes a new batch.
# NOTE(review): the logged URL suggests the data is streamed as tar shards from
# S3 — confirm against download_dataset.
data_loader = iter(download_dataset(dataset_cfg, train_cfg))

URL: pipe: aws s3 cp s3://eli-datasets/datasets/EleutherAI-pythia-70m-resid_post-4-5-100000000/{00000000..00000999}.tar -


# Print summary stats for validation

In [6]:
from eli.train.encoder import get_loss

# Number of batches pulled from the shared `data_loader` iterator for the estimate.
num_batches = 20

losses = []

# Evaluation only: nothing is back-propagated, so disable autograd to avoid
# building (and holding in memory) a graph for every batch.
with torch.no_grad():
    for _ in range(num_batches):
        target_acts, target_generated_tokens = next(data_loader)
        target_acts = preprocess_acts(target_acts)
        target_acts, target_generated_tokens = (
            target_acts.to(device),
            target_generated_tokens.to(device),
        )
        target_generated_tokens, attention_mask = (
            preprocess_target_generated_tokens(
                target_generated_tokens,
                target_tokenizer,
                decoder_tokenizer,
            )
        )
        batch_loss = get_loss(
            train_cfg,
            device,
            encoder_decoder,
            target_generated_tokens,
            attention_mask,
            target_acts,
            decoder_tokenizer,
            # NOTE(review): presumably the training-iteration argument (the model
            # is called with train_iter=-1 elsewhere in this notebook) — confirm
            # against get_loss.
            -1,
        )
        losses.append(batch_loss.item())

# Mean loss over the `num_batches` batches.
print(np.mean(losses))

3.2158238291740417


# View embeddings

In [7]:
# Pull the next batch from the shared iterator; the interactive viewer below
# inspects this batch sample by sample.
target_acts, target_generated_tokens = next(data_loader)

# Preprocess the activations, then move both tensors onto the compute device.
target_acts = preprocess_acts(target_acts).to(device)
target_generated_tokens = target_generated_tokens.to(device)

# Convert the target-model tokens for the decoder tokenizer; this also yields
# the matching attention mask.
target_generated_tokens, attention_mask = preprocess_target_generated_tokens(
    target_generated_tokens,
    target_tokenizer,
    decoder_tokenizer,
)

In [14]:
from eli.train.encoder import get_target_prediction_loss

# Output widget that display_sample renders each sample's report into
sample_output = widgets.Output()

# Batch index shared with the "Next Sample" button handler via `global`
# (the button itself is created at the bottom of this cell)
current_sample = 0

def get_similarities_to_embeddings(
    embeddings_seq: Float[Tensor, "batch tok d_embed"],
    embeddings_vocab: Float[Tensor, "vocab d_embed"],
    cosine: bool = True,
) -> Float[Tensor, "batch tok vocab"]:
    """Score every sequence embedding against every vocabulary embedding.

    Args:
        embeddings_seq: Embeddings to look up, one per (batch, token) position.
        embeddings_vocab: Reference embedding matrix, one row per vocabulary entry.
        cosine: Selects the metric (see Returns).

    Returns:
        A ``[batch, tok, vocab]`` tensor.

        * ``cosine=True``: cosine similarities (larger means closer). Rows with
          zero norm yield similarity 0 instead of NaN.
        * ``cosine=False``: Euclidean (L2) distances. Note this is a *distance*:
          smaller means closer, so a top-k over it selects the farthest entries.
    """
    if not cosine:
        # Leading singleton batch dim on the vocabulary side, as cdist works on
        # batched inputs.
        return torch.cdist(embeddings_seq, embeddings_vocab.unsqueeze(0), p=2)

    # normalize() divides by max(norm, eps), so an all-zero embedding no longer
    # produces a 0/0 = NaN row; non-degenerate rows are scaled exactly as before.
    seq_unit = torch.nn.functional.normalize(embeddings_seq, p=2, dim=-1)
    vocab_unit = torch.nn.functional.normalize(embeddings_vocab, p=2, dim=-1)

    return einsum(
        seq_unit,
        vocab_unit,
        "batch tok d_embed, vocab d_embed -> batch tok vocab",
    )


def on_button_click(b):
    """Click handler for the "Next Sample" button: show the following sample.

    ``current_sample`` is treated as the index of the sample currently on
    screen. Sample 0 is rendered when the cell first runs, so a click must move
    to ``current_sample + 1``; displaying ``current_sample`` itself would render
    sample 0 a second time and record its loss twice.

    Args:
        b: The button instance passed in by ipywidgets (unused).
    """
    global current_sample

    next_sample = current_sample + 1
    if next_sample < target_generated_tokens.shape[0]:
        display_sample(next_sample)
        # Only advance once rendering succeeded, so a failed render can be retried.
        current_sample = next_sample
    else:
        with sample_output:
            print("End of batch reached!")


def create_table(title, headers, rows, col_widths=None):
    """Format a left-aligned, fixed-width plain-text table.

    Args:
        title: Text printed on the first line (may be empty).
        headers: Sequence of header cells.
        rows: Iterable of rows, each a sequence of cells; cells are rendered
            with ``str()``.
        col_widths: Optional sequence of column widths. It is never modified.
            Columns without an entry use a width of 15, and the first column
            is always at least 8 wide so row labels fit.

    Returns:
        The table as a single string: title line, header line, a dashed
        separator as long as the header line, then one line per row. Every
        line, including the last, ends with a newline.
    """
    default_width = 15
    min_label_width = 8

    headers = list(headers)
    rows = [list(row) for row in rows]

    # Work on a copy so the caller's list is left untouched, and pad it so that
    # every column appearing in the header or in any row has a width.
    n_cols = max([len(headers)] + [len(row) for row in rows])
    widths = [] if col_widths is None else list(col_widths)
    widths.extend([default_width] * (n_cols - len(widths)))
    if widths:
        widths[0] = max(widths[0], min_label_width)

    def format_row(cells):
        return "".join(str(cell).ljust(width) for cell, width in zip(cells, widths))

    header_row = format_row(headers)
    lines = [f"{title}", header_row, "-" * len(header_row)]
    lines.extend(format_row(row) for row in rows)
    return "\n".join(lines) + "\n"


# Target-prediction loss of every sample rendered so far (one entry per
# display_sample call). Re-running this cell resets the history.
losses = []


def display_sample(sample_idx, top_k=10):
    """Render a report for one sample of the current batch into ``sample_output``.

    The report shows the sample's target-prediction loss, the decoded target
    tokens, and — between the two halves of ``PROMPT_DECODER`` split at
    ``<thought>`` — a table listing, for each virtual embedding produced by the
    encoder, the ``top_k`` decoder vocabulary tokens with the highest cosine
    similarity to it. The last line is the mean loss over all calls so far.

    Args:
        sample_idx: Index of the sample within the current batch.
        top_k: Number of nearest vocabulary tokens listed per virtual
            embedding (capped at the vocabulary size).

    Side effects:
        Appends the sample's loss to the module-level ``losses`` list and
        replaces the contents of ``sample_output``.
    """
    with torch.no_grad():
        # Slice rather than index so each tensor keeps a leading batch dim of 1.
        sample_tokens = target_generated_tokens[sample_idx : sample_idx + 1]
        sample_acts = target_acts[sample_idx : sample_idx + 1]
        sample_attention_mask = attention_mask[sample_idx : sample_idx + 1]

        # The model returns three values; the second one is not used here.
        decoder_logits_target, _, virtual_embs = encoder_decoder(
            sample_acts, sample_tokens, sample_attention_mask, train_iter=-1
        )

        pred_loss = get_target_prediction_loss(
            decoder_logits_target, sample_tokens, decoder_tokenizer
        ).item()
        losses.append(pred_loss)

        # Cosine similarity between each virtual embedding and every row of the
        # decoder's input-embedding matrix.
        vocab_embeddings = encoder_decoder.decoder.get_input_embeddings().weight
        similarities = get_similarities_to_embeddings(
            virtual_embs, vocab_embeddings, cosine=True
        )  # [batch tok vocab]

        # Most similar vocabulary tokens per virtual embedding.
        num_top = min(top_k, similarities.shape[-1])
        top_values, top_indices = torch.topk(
            similarities[0], k=num_top, dim=-1
        )  # each [tok k]

        # Convert to nested Python lists once instead of calling .item() for
        # every table cell.
        top_values = top_values.tolist()
        top_indices = top_indices.tolist()

        num_virtual_embs = virtual_embs.shape[1]

    # Decode tokens for display
    sample_decoded = decoder_tokenizer.decode(sample_tokens[0])

    with sample_output:
        sample_output.clear_output(wait=True)
        print(f"Sample {sample_idx + 1}/{target_generated_tokens.shape[0]}")
        print(f"Target prediction loss: {pred_loss:.6f}")
        print("\nTarget tokens:")
        print(sample_decoded, "\n")

        # The table is printed where the <thought> placeholder sits in the prompt.
        prompt_prefix, prompt_suffix = PROMPT_DECODER.split("<thought>")

        print(prompt_prefix)

        col_width = 15
        headers = ["Token"]
        for i in range(num_virtual_embs):
            headers.extend([f"Emb {i}", f"Sim {i}"])

        # One row per rank; each virtual embedding contributes a token column
        # and a similarity column.
        token_rows = []
        for k in range(num_top):
            row = [f"Top {k + 1}:"]
            for j in range(num_virtual_embs):
                token_text = decoder_tokenizer.decode([top_indices[j][k]])
                # Escape newlines and tabs so a token cannot break the layout
                token_text = token_text.replace("\n", "\\n").replace("\t", "\\t")
                # Truncate to fit in the column
                row.append(token_text[: col_width - 2])
                row.append(f"{top_values[j][k]:.3f}")
            token_rows.append(row)

        token_table = create_table(
            "",
            headers,
            token_rows,
            [8] + [col_width, 8] * num_virtual_embs,
        )
        print(token_table)

        print(prompt_suffix)

        # Running mean of the loss over every display_sample call so far
        # (a sample rendered more than once is counted once per render).
        print(sum(losses) / len(losses))


# Interactive sample investigation: each click on the button invokes
# on_button_click.
next_button = widgets.Button(description="Next Sample")
next_button.on_click(on_button_click)

# Render the sample at `current_sample` into the output widget up front so the
# report area is not empty when the widgets appear.
display_sample(current_sample)

# Button first, report area underneath.
display(next_button, sample_output)


Button(description='Next Sample', style=ButtonStyle())

Output()