In [7]:
import os
import json
import random
from tqdm import tqdm

import torch
from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM
import numpy as np
# CUDA device string that the cells below pass to `.to(device)` for models and tokenized inputs.
device = "cuda:1"

In [8]:
# Same device string as in the import cell.
device = "cuda:1"
model_dir = '/home/tianxueyun/Llama-2-13b-hf'

# The tokenizer and the float16 weights are both read from the same local directory.
tokenizer = AutoTokenizer.from_pretrained(model_dir)
model = AutoModelForCausalLM.from_pretrained(model_dir, torch_dtype=torch.float16)
model = model.to(device)

# eval() returns the module itself, so this last expression also displays the architecture.
model.eval()

Loading checkpoint shards: 100%|██████████| 3/3 [00:22<00:00,  7.63s/it]


LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 5120)
    (layers): ModuleList(
      (0-39): 40 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=5120, out_features=5120, bias=False)
          (k_proj): Linear(in_features=5120, out_features=5120, bias=False)
          (v_proj): Linear(in_features=5120, out_features=5120, bias=False)
          (o_proj): Linear(in_features=5120, out_features=5120, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=5120, out_features=13824, bias=False)
          (up_proj): Linear(in_features=5120, out_features=13824, bias=False)
          (down_proj): Linear(in_features=13824, out_features=5120, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNorm()
  )
  (lm_

contriever

In [11]:
def mean_pooling(token_embeddings, mask):
    """Average token embeddings over dim 1, ignoring positions where `mask` is 0.

    Args:
        token_embeddings: tensor indexed as (batch, token, feature).
        mask: tensor indexed as (batch, token); non-zero marks a real token.

    Returns:
        Tensor indexed as (batch, feature): the sum of the unmasked token
        vectors divided by the number of unmasked tokens in each row.
    """
    is_padding = mask.unsqueeze(-1).bool().logical_not()
    summed = token_embeddings.masked_fill(is_padding, 0.).sum(dim=1)
    token_counts = mask.sum(dim=1).unsqueeze(-1)
    return summed / token_counts

def get_sent_embeddings(sents, contriever, tok, BSZ=32):
    """Embed `sents` in batches of `BSZ` and return all embeddings stacked row-wise.

    Args:
        sents: list of sentences to embed.
        contriever: encoder model; element 0 of its output is mean-pooled.
        tok: tokenizer matching `contriever`.
        BSZ: number of sentences per forward pass.

    Returns:
        One tensor with one row per sentence, in input order, living on the
        global `device`.
    """
    pooled_batches = []
    for start in tqdm(range(0, len(sents), BSZ)):
        encoded = tok(
            sents[start:start + BSZ],
            padding=True,
            truncation=True,
            return_tensors='pt',
        ).to(device)
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            hidden_states = contriever(**encoded)[0]
            pooled = mean_pooling(hidden_states, encoded['attention_mask'])
        pooled_batches.append(pooled)
    return torch.vstack(pooled_batches)

def retrieve_facts(query, fact_embs, contriever, tok, k=1):
    """Return the indices of the `k` fact embeddings most similar to `query`.

    Args:
        query: the question string to embed.
        fact_embs: tensor with one row per candidate fact embedding.
        contriever: encoder model; element 0 of its output is mean-pooled.
        tok: tokenizer matching `contriever`.
        k: number of nearest facts to return.

    Returns:
        Tensor of `k` row indices into `fact_embs`, ordered by decreasing
        dot-product similarity.
    """
    encoded = tok([query], padding=True, truncation=True, return_tensors='pt').to(device)
    with torch.no_grad():
        hidden_states = contriever(**encoded)[0]
        query_emb = mean_pooling(hidden_states, encoded['attention_mask'])
    # A single query was encoded, so row 0 holds its score against every fact.
    scores = torch.matmul(query_emb, fact_embs.T)[0]
    return torch.topk(scores, k, largest=True).indices

获取retrieved fact的embedding

In [12]:
with open('/home/tianxueyun/MQuAKE/datasets/MQuAKE-CF-3k.json', 'r') as f:
    dataset = json.load(f)

# One sentence per distinct edit: the rewrite prompt filled with its subject,
# followed by the new target string. The set removes duplicates.
new_facts = list({
    f'{r["prompt"].format(r["subject"])} {r["target_new"]["str"]}'
    for d in dataset
    for r in d["requested_rewrite"]
})

# Load the retriever and its tokenizer, then embed every edited fact once.
contriever = AutoModel.from_pretrained("/home/tianxueyun/MQuAKE/contriever").to(device)
tokenizer_con = AutoTokenizer.from_pretrained("/home/tianxueyun/MQuAKE/contriever")
embs = get_sent_embeddings(new_facts, contriever, tokenizer_con)

100%|██████████| 88/88 [00:15<00:00,  5.60it/s]


构建用于判断contradict的数据集

In [59]:
# Build one record per distinct question.
#   Pass 1 (label 'contradict'): questions attached to edited facts. The
#     generated answer uses the original target, while the gold fact uses the
#     new target.
#   Pass 2 (label 'not contradict'): unedited single-hop questions. The
#     generated answer and the gold fact are the same sentence.
# Each record is a one-element list holding a dict, because the evaluation
# cell reads records as data[0].
cf_dataset = []
test = set()  # questions already emitted; shared by both passes so no question appears twice
contradict_count = 0
not_contradict_count = 0

for d in tqdm(dataset):
    for qs in d['requested_rewrite']:
        question = qs['question']
        if question in test:
            continue
        test.add(question)

        statement = qs["prompt"].format(qs["subject"])
        gold_retrieve_fact = f'{statement} {qs["target_new"]["str"]}'
        gold_generate = f'{statement} {qs["target_true"]["str"]}'
        fact_ids = retrieve_facts(question, embs, contriever, tokenizer_con)
        retrieved_fact = new_facts[fact_ids[0]]
        ans_correct = qs["target_new"]["str"]

        # Look up the aliases of the new answer. Reset the list for every
        # question so that a question without a matching hop does not inherit
        # the aliases left over from the previous question.
        ans_alias = []
        for hop in d['new_single_hops']:
            if question == hop['question']:
                ans_alias = hop['answer_alias']

        cf_dataset.append([{
            'question': question,
            'gold_retrieve_fact': gold_retrieve_fact,
            'retrieved_fact': retrieved_fact,
            'gold_generate': gold_generate,
            'answer_correct': ans_correct,
            'ans_alias': ans_alias,
            'contradict': 'contradict',
        }])
        contradict_count += 1

for d in tqdm(dataset):
    for qs in d['single_hops']:
        question = qs['question']
        if question in test:
            continue
        test.add(question)

        # No edit applies here, so the gold fact equals the generated answer.
        gold_retrieve_fact = qs['cloze'] + " " + qs['answer']
        gold_generate = qs['cloze'] + " " + qs['answer']
        fact_ids = retrieve_facts(question, embs, contriever, tokenizer_con)
        retrieved_fact = new_facts[fact_ids[0]]

        cf_dataset.append([{
            'question': question,
            'gold_retrieve_fact': gold_retrieve_fact,
            'retrieved_fact': retrieved_fact,
            'gold_generate': gold_generate,
            'answer_correct': qs['answer'],
            'ans_alias': qs['answer_alias'],
            'contradict': 'not contradict',
        }])
        not_contradict_count += 1

print(contradict_count)
print(not_contradict_count)

  0%|          | 0/3000 [00:00<?, ?it/s]

100%|██████████| 3000/3000 [01:06<00:00, 45.40it/s] 
100%|██████████| 3000/3000 [00:30<00:00, 97.87it/s] 

2785
1277





In [61]:
# Display the last ten records of cf_dataset.
cf_dataset[-10:]

[[{'question': 'What position does Erik Spoelstra play?',
   'gold_retrieve_fact': 'Erik Spoelstra plays the position of point guard',
   'retrieved_fact': 'Rik Smits plays the position of punter',
   'gold_generate': 'Erik Spoelstra plays the position of point guard',
   'answer_correct': 'point guard',
   'ans_alias': [],
   'contradict': 'not contradict'}],
 [{'question': 'Which company is Windows Phone 8 produced by?',
   'gold_retrieve_fact': 'The company that produced Windows Phone 8 is Microsoft',
   'retrieved_fact': 'The company that produced Windows Phone 8.1 is Toyota',
   'gold_generate': 'The company that produced Windows Phone 8 is Microsoft',
   'answer_correct': 'Microsoft',
   'ans_alias': ['MS',
    'Micro-Soft',
    'Microsoft Corp.',
    'Microsoft Corporation',
    'MSFT',
    'MICROSOFT TECHNOLOGY LICENSING, LLC     (Redmond, WA)'],
   'contradict': 'not contradict'}],
 [{'question': 'Who performed Illmatic?',
   'gold_retrieve_fact': 'Illmatic was performed by Na

查看有多少条retrieved fact与real fact（gold_retrieve_fact）一致；总数减去该值即为不一致的条数

In [53]:
# Number of records whose top-1 retrieved fact is exactly the gold fact,
# followed by the total number of records.
count = sum(
    1
    for rec in cf_dataset
    if rec[0]['gold_retrieve_fact'] == rec[0]['retrieved_fact']
)
print(count)
print(len(cf_dataset))
    

2718
4062


保存cf_dataset到json

In [54]:
# Serialize first, then write the text to contradict.json in the working directory.
serialized = json.dumps(cf_dataset)
with open('contradict.json', 'w') as out_file:
    out_file.write(serialized)

In [127]:
# Read the saved file back and show how many records it holds.
with open('contradict.json') as saved:
    j = json.loads(saved.read())
len(j)

4062

测试判断contradict/ans/contradict+ans

In [38]:
def _read_prompt(path):
    """Return the full text of the prompt file at `path`."""
    with open(path, 'r') as f:
        return f.read()


# Three prompt templates, one per evaluation mode.
contradict_prompt = _read_prompt('/home/tianxueyun/MQuAKE/prompts/contradict-prompt.txt')
contradict_ans_prompt = _read_prompt('/home/tianxueyun/MQuAKE/prompts/contradict-ans-prompt.txt')
ans_prompt = _read_prompt('/home/tianxueyun/MQuAKE/prompts/ans-prompt.txt')

In [39]:
def get_ans(prompt):
    """Generate a continuation of `prompt` and return only the first answer segment.

    The global `model` generates up to 80 tokens beyond the prompt. The decoded
    text after the prompt is cut at whichever stop marker comes first:
    '\n\nQuestion:' or 'Retrieved fact:'. If neither marker occurs, the whole
    continuation is returned.

    Args:
        prompt: the full prompt text.

    Returns:
        The generated text that follows the prompt, up to but excluding the
        first stop marker.
    """
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
    output = model.generate(input_ids, max_length=input_ids.size()[1] + 80, num_return_sequences=1)
    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)

    # NOTE(review): the character-offset slice assumes decoding reproduces the
    # prompt text exactly. Confirm this holds for the tokenizer in use.
    rest = generated_text[len(prompt):]

    # Keep only the markers that actually occur. Without this filter a double
    # miss produces index -1, and rest[:-1] silently drops the last character.
    marker_positions = [
        pos
        for pos in (rest.find('\n\nQuestion:'), rest.find('Retrieved fact:'))
        if pos != -1
    ]
    if not marker_positions:
        return rest
    return rest[:min(marker_positions)]

In [57]:
# Load the contradiction dataset.
with open('contradict.json', 'r') as file:
    j = json.load(file)

ANSWER_MARKER = '. The answer is: '


def parse_judgement(last_sent):
    """Split a verdict line into (predicted_label, answer).

    Expected form:
        'Retrieved fact <verdict> to generated answer. The answer is: <answer>.'

    The label is 'contradict' only when the verdict mentions 'contradict'
    without a negation; otherwise it is 'not contradict'. Matching on the
    wording rather than on a fixed slice length accepts both 'contradict' and
    'contradicts'.

    Returns None when the answer marker is missing.
    """
    pos = last_sent.find(ANSWER_MARKER)
    if pos == -1:
        return None
    verdict = last_sent[:pos]
    negated = 'not contradict' in verdict or "n't contradict" in verdict
    if 'contradict' in verdict and not negated:
        predicted = 'contradict'
    else:
        predicted = 'not contradict'
    answer = last_sent[pos + len(ANSWER_MARKER):]
    # Drop the closing full stop only when there is one.
    if answer.endswith('.'):
        answer = answer[:-1]
    return predicted, answer


cor_c = 0  # correct verdict and answer on 'contradict' records
cor_n = 0  # correct verdict and answer on 'not contradict' records
tot = 0
err = []   # records whose contradiction verdict was wrong
for data in tqdm(j):
    tot += 1
    record = data[0]
    question = record['question']
    gold_retrieve_fact = record['gold_retrieve_fact']
    gold_generate = record['gold_generate']
    answer_correct = record['answer_correct']
    contradict = record['contradict']
    # Files written before aliases were added have no 'ans_alias' key.
    ans_alias = record.get('ans_alias', [])

    prompt = contradict_ans_prompt +'\n\nQuestion: '+ question + '\nGenerated answer: ' + gold_generate +'.\n'+ 'Retrieved fact: ' + gold_retrieve_fact + '.'

    gen = get_ans(prompt)
    last_sent = gen.strip().split('\n')[-1]
    parsed = parse_judgement(last_sent) if last_sent.startswith('Retrieved fact') else None
    if parsed is None:
        # The model did not produce a verdict line in the expected format.
        print('not_in:', gen)
        continue

    predicted, ans = parsed
    if predicted != contradict:
        err.append(data)
        if predicted == 'contradict':
            print('c:', last_sent)
        continue

    if ans == answer_correct or ans in ans_alias:
        if contradict == 'contradict':
            cor_c += 1
        else:
            cor_n += 1
    elif contradict == 'not contradict':
        # Verdict was right but the extracted answer was not.
        print(ans)
        print(last_sent)

print('total:', (cor_c + cor_n) / tot if tot else 0.0)
print('cor_c:',cor_c)
print('cor_n:',cor_n)

json_data = json.dumps(err)
with open('err_contradict_ans.json', 'w') as file:
    file.write(json_data)

 50%|█████     | 5/10 [00:37<00:37,  7.51s/it]

Judo
Retrieved fact does not contradict to generated answer. The answer is: Judo.


100%|██████████| 10/10 [01:12<00:00,  7.30s/it]

total: 0.9
cor_c: 0
cor_n: 9





In [46]:
# Display the last ten records loaded from contradict.json.
j[-10:]

[[{'question': 'What position does Erik Spoelstra play?',
   'gold_retrieve_fact': 'Erik Spoelstra plays the position of point guard',
   'retrieved_fact': 'Rik Smits plays the position of punter',
   'gold_generate': 'Erik Spoelstra plays the position of point guard',
   'answer_correct': 'point guard',
   'contradict': 'not contradict'}],
 [{'question': 'Which company is Windows Phone 8 produced by?',
   'gold_retrieve_fact': 'The company that produced Windows Phone 8 is Microsoft',
   'retrieved_fact': 'The company that produced Windows Phone 8.1 is Toyota',
   'gold_generate': 'The company that produced Windows Phone 8 is Microsoft',
   'answer_correct': 'Microsoft',
   'contradict': 'not contradict'}],
 [{'question': 'Who performed Illmatic?',
   'gold_retrieve_fact': 'Illmatic was performed by Nas',
   'retrieved_fact': 'Chris Claremont is famous for Arthashastra',
   'gold_generate': 'Illmatic was performed by Nas',
   'answer_correct': 'Nas',
   'contradict': 'not contradict'}]

In [10]:
# Scratch check of the offsets used to pull the verdict and the answer out of
# a model output line.
a = 'Retrieved fact does not contradicts to generated answer. The answer is: Crotia.'

marker = '. The answer is: '
verdict_part, sep, answer_part = a.partition(marker)

# Skip the 15-character prefix 'Retrieved fact ' and drop the trailing
# 21 characters 's to generated answer'.
contradict = verdict_part[15:-21]
# Drop the closing full stop.
ans = answer_part[:-1]

print(contradict)
print(ans)

does not contradict
Crotia
