In [1]:
from typing import *
import random
import json
from itertools import chain

from tqdm import tqdm
import numpy as np
import pandas as pd

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader

from transformers import AutoTokenizer, AdamW, get_linear_schedule_with_warmup
from torch.utils.data import Dataset, DataLoader

2022-12-13 14:23:34.217697: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudart.so.11.0


In [2]:
import sys
sys.path.append('..')

from model.par_with_attr import ParWithAttr
from model.base_par import DepParser
from utils import arc_rel_loss, uas_las, to_cuda

# Config

In [3]:
class CFG:
    """Paths and hyper-parameters shared across the notebook, used as a plain namespace.

    ``CFG.tokenizer`` is attached once the tokenizer has been loaded.
    """

    # ---- data ----
    train_file: str = '../data/train_proc.json'
    dev_file: str = '../data/val_proc.json'

    # ---- encoder / tokenisation ----
    plm: str = 'hfl/chinese-electra-180g-base-discriminator'  # id passed to AutoTokenizer.from_pretrained
    max_length: int = 128   # padded/truncated sequence length given to the tokenizer
    hidden_size: int = 400
    num_labels: int = 35    # matches len(rel2id) printed in the label cell

    # ---- optimisation ----
    random_seed: int = 42
    num_epochs: int = 10
    batch_size: int = 256
    plm_lr: float = 2e-5    # presumably the learning rate of the pretrained encoder
    head_lr: float = 1e-4   # presumably the learning rate of the parser head
    weight_decay: float = 0.01
    dropout: float = 0.2
    grad_clip: int = 2
    scheduler: str = 'linear'
    warmup_ratio: float = 0.1
    num_early_stop: int = 3

    # ---- logging / runtime ----
    print_every: int = 2000
    eval_every: int = 4000
    cuda: bool = True
    fp16: bool = True

# Relation labels

In [4]:
# Relation inventory: label -> Chinese description (English gloss in the trailing comments).
# Insertion order matters: rel2id / id2rel assign integer ids by enumerating this dict.
rel_dct = {
    'root': '根节点',  # root node
    'sasubj-obj': '同主同宾',  # same subject and same object
    'sasubj': '同主语',  # same subject
    'dfsubj': '不同主语',  # different subject
    'subj': '主语',  # subject
    'subj-in': '内部主语',  # inner subject
    'obj': '宾语',  # object
    'pred': '谓语',  # predicate
    'att': '定语',  # attributive
    'adv': '状语',  # adverbial
    'cmp': '补语',  # complement
    'coo': '并列',  # coordination
    'pobj': '介宾',  # prepositional object
    'iobj': '间宾',  # indirect object
    'de': '的',  # "de" (的) construction
    'adjct': '附加',  # adjunct; also the fallback label for unannotated / unknown relations
    'app': '称呼',  # form of address
    'exp': '解释',  # explanation
    'punc': '标点',  # punctuation
    'frag': '片段',  # fragment
    'repet': '重复',  # repetition
    # rst -- presumably Rhetorical-Structure-Theory-style discourse relations
    'attr': '归属',  # attribution
    'bckg': '背景',  # background
    'cause': '因果',  # cause-effect
    'comp': '比较',  # comparison
    'cond': '状况',  # condition / circumstance
    'cont': '对比',  # contrast
    'elbr': '阐述',  # elaboration
    'enbm': '目的',  # purpose (enablement)
    'eval': '评价',  # evaluation
    'expl': '解释-例证',  # explanation-exemplification
    'joint': '联合',  # joint
    'manner': '方式',  # manner
    'rstm': '重申',  # restatement
    'temp': '时序',  # temporal
    # Disabled labels -- not part of the id mapping:
    # 'tp-chg': '主题变更',  (topic change)
    # 'prob-sol': '问题-解决',  (problem-solution)
    # 'qst-ans': '疑问-回答',  (question-answer)
    # 'stm-rsp': '陈述-回应',  (statement-response)
    # 'req-proc': '需求-处理',  (request-processing)
}

In [5]:
# Label -> integer id, numbered in the insertion order of rel_dct.
rel2id = {label: label_id for label_id, label in enumerate(rel_dct)}
print(rel2id)
print(len(rel2id))

{'root': 0, 'sasubj-obj': 1, 'sasubj': 2, 'dfsubj': 3, 'subj': 4, 'subj-in': 5, 'obj': 6, 'pred': 7, 'att': 8, 'adv': 9, 'cmp': 10, 'coo': 11, 'pobj': 12, 'iobj': 13, 'de': 14, 'adjct': 15, 'app': 16, 'exp': 17, 'punc': 18, 'frag': 19, 'repet': 20, 'attr': 21, 'bckg': 22, 'cause': 23, 'comp': 24, 'cond': 25, 'cont': 26, 'elbr': 27, 'enbm': 28, 'eval': 29, 'expl': 30, 'joint': 31, 'manner': 32, 'rstm': 33, 'temp': 34}
35


In [6]:
# Integer id -> label; position i holds the label that rel2id maps to i.
id2rel = [label for label in rel_dct]
print(id2rel)

['root', 'sasubj-obj', 'sasubj', 'dfsubj', 'subj', 'subj-in', 'obj', 'pred', 'att', 'adv', 'cmp', 'coo', 'pobj', 'iobj', 'de', 'adjct', 'app', 'exp', 'punc', 'frag', 'repet', 'attr', 'bckg', 'cause', 'comp', 'cond', 'cont', 'elbr', 'enbm', 'eval', 'expl', 'joint', 'manner', 'rstm', 'temp']


# Seed & Device

In [7]:
def seed_everything(seed=CFG.random_seed):
    """Seed Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Args:
        seed: integer seed; defaults to ``CFG.random_seed``.
    """
    # NumPy's legacy seeding accepts [0, 2**32 - 1]; reducing modulo 2**32 keeps
    # every value in that range distinct (modulo 2**32 - 1 mapped 2**32 - 1 onto 0).
    np.random.seed(seed % 2**32)
    random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU explicitly, not only the current device.
    torch.cuda.manual_seed_all(seed)
    # Trade cuDNN autotuning speed for reproducible kernels.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

seed_everything()

In [8]:
# Use the GPU whenever PyTorch can see one.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print(f'Using device: {device}')

Using device: cuda


# Data

In [9]:
class Dependency:
    """One word of a dependency-annotated sentence.

    Attributes:
        id: position of the word in its sentence (str from the CoNLL loaders,
            1-based int from ``load_annoted``).
        word: surface form.
        head: id of the head word; 0 stands for the root.
        rel: relation label from the head to this word.
    """

    def __init__(self, idx, word, head, rel):
        self.id, self.word, self.head, self.rel = idx, word, head, rel

    def __str__(self):
        """Render as a 10-column CoNLL row; only ID, FORM, HEAD and DEPREL are filled, the rest are '_'."""
        blank = "_"
        columns = (str(self.id), self.word, blank, blank, blank, blank,
                   str(self.head), self.rel, blank, blank)
        return "\t".join(columns)

    def __repr__(self):
        return "({}, {}, {})".format(self.word, self.head, self.rel)

In [10]:
def load_codt(data_file: str):
    """Yield the sentences of a 10-column CoNLL file as lists of ``Dependency``.

    Columns used: ID (0), FORM (1), HEAD (6) and DEPREL (7); values stay strings.
    Sentences are separated by blank lines. Rows without exactly 10
    whitespace-separated fields are ignored.

    Args:
        data_file: path to a UTF-8 encoded CoNLL file.

    Yields:
        List[Dependency]: one non-empty list per sentence.
    """
    sentence: List[Dependency] = []
    with open(data_file, 'r', encoding='utf-8') as f:
        for line in f:
            toks = line.split()
            if not toks:
                # A blank line closes the sentence; runs of blank lines yield nothing extra.
                if sentence:
                    yield sentence
                    sentence = []
            elif len(toks) == 10:
                # Dependency takes (id, word, head, rel); the POS column is not kept.
                sentence.append(Dependency(toks[0], toks[1], toks[6], toks[7]))
    # The file may not end with a blank line; do not drop the last sentence.
    if sentence:
        yield sentence
                sentence.append(dep)

def load_codt_with_aug(data_file: str, aug_file: str = '../aug/codt/codt_train_fixed.conll'):
    """Yield the sentences of ``data_file`` followed by those of ``aug_file``.

    Both files are 10-column CoNLL: ID (0), FORM (1), HEAD (6) and DEPREL (7) are
    kept as strings, sentences are separated by blank lines, and rows without
    exactly 10 whitespace-separated fields are ignored.

    Args:
        data_file: primary CoNLL file.
        aug_file: augmentation CoNLL file appended after ``data_file``.

    Yields:
        List[Dependency]: one non-empty list per sentence.
    """
    for path in (data_file, aug_file):
        sentence: List[Dependency] = []
        # ``with`` closes each file even if the generator is abandoned early.
        # NOTE(review): the augmentation file is assumed to be UTF-8 like the main file.
        with open(path, 'r', encoding='utf-8') as f:
            for line in f:
                toks = line.split()
                if not toks:
                    if sentence:
                        yield sentence
                        sentence = []
                elif len(toks) == 10:
                    # Dependency takes (id, word, head, rel); the POS column is not kept.
                    sentence.append(Dependency(toks[0], toks[1], toks[6], toks[7]))
        # Flush per file so a missing trailing blank line neither drops the last
        # sentence nor glues it onto the first sentence of the next file.
        if sentence:
            yield sentence

In [11]:
def load_annoted(data_file):
    """Load annotated dialogues and convert every utterance into a list of ``Dependency``.

    The JSON file holds a list of dialogues. Each dialogue has
    ``'relationship'`` triples ``[head, rel, tail]``, whose endpoints are written
    ``"<utterance>-<word>"``, and a ``'dialog'`` list of ``{'turn', 'utterance'}``
    items with space-separated words. Only relations whose endpoints lie in the
    same utterance are used. A word with no annotated head is attached to the
    preceding word (index 0, the root, for the first word) with relation ``'adjct'``.

    Returns:
        List[List[Dependency]]: one list per utterance; word ids start at 1.
    """
    with open(data_file, 'r', encoding='utf-8') as f:
        dialogues = json.load(f)

    sentences: List[List[Dependency]] = []

    for dialogue in dialogues:
        # utterance index -> {dependent word id -> [head word id, relation]}
        arcs_by_uttr = {}
        for head, rel, tail in dialogue['relationship']:
            head_uttr, head_word = map(int, head.split('-'))
            tail_uttr, tail_word = map(int, tail.split('-'))
            if head_uttr == tail_uttr:
                arcs_by_uttr.setdefault(head_uttr, {})[tail_word] = [head_word, rel]

        for item in dialogue['dialog']:
            turn = item['turn']
            words = item['utterance'].split(' ')
            arcs = arcs_by_uttr.get(turn, {})
            deps: List[Dependency] = []
            for word_id, word in enumerate(words, start=1):
                # Missing annotation: fall back to the previous word and 'adjct'.
                head_word, rel = arcs.get(word_id, [word_id - 1, 'adjct'])
                deps.append(Dependency(word_id, word, head_word, rel))
            sentences.append(deps)

    return sentences

In [12]:
# Load the tokenizer once; it is also exposed on the config for later consumers.
CFG.tokenizer = tokenizer = AutoTokenizer.from_pretrained(CFG.plm)

In [13]:
# One list of Dependency objects per annotated utterance.
data = list(load_annoted(CFG.train_file))
print(len(data))

236590


In [14]:
class CODTDataset(Dataset):
    """Dependency-parsing dataset pairing tokenizer inputs with word-level labels.

    Every item is ``(inputs, offsets, heads, rels, masks)``:

    * ``inputs``  -- dict of ``input_ids`` / ``token_type_ids`` / ``attention_mask``
      1-D tensors, padded/truncated to ``cfg.max_length``;
    * ``offsets`` -- 1-D tensor of length ``cfg.max_length - 1`` giving, per word,
      the position of its first sub-token counted from position 1 of the encoding
      (position 0 is skipped), zero-padded;
    * ``heads`` / ``rels`` / ``masks`` -- int64 arrays of length ``cfg.max_length``
      where slot ``i + 1`` describes word ``i`` and slot 0 stands for the root.

    Args:
        cfg: config object; ``max_length`` is read, and ``tokenizer`` when present
            (otherwise the module-level ``tokenizer`` is used).
        data: sequence of sentences, each a list of ``Dependency``.
        train: when True an unknown relation label raises ``KeyError``; otherwise
            it is mapped to ``'adjct'``.
    """

    def __init__(self, cfg, data, train):
        self.train = train
        self.cfg = cfg
        self.inputs, self.offsets, self.heads, self.rels, self.masks = self.read_data(data)

    def _encode_labels(self, deps):
        """Return (words, heads, rels, masks) for one sentence, keeping at most max_length - 1 words."""
        max_len = self.cfg.max_length
        words = []
        # Zero-initialised: head 0 is the root index, and unlabelled slots stay masked out.
        head_tokens = np.zeros(max_len, dtype=np.int64)
        rel_tokens = np.zeros(max_len, dtype=np.int64)
        mask_tokens = np.zeros(max_len, dtype=np.int64)
        for i, dep in enumerate(deps):
            if i + 1 == max_len:  # slot 0 is the root, so only max_len - 1 words fit
                break
            words.append(dep.word)

            # str() treats int heads (load_annoted) and str heads (CoNLL loaders) alike,
            # so an int -1 is masked exactly like the string '-1'.
            head = str(dep.head)
            # Unannotated heads and heads beyond the window keep head 0 / mask 0.
            if head not in ('_', '-1') and int(head) + 1 < max_len:
                head_tokens[i + 1] = int(head)
                mask_tokens[i + 1] = 1

            if self.train:
                rel_tokens[i + 1] = rel2id[dep.rel]
            else:
                rel_tokens[i + 1] = rel2id.get(dep.rel, rel2id['adjct'])
        return words, head_tokens, rel_tokens, mask_tokens

    def _first_subword_offsets(self, tokenized):
        """Positions (relative to encoding position 1) of each word's first sub-token, zero-padded.

        A token counts as a word start when its character span starts at 0 and is
        non-empty. This assumes offsets restart for every pre-split word
        (``is_split_into_words=True``) -- TODO confirm for the tokenizer in use.
        NOTE(review): a word that produces no sub-token gets no entry, which
        shifts every later word's offset.
        """
        width = self.cfg.max_length - 1
        starts = [idx for idx, (start, end) in enumerate(tokenized.offset_mapping[0][1:])
                  if start == 0 and end != 0]
        if len(starts) < width:
            starts.extend([0] * (width - len(starts)))
        return torch.as_tensor(starts)

    def read_data(self, data):
        """Encode every sentence; returns parallel lists (inputs, offsets, heads, rels, masks)."""
        tok = getattr(self.cfg, 'tokenizer', None)
        if tok is None:
            tok = tokenizer
        inputs, offsets, heads, rels, masks = [], [], [], [], []

        for deps in tqdm(data):
            words, head_tokens, rel_tokens, mask_tokens = self._encode_labels(deps)

            # return_offsets_mapping requires a "fast" (Rust-backed) tokenizer.
            tokenized = tok.encode_plus(words,
                                        padding='max_length',
                                        truncation=True,
                                        max_length=self.cfg.max_length,
                                        return_offsets_mapping=True,
                                        return_tensors='pt',
                                        is_split_into_words=True)
            inputs.append({"input_ids": tokenized['input_ids'][0],
                           "token_type_ids": tokenized['token_type_ids'][0],
                           "attention_mask": tokenized['attention_mask'][0]})
            offsets.append(self._first_subword_offsets(tokenized))

            heads.append(head_tokens)
            rels.append(rel_tokens)
            masks.append(mask_tokens)

        return inputs, offsets, heads, rels, masks

    def __getitem__(self, idx):
        return self.inputs[idx], self.offsets[idx], self.heads[idx], self.rels[idx], self.masks[idx]

    def __len__(self):
        return len(self.rels)

In [15]:
dataset = CODTDataset(cfg=CFG, data=data, train=True)  # train=True: unknown relation labels raise

100%|██████████| 236590/236590 [06:08<00:00, 642.88it/s]


# Model

In [16]:
# Checkpoint of the parser whose predictions are used for pseudo-labelling below.
MODEL_CKPT = '../results/few_shot/with_codt/model.bin'

model = DepParser(CFG)
# map_location='cpu' makes loading independent of the device the checkpoint was saved from;
# the printed result reports missing / unexpected keys.
print(model.load_state_dict(torch.load(MODEL_CKPT, map_location='cpu')))

model = model.cuda()
model.eval()  # inference only in this notebook

Some weights of the model checkpoint at hfl/chinese-electra-180g-base-discriminator were not used when initializing ElectraModel: ['discriminator_predictions.dense_prediction.weight', 'discriminator_predictions.dense_prediction.bias', 'discriminator_predictions.dense.weight', 'discriminator_predictions.dense.bias']
- This IS expected if you are initializing ElectraModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ElectraModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


<All keys matched successfully>


# Prediction

In [17]:
# Pseudo-label the dataset and keep only sentences predicted with high confidence
# for both arcs and relations.
CONF_THRESHOLD = 0.98  # minimum mean top-probability over scored words

# No shuffling: batches follow dataset order, which the output cell relies on.
dataloader = DataLoader(dataset, batch_size=CFG.batch_size)

# Collect per-batch results in lists and concatenate once after the loop;
# torch.cat on the growing tensors every batch re-copies them each time (quadratic).
arc_chunks, rel_chunks, mask_chunks, id_chunks = [], [], [], []
batch_start = 0  # dataset index of the first sentence in the current batch

model.eval()
for batch in tqdm(dataloader):
    inputs, offsets, heads, rels, masks = batch
    inputs = {key: value.cuda() for key, value in inputs.items()}
    offsets, heads, rels, masks = to_cuda(data=(offsets, heads, rels, masks))
    batch_len = masks.size(0)

    with torch.no_grad():
        arc_logit, rel_logit = model(inputs, offsets, masks=masks, heads=None, rels=None, evaluate=True)

    arc_prob = arc_logit.softmax(-1)
    rel_prob = rel_logit.softmax(-1)

    # Confidence = mean over scored words (mask == 1) of the top probability.
    # A sentence with no scored word gives 0/0 = NaN, which never passes the threshold.
    mask_3d = masks.unsqueeze(-1)
    n_scored = masks.sum(-1)
    arc_conf = (arc_prob * mask_3d).max(-1)[0].sum(-1) / n_scored
    rel_conf = (rel_prob * mask_3d).max(-1)[0].sum(-1) / n_scored
    keep = ((arc_conf > CONF_THRESHOLD) & (rel_conf > CONF_THRESHOLD)).nonzero().squeeze(1)

    arc_prob = arc_prob[keep]  # advanced indexing copies, so the write below stays local
    rel_prob = rel_prob[keep]
    # Suppress self-loops so a later argmax never picks a word as its own head.
    arc_prob[:, torch.arange(arc_prob.size(1)), torch.arange(arc_prob.size(2))] = -1e4

    arc_chunks.append(arc_prob.cpu())
    rel_chunks.append(rel_prob.cpu())
    mask_chunks.append(masks[keep].cpu())
    id_chunks.append(keep.cpu().int() + batch_start)
    batch_start += batch_len

# Names used by later cells. Despite the names, these hold softmax probabilities
# (self-loops set to -1e4), not raw logits. Rows are in ascending dataset order.
arc_logits = torch.cat(arc_chunks) if arc_chunks else torch.Tensor()
rel_logits = torch.cat(rel_chunks) if rel_chunks else torch.Tensor()
masks_whole = torch.cat(mask_chunks) if mask_chunks else torch.Tensor()
sampled_ids = torch.cat(id_chunks) if id_chunks else torch.Tensor()

 63%|██████▎   | 585/925 [08:49<08:09,  1.44s/it]

: 

: 

In [None]:
arc_logits.size()

torch.Size([75521, 128, 128])

In [None]:
sampled_ids = sampled_ids.to(torch.int32)  # dataset indices of the kept sentences

In [None]:
# Most probable head / relation id for every slot of every kept sentence.
arc_preds, rel_preds = (scores.argmax(dim=-1) for scores in (arc_logits, rel_logits))

# Output

In [None]:
# Write the kept sentences, with predicted heads / relations, as 10-column CoNLL.
OUTPUT_FILE = '../diag_train_sampled.conll'

sampled_set = set(sampled_ids.tolist())  # O(1) membership instead of scanning the tensor
max_words = arc_preds.size(1) - 1        # slot 0 is the root; words occupy slots 1..max_words

cnt = 0  # row of arc_preds / rel_preds holding the current kept sentence
with open(OUTPUT_FILE, 'w+', encoding='utf-8') as fw:
    for i, d in enumerate(data):
        if i not in sampled_set:
            continue
        # Kept rows are in ascending dataset order, so the cnt-th kept sentence is row cnt.
        if len(d) <= max_words:
            for idx, dep in enumerate(d):
                head = arc_preds[cnt][idx + 1].item()
                rel = id2rel[rel_preds[cnt][idx + 1].item()]
                fw.write(f'{dep.id}\t{dep.word}\t_\t_\t_\t_\t{head}\t{rel}\t_\t_\n')
            fw.write('\n')
        # Longer sentences were truncated to max_words slots and cannot be written in
        # full; they are skipped (indexing past the last slot would raise IndexError).
        cnt += 1