In [4]:
import os
from collections import Counter
from typing import *
import random
import json

In [5]:
from tqdm import tqdm
import numpy as np
from sklearn.model_selection import KFold

import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader, Subset
from torch.cuda.amp import autocast, GradScaler

from transformers import AutoConfig, AutoModel, AutoTokenizer, AdamW, get_constant_schedule_with_warmup, get_linear_schedule_with_warmup

In [25]:
import sys
sys.path.append('..')

from utils import to_cuda, arc_rel_loss, uas_las
from constant import punct_lst

## Seed

In [7]:
def seed_everything(seed=42):
    np.random.seed(seed%(2**32-1))
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic =True
    torch.backends.cudnn.benchmark = False

seed_everything()

## CFG

In [8]:
class CFG:
    data_file = 'data/seg_with_train.json'
    plm = 'hfl/chinese-electra-180g-large-discriminator'
    num_folds = 5
    trn_folds = [0]  # only one fold, as splitting train/val randomly
    random_seed = 42
    num_epochs = 3
    batch_size = 128
    max_length = 160
    num_labels = 2
    lr = 2e-5
    weight_decay = 0.01
    dropout = 0.2
    grad_clip = 1
    scheduler = 'linear'
    warmup_ratio = 0.1
    num_early_stop = 5
    hidden_size = 400
    print_every = 500  
    eval_every = 1000
    cuda = True
    fp16 = True
    debug = False

In [11]:
data_file = '../data_testset/1to800_1108.json'

In [12]:
with open(data_file, 'r', encoding='utf-8') as f:
    data = json.load(f)

In [13]:
class Dependency():
    def __init__(self, idx, word, head, rel):
        self.id = idx
        self.word = word
        self.tag = '_'
        self.head = head
        self.rel = rel

    def __str__(self):
        # example:  1	上海	_	NR	NR	_	2	nn	_	_
        values = [str(self.idx), self.word, "_", self.tag, "_", "_", str(self.head), self.rel, "_", "_"]
        return '\t'.join(values)

    def __repr__(self):
        return f"({self.word}, {self.tag}, {self.head}, {self.rel})"

In [14]:
rst_dct = {
    'attr': '归属',
    'bckg': '背景',
    'cause': '因果',
    'comp': '比较',
    'cond': '状况',
    'cont': '对比',
    'elbr': '阐述',
    'enbm': '目的',
    'eval': '评价',
    'expl': '解释-例证',
    'joint': '联合',
    'manner': '方式',
    'rstm': '重申',
    'temp': '时序',
    'tp-chg': '主题变更',
    'prob-sol': '问题-解决',
    'qst-ans': '疑问-回答',
    'stm-rsp': '陈述-回应',
    'req-proc': '需求-处理',
}

rst_lst = [x for x in rst_dct.keys()]
print(rst_lst)

['attr', 'bckg', 'cause', 'comp', 'cond', 'cont', 'elbr', 'enbm', 'eval', 'expl', 'joint', 'manner', 'rstm', 'temp', 'tp-chg', 'prob-sol', 'qst-ans', 'stm-rsp', 'req-proc']


In [15]:
# debug

sample_lst:List[List[Dependency]] = []
split_ids_lst = []
a, b = 278, 1

for i, d in enumerate(data):
    # if i != a - 1:
    #     continue
        
    rel_dct = {}
    for tripple in d['relationship']:
        head, rel, tail = tripple
        head_uttr_idx, head_word_idx = [int(x) for x in head.split('-')]
        tail_uttr_idx, tail_word_idx = [int(x) for x in tail.split('-')]
        if head_uttr_idx != tail_uttr_idx:
            continue

        if not rel_dct.get(head_uttr_idx, None):
            rel_dct[head_uttr_idx] = {tail_word_idx: [head_word_idx, rel]}
        else:
            rel_dct[head_uttr_idx][tail_word_idx] = [head_word_idx, rel]
    
    
    for item in d['dialog']:
        turn = item['turn']
        # if turn != b - 1:
        #     continue
        utterance = item['utterance']
        # print(utterance)
        dep_lst:List[Dependency] = []
        
        edus = [[-1, -1]]  # 0 for root
        inner_rels, inter_rels = [], []
        for word_idx, word in enumerate(utterance.split(' ')):
            head_word_idx, rel = rel_dct[turn].get(word_idx + 1, [word_idx, 'adjct'])  # some word annoted missed, padded with last word and 'adjct'
            
            if rel in rst_lst:
                inter_rels.append(rel)
                edus.append([word_idx + 1, word_idx + 1])
                continue
            
            expand, include = False, False
            for edu in edus:
                if edu[0] <= head_word_idx <= edu[1]:
                    edu[0] = min(edu[0], word_idx + 1)
                    edu[1] = max(edu[1], word_idx + 1)
                    expand = True
                    break
                elif edu[0] <= word_idx + 1 <= edu[1] and head_word_idx != 0:
                    edu[0] = min(edu[0], head_word_idx)
                    edu[1] = max(edu[1], head_word_idx)
                    include = True
                    break
                    
            if not expand and not include:
                if head_word_idx == 0:  # ignore root
                    edus.append([word_idx + 1, word_idx + 1])
                else:
                    # max with 1 to ignore root index
                    edus.append([max(min(word_idx + 1, head_word_idx), 1), max(word_idx + 1, head_word_idx)])
            inner_rels.append(rel)
            
            dep_lst.append(Dependency(word_idx + 1, word, head_word_idx, rel))  # start from 1
        
        # print(edus)
        
        finals = [edus[1]]
        if len(edus) >= 2:
            for edu in edus[2:]:
                if finals[-1][0] <= edu[0] <= edu[1] <= finals[-1][1]:
                    continue
                finals.append(edu)
                
        split_ids = [x[1] for x in finals]
        split_ids_lst.append(split_ids)
#         print(split_ids)      
        
#         print(finals)
#         print(dep_lst)
        
        sample_lst.append(dep_lst)

In [16]:
sample_lst[0]

[(有, _, 0, root),
 (什么, _, 3, att),
 (问题, _, 1, obj),
 (我, _, 6, subj),
 (可以, _, 6, adv),
 (您, _, 6, obj),
 (处理, _, 6, sasubj),
 (或, _, 10, adjct),
 (解决, _, 8, sasubj-obj),
 (呢, _, 6, adjct),
 (?, _, 11, punc)]

In [17]:
split_ids_lst[0]

[3, 12]

## Data

In [18]:
tokenizer = AutoTokenizer.from_pretrained(CFG.plm)
print(len(tokenizer))
 
num_added_toks = tokenizer.add_tokens(['[root]', '[qst]', '[ans]'], special_tokens=True)
tokenizer.root_token = '[root]'
tokenizer.root_token_ids = tokenizer('[root]')['input_ids'][1]
print(f"add token: {tokenizer.root_token} {tokenizer.root_token_ids}")

tokenizer.qst_token = '[qst]'
tokenizer.qst_token_ids = tokenizer('[qst]')['input_ids'][1]
print(f"add token: {tokenizer.qst_token} {tokenizer.qst_token_ids}")

tokenizer.ans_token = '[ans]'
tokenizer.ans_token_ids = tokenizer('[ans]')['input_ids'][1]
print(f"add token: {tokenizer.ans_token} {tokenizer.ans_token_ids}")
print(len(tokenizer))

CFG.tokenizer = tokenizer

21128
add token: [root] 21128
add token: [qst] 21129
add token: [ans] 21130
21131


In [19]:
def data_gen(data):
    special_tokens = {
        "Q": "[qst]",
        "A": "[ans]",
    }    

    for i, d in enumerate(data):
        rel_dct = {}
        for tripple in d['relationship']:
            head, rel, tail = tripple
            head_uttr_idx, head_word_idx = [int(x) for x in head.split('-')]
            tail_uttr_idx, tail_word_idx = [int(x) for x in tail.split('-')]
            if head_uttr_idx != tail_uttr_idx:
                continue

            if not rel_dct.get(head_uttr_idx, None):
                rel_dct[head_uttr_idx] = {tail_word_idx: [head_word_idx, rel]}
            else:
                rel_dct[head_uttr_idx][tail_word_idx] = [head_word_idx, rel]


        for item in d['dialog']:
            turn = item['turn']
            speaker = item['speaker']
            utterance = item['utterance']
            dep_lst:List[Dependency] = []

            edus = [[-1, -1]]  # 0 for root
            inner_rels, inter_rels = [], []
            for word_idx, word in enumerate(utterance.split(' ')):
                head_word_idx, rel = rel_dct[turn].get(word_idx + 1, [word_idx, 'adjct'])  # some word annoted missed, padded with last word and 'adjct'

                if rel in rst_lst:
                    inter_rels.append(rel)
                    edus.append([word_idx + 1, word_idx + 1])
                    continue

                expand, include = False, False
                for edu in edus:
                    if edu[0] <= head_word_idx <= edu[1]:
                        edu[0] = min(edu[0], word_idx + 1)
                        edu[1] = max(edu[1], word_idx + 1)
                        expand = True
                        break
                    elif edu[0] <= word_idx + 1 <= edu[1] and head_word_idx != 0:
                        edu[0] = min(edu[0], head_word_idx)
                        edu[1] = max(edu[1], head_word_idx)
                        include = True
                        break

                if not expand and not include:
                    if head_word_idx == 0:  # ignore root
                        edus.append([word_idx + 1, word_idx + 1])
                    else:
                        # max with 1 to ignore root index
                        edus.append([max(min(word_idx + 1, head_word_idx), 1), max(word_idx + 1, head_word_idx)])
                inner_rels.append(rel)

                dep_lst.append(Dependency(word_idx + 1, word, head_word_idx, rel))  # start from 1

            finals = [edus[1]]
            if len(edus) >= 2:
                for edu in edus[2:]:
                    if finals[-1][0] <= edu[0] <= edu[1] <= finals[-1][1]:
                        continue
                    finals.append(edu)

            split_ids = [x[1] for x in finals]

            yield [special_tokens[speaker]] + utterance.split(' '), split_ids

In [20]:
class EduDataset(Dataset):
    def __init__(self, cfg, data):
        self.cfg = cfg
        self.data = data
        self.inputs, self.offsets, self.tags = self.load_data()
    
    def load_data(self):
        inputs, offsets, tags = [], [], []
            
        for word_lst, split_ids in tqdm(data_gen(self.data)):
            tokenized = self.cfg.tokenizer.encode_plus(word_lst, 
                                                       padding='max_length', 
                                                       truncation=True,
                                                       max_length=self.cfg.max_length + 2,  # reserved for cls and sep
                                                       return_offsets_mapping=True, 
                                                       return_tensors='pt',
                                                       is_split_into_words=True)
            
            inputs.append({"input_ids": tokenized['input_ids'][0],
                           "token_type_ids": tokenized['token_type_ids'][0],
                           "attention_mask": tokenized['attention_mask'][0]
                          })
            
            sentence_word_idx = []
            for idx, (start, end) in enumerate(tokenized.offset_mapping[0][1:]):
                if start == 0 and end != 0:
                    sentence_word_idx.append(idx)
            if len(sentence_word_idx) < self.cfg.max_length - 1:
                sentence_word_idx.extend([0]* (self.cfg.max_length - len(sentence_word_idx)))
            offsets.append(torch.as_tensor(sentence_word_idx))
            
            # ignore cls for convenience
            tag = torch.full(size=(CFG.max_length, ), fill_value=-1, dtype=torch.long)
            tag[0:split_ids[-1]+1] = 0
            tag[split_ids[:-1]] = 1
            
            tags.append(tag)  
            
        return inputs, offsets, tags
    
    def __getitem__(self, idx):
        return self.inputs[idx], self.offsets[idx], self.tags[idx]
    
    def __len__(self):
        return len(self.inputs)

In [21]:
dataset = EduDataset(CFG, data)

eval_iter = DataLoader(dataset, shuffle=False, batch_size=CFG.batch_size)

20086it [00:38, 525.78it/s]


In [28]:
def splitby_punct():
    eval_iter = DataLoader(dataset, shuffle=False, batch_size=1)

    punct_lst = ["，", ".", "。", "!", "?", "~", "...", "......"]
    
    preds_whole, tags_whole = torch.Tensor(), torch.Tensor()
    for batch, (word_lst, _) in zip(eval_iter, data_gen(data)):
        preds = torch.zeros_like(batch[-1]).int()
        for i, word in enumerate(word_lst[:-1]):
            if word in punct_lst and word_lst[i+1] not in punct_lst:
                preds[0][i] = 1

        preds_whole = torch.cat([preds_whole, preds], dim=0)
        tags_whole = torch.cat([tags_whole, batch[-1]], dim=0)

    tags_whole = tags_whole.masked_fill(tags_whole==-1, 0)
    
    # inner split points f1
    predictions = preds_whole.view(-1)
    labels = tags_whole.view(-1)

    pred_detected = (predictions == 1)
    label_detected = (labels == 1)
    co = pred_detected * (pred_detected == label_detected)
    co_sum = sum(co)

    if sum(pred_detected) == 0:
        precision = 0
    else:
        precision = co_sum / sum(pred_detected)
    recall = co_sum / sum(label_detected)
    f1 = (2 * precision * recall) / (precision + recall)

    print(f1)
    
    # single f1
    # tags_whole = tags_whole.masked_fill(tags_whole==-1, 0)
    labels_single = (tags_whole.sum(1) == 0)
    preds_single = (preds_whole.sum(1) == 0)

    co = preds_single * (preds_single == labels_single)
    co_sum = sum(co)

    if sum(preds_single) == 0:
        precision = 0
    else:
        precision = co_sum / sum(preds_single)
    recall = co_sum / sum(labels_single)
    f1 = (2 * precision * recall) / (precision + recall)

    print(f1)

In [29]:
# splitby_punct()
# assert False

tensor(0.6882)
tensor(0.9010)


AssertionError: 

In [None]:
class SegModel(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        
        self.cfg = cfg
        
        self.encoder = AutoModel.from_pretrained(cfg.plm)
        self.encoder.resize_token_embeddings(len(cfg.tokenizer))
        
        self.mlp = nn.Linear(self.encoder.config.hidden_size, cfg.num_labels)
        self.dropout = nn.Dropout(cfg.dropout)
        
    def feat(self, inputs, offsets):
        length = torch.sum(inputs["attention_mask"], dim=-1) - 2
        
        feats, *_ = self.encoder(**inputs, return_dict=False)   # batch_size, seq_len (tokenized), plm_hidden_size
           
        # remove [CLS] [SEP]
        cls_feat = feats[:, :1]
        char_feat = torch.narrow(feats, 1, 1, feats.size(1) - 2)
        return cls_feat, char_feat, length
        
    def forward(self, inputs, offsets, tags=None):
        cls_feat, char_feat, char_len = self.feat(inputs, offsets)
        
        word_idx = offsets.unsqueeze(-1).expand(-1, -1, char_feat.shape[-1])  # expand to the size of char feat
        word_feat = torch.gather(char_feat, dim=1, index=word_idx)  # embeddings of first char in each word
        
        feats = self.dropout(word_feat)
        logits = self.mlp(feats)
        
        if tags is not None:
            loss = nn.CrossEntropyLoss(reduction='mean', ignore_index=-1)(logits.view(-1, logits.size(-1)), tags.long().view(-1))
            return torch.softmax(logits, dim=-1), loss

        return torch.softmax(logits, dim=-1)

In [None]:
# def metrics_fn(predictions, labels):
#     predictions = predictions.view(-1, predictions.size(-1))
#     labels = labels.view(-1)
    
#     pred_detected = (predictions[:, 1] > 0.5)
#     label_detected = (labels == 1)
#     co = pred_detected * (pred_detected == label_detected)
    
#     if sum(pred_detected) == 0:
#         precision = 0
#     else:
#         precision = sum(co) / sum(pred_detected)
#     recall = sum(co) / sum(label_detected)
#     f1 = (2 * precision * recall) / (precision + recall)

#     return {"f1": f1.item()}

In [None]:
def metrics_fn(predictions, labels):
    left = predictions.detach()
    right = predictions.detach()
    
    predictions = predictions.view(-1, predictions.size(-1))
    labels = labels.view(-1)
    
    pred_detected = (predictions[:, 1] > 0.5)
    label_detected = (labels == 1)
    co = pred_detected * (pred_detected == label_detected)
    co_sum = sum(co)
    
    tolerance = 3
    for i in range(tolerance):
        left = torch.cat([left, torch.zeros(left.size(0), 1, left.size(2)).int()], dim=1)[:, 1:, :]
        right = torch.cat([torch.zeros(left.size(0), 1, left.size(2)).int(), right], dim=1)[:, :-1, :]
        
        left_flat = left.reshape(-1, left.size(-1))
        right_flat = right.reshape(-1, right.size(-1))
        
        left_detected = (left_flat[:, 1] > 0.5)
        right_detected = (right_flat[:, 1] > 0.5)
        
        co_sum += sum(left_detected * (left_detected == label_detected))
        co_sum += sum(right_detected * (right_detected == label_detected))
        
    if sum(pred_detected) == 0:
        precision = 0
    else:
        precision = co_sum / sum(pred_detected)
    recall = co_sum / sum(label_detected)
    f1 = (2 * precision * recall) / (precision + recall)

    return {"f1": f1.item()}

In [None]:
def single_f1(predictions, labels):
    print(labels)
    labels = labels.masked_fill(labels==-1, 0)
    print(labels)
    labels_single = (labels.sum(1) == 0)
    preds_single = ((predictions[:, :, 1] > 0.5).sum(1) == 0)
    print(labels_single)
    print((predictions[:, :, 1] > 0.5).sum(1))
    
    co = preds_single * (preds_single == labels_single)
    co_sum = sum(co)
    
    if sum(preds_single) == 0:
        precision = 0
    else:
        precision = co_sum / sum(preds_single)
    recall = co_sum / sum(labels_single)
    f1 = (2 * precision * recall) / (precision + recall)
    
    return {"f1": f1.item()}

In [None]:
# single_f1(predictions=logits_whole, labels=tags_whole)

In [None]:
# metrics_fn(predictions=logits_whole, labels=tags_whole)

In [None]:
model = SegModel(CFG)
model.load_state_dict(torch.load(f'/root/autodl-tmp/diag_dep/edu_seg/0/model.bin'))
model.to('cuda' if CFG.cuda else 'cpu')
model.eval()

avg_loss, step = 0.0, 0
logits_whole, tags_whole = torch.Tensor(), torch.Tensor()
for step, batch in tqdm(enumerate(eval_iter)):
    inputs, offsets, tags = batch

    if CFG.cuda and torch.cuda.is_available():
        inputs_cuda = {}
        for key,value in inputs.items():
            inputs_cuda[key] = value.cuda()
        inputs = inputs_cuda
        offsets, tags = to_cuda(data=(offsets, tags))

    with torch.no_grad():
        logits, loss = model(inputs, offsets, tags)

    logits_whole = torch.cat([logits_whole, logits.cpu()], dim=0)
    tags_whole = torch.cat([tags_whole, tags.cpu()], dim=0)

    avg_loss += loss * len(tags)  # times the batch size of data

metrics = metrics_fn(predictions=logits_whole, labels=tags_whole)

avg_loss /= len(eval_iter.dataset)
print("--Evaluation:")
print("-Loss: {}  F1: {} \n".format(avg_loss, metrics['f1']))

Some weights of the model checkpoint at hfl/chinese-electra-180g-large-discriminator were not used when initializing ElectraModel: ['discriminator_predictions.dense_prediction.bias', 'discriminator_predictions.dense.bias', 'discriminator_predictions.dense_prediction.weight', 'discriminator_predictions.dense.weight']
- This IS expected if you are initializing ElectraModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ElectraModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
98it [00:56,  1.75it/s]


--Evaluation:
-Loss: 0.22451341152191162  F1: 0.14755763113498688 

