In [1]:
import os
import logging
from tqdm import tqdm, trange
import torch.nn as nn

import numpy as np
import torch
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from transformers import BertConfig, AdamW, get_linear_schedule_with_warmup

from utils import MODEL_CLASSES, compute_metrics, get_intent_labels, get_slot_labels, compute_metrics_multi_intent,compute_metrics_multi_intent_Pro,compute_metrics_final

from model.modeling_final import JointBERTMultiIntent
from seqeval.metrics.sequence_labeling import get_entities

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class Trainer_multi(object):
    def __init__(self, args, train_dataset=None, dev_dataset=None, test_dataset=None):
        self.args = args
        self.train_dataset = train_dataset
        self.dev_dataset = dev_dataset
        self.test_dataset = test_dataset
        
        self.slot_preds_list = None
        self.intent_token_preds_list = None
        
        
        # set of intents
        self.intent_label_lst = get_intent_labels(args)
        # set of slots
        self.slot_label_lst = get_slot_labels(args)

        self.num_intent_labels = len(self.intent_label_lst)
        self.num_slot_labels = len(self.slot_label_lst)

        # Use cross entropy ignore index as padding label id so that only real label ids contribute to the loss later
        self.pad_token_label_id = args.ignore_index

        # self.config_class, self.model_class, _ = MODEL_CLASSES[args.model_type]
        #
        # self.config = self.config_class.from_pretrained(args.model_name_or_path, finetuning_task=args.task)

        # self.model = self.model_class.from_pretrained(args.model_name_or_path,
        #                                               config=self.config,
        #                                               )

        self.model = JointBERTMultiIntent()
        # print(self.model)
        
        # GPU or CPU
        self.device = "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu"
        self.model.to(self.device)

    def get_intent_token_loss(self,intent_token_logits,intent_token_ids,attention_mask):
        intent_token_loss = 0.0
        intent_token_loss_fct = nn.CrossEntropyLoss(ignore_index=self.args.ignore_index)
        if attention_mask is not None:
            active_intent_loss = attention_mask.view(-1) == 1
            active_intent_logits = intent_token_logits.view(-1, self.num_intent_labels)[active_intent_loss]
            active_intent_tokens = intent_token_ids.view(-1)[active_intent_loss]
            intent_token_loss = intent_token_loss_fct(active_intent_logits, active_intent_tokens)
        else:
            intent_token_loss = intent_token_loss_fct(intent_token_logits.view(-1, self.num_intent_labels), intent_token_ids.view(-1))

        return self.args.slot_loss_coef * intent_token_loss

    def get_slot_loss(self,slot_logits,slot_labels_ids,attention_mask):
        slot_loss = 0.0
        slot_loss_fct = nn.CrossEntropyLoss(ignore_index=self.args.ignore_index)
        # Only keep active parts of the loss
        if attention_mask is not None:
            try:
                active_loss = attention_mask.view(-1) == 1
                attention_mask_cpu = attention_mask.data.cpu().numpy()
                active_loss_cpu = active_loss.data.cpu().numpy()
                active_logits = slot_logits.view(-1, self.num_slot_labels)[active_loss]
                active_labels = slot_labels_ids.view(-1)[active_loss]
                slot_loss = slot_loss_fct(active_logits, active_labels)
            except:
                print('attention_mask: ', attention_mask_cpu)
                print('active_loss: ', active_loss_cpu)
                logger.info('attention_mask: ', attention_mask_cpu)
                logger.info('active_loss: ', active_loss_cpu)
        else:
            slot_loss = slot_loss_fct(slot_logits.view(-1, self.num_slot_labels), slot_labels_ids.view(-1))

        return self.args.slot_loss_coef * slot_loss

    def get_referee_token_loss(self,referee_token_logits,referee_labels_ids,attention_mask,pro_sample_mask):
        referee_token_loss = 0.0
        class_weights = torch.FloatTensor([1,10,300]).to('cuda')
        referee_token_loss_fct = nn.CrossEntropyLoss(weight = class_weights,ignore_index=self.args.ignore_index) #self.referee_token_loss_fct

        if attention_mask is not None:
            try:
                active_loss = attention_mask[pro_sample_mask].view(-1) == 1
                attention_mask_cpu = attention_mask.data.cpu().numpy()
                active_loss_cpu = active_loss.data.cpu().numpy()
                active_logits = referee_token_logits.view(-1, 3)[active_loss]
                active_labels = referee_labels_ids[pro_sample_mask].view(-1)[active_loss]
                referee_token_loss = referee_token_loss_fct(active_logits,
                                                            active_labels)
            except:
                logger.info('attention_mask: ', attention_mask_cpu)
                logger.info('active_loss: ', active_loss_cpu)
        else:
            referee_token_loss = referee_token_loss_fct(referee_token_logits.view(-1, 3), referee_labels_ids[pro_sample_mask].view(-1))

        return self.args.pro_loss_coef * referee_token_loss



    def train(self):
        
        #logger.info(vars(self.args))
        train_sampler = RandomSampler(self.train_dataset)
        train_dataloader = DataLoader(self.train_dataset, sampler=train_sampler, batch_size=self.args.train_batch_size)

        if self.args.max_steps > 0:
            t_total = self.args.max_steps
            self.args.num_train_epochs = self.args.max_steps // (len(train_dataloader) // self.args.gradient_accumulation_steps) + 1
        else:
            t_total = len(train_dataloader) // self.args.gradient_accumulation_steps * self.args.num_train_epochs

        # Prepare optimizer and schedule (linear warmup and decay)
        no_decay = ['bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [
            {'params': [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)],
             'weight_decay': self.args.weight_decay},
            {'params': [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
        ]
        optimizer = AdamW(optimizer_grouped_parameters, lr=self.args.learning_rate, eps=self.args.adam_epsilon)
        scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=self.args.warmup_steps, num_training_steps=t_total)

        # Train!
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(self.train_dataset))
        logger.info("  Num Epochs = %d", self.args.num_train_epochs)
        logger.info("  Total train batch size = %d", self.args.train_batch_size)
        logger.info("  Gradient Accumulation steps = %d", self.args.gradient_accumulation_steps)
        logger.info("  Total optimization steps = %d", t_total)
        logger.info("  Logging steps = %d", self.args.logging_steps)
        logger.info("  Save steps = %d", self.args.save_steps)

        global_step = 0
        tr_loss = 0.0
        self.model.zero_grad()

        train_iterator = trange(int(self.args.num_train_epochs), desc="Epoch")

        step_per_epoch = len(train_dataloader) // 2

        # record the evaluation loss
        eval_acc = 0.0
        MAX_RECORD = self.args.patience
        num_eval = -1
        eval_result_record = (num_eval, eval_acc)
        flag = False
        for _ in train_iterator:
            epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=FLAG)
            for step, batch in enumerate(epoch_iterator):
                self.model.train()
                batch = tuple(t.to(self.device) for t in batch)  # GPU or CPU

                attention_mask = batch[1]
                intent_label_ids = batch[3]
                slot_labels_ids =  batch[4]
                intent_token_ids =  batch[5]
                B_tag_mask =  batch[6]
                BI_tag_mask =  batch[7]
                tag_intent_label =  batch[8]
                referee_labels_ids =  batch[9] #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
                pro_labels_ids = batch[10]

                inputs = {'input_ids': batch[0],
                          'attention_mask': batch[1],
                          'pro_labels_ids' : batch[10]}
                if self.args.model_type != 'distilbert':
                    inputs['token_type_ids'] = batch[2]


                outputs = self.model(**inputs)


                logits = outputs[0]

                if self.args.pro and self.args.intent_seq:
                    slot_logits, intent_token_logits, referee_token_logits,all_referee_token_logits = logits

                slot_loss = self.get_slot_loss(slot_logits,slot_labels_ids,attention_mask)
                intent_token_loss =  self.get_intent_token_loss(intent_token_logits,intent_token_ids,attention_mask)

                pro_token_mask = pro_labels_ids > 0
                pro_sample_mask = torch.max(pro_token_mask.long(),dim = 1)[0] > 0
                referee_token_loss = self.get_referee_token_loss(referee_token_logits,referee_labels_ids,attention_mask,pro_sample_mask)
                loss = slot_loss + intent_token_loss + referee_token_loss





                if self.args.gradient_accumulation_steps > 1:
                    loss = loss / self.args.gradient_accumulation_steps

                loss.backward()

                tr_loss += loss.item()
                if (step + 1) % self.args.gradient_accumulation_steps == 0:
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.max_grad_norm)

                    optimizer.step()
                    scheduler.step()  # Update learning rate schedule
                    self.model.zero_grad()
                    global_step += 1

                    # if self.args.logging_steps > 0 and global_step % self.args.logging_steps == 0:
                    if self.args.logging_steps > 0 and global_step % step_per_epoch == 0:
                        logger.info("***** Training Step %d *****", step)
                        logger.info("  total_loss = %f", loss)
                        logger.info("  slot_loss = %f", slot_loss)
                        logger.info("  intent_token_loss = %f", intent_token_loss)
                        logger.info("  referee_token_loss = %f", referee_token_loss) #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!

                        dev_result = self.evaluate("dev")
                        # test_result = self.evaluate("test")
                        num_eval += 1
                        if self.args.patience != 0:
                            if dev_result['sementic_frame_acc'] + dev_result['intent_acc'] + dev_result['slot_f1']   > eval_result_record[1]:
                                self.save_model()
                                eval_result_record = (num_eval, dev_result['sementic_frame_acc'] + dev_result['intent_acc'] + dev_result['slot_f1'] )
                            else:
                                cur_num_eval = eval_result_record[0]
                                if num_eval - cur_num_eval >= MAX_RECORD:
                                    # it has been ok
                                    logger.info(' EARLY STOP Evaluate: at {}, best eval {} intent_slot_acc: {} '.format(num_eval, cur_num_eval, eval_result_record[1]))
                                    flag = True
                                    break
                        else:
                            self.save_model()

                            

                if 0 < self.args.max_steps < global_step:
                    epoch_iterator.close()
                    break
            print('referee_token_loss: ',referee_token_loss.item(),' slot_loss: ',slot_loss.item(), ' intent_token_loss: ',intent_token_loss.item())
            if flag:
                train_iterator.close()
                break

            if 0 < self.args.max_steps < global_step:
                train_iterator.close()
                break

        return global_step, tr_loss / global_step

    def evaluate(self, mode):
        if mode == 'test':
            dataset = self.test_dataset
        elif mode == 'dev':
            dataset = self.dev_dataset
        else:
            raise Exception("Only dev and test dataset available")

        eval_sampler = SequentialSampler(dataset)
        eval_dataloader = DataLoader(dataset, sampler=eval_sampler, batch_size=self.args.eval_batch_size)

        # Eval!
        logger.info("***** Running evaluation on %s dataset *****", mode)
        logger.info("  Num examples = %d", len(dataset))
        logger.info("  Batch size = %d", self.args.eval_batch_size)
        eval_loss = 0.0
        nb_eval_steps = 0
        intent_preds = None
        slot_preds = None
        intent_token_preds = None
        out_intent_label_ids = None
        out_slot_labels_ids = None
        out_intent_token_ids = None
        
        tag_intent_preds = None
        out_tag_intent_ids = None
        referee_preds = None #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
        out_referee_labels_ids = None #!!!!!!!!!!!!!!!!!!!!!!!!!
        all_referee_preds = None #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
        all_out_referee_labels_ids = None #!!!!!!!!!!!!!!!!!!!!!!!!!

        self.model.eval()

        for batch in tqdm(eval_dataloader, desc="Evaluating", disable=FLAG):
            batch = tuple(t.to(self.device) for t in batch)
            with torch.no_grad():
                attention_mask = batch[1]
                intent_label_ids = batch[3]
                slot_labels_ids =  batch[4]
                intent_token_ids =  batch[5]
                B_tag_mask =  batch[6]
                BI_tag_mask =  batch[7]
                tag_intent_label =  batch[8]
                referee_labels_ids =  batch[9] #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
                pro_labels_ids = batch[10]

                inputs = {'input_ids': batch[0],
                          'attention_mask': batch[1],
                          'pro_labels_ids' : batch[10]}


                if self.args.model_type != 'distilbert':
                    inputs['token_type_ids'] = batch[2]
                outputs = self.model(**inputs)


            logits = outputs[0]
            if self.args.pro and self.args.intent_seq:
                slot_logits, intent_token_logits, referee_token_logits,all_referee_token_logits = logits

            slot_loss = self.get_slot_loss(slot_logits,slot_labels_ids,attention_mask)
            intent_token_loss =  self.get_intent_token_loss(intent_token_logits,intent_token_ids,attention_mask)

            pro_token_mask = pro_labels_ids > 0
            pro_sample_mask = torch.max(pro_token_mask.long(),dim = 1)[0] > 0
            referee_token_loss = self.get_referee_token_loss(referee_token_logits,referee_labels_ids,attention_mask,pro_sample_mask)
            loss = slot_loss + intent_token_loss + referee_token_loss


            # ============================= Slot prediction ==============================
            if slot_preds is None:
                if self.args.use_crf:
                    # decode() in `torchcrf` returns list with best index directly
                    slot_preds = np.array(self.model.crf.decode(slot_logits))
                else:
                    slot_preds = slot_logits.detach().cpu().numpy()

                out_slot_labels_ids = slot_labels_ids.detach().cpu().numpy()
            else:
                if self.args.use_crf:
                    slot_preds = np.append(slot_preds, np.array(self.model.crf.decode(slot_logits)), axis=0)
                else:
                    slot_preds = np.append(slot_preds, slot_logits.detach().cpu().numpy(), axis=0)

                out_slot_labels_ids = np.append(out_slot_labels_ids, slot_labels_ids.detach().cpu().numpy(), axis=0)

            # ============================= Pronoun referee prediction ==============================
            if self.args.pro:
                if all_referee_preds is None:
                    all_referee_preds = all_referee_token_logits.detach().cpu().numpy()
                    referee_preds = referee_token_logits.detach().cpu().numpy()

                    pro_sample_mask_np = (torch.max(inputs["pro_labels_ids"],dim = 1)[0] > 0).detach().cpu().numpy()
                    all_out_referee_labels_ids = referee_labels_ids.detach().cpu().numpy()
                    out_referee_labels_ids = all_out_referee_labels_ids[pro_sample_mask_np]#np.array([ele for i,ele in enumerate(all_out_referee_labels_ids) if pro_sample_mask_np[i] != False])

                else:
                    all_referee_preds = np.append(all_referee_preds,all_referee_token_logits.detach().cpu().numpy(), axis = 0)
                    referee_preds = np.append(referee_preds, referee_token_logits.detach().cpu().numpy(), axis = 0)

                    pro_sample_mask_np = (torch.max(inputs["pro_labels_ids"],dim = 1)[0] > 0).detach().cpu().numpy()
                    new_all_out_referee_labels_ids = referee_labels_ids.detach().cpu().numpy()
                    all_out_referee_labels_ids = np.append(all_out_referee_labels_ids,new_all_out_referee_labels_ids,axis = 0)
                    new_out_referee_labels_ids = new_all_out_referee_labels_ids[pro_sample_mask_np]#np.array([ele for i,ele in enumerate(new_all_out_referee_labels_ids) if pro_sample_mask_np[i] != False])
                    out_referee_labels_ids = np.append(out_referee_labels_ids, new_out_referee_labels_ids, axis = 0)



            # print('referee_preds shape: ',referee_preds.shape)
            # print('all_referee_preds shape: ', all_referee_preds.shape)
            # print('out_referee_labels_ids shape: ',out_referee_labels_ids.shape)
            # print('all_out_referee_labels_ids shape: ', all_out_referee_labels_ids.shape)
            # ============================== Intent Token Seq =============================
            if self.args.intent_seq:
                if intent_token_preds is None:
                    if self.args.use_crf:
                        intent_token_preds = np.array(self.model.crf.decode(intent_token_logits))
                    else:
                        intent_token_preds = intent_token_logits.detach().cpu().numpy()

                    out_intent_token_ids = intent_token_ids.detach().cpu().numpy()
                else:
                    if self.args.use_crf:
                        intent_token_preds = np.append(intent_token_preds, np.array(self.model.crf.decode(intent_token_logits)), axis=0)
                    else:
                        intent_token_preds = np.append(intent_token_preds, intent_token_logits.detach().cpu().numpy(), axis=0)

                    out_intent_token_ids = np.append(out_intent_token_ids, intent_token_ids.detach().cpu().numpy(), axis=0)


            eval_loss += loss.item()
            nb_eval_steps += 1
            eval_loss = eval_loss / nb_eval_steps
            results = {
                "loss": eval_loss
            }
        nb_eval_steps += 1


        if not self.args.use_crf:
            slot_preds = np.argmax(slot_preds, axis=2)
        slot_label_map = {i: label for i, label in enumerate(self.slot_label_lst)}
        out_slot_label_list = [[] for _ in range(out_slot_labels_ids.shape[0])]
        slot_preds_list = [[] for _ in range(out_slot_labels_ids.shape[0])]

        
        # generate mask
        for i in range(out_slot_labels_ids.shape[0]):
            # record the padding position
            pos_offset = [0 for _ in range(out_slot_labels_ids.shape[1])]
            pos_cnt = 0
            padding_recording = [0 for _ in range(out_slot_labels_ids.shape[1])]
            
            for j in range(out_slot_labels_ids.shape[1]):
                if out_slot_labels_ids[i, j] != self.pad_token_label_id:
                    out_slot_label_list[i].append(slot_label_map[out_slot_labels_ids[i][j]])
                    slot_preds_list[i].append(slot_label_map[slot_preds[i][j]])
                    pos_offset[pos_cnt+1] = pos_offset[pos_cnt]
                    pos_cnt += 1
                else:
                    pos_offset[pos_cnt] = pos_offset[pos_cnt] + 1
                    padding_recording[j] = 1
                    



        # ============================= Pronoun Referee Prediction ============================ !!!!!!!!!!!!!!!!!!!!!!!!!!!!
        # print('referee_preds shape: ',referee_preds.shape)
        # print('all_referee_preds shape: ', all_referee_preds.shape)
        # print('out_referee_labels_ids shape: ',out_referee_labels_ids.shape)
        print('all_out_referee_labels_ids shape: ', all_out_referee_labels_ids.shape)
        print('out_intent_token_ids shape: ',out_intent_token_ids.shape)

        referee_token_map = {0:'PAD', 1:'O' ,2: 'B-referee'} # All referee are just one word in EGPSR


        if self.args.pro:
            referee_preds = np.argmax(referee_preds, axis=2)
            all_referee_preds = np.argmax(all_referee_preds, axis=2)


            referee_preds_list = [[] for _ in range(out_referee_labels_ids.shape[0])]
            out_referee_label_list = [[] for _ in range(out_referee_labels_ids.shape[0])]
            all_referee_preds_list = [[] for _ in range(all_out_referee_labels_ids.shape[0])]
            all_out_referee_label_list = [[] for _ in range(all_out_referee_labels_ids.shape[0])]



            for i in range(out_referee_labels_ids.shape[0]):
                for j in range(out_referee_labels_ids.shape[1]):
                    if out_referee_labels_ids[i, j] != self.pad_token_label_id:
                        out_referee_label_list[i].append(referee_token_map[out_referee_labels_ids[i][j]])
                        referee_preds_list[i].append(referee_token_map[referee_preds[i][j]])

            for i in range(all_out_referee_labels_ids.shape[0]):
                for j in range(all_out_referee_labels_ids.shape[1]):
                    if all_out_referee_labels_ids[i, j] != self.pad_token_label_id:
                        all_out_referee_label_list[i].append(referee_token_map[all_out_referee_labels_ids[i][j]])
                        all_referee_preds_list[i].append(referee_token_map[all_referee_preds[i][j]])



        intent_token_map = {i: label for i, label in enumerate(self.intent_label_lst)}
        out_intent_token_list = None
        intent_token_preds_list = None
        # ============================= Intent Seq Prediction ============================
        if self.args.intent_seq:
            if not self.args.use_crf:
                intent_token_preds = np.argmax(intent_token_preds, axis=2)
            out_intent_token_list = [[] for _ in range(out_intent_token_ids.shape[0])]
            intent_token_preds_list = [[] for _ in range(out_intent_token_ids.shape[0])]

            for i in range(out_intent_token_ids.shape[0]):
                for j in range(out_intent_token_ids.shape[1]):
                    if out_intent_token_ids[i, j] != self.pad_token_label_id:
                        out_intent_token_list[i].append(intent_token_map[out_intent_token_ids[i][j]])
                        intent_token_preds_list[i].append(intent_token_map[intent_token_preds[i][j]])




        total_result = compute_metrics_final(
                                       slot_preds_list,
                                       out_slot_label_list,
                                       intent_token_preds_list,
                                       out_intent_token_list,
                                       referee_preds_list,
                                       out_referee_label_list
                                      )
        results.update(total_result)

        # !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! Pronoun Acc !!!!!!!!!!!!!!!!!!!!!!!!

        correct = 0
        for ref_pred_seq,ref_label_seq in zip(referee_preds_list,out_referee_label_list):
            if ref_pred_seq == ref_label_seq:
                correct += 1
            # else:
            #     print('ref_pred_seq: ',ref_pred_seq,'\n')
            #     print('ref_label_seq: ',ref_label_seq,'\n')
        ref_acc = correct/len(referee_preds_list)
        print('Pronoun Accurac: ',ref_acc, ', correct: ',correct, ', total: ',len(referee_preds_list))
        results.update({'Pronoun Accuracy':ref_acc})

        #-----------------------------------------------------

        # tp = 0
        # fn = 0
        # fp = 0
        # tn = 0
        # for sample_idx,has_pro in enumerate(pro_sample_mask_np):
        #     if has_pro:
        #         if referee_token_preds_list[sample_idx] == out_referee_token_list[sample_idx]:
        #             tp += 1
        #         elif all(referee_token_preds_list[sample_idx] == 'O'): #referee not detected
        #             fn += 1
        #         else:
        #             fp += 1 #detected the wrong referee
        #
        # ref_acc = tp/sum(pro_sample_mask_np)
        # results.update({'Pronoun Accuracy':ref_acc})

        #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!

        logger.info("***** Eval results *****")
        for key in sorted(results.keys()):
            logger.info("  %s_%s = %s", mode, key, str(results[key]))
        
        #self.store_pred(slot_preds_list,intent_token_preds_list)
        self.slot_preds_list = slot_preds_list
        self.intent_token_preds_list = intent_token_preds_list
        return results

    def save_model(self):
        # Save model checkpoint (Overwrite)
        if not os.path.exists(self.args.model_dir):
            os.makedirs(self.args.model_dir)
        model_to_save = self.model.module if hasattr(self.model, 'module') else self.model
        #model_to_save.save_pretrained(os.path.join(self.args.model_dir, 'pytorch_model.pt'))
        torch.save(model_to_save, os.path.join(self.args.model_dir, 'pytorch_model.pt'))
        # Save training arguments together with the trained model
        #torch.save(self.args, os.path.join(self.args.model_dir, 'training_args.pt'))
        logger.info("Saving model checkpoint to %s", self.args.model_dir)

    def load_model(self):
        # Check whether model exists
        if not os.path.exists(self.args.model_dir):
            raise Exception("Model doesn't exists! Train first!")

        try:
            self.model = self.model_class.from_pretrained(self.args.model_dir,
                                                          args=self.args,
                                                          intent_label_lst=self.intent_label_lst,
                                                          slot_label_lst=self.slot_label_lst)
            self.model.to(self.device)
            logger.info("***** Model Loaded *****")
        except:
            raise Exception("Some model files might be missing...")

In [3]:
FLAG = False
logger = logging.getLogger(__name__)

In [4]:

import random
import time
import argparse
from datetime import datetime
import logging

#from trainer import Trainer, Trainer_multi, Trainer_woISeq
from utils import init_logger, load_tokenizer, read_prediction_text, set_seed, MODEL_CLASSES, MODEL_PATH_MAP
from data_loader import load_and_cache_examples

def init_logger():
    logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
                        datefmt='%m/%d/%Y %H:%M:%S',
                        level=logging.INFO)

def main(args):
#     init_logger(args)
#     init_logger()

    set_seed(args)
    tokenizer = load_tokenizer(args)
    train_dataset = load_and_cache_examples(args, tokenizer, mode="train")
    dev_dataset = load_and_cache_examples(args, tokenizer, mode="dev")
    test_dataset = load_and_cache_examples(args, tokenizer, mode="test")
    
    if args.multi_intent == 1:
        trainer = Trainer_multi(args, train_dataset, dev_dataset, test_dataset)
    else:
        trainer = Trainer(args, train_dataset, dev_dataset, test_dataset)
    if args.do_train:
        trainer.train()
    if args.do_eval:
        trainer.load_model()
        trainer.evaluate("test")
    return train_dataset



if __name__ == '__main__':
    time_wait = random.uniform(0, 10)
    time.sleep(time_wait)
    parser = argparse.ArgumentParser()
#     parser.add_argument("--task", default='mixsnips', required=True, type=str, help="The name of the task to train")
    parser.add_argument("--task", default='gpsr_pro_instance', type=str, help="The name of the task to train")

#     parser.add_argument("--model_dir", default='./gpsr_model', required=True, type=str, help="Path to save, load model")
    parser.add_argument("--model_dir", default='./gpsr_final_model', type=str, help="Path to save, load model")

    parser.add_argument("--data_dir", default="./data", type=str, help="The input data dir")
    parser.add_argument("--intent_label_file", default="intent_label.txt", type=str, help="Intent Label file")
    parser.add_argument("--slot_label_file", default="slot_label.txt", type=str, help="Slot Label file")
    parser.add_argument("--model_type", default="multibert", type=str, help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
#     parser.add_argument("--intent_seq", type=int, default=0, help="whether we use intent seq setting")
    parser.add_argument("--intent_seq", type=int, default=1, help="whether we use intent seq setting")


    parser.add_argument("--pro", type=int, default=1, help="support pronoun disambiguition")#!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!


    parser.add_argument("--multi_intent", type=int, default=1, help="whether we use multi intent setting")
    parser.add_argument("--tag_intent", type=int, default=0, help="whether we can use tag to predict intent")
    
    parser.add_argument("--BI_tag", type=int, default=0, help='use BI sum or just B')
    parser.add_argument("--cls_token_cat", type=int, default=1, help='whether we cat the cls to the slot output of bert')
    parser.add_argument("--intent_attn", type=int, default=1, help='whether we use attention mechanism on the CLS intent output')
    parser.add_argument("--num_mask", type=int, default=7, help="assumptive number of slot in one sentence")
                                           #max slot num = 7

    parser.add_argument('--seed', type=int, default=25, help="random seed for initialization")
    parser.add_argument("--train_batch_size", default=64, type=int, help="Batch size for training.")
#     parser.add_argument("--train_batch_size", default=64, type=int, help="Batch size for training.")

    parser.add_argument("--eval_batch_size", default=128, type=int, help="Batch size for evaluation.")
    parser.add_argument("--max_seq_len", default=32, type=int, help="The maximum total input sequence length after tokenization.")
    parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
#     parser.add_argument("--num_train_epochs", default=10.0, type=float, help="Total number of training epochs to perform.")
    parser.add_argument("--num_train_epochs", default=4.0, type=float, help="Total number of training epochs to perform.")
                                            #####
    
    parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
    parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
                        help="Number of updates steps to accumulate before performing a backward/update pass.")
    parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
    parser.add_argument("--max_grad_norm", default=1, type=float, help="Max gradient norm.")
    parser.add_argument("--max_steps", default=-1, type=int, help="If > 0: set total number of training steps to perform. Override num_train_epochs.")
    parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
    parser.add_argument("--dropout_rate", default=0.1, type=float, help="Dropout for fully-connected layers")
    parser.add_argument('--logging_steps', type=int, default=500, help="Log every X updates steps.")
    parser.add_argument('--save_steps', type=int, default=300, help="Save checkpoint every X updates steps.")
    parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
    parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the test set.")
    parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
    parser.add_argument("--ignore_index", default=0, type=int,
                        help='Specifies a target value that is ignored and does not contribute to the input gradient')
    parser.add_argument('--slot_loss_coef', type=float, default=2.0, help='Coefficient for the slot loss.')


    parser.add_argument('--pro_loss_coef', type=float, default=2.0, help='Coefficient for the pronoun loss.')



    parser.add_argument('--tag_intent_coef', type=float, default=1.0, help='Coefficient for the tag intent loss')

    # CRF option
    parser.add_argument("--use_crf", action="store_true", help="Whether to use CRF")
    parser.add_argument("--slot_pad_label", default="PAD", type=str, help="Pad token for slot label pad (to be ignore when calculate loss)")
    parser.add_argument("--patience", default=0, type=int, help="The initial learning rate for Adam.")
    
    parser.add_argument('-f')#########################
    
    args = parser.parse_args()
    
    now = datetime.now()
    args.model_dir = args.model_dir + '_' + now.strftime('%m-%d-%H:%M:%S')
    args.model_name_or_path = MODEL_PATH_MAP[args.model_type]

tokenizer = load_tokenizer(args)

In [5]:
train_dataset = load_and_cache_examples(args, tokenizer, mode="train")
dev_dataset = load_and_cache_examples(args, tokenizer, mode="dev")
test_dataset = load_and_cache_examples(args, tokenizer, mode="test")
trainer = Trainer_multi(args, train_dataset, dev_dataset, test_dataset)
trainer.train()
# test_dataset[1]

Epoch:   0%|          | 0/3 [00:00<?, ?it/s]
Iteration:   0%|          | 0/514 [00:00<?, ?it/s][A
Iteration:   0%|          | 1/514 [00:00<05:46,  1.48it/s][A
Iteration:   0%|          | 2/514 [00:00<03:14,  2.63it/s][A
Iteration:   1%|          | 3/514 [00:01<02:26,  3.50it/s][A
Iteration:   1%|          | 4/514 [00:01<02:03,  4.14it/s][A
Iteration:   1%|          | 5/514 [00:01<01:50,  4.61it/s][A
Iteration:   1%|          | 6/514 [00:01<01:44,  4.88it/s][A
Iteration:   1%|▏         | 7/514 [00:01<01:39,  5.12it/s][A
Iteration:   2%|▏         | 8/514 [00:01<01:35,  5.29it/s][A
Iteration:   2%|▏         | 9/514 [00:02<01:32,  5.44it/s][A
Iteration:   2%|▏         | 10/514 [00:02<01:31,  5.53it/s][A
Iteration:   2%|▏         | 11/514 [00:02<01:30,  5.57it/s][A
Iteration:   2%|▏         | 12/514 [00:02<01:29,  5.59it/s][A
Iteration:   3%|▎         | 13/514 [00:02<01:29,  5.61it/s][A
Iteration:   3%|▎         | 14/514 [00:02<01:28,  5.65it/s][A
Iteration:   3%|▎         | 

all_out_referee_labels_ids shape:  (2192, 32)
out_intent_token_ids shape:  (2192, 32)
preds:  ['O', 'O', 'O', 'O', 'O', 'B-referee', 'O', 'O', 'O', 'O', 'O', 'O', 'O']
labels:  ['O', 'O', 'O', 'O', 'O', 'B-referee', 'O', 'O', 'O', 'O', 'O', 'O', 'O']
Pronoun Accurac:  0.8680351906158358 , correct:  592 , total:  682



Iteration:  50%|█████     | 257/514 [00:48<04:01,  1.06it/s][A
Iteration:  50%|█████     | 258/514 [00:49<03:02,  1.40it/s][A
Iteration:  50%|█████     | 259/514 [00:49<02:20,  1.81it/s][A
Iteration:  51%|█████     | 260/514 [00:49<01:51,  2.27it/s][A
Iteration:  51%|█████     | 261/514 [00:49<01:31,  2.77it/s][A
Iteration:  51%|█████     | 262/514 [00:49<01:17,  3.26it/s][A
Iteration:  51%|█████     | 263/514 [00:49<01:07,  3.71it/s][A
Iteration:  51%|█████▏    | 264/514 [00:50<01:00,  4.12it/s][A
Iteration:  52%|█████▏    | 265/514 [00:50<00:55,  4.46it/s][A
Iteration:  52%|█████▏    | 266/514 [00:50<00:52,  4.76it/s][A
Iteration:  52%|█████▏    | 267/514 [00:50<00:49,  4.97it/s][A
Iteration:  52%|█████▏    | 268/514 [00:50<00:48,  5.12it/s][A
Iteration:  52%|█████▏    | 269/514 [00:51<00:46,  5.23it/s][A
Iteration:  53%|█████▎    | 270/514 [00:51<00:45,  5.31it/s][A
Iteration:  53%|█████▎    | 271/514 [00:51<00:45,  5.39it/s][A
Iteration:  53%|█████▎    | 272/514 [00

all_out_referee_labels_ids shape:  (2192, 32)
out_intent_token_ids shape:  (2192, 32)
preds:  ['O', 'O', 'O', 'O', 'O', 'B-referee', 'O', 'O', 'O', 'O', 'O', 'O', 'O']
labels:  ['O', 'O', 'O', 'O', 'O', 'B-referee', 'O', 'O', 'O', 'O', 'O', 'O', 'O']
Pronoun Accurac:  0.9941348973607038 , correct:  678 , total:  682



Iteration: 100%|██████████| 514/514 [01:38<00:00,  5.24it/s][A
Epoch:  33%|███▎      | 1/3 [01:38<03:16, 98.02s/it]

referee_token_loss:  0.0013531188014894724  slot_loss:  0.007066477555781603  intent_token_loss:  0.006665100809186697



Iteration:   0%|          | 0/514 [00:00<?, ?it/s][A
Iteration:   0%|          | 1/514 [00:00<01:30,  5.69it/s][A
Iteration:   0%|          | 2/514 [00:00<01:31,  5.60it/s][A
Iteration:   1%|          | 3/514 [00:00<01:31,  5.58it/s][A
Iteration:   1%|          | 4/514 [00:00<01:31,  5.57it/s][A
Iteration:   1%|          | 5/514 [00:00<01:32,  5.53it/s][A
Iteration:   1%|          | 6/514 [00:01<01:32,  5.50it/s][A
Iteration:   1%|▏         | 7/514 [00:01<01:31,  5.52it/s][A
Iteration:   2%|▏         | 8/514 [00:01<01:31,  5.54it/s][A
Iteration:   2%|▏         | 9/514 [00:01<01:30,  5.56it/s][A
Iteration:   2%|▏         | 10/514 [00:01<01:30,  5.56it/s][A
Iteration:   2%|▏         | 11/514 [00:01<01:30,  5.54it/s][A
Iteration:   2%|▏         | 12/514 [00:02<01:30,  5.54it/s][A
Iteration:   3%|▎         | 13/514 [00:02<01:30,  5.54it/s][A
Iteration:   3%|▎         | 14/514 [00:02<01:30,  5.55it/s][A
Iteration:   3%|▎         | 15/514 [00:02<01:29,  5.55it/s][A
Iteration

all_out_referee_labels_ids shape:  (2192, 32)
out_intent_token_ids shape:  (2192, 32)
preds:  ['O', 'O', 'O', 'O', 'O', 'B-referee', 'O', 'O', 'O', 'O', 'O', 'O', 'O']
labels:  ['O', 'O', 'O', 'O', 'O', 'B-referee', 'O', 'O', 'O', 'O', 'O', 'O', 'O']
Pronoun Accurac:  0.9956011730205279 , correct:  679 , total:  682



Iteration:  50%|█████     | 257/514 [00:49<04:20,  1.01s/it][A
Iteration:  50%|█████     | 258/514 [00:49<03:15,  1.31it/s][A
Iteration:  50%|█████     | 259/514 [00:49<02:29,  1.70it/s][A
Iteration:  51%|█████     | 260/514 [00:49<01:58,  2.15it/s][A
Iteration:  51%|█████     | 261/514 [00:50<01:36,  2.63it/s][A
Iteration:  51%|█████     | 262/514 [00:50<01:20,  3.12it/s][A
Iteration:  51%|█████     | 263/514 [00:50<01:10,  3.58it/s][A
Iteration:  51%|█████▏    | 264/514 [00:50<01:02,  4.00it/s][A
Iteration:  52%|█████▏    | 265/514 [00:50<00:57,  4.36it/s][A
Iteration:  52%|█████▏    | 266/514 [00:50<00:53,  4.66it/s][A
Iteration:  52%|█████▏    | 267/514 [00:51<00:50,  4.90it/s][A
Iteration:  52%|█████▏    | 268/514 [00:51<00:48,  5.06it/s][A
Iteration:  52%|█████▏    | 269/514 [00:51<00:47,  5.18it/s][A
Iteration:  53%|█████▎    | 270/514 [00:51<00:46,  5.27it/s][A
Iteration:  53%|█████▎    | 271/514 [00:51<00:45,  5.33it/s][A
Iteration:  53%|█████▎    | 272/514 [00

all_out_referee_labels_ids shape:  (2192, 32)
out_intent_token_ids shape:  (2192, 32)
preds:  ['O', 'O', 'O', 'O', 'O', 'B-referee', 'O', 'O', 'O', 'O', 'O', 'O', 'O']
labels:  ['O', 'O', 'O', 'O', 'O', 'B-referee', 'O', 'O', 'O', 'O', 'O', 'O', 'O']
Pronoun Accurac:  0.9956011730205279 , correct:  679 , total:  682



Iteration: 100%|██████████| 514/514 [01:38<00:00,  5.20it/s][A
Epoch:  67%|██████▋   | 2/3 [03:16<01:38, 98.51s/it]

referee_token_loss:  0.00012732377217616886  slot_loss:  0.0022298325784504414  intent_token_loss:  0.0005015155766159296



Iteration:   0%|          | 0/514 [00:00<?, ?it/s][A
Iteration:   0%|          | 1/514 [00:00<01:31,  5.63it/s][A
Iteration:   0%|          | 2/514 [00:00<01:31,  5.59it/s][A
Iteration:   1%|          | 3/514 [00:00<01:31,  5.56it/s][A
Iteration:   1%|          | 4/514 [00:00<01:31,  5.55it/s][A
Iteration:   1%|          | 5/514 [00:00<01:32,  5.51it/s][A
Iteration:   1%|          | 6/514 [00:01<01:32,  5.49it/s][A
Iteration:   1%|▏         | 7/514 [00:01<01:32,  5.50it/s][A
Iteration:   2%|▏         | 8/514 [00:01<01:31,  5.51it/s][A
Iteration:   2%|▏         | 9/514 [00:01<01:31,  5.50it/s][A
Iteration:   2%|▏         | 10/514 [00:01<01:31,  5.49it/s][A
Iteration:   2%|▏         | 11/514 [00:01<01:31,  5.47it/s][A
Iteration:   2%|▏         | 12/514 [00:02<01:31,  5.49it/s][A
Iteration:   3%|▎         | 13/514 [00:02<01:31,  5.50it/s][A
Iteration:   3%|▎         | 14/514 [00:02<01:30,  5.51it/s][A
Iteration:   3%|▎         | 15/514 [00:02<01:30,  5.51it/s][A
Iteration

all_out_referee_labels_ids shape:  (2192, 32)
out_intent_token_ids shape:  (2192, 32)
preds:  ['O', 'O', 'O', 'O', 'O', 'B-referee', 'O', 'O', 'O', 'O', 'O', 'O', 'O']
labels:  ['O', 'O', 'O', 'O', 'O', 'B-referee', 'O', 'O', 'O', 'O', 'O', 'O', 'O']
Pronoun Accurac:  0.9956011730205279 , correct:  679 , total:  682



Iteration:  50%|█████     | 257/514 [00:49<04:26,  1.04s/it][A
Iteration:  50%|█████     | 258/514 [00:49<03:19,  1.28it/s][A
Iteration:  50%|█████     | 259/514 [00:50<02:33,  1.67it/s][A
Iteration:  51%|█████     | 260/514 [00:50<02:00,  2.11it/s][A
Iteration:  51%|█████     | 261/514 [00:50<01:37,  2.59it/s][A
Iteration:  51%|█████     | 262/514 [00:50<01:22,  3.07it/s][A
Iteration:  51%|█████     | 263/514 [00:50<01:11,  3.53it/s][A
Iteration:  51%|█████▏    | 264/514 [00:50<01:03,  3.95it/s][A
Iteration:  52%|█████▏    | 265/514 [00:51<00:57,  4.32it/s][A
Iteration:  52%|█████▏    | 266/514 [00:51<00:53,  4.63it/s][A
Iteration:  52%|█████▏    | 267/514 [00:51<00:50,  4.86it/s][A
Iteration:  52%|█████▏    | 268/514 [00:51<00:49,  5.01it/s][A
Iteration:  52%|█████▏    | 269/514 [00:51<00:47,  5.13it/s][A
Iteration:  53%|█████▎    | 270/514 [00:52<00:46,  5.24it/s][A
Iteration:  53%|█████▎    | 271/514 [00:52<00:45,  5.33it/s][A
Iteration:  53%|█████▎    | 272/514 [00

all_out_referee_labels_ids shape:  (2192, 32)
out_intent_token_ids shape:  (2192, 32)
preds:  ['O', 'O', 'O', 'O', 'O', 'B-referee', 'O', 'O', 'O', 'O', 'O', 'O', 'O']
labels:  ['O', 'O', 'O', 'O', 'O', 'B-referee', 'O', 'O', 'O', 'O', 'O', 'O', 'O']
Pronoun Accurac:  0.9956011730205279 , correct:  679 , total:  682



Iteration: 100%|██████████| 514/514 [01:39<00:00,  5.17it/s][A
Epoch: 100%|██████████| 3/3 [04:56<00:00, 98.74s/it]

referee_token_loss:  0.00024469883646816015  slot_loss:  0.00027137796860188246  intent_token_loss:  0.00034721242263913155





(1542, 0.20752348990028455)

In [6]:
trainer.evaluate("test")

Evaluating: 100%|██████████| 69/69 [00:07<00:00,  8.85it/s]


all_out_referee_labels_ids shape:  (8771, 32)
out_intent_token_ids shape:  (8771, 32)
preds:  ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-referee', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']
labels:  ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-referee', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']
Pronoun Accurac:  0.9988901220865705 , correct:  2700 , total:  2703


{'loss': 4.233489614536386e-06,
 'slot_precision': 0.9999703466476885,
 'slot_recall': 0.9999703466476885,
 'slot_f1': 0.9999703466476885,
 'Pro_precision': 0.9988901220865705,
 'Pro_recall': 0.9988901220865705,
 'Pro_f1': 0.9988901220865705,
 'intent_token_precision': 1.0,
 'intent_token_recall': 1.0,
 'intent_token_f1': 1.0,
 'Pronoun Accuracy': 0.9988901220865705}

In [7]:
trainer.save()

AttributeError: 'Trainer_multi' object has no attribute 'save'

In [None]:
# convert the local check point to onnx
import subprocess
subprocess.run(f"python -m transformers.onnx --model=local-pt-checkpoint --feature=sequence-classification onnx/".split())

In [None]:
# !python -m transformers.onnx --model=gpsr_final_model_02-14-15\:18\:17/ onnx/

In [None]:
from transformers.convert_graph_to_onnx import convert

# Handles all the above steps for you
convert(framework="pt", model="gpsr_final_model_02-14-18:20:41", output="onnx/bert-base-cased.onnx", opset=11)
