In [1]:
from collections import Counter

import numpy as np

In [2]:
def get_ngrams(n: int, sentence: list[str]) -> list[str]:
    """Return every contiguous window of ``n`` tokens as a space-joined string.

    Args:
        n: Number of tokens per n-gram.
        sentence: Tokenized sentence.

    Returns:
        The n-grams in left-to-right order, duplicates included. The list is
        empty when the sentence holds fewer than ``n`` tokens.
    """
    window_starts = range(len(sentence) - n + 1)
    return [' '.join(sentence[start:start + n]) for start in window_starts]

def closest_r(rs: list[list[str]], c: list[str]) -> int:
    """Return the index of the reference whose length is closest to the candidate's.

    When several references are equally close in length, the shortest of them
    is chosen (the usual BLEU convention), so the result does not depend on the
    order in which the references are listed. If the tied references also have
    the same length, the first one in ``rs`` wins.

    Args:
        rs: Tokenized reference sentences.
        c: Tokenized candidate sentence.

    Returns:
        Position in ``rs`` of the selected reference, as a plain ``int``.

    Raises:
        ValueError: If ``rs`` is empty.
    """
    if not rs:
        raise ValueError("closest_r() requires at least one reference sentence")

    c_len = len(c)
    # Sort key: primary = distance to the candidate length, secondary = the
    # reference length itself (shorter wins a tie). min() keeps the first
    # index among fully equal keys.
    return min(
        range(len(rs)),
        key=lambda idx: (abs(len(rs[idx]) - c_len), len(rs[idx])),
    )

def brevity_penalty(rs: list[list[str]], c: list[str]) -> float:
    """Compute the BLEU brevity penalty of a candidate against its references.

    The reference used for comparison is the one selected by ``closest_r``.
    The penalty is 1 when the candidate is at least as long as that reference
    and ``exp(1 - len(reference) / len(candidate))`` otherwise.

    Args:
        rs: Tokenized reference sentences.
        c: Tokenized candidate sentence.

    Returns:
        The brevity penalty. An empty candidate scored against a non-empty
        reference yields 0.0 (the limit of the formula as the candidate length
        goes to zero) instead of dividing by zero.
    """
    c_len = len(c)
    r_len = len(rs[closest_r(rs=rs, c=c)])

    if c_len >= r_len:
        return 1.0
    if c_len == 0:
        # r_len / c_len below would raise ZeroDivisionError.
        return 0.0
    return float(np.exp(1 - r_len / c_len))

def compute_p(c: list[str], rs: list[list[str]], N: int) -> np.ndarray:
    """Compute the modified (clipped) n-gram precisions p_1 ... p_N of a candidate.

    For each order ``n`` the count of every distinct candidate n-gram is
    clipped to the largest count of that n-gram in any single reference; the
    clipped counts are summed and divided by the total number of candidate
    n-grams.

    Args:
        c: Tokenized candidate sentence.
        rs: Tokenized reference sentences.
        N: Highest n-gram order to evaluate (orders 1..N are computed).

    Returns:
        Array of length ``N`` whose entry ``n - 1`` is the precision for order
        ``n``. An order for which the candidate has no n-grams (candidate
        shorter than ``n`` tokens) gets precision 0.0.

    Raises:
        ValueError: If ``rs`` is empty.
    """
    if not rs:
        raise ValueError("compute_p() requires at least one reference sentence")

    p = []
    for n in range(1, N + 1):
        counts_c = Counter(get_ngrams(n=n, sentence=c))
        counts_rs = [Counter(get_ngrams(n=n, sentence=r)) for r in rs]

        # Each *distinct* candidate n-gram is visited exactly once. Visiting
        # every occurrence would add a repeated n-gram's clipped count once
        # per repetition and skew the ratio.
        numerator = sum(
            min(count_c, max(counts_r[ngram] for counts_r in counts_rs))
            for ngram, count_c in counts_c.items()
        )
        denominator = sum(counts_c.values())

        # No n-grams of this order in the candidate: avoid 0 / 0.
        p.append(numerator / denominator if denominator else 0.0)

    return np.array(p)

def compute_bleu(c: list[str], rs: list[list[str]], lambdas: np.ndarray) -> float:
    """Compute the BLEU score of a candidate against one or more references.

    The score is ``BP * exp(sum_n lambdas[n-1] * log(p_n))`` where ``p_n`` are
    the precisions returned by ``compute_p`` and ``BP`` is the value returned
    by ``brevity_penalty``.

    Args:
        c: Tokenized candidate sentence.
        rs: Tokenized reference sentences.
        lambdas: One weight per n-gram order; ``lambdas[0]`` weights order 1,
            ``lambdas[1]`` order 2, and so on. Its length sets the highest
            order evaluated.

    Returns:
        The BLEU score. It is 0.0 when any order with a non-zero weight has
        precision 0.
    """
    weights = np.asarray(lambdas, dtype=float)
    p = compute_p(c=c, rs=rs, N=len(weights))
    bp = brevity_penalty(rs=rs, c=c)

    # An order with weight 0 contributes a factor p_n ** 0 == 1, so leave it
    # out; otherwise 0 * log(0) would turn the whole score into NaN.
    weighted = weights != 0
    if np.any(p[weighted] == 0):
        # exp(-inf) is 0 anyway; returning directly avoids numpy's
        # divide-by-zero RuntimeWarning from log(0).
        return 0.0

    return float(bp * np.exp(np.sum(weights[weighted] * np.log(p[weighted]))))

In [3]:
# Reference translations, tokenized by splitting on single spaces.
references = [
    text.split(' ')
    for text in (
        "resources have to be sufficient and they have to be predictable",
        "adequate and predictable resources are required",
    )
]
# Candidate translations to be scored, tokenized the same way.
translations = [
    text.split(' ')
    for text in (
        "here is a need for adequate and predictable resources",
        "resources be sufficient and predictable to",
    )
]
# One weight per n-gram order: 0.5 for unigrams and 0.5 for bigrams.
lambdas = np.full(2, 0.5)

In [4]:
# Score every candidate in `translations` against the complete `references`
# list and print each BLEU value rounded to three decimals.
for i, c in enumerate(translations):
    bleu = compute_bleu(c=c, rs=references, lambdas=lambdas)
    print(f"BLEU for NMT Translation c_{i} is {bleu:.3f}")

BLEU for NMT Translation c_0 is 0.327
BLEU for NMT Translation c_1 is 0.775


In [5]:
# Score every candidate again, this time against a single reference: the slice
# `references[1:2]` keeps only the second reference sentence, still wrapped in
# a list because compute_bleu takes a list of references.
for i, c in enumerate(translations):
    bleu = compute_bleu(c=c, rs=references[1:2], lambdas=lambdas)
    print(f"BLEU for NMT Translation c_{i} is {bleu:.3f}")

BLEU for NMT Translation c_0 is 0.408
BLEU for NMT Translation c_1 is 0.316
