In [1]:
import os
import numpy as np
import pandas as pd
import pickle
from rouge import Rouge
import string

In [2]:
from nltk.translate.bleu_score import sentence_bleu
from nltk import word_tokenize

In [3]:
import warnings
warnings.filterwarnings('ignore')

In [4]:
results_paths = [
    'structure_custom_t5_convokit_bs_1_2_lr_2e5_ep_5_w_0.pkl',
    'structure_custom_t5_convokit_bs_1_2_lr_2e5_ep_5_w_100.pkl'
]

dataset_paths = [
    'val_structure_convokit.csv',
    'val_structure_convokit.csv'
]

In [6]:
def calc_accuracy(preds):
    y_true = np.array([p[0] for p in preds])
    y_pred = np.array([p[1] for p in preds])
    return np.mean(y_true == y_pred)

In [7]:
for res_path, data_path in zip(results_paths, dataset_paths):
    print(res_path)
    test_data = pd.read_csv("data/" + data_path, sep='\t')
    X_test = test_data['context'].values
    y_test = test_data['structure'].values
    
    with open('predictions/' + res_path, 'rb') as f:
        _, preds = pickle.load(f)
        
    if '_t5_' in res_path:
        for i in range(len(preds)):
            if len(preds[i][1].split()) > 1:
                preds[i][1] = preds[i][1].split(' ', 1)[1]
            else:
                preds[i][1] = ''
    
    for i in range(len(preds)):
        if preds[i][1].startswith('<unk>') and preds[i][1][4] != ' ':
            preds[i][1] = '<unk> ' + preds[i][1][5:]
            
    print('No errors:', len([p for p in preds if p[1] != 'err']))
    
    relations = []
    cnt_err = 0
    for i in range(len(preds)):
        if preds[i][1] != 'err':
            pred_rel = preds[i][1].split(' ', 1)[0]
            relation = y_test[i].split(' ', 1)[0]
            relations.append([relation, pred_rel])
            
    print('Accuracy:', round(calc_accuracy(relations), 3))
    
    rouge = Rouge()
    hyps, refs = [], []
    for i in range(len(preds)):
        #try:
        if len(preds[i][1].split(' ', 1)) > 1:
            hyps.append(preds[i][1].split(' ', 1)[1])
        else:
            hyps.append('')
            
        if len(y_test[i].split(' ', 1)) > 1:
            refs.append(y_test[i].split(' ', 1)[1])
        else:
            refs.append('')
        #except:
        #    continue
    
    gen_ref = zip(hyps, refs)
    gen_ref = [_ for _ in gen_ref if not all(j in string.punctuation for j in _[1]) and not all(j in string.punctuation for j in _[0])]
    gens, refs  = zip(*gen_ref)
    
    rouge_res = rouge.get_scores(gens, refs, avg=True, ignore_empty=False)
    print()
    print('ROUGE-1:', round(100 * rouge_res['rouge-1']['f'], 2))
    print('ROUGE-2:', round(100 * rouge_res['rouge-2']['f'], 2))
    print('ROUGE-L:', round(100 * rouge_res['rouge-l']['f'], 2))
    
    mean_bleu = 0
    for gen, ref in zip(gens, refs):
        mean_bleu += sentence_bleu([word_tokenize(ref)], word_tokenize(gen), weights=[1,0,0,0])
    mean_bleu /= len(gens)
    print()
    print('BLEU-1:', round(100 * mean_bleu, 2))
    
    mean_bleu = 0
    for gen, ref in zip(gens, refs):
        mean_bleu += sentence_bleu([word_tokenize(ref)], word_tokenize(gen), weights=[1,1,0,0])
    mean_bleu /= len(gens)
    print('BLEU-2:', round(100 * mean_bleu, 2))
    
    print('\n' + '-'*50 + '\n')

structure_custom_t5_convokit_bs_1_2_lr_2e5_ep_5_w_0.pkl
No errors: 15020
Accuracy: 0.113

ROUGE-1: 7.31
ROUGE-2: 0.41
ROUGE-L: 6.62

BLEU-1: 6.36
BLEU-2: 0.27

--------------------------------------------------

structure_custom_t5_convokit_bs_1_2_lr_2e5_ep_5_w_100.pkl
No errors: 15020
Accuracy: 0.466

ROUGE-1: 8.81
ROUGE-2: 0.5
ROUGE-L: 7.87

BLEU-1: 8.02
BLEU-2: 0.24

--------------------------------------------------



In [8]:
results_paths = [
    'structure_custom_bart_convokit_bs_1_2_lr_2e5_ep_5_norelut_v2.pkl',
    'structure_custom_bart_convokit_bs_1_2_lr_2e5_ep_5_norels_v2.pkl'
]

In [19]:
results_paths = [
    'structure_custom_t5_convokit_bs_1_2_lr_2e5_ep_5_norelut.pkl',
    'structure_custom_t5_convokit_bs_1_2_lr_2e5_ep_5_norels.pkl'
]

In [20]:
dataset_paths = [
    'val_structure_convokit_norelut.csv',
    'val_structure_convokit_norels.csv'
]

In [21]:
for res_path, data_path in zip(results_paths, dataset_paths):
    print(res_path)
    test_data = pd.read_csv("data/" + data_path, sep='\t')
    X_test = test_data['context'].values
    y_test = test_data['structure'].values
    
    with open('predictions/' + res_path, 'rb') as f:
        _, preds = pickle.load(f)
        
    if '_t5_' in res_path:
        for i in range(len(preds)):
            if len(preds[i][1].split()) > 1:
                preds[i][1] = preds[i][1].split(' ', 1)[1]
            else:
                preds[i][1] = ''
    
    for i in range(len(preds)):
        if preds[i][1].startswith('<unk>') and preds[i][1][4] != ' ':
            preds[i][1] = '<unk> ' + preds[i][1][5:]
            
    print('No errors:', len([p for p in preds if p[1] != 'err']))
    
    rouge = Rouge()
    hyps, refs = [], []
    for i in range(len(preds)):
        #try:
        hyps.append(preds[i][1])
        refs.append(y_test[i])
        #except:
        #    continue
    
    gen_ref = zip(hyps, refs)
    gen_ref = [_ for _ in gen_ref if not all(j in string.punctuation for j in _[1]) and not all(j in string.punctuation for j in _[0])]
    gens, refs  = zip(*gen_ref)
    
    rouge_res = rouge.get_scores(gens, refs, avg=True, ignore_empty=False)
    print()
    print('ROUGE-1:', round(100 * rouge_res['rouge-1']['f'], 2))
    print('ROUGE-2:', round(100 * rouge_res['rouge-2']['f'], 2))
    print('ROUGE-L:', round(100 * rouge_res['rouge-l']['f'], 2))
    
    mean_bleu = 0
    for gen, ref in zip(gens, refs):
        mean_bleu += sentence_bleu([word_tokenize(ref)], word_tokenize(gen), weights=[1,0,0,0])
    mean_bleu /= len(gens)
    print()
    print('BLEU-1:', round(100 * mean_bleu, 2))
    
    mean_bleu = 0
    for gen, ref in zip(gens, refs):
        mean_bleu += sentence_bleu([word_tokenize(ref)], word_tokenize(gen), weights=[1,1,0,0])
    mean_bleu /= len(gens)
    print('BLEU-2:', round(100 * mean_bleu, 2))
    
    print('\n' + '-'*50 + '\n')

structure_custom_t5_convokit_bs_1_2_lr_2e5_ep_5_norelut.pkl
No errors: 14887

ROUGE-1: 7.06
ROUGE-2: 0.39
ROUGE-L: 6.36

BLEU-1: 6.12
BLEU-2: 0.25

--------------------------------------------------

structure_custom_t5_convokit_bs_1_2_lr_2e5_ep_5_norels.pkl
No errors: 14887

ROUGE-1: 6.93
ROUGE-2: 0.4
ROUGE-L: 6.23

BLEU-1: 5.95
BLEU-2: 0.2

--------------------------------------------------

