In [1]:
import os
os.chdir('../MQuAKE/')
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'

import json
import random
from tqdm import tqdm

import torch
from transformers import AutoTokenizer, AutoModel

#### Set up OpenAI API

In [4]:
import openai

openai.api_key = os.getenv("OPENAI_API_KEY")

def call_gpt(cur_prompt, stop):
    """Send a prompt to the OpenAI completions endpoint and return its text.

    Args:
        cur_prompt: Prompt string forwarded as ``prompt``.
        stop: Stop sequence(s) forwarded as ``stop``.

    Returns:
        The ``text`` field of the first choice in the response.
    """
    # Temperature 0 and a 256-token cap are fixed for every call.
    response = openai.Completion.create(
        model="text-davinci-003",
        prompt=cur_prompt,
        stop=stop,
        max_tokens=256,
        temperature=0,
    )
    first_choice = response['choices'][0]
    return first_choice['text']

ModuleNotFoundError: No module named 'openai'

In [2]:
# Load the GPT-J 6B causal LM and its tokenizer from the Hugging Face hub.
# device_map='auto' delegates weight placement to the library rather than
# calling .cuda() explicitly.
from transformers import AutoModelForCausalLM
gptj_model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", device_map='auto')
gptj_tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")



In [3]:
#gptj_model.cuda()
# for nm, par in gptj_model.named_parameters():
#     print(nm, par.shape, par.device)

In [4]:
from transformers import StoppingCriteria

class StopByWords(StoppingCriteria):
    """Stopping criterion that ends generation once the text ends with a stop word.

    Matching is done on the decoded string rather than on token ids, because
    the same stop word can be tokenized in several different ways.

    Args:
        stop_words: Iterable of stop strings (a single string is also accepted).
        tokenizer: Tokenizer used to decode the tail of the generated ids.
    """

    def __init__(self, stop_words, tokenizer):
        StoppingCriteria.__init__(self)
        self.tokenizer = tokenizer
        # A bare string would otherwise be iterated character by character.
        if isinstance(stop_words, str):
            stop_words = [stop_words]
        # Drop empty strings: ``str.endswith('')`` is always True, which would
        # stop generation after the very first token.
        self.stop_words = [sw for sw in stop_words if sw]

    def __call__(self, input_ids, scores=None, **kwargs):
        """Return True when the first sequence's decoded tail ends with a stop word.

        Only ``input_ids[0]`` is inspected, and only its last 10 tokens are
        decoded to keep the per-step cost low. ``**kwargs`` is accepted to
        match the base ``StoppingCriteria.__call__`` signature.
        """
        suffix_str = self.tokenizer.decode(input_ids[0][-10:]).strip()
        return any(suffix_str.endswith(sw) for sw in self.stop_words)

class gptj_interface:
    """Wrapper around GPT-J generation with a configurable generation config and stop words.

    Args:
        gptj_tokenizer: Tokenizer used to encode prompts and decode outputs.
        gptj_model: Model exposing ``.device`` and ``.generate(...)``.
        stop_words: Stop strings handed to ``StopByWords``.
        gen_config: Optional dict of generation overrides. Recognised keys:
            ``repetition_penalty``, ``temperature``, ``top_k``, ``top_p``,
            ``max_new_tokens``, ``do_sample``, ``num_return_sequences``.
    """

    def __init__(self, gptj_tokenizer, gptj_model, stop_words, gen_config=None):
        self.tokenizer = gptj_tokenizer
        self.model = gptj_model
        # Avoid a shared mutable default; a caller-supplied dict is kept by reference.
        self.cfg = {} if gen_config is None else gen_config
        self.stopper = StopByWords(stop_words, gptj_tokenizer)

    def call_gptj_local(self, cur_prompt):
        """Generate continuations for ``cur_prompt``.

        Returns:
            A list of decoded strings, one per returned sequence. Each string is
            the decoding of the full returned id sequence.
        """
        # <1, prompt_len>
        inputs_dict = self.tokenizer(cur_prompt, return_tensors='pt').to(self.model.device)
        output_dict = self.model.generate(
            **inputs_dict,
            return_dict_in_generate=True,
            # 1.0 means "no penalty"; the previous default of 0. is not a valid
            # (strictly positive) penalty value.
            repetition_penalty=self.cfg.get('repetition_penalty', 1.),
            temperature=self.cfg.get('temperature', 1.),
            top_k=self.cfg.get('top_k', 1),
            top_p=self.cfg.get('top_p', 1.),
            max_new_tokens=self.cfg.get('max_new_tokens', 32),
            do_sample=self.cfg.get('do_sample', True),
            # NOTE(review): presumably the end-of-text id of the GPT-J tokenizer — confirm.
            pad_token_id=50256,
            stopping_criteria=[self.stopper],
            # Previously hard-coded to 5; still the default.
            num_return_sequences=self.cfg.get('num_return_sequences', 5),
        )
        return [self.tokenizer.decode(output_ids) for output_ids in output_dict.sequences]

In [58]:
# Generation settings passed to gptj_interface (read there via cfg.get).
clue_config = {
    'repetition_penalty' : 1.05,
    'temperature' : 0.3,
    'top_k' : 5,
    'top_p' : 0.85,
    'max_new_tokens' : 32,
    'do_sample' : True,
}

# Stop as soon as the decoded text ends with a period.
tmp_stoppers = ['.']

# Build the wrapper; the bare name on the last line displays its repr as the cell output.
gptj_api = gptj_interface(gptj_tokenizer, gptj_model, tmp_stoppers, clue_config)
gptj_api

<__main__.gptj_interface at 0x7fe3c91364c0>

In [21]:
output_sents = gptj_api.call_gptj_local('Ellie Kemper is a citizen of')
for sent in output_sents:
    print(sent, '\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n')

Ellie Kemper is a citizen of the world.

The actress, who plays Carol on The Office, has lived in London, Paris, New York and Los Angeles, and she’s currently living in Los Angeles with her husband, actor Paul F. Tompkins. But when it comes to her career, Kemper has never been afraid 

Ellie Kemper is a citizen of the world. She’s lived in London, Paris and New York City. She’s traveled to more than 40 countries on five continents. She’s been to the top of Mount Kilimanjaro, the bottom of the Grand Canyon, and the middle of the Sahara Desert.

She 

Ellie Kemper is a citizen of the world.

The actress, best known for her role as Annie Edison on The Office, has lived in London, Los Angeles and New York City, and she’s currently based in Austin, Texas. But she’s never been to Paris.

Until now.

Kemper 

Ellie Kemper is a citizen of the world.

The actress, who plays Carol on The Office, has been to more than 60 countries and has lived in New York, Los Angeles, London, Paris, and Berlin.

In [60]:
sent

'Ellie Kemper is a citizen of the world.'

In [36]:
gptj_tokenizer.tokenize('旻旻旻')

['æĹ', '»', 'æĹ', '»', 'æĹ', '»']

#### Functions for retrieval models (Contriever)

In [5]:
def mean_pooling(token_embeddings, mask):
    """Average token embeddings over dim 1, ignoring positions where ``mask`` is 0.

    Args:
        token_embeddings: Tensor whose dim 1 is summed over and whose trailing
            dim is broadcast against the mask.
        mask: Tensor of 0/1 flags; a trailing axis is appended before it is
            applied to ``token_embeddings``.

    Returns:
        The sum of the unmasked embeddings along dim 1 divided by the number of
        unmasked positions.
    """
    keep = mask.unsqueeze(-1).bool()
    # Zero out masked positions so they do not contribute to the sum.
    zeroed = token_embeddings.masked_fill(~keep, 0.)
    token_counts = mask.sum(dim=1).unsqueeze(-1)
    return zeroed.sum(dim=1) / token_counts

def get_sent_embeddings(sents, contriever, tok, BSZ=32):
    """Embed ``sents`` in batches with ``contriever`` and mean pooling.

    Args:
        sents: List of sentences to embed.
        contriever: Encoder called as ``contriever(**inputs)``; its first output
            is mean-pooled with the attention mask.
        tok: Tokenizer producing a batch with an ``attention_mask`` entry.
        BSZ: Batch size.

    Returns:
        A CPU tensor with all batch embeddings stacked vertically, in the same
        order as ``sents``.
    """
    # Follow the encoder's own device instead of assuming CUDA; fall back to
    # the previous hard-coded "cuda" when the encoder exposes no ``.device``.
    device = getattr(contriever, "device", "cuda")
    all_embs = []
    for i in tqdm(range(0, len(sents), BSZ)):
        sent_batch = sents[i:i+BSZ]
        inputs = tok(sent_batch, padding=True, truncation=True, return_tensors='pt').to(device)
        with torch.no_grad():
            outputs = contriever(**inputs)
            embeddings = mean_pooling(outputs[0], inputs['attention_mask'])
        # Move each batch to the CPU right away so the accumulated index does
        # not stay on the encoder's device.
        all_embs.append(embeddings.cpu())
    return torch.vstack(all_embs)

def retrieve_facts(query, fact_embs, contriever, tok, k=1):
    """Return the indices of the ``k`` facts most similar to ``query``.

    Args:
        query: Query string.
        fact_embs: CPU tensor of fact embeddings, one row per fact.
        contriever: Encoder called as ``contriever(**inputs)``.
        tok: Tokenizer producing a batch with an ``attention_mask`` entry.
        k: Number of neighbours to return; clamped to the number of facts.

    Returns:
        A tensor of indices into ``fact_embs``, ordered by decreasing
        dot-product similarity.
    """
    # Follow the encoder's own device instead of assuming CUDA; fall back to
    # the previous hard-coded "cuda" when the encoder exposes no ``.device``.
    device = getattr(contriever, "device", "cuda")
    inputs = tok([query], padding=True, truncation=True, return_tensors='pt').to(device)
    with torch.no_grad():
        outputs = contriever(**inputs)
        query_emb = mean_pooling(outputs[0], inputs['attention_mask']).cpu()
    sim = (query_emb @ fact_embs.T)[0]
    # topk raises if k exceeds the number of candidates.
    k = min(k, sim.shape[0])
    knn = sim.topk(k, largest=True)
    return knn.indices

#### Load dataset

In [6]:
# Load the MQuAKE-CF-3k examples. The file is read as UTF-8 explicitly so the
# result does not depend on the platform's default text encoding.
with open('datasets/MQuAKE-CF-3k.json', 'r', encoding='utf-8') as f:
    dataset = json.load(f)

#### Build a memory index which contains all the edits

In [7]:
# Collect every edited fact across the dataset as a natural-language sentence
# ("<prompt filled with subject> <new target>"), de-duplicated through a set.
new_facts = set()
for d in dataset:
    for r in d["requested_rewrite"]:
        new_facts.add(f'{r["prompt"].format(r["subject"])} {r["target_new"]["str"]}')
# Sort rather than list(): set iteration order for strings varies between
# interpreter sessions, which would make fact indices irreproducible.
new_facts = sorted(new_facts)

In [11]:
# Load the Contriever (MS MARCO) retriever and its tokenizer; the encoder is
# moved to the GPU explicitly with .cuda().
contriever = AutoModel.from_pretrained("facebook/contriever-msmarco").cuda()
# contriever = AutoModel.from_pretrained("facebook/contriever-msmarco", device_map='auto')
tokenizer = AutoTokenizer.from_pretrained("facebook/contriever-msmarco")

In [12]:
embs = get_sent_embeddings(new_facts, contriever, tokenizer)

100%|█████████████████████████████████████████████████████████████| 88/88 [00:02<00:00, 36.48it/s]


In [14]:
# Run test for retrieval index
fact_ids = retrieve_facts("Who is the president of the US?", embs, contriever, tokenizer)
print(new_facts[fact_ids[0]])

The name of the current head of state in United States of America is Norodom Sihamoni


#### Run MeLLo

In [15]:
# read prompts
with open('prompts/MeLLo-prompt.txt', 'r') as f:
    task_prompt = f.read()
mquake_stop = ["Retrieved fact:", "Question:"]

In [16]:
# Generation settings for the MeLLo run (read by gptj_interface via cfg.get);
# allows up to 64 new tokens per call.
clue_config = {
    'repetition_penalty' : 1.05,
    'temperature' : 0.3,
    'top_k' : 5,
    'top_p' : 0.85,
    'max_new_tokens' : 64,
    'do_sample' : True,
}

# Rebuild the wrapper with the MeLLo stop markers; the bare name on the last
# line displays its repr as the cell output.
gptj_api = gptj_interface(gptj_tokenizer, gptj_model, mquake_stop, clue_config)
gptj_api

<__main__.gptj_interface at 0x7f1592c94790>

In [27]:
# Run MeLLo on the remaining examples, resuming from the counters below.
# NOTE: T is left over from the "first T examples" setup and is not used by the loop.
T = 10

#cor = 0
#tot = 0
# Resume point: `cor` correct answers among the first `tot` examples.
cor, tot = 339, 2829

# `with` guarantees the log is flushed and closed even if the loop raises.
with open('trial.log', 'a') as fout:
    for d in tqdm(dataset[tot:]):
        tot += 1
        for q in d["questions"]:
            found_ans = False
            # NOTE(review): "Qustion" looks like a typo for "Question". It is kept
            # byte-for-byte so results stay comparable with the counters above;
            # confirm before changing.
            prompt = task_prompt + "\n\nQustion: " + q
            print('======================================\n[Question]', q, file=fout)
            for i in range(4):
                # prompt the model to generate a subquestion and a tentative answer
                #gen = call_gpt(prompt, mquake_stop)
                llm_output = gptj_api.call_gptj_local(prompt)
                # Simply take the first returned sequence.
                gen = llm_output[0]
                print('\n--------~~~~~~~~--------\n', gen[len(task_prompt)+2 : ], end='\n------------------------\n', file=fout, flush=True)

                gen_text = gen.strip()
                gen_lines = gen_text.split('\n')
                # The parsing below looks three lines back; bail out cleanly
                # instead of raising IndexError on a too-short generation.
                if len(gen_lines) < 3:
                    print('[Generation Error] Fewer than three lines', file=fout)
                    break # failed case

                # if final answer is there, get the answer and exit
                # NOTE: GPT-J does not stop by itself, so the start of the next
                # "Question:" also counts as a finalize trigger.
                last_sent, prev_sent = gen_lines[-1], gen_lines[-3]
                if last_sent.startswith('Final answer: '):
                    ans = last_sent[len("Final answer: "):]
                    found_ans = True
                elif last_sent.startswith('Question:'):
                    # Log and give up on this question rather than asserting: one
                    # malformed generation must not abort a multi-day run.
                    if not prev_sent.startswith('Final answer: '):
                        print('[Final Answer Prefix Error]', prev_sent, file=fout)
                        break # failed case
                    ans = prev_sent[len("Final answer: "):]
                    found_ans = True
                if found_ans:
                    print('[Found Answer]', ans, file=fout)
                    break

                # otherwise, extract the generated subquestion
                # NOTE: the stopping criterion keeps the stop word in the output, so
                # the subquestion sits three lines from the end (before the
                # generated answer and the trailing "Retrieved fact:" line).
                subquestion = prev_sent
                if not subquestion.startswith('Subquestion: '):
                    print('[Subquestion Prefix Error]', subquestion, file=fout)
                    break # failed case
                subquestion = subquestion[len("Subquestion: "):]

                # Only strip the trailing marker when it is really there; otherwise
                # the slice below would cut arbitrary generated text.
                if not gen_text.endswith('\nRetrieved fact:'):
                    print('[Generation Error] Missing trailing "Retrieved fact:"', file=fout)
                    break # failed case

                # retrieve an edited fact using the generated subquestion
                fact_ids = retrieve_facts(subquestion, embs, contriever, tokenizer)
                fact_sent = new_facts[fact_ids[0]]

                # put the retrieved fact at the end of the prompt, the model self-checks if it contradicts
                #prompt = prompt + gen + 'Retrieved fact: ' + fact_sent + '.'
                # NOTE: generate() returns the input prompt as part of the output, so
                # the new prompt is rebuilt from `gen` itself, dropping the
                # model-generated "Retrieved fact:" marker first.
                prompt = gen_text[:-len('\nRetrieved fact:')]
                prompt += '\nRetrieved fact: ' + fact_sent + '.'

            if not found_ans:
                continue
            # if the answer is correct
            if ans == d["new_answer"] or ans in d["new_answer_alias"]:
                cor += 1
                break

        print('Running acc = {} / {} = {}'.format(cor, tot, cor/tot), file=fout)

    print(f'Multi-hop acc = {cor / tot} ({cor} / {tot})', file=fout)

 64%|████████████████████████████████▋                  | 919/1436 [72:57:12<40:47:27, 284.04s/it]

In [None]:
# E2E unlearning
e2e_prompt = d['questions'][0]

# Ad-hoc unlearning
edit_statements = []
edit_facts = d['requested_rewrite']

In [31]:
cor, tot

(339, 2830)

In [1]:
import sys
import os
#os.chdir('../MQuAKE/')
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'

import ray
ray.init(num_cpus=24)

from vllm import LLM, SamplingParams
import torch

2024-04-14 21:57:55,573	INFO worker.py:1752 -- Started a local Ray instance.


In [2]:
# Build the vLLM engine for GPT-J 6B, sharded across every visible GPU and
# with gpu_memory_utilization set to 0.8.
llm = LLM(
    model = "EleutherAI/gpt-j-6B",
    tokenizer = "EleutherAI/gpt-j-6B",
    # dtype = 'float32',
    # Use all GPUs
    tensor_parallel_size = torch.cuda.device_count(),
    # tensor_parallel_size = 1,
    gpu_memory_utilization = 0.8,
)

2024-04-14 21:57:59,901	INFO worker.py:1585 -- Calling ray.init() again after it has already been called.


INFO 04-14 21:58:00 llm_engine.py:74] Initializing an LLM engine (v0.4.0.post1) with config: model='EleutherAI/gpt-j-6B', tokenizer='EleutherAI/gpt-j-6B', tokenizer_mode=auto, revision=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.float16, max_seq_len=2048, download_dir=None, load_format=auto, tensor_parallel_size=2, disable_custom_all_reduce=True, quantization=None, enforce_eager=False, kv_cache_dtype=auto, device_config=cuda, seed=0)
INFO 04-14 21:58:07 selector.py:16] Using FlashAttention backend.
[36m(RayWorkerVllm pid=231378)[0m INFO 04-14 21:58:08 selector.py:16] Using FlashAttention backend.
INFO 04-14 21:58:08 pynccl_utils.py:45] vLLM is using nccl==2.18.1
[36m(RayWorkerVllm pid=231378)[0m INFO 04-14 21:58:08 pynccl_utils.py:45] vLLM is using nccl==2.18.1
INFO 04-14 21:58:12 weight_utils.py:177] Using model weights format ['*.bin']
[36m(RayWorkerVllm pid=231378)[0m INFO 04-14 21:58:12 weight_utils.py:177] Using model weights format ['*.bin']
INFO 04-

In [143]:
# Beam-search decoding variant: temperature 0, 3 candidates, 24 to 64 new
# tokens, stopping on (and keeping) "Retrieved fact:" or a blank line.
# NOTE(review): presumably vLLM requires temperature 0 and best_of > 1 when
# use_beam_search is on — confirm against the installed version.
sampling_config = SamplingParams(
    # repetition_penalty = gen_config.get('repetition_penalty', 0.),
    temperature = 0.,
    best_of = 3,
    # n = gen_config.get('top_k', 1),
    # top_k = gen_config.get('top_k', 1),
    # top_p = gen_config.get('top_p', 1.),
    max_tokens = 64,
    min_tokens = 24,
    stop = ["Retrieved fact:", "\n\n"],
    include_stop_str_in_output = True,
    use_beam_search = True,
)

In [3]:
# Sampling variant using the same values as clue_config (repetition penalty
# 1.05, temperature 0.3, top-k 5, top-p 0.85, up to 64 new tokens), with
# best_of = 3. Stops on, and keeps, "Retrieved fact:" or a blank line.
sampling_config = SamplingParams(
    repetition_penalty = 1.05,
    temperature = 0.3,
    top_k = 5,
    top_p = 0.85,
    max_tokens = 64,
    use_beam_search = False,
    stop = ["Retrieved fact:", "\n\n"],
    include_stop_str_in_output = True,
    best_of = 3,
)

[36m(RayWorkerVllm pid=231378)[0m INFO 04-14 21:58:44 model_runner.py:867] Graph capturing finished in 6 secs.


In [12]:
lo = llm.generate('Question: Who\'s the US president?', sampling_config)

Processed prompts: 100%|██████████| 1/1 [00:01<00:00,  1.19s/it]


In [13]:
lo[0].outputs

[CompletionOutput(index=1, text=' A. The US President\nB- The President of the USA C-The US President (Presidential)\nC+The president is a political officer of USA, which holds a position as an officer in the executive department. D+He has a position in the cabinet, as the highest executive official, who is', token_ids=[317, 13, 383, 1294, 1992, 198, 33, 12, 383, 1992, 286, 262, 4916, 327, 12, 464, 1294, 1992, 357, 10364, 498, 8, 198, 34, 10, 464, 1893, 318, 257, 1964, 3818, 286, 4916, 11, 543, 6622, 257, 2292, 355, 281, 3818, 287, 262, 4640, 5011, 13, 360, 10, 1544, 468, 257, 2292, 287, 262, 13447, 11, 355, 262, 4511, 4640, 1743, 11, 508, 318], cumulative_logprob=-103.1183865070343, logprobs=None, finish_reason=length, stop_reason=None),
 CompletionOutput(index=0, text='\nAnswer 1: Barack ObamanObama, a/l. (B) (Barac, Barack), né Barack\nObama Sr., born 1961, a.l./n.l. of a.l.. His wife, a./o\nSuey-Chi Obama, né Dunham', token_ids=[198, 33706, 352, 25, 8732, 1835, 10546, 15948, 11, 25