In [1]:
from utils import *

In [None]:
# Peek at the shape of the model's W_V attribute.
# NOTE(review): `model` is not defined in any visible cell; presumably it is
# provided by `from utils import *` — TODO confirm.
model.W_V.shape

In [2]:
# Collect the "resid_pre" activation names for layers 0 through 7.
act_names = [
    utils.get_act_name("resid_pre", layer_idx)
    for layer_idx in range(8)
]
# Fetch those activations through get_activation. Presumably 10000 is the
# number of samples and start=0 the first sample index — TODO confirm in utils.
fake_cache = get_activation(act_names, 10000, start=0)

In [3]:
# Sanity-check the shape of the cached layer-0 "resid_pre" activations.
fake_cache[utils.get_act_name("resid_pre", 0)].shape

torch.Size([10000, 59, 512])

In [6]:
# Make list of all placed tiles, then list of lists of all adjacent tiles
def get_placed_yours_acc(layer_yours, layer_placed):
    """Measure how often the "linear" probe says YOURS where the "placed" probe says PLACED.

    Both probes are applied to the cached ``resid_pre`` activations in
    ``fake_cache`` and reduced with an argmax over the last (``o``) axis of the
    ``'b p d, d r c o -> b p r c o'`` projection.

    Args:
        layer_yours: Layer whose activations and "linear" probe produce the
            YOURS prediction.
        layer_placed: Layer whose activations and "placed" probe select the
            entries that are scored.

    Returns:
        Tuple ``(num_correct, num_total, accuracy)``. ``num_total`` counts the
        ``(b, p, r, c)`` entries whose placed-probe argmax equals ``PLACED``,
        ``num_correct`` counts those whose linear-probe argmax also equals
        ``YOURS``, and ``accuracy`` is ``num_correct / num_total``.
    """
    def probe_argmax(layer, probe_name):
        # Project the cached residual stream through the probe, keep the top option.
        resid = fake_cache[utils.get_act_name("resid_pre", layer)].to(device)
        probe = get_probe(layer, probe_name, "post")[0].to(device)
        logits = einops.einsum(resid, probe, 'b p d, d r c o -> b p r c o')
        return logits.argmax(axis=-1)

    is_placed = probe_argmax(layer_placed, "placed") == PLACED
    is_yours = probe_argmax(layer_yours, "linear") == YOURS

    num_total = is_placed.sum()
    num_correct = (is_placed * is_yours).sum()
    accuracy = num_correct / num_total
    return num_correct, num_total, accuracy

# Evaluate with both probes read from layer 1; the bare tuple on the last line
# is what the cell displays.
num_correct, num_total, acc = get_placed_yours_acc(1, 1)
num_correct, num_total, acc

(tensor(589351, device='cuda:0'),
 tensor(592642, device='cuda:0'),
 tensor(0.9944, device='cuda:0'))

In [7]:
import torch.nn.functional as F
# Integer labels for adjacency.
# NOTE(review): neither constant is referenced in the visible cells;
# mark_adjacent_positions returns a boolean mask instead — confirm whether
# these are still needed.
ADJACENT = 0
NOT_ADJACENT = 1

def mark_adjacent_positions(layer_placed):
    """Mark every cell that neighbours a cell the "placed" probe predicts as PLACED.

    The "placed" probe for ``layer_placed`` is applied to the cached
    ``resid_pre`` activations of the same layer and reduced with an argmax over
    the last axis. A 3x3 convolution with a zero centre then flags each cell
    that has at least one PLACED cell among its eight neighbours (diagonals
    included). A PLACED cell is itself flagged only if another PLACED cell
    neighbours it.

    Args:
        layer_placed: Layer whose activations and "placed" probe are used.

    Returns:
        Boolean tensor of shape ``(batch, position, height, width)``, matching
        the shape of the probe's argmax result; ``True`` marks adjacent cells.
    """
    placed_probe = get_probe(layer_placed, "placed", "post")[0].to(device)
    resid_layer_placed = fake_cache[utils.get_act_name("resid_pre", layer_placed)].to(device)
    placed_probe_result = einops.einsum(
        resid_layer_placed, placed_probe, 'b p d, d r c o -> b p r c o'
    ).argmax(axis=-1)
    batch_size, num_positions, height, width = placed_probe_result.shape

    # Neighbourhood kernel for all 8 directions; the zero centre excludes the
    # cell itself. conv2d expects (out_channels, in_channels, kH, kW).
    kernel = t.tensor(
        [[1, 1, 1],
         [1, 0, 1],
         [1, 1, 1]],
        dtype=t.float32,
        device=placed_probe_result.device,
    ).view(1, 1, 3, 3)

    # Fold batch and position into one leading axis so each board is a
    # single-channel image for conv2d.
    placed = (placed_probe_result == PLACED).view(-1, 1, height, width).float()

    # padding=1 keeps the spatial size, so edge cells are handled too.
    adjacent = F.conv2d(placed, kernel, padding=1) > 0

    # Restore the (batch, position, height, width) layout; the mask stays boolean.
    return adjacent.view(batch_size, num_positions, height, width)

# Adjacency mask derived from the layer-1 "placed" probe; later cells inspect
# it and pass it to get_flipped_scores.
placed_adjacent_result = mark_adjacent_positions(1)

In [12]:
# Inspect the adjacency mask for batch item 0 at index 2 of the position axis.
placed_adjacent_result[0, 2]

tensor([[False, False, False, False, False, False, False, False],
        [False, False, False, False, False, False, False, False],
        [False, False, False, False,  True,  True,  True, False],
        [False, False, False, False,  True, False,  True, False],
        [False, False, False, False,  True,  True,  True, False],
        [False, False, False, False, False, False, False, False],
        [False, False, False, False, False, False, False, False],
        [False, False, False, False, False, False, False, False]],
       device='cuda:0')

In [23]:
def get_flipped_scores(layer_flipped_pred, layer_flipped_correct, placed_adjacent_result):
    """Score one layer's "flipped" probe predictions against another layer's, on masked cells.

    The "flipped" probe of each layer is applied to that layer's cached
    ``resid_pre`` activations and reduced with an argmax over the last axis.
    The result from ``layer_flipped_correct`` is treated as the reference
    labelling and the result from ``layer_flipped_pred`` as the prediction.
    Only cells selected by ``placed_adjacent_result`` can count as positive.

    Args:
        layer_flipped_pred: Layer that supplies the predicted FLIPPED cells.
        layer_flipped_correct: Layer that supplies the reference FLIPPED cells.
        placed_adjacent_result: Mask multiplied into both results; expected to
            be boolean (or 0/1) and broadcastable to the probes' argmax shape.

    Returns:
        Tuple ``(recall, precision, f1)`` of scalar tensors. The divisions are
        not guarded, so a score is NaN when its denominator is zero.
    """
    def flipped_mask(layer):
        # 0/1 int32 mask: probe argmax equals FLIPPED and the cell is selected.
        resid = fake_cache[utils.get_act_name("resid_pre", layer)].to(device)
        probe = get_probe(layer, "flipped", "post")[0].to(device)
        probe_result = einops.einsum(resid, probe, 'b p d, d r c o -> b p r c o').argmax(axis=-1)
        return ((probe_result == FLIPPED) * placed_adjacent_result).to(t.int32)

    pred = flipped_mask(layer_flipped_pred)
    correct = flipped_mask(layer_flipped_correct)

    tp = (correct * pred).sum()
    # For 0/1 masks, tp + fn is the number of reference positives and tp + fp
    # the number of predicted positives, so no extra full-size temporaries
    # (and no unused true-negative count) are needed.
    num_reference = correct.sum()
    num_predicted = pred.sum()

    recall = tp / num_reference
    precision = tp / num_predicted
    f1 = 2 * recall * precision / (recall + precision)
    return recall, precision, f1

# Compare layer-1 "flipped" predictions against layer 6 as the reference,
# restricted to the adjacency mask; the bare tuple is displayed as the output.
recall, precision, f1 = get_flipped_scores(1, 6, placed_adjacent_result)
recall, precision, f1

(tensor(0.6266, device='cuda:0'),
 tensor(0.8550, device='cuda:0'),
 tensor(0.7232, device='cuda:0'))

In [None]:
# Important Scores:
# Embedding Accuracy at Placed : 0.9269
# Layer 0 Accuracy at Placed : 0.9986

# Flipped Scores (layer 1 predictions vs. layer 6 as reference, on cells adjacent to the placed tile):
# Recall: 0.6266
# Precision: 0.8550
# F1: 0.7232