In [16]:
from rouge import Rouge
from rouge.rouge_score import *
import nltk
import os
import ssl

def read_list_asline(path):
    """Read a UTF-8 text file and return its lines as a list.

    Each line has leading and trailing whitespace (including the newline)
    removed; empty lines are kept as ``''``.
    """
    with open(path, 'r', encoding='utf-8') as handle:
        return [raw_line.strip() for raw_line in handle]


def download_nltk(packages=("punkt",)):
    """Download NLTK data packages with TLS certificate checks temporarily off.

    While the download runs, ``ssl``'s default HTTPS context factory is swapped
    for ``ssl._create_unverified_context`` (presumably a workaround for a
    Python install without a usable CA bundle). The previous factory is
    restored afterwards, even if a download fails, so later HTTPS traffic in
    the same kernel is verified again.

    Args:
        packages: A single package id or an iterable of package ids, each
            passed to ``nltk.download``. Defaults to ``("punkt",)``, the only
            package the original version downloaded.
    """
    if isinstance(packages, str):
        packages = (packages,)

    previous_factory = getattr(ssl, "_create_default_https_context", None)
    unverified_factory = getattr(ssl, "_create_unverified_context", None)
    if unverified_factory is not None:
        ssl._create_default_https_context = unverified_factory
    try:
        for package in packages:
            nltk.download(package)
    finally:
        # Undo the global override instead of leaving certificate
        # verification disabled process-wide.
        if previous_factory is not None:
            ssl._create_default_https_context = previous_factory


def _text_to_ngrams(text, n=1):
    """Tokenize ``text`` and wrap its ``n``-grams in an ``Ngrams`` container.

    Tokens come from ``nltk.word_tokenize``; the ``n``-gram tuples from
    ``nltk.ngrams`` are materialized into a list before being handed to
    ``Ngrams`` (provided by the ``rouge.rouge_score`` star import).
    """
    tokens = nltk.word_tokenize(text)
    gram_list = list(nltk.ngrams(tokens, n))
    return Ngrams(gram_list)

def _get_rouge_from_ngram(reference_ngrams: Ngrams, evaluated_ngrams: Ngrams) -> dict:
    """Compute ROUGE-N scores of ``evaluated_ngrams`` against ``reference_ngrams``.

    The n-grams present in both sets are counted, and the three counts
    (evaluated, reference, overlap) are passed to ``f_r_p_rouge_n`` from the
    ``rouge.rouge_score`` star import, whose result is returned as-is.
    """
    shared = evaluated_ngrams.intersection(reference_ngrams)
    return f_r_p_rouge_n(
        len(evaluated_ngrams),
        len(reference_ngrams),
        len(shared),
    )


In [14]:
# Fetch the NLTK "punkt" package via download_nltk (defined in the cell above).
# NOTE(review): the saved output below ends in a NameError because this cell
# (In [14]) ran before the cell that defines download_nltk (In [16]); run the
# cells top to bottom on a fresh kernel.
download_nltk()

[nltk_data] Downloading package punkt to
[nltk_data]     /Users/silas.rudolf/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


NameError: name 'download_nltk' is not defined

In [3]:
# Directory with the stage-1 split. Set the STAGE1_DATA_DIR environment variable
# to use another location instead of editing this machine-specific default.
DATA_DIR = os.environ.get(
    'STAGE1_DATA_DIR',
    '/Users/silas.rudolf/projects/School/MA/experiments/data/stage_1',
)
# One entry per line of each file (see read_list_asline).
data = read_list_asline(os.path.join(DATA_DIR, 'test.source'))
labels = read_list_asline(os.path.join(DATA_DIR, 'test_duplicated.target'))

In [17]:
MAX_LENGTH = 100  # default cap on greedy selection rounds in fast_rouge


# This function is faster than seg_based_on_rouge because it uses the ngrams to compute rouge rather than text.
def fast_rouge(sou, tar, name=None, verbose=False, max_length=None):
    """Greedily select sentences of ``tar`` that maximize ROUGE F against ``sou``.

    ``tar`` is split with ``nltk.sent_tokenize``. In each round, every
    remaining sentence is scored by the F value of (already-selected n-grams
    UNION its n-grams) against the n-grams of ``sou``. The best one is
    accepted if it strictly improves the running best score; otherwise
    selection stops. At most ``max_length`` rounds are run.

    The selected n-gram set is grown by union rather than by re-tokenizing the
    concatenated text. This keeps every round proportional to the number of
    remaining sentences, and guarantees the score used to accept a sentence is
    the same one it was chosen with.

    Args:
        sou: Text the selection is scored against.
        tar: Text whose sentences are the selection candidates.
        name: Identifier shown in the verbose summary line.
        verbose: If true, print a one-line summary of the selection.
        max_length: Maximum number of rounds; ``None`` (default) uses the
            module-level ``MAX_LENGTH`` at call time.

    Returns:
        ``(best_sents, best_string)``:

        * ``best_sents`` lists the accepted ``(sentence, score_dict, index)``
          tuples in selection order.
        * ``best_string`` joins the accepted sentences with spaces, in their
          original order within ``tar``.
    """
    if max_length is None:
        max_length = MAX_LENGTH

    best_score = 0
    best_sents = []
    cur_ngram = Ngrams()  # union of n-grams of all accepted sentences

    # use ngram to represent each text
    sou = _text_to_ngrams(sou)
    # Remaining candidates as (sentence, ngrams, index within tar).
    seg = [(x, _text_to_ngrams(x), i) for i, x in enumerate(nltk.sent_tokenize(tar))]
    tot_len = len(seg)

    for _ in range(min(max_length, tot_len)):
        # NOTE(review): the candidate is passed as `reference_ngrams` and `sou`
        # as `evaluated_ngrams`. Assuming 'f' is the symmetric F1, selection is
        # unaffected, but the reported 'p'/'r' follow this argument order.
        scored = [(x, _get_rouge_from_ngram(cur_ngram.union(seg_ngram), sou), i, seg_ngram)
                  for x, seg_ngram, i in seg]
        sent, score, idx, sent_ngram = max(scored, key=lambda cand: cand[1]['f'])
        if score['f'] <= best_score:
            break
        best_score = score['f']
        best_sents.append((sent, score, idx))
        cur_ngram = cur_ngram.union(sent_ngram)
        seg = [cand for cand in seg if cand[2] != idx]  # each sentence is picked at most once

    if verbose:
        print("id:", name, "input/output:", tot_len, len(best_sents), "best:", best_score)

    # Indices are unique, so sorting by index restores document order.
    best_string = ' '.join(x[0] for x in sorted(best_sents, key=lambda x: x[2]))

    return best_sents, best_string

In [20]:
# Greedy alignment of the first line of test.source with the first line of
# test_duplicated.target; displays the (best_sents, best_string) tuple.
fast_rouge(data[0],labels[0])

([('This meeting was the eleventh evidence session on the Children Abolition of Defense of Reasonable Publishment Wales Bill.',
   {'f': 0.07738095142024519,
    'p': 0.04075235109717868,
    'r': 0.7647058823529411},
   0),
  ('Barry Hughes then further explained our-of-court disposals and responded to a specific infrastructure for these cases.',
   {'f': 0.11931818011896307,
    'p': 0.06583072100313479,
    'r': 0.6363636363636364},
   6),
  ('The first one was how the Bill protected the children in terms of prosecutions.',
   {'f': 0.1495844854784724,
    'p': 0.08463949843260188,
    'r': 0.6428571428571429},
   3),
  ('Barry Hughes was there to answer questions related to the Bill.',
   {'f': 0.1643835594409458,
    'p': 0.09404388714733543,
    'r': 0.6521739130434783},
   1),
  ('In the last part, the meeting turned to discuss a number of unintended consequences related to the Bill.',
   {'f': 0.17647058572664362, 'p': 0.10344827586206896, 'r': 0.6},
   7),
  ("The second part 