In [1]:
import os
from collections import Counter
from typing import *
import random
import json

In [2]:
from tqdm import tqdm
import numpy as np
from sklearn.model_selection import KFold

import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader, Subset
from torch.cuda.amp import autocast, GradScaler

from transformers import AutoConfig, AutoModel, AutoTokenizer, AdamW, get_constant_schedule_with_warmup, get_linear_schedule_with_warmup

In [3]:
from model import DepParser
from utils import arc_rel_loss, uas_las, to_cuda

## Config

In [4]:
class CFG:
    data_file = '/root/diag_dep/data_testset/1to500_1013.json'
    plm = 'hfl/chinese-electra-180g-large-discriminator'
    num_folds = 5
    trn_folds = [0, 1, 2, 3, 4]
    random_seed = 42
    num_epochs = 10
    batch_size = 128
    lr = 2e-5
    weight_decay = 0.01
    dropout = 0.2
    grad_clip = 1
    scheduler = 'linear'
    warmup_ratio = 0.1
    num_early_stop = 5
    max_length = 160
    num_labels = 40
    hidden_size = 400
    print_every = 1e9
    eval_every = 50
    cuda = True
    fp16 = True

## Seed and Device

In [5]:
def seed_everything(seed=CFG.random_seed):
    np.random.seed(seed%(2**32-1))
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic =True
    torch.backends.cudnn.benchmark = False

seed_everything()

## Data

In [6]:
rel_dct = {
    'root': '根节点',
    'sasubj-obj': '同主同宾',
    'sasubj': '同主语',
    'dfsubj': '不同主语',
    'subj': '主语',
    'subj-in': '内部主语',
    'obj': '宾语',
    'pred': '谓语',
    'att': '定语',
    'adv': '状语',
    'cmp': '补语',
    'coo': '并列',
    'pobj': '介宾',
    'iobj': '间宾',
    'de': '的',
    'adjct': '附加',
    'app': '称呼',
    'exp': '解释',
    'punc': '标点',
    'frag': '片段',
    'repet': '重复',
    # rst
    'attr': '归属',
    'bckg': '背景',
    'cause': '因果',
    'comp': '比较',
    'cond': '状况',
    'cont': '对比',
    'elbr': '阐述',
    'enbm': '目的',
    'eval': '评价',
    'expl': '解释-例证',
    'joint': '联合',
    'manner': '方式',
    'rstm': '重申',
    'temp': '时序',
    'tp-chg': '主题变更',
    'prob-sol': '问题-解决',
    'qst-ans': '疑问-回答',
    'stm-rsp': '陈述-回应',
    'req-proc': '需求-处理',
}

In [7]:
rel2id = {key:idx for idx, key in enumerate(rel_dct.keys())}
print(rel2id)

id2rel = [key for key in rel_dct.keys()]
print(id2rel)

{'root': 0, 'sasubj-obj': 1, 'sasubj': 2, 'dfsubj': 3, 'subj': 4, 'subj-in': 5, 'obj': 6, 'pred': 7, 'att': 8, 'adv': 9, 'cmp': 10, 'coo': 11, 'pobj': 12, 'iobj': 13, 'de': 14, 'adjct': 15, 'app': 16, 'exp': 17, 'punc': 18, 'frag': 19, 'repet': 20, 'attr': 21, 'bckg': 22, 'cause': 23, 'comp': 24, 'cond': 25, 'cont': 26, 'elbr': 27, 'enbm': 28, 'eval': 29, 'expl': 30, 'joint': 31, 'manner': 32, 'rstm': 33, 'temp': 34, 'tp-chg': 35, 'prob-sol': 36, 'qst-ans': 37, 'stm-rsp': 38, 'req-proc': 39}
['root', 'sasubj-obj', 'sasubj', 'dfsubj', 'subj', 'subj-in', 'obj', 'pred', 'att', 'adv', 'cmp', 'coo', 'pobj', 'iobj', 'de', 'adjct', 'app', 'exp', 'punc', 'frag', 'repet', 'attr', 'bckg', 'cause', 'comp', 'cond', 'cont', 'elbr', 'enbm', 'eval', 'expl', 'joint', 'manner', 'rstm', 'temp', 'tp-chg', 'prob-sol', 'qst-ans', 'stm-rsp', 'req-proc']


In [8]:
tokenizer = AutoTokenizer.from_pretrained(CFG.plm)
print(len(tokenizer))
 
num_added_toks = tokenizer.add_tokens(['[root]', '[qst]', '[aws]'], special_tokens=True)
tokenizer.root_token = '[root]'
tokenizer.root_token_ids = tokenizer('[root]')['input_ids'][1]
print(f"add token: {tokenizer.root_token} {tokenizer.root_token_ids}")

tokenizer.qst_token = '[qst]'
tokenizer.qst_token_ids = tokenizer('[qst]')['input_ids'][1]
print(f"add token: {tokenizer.qst_token} {tokenizer.qst_token_ids}")

tokenizer.ans_token = '[ans]'
tokenizer.ans_token_ids = tokenizer('[ans]')['input_ids'][1]
print(f"add token: {tokenizer.ans_token} {tokenizer.ans_token_ids}")
print(len(tokenizer))

CFG.tokenizer = tokenizer

21128
add token: [root] 21128
add token: [qst] 21129
add token: [ans] 138
21131


In [9]:
class Dependency():
    def __init__(self, idx, word, head, rel):
        self.id = idx
        self.word = word
        self.tag = '_'
        self.head = head
        self.rel = rel

    def __str__(self):
        # example:  1	上海	_	NR	NR	_	2	nn	_	_
        values = [str(self.idx), self.word, "_", self.tag, "_", "_", str(self.head), self.rel, "_", "_"]
        return '\t'.join(values)

    def __repr__(self):
        return f"({self.id}, {self.word}, {self.tag}, {self.head}, {self.rel})"

In [10]:
def load_data(data_file, get_data_split=False):
    with open(CFG.data_file, 'r', encoding='utf-8') as f:
        data = json.load(f)[500:]  # have annotated 500 data
        
    sample_lst:List[List[Dependency]] = []
    split_ids:List[int] = []
    
    for d in data:
        rel_dct = {}
        for tripple in d['relationship']:
            head, rel, tail = tripple
            head_uttr_idx, head_word_idx = [int(x) for x in head.split('-')]
            tail_uttr_idx, tail_word_idx = [int(x) for x in tail.split('-')]
            if head_uttr_idx != tail_uttr_idx:
                continue
            
            if not rel_dct.get(head_uttr_idx, None):
                rel_dct[head_uttr_idx] = {tail_word_idx: [head_word_idx, rel]}
            else:
                rel_dct[head_uttr_idx][tail_word_idx] = [head_word_idx, rel]
            
        for idx, item in enumerate(d['dialog']):
            turn = item['turn']
            utterance = item['utterance']
            # dep_lst:List[Dependency] = [Dependency(0, '[root]', -1, '_')]
            dep_lst:List[Dependency] = []
            
            for word_idx, word in enumerate(utterance.split(' ')):
                head_word_idx, rel = rel_dct[turn].get(word_idx + 1, [word_idx, 'adjct'])  # some word annoted missed, padded with last word and 'adjct'
                dep_lst.append(Dependency(word_idx + 1, word, head_word_idx, rel))  # start from 1
            
            sample_lst.append(dep_lst)
        split_ids.append(idx + 1)
     
    if get_data_split:
        return sample_lst, split_ids
    return sample_lst

In [11]:
class DialogDataset(Dataset):
    def __init__(self, cfg):
        self.cfg = cfg
        self.inputs, self.offsets, self.heads, self.rels, self.masks = self.read_data()
        
    def read_data(self):
        inputs, offsets = [], []
        tags, heads, rels, masks = [], [], [], []
        
        for deps in load_data(self.cfg.data_file):
            # another sentence
            seq_len = len(deps)

            word_lst = [] 
#                 head_tokens = np.ones(self.cfg.max_length, dtype=np.int64)*(-1)  # root index is 0, thus using -1 for padding 
            head_tokens = np.zeros(self.cfg.max_length, dtype=np.int64)  # same as root index is 0, constrainting by mask 
            rel_tokens = np.zeros(self.cfg.max_length, dtype=np.int64)
            mask_tokens = np.zeros(self.cfg.max_length, dtype=np.int64)
            for i, dep in enumerate(deps):
                if i == seq_len or i + 1== self.cfg.max_length:
                    break

                word_lst.append(dep.word)

                if dep.head == -1 or dep.head + 1 >= self.cfg.max_length:
                    head_tokens[i+1] = 0
                    mask_tokens[i+1] = 0
                else:
                    head_tokens[i+1] = int(dep.head)
                    mask_tokens[i+1] = 1
#                     head_tokens[i] = dep.head if dep.head != '_' else 0
                rel_tokens[i+1] = rel2id.get(dep.rel, 0)

            tokenized = tokenizer.encode_plus(word_lst, 
                                              padding='max_length', 
                                              truncation=True,
                                              max_length=self.cfg.max_length, 
                                              return_offsets_mapping=True, 
                                              return_tensors='pt',
                                              is_split_into_words=True)
            inputs.append({"input_ids": tokenized['input_ids'][0],
                          "token_type_ids": tokenized['token_type_ids'][0],
                           "attention_mask": tokenized['attention_mask'][0]
                          })

#                 sentence_word_idx = np.zeros(self.cfg.max_length, dtype=np.int64)
            sentence_word_idx = []
            for idx, (start, end) in enumerate(tokenized.offset_mapping[0][1:]):
                if start == 0 and end != 0:
                    sentence_word_idx.append(idx)
#                         sentence_word_idx[idx] = idx
            if len(sentence_word_idx) < self.cfg.max_length - 1:
                sentence_word_idx.extend([0]* (self.cfg.max_length - 1 - len(sentence_word_idx)))
            offsets.append(torch.as_tensor(sentence_word_idx))
#                 offsets.append(sentence_word_idx)

            heads.append(head_tokens)
            rels.append(rel_tokens)
            masks.append(mask_tokens)
                    
        return inputs, offsets, heads, rels, masks

    def __getitem__(self, idx):
        return self.inputs[idx], self.offsets[idx], self.heads[idx], self.rels[idx], self.masks[idx]
    
    def __len__(self):
        return len(self.rels)

## prediction

In [12]:
arc_logits_avg, rel_logits_avg = None, None

for fold in CFG.trn_folds:
    print(f'FOLD {fold}')
    print('--------------------------------')

    if CFG.cuda and torch.cuda.is_available:
        torch.cuda.empty_cache()

    # data loading
    test_dataset = DialogDataset(CFG)
    print(len(test_dataset))
    test_iter = DataLoader(test_dataset, batch_size=CFG.batch_size, shuffle=False)
                                 
    model = DepParser(CFG)
    model.load_state_dict(torch.load(f'/root/autodl-tmp/diag_dep/1to500-inner/{fold}/model.bin'))
    model.to('cuda' if CFG.cuda else 'cpu')
    model.eval()
    
    arc_logits_onefold, rel_logits_onefold, masks_onefold = torch.Tensor(), torch.Tensor(), torch.Tensor()
    for batch in test_iter:
        inputs, offsets, heads, rels, masks = batch

        if CFG.cuda and torch.cuda.is_available():
            inputs_cuda = {}
            for key,value in inputs.items():
                inputs_cuda[key] = value.cuda()
            inputs = inputs_cuda
        offsets, heads, rels, masks = to_cuda(data=(offsets, heads, rels, masks))
        masks = (masks == 0)
        
        with torch.no_grad():
            arc_logits, rel_logits = model(inputs, heads, offsets, evaluate=True)
            
        batch_size, seq_len, seq_len = arc_logits.shape 
        
        # diag
        arc_logits[:, torch.arange(seq_len), torch.arange(seq_len)] = -1e4
        
#         # one and only one root
#         root_masks = torch.full_like(input=arc_logits[:, :, 0], fill_value=True, dtype=torch.bool)
#         root_argmax = arc_logits[:, :, 0].argmax(dim=1)
#         root_masks[torch.arange(batch_size), root_argmax] = False
#         arc_logits[:, :, 0].masked_fill_(mask=root_masks, value=-1e4)
#         arc_logits[:, :, 0].masked_fill_(mask=~root_masks, value=1e4)
        
#         root_masks = torch.full_like(input=rel_logits[:, :, 0], fill_value=True, dtype=torch.bool)
#         root_argmax = rel_logits[:, :, 0].argmax(dim=1)
#         root_masks[torch.arange(batch_size), root_argmax] = False
#         rel_logits[:, :, 0].masked_fill_(mask=root_masks, value=-1e4)
#         rel_logits[:, :, 0].masked_fill_(mask=~root_masks, value=1e4)
        
        masks, arc_logits, rel_logits = masks.cpu(), arc_logits.cpu(), rel_logits.cpu()
        masks_onefold = torch.cat([masks_onefold, masks], dim=0)
        arc_logits_onefold = torch.cat([arc_logits_onefold, arc_logits], dim=0)
        rel_logits_onefold = torch.cat([rel_logits_onefold, rel_logits], dim=0)
        
    head_preds = arc_logits_onefold.argmax(-1)
    head_preds.masked_fill_(masks_onefold.bool(), value=-1)
    # head_preds = head_preds[1:].tolist()

    rel_preds = rel_logits_onefold.argmax(-1)
    rel_preds.masked_fill_(masks_onefold.bool(), value=-1)
    # rel_preds = rel_preds[1:].tolist()

    print(head_preds)
    print(rel_preds)
    
    if arc_logits_avg is None:
        arc_logits_avg = arc_logits_onefold / len(CFG.trn_folds)
        rel_logits_avg = rel_logits_onefold / len(CFG.trn_folds)
    else:
        arc_logits_avg += arc_logits_onefold / len(CFG.trn_folds)
        rel_logits_avg += rel_logits_onefold / len(CFG.trn_folds)

head_preds = arc_logits_avg.argmax(-1)
head_preds.masked_fill_(masks_onefold.bool(), value=-1)
# head_preds = head_preds[1:].tolist()

rel_preds = rel_logits_avg.argmax(-1)
rel_preds.masked_fill_(masks_onefold.bool(), value=-1)
# rel_preds = rel_preds[1:].tolist()

print("-----Argmax the Mean of all fold's Logits-----")
print(head_preds)
print(rel_preds)

FOLD 0
--------------------------------
7601


Some weights of the model checkpoint at hfl/chinese-electra-180g-large-discriminator were not used when initializing ElectraModel: ['discriminator_predictions.dense_prediction.weight', 'discriminator_predictions.dense.weight', 'discriminator_predictions.dense_prediction.bias', 'discriminator_predictions.dense.bias']
- This IS expected if you are initializing ElectraModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ElectraModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


tensor([[-1,  0,  1,  ..., -1, -1, -1],
        [-1,  0,  1,  ..., -1, -1, -1],
        [-1,  0,  1,  ..., -1, -1, -1],
        ...,
        [-1,  0,  3,  ..., -1, -1, -1],
        [-1,  2,  3,  ..., -1, -1, -1],
        [-1,  3,  1,  ..., -1, -1, -1]])
tensor([[-1,  0, 15,  ..., -1, -1, -1],
        [-1,  0, 15,  ..., -1, -1, -1],
        [-1,  0,  6,  ..., -1, -1, -1],
        ...,
        [-1,  0,  9,  ..., -1, -1, -1],
        [-1,  8,  9,  ..., -1, -1, -1],
        [-1, 16, 18,  ..., -1, -1, -1]])
FOLD 1
--------------------------------
7601


Some weights of the model checkpoint at hfl/chinese-electra-180g-large-discriminator were not used when initializing ElectraModel: ['discriminator_predictions.dense_prediction.weight', 'discriminator_predictions.dense.weight', 'discriminator_predictions.dense_prediction.bias', 'discriminator_predictions.dense.bias']
- This IS expected if you are initializing ElectraModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ElectraModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


tensor([[-1,  0,  1,  ..., -1, -1, -1],
        [-1,  0,  1,  ..., -1, -1, -1],
        [-1,  0,  1,  ..., -1, -1, -1],
        ...,
        [-1,  0,  3,  ..., -1, -1, -1],
        [-1,  2,  3,  ..., -1, -1, -1],
        [-1,  3,  1,  ..., -1, -1, -1]])
tensor([[-1,  0, 15,  ..., -1, -1, -1],
        [-1,  0, 15,  ..., -1, -1, -1],
        [-1,  0,  6,  ..., -1, -1, -1],
        ...,
        [-1,  0,  9,  ..., -1, -1, -1],
        [-1,  8,  9,  ..., -1, -1, -1],
        [-1, 16, 18,  ..., -1, -1, -1]])
FOLD 2
--------------------------------
7601


Some weights of the model checkpoint at hfl/chinese-electra-180g-large-discriminator were not used when initializing ElectraModel: ['discriminator_predictions.dense_prediction.weight', 'discriminator_predictions.dense.weight', 'discriminator_predictions.dense_prediction.bias', 'discriminator_predictions.dense.bias']
- This IS expected if you are initializing ElectraModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ElectraModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


tensor([[-1,  0,  1,  ..., -1, -1, -1],
        [-1,  0,  1,  ..., -1, -1, -1],
        [-1,  0,  1,  ..., -1, -1, -1],
        ...,
        [-1,  0,  3,  ..., -1, -1, -1],
        [-1,  2,  3,  ..., -1, -1, -1],
        [-1,  3,  1,  ..., -1, -1, -1]])
tensor([[-1,  0, 15,  ..., -1, -1, -1],
        [-1,  0, 15,  ..., -1, -1, -1],
        [-1,  0,  6,  ..., -1, -1, -1],
        ...,
        [-1,  0,  9,  ..., -1, -1, -1],
        [-1,  8,  9,  ..., -1, -1, -1],
        [-1, 16, 18,  ..., -1, -1, -1]])
FOLD 3
--------------------------------
7601


Some weights of the model checkpoint at hfl/chinese-electra-180g-large-discriminator were not used when initializing ElectraModel: ['discriminator_predictions.dense_prediction.weight', 'discriminator_predictions.dense.weight', 'discriminator_predictions.dense_prediction.bias', 'discriminator_predictions.dense.bias']
- This IS expected if you are initializing ElectraModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ElectraModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


tensor([[-1,  0,  1,  ..., -1, -1, -1],
        [-1,  0,  1,  ..., -1, -1, -1],
        [-1,  0,  1,  ..., -1, -1, -1],
        ...,
        [-1,  0,  3,  ..., -1, -1, -1],
        [-1,  2,  3,  ..., -1, -1, -1],
        [-1,  3,  1,  ..., -1, -1, -1]])
tensor([[-1,  0, 15,  ..., -1, -1, -1],
        [-1,  0, 15,  ..., -1, -1, -1],
        [-1,  0,  6,  ..., -1, -1, -1],
        ...,
        [-1,  0,  9,  ..., -1, -1, -1],
        [-1,  8,  9,  ..., -1, -1, -1],
        [-1, 16, 18,  ..., -1, -1, -1]])
FOLD 4
--------------------------------
7601


Some weights of the model checkpoint at hfl/chinese-electra-180g-large-discriminator were not used when initializing ElectraModel: ['discriminator_predictions.dense_prediction.weight', 'discriminator_predictions.dense.weight', 'discriminator_predictions.dense_prediction.bias', 'discriminator_predictions.dense.bias']
- This IS expected if you are initializing ElectraModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ElectraModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


tensor([[-1,  0,  1,  ..., -1, -1, -1],
        [-1,  0,  1,  ..., -1, -1, -1],
        [-1,  0,  1,  ..., -1, -1, -1],
        ...,
        [-1,  0,  3,  ..., -1, -1, -1],
        [-1,  2,  3,  ..., -1, -1, -1],
        [-1,  3,  1,  ..., -1, -1, -1]])
tensor([[-1,  0, 15,  ..., -1, -1, -1],
        [-1,  0, 15,  ..., -1, -1, -1],
        [-1,  0,  6,  ..., -1, -1, -1],
        ...,
        [-1,  0,  9,  ..., -1, -1, -1],
        [-1,  8,  9,  ..., -1, -1, -1],
        [-1, 16, 18,  ..., -1, -1, -1]])
-----Argmax the Mean of all fold's Logits-----
tensor([[-1,  0,  1,  ..., -1, -1, -1],
        [-1,  0,  1,  ..., -1, -1, -1],
        [-1,  0,  1,  ..., -1, -1, -1],
        ...,
        [-1,  0,  3,  ..., -1, -1, -1],
        [-1,  2,  3,  ..., -1, -1, -1],
        [-1,  3,  1,  ..., -1, -1, -1]])
tensor([[-1,  0, 15,  ..., -1, -1, -1],
        [-1,  0, 15,  ..., -1, -1, -1],
        [-1,  0,  6,  ..., -1, -1, -1],
        ...,
        [-1,  0,  9,  ..., -1, -1, -1],
        [-1,  8,  

In [13]:
data, split_ids = load_data(CFG.data_file, get_data_split=True)
max_len = 0
for d in data:
    max_len = max(max_len, len(d))
print(max_len)

152


In [14]:
with open(CFG.data_file, 'r', encoding='utf-8') as f:
    output = json.load(f)  

data, split_ids = load_data(CFG.data_file, get_data_split=True)

uttr_cnt, dialog_cnt = 0, 0
split_idx = split_ids.pop(0)
tripples = []

for i, d in enumerate(data):
    # print(d)
    head_pred, rel_pred = head_preds[i], rel_preds[i]
    # print(head_pred)
    # print(rel_pred)

    for dep in d:
        tail = f'{uttr_cnt}-{dep.id}'
        head_word_idx = head_pred[dep.id].item()
        # try:
        #     head = f'{uttr_cnt}-{head_pred[dep.id].item()}'
        #     rel = id2rel[rel_pred[dep.id].item()]
        # except IndexError:
        #     head = f'{uttr_cnt}-{head_pred[dep.id - 1].item()}'
        #     rel = id2rel[rel_pred[dep.id - 1].item()]
        head = f'{uttr_cnt}-{head_word_idx}'
        rel = id2rel[rel_pred[dep.id].item()]
        if head_word_idx <= len(d):
            tripples.append([head, rel, tail])
    
    uttr_cnt += 1
    if uttr_cnt == split_idx:
        # 500 is the index of dialog which is not annnoted
        output[dialog_cnt + 500]['relationship'] = tripples
        # print(tripples)
        dialog_cnt += 1
        tripples = []
        uttr_cnt = 0
        if len(split_ids) != 0:
            split_idx = split_ids.pop(0)

In [15]:
fw = open('pre_annot_new/annoted_by_1to500.json', 'w', encoding='utf-8')
save_str = json.dumps(output, ensure_ascii=False, indent=4, separators=(',', ': '))
fw.write(save_str)
fw.close()

In [None]:
import os
os.system("shutdown")