In [1]:
import os
import logging
from tqdm import tqdm, trange
import torch.nn as nn

import numpy as np
import torch
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from transformers import BertConfig, AdamW, get_linear_schedule_with_warmup

from utils import MODEL_CLASSES, compute_metrics, get_intent_labels, get_slot_labels, compute_metrics_multi_intent,compute_metrics_multi_intent_Pro

from seqeval.metrics.sequence_labeling import get_entities

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def weighted_binary_cross_entropy(output, target, weights=None):
    """Mean (optionally class-weighted) binary cross-entropy.

    Args:
        output: tensor of predicted probabilities in [0, 1].
        target: tensor broadcastable to ``output`` holding the targets.
        weights: optional pair ``(w_neg, w_pos)``; ``weights[1]`` scales the
            positive term ``target * log(output)`` and ``weights[0]`` scales the
            negative term ``(1 - target) * log(1 - output)``.

    Returns:
        Scalar tensor ``-mean(w_pos * t * log(p) + w_neg * (1 - t) * log(1 - p))``.
    """
    if weights is not None:
        assert len(weights) == 2

    # Clamp the log terms at -100, as torch.nn.BCELoss does, so a prediction of
    # exactly 0 or 1 yields a finite loss instead of 0 * -inf = nan.
    log_p = torch.clamp(torch.log(output), min=-100)
    log_not_p = torch.clamp(torch.log(1 - output), min=-100)

    if weights is not None:
        loss = weights[1] * (target * log_p) + \
               weights[0] * ((1 - target) * log_not_p)
    else:
        loss = target * log_p + (1 - target) * log_not_p

    return torch.neg(torch.mean(loss))

class Trainer_multi(object):
    """Trains and evaluates the multi-intent joint model with pronoun-referee prediction.

    The concrete config/model classes come from ``MODEL_CLASSES[args.model_type]``.
    Batches are tuples laid out as:
    0 input_ids, 1 attention_mask, 2 token_type_ids, 3 intent_label_ids,
    4 slot_labels_ids, 5 intent_token_ids, 6 B_tag_mask, 7 BI_tag_mask,
    8 tag_intent_label, 9 referee_labels_ids, 10 pro_labels_ids.

    The methods use the module-level globals ``logger`` and ``FLAG`` (passed as
    tqdm's ``disable=``). They are looked up when a method runs, so they must be
    defined before ``train``/``evaluate`` is called.
    """

    def __init__(self, args, train_dataset=None, dev_dataset=None, test_dataset=None):
        """Load label sets, config and pretrained model, and move the model to the device.

        Args:
            args: run configuration (argparse namespace).
            train_dataset / dev_dataset / test_dataset: datasets yielding batches
                in the layout documented on the class.
        """
        self.args = args
        self.train_dataset = train_dataset
        self.dev_dataset = dev_dataset
        self.test_dataset = test_dataset

        # Filled by evaluate() with the most recent decoded predictions.
        self.slot_preds_list = None
        self.intent_token_preds_list = None

        # set of intents
        self.intent_label_lst = get_intent_labels(args)
        # set of slots
        self.slot_label_lst = get_slot_labels(args)
        # Use cross entropy ignore index as padding label id so that only real label ids contribute to the loss later
        self.pad_token_label_id = args.ignore_index

        self.config_class, self.model_class, _ = MODEL_CLASSES[args.model_type]
        self.config = self.config_class.from_pretrained(args.model_name_or_path, finetuning_task=args.task)
        self.model = self.model_class.from_pretrained(args.model_name_or_path,
                                                      config=self.config,
                                                      args=args,
                                                      intent_label_lst=self.intent_label_lst,
                                                      slot_label_lst=self.slot_label_lst,
                                                      )

        # GPU or CPU
        self.device = "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu"
        self.model.to(self.device)

    # ------------------------------------------------------------------ helpers

    def _build_inputs(self, batch, tag_masks=None):
        """Build the model keyword arguments from a batch.

        Args:
            batch: tuple in the layout documented on the class.
            tag_masks: optional ``(B_tag_mask, BI_tag_mask)`` overriding
                ``batch[6]`` / ``batch[7]`` (evaluate() passes predicted masks here).
        """
        if tag_masks is None:
            B_tag_mask, BI_tag_mask = batch[6], batch[7]
        else:
            B_tag_mask, BI_tag_mask = tag_masks
        inputs = {'input_ids': batch[0],
                  'attention_mask': batch[1],
                  'intent_label_ids': batch[3],
                  'slot_labels_ids': batch[4],
                  'intent_token_ids': batch[5],
                  'B_tag_mask': B_tag_mask,
                  'BI_tag_mask': BI_tag_mask,
                  'tag_intent_label': batch[8],
                  'referee_labels_ids': batch[9],
                  'pro_labels_ids': batch[10]}
        if self.args.model_type != 'distilbert':
            inputs['token_type_ids'] = batch[2]
        return inputs

    def _unpack_outputs(self, outputs):
        """Split model outputs into ``(losses, logits_by_head)``.

        ``outputs[1]`` is a tuple of logits whose composition depends on the
        enabled heads (``args.pro`` / ``args.intent_seq`` / ``args.tag_intent``).
        It is returned as a dict keyed by head name.

        Raises:
            ValueError: if the number of logits does not match the configuration.
        """
        losses, logits = outputs[:2]
        args = self.args
        if args.pro and args.intent_seq and args.tag_intent:
            names = ('intent', 'slot', 'intent_token', 'tag_intent', 'referee', 'all_referee')
        elif args.intent_seq and args.tag_intent:
            names = ('intent', 'slot', 'intent_token', 'tag_intent')
        elif args.intent_seq:
            names = ('intent', 'slot', 'intent_token')
        elif args.tag_intent:
            names = ('intent', 'slot', 'tag_intent')
        else:
            names = ('intent', 'slot')
        if len(logits) != len(names):
            raise ValueError("expected {} logits {} from the model, got {}".format(len(names), names, len(logits)))
        return losses, dict(zip(names, logits))

    def _decode_tokens(self, token_logits):
        """Return per-token predictions as a numpy array.

        With ``args.use_crf``, these are the CRF best paths (label ids). Otherwise
        they are the raw logits, and the caller takes the argmax later.
        """
        if self.args.use_crf:
            # decode() in `torchcrf` returns list with best index directly
            return np.array(self.model.crf.decode(token_logits))
        return token_logits.detach().cpu().numpy()

    @staticmethod
    def _append_rows(accumulated, new_rows):
        """Return ``new_rows`` if nothing is accumulated yet, else both stacked along axis 0."""
        if accumulated is None:
            return new_rows
        return np.append(accumulated, new_rows, axis=0)

    @staticmethod
    def _align_labels(label_ids, pred_ids, label_map, pad_id):
        """Map gold/predicted ids to label strings, dropping positions whose gold id is ``pad_id``.

        Args:
            label_ids: 2-D array of gold ids (rows x positions).
            pred_ids: array indexable as ``pred_ids[i][j]`` for the same positions.
            label_map: id -> label string.
            pad_id: gold id marking positions to skip.

        Returns:
            ``(gold_lists, pred_lists)``: one list of label strings per row.
        """
        num_rows = label_ids.shape[0]
        gold_lists = [[] for _ in range(num_rows)]
        pred_lists = [[] for _ in range(num_rows)]
        for i in range(num_rows):
            for j in range(label_ids.shape[1]):
                if label_ids[i, j] != pad_id:
                    gold_lists[i].append(label_map[label_ids[i][j]])
                    pred_lists[i].append(label_map[pred_ids[i][j]])
        return gold_lists, pred_lists

    # ----------------------------------------------------------------- training

    def train(self):
        """Fine-tune the model on ``self.train_dataset``.

        Uses AdamW with a linear warmup/decay schedule, gradient accumulation and
        gradient clipping. About twice per epoch (measured in optimizer steps) it
        logs the current losses and evaluates on dev and test.

        When ``args.patience != 0``, the model is checkpointed whenever the dev score
        ``sementic_frame_acc + intent_acc + slot_f1`` improves. Training stops early
        after ``args.patience`` evaluations without improvement. With
        ``patience == 0``, the model is saved after every evaluation.

        Returns:
            ``(global_step, average accumulated training loss per optimizer step)``.
        """
        train_sampler = RandomSampler(self.train_dataset)
        train_dataloader = DataLoader(self.train_dataset, sampler=train_sampler, batch_size=self.args.train_batch_size)

        if self.args.max_steps > 0:
            t_total = self.args.max_steps
            self.args.num_train_epochs = self.args.max_steps // (len(train_dataloader) // self.args.gradient_accumulation_steps) + 1
        else:
            t_total = len(train_dataloader) // self.args.gradient_accumulation_steps * self.args.num_train_epochs

        # Prepare optimizer and schedule (linear warmup and decay)
        no_decay = ['bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [
            {'params': [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)],
             'weight_decay': self.args.weight_decay},
            {'params': [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
        ]
        optimizer = AdamW(optimizer_grouped_parameters, lr=self.args.learning_rate, eps=self.args.adam_epsilon)
        scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=self.args.warmup_steps, num_training_steps=t_total)

        # Train!
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(self.train_dataset))
        logger.info("  Num Epochs = %d", self.args.num_train_epochs)
        logger.info("  Total train batch size = %d", self.args.train_batch_size)
        logger.info("  Gradient Accumulation steps = %d", self.args.gradient_accumulation_steps)
        logger.info("  Total optimization steps = %d", t_total)
        logger.info("  Logging steps = %d", self.args.logging_steps)
        logger.info("  Save steps = %d", self.args.save_steps)

        global_step = 0
        tr_loss = 0.0
        self.model.zero_grad()

        train_iterator = trange(int(self.args.num_train_epochs), desc="Epoch")

        # Evaluate about twice per epoch. global_step counts optimizer updates, so
        # divide by gradient_accumulation_steps (as t_total does). Clamp to 1 so the
        # modulo below cannot divide by zero on tiny datasets.
        updates_per_epoch = len(train_dataloader) // self.args.gradient_accumulation_steps
        step_per_epoch = max(1, updates_per_epoch // 2)

        # Early-stopping bookkeeping: (index of the best evaluation, best dev score).
        MAX_RECORD = self.args.patience
        num_eval = -1
        eval_result_record = (num_eval, 0.0)
        flag = False
        for _ in train_iterator:
            epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=FLAG)
            for step, batch in enumerate(epoch_iterator):
                self.model.train()
                batch = tuple(t.to(self.device) for t in batch)  # GPU or CPU
                inputs = self._build_inputs(batch)

                outputs = self.model(**inputs)
                # outputs[0] is indexed as (total, intent, slot, intent_token, tag_intent, referee_token) losses.
                losses = outputs[0]
                loss = losses[0]
                intent_loss = losses[1]
                slot_loss = losses[2]
                intent_token_loss = losses[3]
                tag_intent_loss = losses[4]
                referee_token_loss = losses[5]

                if step == 500:
                    print('referee_token_loss: ', referee_token_loss)

                if self.args.gradient_accumulation_steps > 1:
                    loss = loss / self.args.gradient_accumulation_steps

                loss.backward()
                tr_loss += loss.item()

                if (step + 1) % self.args.gradient_accumulation_steps == 0:
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.max_grad_norm)
                    optimizer.step()
                    scheduler.step()  # Update learning rate schedule
                    self.model.zero_grad()
                    global_step += 1

                    if self.args.logging_steps > 0 and global_step % step_per_epoch == 0:
                        logger.info("***** Training Step %d *****", step)
                        logger.info("  total_loss = %f", loss)
                        logger.info("  intent_loss = %f", intent_loss)
                        logger.info("  slot_loss = %f", slot_loss)
                        logger.info("  intent_token_loss = %f", intent_token_loss)
                        logger.info("  tag_intent_loss = %f", tag_intent_loss)
                        logger.info("  referee_token_loss = %f", referee_token_loss)

                        dev_result = self.evaluate("dev")
                        self.evaluate("test")  # logged only; not used for model selection
                        num_eval += 1
                        if self.args.patience != 0:
                            dev_score = dev_result['sementic_frame_acc'] + dev_result['intent_acc'] + dev_result['slot_f1']
                            if dev_score > eval_result_record[1]:
                                self.save_model()
                                eval_result_record = (num_eval, dev_score)
                            else:
                                cur_num_eval = eval_result_record[0]
                                if num_eval - cur_num_eval >= MAX_RECORD:
                                    logger.info(' EARLY STOP Evaluate: at {}, best eval {} intent_slot_acc: {} '.format(num_eval, cur_num_eval, eval_result_record[1]))
                                    flag = True
                                    break
                        else:
                            self.save_model()

                if 0 < self.args.max_steps < global_step:
                    epoch_iterator.close()
                    break

            if flag:
                train_iterator.close()
                break

            if 0 < self.args.max_steps < global_step:
                train_iterator.close()
                break

        # Avoid ZeroDivisionError when no optimizer step was taken.
        return global_step, tr_loss / max(global_step, 1)

    # --------------------------------------------------------------- evaluation

    def evaluate(self, mode):
        """Evaluate on the dev or test split and return a metrics dict.

        Two passes are made over the (sequentially sampled) data:
          1. Using the B/BI tag masks stored in the batch, collect intent, slot,
             intent-token and (when ``args.pro``) referee predictions.
          2. Rebuild entity masks from the predicted slot tags of pass 1. Run the
             model again with those masks to get the loss and tag-intent predictions.

        Args:
            mode: ``'dev'`` or ``'test'``.

        Returns:
            dict with ``'loss'``, the entries returned by
            ``compute_metrics_multi_intent_Pro`` and, when ``args.pro`` is set,
            ``'Pronoun Accuracy'``.

        Raises:
            Exception: if ``mode`` is neither ``'dev'`` nor ``'test'``.
        """
        if mode == 'test':
            dataset = self.test_dataset
        elif mode == 'dev':
            dataset = self.dev_dataset
        else:
            raise Exception("Only dev and test dataset available")

        eval_sampler = SequentialSampler(dataset)
        eval_dataloader = DataLoader(dataset, sampler=eval_sampler, batch_size=self.args.eval_batch_size)

        # Eval!
        logger.info("***** Running evaluation on %s dataset *****", mode)
        logger.info("  Num examples = %d", len(dataset))
        logger.info("  Batch size = %d", self.args.eval_batch_size)

        self.model.eval()

        # ============ Pass 1: predictions using the batch's tag masks ============
        intent_preds = None
        out_intent_label_ids = None
        slot_preds = None
        out_slot_labels_ids = None
        intent_token_preds = None
        out_intent_token_ids = None
        referee_preds = None
        out_referee_labels_ids = None

        for batch in tqdm(eval_dataloader, desc="Evaluating", disable=FLAG):
            batch = tuple(t.to(self.device) for t in batch)
            with torch.no_grad():
                inputs = self._build_inputs(batch)
                _, logits = self._unpack_outputs(self.model(**inputs))

            # Intent prediction
            intent_preds = self._append_rows(intent_preds, logits['intent'].detach().cpu().numpy())
            out_intent_label_ids = self._append_rows(
                out_intent_label_ids, inputs['intent_label_ids'].detach().cpu().numpy())

            # Slot prediction
            slot_preds = self._append_rows(slot_preds, self._decode_tokens(logits['slot']))
            out_slot_labels_ids = self._append_rows(
                out_slot_labels_ids, inputs["slot_labels_ids"].detach().cpu().numpy())

            # Pronoun referee prediction
            if self.args.pro:
                if 'referee' not in logits:
                    raise ValueError("args.pro requires args.intent_seq and args.tag_intent "
                                     "(only that configuration yields referee logits)")
                referee_preds = self._append_rows(referee_preds, logits['referee'].detach().cpu().numpy())
                # A sample contains a pronoun when any entry of its pro_labels_ids row is > 0.
                pro_sample_mask_np = (torch.max(inputs["pro_labels_ids"], dim=1)[0] > 0).detach().cpu().numpy()
                batch_referee_labels = inputs["referee_labels_ids"].detach().cpu().numpy()
                # Keep gold referee labels only for pronoun samples. Boolean indexing
                # keeps the 2-D shape even when a batch has no pronoun sample; the
                # previous list-built array collapsed to shape (0,) and broke np.append.
                out_referee_labels_ids = self._append_rows(
                    out_referee_labels_ids, batch_referee_labels[pro_sample_mask_np])

            # Intent token sequence
            if self.args.intent_seq:
                intent_token_preds = self._append_rows(intent_token_preds, self._decode_tokens(logits['intent_token']))
                out_intent_token_ids = self._append_rows(
                    out_intent_token_ids, inputs["intent_token_ids"].detach().cpu().numpy())

        # Multi-label intents: threshold every score at 0.5.
        # NOTE(review): assumes the intent head outputs probabilities (sigmoid) — confirm in the model.
        intent_preds = torch.as_tensor(intent_preds > 0.5, dtype=torch.int32)

        # Slot result: (num_samples, seq_len) label ids
        if not self.args.use_crf:
            slot_preds = np.argmax(slot_preds, axis=2)
        slot_label_map = {i: label for i, label in enumerate(self.slot_label_lst)}
        num_samples, seq_len = out_slot_labels_ids.shape
        out_slot_label_list = [[] for _ in range(num_samples)]
        slot_preds_list = [[] for _ in range(num_samples)]

        # One list of num_mask entity masks per sample; B-only or BI masks depending on args.BI_tag.
        tag_mask_pred = []
        for i in range(num_samples):
            # pos_offset[k] = number of padded positions before the k-th non-pad label,
            # so that token position = k + pos_offset[k]. One extra slot keeps the
            # pos_offset[pos_cnt + 1] write in range for a row without padding.
            pos_offset = [0] * (seq_len + 1)
            pos_cnt = 0
            padding_recording = [0] * seq_len

            for j in range(seq_len):
                if out_slot_labels_ids[i, j] != self.pad_token_label_id:
                    out_slot_label_list[i].append(slot_label_map[out_slot_labels_ids[i][j]])
                    slot_preds_list[i].append(slot_label_map[slot_preds[i][j]])
                    pos_offset[pos_cnt + 1] = pos_offset[pos_cnt]
                    pos_cnt += 1
                else:
                    pos_offset[pos_cnt] = pos_offset[pos_cnt] + 1
                    padding_recording[j] = 1

            # Entities are indexed here as (type, start, end); keep only those whose
            # first predicted tag starts with 'B', at most num_mask of them.
            entities = [entity for entity in get_entities(slot_preds_list[i])
                        if slot_preds_list[i][entity[1]].startswith('B')]
            entities = entities[:self.args.num_mask]

            entity_masks = []
            for entity in entities:
                entity_mask = [0] * seq_len
                start_idx = entity[1] + pos_offset[entity[1]]
                end_idx = entity[2] + pos_offset[entity[2]] + 1
                if self.args.BI_tag:
                    # Mark the whole span, then clear padded positions inside it.
                    entity_mask[start_idx:end_idx] = [1] * (end_idx - start_idx)
                    for padding_idx in range(start_idx, end_idx):
                        if padding_recording[padding_idx]:
                            entity_mask[padding_idx] = 0
                else:
                    # Mark only the first token of the entity.
                    entity_mask[start_idx] = 1
                entity_masks.append(entity_mask)

            # Pad with all-zero masks so every sample has num_mask masks.
            entity_masks.extend([0] * seq_len for _ in range(self.args.num_mask - len(entity_masks)))
            tag_mask_pred.append(entity_masks)

        tag_mask_pred_tensor = torch.FloatTensor(tag_mask_pred)

        # ============ Pass 2: loss and tag-intent with predicted masks ============
        eval_loss = 0.0
        nb_eval_steps = 0
        tag_intent_preds = None
        out_tag_intent_ids = None
        batch_size = self.args.eval_batch_size

        for eval_idx, batch in enumerate(tqdm(eval_dataloader, desc="Evaluating", disable=FLAG)):
            # The sampler is sequential, so batch eval_idx covers these rows of the predicted masks.
            # Moved to the model's device, as the batch masks used in pass 1 are.
            batch_tag_mask = tag_mask_pred_tensor[eval_idx * batch_size:(eval_idx + 1) * batch_size].to(self.device)
            tag_masks = (None, batch_tag_mask) if self.args.BI_tag else (batch_tag_mask, None)

            batch = tuple(t.to(self.device) for t in batch)
            with torch.no_grad():
                inputs = self._build_inputs(batch, tag_masks=tag_masks)
                tmp_eval_loss, logits = self._unpack_outputs(self.model(**inputs))
                eval_loss += tmp_eval_loss[0].mean().item()
            nb_eval_steps += 1

            if self.args.tag_intent:
                tag_intent_label = inputs['tag_intent_label']
                batch_tag_intent_preds = logits['tag_intent'].view(tag_intent_label.size(0), tag_intent_label.size(1), -1)
                tag_intent_preds = self._append_rows(tag_intent_preds, batch_tag_intent_preds.detach().cpu().numpy())
                out_tag_intent_ids = self._append_rows(out_tag_intent_ids, tag_intent_label.detach().cpu().numpy())

        results = {
            "loss": eval_loss / nb_eval_steps
        }

        # ============ Decode ids to label strings ============
        # Heads that are disabled contribute None to the metrics call.
        referee_preds_list = None
        out_referee_label_list = None
        if self.args.pro:
            referee_token_map = {0: 'PAD', 1: 'O', 2: 'B-referee'}  # All referee are just one word in EGPSR
            referee_preds = np.argmax(referee_preds, axis=2)
            out_referee_label_list, referee_preds_list = self._align_labels(
                out_referee_labels_ids, referee_preds, referee_token_map, self.pad_token_label_id)

        intent_token_map = {i: label for i, label in enumerate(self.intent_label_lst)}

        out_intent_token_list = None
        intent_token_preds_list = None
        if self.args.intent_seq:
            if not self.args.use_crf:
                intent_token_preds = np.argmax(intent_token_preds, axis=2)
            out_intent_token_list, intent_token_preds_list = self._align_labels(
                out_intent_token_ids, intent_token_preds, intent_token_map, self.pad_token_label_id)

        out_tag_intent_list = None
        tag_intent_preds_list = None
        if self.args.tag_intent:
            tag_intent_preds = np.argmax(tag_intent_preds, axis=2)
            out_tag_intent_list, tag_intent_preds_list = self._align_labels(
                out_tag_intent_ids, tag_intent_preds, intent_token_map, self.pad_token_label_id)

        # NOTE(review): confirm compute_metrics_multi_intent_Pro accepts None for disabled heads.
        total_result = compute_metrics_multi_intent_Pro(intent_preds,
                                                        out_intent_label_ids,
                                                        slot_preds_list,
                                                        out_slot_label_list,
                                                        intent_token_preds_list,
                                                        out_intent_token_list,
                                                        tag_intent_preds_list,
                                                        out_tag_intent_list,
                                                        referee_preds_list,
                                                        out_referee_label_list
                                                        )
        results.update(total_result)

        # Pronoun accuracy: fraction of pronoun samples whose whole referee sequence is exactly right.
        if self.args.pro:
            correct = sum(1 for pred_seq, gold_seq in zip(referee_preds_list, out_referee_label_list)
                          if pred_seq == gold_seq)
            total = len(referee_preds_list)
            # No pronoun samples at all -> report 0.0 instead of dividing by zero.
            ref_acc = correct / total if total else 0.0
            print('Pronoun Accuracy: ', ref_acc, ', correct: ', correct, ', total: ', total)
            results.update({'Pronoun Accuracy': ref_acc})

        logger.info("***** Eval results *****")
        for key in sorted(results.keys()):
            logger.info("  %s_%s = %s", mode, key, str(results[key]))

        self.slot_preds_list = slot_preds_list
        self.intent_token_preds_list = intent_token_preds_list
        return results

    # ------------------------------------------------------------ checkpointing

    def save_model(self):
        """Save the model and training args to ``args.model_dir``, overwriting any previous checkpoint."""
        os.makedirs(self.args.model_dir, exist_ok=True)
        # Unwrap DataParallel-style wrappers that expose the real model as `.module`.
        model_to_save = self.model.module if hasattr(self.model, 'module') else self.model
        model_to_save.save_pretrained(self.args.model_dir)

        # Save training arguments together with the trained model
        torch.save(self.args, os.path.join(self.args.model_dir, 'training_args.bin'))
        logger.info("Saving model checkpoint to %s", self.args.model_dir)

    def load_model(self):
        """Reload the model from ``args.model_dir`` and move it to ``self.device``.

        Raises:
            Exception: if the directory does not exist or loading fails. The
                underlying error is chained as ``__cause__``.
        """
        if not os.path.exists(self.args.model_dir):
            raise Exception("Model doesn't exists! Train first!")

        try:
            self.model = self.model_class.from_pretrained(self.args.model_dir,
                                                          args=self.args,
                                                          intent_label_lst=self.intent_label_lst,
                                                          slot_label_lst=self.slot_label_lst)
            self.model.to(self.device)
            logger.info("***** Model Loaded *****")
        except Exception as e:
            # Keep the original error attached instead of discarding it.
            raise Exception("Some model files might be missing...") from e

In [3]:
# Module logger shared by Trainer_multi (looked up when its methods run).
logger = logging.getLogger(__name__)
# Passed as tqdm's `disable=` in Trainer_multi; False keeps progress bars visible.
FLAG = False

In [4]:
import random
import time
import argparse
from datetime import datetime
import logging

#from trainer import Trainer, Trainer_multi, Trainer_woISeq
from utils import init_logger, load_tokenizer, read_prediction_text, set_seed, MODEL_CLASSES, MODEL_PATH_MAP
from data_loader import load_and_cache_examples, processors

def init_logger():
    """Configure the root logger for INFO-level, timestamped output.

    NOTE: this definition shadows the ``init_logger`` imported from ``utils`` in the
    same cell. ``logging.basicConfig`` does nothing if the root logger already has
    handlers.
    """
    log_format = '%(asctime)s - %(levelname)s - %(name)s -   %(message)s'
    date_format = '%m/%d/%Y %H:%M:%S'
    logging.basicConfig(format=log_format, datefmt=date_format, level=logging.INFO)

def main(args):
    """Run the pipeline described by ``args``: seed, load data, optionally train and evaluate.

    Args:
        args: the ``argparse.Namespace`` built in the ``__main__`` block. This function
            reads ``multi_intent``, ``do_train`` and ``do_eval``. Other fields are
            consumed by ``set_seed``, the tokenizer, the data loader and the trainer.

    Returns:
        The training dataset produced by ``load_and_cache_examples(..., mode="train")``.

    Raises:
        NotImplementedError: if ``args.multi_intent != 1`` and no single-intent
            ``Trainer`` class is defined in this notebook's namespace.
    """
    if args.multi_intent == 1:
        trainer_cls = Trainer_multi
    else:
        # The single-intent `Trainer` import (`from trainer import ...`) is commented out in
        # this notebook. Referencing `Trainer` directly raised NameError only after all
        # datasets had been loaded, so resolve it up front and fail fast with a clear message.
        trainer_cls = globals().get("Trainer")
        if trainer_cls is None:
            raise NotImplementedError(
                "Single-intent training needs `Trainer` from trainer.py, which is not imported "
                "in this notebook; import it or run with --multi_intent 1.")

    set_seed(args)
    tokenizer = load_tokenizer(args)
    train_dataset = load_and_cache_examples(args, tokenizer, mode="train")
    dev_dataset = load_and_cache_examples(args, tokenizer, mode="dev")
    test_dataset = load_and_cache_examples(args, tokenizer, mode="test")

    trainer = trainer_cls(args, train_dataset, dev_dataset, test_dataset)
    if args.do_train:
        trainer.train()
    if args.do_eval:
        # Reloads from args.model_dir (presumably the checkpoint saved during training)
        # before scoring the test split.
        trainer.load_model()
        trainer.evaluate("test")
    return train_dataset



if __name__ == '__main__':
    # NOTE(review): random 0-10 s pause before parsing args. Presumably this staggers
    # concurrently launched runs so their second-resolution timestamped model_dir names
    # (built below) do not collide. Confirm whether it is still needed in a notebook.
    time_wait = random.uniform(0, 10)
    time.sleep(time_wait)

    parser = argparse.ArgumentParser()
    # NOTE: --task and --model_dir used to be required=True (with 'mixsnips' / './gpsr_model');
    # they now have defaults so the cell runs without command-line arguments.
    parser.add_argument("--task", default='gpsr_pro', type=str, help="The name of the task to train")
    parser.add_argument("--model_dir", default='./gpsr_pro_model', type=str, help="Path to save, load model")

    parser.add_argument("--data_dir", default="./data", type=str, help="The input data dir")
    parser.add_argument("--intent_label_file", default="intent_label.txt", type=str, help="Intent Label file")
    parser.add_argument("--slot_label_file", default="slot_label.txt", type=str, help="Slot Label file")
    parser.add_argument("--model_type", default="multibert", type=str, help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
    # NOTE: previously defaulted to 0.
    parser.add_argument("--intent_seq", type=int, default=1, help="whether we use intent seq setting")
    parser.add_argument("--pro", type=int, default=1, help="support pronoun disambiguition")
    parser.add_argument("--multi_intent", type=int, default=1, help="whether we use multi intent setting")
    parser.add_argument("--tag_intent", type=int, default=1, help="whether we can use tag to predict intent")

    parser.add_argument("--BI_tag", type=int, default=1, help='use BI sum or just B')
    parser.add_argument("--cls_token_cat", type=int, default=1, help='whether we cat the cls to the slot output of bert')
    parser.add_argument("--intent_attn", type=int, default=1, help='whether we use attention mechanism on the CLS intent output')
    # Max slot num = 7.
    parser.add_argument("--num_mask", type=int, default=7, help="assumptive number of slot in one sentence")

    parser.add_argument('--seed', type=int, default=25, help="random seed for initialization")
    parser.add_argument("--train_batch_size", default=64, type=int, help="Batch size for training.")
    parser.add_argument("--eval_batch_size", default=128, type=int, help="Batch size for evaluation.")
    parser.add_argument("--max_seq_len", default=32, type=int, help="The maximum total input sequence length after tokenization.")
    parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
    # NOTE: previously defaulted to 10.0 epochs.
    parser.add_argument("--num_train_epochs", default=2.0, type=float, help="Total number of training epochs to perform.")

    parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
    parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
                        help="Number of updates steps to accumulate before performing a backward/update pass.")
    parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
    parser.add_argument("--max_grad_norm", default=1, type=float, help="Max gradient norm.")
    parser.add_argument("--max_steps", default=-1, type=int, help="If > 0: set total number of training steps to perform. Override num_train_epochs.")
    parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
    parser.add_argument("--dropout_rate", default=0.1, type=float, help="Dropout for fully-connected layers")
    parser.add_argument('--logging_steps', type=int, default=500, help="Log every X updates steps.")
    parser.add_argument('--save_steps', type=int, default=300, help="Save checkpoint every X updates steps.")
    parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
    parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the test set.")
    parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
    parser.add_argument("--ignore_index", default=0, type=int,
                        help='Specifies a target value that is ignored and does not contribute to the input gradient')
    parser.add_argument('--slot_loss_coef', type=float, default=2.0, help='Coefficient for the slot loss.')
    parser.add_argument('--pro_loss_coef', type=float, default=10.0, help='Coefficient for the pronoun loss.')
    parser.add_argument('--tag_intent_coef', type=float, default=1.0, help='Coefficient for the tag intent loss')

    # CRF option
    parser.add_argument("--use_crf", action="store_true", help="Whether to use CRF")
    parser.add_argument("--slot_pad_label", default="PAD", type=str, help="Pad token for slot label pad (to be ignore when calculate loss)")
    # The help text used to be a copy of --learning_rate's.
    # NOTE(review): "early stopping" is inferred from the option name; confirm against how
    # Trainer_multi.train uses args.patience.
    parser.add_argument("--patience", default=0, type=int, help="Patience for early stopping.")

    # ipykernel is launched with `-f <connection file>`, which appears in sys.argv.
    # Declaring `-f` lets parse_args() succeed when this cell runs inside a notebook.
    parser.add_argument('-f')

    args = parser.parse_args()

    # Suffix model_dir with a timestamp so repeated runs do not overwrite each other.
    # NOTE(review): ':' in the suffix is not a valid path character on Windows.
    now = datetime.now()
    args.model_dir = args.model_dir + '_' + now.strftime('%m-%d-%H:%M:%S')
    args.model_name_or_path = MODEL_PATH_MAP[args.model_type]

    # Kept inside the guard: `args` only exists when this block runs, so at module level
    # this line raised NameError whenever the code was imported instead of executed.
    tokenizer = load_tokenizer(args)


In [5]:
# Build the three splits in the same order as main() (train, dev, test), then train the
# multi-intent trainer. The cell displays the value returned by trainer.train().
train_dataset, dev_dataset, test_dataset = (
    load_and_cache_examples(args, tokenizer, mode=split)
    for split in ("train", "dev", "test")
)
trainer = Trainer_multi(args, train_dataset, dev_dataset, test_dataset)
trainer.train()

Some weights of the model checkpoint at bert-base-uncased were not used when initializing JointBERTMultiIntent: ['cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing JointBERTMultiIntent from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing JointBERTMultiIntent from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of JointBERTMultiIntent were not initialized from the model checkpoint at bert-base-uncased and are newly ini

preds:  ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-referee', 'O', 'O', 'O', 'O']
labels:  ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-referee', 'O', 'O', 'O', 'O']
Pronoun Accurac:  0.9921414538310412 , correct:  505 , total:  509




Evaluating:   0%|          | 0/69 [00:00<?, ?it/s][A[A

Evaluating:   1%|▏         | 1/69 [00:00<00:06,  9.89it/s][A[A

Evaluating:   3%|▎         | 2/69 [00:00<00:07,  9.30it/s][A[A

Evaluating:   4%|▍         | 3/69 [00:00<00:07,  9.26it/s][A[A

Evaluating:   6%|▌         | 4/69 [00:00<00:07,  9.23it/s][A[A

Evaluating:   7%|▋         | 5/69 [00:00<00:06,  9.18it/s][A[A

Evaluating:   9%|▊         | 6/69 [00:00<00:06,  9.18it/s][A[A

Evaluating:  10%|█         | 7/69 [00:00<00:06,  9.18it/s][A[A

Evaluating:  12%|█▏        | 8/69 [00:00<00:06,  9.15it/s][A[A

Evaluating:  13%|█▎        | 9/69 [00:00<00:06,  9.06it/s][A[A

Evaluating:  14%|█▍        | 10/69 [00:01<00:06,  9.10it/s][A[A

Evaluating:  16%|█▌        | 11/69 [00:01<00:06,  9.10it/s][A[A

Evaluating:  17%|█▋        | 12/69 [00:01<00:06,  9.10it/s][A[A

Evaluating:  19%|█▉        | 13/69 [00:01<00:06,  9.08it/s][A[A

Evaluating:  20%|██        | 14/69 [00:01<00:06,  9.09it/s][A[A

Evaluating:

preds:  ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-referee', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']
labels:  ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-referee', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']
Pronoun Accurac:  0.9975973089860644 , correct:  2076 , total:  2081



Iteration:  50%|█████     | 257/514 [01:06<27:42,  6.47s/it][A
Iteration:  50%|█████     | 258/514 [01:06<19:32,  4.58s/it][A
Iteration:  50%|█████     | 259/514 [01:06<13:50,  3.26s/it][A
Iteration:  51%|█████     | 260/514 [01:07<09:52,  2.33s/it][A
Iteration:  51%|█████     | 261/514 [01:07<07:06,  1.69s/it][A
Iteration:  51%|█████     | 262/514 [01:07<05:10,  1.23s/it][A
Iteration:  51%|█████     | 263/514 [01:07<03:50,  1.09it/s][A
Iteration:  51%|█████▏    | 264/514 [01:07<02:53,  1.44it/s][A
Iteration:  52%|█████▏    | 265/514 [01:07<02:14,  1.85it/s][A
Iteration:  52%|█████▏    | 266/514 [01:08<01:46,  2.32it/s][A
Iteration:  52%|█████▏    | 267/514 [01:08<01:27,  2.83it/s][A
Iteration:  52%|█████▏    | 268/514 [01:08<01:14,  3.32it/s][A
Iteration:  52%|█████▏    | 269/514 [01:08<01:04,  3.79it/s][A
Iteration:  53%|█████▎    | 270/514 [01:08<00:57,  4.21it/s][A
Iteration:  53%|█████▎    | 271/514 [01:09<00:53,  4.56it/s][A
Iteration:  53%|█████▎    | 272/514 [01

referee_token_loss:  tensor(0.0002, device='cuda:0', grad_fn=<NllLossBackward0>)



Iteration:  98%|█████████▊| 502/514 [01:49<00:02,  5.62it/s][A
Iteration:  98%|█████████▊| 503/514 [01:50<00:01,  5.62it/s][A
Iteration:  98%|█████████▊| 504/514 [01:50<00:01,  5.62it/s][A
Iteration:  98%|█████████▊| 505/514 [01:50<00:01,  5.63it/s][A
Iteration:  98%|█████████▊| 506/514 [01:50<00:01,  5.63it/s][A
Iteration:  99%|█████████▊| 507/514 [01:50<00:01,  5.63it/s][A
Iteration:  99%|█████████▉| 508/514 [01:51<00:01,  5.63it/s][A
Iteration:  99%|█████████▉| 509/514 [01:51<00:00,  5.62it/s][A
Iteration:  99%|█████████▉| 510/514 [01:51<00:00,  5.62it/s][A
Iteration:  99%|█████████▉| 511/514 [01:51<00:00,  5.62it/s][A
Iteration: 100%|█████████▉| 512/514 [01:51<00:00,  5.62it/s][A
Iteration: 100%|█████████▉| 513/514 [01:51<00:00,  5.64it/s][A

Evaluating:   0%|          | 0/18 [00:00<?, ?it/s][A[A

Evaluating:   6%|▌         | 1/18 [00:00<00:01,  9.57it/s][A[A

Evaluating:  11%|█         | 2/18 [00:00<00:01,  9.06it/s][A[A

Evaluating:  17%|█▋        | 3/18 [00:00

preds:  ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-referee', 'O', 'O', 'O', 'O']
labels:  ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-referee', 'O', 'O', 'O', 'O']
Pronoun Accurac:  0.9941060903732809 , correct:  506 , total:  509




Evaluating:   0%|          | 0/69 [00:00<?, ?it/s][A[A

Evaluating:   1%|▏         | 1/69 [00:00<00:06,  9.73it/s][A[A

Evaluating:   3%|▎         | 2/69 [00:00<00:07,  9.19it/s][A[A

Evaluating:   4%|▍         | 3/69 [00:00<00:07,  9.13it/s][A[A

Evaluating:   6%|▌         | 4/69 [00:00<00:07,  9.14it/s][A[A

Evaluating:   7%|▋         | 5/69 [00:00<00:07,  9.13it/s][A[A

Evaluating:   9%|▊         | 6/69 [00:00<00:06,  9.12it/s][A[A

Evaluating:  10%|█         | 7/69 [00:00<00:06,  9.13it/s][A[A

Evaluating:  12%|█▏        | 8/69 [00:00<00:06,  9.10it/s][A[A

Evaluating:  13%|█▎        | 9/69 [00:00<00:06,  9.05it/s][A[A

Evaluating:  14%|█▍        | 10/69 [00:01<00:06,  9.09it/s][A[A

Evaluating:  16%|█▌        | 11/69 [00:01<00:06,  9.08it/s][A[A

Evaluating:  17%|█▋        | 12/69 [00:01<00:06,  9.08it/s][A[A

Evaluating:  19%|█▉        | 13/69 [00:01<00:06,  9.07it/s][A[A

Evaluating:  20%|██        | 14/69 [00:01<00:06,  9.07it/s][A[A

Evaluating:

preds:  ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-referee', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']
labels:  ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-referee', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']
Pronoun Accurac:  0.9985583853916387 , correct:  2078 , total:  2081



Iteration: 100%|██████████| 514/514 [02:13<00:00,  3.85it/s][A
Epoch:  50%|█████     | 1/2 [02:13<02:13, 133.48s/it]
Iteration:   0%|          | 0/514 [00:00<?, ?it/s][A
Iteration:   0%|          | 1/514 [00:00<01:30,  5.65it/s][A
Iteration:   0%|          | 2/514 [00:00<01:30,  5.68it/s][A
Iteration:   1%|          | 3/514 [00:00<01:30,  5.66it/s][A
Iteration:   1%|          | 4/514 [00:00<01:30,  5.66it/s][A
Iteration:   1%|          | 5/514 [00:00<01:30,  5.63it/s][A
Iteration:   1%|          | 6/514 [00:01<01:30,  5.62it/s][A
Iteration:   1%|▏         | 7/514 [00:01<01:30,  5.62it/s][A
Iteration:   2%|▏         | 8/514 [00:01<01:30,  5.62it/s][A
Iteration:   2%|▏         | 9/514 [00:01<01:29,  5.63it/s][A
Iteration:   2%|▏         | 10/514 [00:01<01:29,  5.63it/s][A
Iteration:   2%|▏         | 11/514 [00:01<01:29,  5.64it/s][A
Iteration:   2%|▏         | 12/514 [00:02<01:29,  5.63it/s][A
Iteration:   3%|▎         | 13/514 [00:02<01:28,  5.63it/s][A
Iteration:   3%|▎

preds:  ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-referee', 'O', 'O', 'O', 'O']
labels:  ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-referee', 'O', 'O', 'O', 'O']
Pronoun Accurac:  0.9921414538310412 , correct:  505 , total:  509




Evaluating:   0%|          | 0/69 [00:00<?, ?it/s][A[A

Evaluating:   1%|▏         | 1/69 [00:00<00:07,  9.66it/s][A[A

Evaluating:   3%|▎         | 2/69 [00:00<00:07,  9.20it/s][A[A

Evaluating:   4%|▍         | 3/69 [00:00<00:07,  9.16it/s][A[A

Evaluating:   6%|▌         | 4/69 [00:00<00:07,  9.14it/s][A[A

Evaluating:   7%|▋         | 5/69 [00:00<00:07,  9.11it/s][A[A

Evaluating:   9%|▊         | 6/69 [00:00<00:06,  9.06it/s][A[A

Evaluating:  10%|█         | 7/69 [00:00<00:06,  9.04it/s][A[A

Evaluating:  12%|█▏        | 8/69 [00:00<00:06,  9.01it/s][A[A

Evaluating:  13%|█▎        | 9/69 [00:00<00:06,  8.96it/s][A[A

Evaluating:  14%|█▍        | 10/69 [00:01<00:06,  8.99it/s][A[A

Evaluating:  16%|█▌        | 11/69 [00:01<00:06,  9.02it/s][A[A

Evaluating:  17%|█▋        | 12/69 [00:01<00:06,  9.02it/s][A[A

Evaluating:  19%|█▉        | 13/69 [00:01<00:06,  9.04it/s][A[A

Evaluating:  20%|██        | 14/69 [00:01<00:06,  9.05it/s][A[A

Evaluating:

preds:  ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-referee', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']
labels:  ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-referee', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']
Pronoun Accurac:  0.9985583853916387 , correct:  2078 , total:  2081



Iteration:  50%|█████     | 257/514 [01:07<28:27,  6.65s/it][A
Iteration:  50%|█████     | 258/514 [01:07<20:04,  4.70s/it][A
Iteration:  50%|█████     | 259/514 [01:07<14:13,  3.35s/it][A
Iteration:  51%|█████     | 260/514 [01:07<10:08,  2.39s/it][A
Iteration:  51%|█████     | 261/514 [01:08<07:17,  1.73s/it][A
Iteration:  51%|█████     | 262/514 [01:08<05:18,  1.26s/it][A
Iteration:  51%|█████     | 263/514 [01:08<03:55,  1.07it/s][A
Iteration:  51%|█████▏    | 264/514 [01:08<02:57,  1.41it/s][A
Iteration:  52%|█████▏    | 265/514 [01:08<02:17,  1.82it/s][A
Iteration:  52%|█████▏    | 266/514 [01:08<01:48,  2.28it/s][A
Iteration:  52%|█████▏    | 267/514 [01:09<01:28,  2.78it/s][A
Iteration:  52%|█████▏    | 268/514 [01:09<01:15,  3.27it/s][A
Iteration:  52%|█████▏    | 269/514 [01:09<01:05,  3.74it/s][A
Iteration:  53%|█████▎    | 270/514 [01:09<00:58,  4.16it/s][A
Iteration:  53%|█████▎    | 271/514 [01:09<00:53,  4.51it/s][A
Iteration:  53%|█████▎    | 272/514 [01

referee_token_loss:  tensor(2.7099e-05, device='cuda:0', grad_fn=<NllLossBackward0>)



Iteration:  98%|█████████▊| 502/514 [01:50<00:02,  5.57it/s][A
Iteration:  98%|█████████▊| 503/514 [01:51<00:01,  5.58it/s][A
Iteration:  98%|█████████▊| 504/514 [01:51<00:01,  5.58it/s][A
Iteration:  98%|█████████▊| 505/514 [01:51<00:01,  5.59it/s][A
Iteration:  98%|█████████▊| 506/514 [01:51<00:01,  5.59it/s][A
Iteration:  99%|█████████▊| 507/514 [01:51<00:01,  5.58it/s][A
Iteration:  99%|█████████▉| 508/514 [01:52<00:01,  5.58it/s][A
Iteration:  99%|█████████▉| 509/514 [01:52<00:00,  5.59it/s][A
Iteration:  99%|█████████▉| 510/514 [01:52<00:00,  5.61it/s][A
Iteration:  99%|█████████▉| 511/514 [01:52<00:00,  5.62it/s][A
Iteration: 100%|█████████▉| 512/514 [01:52<00:00,  5.62it/s][A
Iteration: 100%|█████████▉| 513/514 [01:52<00:00,  5.60it/s][A

Evaluating:   0%|          | 0/18 [00:00<?, ?it/s][A[A

Evaluating:   6%|▌         | 1/18 [00:00<00:01,  9.49it/s][A[A

Evaluating:  11%|█         | 2/18 [00:00<00:01,  9.13it/s][A[A

Evaluating:  17%|█▋        | 3/18 [00:00

preds:  ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-referee', 'O', 'O', 'O', 'O']
labels:  ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-referee', 'O', 'O', 'O', 'O']
Pronoun Accurac:  0.9941060903732809 , correct:  506 , total:  509




Evaluating:   0%|          | 0/69 [00:00<?, ?it/s][A[A

Evaluating:   1%|▏         | 1/69 [00:00<00:06,  9.82it/s][A[A

Evaluating:   3%|▎         | 2/69 [00:00<00:07,  9.19it/s][A[A

Evaluating:   4%|▍         | 3/69 [00:00<00:07,  9.06it/s][A[A

Evaluating:   6%|▌         | 4/69 [00:00<00:07,  9.05it/s][A[A

Evaluating:   7%|▋         | 5/69 [00:00<00:07,  9.05it/s][A[A

Evaluating:   9%|▊         | 6/69 [00:00<00:06,  9.06it/s][A[A

Evaluating:  10%|█         | 7/69 [00:00<00:06,  9.06it/s][A[A

Evaluating:  12%|█▏        | 8/69 [00:00<00:06,  8.95it/s][A[A

Evaluating:  13%|█▎        | 9/69 [00:00<00:06,  8.87it/s][A[A

Evaluating:  14%|█▍        | 10/69 [00:01<00:06,  8.89it/s][A[A

Evaluating:  16%|█▌        | 11/69 [00:01<00:06,  8.90it/s][A[A

Evaluating:  17%|█▋        | 12/69 [00:01<00:06,  8.93it/s][A[A

Evaluating:  19%|█▉        | 13/69 [00:01<00:06,  8.92it/s][A[A

Evaluating:  20%|██        | 14/69 [00:01<00:06,  8.95it/s][A[A

Evaluating:

preds:  ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-referee', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']
labels:  ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-referee', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']
Pronoun Accurac:  0.9985583853916387 , correct:  2078 , total:  2081



Iteration: 100%|██████████| 514/514 [02:14<00:00,  3.81it/s][A
Epoch: 100%|██████████| 2/2 [04:28<00:00, 134.11s/it]


(1028, 0.3613247651142949)

In [6]:
# Final evaluation on the test split; the cell displays the metrics dict returned by evaluate().
trainer.evaluate("test")

Evaluating: 100%|██████████| 69/69 [00:07<00:00,  9.06it/s]
Evaluating: 69it [00:07,  9.06it/s]


preds:  ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-referee', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']
labels:  ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-referee', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']
Pronoun Accurac:  0.9985583853916387 , correct:  2078 , total:  2081


{'loss': 0.1239446930155374,
 'intent_acc': 1.0,
 'slot_precision': 1.0,
 'slot_recall': 1.0,
 'slot_f1': 1.0,
 'Pro_precision': 0.9985604606525912,
 'Pro_recall': 1.0,
 'Pro_f1': 0.999279711884754,
 'intent_token_precision': 1.0,
 'intent_token_recall': 1.0,
 'intent_token_f1': 1.0,
 'tag_intent_acc': 0.9833333333333333,
 'sementic_frame_acc': 0.9489225857940942,
 'intent_slot_acc': 1.0,
 'Pronoun Accuracy': 0.9985583853916387}