In [1]:
import os
from collections import Counter
from typing import *
import random
import json

In [2]:
from tqdm import tqdm
import numpy as np
from sklearn.model_selection import KFold

import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader, Subset
from torch.cuda.amp import autocast, GradScaler

from transformers import AutoConfig, AutoModel, AutoTokenizer, AdamW, get_constant_schedule_with_warmup, get_linear_schedule_with_warmup

In [3]:
from trainer import BasicTrainer
from model import DepParserTwostage
from utils import arc_rel_loss_split, uas_las_split

## Config

In [4]:
class CFG:
    data_file = '/root/diag_dep/data_new/1to200_0923.json'
    plm = 'hfl/chinese-electra-180g-large-discriminator'
    num_folds = 5
    trn_folds = [0, 1, 2, 3, 4]
    random_seed = 42
    num_epochs = 15
    batch_size = 1  # can not work if batch size != 1
    lr = 2e-5
    weight_decay = 0.01
    dropout = 0.2
    grad_clip = 1
    scheduler = 'linear'
    warmup_ratio = 0.12
    num_early_stop = 5
    max_length = 70
    max_turns = 56
    num_labels = 40
    hidden_size = 200
    print_every = 1e9
    eval_every = 100
    cuda = True
    fp16 = True

## Seed and Device

In [5]:
def seed_everything(seed=CFG.random_seed):
    np.random.seed(seed%(2**32-1))
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic =True
    torch.backends.cudnn.benchmark = False

seed_everything()

In [6]:
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

print(f'Using device: {device}')

Using device: cuda


## Data

In [7]:
rel_dct = {
    'root': '根节点',
    'sasubj-obj': '同主同宾',
    'sasubj': '同主语',
    'dfsubj': '不同主语',
    'subj': '主语',
    'subj-in': '内部主语',
    'obj': '宾语',
    'pred': '谓语',
    'att': '定语',
    'adv': '状语',
    'cmp': '补语',
    'coo': '并列',
    'pobj': '介宾',
    'iobj': '间宾',
    'de': '的',
    'adjct': '附加',
    'app': '称呼',
    'exp': '解释',
    'punc': '标点',
    'frag': '片段',
    'repet': '重复',
    # rst
    'attr': '归属',
    'bckg': '背景',
    'cause': '因果',
    'comp': '比较',
    'cond': '状况',
    'cont': '对比',
    'elbr': '阐述',
    'enbm': '目的',
    'eval': '评价',
    'expl': '解释-例证',
    'joint': '联合',
    'manner': '方式',
    'rstm': '重申',
    'temp': '时序',
    'tp-chg': '主题变更',
    'prob-sol': '问题-解决',
    'qst-ans': '疑问-回答',
    'stm-rsp': '陈述-回应',
    'req-proc': '需求-处理',
}

In [8]:
rst_lst = [
    'attr','bckg','cause','comp','cond','cont','elbr','enbm','eval','expl','joint',
    'manner','rstm','temp','tp-chg','prob-sol','qst-ans','stm-rsp','req-proc',
]

In [9]:
rel2id = {key:idx for idx, key in enumerate(rel_dct.keys())}
print(rel2id)

{'root': 0, 'sasubj-obj': 1, 'sasubj': 2, 'dfsubj': 3, 'subj': 4, 'subj-in': 5, 'obj': 6, 'pred': 7, 'att': 8, 'adv': 9, 'cmp': 10, 'coo': 11, 'pobj': 12, 'iobj': 13, 'de': 14, 'adjct': 15, 'app': 16, 'exp': 17, 'punc': 18, 'frag': 19, 'repet': 20, 'attr': 21, 'bckg': 22, 'cause': 23, 'comp': 24, 'cond': 25, 'cont': 26, 'elbr': 27, 'enbm': 28, 'eval': 29, 'expl': 30, 'joint': 31, 'manner': 32, 'rstm': 33, 'temp': 34, 'tp-chg': 35, 'prob-sol': 36, 'qst-ans': 37, 'stm-rsp': 38, 'req-proc': 39}


In [10]:
tokenizer = AutoTokenizer.from_pretrained(CFG.plm)
print(len(tokenizer))
 
num_added_toks = tokenizer.add_tokens(['[root]', '[qst]', '[aws]'], special_tokens=True)
tokenizer.root_token = '[root]'
tokenizer.root_token_ids = tokenizer('[root]')['input_ids'][1]
print(f"add token: {tokenizer.root_token} {tokenizer.root_token_ids}")

tokenizer.root_token = '[qst]'
tokenizer.root_token_ids = tokenizer('[qst]')['input_ids'][1]
print(f"add token: {tokenizer.root_token} {tokenizer.root_token_ids}")

tokenizer.root_token = '[ans]'
tokenizer.root_token_ids = tokenizer('[ans]')['input_ids'][1]
print(f"add token: {tokenizer.root_token} {tokenizer.root_token_ids}")
print(len(tokenizer))

CFG.tokenizer = tokenizer

21128
add token: [root] 21128
add token: [qst] 21129
add token: [ans] 138
21131


In [11]:
with open(CFG.data_file, 'r', encoding='utf-8') as f:
    data = json.load(f)[:200]  # have annotated 200 data

sent_lst = []
max_len, trun_cnt = 0, 0
for d in data:
    word_lst = []
    for item in d['dialog']:
        word_lst.extend(item['utterance'].split(' '))
    if len(word_lst) > max_len:
        max_len = len(word_lst)
    if len(word_lst) > 450:
        trun_cnt += 1
    sent_lst.append(len(word_lst))

print(trun_cnt)
print(max_len)
print(np.mean(sent_lst))

1
626
209.725


In [12]:
class Dependency():
    def __init__(self, idx, word, head, rel):
        self.id = idx
        self.word = word
        self.tag = '_'
        self.head = head
        self.rel = rel

    def __str__(self):
        # example:  1	上海	_	NR	NR	_	2	nn	_	_
        values = [str(self.idx), self.word, "_", self.tag, "_", "_", str(self.head), self.rel, "_", "_"]
        return '\t'.join(values)

    def __repr__(self):
        return f"({self.id}, {self.word}, {self.tag}, {self.head}, {self.rel})"

In [13]:
def load_annoted(data_file):
    max_turns = 0
    with open(CFG.data_file, 'r', encoding='utf-8') as f:
        data = json.load(f)[:200]  # have annotated 200 data
        
    sample_lst:List[List[Dependency]] = []
    
    for d in data:
        rel_dct = {}
        for tripple in d['relationship']:
            head, rel, tail = tripple
            head_uttr_idx, head_word_idx = [int(x) for x in head.split('-')]
            tail_uttr_idx, tail_word_idx = [int(x) for x in tail.split('-')]
            
            if rel == 'root' and head_uttr_idx != 0: # ignore root
                continue
                 
            if not rel_dct.get(tail_uttr_idx, None):
                rel_dct[tail_uttr_idx] = {tail_word_idx: [head, rel]}
            else:
                rel_dct[tail_uttr_idx][tail_word_idx] = [head, rel]
                
        sent_lens_accum = [0]
        for i, item in enumerate(d['dialog']):
            utterance = item['utterance']
            sent_lens_accum.append(sent_lens_accum[i] + len(utterance.split(' ')) + 1)
        
        one_sample:List[List[Dependency]] = []
        for item in d['dialog']:
            turn = item['turn']
            if turn > max_turns:
                max_turns = turn
            utterance = item['utterance']
            # dep_lst:List[Dependency] = [Dependency(0, '[root]', -1, '_')]
            
            role = '[ans]' if item['speaker'] == 'A' else '[qst]'
            dep_lst = [Dependency(0, role, -1, '_')]
            
            for word_idx, word in enumerate(utterance.split(' ')):
                tail2head = rel_dct.get(turn, {1: [f'{turn}-{word_idx}', 'adjct']})
                head, rel = tail2head.get(word_idx + 1, [f'{turn}-{word_idx}', 'adjct'])  # some word annoted missed, padded with last word and 'adjct'
                head_uttr_idx, head_word_idx = [int(x) for x in head.split('-')]
                
                if head_uttr_idx == turn:
                    dep_lst.append(Dependency(word_idx + 1, word, head_word_idx, rel)) 
                else:
                    dep_lst.append(Dependency(word_idx + 1, word, sent_lens_accum[head_uttr_idx] + head_word_idx, rel))  # add with accumulated length
            one_sample.append(dep_lst)
         
        sample_lst.append(one_sample)
    
    print(max_turns)
    return sample_lst

In [14]:
# data = load_annoted(CFG.data_file)
# data[0]

In [15]:
class DialogDataset(Dataset):
    def __init__(self, cfg):
        self.cfg = cfg
        self.inputs, self.offsets, self.heads, self.rels, self.masks = self.read_data()
        
    def read_data(self):
        inputs, offsets = [], []
        tags, heads, rels, masks = [], [], [], []
        
        for deps_lst in load_annoted(self.cfg.data_file):
            head_tokens = np.zeros((self.cfg.max_turns, self.cfg.max_length), dtype=np.int64)  # same as root index is 0, constrainting by mask 
            rel_tokens = np.zeros((self.cfg.max_turns, self.cfg.max_length), dtype=np.int64)
            mask_tokens = np.zeros((self.cfg.max_turns, self.cfg.max_length), dtype=np.int64)
            
            inputs_one = []
            sentence_word_ids = []
            
            for idx, deps in enumerate(deps_lst):
                if idx + 1 == self.cfg.max_turns:
                    break
                seq_len = len(deps)

                word_lst = [] 
                for i, dep in enumerate(deps):
                    if i == seq_len or i + 1== self.cfg.max_length:
                        break

                    word_lst.append(dep.word)

                    if int(dep.head) == -1 or int(dep.head) >= self.cfg.max_length:
                        head_tokens[idx, i+1] = 0
                        mask_tokens[idx, i+1] = 0
                    else:
                        head_tokens[idx, i+1] = int(dep.head)
                        mask_tokens[idx, i+1] = 1

                    rel_tokens[idx, i+1] = rel2id.get(dep.rel, 0)

                tokenized = tokenizer.encode_plus(word_lst, 
                                                  padding='max_length', 
                                                  truncation=True,
                                                  max_length=self.cfg.max_length, 
                                                  return_offsets_mapping=True, 
                                                  return_tensors='pt',
                                                  is_split_into_words=True)
                inputs_one.append({"input_ids": tokenized['input_ids'][0],
                                   "token_type_ids": tokenized['token_type_ids'][0],
                                   "attention_mask": tokenized['attention_mask'][0]
                              })

                sentence_word_idx = []
                for idx, (start, end) in enumerate(tokenized.offset_mapping[0][1:]):
                    if start == 0 and end != 0:
                        sentence_word_idx.append(idx)

                if len(sentence_word_idx) < self.cfg.max_length - 1:
                    sentence_word_idx.extend([0]* (self.cfg.max_length - 1 - len(sentence_word_idx)))
                    
                sentence_word_ids.append(sentence_word_idx)
            
            inputs_merge = {"input_ids": torch.Tensor(), "token_type_ids": torch.Tensor(), "attention_mask": torch.Tensor()}
            for tokenized in inputs_one:
                for key, value in tokenized.items():
                    inputs_merge[key] = torch.cat([inputs_merge[key], value.unsqueeze(0)], dim=0)
            
            # padding
            inputs_merge = {key: torch.cat([value, torch.zeros(CFG.max_turns - value.size()[0], CFG.max_length)], dim=0).int() for key, value in inputs_merge.items()}
            sentence_word_ids = torch.as_tensor(sentence_word_ids)
            sentence_word_ids = torch.cat([sentence_word_ids, torch.zeros(CFG.max_turns - sentence_word_ids.size()[0], CFG.max_length - 1)], dim=0).long()
                    
            inputs.append(inputs_merge)        
            offsets.append(sentence_word_ids)
            heads.append(head_tokens)
            rels.append(rel_tokens)
            masks.append(mask_tokens)
                   
        return inputs, offsets, heads, rels, masks

    def __getitem__(self, idx):
        return self.inputs[idx], self.offsets[idx], self.heads[idx], self.rels[idx], self.masks[idx]
    
    def __len__(self):
        return len(self.rels)

In [16]:
dataset = DialogDataset(CFG)

56


In [17]:
dataset[0][0]['input_ids'].shape

torch.Size([56, 70])

In [18]:
dataset[0][1].shape

torch.Size([56, 69])

In [19]:
dataset[0][2].shape

(56, 70)

In [20]:
dataset[0][3].shape

(56, 70)

In [21]:
dataset[0][4].shape

(56, 70)

In [22]:
data_iter = DataLoader(dataset, batch_size = 1)

In [23]:
for item in data_iter:
    print(item[0]['input_ids'].shape)
    break

torch.Size([1, 56, 70])


## Tranining

In [24]:
kfold = KFold(n_splits=CFG.num_folds, shuffle=True, random_state=CFG.random_seed)

In [25]:
for fold, (train_ids, val_ids) in enumerate(kfold.split(dataset)):
    if fold not in CFG.trn_folds:
        continue
    print(f'FOLD {fold}')
    print(f'{len(train_ids)}/{len(val_ids)}')
    print('--------------------------------')

    if CFG.cuda and torch.cuda.is_available:
        torch.cuda.empty_cache()

    random.shuffle(train_ids)
    random.shuffle(val_ids)

    tr_dataset = Subset(dataset, train_ids)
    va_dataset = Subset(dataset, val_ids)
    
    tr_iter = DataLoader(tr_dataset, batch_size=CFG.batch_size)
    va_iter = DataLoader(va_dataset, batch_size=CFG.batch_size)
    
    model = DepParserTwostage(CFG)
    
    optim = AdamW(model.parameters(), 
                      lr=CFG.lr,
                      weight_decay=CFG.weight_decay
                      )

    training_step = int(CFG.num_epochs * (len(train_ids) / CFG.batch_size))
    warmup_step = int(CFG.warmup_ratio * training_step)  
    lr_scheduler = get_linear_schedule_with_warmup(optimizer=optim, 
                                                        num_warmup_steps=warmup_step, 
                                                        num_training_steps=training_step)

    trainer = BasicTrainer(optim=optim, 
                        lr_scheduler=lr_scheduler,
                        trainset_size=len(train_ids), 
                        loss_fn=arc_rel_loss_split, 
                        metrics_fn=uas_las_split, 
                        config=CFG)
    
    best_res, best_state_dict = trainer.train(model=model, train_iter=tr_iter, val_iter=va_iter)
    print(best_res)
    with open("/root/autodl-tmp/diag_dep/k-fold/res.txt", 'a+') as f:
        f.write(f'{fold}\t {str(best_res)}\n')
    
    torch.save(best_state_dict, f"/root/autodl-tmp/diag_dep/k-fold/{fold}/model.bin")

FOLD 0
160/40
--------------------------------


KeyboardInterrupt: 

In [None]:
os.system('shutdown')