In [17]:
import os

os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
os.environ['HF_HOME'] = '/root/autodl-tmp/huggingface_cache'
# from intervented_model.llama import Intervented_LlamaForCausalLM
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import argparse
import intervented_model.llama as llama
from datasets import load_dataset
from torch.utils.data import DataLoader, Subset
import re
from transformers import LlamaTokenizer, LlamaForCausalLM, AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer
import json
from tqdm import tqdm

class ValueFunction(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(ValueFunction, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        #self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, output_dim)
    
    def forward(self, x):
        x = F.relu(self.fc1(x))
        #x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

class DataCollatorReward:
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def __call__(self, data):
        batch = {}
        data_batch = []
        for sample in data:
            data_batch.append({"input_ids": sample['input_ids'], "attention_mask": sample["attention_mask"]})
        batch_data = self.tokenizer.pad(data_batch, padding=True, return_tensors="pt")
        batch['input_ids'] = batch_data['input_ids']
        batch['attention_mask'] = batch_data['attention_mask']
        return batch  

In [18]:
import torch
from transformers import LlamaModel, LlamaConfig, LlamaForCausalLM
import torch.nn as nn
from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
)
from typing import List, Optional, Tuple, Union
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import (
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
)
from torch.nn import CrossEntropyLoss
import torch.utils.checkpoint
from transformers.cache_utils import Cache, DynamicCache, StaticCache
# from transformers.src.transformers.models.llama.modeling_llama import LlamaPreTrainedModel



logger = logging.get_logger(__name__)



class Intervented_LlamaModel(LlamaModel):
    def __init__(self, config: LlamaConfig):
        super().__init__(config)
    
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        value_model: Optional[nn.Module] = None,
        lr: Optional[float] = None,
        epochs: Optional[int] = None
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
            )

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        past_seen_tokens = 0
        if use_cache:  # kept for BC (cache positions)
            if not isinstance(past_key_values, StaticCache):
                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
                past_seen_tokens = past_key_values.get_seq_length()

        if cache_position is None:
            if isinstance(past_key_values, StaticCache):
                raise ValueError("cache_position is a required argument when using StaticCache.")
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)
        print("DEBUG past_key_values type:", type(past_key_values))
        # causal_mask = self._update_causal_mask(
        #     attention_mask, inputs_embeds, cache_position, past_seen_tokens + inputs_embeds.shape[1]
        # )
        print('hi i define causal mask')
        causal_mask = self._update_causal_mask(
            attention_mask,
            inputs_embeds,
            cache_position,
            past_key_values,
        )
        # embed positions
        hidden_states = inputs_embeds

        # decoder layers
        print('hi here decoder layers')
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None
        cos, sin = self.rotary_emb(hidden_states, position_ids)
        position_embeddings = (cos, sin)
        print('start decoder layer')
        for decoder_layer in self.layers:
            print('decoder layer, total ', len(self.layers))
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    causal_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                    cache_position,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                    position_embeddings=position_embeddings
                )

            hidden_states = layer_outputs[0]
            
            if use_cache:
                if isinstance(layer_outputs[-1], (tuple, list, Cache)):
                    next_decoder_cache = layer_outputs[-1]
                else:
                    next_decoder_cache = None
            else:
                next_decoder_cache = None
            print("DEBUG layer_outputs length:", len(layer_outputs))

            # if use_cache:
            #     next_decoder_cache = layer_outputs[2 if output_attentions else 1]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)
        
        #### use the value model to optimize the hidden states

        hidden_states_opt = hidden_states.detach().clone() # also need to set the hidden states to be requires_grad=True
        hidden_states_opt.requires_grad = True

        optimizer = torch.optim.SGD([hidden_states_opt], lr=lr)  # Using Adam optimizer to update x

        with torch.enable_grad():
            for iteration in range(epochs):  # Run for a number of iterations
                print('hi iter', iteration)
                optimizer.zero_grad()   # Clear previous gradients
                output = value_model(hidden_states_opt)         # Compute the function value at current x
                loss = -output.sum()          # Negate the output to convert maximization problem to minimization
                loss.backward()         # Compute gradients with respect to x
                optimizer.step()        # Update x
            
        hidden_states = hidden_states_opt.detach().clone()  # Update the hidden states with the optimized x

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = None
        if use_cache:
            next_cache = (
                next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache
            )
        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

class Intervented_LlamaForCausalLM(LlamaForCausalLM):
    # def __init__(self, config, value_model: nn.Module):
    def __init__(self, config):
        super().__init__(config)
        self.model = Intervented_LlamaModel(config)
        self.value_model = None

    def set_value_model(self, value_model):
        self.value_model = value_model.to(torch.float16)
    
    def set_lr_and_epochs(self, lr, epochs):
        self.lr = lr
        self.epochs = epochs
    
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            value_model=self.value_model,
            epochs = self.epochs,
            lr = self.lr
        )

        hidden_states = outputs[0]
        if self.config.pretraining_tp > 1:
            lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
            logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
            logits = torch.cat(logits, dim=-1)
        else:
            logits = self.lm_head(hidden_states)
        logits = logits.float()

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


In [19]:
parser = argparse.ArgumentParser()
parser.add_argument('--model_name', type=str, default='vicuna_7B', choices=["vicuna_7B", "falcon_7B", "llama3_8B"])
parser.add_argument('--dataset_name', type=str, default='shp', choices=["hh_rlhf", "shp"])
parser.add_argument('--device', type=int, default=0)
parser.add_argument('--use_intervention', default=True)
parser.add_argument('--value_lr', default=0.0001)
parser.add_argument('--epochs', type=int, default=100)
parser.add_argument('--lr', type=float, default=0.001)
parser.add_argument('--bellman', default=False)
args = parser.parse_args([])

if args.use_intervention:

    value_model = ValueFunction(input_dim=4096, hidden_dim=4096, output_dim=1)
    ##load weights
    value_model.load_state_dict(torch.load(f'value_model_checkpoint/value_model_{args.model_name}_{args.dataset_name}.pth', 
                                           map_location=torch.device('cpu')))


MODEL_NAMES = { 
    'vicuna_7B': 'lmsys/vicuna-7b-v1.5', 
    'falcon_7B': 'tiiuae/falcon-7b-instruct',
    'llama3_8B': 'meta-llama/Meta-Llama-3-8B'
}

DATASET_NAMES = { 
    'hh_rlhf': 'Anthropic/hh-rlhf', 
    'shp': 'stanfordnlp/SHP'
}
MODEL = MODEL_NAMES[args.model_name]
DATASET = DATASET_NAMES[args.dataset_name]
tokenizer = LlamaTokenizer.from_pretrained(MODEL, padding_side='left')

device = f"cuda:{args.device}" if torch.cuda.is_available() else "cpu"

In [20]:
if args.use_intervention:
    model = Intervented_LlamaForCausalLM.from_pretrained(MODEL, low_cpu_mem_usage=True, torch_dtype=torch.float16, device_map=device, cache_dir='/root/autodl-tmp/huggingface_cache')
    model.set_value_model(value_model)
    model.set_lr_and_epochs(args.lr, args.epochs)
else:
    model = LlamaForCausalLM.from_pretrained(MODEL, low_cpu_mem_usage=True, torch_dtype=torch.float16, device_map=device, cache_dir='/root/autodl-tmp/huggingface_cache')

model.generation_config.temperature = None
model.generation_config.top_k = None
model.to(device)

dataset = load_dataset(DATASET)

if args.dataset_name == 'hh_rlhf':
    dataset = dataset.remove_columns("rejected")
    for split in dataset.keys():
        dataset[split] = dataset[split].rename_column('chosen', 'prompt')
    test_dataset = dataset['test']
elif args.dataset_name == 'shp':
    test_file_path = 'dataset/test_dataset_shp.json'
    test_dataset = load_dataset('json', data_files=test_file_path)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/18 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/18 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/18 [00:00<?, ?it/s]

In [21]:
if args.model_name == 'vicuna_7B':
    begin_word = 'Human: '
elif args.model_name == 'llama3_8B':
    begin_word = 'User: '
elif args.model_name == 'falcon_7B':
    begin_word = 'User: '
prompting = '''
    A question from a curious user and an answer from an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the userâ€™s questions.\n'''

def preprocessing(example):
    if args.dataset_name == 'hh_rlhf':
        replaced_text = example['prompt'].replace("Human:", begin_word)
        parts = replaced_text.rsplit("Assistant:", 1)  # Split the string at the last occurrence of "Assistant:"
        result = parts[0] + "Assistant:"  # Append "Assistant:" back to the first part if needed
    elif args.dataset_name == 'shp':
        text = example['history']
        result = begin_word + text + "\nAssistant:"    


    return {'prompt': result}


# dataset = dataset.map(preprocessing)
test_dataset = test_dataset.map(preprocessing)
dataloader = DataLoader(test_dataset['train'], batch_size=1) # only use the test set for now
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.eos_token_id
    model.config.pad_token = tokenizer.eos_token
    model.config.pad_token_id = tokenizer.eos_token_id    

In [None]:
generated_responses = []
for batch_prompts in tqdm(dataloader):
    # print(batch_prompts)
    encoded_inputs = tokenizer(batch_prompts['prompt'], return_tensors="pt", padding=True) # tokenize the prompt
    input_ids = encoded_inputs['input_ids'].to(device)
    attention_mask = encoded_inputs['attention_mask'].to(device)
    outputs = model.generate(input_ids, attention_mask=attention_mask, max_new_tokens=128, num_return_sequences=1, return_dict_in_generate=True)
    outputs_text = tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True)
    for prompt, output in zip(batch_prompts['prompt'], outputs_text):
        generated_responses.append({'prompt': prompt, 'result': output.removeprefix(prompt), 'response': output})
    break

  0%|          | 0/1000 [00:00<?, ?it/s]

DEBUG past_key_values type: <class 'transformers.cache_utils.DynamicCache'>
hi i define causal mask
hi here decoder layers
start decoder layer
decoder layer, total  32
DEBUG layer_outputs length: 1
decoder layer, total  32
DEBUG layer_outputs length: 1
decoder layer, total  32
DEBUG layer_outputs length: 1
decoder layer, total  32
DEBUG layer_outputs length: 1
decoder layer, total  32
DEBUG layer_outputs length: 1
decoder layer, total  32
DEBUG layer_outputs length: 1
decoder layer, total  32
DEBUG layer_outputs length: 1
decoder layer, total  32
DEBUG layer_outputs length: 1
decoder layer, total  32
DEBUG layer_outputs length: 1
decoder layer, total  32
DEBUG layer_outputs length: 1
decoder layer, total  32
DEBUG layer_outputs length: 1
decoder layer, total  32
DEBUG layer_outputs length: 1
decoder layer, total  32
DEBUG layer_outputs length: 1
decoder layer, total  32
DEBUG layer_outputs length: 1
decoder layer, total  32
DEBUG layer_outputs length: 1
decoder layer, total  32
DEBUG l