In [9]:
from transformer_lens.cautils.notebook import *
from transformer_lens.rs.callum.keys_fixed import get_effective_embedding_2

from transformer_lens import FactoredMatrix

# Clear this cell's displayed output (e.g. any messages emitted while importing).
clear_output()

In [10]:
# GPT-2 small with TransformerLens weight processing: LayerNorm folded into the
# adjacent weights, writing weights and the unembedding centered.
# (refactor_factored_attn_matrices is left at its default.)
model = HookedTransformer.from_pretrained(
    "gpt2-small",
    fold_ln=True,
    center_writing_weights=True,
    center_unembed=True,
)

# Enable separate per-head q/k/v input hooks and per-head attention-result hooks.
model.set_use_split_qkv_input(True)
model.set_use_attn_result(True)

clear_output()

In [39]:
def get_neg_copying_score(
    model: HookedTransformer,
    W_EE: Float[Tensor, "d_vocab d_model"],
    head: Tuple[int, int],
    n_batches: int = 10,
    sample_size: Optional[int] = None,
    return_prob: bool = False,
    return_frac_on_topk: bool = False,
    truncate_negative: bool = False,
    k: int = 5,
):
    '''
    Monte-Carlo estimate of how strongly one head's OV circuit does *negative* copying,
    i.e. writes a negative logit for the token it attends to.

    For each of `n_batches` batches, `sample_size` distinct token ids `s` are drawn
    (without replacement) and the negated OV circuit restricted to them is formed:

        scores = -(W_EE[s] @ W_V) @ (W_O @ W_U[:, s])     # (sample_size, sample_size)

    Row i holds the (negated) logits written onto every sampled token when the source
    token is s[i]. The per-batch score is then one of:

      * default             : mean diagonal entry minus mean off-diagonal entry;
      * return_prob         : mean softmax probability on the diagonal (softmax taken
                              over the sampled tokens only);
      * return_frac_on_topk : fraction of rows whose own (diagonal) token is among
                              that row's top-`k` entries.

    Args:
        model: supplies W_V, W_O, W_U and cfg.d_vocab.
        W_EE: embedding matrix fed into the OV circuit (rows indexed by token id).
        head: (layer, head_index) of the head to score.
        n_batches: number of independent random samples to average over (>= 1).
        sample_size: tokens per sample. Required despite the `None` default (kept for
            backward compatibility); must be in [1, d_vocab], and >= 2 for the default
            diagonal-vs-off-diagonal mode.
        return_prob: return the mean diagonal probability instead.
        return_frac_on_topk: return the top-`k` hit rate instead.
        truncate_negative: clamp the final averaged score at 0.
        k: the `k` used by `return_frac_on_topk`; clipped to `sample_size`.

    Returns:
        0-dim float tensor: the per-batch scores averaged over `n_batches`.

    Raises:
        ValueError: on mutually exclusive flags or an invalid n_batches / sample_size / k.
    '''
    if return_prob and return_frac_on_topk:
        raise ValueError("return_prob and return_frac_on_topk are mutually exclusive")
    if n_batches < 1:
        raise ValueError(f"n_batches must be >= 1, got {n_batches}")

    d_vocab = model.cfg.d_vocab
    if not isinstance(sample_size, int) or not 1 <= sample_size <= d_vocab:
        raise ValueError(f"sample_size must be an int in [1, {d_vocab}], got {sample_size!r}")
    if not (return_prob or return_frac_on_topk) and sample_size < 2:
        # The off-diagonal mean divides by sample_size * (sample_size - 1).
        raise ValueError("the default (diagonal minus off-diagonal) score needs sample_size >= 2")
    if return_frac_on_topk and k < 1:
        raise ValueError(f"k must be >= 1, got {k}")

    layer, head_idx = head

    with t.no_grad():
        # Keep the two low-rank factors of W_EE @ W_V @ W_O @ W_U separate so the
        # (d_vocab, d_vocab) product is never materialised; only sampled slices are.
        ov_left = W_EE @ model.W_V[layer, head_idx]          # (d_vocab, d_head)
        ov_right = model.W_O[layer, head_idx] @ model.W_U    # (d_head, d_vocab)
        device = ov_left.device

        batch_scores = []
        for _ in range(n_batches):
            # Without replacement: a repeated token would otherwise land in an
            # off-diagonal cell and be counted as "not a copy".
            sample = t.randperm(d_vocab, device=device)[:sample_size]
            neg_scores = -(ov_left[sample] @ ov_right[:, sample])

            if return_prob:
                score = neg_scores.softmax(dim=-1).diag().mean()
            elif return_frac_on_topk:
                top_idx = neg_scores.topk(k=min(k, sample_size), dim=-1).indices  # (s, k)
                own_idx = t.arange(sample_size, device=device).unsqueeze(-1)       # (s, 1)
                score = (top_idx == own_idx).any(dim=-1).float().mean()
            else:
                diag_sum = neg_scores.trace()
                offdiag_sum = neg_scores.sum() - diag_sum
                diag_avg = diag_sum / sample_size
                offdiag_avg = offdiag_sum / (sample_size * (sample_size - 1))
                score = diag_avg - offdiag_avg
            batch_scores.append(score.item())

    mean_score = t.tensor(batch_scores).mean()
    return mean_score.clamp(min=0.0) if truncate_negative else mean_score

In [40]:
# Compute the effective embeddings once (each call to get_effective_embedding_2 does the
# full computation) and pull out the three variants fed into each head's OV circuit.
effective_embeddings = get_effective_embedding_2(model)
W_EE = effective_embeddings["W_E (including MLPs)"]
W_EE_subE = effective_embeddings["W_E (only MLPs)"]
W_E = effective_embeddings["W_E (no MLPs)"]

embedding_variants = {
    "W_E (including MLPs)": W_EE,
    "W_E (only MLPs)": W_EE_subE,
    "W_E (no MLPs)": W_E,
}

# Scores are estimated from random token samples; seed so the figure is reproducible.
t.manual_seed(0)

results = t.zeros(len(embedding_variants), model.cfg.n_layers, model.cfg.n_heads).float()
for variant_idx, embedding in enumerate(embedding_variants.values()):
    for layer, head_idx in itertools.product(range(model.cfg.n_layers), range(model.cfg.n_heads)):
        results[variant_idx, layer, head_idx] = get_neg_copying_score(
            model,
            embedding,
            head=(layer, head_idx),
            n_batches=10,
            sample_size=1000,
            truncate_negative=True,
        )

imshow(
    results,
    title="Negative copying scores (only negative heads shown)",
    width=1200,
    facet_col=0,
    facet_labels=list(embedding_variants),
)

In [45]:
# Reuse W_EE / W_EE_subE / W_E from the previous cell rather than recomputing the
# (expensive) effective embeddings three more times.
prob_embedding_variants = {
    "W_E (including MLPs)": W_EE,
    "W_E (only MLPs)": W_EE_subE,
    "W_E (no MLPs)": W_E,
}

# Scores are estimated from random token samples; seed so the figure is reproducible.
t.manual_seed(0)

results = t.zeros(len(prob_embedding_variants), model.cfg.n_layers, model.cfg.n_heads).float()
for variant_idx, embedding in enumerate(prob_embedding_variants.values()):
    for layer, head_idx in itertools.product(range(model.cfg.n_layers), range(model.cfg.n_heads)):
        results[variant_idx, layer, head_idx] = get_neg_copying_score(
            model,
            embedding,
            head=(layer, head_idx),
            n_batches=10,
            sample_size=5000,
            return_prob=True,
        )

# No truncation here: every head is shown, and values are probabilities in [0, 1].
imshow(
    results,
    title="Negative copying: mean softmax probability on the source token (all heads)",
    width=1200,
    facet_col=0,
    facet_labels=list(prob_embedding_variants),
    zmax=1.0,
    zmin=-1.0,
)