In [1]:
from typing import *
import random
import json

In [2]:
from tqdm import tqdm
import numpy as np
import pandas as pd

import torch
from torch.utils.data import Dataset, DataLoader

In [3]:
from transformers import AutoTokenizer
from torch.utils.data import Dataset, DataLoader

In [4]:
import sys
sys.path.append('..')

In [5]:
from model.par_with_attr import ParWithAttr
from model.base_par import DepParser
from utils import uas_las, to_cuda

# Config

In [6]:
class CFG:
    # domains = ['BC', 'FIN', 'LEG', 'PB', 'PC', 'ZX']
    domains = ['BC']
    # plm = 'hfl/chinese-electra-180g-large-discriminator'
    plm = 'hfl/chinese-electra-180g-base-discriminator'
    data_file = '../data_testset/1to800_1108.json'
    random_seed = 42
    num_epochs = 15
    batch_size = 128
    lr = 2e-5
    weight_decay = 0.01
    dropout = 0.2
    grad_clip = 2
    scheduler = 'linear'
    warmup_ratio = 0.1
    num_early_stop = 3
    max_length = 160
    hidden_size = 400
    num_labels = 35
    gamma = 0.7
    alpha = 0.7
    print_every = 400
    eval_every = 800
    cuda = True
    fp16 = True

# Count

In [7]:
domains = CFG.domains

tags, rels = [], []
for domain in domains:
    # file = f'suda/train/{domain}-Train-full.conll'
    file = '../aug/codt/codt_train.conll'
    fp = open(file, encoding='utf-8')
    for line in fp.readlines():
        # another sentence
        if line == '\n':
            continue
        line_sp = line.split('\t')
        rels.append(line_sp[7])
        tags.append(line_sp[3])
    
    fp.close()
    break

In [8]:
rel_dct = {
    'root': '根节点',
    'sasubj-obj': '同主同宾',
    'sasubj': '同主语',
    'dfsubj': '不同主语',
    'subj': '主语',
    'subj-in': '内部主语',
    'obj': '宾语',
    'pred': '谓语',
    'att': '定语',
    'adv': '状语',
    'cmp': '补语',
    'coo': '并列',
    'pobj': '介宾',
    'iobj': '间宾',
    'de': '的',
    'adjct': '附加',
    'app': '称呼',
    'exp': '解释',
    'punc': '标点',
    'frag': '片段',
    'repet': '重复',
    # rst
    'attr': '归属',
    'bckg': '背景',
    'cause': '因果',
    'comp': '比较',
    'cond': '状况',
    'cont': '对比',
    'elbr': '阐述',
    'enbm': '目的',
    'eval': '评价',
    'expl': '解释-例证',
    'joint': '联合',
    'manner': '方式',
    'rstm': '重申',
    'temp': '时序',
    # 'tp-chg': '主题变更',
    # 'prob-sol': '问题-解决',
    # 'qst-ans': '疑问-回答',
    # 'stm-rsp': '陈述-回应',
    # 'req-proc': '需求-处理',
}

In [9]:
rel2id = {}
for i, (key, value) in enumerate(rel_dct.items()):
    rel2id[key] = i
print(rel2id)
print(len(rel2id))

{'root': 0, 'sasubj-obj': 1, 'sasubj': 2, 'dfsubj': 3, 'subj': 4, 'subj-in': 5, 'obj': 6, 'pred': 7, 'att': 8, 'adv': 9, 'cmp': 10, 'coo': 11, 'pobj': 12, 'iobj': 13, 'de': 14, 'adjct': 15, 'app': 16, 'exp': 17, 'punc': 18, 'frag': 19, 'repet': 20, 'attr': 21, 'bckg': 22, 'cause': 23, 'comp': 24, 'cond': 25, 'cont': 26, 'elbr': 27, 'enbm': 28, 'eval': 29, 'expl': 30, 'joint': 31, 'manner': 32, 'rstm': 33, 'temp': 34}
35


# Seed & Device

In [10]:
def seed_everything(seed=CFG.random_seed):
    np.random.seed(seed%(2**32-1))
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic =True
    torch.backends.cudnn.benchmark = False

seed_everything()

In [11]:
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

print(f'Using device: {device}')

Using device: cuda


# Data

In [12]:
class Dependency():
    def __init__(self, idx, word, head, rel):
        self.id = idx
        self.word = word
        self.head = head
        self.rel = rel

    def __str__(self):
        # example:  1	上海	_	NR	NR	_	2	nn	_	_
        values = [str(self.idx), self.word, "_", "_", "_", "_", str(self.head), self.rel, "_", "_"]
        return '\t'.join(values)

    def __repr__(self):
        return f"({self.word}, {self.head}, {self.rel})"

In [13]:
tokenizer = AutoTokenizer.from_pretrained(CFG.plm)
CFG.tokenizer = tokenizer

In [14]:
attr_df = pd.read_csv('../attribute.csv', sep=',')
extra = pd.read_csv('../extra_attribute.csv', sep=',')

attr_df = pd.concat([attr_df, extra], axis=0).reset_index()
# attr_df = pd.concat([extra], axis=0).reset_index()
attr_df.head()

Unnamed: 0,index,rel,rel_name,attribute,example
0,0,root,根节点,一个句子的根节点，是句子的核心事件的中心词（一般是谓词），如果没有核心，则第一个主要谓语作为...,“你好，我想咨询一下”中，“咨询”是这个句子的核心，而“你 好”只是一句问候语，因此从“咨询...
1,1,sasubj-obj,同主语同宾语,两个谓语（一般为两个动词并列）共享主语和宾语，允许主语完全省略。,“有什么问题我能帮你处理或解决呢？”中，“处理”和“解决”的主语都是“我”，宾语是“问题”，...
2,2,sasubj,同主语,两个谓语共享主语，但是都是不及物动词或具有不同的宾语，允许主语完全省略。,“我帮你问一下”，“帮”和“问”的主语都是我，但显然宾语不同，因此由第一个谓词“帮”发射“同...
3,3,dfsubj,不同主语,两个谓语的主语不同。,“你把订单号发一下，我查询一下”，“发”和“查询”就是由不同主语所发出的动作，因此，由“发”...
4,4,subj,主语,主语是谓语动作的发出者，或是谓语动词的承受对象，或谓语是对主语的状态或其他情况的描述。主语通...,“我买了洗衣液”，“我”是“买”这个动作的主体对象，因此由“买”发射一条“主语”弧到“我”。


In [15]:
attrs = [
    attr_df['rel_name'][i] + 
    '：' +
    attr_df['attribute'][i] +
    ' 例子：' + 
    attr_df['example'][i]
    for i in range(len(attr_df))
]


# attrs[27:28] = [
#     attrs[i] +
#     attr_df['attribute'][2] +
#     attr_df['attribute'][3]
#     for i in range(27, 28)
# ]

# print(attrs)

attr_tokenized = tokenizer(attrs, return_offsets_mapping=False, padding=True, return_tensors='pt')

num_rels, attr_len = attr_tokenized['input_ids'].shape
print(attr_tokenized['input_ids'].shape)

torch.Size([40, 168])


In [16]:
def load_annoted(data_file):
    with open(CFG.data_file, 'r', encoding='utf-8') as f:
        data = json.load(f)
        
    sample_lst:List[List[Dependency]] = []
    
    for d in data:
        rel_dct = {}
        for tripple in d['relationship']:
            head, rel, tail = tripple
            head_uttr_idx, head_word_idx = [int(x) for x in head.split('-')]
            tail_uttr_idx, tail_word_idx = [int(x) for x in tail.split('-')]
            if head_uttr_idx != tail_uttr_idx:
                continue
            
            if not rel_dct.get(head_uttr_idx, None):
                rel_dct[head_uttr_idx] = {tail_word_idx: [head_word_idx, rel]}
            else:
                rel_dct[head_uttr_idx][tail_word_idx] = [head_word_idx, rel]
            
        for item in d['dialog']:
            turn = item['turn']
            utterance = item['utterance']
            # dep_lst:List[Dependency] = [Dependency(0, '[root]', -1, '_')]
            dep_lst:List[Dependency] = []
            
            for word_idx, word in enumerate(utterance.split(' ')):
                head_word_idx, rel = rel_dct[turn].get(word_idx + 1, [word_idx, 'adjct'])  # some word annoted missed, padded with last word and 'adjct'
                dep_lst.append(Dependency(word_idx + 1, word, head_word_idx, rel))  # start from 1
            
            sample_lst.append(dep_lst)
        
    return sample_lst

In [17]:
class DialogDataset(Dataset):
    def __init__(self, cfg):
        self.cfg = cfg
        self.inputs, self.offsets, self.heads, self.rels, self.masks = self.read_data()
        
    def read_data(self):
        inputs, offsets = [], []
        tags, heads, rels, masks = [], [], [], []
        
        for deps in load_annoted(self.cfg.data_file):
            seq_len = len(deps)

            word_lst = [] 
            head_tokens = np.zeros(self.cfg.max_length, dtype=np.int64)  # same as root index is 0, constrainting by mask 
            rel_tokens = np.zeros(self.cfg.max_length, dtype=np.int64)
            mask_tokens = np.zeros(self.cfg.max_length, dtype=np.int64)
            for i, dep in enumerate(deps):
                if i == seq_len or i + 1== self.cfg.max_length:
                    break

                word_lst.append(dep.word)

                if dep.head == -1 or dep.head + 1 >= self.cfg.max_length:
                # if dep.head == -1 or dep.head + 1 >= self.cfg.max_length:
                    head_tokens[i+1] = 0
                    mask_tokens[i+1] = 0
                else:
                    head_tokens[i+1] = int(dep.head)
                    mask_tokens[i+1] = 1

                rel_tokens[i+1] = rel2id.get(dep.rel, rel2id['adjct'])

            tokenized = tokenizer.encode_plus(word_lst, 
                                              padding='max_length', 
                                              truncation=True,
                                              max_length=self.cfg.max_length, 
                                              return_offsets_mapping=True, 
                                              return_tensors='pt',
                                              is_split_into_words=True)
            inputs.append({"input_ids": tokenized['input_ids'][0],
                           "token_type_ids": tokenized['token_type_ids'][0],
                           "attention_mask": tokenized['attention_mask'][0]
                          })

            sentence_word_idx = []
            for idx, (start, end) in enumerate(tokenized.offset_mapping[0][1:]):
                if start == 0 and end != 0:
                    sentence_word_idx.append(idx)

            if len(sentence_word_idx) < self.cfg.max_length - 1:
                sentence_word_idx.extend([0]* (self.cfg.max_length - 1 - len(sentence_word_idx)))
            offsets.append(torch.as_tensor(sentence_word_idx))

            heads.append(head_tokens)
            rels.append(rel_tokens)
            masks.append(mask_tokens)
                    
        return inputs, offsets, heads, rels, masks

    def __getitem__(self, idx):
        return self.inputs[idx], self.offsets[idx], self.heads[idx], self.rels[idx], self.masks[idx]
    
    def __len__(self):
        return len(self.rels)

In [18]:
val_dataset = DialogDataset(CFG)

# Model

In [19]:
# model = ParWithAttr(CFG, attr_tokenized)
model = DepParser(CFG)
# model.load_state_dict(torch.load('results/electra_bc_epoch15.pt'))
print(model.load_state_dict(torch.load('../results/base_par_codt_3642.pt')))
model = model.cuda()

Some weights of the model checkpoint at hfl/chinese-electra-180g-base-discriminator were not used when initializing ElectraModel: ['discriminator_predictions.dense.weight', 'discriminator_predictions.dense.bias', 'discriminator_predictions.dense_prediction.weight', 'discriminator_predictions.dense_prediction.bias']
- This IS expected if you are initializing ElectraModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ElectraModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


<All keys matched successfully>


In [20]:
va_dataloader = DataLoader(val_dataset, batch_size=CFG.batch_size)

arc_logits, rel_logits, similarities = torch.Tensor(), torch.Tensor(), torch.Tensor()
heads_whole, rels_whole, masks_whole = torch.Tensor(), torch.Tensor(), torch.Tensor()
for batch in va_dataloader:
    inputs, offsets, heads, rels, masks = batch

    inputs_cuda = {}
    for key, value in inputs.items():
        inputs_cuda[key] = value.cuda()
    inputs = inputs_cuda

    offsets, heads, rels, masks = to_cuda(data=(offsets, heads, rels, masks))
    
    with torch.no_grad():
        model.eval()
        # arc_logit, rel_logit, similarity = model.predict(inputs, offsets, masks)
        arc_logit, rel_logit = model(inputs, offsets, heads, rels, masks, evaluate=True)
        
    arc_logit[:, torch.arange(arc_logit.size()[1]), torch.arange(arc_logit.size()[2])] = -1e4
    
    arc_logits = torch.cat([arc_logits, arc_logit.cpu()])
    rel_logits = torch.cat([rel_logits, rel_logit.cpu()])
    # similarities = torch.cat([similarities, similarity.cpu()])
    
    heads_whole = torch.cat([heads_whole, heads.cpu()])
    rels_whole = torch.cat([rels_whole, rels.cpu()])
    masks_whole = torch.cat([masks_whole, masks.cpu()])

In [46]:
CFG.num_labels = 21
model = DepParser(CFG)
# model.load_state_dict(torch.load('results/electra_bc_epoch15.pt'))
print(model.load_state_dict(torch.load('../results/base_par.pt')))
model = model.cuda()

Some weights of the model checkpoint at hfl/chinese-electra-180g-base-discriminator were not used when initializing ElectraModel: ['discriminator_predictions.dense.weight', 'discriminator_predictions.dense.bias', 'discriminator_predictions.dense_prediction.weight', 'discriminator_predictions.dense_prediction.bias']
- This IS expected if you are initializing ElectraModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ElectraModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


<All keys matched successfully>


In [47]:
va_dataloader = DataLoader(val_dataset, batch_size=CFG.batch_size)

arc_logits_another, rel_logits_another = torch.Tensor(), torch.Tensor()
for batch in va_dataloader:
    inputs, offsets, heads, rels, masks = batch

    inputs_cuda = {}
    for key, value in inputs.items():
        inputs_cuda[key] = value.cuda()
    inputs = inputs_cuda

    offsets, heads, rels, masks = to_cuda(data=(offsets, heads, rels, masks))
    
    with torch.no_grad():
        model.eval()
        # arc_logit, rel_logit, similarity = model.predict(inputs, offsets, masks)
        arc_logit, rel_logit = model(inputs, offsets, heads, rels, masks, evaluate=True)
        
    arc_logit[:, torch.arange(arc_logit.size()[1]), torch.arange(arc_logit.size()[2])] = -1e4
    
    arc_logits_another = torch.cat([arc_logits_another, arc_logit.cpu()])
    rel_logits_another = torch.cat([rel_logits_another, rel_logit.cpu()])
    # similarities = torch.cat([similarities, similarity.cpu()])
    
    # heads_whole = torch.cat([heads_whole, heads.cpu()])
    # rels_whole = torch.cat([rels_whole, rels.cpu()])
    # masks_whole = torch.cat([masks_whole, masks.cpu()])

: 

In [23]:
model = None

In [24]:
rel_logits_another.shape

torch.Size([20086, 160, 21])

In [25]:
# arc_logits /=2
# arc_logits += arc_logits_another/2

# rel_logits /=2
# rel_logits[:, :, :21]  += rel_logits_another/2

In [26]:
uas_las(arc_logits=arc_logits_another,
        rel_logits=rel_logits_another,
        arc_gt=heads_whole,
        rel_gt=rels_whole,
        mask=masks_whole)

{'UAS': 0.8501551560687515, 'LAS': 0.788198717533519}

In [27]:
arc_logits_tmp = arc_logits.clone()
rel_logits_tmp = rel_logits.clone()

In [45]:
arc_logits = arc_logits_tmp.clone()
rel_logits = rel_logits_tmp.clone()

In [29]:
# rel_logits= rel_logits.softmax(-1)
# rel_logits_another = torch.cat([rel_logits_another, torch.full((rel_logits_another.size(0), CFG.max_length, 14), fill_value=-1e3)], dim=-1).softmax(-1)

In [30]:
# rel_logits= rel_logits
# rel_logits_another = torch.cat([rel_logits_another, torch.full((rel_logits_another.size(0), CFG.max_length, 14), fill_value=-1e3)], dim=-1)

In [31]:
arc_logits = arc_logits.softmax(-1)
arc_logits_another = arc_logits_another.softmax(-1)

rel_logits = rel_logits.softmax(-1)
rel_logits_another = rel_logits_another.softmax(-1)

In [32]:
# arc_logits *= 0.8
# arc_logits += 0.2 * arc_logits_another
# rel_logits[:, :, :21] *= 0.8
# rel_logits[:, :, :21] += 0.2 * rel_logits_another[:, :, :21]

In [33]:
# arc_logits *= 0.4
# arc_logits += torch.where((rel_logits_another.argmax(-1) < 21).unsqueeze(-1).expand(-1, -1, arc_logits.size(-1)), 0.6 * arc_logits_another, 0.6 * arc_logits)

# rel_logits[:, :, :21] *= 0.4
# rel_logits[:, :, :21] += torch.where((rel_logits_another.argmax(-1) < 21).unsqueeze(-1).expand(-1, -1, rel_logits_another.size(-1)), 0.6 * rel_logits_another, 0.6 * rel_logits[:, :, :21])

In [34]:
# arc_logits *= 0.7
# arc_logits += torch.where(arc_logits > arc_logits_another, 0.3 * arc_logits, 0.3 * arc_logits_another)

# rel_logits[:, :, :21] *= 0.7
# rel_logits[:, :, :21] += torch.where(rel_logits[:, :, :21] > rel_logits_another, 0.3 * rel_logits[:, :, :21], 0.3 * rel_logits_another)


arc_logits = torch.where(arc_logits > arc_logits_another, arc_logits, arc_logits_another)
rel_logits[:, :, :21] = torch.where(rel_logits[:, :, :21] > rel_logits_another, rel_logits[:, :, :21], rel_logits_another)

In [35]:
uas_las(arc_logits=arc_logits,
        rel_logits=rel_logits,
        arc_gt=heads_whole,
        rel_gt=rels_whole,
        mask=masks_whole)

{'UAS': 0.8505791119302357, 'LAS': 0.7979320375200938}

In [36]:
# uas_las(arc_logits=arc_logits,
#         rel_logits=similarities,
#         arc_gt=heads_whole,
#         rel_gt=rels_whole,
#         mask=masks_whole)

In [37]:
# rel_preds = similarities.argmax(-1)
# # rel_preds.masked_fill_(rel_preds == 2, 27)
# # rel_preds.masked_fill_(rel_preds == 3, 27)

# arc_logits_correct = (arc_logits.argmax(-1) == heads_whole).long() * masks_whole * (rels_whole >= 21).long()
# rel_logits_correct = (rel_preds == rels_whole).long() * arc_logits_correct

# print(rel_logits_correct.sum() / (rels_whole >= 21).long().sum())
# print(arc_logits_correct.sum() / (rels_whole >= 21).long().sum())

In [38]:
rel_preds = rel_logits.argmax(-1)
head_pred = arc_logits.argmax(-1)

arc_logits_correct = (head_pred == heads_whole).long() * masks_whole * (rels_whole >= 21).long()
rel_logits_correct = (rel_preds == rels_whole).long() * arc_logits_correct

print(rel_logits_correct.sum() / (rels_whole >= 21).long().sum())
print(arc_logits_correct.sum() / (rels_whole >= 21).long().sum())
print('---------------------------------------------------')
arc_logits_correct = (head_pred == heads_whole).long() * masks_whole * (rels_whole < 21).long()
rel_logits_correct = (rel_preds == rels_whole).long() * arc_logits_correct 

print(rel_logits_correct.sum() / (masks_whole * (rels_whole < 21).long()).sum())
print(arc_logits_correct.sum() / (masks_whole * (rels_whole < 21).long()).sum())

print('===================================================')

rel_preds.masked_fill_(rel_preds == 2, 27)
rel_preds.masked_fill_(rel_preds == 3, 27)
# rel_preds.masked_fill_(rel_preds == 6, 27)
# rel_preds.masked_fill_(rel_preds == 12, 27)
# rel_preds.masked_fill_(rel_preds != 27, 27)

arc_logits_correct = (head_pred == heads_whole).long() * masks_whole * (rels_whole >= 21).long()
rel_logits_correct = (rel_preds == rels_whole).long() * arc_logits_correct

print(rel_logits_correct.sum() / (rels_whole >= 21).long().sum())
print(arc_logits_correct.sum() / (rels_whole >= 21).long().sum())
print('---------------------------------------------------')

arc_logits_correct = (head_pred == heads_whole).long() * masks_whole * (rels_whole < 21).long()
rel_logits_correct = (rel_preds == rels_whole).long() * arc_logits_correct 

print(rel_logits_correct.sum() / (masks_whole * (rels_whole < 21).long()).sum())
print(arc_logits_correct.sum() / (masks_whole * (rels_whole < 21).long()).sum())

tensor(0.1702)
tensor(0.5043)
---------------------------------------------------
tensor(0.8373)
tensor(0.8723)
tensor(0.3417)
tensor(0.5043)
---------------------------------------------------
tensor(0.8273)
tensor(0.8723)


In [39]:
rel_preds = rel_logits.argmax(-1)
head_preds = arc_logits.argmax(-1)

# rel_preds.masked_fill_(rel_preds == 2, 27)
# rel_preds.masked_fill_(rel_preds == 3, 27)

for i in range(21, 35):
    arc_logits_correct = (head_pred == heads_whole).long() * masks_whole * (rels_whole == i).long()
    rel_logits_correct = (rel_preds == rels_whole).long() * arc_logits_correct

    print(rel_logits_correct.sum() / (rels_whole == i).long().sum())
    # print(arc_logits_correct.sum() / (rels_whole == i).long().sum())

tensor(0.)
tensor(0.)
tensor(0.0314)
tensor(0.0067)
tensor(0.0095)
tensor(0.1069)
tensor(0.2627)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.0344)
tensor(0.)
tensor(0.0047)
tensor(0.0156)


In [40]:
from constant import rel2id, punct_lst, weak_signals, weak_labels

priority = {
    rel2id['cont']: 1,
    rel2id['temp']: 2, rel2id['cause']: 2, rel2id['bckg']: 2, rel2id['comp']: 2,
    rel2id['joint']: 3, rel2id['attr']: 3,
    None: 5,  # 4 is default
}

reverse_by_words = [
    '你 好', '您 好',
]

id2rel = list(rel2id.keys())

origin4change = [rel2id[x] for x in ['root', 'dfsubj', 'sasubj']]

signal_dct = {}
for i, signals in enumerate(weak_signals):
    for s in signals:
        signal_dct[s] = weak_labels[i]
print(signal_dct)

{'说': 21, '表示': 21, '看': 21, '显示': 21, '知道': 21, '认为': 21, '希望': 21, '指出': 21, '如果': 25, '假如': 25, '的话': 25, '若': 25, '如': 25, '因为': 23, '所以': 23, '导致': 23, '因此': 23, '造成': 23, '由于': 23, '因而': 23, '但是': 26, '可是': 26, '但': 26, '竟': 26, '却': 26, '不过': 26, '居然': 26, '而是': 26, '以及': 31, '也': 31, '并': 31, '并且': 31, '又': 31, '或者': 31, '对于': 22, '自从': 22, '之前': 22, '上次': 22, '明天': 34, '晚上': 34, '到时候': 34, '再': 34, '然后': 34, '接下来': 34, '最后': 34, '随后': 34, '为了': 28, '想要': 28, '使': 28, '为 的': 28, '通过': 32, '必须': 32, '点击': 32, '对 吗': 33, '是 吗': 33, '对 吧': 33, '是 吧': 33, '对 ?': 33, '更': 24, '比': 24, '只': 24, '解释': 30, '比如': 30, '例如': 30, '是 这样': 30, '理想': 29, '真 棒': 29, '太 棒': 29, '真差': 29, '太 差': 29, '不 行': 29, '扯皮': 29}


In [41]:
rel_preds = rel_logits.argmax(-1)
head_preds = arc_logits.argmax(-1)

max_len = CFG.max_length

signals_new_whole = torch.Tensor()
heads_new_whole, rels_new_whole = torch.Tensor(), torch.Tensor()
for sample_idx, deps in tqdm(enumerate(load_annoted(CFG.data_file))):
    seq_len = len(deps)
    if seq_len == 0:
        continue
    
    signals = torch.zeros(max_len).int()
    heads, rels = torch.full(size=(max_len,), fill_value=-2).int(), torch.zeros(max_len).int()
    split, splits, signal, word_lst  = 1, [], None, ['root']
    for i, dep in enumerate(deps[:-1]):
        if i + 2 >= max_len:
            break
        
        word = dep.word
        word_lst.append(word)
        if word in signal_dct.keys() and priority.get(signal_dct[word], 4) < priority.get(signal, 4):
            signal = signal_dct[word]
        if f'{word} {deps[i+1].word}' in signal_dct.keys() and priority.get(signal_dct[f'{word} {deps[i+1].word}'], 4) < priority.get(signal, 4):
            signal = signal_dct[f'{word} {deps[i+1].word}']
        
        if word in punct_lst and deps[i+1].word not in punct_lst:
            if signal is not None and i + 2 - split > 2:  # set 2 to the min length of edu
                signals[split:i+2] = signal
                signal = None
            splits.append(split)
            split = i + 2

        # if masks_whole[sample_idx, i+1] == 0:
        #     heads[i+1] = -1
        #     rels[i+1] = -1
        # else:
        #     heads[i+1] = head_preds[sample_idx, i+1]
        #     rels[i+1] = rel_preds[sample_idx, i+1]

    # add the last data
    if i + 1 < max_len:
        word_lst.append(word)

        # if masks_whole[sample_idx, i+1] == 0:
        #     heads[i+1] = -1
        #     rels[i+1] = -1
        # else:
        #     heads[i+1] = head_preds[sample_idx, i+1]
        #     rels[i+1] = rel_preds[sample_idx, i+1]

    heads = head_preds[sample_idx]
    heads.masked_fill_(mask=~masks_whole[sample_idx].bool(), value=-2)

    rels = rel_preds[sample_idx]
    rels.masked_fill_(mask=~masks_whole[sample_idx].bool(), value=-2)

    if split > 1:
        splits.append(split)
        if signal is not None:
            signals[split:i+2] = signal
    
    splits.append(len(deps))
    # when num of 'edu' >= 2, try change rel and head
    # if len(splits) > 2:
    cnt = -1
    for idx, head in enumerate(heads[1:]):
        if head == -2:
            break
        if head == -1:
            continue

        if len(splits) > 2 and idx + 1 >= splits[cnt+1]:
            cnt += 1

        if ((len(splits) > 2 and (head < splits[cnt] or head >= splits[cnt+1])) or idx - head > 0) and rels[idx + 1] in origin4change:  # cross 'edu'
            if signals[idx+1] != 0:
                rels[idx+1] = signals[idx+1]
                
                if head == 0 and rels[idx + 1] in [rel2id['cond']]:  # reverse
                    tmp_heads = heads.clone()
                    tmp_heads[:splits[cnt+1]] = 0
                    head_idx = [idx + 1]
                    tail_idx = (tmp_heads == idx + 1).nonzero()  # find tail index
                    if len(tail_idx) == 0:  # ring or fail
                        # if 'cont', revsere root; else unchange
                        tail_idx = (heads == idx + 1).nonzero() if signals[idx+1] == rel2id['cont'] else [idx + 1]
                        head_idx = (heads == idx + 1).nonzero() if head_idx == tail_idx else head_idx
                    if len(head_idx) != 0:
                        heads[tail_idx[0]] = 0
                        heads[head_idx[0]] = tail_idx[0]
            elif head != 0:  # default
                rels[idx + 1] = rel2id['elbr']

            # special cases
            if len(splits) > 2 and word_lst[idx+1] == '好' and word_lst[idx] in ['你', '您']:  # reverse
                tmp_heads = heads.clone()
                tmp_heads[:splits[cnt+1]] = 0
                tail_idx = (tmp_heads == idx + 1).nonzero()  # find tail index
                if len(tail_idx) != 0:  
                    heads[tail_idx[0]] = 0
                    heads[idx + 1] = tail_idx[0]
                    rels[idx + 1] = rel2id['elbr']

            # 'attr' label should be reversed again; below can match most of cases
            if splits[cnt] == 1 and rels[idx + 1] in [rel2id['attr']]:
                tmp_heads = heads.clone()
                tmp_heads[:splits[cnt+1]] = 0
                tail_idx = ((tmp_heads != 0) * (tmp_heads >= splits[cnt]) * (tmp_heads < splits[cnt + 1])).nonzero().flatten()
                if len(tail_idx) != 0 and rels[tail_idx[0]] in origin4change:
                    heads[tail_idx[0]] = idx + 1
                    rels[tail_idx[0]] = rels[idx + 1]
                else:
                    dep_idx = heads[idx + 1].item()
                    if dep_idx != 0 and (dep_idx < splits[cnt] or dep_idx > splits[cnt+1]):
                        heads[idx + 1] = 0
                        heads[dep_idx] = idx + 1
                        rels[dep_idx] = rels[idx + 1]

    rels.masked_fill_(heads == 0, 0)  # root
    heads[0] = 0
    heads[1:].masked_fill_(heads[1:] == -2, 0)

    heads_new_whole = torch.cat([heads_new_whole, heads.unsqueeze(0)])
    rels_new_whole = torch.cat([rels_new_whole, rels.unsqueeze(0)])
    signals_new_whole = torch.cat([signals_new_whole, signals.unsqueeze(0)])

20086it [00:45, 436.89it/s]


In [42]:
arc_logits_correct = (heads_new_whole == heads_whole).long() * masks_whole * (rels_whole >= 21).long()
rel_logits_correct = (rels_new_whole == rels_whole).long() * arc_logits_correct

print(rel_logits_correct.sum() / (rels_whole >= 21).long().sum())
print(arc_logits_correct.sum() / (rels_whole >= 21).long().sum())
print('---------------------------------------------------')

arc_logits_correct = (heads_new_whole == heads_whole).long() * masks_whole * (rels_whole < 21).long()
rel_logits_correct = (rels_new_whole == rels_whole).long() * arc_logits_correct 

print(rel_logits_correct.sum() / (masks_whole * (rels_whole < 21).long()).sum())
print(arc_logits_correct.sum() / (masks_whole * (rels_whole < 21).long()).sum())

tensor(0.3489)
tensor(0.5243)
---------------------------------------------------
tensor(0.8312)
tensor(0.8735)


In [43]:
torch.save(heads_new_whole, '../preds/head_preds.pt')
torch.save(rels_new_whole, '../preds/rel_preds.pt')

In [44]:
rel_preds = similarities.argmax(-1)
values, topk = similarities.topk(k=2, dim=-1)

new_preds = []
for pred, value in tqdm(zip(topk, values)):
    tmp = []
    for p, v in zip(pred, value):
        if p[0] < 21:
            if p[1] >= 21:
                tmp.append(p[1])
                continue
            # if p[2] >= 21:
            #     tmp.append(p[2])
            #     continue
        tmp.append(p[0])
    new_preds.append(tmp)

new_preds = torch.tensor(new_preds)
arc_logits_correct = (arc_logits.argmax(-1) == heads_whole).long() * (rels_whole >= 21).long()
rel_logits_correct = (new_preds == rels_whole).long() * arc_logits_correct 

print(rel_logits_correct.sum() / (rels_whole >= 21).long().sum())
print(arc_logits_correct.sum() / (rels_whole >= 21).long().sum())

IndexError: argmax(): Expected reduction dim -1 to have non-zero size.

In [None]:
rel_preds = rel_logits.argmax(-1)
rel_preds.masked_fill_(rel_preds == 2, 27)
rel_preds.masked_fill_(rel_preds == 3, 27)

arc_logits_correct = (arc_logits.argmax(-1) == heads_whole).long() * masks_whole * (rels_whole >= 21).long()
rel_logits_correct = (rel_preds == rels_whole).long() * arc_logits_correct

print(rel_logits_correct.sum() / (rels_whole >= 21).long().sum())
print(arc_logits_correct.sum() / (rels_whole >= 21).long().sum())

In [None]:
rel_preds = rel_logits.argmax(-1)

arc_logits_correct = (arc_logits.argmax(-1) == heads_whole).long() * masks_whole * (rels_whole >= 21).long()
rel_logits_correct = (rel_preds == rels_whole).long() * arc_logits_correct

print(rel_logits_correct.sum() / ((rels_whole == 2).long() + (rels_whole == 3).long()).sum())
print(arc_logits_correct.sum() / ((rels_whole == 2).long() + (rels_whole == 3).long()).sum())