<h1>Table of Contents<span class="tocSkip"></span></h1>
<div class="toc"><ul class="toc-item"><li><span><a href="#Implementation" data-toc-modified-id="Implementation-1"><span class="toc-item-num">1&nbsp;&nbsp;</span>Implementation</a></span><ul class="toc-item"><li><span><a href="#N-gram-Precision" data-toc-modified-id="N-gram-Precision-1.1"><span class="toc-item-num">1.1&nbsp;&nbsp;</span>N-gram Precision</a></span></li><li><span><a href="#Modified-N-gram-Precision" data-toc-modified-id="Modified-N-gram-Precision-1.2"><span class="toc-item-num">1.2&nbsp;&nbsp;</span>Modified N-gram Precision</a></span></li><li><span><a href="#Brevity-Penalty" data-toc-modified-id="Brevity-Penalty-1.3"><span class="toc-item-num">1.3&nbsp;&nbsp;</span>Brevity Penalty</a></span></li></ul></li><li><span><a href="#Using-corpus_bleu()" data-toc-modified-id="Using-corpus_bleu()-2"><span class="toc-item-num">2&nbsp;&nbsp;</span>Using <code>corpus_bleu()</code></a></span></li><li><span><a href="#Using-sentence_bleu()" data-toc-modified-id="Using-sentence_bleu()-3"><span class="toc-item-num">3&nbsp;&nbsp;</span>Using <code>sentence_bleu()</code></a></span><ul class="toc-item"><li><span><a href="#NLTK의-BLEU-Vs.-구현한-BLEU-함수" data-toc-modified-id="NLTK의-BLEU-Vs.-구현한-BLEU-함수-3.1"><span class="toc-item-num">3.1&nbsp;&nbsp;</span>NLTK의 BLEU Vs. 구현한 BLEU 함수</a></span></li></ul></li></ul></div>

In [6]:
import nltk
from nltk.translate.bleu_score import corpus_bleu
from nltk.translate.bleu_score import sentence_bleu
import numpy as np
from collections import Counter

# Implementation

In [4]:
def get_ngram2cnt(cand, n):
    return Counter(nltk.ngrams(cand, n))

In [5]:
get_ngram2cnt(cand1, 3)

NameError: name 'cand1' is not defined

## N-gram Precision

In [None]:
def get_ngram_precision(cand, refs, n):
    ngram2cnt_refs = Counter()
    for ref in refs:
        ngram2cnt_refs += get_ngram2cnt(ref, n)
    ngrams_in_refs = 0
    len_cand = 0
    for ngram, cnt in get_ngram2cnt(cand, n).items():
        if ngram in ngram2cnt_refs:
            ngrams_in_refs += cnt 
        len_cand += cnt
    return ngrams_in_refs/len_cand

In [None]:
print(get_ngram_precision(cand1, refs, 1))
print(get_ngram_precision(cand2, refs, 1))

0.9444444444444444
0.5714285714285714


## Modified N-gram Precision

In [None]:
def get_modified_ngram_precision(cand, refs, n):
    def get_count_clip(ngram, cand, refs, n):
        def get_max_ref_count(ngram, refs, n):
            temp = list()
            for ref in refs:
                ngram2cnt_ref = get_ngram2cnt(ref, n)
                temp.append(ngram2cnt_ref[ngram])
            return max(temp)    

        def get_count(ngram, cand, n):
            return get_ngram2cnt(cand, 1)[ngram]

        return min(get_count(ngram, cand, n), get_max_ref_count(ngram, refs, n))
    
    sum_countclip = 0
    len_cand = 0
    for ngram, cnt in get_ngram2cnt(cand, n).items():
        sum_countclip += get_count_clip(ngram, cand, refs, n)
        len_cand += cnt
    return sum_countclip/len_cand

In [None]:
print(get_modified_ngram_precision(cand1, refs, 1))
print(get_modified_ngram_precision(cand2, refs, 1))

0.9444444444444444
0.5714285714285714


In [None]:
cand = "the the the the the the the"
print(get_ngram_precision(cand.split(" "), refs, 1))
print(get_modified_ngram_precision(cand.split(" "), refs, 1))

1.0
0.5714285714285714


## Brevity Penalty

In [None]:
def closest_ref_length(cand, ref_list): # Ca 길이와 가장 근접한 Ref의 길이를 리턴하는 함수
    ca_len = len(cand) # ca 길이
    ref_lens = (len(ref) for ref in ref_list) # Ref들의 길이
    closest_ref_len = min(ref_lens, key=lambda ref_len: (abs(ref_len - ca_len), ref_len))
    # 길이 차이를 최소화하는 Ref를 찾아서 Ref의 길이를 리턴
    return closest_ref_len

In [None]:
def brevity_penalty(cand, ref_list):
    ca_len = len(cand)
    ref_len = closest_ref_length(cand, ref_list)

    if ca_len > ref_len:
        return 1
    elif ca_len == 0 :
    # cand가 비어있다면 BP = 0 → BLEU = 0.0
        return 0
    else:
        return np.exp(1 - ref_len/ca_len)

In [None]:
def bleu_score(cand, ref_list, weights=[0.25, 0.25, 0.25, 0.25]):
    bp = brevity_penalty(cand, ref_list) # 브레버티 패널티, BP

    p_n = [modified_precision(cand, ref_list, n=n) for n, _ in enumerate(weights,start=1)] 
    #p1, p2, p3, ..., pn
    score = np.sum([w_i * np.log(p_i) if p_i != 0 else 0 for w_i, p_i in zip(weights, p_n)])
    return bp * np.exp(score)

위 함수가 동작하기 위해서는 앞서 구현한 get_ngram2cnt, count_clip, modified_precision, brevity_penalty 4개의 함수 또한 모두 구현되어져 있어야 합니다. 지금까지 구현한 BLEU 코드로 계산된 점수와 NLTK 패키지에 이미 구현되어져 있는 BLEU 코드로 계산된 점수를 비교해봅시다.

# Using `corpus_bleu()`

In [7]:
references = [[['this', 'is', 'a', 'test'], ['this', 'is' 'test']]]
candidates = [['this', 'is', 'a', 'test']]
score = corpus_bleu(references, candidates)
print(score)

1.0


# Using `sentence_bleu()`

In [3]:
refs = [["this", "is", "a", "test"], ["this", "is" "test"]]
cands = ["this", "is", "a", "test"]
score = sentence_bleu(refs, cands)
print(score)

1.0


## NLTK의 BLEU Vs. 구현한 BLEU 함수

In [None]:
import nltk.translate.bleu_score as bleu


cand = "It is a guide to action which ensures that the military always obeys the commands of the party"
refs = ["It is a guide to action that ensures that the military will forever heed Party commands", "It is the guiding principle which guarantees the military forces always being under the command of the Party", "It is the practical guide for the army always to heed the directions of the party"]

# 이번 챕터에서 구현한 코드로 계산한 BLEU 점수
print(bleu_score(cand.split(),list(map(lambda ref: ref.split(), refs))))
# NLTK 패키지 구현되어져 있는 코드로 계산한 BLEU 점수
print(bleu.sentence_bleu(list(map(lambda ref: ref.split(), refs)),cand.split()))

0.5045666840058485
0.5045666840058485
