In [1]:
from collections import Counter
from typing import *
import random
import json

In [2]:
from tqdm import tqdm
import numpy as np

import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader

In [3]:
from transformers import AutoConfig, AutoModel, AutoTokenizer, AdamW, get_constant_schedule_with_warmup, get_linear_schedule_with_warmup
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler

2022-12-13 14:09:05.871482: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudart.so.11.0


In [4]:
import sys
sys.path.append('..')

from constant import rel2id

# Config

In [5]:
class CFG:
    """Notebook-wide settings used as a plain namespace.

    The class itself (not an instance) is passed to ``DepParser``. Later cells
    also attach ``CFG.tokenizer``, and the ensemble cell overrides
    ``num_labels`` for its auxiliary parser.
    """

    # ---- data ----
    # all domains: ['BC', 'FIN', 'LEG', 'PB', 'PC', 'ZX']
    domains: List[str] = ['BC']
    max_length: int = 128

    # ---- encoder / parser ----
    plm: str = 'hfl/chinese-electra-180g-base-discriminator'
    hidden_size: int = 400
    dropout: float = 0.2
    # 21 is used for the syntax-only checkpoint in the ensemble cell; 35 presumably
    # matches the few-shot checkpoint loaded below (all keys/shapes matched).
    num_labels: int = 35

    # ---- optimisation ----
    random_seed: int = 42
    num_epochs: int = 10
    batch_size: int = 64
    lr: float = 2e-5
    weight_decay: float = 0.01
    grad_clip: int = 2
    scheduler: str = 'linear'
    warmup_ratio: float = 0.1
    num_early_stop: int = 10

    # ---- runtime ----
    fp16: bool = True
    cuda: bool = True

# Data

In [6]:
class Dependency:
    """One token row of a 10-column CoNLL-style dependency file.

    Args:
        idx: token position as read from the file (``load_codt`` passes the raw
            string from column 1). Stored as ``self.id``.
        word: surface form (column 2).
        tag: POS tag (``load_codt`` passes column 4).
        head: head position (column 7).
        rel: dependency relation label (column 8).
    """

    def __init__(self, idx, word, tag, head, rel):
        self.id = idx
        self.word = word
        self.tag = tag
        self.head = head
        self.rel = rel

    def __str__(self):
        """Render the token as a tab-separated 10-column CoNLL row.

        Only id, word, tag (column 4), head and rel are written; every other
        column is '_'.
        """
        # e.g. '1\t上海\t_\tNR\t_\t_\t2\tnn\t_\t_'
        # The position lives in `self.id` (see __init__); there is no `self.idx`.
        values = [str(self.id), self.word, "_", self.tag, "_", "_", str(self.head), self.rel, "_", "_"]
        return '\t'.join(values)

    def __repr__(self):
        return f"({self.word}, {self.tag}, {self.head}, {self.rel})"

In [7]:
def load_codt(data_file: str) -> "Iterator[List[Dependency]]":
    """Yield the sentences of a 10-column CoNLL-style dependency file.

    Sentences are separated by blank lines. Each yielded sentence is a
    non-empty list of ``Dependency`` built from columns 1 (id), 2 (word),
    4 (tag), 7 (head) and 8 (rel). Non-blank lines that do not split into
    exactly 10 whitespace-separated fields are ignored.

    The last sentence is yielded even when the file does not end with a blank
    line, and runs of blank lines do not produce empty sentences.
    """
    # An artificial root token could be prepended here, e.g.
    # Dependency('0', '<root>', '_', '0', '_'); it is currently not used.
    sentence: List[Dependency] = []

    with open(data_file, 'r', encoding='utf-8') as f:
        # data example: 1  上海  _  NR  NR  _  2  nn  _  _
        for line in f:
            toks = line.split()
            if not toks:
                if sentence:
                    yield sentence
                sentence = []
            elif len(toks) == 10:
                sentence.append(Dependency(toks[0], toks[1], toks[3], toks[6], toks[7]))

    # flush a final sentence that was not followed by a blank line
    if sentence:
        yield sentence

In [8]:
# Tokenizer of the pretrained LM; also attached to CFG, which is handed to DepParser
# (presumably so the parser can read it).
CFG.tokenizer = tokenizer = AutoTokenizer.from_pretrained(CFG.plm)

# Load Model

In [9]:
from model.base_par import DepParser

# This notebook only runs inference: switch to eval mode right away so dropout
# layers (e.g. the CFG.dropout one) are disabled. nn.Module starts in training mode,
# and neither load_state_dict() nor cuda() changes that.
parser = DepParser(CFG).eval()

Some weights of the model checkpoint at hfl/chinese-electra-180g-base-discriminator were not used when initializing ElectraModel: ['discriminator_predictions.dense.weight', 'discriminator_predictions.dense_prediction.weight', 'discriminator_predictions.dense_prediction.bias', 'discriminator_predictions.dense.bias']
- This IS expected if you are initializing ElectraModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ElectraModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [10]:
# model_path = "../results/base_par.pt"
model_path = '../results/few_shot/with_codt/model_shot_50.bin'
# Deserialize onto CPU: the parser is not moved to the GPU until a later cell, and
# load_state_dict copies the tensors into its parameters anyway. This also works
# when the checkpoint was saved on a GPU that is not available here.
model_stat_dict = torch.load(model_path, map_location='cpu')

In [11]:
# Copy the checkpoint weights into the parser (load_state_dict is strict by default,
# so missing/unexpected keys raise).
parser.load_state_dict(model_stat_dict)

<All keys matched successfully>

# Prediction

In [12]:
domains = CFG.domains

# Gold POS tags (column 4) and relation labels (column 8) of the training file.
# Only the first configured domain is read.
# NOTE: `rels` is rebound later by the prediction cell (`heads, rels = predict(...)`).
tags, rels = [], []
for domain in domains[:1]:
    file = f'../suda/train/{domain}-Train-full.conll'
    # `with` closes the file even if a malformed line raises.
    with open(file, encoding='utf-8') as fp:
        for line in fp:
            # blank or whitespace-only lines separate sentences
            if not line.strip():
                continue
            line_sp = line.split('\t')
            rels.append(line_sp[7])
            tags.append(line_sp[3])

In [13]:
print(rel2id)
# id -> label lookup built from the ids themselves rather than from dict insertion
# order; the check guarantees id2rel[i] is the label whose id is i.
id2rel = sorted(rel2id, key=rel2id.get)
assert [rel2id[label] for label in id2rel] == list(range(len(id2rel))), 'rel2id ids must be contiguous from 0'

{'root': 0, 'sasubj-obj': 1, 'sasubj': 2, 'dfsubj': 3, 'subj': 4, 'subj-in': 5, 'obj': 6, 'pred': 7, 'att': 8, 'adv': 9, 'cmp': 10, 'coo': 11, 'pobj': 12, 'iobj': 13, 'de': 14, 'adjct': 15, 'app': 16, 'exp': 17, 'punc': 18, 'frag': 19, 'repet': 20, 'attr': 21, 'bckg': 22, 'cause': 23, 'comp': 24, 'cond': 25, 'cont': 26, 'elbr': 27, 'enbm': 28, 'eval': 29, 'expl': 30, 'joint': 31, 'manner': 32, 'rstm': 33, 'temp': 34, 'tp-chg': 35, 'prob-sol': 36, 'qst-ans': 37, 'stm-rsp': 38, 'req-proc': 39}


In [14]:
# Sample pre-segmented utterance (words separated by single spaces), encoded as
# pre-split words so offsets are reported per word.
text = "你 好 , 请问 有 什么 我 可以 帮 您 的 吗 ?"
tokenized = tokenizer.encode_plus(
    text.split(' '),
    is_split_into_words=True,
    return_offsets_mapping=True,
    return_tensors='pt',
)

In [15]:
def tokenize(tokenizer, text):
    """Split a space-delimited sentence into words and encode it as pre-split input.

    Args:
        tokenizer: a Hugging Face tokenizer (anything with ``encode_plus``).
        text: words separated by single spaces.

    Returns:
        ``(word_ids, encoding)`` where ``word_ids`` is ``[1, 2, ..., n_words]``
        (1-based word positions) and ``encoding`` is what ``encode_plus`` returns
        (PyTorch tensors, with offset mappings).
    """
    words = text.split(' ')
    encoding = tokenizer.encode_plus(
        words,
        return_offsets_mapping=True,
        return_tensors='pt',
        is_split_into_words=True,
    )
    word_ids = [position + 1 for position in range(len(words))]
    return word_ids, encoding

In [16]:
def get_inputs(tokenized):
    """Extract encoder inputs and word-start positions from a tokenizer encoding.

    Args:
        tokenized: an encoding as returned by ``tokenize`` (dict-like access for the
            model inputs, plus an ``offset_mapping`` attribute).

    Returns:
        ``(inputs, word_starts)``: ``inputs`` maps 'input_ids', 'token_type_ids' and
        'attention_mask' to the encoding's tensors; ``word_starts`` is a (1, k)
        tensor with the position of each word's first sub-token, counted in the
        sequence with the leading special token removed.
    """
    model_keys = ("input_ids", "token_type_ids", "attention_mask")
    inputs = {key: tokenized[key] for key in model_keys}

    # With pre-split words the offsets are presumably relative to each word, so a
    # sub-token starting at 0 (and not the empty (0, 0) special-token span) opens a word.
    offsets_wo_first = tokenized.offset_mapping[0][1:]
    word_starts = [
        position
        for position, (start, end) in enumerate(offsets_wo_first)
        if start == 0 and end != 0
    ]
    return inputs, torch.as_tensor(word_starts).unsqueeze(0)

In [17]:
def predict(parser, inputs, offsets):
    """Predict a head and a relation id for every word of one sentence.

    Args:
        parser: the dependency parser module; called with keyword arguments and
            ``evaluate=True``.
        inputs: encoder inputs as produced by ``get_inputs``.
        offsets: word-start positions as produced by ``get_inputs``.

    Returns:
        ``(heads, rels)``: two Python lists with one entry per word. Position 0 of
        the model output (the root) is dropped, so ``heads[i]`` / ``rels[i]`` are
        the predicted head position and relation id of word ``i + 1``.

    The parser is put in eval mode for the forward pass, so dropout is disabled
    and repeated calls give the same result. Its previous train/eval mode is
    restored afterwards.
    """
    was_training = parser.training
    parser.eval()
    try:
        with torch.no_grad():
            arc_logit, rel_logit = parser(inputs=inputs, offsets=offsets, heads=None, rels=None, masks=None, evaluate=True)
    finally:
        parser.train(was_training)

    # Forbid self-loops: a word cannot be its own head.
    rows = torch.arange(arc_logit.size(1), device=arc_logit.device)
    cols = torch.arange(arc_logit.size(2), device=arc_logit.device)
    arc_logit[:, rows, cols] = -1e4

    heads = arc_logit.argmax(-1)[0][1:].tolist()
    rels = rel_logit.argmax(-1)[0][1:].tolist()
    return heads, rels

In [18]:
def output_tripples(idcs, heads, rels, id2rel):
    """Pair every word with its predicted head and relation label.

    Returns a list of ``(head, relation_label, word_position)`` tuples, one per
    entry of ``idcs``. ``heads`` and ``rels`` are indexed by the same position,
    and ``id2rel`` maps relation ids to label strings.
    """
    return [
        (heads[pos], id2rel[rels[pos]], word_id)
        for pos, word_id in enumerate(idcs)
    ]

In [19]:
# Encode the sample utterance; idcs are its 1-based word positions.
idcs, tokenized = tokenize(tokenizer, text)

In [20]:
# Model inputs plus the position of each word's first sub-token.
inputs, offsets = get_inputs(tokenized)

In [21]:
# NOTE: rebinds `rels`, replacing the gold labels collected in the data cell.
heads, rels = predict(parser, inputs, offsets)

In [22]:
# (head, relation, dependent) triples for the sample sentence.
tripples = output_tripples(idcs, heads, rels, id2rel)

In [23]:
# Display the predicted triples.
tripples

[(2, 'subj', 1),
 (0, 'root', 2),
 (2, 'punc', 3),
 (2, 'dfsubj', 4),
 (4, 'obj', 5),
 (11, 'att', 6),
 (9, 'subj', 7),
 (9, 'adv', 8),
 (11, 'de', 9),
 (9, 'obj', 10),
 (5, 'obj', 11),
 (5, 'adjct', 12),
 (12, 'punc', 13)]

In [24]:
# Move the parser to the GPU for the bulk pre-annotation below
# (nn.Module.cuda() moves parameters in place and returns the same module).
parser = parser.cuda()

In [25]:
def pre_annot(in_file, out_file):
    """Append the global ``parser``'s predicted arcs to each dialogue and save as JSON.

    Args:
        in_file: JSON list of dialogues. Each has a 'dialog' list of turns with a
            space-separated 'utterance', and an existing 'relationship' list.
        out_file: destination for the updated JSON (UTF-8, indent 4).

    Every word of every turn adds ``[head, rel, tail]`` to its dialogue's
    'relationship' list. ``head`` and ``tail`` are ``"{turn}-{position}"``
    strings (1-based word positions, head position 0 = root), and ``rel`` is
    the label from ``id2rel``.

    The input is fully read before the output is opened, so a failure part-way
    does not truncate ``out_file``, and ``in_file == out_file`` is safe.
    Inputs are moved to whatever device the parser lives on.
    """
    with open(in_file, 'r', encoding='utf-8') as fp:
        data = json.load(fp)

    device = next(parser.parameters()).device
    for d in tqdm(data):
        for turn, dialog in enumerate(d.get('dialog')):
            idcs, tokenized = tokenize(tokenizer, dialog.get('utterance'))

            inputs, offsets = get_inputs(tokenized)
            inputs = {key: value.to(device) for key, value in inputs.items()}
            heads, rels = predict(parser, inputs, offsets.to(device))

            for k, head_pos in enumerate(heads):
                d['relationship'].append([f"{turn}-{head_pos}", id2rel[rels[k]], f"{turn}-{idcs[k]}"])

    save_str = json.dumps(data, ensure_ascii=False, indent=4, separators=(',', ': '))
    with open(out_file, 'w', encoding='utf-8') as fw:
        fw.write(save_str)

In [26]:
# pre_annot(in_file='data_testset/test.json', out_file='pre_annot_new/test.json')

In [27]:
def predict_with_postproc(parser, word_lst, inputs, offsets):
    """Predict heads/relations for one sentence, then apply rule-based post-processing.

    The forward pass mirrors ``predict`` (no_grad, self-loops masked with -1e4).
    The sentence is then cut into punctuation-delimited segments (called 'edu'
    in the comments below) and scanned for cue words ("weak signals", e.g.
    '如果' -> 'cond', '因为' -> 'cause'). Only arcs predicted as 'root', 'dfsubj'
    or 'sasubj' are candidates, and only those that cross a segment boundary or
    whose head lies more than one word to the left. Such arcs are relabelled
    with their segment's cue label, or with 'elbr' when there is no cue and the
    head is not the root. A few patterns ('cond' on the root, the greetings
    '你 好' / '您 好', and 'attr') may additionally have their arc direction
    reversed.

    Args:
        parser: dependency parser; called with keyword arguments and evaluate=True.
        word_lst: the sentence's words. NOTE(review): mutated in place; '[root]'
            is inserted at index 0.
        inputs: encoder inputs (as built by ``get_inputs``).
        offsets: word-start offsets (as built by ``get_inputs``).

    Returns:
        ``(heads, rels)``: the argmax predictions of batch item 0, edited in place
        and still including position 0 (the root). Returns ``None`` instead of a
        tuple when ``word_lst`` is empty.
    """
    from constant import rel2id, punct_lst

    with torch.no_grad():
        # arc_logit, rel_logit = parser(inputs=inputs, offsets=offsets, heads=None, tags=None, evaluate=True)
        arc_logit, rel_logit = parser(inputs=inputs, offsets=offsets, heads=None, rels=None, masks=None,evaluate=True)

    # NOTE(review): mutates the caller's list; index 0 now stands for the root.
    word_lst.insert(0, '[root]')
    # Forbid self-loops (a word cannot be its own head).
    arc_logit[:, torch.arange(arc_logit.size()[1]), torch.arange(arc_logit.size()[2])] = -1e4

    # Cue words / bigrams; row k votes for the relation weak_labels[k].
    weak_signals = [
    ['说', '表示', '看', '显示', '知道', '认为', '希望', '指出'],
    ['如果', '假如', '的话', '若', '如'],
    ['因为', '所以', '导致', '因此', '造成', '由于', '因而'],
    ['但是', '可是', '但', '竟', '却', '不过', '居然', '而是'],
    ['以及', '也', '并', '并且', '又', '或者'],
    ['对于', '自从', '之前', '上次'],
    ['明天', '晚上', '到时候', '再', '然后', '接下来', '最后', '随后'],
    ['为了', '想要', '使', '为 的'],
    ['通过', '必须', '点击'],
    ['对 吗', '是 吗', '对 吧', '是 吧', '对 ?'],
    ['更', '比', '只'],
    ['解释', '比如', '例如', '是 这样'],
    ['理想', '真 棒', '太 棒', '真差', '太 差', '不 行', '扯皮'],
    ]

    weak_labels = [
        rel2id['attr'],
        rel2id['cond'],
        rel2id['cause'],
        rel2id['cont'],
        rel2id['joint'],
        rel2id['bckg'],
        rel2id['temp'],
        rel2id['enbm'],
        rel2id['manner'],
        rel2id['rstm'],
        rel2id['comp'],
        rel2id['expl'],
        rel2id['eval'],
    ]

    # When several cues occur in one segment, the lowest value wins. Labels not
    # listed default to 4; None (no cue seen yet) is 5, so any cue beats it.
    priority = {
        rel2id['cont']: 1,
        rel2id['temp']: 2, rel2id['cause']: 2, rel2id['bckg']: 2, rel2id['comp']: 2,
        rel2id['joint']: 3, rel2id['attr']: 3,
        None: 5,  # 4 is default
    }

    # NOTE(review): unused; the greeting reversal below is hard-coded instead.
    reverse_by_words = [
        '你 好', '您 好',
    ]

    # NOTE(review): unused below.
    id2rel = list(rel2id.keys())

    # Only arcs predicted with one of these relations are candidates for relabelling.
    origin4change = [rel2id[x] for x in ['root', 'dfsubj', 'sasubj']]

    # cue string -> relation id
    signal_dct = {}
    for i, signals in enumerate(weak_signals):
        for s in signals:
            signal_dct[s] = weak_labels[i]

    from constant import rel2id, punct_lst

    rel_preds = rel_logit.argmax(-1)
    head_preds = arc_logit.argmax(-1)

    # NOTE(review): the three tensors below are never used.
    signals_new_whole = torch.Tensor()
    heads_new_whole, rels_new_whole = torch.Tensor(), torch.Tensor()
    seq_len = len(word_lst[1:])
    if seq_len == 0:
        return
    
    # signals[p]: cue relation id assigned to word position p (0 = none).
    signals = torch.zeros(seq_len+1).int()
    # NOTE(review): heads/rels are overwritten from the predictions further down.
    heads, rels = torch.full(size=(seq_len+1,), fill_value=-2).int(), torch.zeros(seq_len+1).int()
    # split: start position of the current segment; splits: start positions of the
    # segments closed so far; signal: strongest cue seen in the current segment.
    split, splits, signal  = 1, [], None
    for i, word in enumerate(word_lst[:-1]):

        # Unigram, then bigram, cue lookup; keep the higher-priority (lower value) cue.
        if word in signal_dct.keys() and priority.get(signal_dct[word], 4) < priority.get(signal, 4):
            signal = signal_dct[word]
        if f'{word} {word_lst[i+1]}' in signal_dct.keys() and priority.get(signal_dct[f'{word} {word_lst[i+1]}'], 4) < priority.get(signal, 4):
            signal = signal_dct[f'{word} {word_lst[i+1]}']
        
        # A punctuation mark followed by a non-punctuation word closes the current
        # segment (positions split..i+1); the next one starts at i+2.
        # NOTE(review): with '[root]' inserted, `word` sits at position i, so the
        # closed segment also contains the word right after the punctuation
        # (position i+1); confirm this offset is intended.
        if word in punct_lst and word_lst[i+1] not in punct_lst:
            if signal is not None and i + 2 - split > 2:  # set 2 to the min length of edu
                signals[split:i+2] = signal
                signal = None
            splits.append(split)
            split = i + 2

    # add the last data
    # NOTE(review): after the loop i == seq_len - 1, so this condition never holds.
    if i + 1 < seq_len:
        word_lst.append(word)

    heads = head_preds[0]
    # heads.masked_fill_(mask=~masks_whole[sample_idx].bool(), value=-2)

    rels = rel_preds[0]
    # rels.masked_fill_(mask=~masks_whole[sample_idx].bool(), value=-2)

    # Close the final segment, but only if at least one punctuation split
    # happened; unlike above, no minimum-length check is applied here.
    if split > 1:
        splits.append(split)
        if signal is not None:
            signals[split:i+2] = signal
    
    # End sentinel: the number of words.
    splits.append(len(word_lst[1:]))
    # when num of 'edu' >= 2, try change rel and head
    # if len(splits) > 2:
    # cnt: index into `splits` of the segment containing the current word.
    cnt = -1
    for idx, head in enumerate(heads[1:]):
        # NOTE(review): argmax never yields -2/-1; these checks look like leftovers
        # from a padded/batched version of this code.
        if head == -2:
            break
        if head == -1:
            continue

        if len(splits) > 2 and idx + 1 >= splits[cnt+1]:
            cnt += 1

        if ((len(splits) > 2 and (head < splits[cnt] or head >= splits[cnt+1])) or idx - head > 0) and rels[idx + 1] in origin4change:  # cross 'edu'
            if signals[idx+1] != 0:
                rels[idx+1] = signals[idx+1]
                
                if head == 0 and rels[idx + 1] in [rel2id['cond']]:  # reverse
                    # Search positions after the current segment for a word that
                    # depends on the current word.
                    tmp_heads = heads.clone()
                    tmp_heads[:splits[cnt+1]] = 0
                    head_idx = [idx + 1]
                    tail_idx = (tmp_heads == idx + 1).nonzero()  # find tail index
                    if len(tail_idx) == 0:  # ring or fail
                        # if 'cont', reverse root; else unchange
                        # NOTE(review): rels[idx+1] was just set from signals[idx+1]
                        # and equals 'cond' here, so the 'cont' alternative looks
                        # unreachable. Also, `head_idx == tail_idx` compares a list
                        # with a tensor whenever tail_idx comes from nonzero().
                        tail_idx = (heads == idx + 1).nonzero() if signals[idx+1] == rel2id['cont'] else [idx + 1]
                        head_idx = (heads == idx + 1).nonzero() if head_idx == tail_idx else head_idx
                    if len(head_idx) != 0:
                        # Normal case: the found dependent becomes the root and the
                        # current word attaches to it.
                        heads[tail_idx[0]] = 0
                        heads[head_idx[0]] = tail_idx[0]
            elif head != 0:  # default
                rels[idx + 1] = rel2id['elbr']

            # special cases
            # Greeting ('你 好' / '您 好'): the first later-segment word depending on
            # '好' becomes the root and '好' attaches to it with 'elbr'.
            if len(splits) > 2 and word_lst[idx+1] == '好' and word_lst[idx] in ['你', '您']:  # reverse
                tmp_heads = heads.clone()
                tmp_heads[:splits[cnt+1]] = 0
                tail_idx = (tmp_heads == idx + 1).nonzero()  # find tail index
                if len(tail_idx) != 0:  
                    heads[tail_idx[0]] = 0
                    heads[idx + 1] = tail_idx[0]
                    rels[idx + 1] = rel2id['elbr']

            # 'attr' label should be reversed again; below can match most of cases
            if splits[cnt] == 1 and rels[idx + 1] in [rel2id['attr']]:
                tmp_heads = heads.clone()
                tmp_heads[:splits[cnt+1]] = 0
                tail_idx = ((tmp_heads != 0) * (tmp_heads >= splits[cnt]) * (tmp_heads < splits[cnt + 1])).nonzero().flatten()
                if len(tail_idx) != 0 and rels[tail_idx[0]] in origin4change:
                    heads[tail_idx[0]] = idx + 1
                    rels[tail_idx[0]] = rels[idx + 1]
                else:
                    dep_idx = heads[idx + 1].item()
                    # NOTE(review): `> splits[cnt+1]` here vs `>=` in the boundary
                    # test above; confirm which is intended.
                    if dep_idx != 0 and (dep_idx < splits[cnt] or dep_idx > splits[cnt+1]):
                        heads[idx + 1] = 0
                        heads[dep_idx] = idx + 1
                        rels[dep_idx] = rels[idx + 1]

    # Finalize: arcs attached to the root get relation id 0 ('root'), position 0
    # (the root itself) gets head 0, and any -2 placeholders become 0.
    rels.masked_fill_(heads == 0, 0)  # root
    heads[0] = 0
    heads[1:].masked_fill_(heads[1:] == -2, 0)
    
    return heads, rels

In [28]:
def pre_annot_conll(in_file, out_file):
    """Parse every utterance of a dialogue JSON file and write CoNLL-style rows.

    Args:
        in_file: JSON list of dialogues, each with a 'dialog' list of turns that
            carry a space-separated 'utterance'.
        out_file: output file. It gets one blank-line-terminated sentence per
            utterance, with rows ``id, word, _, _, _, _, head, rel, _, _``
            (tab-separated, head 0 = root, rel from ``id2rel``).

    Uses the global ``parser`` / ``tokenizer``. Inputs are moved to the parser's
    device. Both files are closed (and the output flushed) even if the run
    raises or is interrupted.
    """
    with open(in_file, 'r', encoding='utf-8') as fp:
        data = json.load(fp)

    device = next(parser.parameters()).device
    with open(out_file, 'w', encoding='utf-8') as fw:
        for d in tqdm(data):
            for dialog in d.get('dialog'):
                utterance = dialog.get('utterance')
                word_lst = utterance.split(' ')
                _, tokenized = tokenize(tokenizer, utterance)

                inputs, offsets = get_inputs(tokenized)
                inputs = {key: value.to(device) for key, value in inputs.items()}
                heads, rels = predict(parser, inputs, offsets.to(device))

                for i, head in enumerate(heads):
                    fw.write(f'{i+1}\t{word_lst[i]}\t_\t_\t_\t_\t{int(head)}\t{id2rel[rels[i]]}\t_\t_\n')
                fw.write('\n')

In [29]:
def pre_annot_conll_ensemble(in_file, out_file, aux_model_path='../results/base_par.pt'):
    """Like ``pre_annot_conll``, but ensembles the global ``parser`` with an
    auxiliary parser built with ``CFG.num_labels = 21``.

    Both parsers produce softmax probabilities, with self-loops masked before
    the softmax. They are combined by a "hard" vote:
    ``0.8 * main + 0.2 * max(main, aux)``, element-wise and on the unscaled
    probabilities. For relations, only the first ``aux`` labels (21) get the
    auxiliary term; the remaining labels keep ``0.8 * main``.

    Args:
        in_file: JSON list of dialogues, each with a 'dialog' list of turns that
            carry a space-separated 'utterance'.
        out_file: CoNLL-style output, one blank-line-terminated sentence per utterance.
        aux_model_path: state dict of the auxiliary parser.

    Both parsers run in eval mode (dropout off). The main parser's previous
    mode and the global ``CFG.num_labels`` are restored on exit.
    """
    with open(in_file, 'r', encoding='utf-8') as fp:
        data = json.load(fp)

    device = next(parser.parameters()).device

    def _probs(model, inputs, offsets):
        # One forward pass -> (arc, rel) probabilities with self-loops masked.
        with torch.no_grad():
            arc_logit, rel_logit = model(inputs=inputs, offsets=offsets, heads=None, rels=None, masks=None, evaluate=True)
        rows = torch.arange(arc_logit.size(1), device=arc_logit.device)
        cols = torch.arange(arc_logit.size(2), device=arc_logit.device)
        arc_logit[:, rows, cols] = -1e4
        return arc_logit.softmax(-1), rel_logit.softmax(-1)

    saved_num_labels = CFG.num_labels
    was_training = parser.training
    # CFG is shared global state: the override is undone in `finally`, and is kept
    # for the whole run in case the models read CFG after construction.
    CFG.num_labels = 21
    try:
        parser_1 = DepParser(CFG)
        parser_1.load_state_dict(torch.load(aux_model_path, map_location='cpu'))
        parser_1 = parser_1.to(device).eval()
        parser.eval()

        with open(out_file, 'w', encoding='utf-8') as fw:
            for d in tqdm(data):
                for dialog in d.get('dialog'):
                    utterance = dialog.get('utterance')
                    word_lst = utterance.split(' ')
                    _, tokenized = tokenize(tokenizer, utterance)

                    inputs, offsets = get_inputs(tokenized)
                    inputs = {key: value.to(device) for key, value in inputs.items()}
                    offsets = offsets.to(device)

                    arc_prob, rel_prob = _probs(parser, inputs, offsets)
                    arc_prob_aux, rel_prob_aux = _probs(parser_1, inputs, offsets)

                    # soft alternative: 0.8 * main + 0.2 * aux
                    # hard vote; the max is taken before any scaling so the two models
                    # are compared on equal footing
                    arc_score = 0.8 * arc_prob + 0.2 * torch.where(arc_prob > arc_prob_aux, arc_prob, arc_prob_aux)
                    n_shared = rel_prob_aux.size(-1)
                    shared = rel_prob[:, :, :n_shared]
                    rel_score = 0.8 * rel_prob
                    rel_score[:, :, :n_shared] += 0.2 * torch.where(shared > rel_prob_aux, shared, rel_prob_aux)

                    heads = arc_score.argmax(-1)[0][1:].tolist()
                    rels = rel_score.argmax(-1)[0][1:].tolist()

                    for i, head in enumerate(heads):
                        fw.write(f'{i+1}\t{word_lst[i]}\t_\t_\t_\t_\t{int(head)}\t{id2rel[rels[i]]}\t_\t_\n')
                    fw.write('\n')
    finally:
        CFG.num_labels = saved_num_labels
        parser.train(was_training)

In [30]:
def pre_annot_via_tensor(in_file, out_file,
                         head_preds_file='../preds/head_preds_ensemble_logit.pt',
                         rel_preds_file='../preds/rel_preds_ensemble_logit.pt'):
    """Write CoNLL-style rows for every utterance from pre-computed predictions.

    Args:
        in_file: JSON list of dialogues, each with a 'dialog' list of turns that
            carry a space-separated 'utterance'.
        out_file: CoNLL-style output, one blank-line-terminated sentence per utterance.
        head_preds_file, rel_preds_file: ``torch.load``-able sequences of per-sample
            head / relation-id predictions, where entry 0 of each sample is the root.

    Prediction entries are consumed one per utterance, in file order (dialogues
    in order, turns in order). Previously the dialogue index was used, so every
    turn of a dialogue received the same prediction row. No model or tokenizer
    is run.
    """
    with open(in_file, 'r', encoding='utf-8') as fp:
        data = json.load(fp)

    head_preds, rel_preds = torch.load(head_preds_file), torch.load(rel_preds_file)
    sample_idx = 0
    with open(out_file, 'w', encoding='utf-8') as fw:
        for d in tqdm(data):
            for dialog in d.get('dialog'):
                word_lst = dialog.get('utterance').split(' ')
                # drop position 0 (the root)
                heads = head_preds[sample_idx][1:].int()
                rels = rel_preds[sample_idx][1:].int()
                sample_idx += 1

                for j, word in enumerate(word_lst):
                    fw.write(f'{j+1}\t{word}\t_\t_\t_\t_\t{int(heads[j])}\t{id2rel[int(rels[j])]}\t_\t_\n')
                fw.write('\n')

In [31]:
# Pre-annotate the training dialogues into CoNLL format. This is long-running:
# one forward pass per utterance.
pre_annot_conll(in_file='../data/train_proc.json', out_file='../aug/diag_train.conll')

 44%|████▎     | 3977/9101 [11:19<14:35,  5.85it/s]


KeyboardInterrupt: 