In [47]:
import editdistance
import json
import numpy as np
import pandas as pd
from pytorch_pretrained_bert.tokenization import BertTokenizer
from sklearn.metrics import confusion_matrix
from spacy.tokenizer import Tokenizer
from spacy.lang.en import English
from spacy.lang.en.stop_words import STOP_WORDS
import revtok
import string
import tempfile
from tqdm import tqdm

from evaluator import evaluate

In [48]:
test_file = 'sharc1-official/json/sharc_dev.json'
# test_file = '../adverse-datasets/sharc_dev_shuffled.json'
# test_file = '../adverse-datasets/sharc_dev_regular.json'
# test_file = '../adverse-datasets/sharc_dev_augmented.json'


with open(test_file) as file:
    dev_dataset = json.load(file)

In [49]:
BERT_VOCAB = 'bert-base-uncased-vocab.txt'
LOWERCASE = True
tokenizer = BertTokenizer.from_pretrained(BERT_VOCAB, do_lower_case=LOWERCASE, cache_dir=None)
MATCH_IGNORE = {'do', 'have', '?'}
SPAN_IGNORE = set(string.punctuation)

def tokenize(doc):
    if not doc.strip():
        return []
    tokens = []
    for i, t in enumerate(revtok.tokenize(doc)):
        subtokens = tokenizer.tokenize(t.strip())
        for st in subtokens:
            tokens.append({
                'orig': t,
                'sub': st,
                'orig_id': i,
            })
    return tokens

def filter_answer(answer):
    return detokenize([a for a in answer if a['orig'] not in MATCH_IGNORE])

def filter_chunk(answer):
    return detokenize([a for a in answer if a['orig'] not in MATCH_IGNORE])

def detokenize(tokens):
    words = []
    for i, t in enumerate(tokens):
        if t['orig_id'] is None or (i and t['orig_id'] == tokens[i-1]['orig_id']):
            continue
        else:
            words.append(t['orig'])
    return revtok.detokenize(words)

def get_span(context, answer):
    answer = filter_answer(answer)
    best, best_score = None, float('inf')
    stop = False
    for i in range(len(context)):
        if stop:
            break
        for j in range(i, len(context)):
            chunk = filter_chunk(context[i:j+1])
            if '\n' in chunk or '*' in chunk:
                continue
            score = editdistance.eval(answer, chunk)
            if score < best_score or (score == best_score and j-i < best[1]-best[0]):
                best, best_score = (i, j), score
            if chunk == answer:
                stop = True
                break
    s, e = best
    while not context[s]['orig'].strip() or context[s]['orig'] in SPAN_IGNORE:
        s += 1
    while not context[e]['orig'].strip() or context[s]['orig'] in SPAN_IGNORE:
        e -= 1
    return s, e

def get_bullets(context):
    indices = [i for i, c in enumerate(context) if c == '*']
    pairs = list(zip(indices, indices[1:] + [len(context)]))
    cleaned = []
    for s, e in pairs:
        while not context[e-1].strip():
            e -= 1
        while not context[s].strip() or context[s] == '*':
            s += 1
        if e - s > 2 and e - 2 < 45:
            cleaned.append((s, e-1))
    return cleaned

def extract_clauses(data, tokenizer):
    snippet = data['snippet']
    t_snippet = tokenize(snippet)
    questions = data['questions']
    t_questions = [tokenize(q) for q in questions]

    spans = [get_span(t_snippet, q) for q in t_questions]
    bullets = get_bullets(t_snippet)
    all_spans = spans + bullets
    coverage = [False] * len(t_snippet)
    sorted_by_len = sorted(all_spans,  key=lambda tup: tup[1] - tup[0], reverse=True)

    ok = []
    for s, e in sorted_by_len:
        if not all(coverage[s:e+1]):
            for i in range(s, e+1):
                coverage[i] = True
            ok.append((s, e))
    ok.sort(key=lambda tup: tup[0])

    match = {}
    match_text = {}
    clauses = [None] * len(ok)
    for q, tq in zip(questions, t_questions):
        best_score = float('inf')
        best = None
        for i, (s, e) in enumerate(ok):
            score = editdistance.eval(detokenize(tq), detokenize(t_snippet[s:e+1]))
            if score < best_score:
                best_score, best = score, i
                clauses[i] = tq
        match[q] = best
        s, e = ok[best]
        match_text[q] = detokenize(t_snippet[s:e+1])

    return {'questions': {q: tq for q, tq in zip(questions, t_questions)}, 'snippet': snippet, 't_snippet': t_snippet, 'spans': ok, 'match': match, 'match_text': match_text, 'clauses': clauses}

with open(test_file) as f:
    data = json.load(f)
    tasks = {}
    for ex in data:
        for h in ex['evidence']:
            if 'followup_question' in h:
                h['follow_up_question'] = h['followup_question']
                h['follow_up_answer'] = h['followup_answer']
                del h['followup_question']
                del h['followup_answer']
        if ex['tree_id'] in tasks:
            task = tasks[ex['tree_id']]
        else:
            task = tasks[ex['tree_id']] = {'snippet': ex['snippet'], 'questions': set()}
        for h in ex['history'] + ex['evidence']:
            task['questions'].add(h['follow_up_question'])
    keys = sorted(list(tasks.keys()))
    vals = [extract_clauses(tasks[k], tokenizer) for k in tqdm(keys)]
    trees_dev = {k: v for k, v in zip(keys, vals)}

100%|██████████| 69/69 [00:10<00:00,  6.67it/s]


In [50]:
nlp = English()
nlp.add_pipe(nlp.create_pipe('sentencizer')) # updated
tokenizer = nlp.Defaults.create_tokenizer(nlp)

# Classification

## Miscelleaneous functions

In [51]:
def tokenize(text):
    return [token.text for token in tokenizer(text)]

def prettify_utterance(utterance, predicted_answer=None):
    output = 'RULE TEXT: ' + utterance['snippet'] + '\n'
    output += 'SCENARIO: ' + utterance['scenario'] + '\n'          
    output += 'QUESTION: ' + utterance['question'] + '\n'
    output += 'HISTORY: ' + history_to_string(utterance['history']) + '\n'
    output += 'GOLD ANSWER: ' + utterance['answer'] + '\n'
    if predicted_answer:
        output += 'PREDICTED ANSWER: ' + str(predicted_answer)
    return output

def history_to_string(history):
    output = ''
    first_qa = True
    for qa in history:
        if not first_qa:
            output += '\n'
        output += 'Q: ' + qa['follow_up_question'] + '\n'
        output += 'A: ' + qa['follow_up_answer']
        first_qa = False
    return output

def get_action(answer):
    return answer if answer in ['Yes', 'No', 'Irrelevant'] else 'More'

def evaluate_model(model_fn, dataset, print_confusion_matrix_turn_wise=False):
    prediction_json = []
    gold_json = []
    
    for utterance in dataset:
        prediction_json.append({'utterance_id': utterance['utterance_id'], 'answer': model_fn(utterance)})
        gold_json.append({'utterance_id': utterance['utterance_id'], 'answer': utterance['answer']})
        

    if print_confusion_matrix_turn_wise:
        for turn_number in range(1, 6):
            predicted_actions = [get_action(x['answer']) for x in prediction_json 
                                 if len(id_map[x['utterance_id']]['history']) == turn_number - 1]
            gold_actions = [get_action(x['answer']) for x in gold_json
                            if len(id_map[x['utterance_id']]['history']) == turn_number - 1]

            print(f"Turn number: {turn_number}")
            print(confusion_matrix(gold_actions, predicted_actions, labels=['Irrelevant', 'More', 'No', 'Yes']))
            print('\n\n')
            
    
        
    gold_file = tempfile.NamedTemporaryFile('w')
    json.dump(gold_json, gold_file)
    gold_file.seek(0)

    prediction_file = tempfile.NamedTemporaryFile('w')
    json.dump(prediction_json, prediction_file)
    prediction_file.seek(0)
    
    return evaluate(gold_file.name, prediction_file.name, mode='combined')

## Analysis

### Turn wise class distribution

In [52]:
for turn_number in range(1, 6):
    actions = [get_action(utterance['answer']) for utterance in dev_dataset if len(utterance['history']) == turn_number - 1]
    print(f"Turn number: {turn_number}")
    print(pd.Series(actions).value_counts())
    print("\n\n")

Turn number: 1
More          319
Yes           156
No            148
Irrelevant    138
dtype: int64



Turn number: 2
Yes     302
No      290
More    165
dtype: int64



Turn number: 3
Yes     209
No      200
More     60
dtype: int64



Turn number: 4
Yes     104
No       87
More     18
dtype: int64



Turn number: 5
No     41
Yes    33
dtype: int64





### Last follow up answer

In [53]:
for turn_number in range(2, 6):
    truth = [utterance['answer'] == utterance['history'][-1]['follow_up_answer'] for utterance in dev_dataset 
             if utterance['answer'] in ['Yes', 'No'] and len(utterance['history']) == turn_number - 1]
    print(f"Turn number: {turn_number}")
    print(pd.Series(truth).value_counts())
    print("\n\n")

Turn number: 2
True     422
False    170
dtype: int64



Turn number: 3
True     317
False     92
dtype: int64



Turn number: 4
True     170
False     21
dtype: int64



Turn number: 5
True    74
dtype: int64





### Scenario and History empty

In [54]:
truth = [utterance['answer'] == 'Irrelevant' for utterance in dev_dataset 
         if utterance['scenario'] == '' and utterance['history'] == []]
print(pd.Series(truth).value_counts())

True     138
False     69
dtype: int64


## Distribution model

In [55]:
def distribution_model(utterance):
    rule = utterance['snippet']
    history = utterance['history']
    scenario = utterance['scenario']
    question = utterance['question']
    
    
    turn_number = len(history) + 1
    
    if turn_number == 1:
        if history == [] and scenario == '':
            answer = 'Irrelevant'
        else:
            answer = rule
    else:
        answer = history[-1]['follow_up_answer']
        
    return answer

In [56]:
evaluate_model(distribution_model, dev_dataset)

{'micro_accuracy': 0.604,
 'macro_accuracy': 0.6744,
 'bleu_1': 0.1406,
 'bleu_2': 0.1171,
 'bleu_3': 0.1017,
 'bleu_4': 0.0892}

## Smart model

In [57]:
def relevant_query(text, query, threshold=0.5):
    query_tokens = tokenize(query.lower())
    text_tokens = set(tokenize(text.lower()))
    
    relevant_tokens = 0
    total_tokens = 0
    
    for token in query_tokens:
        if token in STOP_WORDS or token in string.punctuation:
            continue
        elif token in text_tokens:
            relevant_tokens += 1
        total_tokens += 1
    
    return (relevant_tokens / total_tokens) >= threshold

In [58]:
def number_rules(rule):
    if '*' in rule: # bullet points
        return rule.count('*')
    else:
        return 1

In [59]:
def next_follow_up(utterance):
    previous_questions = set([x['follow_up_question'] for x in utterance['history']])
    
    tree = trees_dev[utterance['tree_id']]
    dic = {}
    for k, v in tree['match'].items():
        if v in dic:
            dic[v].add(k)
        else:
            dic[v] = {k}
    match = {tuple(v): tree['match_text'][list(v)[0]] for k, v in sorted(dic.items())}
    
    for questions_set, clause in match.items():
        if not any(question in previous_questions for question in questions_set):
            return 'Are you ' + clause + '?'
    return utterance['snippet']

In [60]:
def model(utterance):
    rule = utterance['snippet']
    history = utterance['history']
    scenario = utterance['scenario']
    question = utterance['question']
    turn_number = len(history) + 1
    
    if turn_number == 1:
        if not scenario and not relevant_query(rule, question):
            return 'Irrelevant'
        else:
            return next_follow_up(utterance)
    elif turn_number == 2:
        if (not scenario and number_rules(rule) >= turn_number) or (scenario and number_rules(rule) - 1 >= turn_number):
            return next_follow_up(utterance)
        else:
            return history[-1]['follow_up_answer']
    else:
        return history[-1]['follow_up_answer']

In [61]:
evaluate_model(model, dev_dataset, print_confusion_matrix_turn_wise=False)

{'micro_accuracy': 0.6374,
 'macro_accuracy': 0.7125,
 'bleu_1': 0.6397,
 'bleu_2': 0.5624,
 'bleu_3': 0.5117,
 'bleu_4': 0.4778}