In [3]:
import torch
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam
from transformers import AutoTokenizer, AutoModelForCausalLM
from tqdm import tqdm
import json
import shutil
import os
from datetime import datetime
from glob import glob
import torch.nn.functional as F

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
token = "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"

class BlockOutputWrapper(torch.nn.Module):
    def __init__(self, block):
        super().__init__()
        self.block = block
        self.last_hidden_state = None
        self.add_activations = None
        self.output_init = None

    def forward(self, *args, **kwargs):
        output = self.block(*args, **kwargs)
        self.last_hidden_state = output[0]
        self.output_before_adding = output
        if self.add_activations is not None:
            output = (output[0] + self.add_activations,) + output[1:]
        self.output_after_adding = output
        return output

    def add(self, activations):
        self.add_activations = activations

    def reset(self):
        self.last_hidden_state = None
        self.add_activations = None

    
class Llama27BHelper:
    def __init__(self, pretrained_model="meta-llama/Llama-2-7b-hf"):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model, device_map="auto", use_auth_token=token, torch_dtype=torch.bfloat16)
        self.model = AutoModelForCausalLM.from_pretrained(pretrained_model, device_map="auto", use_auth_token=token, torch_dtype=torch.bfloat16)#.to(self.device)
        for i, layer in enumerate(self.model.model.layers):
            self.model.model.layers[i] = BlockOutputWrapper(layer)

    def generate_text(self, prompt, do_sample=False, temperature=1., max_length=100):
        inputs = self.tokenizer(prompt, return_tensors="pt")
        generate_ids = self.model.generate(inputs.input_ids.to(self.device), do_sample=do_sample, temperature=temperature,max_length=max_length)
        return self.tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    
    def get_logits(self, tokens):
        with torch.no_grad():
            logits = self.model(tokens.to(self.device)).logits
            return logits
    
    def get_last_activations(self, layer):
        return self.model.model.layers[layer].last_hidden_state

    def set_add_activations(self, layer, activations):
        self.model.model.layers[layer].add(activations)

    def reset_all(self):
        for layer in self.model.model.layers:
            layer.reset()

In [None]:
model_name = "meta-llama/Llama-2-7b-hf"

tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=token)
tokenizer.pad_token = tokenizer.eos_token

model = Llama27BHelper(model_name)

Loading checkpoint shards: 100%|██████████| 2/2 [01:38<00:00, 49.31s/it]


In [4]:
tokenizer.tokenize("elephant"), tokenizer.tokenize("crocodile"), tokenizer.tokenize("rhinoceros")

(['▁ele', 'ph', 'ant'], ['▁cro', 'cod', 'ile'], ['▁r', 'hin', 'oc', 'eros'])

In [5]:
acts_size = 4096 if model_name == "meta-llama/Llama-2-7b-hf" else 5120

neg_inputs = ["horn", "the horn", "a horn"]
neg_acts = torch.zeros((1, 1, acts_size))
for seq in neg_inputs:
    model.reset_all()
    model.get_logits(tokenizer.encode(seq, return_tensors="pt"))
    # get the activations of the last token because that carries most of the relevant context
    neg_acts += model.get_last_activations(28)[0, -1, :].detach().cpu()
    
neg_acts = neg_acts / len(neg_inputs)
neg_acts = neg_acts.to(torch.half)

In [6]:
model.reset_all()
# t = tokenizer.encode("A cow", return_tensors="pt")
# print("Tokens", t)
# model.get_logits(t)
# inputs = ["Cow", "the cow", "cow", "A cow", "a cow"]
# inputs = ["Elephant", "the elephant", "elephant", "an elephant"]
inputs = ["crocodile", "the crocodile", "a crocodile", "a crocodile"]
# inputs = ["rhinoceros", "the rhinoceros", "a rhinoceros", "a rhinoceros"]

# inputs = ["cow"]
multipliers = [0, 0.5, 1, 5, 10, 15]
layer = 28

acts = torch.zeros((1, 1, acts_size))
for seq in inputs:
    model.reset_all()
    model.get_logits(tokenizer.encode(seq, return_tensors="pt"))
    # get the activations of the last token because that carries most of the relevant context
    acts += model.get_last_activations(layer)[0, -1, :].detach()

acts = acts / len(inputs)
acts = acts.to(torch.half)

for m in multipliers:
    print(f"\n-----{m}-----")
    model.reset_all()
#     model.set_add_activations(28, m*(acts-neg_acts).to("cuda:1"))
    model.set_add_activations(layer, m*(acts-neg_acts))
    
    for _ in range(5):
        out = model.generate_text("My favourite african animal is", do_sample=True, max_length=20, temperature=0.2)
        print(out[:30] + ":" + out[30:])


-----0-----


RuntimeError: expected scalar type Float but found BFloat16

In [46]:
(980 + 0.15*700 + 2500 / 12 * 2)

1501.6666666666667

In [49]:
(500 + 0.15*700 + 2500 / 15 * 2)

938.3333333333333

In [50]:
2450 / 14

175.0

In [41]:
2500 / 15 * 2

333.3333333333333