In [2]:
import os
import numpy as np
import pandas as pd
import pickle
from rouge import Rouge
import string

In [3]:
from nltk.translate.bleu_score import sentence_bleu
from nltk import word_tokenize

In [9]:
with open('bart_input/special_tokens_map_reddit_dial.pkl', 'rb') as f:
    special_tokens_dict = pickle.load(f)

In [10]:
def remove_special_tokens(s):
    for t in special_tokens_dict['additional_special_tokens']:
        s = s.replace(t, '')
    return s.strip()

In [11]:
def calc_accuracy(preds):
    y_true = np.array([p[0] for p in preds])
    y_pred = np.array([p[1] for p in preds])
    return np.mean(y_true == y_pred)

In [12]:
test_data_full = pd.read_csv("bart_input/val_reddit_dial_df_multi_extented_filt.csv", sep='\t')

In [28]:
def info_from_fn(x):
    if 'base' in x:
        return 'Base model'
    if 'history-title_' in x:
        return 'GroundHog: history, title'
    if 'history#title#grounding' in x:
        return '+grounding'
    if 'to:response_disco' in x:
        return '+discourse planning'
    if 'to:response_aug' in x:
        return '+sentiment&discourse planning'
    if 'history_aug#' in x:
        return '+sentiment'
    if 'from:history_aug_disco' in x and 'amr' not in x:
        return '+discourse'
    return '+AMR'

In [29]:
results_paths = [
 'base_bart_bs8_4ep_lr3e-05__from:history-title-grounding___to:response_nbeams1.pkl',
 'multiencoder_bart_v2_bs8_4ep_lr3e-05_enclen256__from:history#title-history-title___to:response.pkl',
 'multiencoder_bart_v2_bs8_4ep_lr3e-05_enclen256__from:history#title#grounding-history-title-grounding___to:response.pkl',
 'multiencoder_bart_v2_bs8_4ep_lr3e-05_enclen256__from:history_aug_disco#title#grounding-history_aug_disco-title-grounding-history_discourse___to:response.pkl',
 'multiencoder_bart_v2_bs8_4ep_lr3e-05_enclen256__from:history_aug_disco#title#grounding-history_aug_disco-title-grounding-history_discourse-history_amr-addr_amr___to:response.pkl',
 'multiencoder_bart_v2_bs8_4ep_lr3e-05_enclen256__from:history_aug#title#grounding-history_aug-title-grounding-history_discourse-history_amr-addr_amr___to:response.pkl',
 'multiencoder_bart_v2_bs8_4ep_lr3e-05_enclen256__from:history_aug#title#grounding-history_aug-title-grounding-history_discourse-history_amr-addr_amr___to:response_disco.pkl',
 'multiencoder_bart_v2_bs8_4ep_lr3e-05_enclen256__from:history_aug#title#grounding-history_aug-title-grounding-history_discourse-history_amr-addr_amr___to:response_aug.pkl',]

In [54]:
test_data = test_data_full
for res_path in results_paths:
    if 'base' in res_path or ('enclen256' in res_path and '_v2_' in res_path) and \
            ('10ep' not in res_path and 'multiencoder_bart_v2_bs8_7ep' not in res_path):
        #try:
        print(res_path)
        print(info_from_fn(res_path))
        with open('predictions/' + res_path, 'rb') as f:
            preds = pickle.load(f)
            preds = np.array(preds)
            
        target_col = res_path.split('__')[-1][4:].replace('.pkl', '').replace('_nbeams1', '')
        y_test = test_data[target_col].values
            
        if '_aug' in target_col:
            relations = []
            sentiment = []
            for i in range(len(preds)):
                if preds[i] != '' and len(preds[i].split()) > 2:
                    pred_rel, pred_sentim, _ = preds[i].split(' ', 2)
                    rel, sentim, _ = y_test[i].split(' ', 2)
                    relations.append([rel, pred_rel])
                    sentiment.append([sentim, pred_sentim])

            print('Accuracy discourse:', round(calc_accuracy(relations), 3))
            print('Accuracy sentiment:', round(calc_accuracy(sentiment), 3))

        rouge = Rouge()
        hyps, refs = [], []
        notcalc = 0
        for i in range(len(preds)):
            pred_text = remove_special_tokens(preds[i])
            gt_text = remove_special_tokens(y_test[i])
            #if len(pred_text.split(' ', 2)) > 1 and len(gt_text.split(' ', 2)) > 1:
            hyps.append(pred_text)
            refs.append(gt_text)

            #if len(pred_text.split(' ', 2)) <= 1 and len(gt_text.split(' ', 2)) > 1:
            #    notcalc += 1
        #print('\nShort preds:', notcalc)

        gen_ref = zip(hyps, refs)
        gen_ref = [_ for _ in gen_ref if not all(j in string.punctuation for j in _[1]) and not all(j in string.punctuation for j in _[0])]
        gens, refs  = zip(*gen_ref)

        rouge_res = rouge.get_scores(gens, refs, avg=True, ignore_empty=False)
        print()
        print('ROUGE-1:', round(100 * rouge_res['rouge-1']['f'], 2))
        print('ROUGE-2:', round(100 * rouge_res['rouge-2']['f'], 2))
        print('ROUGE-L:', round(100 * rouge_res['rouge-l']['f'], 2))

        for j in range(1, 5):
            weights=[0,0,0,0]
            for k in range(j):
                weights[k] = 1
            mean_bleu = 0
            for gen, ref in zip(gens, refs):
                mean_bleu += sentence_bleu([word_tokenize(ref)], word_tokenize(gen), weights=weights)
            mean_bleu /= len(gens)
            print(f'BLEU-{j}:', round(100 * mean_bleu, 2)) 

        print('\n' + '-'*50 + '\n')

base_bart_bs8_4ep_lr3e-05__from:history-title-grounding___to:response_nbeams1.pkl
Base model

ROUGE-1: 17.71
ROUGE-2: 3.69
ROUGE-L: 15.95


The hypothesis contains 0 counts of 2-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 3-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 4-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()


BLEU-1: 17.2
BLEU-2: 2.86
BLEU-3: 2.44
BLEU-4: 2.37

--------------------------------------------------

multiencoder_bart_v2_bs8_4ep_lr3e-05_enclen256__from:history#title-history-title___to:response.pkl
GroundHog: history, title

ROUGE-1: 17.79
ROUGE-2: 3.71
ROUGE-L: 16.05
BLEU-1: 16.99
BLEU-2: 2.88
BLEU-3: 2.49
BLEU-4: 2.44

--------------------------------------------------

multiencoder_bart_v2_bs8_4ep_lr3e-05_enclen256__from:history#title#grounding-history-title-grounding___to:response.pkl
+grounding

ROUGE-1: 17.86
ROUGE-2: 3.85
ROUGE-L: 16.08
BLEU-1: 17.32
BLEU-2: 3.0
BLEU-3: 2.6
BLEU-4: 2.54

--------------------------------------------------

multiencoder_bart_v2_bs8_4ep_lr3e-05_enclen256__from:history_aug_disco#title#grounding-history_aug_disco-title-grounding-history_discourse___to:response.pkl
+discourse

ROUGE-1: 17.88
ROUGE-2: 3.87
ROUGE-L: 16.09
BLEU-1: 17.15
BLEU-2: 3.04
BLEU-3: 2.63
BLEU-4: 2.57

--------------------------------------------------

multiencoder_bart_v2_

In [None]:
short_size = int(15716 / 0.25 * 0.1) # size for real proportions
test_data_short = pd.read_csv("bart_input/val_reddit_dial_df_multi_extented_short.csv", sep='\t').head(short_size)

In [None]:
test_data = test_data_full
for res_path in results_paths:
    if 'base' in res_path or ('enclen256' in res_path and '_v2_' in res_path) and \
            ('10ep' not in res_path and 'multiencoder_bart_v2_bs8_7ep' not in res_path):
        #try:
        print(res_path)
        print(info_from_fn(res_path))
        with open('predictions/' + res_path, 'rb') as f:
            preds = pickle.load(f)
            preds = np.array(preds)
            
        with open('predictions/' + res_path.replace('.pkl', 'short.pkl').replace('nbeams1short.pkl', 'short.pkl'), 'rb') as f:
            preds_short = pickle.load(f)
            preds = list(preds) + list(preds_short[:short_size])
            
        target_col = res_path.split('__')[-1][4:].replace('.pkl', '').replace('_nbeams1', '')
        y_test = list(test_data_full['response'].values) + list(test_data_short['response'].values)
            
        if '_aug' in target_col:
            relations = []
            sentiment = []
            for i in range(len(preds)):
                if preds[i] != '' and len(preds[i].split()) > 2:
                    pred_rel, pred_sentim, _ = preds[i].split(' ', 2)
                    rel, sentim, _ = y_test[i].split(' ', 2)
                    relations.append([rel, pred_rel])
                    sentiment.append([sentim, pred_sentim])

            print('Accuracy discourse:', round(calc_accuracy(relations), 3))
            print('Accuracy sentiment:', round(calc_accuracy(sentiment), 3))

        rouge = Rouge()
        hyps, refs = [], []
        notcalc = 0
        for i in range(len(preds)):
            pred_text = remove_special_tokens(preds[i])
            gt_text = remove_special_tokens(y_test[i])
            #if len(pred_text.split(' ', 2)) > 1 and len(gt_text.split(' ', 2)) > 1:
            hyps.append(pred_text)
            refs.append(gt_text)

            #if len(pred_text.split(' ', 2)) <= 1 and len(gt_text.split(' ', 2)) > 1:
            #    notcalc += 1
        #print('\nShort preds:', notcalc)

        gen_ref = zip(hyps, refs)
        gen_ref = [_ for _ in gen_ref if not all(j in string.punctuation for j in _[1]) and not all(j in string.punctuation for j in _[0])]
        gens, refs  = zip(*gen_ref)

        rouge_res = rouge.get_scores(gens, refs, avg=True, ignore_empty=False)
        print()
        print('ROUGE-1:', round(100 * rouge_res['rouge-1']['f'], 2))
        print('ROUGE-2:', round(100 * rouge_res['rouge-2']['f'], 2))
        print('ROUGE-L:', round(100 * rouge_res['rouge-l']['f'], 2))

        for j in range(1, 5):
            weights=[0,0,0,0]
            for k in range(j):
                weights[k] = 1
            mean_bleu = 0
            for gen, ref in zip(gens, refs):
                mean_bleu += sentence_bleu([word_tokenize(ref)], word_tokenize(gen), weights=weights)
            mean_bleu /= len(gens)
            print(f'BLEU-{j}:', round(100 * mean_bleu, 2)) 

        print('\n' + '-'*50 + '\n')