In [1]:
from rouge import Rouge
from rouge.rouge_score import *
import nltk
import ssl

def read_list_asline(path):
    """Read a UTF-8 text file and return one list entry per line.

    Every line has its leading and trailing whitespace (including the line
    terminator) stripped; blank lines are kept as empty strings.
    """
    with open(path, 'r', encoding='utf-8') as handle:
        return [raw_line.strip() for raw_line in handle]


def download_nltk(package="punkt"):
    """Download an NLTK data package, by default the ``punkt`` tokenizer models.

    Certificate verification is switched off only for the duration of the
    download (a common workaround for Python installs without a usable CA
    bundle). The process-wide default HTTPS context is restored afterwards,
    even if the download raises, so later HTTPS requests made from this
    kernel are verified again.

    Args:
        package: name of the NLTK resource to fetch. Defaults to "punkt".
            NOTE(review): newer NLTK releases load "punkt_tab" for
            word_tokenize/sent_tokenize — pass that if tokenization raises
            LookupError.
    """
    original_context = ssl._create_default_https_context
    # Some Python builds lack this private helper; then keep verifying.
    unverified_context = getattr(ssl, "_create_unverified_context", None)
    if unverified_context is not None:
        ssl._create_default_https_context = unverified_context
    try:
        nltk.download(package)
    finally:
        # Undo the monkey-patch so certificate checks are not left disabled.
        ssl._create_default_https_context = original_context


def _text_to_ngrams(text, n=1):
    ngrams = list(nltk.ngrams(nltk.word_tokenize(text), n))
    return Ngrams(ngrams)

def _get_rouge_from_ngram(reference_ngrams: Ngrams, evaluated_ngrams: Ngrams)-> dict:
    """Score two n-gram collections with ROUGE-N.

    Args:
        reference_ngrams: n-grams treated as the reference side.
        evaluated_ngrams: n-grams treated as the evaluated side.

    Returns:
        The dict produced by ``f_r_p_rouge_n`` from the evaluated size, the
        reference size and the size of their intersection (callers in this
        notebook read its ``'f'`` key; ``'p'`` and ``'r'`` are also present).
    """
    shared = evaluated_ngrams.intersection(reference_ngrams)
    return f_r_p_rouge_n(len(evaluated_ngrams), len(reference_ngrams), len(shared))

# Fetch the NLTK tokenizer data ("punkt") before any cell below tokenizes text.
download_nltk()

[nltk_data] Downloading package punkt to
[nltk_data]     /Users/silas.rudolf/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [7]:
MAX_LENGTH = 10  # default cap on the number of greedy selection steps in fast_rouge

# Greedy ROUGE oracle. It is faster than seg_based_on_rouge (defined elsewhere)
# because candidates are scored on pre-computed n-gram sets rather than by
# re-tokenizing concatenated text.
def fast_rouge(sou, tar, name=None, verbose=False, max_length=None):
    """Greedily select sentences of ``tar`` whose unigrams best match ``sou``.

    ``tar`` is split with ``nltk.sent_tokenize`` and every sentence is turned
    into unigrams once. At each step the remaining sentence whose addition
    yields the highest ROUGE ``'f'`` score against ``sou`` is chosen. Selection
    stops when that score no longer improves, after ``max_length`` steps, or
    when no sentences remain.

    Args:
        sou: text the selection is scored against.
        tar: text whose sentences are candidates.
        name: identifier included in the verbose log line.
        verbose: if True, print a one-line summary.
        max_length: cap on selection steps; ``None`` uses the module-level
            ``MAX_LENGTH``, read at call time.

    Returns:
        ``(best_sents, best_string)``: ``best_sents`` is a list of
        ``(sentence, score_dict, index)`` tuples in selection order, and
        ``best_string`` joins the chosen sentences in their original order.

    NOTE(review): the candidate selection is passed to
    ``_get_rouge_from_ngram`` as the *reference* and ``sou`` as *evaluated*,
    so ``'p'``/``'r'`` in the score dicts may read swapped relative to the
    usual ROUGE convention; ``'f'`` is presumably symmetric — confirm.
    """
    if max_length is None:
        max_length = MAX_LENGTH

    # Tokenize each text exactly once; the loop only combines n-gram sets.
    sou_ngrams = _text_to_ngrams(sou)
    remaining = [
        (sent, _text_to_ngrams(sent), idx)
        for idx, sent in enumerate(nltk.sent_tokenize(tar))
    ]
    tot_len = len(remaining)

    cur_ngram = Ngrams()
    best_score = 0
    best_sents = []
    for _ in range(min(max_length, tot_len)):
        candidates = [
            (sent, _get_rouge_from_ngram(cur_ngram.union(sent_ngram), sou_ngrams), idx, sent_ngram)
            for sent, sent_ngram, idx in remaining
        ]
        # max() keeps the earliest candidate on ties.
        sent, score, idx, sent_ngram = max(candidates, key=lambda c: c[1]['f'])
        if score['f'] <= best_score:
            break  # the best addition does not improve F, so stop
        best_score = score['f']
        best_sents.append((sent, score, idx))
        # Extend the selection with the same union that was just scored, so the
        # running score always equals the chosen candidate's score.
        cur_ngram = cur_ngram.union(sent_ngram)
        remaining = [r for r in remaining if r[2] != idx]

    if verbose:
        print("id:", name, "input/output:", tot_len, len(best_sents), "best:", best_score)
    # Indices are unique per sentence, so sorting restores document order.
    in_order = sorted(best_sents, key=lambda s: s[2])
    best_string = ' '.join(s[0] for s in in_order)

    return best_sents, best_string

In [33]:
# Demo: `sou` is a clean paragraph; `tar` looks like decoded word-piece output
# (note the `[CLS]` and `##` markers). The cell displays fast_rouge's
# (best_sents, best_string) result.
sou = 'Marco was at a meeting yesterday. The client was excited about the insights they showed him. They lost 21% of conversions in 3 weeks. The average client is losing 41% in the last 25 days. Klara and ZKB have the same problem.'
tar = '[CLS] marco. excited about. conversions in. weeks . 21 % of conversions. % of conversions. of conversions. the client. average client is. in the last 25 days. last 25 days. 25 days. problem . . the. showed him. him . the client. ##k ##b. z ##k ##b. k ##lar ##a. ##lar ##a. z ##k ##b. ##k ##b. same problem . problem .'

fast_rouge(sou=sou, tar=tar)

([('in the last 25 days.',
   {'f': 0.2727272703719008, 'p': 0.15789473684210525, 'r': 1.0},
   9),
  ('21 % of conversions.',
   {'f': 0.4166666633680556, 'p': 0.2631578947368421, 'r': 1.0},
   4),
  ('average client is.',
   {'f': 0.5098039177700885, 'p': 0.34210526315789475, 'r': 1.0},
   8),
  ('excited about.',
   {'f': 0.5660377317906728, 'p': 0.39473684210526316, 'r': 1.0},
   1),
  ('showed him.',
   {'f': 0.6181818139107438, 'p': 0.4473684210526316, 'r': 1.0},
   13),
  ('weeks .', {'f': 0.642857138494898, 'p': 0.47368421052631576, 'r': 1.0}, 3),
  ('problem . . the.', {'f': 0.6666666622222223, 'p': 0.5, 'r': 1.0}, 12)],
 'excited about. weeks . 21 % of conversions. average client is. in the last 25 days. problem . . the. showed him.')