In [1]:
import sys
sys.path.append('../')
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '2'

import warnings
warnings.filterwarnings('ignore')
from data.pipe import BartNERPipe,Bart_RE_NER_Pipe
from model.bart import BartSeq2SeqModel
import fitlog
import random
import numpy as np

import torch
from fastNLP import Trainer
from model.metrics import Seq2SeqSpanMetric,Seq2SeqREMetric
from model.losses import Seq2SeqLoss
from torch import optim
from fastNLP import BucketSampler, GradientClipCallback, cache_results

from model.callbacks import WarmupCallback
from fastNLP.core.sampler import SortedSampler
from model.generater import SequenceGeneratorModel
from fastNLP.core.sampler import  ConstTokenNumSampler
from model.callbacks import FitlogCallback

fitlog.debug()
fitlog.set_log_dir('logs')

import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--dataset_name', default='re_ner_ace05', type=str)

def set_seed(seed=1996):
    print("[SET SEED]: ",seed)
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

args= parser.parse_args([])
dataset_name = args.dataset_name
args.length_penalty = 1
args.save_model = 1

# word: 生成word的start; bpe: 生成所有的bpe; span: 每一段按照start end生成; span_bpe: 每一段都是start的所有bpe，end的所有bpe
args.target_type = 'word'
args.bart_name = '/disk1/wxl/Desktop/DeepKE/example/ner/huggingface/bart-large'
args.schedule = 'linear'
args.decoder_type = 'avg_feature'
args.n_epochs = 30
args.num_beams = 1
args.batch_size = 16
args.use_encoder_mlp = 1
args.lr = 1e-5
args.warmup_ratio = 0.01
eval_start_epoch = 1

# the following hyper-parameters are for target_type=word
if dataset_name == 'conll2003':  # three runs get 93.18/93.18/93.36 F1
    max_len, max_len_a = 10, 0.6
elif dataset_name == 'en-ontonotes':  # three runs get 90.46/90.4/90/52 F1
    max_len, max_len_a = 10, 0.8
elif dataset_name == 'CADEC':
    max_len, max_len_a = 10, 1.6
    args.num_beams = 4
    args.lr = 2e-5
    args.n_epochs = 30
    eval_start_epoch=10
elif dataset_name == 'Share_2013':
    max_len, max_len_a = 10, 0.6
    args.use_encoder_mlp = 0
    args.num_beams = 4
    args.lr = 2e-5
    eval_start_epoch = 5
elif dataset_name == 'Share_2014':
    max_len, max_len_a = 10, 0.6
    args.num_beams = 4
    eval_start_epoch = 5
    args.n_epochs = 30
elif dataset_name == 'genia':  # three runs: 79.29/79.13/78.75
    max_len, max_len_a = 10, 0.5
    args.target_type = 'span'
    args.lr = 2e-5
    args.warmup_ratio = 0.01
elif dataset_name == 'en_ace04':  # four runs: 86.84/86.33/87/87.17
    max_len, max_len_a = 50, 1.1
    args.n_epochs = 55
    args.batch_size = 48
    args.lr = 4e-5
    seed = 4373
elif 're' in dataset_name:
    max_len, max_len_a = 10, 1.6
    args.num_beams = 4
    args.lr = 2e-5
    args.batch_size = 20
    args.n_epochs = 100
    seed = 1688
    eval_start_epoch=1
    rel_type_start = 10
elif dataset_name == 're_ace05':
    max_len, max_len_a = 10, 1.6
    args.num_beams = 4
    args.lr = 2e-5
    args.batch_size = 80
    args.n_epochs = 100
    seed = 1571
    eval_start_epoch=1
    rel_type_start = 9
elif dataset_name == 'en_ace05':  # three runs: 85.39/84.54/84.75
    max_len, max_len_a = 50, 0.7
    args.lr = 3e-5
    args.batch_size = 12
    args.num_beams = 4
    args.warmup_ratio = 0.1

set_seed(seed)
# with open("/disk1/wxl/Desktop/DeepKE/example/baseline/BARTNER/loss_log/D.json","r") as f:
#     b=f.readlines()
# with open("/disk1/wxl/Desktop/DeepKE/example/baseline/BARTNER/loss_log/E.json","r") as f:
#     c=f.readlines()
# for x,y in zip(b,c):
#     assert x==y,print(x,y)
# exit()

save_model = args.save_model
del args.save_model
lr = args.lr
n_epochs = args.n_epochs
batch_size = args.batch_size
num_beams = args.num_beams

length_penalty = args.length_penalty
if isinstance(args.decoder_type, str) and args.decoder_type.lower() == 'none':
    args.decoder_type = None
decoder_type = args.decoder_type
target_type = args.target_type
bart_name = args.bart_name
schedule = args.schedule
use_encoder_mlp = args.use_encoder_mlp

fitlog.add_hyper(args)

#######hyper
#######hyper

demo = False
if demo:
    cache_fn = f"caches/data_{bart_name}_{dataset_name}_{target_type}_demo.pt"
else:
    cache_fn = f"caches/data_{bart_name}_{dataset_name}_{target_type}.pt"

@cache_results(cache_fn, _refresh=False)
def get_data():
    if 're' in dataset_name:
        pipe = Bart_RE_NER_Pipe(tokenizer=bart_name, dataset_name=dataset_name, target_type=target_type, no_ent_type=False)

    if dataset_name == 'conll2003':
        paths = {'test': "./data/conll2003/test.txt",
                 'train': "./data/conll2003/train.txt",
                 'dev': "./data/conll2003/dev.txt"}
        data_bundle = pipe.process_from_file(paths, demo=demo)
    elif dataset_name == 'en-ontonotes':
        paths = './data/en-ontonotes/english'
        data_bundle = pipe.process_from_file(paths)
    else:
        print(f'./data/{dataset_name}')
        data_bundle = pipe.process_from_file(f'./data/{dataset_name}', demo=demo)
    return pipe, data_bundle, pipe.tokenizer, pipe.mapping2id
print(f'max_len_a:{max_len_a}, max_len:{max_len}')

pybuilddir.txt
pybuilddir.txt
/usr/share/zoneinfo/UTC
/usr/lib/ssl/certs/ca-certificates.crt
[SET SEED]:  1688
max_len_a:1.6, max_len:10


In [2]:
pipe, data_bundle, tokenizer, mapping2id = get_data()
ds = data_bundle.get_dataset("test")


Read cache from caches/data_/disk1/wxl/Desktop/DeepKE/example/ner/huggingface/bart-large_re_ner_ace05_word.pt.


In [2]:
for i in data_bundle.get_dataset("train"):
    tgt_seq_len = i["tgt_seq_len"]
    src_seq_len = i["src_seq_len"]
    assert tgt_seq_len < 10 + 1.3*src_seq_len, f"{tgt_seq_len},{src_seq_len},{i}"
for i in data_bundle.get_dataset("dev"):
    tgt_seq_len = i["tgt_seq_len"]
    src_seq_len = i["src_seq_len"]
    assert tgt_seq_len < 10 + 1.3*src_seq_len
for i in data_bundle.get_dataset("test"):
    tgt_seq_len = i["tgt_seq_len"]
    src_seq_len = i["src_seq_len"]
    assert tgt_seq_len < 10 + 1.2*src_seq_len, f"{tgt_seq_len},{src_seq_len},{10 + 1.3*src_seq_len},{i}"


NameError: name 'data_bundle' is not defined

In [10]:
from fastNLP import DataSet,DataSetIter
from fastNLP import SequentialSampler
sampler = SequentialSampler()
ds =  data_bundle.get_dataset("test")
batch = DataSetIter(batch_size=80, dataset=ds, sampler=sampler)
batch2 = DataSetIter(batch_size=2, dataset=ds, sampler=sampler)
# model2.eval()

In [11]:
ds[6]["re_tgt_tokens"]

[3,
 20,
 6,
 19,
 5,
 10,
 21,
 22,
 6,
 28,
 29,
 30,
 31,
 32,
 5,
 10,
 28,
 29,
 30,
 31,
 32,
 5,
 35,
 36,
 5,
 15,
 1]

In [9]:
from model.utils import get_span_from_pred, get_RE_from_pred,get_ent_tgt_tokens
from itertools import chain
max_type_id = len(pipe.mapping2targetid) + 2
label_ids = list(mapping2id.values())
if 're' in dataset_name:
    metric = Seq2SeqREMetric(1, num_labels=len(label_ids), rel_type_start=rel_type_start, target_type=target_type)
else:
    metric = Seq2SeqSpanMetric(1, num_labels=len(label_ids), target_type=target_type)
model2 = torch.load("/disk1/wxl/Desktop/DeepKE/example/baseline/BARTNER/save_models/re_ace05_1688_1734576024.9018843/best_SequenceGeneratorModel_f_2024-12-19-10-40-29-474678").to('cuda')
model2.eval()
max_type_id = len(pipe.mapping2targetid) + 2
pred_num = 0
all_ent_num = 0
false_num = 0
true_num = 0
false_pred_num = 0
false_list = []
true_list = [] 
false_pred_list = []
for id,i in enumerate(batch):
    #print(id)
    src_tokens=i[0]["src_tokens"]
    src_seq_len=i[0]["src_seq_len"]
    first=i[0]["first"]
    tgt_tokens = i[1]["tgt_tokens"].to('cuda')
    pred = model2.predict(src_tokens.to('cuda'), src_seq_len.to('cuda'), first.to('cuda'))['pred']
    pred_spans = [get_RE_from_pred(pred[id], max_type_id, 9) for id in range(pred.shape[0])]
    
    pred_spans = [set([str(y) for y in pred_spans[x]]) for x in range(pred.shape[0])]


    #target_spans1 = [set([str(list(i[1]["target_span"][idxx][idx*3]) + list(i[1]["target_span"][idxx][idx*3+1]) + list(i[1]["target_span"][idxx][idx*3+2]) ) for idx in range(len(i[1]["target_span"][idxx]))]) for idxx in range(pred.shape[0])]
    
    target_spans = []
    for idx_1 in range(pred.shape[0]):
        one_batch_tgt = i[1]["target_span"][idx_1]
        str_list = []
        for idx_2 in (range(len(one_batch_tgt) // 3)):
            rel_tgt = list(one_batch_tgt[idx_2*3]) + list(one_batch_tgt[idx_2*3+1]) +list(one_batch_tgt[idx_2*3+2])
            rel_tgt.sort()
            rel_tgt_str = str(rel_tgt)
            str_list.append(rel_tgt_str)
        target_spans.append(set(str_list))
    false_spans = [target_spans[id] - pred_spans[id] for id in range(pred.shape[0])] # 没预测出来的span
    # print(false_spans)
    true_spans = [target_spans[id] & pred_spans[id] for id in range(pred.shape[0])] # 预测出来的span
    # x = [list(ss) for ss in pred_spans]
    # y = [list(ss) for ss in target_spans]
    # for xx in x:
    #     xx.sort()
    # for yy in y:
    #     yy.sort()
    # print("pred_spans: ",x, "\ntarget_spans",y,"\n")
    # print("pred_spans: ",pred_spans, "\ntarget_spans",target_spans,"\n")
    # if id > 10:
    #     break
    false_pred_spans = [pred_spans[id] - target_spans[id] for id in range(pred.shape[0])]
    # print("false_pred_spans: ",false_pred_spans, "\nfalse_spans",false_spans,"\n")
    all_ent_num += sum(len(i[1]["target_span"][id]) for id in range(pred.shape[0]))
    false_num += sum([len(fs) for fs in false_spans])
    true_num += sum([len(ts) for ts in true_spans])
    false_pred_num += sum(len(fps) for fps in false_pred_spans)
    false_list += false_spans
    true_list += true_spans
    false_pred_list += false_pred_spans
    res = metric.evaluate(i[1]["target_span"], pred, tgt_tokens)
    # print("pred: ",pred, "\ntarget: ",tgt_tokens,"\n")
print(metric.get_metric())
print(true_num, " ", false_num, " ", false_pred_num)
print(all_ent_num)



metric类型: [RE]  rel_type_start: 9!

正确预测个数： 694  错误预测个数： 364  未被预测的正确实体个数： 457
{'f': 62.83, 'rec': 60.3, 'pre': 65.60000000000001, 'em': 0.78}
0   1151   1058
3453


In [10]:
tokenizer.convert_tokens_to_ids(['relation','entity','input'])
#print(tokenizer.convert_ids_to_tokens([2]))
# # print("=================")

# label_ids = list(mapping2id.values())
# metric = Seq2SeqSpanMetric(1, num_labels=len(label_ids), target_type=target_type)
# model2.eval()
# max_type_id = len(pipe.mapping2targetid) + 2
# pred_num = 0
# all_ent_num = 0
# false_num = 0
# true_num = 0
# false_pred_num = 0
# false_list = []
# true_list = [] 
# false_pred_list = []
# for id,i in enumerate(batch2):
#     src_tokens=i[0]["src_tokens"]
#     src_seq_len=i[0]["src_seq_len"]
#     first=i[0]["first"]
#     tgt_tokens = i[1]["tgt_tokens"].to('cuda')
#     pred = model2.predict(src_tokens.to('cuda'), src_seq_len.to('cuda'), first.to('cuda'))['pred']
#     print(pred)
#     pred_spans = [get_span_from_pred(pred[id], max_type_id) for id in range(pred.shape[0])]
#     pred_spans = [set([str(y) for y in pred_spans[x]]) for x in range(pred.shape[0])]

#     target_spans = [set([str(list(span)) for span in i[1]["target_span"][id]]) for id in range(pred.shape[0])]
    
#     false_spans = [target_spans[id] - pred_spans[id] for id in range(pred.shape[0])] # 没预测出来的span
#     #print(false_spans)
#     true_spans = [target_spans[id] & pred_spans[id] for id in range(pred.shape[0])] # 没预测出来的span
#     x = [list(ss) for ss in pred_spans]
#     y = [list(ss) for ss in target_spans]
#     for xx in x:
#         xx.sort()
#     for yy in y:
#         yy.sort()
#     print("pred_spans: ",pred_spans, "\ntarget_spans",target_spans,"\n")
#     if id > 1:
#         break
#     false_pred_spans = [pred_spans[id] - target_spans[id] for id in range(pred.shape[0])]
#     pred_num += sum([len(x[i]) for i in range(pred.shape[0])])
#     all_ent_num += sum(len(i[1]["target_span"][id]) for id in range(pred.shape[0]))
#     false_num += sum([len(fs) for fs in false_spans])
#     true_num += sum([len(ts) for ts in true_spans])
#     false_pred_num += sum(len(fps) for fps in false_pred_spans)
#     false_list += false_spans
#     true_list += true_spans
#     false_pred_list += false_pred_spans
#     res = metric.evaluate(i[1]["target_span"], pred, tgt_tokens)
#     #print("pred: ",pred, "\ntarget: ",tgt_tokens,"\n")
# print(metric.get_metric())
# print(true_num, " ", false_num, " ", false_pred_num)
# print(all_ent_num)

[47114, 46317, 46797]