In [1]:
import os
import logging
import argparse
from tqdm import tqdm, trange
import numpy as np
import torch
from torch.utils.data import TensorDataset, DataLoader, SequentialSampler
from utils import init_logger, load_tokenizer, get_intent_labels, get_slot_labels, MODEL_CLASSES, MODEL_PATH_MAP
# from utils import init_logger, load_tokenizer, read_prediction_text, set_seed, MODEL_CLASSES, MODEL_PATH_MAP, get_intent_labels, get_slot_labels

from seqeval.metrics.sequence_labeling import get_entities

logger = logging.getLogger(__name__)

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
def get_device(pred_config):
    """Pick the torch device string used for inference.

    Returns "cpu" when the user passed --no_cuda or no GPU is visible to torch,
    otherwise "cuda".
    """
    if pred_config.no_cuda or not torch.cuda.is_available():
        return "cpu"
    return "cuda"

In [5]:
def get_args(pred_config):
    """Load the argument object saved at training time.

    Reads `<pred_config.model_dir>/training_args.bin` via `torch.load` and returns
    whatever object was stored there.

    NOTE(review): `torch.load` unpickles the file, so only point `model_dir` at trusted
    checkpoints. Newer torch releases default to `weights_only=True`, which may refuse a
    pickled argparse.Namespace — confirm against the installed torch version.
    """
    args_path = os.path.join(pred_config.model_dir, 'training_args.bin')
    return torch.load(args_path)

In [16]:
def load_model(pred_config, args, device):
    """Instantiate the trained joint model and put it in eval mode on `device`.

    Args:
        pred_config: prediction config; only `model_dir` is read, to check that the
            checkpoint directory exists.
        args: training-time namespace (see `get_args`). `args.model_type` selects the model
            class from MODEL_CLASSES and `args.model_dir` is passed to `from_pretrained`.
        device: torch device string, e.g. "cpu" or "cuda".

    Returns:
        The loaded model, moved to `device` and switched to eval mode.

    Raises:
        Exception: if `pred_config.model_dir` does not exist, or if loading fails for any
            reason. In the latter case the underlying error is chained as `__cause__`.
    """
    # Check whether model exists
    if not os.path.exists(pred_config.model_dir):
        raise Exception("Model doesn't exists! Train first!")
    try:
        # NOTE(review): the existence check above uses pred_config.model_dir, but the weights
        # are loaded from args.model_dir (the path recorded at training time) — confirm they agree.
        model_class = MODEL_CLASSES[args.model_type][1]
        model = model_class.from_pretrained(args.model_dir,
                                            args=args,
                                            intent_label_lst=get_intent_labels(args),
                                            slot_label_lst=get_slot_labels(args))
        model.to(device)
        model.eval()
        logger.info("***** Model Loaded *****")
    except Exception as exc:
        # Chain the original error so a missing file, a bad config, or a constructor
        # signature mismatch can be told apart. Unlike a bare `except:`, this does not
        # swallow KeyboardInterrupt/SystemExit.
        raise Exception("Some model files might be missing...") from exc
    return model

In [7]:
def read_input_file(pred_config):
    """Read `pred_config.input_file` (UTF-8) as whitespace-tokenized sentences.

    Returns one list of words per line of the file, in file order; blank lines
    produce empty lists.
    """
    with open(pred_config.input_file, "r", encoding="utf-8") as handle:
        return [raw_line.strip().split() for raw_line in handle]

In [8]:
def convert_input_file_to_tensor_dataset(lines,
                                         pred_config,
                                         args,
                                         tokenizer,
                                         pad_token_label_id,
                                         cls_token_segment_id=0,
                                         pad_token_segment_id=0,
                                         sequence_a_segment_id=0,
                                         mask_padding_with_zero=True):
    """Tokenize pre-split sentences into a TensorDataset for prediction.

    Each sentence is word-piece tokenized, truncated to `args.max_seq_len - 2` tokens,
    wrapped as ``[CLS] tokens [SEP]`` and padded up to `args.max_seq_len`.

    Args:
        lines: list of sentences, each a list of words (as returned by `read_input_file`).
        pred_config: unused; kept for call-site compatibility.
        args: training namespace; only `args.max_seq_len` is read.
        tokenizer: must provide `tokenize`, `convert_tokens_to_ids` and the `cls_token`,
            `sep_token`, `unk_token`, `pad_token_id` attributes.
        pad_token_label_id: slot-mask value for positions without a word-level prediction
            ([CLS], [SEP], padding, non-first sub-tokens). The first sub-token of every word
            is marked with `pad_token_label_id + 1`.
        cls_token_segment_id, pad_token_segment_id, sequence_a_segment_id: token-type ids
            for [CLS], padding, and the sentence (including [SEP]) respectively.
        mask_padding_with_zero: if True the attention mask is 1 on real tokens and 0 on
            padding; if False those values are swapped.

    Returns:
        TensorDataset of (input_ids, attention_mask, token_type_ids, slot_label_mask),
        one long-tensor row per sentence.
    """
    # Setting based on the current model type
    cls_token = tokenizer.cls_token
    sep_token = tokenizer.sep_token
    unk_token = tokenizer.unk_token
    pad_token_id = tokenizer.pad_token_id

    all_input_ids = []
    all_attention_mask = []
    all_token_type_ids = []
    all_slot_label_mask = []
    for words in lines:
        tokens = []
        slot_label_mask = []
        for word in words:
            word_tokens = tokenizer.tokenize(word)
            if not word_tokens:
                word_tokens = [unk_token]  # For handling the bad-encoded word
            tokens.extend(word_tokens)
            # Mark the first sub-token of the word with pad_token_label_id + 1 so it can be told
            # apart from the remaining sub-tokens, which get pad_token_label_id.
            slot_label_mask.extend([pad_token_label_id + 1] + [pad_token_label_id] * (len(word_tokens) - 1))

        # Account for [CLS] and [SEP]
        special_tokens_count = 2
        if len(tokens) > args.max_seq_len - special_tokens_count:
            tokens = tokens[: (args.max_seq_len - special_tokens_count)]
            slot_label_mask = slot_label_mask[:(args.max_seq_len - special_tokens_count)]

        # Add [SEP] token
        tokens += [sep_token]
        token_type_ids = [sequence_a_segment_id] * len(tokens)
        slot_label_mask += [pad_token_label_id]

        # Add [CLS] token
        tokens = [cls_token] + tokens
        token_type_ids = [cls_token_segment_id] + token_type_ids
        slot_label_mask = [pad_token_label_id] + slot_label_mask
        input_ids = tokenizer.convert_tokens_to_ids(tokens)

        # With mask_padding_with_zero, real tokens get 1 and padding gets 0 (inverted otherwise).
        attention_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)

        # Pad every row up to max_seq_len.
        padding_length = args.max_seq_len - len(input_ids)
        input_ids = input_ids + ([pad_token_id] * padding_length)
        attention_mask = attention_mask + ([0 if mask_padding_with_zero else 1] * padding_length)
        token_type_ids = token_type_ids + ([pad_token_segment_id] * padding_length)
        slot_label_mask = slot_label_mask + ([pad_token_label_id] * padding_length)

        all_input_ids.append(input_ids)
        all_attention_mask.append(attention_mask)
        all_token_type_ids.append(token_type_ids)
        all_slot_label_mask.append(slot_label_mask)

    # Change to Tensor
    all_input_ids = torch.tensor(all_input_ids, dtype=torch.long)
    all_attention_mask = torch.tensor(all_attention_mask, dtype=torch.long)
    all_token_type_ids = torch.tensor(all_token_type_ids, dtype=torch.long)
    all_slot_label_mask = torch.tensor(all_slot_label_mask, dtype=torch.long)
    dataset = TensorDataset(all_input_ids, all_attention_mask, all_token_type_ids, all_slot_label_mask)
    return dataset

In [10]:
_PRO_CASE_RULE = '---------------------------------------------------------------------\n'


def _to_numpy(tensor):
    """Detach `tensor`, move it to CPU and return it as a numpy array."""
    return tensor.detach().cpu().numpy()


def _concat(accumulated, new_rows):
    """Append `new_rows` to `accumulated` along axis 0; `accumulated` is None before the first batch."""
    return new_rows if accumulated is None else np.append(accumulated, new_rows, axis=0)


def _build_model_inputs(batch, args, B_tag_mask, BI_tag_mask):
    """Map the positional tensors of one dataset batch to the keyword arguments of `model(**inputs)`.

    The index -> field mapping is assumed to match the dataset built by
    `load_and_cache_examples` (TODO confirm against data_loader). `token_type_ids` is
    only passed for non-DistilBERT models.
    """
    inputs = {'input_ids': batch[0],
              'attention_mask': batch[1],
              'intent_label_ids': batch[3],
              'slot_labels_ids': batch[4],
              'intent_token_ids': batch[5],
              'B_tag_mask': B_tag_mask,
              'BI_tag_mask': BI_tag_mask,
              'tag_intent_label': batch[8],
              'referee_labels_ids': batch[9],
              'pro_labels_ids': batch[10]}
    if args.model_type != "distilbert":
        inputs["token_type_ids"] = batch[2]
    return inputs


def _unpack_logits(outputs, args):
    """Name the logits in `outputs[1]` according to which heads are enabled in `args`.

    Returns a dict with keys among "intent", "slot", "intent_token", "tag_intent",
    "referee" and "all_referee". Raises ValueError if the model returned a different
    number of logits than the configuration implies.
    """
    if args.pro and args.intent_seq and args.tag_intent:
        names = ("intent", "slot", "intent_token", "tag_intent", "referee", "all_referee")
    elif args.intent_seq and args.tag_intent:
        names = ("intent", "slot", "intent_token", "tag_intent")
    elif args.intent_seq:
        names = ("intent", "slot", "intent_token")
    elif args.tag_intent:
        names = ("intent", "slot", "tag_intent")
    else:
        names = ("intent", "slot")
    logits = tuple(outputs[1])
    if len(logits) != len(names):
        raise ValueError("Expected {} logits {}, got {}".format(len(names), names, len(logits)))
    return dict(zip(names, logits))


def _decode_sequences(label_ids, pred_ids, label_map, ignore_id):
    """Convert gold/predicted id matrices to lists of label strings, skipping ignored positions.

    Position j of row i is kept when `label_ids[i, j] != ignore_id`; the same positions
    are read from `pred_ids`. Returns `(gold_lists, pred_lists)`, one list per row.
    """
    gold_lists, pred_lists = [], []
    for i in range(label_ids.shape[0]):
        kept = [j for j in range(label_ids.shape[1]) if label_ids[i, j] != ignore_id]
        gold_lists.append([label_map[label_ids[i][j]] for j in kept])
        pred_lists.append([label_map[pred_ids[i][j]] for j in kept])
    return gold_lists, pred_lists


def _build_entity_masks(pred_labels, pos_offset, padding_recording, seq_len, args):
    """Build `args.num_mask` token-level masks, one per predicted entity that starts with a B- tag.

    `pos_offset[k]` maps word index k back to token index `k + pos_offset[k]`;
    `padding_recording[t]` is 1 for token positions without a word-level label.
    With `args.BI_tag` each mask covers the entity's labeled tokens, otherwise only its
    first token. Masks beyond the number of entities are all zeros; extra entities are dropped.
    """
    entities = [ent for ent in get_entities(pred_labels) if pred_labels[ent[1]].startswith('B')]
    entities = entities[:args.num_mask]
    masks = []
    for _, start_word, end_word in entities:
        mask = [0] * seq_len
        start_idx = start_word + pos_offset[start_word]
        end_idx = end_word + pos_offset[end_word] + 1
        if args.BI_tag:
            for idx in range(start_idx, end_idx):
                mask[idx] = 0 if padding_recording[idx] else 1
        else:
            mask[start_idx] = 1
        masks.append(mask)
    masks.extend([0] * seq_len for _ in range(args.num_mask - len(masks)))
    return masks


def _format_prediction(words, intent_labels, slot_labels, referee_labels, pronouns):
    """Render one sentence of predictions as output text (including the trailing newline).

    Words labeled 'O' for both intent and slot are written as-is; other words become
    ``[word:intent:slot]``. If `referee_labels` contains 'B-referee', labeled words found in
    `pronouns` become ``[pronoun:referee_word:intent:slot]`` (referee_word is the word at the
    first 'B-referee' position) and the line is wrapped in a "* Pro Case" banner.
    """
    if 'B-referee' not in referee_labels:
        parts = [word if (s_pred == 'O' and i_pred == 'O') else "[{}:{}:{}]".format(word, i_pred, s_pred)
                 for word, i_pred, s_pred in zip(words, intent_labels, slot_labels)]
        return " ".join(parts) + "\n"

    r_idx = referee_labels.index('B-referee')
    parts = []
    for word, i_pred, s_pred, _ in zip(words, intent_labels, slot_labels, referee_labels):
        if s_pred == 'O' and i_pred == 'O':
            parts.append(word)
        elif word in pronouns and r_idx < len(words):
            parts.append("[{}:{}:{}:{}]".format(word, words[r_idx], i_pred, s_pred))
        else:
            if word in pronouns:
                # The predicted referee lies beyond the input words, so it cannot be named.
                logger.warning("Referee index %d out of range for sentence: %s", r_idx, " ".join(words))
            parts.append("[{}:{}:{}]".format(word, i_pred, s_pred))
    return '\n' + _PRO_CASE_RULE + '* Pro Case: \n' + " ".join(parts) + '\n' + _PRO_CASE_RULE + ' \n'


def predict(pred_config, dataset):
    """Run the trained joint model over `dataset` and write annotated sentences to the output file.

    Up to two passes are made over the data:
      1. a forward pass using the dataset's own tag masks (batch[6], batch[7]) that collects
         intent, slot, intent-token and (if enabled) pronoun-referee predictions;
      2. only if `args.tag_intent` is set, a second forward pass in which the B/BI tag masks are
         rebuilt from the *predicted* slot entities, to collect tag-intent predictions.

    The words written to `pred_config.output_file` come from `read_input_file(pred_config)`,
    so `pred_config.input_file` must hold the same sentences, in the same order, as `dataset`.

    Args:
        pred_config: parsed CLI namespace (model_dir, input_file, output_file, batch_size, no_cuda).
        dataset: a TensorDataset laid out as expected by `_build_model_inputs`.
    """
    # load model and args
    args = get_args(pred_config)
    device = get_device(pred_config)
    model = load_model(pred_config, args, device)
    logger.info(args)
    intent_label_lst = get_intent_labels(args)
    slot_label_lst = get_slot_labels(args)
    pad_token_label_id = args.ignore_index
    lines = read_input_file(pred_config)

    # Referee logits are only unpacked in the pro + intent_seq + tag_intent configuration,
    # so pronoun handling is only possible there.
    use_pro = bool(args.pro and args.intent_seq and args.tag_intent)

    sampler = SequentialSampler(dataset)
    data_loader = DataLoader(dataset, sampler=sampler, batch_size=pred_config.batch_size)

    intent_preds = None
    out_intent_label_ids = None
    slot_preds = None
    out_slot_labels_ids = None
    intent_token_preds = None
    out_intent_token_ids = None
    referee_preds = None
    out_referee_labels_ids = None
    all_referee_preds = None
    all_out_referee_labels_ids = None

    # ---- Pass 1: intent / slot / intent-token / referee predictions with the dataset's tag masks.
    for batch in tqdm(data_loader, desc="Predicting"):
        batch = tuple(t.to(device) for t in batch)
        with torch.no_grad():
            inputs = _build_model_inputs(batch, args, batch[6], batch[7])
            logits = _unpack_logits(model(**inputs), args)

        intent_preds = _concat(intent_preds, _to_numpy(logits["intent"]))
        out_intent_label_ids = _concat(out_intent_label_ids, _to_numpy(inputs['intent_label_ids']))

        if args.use_crf:
            # decode() in `torchcrf` returns list with best index directly
            batch_slot_preds = np.array(model.crf.decode(logits["slot"]))
        else:
            batch_slot_preds = _to_numpy(logits["slot"])
        slot_preds = _concat(slot_preds, batch_slot_preds)
        out_slot_labels_ids = _concat(out_slot_labels_ids, _to_numpy(inputs["slot_labels_ids"]))

        if use_pro:
            all_referee_preds = _concat(all_referee_preds, _to_numpy(logits["all_referee"]))
            referee_preds = _concat(referee_preds, _to_numpy(logits["referee"]))
            batch_referee_ids = _to_numpy(inputs["referee_labels_ids"])
            # Samples whose pro_labels_ids row has a maximum of 1 are treated as pronoun cases.
            is_pro_sample = _to_numpy(torch.max(inputs["pro_labels_ids"], dim=1)[0] == 1)
            all_out_referee_labels_ids = _concat(all_out_referee_labels_ids, batch_referee_ids)
            out_referee_labels_ids = _concat(out_referee_labels_ids, batch_referee_ids[is_pro_sample])

        if args.intent_seq:
            if args.use_crf:
                # NOTE(review): reuses the slot CRF to decode intent tokens — confirm this is intended.
                batch_intent_token_preds = np.array(model.crf.decode(logits["intent_token"]))
            else:
                batch_intent_token_preds = _to_numpy(logits["intent_token"])
            intent_token_preds = _concat(intent_token_preds, batch_intent_token_preds)
            out_intent_token_ids = _concat(out_intent_token_ids, _to_numpy(inputs["intent_token_ids"]))

    if out_slot_labels_ids is None:
        # Empty dataset: nothing to decode, but still leave an (empty) output file behind.
        with open(pred_config.output_file, "w", encoding="utf-8"):
            pass
        logger.info("Prediction Done!")
        return

    # Intent result: one 0/1 flag per intent (multi-label).
    # NOTE(review): thresholds the raw model outputs at 0.5 — correct only if they are already
    # probabilities. The result is not used in the written output.
    intent_preds = torch.as_tensor(intent_preds > 0.5, dtype=torch.int32)

    # ---- Slot decoding, plus entity masks built from the *predicted* slots for pass 2.
    if not args.use_crf:
        slot_preds = np.argmax(slot_preds, axis=2)
    slot_label_map = {i: label for i, label in enumerate(slot_label_lst)}
    num_samples, seq_len = out_slot_labels_ids.shape
    out_slot_label_list = [[] for _ in range(num_samples)]
    slot_preds_list = [[] for _ in range(num_samples)]
    tag_mask_pred = []

    for i in range(num_samples):
        # pos_offset[k] counts ignored positions before the k-th kept word, so word k sits at
        # token index k + pos_offset[k]; padding_recording marks every ignored token index.
        pos_offset = [0] * seq_len
        pos_cnt = 0
        padding_recording = [0] * seq_len
        for j in range(seq_len):
            if out_slot_labels_ids[i, j] != pad_token_label_id:
                out_slot_label_list[i].append(slot_label_map[out_slot_labels_ids[i][j]])
                slot_preds_list[i].append(slot_label_map[slot_preds[i][j]])
                pos_offset[pos_cnt + 1] = pos_offset[pos_cnt]
                pos_cnt += 1
            else:
                pos_offset[pos_cnt] += 1
                padding_recording[j] = 1
        tag_mask_pred.append(
            _build_entity_masks(slot_preds_list[i], pos_offset, padding_recording, seq_len, args))

    # ---- Pass 2 (tag_intent only): rerun the model with B/BI tag masks derived from the
    # predicted slot entities and collect tag-intent predictions.
    tag_intent_preds = None
    out_tag_intent_ids = None
    if args.tag_intent:
        tag_mask_tensor = torch.tensor(tag_mask_pred, dtype=torch.float32)
        batch_size = pred_config.batch_size
        for eval_idx, batch in tqdm(enumerate(data_loader), desc="Predicting", disable=False):
            # The sampler is sequential, so batch eval_idx covers rows [eval_idx*bs, (eval_idx+1)*bs).
            # Move the mask to the same device as the batch tensors.
            mask_slice = tag_mask_tensor[eval_idx * batch_size:(eval_idx + 1) * batch_size].to(device)
            B_tag_mask, BI_tag_mask = (None, mask_slice) if args.BI_tag else (mask_slice, None)

            batch = tuple(t.to(device) for t in batch)
            with torch.no_grad():
                inputs = _build_model_inputs(batch, args, B_tag_mask, BI_tag_mask)
                logits = _unpack_logits(model(**inputs), args)

            size_1 = inputs['tag_intent_label'].size(0)
            size_2 = inputs['tag_intent_label'].size(1)
            tag_intent_preds = _concat(tag_intent_preds,
                                       _to_numpy(logits["tag_intent"].view(size_1, size_2, -1)))
            out_tag_intent_ids = _concat(out_tag_intent_ids, _to_numpy(inputs['tag_intent_label']))

    # ---- Pronoun referee decoding (all referees are single words in EGPSR).
    referee_token_map = {0: 'PAD', 1: 'O', 2: 'B-referee'}
    all_referee_preds_list = [[] for _ in range(num_samples)]
    if use_pro:
        referee_preds = np.argmax(referee_preds, axis=2)
        all_referee_preds = np.argmax(all_referee_preds, axis=2)
        # Gold/pred lists for evaluation; only all_referee_preds_list feeds the written output.
        out_referee_label_list, referee_preds_list = _decode_sequences(
            out_referee_labels_ids, referee_preds, referee_token_map, pad_token_label_id)
        all_out_referee_label_list, all_referee_preds_list = _decode_sequences(
            all_out_referee_labels_ids, all_referee_preds, referee_token_map, pad_token_label_id)

    # ---- Intent-token sequence decoding.
    intent_token_map = {i: label for i, label in enumerate(intent_label_lst)}
    out_intent_token_list = None
    intent_token_preds_list = None
    if args.intent_seq:
        if not args.use_crf:
            intent_token_preds = np.argmax(intent_token_preds, axis=2)
        out_intent_token_list, intent_token_preds_list = _decode_sequences(
            out_intent_token_ids, intent_token_preds, intent_token_map, pad_token_label_id)

    # ---- Tag-intent decoding (computed for evaluation; not part of the written output).
    out_tag_intent_list = None
    tag_intent_preds_list = None
    if args.tag_intent:
        tag_intent_preds = np.argmax(tag_intent_preds, axis=2)
        out_tag_intent_list, tag_intent_preds_list = _decode_sequences(
            out_tag_intent_ids, tag_intent_preds, intent_token_map, pad_token_label_id)

    # Without an intent-token head every word is treated as intent 'O'.
    if intent_token_preds_list is None:
        intent_token_preds_list = [['O'] * len(sample_slots) for sample_slots in slot_preds_list]

    # Write to output file
    pronouns = ['him', 'her', 'it', 'its']
    with open(pred_config.output_file, "w", encoding="utf-8") as f:
        for words, sample_slots, sample_intents, sample_referees in zip(
                lines, slot_preds_list, intent_token_preds_list, all_referee_preds_list):
            f.write(_format_prediction(words, sample_intents, sample_slots, sample_referees, pronouns))

    logger.info("Prediction Done!")

In [17]:
from data_loader import load_and_cache_examples
import argparse

# train_dataset = load_and_cache_examples(args, tokenizer, mode="train")

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_file", default="data/gpsr_pro_instance/test/seq.in", type=str, help="Input file for prediction")

    parser.add_argument("--task", default='gpsr_pro_instance', type=str, help="The name of the task to train")
    parser.add_argument("--model_type", default="multibert", type=str,
                        help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
    parser.add_argument("--data_dir", default="./data", type=str, help="The input data dir")
    parser.add_argument("--intent_label_file", default="intent_label.txt", type=str, help="Intent Label file")
    parser.add_argument("--slot_label_file", default="slot_label.txt", type=str, help="Slot Label file")
    parser.add_argument("--intent_seq", type=int, default=1, help="whether we use intent seq setting")
    parser.add_argument("--pro", type=int, default=1, help="support pronoun disambiguition")
    parser.add_argument("--num_mask", type=int, default=7, help="assumptive number of slot in one sentence")
    parser.add_argument("--ignore_index", default=0, type=int,
                        help='Specifies a target value that is ignored and does not contribute to the input gradient')

    parser.add_argument("--output_file", default="sample_pred_out_pro_instance.txt", type=str, help="Output file for prediction")
    parser.add_argument("--model_dir", default="./gpsr_pro_instance_model_02-10-02:24:43", type=str, help="Path to save, load model")
    parser.add_argument("--max_seq_len", default=32, type=int,
                        help="The maximum total input sequence length after tokenization.")
    parser.add_argument("--batch_size", default=128, type=int, help="Batch size for prediction")
    parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
    # A Jupyter kernel is started with "-f <connection file>" on its command line; declaring -f
    # lets parse_args() succeed inside the notebook. The value is ignored.
    parser.add_argument('-f')

    pred_config = parser.parse_args()
    pred_config.model_name_or_path = MODEL_PATH_MAP[pred_config.model_type]

    # pred_config is passed where the tokenizer/data helpers expect an `args` namespace.
    tokenizer = load_tokenizer(pred_config)
    test_dataset = load_and_cache_examples(pred_config, tokenizer, mode="test")

    # predict() takes the words it writes out from --input_file, so that file must be the
    # seq.in of the same split as test_dataset.
    predict(pred_config, test_dataset)

Predicting:   0%|          | 0/69 [00:00<?, ?it/s]


TypeError: forward() got an unexpected keyword argument 'intent_label_ids'

In [9]:
# Inspect the first test example: decoded input text, then fields 4 and 9 of the example
# (passed to the model as slot_labels_ids and referee_labels_ids); field 10
# (pro_labels_ids) is left as the cell's displayed value.
sample = test_dataset[0]
print(tokenizer.decode(sample[0]), sample[4], sample[9], sep="\n")
sample[10]

[CLS] could you look for mary, retrieve the orange from the cupboard, and set it on the sink [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]
tensor([0, 1, 1, 1, 1, 7, 1, 1, 1, 2, 1, 1, 8, 1, 1, 1, 2, 1, 1, 3, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0])


tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0])

In [10]:
# debug: run the model over the test set, keeping the last batch's `inputs` / `outputs`
# around for inspection in the cells below.
args = get_args(pred_config)
# Respect --no_cuda and CPU-only machines instead of hard-coding 'cuda'.
device = get_device(pred_config)
model = load_model(pred_config, args, device)
sampler = SequentialSampler(test_dataset)
data_loader = DataLoader(test_dataset, sampler=sampler, batch_size=pred_config.batch_size)
for batch in tqdm(data_loader, desc="Predicting"):
    batch = tuple(t.to(device) for t in batch)
    with torch.no_grad():
        inputs = {'input_ids': batch[0],
                  'attention_mask': batch[1],
                  'intent_label_ids': batch[3],
                  'slot_labels_ids': batch[4],
                  'intent_token_ids': batch[5],
                  'B_tag_mask': batch[6],
                  'BI_tag_mask': batch[7],
                  'tag_intent_label': batch[8],
                  'referee_labels_ids': batch[9],
                  'pro_labels_ids': batch[10]}

        if args.model_type != "distilbert":
            inputs["token_type_ids"] = batch[2]
        outputs = model(**inputs)
        if args.pro and args.intent_seq and args.tag_intent:
            tmp_eval_loss, (intent_logits, slot_logits, intent_token_logits, tag_intent_logits,
                            referee_token_logits, all_referee_token_logits) = outputs[:2]

Predicting: 100%|██████████| 69/69 [00:07<00:00,  9.27it/s]


In [11]:
# Per-sample flag: True where any position of pro_labels_ids in the last batch is positive.
inputs["pro_labels_ids"].max(dim=1).values > 0

tensor([False, False,  True, False, False, False, False,  True, False, False,
        False, False, False,  True,  True, False, False, False, False, False,
        False, False,  True, False, False, False,  True, False, False,  True,
        False, False, False,  True, False,  True,  True,  True,  True,  True,
         True, False, False,  True, False,  True, False,  True, False,  True,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False], device='cuda:0')

In [12]:
# Decode row 7 of the last batch back to text (row 7 is flagged True by the pro-sample mask cell).
tokenizer.decode(inputs["input_ids"][7, :])

'[CLS] leave sponge on the sink, grasp the fruits from the dining table, and deliver to it to me please [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]'

In [13]:
# Shapes of the pronoun / referee label tensors in the last batch, the per-sample
# pronoun flag, and the referee labels of row 7.
print(inputs["pro_labels_ids"].size())
print(inputs["referee_labels_ids"].size())
b = inputs["pro_labels_ids"].max(dim=1).values.gt(0)
print(b.size())
print(inputs["referee_labels_ids"][7])

torch.Size([67, 32])
torch.Size([67, 32])
torch.Size([67])
tensor([0, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0], device='cuda:0')


In [14]:
# Row selection with a boolean mask: compress(..., axis=0) keeps the rows where b is True,
# exactly like boolean indexing a[b].
a = np.array([np.arange(1, 5), np.arange(2, 6), np.arange(4, 8)])
b = np.array([1, 0, 1], dtype=bool)
c = a.compress(b, axis=0)
c

array([[1, 2, 3, 4],
       [4, 5, 6, 7]])

In [15]:
# zip() stops at the shortest input, so only three pairs are printed.
a = list(range(1, 4))
b = list(range(1, 6))
for c, d in zip(a, b):
    print(f"{c} {d}")

1 1
2 2
3 3
