In [1]:
from refchecker import LLMExtractor, LLMChecker
import json
import os
import re
import numpy as np


def replace_names(sentences, names):
    """Mask every whole-word occurrence of the given names in each sentence.

    Args:
        sentences: Iterable of strings to process.
        names: Iterable of name strings. Each one is regex-escaped, so it is
            matched literally, and wrapped in ``\\b`` word boundaries.

    Returns:
        A new list with one masked string per input sentence, in input order.
        Every match is replaced by the literal text "[place_holder]".
    """
    # Substitute longer names first: otherwise a short name that is contained
    # in a longer one (e.g. "Wei" vs. "Wei from China") would consume part of
    # the longer phrase whenever the caller happens to list it first.
    # Empty names are skipped because r'\b\b' matches at every word boundary.
    ordered_names = sorted((name for name in names if name), key=len, reverse=True)
    patterns = [re.compile(r'\b{}\b'.format(re.escape(name))) for name in ordered_names]

    replaced_sentences = []
    for sentence in sentences:
        for pattern in patterns:
            sentence = pattern.sub("[place_holder]", sentence)
        replaced_sentences.append(sentence)
    return replaced_sentences

# Names masked out of each response before claim extraction; the full phrase
# is listed ahead of the bare name.
name_list = ["Wei from China", "Wei"]

def extract(data_path1, output_path='data/complex/claims_mod.json'):
    """Extract triplet claims from every response in a JSON file and save them.

    Args:
        data_path1: Path to a JSON file containing a dict of records, each of
            which has a 'response' entry.
        output_path: Destination for the records once a 'claims' entry has
            been added to each. Defaults to the path this function previously
            hard-coded, so existing calls behave as before.
    """
    extractor = LLMExtractor(
        claim_format='triplet',
        model='bedrock/meta.llama3-70b-instruct-v1:0',
        batch_size=50,
    )
    with open(data_path1) as f:
        data = json.load(f)

    # Mask the names in each response before extraction.
    # NOTE(review): every batch entry is a one-element list (the return value
    # of replace_names), not a bare string -- confirm this is the shape
    # LLMExtractor.extract expects.
    batch_responses = [
        replace_names([d['response']], name_list) for d in data.values()
    ]

    response_extract_results = extractor.extract(
        batch_responses=batch_responses
    )
    response_claims = [[c.content for c in res.claims] for res in response_extract_results]

    if len(data) == 1:
        # Single-record files store the whole nested result list; check()
        # strips that extra level again when it sees exactly one record.
        for d in data.values():
            d['claims'] = response_claims
    else:
        # Results come back in the same order the responses were submitted,
        # so align by position rather than by parsing the record key as an
        # integer (which only worked for keys "0".."N-1" in that exact order).
        for position, d in enumerate(data.values()):
            d['claims'] = response_claims[position]

    with open(output_path, 'w') as f:
        json.dump(data, f, indent=4)

# Run claim extraction on the modified raw responses.
extract(data_path1="data/complex/raw_mod.json")


  from .autonotebook import tqdm as notebook_tqdm




100%|██████████| 2/2 [00:12<00:00,  6.42s/it]


In [3]:
# Load both files through context managers so the handles are closed promptly.
with open("data/complex/raw_mod.json") as raw_file:
    data_raw = json.load(raw_file)
with open("data/complex/claims_mod.json") as claims_file:
    data_claims = json.load(claims_file)

# Claim count per record; every record in both files is assumed to carry a 'claims' key.
lengths_raw = [len(item['claims']) for item in data_raw.values()]
lengths_claims = [len(item['claims']) for item in data_claims.values()]

# Per-record change in claim count (claims file minus raw file).
difference = np.array(lengths_claims) - np.array(lengths_raw)
print(difference)



[ 0  3  0  0  0  1  0  0  0  0  0  0  1  0  1  0  0  0  0  0  8  0  0  0
  2  0  0  1  0 -1  0  0  3  0  0  0  0  0  0  0  0  0  0  0  0  1  1  1
  0  1  0  3  0  0  2  0  7  0  0  1  0  0  1  0  0  0  0  0  0  0  0  0
  2  0  0  0  0  0  0  0  0  0  0  0  0  2  0  0  0  0  0  0  0  0  0  0
  0  0  0  0]


In [4]:
# Question that check() passes to the checker alongside every claim batch.
question = "What kind of activities has Wei from China been involved in?"

def check(data_path1, data_path2, data_path3):
    """Label the claims in one file against the responses in another and save.

    Args:
        data_path1: Path to a JSON dict of records, each with a 'claims' entry.
        data_path2: Path to a JSON dict of records whose 'response' entries
            are used as the references.
        data_path3: Path the records from ``data_path1`` are written to after
            a 'labels' entry has been added to each.

    The module-level ``question`` is sent as the question for every batch.
    """
    checker = LLMChecker(
        model='bedrock/meta.llama3-70b-instruct-v1:0',
        batch_size=50,
    )

    with open(data_path1) as f:
        data = json.load(f)
    with open(data_path2) as f:
        ref_data = json.load(f)

    references = [d['response'] for d in ref_data.values()]

    batch_claims = [a['claims'] for a in data.values()]
    if len(batch_claims) == 1:
        # A lone record's 'claims' already holds the full nested result list
        # (extract() stores it that way), so drop the extra wrapping level.
        batch_claims = batch_claims[0]

    # Every claim batch is checked against the complete list of references.
    batch_labels = checker.check(
        batch_claims=batch_claims,
        batch_questions=[question] * len(batch_claims),
        batch_references=[references] * len(batch_claims),
        is_joint=True,
        joint_check_num=10
    )
    # Labels are attached by position, following the record iteration order.
    for i, d in enumerate(data.values()):
        d['labels'] = batch_labels[i]

    with open(data_path3, 'w') as file:
        json.dump(data, file, indent=4)


# Cross-check: each claims file is verified against the other variant's raw responses.
check(
    data_path1="data/complex/claims_ori.json",
    data_path2="data/complex/raw_mod.json",
    data_path3="data/complex/checked_ori.json",
)
check(
    data_path1="data/complex/claims_mod.json",
    data_path2="data/complex/raw_ori.json",
    data_path3="data/complex/checked_mod.json",
)

100%|██████████| 2/2 [00:04<00:00,  2.04s/it]
100%|██████████| 3/3 [00:09<00:00,  3.06s/it]
