Let's assume that we have a random model that we counterfactually force to gate at certain times. What does the advantage of doing so look like? 

In [1]:
import torch
from torch import nn
import torch.nn.functional as F
from transformers import AutoTokenizer

from clean_code.flexible_bitter_llm import FlexibleBitterLLM, Gemma2RotaryEmbedding, IndependentWrapperGater
from clean_code.bitter_llm import RandomGater
torch.serialization.add_safe_globals([nn.modules.sparse.Embedding])

# Import matplotlib for plotting
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from IPython.display import HTML

import numpy as np

In [2]:

# Run on the GPU when one is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [3]:
# Tokenizer loaded from the "google/byt5-large" checkpoint.
byte5_tokenizer = AutoTokenizer.from_pretrained("google/byt5-large")

In [8]:
# Choose which checkpoint to analyse:
#   True  -> a checkpoint from experiment_28; its gate filter values are also read into `filter_vals`.
#   False -> the random-select base model, with its down-layer gate wrapped in IndependentWrapperGater.
# NOTE: `filter_vals` is only defined by this cell when the flag is True.
# NOTE(review): weights_only=False unpickles arbitrary Python objects — only load trusted checkpoint files.
load_model_with_filter = False
if load_model_with_filter:
    my_model = torch.load("experiment_28/model_8.pt", weights_only=False)
    filter_vals = [v.item() for v in my_model.down_layer_gate.filter]
else:
    my_model = torch.load("training_random_base_model/random_select_early_output_base_model_42.pt", weights_only=False)
    my_model.down_layer_gate = IndependentWrapperGater(my_model.down_layer_gate)

--- 
Quick question: what filters are learned by the models?


In [5]:
# Print the learned gate parameters (base value and filter values) of the experiment_28 checkpoints
# for seeds 1-3. Disabled by default because it loads three checkpoints from disk.
inspect_filter_values = False
if inspect_filter_values:
    for seed in range(1, 4):
        # Use a dedicated name so the model under analysis (`my_model`) is not overwritten.
        inspected_model = torch.load(f"experiment_28/model_{seed}.pt", weights_only=False)
        base_val = inspected_model.down_layer_gate.base_value.item()
        filter_vals = [v.item() for v in inspected_model.down_layer_gate.filter]
        # One line per seed: the base value followed by the filter values.
        print(base_val, filter_vals)

-1.001364827156067 [-0.2573898732662201, -0.0959232747554779, 0.003937778528779745, 0.002313619013875723]
-1.006346583366394 [-0.2470589131116867, -0.11891937255859375, -0.008503127843141556, 0.04420414939522743]
-0.9988336563110352 [-0.26974087953567505, -0.0981111004948616, -0.01271668542176485, 0.009889219887554646]


Indeed, suppression of gating right after a gate _is_ learned, just by a very small amount! This is likely due to the Adam optimizer normalizing the gradients passed to these values. For example, if the gradient $X$ to one of these parameters has first moment $E(X) = \mu$ and second moment $E(X^2) = v^2$, then Adam will normalise the update to roughly $\frac{\mu}{v}$ on average and then scale it by `1e-4`. This means that changing the loss balance won't help, but a more flexible model might.

---

In [6]:

def load_old_model(file_name):
    """Load a pickled checkpoint and patch it so it can be used as a FlexibleBitterLLM.

    The loaded object is re-tagged as `FlexibleBitterLLM`, given a fresh rotary embedding built from its
    `byte_layer_config`, switched to eager attention, cast to float32, and every down/mid/up layer gets
    an attention logit soft-cap of 50.0.

    Args:
        file_name: path of the checkpoint, passed to `torch.load`.

    Returns:
        The patched model.

    NOTE(review): weights_only=False unpickles arbitrary Python objects — only use on trusted files.
    """
    legacy_model = torch.load(file_name, weights_only=False)

    # Swap the class in place instead of rebuilding the model.
    legacy_model.__class__ = FlexibleBitterLLM
    legacy_model.rotary_emb = Gemma2RotaryEmbedding(legacy_model.byte_layer_config)
    legacy_model.attn_implementation = "eager"  # Use eager attention for higher precision.

    legacy_model = legacy_model.to(dtype=torch.float32)

    every_layer = [*legacy_model.down_layers, *legacy_model.mid_layers, *legacy_model.up_layers]
    for layer in every_layer:
        layer.self_attn.attn_logit_softcapping = 50.0

    return legacy_model



In [9]:
# Display the module tree of the loaded model.
my_model

FlexibleBitterLLM(
  (embedding): Embedding(256, 768)
  (down_layers): ModuleList(
    (0-1): 2 x OptimizedModule(
      (_orig_mod): Gemma2DecoderLayer(
        (self_attn): Gemma2Attention(
          (q_proj): Linear(in_features=768, out_features=768, bias=False)
          (k_proj): Linear(in_features=768, out_features=768, bias=False)
          (v_proj): Linear(in_features=768, out_features=768, bias=False)
          (o_proj): Linear(in_features=768, out_features=768, bias=False)
        )
        (mlp): Gemma2MLP(
          (gate_proj): Linear(in_features=768, out_features=768, bias=False)
          (up_proj): Linear(in_features=768, out_features=768, bias=False)
          (down_proj): Linear(in_features=768, out_features=768, bias=False)
          (act_fn): PytorchGELUTanh()
        )
        (input_layernorm): Gemma2RMSNorm((768,), eps=1e-06)
        (post_attention_layernorm): Gemma2RMSNorm((768,), eps=1e-06)
        (pre_feedforward_layernorm): Gemma2RMSNorm((768,), eps=1e-06)


In [10]:
def index_by_next_token(vals, token_ids):
    """Pick, at every position, the value assigned to the token that actually comes next.

    Args:
        vals: tensor of shape (batch_size, seq_len, vocab_size).
        token_ids: 1-D tensor of shape (seq_len,); the same token sequence is used for every batch row.

    Returns:
        Tensor of shape (batch_size, seq_len - 1) whose entry [b, t] is vals[b, t, token_ids[t + 1]].
    """
    # The last position has no next token, so it is dropped.
    predictions = vals[:, :-1, :]
    # Shape the targets as (1, seq_len - 1, 1) and repeat them across the batch for gather.
    targets = token_ids[1:][None, :, None].expand(predictions.shape[0], -1, 1)
    return predictions.gather(dim=-1, index=targets).squeeze(-1)

In [11]:
def logit_diff_stats_batch_independent(a_logits, b_logits):
    """Mean and standard error of (a - b) per position, treating the two batches as independent samples.

    Args:
        a_logits: tensor of shape (batch_size_a, seq_len).
        b_logits: tensor of shape (batch_size_b, seq_len).

    Returns:
        (diff_mean, diff_sterr), each of shape (seq_len,). The standard error combines the two
        per-group standard errors in quadrature, i.e. no pairing across batch indices is assumed.
    """
    n_a, n_b = a_logits.shape[0], b_logits.shape[0]

    diff_mean = a_logits.mean(dim=0) - b_logits.mean(dim=0)

    # Squared standard error of each group's mean.
    sq_sterr_a = (a_logits.std(dim=0) ** 2) / n_a
    sq_sterr_b = (b_logits.std(dim=0) ** 2) / n_b
    diff_sterr = torch.sqrt(sq_sterr_a + sq_sterr_b)

    return diff_mean, diff_sterr

In [12]:
# Refactoring the plotting functions to be more modular:
# Helper function: visualize the hard to predict tokens.
def plot_character_intensities(txt, intensities, intensity_to_value_fn=lambda x: x, colorbar_label="intensity"):
    """Display `txt` as HTML with each character's background coloured by its intensity.

    Intensity 0.0 is rendered blue and 1.0 red; values outside [0, 1] are clamped to that range.
    A colourbar with 11 swatches (0.0, 0.1, ..., 1.0) is shown above the text.

    Args:
        txt: the string to display. Characters are paired with `intensities` by position; if the two
            differ in length the extra entries are ignored.
        intensities: iterable of numbers, nominally in [0, 1], one per character.
        intensity_to_value_fn: maps a colourbar intensity to the number printed on its swatch.
        colorbar_label: caption shown above the colourbar (inserted as-is, not escaped).

    Returns:
        None. The HTML is shown through IPython's `display`.
    """
    import html  # Local import: only needed here, to escape the displayed characters.

    text_color = "rgb(255, 255, 255)"

    def background_color(intensity):
        # Clamp so out-of-range intensities cannot produce invalid rgb() components.
        level = min(1.0, max(0.0, intensity))
        return f"rgb({int(level*255)}, 0, {int((1.0 - level)*255)})"

    # Colourbar swatches: 0.0 to 1.0 in steps of 0.1.
    colorbar = ""
    for i in range(11):
        intensity = i / 10
        colorbar += f'<span style="color:{text_color}; background-color:{background_color(intensity)}; margin-right:2px; padding:0 5px;">{intensity_to_value_fn(intensity):.2f}</span>'

    # Add a legend for the colorbar
    colorbar_html = f'''
    <div style="margin-bottom:10px;">
        <div style="font-family:monospace; font-size:12px; margin-bottom:3px;">{colorbar_label}</div>
        <div style="font-family:monospace; font-size:14px;">{colorbar}</div>
    </div>
    '''

    # One coloured span per character. Characters are HTML-escaped so that '<', '>' and '&'
    # in the text are shown literally instead of being parsed as markup.
    colored_text = ""
    for char, intensity in zip(txt, intensities):
        colored_text += f'<span style="color:{text_color}; background-color:{background_color(intensity)};">{html.escape(char)}</span>'

    # Display the colorbar and colored text
    display(HTML(f'''
    <div>
        {colorbar_html}
        <div style="font-family:monospace; font-size:14px;">{colored_text}</div>
    </div>
    '''))


In [13]:
# From https://www.bbc.com/news/articles/crrz44d7v08o
test_string = """This BBC interview with Prince Harry will become one of those famous moments when television collides with the world of the royals.

It was like an emotional avalanche. It began with some stones being kicked over with questions about security and then the interview turned into a spectacular release of what seemed to be a rolling mountain of pent-up frustration and a poignant sense of separation.

The starting point was Prince Harry's defeat in the courts as he sought to overturn a downgrading of his security in the UK. He seemed wounded. Had he decided it was time to have his say? And then really say some more?

A conversation about security was suddenly becoming about a whole range of insecurities.
"""
# Encode the passage as a PyTorch tensor on `device`; `token_ids` is its first (only) row.
# presumably test_batch has shape (1, seq_len) — TODO confirm
test_batch = byte5_tokenizer.encode(test_string, return_tensors="pt", padding=True).to(device)
token_ids = test_batch[0]


---
Analysis 1: keeping the preceding gates fixed to make it the "true" advantage.

In [14]:
# Forward pass (no gradients) on two copies of the same input; the gate samples of this
# output serve as the two "preceding gating" contexts in Analysis 1.
with torch.no_grad():
    base_out = my_model(test_batch.expand(2, -1))



In [15]:
gate_of_interest = 23
base_down_gate_samples = base_out["down_gate_samples"]

# Boolean mask of shape (1, seq_len): True for positions 0..gate_of_interest, False afterwards.
# presumably this marks the gates that are prescribed rather than re-sampled — confirm against the model's forward.
gate_samples_mask = (torch.arange(test_batch.shape[1], device=device) <= gate_of_interest).unsqueeze(0)

# Two counterfactual copies of the sampled gates that differ only at the gate of interest.
gate_gate_samples = base_down_gate_samples.clone()
no_gate_samples = base_down_gate_samples.clone()
gate_gate_samples[:, gate_of_interest] = 1
no_gate_samples[:, gate_of_interest] = 0

What we want: the counterfactual difference between gating at the token of interest and not gating at the token of interest, integrated over some large number of samples.

In [18]:
# Number of identical copies of the input evaluated in one forward pass.
forward_batch_size = 512
# NOTE(review): `expanded_gate_samples_mask` is not referenced anywhere else in this notebook;
# the forward calls pass the un-expanded `gate_samples_mask` instead.
expanded_gate_samples_mask = gate_samples_mask.expand(forward_batch_size, -1)
expanded_test_batch = test_batch.expand(forward_batch_size, -1)


In [19]:
# One forward pass per base gating row, with the gate of interest forced ON.
gate_gate_out = []

with torch.no_grad():
    for forced_on_row in gate_gate_samples:
        # Repeat the prescribed gates across the whole forward batch.
        batched_on_gates = forced_on_row.expand(forward_batch_size, -1)
        gate_gate_out.append(
            my_model(
                expanded_test_batch,
                prescribed_down_gate_samples=batched_on_gates,
                down_gate_mask=gate_samples_mask,
            )
        )

BackendCompilerFailed: backend='inductor' raised:
OutOfMemoryError: CUDA out of memory. Tried to allocate 11.60 GiB. GPU 0 has a total capacity of 47.54 GiB of which 5.21 GiB is free. Including non-PyTorch memory, this process has 42.32 GiB memory in use. Of the allocated memory 41.95 GiB is allocated by PyTorch, and 66.53 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"


In [29]:
# One forward pass per base gating row, with the gate of interest forced OFF.
no_gate_out = []

with torch.no_grad():
    for forced_off_row in no_gate_samples:
        # Repeat the prescribed gates across the whole forward batch.
        batched_off_gates = forced_off_row.expand(forward_batch_size, -1)
        no_gate_out.append(
            my_model(
                expanded_test_batch,
                prescribed_down_gate_samples=batched_off_gates,
                down_gate_mask=gate_samples_mask,
            )
        )

In [30]:
# For each preceding-gating context, keep only the logit of the token that actually comes next.
next_token_logits_gate = [index_by_next_token(ggo["logits"], token_ids) for ggo in gate_gate_out]
next_token_logits_no_gate = [index_by_next_token(ngo["logits"], token_ids) for ngo in no_gate_out]

In [31]:
# One (diff_mean, diff_sterr) pair per preceding-gating context: gate ON minus gate OFF.
diff_stats = [logit_diff_stats_batch_independent(a, b) for a, b in zip(next_token_logits_gate, next_token_logits_no_gate)]


In [32]:
# Visualize the pattern of differences with standard error
def visualize_logit_diffs(diff_mean, diff_sterr, **plot_kwargs):
    """Plot a per-position curve with a shaded +/- one standard error band on the current axes.

    Args:
        diff_mean: 1-D tensor of values to plot against position.
        diff_sterr: 1-D tensor of the same length giving the half-width of the band.
        **plot_kwargs: forwarded to `plt.plot` for the curve (e.g. `label`).
    """
    center = diff_mean.cpu().numpy()
    half_width = diff_sterr.cpu().numpy()
    positions = range(len(center))

    plt.plot(center, **plot_kwargs)
    plt.fill_between(positions, center - half_width, center + half_width, alpha=0.3, label='Standard Error')


Observation 1: the advantage is very noisy depending on the preceding gating:

In [33]:
plt.figure(figsize=(12, 5))
# One curve per preceding-gating context (the two rows of the base forward pass).
for context_idx in (0, 1):
    context_mean, context_sterr = diff_stats[context_idx]
    visualize_logit_diffs(context_mean, context_sterr, label=f"preceding gating {context_idx}")
plt.axvline(x=gate_of_interest, color='r', linestyle='--', label=f'Gate of interest (position {gate_of_interest})')
plt.xlim(0, 100)
plt.xlabel('Position')
plt.ylabel('Logit Difference')
plt.title('Next Token Logit Differences (gate - no gate)')
plt.legend()
plt.grid(True)
plt.show()

In [34]:
# Show the text up to and including the position of the gate of interest.
test_string[:gate_of_interest+1]

In [35]:
# Show the sampled gates of both contexts up to and including the gate of interest.
base_down_gate_samples[:, :gate_of_interest+1]

In [20]:
# Display the index of the gate currently under study.
gate_of_interest

23

In [21]:
# Decode the single token at the gate of interest to see which character it is.
byte5_tokenizer.decode(token_ids[gate_of_interest])

' '

In [22]:
from tqdm import trange


---
Analysis 2: what does the advantage look like if we marginalise over the preceding gates too?

In [23]:
n_forward_minibatches = 32
forward_minibatch_size = 512
forward_batch_size = n_forward_minibatches * forward_minibatch_size

# Re-sample the gates n_forward_minibatches * forward_minibatch_size times (the model samples them itself):
full_logits = []
full_gate_samples = []

for i in trange(n_forward_minibatches):
    with torch.no_grad():
        full_out = my_model(test_batch.expand(forward_minibatch_size, -1))
        full_logits.append(full_out["logits"])
        full_gate_samples.append(full_out["down_gate_samples"])

# Stack the minibatches along the batch dimension.
full_logits = torch.cat(full_logits, dim=0)
full_gate_samples = torch.cat(full_gate_samples, dim=0)
# `full_out` here is the output of the last minibatch only.
full_early_logits = full_out["early_logits"][0, :, :] # the early logits do not depend on the gate samples.
full_early_logits = full_early_logits.unsqueeze(0) # Put the batch dimension back in.

# Keep only the logit assigned to the token that actually comes next at every position.
full_next_token_logits = index_by_next_token(full_logits, token_ids)
full_early_next_token_logits = index_by_next_token(full_early_logits, token_ids)

  0%|                                                                                                                                             | 0/32 [00:00<?, ?it/s]


BackendCompilerFailed: backend='inductor' raised:
OutOfMemoryError: CUDA out of memory. Tried to allocate 11.54 GiB. GPU 0 has a total capacity of 47.54 GiB of which 3.99 GiB is free. Including non-PyTorch memory, this process has 43.54 GiB memory in use. Of the allocated memory 43.17 GiB is allocated by PyTorch, and 65.02 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"


In [65]:
# Sanity check of the shapes produced by the sampling cell.
full_next_token_logits.shape, full_early_next_token_logits.shape

In [66]:

def split_logits_by_gate_activation(full_gate_samples, full_next_token_logits, gate_of_interest):
    """Partition the sampled rows by whether the gate at `gate_of_interest` fired.

    Args:
        full_gate_samples: tensor of shape (batch_size, seq_len) holding the sampled gates.
        full_next_token_logits: tensor whose first dimension is the same batch; rows are selected from it.
        gate_of_interest: integer position of the gate to split on.

    Returns:
        (gate_on_logits, gate_off_logits): the rows of `full_next_token_logits` whose gate at
        `gate_of_interest` equals 1, and those whose gate equals 0, respectively.
    """
    gate_column = full_gate_samples[:, gate_of_interest]

    rows_with_gate = full_next_token_logits[gate_column == 1]
    rows_without_gate = full_next_token_logits[gate_column == 0]

    return rows_with_gate, rows_without_gate

# NOTE: this re-binds `gate_of_interest` (it was 23 in Analysis 1).
gate_of_interest = 25
gate_on_logits, gate_off_logits = split_logits_by_gate_activation(full_gate_samples, full_next_token_logits, gate_of_interest)
# Print shapes to verify the split
print(f"Gate ON logits shape: {gate_on_logits.shape}")
print(f"Gate OFF logits shape: {gate_off_logits.shape}")


In [67]:
# Per-position mean difference (gate ON minus gate OFF) and its standard error, treating the two groups as independent.
diff_mean, diff_sterr = logit_diff_stats_batch_independent(gate_on_logits, gate_off_logits)

In [68]:
# Zoomed view (first 100 positions) of the marginalised gate - no gate difference.
fig, ax = plt.subplots(figsize=(12, 5))
visualize_logit_diffs(diff_mean, diff_sterr, label="marginalised over preceding gating")
ax.axvline(x=gate_of_interest, color='r', linestyle='--', label=f'Gate of interest (position {gate_of_interest})')
ax.set_xlim(0, 100)
ax.set_title('Next Token Logit Differences (gate - no gate)')
ax.set_xlabel('Position')
ax.set_ylabel('Logit Difference')
ax.legend()
ax.grid(True)
plt.show()

In [69]:
# Wide view (first 600 positions) of the same marginalised difference as the zoomed plot.
plt.figure(figsize=(12, 5))
# The curve is marginalised over the preceding gates, so it is labelled as such
# (the previous "preceding gating 0" label was a leftover from Analysis 1).
visualize_logit_diffs(diff_mean, diff_sterr, label="marginalised over preceding gating")
plt.xlim(0, 600)
plt.axvline(x=gate_of_interest, color='r', linestyle='--', label=f'Gate of interest (position {gate_of_interest})')
plt.title('Next Token Logit Differences (gate - no gate)')
plt.xlabel('Position')
plt.ylabel('Logit Difference')
plt.legend()
plt.grid(True)
plt.show()

In [70]:
# Print the text up to and including the position of the gate of interest.
print(test_string[:gate_of_interest+1])

In [71]:
# Compare the number of per-position differences with the number of characters in the text.
diff_mean.shape, len(test_string)

In [72]:
def plot_logit_diffs_intensities(txt, diff_mean, **kwargs):
    """Colour each character of `txt` by the logit difference for predicting that character.

    Args:
        txt: the text to display.
        diff_mean: 1-D tensor; entry t is the difference for predicting the token at t + 1.
        **kwargs: forwarded to `plot_linear_intensity_scale`. `colorbar_label` defaults to
            "logit difference" when not given.
    """
    diff_mean = diff_mean.cpu().numpy()
    # diff_mean concerns the _next_ token, so shift by one: the first character is never predicted
    # and gets 0, and the final entry is dropped to keep the length unchanged.
    shifted = np.concatenate([np.zeros(1), diff_mean[:-1]])
    # Forward caller options instead of silently discarding them.
    kwargs.setdefault("colorbar_label", "logit difference")
    plot_linear_intensity_scale(txt, shifted, **kwargs)


def plot_linear_intensity_scale(txt, unscaled_intensities, **kwargs):
    """Min-max rescale the values to [0, 1] and display `txt` coloured by them.

    Args:
        txt: the text to display.
        unscaled_intensities: array of raw values, one per character (must support .min()/.max()).
        **kwargs: forwarded to `plot_character_intensities` (e.g. `colorbar_label`).

    The colourbar swatches show the original (unscaled) values. If all values are equal, every
    character gets intensity 0 instead of dividing by zero.
    """
    lowest = unscaled_intensities.min()
    span = unscaled_intensities.max() - lowest
    # Guard against a zero range (constant input), which would otherwise give 0/0 = NaN.
    divisor = span if span > 0 else 1.0

    intensities = (unscaled_intensities - lowest) / divisor
    # Inverse of the rescaling, used to label the colourbar.
    intensity_to_value_fn = lambda x: x * span + lowest
    plot_character_intensities(txt, intensities, intensity_to_value_fn, **kwargs)


# Render the marginalised gate - no gate difference directly on the text.
plot_logit_diffs_intensities(test_string, diff_mean)


In [73]:
# Range of the per-position differences (these are the end points of the colour scale above).
diff_mean.min(), diff_mean.max()

In [74]:
# Number of positions, starting at the gate, over which the difference is summed.
window_size = 10

# Windowed advantage of gating at `gate_of_interest`.
diff_mean[gate_of_interest:gate_of_interest+window_size].sum()

In [75]:
# Windowed advantage of gating at every position: entry k of `window_advantages` belongs to gate k + 1.
sequence_length = full_gate_samples.shape[1]
window_advantages = []

for gate_of_interest_loop in range(1, sequence_length - window_size):
    on_rows, off_rows = split_logits_by_gate_activation(full_gate_samples, full_next_token_logits, gate_of_interest_loop)
    # Only the mean difference is needed here; the standard error is discarded.
    loop_diff_mean = logit_diff_stats_batch_independent(on_rows, off_rows)[0]

    window = slice(gate_of_interest_loop, gate_of_interest_loop + window_size)
    window_advantages.append(loop_diff_mean[window].sum().item())

plt.plot(window_advantages)
plt.show()


In [76]:
# Verify that we are indeed computing the correct advantage:
# the loop starts at gate 1, so gate g is stored at index g - 1; this should equal the
# single-gate windowed sum computed for `gate_of_interest` earlier.
window_advantages[gate_of_interest-1]

In [83]:
def plot_clipped_advantage_intensities(txt, advantages, clip_value=1.5, **kwargs):
    """Colour each character of `txt` by its advantage, clipped to [-clip_value, clip_value].

    Args:
        txt: the text to display.
        advantages: sequence of per-position advantages; entry k is shown on character k + 1.
        clip_value: symmetric clipping bound applied before rescaling.
        **kwargs: forwarded to `plot_linear_intensity_scale`.
    """
    # The first token is not predicted, so it gets an advantage of 0.
    padded = np.concatenate([np.zeros(1), np.array(advantages)])
    plot_linear_intensity_scale(
        txt,
        np.clip(padded, -clip_value, clip_value),
        colorbar_label="clipped advantage",
        **kwargs,
    )

# Show the windowed advantage of gating at each character, clipped to the default +/- 1.5.
plot_clipped_advantage_intensities(test_string, window_advantages)


Observation: 
- Gating at the second "Prince Harry" is extremely important to be able to use an induction head (but potentially not as important if there's already been a couple rounds of gating)
- Surprisingly (to me), the advantage of gating on sequence delimiters is basically null

We see that the base model indeed 

In [87]:
# Per-position gain of the final output over the early output: the batch mean of the
# next-token logit minus the early next-token logit (which has a single batch row).
mean_entropy_gain = full_next_token_logits.mean(dim=0) - full_early_next_token_logits[0]
plot_clipped_advantage_intensities(test_string, mean_entropy_gain.cpu(), clip_value=1.)

In [86]:
# Summary of the gain over positions: mean, minimum, maximum.
mean_entropy_gain.mean().item(), mean_entropy_gain.min().item(), mean_entropy_gain.max().item()

---
Analysis 3: what is the advantage of gating after a given preceding gating pattern?

prediction: 50% that gating with 5 preceding zeros has a more positive advantage and gating with 2 preceding ones has a more negative advantage.

In [88]:
# Number of sampled gatings and sequence length.
full_gate_samples.shape

In [89]:

def count_preceding_gates(full_gate_samples, n_preceding_ones):
    """Count the gates in the window of `n_preceding_ones` positions ending at each position.

    Implemented by sliding an all-ones filter of length `n_preceding_ones` over the gate samples
    via `filter_preceding_gates`.
    """
    return filter_preceding_gates(full_gate_samples, torch.ones(n_preceding_ones))

def filter_preceding_gates(full_gate_samples, filter_1d):
    """Slide a 1-D filter over the gate samples, aligned so each output only sees past positions.

    Output [b, t] is sum_j filter_1d[j] * gates[b, t - (k - 1) + j] for k = len(filter_1d), with
    positions before the start of the sequence treated as 0. The last filter entry therefore
    multiplies the gate at position t itself.

    Args:
        full_gate_samples: tensor of shape (batch_size, seq_len); any dtype castable to float32.
        filter_1d: 1-D tensor of filter values.

    Returns:
        int32 tensor of shape (batch_size, seq_len) on the notebook-level `device`.
    """
    filter_size, = filter_1d.shape
    seq_len = full_gate_samples.shape[-1]

    # conv1d expects (batch, channels, length) inputs and (out_channels, in_channels, k) filters.
    gates = full_gate_samples.unsqueeze(1).to(device=device, dtype=torch.float32)
    kernel = filter_1d.unsqueeze(0).unsqueeze(0).to(device=device, dtype=gates.dtype)

    response = F.conv1d(gates, kernel, padding=filter_size-1).squeeze(1)
    # Padding on both sides yields seq_len + k - 1 outputs; keep the first seq_len (the causal ones).
    # Slicing by length rather than by `:-k+1` also works for k == 1, where `:-0` would be empty.
    response = response[:, :seq_len]
    return response.to(dtype=torch.int32)


# Number of gates in the 2 (resp. 3) positions ending at each position.
sum_preceding_2_gates = count_preceding_gates(full_gate_samples, 2)
sum_preceding_3_gates = count_preceding_gates(full_gate_samples, 3)
sum_preceding_2_gates.shape

In [90]:
# Peek at the first gate samples, to compare by eye with the counts in the next cell.
full_gate_samples[:5, :20]

In [91]:
# Matching slice of the 3-position gate counts.
sum_preceding_3_gates[:5, :20]

In [92]:
# True where the three gates ending at this position are all on (pattern 111); then count the occurrences.
triple_gates = (sum_preceding_3_gates == 3)
triple_gates.sum()

Gather all the examples of 2 preceding gates and 3 preceding gates. What's the advantage of 111 compared to 110?

i.e. $$\mathbb{E}(\log p(x_{i+1:i+11}) \vert a_{i-2:i} = 11, a_{i} = 1) - \mathbb{E}(\log p(x_{i+1:i+11}) \vert a_{i-2:i} = 11, a_{i} = 0)$$

Where the expectation is taken over all $x$ and $a_i$. 111 can be detected with value 3 on filter 111, 110 can be detected with value 2 on filter 11-1.

In [93]:
# Filter whose response equals 2 exactly when the three gates ending at a position are 1, 1, 0.
counterfactual_pattern = torch.tensor([1, 1, -1])
counterfactual_vals = filter_preceding_gates(full_gate_samples, counterfactual_pattern)
counterfactual_vals[:3, :20]

In [94]:
# True where the pattern 110 ends at this position; then count the occurrences.
counterfactual_gates = (counterfactual_vals == 2)
counterfactual_gates.sum()


In [12]:
# Shape check; requires the Analysis 2 sampling cell to have been run in this kernel.
full_next_token_logits.shape

NameError: name 'full_next_token_logits' is not defined

In [13]:
extender = torch.ones((10,)) # Sliding this window across the sequence allows us to average the logits over the window, so we compute the advantage.

# Each entry counts how many pattern matches end within the 10 positions ending at that index,
# so these are integer weights rather than 0/1 masks.
counteractual_gates_window = filter_preceding_gates(counterfactual_gates, extender)
triple_gates_window = filter_preceding_gates(triple_gates, extender)

counteractual_gates_window[:5, :20], counteractual_gates_window.sum()

NameError: name 'filter_preceding_gates' is not defined

In [14]:
def masked_logit_stats(full_next_token_logits, mask):
    """Weighted mean, standard deviation and standard error of the next-token logits under `mask`.

    Args:
        full_next_token_logits: tensor of shape (batch_size, seq_len - 1).
        mask: tensor of shape (batch_size, seq_len) of non-negative weights (0/1 or counts);
            its last column is dropped to align with the next-token logits.

    Returns:
        (mean, std, sterr) as Python floats, where sterr = std / sqrt(sum of weights).
        All three are NaN when the mask sums to zero.
    """
    next_token_mask = mask[:, :-1]
    # Total weight, as a float so it can be divided by and square-rooted directly.
    n_vals = next_token_mask.sum().to(full_next_token_logits.dtype)

    mean_logits = (full_next_token_logits * next_token_mask).sum() / n_vals
    mean_sq_logits = (full_next_token_logits**2 * next_token_mask).sum() / n_vals
    # E[x^2] - E[x]^2 can come out slightly negative through rounding; clamp so sqrt never yields NaN.
    variance = torch.clamp(mean_sq_logits - mean_logits**2, min=0.0)
    std_logits = torch.sqrt(variance)
    sterr_logits = std_logits / torch.sqrt(n_vals)

    return mean_logits.item(), std_logits.item(), sterr_logits.item()

In [15]:
# Weighted logit statistics over the windows that follow a 111 pattern.
mean_triple, std_triple, sterr_triple = masked_logit_stats(full_next_token_logits, triple_gates_window)
mean_triple, std_triple, sterr_triple

NameError: name 'full_next_token_logits' is not defined

In [16]:
# Weighted logit statistics over the windows that follow a 110 pattern.
mean_counterfactual, std_counterfactual, sterr_counterfactual = masked_logit_stats(full_next_token_logits, counteractual_gates_window)
mean_counterfactual, std_counterfactual, sterr_counterfactual

NameError: name 'full_next_token_logits' is not defined

We see that there is basically no mean logit difference between being preceded with 111 and 110. Let's check what it looks like for 000.

In [100]:
# True where the three gates ending at this position are all off (pattern 000).
# NOTE: positions before the start of the sequence count as 0, so the first two positions
# can match with fewer than three real gates.
triple_zero_gates = (sum_preceding_3_gates == 0)
triple_zero_gates_window = filter_preceding_gates(triple_zero_gates, extender)
mean_triple_zero, std_triple_zero, sterr_triple_zero = masked_logit_stats(full_next_token_logits, triple_zero_gates_window)
mean_triple_zero, std_triple_zero, sterr_triple_zero

Again, no measurable difference (todo: improve and validate standard deviation calculation, update: these aren't perfect but for some reason the triple zero seems to win).

--- 
Let's retry this analysis but at the fine-grained token level. What we'll do is intervene on the preceding two bits and see how this impacts the logit diff.

In [101]:
# Peek at the first gate samples again before conditioning on individual gate patterns.
full_gate_samples[:5, :20]

In [102]:

def split_logits_by_pattern(full_gate_samples, full_next_token_logits, gates_of_interest, patterns):
    """Group the sampled rows by which gate pattern they show at the chosen positions.

    Args:
        full_gate_samples: tensor of shape (batch_size, seq_len).
        full_next_token_logits: tensor of shape (batch_size, seq_len-1).
        gates_of_interest: index tensor of shape (pattern_size,) — the gate positions to inspect.
        patterns: tensor of shape (n_patterns, pattern_size) — candidate gate values at those positions.

    Returns:
        A list with one tensor per pattern, containing the rows of `full_next_token_logits` whose
        gates at `gates_of_interest` equal that pattern exactly.
    """
    n_patterns = patterns.shape[0]

    # (batch_size, pattern_size): the gate values at the positions of interest.
    observed = full_gate_samples[:, gates_of_interest]

    # Broadcast (batch_size, 1, pattern_size) against (1, n_patterns, pattern_size); a row matches
    # a pattern only when every position agrees.
    row_matches_pattern = (observed[:, None, :] == patterns[None, :, :]).all(dim=-1)

    grouped = []
    for pattern_idx in range(n_patterns):
        grouped.append(full_next_token_logits[row_matches_pattern[:, pattern_idx]])
    return grouped

# Just to check that the above works:
# (disabled by default; when enabled it prints the rows selected for each of the four patterns
# and re-binds `gates_of_interest`, `patterns` and `pattern_matches`)
test = False
if test:
    gates_of_interest = torch.tensor([3, 4]).to(device=device)

    patterns = torch.tensor([
        [0, 0],
        [0, 1],
        [1, 0],
        [1, 1]
    ]).to(device=device)

    fake_logits = torch.arange(25).reshape(5, 5).to(device=device)
    fake_gate_samples = torch.tensor([
        [0, 0, 0, 0, 0],
        [0, 0, 0, 1, 0],
        [1, 1, 1, 0, 1],
        [0, 0, 0, 1, 1],
        [0, 1, 1, 0, 1]
    ]).to(device=device)

    pattern_matches = split_logits_by_pattern(fake_gate_samples, fake_logits, gates_of_interest, patterns)
    print(*pattern_matches, sep="\n")

In [103]:
# Condition on the gates at positions 40 and 41 and enumerate all four on/off combinations.
gates_of_interest = torch.tensor([40, 41], device=device)

pattern_list = [[0, 0], [0, 1], [1, 0], [1, 1]]
patterns = torch.tensor(pattern_list, device=device)

# One tensor of next-token logits per pattern, in the order of `pattern_list`.
pattern_matches = split_logits_by_pattern(full_gate_samples, full_next_token_logits, gates_of_interest, patterns)

In [104]:
# Per-pattern statistics over the batch dimension: mean, standard deviation, and standard error
# (standard deviation divided by the square root of the number of matching rows).
pattern_means = [matched.mean(dim=0) for matched in pattern_matches]
pattern_stds = [matched.std(dim=0) for matched in pattern_matches]
pattern_stderrs = [
    spread / np.sqrt(matched.shape[0])
    for spread, matched in zip(pattern_stds, pattern_matches)
]

In [105]:
def pattern_to_str(pattern):
    """Concatenate the entries of `pattern` into one string, e.g. [0, 1] -> "01"."""
    return "".join(map(str, pattern))

# The first pattern in `pattern_list` is the reference that every pattern is compared against.
base_pattern_mean = pattern_means[0]
base_pattern_str = pattern_to_str(pattern_list[0])

In [106]:

plt.figure(figsize=(12, 5))

# One curve per pattern: its mean logit minus the mean logit of the base pattern.
# NOTE(review): the shaded band is the standard error of the pattern alone; it does not include
# the uncertainty of the base pattern's mean.
for pattern, pattern_mean, pattern_stderr in zip(pattern_list, pattern_means, pattern_stderrs):
    pattern_str = pattern_to_str(pattern)
    visualize_logit_diffs(pattern_mean - base_pattern_mean, pattern_stderr, label=f"gating pattern {pattern_str}")

# First and last conditioned gate positions, marked with vertical lines below.
min_goi = gates_of_interest.min().item()
max_goi = gates_of_interest.max().item()

plt.xlim(30, 200)
plt.axvline(x=min_goi, color='r', linestyle='--', label=f'Gates of interest (position {min_goi})')
plt.axvline(x=max_goi, color='r', linestyle='--', label=f'Gates of interest (position {max_goi})')
plt.title(f'Next Token Logit Differences (pattern - {base_pattern_str})')
plt.xlabel('Position')
plt.ylabel('Logit Difference')
plt.legend()
plt.grid(True)
plt.show()


In [107]:

# Split the text into: before the conditioned gates, at the conditioned gates, and after them.
test_string[:min_goi], test_string[min_goi:max_goi+1], test_string[max_goi+1:]


---
Okay, but what if we marginalise over all prior gates and bytes $x$, and simply ask about what this value looks like:

$$
\mathbb{E} \log p(x_{i+r} \vert a_{i-3:i})
$$

TODO: continue this analysis

In [113]:
def pattern_matches_mask(full_gate_samples, pattern):
    """Mark every position at which `pattern` starts in the gate samples.

    Args:
        full_gate_samples: (batch_size, seq_len)
        pattern: (pattern_size,) tensor of 0/1 gate values.
    Returns:
        mask: (batch_size, seq_len-pattern_size+1) boolean tensor; entry i is True when gates
        i, i+1, ..., i+pattern_size-1 equal the pattern.
    """
    (pattern_size,) = pattern.shape

    # Map 1 -> +1 and 0 -> -1. Sliding this filter over 0/1 gates reaches its maximum, the number
    # of ones in the pattern, only where the gates reproduce the pattern exactly.
    signed_filter = 2 * pattern - 1
    peak_response = (signed_filter * pattern).sum()

    response = filter_preceding_gates(full_gate_samples, signed_filter)
    # The filter response is indexed by the position where the window ends; drop the first
    # pattern_size-1 entries so the mask is indexed by the position where the pattern starts.
    return (response == peak_response)[:, pattern_size - 1:]

# Try the mask on one pattern and print it next to the raw gates for a visual check.
# NOTE: `my_pattern_mask` defined here is used by later cells, so this guard must stay True for them to run.
test = True
if test:
    my_pattern = torch.tensor([1, 0, 0, 0, 1])
    my_pattern_mask = pattern_matches_mask(full_gate_samples, my_pattern)
    print(f"{full_gate_samples.shape=}")
    print(full_gate_samples[:5, :20])
    print(f"{my_pattern_mask.shape=}")
    print(my_pattern_mask[:5, :20].int())


In [114]:
# Scratch shape check: the mask with an extra axis inserted at dim 1, next to the logits' shape.
boi = my_pattern_mask.unsqueeze(1) 
# bruh = full_next_token_logits * boi
# bruh.shape
boi.shape, full_next_token_logits.shape

In [115]:
def pattern_logits_following_mean(full_next_token_logits, pattern_mask, offset=0):
    """Mean next-token logit at a fixed offset from every position where the pattern starts.

    Args:
        full_next_token_logits: tensor of shape (batch_size, seq_len-1).
        pattern_mask: tensor of shape (batch_size, seq_len-pattern_size+1); entry i marks a
            pattern starting at position i.
        offset: how many positions after the pattern start to read the logit from.

    Returns:
        The mean, as a Python float, of the logits at i + offset over all marked positions i
        for which i + offset is still inside the sequence.
    """
    # Shift the logits left so that column i holds the logit at position i + offset.
    shifted_logits = full_next_token_logits[:, offset:]

    # Trim both tensors to their common length so they line up column by column.
    common_len = min(shifted_logits.shape[1], pattern_mask.shape[1])
    aligned_logits = shifted_logits[:, :common_len]
    aligned_mask = pattern_mask[:, :common_len]

    # TODO: maybe consider comparing to the diff of the mean within a given logit. this could maybe make the output less noisy.

    masked_total = (aligned_logits * aligned_mask).sum()
    return (masked_total / aligned_mask.sum()).item()

def pattern_logits_followning_mean_list(full_next_token_logits, pattern_mask, window_of_interest):
    """Evaluate `pattern_logits_following_mean` at offsets 0, 1, ..., window_of_interest - 1.

    Returns:
        A list with one mean per offset, in increasing offset order.
    """
    means_by_offset = []
    for offset in range(window_of_interest):
        means_by_offset.append(pattern_logits_following_mean(full_next_token_logits, pattern_mask, offset))
    return means_by_offset

# Number of offsets after the pattern start to evaluate.
window_of_interest = 100

my_pattern = torch.tensor([1, 0, 0, 0, 1])
# Recompute the mask from the pattern defined here so the plot always matches `my_pattern`,
# regardless of whether the guarded test cell was run.
my_pattern_mask = pattern_matches_mask(full_gate_samples, my_pattern)
following_mean_logits = pattern_logits_followning_mean_list(full_next_token_logits, my_pattern_mask, window_of_interest)

pattern_end_offset = my_pattern.shape[0]-1
plt.plot(following_mean_logits)
# Mark where the pattern starts and ends on the offset axis. These are offsets relative to the
# pattern start, not absolute sequence positions.
plt.axvline(x=0, color='r', linestyle='--', label='Pattern start (offset 0)')
plt.axvline(x=pattern_end_offset, color='r', linestyle='--', label=f'Pattern end (offset {pattern_end_offset})')
plt.xlabel('Offset from pattern start')
plt.ylabel('Mean next-token logit')
plt.legend()
plt.show()


TODO: maybe consider comparing to the diff of the mean within a given logit. This could maybe make the output less noisy (?).

This looks somewhat promising: it looks like the first $1$ boosts the predictive power, which then decays until the second $1$. Super noisy though.

In [112]:
# abcd

--- 
TODO question: has our model learned induction heads?

- Need to analyse the in context learning score to see if there is an improvement

---
Analysis of total number of gates vs log-likelihood.

In [111]:
# Per sampled gating: total number of gates fired and the sum of the next-token logits over the sequence.
# NOTE(review): the axis label below calls this a log-likelihood in nats — that assumes the model's
# "logits" are log-probabilities; confirm.
total_gates = full_gate_samples.sum(dim=1).cpu().numpy()
total_logits = full_next_token_logits.sum(dim=1).cpu().numpy()

# Compute correlation between total gates and total logits
from scipy import stats
correlation, p_value = stats.pearsonr(total_gates, total_logits)
# Get one-sided p-value for positive correlation
# (halve the two-sided p-value when the observed correlation is positive; otherwise take the complement)
one_sided_p_value = p_value / 2 if correlation > 0 else 1 - (p_value / 2)
print(f"Correlation between total gates and total log-likelihood: {correlation:.4f}")
print(f"P-value for positive correlation: {one_sided_p_value:.6f}")

plt.figure(figsize=(12, 5))
plt.scatter(total_gates, total_logits)
plt.xlabel('Total number of gates')
plt.ylabel('Total log-likelihood (nats)')
plt.show()
