In [None]:
import torch
from torch import nn
from transformers import BertModel
import numpy as np
import random
from transformers import BertTokenizer
import math
from collections import Counter

import argparse
import os
import json
from torch.utils.tensorboard import SummaryWriter
import zipfile
from transformers import AdamW, get_linear_schedule_with_warmup
from convlab2.nlu.jointBERT.dataloader import Dataloader
from convlab2.nlu.jointBERT.jointBERT import JointBERT


In [None]:
class Dataloader:
    """Prepare slot-tagging samples and turn them into padded tensors.

    Every sample is a mutable list. As indexed by this class, position 0 holds
    the utterance words, position 1 the per-word tags and position 3 the
    context utterances (strings); position 2 is never touched here.
    ``load_data`` overwrites position 3 with token ids and appends three
    entries, so afterwards ``sample[-3]`` is the piece-index -> word-index map
    (or ``None``), ``sample[-2]`` the tokens fed to the model and
    ``sample[-1]`` the tag ids.
    """

    def __init__(self, intent_vocab, tag_vocab, pretrained_weights):
        """
        :param intent_vocab: sequence of intent labels; the position is the id
        :param tag_vocab: sequence of tag labels; the position is the id
        :param pretrained_weights: value passed to ``BertTokenizer.from_pretrained``
        """
        self.tag_vocab = tag_vocab
        self.intent_dim = len(intent_vocab)
        self.tag_dim = len(tag_vocab)
        self.id2intent = dict(enumerate(intent_vocab))
        self.intent2id = {x: i for i, x in enumerate(intent_vocab)}
        self.id2tag = dict(enumerate(tag_vocab))
        self.tag2id = {x: i for i, x in enumerate(tag_vocab)}
        self.tokenizer = BertTokenizer.from_pretrained(pretrained_weights)
        self.data = {}
        self.intent_weight = [1] * len(self.intent2id)

    def load_data(self, data, data_key, cut_sen_len, use_bert_tokenizer=True):
        """
        Register ``data`` under ``data_key`` and preprocess every sample in place.

        :param data: list of samples (see the class docstring for the layout)
        :param data_key: name of the split, e.g. train/val/test
        :param cut_sen_len: if > 0, keep at most this many words/tags per
            utterance and this many whitespace-separated tokens per context
            utterance
        :param use_bert_tokenizer: split words into word pieces via
            ``bert_tokenize``; otherwise the words are used unchanged
        :return: None
        """
        self.data[data_key] = data
        for d in data:
            if cut_sen_len > 0:
                d[0] = d[0][:cut_sen_len]
                d[1] = d[1][:cut_sen_len]
                d[3] = [' '.join(s.split()[:cut_sen_len]) for s in d[3]]

            # The context utterances become one id sequence, joined by [SEP].
            d[3] = self.tokenizer.encode('[CLS] ' + ' [SEP] '.join(d[3]))

            if use_bert_tokenizer:
                word_seq, tag_seq, new2ori = self.bert_tokenize(d[0], d[1])
            else:
                word_seq = d[0]
                tag_seq = d[1]
                new2ori = None
            d.append(new2ori)
            d.append(word_seq)
            d.append(self.seq_tag2id(tag_seq))

    def bert_tokenize(self, word_seq, tag_seq):
        """
        Split words into word pieces and repeat each word's tag for its pieces.

        The joined words are first split by the basic tokenizer; consecutive
        basic tokens are accumulated until they match (case-insensitively) the
        current original word, at which point the word index advances.

        :param word_seq: list of original words
        :param tag_seq: list of tags, indexed by original word position
        :return: (word pieces, tag per piece, dict piece index -> word index)
        """
        split_tokens = []
        new_tag_seq = []
        new2ori = {}
        basic_tokens = self.tokenizer.basic_tokenizer.tokenize(' '.join(word_seq))
        accum = ''
        j = 0  # index of the original word currently being matched
        for token in basic_tokens:
            if (accum + token).lower() == word_seq[j].lower():
                accum = ''
            else:
                accum += token
            for sub_token in self.tokenizer.wordpiece_tokenizer.tokenize(token):
                new2ori[len(new_tag_seq)] = j
                split_tokens.append(sub_token)
                new_tag_seq.append(tag_seq[j])
            if accum == '':
                j += 1
        return split_tokens, new_tag_seq, new2ori

    def seq_tag2id(self, tags):
        """Map tags to ids, silently dropping tags missing from the vocabulary."""
        return [self.tag2id[x] for x in tags if x in self.tag2id]

    def seq_id2tag(self, ids):
        """Map tag ids back to tags; an unknown id raises ``KeyError``."""
        return [self.id2tag[x] for x in ids]

    def seq_intent2id(self, intents):
        """Map intents to ids, silently dropping intents missing from the vocabulary."""
        return [self.intent2id[x] for x in intents if x in self.intent2id]

    def seq_id2intent(self, ids):
        """Map intent ids back to intents; an unknown id raises ``KeyError``."""
        return [self.id2intent[x] for x in ids]

    def pad_batch(self, batch_data):
        """
        Build zero-padded ``torch.long`` tensors for a list of loaded samples.

        :param batch_data: non-empty list of samples already processed by
            ``load_data``
        :return: (word ids, tag ids, word mask, tag mask, context ids,
            context mask); the first four are (batch, longest utterance + 2),
            the last two are (batch, longest context)
        """
        batch_size = len(batch_data)
        # Size by the tokens that are actually written below (x[-2]), not by
        # the raw words (x[0]): word-piece splitting can make them longer.
        # The +2 leaves room for [CLS] and [SEP].
        max_seq_len = max(len(x[-2]) for x in batch_data) + 2
        word_mask_tensor = torch.zeros((batch_size, max_seq_len), dtype=torch.long)
        word_seq_tensor = torch.zeros((batch_size, max_seq_len), dtype=torch.long)
        tag_mask_tensor = torch.zeros((batch_size, max_seq_len), dtype=torch.long)
        tag_seq_tensor = torch.zeros((batch_size, max_seq_len), dtype=torch.long)
        context_max_seq_len = max(len(x[3]) for x in batch_data)
        context_mask_tensor = torch.zeros((batch_size, context_max_seq_len), dtype=torch.long)
        context_seq_tensor = torch.zeros((batch_size, context_max_seq_len), dtype=torch.long)
        for i, sample in enumerate(batch_data):
            words = ['[CLS]'] + sample[-2] + ['[SEP]']
            tags = sample[-1]
            indexed_tokens = self.tokenizer.convert_tokens_to_ids(words)
            sen_len = len(words)
            word_seq_tensor[i, :sen_len] = torch.LongTensor(indexed_tokens)
            # Tags and the tag mask skip the [CLS] and [SEP] positions.
            tag_seq_tensor[i, 1:sen_len - 1] = torch.LongTensor(tags)
            word_mask_tensor[i, :sen_len] = 1
            tag_mask_tensor[i, 1:sen_len - 1] = 1

            context_ids = sample[3]
            context_len = len(context_ids)
            context_seq_tensor[i, :context_len] = torch.LongTensor(context_ids)
            context_mask_tensor[i, :context_len] = 1

        return word_seq_tensor, tag_seq_tensor, word_mask_tensor, tag_mask_tensor, context_seq_tensor, context_mask_tensor

    def get_train_batch(self, batch_size):
        """Return a padded batch of ``batch_size`` training samples drawn with replacement."""
        batch_data = random.choices(self.data['train'], k=batch_size)
        return self.pad_batch(batch_data)

    def yield_batches(self, batch_size, data_key):
        """
        Iterate over a split in order, in chunks of at most ``batch_size``.

        :return: generator of (padded tensors, raw samples, number of samples)
        """
        batch_num = math.ceil(len(self.data[data_key]) / batch_size)
        for i in range(batch_num):
            batch_data = self.data[data_key][i * batch_size:(i + 1) * batch_size]
            yield self.pad_batch(batch_data), batch_data, len(batch_data)


In [None]:
class JointBERT(nn.Module):
    """BERT encoder followed by a per-token slot classification head.

    Optionally the pooled encoding of a context sequence is concatenated to
    every token representation before classification, and an intermediate
    ReLU layer can be placed in front of the classifier.
    """

    def __init__(self, model_config, device, slot_dim):
        """
        :param model_config: mapping read for the keys ``pretrained_weights``,
            ``dropout``, ``context``, ``finetune``, ``context_grad`` and
            ``hidden_units``
        :param device: stored on the instance; not used inside this class
        :param slot_dim: number of slot labels produced per token
        """
        super().__init__()
        self.slot_num_labels = slot_dim
        self.device = device
        self.bert = BertModel.from_pretrained(model_config['pretrained_weights'])
        self.dropout = nn.Dropout(model_config['dropout'])
        self.context = model_config['context']
        self.finetune = model_config['finetune']
        self.context_grad = model_config['context_grad']
        self.hidden_units = model_config['hidden_units']

        # With context enabled each token vector is prefixed by the pooled
        # context vector, doubling the feature width.
        bert_width = self.bert.config.hidden_size
        feature_width = 2 * bert_width if self.context else bert_width

        if self.hidden_units > 0:
            self.slot_classifier = nn.Linear(self.hidden_units, self.slot_num_labels)
            self.slot_hidden = nn.Linear(feature_width, self.hidden_units)
            nn.init.xavier_uniform_(self.slot_hidden.weight)
        else:
            self.slot_classifier = nn.Linear(feature_width, self.slot_num_labels)
        nn.init.xavier_uniform_(self.slot_classifier.weight)
        self.slot_loss_fct = torch.nn.CrossEntropyLoss()

    def _encode(self, input_ids, attention_mask, no_grad):
        """Run BERT on the inputs, inside ``torch.no_grad()`` when ``no_grad`` is true."""
        if no_grad:
            with torch.no_grad():
                return self.bert(input_ids=input_ids, attention_mask=attention_mask)
        return self.bert(input_ids=input_ids, attention_mask=attention_mask)

    def forward(self, word_seq_tensor, word_mask_tensor, tag_seq_tensor=None, tag_mask_tensor=None,
                context_seq_tensor=None, context_mask_tensor=None):
        """
        Compute slot logits and, when tags are supplied, the slot loss.

        :param word_seq_tensor: token ids of the utterances
        :param word_mask_tensor: attention mask for ``word_seq_tensor``
        :param tag_seq_tensor: gold tag ids; when ``None`` no loss is computed
        :param tag_mask_tensor: positions equal to 1 contribute to the loss
        :param context_seq_tensor: token ids of the context; only used when
            the model was configured with ``context``
        :param context_mask_tensor: attention mask for ``context_seq_tensor``
        :return: ``(slot_logits,)`` or ``(slot_logits, slot_loss)``
        """
        bert_frozen = not self.finetune
        if bert_frozen:
            self.bert.eval()
        encoded = self._encode(word_seq_tensor, word_mask_tensor, bert_frozen)
        token_features = encoded[0]
        utterance_feature = encoded[1]

        if self.context and context_seq_tensor is not None:
            # Gradients reach the context pass only if both flags are set.
            context_frozen = not (self.finetune and self.context_grad)
            context_feature = self._encode(context_seq_tensor, context_mask_tensor, context_frozen)[1]
            tiled_context = context_feature.unsqueeze(1).repeat(1, token_features.size(1), 1)
            token_features = torch.cat([tiled_context, token_features], dim=-1)
            utterance_feature = torch.cat([context_feature, utterance_feature], dim=-1)

        if self.hidden_units > 0:
            projected = self.slot_hidden(self.dropout(token_features))
            token_features = nn.functional.relu(projected)

        slot_logits = self.slot_classifier(self.dropout(token_features))

        # No head consumes the pooled vector; the dropout call is retained so
        # the sequence of dropout invocations per forward pass is unchanged.
        self.dropout(utterance_feature)

        if tag_seq_tensor is None:
            return (slot_logits,)

        # Flatten batch and sequence, keep only positions flagged by the mask.
        active_positions = tag_mask_tensor.view(-1) == 1
        flat_logits = slot_logits.view(-1, self.slot_num_labels)
        flat_labels = tag_seq_tensor.view(-1)
        slot_loss = self.slot_loss_fct(flat_logits[active_positions], flat_labels[active_positions])
        return (slot_logits, slot_loss)


In [None]:
def set_seed(seed):
    """Seed the ``random``, NumPy and PyTorch generators with the same value.

    :param seed: value handed unchanged to each library's seeding function
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
# ---- Configuration -------------------------------------------------------
config_path = './crosswoz_all_context.json'
# Context managers ensure each JSON file handle is closed after parsing.
with open(config_path) as config_file:
    config = json.load(config_file)
data_dir = config['data_dir']
output_dir = config['output_dir']
log_dir = config['log_dir']
DEVICE = config['DEVICE']

set_seed(config['seed'])

# The evaluation helpers are dataset specific; only crosswoz is wired up here.
if 'crosswoz' in data_dir:
    print('-' * 20 + 'dataset:crosswoz' + '-' * 20)
    from convlab2.nlu.jointBERT.crosswoz.postprocess import is_slot_da, calculateF1, recover_intent

# ---- Vocabularies and data -----------------------------------------------
with open(os.path.join(data_dir, 'intent_vocab.json'), encoding="utf-8") as vocab_file:
    intent_vocab = json.load(vocab_file)
with open(os.path.join(data_dir, 'tag_vocab.json'), encoding="utf-8") as vocab_file:
    tag_vocab = json.load(vocab_file)
dataloader = Dataloader(intent_vocab=intent_vocab, tag_vocab=tag_vocab,
                        pretrained_weights=config['model']['pretrained_weights'])
print('tag num:', len(tag_vocab))
for data_key in ['train', 'val', 'test']:
    data_path = os.path.join(data_dir, '{}_data.json'.format(data_key))
    with open(data_path, encoding="utf-8") as data_file:
        split_data = json.load(data_file)
    dataloader.load_data(split_data, data_key,
                         cut_sen_len=config['cut_sen_len'], use_bert_tokenizer=config['use_bert_tokenizer'])
    print('{} set size: {}'.format(data_key, len(dataloader.data[data_key])))

# exist_ok avoids the race between checking for and creating the directory.
os.makedirs(output_dir, exist_ok=True)
os.makedirs(log_dir, exist_ok=True)

writer = SummaryWriter(log_dir)

# ---- Model and optimizer -------------------------------------------------
model = JointBERT(config['model'], DEVICE, dataloader.tag_dim)
model.to(DEVICE)

if config['model']['finetune']:
    # Parameters whose name contains one of these substrings get no weight decay.
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in model.named_parameters() if
                    not any(nd in n for nd in no_decay) and p.requires_grad],
         'weight_decay': config['model']['weight_decay']},
        {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay) and p.requires_grad],
         'weight_decay': 0.0}
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=config['model']['learning_rate'],
                      eps=config['model']['adam_epsilon'])
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=config['model']['warmup_steps'],
                                                num_training_steps=config['model']['max_step'])
else:
    # Without fine-tuning every parameter with 'bert' in its name is frozen and
    # only the remaining ones are optimized; no scheduler is created here.
    for n, p in model.named_parameters():
        if 'bert' in n:
            p.requires_grad = False
    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()),
                                 lr=config['model']['learning_rate'])

for name, param in model.named_parameters():
    print(name, param.shape, param.device, param.requires_grad)

# ---- Training state ------------------------------------------------------
max_step = config['model']['max_step']
check_step = config['model']['check_step']
batch_size = config['model']['batch_size']
model.zero_grad()
train_slot_loss = 0
best_val_f1 = 0.

writer.add_text('config', json.dumps(config))

# Train for max_step steps on randomly drawn batches; every check_step steps
# evaluate on the validation split and keep the checkpoint with the best
# overall F1.
for step in range(1, max_step + 1):
    model.train()
    batched_data = dataloader.get_train_batch(batch_size)

    batched_data = tuple(t.to(DEVICE) for t in batched_data)
    word_seq_tensor, tag_seq_tensor, word_mask_tensor, tag_mask_tensor, context_seq_tensor, context_mask_tensor = batched_data
    if not config['model']['context']:
        context_seq_tensor, context_mask_tensor = None, None
    # Call the module itself rather than .forward() so registered hooks run.
    _, slot_loss = model(word_seq_tensor,
                         word_mask_tensor,
                         tag_seq_tensor,
                         tag_mask_tensor,
                         context_seq_tensor,
                         context_mask_tensor)

    train_slot_loss += slot_loss.item()
    loss = slot_loss
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()
    if config['model']['finetune']:
        scheduler.step()  # Update learning rate schedule
    model.zero_grad()
    if step % check_step == 0:
        # Average of the training loss since the previous check.
        train_slot_loss = train_slot_loss / check_step
        print('[%d|%d] step' % (step, max_step))
        print('\t slot loss:', train_slot_loss)
        predict_golden = {'slot': [], 'overall': []}
        val_slot_loss = 0
        model.eval()
        for pad_batch, ori_batch, real_batch_size in dataloader.yield_batches(batch_size, data_key='val'):
            pad_batch = tuple(t.to(DEVICE) for t in pad_batch)
            word_seq_tensor, tag_seq_tensor, word_mask_tensor, tag_mask_tensor, context_seq_tensor, context_mask_tensor = pad_batch
            if not config['model']['context']:
                context_seq_tensor, context_mask_tensor = None, None

            with torch.no_grad():
                slot_logits, slot_loss = model(word_seq_tensor, word_mask_tensor,
                                               tag_seq_tensor,
                                               tag_mask_tensor,
                                               context_seq_tensor,
                                               context_mask_tensor)
            # Weight by batch size so the final division yields a per-sample mean.
            val_slot_loss += slot_loss.item() * real_batch_size
            for j in range(real_batch_size):
                predicts = recover_intent(dataloader, slot_logits[j], tag_mask_tensor[j],
                                          ori_batch[j][0], ori_batch[j][1])
                labels = ori_batch[j][2]

                predict_golden['overall'].append({
                    'predict': predicts,
                    'golden': labels
                })
                predict_golden['slot'].append({
                    'predict': [x for x in predicts if is_slot_da(x)],
                    'golden': [x for x in labels if is_slot_da(x)]
                })
        # Log up to ten samples; the validation split may contain fewer.
        for j in range(min(10, len(predict_golden['overall']))):
            writer.add_text('val_sample_{}'.format(j),
                            json.dumps(predict_golden['overall'][j], indent=2, ensure_ascii=False),
                            global_step=step)

        total = len(dataloader.data['val'])
        val_slot_loss /= total
        print('%d samples val' % total)
        print('\t slot loss:', val_slot_loss)

        writer.add_scalar('slot_loss/train', train_slot_loss, global_step=step)
        writer.add_scalar('slot_loss/val', val_slot_loss, global_step=step)

        val_f1 = {}
        for x in ['slot', 'overall']:
            precision, recall, F1 = calculateF1(predict_golden[x])
            val_f1[x] = F1
            print('-' * 20 + x + '-' * 20)
            print('\t Precision: %.2f' % (100 * precision))
            print('\t Recall: %.2f' % (100 * recall))
            print('\t F1: %.2f' % (100 * F1))

            writer.add_scalar('val_{}/precision'.format(x), precision, global_step=step)
            writer.add_scalar('val_{}/recall'.format(x), recall, global_step=step)
            writer.add_scalar('val_{}/F1'.format(x), F1, global_step=step)

        # Model selection is based on the 'overall' F1.
        if val_f1['overall'] > best_val_f1:
            best_val_f1 = val_f1['overall']
            torch.save(model.state_dict(), os.path.join(output_dir, 'pytorch_model.bin'))
            print('best val F1 %.4f' % best_val_f1)
            print('save on', output_dir)

        train_slot_loss = 0

# Record the best validation score, then release the TensorBoard writer.
best_f1_text = '%.2f' % (100 * best_val_f1)
writer.add_text('val overall F1', best_f1_text)
writer.close()

# Pack the saved checkpoint into the archive named in the config.
zip_path = config['zipped_model_path']
model_path = os.path.join(output_dir, 'pytorch_model.bin')
print('zip model to', zip_path)

with zipfile.ZipFile(zip_path, mode='w', compression=zipfile.ZIP_DEFLATED) as archive:
    archive.write(model_path)
