In [None]:
from src.hooking.llama_attention import LlamaAttentionPatcher
from src.selection.functional import get_patches_to_verify_independent_enrichment


def cache_q_projections(
    mt,
    input,
    query_locations,
    return_output,
    projection_signature = ".q_proj"
):
    """Run ``mt`` on ``input`` once and collect per-head projection vectors.

    For every ``(layer_idx, head_idx, query_idx)`` in ``query_locations`` the
    output of the layer's projection sub-module (its name is
    ``mt.attn_module_name_format.format(layer_idx) + projection_signature``)
    is saved during a single traced forward pass, reshaped to
    ``(batch, heads, seq, head_dim)`` and sliced at the requested head and
    token position.

    Args:
        mt: Model wrapper. The code reads ``config.num_attention_heads``,
            ``config.num_key_value_heads``, ``n_embd``,
            ``attn_module_name_format`` and uses ``trace(...)`` / ``output``.
        input: Tokenized input exposing ``input_ids`` with at least two
            dimensions (read as batch size and sequence length).
        query_locations: Iterable of ``(layer_idx, head_idx, query_idx)``.
        return_output: When truthy, the saved model output is returned too.
        projection_signature: Suffix selecting the projection sub-module.
            For ``".k_proj"`` / ``".v_proj"`` with grouped heads the result is
            expanded with ``repeat_kv`` so it can be indexed by query head.

    Returns:
        A dict mapping ``(layer_idx, head_idx, query_idx)`` to a cloned,
        squeezed tensor; or ``(that dict, saved model output)`` when
        ``return_output`` is truthy.
    """
    # Group the requested (head, token) pairs by layer so that each
    # projection module is saved only once, however many heads are requested.
    layer_to_hq = {}
    for layer_idx, head_idx, query_idx in query_locations:
        layer_to_hq.setdefault(layer_idx, []).append((head_idx, query_idx))

    # Build each module name once; it is needed both inside and after the trace.
    proj_names = {
        layer_idx: mt.attn_module_name_format.format(layer_idx) + projection_signature
        for layer_idx in layer_to_hq
    }

    batch_size = input.input_ids.shape[0]
    seq_len = input.input_ids.shape[1]
    n_heads = mt.config.num_attention_heads
    head_dim = mt.n_embd // n_heads
    group_size = n_heads // mt.config.num_key_value_heads

    q_module_projections_per_layer = {}
    output = None
    with mt.trace(input):
        for q_proj_name in proj_names.values():
            # Fix: the module handle has to be *assigned*; the previous code
            # accessed an attribute on the not-yet-defined `q_proj_module`.
            q_proj_module = get_module_nnsight(mt, q_proj_name)
            q_module_projections_per_layer[q_proj_name] = q_proj_module.output.save()

        if return_output:
            output = mt.output.save()

    q_projections = {}
    for layer_idx, query_locs in layer_to_hq.items():
        # (batch, seq, heads * head_dim) -> (batch, heads, seq, head_dim)
        q_proj_out = (
            q_module_projections_per_layer[proj_names[layer_idx]]
            .view(batch_size, seq_len, -1, head_dim)
            .transpose(1, 2)
        )
        if projection_signature in [".k_proj", ".v_proj"] and group_size != 1:
            # Key/value projections have fewer heads than queries; expand them
            # so that `head_idx` can be used as a query-head index.
            q_proj_out = repeat_kv(q_proj_out, n_rep=group_size)
        for head_idx, query_idx in query_locs:
            # clone() detaches the slice from the full projection's storage;
            # squeeze() drops singleton dims (e.g. a batch of size one).
            q_projections[(layer_idx, head_idx, query_idx)] = (
                q_proj_out[:, head_idx, query_idx, :].clone().squeeze()
            )

    if return_output:
        return q_projections, output
    return q_projections


def validate_q_proj_ie_on_sample_pair(
    mt,
    clean_sample,
    patch_sample,
    heads,
    query_indices,
    verify_head_behavior_on,
    ablate_possible_ans_info_from_options,
    amplification_scale,
    must_track_tokens,
    patch_args
):
    """Patch cached query projections from a patch prompt into a clean run.

    The query projections of ``heads`` are cached while running the patch
    sample (or a batch of patch-sample variants, averaged over the batch) and
    written into the corresponding positions of the clean sample's run. The
    last-token predictions of the patch run, the clean run and the intervened
    run are logged.

    Args:
        mt: Model wrapper, also passed wherever a ``tokenizer`` is expected.
        clean_sample: Sample receiving the patches; ``prompt()``, ``options``,
            ``subj`` and ``obj`` are read.
        patch_sample: Sample the query states are taken from; additionally
            ``category`` and ``obj_idx`` are read when a patch batch is built.
        heads: Sized iterable of ``(layer_idx, head_idx)`` pairs.
        query_indices: Mapping from a token position in the patch prompt
            (keys, where states are cached) to the token position in the
            clean prompt (values, where they are written).
        verify_head_behavior_on: Query index forwarded to
            ``verify_head_patterns``; ``None`` skips the attention-pattern
            verification but still performs the patching run.
        ablate_possible_ans_info_from_options: Forwarded to
            ``verify_head_patterns``; when truthy on the direct patching path,
            the patches from
            ``get_patches_to_verify_independent_enrichment`` are added.
        amplification_scale: ``1.0`` patches the query states only; any other
            value installs ``LlamaAttentionPatcher`` forwards that receive the
            patches plus this scale for every selected head/position.
        must_track_tokens: Extra token ids tracked next to the first-token ids
            of the clean options. ``None`` is treated as empty.
        patch_args: Dict of patch sampling options. ``batch_size`` (default 1)
            is always read; for a batch size above one ``task`` and
            ``distinct_options`` are required, and with
            ``distinct_options is True`` also ``prompt_template_idx``,
            ``option_style`` and ``n_distractors``.

    Returns:
        ``{"clean_sample": clean_sample, "patch_sample": patch_sample}``.
    """
    clean_tokenized = prepare_input(prompts=clean_sample.prompt(), tokenizer=mt)
    patch_tokenized = prepare_input(prompts=patch_sample.prompt(), tokenizer=mt)

    patch_batch_size = patch_args.get("batch_size", 1)
    use_patch_batch = patch_batch_size > 1
    # None (or a tuple) must not break the list concatenations further down.
    extra_tracked_tokens = list(must_track_tokens or [])

    if use_patch_batch:
        patch_samples = []
        task = patch_args['task']
        logger.debug(f"Sampling {patch_batch_size} patch_samples...")

        while len(patch_samples) < patch_batch_size:
            # Cycle the answer position over the option slots.
            # Fix: the option count comes from `patch_sample`; the list being
            # built (`patch_samples`) has no `options` attribute.
            obj_idx = len(patch_samples) % len(patch_sample.options)

            if patch_args["distinct_options"] is True:
                # Draw a fresh sample whose answer differs from both inputs.
                sample = task.get_random_sample(
                    mt=mt,
                    category=patch_sample.category,
                    prompt_template_idx=patch_args["prompt_template_idx"],
                    options_style=patch_args["option_style"],
                    filter_by_lm_prediction=True,
                    exclude_objs=[clean_sample.obj, patch_sample.obj],
                    n_distractors=patch_args["n_distractors"],
                    obj_idx=obj_idx,
                )
            else:
                # Reuse the patch sample but move its answer to slot `obj_idx`
                # by swapping it with whatever currently occupies that slot.
                sample = copy.deepcopy(patch_sample)
                sample.options[obj_idx], sample.options[sample.obj_idx] = (
                    sample.options[sample.obj_idx],
                    sample.options[obj_idx]
                )
                sample.obj_idx = obj_idx
            patch_samples.append(sample)

        patch_tokenized_batch = prepare_input(
            prompts=[sample.prompt() for sample in patch_samples], tokenizer=mt
        )
        logger.debug(f"{patch_tokenized_batch.input_ids.shape}")

    if verify_head_behavior_on is not None:
        logger.info("Verifying head behavior...")

        # The return values of the two unpatched verification calls are not
        # used in this function.
        logger.info(f"Clean Sample >> Ans: {clean_sample.obj}")
        verify_head_patterns(
            prompt=clean_sample.prompt(),
            tokenized_prompt=clean_tokenized,
            options=[f"{opt}," for opt in clean_sample.options[:-1]]
            + [f"{clean_sample.options[-1]}."],
            pivot=clean_sample.subj,
            mt=mt,
            heads=heads,
            generate_full_answer=True,
            query_index=verify_head_behavior_on,
            ablate_possible_ans_info_from_options=ablate_possible_ans_info_from_options
        )

        logger.info(f"Patch Sample >> Ans: {patch_sample.obj}")
        verify_head_patterns(
            prompt=patch_sample.prompt(),
            tokenized_prompt=patch_tokenized,
            options=[f"{opt}," for opt in patch_sample.options[:-1]]
            + [f"{patch_sample.options[-1]}."],
            pivot=patch_sample.subj,
            mt=mt,
            heads=heads,
            generate_full_answer=True,
            query_index=verify_head_behavior_on,
            ablate_possible_ans_info_from_options=ablate_possible_ans_info_from_options,
        )

    # Fix: everything below used to be nested under the verification `if`.
    # That made the later `verify_head_behavior_on is not None` checks
    # redundant, left the direct-patching branch unreachable, and returned
    # None whenever verification was disabled.
    logger.info(f"Caching the query states for the {len(heads)} heads")

    query_locations = [
        (layer_idx, head_idx, patch_query_idx)
        for layer_idx, head_idx in heads
        for patch_query_idx in query_indices.keys()
    ]

    # The single-sample run is always needed for the patch prediction.
    cached_q_states, patch_output = cache_q_projections(
        mt=mt,
        input=patch_tokenized,
        query_locations=query_locations,
        return_output=True
    )

    if use_patch_batch:
        # Replace the single-sample states with the mean over the patch batch.
        cached_q_states = cache_q_projections(
            mt=mt,
            input=patch_tokenized_batch,
            query_locations=query_locations,
            return_output=False,
        )
        for lok in cached_q_states:
            cached_q_states[lok] = cached_q_states[lok].mean(dim=0)

    # A state cached at `patch_query_idx` in the patch prompt is written to
    # `query_indices[patch_query_idx]` in the clean prompt.
    q_proj_patches = [
        PatchSpec(
            location=(
                mt.attn_module_name_format.format(layer_idx) + ".q_proj",
                head_idx,
                query_indices[patch_query_idx],
            ),
            patch=q_proj
        )
        for (layer_idx, head_idx, patch_query_idx), q_proj in cached_q_states.items()
    ]

    patch_logits = patch_output.logits[:, -1, :].squeeze()
    patch_predictions = interpret_logits(
        tokenizer=mt,
        logits=patch_logits,
    )
    logger.info(f"patch_prediction={[str(pred) for pred in patch_predictions]}")

    interested_tokens = [
        get_first_token_id(name=opt, tokenizer=mt.tokenizer, prefix=" ")
        for opt in clean_sample.options
    ]
    tracked_tokens = interested_tokens + extra_tracked_tokens

    logger.info("clean run")
    clean_output = patch_with_baukit(
        mt=mt,
        inputs=clean_tokenized,
        patches=[]
    )
    clean_logits = clean_output.logits[:, -1, :].squeeze()
    clean_predictions, clean_track = interpret_logits(
        tokenizer=mt,
        logits=clean_logits,
        interested_tokens=tracked_tokens
    )
    logger.info(f"clean_prediction={[str(pred) for pred in clean_predictions]}")
    logger.info(f"clean_track={clean_track}")

    logger.info("patching the q_proj states")

    if verify_head_behavior_on is not None and amplification_scale == 1.0:
        # NOTE(review): unlike the two verification calls above, the last
        # option is formatted here without a trailing "." - confirm that this
        # difference is intended.
        int_attn_pattern = verify_head_patterns(
            prompt=clean_sample.prompt(),
            tokenized_prompt=clean_tokenized,
            options=[f"{opt}," for opt in clean_sample.options[:-1]]
            + [f"{clean_sample.options[-1]}"],
            pivot=clean_sample.subj,
            mt=mt,
            heads=heads,
            query_patches=q_proj_patches,
            generate_full_answer=False,
            query_index=verify_head_behavior_on,
            ablate_possible_ans_info_from_options=ablate_possible_ans_info_from_options,
        )
        int_logits = int_attn_pattern["logits"].squeeze()

    else:
        amplify = amplification_scale != 1.0
        default_attn_implementation = mt.config._attn_implementation
        attention_patterns = {}
        head_contributions = {}

        try:
            if amplify:
                mt.reset_forward()
                mt.set_attn_implementation("sdpa")

                layers_to_heads = {}
                for layer_idx, head_idx in heads:
                    layers_to_heads.setdefault(layer_idx, []).append(head_idx)

                # Fix: the grouping statements now form the body of the loop
                # (previously mis-indented, which was a syntax error).
                layers_to_q_patches = {}
                for (
                    layer_idx,
                    head_idx,
                    patch_query_idx,
                ), patch in cached_q_states.items():
                    layers_to_q_patches.setdefault(layer_idx, []).append(
                        (head_idx, query_indices[patch_query_idx], patch)
                    )

                for layer_idx, head_indices in layers_to_heads.items():
                    attn_block_name = mt.attn_module_name_format.format(layer_idx)
                    attn_block = baukit.get_module(mt._model, attn_block_name)

                    attention_patterns[layer_idx] = {}
                    head_contributions[layer_idx] = {}

                    # Swap in a forward that receives the query patches and
                    # the amplification factors; the two dicts above are
                    # handed over as its storage arguments.
                    attn_block.forward = types.MethodType(
                        LlamaAttentionPatcher(
                            block_name=attn_block_name,
                            save_attn_for=head_indices,
                            store_attn_matrices=attention_patterns[layer_idx],
                            store_head_contributions=head_contributions[layer_idx],
                            query_patches=layers_to_q_patches[layer_idx],
                            amplify_contributions=[
                                (head_idx, q_idx, amplification_scale)
                                for head_idx in head_indices
                                for q_idx in query_indices.values()
                            ],
                        ),
                        attn_block,
                    )
                # The query patches were given to the patched forwards above.
                patches = []

            else:
                # Copy so that extending `patches` leaves q_proj_patches intact.
                patches = list(q_proj_patches)

            if ablate_possible_ans_info_from_options:
                patches.extend(
                    get_patches_to_verify_independent_enrichment(
                        prompt=clean_sample.prompt(),
                        options=clean_sample.options,
                        pivot=clean_sample.subj,
                        mt=mt,
                        tokenized_prompt=clean_tokenized,
                    )
                )

            int_out = patch_with_baukit(
                mt=mt,
                inputs=clean_tokenized,
                patches=patches,
            )
            int_logits = int_out.logits[:, -1, :].squeeze()
        finally:
            # Always undo the forward replacement and the attention
            # implementation switch, even when the intervention run raised.
            if amplify:
                mt.reset_forward()
                mt.set_attn_implementation(default_attn_implementation)

        if amplify and verify_head_behavior_on is not None:
            # Average the stored attention matrices over all selected heads.
            attn_matrix = [
                attention_patterns[layer_idx][head_idx].cpu()
                for layer_idx in attention_patterns
                for head_idx in attention_patterns[layer_idx]
            ]
            attn_matrix = torch.stack(attn_matrix).squeeze().mean(dim=0)

            visualize_attn_matrix(
                attn_matrix=attn_matrix,
                tokens=[
                    mt.tokenizer.decode(t) for t in clean_tokenized["input_ids"][0]
                ],
            )

    int_predictions, int_track = interpret_logits(
        tokenizer=mt,
        logits=int_logits,
        interested_tokens=tracked_tokens
    )
    logger.info(f"int_prediction={[str(pred) for pred in int_predictions]}")
    logger.info(f"int_track={int_track}")

    return {
        "clean_sample": clean_sample,
        "patch_sample": patch_sample
    }