In [1]:
from typing import List, Tuple, Dict, Callable, Any

import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # Suppress TF info/warning messages

import sys
sys.path.append(".")

import json

import numpy as np
import pandas as pd

import tensorflow as tf
from pytorch_transformers import RobertaTokenizer

from src.datasets import create_test_dataset_for_prediction
from src.model_qbert import QBERT
from src.model_utils import create_masks

# Task selection: 'intent' or 'type'. The checkpoint directory and the epoch
# to restore are both derived from this choice.
class_type = 'intent'
checkpoints_path = f"models/{class_type}"
restore_epoch = {'intent': 3}.get(class_type, 4)

# Label name -> class index, one table per task. A label's position in its
# list IS its class index, so the order of these lists must not be changed.
# (The index -> label direction used to decode argmax outputs is derived
# from this table further down, as `pred_mapping`.)
mappings = {
    'type': {
        label: index
        for index, label in enumerate([
            'Ask about antecedent',
            'Ask about consequence',
            'Ask for confirmation',
            'Irony',
            'Negative rhetoric',
            'Positive rhetoric',
            'Request information',
            'Suggest a reason',
            'Suggest a solution',
        ])
    },
    'intent': {
        label: index
        for index, label in enumerate([
            'Amplify excitement',
            'Amplify joy',
            'Amplify pride',
            'De-escalate',
            'Express concern',
            'Express interest',
            'Moralize speaker',
            'Motivate',
            'Offer relief',
            'Pass judgement',
            'Support',
            'Sympathize',
        ])
    },
}

# Rebuild the QBERT model and restore a trained checkpoint for inference.
#   - Mostly copied from the original code repository.
#   - NOTE(review): the original comment here said "the parameters don't
#     matter for inference". That presumably refers to the optimizer
#     settings; the architecture values below likely have to match the
#     ones the checkpoint was trained with -- confirm.
num_layers         = 12
d_model            = 768
num_heads          = 12
dff                = d_model * 4
hidden_act         = "gelu"
dropout_rate       = 0.1
layer_norm_eps     = 1e-5
max_position_embed = 514

tokenizer  = RobertaTokenizer.from_pretrained("roberta-base")
vocab_size = tokenizer.vocab_size

# `lab_mapping` is label -> index; `pred_mapping` is the inverse
# (index -> label), used to decode argmax outputs.
lab_mapping  = mappings[class_type]
pred_mapping = {v: k for k, v in lab_mapping.items()}
num_classes  = len(pred_mapping)

adam_beta_1  = 0.9
adam_beta_2  = 0.98
adam_epsilon = 1e-6

qbert = QBERT(num_layers,
              d_model,
              num_heads,
              dff,
              hidden_act,
              dropout_rate,
              layer_norm_eps,
              max_position_embed,
              vocab_size,
              num_classes)
# The optimizer is only created so the Checkpoint object has the same
# (model, optimizer) structure as the one constructed here for restoring.
optimizer = tf.keras.optimizers.legacy.Adam(2e-5,
                                            beta_1=adam_beta_1,
                                            beta_2=adam_beta_2,
                                            epsilon=adam_epsilon)
ckpt = tf.train.Checkpoint(model=qbert, optimizer=optimizer)
ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoints_path, max_to_keep=None)

# Resolve the checkpoint path once, and fail with an actionable message when
# the requested epoch is not on disk (instead of a bare "list index out of
# range").
available_ckpts = ckpt_manager.checkpoints
try:
    restore_path = available_ckpts[restore_epoch - 1]
except IndexError:
    raise IndexError(
        f"Cannot restore epoch {restore_epoch}: found {len(available_ckpts)} "
        f"checkpoint(s) under '{checkpoints_path}'."
    ) from None

# expect_partial(): not every saved value is consumed at inference time, so
# silence the warnings about unused checkpoint entries.
ckpt.restore(restore_path).expect_partial()
print('Checkpoint {} restored!!!'.format(restore_path))

Checkpoint models/intent/ckpt-3 restored!!!


In [2]:
def predict(test_dataset: tf.data.Dataset) -> pd.DataFrame:
    """Run the restored QBERT model over a dataset and collect its predictions.

    Adapted from the inference loop of the original code repository.

    Relies on the notebook-level globals `qbert`, `create_masks`,
    `pred_mapping` and `class_type`.

    Args:
        test_dataset: Iterable of batches that unpack as `(inp, weights, ids)`.

    Returns:
        A DataFrame with one row per example and two columns:
        `id` (the example ids yielded by the dataset) and
        `predicted_{class_type}_label` (the human-readable predicted label).
        The frame is empty if the dataset yields no batches.
    """

    import tqdm

    y_pred = []
    pred_ids = []
    for inp, weights, ids in tqdm.tqdm(test_dataset):
        enc_padding_mask = create_masks(inp)
        # Third positional argument is False for inference
        # (presumably the `training` flag -- confirm against QBERT.call).
        pred_class = qbert(inp, weights, False, enc_padding_mask)
        y_pred += np.argmax(pred_class.numpy(), axis=1).tolist()
        # reshape(-1) instead of squeeze(): squeeze() turns a batch holding a
        # single example into a 0-d array, whose .tolist() is a bare scalar
        # that cannot be concatenated onto a list.
        pred_ids += ids.numpy().reshape(-1).tolist()

    # Decode class indices once, after the loop. This avoids re-mapping the
    # whole accumulated list on every batch and keeps `pred_labels` defined
    # even when the dataset is empty.
    pred_labels = [pred_mapping[pred] for pred in y_pred]

    return pd.DataFrame({'id': pred_ids, f'predicted_{class_type}_label': pred_labels})


def format_dialogue(
    dialogue: List[Dict[str, str]], 
    sep_char: str = "\n",
    key_content: str = "text"
) -> str:
    """Flatten a dialogue into one string, one utterance per separator.

    Args:
        dialogue: Dialogue turns in order; each turn is a dict that holds
            its utterance under `key_content`.
        sep_char: String placed between consecutive utterances. Defaults to
            the newline character, which is the format expected by the EQT
            helper `create_test_dataset_for_prediction()`.
        key_content: Key under which each turn stores its utterance text.

    Returns:
        The utterances joined by `sep_char`, in their original order.

    Note:
        Per the original notes, `create_test_dataset_for_prediction()`
        splits the string on newlines and REVERSES the order, so the last
        speaker's utterance comes first. The EQT authors state this makes
        the model focus on the sentence that is actually being classified.
    """

    utterances = [turn[key_content] for turn in dialogue]
    return sep_char.join(utterances)


def generate_samples_from_dialogue(
    dialogue: List[Dict[str, str]],
    q_predictor: Callable[[str], bool] = None,
    key_content: str = "text",
    key_role: str = "role",
    role_bot: str = "bot",
) -> List[str]:
    """Turn one dialogue into EQT inference samples, one per bot question.

    Every bot turn is split into sentences. For each sentence recognised as
    a question, one sample is emitted: all preceding turns, followed by the
    bot turn cut off right after that question, flattened with
    `format_dialogue()`. A bot turn with two questions therefore yields two
    samples; a bot turn with no question yields none.

    Args:
        dialogue: Dialogue turns in order, each a dict with a role and text.
        q_predictor: Optional predicate deciding whether a sentence is a
            question. When omitted (or falsy), a sentence counts as a
            question if it ends with a question mark.
        key_content: Key under which each turn stores its utterance text.
        key_role: Key under which each turn stores the speaker role.
        role_bot: Role value that marks a bot turn.

    Returns:
        The formatted samples, in order of appearance. The input dialogue
        is not modified.
    """

    from nltk.tokenize import sent_tokenize

    def _ends_with_question_mark(sentence):
        return sentence.endswith("?")

    is_question = q_predictor if q_predictor else _ends_with_question_mark

    samples = []
    for turn_idx, turn in enumerate(dialogue):
        if turn[key_role] != role_bot:
            continue

        # Everything said before this bot turn
        history = dialogue[:turn_idx]
        sentences = sent_tokenize(turn[key_content])

        for end, sentence in enumerate(sentences, start=1):
            if not is_question(sentence):
                continue
            # Copy of the bot turn whose text stops right after this question
            truncated_turn = {**turn, key_content: " ".join(sentences[:end])}
            samples.append(
                format_dialogue(history + [truncated_turn], key_content=key_content)
            )
    return samples

In [3]:
# Load the example dialogues. Each dialogue is a list of turns, and each
#   turn is a dictionary with keys "text" and "role".
#   - The file is read explicitly as UTF-8 so the result does not depend
#     on the platform's default locale encoding.
with open('data/two_dialogues.json', 'r', encoding='utf-8') as f:
    source_dialogues: List[List[Dict[str, str]]] = json.load(f)

# Preview the first dialogue (the file handle is already closed here)
print(json.dumps(source_dialogues[0], indent=4))

[
    {
        "role": "human",
        "text": "Hal, I'm here because I need help in understanding my daughter's relationship with her girlfriend, and how to support her. I really care about my daughter and I want her to be happy, but I'm also worried about how my relatives and friends will perceive her and the relationship. I know it's important to support her and to break away from family norms, but I'm struggling with how to do that. Can you help me?"
    },
    {
        "role": "bot",
        "text": "I'm sorry to hear that this is such a difficult situation. It sounds like you have a lot of conflicting feelings and pressures. Can you tell me what your biggest concerns are in this situation, and how you think I can help?"
    },
    {
        "role": "human",
        "text": "My biggest concern is how to reconcile my traditional beliefs with supporting my daughter's relationship. I want her to be happy, but I'm also afraid of how her relationship might be seen within my communit

In [6]:
# Collect the inference samples of the first two dialogues, in order
ls_samples = [
    *generate_samples_from_dialogue(source_dialogues[0]),
    *generate_samples_from_dialogue(source_dialogues[1]),
]

# The EQT API needs exactly these column names: `id` (taken from the
# running row index) and `utterance_truncated`.
df_samples = (
    pd.DataFrame({'utterance_truncated': ls_samples})
    .reset_index()
    .rename(columns={'index': 'id'})
)

# NOTE(review): 32 and 256 are presumably the batch size and the maximum
# sequence length -- confirm against create_test_dataset_for_prediction().
ds_samples = create_test_dataset_for_prediction(
    tokenizer, df_samples, 32, 256, lab_mapping
)
df_samples

Vocabulary size is 50265.


100%|██████████| 9/9 [00:00<00:00, 711.77it/s]

Created dataset with 9 examples.





Unnamed: 0,id,utterance_truncated
0,0,"Hal, I'm here because I need help in understan..."
1,1,"Hal, I'm here because I need help in understan..."
2,2,"Hal, I'm here because I need help in understan..."
3,3,"Hi Hal, I'm in a really tough spot right now. ..."
4,4,"Hi Hal, I'm in a really tough spot right now. ..."
5,5,"Hi Hal, I'm in a really tough spot right now. ..."
6,6,"Hi Hal, I'm in a really tough spot right now. ..."
7,7,"Hi Hal, I'm in a really tough spot right now. ..."
8,8,"Hi Hal, I'm in a really tough spot right now. ..."


In [7]:
# Run inference, then attach each prediction to its source utterance via
# the shared `id` column (left join keeps every prediction row).
df_predictions = predict(ds_samples).merge(df_samples, on='id', how='left')
df_predictions

100%|██████████| 1/1 [00:00<00:00,  2.25it/s]


Unnamed: 0,id,predicted_intent_label,utterance_truncated
0,0,Offer relief,"Hal, I'm here because I need help in understan..."
1,1,Express interest,"Hal, I'm here because I need help in understan..."
2,2,Offer relief,"Hal, I'm here because I need help in understan..."
3,3,Offer relief,"Hi Hal, I'm in a really tough spot right now. ..."
4,4,Offer relief,"Hi Hal, I'm in a really tough spot right now. ..."
5,5,Offer relief,"Hi Hal, I'm in a really tough spot right now. ..."
6,6,Offer relief,"Hi Hal, I'm in a really tough spot right now. ..."
7,7,Offer relief,"Hi Hal, I'm in a really tough spot right now. ..."
8,8,Express interest,"Hi Hal, I'm in a really tough spot right now. ..."
