In [1]:
%matplotlib inline

import numpy as np
#import html # python 3.4
import os.path
import string
import sys
from random import randint

sys.path.insert(0, '/home/ec2-user/kklab/Projects/lrlp/scripts/oov_translate')
from config import *
from utils import *

In [2]:
# Scratch directory for this notebook's temporary files (hypothesis/reference
# dumps for scoring, fast_align input/output).
# NOTE(review): plain string concatenation assumes exp_dir ends with '/'.
tmp_dir = exp_dir+"oov_trans_bound/"
# exist_ok avoids the check-then-create race of os.path.exists + os.makedirs
os.makedirs(tmp_dir, exist_ok=True)


In [5]:
def get_best_hyp(hyp_list, ref_sent, metric):
    '''
    Pick the hypothesis with the highest sentence-level score against one reference.

    The hypotheses are written one per line to tmp_dir+"hyp_tmp", the reference
    is repeated once per hypothesis in tmp_dir+"ref_tmp", and the external scorer
    selected by `metric` scores each line pair.

    params:
        hyp_list: non-empty list of hypothesis sentences (no embedded newlines)
        ref_sent: reference sentence
        metric: "bleu"   -> sentence-bleu (./sentence-bleu ref < hyp, scores in [0, 1])
                "meteor" -> METEOR jar, per-segment scores
    return:
        best_hyp: the first hypothesis in hyp_list attaining the maximum score
    raises:
        ValueError: if metric is unsupported, hyp_list is empty, or the scorer
                    does not return exactly one score per hypothesis
    '''
    # validate before touching any files
    if metric not in ("bleu", "meteor"):
        raise ValueError("metric \""+metric+"\" isn't supported!")
    if not hyp_list:
        raise ValueError("get_best_hyp: hyp_list is empty, nothing to score")

    hyp_file = tmp_dir+"hyp_tmp"
    ref_file_dup = tmp_dir+"ref_tmp"
    # NOTE: fixed file names, so concurrent calls would clobber each other
    with open(hyp_file, 'w') as fw_hyp, open(ref_file_dup, 'w') as fw_ref:
        for hyp in hyp_list:
            fw_hyp.write(hyp+'\n')
            fw_ref.write(ref_sent+'\n')

    if metric == "bleu":
        # one score per output line, in hypothesis order
        stdout, _ = sh(sent_bleu+" "+ref_file_dup+" < "+hyp_file)
        scores = [float(item) for item in stdout.strip().split('\n')]
    else:
        # keep only the lines containing 'Segment'; the score is the 2nd tab-separated field
        stdout, _ = sh("java -Xmx2G -jar "+meteor_bin+" "+hyp_file+" "+ref_file_dup+" -norm -noPunct | grep 'Segment'")
        scores = [float(line.split('\t')[1]) for line in stdout.strip().split('\n')]

    # guard against silently picking the wrong hypothesis when scorer output and
    # hypotheses are misaligned (e.g. a hypothesis containing a newline)
    if len(scores) != len(hyp_list):
        raise ValueError("get_best_hyp: got %d scores for %d hypotheses" % (len(scores), len(hyp_list)))

    # list.index returns the first maximum, so ties go to the earliest hypothesis
    return hyp_list[scores.index(max(scores))]



def get_best_hyp_est(tra_tok, oov_candidates, ref_sent, metric):
    '''
    Greedy approximation of get_best_hyp for sentences with too many OOV
    combinations to enumerate: OOVs are resolved one position at a time, left to
    right, each choice conditioned on the choices already made.

    params:
        tra_tok: tokenized translation
        oov_candidates: {oov: {candidate: score}} (ug_dict format)
        ref_sent: reference sentence
        metric: "bleu" or "meteor"
    return:
        the translation (space-joined) with every resolvable OOV replaced by the
        candidate that maximized the metric at its step
    Positions are left unchanged when the token has no entry in oov_candidates,
    when its candidate dict is empty, or when the token is itself one of its
    own candidates.
    '''
    tra_tok_new = list(tra_tok)
    for i, tok in enumerate(tra_tok):
        if tok not in oov_candidates or tok in oov_candidates[tok]:
            continue
        candidates = list(oov_candidates[tok].keys())
        # nothing to choose from: scoring an empty hypothesis list would fail
        if not candidates:
            continue
        # one hypothesis per candidate, differing only at position i
        hyp_list = []
        for candidate in candidates:
            sent = list(tra_tok_new)
            sent[i] = candidate
            hyp_list.append(' '.join(sent))
        best_hyp = get_best_hyp(hyp_list, ref_sent, metric)
        tra_tok_new[i] = candidates[hyp_list.index(best_hyp)]

    return ' '.join(tra_tok_new)



def translate_oov_by_alignment(tra_tok, oov_pos, ref_tok, pairs):
    '''
    Replace every OOV token of a translation with a word from its reference,
    guided by a fast_align word alignment.

    For each OOV position i:
      * if i is aligned, use ref_tok[j] for the first "i-j" pair mentioning i;
      * otherwise use the first reference token not already present in the
        (partially rewritten) translation;
      * if every reference token is already present, use a random reference token.
    If ref_tok is empty, unaligned OOVs are left unchanged.

    params:
        tra_tok: tokenized translation
        oov_pos: positions of oov words in tra_tok
        ref_tok: tokenized reference
        pairs: alignment pairs as "i-j" strings (i indexes tra_tok, j indexes ref_tok);
               empty strings (e.g. from splitting an empty alignment line) are ignored
    return:
        res: list of tokens, tra_tok with the OOVs replaced
        oov_trans: {oov word: set({translations})}
    '''
    # first reference position aligned to each translation position
    tra_to_ref = {}
    for pair in pairs:
        if not pair:
            continue
        lr = pair.split('-')
        tra_to_ref.setdefault(int(lr[0]), int(lr[1]))

    oov_pos_set = set(oov_pos)
    oov_trans = {}
    res = list(tra_tok)

    for idx_tra in range(len(tra_tok)):
        if idx_tra not in oov_pos_set:
            continue
        if idx_tra in tra_to_ref:
            # replace the oov word with the aligned word from the reference
            res[idx_tra] = ref_tok[tra_to_ref[idx_tra]]
        elif ref_tok:
            # unaligned: prefer a reference word the translation doesn't have yet
            unused = next((w for w in ref_tok if w not in res), None)
            if unused is not None:
                res[idx_tra] = unused
            else:
                # randint is inclusive on both sides
                res[idx_tra] = ref_tok[randint(0, len(ref_tok)-1)]
        oov_trans.setdefault(tra_tok[idx_tra], set()).add(res[idx_tra])

    return res, oov_trans


def align_oov(in_file, out_file):
    '''
    Word-align every translation in tra_file to its reference in ref_file with fast_align.

    params:
        in_file: path where the fast_align input is written,
                 one "translation ||| reference" line per sentence
        out_file: path that fast_align's stdout (the alignment) is redirected to
    return:
        None
    '''
    # Line k of the parallel file pairs translation k with reference k; if
    # ref_file runs out of lines, the reference side is written as empty.
    with open(tra_file) as tra_in, open(ref_file) as ref_in, open(in_file, 'w') as parallel_out:
        for tra_line in tra_in:
            ref_line = ref_in.readline().strip()
            parallel_out.write(tra_line.strip() + " ||| " + ref_line + '\n')

    # source side: translation with oov; target side: reference
    sh(fast_align+" -i "+in_file+" -d -o -v > "+out_file)

    
def _oov_trans_bound_align(res_file):
    '''Alignment bound: replace each OOV with the reference word fast_align aligns it to.'''
    ### parallel text fed to fast-align
    triple_pipe = tmp_dir+"triple_pipe"
    ### alignment result produced by fast-align
    forward_align = tmp_dir+"forward_align"
    align_oov(triple_pipe, forward_align)

    with open(tra_file) as ft, \
    open(oov_file) as fo, \
    open(ref_file) as fr, \
    open(forward_align) as ff, \
    open(res_file, 'w') as fres, \
    open(oov_aligned_file, 'w') as foa:
        for ctr, l_tra in enumerate(ft):
            l_oov = fo.readline()
            l_ref = fr.readline().strip()
            l_align = ff.readline().strip()

            print(ctr)
            ### html unescaping not happening
            tra_tok, oov_pos, _ = get_context_oov_pos(l_tra, l_oov)
            ref_tok = l_ref.split(' ')
            pairs = l_align.split(' ')

            res, oov_trans = translate_oov_by_alignment(tra_tok, oov_pos, ref_tok, pairs)
            for oov in oov_trans:
                ### write reference-aligned oov translation to file for other
                ### programs to use; sorted so the file is reproducible across runs
                foa.write('\t'.join([oov]+sorted(oov_trans[oov]))+'\n')

            res = ' '.join(res)
            print(res)
            fres.write(res+'\n')


def _oov_trans_bound_lattice(metric, res_file, add_ref, max_hyp):
    '''Lattice bound: per sentence, keep the OOV-candidate combination with the best metric.'''
    import contextlib

    if add_ref:
        ### parallel text fed to fast-align
        triple_pipe = tmp_dir+"triple_pipe_mix"
        ### alignment result produced by fast-align
        forward_align = tmp_dir+"forward_align_mix"
        align_oov(triple_pipe, forward_align)

    with contextlib.ExitStack() as stack:
        ft = stack.enter_context(open(tra_file))
        fo = stack.enter_context(open(oov_file))
        fr = stack.enter_context(open(ref_file))
        ff = stack.enter_context(open(forward_align)) if add_ref else None
        fres = stack.enter_context(open(res_file, 'w'))

        for ctr, l_tra in enumerate(ft):
            l_oov = fo.readline()
            l_ref = fr.readline().strip()
            l_align = ff.readline().strip() if ff is not None else None

            ### html unescaping not happening
            tra_tok, oov_pos, _ = get_context_oov_pos(l_tra, l_oov)

            oov_trans = {}
            if l_align is not None:
                ### {oov:{reference word}} from the alignment
                ref_tok = l_ref.split(' ')
                pairs = l_align.split(' ')
                _, oov_trans = translate_oov_by_alignment(tra_tok, oov_pos, ref_tok, pairs)

            ### {oov}
            oov_words_set = set(tra_tok[i] for i in oov_pos)
            ### {oov:{candidate:score}}
            oov_candidates = get_oov_candidates(ug_dict, oov_words_set)

            ### merge the oracle (reference-aligned) candidates into the real ones
            for oov in oov_trans:
                add_ref_to_oov_candidates(oov_candidates, oov, oov_trans[oov])

            num_hyp = get_num_hyp(oov_candidates, tra_tok, oov_pos)
            if num_hyp <= max_hyp:
                ### few enough: enumerate every combination and score them all
                all_sentences = get_all_sentences(tra_tok, oov_candidates)
                best_trans = get_best_hyp(all_sentences, l_ref, metric)
            else:
                ### too many hypotheses: decode the oovs one by one
                best_trans = get_best_hyp_est(tra_tok, oov_candidates, l_ref, metric)

            print(ctr)
            print(best_trans)
            fres.write(best_trans+'\n')


def oov_trans_bound(method, \
                    metric, \
                    res_file, \
                    max_hyp=50**2.8):
    '''
    Compute an oracle (upper-bound) OOV translation of tra_file against ref_file
    and write one translated sentence per line to res_file.

    params:
        method:
            "align": replace each oov with the reference word fast_align aligns it
                     to (metric is unused); also writes oov -> reference words
                     to oov_aligned_file
            "lattice": for each sentence, choose the combination of ug_dict oov
                       candidates that maximizes the sentence-level metric
            "lattice-align": like "lattice", but the reference-aligned words are
                             first added to the oov candidate lists
        metric: "bleu" or "meteor" (used by the lattice methods only)
        res_file: path to write the oov translation result to
        max_hyp: sentences with at most this many hypotheses are scored
                 exhaustively; larger ones are decoded oov by oov
    return:
        None
    raises:
        ValueError: if method is not one of the above
    '''
    if method == "align":
        _oov_trans_bound_align(res_file)
    elif method in ("lattice", "lattice-align"):
        _oov_trans_bound_lattice(metric, res_file, method == "lattice-align", max_hyp)
    else:
        raise ValueError("method \""+method+"\" isn't supported!")

In [4]:
# -------- hyperparameters specific to this method --------
### "align", "lattice" or "lattice-align" (see oov_trans_bound's docstring)
method = "lattice-align"
### "bleu" or "meteor"; used by both lattice methods ("lattice", "lattice-align"),
### ignored when method == "align"
metric = "bleu"

# -------- write --------
### "align" does not use the metric, so its result file is tagged "ref" instead
if method == "align":
    res_file = res_dir+"_".join(["best", "ref", method])
else:
    res_file = res_dir+"_".join(["best", metric, method])

# -------- translate --------
oov_trans_bound(method, metric, res_file)

number of hypotheses: 1
0
highlight colour ,
number of hypotheses: 1
1
plan
number of hypotheses: 1
2
Interactive Console for manipulating currently selected accessible
number of hypotheses: 1
3
_ C lear L .
number of hypotheses: 1
4
source
number of hypotheses: 1
5
Plugin with methods of various selecting accessibles quickly .
number of hypotheses: 1
6
Dogtail
number of hypotheses: 1
7
current Script be .
number of hypotheses: 1
8
_ Schema M ,
number of hypotheses: 1
9
in vain .
number of hypotheses: 1
10
role
number of hypotheses: 1
11
EXCEPT
number of hypotheses: 1
12
DEBUG
number of hypotheses: 1
13
basic
number of hypotheses: 1
14
but than focused WIDGET
number of hypotheses: 1
15
% s &apos; s brothers , or by label
number of hypotheses: 1
16
% ( rolename ) s index % ( parent num1 ) d does not match row and column ( num2 index % d
number of hypotheses: 1
17
_ Contents
number of hypotheses: 1
18
He . do not turn away .
number of hypotheses: 1
19
Japanese know .
number of hypotheses

In [6]:
# Debug check: show which data split is active. `dataset` is not assigned in
# this notebook; presumably it comes from `from config import *` — verify.
print(dataset)

dev


In [7]:
# Debug check. NOTE(review): `s` is not assigned by any live code in this
# notebook (the only `s = [...]` lines are commented out), so this value comes
# from a star import (config/utils) or leftover kernel state — verify before
# relying on it.
print(s)

il3
