In [1]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from baukit import TraceDict
from einops import rearrange, einsum
from tqdm import tqdm
from functools import partial
from datasets import Dataset
from model_aligner_script import load_data
from counterfactual_datasets.entity_tracking import object_alignment_example_sampler

# Run on the GPU when PyTorch can see one; otherwise stay on the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

# Seed PyTorch's global RNG so repeated runs draw the same random numbers.
torch.manual_seed(42)

  from .autonotebook import tqdm as notebook_tqdm


<torch._C.Generator at 0x7f147f5ecd10>

In [2]:
# Local checkpoint directory that holds both the tokenizer and the model weights.
path = "./llama_7b"
tokenizer = AutoTokenizer.from_pretrained(path)
# Later cells read `model.config` and run forward passes through `model`, so it
# has to be created here for the notebook to run from a fresh kernel.
model = AutoModelForCausalLM.from_pretrained(path).to(DEVICE)
# Use the end-of-sequence token as the padding token.
tokenizer.pad_token_id = tokenizer.eos_token_id

# The weights are only read in this notebook: switch to eval mode and exclude
# every parameter from autograd.
model.eval()
for param in model.parameters():
    param.requires_grad_(False)

loading file tokenizer.model
loading file added_tokens.json
loading file special_tokens_map.json
loading file tokenizer_config.json


In [24]:
# Head count per attention layer, and the number of hidden features each head
# owns (hidden size split evenly across heads with integer division).
NUM_HEADS, HEAD_SIZE = (
    model.config.num_attention_heads,
    model.config.hidden_size // model.config.num_attention_heads,
)

## Desiderata

In [6]:
# CSV passed to `load_data` as `object_file`.
object_file_path = "./box_datasets/objects_with_bnc_frequency.csv"
# JSONL passed to `load_data` as `data_file`.
data_file_path = "./box_datasets/no_instructions/original/3/train.jsonl"

In [7]:
# Build the train / eval / test splits of the object-value desideratum.
# `load_data` returns three objects, unpacked in that order.
(
    objValueFetcher_train,
    objValueFetcher_eval,
    objValueFetcher_test,
) = load_data(
    # Where the examples come from.
    data_file=data_file_path,
    object_file=object_file_path,
    # How examples are tokenised and paired.
    tokenizer=tokenizer,
    aligner_func=object_alignment_example_sampler,
    alt_examples=False,
    architecture="",
    # Dataset geometry.
    num_ents_or_ops=3,
    data_size=500,
    batch_size=30,
)

Train size:  400
Eval size:  50
Test size:  50


In [8]:
# Each split is a list of desiderata; this notebook uses a single one
# (the object-value fetcher) per split.
desiderata_train, desiderata_eval, desiderata_valid = (
    [objValueFetcher_train],
    [objValueFetcher_eval],
    [objValueFetcher_test],
)

In [9]:
# Pull the first batch of the first training desideratum and show, for its
# first example, the decoded base prompt, source prompt and label.
data = next(iter(desiderata_train[0]))
for field in ("base_input_ids", "source_input_ids", "labels"):
    print(tokenizer.decode(data[field][0]))

  Box 0 contains boot, Box 1 contains lunchbox, Box 2 contains bell. Box 0 contains</s></s></s></s></s></s></s>
  Box 0 contains stone, Box 1 contains lunchbox, Box 2 contains bell. Box 0 contains</s></s></s></s></s></s></s>
 stone


## Training Binary Mask

In [23]:
# One hook target per decoder layer: the attention output projection.
modules = [
    f"model.layers.{layer_idx}.self_attn.o_proj"
    for layer_idx in range(model.config.num_hidden_layers)
]

In [25]:
# Cache, for every training batch, the tensor fed into each hooked attention
# output projection while the model reads the *source* prompts.
# Layout: source_activations_valid[desideratum index][batch index][module name].
source_activations_valid = {}

for di, desid_train in enumerate(desiderata_train):
    source_activations_valid[di] = {}

    for bi, inputs in enumerate(desid_train):
        # Stage every tensor field of the batch on the compute device.
        for key, value in inputs.items():
            if isinstance(value, torch.Tensor):
                inputs[key] = value.to(DEVICE)

        batch_cache = source_activations_valid[di][bi] = {}
        with torch.no_grad(), TraceDict(model, modules, retain_input=True) as trace:
            _ = model(inputs["source_input_ids"])

            for module in modules:
                record = trace[module]
                # Attention projections are cached by their input; any other
                # module would be cached by its output.
                captured = record.input if "self_attn" in module else record.output
                batch_cache[module] = captured.detach().cpu()

        # Drop the trace and move the batch back to the CPU before the next one.
        del trace
        torch.cuda.empty_cache()
        for key, value in inputs.items():
            if isinstance(value, torch.Tensor):
                inputs[key] = value.to("cpu")

In [28]:
# Expand every attention module into one entry per head ("<module>.<head>");
# any non-attention module is kept as a single entry.
modules_w_heads = []
for module in modules:
    if "self_attn" not in module:
        modules_w_heads.append(module)
        continue
    modules_w_heads.extend(
        f"{module}.{head}" for head in range(model.config.num_attention_heads)
    )

# Position of each entry in the flat mask vector.
mask_dict = dict(zip(modules_w_heads, range(len(modules_w_heads))))

In [43]:
def edit_output(inputs, output, layer, mask, from_activations, to_last_token_pos, from_last_token_pos):
    """Blend cached activations into an attention ``o_proj`` input and re-project it.

    For every head of ``layer`` and every batch row, the head's slice of the
    module input at ``to_last_token_pos[row]`` is overwritten in place with
    ``m * current + (1 - m) * cached``, where ``m`` is that head's entry in
    ``mask`` (looked up through the global ``mask_dict``) and ``cached`` is
    read from ``from_activations[layer]`` at ``from_last_token_pos[row]``.
    The patched input is then multiplied by the layer's weight matrix and the
    product is returned. No bias term is applied.

    Args:
        inputs: Sequence whose first element is the module input, indexed as
            [batch row, position, feature].
        output: Not used by this function.
        layer: Module name; must contain ``"self_attn"``.
        mask: Indexable collection of per-head mixing weights.
        from_activations: Mapping from module name to the cached activation
            tensor for the current batch. It is not modified.
        to_last_token_pos: Per-row position to patch in the module input.
        from_last_token_pos: Per-row position to read in the cached tensor.

    Returns:
        ``einsum("bsh,oh->bso", patched_input, weight)`` for this layer.

    Raises:
        AssertionError: If ``layer`` does not contain ``"self_attn"``.
    """
    if "self_attn" not in layer:
        # Raised explicitly so the guard also holds when Python runs with -O,
        # which strips `assert` statements.
        raise AssertionError("shouldn't be here")

    inp = inputs[0]
    # Work on a local device copy; the shared cache entry stays where it is.
    cached = from_activations[layer].to(DEVICE)

    for head_idx in range(NUM_HEADS):
        head_start = head_idx * HEAD_SIZE
        head_end = head_start + HEAD_SIZE
        abl_amt = mask[mask_dict[f"{layer}.{head_idx}"]]

        for bi in range(inp.shape[0]):
            to_pos = to_last_token_pos[bi]
            from_pos = from_last_token_pos[bi]
            # NOTE: this writes into the module's own input tensor in place.
            inp[bi, to_pos, head_start:head_end] = (
                abl_amt * inp[bi, to_pos, head_start:head_end]
                + (1 - abl_amt) * cached[bi, from_pos, head_start:head_end]
            )

    # Release the device copy before the projection allocates its result.
    del cached

    weights = model.state_dict()[f"{layer}.weight"]
    mod_output = torch.einsum("bsh,oh->bso", inp, weights)

    del weights
    torch.cuda.empty_cache()
    return mod_output

In [61]:
# mask.data.clamp_(0, 1)

## Evaluating Learned Mask

In [62]:
# rounded = [torch.round(mask) for mask in masks.values()]
# (rounded[0] == 0).nonzero().shape

In [10]:
# Hook targets for evaluation: the attention output projection of each layer.
modules = [
    f"model.layers.{layer_idx}.self_attn.o_proj"
    for layer_idx in range(model.config.num_hidden_layers)
]

In [44]:
# Cache, for every training batch, the tensor fed into each hooked attention
# output projection while the model reads the *base* prompts.
# NOTE: despite the variable name, these come from `base_input_ids` of
# `desiderata_train`.
# Layout: source_activations_valid[desideratum index][batch index][module name].
source_activations_valid = {}

for di, desid_train in enumerate(desiderata_train):
    source_activations_valid[di] = {}

    for bi, inputs in enumerate(desid_train):
        # Stage every tensor field of the batch on the compute device.
        for key, value in inputs.items():
            if isinstance(value, torch.Tensor):
                inputs[key] = value.to(DEVICE)

        batch_cache = source_activations_valid[di][bi] = {}
        with torch.no_grad(), TraceDict(model, modules, retain_input=True) as trace:
            _ = model(inputs["base_input_ids"])

            for module in modules:
                record = trace[module]
                # Attention projections are cached by their input; any other
                # module would be cached by its output.
                captured = record.input if "self_attn" in module else record.output
                batch_cache[module] = captured.detach().cpu()

        # Drop the trace and move the batch back to the CPU before the next one.
        del trace
        torch.cuda.empty_cache()
        for key, value in inputs.items():
            if isinstance(value, torch.Tensor):
                inputs[key] = value.to("cpu")

In [46]:
# Evaluate each candidate mask: run the model on the source prompts with the
# edit hook installed on every `o_proj`, and count how often the arg-max token
# at the last source position equals the label.
with torch.no_grad():
    for round_mask in [torch.ones(model.config.num_hidden_layers * model.config.num_attention_heads)]:
        count, total = 0, 0
        for di, desid_valid in enumerate(desiderata_train):
            # Iterate the loader bound by this loop, not a variable left over
            # from an earlier cell.
            # NOTE(review): the cached activations are looked up by batch index,
            # which assumes the loader yields batches in the same order as when
            # the cache was built. Confirm it does not reshuffle.
            for bi, inputs in enumerate(desid_valid):
                with TraceDict(
                    model,
                    modules,
                    edit_output=partial(
                        # NOTE(review): `dummy_edit` is not defined in the
                        # visible part of this notebook. Confirm it exists, or
                        # whether `edit_output` was intended.
                        dummy_edit,
                        mask=round_mask,
                        from_activations=source_activations_valid[di][bi],
                        to_last_token_pos=inputs["source_input_last_pos"],
                        from_last_token_pos=inputs["base_input_last_pos"],
                    ),
                ) as _:
                    outputs = model(inputs['source_input_ids'].to(DEVICE))

                batch_size = inputs['source_input_ids'].size(0)
                for i in range(batch_size):
                    logits = outputs.logits[i, inputs['source_input_last_pos'][i]]
                    pred = torch.argmax(logits, dim=-1)
                    label = inputs['labels'][i]

                    # Compare as Python scalars so the device each tensor
                    # lives on does not matter.
                    if pred.item() == label.item():
                        count += 1

                total += batch_size

        # Report NaN rather than dividing by zero when there was nothing to score.
        accuracy = count / total if total else float("nan")
        print(f'Accuracy: {accuracy}')

Accuracy: 0.0


: 

In [116]:
# Sanity check on one hand-written prompt: print the model's most likely next token.
prompt = "Box 0 contains jacket, Box 1 contains nothing, Box 2 contains lantern. Box 2 contains"
tokens = tokenizer(prompt, return_tensors='pt').input_ids.to(DEVICE)
# Inference only, so skip building the autograd graph.
with torch.no_grad():
    output = model(tokens)
# Arg-max over the vocabulary at the final position of the only sequence.
pred = torch.argmax(output.logits[0, -1], dim=-1)
print(tokenizer.decode(pred))

 nothing


In [38]:
# Decode the first source prompt of whichever batch `inputs` currently holds
# (the last batch iterated by an earlier cell).
first_source_ids = inputs['source_input_ids'][0]
tokenizer.decode(first_source_ids)

'  Box 0 contains cash, Box 1 contains contrabass, Box 2 contains nametag. Box 0 contains</s></s></s>'