Suppose we take a random model and counterfactually force it to gate at certain times. What does the resulting advantage look like?

In [1]:
import torch
from torch import nn
from transformers import AutoTokenizer

from clean_code.flexible_bitter_llm import FlexibleBitterLLM, Gemma2RotaryEmbedding
from clean_code.bitter_llm import RandomGater, CausalGemmaMiniBitterLLM
# Allow-list these classes for torch.load(..., weights_only=True).
# NOTE(review): the checkpoint load in this notebook passes weights_only=False, which does not
# consult this allow-list; keep it only if a weights_only=True load is used elsewhere.
torch.serialization.add_safe_globals(
    [CausalGemmaMiniBitterLLM, nn.Embedding]  # nn.Embedding is nn.modules.sparse.Embedding
)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:

# Run-wide configuration.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The model appears to sample its down-gates stochastically: the reference and counterfactual
# runs below diverge after the prescribed prefix. Fix the RNG so that Restart & Run All
# reproduces the same samples. torch.manual_seed seeds the CPU and all CUDA generators.
SEED = 0
torch.manual_seed(SEED)


In [3]:
# Hub checkpoint whose tokenizer encodes the test text into model input ids.
BYT5_TOKENIZER_NAME = "google/byt5-large"
byte5_tokenizer = AutoTokenizer.from_pretrained(BYT5_TOKENIZER_NAME)

In [4]:
# Load the exp15 checkpoint and re-dress it as a FlexibleBitterLLM. Later cells call this class
# with prescribed down-gate samples.
# NOTE(review): assigning __class__ assumes FlexibleBitterLLM is attribute-compatible with the
# pickled class. Confirm this if either class changes.
CHECKPOINT_PATH = "bitter-llm-exp15.pt"
ATTN_LOGIT_SOFTCAPPING = 50.0

# SECURITY: weights_only=False unpickles arbitrary objects (the whole nn.Module here).
# Only load checkpoints from a trusted source.
# map_location puts the model on `device`, the same device the inputs are moved to. Without it,
# the tensors land wherever they were saved, which can mismatch `device`.
my_model = torch.load(CHECKPOINT_PATH, map_location=device, weights_only=False)
my_model.__class__ = FlexibleBitterLLM
# NOTE(review): flash_attention_2 presumably needs CUDA, so the CPU fallback for `device` will
# likely fail at forward time. Confirm this before running on CPU.
my_model.attn_implementation = "flash_attention_2"
my_model.rotary_emb = Gemma2RotaryEmbedding(my_model.byte_layer_config)

# The layers are torch.compile wrappers (OptimizedModule, per the repr below). Attribute access
# is forwarded to the wrapped Gemma2DecoderLayer, so this sets softcapping on its attention module.
for layer in [*my_model.down_layers, *my_model.mid_layers, *my_model.up_layers]:
    layer.self_attn.attn_logit_softcapping = ATTN_LOGIT_SOFTCAPPING

# NOTE(review): the model is never switched to eval(). If gating or dropout depends on
# `self.training`, these results reflect training-mode behaviour.
my_model

FlexibleBitterLLM(
  (embedding): Embedding(256, 512)
  (down_layers): ModuleList(
    (0-1): 2 x OptimizedModule(
      (_orig_mod): Gemma2DecoderLayer(
        (self_attn): Gemma2Attention(
          (q_proj): Linear(in_features=512, out_features=512, bias=False)
          (k_proj): Linear(in_features=512, out_features=512, bias=False)
          (v_proj): Linear(in_features=512, out_features=512, bias=False)
          (o_proj): Linear(in_features=512, out_features=512, bias=False)
          (rotary_emb): Gemma2RotaryEmbedding()
        )
        (mlp): Gemma2MLP(
          (gate_proj): Linear(in_features=512, out_features=512, bias=False)
          (up_proj): Linear(in_features=512, out_features=512, bias=False)
          (down_proj): Linear(in_features=512, out_features=512, bias=False)
          (act_fn): PytorchGELUTanh()
        )
        (input_layernorm): Gemma2RMSNorm((512,), eps=1e-06)
        (pre_feedforward_layernorm): Gemma2RMSNorm((512,), eps=1e-06)
        (post_feedfor

In [5]:
# From https://www.bbc.com/news/articles/crrz44d7v08o
test_string = """This BBC interview with Prince Harry will become one of those famous moments when television collides with the world of the royals.

It was like an emotional avalanche. It began with some stones being kicked over with questions about security and then the interview turned into a spectacular release of what seemed to be a rolling mountain of pent-up frustration and a poignant sense of separation.

The starting point was Prince Harry's defeat in the courts as he sought to overturn a downgrading of his security in the UK. He seemed wounded. Had he decided it was time to have his say? And then really say some more?

A conversation about security was suddenly becoming about a whole range of insecurities.
"""

# Tokenise the single document. Calling the tokenizer and taking "input_ids" yields the same
# ids as .encode; the result is a (1, seq_len) tensor.
encoded_test = byte5_tokenizer(test_string, return_tensors="pt", padding=True)
test_batch = encoded_test["input_ids"].to(device)

# Reference run: no prescription is passed, so the model chooses its own down-gate samples.
with torch.no_grad():
    base_out = my_model(test_batch)

In [6]:
# Down-gate samples the model drew on its own in the reference run.
base_down_gate_samples = base_out["down_gate_samples"]
# The prefix mask built next is sized from test_batch, so the samples must line up with it
# one-to-one.
assert base_down_gate_samples.shape == test_batch.shape, (
    base_down_gate_samples.shape,
    test_batch.shape,
)

In [7]:
# Positions [0, gate_of_interest) copy their down-gate decisions from the reference run.
# From gate_of_interest onward the model samples freely. Note that this mask does not force a
# gate *at* gate_of_interest.
gate_of_interest = 100
batch_size, seq_len = test_batch.shape

# True on prescribed positions. Comparing against arange also covers
# gate_of_interest > seq_len (all True): the previous ones/zeros concatenation raised there,
# because torch.zeros got a negative size.
base_down_gate_samples_mask = (
    torch.arange(seq_len, device=device).expand(batch_size, seq_len) < gate_of_interest
)
base_down_gate_samples.shape, base_down_gate_samples_mask.shape

(torch.Size([1, 710]), torch.Size([1, 710]))

In [18]:
# Counterfactual run: reuse the reference gates wherever the mask is True, and let the model
# sample the rest.
with torch.no_grad():
    test_out = my_model(
        test_batch,
        prescribed_down_gate_samples=base_down_gate_samples,
        down_gate_mask=base_down_gate_samples_mask,
    )

# Sanity check: the prescription was honoured at every masked position.
assert torch.equal(
    test_out["down_gate_samples"][base_down_gate_samples_mask],
    base_down_gate_samples[base_down_gate_samples_mask],
), "prescribed down-gate samples were not applied"

In [19]:
# Reference-run gates. This is the same tensor as base_out["down_gate_samples"].
base_down_gate_samples

tensor([[1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1,
         0, 1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 1,
         0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 0,
         1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0,
         0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0,
         0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0,
         1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0,
         1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0,
         1, 0, 0, 0, 0, 0, 0

In [20]:
# Counterfactual-run gates, named in parallel with base_down_gate_samples.
test_down_gate_samples = test_out["down_gate_samples"]
test_down_gate_samples

tensor([[1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1,
         0, 1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 1,
         0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 0,
         1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0,
         0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0,
         0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0,
         0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0,
         0, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0

In [21]:
# Positive entries mean the reference run assigns a higher logit than the counterfactual run.
base_logits = base_out["logits"]
test_logits = test_out["logits"]
logit_diff = base_logits - test_logits

In [22]:
logit_diff  # reference minus counterfactual logits, shape (batch, positions, vocab)

tensor([[[ 5.2212e-02, -5.1270e-03,  1.6527e-01,  ...,  3.0318e-01,
           4.0684e-01,  3.9837e-01],
         [-3.9117e-01, -1.2651e-01,  1.4279e-01,  ...,  2.8289e-01,
           3.7215e-01,  4.0109e-01],
         [-2.8314e-01, -2.9463e-01, -9.6945e-02,  ..., -3.4163e-02,
          -2.7750e-02, -2.9251e-02],
         ...,
         [ 3.1815e-03,  5.2868e-02,  4.8299e-01,  ...,  6.6721e-01,
           7.8633e-01,  5.9012e-01],
         [-1.1690e+00, -5.7395e-01, -1.3976e+00,  ..., -4.6978e-01,
          -5.3423e-01, -6.2995e-01],
         [-1.9269e-04,  1.6357e-01,  3.8826e-02,  ...,  8.8838e-01,
           5.7656e-01,  9.8496e-01]]], device='cuda:0')

In [23]:
logit_diff.size()

torch.Size([1, 710, 256])

In [24]:
# Absolute logit differences, reduced two ways. logit_diff has shape (batch, positions, vocab).
abs_logit_diff = logit_diff.abs()

# Reduce over positions (dim=1): one total per vocab entry, shape (batch, vocab).
# Each entry sums over every position in the sequence, so its magnitude grows with length.
token_diffs = abs_logit_diff.sum(dim=1)

# Reduce over vocab (dim=-1): the L1 logit distance at each position, shape (batch, positions).
# This shows *where* the counterfactual gating changes predictions, for example whether
# positions inside the prescribed prefix (before gate_of_interest) are left unchanged.
position_diffs = abs_logit_diff.sum(dim=-1)

In [25]:
token_diffs.size()

torch.Size([1, 256])

In [26]:
# NOTE(review): token_diffs is |logit_diff|.sum(dim=1), and dim=1 of the (1, 710, 256)
# logit_diff is the position axis. Each value is therefore a total over ~710 positions, e.g.
# 250 / 710 ≈ 0.35 per position, which accounts for much of the large magnitude.
# Separately, the first rows of logit_diff (positions 0-2) are already nonzero, even though
# those positions lie inside the prescribed prefix. If the model is meant to be causal, check
# for non-causality or nondeterminism before interpreting later positions.
token_diffs # TODO: figure out why the token diffs are so high.

tensor([[250.0041, 230.8676, 268.8690, 368.0974, 368.0240, 288.3623, 264.1929,
         314.3453, 344.9648, 373.2635, 361.4180, 263.8871, 324.1987, 161.2671,
         339.4470, 345.0281, 335.7305, 302.0316, 278.0880, 386.1330, 299.8571,
         329.6399, 272.7560, 353.1989, 307.4991, 381.5919, 301.2749, 300.9753,
         384.9588, 343.0367, 293.8155, 347.0078, 303.4677, 383.1190, 252.8161,
          89.5313, 211.1520, 109.2413, 178.0293, 240.4632, 172.2382, 151.1092,
         134.0942, 196.8838, 174.0321, 147.6389, 162.0296, 146.1143, 192.8975,
         132.4105, 218.8984, 161.5545, 169.9106, 143.5901, 141.5398, 166.3748,
         178.7725, 174.4449, 187.3694, 186.3413, 158.0880, 155.7456, 155.5697,
         197.6676, 173.6640, 163.6866, 199.7448, 211.9181, 192.4732, 243.1473,
         161.7686, 220.3329, 158.1173, 172.8017, 287.5665, 309.6010, 178.4384,
         163.0322, 207.3238, 198.5522, 169.3605, 244.9174, 159.0432, 162.9296,
         210.6855, 161.7034, 171.1038, 150.6186, 214