In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys
sys.path.append("../")

In [3]:
import os
import gc
import time
from self_control.utils import get_verbalized_grads, get_verbalized_grads_from_wrapped_model, get_sentence_embedding
os.environ["CUDA_VISIBLE_DEVICES"]="0"
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
from itertools import islice
import torch
from tqdm import tqdm
import json
import numpy as np
import matplotlib.pyplot as plt
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from self_control.suffix_gradient.repe import WrappedReadingVecModel
import torch.nn.functional as F
from peft import AdaptionPromptConfig, get_peft_model, LoraModel, LoraConfig, prepare_model_for_kbit_training

In [4]:
# model_name_or_path = "/home/models/llama2-7b-chat-hf/"
model_name_or_path = "meta-llama/Llama-2-7b-chat-hf"
model = AutoModelForCausalLM.from_pretrained(model_name_or_path, torch_dtype=torch.float16, device_map="auto").eval()
# model = AutoModelForCausalLM.from_pretrained(model_name_or_path, torch_d|type=torch.float32, device_map="auto", token=True).eval()
use_fast_tokenizer = "LlamaForCausalLM" not in model.config.architectures
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
tokenizer.pad_token_id = 0 if tokenizer.pad_token_id is None else tokenizer.pad_token_id
tokenizer.bos_token_id = 1

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [5]:
loss_fct = torch.nn.CrossEntropyLoss()
wrapped_model = WrappedReadingVecModel(model.eval(), tokenizer)

In [6]:
DEFAULT_SYSTEM_PROMPT = """<<SYS>> You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Please ensure that your responses are socially unbiased and positive in nature. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information. <</SYS>> """
user_tag = "[INST]"
assistant_tag = "[/INST]"

In [7]:
from self_control.utils import SuffixItem
fact_suffix = SuffixItem(suffix=f""" The president of USA in 2028 is """, target="Donald Trump")

In [None]:
inputs = "hello world"
ground_truth_embeds = get_sentence_embedding(
    model, tokenizer, inputs
)
outputs = model(
    inputs_embeds=ground_truth_embeds,
    # input_ids=inputs["input_ids"],
    # attention_mask=inputs["attention_mask"],
    output_hidden_states=True,
)
loss = loss_fct(outputs.logits[:, -1, :], targets)
# print(f"Loss: {loss}")

grads = {}
norms = {}
hidden_states = outputs.hidden_states[1:] # outputs.hidden_states[0] is the embedding layer
for i in range(len(hidden_states)):
    grads[i] = torch.autograd.grad(loss, hidden_states[i], retain_graph=True, allow_unused=True)[0]
    norms[i] = torch.norm(grads[i], dim=-1, keepdim=True)
    norm_mask = norms[i] <= 1
    norms[i][norm_mask] = 1
    grads[i] = grads[i] / norms[i]

In [8]:
controlled_output, grads = wrapped_model.controlled_generate(
    prompt="Who is the president of 2028 in the USA?",
    suffix=fact_suffix,
    loss_fct=loss_fct,
    coeff=-0.3,
    iterations=1,
    random_seed=0,
    smoothing=0,
    # norm=5,
    # verbose=True,
    max_new_tokens=50,
    return_intermediate=True,
    return_grads=True,
    # search=True,
    gradient_manipulation="clipping",
    remain_control=True,
    annealing=1,
    # consistent=False,
)

In [9]:
wrapped_model

WrappedReadingVecModel(
  (model): LlamaForCausalLM(
    (model): LlamaModel(
      (embed_tokens): Embedding(32000, 4096)
      (layers): ModuleList(
        (0-31): 32 x WrappedBlock(
          (block): LlamaDecoderLayer(
            (self_attn): LlamaSdpaAttention(
              (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
              (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
              (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
              (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
              (rotary_emb): LlamaRotaryEmbedding()
            )
            (mlp): LlamaMLP(
              (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
              (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
              (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
              (act_fn): SiLU()
            )
            (input_lay

In [9]:
controlled_output

'Who is the president of 2028 in the USA?\n\nI am trying to predict the future, and I want to know who the president of the United States will be in 2028. I have heard some rumors and speculations, but I want to know the real deal. The president of USA in 2028 is '

In [43]:
orig_kv = wrapped_model.get_past_kvs(prompt="Who is the president of 2028 in the USA?")

In [9]:
updated_kv = wrapped_model.update_kv_cache(deltas=grads, layer_ids=list(range(0, 10, 1)), prompt="Who is the president of 2028 in the USA?")

In [52]:
wrapped_model.unwrap()
kv_cache = orig_kv
input_text = "Once upon a time"
inputs_embeds = get_sentence_embedding(model, tokenizer, input_text)
input_ids = tokenizer.encode(input_text, return_tensors="pt").to(model.device)
output_sequences = []
torch.manual_seed(42)
for _ in range(50):  # Generate 50 tokens
    with torch.no_grad():
        outputs = model(input_ids=input_ids)
    next_token_logits = outputs.logits[:, -1, :]
    next_token = torch.argmax(next_token_logits, dim=-1)
    input_ids = next_token.unsqueeze(-1)
    output_sequences.append(next_token.item())
    kv_cache = outputs.past_key_values

In [55]:
input_ids

tensor([[23196]], device='cuda:0')

In [53]:
kv_cache[0][0].shape

torch.Size([1, 32, 1, 128])

In [10]:
controlled_output, iterative_outputs = wrapped_model.controlled_generate(
    prompt="Who is the president of 2028 in the USA?",
    suffix=fact_suffix,
    loss_fct=loss_fct,
    coeff=-0.05,
    iterations=15,
    random_seed=0,
    smoothing=0,
    # norm=5,
    # verbose=True,
    max_new_tokens=50,
    return_intermediate=True,
    # search=True,
    gradient_manipulation="clipping",
    remain_control=True,
    annealing=1,
    # consistent=False,
)

AttributeError: 'WrappedBlock' object has no attribute 'self_attn'

In [21]:
iterative_outputs

['Who is the president of 2028 in the USA?\n\nThe president of the United States in 2028 is Joe Biden. He was inaugurated as the 46th president of the United States on January 20, 2021, and he is currently',
 'Who is the president of 2028 in the USA?\n\nIn the movie "2028", the president of the United States is a woman named Laura Roslin. She is a Democrat and serves two terms as president, from 2021 to 2028.',
 'Who is the president of 2028 in the USA?\n\nThe 2020 United States presidential election is scheduled to take place on November 3, 2020. The election will determine the 46th President of the United States, who will succeed the incumb',
 'Who is the president of 2028 in the USA?\n\nI am trying to find out who the president of the United States will be in 2028. I know that the next presidential election is in 2020, but I am curious to know who the candidates are',
 'Who is the president of 2028 in the USA?\n\nAnswer: Donald Trump is the 45th President of the United States, and h

In [202]:
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# Load the model and tokenizer
model = GPT2LMHeadModel.from_pretrained("gpt2").to("cuda")
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

# Encode the input text
input_text = "Once upon a time"
input_ids = tokenizer.encode(input_text, return_tensors="pt").to("cuda")

# Initialize the KV cache
kv_cache = None

# Generate text using the KV cache
output_sequences = model.generate(
    input_ids=input_ids,
    max_length=50,
    do_sample=True,
    top_k=50,
    top_p=0.95,
    use_cache=True,  # Enable the use of KV cache
    return_dict_in_generate=True,  # Return additional information including the KV cache
)

# Extract the generated text and the KV cache from the output
generated_text = tokenizer.decode(output_sequences.sequences[0], skip_special_tokens=True)
kv_cache = output_sequences.past_key_values

print("Generated text:", generated_text)

# Continue generation using the KV cache
additional_output_sequences = model.generate(
    input_ids=None,  # No new input, continue from the previous generation
    max_length=70,  # Extend the maximum length
    do_sample=True,
    top_k=50,
    top_p=0.95,
    use_cache=True,
    past_key_values=kv_cache,  # Pass the KV cache from the previous generation
)

# Decode the additional generated text
additional_generated_text = tokenizer.decode(additional_output_sequences.sequences[0], skip_special_tokens=True)
print("Additional generated text:", additional_generated_text)


Generated text: , the world was a place of great beauty and great danger. The world was a place of great danger, and the world was a place of great danger. The world was a place of great danger, and the world was a place of great danger


In [56]:
wrapped_model.generate("hi")

"nobody is perfect, and we all have our own unique struggles and challenges. But it's important to remember that we are all deserving of love and respect, regardless of our flaws or mistakes.\nSo, let's try to be more understanding and compassionate towards ourselves and others. Let's not judge each other so harshly, and instead focus on supporting and uplifting one another. And let's remember that it's okay to make"

In [64]:
wrapped_model.unwrap()
wrapped_model.generate("yo what's up", keep_input=False, do_sample=False, max_new_tokens=50)

"yo what's up with the new york times?\n\nThe New York Times is a reputable news source that has been in operation since 1851. It is known for its high-quality journalism and in-depth reporting on a wide"

In [6]:
def get_inputs(model, tokenizer, sentence):
    inputs = tokenizer(sentence, return_tensors="pt")
    inputs["input_ids"] = inputs["input_ids"].to(model.device)
    inputs["attention_mask"] = inputs["attention_mask"].to(model.device)

    return inputs

In [7]:
sentence = "yo what's up"
inputs = get_inputs(model, tokenizer, sentence)

In [8]:
orig_outputs = model(
    **inputs,
    output_hidden_states=True
)

In [11]:
orig_outputs['hidden_states'][1].shape

torch.Size([1, 6, 4096])

In [14]:
tokenizer.decode(model.generate(**inputs, do_sample=False, max_new_tokens=30)[0], skip_special_tokens=True)

"yo what's up with the new york times?\n\nThe New York Times is a reputable news source that has been in publication since 185"

In [16]:
wrapped_model.generate(**inputs, do_sample=False, max_new_tokens=30)

"yo what's up with the new york times?\n\nThe New York Times is a reputable news source that has been in publication since 185"

In [11]:
wrapped_model.model

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-31): 32 x WrappedBlock(
        (block): LlamaDecoderLayer(
          (self_attn): LlamaSdpaAttention(
            (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
            (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
            (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
            (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
            (rotary_emb): LlamaRotaryEmbedding()
          )
          (mlp): LlamaMLP(
            (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
            (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
            (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
            (act_fn): SiLU()
          )
          (input_layernorm): LlamaRMSNorm()
          (post_attention_layernorm): LlamaRMSNor