In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

## Path to the GPT-J checkpoint: tokenizer and model are both loaded from 'Instruct_gpt_J'.
tokenizer = AutoTokenizer.from_pretrained('Instruct_gpt_J')
# Load the causal LM with float16 weights and move it to the GPU (.cuda() needs a CUDA device).
# NOTE(review): max_length and temperature are passed at load time; presumably they become the
# default generation settings — confirm against the installed transformers version.
generator = AutoModelForCausalLM.from_pretrained("Instruct_gpt_J", max_length = 512, temperature=1, torch_dtype=torch.float16).cuda()

In [None]:
# Imported in this cell because it runs before the cell that holds the notebook's main import
# block; without it a fresh "Restart & Run All" stops here with a NameError.
from collections import defaultdict

# Maps each context sentence to the list of raw model answers, one entry per queried label.
sentence_dic = defaultdict(list)

In [None]:
import json
from collections import defaultdict
import re
from tqdm import tqdm
import os
import string
import argparse

def _labels_for_path(file_dir):
    """Return the ordered entity-label list of the dataset that ``file_dir`` refers to.

    Raises:
        ValueError: if the path names none of the supported domains.
    """
    # The label order is fixed: answers are stored positionally, one per label.
    ### CrossTest Dataset
    domain_labels = [
        ('ai', ["Field", "Task", "Conference", "Misc", "Product", "Programlang", "Organisation", "Algorithm", "Researcher", "Metrics", "University", "Country", "Person", "Location"]),
        ('literature', ['Book', 'Writer', 'Award', 'Misc', 'Organisation', 'Person', 'Literarygenre', 'Poem', 'Event', 'Country', 'Location', 'Magazine']),
        ('music', ['Award', 'Album', 'Band', 'Musicalartist', 'Musicgenre', 'Organisation', 'Song', 'Location', 'Event', 'Misc', 'Country', 'Person', 'Musicalinstrument']),
        ('politics', ['Politicalparty', 'Election', 'Organisation', 'Politician', 'Event', 'Person', 'Location', 'Misc', 'Country']),
        ('science', ['Organisation', 'Scientist', 'Misc', 'Award', 'Astronomicalobject', 'Academicjournal', 'University', 'Chemicalcompound', 'Person', 'Location', 'Protein', 'Event', 'Enzyme', 'Discipline', 'Country', 'Chemicalelement', 'Theory']),
    ]
    ### CoNLL2003 (including Typos and OOV) and OntoNotes5 are not supported yet.

    # Prefer a whole path token so that a short name such as 'ai' is not picked up from inside
    # an unrelated word (e.g. 'domain/politics.test' must resolve to 'politics').
    tokens = set(re.split(r'[^A-Za-z0-9]+', file_dir))
    for domain, labels in domain_labels:
        if domain in tokens:
            return labels

    # Legacy behaviour: plain substring match, first hit in the order above wins.
    for domain, labels in domain_labels:
        if domain in file_dir:
            return labels

    supported = ', '.join(domain for domain, _ in domain_labels)
    raise ValueError(
        "Cannot infer the label set from path {!r}; expected one of: {}".format(file_dir, supported)
    )


def data_helper(file_dir):
    """Query the model once per (sentence, label) pair found in ``file_dir``.

    Every line of the file that contains the text 'context' is treated as a sentence line; the
    sentence is the part after the first ": " with its last character dropped. For each label a
    prompt is generated and the third line of the decoded output is stored as the raw answer.

    Side effects: appends to the module-level ``sentence_dic`` (so calling this twice for the
    same file appends a second set of answers) and prints the number of sentences processed.

    Args:
        file_dir: path of the dataset file; it must name one of the supported domains.

    Returns:
        A ``(sentence_dic, labels)`` tuple: the shared sentence -> raw answers mapping and the
        label list, in the order the answers were appended.

    Raises:
        ValueError: if no label set can be inferred from ``file_dir``.
    """
    # Resolved before anything expensive happens so a bad path fails immediately.
    labels = _labels_for_path(file_dir)
    prompt = "Extract the entities in the text to the entity type "
    idx = 0

    # The bar is sized by the file size; each line advances it by its own length.
    with tqdm(total=os.path.getsize(file_dir)) as pbar:
        with open(file_dir, 'r') as f:
            for raw_line in f:
                line = raw_line.strip('\n')
                # NOTE(review): any line containing 'context' is taken as a sentence line; the
                # same rule is used when the results are read back, so it is kept unchanged.
                if 'context' in line:
                    idx += 1
                    context = line.split(": ")[1][:-1]
                    for label in labels:
                        query = prompt + label + ":\n" + context
                        inputs = tokenizer(query, return_tensors='pt')
                        outputs = generator.generate(inputs.input_ids.cuda(), pad_token_id=tokenizer.eos_token_id)
                        outputs_str = tokenizer.decode(outputs[0])
                        # Lines 0 and 1 echo the prompt and the sentence; line 2 is the answer.
                        # An output without an answer line is recorded as "no entities" instead
                        # of aborting the whole run with an IndexError.
                        output_lines = outputs_str.split("\n")
                        result = output_lines[2] if len(output_lines) > 2 else ''
                        sentence_dic.setdefault(context, []).append(result)
                # Use the unstripped length so the newline is counted towards the file size.
                pbar.update(len(raw_line))

    print("number_sentence:", idx)
    return sentence_dic, labels

def get_entity(line, context):
    """Extract the entity mentions from one raw model answer.

    Only the text after the last ": " of ``line`` is used. It is split into candidates, all
    ASCII punctuation is removed from each candidate, and a candidate is kept only if it is
    non-empty and occurs (case-insensitively, as a substring) in ``context``.

    Args:
        line: raw answer line produced by the model.
        context: the sentence the answer refers to.

    Returns:
        The kept candidates joined with ',' ('' when nothing is kept).
    """
    answer = line.split(": ")[-1]
    # NOTE(review): the alternative split is only used when the answer contains a backslash
    # followed by a space; it then splits on '|, ' or on a double quote. Kept exactly as before,
    # but spelled without invalid escape sequences (which newer Python versions warn about).
    if '\\ ' in answer:
        candidates = re.split(r'\|, |"', answer)
    else:
        candidates = answer.split(", ")

    # Built once: deletes every character of string.punctuation.
    strip_punctuation = str.maketrans('', '', string.punctuation)
    context_lower = context.lower()

    kept = []
    for candidate in candidates:
        cleaned = candidate.translate(strip_punctuation)
        if cleaned != '' and cleaned.lower() in context_lower:
            kept.append(cleaned)

    return ','.join(kept)

def get_result(ground_file, sentence_dic, labels):
    """Build one prediction record per sentence line of ``ground_file``.

    Sentence lines are the lines containing the text 'context'; the sentence is the part after
    the first ": " with its last character dropped, and it is used as the lookup key into
    ``sentence_dic``. The stored raw answers are paired positionally with ``labels``; surplus
    answers are ignored.

    Args:
        ground_file: path of the dataset file.
        sentence_dic: mapping from sentence to the list of raw model answers.
        labels: label names, in the order the answers were generated.

    Returns:
        A list of ``{"context": sentence, "entity": {label_lowercase: "e1,e2,..."}}`` dicts.
        A sentence with no stored answers gets an empty "entity" dict.
    """
    all_datas = []
    with open(ground_file, 'r') as f:
        for raw_line in f:
            line = raw_line.strip('\n')
            if 'context' not in line:
                continue
            key = line.split(": ")[1][:-1]
            # .get() keeps the lookup read-only: indexing a defaultdict would insert an empty
            # entry for every unknown sentence, and a plain dict would raise KeyError.
            results = list(sentence_dic.get(key, []))
            pos = {}
            for label, raw_answer in zip(labels, results):
                pos[str(label).lower()] = get_entity(raw_answer, key)
            all_datas.append({
                "context": key,
                "entity": pos})
    return all_datas


# Dataset path: set this to the test file that should be evaluated.
real_dir = 'dataset/politics.test'
print("get dataset!")




In [None]:
print("Generating results...")

# Runs one generation per (sentence, label) pair of the test file and returns the shared
# sentence -> raw answers mapping together with the label order used for the queries.
sentence_dic, labels = data_helper(real_dir)

In [None]:
### Get results: turn the raw answers into one {"context", "entity"} record per sentence.
all_datas = get_result(real_dir, sentence_dic, labels)

In [None]:
### Save results
# pred_dir should be changed to match the dataset file under test, e.g. xxx_result.test
pred_dir = "xxx_result.test"
# Written as UTF-8 explicitly: ensure_ascii=False emits raw non-ASCII characters, and the
# evaluation step reads this file back with encoding="utf-8".
with open(pred_dir, "w", encoding="utf-8") as f:
    json.dump(all_datas, f, sort_keys=False, ensure_ascii=False, indent=2)
    

In [None]:
def _entity_set(value):
    """Turn a comma-separated entity string into a set of trimmed, lower-cased names.

    ``None`` and blank strings yield an empty set; empty fragments are dropped.
    """
    if value is None:
        return set()
    return {part.strip().lower() for part in str(value).split(',')} - {''}


def metrics(real_dir, pred_dir):
    """Count entity-level true positives, false negatives and false positives.

    Both files are JSON lists of ``{"context": ..., "entity": {label: value}}`` records that
    are compared position by position. For every label of the ground-truth record the gold and
    predicted values are split on commas, trimmed, lower-cased and compared as sets.

    NOTE(review): ground-truth values are assumed to be comma-separated strings or null —
    confirm against the dataset files.

    Args:
        real_dir: path of the ground-truth JSON file.
        pred_dir: path of the prediction JSON file.

    Returns:
        The tuple ``(tp, fn, fp)``.

    Raises:
        AssertionError: if a record pair has a different number of labels, or a ground-truth
            label is missing from the prediction.
        IndexError: if there are fewer prediction records than ground-truth records.
    """
    with open(pred_dir, encoding="utf-8") as f:
        all_test_data = json.load(f)
    with open(real_dir, encoding="utf-8") as f:
        all_ground_data = json.load(f)

    tp, fp, fn = 0, 0, 0
    print(len(all_ground_data))
    print(len(all_test_data))

    for idx, ground_data in enumerate(all_ground_data):
        test_data = all_test_data[idx]
        pred_idxLab = test_data["entity"]
        real_idxLab = ground_data["entity"]

        assert len(pred_idxLab) == len(real_idxLab)
        for key, value in real_idxLab.items():
            assert key in pred_idxLab
            real_set = _entity_set(value)
            pred_set = _entity_set(pred_idxLab[key])

            # The same set arithmetic covers every case: with no gold entities everything
            # predicted is a false positive, with no prediction every gold entity is missed.
            tp += len(real_set & pred_set)
            fn += len(real_set - pred_set)
            fp += len(pred_set - real_set)

    return tp, fn, fp


print("evaluating...")
### Evaluate results: compare the saved predictions against the ground-truth file.
tp, fn, fp = metrics(real_dir, pred_dir)
# The 1e-5 term keeps each division defined when its counts are all zero.
precision = tp/(tp+fp+1e-5) 
recall = tp/(tp+fn+1e-5)
f1 = 2*precision*recall/(precision+recall+1e-5)
print('| [ENTITY] precision: {0:3.4f}, recall: {1:3.4f}, f1: {2:3.4f}'.format(precision, recall, f1) + '\r')  

print('finished! 🎉')