In [26]:
from typing import *
import random

In [27]:
from tqdm import tqdm
import numpy as np
import pandas as pd

import torch
from torch.utils.data import Dataset, DataLoader

In [28]:
from transformers import AutoTokenizer
from torch.utils.data import Dataset, DataLoader

In [29]:
from model.par_with_attr import ParWithAttr
from utils import uas_las, to_cuda

# Config

In [30]:
class CFG:
    """Static configuration namespace shared by the data, tokenizer and model cells.

    The class is never instantiated in this notebook; attributes are read
    directly off the class (e.g. ``CFG.domains``, ``CFG.plm``).  A ``tokenizer``
    attribute is attached at runtime by a later cell.
    """
    # Domain codes; each one is substituted into the suda/train and suda/dev file names.
    # domains = ['BC', 'FIN', 'LEG', 'PB', 'PC', 'ZX']
    domains = ['BC']
    # Checkpoint name handed to AutoTokenizer.from_pretrained.
    # plm = 'hfl/chinese-electra-180g-large-discriminator'
    plm = 'hfl/chinese-electra-180g-base-discriminator'
    random_seed = 42  # default argument of seed_everything()
    # NOTE(review): the training hyper-parameters below (num_epochs ... grad_clip,
    # scheduler, warmup_ratio, num_early_stop) are not read by any cell visible in
    # this notebook; presumably consumed by ParWithAttr or a training script -- confirm.
    num_epochs = 15
    # batch_size = 64  # using 22g memory when using electra-large
    batch_size = 32
    lr = 2e-5
    weight_decay = 0.01
    dropout = 0.2
    grad_clip = 2
    scheduler = 'linear'
    warmup_ratio = 0.1
    num_early_stop = 3
    max_length = 128  # fixed padded length of every example built by CODTDataset
    hidden_size = 400  # NOTE(review): not used in the visible cells; likely a model dimension -- confirm
    num_labels = 21  # matches the 21 entries of rel_dct / rel2id
    # NOTE(review): meaning of gamma / alpha cannot be determined from this notebook -- confirm in ParWithAttr.
    gamma = 0.7
    alpha = 0.7
    # NOTE(review): print_every / eval_every / cuda / fp16 are not read by the visible cells.
    print_every = 400
    eval_every = 800
    cuda = True
    fp16 = True

# Count

In [31]:
domains = CFG.domains

# Collect the POS-tag column (index 3) and the relation column (index 7) of
# every token line of the training file, e.g. to inspect the label inventory.
tags, rels = [], []
for domain in domains:
    file = f'suda/train/{domain}-Train-full.conll'
    # `with` guarantees the handle is closed even when a malformed line raises;
    # iterating the handle streams the file instead of loading it all at once.
    with open(file, encoding='utf-8') as fp:
        for line in fp:
            # Blank (or whitespace-only) lines separate sentences.
            if not line.strip():
                continue
            line_sp = line.split('\t')
            rels.append(line_sp[7])
            tags.append(line_sp[3])

    # NOTE(review): this break stops after the first domain, so with several
    # entries in CFG.domains only the first file is counted -- confirm intent.
    break

In [32]:
# Dependency relation label -> Chinese display name (21 entries).
# Insertion order is significant: rel2id assigns integer ids by enumerating
# this dict, so reordering entries changes every relation id.
rel_dct = {
    'root': '根节点',
    'sasubj-obj': '同主同宾',
    'sasubj': '同主语',
    'dfsubj': '不同主语',
    'subj': '主语',
    'subj-in': '内部主语',
    'obj': '宾语',
    'pred': '谓语',
    'att': '定语',
    'adv': '状语',
    'cmp': '补语',
    'coo': '并列',
    'pobj': '介宾',
    'iobj': '间宾',
    'de': '的',
    'adjct': '附加',
    'app': '称呼',
    'exp': '解释',
    'punc': '标点',
    'frag': '片段',
    'repet': '重复',
}

In [33]:
# Relation label -> integer id, numbered in the insertion order of rel_dct.
rel2id = {label: index for index, label in enumerate(rel_dct)}
print(rel2id)
print(len(rel2id))

{'root': 0, 'sasubj-obj': 1, 'sasubj': 2, 'dfsubj': 3, 'subj': 4, 'subj-in': 5, 'obj': 6, 'pred': 7, 'att': 8, 'adv': 9, 'cmp': 10, 'coo': 11, 'pobj': 12, 'iobj': 13, 'de': 14, 'adjct': 15, 'app': 16, 'exp': 17, 'punc': 18, 'frag': 19, 'repet': 20}
21


# Seed & Device

In [34]:
def seed_everything(seed=CFG.random_seed):
    """Seed the Python, NumPy and PyTorch (CPU and CUDA) random generators.

    Also sets ``cudnn.deterministic = True`` and ``cudnn.benchmark = False``.

    Args:
        seed: Integer seed; defaults to ``CFG.random_seed``.
    """
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    random.seed(seed)
    # The seed is reduced modulo 2**32 - 1 before it is handed to NumPy.
    np.random.seed(seed % (2**32 - 1))
    for torch_seeder in (torch.manual_seed, torch.cuda.manual_seed):
        torch_seeder(seed)


seed_everything()

In [35]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print(f'Using device: {device}')

Using device: cuda


# Data

In [36]:
class Dependency:
    """One token line of a CoNLL-style dependency file.

    Attributes:
        id: Token index within the sentence (first column).
        word: Surface form of the token.
        tag: Part-of-speech tag.
        head: Index of the head token.
        rel: Dependency relation label to the head.
    """

    def __init__(self, idx, word, tag, head, rel):
        self.id = idx
        self.word = word
        self.tag = tag
        self.head = head
        self.rel = rel

    def __str__(self):
        """Render the token as a 10-column, tab-separated CoNLL line.

        Columns that are not stored on the object are written as ``_``.
        """
        # The index is stored as ``self.id``; the previous ``self.idx`` lookup
        # raised AttributeError on every str() call.
        values = [str(self.id), self.word, "_", self.tag, "_", "_", str(self.head), self.rel, "_", "_"]
        return '\t'.join(values)

    def __repr__(self):
        return f"({self.word}, {self.tag}, {self.head}, {self.rel})"

In [37]:
def load_codt(data_file: str):
    """Yield the sentences of a CoNLL-style file as lists of ``Dependency``.

    Each token line must split (on whitespace) into exactly 10 fields; fields
    0, 1, 3, 6 and 7 become id, word, tag, head and rel.  Lines with any other
    non-zero number of fields are ignored.  No artificial root token is added.

    Args:
        data_file: Path of the UTF-8 encoded file to read.

    Yields:
        One non-empty ``List[Dependency]`` per sentence, in file order.
    """
    sentence: List[Dependency] = []

    with open(data_file, 'r', encoding='utf-8') as f:
        # data example: 1	上海	_	NR	NR	_	2	nn	_	_
        for line in f:
            toks = line.split()
            if not toks:
                # Blank line closes the current sentence.  Consecutive blank
                # lines must not produce empty sentences.
                if sentence:
                    yield sentence
                    sentence = []
            elif len(toks) == 10:
                sentence.append(Dependency(toks[0], toks[1], toks[3], toks[6], toks[7]))

    # Flush the last sentence when the file does not end with a blank line.
    if sentence:
        yield sentence

In [38]:
# Load the tokenizer that belongs to the configured checkpoint and expose it
# both as a notebook global and as CFG.tokenizer.
CFG.tokenizer = tokenizer = AutoTokenizer.from_pretrained(CFG.plm)

In [39]:
# Table of relation descriptions; the cells below read its `rel_name` and
# `attribute` columns.  The recorded preview also shows an `Unnamed: 0` column.
attr_df = pd.read_csv('attribute.csv', sep=',')
attr_df.head()

Unnamed: 0,rel,rel_name,attribute,example
0,root,根节点,一个句子的根节点，是句子的核心事件的中心词（一般是谓词），如果没有核心，则第一个主要谓语作为...,“你好，我想咨询一下”中，“咨询”是这个句子的核心，而“你 好”只是一句问候语，因此从“咨询...
1,sasubj-obj,同主语同宾语,两个谓语（一般为两个动词并列）共享主语和宾语，允许主语完全省略。,“有什么问题我能帮你处理或解决呢？”中，“处理”和“解决”的主语都是“我”，宾语是“问题”，...
2,sasubj,同主语,两个谓语共享主语，但是都是不及物动词或具有不同的宾语，允许主语完全省略。,“我帮你问一下”，“帮”和“问”的主语都是我，但显然宾语不同，因此由第一个谓词“帮”发射“同...
3,dfsubj,不同主语,两个谓语的主语不同。,“你把订单号发一下，我查询一下”，“发”和“查询”就是由不同主语所发出的动作，因此，由“发”...
4,subj,主语,主语是谓语动作的发出者，或是谓语动词的承受对象，或谓语是对主语的状态或其他情况的描述。主语通...,“我买了洗衣液”，“我”是“买”这个动作的主体对象，因此由“买”发射一条“主语”弧到“我”。


In [40]:
# One description string per row of attr_df: "<rel_name>：<attribute>".
# NOTE(review): downstream code indexes the tokenized rows with rel2id ids, so
# the CSV row order is assumed to match the order of rel_dct -- confirm.
attrs = [
    rel_name + '：' + attribute
    for rel_name, attribute in zip(attr_df['rel_name'], attr_df['attribute'])
]

# Tokenize all descriptions in one padded batch of PyTorch tensors.
attr_tokenized = tokenizer(
    attrs,
    padding=True,
    return_offsets_mapping=False,
    return_tensors='pt',
)

attr_shape = attr_tokenized['input_ids'].shape
num_rels, attr_len = attr_shape
print(attr_shape)

torch.Size([21, 109])


In [41]:
# special_tokens_dict = {'additional_special_tokens': ['<root>']}
 
# num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)
# print('We have added', num_added_toks, 'tokens')

In [42]:
class CODTDataset(Dataset):
    """Eagerly pre-processed dependency-parsing examples for one data split.

    Every sentence of every configured domain is converted at construction
    time.  Item ``k`` is the tuple
    ``(inputs, offsets, heads, rels, masks, rel_attrs)`` where

    * ``inputs``    -- dict of tokenizer tensors padded to ``cfg.max_length``;
    * ``offsets``   -- ``max_length - 1`` indices of each word's first
      sub-token (relative to the sequence without its first token), 0-padded;
    * ``heads`` / ``rels`` / ``masks`` -- ``int64`` arrays of length
      ``max_length``; word ``i`` (0-based) is stored at position ``i + 1`` and
      position 0 is left at zero;
    * ``rel_attrs`` -- dict of ``(max_length, attr_len)`` tensors holding the
      tokenized description of each word's relation.  Row ``i`` belongs to
      word ``i`` (NOT shifted by one like heads/rels/masks).
      NOTE(review): confirm the model expects this un-shifted layout.

    Relies on the notebook globals ``tokenizer``, ``rel2id``,
    ``attr_tokenized`` and ``attr_len``.
    """

    def __init__(self, cfg, train):
        """
        Args:
            cfg: Config object; ``domains`` and ``max_length`` are read.
            train: ``True`` reads the training files, ``False`` the dev files.
        """
        self.train = train
        self.cfg = cfg
        (self.inputs, self.offsets, self.heads,
         self.rels, self.masks, self.rel_attrs) = self.read_data()

    def read_data(self):
        """Read and encode all sentences of the selected split.

        Returns:
            Six parallel lists: inputs, offsets, heads, rels, masks, rel_attrs.
        """
        max_length = self.cfg.max_length
        # Unknown relation labels fall back to the id of 'adjct'.
        fallback_rel = rel2id['adjct']
        # Output dtype of each relation-description tensor.
        attr_dtypes = {
            'input_ids': torch.long,
            'token_type_ids': torch.int,
            'attention_mask': torch.int,
        }

        inputs, offsets = [], []
        heads, rels, masks = [], [], []
        rel_attrs = []

        for domain in self.cfg.domains:
            if self.train:
                file = f'suda/train/{domain}-Train-full.conll'
            else:
                file = f'suda/dev/{domain}-Dev.conll'

            for deps in tqdm(load_codt(file)):
                # Position 0 is reserved, so at most max_length - 1 words fit.
                kept = deps[:max_length - 1]
                word_lst = [dep.word for dep in kept]
                rel_ids = [rel2id.get(dep.rel, fallback_rel) for dep in kept]

                # Zero doubles as "head = root" and as "masked out"; the mask
                # array tells the two apart.
                head_tokens = np.zeros(max_length, dtype=np.int64)
                rel_tokens = np.zeros(max_length, dtype=np.int64)
                mask_tokens = np.zeros(max_length, dtype=np.int64)
                for pos, (dep, rel_id) in enumerate(zip(kept, rel_ids), start=1):
                    rel_tokens[pos] = rel_id
                    # Unannotated heads ('_' / '-1') and heads that fall
                    # outside the padded window stay masked.
                    if dep.head not in ('_', '-1') and int(dep.head) + 1 < max_length:
                        head_tokens[pos] = int(dep.head)
                        mask_tokens[pos] = 1

                # Gather all description rows with one indexing op instead of
                # one torch.cat per token, then pad up to max_length rows.
                # The pad size is derived from the rows actually kept, so
                # sentences longer than the window no longer yield a
                # mis-shaped (or negative-sized) padding block.
                rel_index = torch.as_tensor(rel_ids, dtype=torch.long)
                pad_rows = max_length - len(kept)
                rel_attr = {
                    key: torch.cat(
                        [
                            attr_tokenized[key][rel_index].to(dtype),
                            torch.zeros((pad_rows, attr_len), dtype=dtype),
                        ],
                        dim=0,
                    )
                    for key, dtype in attr_dtypes.items()
                }
                rel_attrs.append(rel_attr)

                tokenized = tokenizer.encode_plus(word_lst,
                                                  padding='max_length',
                                                  truncation=True,
                                                  max_length=max_length,
                                                  return_offsets_mapping=True,
                                                  return_tensors='pt',
                                                  is_split_into_words=True)
                inputs.append({"input_ids": tokenized['input_ids'][0],
                               "token_type_ids": tokenized['token_type_ids'][0],
                               "attention_mask": tokenized['attention_mask'][0]
                               })

                # A sub-token whose offset starts at 0 with a non-zero end is
                # treated as the first piece of a word; (0, 0) entries are
                # skipped.  Indices are relative to the sequence with its
                # first token removed.
                word_starts = [
                    idx
                    for idx, (start, end) in enumerate(tokenized['offset_mapping'][0][1:].tolist())
                    if start == 0 and end != 0
                ]
                # Right-pad with zeros to the fixed length max_length - 1.
                word_starts.extend([0] * (max_length - 1 - len(word_starts)))
                offsets.append(torch.as_tensor(word_starts))

                heads.append(head_tokens)
                rels.append(rel_tokens)
                masks.append(mask_tokens)

        return inputs, offsets, heads, rels, masks, rel_attrs

    def __getitem__(self, idx):
        return self.inputs[idx], self.offsets[idx], self.heads[idx], self.rels[idx], self.masks[idx], self.rel_attrs[idx]

    def __len__(self):
        return len(self.rels)

In [43]:
# Only the dev split is built here (train=False selects the suda/dev files);
# the training split is left disabled.
# train_dataset = CODTDataset(CFG, train=True)
val_dataset = CODTDataset(CFG, train=False)

974it [00:01, 500.35it/s]


# Model

In [44]:
model = ParWithAttr(CFG, attr_tokenized)

# model.load_state_dict(torch.load('results/electra_bc_epoch15.pt'))
# map_location='cpu' lets a checkpoint saved from a GPU be read on any machine;
# the weights are moved to the selected device afterwards.
state_dict = torch.load('results/base_par.pt', map_location='cpu')
# load_state_dict reports missing / unexpected keys; print it as a sanity check.
print(model.load_state_dict(state_dict))
# Use the device chosen in the "Seed & Device" section instead of forcing CUDA.
model = model.to(device)

Some weights of the model checkpoint at hfl/chinese-electra-180g-base-discriminator were not used when initializing ElectraModel: ['discriminator_predictions.dense.weight', 'discriminator_predictions.dense_prediction.weight', 'discriminator_predictions.dense_prediction.bias', 'discriminator_predictions.dense.bias']
- This IS expected if you are initializing ElectraModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ElectraModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at hfl/chinese-electra-180g-base-discriminator were not used when initializing ElectraModel: ['discriminator_predictions.dense.weight', 'discriminator_predictions.dense_prediction.weight'

<All keys matched successfully>


In [48]:
va_dataloader = DataLoader(val_dataset, batch_size=128)

# Switching to eval mode is loop-invariant, so do it once up front.
model.eval()

# Per-batch results are collected in lists and concatenated once at the end;
# concatenating inside the loop re-copies everything gathered so far.
arc_chunks, rel_chunks = [], []
head_chunks, rel_gt_chunks, mask_chunks = [], [], []
for batch in va_dataloader:
    inputs, offsets, heads, rels, masks, rel_attrs = batch

    inputs = {key: value.cuda() for key, value in inputs.items()}
    offsets, heads, rels, masks = to_cuda(data=(offsets, heads, rels, masks))

    with torch.no_grad():
        arc_logit, rel_logit, similarity = model.predict(inputs, offsets, masks)

    arc_chunks.append(arc_logit.cpu())
    # NOTE(review): the relation scores kept for evaluation are `similarity`,
    # not `rel_logit` -- confirm this is the intended output to score.
    rel_chunks.append(similarity.cpu())

    head_chunks.append(heads.cpu())
    rel_gt_chunks.append(rels.cpu())
    mask_chunks.append(masks.cpu())

# The leading empty float tensor is the same seed value the incremental
# concatenation started from, so the resulting dtypes are unchanged.
arc_logits = torch.cat([torch.Tensor(), *arc_chunks])
rel_logits = torch.cat([torch.Tensor(), *rel_chunks])
heads_whole = torch.cat([torch.Tensor(), *head_chunks])
rels_whole = torch.cat([torch.Tensor(), *rel_gt_chunks])
masks_whole = torch.cat([torch.Tensor(), *mask_chunks])

In [46]:
# Sanity check on the accumulated relation scores; the recorded run shows
# (974, 128, 21), i.e. 974 dev sentences x max_length 128 x 21 relations.
rel_logits.shape

torch.Size([974, 128, 21])

In [47]:
# Score the accumulated dev predictions against the gold heads / relations;
# the returned dict holds 'UAS' and 'LAS'.
uas_las(arc_logits=arc_logits,
        rel_logits=rel_logits,
        arc_gt=heads_whole,  
        rel_gt=rels_whole,
        mask=masks_whole)

{'UAS': 0.8881118881118881, 'LAS': 0.8590803136257682}

In [49]:
# NOTE(review): this call is textually identical to the one in the previous
# cell, yet the recorded LAS differs while UAS is the same.  Presumably
# `rel_logits` was rebuilt from a different model output between the two
# executions (the cell numbers are out of order) -- confirm which one to keep.
uas_las(arc_logits=arc_logits,
        rel_logits=rel_logits,
        arc_gt=heads_whole,  
        rel_gt=rels_whole,
        mask=masks_whole)

{'UAS': 0.8881118881118881, 'LAS': 0.8370417461326553}