In [1]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from baukit import TraceDict
from einops import rearrange, einsum
from tqdm import tqdm
from functools import partial
from datasets import Dataset
from model_aligner_script import load_data
from counterfactual_datasets.entity_tracking import object_alignment_example_sampler

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(42)

  from .autonotebook import tqdm as notebook_tqdm


<torch._C.Generator at 0x7f490849ccf0>

In [2]:
path = "./llama_7b"
tokenizer = AutoTokenizer.from_pretrained(path)
model = AutoModelForCausalLM.from_pretrained(path).to(DEVICE)
tokenizer.pad_token_id = tokenizer.eos_token_id

model.eval()
for param in model.parameters():
    param.requires_grad_(False)

loading file tokenizer.model
loading file added_tokens.json
loading file special_tokens_map.json
loading file tokenizer_config.json
loading configuration file ./llama_7b/config.json
Model config LlamaConfig {
  "_name_or_path": "./llama_7b",
  "architectures": [
    "LlamaForCausalLM"
  ],
  "bos_token_id": 1,
  "eos_token_id": 2,
  "hidden_act": "silu",
  "hidden_size": 4096,
  "initializer_range": 0.02,
  "intermediate_size": 11008,
  "model_type": "llama",
  "num_attention_heads": 32,
  "num_hidden_layers": 32,
  "pad_token_id": 0,
  "rms_norm_eps": 1e-06,
  "tie_word_embeddings": false,
  "torch_dtype": "float32",
  "transformers_version": "4.28.0.dev0",
  "use_cache": true,
  "vocab_size": 32000
}

loading weights file ./llama_7b/pytorch_model.bin.index.json
Generate config GenerationConfig {
  "_from_model_config": true,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "pad_token_id": 0,
  "transformers_version": "4.28.0.dev0"
}

Loading checkpoint shards: 100%|██████████| 3/3 [00:09<

In [3]:
NUM_HEADS = model.config.num_attention_heads
HEAD_SIZE = model.config.hidden_size // NUM_HEADS

## Desiderata

In [4]:
data_file_path = "./box_datasets/no_instructions/original/3/train.jsonl"
object_file_path = "./box_datasets/objects_with_bnc_frequency.csv"

In [5]:
objValueFetcher_train, objValueFetcher_eval, objValueFetcher_test = load_data(
    tokenizer=tokenizer,
    data_size=500,
    aligner_func=object_alignment_example_sampler,
    data_file=data_file_path,
    num_ents_or_ops=3,
    batch_size=40,
    architecture="",
    object_file=object_file_path,
    alt_examples=False,
)

Train size:  400
Eval size:  50
Test size:  50


In [6]:
desiderata_train = [objValueFetcher_train]
desiderata_eval = [objValueFetcher_eval]
desiderata_valid = [objValueFetcher_test]

In [7]:
data = next(enumerate(desiderata_train[0]))[1]
bi = 1
print(tokenizer.decode(data["base_input_ids"][bi][:data["base_input_last_pos"][bi]]))
print(tokenizer.decode(data["source_input_ids"][bi][:data["source_input_last_pos"][bi]]))
print(tokenizer.decode(data["labels"][bi]))

 Box 0 contains the sheet, Box 1 contains the tunic, Box 2 contains the incense. Box 1 contains
 Box 0 contains the sheet, Box 1 contains the fig, Box 2 contains the incense. Box 1 contains
 fig


## Training Binary Mask

In [8]:
modules = [f"model.layers.{i}.self_attn.o_proj" for i in range(model.config.num_hidden_layers)]

In [9]:
from_activations_train = {}

for di, desid_train in enumerate(desiderata_train):
    from_activations_train[di] = {}

    for bi, inputs in enumerate(desid_train):
        for k, v in inputs.items():
            if v is not None and isinstance(v, torch.Tensor):
                inputs[k] = v.to(DEVICE)

        from_activations_train[di][bi] = {}
        with torch.no_grad():
            with TraceDict(model, modules, retain_input=True) as trace:
                _ = model(inputs["source_input_ids"])

                for module in modules:
                    if "self_attn" in module:
                        from_activations_train[di][bi][module] = trace[module].input.detach().cpu()
                    else:
                        from_activations_train[di][bi][module] = trace[module].output.detach().cpu()

        for k, v in inputs.items():
            if v is not None and isinstance(v, torch.Tensor):
                inputs[k] = v.to("cpu")
        
        del trace
        torch.cuda.empty_cache()

In [10]:
modules_w_heads = []
for module in modules:
    if "self_attn" in module:
        for head in range(model.config.num_attention_heads):
            modules_w_heads.append(f"{module}.{head}")
    else:
        modules_w_heads.append(module)

mask_dict = {module: i for i, module in enumerate(modules_w_heads)}

In [11]:
def edit_output(inputs, output, layer, mask, from_activations, to_last_token_pos, from_last_token_pos):
    if "self_attn" in layer:
        inp = inputs[0]
        from_activations[layer] = from_activations[layer].to(DEVICE)

        # Computing the output of each head in this layer after the intervention
        for head_idx in range(NUM_HEADS):
            head_start = head_idx * HEAD_SIZE
            head_end = (head_idx + 1) * HEAD_SIZE
            abl_amt = mask[mask_dict[f"{layer}.{head_idx}"]]

            for bi in range(inp.shape[0]):
                intervention = abl_amt * inp[bi, to_last_token_pos[bi], head_start:head_end].clone() + (1 - abl_amt) * from_activations[layer][bi, from_last_token_pos[bi], head_start:head_end]
                inp[bi, to_last_token_pos[bi], head_start:head_end] = intervention

        from_activations[layer] = from_activations[layer].to("cpu")

        weights = model.state_dict()[f"{layer}.weight"]
        mod_output = einsum(inp, weights, "batch seq_len hidden_size, d_model hidden_size -> batch seq_len d_model")

        del weights
        torch.cuda.empty_cache()
        return mod_output

    else:
        assert False, "shouldn't be here"

In [14]:
mask = torch.ones(len(modules_w_heads), requires_grad=True, device=DEVICE, dtype=torch.float)
optimizer = torch.optim.Adam([mask], lr=1e-1)

for epoch in range(100):
    for di, desid_train in enumerate(desiderata_train):
        for bi, inputs in enumerate(desid_train):
            mask.data.clamp_(0, 1)
            optimizer.zero_grad()

            with TraceDict(
                model, 
                modules, 
                edit_output=partial(
                    edit_output, 
                    mask=mask,
                    from_activations=from_activations_train[di][bi], 
                    to_last_token_pos=inputs["base_input_last_pos"],
                    from_last_token_pos=inputs["source_input_last_pos"]
                )
            ) as _:
                output = model(inputs["base_input_ids"].to(DEVICE))

            target_logits = 0
            for idx in range(inputs["base_input_ids"].size(0)):
                target = inputs["labels"][idx]
                target_logits += output.logits[idx, inputs["base_input_last_pos"][idx], target]
            target_logits /= inputs["base_input_ids"].size(0)

            # maximize the target logits => minimize the negative target logits
            # minimize the number of heads => maximize #ones in the mask
            loss = 6*-target_logits + torch.sum(1 - mask)
            print(f"epoch: {epoch}, bi: {bi}, Loss: {loss.item()}, Target logits: {target_logits.item()}")
            rounded = torch.round(mask.data)
            print(f"#Zero heads: {(rounded == 0).nonzero().shape[0]}")

            loss.backward()
            optimizer.step()
        
        del output
        torch.cuda.empty_cache()

epoch: 0, bi: 0, Loss: -33.346099853515625, Target logits: 5.55768346786499
#Zero heads: 0
epoch: 0, bi: 1, Loss: -33.410945892333984, Target logits: 5.935157775878906
#Zero heads: 0
epoch: 0, bi: 2, Loss: -33.81093978881836, Target logits: 6.2650275230407715
#Zero heads: 0
epoch: 0, bi: 3, Loss: -33.954158782958984, Target logits: 6.501003265380859
#Zero heads: 0
epoch: 0, bi: 4, Loss: -43.0528564453125, Target logits: 8.271024703979492
#Zero heads: 0
epoch: 0, bi: 5, Loss: -32.17881393432617, Target logits: 6.680758953094482
#Zero heads: 1
epoch: 0, bi: 6, Loss: -37.13197326660156, Target logits: 7.72890567779541
#Zero heads: 12
epoch: 0, bi: 7, Loss: -37.0977783203125, Target logits: 7.933606147766113
#Zero heads: 13
epoch: 0, bi: 8, Loss: -38.54935836791992, Target logits: 8.368754386901855
#Zero heads: 15
epoch: 0, bi: 9, Loss: -41.75432586669922, Target logits: 9.07913875579834
#Zero heads: 15
epoch: 1, bi: 0, Loss: -42.205509185791016, Target logits: 9.314103126525879
#Zero head

In [18]:
masked_heads = []
inverse_mask_dict = {v: k for k, v in mask_dict.items()}

for mask_idx in (mask == 0).nonzero()[:, 0]:
    layer = inverse_mask_dict[mask_idx.item()]
    layer_idx = int(layer.split('.')[2])
    head_idx = int(layer.split('.')[-1])
    masked_heads.append([layer_idx, head_idx])

In [19]:
masked_heads

[[21, 3], [21, 25], [23, 20], [23, 24], [24, 5], [24, 7], [28, 17]]

: 

In [16]:
del output
torch.cuda.empty_cache()

In [41]:
rounded = torch.round(mask)
(rounded == 0).nonzero().shape[0]

567

## Evaluating Learned Mask

In [62]:
# rounded = [torch.round(mask) for mask in masks.values()]
# (rounded[0] == 0).nonzero().shape

In [10]:
modules = [f"model.layers.{i}.self_attn.o_proj" for i in range(model.config.num_hidden_layers)]

In [44]:
source_activations_valid = {}

for di, desid_train in enumerate(desiderata_train):
    source_activations_valid[di] = {}

    for bi, inputs in enumerate(desid_train):
        for k, v in inputs.items():
            if v is not None and isinstance(v, torch.Tensor):
                inputs[k] = v.to(DEVICE)

        source_activations_valid[di][bi] = {}
        with torch.no_grad():
            with TraceDict(model, modules, retain_input=True) as trace:
                _ = model(inputs["base_input_ids"])

                for module in modules:
                    if "self_attn" in module:
                        source_activations_valid[di][bi][module] = trace[module].input.detach().cpu()
                    else:
                        source_activations_valid[di][bi][module] = trace[module].output.detach().cpu()

        del trace
        torch.cuda.empty_cache()
        for k, v in inputs.items():
            if v is not None and isinstance(v, torch.Tensor):
                inputs[k] = v.to("cpu")

In [46]:
with torch.no_grad():
    for round_mask in [torch.ones(model.config.num_hidden_layers * model.config.num_attention_heads)]:
        count, total = 0, 0
        for di, desid_valid in enumerate(desiderata_train):
            accuracy = []
            for bi, inputs in enumerate(desid_train):
                with TraceDict(
                    model,
                    modules,
                    edit_output=partial(
                        dummy_edit,
                        mask=round_mask,
                        from_activations=source_activations_valid[di][bi],
                        to_last_token_pos=inputs["source_input_last_pos"],
                        from_last_token_pos=inputs["base_input_last_pos"],
                    ),
                ) as _:
                    outputs = model(inputs['source_input_ids'].to(DEVICE))

                for i in range(inputs['source_input_ids'].size(0)):
                    logits = outputs.logits[i, inputs['source_input_last_pos'][i]]
                    pred = torch.argmax(logits, dim=-1)
                    label = inputs['labels'][i]

                    if pred == label:
                        count += 1

                total += inputs['source_input_ids'].size(0)
            
        print(f'Accuracy: {count / total}')

Accuracy: 0.0


: 

In [116]:
prompt = "Box 0 contains jacket, Box 1 contains nothing, Box 2 contains lantern. Box 2 contains"
tokens = tokenizer(prompt, return_tensors='pt').input_ids.to(DEVICE)
output = model(tokens)
pred = torch.argmax(output.logits[0, -1], dim=-1)
print(tokenizer.decode(pred))

 nothing


In [38]:
tokenizer.decode(inputs['source_input_ids'][0])

'  Box 0 contains cash, Box 1 contains contrabass, Box 2 contains nametag. Box 0 contains</s></s></s>'