<a href="https://colab.research.google.com/github/JoonYoung-Sohn/practice/blob/master/201205_BLEU_score.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from collections import Counter
import numpy as np
from nltk import ngrams

In [2]:
def simple_count(tokens, n):
    """Count how often each n-gram occurs in a tokenized sentence.

    Args:
        tokens: The tokenized candidate sentence (an iterable of tokens).
        n: The order of the n-grams to extract (1 = unigram, 2 = bigram, ...).

    Returns:
        A ``Counter`` mapping each n-gram (as a tuple of tokens) to its
        number of occurrences.
    """
    n_grams = ngrams(tokens, n)
    return Counter(n_grams)

In [3]:
# Demo: unigram counts for a sample candidate sentence.
# Tokenization is a plain whitespace split, so punctuation stays attached to
# its word (the last token is 'party.', as the printed Counter shows).
candidate = "It is a guide to action which ensures that the military always obeys the commands of the party."
tokens = candidate.split() 
result = simple_count(tokens, 1) 
print(result)

Counter({('the',): 3, ('It',): 1, ('is',): 1, ('a',): 1, ('guide',): 1, ('to',): 1, ('action',): 1, ('which',): 1, ('ensures',): 1, ('that',): 1, ('military',): 1, ('always',): 1, ('obeys',): 1, ('commands',): 1, ('of',): 1, ('party.',): 1})


In [4]:
# Demo: a degenerate candidate consisting of one word repeated seven times;
# every token lands in the same unigram bucket.
candidate = 'the the the the the the the'
tokens = candidate.split() 
result = simple_count(tokens, 1)
print(result)

Counter({('the',): 7})


In [5]:
def count_clip(candidate, reference_list, n):
    """Clip the candidate's n-gram counts by their maximum reference counts.

    Args:
        candidate: The tokenized candidate sentence.
        reference_list: A list of tokenized reference sentences.
        n: The order of the n-grams.

    Returns:
        A dict mapping every n-gram of the candidate to
        ``min(count, max_ref_count)``, where ``max_ref_count`` is the largest
        number of times that n-gram occurs in any single reference
        (0 when it occurs in none of them).
    """
    candidate_counts = simple_count(candidate, n)

    # max_ref_count: for each n-gram, the highest count seen in one reference.
    max_ref_counts = {}
    for reference in reference_list:
        for gram, freq in simple_count(reference, n).items():
            max_ref_counts[gram] = max(freq, max_ref_counts.get(gram, 0))

    # count_clip = min(count, max_ref_count); .get() yields 0 for n-grams
    # that never appear in a reference.
    return {
        gram: min(freq, max_ref_counts.get(gram, 0))
        for gram, freq in candidate_counts.items()
    }

In [6]:
# Demo: clipped unigram counts for a candidate that only repeats 'the'.
# `candidate` and `references` are reused by the modified-precision demo below.
candidate = 'the the the the the the the'
references = [
    'the cat is on the mat',
    'there is a cat on the mat',
]
result = count_clip(
    candidate.split(),
    [ref.split() for ref in references],
    1,
)
print(result)

{('the',): 2}


In [7]:
def modified_precision(candidate, reference_list, n):
    """Compute the modified n-gram precision of a candidate.

    Args:
        candidate: The tokenized candidate sentence.
        reference_list: A list of tokenized reference sentences.
        n: The order of the n-grams.

    Returns:
        The sum of clipped n-gram counts divided by the total number of
        n-grams in the candidate. When the candidate has no n-grams of
        order ``n`` the denominator is treated as 1, so the result is 0.
    """
    # Numerator: clipped counts summed over all candidate n-grams.
    clipped_total = sum(count_clip(candidate, reference_list, n).values())

    # Denominator: raw number of candidate n-grams (guarded against zero).
    candidate_total = sum(simple_count(candidate, n).values())

    return clipped_total / (candidate_total or 1)

In [8]:
# Unigram modified precision (n=1) for the `candidate` / `references`
# defined in the count_clip demo cell.
result = modified_precision(
    candidate.split(),
    [ref.split() for ref in references],
    1,
)
print(result)

0.2857142857142857


In [9]:
def closest_ref_length(candidate, reference_list):
    """Return the reference length closest to the candidate's length.

    Args:
        candidate: The tokenized candidate sentence.
        reference_list: A list of tokenized reference sentences.

    Returns:
        The length of the reference whose length differs least from
        ``len(candidate)``. When two references are equally close, the
        shorter length is returned.
    """
    target_len = len(candidate)

    def distance_then_length(length):
        # Primary key: distance to the candidate length; tie-break: shorter wins.
        return abs(length - target_len), length

    return min(map(len, reference_list), key=distance_then_length)

In [10]:
def brevity_penalty(candidate, reference_list):
    """Compute the BLEU brevity penalty for a candidate.

    Args:
        candidate: The tokenized candidate sentence.
        reference_list: A list of tokenized reference sentences.

    Returns:
        1 when the candidate is longer than the closest reference length,
        0 when the candidate is empty, and ``exp(1 - r/c)`` otherwise, where
        ``c`` is the candidate length and ``r`` the closest reference length.
    """
    hyp_len = len(candidate)
    closest_len = closest_ref_length(candidate, reference_list)

    # Longer than the closest reference: no penalty.
    if hyp_len > closest_len:
        return 1
    # Empty candidate: avoid dividing by zero below.
    if hyp_len == 0:
        return 0
    return np.exp(1 - closest_len / hyp_len)

In [11]:
def bleu_score(candidate, reference_list, weights=(0.25, 0.25, 0.25, 0.25)):
    """Compute the sentence-level BLEU score of a candidate.

    The score is ``BP * exp(sum_n w_n * log(p_n))``, where ``p_n`` is the
    modified n-gram precision of order ``n`` and ``BP`` the brevity penalty.

    Args:
        candidate: The tokenized candidate sentence.
        reference_list: A list of tokenized reference sentences.
        weights: One weight per n-gram order, starting at unigrams; its
            length determines the highest order used. Defaults to uniform
            weights over orders 1-4.

    Returns:
        The BLEU score as a float. It is 0.0 whenever a precision with a
        non-zero weight is zero, because the weighted geometric mean of the
        precisions is then zero.
    """
    bp = brevity_penalty(candidate, reference_list)
    p_n = [
        modified_precision(candidate, reference_list, n=n)
        for n, _ in enumerate(weights, start=1)
    ]

    # Orders with a zero weight contribute nothing to the geometric mean.
    weighted = [(w_i, p_i) for w_i, p_i in zip(weights, p_n) if w_i != 0]

    # log(0) is undefined; a zero precision drives the geometric mean to 0.
    # Substituting 0 for the log term would instead treat that precision as
    # 1 and inflate the score (e.g. 1.0 for a candidate with no overlap).
    if any(p_i == 0 for _, p_i in weighted):
        return 0.0

    score = np.sum([w_i * np.log(p_i) for w_i, p_i in weighted])
    return bp * np.exp(score)

In [12]:
import nltk.translate.bleu_score as bleu

# Compare the hand-written bleu_score with NLTK's sentence_bleu on the same
# candidate and references; the two printed values are expected to match.
candidate = 'It is a guide to action which ensures that the military always obeys the commands of the party'
references = [
    'It is a guide to action that ensures that the military will forever heed Party commands',
    'It is the guiding principle which guarantees the military forces always being under the command of the Party',
    'It is the practical guide for the army always to heed the directions of the party',
]

print(bleu_score(candidate.split(), [ref.split() for ref in references]))

# Note the argument order: NLTK takes the references first, then the hypothesis.
print(bleu.sentence_bleu([ref.split() for ref in references], candidate.split()))

0.5045666840058485
0.5045666840058485
