In [None]:
def generate_with_refusal_edit(
    prompts,
    ablate=True,  # True = subtract vector, False = add vector
    scale=1.0,
    max_new_tokens=150,
    tokenizer_kwargs=None,
    layer="layer_21",
    step=0,
):
    """Generate greedy completions while editing decoder-layer outputs with a stored vector.

    Forward hooks are attached to every block of ``model.model.layers`` whose
    key ``f"layer_{i}"`` is present in the module-level ``refusal_vectors``
    mapping. The hooks stay active for both the plain forward pass and the
    ``generate`` call below, and are always removed before returning.

    The edit vector is always ``refusal_vectors[layer][step]``, regardless of
    which block the hook fires on:

    * ``ablate=True``: every hooked block has ``scale`` times its projection
      onto the (normalized) vector subtracted from its output.
    * ``ablate=False``: ``scale * vector`` is added, but only to the output of
      the block whose key equals ``layer``; other blocks are left untouched.

    Relies on the module-level names ``model``, ``tokenizer``,
    ``refusal_vectors``, ``torch`` and ``partial``.

    Args:
        prompts: Iterable of user-message strings; each becomes a
            single-turn chat rendered with the tokenizer's chat template.
        ablate: Select projection removal (True) or vector addition (False).
        scale: Multiplier applied to the removed projection / added vector.
        max_new_tokens: Forwarded to ``model.generate``.
        tokenizer_kwargs: Extra keyword arguments forwarded to
            ``tokenizer.apply_chat_template``. ``None`` means no extras.
        layer: Key into ``refusal_vectors`` selecting the vector to use (and,
            when adding, the single block that is edited).
        step: Index into ``refusal_vectors[layer]``.

    Returns:
        The object returned by ``model.generate(..., return_dict_in_generate=True,
        output_logits=True)`` with one extra attribute, ``last_input_logits``:
        the logits at the last attended input position of each prompt, shape
        ``[batch_size, vocab_size]``, taken from a separate forward pass that
        ran with the same hooks active.
    """
    # Avoid a shared mutable default; callers that pass a dict are unaffected.
    if tokenizer_kwargs is None:
        tokenizer_kwargs = {}

    hooks = []

    def project_orthogonal(hidden, vec, scale):
        # vec should be 1D; normalize explicitly so coeff is a true projection.
        vec = vec / vec.norm()
        coeff = (hidden * vec).sum(dim=-1, keepdim=True)
        return hidden - scale * coeff * vec

    def edit_fn(module, inputs, output, layer_key):
        # Depending on the model/library version a decoder block may return
        # either the hidden-state tensor itself or a tuple whose first element
        # is the hidden states; support both and preserve the container type.
        is_tuple = isinstance(output, tuple)
        hidden = output[0] if is_tuple else output

        # Intentionally indexed by `layer`, not `layer_key`: the same vector is
        # used at every hooked block. Match the activation dtype as well as the
        # device so the edit does not silently upcast the layer output.
        vec = refusal_vectors[layer][step].to(
            device=hidden.device, dtype=hidden.dtype
        )

        if ablate:
            hidden = project_orthogonal(hidden, vec, scale)
        elif layer_key == layer:
            # NOTE(review): addition is restricted to the source layer only;
            # confirm against the reference method.
            hidden = hidden + scale * vec
        else:
            return output

        return (hidden, *output[1:]) if is_tuple else hidden

    try:
        # Register inside the try so a failure part-way through registration
        # still removes the hooks that were already attached.
        for i, block in enumerate(model.model.layers):
            key = f"layer_{i}"
            if key in refusal_vectors:
                hooks.append(
                    block.register_forward_hook(partial(edit_fn, layer_key=key))
                )

        convs = [
            tokenizer.apply_chat_template(
                [{"role": "user", "content": p}],
                add_generation_prompt=True,
                tokenize=False,
                enable_thinking=False,
                **tokenizer_kwargs,
            )
            for p in prompts
        ]

        inputs = tokenizer(
            convs,
            return_tensors="pt",
            padding=True,
        )
        inputs = {k: v.to(model.device) for k, v in inputs.items()}

        with torch.no_grad():
            # Forward once to get logits at the last *input* token per example.
            logits = model(**inputs).logits  # [B, T, V]

            # Index of the last attended position per example. Scanning the
            # flipped mask for its first 1 works for both left and right
            # padding (mask.sum() - 1 is only correct for right padding).
            mask = inputs["attention_mask"]  # [B, T]
            seq_len = mask.shape[1]
            last_indices = seq_len - 1 - mask.flip(dims=[1]).argmax(dim=1)  # [B]

            # Gather logits at those indices => [B, V]
            bsz, _, vocab_size = logits.shape
            gather_idx = last_indices.view(bsz, 1, 1).expand(bsz, 1, vocab_size)
            last_input_logits = torch.gather(
                logits, dim=1, index=gather_idx
            ).squeeze(1)

            # Generate as usual (batched, greedy).
            output = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=False,
                return_dict_in_generate=True,
                output_logits=True,
            )

        # Attach the batched last-input logits to the generation output.
        output.last_input_logits = last_input_logits  # [batch_size, vocab_size]
        return output

    finally:
        # Ensure cleanup even if tokenization, forward, or generation fails.
        for h in hooks:
            h.remove()