In [1]:
# Unofficial Inplementation of In-Context-Vector with transformer_lens Library
import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, FactoredMatrix
import torch
import torch.nn as nn
import einops
from fancy_einsum import einsum
import tqdm.auto as tqdm
import plotly.express as px
import torch
import torch.nn as nn
import torch.nn.functional as F
from jaxtyping import Float
from functools import partial
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from typing import List, Tuple
import os

In [1]:
import gpustat
gpustat.print_gpustat()

SH-IDC1-10-140-1-42       Mon Feb 26 15:14:45 2024  525.60.13
[0] NVIDIA A100-SXM4-80GB | 59°C,  99 % | 74975 / 81920 MB | limo(74972M)
[1] NVIDIA A100-SXM4-80GB | 44°C,  79 % | 20027 / 81920 MB | limo(20024M)
[2] NVIDIA A100-SXM4-80GB | 51°C,  99 % | 74975 / 81920 MB | limo(74972M)
[3] NVIDIA A100-SXM4-80GB | 57°C,  45 % | 35831 / 81920 MB | limo(35828M)
[4] NVIDIA A100-SXM4-80GB | 39°C,   0 % |     0 / 81920 MB |
[5] NVIDIA A100-SXM4-80GB | 41°C,  46 % | 39435 / 81920 MB | limo(39432M)
[6] NVIDIA A100-SXM4-80GB | 41°C,  47 % | 39533 / 81920 MB | limo(39530M)
[7] NVIDIA A100-SXM4-80GB | 51°C,  49 % | 35995 / 81920 MB | limo(35992M)


In [4]:
model_path = os.path.join(os.environ['my_models_dir'], 'llama-7b')
llama_7b_model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16)
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = HookedTransformer.from_pretrained_no_processing("llama-7b-hf", hf_model=llama_7b_model, tokenizer=tokenizer, dtype='float16')
# model = HookedTransformer.from_pretrained_no_processing("gpt2-xl", dtype='float16')
dataset = load_dataset("s-nlp/paradetox")
display(dataset)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

  return self.fget.__get__(instance, owner)()


Loaded pretrained model llama-7b-hf into HookedTransformer


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


DatasetDict({
    train: Dataset({
        features: ['en_toxic_comment', 'en_neutral_comment'],
        num_rows: 19744
    })
})

In [7]:
act_name = 'resid_pre'
def svd_flip(u, v):
    # columns of u, rows of v
    max_abs_cols = torch.argmax(torch.abs(u), 0)
    i = torch.arange(u.shape[1]).to(u.device)
    signs = torch.sign(u[max_abs_cols, i])
    u *= signs
    v *= signs.view(-1, 1)
    return u, v

class PCA(nn.Module):
    def __init__(self, n_components):
        super().__init__()
        self.n_components = n_components

    @torch.no_grad()
    def fit(self, X):
        n, d = X.size()
        if self.n_components is not None:
            d = min(self.n_components, d)
        self.register_buffer("mean_", X.mean(0, keepdim=True))
        Z = X - self.mean_ # center
        U, S, Vh = torch.linalg.svd(Z, full_matrices=False)
        Vt = Vh
        U, Vt = svd_flip(U, Vt)
        self.register_buffer("components_", Vt[:d])
        return self

    def forward(self, X):
        return self.transform(X)

    def transform(self, X):
        assert hasattr(self, "components_"), "PCA must be fit before use."
        return torch.matmul(X - self.mean_, self.components_.t())

    def fit_transform(self, X):
        self.fit(X)
        return self.transform(X)

    def inverse_transform(self, Y):
        assert hasattr(self, "components_"), "PCA must be fit before use."
        return torch.matmul(Y, self.components_) + self.mean_

def get_in_context_vectors(model:HookedTransformer, positive_sentences, negative_sentences):
    pos_tokens = model.to_tokens(positive_sentences)
    neg_tokens = model.to_tokens(negative_sentences)
    names_filter = lambda x : x.startswith("blocks.") and x.endswith(act_name)
    pos_logits, pos_cache = model.run_with_cache(pos_tokens, names_filter=names_filter)
    neg_logits, neg_cache = model.run_with_cache(neg_tokens, names_filter=names_filter)
    pos_vectors = einops.rearrange([pos_cache[utils.get_act_name(act_name, l)][:,-1,:] for l in range(model.cfg.n_layers)], 'l b d -> b (l d)')
    neg_vectors = einops.rearrange([neg_cache[utils.get_act_name(act_name, l)][:,-1,:] for l in range(model.cfg.n_layers)], 'l b d -> b (l d)')
    fit_data = pos_vectors - neg_vectors
    pca = PCA(n_components=1).to(fit_data.device).fit(fit_data.float())
    direction = (pca.components_.sum(dim=0,keepdim=True) + pca.mean_).mean(0)
    icv = direction.view(model.cfg.n_layers, -1)
    return icv

def display_icv(model:HookedTransformer, icv, input, lamb=0.1):
    model.reset_hooks()
    original_output = model.generate(input, max_new_tokens=20, temperature=0, top_k=1, top_p=1, do_sample=False)
    def residual_stream_edit_hook(
        resid_pre: Float[torch.Tensor, "batch pos d_model"],
        hook: HookPoint,
        layer: int
    ) -> Float[torch.Tensor, "batch pos d_model"]:
        original_norm = torch.norm(resid_pre, dim=-1, keepdim=True)
        resid_pre += einops.repeat(icv[layer], 'd_model -> batch pos d_model', batch=resid_pre.shape[0], pos=resid_pre.shape[1]) * lamb
        new_norm = torch.norm(resid_pre, dim=-1, keepdim=True)
        resid_pre = resid_pre / new_norm * original_norm
        return resid_pre
    for l in range(model.cfg.n_layers):
        model.blocks[l].hook_resid_pre.add_hook(partial(residual_stream_edit_hook, layer=l))
    modified_output = model.generate(input, max_new_tokens=20, temperature=0, top_k=1, top_p=1, do_sample=False)
    print(f"Original: \n\n{original_output}\nModified: \n\n{modified_output}")

In [8]:
num_shots = 5
positive_sentences = dataset['train']['en_neutral_comment'][:num_shots]
negative_sentences = dataset['train']['en_toxic_comment'][:num_shots]
icv = get_in_context_vectors(model, positive_sentences, negative_sentences)

In [9]:
input_template = "Instruction: Please paraphrase the following sentence.\nSentence:{}\nParaphrase:"
test_sentence = dataset['train']['en_toxic_comment'][-1]
input=input_template.format(test_sentence)
print('input: ', input)
display_icv(model, icv, input)

input:  Instruction: Please paraphrase the following sentence.
Sentence:if anyone deserved it , it was this shit bag .
Paraphrase:


  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

Original: 

Instruction: Please paraphrase the following sentence.
Sentence:if anyone deserved it , it was this shit bag .
Paraphrase: If anyone deserved it, it was this shit bag.
I'm not sure if
Modified: 

Instruction: Please paraphrase the following sentence.
Sentence:if anyone deserved it , it was this shit bag .
Paraphrase:textttextttextttextttextttextttextttextttextttextttextttextttextttextttextttextttextttextttextttextt
