In [1]:
from typing import *
import random
import json

In [2]:
from tqdm import tqdm
import numpy as np
import pandas as pd

import torch
from torch.utils.data import Dataset, DataLoader

In [3]:
from transformers import AutoTokenizer
from torch.utils.data import Dataset, DataLoader

In [4]:
import sys
sys.path.append('..')

In [5]:
from model.par_with_attr import ParWithAttr
from model.base_par import DepParser
from utils import uas_las, to_cuda

# Config

In [6]:
class CFG:
    plm = 'hfl/chinese-electra-180g-base-discriminator'
    data_file = '../data_testset/1to800_1108.json'
    random_seed = 42
    num_epochs = 15
    batch_size = 128
    lr = 2e-5
    weight_decay = 0.01
    dropout = 0.2
    grad_clip = 2
    scheduler = 'linear'
    warmup_ratio = 0.1
    num_early_stop = 3
    max_length = 160
    hidden_size = 400
    num_labels = 21
    gamma = 0.7
    alpha = 0.7
    print_every = 400
    eval_every = 800
    cuda = True
    fp16 = True

# Count

In [7]:
rel_dct = {
    'root': '根节点',
    'sasubj-obj': '同主同宾',
    'sasubj': '同主语',
    'dfsubj': '不同主语',
    'subj': '主语',
    'subj-in': '内部主语',
    'obj': '宾语',
    'pred': '谓语',
    'att': '定语',
    'adv': '状语',
    'cmp': '补语',
    'coo': '并列',
    'pobj': '介宾',
    'iobj': '间宾',
    'de': '的',
    'adjct': '附加',
    'app': '称呼',
    'exp': '解释',
    'punc': '标点',
    'frag': '片段',
    'repet': '重复',
    # rst
    'attr': '归属',
    'bckg': '背景',
    'cause': '因果',
    'comp': '比较',
    'cond': '状况',
    'cont': '对比',
    'elbr': '阐述',
    'enbm': '目的',
    'eval': '评价',
    'expl': '解释-例证',
    'joint': '联合',
    'manner': '方式',
    'rstm': '重申',
    'temp': '时序',
    # 'tp-chg': '主题变更',
    # 'prob-sol': '问题-解决',
    # 'qst-ans': '疑问-回答',
    # 'stm-rsp': '陈述-回应',
    # 'req-proc': '需求-处理',
}

In [8]:
rel2id = {}
for i, (key, value) in enumerate(rel_dct.items()):
    rel2id[key] = i
print(rel2id)
print(len(rel2id))

{'root': 0, 'sasubj-obj': 1, 'sasubj': 2, 'dfsubj': 3, 'subj': 4, 'subj-in': 5, 'obj': 6, 'pred': 7, 'att': 8, 'adv': 9, 'cmp': 10, 'coo': 11, 'pobj': 12, 'iobj': 13, 'de': 14, 'adjct': 15, 'app': 16, 'exp': 17, 'punc': 18, 'frag': 19, 'repet': 20, 'attr': 21, 'bckg': 22, 'cause': 23, 'comp': 24, 'cond': 25, 'cont': 26, 'elbr': 27, 'enbm': 28, 'eval': 29, 'expl': 30, 'joint': 31, 'manner': 32, 'rstm': 33, 'temp': 34}
35


# Seed & Device

In [9]:
def seed_everything(seed=CFG.random_seed):
    np.random.seed(seed%(2**32-1))
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic =True
    torch.backends.cudnn.benchmark = False

seed_everything()

In [10]:
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

print(f'Using device: {device}')

Using device: cuda


# Data

In [11]:
class Dependency():
    def __init__(self, idx, word, head, rel):
        self.id = idx
        self.word = word
        self.head = head
        self.rel = rel

    def __str__(self):
        # example:  1	上海	_	NR	NR	_	2	nn	_	_
        values = [str(self.idx), self.word, "_", "_", "_", "_", str(self.head), self.rel, "_", "_"]
        return '\t'.join(values)

    def __repr__(self):
        return f"({self.word}, {self.head}, {self.rel})"

In [12]:
tokenizer = AutoTokenizer.from_pretrained(CFG.plm)
CFG.tokenizer = tokenizer

In [13]:
def load_annoted(data_file):
    with open(CFG.data_file, 'r', encoding='utf-8') as f:
        data = json.load(f)
        
    sample_lst:List[List[Dependency]] = []
    
    for d in data:
        rel_dct = {}
        for tripple in d['relationship']:
            head, rel, tail = tripple
            head_uttr_idx, head_word_idx = [int(x) for x in head.split('-')]
            tail_uttr_idx, tail_word_idx = [int(x) for x in tail.split('-')]
            if head_uttr_idx != tail_uttr_idx:
                continue
            
            if not rel_dct.get(head_uttr_idx, None):
                rel_dct[head_uttr_idx] = {tail_word_idx: [head_word_idx, rel]}
            else:
                rel_dct[head_uttr_idx][tail_word_idx] = [head_word_idx, rel]
            
        for item in d['dialog']:
            turn = item['turn']
            utterance = item['utterance']
            # dep_lst:List[Dependency] = [Dependency(0, '[root]', -1, '_')]
            dep_lst:List[Dependency] = []
            
            for word_idx, word in enumerate(utterance.split(' ')):
                head_word_idx, rel = rel_dct[turn].get(word_idx + 1, [word_idx, 'adjct'])  # some word annoted missed, padded with last word and 'adjct'
                dep_lst.append(Dependency(word_idx + 1, word, head_word_idx, rel))  # start from 1
            
            sample_lst.append(dep_lst)
        
    return sample_lst

In [14]:
class DialogDataset(Dataset):
    def __init__(self, cfg):
        self.cfg = cfg
        self.inputs, self.offsets, self.heads, self.rels, self.masks = self.read_data()
        
    def read_data(self):
        inputs, offsets = [], []
        tags, heads, rels, masks = [], [], [], []
        
        for deps in load_annoted(self.cfg.data_file):
            seq_len = len(deps)

            word_lst = [] 
            head_tokens = np.zeros(self.cfg.max_length, dtype=np.int64)  # same as root index is 0, constrainting by mask 
            rel_tokens = np.zeros(self.cfg.max_length, dtype=np.int64)
            mask_tokens = np.zeros(self.cfg.max_length, dtype=np.int64)
            for i, dep in enumerate(deps):
                if i == seq_len or i + 1== self.cfg.max_length:
                    break

                word_lst.append(dep.word)

                if dep.head != -1 and dep.rel in rel2id.keys() and dep.head + 1 < self.cfg.max_length:
                    head_tokens[i+1] = dep.head
                    mask_tokens[i+1] = 1
                    rel_tokens[i+1] = rel2id[dep.rel]

                rel_tokens[i+1] = rel2id.get(dep.rel, rel2id['adjct'])

            tokenized = tokenizer.encode_plus(word_lst, 
                                              padding='max_length', 
                                              truncation=True,
                                              max_length=self.cfg.max_length, 
                                              return_offsets_mapping=True, 
                                              return_tensors='pt',
                                              is_split_into_words=True)
            inputs.append({"input_ids": tokenized['input_ids'][0],
                           "token_type_ids": tokenized['token_type_ids'][0],
                           "attention_mask": tokenized['attention_mask'][0]
                          })

            sentence_word_idx = []
            for idx, (start, end) in enumerate(tokenized.offset_mapping[0][1:]):
                if start == 0 and end != 0:
                    sentence_word_idx.append(idx)

            if len(sentence_word_idx) < self.cfg.max_length - 1:
                sentence_word_idx.extend([0]* (self.cfg.max_length - 1 - len(sentence_word_idx)))
            offsets.append(torch.as_tensor(sentence_word_idx))

            heads.append(head_tokens)
            rels.append(rel_tokens)
            masks.append(mask_tokens)
                    
        return inputs, offsets, heads, rels, masks

    def __getitem__(self, idx):
        return self.inputs[idx], self.offsets[idx], self.heads[idx], self.rels[idx], self.masks[idx]
    
    def __len__(self):
        return len(self.rels)

In [15]:
val_dataset = DialogDataset(CFG)

# Model

In [16]:
# model = ParWithAttr(CFG, attr_tokenized)
model = DepParser(CFG)
print(model.load_state_dict(torch.load('../results/base_par.pt')))
model = model.cuda()

Some weights of the model checkpoint at hfl/chinese-electra-180g-base-discriminator were not used when initializing ElectraModel: ['discriminator_predictions.dense_prediction.weight', 'discriminator_predictions.dense.weight', 'discriminator_predictions.dense.bias', 'discriminator_predictions.dense_prediction.bias']
- This IS expected if you are initializing ElectraModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ElectraModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


<All keys matched successfully>


In [17]:
va_dataloader = DataLoader(val_dataset, batch_size=CFG.batch_size)

arc_logits, rel_logits, similarities = torch.Tensor(), torch.Tensor(), torch.Tensor()
heads_whole, rels_whole, masks_whole = torch.Tensor(), torch.Tensor(), torch.Tensor()
for batch in va_dataloader:
    inputs, offsets, heads, rels, masks = batch

    inputs_cuda = {}
    for key, value in inputs.items():
        inputs_cuda[key] = value.cuda()
    inputs = inputs_cuda

    offsets, heads, rels, masks = to_cuda(data=(offsets, heads, rels, masks))
    
    with torch.no_grad():
        model.eval()
        # arc_logit, rel_logit, similarity = model.predict(inputs, offsets, masks)
        arc_logit, rel_logit = model(inputs, offsets, heads, rels, masks, evaluate=True)
        
    arc_logit[:, torch.arange(arc_logit.size()[1]), torch.arange(arc_logit.size()[2])] = -1e4
    
    arc_logits = torch.cat([arc_logits, arc_logit.cpu()])
    rel_logits = torch.cat([rel_logits, rel_logit.cpu()])
    # similarities = torch.cat([similarities, similarity.cpu()])
    
    heads_whole = torch.cat([heads_whole, heads.cpu()])
    rels_whole = torch.cat([rels_whole, rels.cpu()])
    masks_whole = torch.cat([masks_whole, masks.cpu()])

In [18]:
uas_las(arc_logits=arc_logits,
        rel_logits=rel_logits,
        arc_gt=heads_whole,
        rel_gt=rels_whole,
        mask=masks_whole)

{'UAS': 0.8475163106201569, 'LAS': 0.7871081329344984}

In [19]:
# uas_las(arc_logits=arc_logits,
#         rel_logits=similarities,
#         arc_gt=heads_whole,
#         rel_gt=rels_whole,
#         mask=masks_whole)

In [20]:
# rel_preds = similarities.argmax(-1)
# # rel_preds.masked_fill_(rel_preds == 2, 27)
# # rel_preds.masked_fill_(rel_preds == 3, 27)

# arc_logits_correct = (arc_logits.argmax(-1) == heads_whole).long() * masks_whole * (rels_whole >= 21).long()
# rel_logits_correct = (rel_preds == rels_whole).long() * arc_logits_correct

# print(rel_logits_correct.sum() / (rels_whole >= 21).long().sum())
# print(arc_logits_correct.sum() / (rels_whole >= 21).long().sum())

In [41]:
rel_preds = rel_logits.argmax(-1)
head_pred = arc_logits.argmax(-1)

arc_logits_correct = (head_pred == heads_whole).long() * masks_whole * (rels_whole >= 21).long()
rel_logits_correct = (rel_preds == rels_whole).long() * arc_logits_correct

print(rel_logits_correct.sum() / (rels_whole >= 21).long().sum())
print(arc_logits_correct.sum() / (rels_whole >= 21).long().sum())
print('---------------------------------------------------')
arc_logits_correct = (head_pred == heads_whole).long() * masks_whole * (rels_whole < 21).long()
rel_logits_correct = (rel_preds == rels_whole).long() * arc_logits_correct 

print(rel_logits_correct.sum() / (masks_whole * (rels_whole < 21).long()).sum())
print(arc_logits_correct.sum() / (masks_whole * (rels_whole < 21).long()).sum())

print('===================================================')

target = 25

rel_preds.masked_fill_(rel_preds == 2, target)
# rel_preds.masked_fill_(rel_preds == 3, target)
# rel_preds.masked_fill_(rel_preds == 6, 27)
# rel_preds.masked_fill_(rel_preds == 12, 27)
# rel_preds.masked_fill_(rel_preds != 27, 27)

arc_logits_correct = (head_pred == heads_whole).long() * masks_whole * (rels_whole == target).long()
rel_logits_correct = (rel_preds == rels_whole).long() * arc_logits_correct

print(rel_logits_correct.sum() / (rels_whole == target).long().sum())
print(arc_logits_correct.sum() / (rels_whole == target).long().sum())
print('---------------------------------------------------')

arc_logits_correct = (head_pred == heads_whole).long() * masks_whole * (rels_whole == target).long()
rel_logits_correct = (rel_preds == rels_whole).long() * arc_logits_correct 

print(rel_logits_correct.sum() / (masks_whole * (rels_whole < 21).long()).sum())
print(arc_logits_correct.sum() / (masks_whole * (rels_whole < 21).long()).sum())

tensor(0.)
tensor(0.4783)
---------------------------------------------------
tensor(0.8365)
tensor(0.8707)
tensor(0.)
tensor(0.0541)
---------------------------------------------------
tensor(0.)
tensor(0.0003)


In [54]:
# matches = {}

# for target in range(21, 36):
#     rel_preds = rel_logits.argmax(-1)
#     head_pred = arc_logits.argmax(-1)

#     # rel_preds.masked_fill_(rel_preds == 0, target)
#     # rel_preds.masked_fill_(rel_preds == 2, target)
#     rel_preds.masked_fill_(rel_preds == 3, target)
#     # rel_preds.masked_fill_(rel_preds == 6, target)
#     # rel_preds.masked_fill_(rel_preds == 12, 27)
#     # rel_preds.masked_fill_(rel_preds != 27, 27)

#     arc_logits_correct = (head_pred == heads_whole).long() * masks_whole * (rels_whole == target).long()
#     rel_logits_correct = (rel_preds == rels_whole).long() * arc_logits_correct

#     matches[target] = rel_logits_correct.sum() / (rels_whole == target).long().sum()
# print(matches)

{21: tensor(0.), 22: tensor(0.1422), 23: tensor(0.3972), 24: tensor(0.3289), 25: tensor(0.0081), 26: tensor(0.3168), 27: tensor(0.2656), 28: tensor(0.1160), 29: tensor(0.2400), 30: tensor(0.5000), 31: tensor(0.1485), 32: tensor(0.0291), 33: tensor(0.0467), 34: tensor(0.0814), 35: tensor(nan)}


In [22]:
rel_preds = rel_logits.argmax(-1)
head_preds = arc_logits.argmax(-1)

# rel_preds.masked_fill_(rel_preds == 2, 27)
# rel_preds.masked_fill_(rel_preds == 3, 27)

for i in range(21, 35):
    arc_logits_correct = (head_pred == heads_whole).long() * masks_whole * (rels_whole == i).long()
    rel_logits_correct = (rel_preds == rels_whole).long() * arc_logits_correct

    print(rel_logits_correct.sum() / (rels_whole == i).long().sum())

tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)


In [23]:
def load_codt_signal(data_file: str, return_two=False):
    sentence:List[Dependency] = []

    with open(data_file, 'r', encoding='utf-8') as f:
        # data example: 1	上海	_	NR	NR	_	2	nn	_	_
        for line in f.readlines():
            toks = line.split()
            if len(toks) == 0 and len(sentence) != 0:
                yield sentence
                sentence = []
            elif len(toks) == 10:                
                if return_two:
                    sentence.append([int(toks[2]), int(toks[3])])
                else:
                    sentence.append(int(toks[2]))

In [24]:
from constant import rel2id, punct_lst, weak_signals, weak_labels

id2rel = list(rel2id.keys())

origin4change = [rel2id[x] for x in ['root', 'dfsubj', 'sasubj']]
# origin4change.extend([i for i in range(21, 35)])

signal_dct = {}
for i, signals in enumerate(weak_signals):
    for s in signals:
        signal_dct[s] = weak_labels[i]
print(signal_dct)

{'说': 21, '表示': 21, '看到': 21, '显示': 21, '知道': 21, '认为': 21, '希望': 21, '指出': 21, '如果': 25, '假如': 25, '的话': 25, '若': 25, '如': 25, '因为': 23, '所以': 23, '导致': 23, '因此': 23, '造成': 23, '由于': 23, '因而': 23, '但是': 26, '可是': 26, '但': 26, '竟': 26, '却': 26, '不过': 26, '居然': 26, '而是': 26, '以及': 31, '也': 31, '并': 31, '并且': 31, '又': 31, '或者': 31, '对于': 22, '自从': 22, '上次': 22, '明天': 34, '晚上': 34, '到时候': 34, '再': 34, '然后': 34, '接下来': 34, '最后': 34, '随后': 34, '为了': 28, '使': 28, '为 的': 28, '为 了': 28, '通过': 32, '必须': 32, '点击': 32, '对 吗': 33, '是 吗': 33, '对 吧': 33, '是 吧': 33, '对 ?': 33, '更': 24, '比': 24, '只': 24, '解释': 30, '比如': 30, '例如': 30, '是 这样': 30, '理想': 29, '真 棒': 29, '太 棒': 29, '真差': 29, '太 差': 29, '不 行': 29, '扯皮': 29, '这么 麻烦': 29}


In [25]:
rel_preds = rel_logits.argmax(-1)
head_preds = arc_logits.argmax(-1)

max_len = CFG.max_length

signals_new_whole = torch.Tensor()
heads_new_whole, rels_new_whole = torch.Tensor(), torch.Tensor()
for sample_idx, (deps, pred_signals) in tqdm(enumerate(zip(load_annoted(CFG.data_file), load_codt_signal('../mlm_based/diag_test.conll')))):
    seq_len = len(deps)
    if seq_len == 0:
        continue
    
    signals = torch.full(size=(max_len,), fill_value=rel2id['elbr']).int()
    heads, rels = torch.full(size=(max_len,), fill_value=-2).int(), torch.zeros(max_len).int()
    split, splits, signal, word_lst  = 1, [1], rel2id['elbr'], ['root']
    for i, dep in enumerate(deps[:-1]):
        if i + 2 >= max_len:
            break
        
        word = dep.word
        word_lst.append(word)

        # if word in signal_dct.keys():
        #     signal = signal_dct[word]
        # if f'{word} {deps[i+1].word}' in signal_dct.keys():
        #     signal = signal_dct[f'{word} {deps[i+1].word}']

        try:
            signal = pred_signals[i]
        except IndexError:
            signal = pred_signals[len(pred_signals) - 1]

        if word in punct_lst and deps[i+1].word not in punct_lst:
            if i + 2 - split > 2:  # set 2 to the min length of edu
                signals[split:i+2] = signal
                # signal = None
            split = i + 2
            splits.append(split)

    splits.append(len(deps))
            
    # add the last data
    if i + 1 < max_len:
        signal = pred_signals[-1]
        word_lst.append(word)

    heads = head_preds[sample_idx]
    heads.masked_fill_(mask=~masks_whole[sample_idx].bool(), value=-2)

    rels = rel_preds[sample_idx]
    rels.masked_fill_(mask=~masks_whole[sample_idx].bool(), value=-2)

    cnt, attr, = -1, False
    for idx, head in enumerate(heads[1:]):
        if head == -2:
            break
        if head == -1:
            continue

        if len(splits) > 2 and idx + 1 >= splits[cnt+1]:
            cnt += 1

        if ((len(splits) > 2 and (head < splits[cnt] or head >= splits[cnt+1])) or idx - head > 0) and rels[idx + 1] in origin4change:  # cross 'edu'

            rels[idx+1] = signals[idx+1]

            if rels[idx + 1] in [rel2id['cond']]:  # reverse
                tmp_heads = heads.clone()
                tmp_heads[:splits[cnt+1]] = 0
                head_idx = [idx + 1]
                tail_idx = (tmp_heads == idx + 1).nonzero()  # find tail index
                if len(tail_idx) == 0:  # ring or fail
                    # unchange
                    tail_idx = [idx + 1]
                    head_idx = (heads == idx + 1).nonzero() if head_idx == tail_idx else head_idx
                elif len(head_idx) != 0:
                    heads[tail_idx[0]] = 0
                    heads[head_idx[0]] = tail_idx[0]

            # special cases
            if word_lst[idx+1] == '好' and word_lst[idx] in ['你', '您']:  # reverse
                tmp_heads = heads.clone()
                tmp_heads[:splits[cnt+1]] = 0
                tail_idx = (tmp_heads == idx + 1).nonzero()  # find tail index
                if len(tail_idx) != 0:  
                    heads[tail_idx[0]] = 0
                    heads[idx + 1] = tail_idx[0]
                    rels[idx + 1] = rel2id['elbr']

        if not attr and rels[idx + 1] in [rel2id['obj']] and signals[idx+1] == rel2id['attr']:
            rels[idx+1] = signals[idx+1]
            attr = True

    rels.masked_fill_(heads == 0, 0)  # root
    heads[0] = 0
    heads[1:].masked_fill_(heads[1:] == -2, 0)

    heads_new_whole = torch.cat([heads_new_whole, heads.unsqueeze(0)])
    rels_new_whole = torch.cat([rels_new_whole, rels.unsqueeze(0)])
    signals_new_whole = torch.cat([signals_new_whole, signals.unsqueeze(0)])

20086it [00:44, 454.77it/s]


In [26]:
arc_logits_correct = (heads_new_whole == heads_whole).long() * masks_whole * (rels_whole >= 21).long()
rel_logits_correct = (rels_new_whole == rels_whole).long() * arc_logits_correct

print(rel_logits_correct.sum() / (rels_whole >= 21).long().sum())
print(arc_logits_correct.sum() / (rels_whole >= 21).long().sum())
print('---------------------------------------------------')

arc_logits_correct = (heads_new_whole == heads_whole).long() * masks_whole * (rels_whole < 21).long()
rel_logits_correct = (rels_new_whole == rels_whole).long() * arc_logits_correct 

print(rel_logits_correct.sum() / (masks_whole * (rels_whole < 21).long()).sum())
print(arc_logits_correct.sum() / (masks_whole * (rels_whole < 21).long()).sum())

tensor(0.3777)
tensor(0.5209)
---------------------------------------------------
tensor(0.8302)
tensor(0.8725)


In [27]:
# arc_logits_correct = (heads_new_whole == heads_whole).long() * masks_whole * (rels_whole == 0).long()
# rel_logits_correct = (rels_new_whole == rels_whole).long() * arc_logits_correct 

# print(rel_logits_correct.sum() / (masks_whole * (rels_whole == 0).long()).sum())
# print(arc_logits_correct.sum() / (masks_whole * (rels_whole == 0).long()).sum())

# Inter

In [28]:
# rel_preds = rel_logits.argmax(-1)

cnt = 0
root_ids = []
# for rel_pred, mask in zip(rel_preds, masks_whole):
for rel_pred, mask in zip(rels_new_whole, masks_whole):
    try:
        root_idx = (((rel_pred == 0) * mask) != 0).nonzero()[0].item()
    except IndexError: # no root
        root_idx = 2
    root_ids.append(root_idx)

print(root_ids[:10])

[1, 2, 5, 3, 2, 2, 2, 1, 1, 1]


In [29]:
def load_inter(data_file):
    with open(data_file, 'r', encoding='utf-8') as f:
        data = json.load(f)
        
    signal_iter = load_codt_signal('../mlm_based/diag_test.conll', return_two=True)

    sample_lst:List[List[Dependency]] = []
    # for d, pred_signals in tqdm(zip(data, load_codt_signal('../prompt_based/diag_test.conll', idx=3))):
    for d in data:
        rel_dct = {}
        for tripple in d['relationship']:
            head, rel, tail = tripple
            head_uttr_idx, head_word_idx = [int(x) for x in head.split('-')]
            tail_uttr_idx, tail_word_idx = [int(x) for x in tail.split('-')]
            
            if rel == 'root' and head_uttr_idx != 0: # ignore root
                continue
                 
            if not rel_dct.get(tail_uttr_idx, None):
                rel_dct[tail_uttr_idx] = {tail_word_idx: [head, rel]}
            else:
                rel_dct[tail_uttr_idx][tail_word_idx] = [head, rel]
                
        sent_lens_accum = [1]
        for i, item in enumerate(d['dialog']):
            utterance = item['utterance']
            sent_lens_accum.append(sent_lens_accum[i] + len(utterance.split(' ')) + 1)
        sent_lens_accum[0] = 0
        
        dep_lst:List[Dependency] = []
        role_lst:List[str] = []
        weak_signal = []
        for item in d['dialog']:
            turn = item['turn']
            utterance = item['utterance']

            pred_signals = next(signal_iter)

            role = '[ans]' if item['speaker'] == 'A' else '[qst]'
            dep_lst.append(Dependency(sent_lens_accum[turn], role, -1, '_'))  
            
            tmp_signal = []
            for word_idx, word in enumerate(utterance.split(' ')):
                tail2head = rel_dct.get(turn, {1: [f'{turn}-{word_idx}', 'adjct']})
                head, rel = tail2head.get(word_idx + 1, [f'{turn}-{word_idx}', 'adjct'])  # some word annoted missed, padded with last word and 'adjct'
                head_uttr_idx, head_word_idx = [int(x) for x in head.split('-')]
                
                # only parse cross-utterance
                if turn != head_uttr_idx:
                    dep_lst.append(Dependency(sent_lens_accum[turn] + word_idx + 1, word, sent_lens_accum[head_uttr_idx] + head_word_idx, rel))  # add with accumulated length
                else:
                    dep_lst.append(Dependency(sent_lens_accum[turn] + word_idx + 1, word, -1, '_')) 

                try:
                    signal1, signal2 = pred_signals[i]
                except IndexError:
                    signal1, signal2 = pred_signals[len(pred_signals) - 1]
                
                tmp_signal = [signal1, signal2]
                
                # if word in weak_signal_dct.keys():
                #     tmp_signal.append(weak_signal_dct[word])

            if len(tmp_signal) != 0:
                # weak_signal.append(tmp_signal[-1])  # choose the last
                weak_signal.append(tmp_signal)  # choose the last
            else:
                weak_signal.append(-1)
            role_lst.append(item['speaker'])        
        sample_lst.append([dep_lst, role_lst, weak_signal])
        
    return sample_lst

In [30]:
class InterDataset(Dataset):
    def __init__(self, cfg):
        self.cfg = cfg
        self.inputs, self.offsets, self.heads, self.rels, self.masks, self.speakers, self.signs = self.read_data()
        
    def read_data(self):
        inputs, offsets = [], []
        tags, heads, rels, masks, speakers, signs = [], [], [], [], [], []
                
        for idx, (deps, roles, sign) in enumerate(load_inter(self.cfg.data_file)):
            seq_len = len(deps)
            signs.append(sign)

            word_lst = [] 
            head_tokens = np.zeros(1024, dtype=np.int64)  # same as root index is 0, constrainting by mask 
            rel_tokens = np.zeros(1024, dtype=np.int64)
            mask_tokens = np.zeros(1024, dtype=np.int64)
            for i, dep in enumerate(deps):
                if i == seq_len or i + 1== 1024:
                    break

                word_lst.append(dep.word)
                
                if int(dep.head) == -1 or int(dep.head) + 1 >= 1024:
                    head_tokens[i+1] = 0
                    mask_tokens[i+1] = 0
                else:
                    head_tokens[i+1] = int(dep.head)
                    mask_tokens[i+1] = 1
#                     head_tokens[i] = dep.head if dep.head != '_' else 0
                rel_tokens[i+1] = rel2id.get(dep.rel, 0)

            tokenized = tokenizer.encode_plus(word_lst, 
                                              padding='max_length', 
                                              truncation=True,
                                              max_length=1024, 
                                              return_offsets_mapping=True, 
                                              return_tensors='pt',
                                              is_split_into_words=True)
            inputs.append({"input_ids": tokenized['input_ids'][0],
                          "token_type_ids": tokenized['token_type_ids'][0],
                           "attention_mask": tokenized['attention_mask'][0]
                          })

#                 sentence_word_idx = np.zeros(self.cfg.max_length, dtype=np.int64)
            sentence_word_idx = []
            for idx, (start, end) in enumerate(tokenized.offset_mapping[0][1:]):
                if start == 0 and end != 0:
                    sentence_word_idx.append(idx)
#                         sentence_word_idx[idx] = idx
            if len(sentence_word_idx) < 1024 - 1:
                sentence_word_idx.extend([0]* (1024 - 1 - len(sentence_word_idx)))
            offsets.append(torch.as_tensor(sentence_word_idx))
#                 offsets.append(sentence_word_idx)

            heads.append(head_tokens)
            rels.append(rel_tokens)
            masks.append(mask_tokens)
            speakers.append(roles)
                    
        return inputs, offsets, heads, rels, masks, speakers, signs

    def __getitem__(self, idx):
        return self.inputs[idx], self.offsets[idx], self.heads[idx], self.rels[idx], self.masks[idx], self.speakers[idx], self.signs[idx]
    
    def __len__(self):
        return len(self.rels)

In [55]:
inter_dataset = InterDataset(CFG)
inter_dataloader = DataLoader(inter_dataset, batch_size=1)

In [70]:
cnt = 0

inter_heads_whole, inter_rels_whole, inter_masks_whole = torch.Tensor(), torch.Tensor(), torch.Tensor()
inter_heads_preds, inter_rels_preds = torch.Tensor(), torch.Tensor()
for batch in inter_dataloader:
    inputs, offsets, heads, rels, masks, speakers, signs = batch
    inter_head_preds = torch.zeros_like(heads, dtype=int)
    inter_rel_preds = torch.zeros_like(rels, dtype=int)

    inter_heads_whole = torch.cat([inter_heads_whole, heads])
    inter_rels_whole = torch.cat([inter_rels_whole, rels])
    inter_masks_whole = torch.cat([inter_masks_whole, masks])

    accum = 1
    # print(masks.sum())
    for i, speakr in enumerate(speakers[1:]):
        seq_len = masks_whole[cnt].sum().item() + 1
        # if signs[i] != -1:
        #     rel = signs[i]
        # elif speakr == speakers[i]:
        #     rel = 27
        # else:
        #     rel = 38
        if speakr == speakers[i]:
            rel = signs[i][0]
        else:
            rel = signs[i][1]

        # rel = 39
        
        head_idx = int(root_ids[cnt] + accum) if i > 0 else root_ids[cnt]
        tail_idx = int(root_ids[cnt+1] + accum + seq_len)
        
        # print(head_idx, tail_idx)
        inter_head_preds[0][tail_idx] = head_idx
        inter_rel_preds[0][tail_idx] = rel

        # print(inter_head_preds)
        # print(inter_rel_preds)
        cnt += 1
        accum += seq_len

    cnt += 1

    inter_heads_preds = torch.cat([inter_heads_preds, inter_head_preds])
    inter_rels_preds = torch.cat([inter_rels_preds, inter_rel_preds])

In [59]:
arc_logits_correct = (inter_heads_preds == inter_heads_whole).long() * inter_masks_whole
rel_logits_correct = (inter_rels_preds == inter_rels_whole).long() * arc_logits_correct

print(rel_logits_correct.sum() / inter_masks_whole.long().sum())
print(arc_logits_correct.sum() / inter_masks_whole.long().sum())

tensor(0.2616)
tensor(0.7007)


In [71]:
# arc_logits_correct = (inter_heads_preds == inter_heads_whole).long() * inter_masks_whole * (inter_rels_whole == 39).long()
# rel_logits_correct = (inter_rels_preds == inter_rels_whole).long() * arc_logits_correct

# print(rel_logits_correct.sum() / (inter_rels_whole == 39).long().sum())

tensor(0.6078)


In [34]:
(rels_whole >= 21).long().sum()

tensor(10028)

In [35]:
(masks_whole * (rels_whole < 21).long()).sum()

tensor(159800.)

In [36]:
inter_masks_whole.long().sum()

tensor(19175)

In [37]:
arc_logits_correct_inner = (heads_new_whole == heads_whole).long() * masks_whole * (rels_whole >= 21).long()
arc_logits_correct_inter = (inter_heads_preds == inter_heads_whole).long() * inter_masks_whole

(arc_logits_correct_inner.sum() + arc_logits_correct_inter.sum()) / ((rels_whole >= 21).long().sum() + inter_masks_whole.long().sum())

tensor(0.6390)

In [38]:
rel_logits_correct_inner = (rels_new_whole == rels_whole).long() * arc_logits_correct_inner
rel_logits_correct_inter = (inter_rels_preds == inter_rels_whole).long() * arc_logits_correct_inter

(rel_logits_correct_inner.sum() + rel_logits_correct_inter.sum()) / ((rels_whole >= 21).long().sum() + inter_masks_whole.long().sum())

tensor(0.4914)