In [2]:
import ipywidgets as widgets
import numpy as np
import torch
from IPython.display import clear_output, display
from transformers import AutoTokenizer

from eli.config import cfg, encoder_cfg
from eli.encoder import (
    PROMPT_PREFIX,
    PROMPT_SUFFIX,
    Encoder,
    EncoderDecoder,
    EncoderTrainer,
    calculate_dinalar_loss,
    get_embeddings_from_decoder,
    kl_div,
)

In [None]:
# Load encoder and decoder

# Collapse the data buffer and the target-model batch to a single training batch.
# NOTE: this mutates the shared global `cfg`, which later cells (e.g. the
# DataCollector) presumably read — keep this cell before data collection.
cfg.buffer_size_samples = cfg.target_model_batch_size_samples = cfg.train_batch_size_samples

tokenizer = AutoTokenizer.from_pretrained(cfg.decoder_model_name)

encoder_decoder = EncoderDecoder(cfg, encoder_cfg, tokenizer).to(cfg.device)

# Build a fresh encoder and load the trained weights into it.
# map_location places the checkpoint tensors directly on cfg.device, so loading
# does not depend on which device the checkpoint was saved from.
encoder = Encoder(cfg, encoder_cfg).to(cfg.device)

encoder_path = "saved_models/encoder-dinalar-1e-2.pt"
encoder.load_state_dict(torch.load(encoder_path, map_location=cfg.device))

# Swap the trained encoder into the EncoderDecoder (replacing the encoder it
# presumably constructed in __init__ — TODO confirm).
encoder_decoder.encoder = encoder

# Inference only: put all submodules into eval mode.
encoder_decoder = encoder_decoder.eval()

In [2]:
from transformers import AutoModelForCausalLM

# Target model as a Hugging Face causal LM (not moved to cfg.device here).
# NOTE(review): `model` is not referenced by any of the cells shown — drop if unused.
model = AutoModelForCausalLM.from_pretrained(
    pretrained_model_name_or_path=cfg.target_model_name
)

In [3]:
# Load eval data: run the DataCollector in-process (no worker processes)
# and keep the tensors it gathers.
from eli.data import DataCollector

print("Initializing data collector")
data_collector = DataCollector(use_workers=False)

print("Collecting data")
data_collector.collect_data()

# Collected tensors, keyed by name (later cells read "target_generated_tokens",
# "target_logits" and "target_acts").
data = data_collector.data

Initializing data collector


INFO:root:target_generated_tokens size: 480 bytes (0.00 MB)
INFO:root:target_logits size: 65667072 bytes (62.62 MB)
INFO:root:target_acts size: 65536 bytes (0.06 MB)
INFO:root:input_tokens size: 2048 bytes (0.00 MB)
INFO:root:Total shared memory size: 0.06 GB


Collecting data


Resolving data files:   0%|          | 0/1024 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/1024 [00:00<?, ?it/s]

INFO:root:Tokenize and concatenate called
INFO:root:Full text length: 43667353
Token indices sequence length is longer than the specified maximum sequence length for this model (472110 > 131072). Running this sequence through the model will result in indexing errors
INFO:root:Num tokens: 9321501
  return torch.tensor(value, **{**default_dtype, **self.torch_tensor_kwargs})
INFO:root:Processing data directly on cuda without workers
INFO:root:Processing chunk 0:8 on cuda
INFO:root:Processing batch 0:8


Loaded pretrained model meta-llama/Llama-3.2-1B-Instruct into HookedTransformer
Moving model to device:  cuda


INFO:root:CHUNK 0:8 COMPLETED, Max GPU memory allocated: 18.30 GB, Max GPU memory reserved: 18.42 GB
INFO:root:Direct data processing completed


In [3]:
# Print loss statistics
# Average the training-objective losses over every full batch of collected data.
target_generated_tokens = data["target_generated_tokens"]
target_logits = data["target_logits"]
target_acts = data["target_acts"]

buffer_size = target_acts.shape[0]
batch_size = cfg.train_batch_size_samples
# Integer division: a trailing partial batch is not evaluated.
num_batches = buffer_size // batch_size

target_prediction_losses = []
dinalar_losses = []

# Pure evaluation: only scalar .item() values are kept, so disable autograd
# (otherwise every forward pass builds a graph that is never used).
with torch.no_grad():
    for batch_idx in range(num_batches):
        start_idx = batch_idx * batch_size
        end_idx = start_idx + batch_size

        # Extract batch data and move to device
        batch_tokens = target_generated_tokens[start_idx:end_idx].to(cfg.device)
        batch_logits = target_logits[start_idx:end_idx].to(cfg.device, dtype=torch.float32)
        batch_acts = target_acts[start_idx:end_idx].to(cfg.device)

        loss, target_prediction_loss, dinalar_loss = EncoderTrainer.loss(
            cfg, encoder_decoder, batch_tokens, batch_logits, batch_acts, -1
        )

        target_prediction_losses.append(target_prediction_loss.item())
        dinalar_losses.append(dinalar_loss.item())

if num_batches == 0:
    # np.mean([]) would print nan together with a RuntimeWarning.
    print(f"No full batch to evaluate: {buffer_size} samples < batch size {batch_size}")
else:
    print(f"Target prediction loss: {np.mean(target_prediction_losses)}")
    print(f"Dinalar loss: {np.mean(dinalar_losses)}")

Target prediction loss: 0.499033123254776
Dinalar loss: 0.48165082931518555


In [6]:
# Two output areas: one for the rendered sample, one for the "Next Sample" button.
sample_output, button_output = widgets.Output(), widgets.Output()

# Index of the next sample to render; shared (as a global) with on_button_click.
# NOTE(review): later cells also assign the global `current_sample`, which
# resets this viewer's position if those cells are run afterwards.
current_sample = 0


def on_button_click(b):
    """Button callback: render the next sample, or report that the batch is exhausted.

    Args:
        b: The clicked ipywidgets Button (unused).
    """
    global current_sample
    if current_sample >= batch_size:
        with sample_output:
            print("End of batch reached!")
        return
    display_sample(current_sample)
    current_sample += 1


def create_table(title, headers, rows, col_widths=None):
    """Render a fixed-width, left-aligned text table.

    Args:
        title: Title line placed above the table ("" yields an empty first line).
        headers: Header strings, one per column.
        rows: Rows to render; each row is a sequence of values converted with str().
        col_widths: Per-column widths (defaults to 15 for every column). The first
            column is always at least 8 wide so "Top k:" row labels fit. The
            caller's list is never modified.

    Returns:
        The table as one string: the title line, the header line, a dashed
        separator as long as the header line, then one line per row; every
        line ends with "\\n".

    Raises:
        IndexError: if a header or row has more cells than there are widths.
    """
    # Work on a copy so a caller reusing its widths list is not affected.
    widths = [15] * len(headers) if col_widths is None else list(col_widths)
    if widths:
        widths[0] = max(widths[0], 8)

    def _render(cells):
        # Cells are padded but never truncated; callers truncate long text.
        return "".join(str(cell).ljust(widths[i]) for i, cell in enumerate(cells))

    header_row = _render(headers)
    lines = [title, header_row, "-" * len(header_row)]
    lines.extend(_render(row) for row in rows)
    return "\n".join(lines) + "\n"


# Function to display a single sample
def display_sample(sample_idx):
    """Render one sample's losses and nearest-token analysis into `sample_output`.

    Runs the encoder-decoder on sample `sample_idx` as a batch of size 1 and
    reports its target-prediction (KL) loss and dinalar loss. For every
    virtual embedding the encoder produced, it lists the 3 vocabulary tokens
    whose rows in the decoder's token embedding table are closest in
    *squared* L2 distance.

    Args:
        sample_idx: Index into the globals `target_generated_tokens`,
            `target_logits` and `target_acts` (set by the loss-statistics cell).
    """
    # Unwrap DataParallel once; the dinalar loss and the token lookup both need it.
    decoder = (
        encoder_decoder.module.decoder
        if isinstance(encoder_decoder, torch.nn.DataParallel)
        else encoder_decoder.decoder
    )
    top_k = 3
    col_width = 15

    with torch.no_grad():
        with torch.autocast(device_type=cfg.device.type, dtype=cfg.dtype):
            # Extract single sample as a "batch" of size 1
            sample_tokens = target_generated_tokens[sample_idx : sample_idx + 1].to(cfg.device)
            sample_logits = target_logits[sample_idx : sample_idx + 1].to(
                cfg.device, dtype=torch.float32
            )
            sample_acts = target_acts[sample_idx : sample_idx + 1].to(cfg.device)

            decoder_logits_target, decoder_logits_encoding, virtual_embs = encoder_decoder(
                sample_acts, sample_tokens, -1
            )

            pred_loss = kl_div(decoder_logits_target, sample_logits).item()
            din_loss = calculate_dinalar_loss(
                decoder_logits_encoding, virtual_embs, decoder
            ).item()

            token_embeddings = get_embeddings_from_decoder(decoder).weight  # [vocab_size, d_embed]

            # Squared L2 distances, [vocab_size, encoding_len]. torch.cdist avoids the
            # [vocab_size, encoding_len, d_embed] difference tensor that explicit
            # broadcasting would allocate (several GB for a ~128k vocabulary).
            distances = torch.cdist(token_embeddings.float(), virtual_embs[0].float()).pow(2)
            top_values, top_indices = torch.topk(distances, k=top_k, dim=0, largest=False)

    sample_decoded = tokenizer.decode(sample_tokens[0])
    # NOTE(review): tokenizer() adds special tokens by default, so the decoded
    # prefix/suffix may show a BOS marker — confirm against the real context.
    prefix_tokens = tokenizer(PROMPT_PREFIX, return_tensors="pt").input_ids[0]
    suffix_tokens = tokenizer(PROMPT_SUFFIX, return_tensors="pt").input_ids[0]

    n_embs = virtual_embs.shape[1]
    headers = ["Token"] + [f"Emb {i}" for i in range(n_embs)]
    widths = [8] + [col_width] * n_embs

    def _token_cell(token_id):
        # Escape newlines/tabs and truncate so the text fits its column.
        text = tokenizer.decode([token_id]).replace("\n", "\\n").replace("\t", "\\t")
        return text[: col_width - 2]

    # One row per rank (Top 1..k), one column per virtual embedding.
    token_rows = [
        [f"Top {k+1}:"] + [_token_cell(token_id) for token_id in ids]
        for k, ids in enumerate(top_indices.tolist())
    ]
    distance_rows = [
        [f"Top {k+1}:"] + [f"{distance:.5f}" for distance in values]
        for k, values in enumerate(top_values.tolist())
    ]

    with sample_output:
        sample_output.clear_output(wait=True)
        print(f"Sample {sample_idx+1}/{batch_size}")
        print(f"Target prediction loss: {pred_loss:.6f}")
        print(f"Dinalar loss: {din_loss:.6f}")
        print("\nTarget tokens:")
        print(sample_decoded, "\n")

        print(tokenizer.decode(prefix_tokens))
        print(create_table("", headers, token_rows, widths))
        # Values are squared L2 distances (no square root is taken).
        print(create_table("L2 Distances:", headers, distance_rows, widths))
        print(tokenizer.decode(suffix_tokens))


# Interactive sample investigation: a "Next Sample" button drives the viewer.
next_button = widgets.Button(description="Next Sample")
next_button.on_click(on_button_click)

# Button and sample live in separate output areas, so clearing the sample
# view (display_sample calls clear_output) leaves the button in place.
with button_output:
    display(next_button)
display(button_output, sample_output)

# Render the first sample right away; the button advances from there.
if batch_size <= 0:
    with sample_output:
        print("No samples in batch!")
else:
    display_sample(current_sample)
    current_sample += 1

Output()

Output()

In [8]:
import ipywidgets as widgets
from IPython.display import display, clear_output

# Output areas for this cell: comparison results, and the input controls.
test_output, button_area = widgets.Output(), widgets.Output()

def get_embeddings_for_text(text, tokenizer, decoder, required_len, add_special_tokens=False):
    """Embed `text` with the decoder's token embeddings, fixed to `required_len` positions.

    Args:
        text: Text to tokenize and embed.
        tokenizer: Tokenizer matching `decoder`.
        decoder: Decoder model whose embedding layer (via get_embeddings_from_decoder) is used.
        required_len: Number of embedding positions to return.
        add_special_tokens: Whether the tokenizer may add special tokens such as BOS.
            Defaults to False: these embeddings stand in for the virtual embeddings,
            which sit *inside* the prompt, where a BOS token does not belong.

    Returns:
        (tokens, text_embeddings): `tokens` is the untruncated [1, n_tokens] id
        tensor on cfg.device. `text_embeddings` has `required_len` positions
        along dim 1; it is truncated, or right-padded with zero vectors.
    """
    # dtype=long also covers empty text, for which the tokenizer returns a float tensor.
    tokens = tokenizer(
        text, return_tensors="pt", add_special_tokens=add_special_tokens
    ).input_ids.to(device=cfg.device, dtype=torch.long)

    text_embeddings = get_embeddings_from_decoder(decoder)(tokens)

    n_tokens = text_embeddings.shape[1]
    if n_tokens > required_len:
        text_embeddings = text_embeddings[:, :required_len, :]
    elif n_tokens < required_len:
        # F.pad keeps the embeddings' dtype and device (torch.zeros defaulted to float32).
        # NOTE(review): zero vectors are not a neutral "empty" token for the decoder.
        text_embeddings = torch.nn.functional.pad(
            text_embeddings, (0, 0, 0, required_len - n_tokens)
        )

    return tokens, text_embeddings

def test_embedding_comparison(user_text, sample_idx=0):
    """Compare decoder prediction loss for user-text embeddings vs. encoder virtual embeddings.

    For sample `sample_idx`, the decoder prompt's encoding slot is filled once
    with embeddings of `user_text` (via get_embeddings_for_text) and once with
    the encoder's virtual embeddings of the sample's activations. For each, it
    computes the KL prediction loss on the last `cfg.decoder_pred_len_toks`
    positions and the top-3 closest vocabulary embeddings (squared L2) for
    every encoding position.

    Args:
        user_text: Text whose token embeddings fill the encoding slot.
        sample_idx: Index of the sample in the collected data.

    Returns:
        Dict with the user text, tokens and embeddings, the virtual
        embeddings, both prediction losses, the sample tokens and decoded
        text, and top-k indices/values ([top_k, encoding_len]) for both
        embedding sets.
    """
    # Unwrap DataParallel once so the encoder, decoder and context assembly all
    # come from the same underlying EncoderDecoder.
    base = (
        encoder_decoder.module
        if isinstance(encoder_decoder, torch.nn.DataParallel)
        else encoder_decoder
    )
    decoder = base.decoder
    top_k = 3

    with torch.no_grad():
        with torch.autocast(device_type=cfg.device.type, dtype=cfg.dtype):
            sample_tokens = target_generated_tokens[sample_idx : sample_idx + 1].to(cfg.device)
            sample_logits = target_logits[sample_idx : sample_idx + 1].to(
                cfg.device, dtype=torch.float32
            )
            sample_acts = target_acts[sample_idx : sample_idx + 1].to(cfg.device)

            def _prediction_loss(encoding_embeddings):
                """KL loss with `encoding_embeddings` placed in the prompt's encoding slot."""
                context_embeddings, attention_mask, _ = base.assemble_decoder_context_embeddings(
                    sample_tokens, encoding_embeddings, -1
                )
                logits = decoder(
                    inputs_embeds=context_embeddings, attention_mask=attention_mask
                ).logits
                return kl_div(logits[:, -cfg.decoder_pred_len_toks :, :], sample_logits).item()

            user_tokens, user_embeddings = get_embeddings_for_text(
                user_text, tokenizer, decoder, cfg.encoding_len_toks
            )
            virtual_embeddings = base.encoder(sample_acts)

            user_pred_loss = _prediction_loss(user_embeddings)
            virtual_pred_loss = _prediction_loss(virtual_embeddings)

            token_embeddings = get_embeddings_from_decoder(decoder).weight.float()  # [vocab_size, d_embed]

            def _nearest_tokens(embs):
                # Squared L2 distances [vocab_size, encoding_len]; cdist avoids the
                # [vocab_size, encoding_len, d_embed] broadcast intermediate.
                distances = torch.cdist(token_embeddings, embs[0].float()).pow(2)
                return torch.topk(distances, k=top_k, dim=0, largest=False)

            user_top_values, user_top_indices = _nearest_tokens(user_embeddings)
            virtual_top_values, virtual_top_indices = _nearest_tokens(virtual_embeddings)

    return {
        "user_text": user_text,
        "user_tokens": user_tokens,
        "user_embeddings": user_embeddings,
        "virtual_embeddings": virtual_embeddings,
        "user_pred_loss": user_pred_loss,
        "virtual_pred_loss": virtual_pred_loss,
        "sample_tokens": sample_tokens,
        "sample_decoded": tokenizer.decode(sample_tokens[0]),
        "user_top_indices": user_top_indices,
        "user_top_values": user_top_values,
        "virtual_top_indices": virtual_top_indices,
        "virtual_top_values": virtual_top_values,
    }

# Free-text input whose embeddings are compared with the encoder's virtual embeddings.
text_input = widgets.Textarea(
    value="The model is thinking about the relationship between cause and effect.",
    placeholder="Enter text to embed...",
    description="Input Text:",
    disabled=False,
    layout=widgets.Layout(width="100%", height="100px"),
)

# Sample index shown by this cell's viewer.
# NOTE(review): reassigns the global used by the earlier "Next Sample" viewer too.
current_sample = 0

# "Run Test" re-evaluates the current sample; "Next Sample" advances the index.
run_button, next_button = (
    widgets.Button(description=label, disabled=False, button_style=style, tooltip=tip)
    for label, style, tip in (
        ('Run Test', 'primary', 'Run the test with the provided text'),
        ('Next Sample', 'info', 'Move to the next sample'),
    )
)

# Function to display results
def _closest_token_rows(top_indices, col_width):
    """Table rows of decoded nearest tokens from a [top_k, n_embs] index tensor."""
    rows = []
    for rank, token_ids in enumerate(top_indices.tolist(), start=1):
        cells = []
        for token_id in token_ids:
            # Escape newlines/tabs and truncate so the text fits its column.
            text = tokenizer.decode([token_id]).replace("\n", "\\n").replace("\t", "\\t")
            cells.append(text[: col_width - 2])
        rows.append([f"Top {rank}:"] + cells)
    return rows


def _distance_rows(top_values):
    """Table rows of distances (5 decimals) from a [top_k, n_embs] value tensor."""
    return [
        [f"Top {rank}:"] + [f"{distance:.5f}" for distance in values]
        for rank, values in enumerate(top_values.tolist(), start=1)
    ]


def display_test_results(results):
    """Render the output of test_embedding_comparison into `test_output`.

    Prints the sample position, both prediction losses and the target text,
    followed by four tables: closest tokens and (squared) distances for the
    user embeddings and for the virtual embeddings. The number of ranks shown
    is whatever top-k the results contain.

    Args:
        results: Dict returned by test_embedding_comparison.
    """
    col_width = 15
    n_embs = results["user_embeddings"].shape[1]
    headers = ["Token"] + [f"Emb {i}" for i in range(n_embs)]
    widths = [8] + [col_width] * n_embs

    sections = [
        ("\nClosest tokens to USER embeddings:",
         _closest_token_rows(results["user_top_indices"], col_width)),
        ("\nClosest tokens to VIRTUAL embeddings:",
         _closest_token_rows(results["virtual_top_indices"], col_width)),
        ("\nL2 Distances for USER embeddings:",
         _distance_rows(results["user_top_values"])),
        ("\nL2 Distances for VIRTUAL embeddings:",
         _distance_rows(results["virtual_top_values"])),
    ]

    with test_output:
        test_output.clear_output(wait=True)
        print(f"Sample {current_sample+1}/{batch_size}")
        print(f"User text: \"{results['user_text']}\"")
        print(f"User embeddings prediction loss: {results['user_pred_loss']:.6f}")
        print(f"Virtual embeddings prediction loss: {results['virtual_pred_loss']:.6f}")
        print("\nTarget sample text:")
        print(results["sample_decoded"])

        # create_table is defined in the sample-viewer cell; run that cell first.
        for heading, rows in sections:
            print(heading)
            print(create_table("", headers, rows, widths))

# Define button click handlers
def on_run_button_clicked(b):
    """'Run Test' callback: re-run the comparison for the current sample and text.

    Args:
        b: The clicked ipywidgets Button (unused).
    """
    with test_output:
        print("Running test...")
    display_test_results(test_embedding_comparison(text_input.value, current_sample))

def on_next_button_clicked(b):
    """'Next Sample' callback: advance to the next sample and re-run the comparison.

    Args:
        b: The clicked ipywidgets Button (unused).
    """
    global current_sample
    if current_sample + 1 >= batch_size:
        with test_output:
            test_output.clear_output(wait=True)
            print("End of batch reached!")
        return
    current_sample += 1
    display_test_results(test_embedding_comparison(text_input.value, current_sample))

# Wire callbacks, lay out the controls, then show an initial comparison.
next_button.on_click(on_next_button_clicked)
run_button.on_click(on_run_button_clicked)

# Text box above a row of [Run Test, Next Sample].
controls = widgets.VBox([text_input, widgets.HBox([run_button, next_button])])
with button_area:
    display(controls)
display(button_area, test_output)

# Initial run with the default text
if batch_size <= 0:
    with test_output:
        print("No samples in batch!")
else:
    results = test_embedding_comparison(text_input.value, current_sample)
    display_test_results(results)

Output()

Output()

In [13]:
import ipywidgets as widgets
from IPython.display import display, clear_output

# Output areas for this cell: decoded context and results, and the input controls.
context_output, button_area = widgets.Output(), widgets.Output()

# Sample index shown by this cell's viewer.
# NOTE(review): this is the same global the earlier viewers' callbacks use.
current_sample = 0

# Text prompt wrapped around the user-provided description by
# assemble_decoder_context_without_virtual (text-only baseline: no virtual embeddings).
# NOTE(review): these assignments shadow PROMPT_PREFIX / PROMPT_SUFFIX imported
# from eli.encoder. Notebook code that reads these globals afterwards (e.g. the
# prefix/suffix printed by display_sample) sees these values instead — consider
# renaming them.
PROMPT_PREFIX = """
User: Your task is to predict what another LLM will say, given the following
description of what the LLM is currently thinking: \" 
"""

PROMPT_SUFFIX = """\". Provide your prediction and nothing else.
Assistant:
"""

def assemble_decoder_context_without_virtual(text, sample_idx=0):
    """Score a text-only decoder context (no virtual embeddings) against a sample's target.

    Builds the token sequence ``PROMPT_PREFIX + text + PROMPT_SUFFIX + target
    tokens``, runs the decoder on it, and computes the KL prediction loss of
    the last ``cfg.decoder_pred_len_toks`` positions against the sample's
    target logits.

    Only the prefix keeps the tokenizer's special tokens (e.g. BOS). The user
    text and the suffix are tokenized with ``add_special_tokens=False`` so
    that no BOS token appears in the middle of the sequence.

    Args:
        text: User-provided description placed between prefix and suffix.
        sample_idx: Index of the sample in the collected data.

    Returns:
        Dict with the assembled ``input_tokens`` ([1, total_len]), decoded
        context and target text, the prediction loss, the sample tokens,
        and the token lengths of prefix, user text and suffix.
    """
    # Unwrap DataParallel, consistent with the other cells.
    decoder = (
        encoder_decoder.module.decoder
        if isinstance(encoder_decoder, torch.nn.DataParallel)
        else encoder_decoder.decoder
    )

    def _tokenize(s, add_special_tokens):
        # dtype=long also covers empty text, for which the tokenizer returns a float tensor.
        return tokenizer(
            s, return_tensors="pt", add_special_tokens=add_special_tokens
        ).input_ids.to(device=cfg.device, dtype=torch.long)

    with torch.no_grad():
        with torch.autocast(device_type=cfg.device.type, dtype=cfg.dtype):
            sample_tokens = target_generated_tokens[sample_idx : sample_idx + 1].to(cfg.device)
            sample_logits = target_logits[sample_idx : sample_idx + 1].to(
                cfg.device, dtype=torch.float32
            )

            # Repeat the prompt pieces to the sample batch size (1 here).
            n_rows = sample_tokens.shape[0]
            prefix_tokens = _tokenize(PROMPT_PREFIX, True).repeat(n_rows, 1)
            user_text_tokens = _tokenize(text, False).repeat(n_rows, 1)
            suffix_tokens = _tokenize(PROMPT_SUFFIX, False).repeat(n_rows, 1)

            # Format: [prefix] + [user_text] + [suffix] + [target_tokens]
            input_tokens = torch.cat(
                [prefix_tokens, user_text_tokens, suffix_tokens, sample_tokens], dim=1
            )

            # No padding is involved, so the decoder's default (all-ones) attention is used.
            logits = decoder(input_ids=input_tokens).logits
            prediction_loss = kl_div(
                logits[:, -cfg.decoder_pred_len_toks :, :], sample_logits
            ).item()

    return {
        "input_tokens": input_tokens,
        "context_decoded": tokenizer.decode(input_tokens[0]),
        "target_decoded": tokenizer.decode(sample_tokens[0]),
        "prediction_loss": prediction_loss,
        "sample_tokens": sample_tokens,
        "prefix_len": prefix_tokens.shape[1],
        "user_text_len": user_text_tokens.shape[1],
        "suffix_len": suffix_tokens.shape[1],
    }

def display_context_and_results(results):
    """Render the text-only context, its token structure and its loss into `context_output`.

    Args:
        results: Dict returned by assemble_decoder_context_without_virtual.
    """
    rule = "-" * 80
    with context_output:
        context_output.clear_output(wait=True)

        print(f"Sample {current_sample+1}/{batch_size}")
        print(f"User text: \"{text_input.value}\"")
        print(f"Prediction loss: {results['prediction_loss']:.6f}")

        # Token count of every segment of the assembled context
        structure = (
            ("Prefix tokens", results["prefix_len"]),
            ("User text tokens", results["user_text_len"]),
            ("Suffix tokens", results["suffix_len"]),
            ("Target tokens", results["sample_tokens"].shape[1]),
            ("Total tokens", results["input_tokens"].shape[1]),
        )
        print("\nContext Structure:")
        for label, count in structure:
            print(f"{label}: {count} tokens")

        print("\nAssembled Context (decoded):")
        print(rule)
        print(results["context_decoded"])
        print(rule)

        print("\nTarget tokens (what the model should predict):")
        print(results["target_decoded"])

# Free-text description placed between the prompt prefix and suffix.
text_input = widgets.Textarea(
    value="The model is thinking about the relationship between cause and effect.",
    placeholder="Enter text to use as context...",
    description="Input Text:",
    disabled=False,
    layout=widgets.Layout(width="100%", height="100px"),
)

# "Run Test" re-evaluates the current sample; "Next Sample" advances the index.
run_button, next_button = (
    widgets.Button(description=label, disabled=False, button_style=style, tooltip=tip)
    for label, style, tip in (
        ('Run Test', 'primary', 'Run the test with the provided text'),
        ('Next Sample', 'info', 'Move to the next sample'),
    )
)

def on_run_button_clicked(b):
    """'Run Test' callback: re-score the text-only context for the current sample.

    Redefines the same-named handler from the comparison cell; buttons already
    bound there keep their original callback.

    Args:
        b: The clicked ipywidgets Button (unused).
    """
    with context_output:
        print("Running test...")
    display_context_and_results(
        assemble_decoder_context_without_virtual(text_input.value, current_sample)
    )

def on_next_button_clicked(b):
    """'Next Sample' callback: advance the sample index and re-score the text-only context.

    Args:
        b: The clicked ipywidgets Button (unused).
    """
    global current_sample
    if current_sample + 1 >= batch_size:
        with context_output:
            context_output.clear_output(wait=True)
            print("End of batch reached!")
        return
    current_sample += 1
    display_context_and_results(
        assemble_decoder_context_without_virtual(text_input.value, current_sample)
    )

# Wire callbacks, lay out the controls, then show an initial result.
next_button.on_click(on_next_button_clicked)
run_button.on_click(on_run_button_clicked)

# Text box above a row of [Run Test, Next Sample].
controls = widgets.VBox([text_input, widgets.HBox([run_button, next_button])])
with button_area:
    display(controls)
display(button_area, context_output)

# Initial run with the default text
if batch_size <= 0:
    with context_output:
        print("No samples in batch!")
else:
    results = assemble_decoder_context_without_virtual(text_input.value, current_sample)
    display_context_and_results(results)

Output()

Output()