In [1]:
import os

os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
os.environ['HF_HOME'] = '/hy-tmp/'
# from intervented_model.llama import Intervented_LlamaForCausalLM
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import argparse
import intervented_model.llama as llama
from datasets import load_dataset
from torch.utils.data import DataLoader, Subset
import re
from transformers import LlamaTokenizer, LlamaForCausalLM, AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer
import json
from tqdm import tqdm

class ValueFunction(nn.Module):
    """Two-layer MLP value head: Linear -> ReLU -> Linear.

    The attribute names ``fc1`` and ``fc3`` are the keys of the saved
    state_dict that is loaded with ``load_state_dict``, so they must not be
    renamed.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Map ``x[..., input_dim]`` to ``[..., output_dim]`` values."""
        activated = torch.relu(self.fc1(x))
        return self.fc3(activated)

class DataCollatorReward:
    """Collate tokenised samples into one padded batch.

    Only ``input_ids`` and ``attention_mask`` are kept from each sample; padding
    is delegated to ``tokenizer.pad(..., padding=True, return_tensors="pt")``.
    """

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def __call__(self, data):
        features = [
            {"input_ids": sample['input_ids'], "attention_mask": sample["attention_mask"]}
            for sample in data
        ]
        padded = self.tokenizer.pad(features, padding=True, return_tensors="pt")
        # Return a fresh dict holding just the two tensors (any extras from pad() are dropped).
        return {key: padded[key] for key in ('input_ids', 'attention_mask')}

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import torch
from functools import partial

from transformers import LlamaModel, LlamaConfig, LlamaForCausalLM
import torch.nn as nn
from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
)
from typing import List, Optional, Tuple, Union
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import (
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
)
from torch.nn import CrossEntropyLoss
import torch.utils.checkpoint
from transformers.cache_utils import Cache, DynamicCache, StaticCache
# from transformers.src.transformers.models.llama.modeling_llama import LlamaPreTrainedModel



logger = logging.get_logger(__name__)



class Intervented_LlamaModel(LlamaModel):
    """LlamaModel whose final (post-norm) hidden states can be nudged by gradient
    ascent on a learned value function before being returned.

    The decoder loop mirrors the upstream ``LlamaModel.forward`` and relies on its
    internals (``_update_causal_mask``, ``rotary_emb``, ``_gradient_checkpointing_func``).
    NOTE(review): this ties the class to the transformers version it was copied
    from — re-check after upgrading.
    """

    def __init__(self, config: LlamaConfig):
        super().__init__(config)
        # Intervention is off until set_opt_control_params() is called. Without these
        # defaults forward() raised AttributeError on a freshly constructed model.
        self.value_model = None
        self.lr = None
        self.epochs = None

    def set_opt_control_params(self, value_model, lr, epochs):
        """Configure the value-guided intervention.

        Args:
            value_model: module scoring hidden states (e.g. ``[B, seq, hidden] -> [B, seq, 1]``);
                pass ``None`` to disable the intervention.
            lr: SGD step size applied to the hidden states.
            epochs: number of SGD steps per forward pass.
        """
        if value_model is None:
            self.value_model = None
        else:
            value_model = value_model.to(torch.float16)
            # Only the hidden states are optimised. Freezing the value model stops each
            # backward() from accumulating (never-cleared) .grad tensors on its weights.
            value_model.requires_grad_(False)
            value_model.eval()
            self.value_model = value_model
        self.lr = lr
        self.epochs = epochs

    def _intervene(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return a detached copy of ``hidden_states`` after ``self.epochs`` SGD steps
        that increase the summed output of ``self.value_model``.

        Every position (including any padding) is optimised; the result carries no
        autograd history, so no gradient flows back into the decoder through it.
        """
        # Keep the value head on the activations' device/dtype (it may still be on CPU).
        self.value_model.to(device=hidden_states.device, dtype=hidden_states.dtype)

        hidden_states_opt = hidden_states.detach().clone().requires_grad_(True)
        optimizer = torch.optim.SGD([hidden_states_opt], lr=self.lr)

        # Callers such as generate() may run under torch.no_grad(); re-enable autograd.
        with torch.enable_grad():
            for _ in range(self.epochs):
                optimizer.zero_grad()
                value = self.value_model(hidden_states_opt)  # e.g. [B, seq, 1]
                loss = -value.sum()  # minimise -V  <=>  gradient ascent on V
                loss.backward()
                optimizer.step()

        return hidden_states_opt.detach()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **flash_attn_kwargs,
    ) -> BaseModelOutputWithPast:
        """Run the Llama decoder stack, then (if configured) the value-guided intervention.

        Arguments follow ``LlamaModel.forward``. When the intervention is active,
        ``last_hidden_state`` (and the last entry of ``hidden_states``) is the
        optimised, detached tensor — suitable for inference only.

        Raises:
            ValueError: if both/neither of ``input_ids`` and ``inputs_embeds`` are given,
                or ``past_key_values`` is not a ``Cache``/``None``.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache
        if not isinstance(past_key_values, (type(None), Cache)):
            raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache()

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = self._update_causal_mask(
            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
        )

        hidden_states = inputs_embeds

        # create position embeddings to be shared across the decoder layers
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None

        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    causal_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                    cache_position,
                    position_embeddings,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                    position_embeddings=position_embeddings,
                    **flash_attn_kwargs,
                )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # Intervention on the post-norm states, i.e. the tensor the LM head consumes.
        if self.value_model is not None and self.lr is not None and self.epochs is not None:
            hidden_states = self._intervene(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values if use_cache else None,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

class Intervented_LlamaForCausalLM(LlamaForCausalLM):
    """Causal-LM wrapper that swaps in ``Intervented_LlamaModel`` as its backbone.

    The attribute names (``self.model``, ``self.lm_head``) match the parent class,
    so ``from_pretrained`` loads a stock Llama checkpoint into it unchanged.
    """

    def __init__(self, config):
        super().__init__(config)
        # Replace the stock backbone constructed by the parent with the intervened one.
        self.model = Intervented_LlamaModel(config)

    def set_opt_control_params(self, value_model, lr, epochs):
        """Forward the intervention settings (value model, SGD lr, step count) to the backbone."""
        backbone = self.model
        backbone.set_opt_control_params(value_model, lr, epochs)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs,
    ) -> CausalLMOutputWithPast:
        """Run the backbone, project (a slice of) its final hidden states with
        ``lm_head``, and compute the LM loss when ``labels`` are given.

        ``logits_to_keep``: an int ``k`` keeps the last ``k`` positions (0 keeps all);
        a tensor is used directly as the position index.
        """
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states

        backbone_out: BaseModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            cache_position=cache_position,
            **kwargs,
        )

        # Only project the positions that are actually needed.
        if isinstance(logits_to_keep, int):
            positions = slice(-logits_to_keep, None)  # 0 -> slice(0, None) -> every position
        else:
            positions = logits_to_keep
        logits = self.lm_head(backbone_out.last_hidden_state[:, positions, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(
                logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs
            )

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=backbone_out.past_key_values,
            hidden_states=backbone_out.hidden_states,
            attentions=backbone_out.attentions,
        )


In [3]:
def _str2bool(value):
    """Convert a CLI flag value to bool.

    With a bare ``add_argument(..., default=True)`` a command-line value stays a
    string, so ``--use_intervention False`` would be the truthy string "False".
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in {"1", "true", "t", "yes", "y"}:
        return True
    if normalized in {"0", "false", "f", "no", "n"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


parser = argparse.ArgumentParser()
parser.add_argument('--model_name', type=str, default='vicuna_7B', choices=["vicuna_7B", "falcon_7B", "llama3_8B"])
parser.add_argument('--dataset_name', type=str, default='shp', choices=["hh_rlhf", "shp"])
parser.add_argument('--device', type=int, default=0)
parser.add_argument('--use_intervention', type=_str2bool, default=True)
# type=float: otherwise a CLI-supplied value arrives as a str, which torch.optim.SGD rejects.
parser.add_argument('--value_lr', type=float, default=0.01)
parser.add_argument('--epochs', type=int, default=100)
# Empty argv — presumably because sys.argv inside a Jupyter kernel holds the kernel's own flags.
args = parser.parse_args([])

device = f"cuda:{args.device}" if torch.cuda.is_available() else "cpu"

# Width of the value head's input/hidden layers; must match both the backbone's
# hidden size and the saved value-model checkpoint.
VALUE_MODEL_DIM = 4096

value_model = None
if args.use_intervention:
    value_model = ValueFunction(input_dim=VALUE_MODEL_DIM, hidden_dim=VALUE_MODEL_DIM, output_dim=1)
    # weights_only=True: the file is a plain state_dict, so torch.load need not
    # unpickle arbitrary Python objects.
    state_dict = torch.load(
        f'value_model_checkpoint/value_model_{args.model_name}_{args.dataset_name}.pth',
        map_location=torch.device(device),
        weights_only=True,
    )
    value_model.load_state_dict(state_dict)


MODEL_NAMES = { 
    'vicuna_7B': 'lmsys/vicuna-7b-v1.5', 
    'falcon_7B': 'tiiuae/falcon-7b-instruct',
    'llama3_8B': 'meta-llama/Meta-Llama-3-8B'
}

DATASET_NAMES = { 
    'hh_rlhf': 'Anthropic/hh-rlhf', 
    'shp': 'stanfordnlp/SHP'
}
MODEL = MODEL_NAMES[args.model_name]
DATASET = DATASET_NAMES[args.dataset_name]
# AutoTokenizer picks the tokenizer class declared by each checkpoint: the
# SentencePiece-only LlamaTokenizer cannot load the Llama-3 / Falcon checkpoints.
# use_fast=False keeps the slow LlamaTokenizer for Vicuna, so its tokenisation is unchanged.
# NOTE(review): the model classes used later are still Llama-specific, so falcon_7B
# remains unsupported end-to-end.
tokenizer = AutoTokenizer.from_pretrained(MODEL, padding_side='left', use_fast=False)

In [4]:
# Same loading options for both variants; only the class differs.
model_cls = Intervented_LlamaForCausalLM if args.use_intervention else LlamaForCausalLM
model = model_cls.from_pretrained(
    MODEL,
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
    device_map=device,
    cache_dir='/hy-tmp/huggingface_cache',
)
if args.use_intervention:
    model.set_opt_control_params(value_model, args.value_lr, args.epochs)

# Clear the checkpoint's sampling knobs (presumably to silence warnings under greedy decoding).
model.generation_config.temperature = None
model.generation_config.top_k = None
# Also moves the value head, which set_opt_control_params registered as a submodule.
model.to(device)

dataset = None
if args.dataset_name == 'hh_rlhf':
    dataset = load_dataset(DATASET).remove_columns("rejected")
    for split_name in dataset.keys():
        dataset[split_name] = dataset[split_name].rename_column('chosen', 'prompt')
    test_dataset = dataset['test']
elif args.dataset_name == 'shp':
    # Currently evaluates the toxic prompt set; the SHP split lives at 'dataset/test_dataset_shp.json'.
    test_file_path = 'dataset/test_dataset_toxic.json'
    test_dataset = load_dataset('json', data_files=test_file_path)

Loading checkpoint shards: 100%|██████████| 2/2 [00:03<00:00,  1.56s/it]


In [5]:
# Speaker tag that opens the user turn in each prompt, per backbone.
BEGIN_WORDS = {
    'vicuna_7B': 'Human: ',
    'llama3_8B': 'User: ',
    'falcon_7B': 'User: ',
}
begin_word = BEGIN_WORDS[args.model_name]
# NOTE(review): not used by preprocessing() below — confirm whether it should be prepended.
prompting = '''
    A question from a curious user and an answer from an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user’s questions.\n'''

def preprocessing(example):
    """Build the generation prompt for one dataset example.

    hh_rlhf: rewrite every "Human:" tag to ``begin_word`` and cut the conversation
    after its last "Assistant:" tag (dropping the final assistant reply).
    shp: wrap the ``history`` field as a single user turn followed by "Assistant:".

    Returns:
        dict with the single key ``'prompt'``.
    """
    if args.dataset_name == 'hh_rlhf':
        # Also consume the space after "Human:": begin_word already ends with one, so a
        # plain str.replace produced "Human:  text" (double space). A callable
        # replacement keeps begin_word from being parsed as a regex template.
        replaced_text = re.sub(r"Human: ?", lambda _match: begin_word, example['prompt'])
        parts = replaced_text.rsplit("Assistant:", 1)  # split at the last "Assistant:"
        result = parts[0] + "Assistant:"
    elif args.dataset_name == 'shp':
        result = begin_word + example['history'] + "\nAssistant:"

    return {'prompt': result}


test_dataset = test_dataset.map(preprocessing)
# hh_rlhf yields a plain Dataset (its 'test' split) while the JSON loader yields a
# DatasetDict (a dict subclass) with a single 'train' split. Indexing ['train'] on a
# plain Dataset would look up a column, so only unwrap the dict case.
eval_dataset = test_dataset['train'] if isinstance(test_dataset, dict) else test_dataset
EVAL_BATCH_SIZE = 8
dataloader = DataLoader(eval_dataset, batch_size=EVAL_BATCH_SIZE)
# Padded batching needs a pad token; fall back to EOS when the tokenizer has none.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.eos_token_id
    model.config.pad_token = tokenizer.eos_token
    model.config.pad_token_id = tokenizer.eos_token_id    

Map: 100%|██████████| 50/50 [00:00<00:00, 7904.53 examples/s]


In [6]:
MAX_NEW_TOKENS = 128

generated_responses = []
for batch_prompts in tqdm(dataloader):
    encoded_inputs = tokenizer(batch_prompts['prompt'], return_tensors="pt", padding=True)
    input_ids = encoded_inputs['input_ids'].to(device)
    attention_mask = encoded_inputs['attention_mask'].to(device)
    outputs = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_new_tokens=MAX_NEW_TOKENS,
        num_return_sequences=1,
        return_dict_in_generate=True,
    )
    full_texts = tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True)
    # For a decoder-only model the returned sequences start with input_ids, so slicing
    # past them isolates the continuation. The old str.removeprefix(prompt) silently
    # kept the whole text whenever decode(encode(prompt)) != prompt.
    new_texts = tokenizer.batch_decode(outputs.sequences[:, input_ids.shape[1]:], skip_special_tokens=True)
    for prompt, full_text, new_text in zip(batch_prompts['prompt'], full_texts, new_texts):
        generated_responses.append({'prompt': prompt, 'result': new_text, 'response': full_text})

100%|██████████| 7/7 [01:26<00:00, 12.29s/it]


In [7]:
# Derive the file name from the run configuration; the old hard-coded
# "intervented_vicuna_7B" mislabeled baseline and non-Vicuna runs.
# The 'shp' option currently reads dataset/test_dataset_toxic.json, hence "toxic".
run_tag = 'intervented' if args.use_intervention else 'base'
data_tag = 'toxic' if args.dataset_name == 'shp' else args.dataset_name
save_path = f"output/gen_responses_{run_tag}_{args.model_name}_{data_tag}.json"
os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)

with open(save_path, "w", encoding="utf-8") as f:
    json.dump(generated_responses, f, ensure_ascii=False, indent=2)

In [7]:
# Show each generated continuation, numbered from 1.
for number, record in enumerate(generated_responses, start=1):
    print("(", number, ") ", record['result'], '\n')

( 1 )   As an AI language model, I cannot provide a definitive answer to this question as it is based on speculation and conjecture. However, I can offer some possible explanations.

Firstly, it is possible that the Weasley twins did not notice the connection between Quirrell and Riddle's names on the Marauder's map. They may have been more focused on other things, such as trying to outsmart Professor Snape or finding the Sorcerer's Stone.

Secondly, even if they did notice the connection, they may have chosen not to say 

( 2 )   Here are some fair and revealing questions that could be asked of advocates of economic policy by someone who knows very little about the field of economics but will be affected by those policies:

1. What are the potential benefits and drawbacks of this policy for people like me who have little to no understanding of economics?
2. How does this policy address the needs of low-income and marginalized communities, and what evidence do you have to support its e