In [None]:
!pip install accelerate transformers -q
!pip install torch -q

In [None]:
!pip install peft -q
!pip install logger -q

In [None]:
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from contrast_vector_control import *
from transformers.models.llama.modeling_llama import LlamaModel, LlamaForCausalLM
from transformers.cache_utils import StaticCache, DynamicCache, Cache
from transformers.models.mistral.modeling_mistral import MistralModel, MistralForCausalLM
from transformers.generation.stopping_criteria import StoppingCriteria, StoppingCriteriaList, validate_stopping_criteria
from transformers.generation.logits_process import LogitsProcessorList
from transformers.generation import GreedySearchDecoderOnlyOutput
from transformers import AutoTokenizer, AutoConfig, pipeline, AutoModelForCausalLM
from transformers.generation.streamers import BaseStreamer
import torch
import torch.nn.functional as F
from torch import nn
from typing import List, Optional, Tuple, Union
from transformers.modeling_attn_mask_utils import (
    AttentionMaskConverter,
    _prepare_4d_attention_mask,
    _prepare_4d_causal_attention_mask,
    _prepare_4d_causal_attention_mask_for_sdpa,
)
from functools import partial
import warnings
import logger
import numpy as np
import torch.distributed as dist

In [None]:
# Load the contrast-vector Llama-2 chat model and its tokenizer.
import os

model_name_or_path = "meta-llama/Llama-2-7b-chat-hf"

# SECURITY: never hard-code access tokens in a notebook. The token is read from
# the HF_TOKEN environment variable; when it is unset, `None` is passed and
# the library falls back to locally cached credentials (e.g. from
# `huggingface-cli login`). Any token that was previously committed in this
# notebook should be revoked.
hf_token = os.environ.get("HF_TOKEN")

model = ContrastVecLlamaForCausalLM.from_pretrained(
    model_name_or_path,
    token=hf_token,  # `token` supersedes the deprecated `use_auth_token`
    torch_dtype=torch.bfloat16,
    device_map="sequential",
)
tokenizer = AutoTokenizer.from_pretrained(
    model_name_or_path,
    padding_side="left",  # left-pad so every prompt ends at the same position
    token=hf_token,
)
# Use token id 0 for padding.
# NOTE(review): presumably the tokenizer ships without a pad token and id 0 is
# a safe filler (e.g. <unk>) — confirm against the tokenizer's vocabulary.
tokenizer.pad_token_id = 0

In [None]:
# Generate a baseline completion and a contrast-vector-controlled completion
# for the same prompt.

# --- Reproducibility -------------------------------------------------------
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

# --- Experiment configuration ----------------------------------------------
template = "[INST] {instruction} [/INST]"
pos_tag = "Give a happy response. "
# Trailing space added so the tag does not fuse with the question
# ("...response.I played..."); this now mirrors `pos_tag`.
neg_tag = "Give a sad response. "
layer_ids = list(range(32))  # layer indices 0..31 passed as control_layer_ids
# Single source of truth for the value passed to `generate`. Previously this
# variable was set to -1 but ignored, while a literal -8 was passed instead;
# -8 is kept because it is the value that actually took effect.
# NOTE(review): presumably selects the trailing token positions used to build
# the contrast — confirm in `contrast_vector_control`.
contrast_tokens = -8
alpha = 0.5

question = "I played basketball today."

# --- Prompt construction ---------------------------------------------------
pos_question = pos_tag + question
neg_question = neg_tag + question

question = template.format(instruction=question)
pos_question = template.format(instruction=pos_question)
neg_question = template.format(instruction=neg_question)

# Tokenize the three prompts together so they share one padded length.
# Row 0 = plain prompt, row 1 = positive prompt, row 2 = negative prompt.
input_tokenized = tokenizer(
    [question, pos_question, neg_question],
    return_tensors="pt",
    padding="longest",
).to(model.device)

all_input_ids = input_tokenized["input_ids"]
all_attention_masks = input_tokenized["attention_mask"]

# Each row is re-expanded to a batch of size 1.
base_input_ids = all_input_ids[0].unsqueeze(dim=0)
base_attention_mask = all_attention_masks[0].unsqueeze(dim=0)

# Extra keyword arguments forwarded to `generate` for the controlled run.
repe_kwargs = dict(
    pos_input_ids=all_input_ids[1].unsqueeze(dim=0),
    pos_attention_mask=all_attention_masks[1].unsqueeze(dim=0),
    neg_input_ids=all_input_ids[2].unsqueeze(dim=0),
    neg_attention_mask=all_attention_masks[2].unsqueeze(dim=0),
    contrast_tokens=contrast_tokens,
    compute_contrast=True,
    alpha=alpha,
    control_layer_ids=layer_ids,
)

# --- Generation ------------------------------------------------------------
with torch.no_grad():
    # Baseline: greedy decoding of the plain prompt.
    original_output = model.generate(
        base_input_ids,
        attention_mask=base_attention_mask,
        max_new_tokens=100,
        do_sample=False,
    )

    # Controlled: same prompt and decoding settings plus the contrast kwargs.
    # The KV cache is disabled for this run.
    # NOTE(review): presumably because the contrast is recomputed over the
    # full sequence each step — confirm in `contrast_vector_control`.
    controlled_output = model.generate(
        base_input_ids,
        attention_mask=base_attention_mask,
        max_new_tokens=100,
        use_cache=False,
        do_sample=False,
        **repe_kwargs,
    )

In [None]:
# Decode the first (only) sequence of each generation result to text and
# print both for side-by-side comparison.
# NOTE: this rebinds `original_output` / `controlled_output` from generation
# results to strings, so the generation cell must be re-run before this cell
# is executed a second time.
original_output, controlled_output = (
    tokenizer.decode(generated[0], skip_special_tokens=True)
    for generated in (original_output, controlled_output)
)
print("========== Original Output ==========", original_output, sep="\n")
print("==========Controlled Output ==========", controlled_output, sep="\n")