In [1]:
from ioi_utils import Node
from mandala._next.imports import *
from mandala._next.common_imports import *

# A proof-of-concept evaluation for the greater-than circuit

In [4]:
### reproduce the main aspects of the dataset of https://arxiv.org/pdf/2305.00586
# one noun per line; `splitlines()` plus the emptiness filter drops the blank
# entry that a trailing newline would otherwise produce (`split('\n')` keeps it,
# and an empty noun yields prompts like "The  lasted from the year ...")
with open('data/gt-nouns.txt', 'r', encoding='utf-8') as f:
    NOUNS = [line.strip() for line in f.read().splitlines() if line.strip()]

# figure out which years are tokenized the way we want
prompt = "The war lasted from the year 1732 to the year 17"

# (layer, head) pairs of the attention heads studied in this notebook
HEAD_LOCATIONS = [(5, 1), (5, 5), (6, 1), (6, 9), (7, 10), (8, 8), (8, 11)]
# one 'z' node per head at seq_pos=-1 (presumably the final token; see ioi_utils.Node)
NODES = [Node(component_name='z', layer=layer, head=head, seq_pos=-1) for layer, head in HEAD_LOCATIONS]
# placeholder set of year suffixes; replace with the years found to tokenize as desired
YYS = ['23', '42'] # TODO

class Prompt:
    """
    One greater-than prompt, rendered by `sentence` as
    "The {noun} lasted from the year {xx}{yy} to the year {xx}".
    """

    def __init__(self,
                 noun: str,
                 yy: str,
                 xx: str = '17', # for simplicity, everything is set in the 18th century
                 ):
        """
        - noun: the word filling the "The ___ lasted" slot
        - yy: the last two digits of the start year
        - xx: the first two digits of the year (the century prefix)
        """
        self.noun, self.xx, self.yy = noun, xx, yy

    @property
    def sentence(self) -> str:
        return f"The {self.noun} lasted from the year {self.xx}{self.yy} to the year {self.xx}"

    def with_changed_yy(self, yy: str) -> 'Prompt':
        # a new prompt sharing noun and century, differing only in yy
        return Prompt(noun=self.noun, yy=yy, xx=self.xx)

class PromptDistribution:
    """
    A distribution over prompts: each sample picks a yy and a noun
    independently and uniformly (through Python's global `random` module)
    from the given lists.
    """

    def __init__(
        self,
        yys: List[str],
        nouns: List[str],
    ):
        # the candidate lists are stored as given (not copied)
        self.yys, self.nouns = yys, nouns

    def sample_one(self,) -> Prompt:
        """
        Sample a single prompt from the distribution. The yy is drawn before
        the noun, so the global random stream is consumed in that order.
        """
        picked_yy = random.choice(self.yys)
        picked_noun = random.choice(self.nouns)
        return Prompt(noun=picked_noun, yy=picked_yy)


FULL_DISTRIBUTION = PromptDistribution(nouns=NOUNS, yys=YYS)  # every noun x every yy

## Circuit utils

In [9]:
import torch
from typing import Sequence
from torch import Tensor
MODEL_ID = 'gpt2-small'


# each feature subset is stored with its feature names sorted, so the same
# subset always maps to the same tuple
FEATURE_SUBSETS = [tuple(sorted(subset)) for subset in (('yy',),)]

def get_yy_to_idx(distribution: 'PromptDistribution') -> Dict[str, int]:
    """
    Map each yy value of `distribution` to its position in `distribution.yys`.

    These positions are what get_prompt_representation uses to index the 'yy'
    codes. `enumerate` is essential here: iterating over the bare strings would
    unpack each two-character yy into a bogus (key, value) pair.
    """
    return {yy: i for i, yy in enumerate(distribution.yys)}

# setup to work with features
YY_TO_IDX = get_yy_to_idx(FULL_DISTRIBUTION)

# number of distinct values of each feature, i.e. the size of the matching
# axis of a code tensor; 'bias' takes a single value
FEATURE_SIZES = {
    'bias': 1,
    'yy': len(YY_TO_IDX),
}

# this collects possible ways to parametrize activations of the model; every
# configuration is prefixed with the ('bias',) feature
FEATURE_CONFIGURATIONS = {
    config_name: [('bias', ), *config_features]
    for config_name, config_features in {
        'independent': [('yy', ), ],
    }.items()
}

@op
def generate_name_samples(n_samples, names, random_seed: int = 0) -> Any:
    """
    Draw `n_samples` entries uniformly with replacement from `names`.

    Uses a dedicated `np.random.RandomState` seeded with `random_seed`. This
    yields exactly the same draws as seeding the legacy global generator and
    calling `np.random.choice`, but it does not reset numpy's global RNG for the
    rest of the notebook. Relying on such a side effect would be fragile
    anyway, because a memoized @op body does not run on cache hits.
    """
    rng = np.random.RandomState(random_seed)
    return rng.choice(names, n_samples, replace=True)

@op
def get_cf_prompts(
    prompts: List[Prompt],
    features: Tuple[str, ...],
    yy_targets: List[str],
) -> Any:
    """
    Build counterfactual prompts: the i-th output is prompts[i] with its yy
    replaced by yy_targets[i]. Noun and century are kept (see
    Prompt.with_changed_yy).

    Only the ('yy',) feature can be changed for now.

    Raises ValueError if `prompts` and `yy_targets` differ in length. Plain
    `zip` would otherwise silently drop the unmatched tail and misalign the
    base/counterfactual pairs downstream.
    """
    assert features == ('yy',)
    if len(prompts) != len(yy_targets):
        raise ValueError(
            f'got {len(prompts)} prompts but {len(yy_targets)} yy targets'
        )
    return [p.with_changed_yy(yy=yy) for p, yy in zip(prompts, yy_targets)]

@op
def generate_prompts(distribution: PromptDistribution, random_seed: int, n_prompts: int,
                     ) -> Any:
    """
    Return a list of `n_prompts` prompts, each from `distribution.sample_one()`.
    Python's global `random` module is reseeded with `random_seed` first, so
    the result is reproducible.
    """
    random.seed(random_seed)
    sampled = []
    for _ in range(n_prompts):
        sampled.append(distribution.sample_one())
    return sampled

################################################################################
### working with features
################################################################################
def get_prompt_representation(p: Prompt) -> Dict[str, int]:
    """
    Extract the value index of every feature for a prompt.

    Every feature that can appear in a code key (see FEATURE_SIZES) needs an
    entry here. 'bias' has a single value (index 0), and 'yy' is indexed via
    YY_TO_IDX. Without the 'bias' entry, any feature list containing ('bias',),
    which includes every FEATURE_CONFIGURATIONS entry, fails with a KeyError in
    get_feature_deep_idx.
    """
    return {'bias': 0, 'yy': YY_TO_IDX[p.yy]}

def get_prompt_feature_vals(p: Prompt) -> Dict[str, Any]:
    # raw (un-indexed) feature values of the prompt, keyed by feature name
    feature_vals = {}
    feature_vals['yy'] = p.yy
    return feature_vals

def get_feature_shape(feature: Tuple[str,...]) -> Tuple[int, ...]:
    # one axis per feature name, sized by the number of values that feature takes
    axis_sizes = [FEATURE_SIZES[name] for name in feature]
    return tuple(axis_sizes)

def get_feature_deep_idx(feature: Tuple[str,...], prompt_rep: Dict[str, int]) -> Tuple[int,...]:
    """
    Given the feature values for a prompt, return the index to use into the
    code of `feature`: one entry per feature name, in order. The last code
    dimension (the code vectors themselves, whose size may vary) is not
    included.
    """
    return tuple(prompt_rep[name] for name in feature)

@op
def get_prompt_feature_idxs(prompts: Optional[Sequence[Prompt]],
                            features: List[Tuple[str,...]],
                            prompt_reps: Optional[List[dict]] = None,
                            ) -> Dict[Tuple[str,...], Tensor]:
    """
    Return a dictionary mapping each feature to a batch of indices into the
    code for this feature over the prompts (indices don't take into account the
    last dimension in the codes, which is the dimension of the code vectors and
    may vary).

    Each value is an integer (torch.long) tensor of shape
    (n_prompts, len(feature)). The shape is enforced explicitly so that an
    empty batch still yields a 2-D index tensor: `torch.tensor([])` alone
    would be a 1-D float tensor and break `idx.shape[1]` in
    get_reconstructions.

    - prompts: used to compute `prompt_reps` when those are not given.
    - prompt_reps: precomputed outputs of get_prompt_representation.
    """
    if prompt_reps is None:
        assert prompts is not None
        prompt_reps = [get_prompt_representation(p) for p in prompts]
    n_prompts = len(prompt_reps)
    return {
        f: torch.tensor(
            [get_feature_deep_idx(f, prompt_rep) for prompt_rep in prompt_reps],
            dtype=torch.long,
        ).reshape(n_prompts, len(f))
        for f in features
    }

def get_reconstructions(
    codes: Any, # Dict[tuple, Tensor],
    prompt_feature_idxs: Optional[Dict[tuple, Tensor]] = None,
    prompts: Optional[Sequence['Prompt']] = None,
    decomposed: bool = False,
    ) -> Union[Tensor, Dict[tuple, Tensor]]:
    """
    Reconstruct prompts according to the given codes. If prompt_feature_idxs
    is not given, it is computed from prompts.

    - codes: maps each feature to its code tensor. The leading axes are indexed
      by the feature's value indices; the last axis is the code dimension.
    - prompt_feature_idxs: maps each feature to an integer tensor with one row
      of value indices per prompt (as built by get_prompt_feature_idxs).
    - decomposed: if True, return the per-feature selected code vectors
      instead of their sum.

    Returns the sum over features of the code vectors selected for each
    prompt, or the per-feature dict if `decomposed`.
    """
    if prompt_feature_idxs is None:
        assert prompts is not None
        # Pass a list, not a dict_keys view (as get_cf_edited_act also does).
        # get_prompt_feature_idxs is a memoized @op whose inputs get
        # hashed/stored, and dict_keys views are not picklable.
        prompt_feature_idxs = get_prompt_feature_idxs(prompts, list(codes.keys()))
    prompt_vectors = {}
    for f, idx in prompt_feature_idxs.items():
        # one index array per leading code axis, plus a full slice over the
        # code dimension -> selects one code vector per prompt
        full_idx = tuple(idx[:, i].cpu().numpy() for i in range(idx.shape[1])) + (slice(None, None, None),)
        prompt_vectors[f] = codes[f][full_idx]
    if decomposed:
        return prompt_vectors
    return sum(prompt_vectors.values())

################################################################################
### computing codes
################################################################################
@op
def get_mean_codes(
    features: List[Tuple[str,...]],
    A: Tensor,
    prompts: Any, # List[Prompt]
) -> Tuple[Any, Any]:
    """
    Compute codes using the mean of the (centered) activations for a given
    feature value.

    - features: the features to compute codes for. ('bias',) gets the mean
      activation over all prompts. Every other feature gets, for each value
      observed among the prompts, the mean of the centered activations of the
      prompts with that value; unobserved values keep an all-zero code.
    - A: activations with one row per prompt. A.shape[1] is used as the code
      dimension (assumes A is (n_prompts, dim); TODO confirm for other shapes).
    - prompts: the prompts A was computed on, in the same order as its rows.

    Returns (codes, reconstructions). codes[f] has shape
    get_feature_shape(f) + (dim,) and matches A's device and dtype;
    reconstructions is the output of get_reconstructions on these codes.
    """
    dim = A.shape[1]
    # one axis per feature name, plus the code dimension
    feature_shapes = {f: tuple(get_feature_shape(f)) + (dim,) for f in features}
    # get the attributes of the prompts
    prompt_feature_idxs = get_prompt_feature_idxs(prompts=prompts, features=features)

    # group rows of A by feature value:
    # prompt_feature_groups[f] = {value (tuple of ints): tensor of row indices}
    prompt_feature_groups = {}
    for f, feature_idxs in prompt_feature_idxs.items():
        groups = {}
        for row, feature_idx in enumerate(feature_idxs):
            groups.setdefault(tuple(feature_idx.tolist()), []).append(row)
        prompt_feature_groups[f] = {value: torch.tensor(rows) for value, rows in groups.items()}

    # Allocate on A's device/dtype instead of a hard-coded `.cuda()`. This keeps
    # these codes consistent with the bias code below, which comes from A, and
    # makes the function usable without a GPU.
    codes = {
        f: torch.zeros(feature_shape, device=A.device, dtype=A.dtype)
        for f, feature_shape in feature_shapes.items()
    }
    A_mean = A.mean(dim=0)
    if ('bias',) in features:
        codes[('bias',)] = A_mean.unsqueeze(0)
    A_centered = A - A_mean
    for f, groups in prompt_feature_groups.items():
        if f == ('bias',):
            continue
        for value, rows in groups.items():
            codes[f][value] = A_centered[rows].mean(dim=0)
    reconstructions = get_reconstructions(
        codes=codes, prompt_feature_idxs=prompt_feature_idxs,
    )
    return codes, reconstructions


################################################################################
### feature editing
################################################################################
def get_edited_act(
    val: Tensor,
    method: str,
    feature_idxs_to_delete: Dict[Tuple[str,...], List[Tuple[int,...]]],
    feature_idxs_to_insert: Dict[Tuple[str,...], List[Tuple[int,...]]],
    codes: Dict[Tuple[str,...], Tensor],
    A_reference: Optional[Tensor] = None,
):
    """
    The core editing function: perform one or several edits on an activation
    using the given codes and method.

    - val: activations to edit, one row per example (not modified in place).
    - method: how the codes in `feature_idxs_to_delete` are removed:
        - 'arithmetic': subtract the code vector;
        - 'zero_ablate_subspace': project out the normalized code direction;
        - 'mean_ablate_subspace': replace the component of val along the
          normalized code direction with that of the mean of `A_reference`
          (which is then required).
      In every case, the codes in `feature_idxs_to_insert` are then added.
    - feature_idxs_to_delete / feature_idxs_to_insert: for each feature, one
      code index per row of `val`.
    - codes: the code tensor of each feature.

    A code with zero norm has no direction to ablate, so it leaves val
    unchanged (instead of producing NaNs).

    Note that the methods based on subspace ablations are not commutative with
    respect to the order of feature insertion/deletion.

    Raises ValueError for an unknown method.
    """
    def gather_codes(f, idxs):
        # one code vector of feature f per example, stacked into (batch, dim)
        return torch.stack([codes[f][i] for i in idxs])

    def unit_directions(vectors):
        # normalize rows; the clamp only matters for (near-)zero codes, which
        # then become zero directions instead of NaN
        norms = vectors.norm(dim=-1, keepdim=True)
        return vectors / norms.clamp_min(torch.finfo(norms.dtype).tiny)

    val = val.clone()
    if method == 'arithmetic':
        for f, idx_to_delete in feature_idxs_to_delete.items():
            val = val - gather_codes(f, idx_to_delete)
    elif method in ('zero_ablate_subspace', 'mean_ablate_subspace'):
        reference_mean = None
        if method == 'mean_ablate_subspace':
            assert A_reference is not None
            reference_mean = A_reference.mean(dim=0)
        for f, idx_to_delete in feature_idxs_to_delete.items():
            directions = unit_directions(gather_codes(f, idx_to_delete))
            # current component of each row along its direction: (batch,)
            current = (val * directions).sum(dim=-1)
            if reference_mean is None:
                target = torch.zeros_like(current)
            else:
                # mean projection of the reference activations onto each direction
                target = (reference_mean * directions).sum(dim=-1)
            # move each row's component along its direction from `current` to `target`
            val = val + (target - current).unsqueeze(-1) * directions
    else:
        raise ValueError(f'unknown method {method}')
    for f, idx_to_insert in feature_idxs_to_insert.items():
        val = val + gather_codes(f, idx_to_insert)
    return val


@op
def get_cf_edited_act(
    val: Tensor,
    features_to_edit: Tuple[str,...],
    base_prompts: Any, # List[Prompt],
    cf_prompts: Any, # List[Prompt],
    codes: Any, # Dict[Tuple[str,...], Tensor],
    method: Literal['mean_ablate_subspace', 'zero_ablate_subspace', 'arithmetic'],
    A_ref: Optional[Tensor] = None,
) -> Tensor:
    """
    Edit an activation using the given counterfactual prompts in order to infer
    the new values for the features being edited.

    - features_to_edit is a tuple of the feature names we want to change,
    e.g. ('yy',). The new values for these features are inferred from the
    counterfactual prompts. Every code whose key contains at least one of these
    names is deleted at its base-prompt value and inserted at its
    counterfactual value. The remaining codes (e.g. ('bias',)) are left alone.
    - method / A_ref: passed on to get_edited_act (A_ref is required for
    'mean_ablate_subspace').
    """
    # Deleting and re-inserting an unchanged code is not a no-op for the
    # subspace-ablation methods, so codes unrelated to features_to_edit must
    # be skipped rather than edited.
    edited_names = set(features_to_edit)
    code_features = [f for f in codes.keys() if edited_names.intersection(f)]
    base_feature_idxs = get_prompt_feature_idxs(prompts=base_prompts, 
                                                features=code_features, )
    cf_feature_idxs = get_prompt_feature_idxs(prompts=cf_prompts,
                                              features=code_features, )

    def turn_tensor_to_tuples(t: Tensor) -> List[Tuple[int,...]]:
        # one index tuple per prompt, the form get_edited_act expects
        return [tuple(x.cpu().tolist()) for x in t]

    base_feature_idxs = {k: turn_tensor_to_tuples(v) for k, v in base_feature_idxs.items()}
    cf_feature_idxs = {k: turn_tensor_to_tuples(v) for k, v in cf_feature_idxs.items()}

    edited_act = get_edited_act(
        val=val,
        codes=codes,
        feature_idxs_to_delete=base_feature_idxs,
        feature_idxs_to_insert=cf_feature_idxs,
        A_reference=A_ref,
        method=method,
    )
    return edited_act

def get_forced_hook(
    prompts: List[Prompt],
    node: Node, 
    A: Tensor,
) -> Tuple[str, Callable]:
    """
    Build a (hook name, hook function) pair. The hook function overwrites the
    activation of `node` with `A` at the locations given by
    `node.idx(prompts=prompts)`, which is recomputed on every call.
    """
    def hook_fn(activation: Tensor, hook: HookPoint) -> Tensor:
        # in-place overwrite of the selected locations, then hand the tensor back
        activation[node.idx(prompts=prompts)] = A
        return activation
    return node.activation_name, hook_fn

@op
def run_activation_patch(
    base_prompts: Any, # List[Prompt],
    cf_prompts: Any, # List[Prompt],
    nodes: List[Node],
    activations: List[Tensor],
    batch_size: int,
    model_id: str = MODEL_ID,
) -> Tuple[Tensor, Tensor]:
    """
    Run a standard activation patch in a batched way.

    The model runs on `base_prompts` while each node in `nodes` is forced to
    the matching rows of the corresponding tensor in `activations`. For each
    batch, the last-position logits are gathered at the answer tokens of the
    base and of the counterfactual prompts.

    Returns (base_logits, cf_logits), concatenated over batches along dim 0.
    Requires one activation row per base prompt and
    len(cf_prompts) == len(base_prompts).
    """
    model = get_model_obj(model_id)
    n = len(base_prompts)
    assert all([n == v.shape[0] for v in activations])
    assert len(cf_prompts) == n, f'{len(cf_prompts)} cf prompts for {n} base prompts'
    n_batches = (n + batch_size - 1) // batch_size
    base_logits_list = []
    cf_logits_list = []
    # Inference only: without no_grad (if parameters require grad), each
    # batch's autograd graph stays referenced by the gathered logits and keeps
    # the whole forward pass alive until the end of the loop.
    with torch.no_grad():
        for i in tqdm(range(n_batches)):
            batch_indices = slice(i * batch_size, (i + 1) * batch_size)
            prompts_batch = base_prompts[batch_indices]
            cf_batch = cf_prompts[batch_indices]
            base_dataset = PromptDataset(prompts_batch, model=model)
            cf_dataset = PromptDataset(cf_batch, model=model)
            hooks = [get_forced_hook(prompts=prompts_batch, node=node, A=act[batch_indices]) for node, act in zip(nodes, activations)]
            changed_logits = model.run_with_hooks(base_dataset.tokens, fwd_hooks=hooks)[:, -1, :]
            # follow the logits' device instead of assuming CUDA
            device = changed_logits.device
            base_logits_list.append(changed_logits.gather(dim=-1, index=base_dataset.answer_tokens.to(device)))
            cf_logits_list.append(changed_logits.gather(dim=-1, index=cf_dataset.answer_tokens.to(device)))
    base_logits = torch.cat(base_logits_list, dim=0)
    cf_logits = torch.cat(cf_logits_list, dim=0)
    return base_logits, cf_logits

def get_probability_difference(
        logits: Tensor, # of shape (batch, vocab), last token logits
        yy_values: List[str],
        yy_token_idxs: List[int],
        yys: List[str], # of shape (batch, )
        ) -> Tensor:
    """
    The metric used to discover the circuit in the paper is the *probability*
    difference, and we adopt it here as well. For each prompt with start-year
    suffix yy, return

        sum_{y > yy} p(y) - sum_{y <= yy} p(y)

    where p is the softmax of the last-token logits over the full vocabulary
    and y ranges over `yy_values`.

    - logits: last-token logits of shape (batch, vocab)
    - yy_values: the candidate year suffixes. "Greater than" is decided by
      position in this list, so it must be sorted in increasing order.
    - yy_token_idxs: the vocabulary index of each entry of `yy_values`
    - yys: the start-year suffix of each prompt, one per row of `logits`

    Returns a tensor of shape (batch,).

    Raises ValueError if the input lengths are inconsistent or a value in
    `yys` is missing from `yy_values`.
    """
    if len(yy_values) != len(yy_token_idxs):
        raise ValueError(f'{len(yy_values)} yy values but {len(yy_token_idxs)} token indices')
    if len(yys) != logits.shape[0]:
        raise ValueError(f'{len(yys)} yys for a batch of {logits.shape[0]} logits')
    device = logits.device
    probs = logits.softmax(dim=-1)
    # probability of every candidate year token: (batch, n_values)
    token_idxs = torch.tensor(yy_token_idxs, dtype=torch.long, device=device)
    yy_probs = probs[:, token_idxs]
    # position of each prompt's own yy (list.index raises ValueError if missing)
    positions = torch.tensor([yy_values.index(yy) for yy in yys], dtype=torch.long, device=device)
    candidate_positions = torch.arange(len(yy_values), device=device)
    is_greater = candidate_positions.unsqueeze(0) > positions.unsqueeze(1)
    p_greater = torch.where(is_greater, yy_probs, torch.zeros_like(yy_probs)).sum(dim=-1)
    p_not_greater = torch.where(is_greater, torch.zeros_like(yy_probs), yy_probs).sum(dim=-1)
    return p_greater - p_not_greater

# Training supervised dictionaries

# Editing with the supervised dictionaries

# Training SAEs

# Editing with the SAEs