In [1]:
import os
import logging
from tqdm import tqdm, trange
import torch.nn as nn

import numpy as np
import torch
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from transformers import BertConfig, AdamW, get_linear_schedule_with_warmup

from utils import MODEL_CLASSES, compute_metrics, get_intent_labels, get_slot_labels, compute_metrics_multi_intent,compute_metrics_multi_intent_Pro,compute_metrics_final

from seqeval.metrics.sequence_labeling import get_entities

In [2]:
from model.modeling_final import JointBERTMultiIntent

class Trainer_multi(object):
    def __init__(self, args, train_dataset=None, dev_dataset=None, test_dataset=None):
        self.args = args
        self.train_dataset = train_dataset
        self.dev_dataset = dev_dataset
        self.test_dataset = test_dataset
        
        self.slot_preds_list = None
        self.intent_token_preds_list = None
        
        
        # set of intents
        self.intent_label_lst = get_intent_labels(args)
        # set of slots
        self.slot_label_lst = get_slot_labels(args)

        self.num_intent_labels = len(self.intent_label_lst)
        self.num_slot_labels = len(self.slot_label_lst)

        # Use cross entropy ignore index as padding label id so that only real label ids contribute to the loss later
        self.pad_token_label_id = args.ignore_index

        self.config_class, self.model_class, _ = MODEL_CLASSES[args.model_type]
        
        self.config = self.config_class.from_pretrained(args.model_name_or_path, finetuning_task=args.task)
        # self.config = BertConfig.from_pretrained('bert-base-uncased', finetuning_task='gpsr_pro_instance')

        self.model = self.model_class.from_pretrained(args.model_name_or_path,
                                                      config=self.config,
                                                      intent_label_lst=self.intent_label_lst,
                                                      slot_label_lst=self.slot_label_lst,
                                                      )

        ###############################################################
        for name, param in self.model.named_parameters():
            if name.startswith("encoder.layer.11") or name.startswith("encoder.layer.10"): # unfroze last 2 layers
                param.requires_grad = True
            # else:
            #     param.requires_grad = False
            print('name: ',name,'param: ',param.requires_grad)
        ###############################################################


        # self.model = JointBERTMultiIntent()
        
        # GPU or CPU
        self.device = "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu"
        self.model.to(self.device)

    def get_intent_token_loss(self,intent_token_logits,intent_token_ids,attention_mask):
        intent_token_loss = 0.0
        intent_token_loss_fct = nn.CrossEntropyLoss(ignore_index=self.args.ignore_index)
        if attention_mask is not None:
            active_intent_loss = attention_mask.view(-1) == 1
            active_intent_logits = intent_token_logits.view(-1, self.num_intent_labels)[active_intent_loss]
            active_intent_tokens = intent_token_ids.view(-1)[active_intent_loss]
            intent_token_loss = intent_token_loss_fct(active_intent_logits, active_intent_tokens)
        else:
            intent_token_loss = intent_token_loss_fct(intent_token_logits.view(-1, self.num_intent_labels), intent_token_ids.view(-1))

        return self.args.slot_loss_coef * intent_token_loss

    def get_slot_loss(self,slot_logits,slot_labels_ids,attention_mask):
        slot_loss = 0.0
        slot_loss_fct = nn.CrossEntropyLoss(ignore_index=self.args.ignore_index)
        # Only keep active parts of the loss
        if attention_mask is not None:
            try:
                active_loss = attention_mask.view(-1) == 1
                attention_mask_cpu = attention_mask.data.cpu().numpy()
                active_loss_cpu = active_loss.data.cpu().numpy()
                active_logits = slot_logits.view(-1, self.num_slot_labels)[active_loss]
                active_labels = slot_labels_ids.view(-1)[active_loss]
                slot_loss = slot_loss_fct(active_logits, active_labels)
            except:
                print('attention_mask: ', attention_mask_cpu)
                print('active_loss: ', active_loss_cpu)
                logger.info('attention_mask: ', attention_mask_cpu)
                logger.info('active_loss: ', active_loss_cpu)
        else:
            slot_loss = slot_loss_fct(slot_logits.view(-1, self.num_slot_labels), slot_labels_ids.view(-1))

        return self.args.slot_loss_coef * slot_loss

    def get_referee_token_loss(self,referee_token_logits,referee_labels_ids,attention_mask,pro_sample_mask):
        referee_token_loss = 0.0
        class_weights = torch.FloatTensor([1,10,200]).to(self.device)
        referee_token_loss_fct = nn.CrossEntropyLoss(weight = class_weights,ignore_index=self.args.ignore_index) #self.referee_token_loss_fct

        if attention_mask is not None:
            try:
                active_loss = attention_mask[pro_sample_mask].view(-1) == 1
                attention_mask_cpu = attention_mask.data.cpu().numpy()
                active_loss_cpu = active_loss.data.cpu().numpy()
                active_logits = referee_token_logits.view(-1, 3)[active_loss]
                active_labels = referee_labels_ids[pro_sample_mask].view(-1)[active_loss]
                referee_token_loss = referee_token_loss_fct(active_logits,
                                                            active_labels)
            except:
                logger.info('attention_mask: ', attention_mask_cpu)
                logger.info('active_loss: ', active_loss_cpu)
        else:
            referee_token_loss = referee_token_loss_fct(referee_token_logits.view(-1, 3), referee_labels_ids[pro_sample_mask].view(-1))

        return self.args.pro_loss_coef * referee_token_loss


    def train(self):
        
        #logger.info(vars(self.args))
        train_sampler = RandomSampler(self.train_dataset)
        train_dataloader = DataLoader(self.train_dataset, sampler=train_sampler, batch_size=self.args.train_batch_size)

        if self.args.max_steps > 0:
            t_total = self.args.max_steps
            self.args.num_train_epochs = self.args.max_steps // (len(train_dataloader) // self.args.gradient_accumulation_steps) + 1
        else:
            t_total = len(train_dataloader) // self.args.gradient_accumulation_steps * self.args.num_train_epochs

        # Prepare optimizer and schedule (linear warmup and decay)
        no_decay = ['bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [
            {'params': [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)],
             'weight_decay': self.args.weight_decay},
            {'params': [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
        ]
        optimizer = AdamW(optimizer_grouped_parameters, lr=self.args.learning_rate, eps=self.args.adam_epsilon)
        scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=self.args.warmup_steps, num_training_steps=t_total)

        # Train!
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(self.train_dataset))
        logger.info("  Num Epochs = %d", self.args.num_train_epochs)
        logger.info("  Total train batch size = %d", self.args.train_batch_size)
        logger.info("  Gradient Accumulation steps = %d", self.args.gradient_accumulation_steps)
        logger.info("  Total optimization steps = %d", t_total)
        logger.info("  Logging steps = %d", self.args.logging_steps)
        logger.info("  Save steps = %d", self.args.save_steps)

        global_step = 0
        tr_loss = 0.0
        self.model.zero_grad()

        train_iterator = trange(int(self.args.num_train_epochs), desc="Epoch")

        step_per_epoch = len(train_dataloader) // 2

        # record the evaluation loss
        eval_acc = 0.0
        MAX_RECORD = self.args.patience
        num_eval = -1
        eval_result_record = (num_eval, eval_acc)
        flag = False
        for _ in train_iterator:
            epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=FLAG)
            for step, batch in enumerate(epoch_iterator):
                self.model.train()
                batch = tuple(t.to(self.device) for t in batch)  # GPU or CPU

                attention_mask = batch[1]
                intent_label_ids = batch[3]
                slot_labels_ids =  batch[4]
                intent_token_ids =  batch[5]
                B_tag_mask =  batch[6]
                BI_tag_mask =  batch[7]
                tag_intent_label =  batch[8]
                referee_labels_ids =  batch[9] #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
                pro_labels_ids = batch[10]

                inputs = {'input_ids': batch[0],
                          'attention_mask': batch[1],
                          'pro_labels_ids' : batch[10]}
                if self.args.model_type != 'distilbert':
                    inputs['token_type_ids'] = batch[2]


                outputs = self.model(**inputs)

                if self.args.pro and self.args.intent_seq:
                    slot_logits, intent_token_logits, referee_token_logits,all_referee_token_logits = outputs

                slot_loss = self.get_slot_loss(slot_logits,slot_labels_ids,attention_mask)
                intent_token_loss =  self.get_intent_token_loss(intent_token_logits,intent_token_ids,attention_mask)

                pro_token_mask = pro_labels_ids > 0
                pro_sample_mask = torch.max(pro_token_mask.long(),dim = 1)[0] > 0
                referee_token_loss = self.get_referee_token_loss(referee_token_logits,referee_labels_ids,attention_mask,pro_sample_mask)
                loss = slot_loss + intent_token_loss + referee_token_loss


                if self.args.gradient_accumulation_steps > 1:
                    loss = loss / self.args.gradient_accumulation_steps

                loss.backward()

                tr_loss += loss.item()
                if (step + 1) % self.args.gradient_accumulation_steps == 0:
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.max_grad_norm)

                    optimizer.step()
                    scheduler.step()  # Update learning rate schedule
                    self.model.zero_grad()
                    global_step += 1

                    # if self.args.logging_steps > 0 and global_step % self.args.logging_steps == 0:
                    if self.args.logging_steps > 0 and global_step % step_per_epoch == 0:
                        logger.info("***** Training Step %d *****", step)
                        logger.info("  total_loss = %f", loss)
                        logger.info("  slot_loss = %f", slot_loss)
                        logger.info("  intent_token_loss = %f", intent_token_loss)
                        logger.info("  referee_token_loss = %f", referee_token_loss) #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!

                        dev_result = self.evaluate("dev")
                        test_result = self.evaluate("test")
                        num_eval += 1
                        if self.args.patience != 0:
                            if dev_result['sementic_frame_acc'] + dev_result['intent_acc'] + dev_result['slot_f1']   > eval_result_record[1]:
                                self.save_model()
                                eval_result_record = (num_eval, dev_result['sementic_frame_acc'] + dev_result['intent_acc'] + dev_result['slot_f1'] )
                            else:
                                cur_num_eval = eval_result_record[0]
                                if num_eval - cur_num_eval >= MAX_RECORD:
                                    # it has been ok
                                    logger.info(' EARLY STOP Evaluate: at {}, best eval {} intent_slot_acc: {} '.format(num_eval, cur_num_eval, eval_result_record[1]))
                                    flag = True
                                    break
                        else:
                            self.save_model()

                            

                if 0 < self.args.max_steps < global_step:
                    epoch_iterator.close()
                    break

            if flag:
                train_iterator.close()
                break

            if 0 < self.args.max_steps < global_step:
                train_iterator.close()
                break

        return global_step, tr_loss / global_step

    def evaluate(self, mode):
        if mode == 'test':
            dataset = self.test_dataset
        elif mode == 'dev':
            dataset = self.dev_dataset
        else:
            raise Exception("Only dev and test dataset available")

        eval_sampler = SequentialSampler(dataset)
        eval_dataloader = DataLoader(dataset, sampler=eval_sampler, batch_size=self.args.eval_batch_size)

        # Eval!
        logger.info("***** Running evaluation on %s dataset *****", mode)
        logger.info("  Num examples = %d", len(dataset))
        logger.info("  Batch size = %d", self.args.eval_batch_size)
        eval_loss = 0.0
        nb_eval_steps = 0
        intent_preds = None
        slot_preds = None
        intent_token_preds = None
        out_intent_label_ids = None
        out_slot_labels_ids = None
        out_intent_token_ids = None
        
        tag_intent_preds = None
        out_tag_intent_ids = None
        referee_preds = None #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
        out_referee_labels_ids = None #!!!!!!!!!!!!!!!!!!!!!!!!!
        all_referee_preds = None #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
        all_out_referee_labels_ids = None #!!!!!!!!!!!!!!!!!!!!!!!!!

        self.model.eval()

        for batch in tqdm(eval_dataloader, desc="Evaluating", disable=FLAG):
            batch = tuple(t.to(self.device) for t in batch)
            with torch.no_grad():
                attention_mask = batch[1]
                intent_label_ids = batch[3]
                slot_labels_ids =  batch[4]
                intent_token_ids =  batch[5]
                B_tag_mask =  batch[6]
                BI_tag_mask =  batch[7]
                tag_intent_label =  batch[8]
                referee_labels_ids =  batch[9] #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
                pro_labels_ids = batch[10]

                inputs = {'input_ids': batch[0],
                          'attention_mask': batch[1],
                          'pro_labels_ids' : batch[10]}


                if self.args.model_type != 'distilbert':
                    inputs['token_type_ids'] = batch[2]
                outputs = self.model(**inputs)

            # logits = outputs[0]
            if self.args.pro and self.args.intent_seq:
                slot_logits, intent_token_logits, referee_token_logits,all_referee_token_logits = outputs

            slot_loss = self.get_slot_loss(slot_logits,slot_labels_ids,attention_mask)
            intent_token_loss =  self.get_intent_token_loss(intent_token_logits,intent_token_ids,attention_mask)

            pro_token_mask = pro_labels_ids > 0
            pro_sample_mask = torch.max(pro_token_mask.long(),dim = 1)[0] > 0
            referee_token_loss = self.get_referee_token_loss(referee_token_logits,referee_labels_ids,attention_mask,pro_sample_mask)
            loss = slot_loss + intent_token_loss + referee_token_loss

            # ============================= Slot prediction ==============================
            if slot_preds is None:
                if self.args.use_crf:
                    # decode() in `torchcrf` returns list with best index directly
                    slot_preds = np.array(self.model.crf.decode(slot_logits))
                else:
                    slot_preds = slot_logits.detach().cpu().numpy()

                out_slot_labels_ids = slot_labels_ids.detach().cpu().numpy()
            else:
                if self.args.use_crf:
                    slot_preds = np.append(slot_preds, np.array(self.model.crf.decode(slot_logits)), axis=0)
                else:
                    slot_preds = np.append(slot_preds, slot_logits.detach().cpu().numpy(), axis=0)

                out_slot_labels_ids = np.append(out_slot_labels_ids, slot_labels_ids.detach().cpu().numpy(), axis=0)

            # ============================= Pronoun referee prediction ==============================
            if self.args.pro:
                if all_referee_preds is None:

                    all_referee_preds = all_referee_token_logits.detach().cpu().numpy()
                    referee_preds = referee_token_logits.detach().cpu().numpy()

                    pro_sample_mask_np = (torch.max(inputs["pro_labels_ids"],dim = 1)[0] > 0).detach().cpu().numpy()
                    all_out_referee_labels_ids = referee_labels_ids.detach().cpu().numpy()
                    out_referee_labels_ids = all_out_referee_labels_ids[pro_sample_mask_np]


                else:

                    all_referee_preds = np.append(all_referee_preds,all_referee_token_logits.detach().cpu().numpy(), axis = 0)
                    referee_preds = np.append(referee_preds, referee_token_logits.detach().cpu().numpy(), axis = 0)

                    pro_sample_mask_np = (torch.max(inputs["pro_labels_ids"],dim = 1)[0] > 0).detach().cpu().numpy()
                    new_all_out_referee_labels_ids = referee_labels_ids.detach().cpu().numpy()
                    all_out_referee_labels_ids = np.append(all_out_referee_labels_ids,new_all_out_referee_labels_ids,axis = 0)
                    new_out_referee_labels_ids = new_all_out_referee_labels_ids[pro_sample_mask_np]#np.array([ele for i,ele in enumerate(new_all_out_referee_labels_ids) if pro_sample_mask_np[i] != False])
                    out_referee_labels_ids = np.append(out_referee_labels_ids, new_out_referee_labels_ids, axis = 0)

            if self.args.intent_seq:
                if intent_token_preds is None:
                    if self.args.use_crf:
                        intent_token_preds = np.array(self.model.crf.decode(intent_token_logits))
                    else:
                        intent_token_preds = intent_token_logits.detach().cpu().numpy()

                    out_intent_token_ids = intent_token_ids.detach().cpu().numpy()
                else:
                    if self.args.use_crf:
                        intent_token_preds = np.append(intent_token_preds, np.array(self.model.crf.decode(intent_token_logits)), axis=0)
                    else:
                        intent_token_preds = np.append(intent_token_preds, intent_token_logits.detach().cpu().numpy(), axis=0)

                    out_intent_token_ids = np.append(out_intent_token_ids, intent_token_ids.detach().cpu().numpy(), axis=0)

            eval_loss += loss.item()
            nb_eval_steps += 1
            eval_loss = eval_loss / nb_eval_steps
            results = {
                "loss": eval_loss
            }
        nb_eval_steps += 1

        # Slot result
        # (batch_size, seq_len)
        if not self.args.use_crf:
            slot_preds = np.argmax(slot_preds, axis=2)
        slot_label_map = {i: label for i, label in enumerate(self.slot_label_lst)}
        out_slot_label_list = [[] for _ in range(out_slot_labels_ids.shape[0])]
        slot_preds_list = [[] for _ in range(out_slot_labels_ids.shape[0])]
        
        B_tag_mask_pred = []
        BI_tag_mask_pred = []
        
        # generate mask
        for i in range(out_slot_labels_ids.shape[0]):
            # record the padding position
            pos_offset = [0 for _ in range(out_slot_labels_ids.shape[1])]
            pos_cnt = 0
            padding_recording = [0 for _ in range(out_slot_labels_ids.shape[1])]
            
            for j in range(out_slot_labels_ids.shape[1]):
                if out_slot_labels_ids[i, j] != self.pad_token_label_id:
                    out_slot_label_list[i].append(slot_label_map[out_slot_labels_ids[i][j]])
                    slot_preds_list[i].append(slot_label_map[slot_preds[i][j]])
                    pos_offset[pos_cnt+1] = pos_offset[pos_cnt]
                    pos_cnt += 1
                else:
                    pos_offset[pos_cnt] = pos_offset[pos_cnt] + 1
                    padding_recording[j] = 1

        # ============================= Pronoun Referee Prediction ============================ !!!!!!!!!!!!!!!!!!!!!!!!!!!!
        # print('referee_preds shape: ',referee_preds.shape)
        # print('all_referee_preds shape: ', all_referee_preds.shape)
        # print('out_referee_labels_ids shape: ',out_referee_labels_ids.shape)
        print('all_out_referee_labels_ids shape: ', all_out_referee_labels_ids.shape)
        # print('out_intent_token_ids shape: ',out_intent_token_ids.shape)

        referee_token_map = {0:'PAD', 1:'O' ,2: 'B-referee'} # All referee are just one word in EGPSR


        if self.args.pro:
            referee_preds = np.argmax(referee_preds, axis=2)
            all_referee_preds = np.argmax(all_referee_preds, axis=2)


            referee_preds_list = [[] for _ in range(out_referee_labels_ids.shape[0])]
            out_referee_label_list = [[] for _ in range(out_referee_labels_ids.shape[0])]
            all_referee_preds_list = [[] for _ in range(all_out_referee_labels_ids.shape[0])]
            all_out_referee_label_list = [[] for _ in range(all_out_referee_labels_ids.shape[0])]



            for i in range(out_referee_labels_ids.shape[0]):
                for j in range(out_referee_labels_ids.shape[1]):
                    if out_referee_labels_ids[i, j] != self.pad_token_label_id:
                        out_referee_label_list[i].append(referee_token_map[out_referee_labels_ids[i][j]])
                        referee_preds_list[i].append(referee_token_map[referee_preds[i][j]])

            for i in range(all_out_referee_labels_ids.shape[0]):
                for j in range(all_out_referee_labels_ids.shape[1]):
                    if all_out_referee_labels_ids[i, j] != self.pad_token_label_id:
                        all_out_referee_label_list[i].append(referee_token_map[all_out_referee_labels_ids[i][j]])
                        all_referee_preds_list[i].append(referee_token_map[all_referee_preds[i][j]])



        intent_token_map = {i: label for i, label in enumerate(self.intent_label_lst)}
        out_intent_token_list = None
        intent_token_preds_list = None
        # ============================= Intent Seq Prediction ============================
        if self.args.intent_seq:
            if not self.args.use_crf:
                intent_token_preds = np.argmax(intent_token_preds, axis=2)
            out_intent_token_list = [[] for _ in range(out_intent_token_ids.shape[0])]
            intent_token_preds_list = [[] for _ in range(out_intent_token_ids.shape[0])]

            for i in range(out_intent_token_ids.shape[0]):
                for j in range(out_intent_token_ids.shape[1]):
                    if out_intent_token_ids[i, j] != self.pad_token_label_id:
                        out_intent_token_list[i].append(intent_token_map[out_intent_token_ids[i][j]])
                        intent_token_preds_list[i].append(intent_token_map[intent_token_preds[i][j]])


        total_result = compute_metrics_final(
                                       slot_preds_list,
                                       out_slot_label_list,
                                       intent_token_preds_list,
                                       out_intent_token_list,
                                       referee_preds_list,
                                       out_referee_label_list
                                      )
        results.update(total_result)
        print(total_result)



        print(slot_label_map)
        print(intent_token_map)

        # !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! Pronoun Acc !!!!!!!!!!!!!!!!!!!!!!!!

        correct = 0
        for ref_pred_seq,ref_label_seq in zip(referee_preds_list,out_referee_label_list):
            if ref_pred_seq == ref_label_seq:
                correct += 1
        ref_acc = correct/len(referee_preds_list)
        print('Pronoun Accurac: ',ref_acc, ', correct: ',correct, ', total: ',len(referee_preds_list))
        results.update({'Pronoun Accuracy':ref_acc})


        correct = 0
        for slot_pred_seq,slot_label_seq in zip(slot_preds_list,out_slot_label_list):
            if slot_pred_seq == slot_label_seq:
                correct += 1
            else:
                print('pred: ',slot_pred_seq)
                print('true: ',slot_label_seq,'\n')
        slot_acc = correct/len(slot_preds_list)
        print('Slot Accurac: ',slot_acc, ', correct: ',correct, ', total: ',len(slot_preds_list))



        correct = 0
        for intent_pred_seq,intent_label_seq in zip(intent_token_preds_list,out_intent_token_list):
            if intent_pred_seq == intent_label_seq:
                correct += 1
            else:
                print('pred: ',intent_pred_seq)
                print('true: ',intent_label_seq,'\n')
        intent_acc = correct/len(intent_token_preds_list)
        print('Intent Accurac: ',intent_acc, ', correct: ',correct, ', total: ',len(intent_token_preds_list))

        logger.info("***** Eval results *****")
        for key in sorted(results.keys()):
            logger.info("  %s_%s = %s", mode, key, str(results[key]))
        
        #self.store_pred(slot_preds_list,intent_token_preds_list)
        self.slot_preds_list = slot_preds_list
        self.intent_token_preds_list = intent_token_preds_list
        return results

    # def save_model(self):
    #     # Save model checkpoint (Overwrite)
    #     if not os.path.exists(self.args.model_dir):
    #         os.makedirs(self.args.model_dir)
    #     model_to_save = self.model.module if hasattr(self.model, 'module') else self.model
    #     #model_to_save.save_pretrained(os.path.join(self.args.model_dir, 'pytorch_model.pt'))
    #     torch.save(model_to_save, os.path.join(self.args.model_dir, 'pytorch_model.pt'))
    #     # Save training arguments together with the trained model
    #     # torch.save(self.args, os.path.join(self.args.model_dir, 'training_args.pt'))
    #     logger.info("Saving model checkpoint to %s", self.args.model_dir)

    def save_model(self):
        # Save model checkpoint (Overwrite)
        if not os.path.exists(self.args.model_dir):
            os.makedirs(self.args.model_dir)
        model_to_save = self.model.module if hasattr(self.model, 'module') else self.model
        model_to_save.save_pretrained(self.args.model_dir)

        # Save training arguments together with the trained model
        torch.save(self.args, os.path.join(self.args.model_dir, 'training_args.bin'))
        logger.info("Saving model checkpoint to %s", self.args.model_dir)

    def load_model(self):
        # Check whether model exists
        if not os.path.exists(self.args.model_dir):
            raise Exception("Model doesn't exists! Train first!")

        try:
            self.model = self.model_class.from_pretrained(self.args.model_dir,
                                                          args=self.args,
                                                          intent_label_lst=self.intent_label_lst,
                                                          slot_label_lst=self.slot_label_lst)
            self.model.to(self.device)
            logger.info("***** Model Loaded *****")
        except:
            raise Exception("Some model files might be missing...")

In [3]:
FLAG = False
logger = logging.getLogger(__name__)

In [4]:
import random
import time
import argparse
from datetime import datetime
import logging
import sys
sys.argv = ['']

#from trainer import Trainer, Trainer_multi, Trainer_woISeq
from utils import init_logger, load_tokenizer, read_prediction_text, set_seed, MODEL_CLASSES, MODEL_PATH_MAP
from data_loader import load_and_cache_examples

def init_logger():
    logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
                        datefmt='%m/%d/%Y %H:%M:%S',
                        level=logging.INFO)

def main(args):
#     init_logger(args)
#     init_logger()

    set_seed(args)
    tokenizer = load_tokenizer(args)
    train_dataset = load_and_cache_examples(args, tokenizer, mode="train")
    dev_dataset = load_and_cache_examples(args, tokenizer, mode="dev")
    test_dataset = load_and_cache_examples(args, tokenizer, mode="test")
    
    if args.multi_intent == 1:
        trainer = Trainer_multi(args, train_dataset, dev_dataset, test_dataset)
    else:
        trainer = Trainer(args, train_dataset, dev_dataset, test_dataset)
    if args.do_train:
        trainer.train()
    if args.do_eval:
        trainer.load_model()
        trainer.evaluate("test")
    return train_dataset



if __name__ == '__main__':
    time_wait = random.uniform(0, 10)
    time.sleep(time_wait)
    parser = argparse.ArgumentParser()
#     parser.add_argument("--task", default='mixsnips', required=True, type=str, help="The name of the task to train")
    parser.add_argument("--task", default='gpsr_pro_instance_say', type=str, help="The name of the task to train")

#     parser.add_argument("--model_dir", default='./gpsr_model', required=True, type=str, help="Path to save, load model")
    parser.add_argument("--model_dir", default='./gpsr_final_model', type=str, help="Path to save, load model")

    parser.add_argument("--data_dir", default="./data", type=str, help="The input data dir")
    parser.add_argument("--intent_label_file", default="intent_label.txt", type=str, help="Intent Label file")
    parser.add_argument("--slot_label_file", default="slot_label.txt", type=str, help="Slot Label file")
    parser.add_argument("--model_type", default="mobilebert", type=str, help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
#     parser.add_argument("--intent_seq", type=int, default=0, help="whether we use intent seq setting")
    parser.add_argument("--intent_seq", type=int, default=1, help="whether we use intent seq setting")

    parser.add_argument("--pro", type=int, default=1, help="support pronoun disambiguition")#!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!

    parser.add_argument("--multi_intent", type=int, default=1, help="whether we use multi intent setting")
    parser.add_argument("--tag_intent", type=int, default=0, help="whether we can use tag to predict intent")
    
    parser.add_argument("--BI_tag", type=int, default=0, help='use BI sum or just B')
    parser.add_argument("--cls_token_cat", type=int, default=1, help='whether we cat the cls to the slot output of bert')
    parser.add_argument("--intent_attn", type=int, default=1, help='whether we use attention mechanism on the CLS intent output')
    parser.add_argument("--num_mask", type=int, default=7, help="assumptive number of slot in one sentence")
                                           #max slot num = 7
    
    parser.add_argument('--seed', type=int, default=25, help="random seed for initialization")
    parser.add_argument("--train_batch_size", type=int, default=64,  help="Batch size for training.")
#     parser.add_argument("--train_batch_size", default=64, type=int, help="Batch size for training.")

    parser.add_argument("--eval_batch_size", default=128, type=int, help="Batch size for evaluation.")
    parser.add_argument("--max_seq_len", default=32, type=int, help="The maximum total input sequence length after tokenization.")
    parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
#     parser.add_argument("--num_train_epochs", default=10.0, type=float, help="Total number of training epochs to perform.")
    parser.add_argument("--num_train_epochs", default=3.0, type=float, help="Total number of training epochs to perform.")
    
    parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
    parser.add_argument('--gradient_accumulation_steps', type=int, default=1, help="Number of updates steps to accumulate before performing a backward/update pass.")
    parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
    parser.add_argument("--max_grad_norm", default=1, type=float, help="Max gradient norm.")
    parser.add_argument("--max_steps", default=-1, type=int, help="If > 0: set total number of training steps to perform. Override num_train_epochs.")
    parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
    parser.add_argument("--dropout_rate", default=0.1, type=float, help="Dropout for fully-connected layers")
    parser.add_argument('--logging_steps', type=int, default=500, help="Log every X updates steps.")
    parser.add_argument('--save_steps', type=int, default=300, help="Save checkpoint every X updates steps.")
    parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
    parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the test set.")
    parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
    parser.add_argument("--ignore_index", default=0, type=int, help='Specifies a target value that is ignored and does not contribute to the input gradient')
    parser.add_argument('--slot_loss_coef', type=float, default=2.0, help='Coefficient for the slot loss.')


    parser.add_argument('--pro_loss_coef', type=float, default=10.0, help='Coefficient for the pronoun loss.')



    parser.add_argument('--tag_intent_coef', type=float, default=1.0, help='Coefficient for the tag intent loss')

    # CRF option
    parser.add_argument("--use_crf", action="store_true", help="Whether to use CRF")
    parser.add_argument("--slot_pad_label", default="PAD", type=str, help="Pad token for slot label pad (to be ignore when calculate loss)")
    parser.add_argument("--patience", default=0, type=int, help="The initial learning rate for Adam.")
    
    parser.add_argument('-f')#########################
    args = parser.parse_args()
    
    now = datetime.now()
    args.model_dir = args.model_dir + '_' + now.strftime('%m-%d-%H:%M:%S')
    args.model_name_or_path = MODEL_PATH_MAP[args.model_type]

tokenizer = load_tokenizer(args)
print("Launing with model name: {}".format(args.model_name_or_path))

Launing with model name: google/mobilebert-uncased


In [5]:
train_dataset = load_and_cache_examples(args, tokenizer, mode="train")
dev_dataset = load_and_cache_examples(args, tokenizer, mode="dev")
test_dataset = load_and_cache_examples(args, tokenizer, mode="test")
trainer = Trainer_multi(args, train_dataset, dev_dataset, test_dataset)
trainer.train()
# test_dataset[1]

Downloading:   0%|          | 0.00/147M [00:00<?, ?B/s]

Some weights of the model checkpoint at google/mobilebert-uncased were not used when initializing JointMobileBERTMultiIntent: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.dense.weight', 'cls.predictions.decoder.weight', 'cls.predictions.decoder.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing JointMobileBERTMultiIntent from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing JointMobileBERTMultiIntent from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of JointMobileBERTM

name:  mobilebert.embeddings.word_embeddings.weight param:  True
name:  mobilebert.embeddings.position_embeddings.weight param:  True
name:  mobilebert.embeddings.token_type_embeddings.weight param:  True
name:  mobilebert.embeddings.embedding_transformation.weight param:  True
name:  mobilebert.embeddings.embedding_transformation.bias param:  True
name:  mobilebert.embeddings.LayerNorm.bias param:  True
name:  mobilebert.embeddings.LayerNorm.weight param:  True
name:  mobilebert.encoder.layer.0.attention.self.query.weight param:  True
name:  mobilebert.encoder.layer.0.attention.self.query.bias param:  True
name:  mobilebert.encoder.layer.0.attention.self.key.weight param:  True
name:  mobilebert.encoder.layer.0.attention.self.key.bias param:  True
name:  mobilebert.encoder.layer.0.attention.self.value.weight param:  True
name:  mobilebert.encoder.layer.0.attention.self.value.bias param:  True
name:  mobilebert.encoder.layer.0.attention.output.dense.weight param:  True
name:  mobileber

Epoch:   0%|          | 0/3 [00:00<?, ?it/s]
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
Evaluating: 100%|██████████| 24/24 [00:01<00:00, 17.29it/s]


all_out_referee_labels_ids shape:  (3016, 32)
{'slot_precision': 1.0, 'slot_recall': 1.0, 'slot_f1': 1.0, 'Pro_precision': 1.0, 'Pro_recall': 1.0, 'Pro_f1': 1.0, 'intent_token_precision': 1.0, 'intent_token_recall': 1.0, 'intent_token_f1': 1.0}
{0: 'PAD', 1: 'O', 2: 'I-obj', 3: 'B-sour', 4: 'B-dest', 5: 'I-sour', 6: 'B-what', 7: 'B-obj', 8: 'I-dest', 9: 'I-per', 10: 'I-what', 11: 'B-per'}
{0: 'PAD', 1: 'O', 2: 'B-greet', 3: 'I-greet', 4: 'B-know', 5: 'I-know', 6: 'B-follow', 7: 'I-follow', 8: 'B-take', 9: 'I-take', 10: 'B-tell', 11: 'I-tell', 12: 'B-guide', 13: 'I-guide', 14: 'B-go', 15: 'I-go', 16: 'B-answer', 17: 'I-answer', 18: 'B-find', 19: 'I-find'}
Pronoun Accurac:  1.0 , correct:  648 , total:  648
Slot Accurac:  1.0 , correct:  3016 , total:  3016
Intent Accurac:  1.0 , correct:  3016 , total:  3016



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
Evaluating: 100%|██████████| 95/95 [00:05<00:00, 16.57it/s]


all_out_referee_labels_ids shape:  (12068, 32)




{'slot_precision': 0.9999767144021423, 'slot_recall': 0.9999767144021423, 'slot_f1': 0.9999767144021423, 'Pro_precision': 0.9969430645777608, 'Pro_recall': 0.9996168582375479, 'Pro_f1': 0.9982781710350106, 'intent_token_precision': 1.0, 'intent_token_recall': 1.0, 'intent_token_f1': 1.0}
{0: 'PAD', 1: 'O', 2: 'I-obj', 3: 'B-sour', 4: 'B-dest', 5: 'I-sour', 6: 'B-what', 7: 'B-obj', 8: 'I-dest', 9: 'I-per', 10: 'I-what', 11: 'B-per'}
{0: 'PAD', 1: 'O', 2: 'B-greet', 3: 'I-greet', 4: 'B-know', 5: 'I-know', 6: 'B-follow', 7: 'I-follow', 8: 'B-take', 9: 'I-take', 10: 'B-tell', 11: 'I-tell', 12: 'B-guide', 13: 'I-guide', 14: 'B-go', 15: 'I-go', 16: 'B-answer', 17: 'I-answer', 18: 'B-find', 19: 'I-find'}
Pronoun Accurac:  0.9969348659003832 , correct:  2602 , total:  2610
pred:  ['O', 'B-dest', 'B-what', 'I-what', 'O', 'O', 'O', 'O', 'I-what', 'I-what', 'I-what', 'I-what', 'I-what']
true:  ['O', 'B-dest', 'B-what', 'I-what', 'I-what', 'O', 'O', 'O', 'I-what', 'I-what', 'I-what', 'I-what', 'I-


[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
Evaluating: 100%|██████████| 24/24 [00:01<00:00, 17.22it/s]


all_out_referee_labels_ids shape:  (3016, 32)
{'slot_precision': 1.0, 'slot_recall': 1.0, 'slot_f1': 1.0, 'Pro_precision': 0.9984591679506933, 'Pro_recall': 1.0, 'Pro_f1': 0.9992289899768697, 'intent_token_precision': 0.9998354044934573, 'intent_token_recall': 0.999917695473251, 'intent_token_f1': 0.9998765482901938}
{0: 'PAD', 1: 'O', 2: 'I-obj', 3: 'B-sour', 4: 'B-dest', 5: 'I-sour', 6: 'B-what', 7: 'B-obj', 8: 'I-dest', 9: 'I-per', 10: 'I-what', 11: 'B-per'}
{0: 'PAD', 1: 'O', 2: 'B-greet', 3: 'I-greet', 4: 'B-know', 5: 'I-know', 6: 'B-follow', 7: 'I-follow', 8: 'B-take', 9: 'I-take', 10: 'B-tell', 11: 'I-tell', 12: 'B-guide', 13: 'I-guide', 14: 'B-go', 15: 'I-go', 16: 'B-answer', 17: 'I-answer', 18: 'B-find', 19: 'I-find'}
Pronoun Accurac:  0.9984567901234568 , correct:  647 , total:  648
Slot Accurac:  1.0 , correct:  3016 , total:  3016
pred:  ['O', 'O', 'B-go', 'I-go', 'I-go', 'I-go', 'O', 'B-find', 'I-find', 'I-find', 'I-take', 'O', 'O', 'B-take', 'I-take', 'I-take', 'I-take']



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
Evaluating: 100%|██████████| 95/95 [00:05<00:00, 16.61it/s]


all_out_referee_labels_ids shape:  (12068, 32)


Iteration: 100%|██████████| 708/708 [02:13<00:00,  5.32it/s]
Epoch:  33%|███▎      | 1/3 [02:13<04:26, 133.20s/it]

{'slot_precision': 1.0, 'slot_recall': 1.0, 'slot_f1': 1.0, 'Pro_precision': 0.9992340099578706, 'Pro_recall': 0.9996168582375479, 'Pro_f1': 0.9994253974334418, 'intent_token_precision': 0.9998359714590339, 'intent_token_recall': 0.9999179790026247, 'intent_token_f1': 0.9998769735493132}
{0: 'PAD', 1: 'O', 2: 'I-obj', 3: 'B-sour', 4: 'B-dest', 5: 'I-sour', 6: 'B-what', 7: 'B-obj', 8: 'I-dest', 9: 'I-per', 10: 'I-what', 11: 'B-per'}
{0: 'PAD', 1: 'O', 2: 'B-greet', 3: 'I-greet', 4: 'B-know', 5: 'I-know', 6: 'B-follow', 7: 'I-follow', 8: 'B-take', 9: 'I-take', 10: 'B-tell', 11: 'I-tell', 12: 'B-guide', 13: 'I-guide', 14: 'B-go', 15: 'I-go', 16: 'B-answer', 17: 'I-answer', 18: 'B-find', 19: 'I-find'}
Pronoun Accurac:  0.9992337164750957 , correct:  2608 , total:  2610
Slot Accurac:  1.0 , correct:  12068 , total:  12068
pred:  ['B-go', 'I-go', 'I-go', 'I-go', 'O', 'B-find', 'I-find', 'I-take', 'O', 'O', 'B-take', 'I-take', 'I-take', 'I-take', 'I-take']
true:  ['B-go', 'I-go', 'I-go', 'I-g


[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
Evaluating: 100%|██████████| 24/24 [00:01<00:00, 17.31it/s]


all_out_referee_labels_ids shape:  (3016, 32)
{'slot_precision': 1.0, 'slot_recall': 1.0, 'slot_f1': 1.0, 'Pro_precision': 0.9984591679506933, 'Pro_recall': 1.0, 'Pro_f1': 0.9992289899768697, 'intent_token_precision': 1.0, 'intent_token_recall': 1.0, 'intent_token_f1': 1.0}
{0: 'PAD', 1: 'O', 2: 'I-obj', 3: 'B-sour', 4: 'B-dest', 5: 'I-sour', 6: 'B-what', 7: 'B-obj', 8: 'I-dest', 9: 'I-per', 10: 'I-what', 11: 'B-per'}
{0: 'PAD', 1: 'O', 2: 'B-greet', 3: 'I-greet', 4: 'B-know', 5: 'I-know', 6: 'B-follow', 7: 'I-follow', 8: 'B-take', 9: 'I-take', 10: 'B-tell', 11: 'I-tell', 12: 'B-guide', 13: 'I-guide', 14: 'B-go', 15: 'I-go', 16: 'B-answer', 17: 'I-answer', 18: 'B-find', 19: 'I-find'}
Pronoun Accurac:  0.9984567901234568 , correct:  647 , total:  648
Slot Accurac:  1.0 , correct:  3016 , total:  3016
Intent Accurac:  1.0 , correct:  3016 , total:  3016



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
Evaluating: 100%|██████████| 95/95 [00:05<00:00, 16.66it/s]


all_out_referee_labels_ids shape:  (12068, 32)




{'slot_precision': 1.0, 'slot_recall': 1.0, 'slot_f1': 1.0, 'Pro_precision': 0.9996168582375479, 'Pro_recall': 0.9996168582375479, 'Pro_f1': 0.9996168582375479, 'intent_token_precision': 1.0, 'intent_token_recall': 1.0, 'intent_token_f1': 1.0}
{0: 'PAD', 1: 'O', 2: 'I-obj', 3: 'B-sour', 4: 'B-dest', 5: 'I-sour', 6: 'B-what', 7: 'B-obj', 8: 'I-dest', 9: 'I-per', 10: 'I-what', 11: 'B-per'}
{0: 'PAD', 1: 'O', 2: 'B-greet', 3: 'I-greet', 4: 'B-know', 5: 'I-know', 6: 'B-follow', 7: 'I-follow', 8: 'B-take', 9: 'I-take', 10: 'B-tell', 11: 'I-tell', 12: 'B-guide', 13: 'I-guide', 14: 'B-go', 15: 'I-go', 16: 'B-answer', 17: 'I-answer', 18: 'B-find', 19: 'I-find'}
Pronoun Accurac:  0.9996168582375479 , correct:  2609 , total:  2610
Slot Accurac:  1.0 , correct:  12068 , total:  12068
Intent Accurac:  1.0 , correct:  12068 , total:  12068



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
Evaluating: 100%|██████████| 24/24 [00:01<00:00, 17.34it/s]


all_out_referee_labels_ids shape:  (3016, 32)
{'slot_precision': 1.0, 'slot_recall': 1.0, 'slot_f1': 1.0, 'Pro_precision': 0.9984591679506933, 'Pro_recall': 1.0, 'Pro_f1': 0.9992289899768697, 'intent_token_precision': 1.0, 'intent_token_recall': 1.0, 'intent_token_f1': 1.0}
{0: 'PAD', 1: 'O', 2: 'I-obj', 3: 'B-sour', 4: 'B-dest', 5: 'I-sour', 6: 'B-what', 7: 'B-obj', 8: 'I-dest', 9: 'I-per', 10: 'I-what', 11: 'B-per'}
{0: 'PAD', 1: 'O', 2: 'B-greet', 3: 'I-greet', 4: 'B-know', 5: 'I-know', 6: 'B-follow', 7: 'I-follow', 8: 'B-take', 9: 'I-take', 10: 'B-tell', 11: 'I-tell', 12: 'B-guide', 13: 'I-guide', 14: 'B-go', 15: 'I-go', 16: 'B-answer', 17: 'I-answer', 18: 'B-find', 19: 'I-find'}
Pronoun Accurac:  0.9984567901234568 , correct:  647 , total:  648
Slot Accurac:  1.0 , correct:  3016 , total:  3016
Intent Accurac:  1.0 , correct:  3016 , total:  3016



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
Evaluating: 100%|██████████| 95/95 [00:06<00:00, 14.38it/s]


all_out_referee_labels_ids shape:  (12068, 32)


Iteration: 100%|██████████| 708/708 [02:13<00:00,  5.31it/s]
Epoch:  67%|██████▋   | 2/3 [04:26<02:13, 133.32s/it]

{'slot_precision': 1.0, 'slot_recall': 1.0, 'slot_f1': 1.0, 'Pro_precision': 0.9996170049789352, 'Pro_recall': 1.0, 'Pro_f1': 0.9998084658111472, 'intent_token_precision': 1.0, 'intent_token_recall': 1.0, 'intent_token_f1': 1.0}
{0: 'PAD', 1: 'O', 2: 'I-obj', 3: 'B-sour', 4: 'B-dest', 5: 'I-sour', 6: 'B-what', 7: 'B-obj', 8: 'I-dest', 9: 'I-per', 10: 'I-what', 11: 'B-per'}
{0: 'PAD', 1: 'O', 2: 'B-greet', 3: 'I-greet', 4: 'B-know', 5: 'I-know', 6: 'B-follow', 7: 'I-follow', 8: 'B-take', 9: 'I-take', 10: 'B-tell', 11: 'I-tell', 12: 'B-guide', 13: 'I-guide', 14: 'B-go', 15: 'I-go', 16: 'B-answer', 17: 'I-answer', 18: 'B-find', 19: 'I-find'}
Pronoun Accurac:  0.9996168582375479 , correct:  2609 , total:  2610
Slot Accurac:  1.0 , correct:  12068 , total:  12068
Intent Accurac:  1.0 , correct:  12068 , total:  12068



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
Evaluating: 100%|██████████| 24/24 [00:01<00:00, 17.19it/s]


all_out_referee_labels_ids shape:  (3016, 32)
{'slot_precision': 1.0, 'slot_recall': 1.0, 'slot_f1': 1.0, 'Pro_precision': 1.0, 'Pro_recall': 1.0, 'Pro_f1': 1.0, 'intent_token_precision': 1.0, 'intent_token_recall': 1.0, 'intent_token_f1': 1.0}
{0: 'PAD', 1: 'O', 2: 'I-obj', 3: 'B-sour', 4: 'B-dest', 5: 'I-sour', 6: 'B-what', 7: 'B-obj', 8: 'I-dest', 9: 'I-per', 10: 'I-what', 11: 'B-per'}
{0: 'PAD', 1: 'O', 2: 'B-greet', 3: 'I-greet', 4: 'B-know', 5: 'I-know', 6: 'B-follow', 7: 'I-follow', 8: 'B-take', 9: 'I-take', 10: 'B-tell', 11: 'I-tell', 12: 'B-guide', 13: 'I-guide', 14: 'B-go', 15: 'I-go', 16: 'B-answer', 17: 'I-answer', 18: 'B-find', 19: 'I-find'}
Pronoun Accurac:  1.0 , correct:  648 , total:  648
Slot Accurac:  1.0 , correct:  3016 , total:  3016
Intent Accurac:  1.0 , correct:  3016 , total:  3016



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
Evaluating: 100%|██████████| 95/95 [00:05<00:00, 16.69it/s]


all_out_referee_labels_ids shape:  (12068, 32)




{'slot_precision': 1.0, 'slot_recall': 1.0, 'slot_f1': 1.0, 'Pro_precision': 0.9996170049789352, 'Pro_recall': 1.0, 'Pro_f1': 0.9998084658111472, 'intent_token_precision': 1.0, 'intent_token_recall': 1.0, 'intent_token_f1': 1.0}
{0: 'PAD', 1: 'O', 2: 'I-obj', 3: 'B-sour', 4: 'B-dest', 5: 'I-sour', 6: 'B-what', 7: 'B-obj', 8: 'I-dest', 9: 'I-per', 10: 'I-what', 11: 'B-per'}
{0: 'PAD', 1: 'O', 2: 'B-greet', 3: 'I-greet', 4: 'B-know', 5: 'I-know', 6: 'B-follow', 7: 'I-follow', 8: 'B-take', 9: 'I-take', 10: 'B-tell', 11: 'I-tell', 12: 'B-guide', 13: 'I-guide', 14: 'B-go', 15: 'I-go', 16: 'B-answer', 17: 'I-answer', 18: 'B-find', 19: 'I-find'}
Pronoun Accurac:  0.9996168582375479 , correct:  2609 , total:  2610
Slot Accurac:  1.0 , correct:  12068 , total:  12068
Intent Accurac:  1.0 , correct:  12068 , total:  12068



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
Evaluating: 100%|██████████| 24/24 [00:01<00:00, 17.29it/s]


all_out_referee_labels_ids shape:  (3016, 32)
{'slot_precision': 1.0, 'slot_recall': 1.0, 'slot_f1': 1.0, 'Pro_precision': 1.0, 'Pro_recall': 1.0, 'Pro_f1': 1.0, 'intent_token_precision': 1.0, 'intent_token_recall': 1.0, 'intent_token_f1': 1.0}
{0: 'PAD', 1: 'O', 2: 'I-obj', 3: 'B-sour', 4: 'B-dest', 5: 'I-sour', 6: 'B-what', 7: 'B-obj', 8: 'I-dest', 9: 'I-per', 10: 'I-what', 11: 'B-per'}
{0: 'PAD', 1: 'O', 2: 'B-greet', 3: 'I-greet', 4: 'B-know', 5: 'I-know', 6: 'B-follow', 7: 'I-follow', 8: 'B-take', 9: 'I-take', 10: 'B-tell', 11: 'I-tell', 12: 'B-guide', 13: 'I-guide', 14: 'B-go', 15: 'I-go', 16: 'B-answer', 17: 'I-answer', 18: 'B-find', 19: 'I-find'}
Pronoun Accurac:  1.0 , correct:  648 , total:  648
Slot Accurac:  1.0 , correct:  3016 , total:  3016
Intent Accurac:  1.0 , correct:  3016 , total:  3016



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
Evaluating: 100%|██████████| 95/95 [00:05<00:00, 16.59it/s]


all_out_referee_labels_ids shape:  (12068, 32)


Iteration: 100%|██████████| 708/708 [02:11<00:00,  5.38it/s]
Epoch: 100%|██████████| 3/3 [06:38<00:00, 132.70s/it]

{'slot_precision': 1.0, 'slot_recall': 1.0, 'slot_f1': 1.0, 'Pro_precision': 0.9996170049789352, 'Pro_recall': 1.0, 'Pro_f1': 0.9998084658111472, 'intent_token_precision': 1.0, 'intent_token_recall': 1.0, 'intent_token_f1': 1.0}
{0: 'PAD', 1: 'O', 2: 'I-obj', 3: 'B-sour', 4: 'B-dest', 5: 'I-sour', 6: 'B-what', 7: 'B-obj', 8: 'I-dest', 9: 'I-per', 10: 'I-what', 11: 'B-per'}
{0: 'PAD', 1: 'O', 2: 'B-greet', 3: 'I-greet', 4: 'B-know', 5: 'I-know', 6: 'B-follow', 7: 'I-follow', 8: 'B-take', 9: 'I-take', 10: 'B-tell', 11: 'I-tell', 12: 'B-guide', 13: 'I-guide', 14: 'B-go', 15: 'I-go', 16: 'B-answer', 17: 'I-answer', 18: 'B-find', 19: 'I-find'}
Pronoun Accurac:  0.9996168582375479 , correct:  2609 , total:  2610
Slot Accurac:  1.0 , correct:  12068 , total:  12068
Intent Accurac:  1.0 , correct:  12068 , total:  12068





(2124, 0.24817539988782883)

In [6]:
trainer.evaluate("test")

Evaluating: 100%|██████████| 95/95 [00:05<00:00, 16.80it/s]


all_out_referee_labels_ids shape:  (12068, 32)
{'slot_precision': 1.0, 'slot_recall': 1.0, 'slot_f1': 1.0, 'Pro_precision': 0.9996170049789352, 'Pro_recall': 1.0, 'Pro_f1': 0.9998084658111472, 'intent_token_precision': 1.0, 'intent_token_recall': 1.0, 'intent_token_f1': 1.0}
{0: 'PAD', 1: 'O', 2: 'I-obj', 3: 'B-sour', 4: 'B-dest', 5: 'I-sour', 6: 'B-what', 7: 'B-obj', 8: 'I-dest', 9: 'I-per', 10: 'I-what', 11: 'B-per'}
{0: 'PAD', 1: 'O', 2: 'B-greet', 3: 'I-greet', 4: 'B-know', 5: 'I-know', 6: 'B-follow', 7: 'I-follow', 8: 'B-take', 9: 'I-take', 10: 'B-tell', 11: 'I-tell', 12: 'B-guide', 13: 'I-guide', 14: 'B-go', 15: 'I-go', 16: 'B-answer', 17: 'I-answer', 18: 'B-find', 19: 'I-find'}
Pronoun Accurac:  0.9996168582375479 , correct:  2609 , total:  2610
Slot Accurac:  1.0 , correct:  12068 , total:  12068
Intent Accurac:  1.0 , correct:  12068 , total:  12068


{'loss': 1.5992251583198933e-07,
 'slot_precision': 1.0,
 'slot_recall': 1.0,
 'slot_f1': 1.0,
 'Pro_precision': 0.9996170049789352,
 'Pro_recall': 1.0,
 'Pro_f1': 0.9998084658111472,
 'intent_token_precision': 1.0,
 'intent_token_recall': 1.0,
 'intent_token_f1': 1.0,
 'Pronoun Accuracy': 0.9996168582375479}

In [7]:
intent_lab = get_intent_labels(args)
slot_label_lst = get_slot_labels(args)
slot_label_lst
intent_lab

['PAD',
 'O',
 'B-greet',
 'I-greet',
 'B-know',
 'I-know',
 'B-follow',
 'I-follow',
 'B-take',
 'I-take',
 'B-tell',
 'I-tell',
 'B-guide',
 'I-guide',
 'B-go',
 'I-go',
 'B-answer',
 'I-answer',
 'B-find',
 'I-find']

In [8]:

model_to_save = trainer.model.module if hasattr(trainer.model, 'module') else trainer.model
trainer.model.save_pretrained('trained_model')
print('finished')



finished
