In [None]:
import torch
import torch.nn.functional as F
from typing import Literal, Optional
import einops
from transformers import AutoModelForCausalLM, AutoTokenizer
from dataclasses import dataclass
import os
import random
import pickle
import itertools
from tqdm import tqdm

import mypkg.whitebox_infra.attribution as attribution
import mypkg.whitebox_infra.dictionaries.batch_topk_sae as batch_topk_sae
import mypkg.whitebox_infra.data_utils as data_utils
import mypkg.whitebox_infra.model_utils as model_utils
import mypkg.whitebox_infra.interp_utils as interp_utils
import mypkg.pipeline.setup.dataset as dataset_setup
import mypkg.pipeline.infra.hiring_bias_prompts as hiring_bias_prompts
from mypkg.eval_config import EvalConfig


: 

In [2]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
dtype = torch.bfloat16

model_name = "google/gemma-2-2b-it"
model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=dtype, device_map=device
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

batch_size = model_utils.MODEL_CONFIGS[model_name]["batch_size"] * 6
context_length = 256


Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  2.06it/s]


In [3]:
chosen_layer = 18
submodules = [model_utils.get_submodule(model, chosen_layer)]


In [4]:
num_layers = len(list(model.model.layers))
print(num_layers)

submodules = [model_utils.get_submodule(model, i) for i in range(num_layers)]

26


In [5]:
dataset_name = "togethercomputer/RedPajama-Data-V2"
num_tokens = 10_000

batched_tokens = interp_utils.get_batched_tokens(
    tokenizer=tokenizer,
    model_name=model_name,
    dataset_name=dataset_name,
    num_tokens=num_tokens,
    batch_size=batch_size,
    device=device,
    context_length=context_length,
    force_rebuild_tokens=False,
)

Loading tokenized dataset from tokens/togethercomputer_RedPajama-Data-V2_10000_google_gemma-2-2b-it.pt


In [6]:
print(model)

Gemma2ForCausalLM(
  (model): Gemma2Model(
    (embed_tokens): Embedding(256000, 2304, padding_idx=0)
    (layers): ModuleList(
      (0-25): 26 x Gemma2DecoderLayer(
        (self_attn): Gemma2Attention(
          (q_proj): Linear(in_features=2304, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2304, out_features=1024, bias=False)
          (v_proj): Linear(in_features=2304, out_features=1024, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2304, bias=False)
        )
        (mlp): Gemma2MLP(
          (gate_proj): Linear(in_features=2304, out_features=9216, bias=False)
          (up_proj): Linear(in_features=2304, out_features=9216, bias=False)
          (down_proj): Linear(in_features=9216, out_features=2304, bias=False)
          (act_fn): PytorchGELUTanh()
        )
        (input_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)
        (post_attention_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)
        (pre_feedforward_layernorm): Gemm

In [17]:
def get_activations_per_layer(
    model: AutoModelForCausalLM,
    submodules: list[torch.nn.Module],
    tokens_batch: dict[str, torch.Tensor],
    get_final_token_only: bool = False,
) -> tuple[dict[torch.nn.Module, torch.Tensor], torch.Tensor]:
    activations_BLD = {}

    def gather_target_act_hook(module, inputs, outputs):
        nonlocal activations_BLD
        assert isinstance(outputs, tuple)
        if get_final_token_only:
            activations_BLD[module] = outputs[0][:, -1:, :]
        else:
            activations_BLD[module] = outputs[0]

    all_handles = []

    try:
        for submodule in submodules:
            handle = submodule.register_forward_hook(gather_target_act_hook)
            all_handles.append(handle)

        logits_BLV = model(**tokens_batch).logits

        if get_final_token_only:
            logits_BLV = logits_BLV[:, -1:, :]

    except Exception as e:
        print(e)
        raise e
    finally:
        for handle in all_handles:
            handle.remove()

    return activations_BLD, logits_BLV


def run_final_block(
    acts_BLD: torch.Tensor,
    model: AutoModelForCausalLM,
) -> torch.Tensor:
    """
    Feeds `acts_BLD` through the last Gemma2DecoderLayer (attention+MLP)
    and returns the updated hidden states.
    """
    B, L, _ = acts_BLD.shape
    device = acts_BLD.device

    pos_ids_BL = torch.arange(L, device=device).unsqueeze(0).expand(B, L)
    position_embeddings = model.model.rotary_emb(acts_BLD, pos_ids_BL)

    hidden_BLD = model.model.layers[-1](
        hidden_states=acts_BLD,
        position_embeddings=position_embeddings,
    )
    return hidden_BLD[0]


def apply_logit_lens(
    acts_BLD: torch.Tensor, model: AutoModelForCausalLM, final_token_only: bool = True
) -> torch.Tensor:
    if final_token_only:
        acts_BLD = acts_BLD[:, -1:, :]

    acts_BLD += run_final_block(acts_BLD, model)

    logits_BLV = model.lm_head(model.model.norm(acts_BLD))
    return logits_BLV


def kl_between_logits(
    logits_p_BLV: torch.Tensor,  # “teacher” / reference
    logits_q_BLV: torch.Tensor,  # “student” / lens
    reduction: Literal["batchmean", "mean", "none"] = "none",
) -> torch.Tensor:
    """
    KL‖(p || q) for two logit tensors of shape [B, L, V].

    * `p` is obtained with softmax on `logits_p_BLV`
    * `q` is obtained with softmax on `logits_q_BLV`
    * Uses `torch.kl_div(log_q, p)` so we pass log-probabilities for q
    """

    B, L, V = logits_p_BLV.shape

    # Convert logits → probabilities / log-probabilities
    p_BLV = torch.softmax(logits_p_BLV, dim=-1)
    log_q_BLV = torch.log_softmax(logits_q_BLV, dim=-1)

    # KL divergence per token; `reduction` handles batching
    kl = F.kl_div(log_q_BLV, p_BLV, reduction=reduction).sum(dim=-1)

    return kl  # scalar if reduction ≠ "none", else [B, L]


@torch.no_grad()
def get_activations(
    model: AutoModelForCausalLM,
    submodules: list[torch.nn.Module],
    batched_tokens: dict[str, torch.Tensor],
    get_final_token_only: bool = False,
) -> dict[int, float]:
    mean_kl_per_layer = {}

    for i, layer in enumerate(submodules):
        mean_kl_per_layer[i] = 0.0

    for batch in tqdm(batched_tokens):
        activations_BLD, logits_BLV = get_activations_per_layer(
            model, submodules, batch, get_final_token_only=get_final_token_only
        )

        for i, layer in enumerate(submodules):
            logit_lens_BLV = apply_logit_lens(
                activations_BLD[layer], model, final_token_only=get_final_token_only
            )
            kl = kl_between_logits(logits_BLV, logit_lens_BLV)
            mean_kl_per_layer[i] += kl.mean().item()

    for i, layer in enumerate(submodules):
        mean_kl_per_layer[i] /= len(batched_tokens)

    return mean_kl_per_layer


model.eval()
torch.set_grad_enabled(False)

# acts_BLD, logits_BLV = get_batch_activations(model, submodules, batched_tokens[0])

kl = get_activations(model, submodules, batched_tokens, get_final_token_only=True)
print(kl)
kl = get_activations(model, submodules, batched_tokens, get_final_token_only=False)
print(kl)


100%|██████████| 7/7 [00:02<00:00,  3.41it/s]


{0: 42.517857142857146, 1: 33.857142857142854, 2: 23.678571428571427, 3: 21.142857142857142, 4: 18.714285714285715, 5: 21.178571428571427, 6: 18.303571428571427, 7: 15.276785714285714, 8: 12.84375, 9: 12.584821428571429, 10: 12.214285714285714, 11: 11.303571428571429, 12: 11.008928571428571, 13: 9.258928571428571, 14: 8.915178571428571, 15: 7.6875, 16: 7.370535714285714, 17: 6.555803571428571, 18: 6.220982142857143, 19: 5.303571428571429, 20: 4.205357142857143, 21: 4.314174107142857, 22: 2.7472098214285716, 23: 1.7064732142857142, 24: 0.6989397321428571, 25: 0.05683244977678571}


100%|██████████| 7/7 [00:09<00:00,  1.34s/it]

{0: 73.64285714285714, 1: 36.964285714285715, 2: 24.125, 3: 19.785714285714285, 4: 19.839285714285715, 5: 22.75, 6: 24.107142857142858, 7: 21.821428571428573, 8: 20.026785714285715, 9: 20.008928571428573, 10: 15.794642857142858, 11: 12.133928571428571, 12: 11.321428571428571, 13: 9.116071428571429, 14: 8.066964285714286, 15: 6.915178571428571, 16: 5.669642857142857, 17: 5.410714285714286, 18: 5.21875, 19: 4.549107142857143, 20: 3.595982142857143, 21: 3.200892857142857, 22: 2.5089285714285716, 23: 1.703125, 24: 0.7806919642857143, 25: 0.04380580357142857}





In [8]:
print(kl)

{0: 73.64285714285714, 1: 36.964285714285715, 2: 24.125, 3: 19.767857142857142, 4: 19.839285714285715, 5: 22.732142857142858, 6: 24.107142857142858, 7: 21.821428571428573, 8: 20.044642857142858, 9: 20.0, 10: 15.803571428571429, 11: 12.142857142857142, 12: 11.3125, 13: 9.107142857142858, 14: 8.0625, 15: 6.90625, 16: 5.669642857142857, 17: 5.40625, 18: 5.214285714285714, 19: 4.549107142857143, 20: 3.595982142857143, 21: 3.200892857142857, 22: 2.5044642857142856, 23: 1.7042410714285714, 24: 0.7801339285714286, 25: 0.04377092633928571}


In [15]:
import torch
import torch.nn.functional as F
from math import log

def uniform_kl_sanity(
    teacher_logits_BLV: torch.Tensor,
    reduction: str = "batchmean",
) -> None:
    """
    Prints KL(teacher ‖ uniform) and KL(uniform ‖ teacher) for a quick scale check.
    """
    B, L, V = teacher_logits_BLV.shape
    device  = teacher_logits_BLV.device
    dtype   = teacher_logits_BLV.dtype

    # 1) build uniform logits (all zeros ⇒ softmax == 1/V)
    uniform_logits_BLV = torch.zeros((B, L, V), device=device, dtype=dtype)

    # 2) helper : per-token KL averaged over tokens
    # def kl_tok_avg(p_logits, q_logits):
    #     p = torch.softmax(p_logits, dim=-1)
    #     log_q = torch.log_softmax(q_logits, dim=-1)
    #     # token-level average: sum over V then mean over B×L
    #     return F.kl_div(log_q, p, reduction="none").sum(-1).mean().item()

    # kl_teacher_uniform   = kl_tok_avg(teacher_logits_BLV, uniform_logits_BLV)
    # kl_uniform_teacher   = kl_tok_avg(uniform_logits_BLV, teacher_logits_BLV)
    # kl_uniform_uniform   = kl_tok_avg(uniform_logits_BLV, uniform_logits_BLV)

    kl_teacher_uniform = kl_between_logits(teacher_logits_BLV, uniform_logits_BLV).mean().item()
    kl_uniform_teacher = kl_between_logits(uniform_logits_BLV, teacher_logits_BLV).mean().item()
    kl_uniform_uniform = kl_between_logits(uniform_logits_BLV, uniform_logits_BLV).mean().item()

    print(f"KL(teacher ‖ uniform)  : {kl_teacher_uniform:8.4f}  nats")
    print(f"KL(uniform ‖ teacher)  : {kl_uniform_teacher:8.4f}  nats")
    print(f"KL(uniform ‖ uniform)  : {kl_uniform_uniform:8.4f}  nats (≈0)")
    print(f"log |V|                : {log(V):8.4f}  nats  (upper bound)")

# --- example call on your first batch ----------------------------------------
with torch.no_grad():
    sample_logits_BLV = model(**batched_tokens[0]).logits
    uniform_kl_sanity(sample_logits_BLV)

KL(teacher ‖ uniform)  :  10.3750  nats
KL(uniform ‖ teacher)  :  11.1250  nats
KL(uniform ‖ uniform)  :   0.0000  nats (≈0)
log |V|                :  12.4529  nats  (upper bound)


{0: 44873.142857142855, 1: 32694.85714285714, 2: 24758.85714285714, 3: 19017.14285714286, 4: 14966.857142857143, 5: 13147.42857142857, 6: 10377.142857142857, 7: 8594.285714285714, 8: 7844.571428571428, 9: 6976.0, 10: 6258.285714285715, 11: 5490.285714285715, 12: 5339.428571428572, 13: 4781.714285714285, 14: 4530.285714285715, 15: 4093.714285714286, 16: 3332.5714285714284, 17: 3140.5714285714284, 18: 2980.5714285714284, 19: 2857.1428571428573, 20: 2253.714285714286, 21: 1982.857142857143, 22: 1657.142857142857, 23: 1313.142857142857, 24: 641.1428571428571, 25: 28.928571428571427}