# evaluate_with_ROUGE.py
from pyrouge import Rouge155
from pyrouge.utils import log
import re
import os
from os.path import join
import logging
import tempfile
import subprocess as sp
from cytoolz import curry
import codecs
import shutil
# Directory of the ROUGE-1.5.5 distribution. eval_rouge() runs the
# `ROUGE-1.5.5.pl` script found here and passes the `data` sub-directory to it
# via `-e`.
# NOTE(review): absolute, machine-specific path -- must be edited when the
# script is run on another machine.
_ROUGE_PATH = '/mnt/e/Work/Ahmed/Summarization/pyrouge/tools/ROUGE-1.5.5'
def read_summaries(file_path, encoding='utf8'):
    """Read a text file and return its lines with surrounding whitespace removed.

    Args:
        file_path: Path of the text file to read.
        encoding: Text encoding used to decode the file. Defaults to UTF-8,
            the encoding ``write_in_files`` uses when writing.

    Returns:
        A list with one string per line of the file, in file order. Line
        terminators and leading/trailing whitespace are stripped; blank lines
        are kept as empty strings.
    """
    # The context manager guarantees the handle is closed even if decoding
    # fails part-way through the file.
    with open(file_path, 'r', encoding=encoding) as data_reader:
        # Text mode already normalises "\r\n" and "\r" to "\n", so strip()
        # alone removes every line terminator.
        return [line.strip() for line in data_reader]
def write_in_files(output_dir, data, extension):
    """Write each entry of ``data`` to numbered UTF-8 files in ``output_dir``.

    Entries are numbered from 1 in iteration order. An entry holding one text
    is written to ``<n>.<extension>``; an entry holding two texts is written
    to ``<n>_1.<extension>`` and ``<n>_2.<extension>``. Entries of any other
    length are skipped without writing anything.

    Args:
        output_dir: Existing directory that receives the files.
        data: Iterable of sequences, each holding the text(s) for one entry.
        extension: File extension to use, without the leading dot.
    """
    for number, summary in enumerate(data, start=1):
        if len(summary) == 1:
            file_names = ['{}.{}'.format(number, extension)]
        elif len(summary) == 2:
            file_names = ['{}_{}.{}'.format(number, part, extension)
                          for part in (1, 2)]
        else:
            # Only one- and two-text entries have a defined naming scheme.
            continue
        for file_name, text in zip(file_names, summary):
            # `with` closes the handle even when write() raises.
            with codecs.open(join(output_dir, file_name), 'w',
                             encoding='utf8') as data_writer:
                data_writer.write(text)
def eval_rouge(dec_pattern, dec_dir, ref_pattern, ref_dir, dir_name,
               cmd='-c 95 -r 1000 -n 2 -m', system_id=1):
    """Evaluate summaries with the original Perl ROUGE-1.5.5 implementation.

    The summaries in ``dec_dir`` and ``ref_dir`` are passed through
    ``Rouge155.convert_summaries_to_rouge_format`` into ``dec`` and ``ref``
    sub-directories of a scratch directory, a ``settings.xml`` is written
    there with ``Rouge155.write_config_static``, and the Perl script is run
    against that settings file.

    Args:
        dec_pattern: Filename pattern for the system summaries, forwarded to
            ``Rouge155.write_config_static``.
        dec_dir: Directory holding the system (decoded) summaries.
        ref_pattern: Filename pattern for the reference summaries, forwarded
            to ``Rouge155.write_config_static``.
        ref_dir: Directory holding the reference summaries.
        dir_name: Name of the run directory; the scratch directory is
            ``<output root>/<dir_name>/temp/`` and is wiped on every call.
        cmd: Extra command-line options for ``ROUGE-1.5.5.pl``, as a single
            string. It is tokenised with ``shlex.split``, so shell features
            such as pipes or redirections are not interpreted.
        system_id: System id forwarded to ``Rouge155.write_config_static``.

    Returns:
        The text printed by the ROUGE script on standard output.

    Raises:
        subprocess.CalledProcessError: If the ROUGE command exits with a
            non-zero status.
    """
    import shlex  # local import: not among this file's module-level imports

    assert _ROUGE_PATH is not None, 'ROUGE path is not configured'
    # silence pyrouge logging
    log.get_global_console_logger().setLevel(logging.WARNING)

    # A fixed scratch directory is used instead of tempfile.TemporaryDirectory,
    # so the converted files remain on disk after the call.
    tmp_dir = '/mnt/e/Work/Ahmed/Summarization/SummRuNNer/output/{}/temp/'.format(dir_name)
    if os.path.exists(tmp_dir):
        shutil.rmtree(tmp_dir)
    # makedirs (rather than mkdir) also creates missing parent directories.
    os.makedirs(tmp_dir)

    converted_dec = join(tmp_dir, 'dec')
    converted_ref = join(tmp_dir, 'ref')
    settings_path = join(tmp_dir, 'settings.xml')
    Rouge155.convert_summaries_to_rouge_format(dec_dir, converted_dec)
    Rouge155.convert_summaries_to_rouge_format(ref_dir, converted_ref)
    Rouge155.write_config_static(
        converted_dec, dec_pattern,
        converted_ref, ref_pattern,
        settings_path, system_id
    )

    # Build an argument list and run without a shell, so paths that contain
    # spaces or shell metacharacters are passed through unmodified.
    # NOTE(review): `sudo` is kept from the original command line; confirm
    # that elevated privileges are really required to run the Perl script.
    argv = ['sudo', 'perl', join(_ROUGE_PATH, 'ROUGE-1.5.5.pl'),
            '-e', join(_ROUGE_PATH, 'data')]
    argv.extend(shlex.split(cmd))
    argv.extend(['-a', settings_path])
    output = sp.check_output(argv, universal_newlines=True)
    return output
def _sum_average_f(rouge_output):
    """Return the sum of all ``Average_F`` values found in ROUGE output text."""
    total = 0
    for line in rouge_output.split('\n'):
        # Test for the exact marker that is split on, so a line mentioning
        # "Average_F" without the colon cannot raise IndexError.
        if 'Average_F:' in line:
            total += float(line.split('Average_F:')[1].split('(')[0].strip())
    return total


def main(dir_name='30_75_forum_keywords2_bert_coatt_keywords', num_steps=50):
    """Run ROUGE on every available test step of a run and print the best one.

    For each step ``i`` in ``range(num_steps)`` whose ``test_<i>/dec/``
    directory exists, the decoded summaries are scored against
    ``test_<i>/ref_abs/`` with ``eval_rouge`` and the ROUGE output is printed.
    The step whose ``Average_F`` values sum highest is reported at the end.

    Args:
        dir_name: Name of the run directory under the output root.
        num_steps: Number of step indices (starting at 0) to probe.
    """
    output_root = '/mnt/e/Work/Ahmed/Summarization/SummRuNNer/output/{}'.format(dir_name)
    dec_pattern = r'(\d+).dec'
    ref_pattern = '#ID#.ref'

    max_rouge_index = 0
    max_rouge = 0
    max_output = ''
    print('Evaluatiig {} ....'.format(dir_name))
    for i in range(num_steps):
        dec_dir = '{}/test_{}/dec/'.format(output_root, i)
        ref_dir = '{}/test_{}/ref_abs/'.format(output_root, i)
        # Steps that produced no decoded output are skipped.
        if not os.path.exists(dec_dir):
            continue
        print('test_{}'.format(i))
        output = eval_rouge(dec_pattern, dec_dir, ref_pattern, ref_dir, dir_name)
        print(output)
        current_rouge = _sum_average_f(output)
        # Strict comparison: on a tie the earliest step stays the best.
        if current_rouge > max_rouge:
            max_output = output
            max_rouge = current_rouge
            max_rouge_index = i
    print('Best Model...... step {}'.format(max_rouge_index))
    print(max_output)
# Run the evaluation only when executed as a script, so importing this module
# (e.g. to reuse eval_rouge) does not start a full evaluation.
if __name__ == '__main__':
    main()