In [1]:
from transformers import (
    AutoConfig,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    HfArgumentParser,
    MBartTokenizer,
    default_data_collator,
    AutoModelWithLMHead,
    set_seed
)

MODEL_NAME = "t5-base"

# Task markers: template/relation delimiters followed by entity-type tags.
SPECIAL_TOKENS = ["<Temp_S>", "<Temp_E>", "<Relation_S>", "<Relation_E>",
                  "<ORG>", "<VEH>", "<WEA>", "<LOC>", "<FAC>", "<PER>", "<GPE>"]

# AutoModelWithLMHead is deprecated; AutoModelForSeq2SeqLM is the encoder-decoder
# auto class and loads the same conditional-generation model for this checkpoint.
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Register the markers unless the tokenizer already encodes "<Temp_S> <Temp_E>"
# as [32000, 32001] followed by id 1.
# NOTE(review): the stock t5-base tokenizer does not pass this check (in the
# recorded run the markers were assigned ids from 32100 upward), so this branch
# runs for t5-base as well, not only for "non-t5" tokenizers.
if tokenizer.encode("<Temp_S> <Temp_E>") != [32000, 32001, 1]:
    tokenizer.add_special_tokens({"additional_special_tokens": SPECIAL_TOKENS})

# Every token id must index a row of the input embedding matrix. Some checkpoints
# (presumably including t5-base) already reserve spare rows, so only grow the
# matrix when the tokenizer has outgrown it rather than shrinking it.
if len(tokenizer) > model.get_input_embeddings().num_embeddings:
    model.resize_token_embeddings(len(tokenizer))




# Test the relation-type label tree

In [2]:
def get_label_name_tree(label_name_list, tokenizer, end_symbol='<end>'):
    """Build a prefix tree (trie) over the token-id sequences of label names.

    Each label is encoded with ``tokenizer.encode(label, add_special_tokens=False)``
    and its ids are inserted as a path of nested dicts. The node at the end of a
    label's path receives the key ``end_symbol`` mapped to ``None``. A label that
    is a prefix of another label is therefore still marked as complete.

    Args:
        label_name_list: iterable of label strings.
        tokenizer: object exposing ``encode(text, add_special_tokens=...)``.
        end_symbol: key that marks the end of a complete label.

    Returns:
        Nested dict whose keys at each level are token ids (plus ``end_symbol``).
    """
    encoded_labels = {
        name: tokenizer.encode(name, add_special_tokens=False)
        for name in label_name_list
    }

    root = {}
    for token_ids in encoded_labels.values():
        node = root
        for token_id in token_ids:
            node = node.setdefault(token_id, {})
        node[end_symbol] = None
    return root

In [3]:
label_name_list = ["PER-SOC", "ORG-AFF", "GEN-AFF", "ART", "PART-WHOLE", "PHYS"]
relation_label_tree = get_label_name_tree(label_name_list, tokenizer)

# Sanity check: following each label's token ids from the root must reach a
# node carrying the default end marker.
for label in label_name_list:
    node = relation_label_tree
    for token_id in tokenizer.encode(label, add_special_tokens=False):
        node = node[token_id]
    assert '<end>' in node, f"label {label!r} is not terminated in the tree"

relation_label_tree

In [1]:
from extraction.predict_parser.predict_parser import Metric

In [2]:
def eval_pred_with_decoding(gold_list, pred_list, text_list=None, raw_list=None):
    """Score predicted relations against gold relations using ``Metric``.

    Args:
        gold_list: gold items, forwarded unchanged to ``Metric.count_instance``.
        pred_list: predicted items, forwarded unchanged to ``Metric.count_instance``.
        text_list: unused; kept for signature compatibility.
        raw_list: unused; kept for signature compatibility.

    Returns:
        A new dict copied from ``Metric.compute_f1(prefix='relation-')``; its keys
        carry the ``'relation-'`` prefix (e.g. ``relation-P``, ``relation-F1``).
    """
    metric = Metric()
    metric.count_instance(gold_list, pred_list, verbose=False)
    return dict(metric.compute_f1(prefix='relation-'))


In [4]:
gold_list = [1, 1, 1, 2, 2, 2, 3, 3]
pred_list = [1, 2, 1, 1, 2, 2, 3, 3]
result = eval_pred_with_decoding(gold_list, pred_list)

# NOTE: positions 1 and 3 differ, yet the recorded run scored P = R = F1 = 100.
# Both lists contain the same multiset (three 1s, three 2s, two 3s), so Metric
# presumably matches predictions to gold order-insensitively. Confirm this in
# extraction.predict_parser before relying on positional alignment.
assert result['relation-F1'] == 100.0, result

In [5]:
# Bare last expression: Jupyter renders the dict as this cell's rich output.
result

{'relation-tp': 8.0, 'relation-gold': 8.0, 'relation-pred': 8.0, 'relation-P': 100.0, 'relation-R': 100.0, 'relation-F1': 100.0}


In [14]:
# Scratch check of raw ids -> text. In the recorded run this decoded to
# "<pad> director Internațional<extra_id_1>": id 0 is the pad token and id 32098
# is a T5 sentinel (<extra_id_N>) token.
a = tokenizer.decode([0, 2090, 31999, 32098])

# Unrelated scratch dict, printed as-is.
b = dict(new=1, start=2)
print(b)

{'new': 1, 'start': 2}


In [15]:
# Print (rather than display) so the decoded text appears without str quotes.
print(f"{a}")

<pad> director Internațional<extra_id_1>


In [2]:
# Re-encode all task markers as one string to confirm that each one is
# registered as a single atomic token.
list_ = ["<Temp_S>", "<Temp_E>", "<Relation_S>", "<Relation_E>",
         "<ORG>", "<VEH>", "<WEA>", "<LOC>", "<FAC>", "<PER>", "<GPE>"]
list_code = tokenizer.encode(" ".join(list_))

# Expect exactly one id per marker followed by the tokenizer's EOS id (the
# trailing 1 in the recorded output). NOTE(review): tokenizers that also prepend
# a BOS id would need this check adjusted.
assert list_code == tokenizer.convert_tokens_to_ids(list_) + [tokenizer.eos_token_id], list_code

In [3]:
# Bare last expression: Jupyter renders the id list as this cell's output.
list_code

[32100, 32101, 32102, 32103, 32104, 32105, 32106, 32107, 32108, 32109, 32110, 1]
