In [8]:
import numpy as np
import time
import pickle
import operator
from __future__ import print_function
from IPython.display import display
import matplotlib.pyplot as plt
import plotly.plotly as py
%matplotlib inline  

In [9]:
def load_file(filepath):
    """Reads a file into a list of phrases. Each phrase in the file must be separated with a new line character '\n' 
    Args:
        filepath (str): the relevant filepath of the file
    Returns:
        phrases (list(str)): A list with all the phrases 
    """
    with open(filepath, 'r') as f:
        phrases = f.readlines()
    return phrases

def is_number(s):
    try:
        float(s)
        return True
    except ValueError:
        return False


In [49]:
class Decoder:
    def __init__(self, trg_lm_file, tm_file, rm_file):
        
        self.trg_LM = self.parse_lm_file(trg_lm_file)
        self.TM = self.parse_tm_file(tm_file)
        self.RM = self.parse_rm_file(rm_file)
        
        self.distortion_limit = 3
        
    
    def parse_lm_file(self, filepath):
#         lm_file = load_file(filepath)

#         LM = {}

#         for line in lm_file:
#             tokens = line.split()
#             if len(tokens) > 0:
#                 if is_number(tokens[0]):

#                     p1 = float(tokens[0])

#                     s = tokens[1:]

#                     if is_number(s[-1]):
#                         p2 = float(s[-1])
#                         del s[-1]
#                         LM[" ".join(s)] = [p1,p2]
#                     else:    
#                         LM[" ".join(s)] = [p1]

#         return LM
        lm_file = load_file(filepath)

        LM = {}

        for line in lm_file:
            tokens = line.split('\t')
            if len(tokens) > 0:
                
                if is_number(tokens[0]):

                    if(len(tokens) == 3):
                        LM[tokens[1]] = [float(tokens[0]), float(tokens[2])]
                    else:
                        LM[tokens[1]] = [float(tokens[0])]

        return LM


    def parse_tm_file(self, filepath):
        tm_file = load_file(filepath)

        TM = {}

        for line in tm_file:
            tokens = line.split("|||")

            f = tokens[0].strip()
            e = tokens[1].strip()

            tokens = tokens[2].split()


            p_fe = float(tokens[0])
            l_fe = float(tokens[1])
            p_ef = float(tokens[2])
            l_ef = float(tokens[3])
            wp = float(tokens[4])

            TM[(f,e)] = [p_fe, l_fe, p_ef, l_ef, wp]

        return TM


    def parse_rm_file(self, filepath):
        re_file = load_file(filepath)

        reorderings = {}

        for line in re_file:
            tokens = line.split("|||")

            f = tokens[0].strip()
            e = tokens[1].strip()

            tokens = tokens[2].split()

            m_rl = float(tokens[0])
            s_rl = float(tokens[1])
            d_rl = float(tokens[2])

            m_lr = float(tokens[3])
            s_lr = float(tokens[4])
            d_lr = float(tokens[5])

            reorderings[(f,e)] = [m_rl, s_rl, d_rl, m_lr, s_lr, d_lr]

        return reorderings
    
    
    def parse_trace(self, trace_line):
        
        tokens = trace_line.split("|||")
        
        trace = []
        for token in tokens: 
            [src_positions, translation] = token.split(':',1) 
            
            src_positions = [int(pos) for pos in src_positions.split("-")]
            
            trace.append([src_positions[0], src_positions[1], translation.strip()])
        
        return trace
    
    
    def get_translation_cost(self, src_phrase, trace):
        
        trace = self.parse_trace(trace)
        src_words = src_phrase.split()
    
        src_phrases = []
        for p in trace:
            src_phrase = " ".join(src_words[p[0]:p[1] + 1])
            src_phrases.append(src_phrase)
            
        
        return cost
    

In [50]:
src_file = load_file('data/file.test.de')
trg_file = load_file('data/file.test.en') 
traces =  load_file('data/testresults.trans.txt.trace')

decoder = Decoder('data/file.en.lm', 'data/phrase-table', 'data/dm_fe_0.75')

In [68]:
def ngram_prob(ngram_words, LM):
    
    if not ngram_words:
        return -4.0
    
    
    ngram_str = " ".join(ngram_words)
    
    if ngram_str in LM:
        
        return LM[ngram_str][0]
    
    else:
        
        backoff = 0
        backoff_str = " ".join(ngram_words[:-1])
        if backoff_str in LM:
            if len(LM[backoff_str]) == 2:
                backoff = LM[backoff_str][1]
            
            
        return ngram_prob(ngram_words[1:], LM) + backoff 
    
    
    


def get_lm_cost(phrase, hist, LM):
    lm_cost = 0
    
    hist_words = hist.split()
    phrase_words = phrase.split()
 
    
    for word in phrase_words:

        lm_cost += ngram_prob(hist_words + [word], LM)
        
        if len(hist_words) == 4:
            del hist_words[0]
            
        hist_words.append(word)
        
        
        
            
    return lm_cost






def get_translation_cost(src_phrase, trace, decoder):

    MISS_COST = np.log10(0.001)
    PHRASE_PENALTY = -1
    HIST_SIZE = 4

    total_cost = 0
    
    
    trace = decoder.parse_trace(trace)
    src_words = src_phrase.split()

    src_phrases = [] 
    for p in trace:
        src_phrase = " ".join(src_words[p[0]:p[1] + 1])
        src_phrases.append(src_phrase)  

    
    
    
    for i in range(len(trace)):
        
        #-----------------------------TM COST-------------------------------------
        tm_cost = 0
        
        pair = (src_phrases[i], trace[i][2])
        
        if pair in decoder.TM:   
            [p_fe, lex_fe, p_ef, lex_ef, wp] = decoder.TM[pair]
                
            #h_TM = p(f|e) + p(e|f) (phrase translation)
            tm_cost += 1.0 * (np.log10(p_fe) + np.log10(p_ef))
                
            #h_TM = l(f|e) + l(e|f) (lexical phrase translation)
            tm_cost += 1.0 * (np.log10(lex_fe) + np.log10(lex_ef))

            #h_wp (word penalty)
            tm_cost += 1.0 * wp
        else: 
            tm_cost += MISS_COST
    
        #---------------------------------------------------------------------
        
        
        
        #-----------------------------LM COST-------------------------------------
        
        #Find the history of the target phrase, up to (n-1) words ('n' for maximum n-gram)
        trg_hist = []
        hist_count = 0
        j = i-1
        while j >= 0:
            prev_trg_words = trace[j][2].split()
            for trg_word in reversed(prev_trg_words):
                trg_hist.insert(0, trg_word)
                hist_count += 1
                
                if hist_count == HIST_SIZE:
                    break;
                
            if hist_count == HIST_SIZE:
                break;
            j -= 1
            
        if hist_count < HIST_SIZE:
            trg_hist.insert(0, '<s>')
            
        trg_hist = " ".join(trg_hist)
            
        lm_cost = get_lm_cost(trace[i][2], trg_hist, decoder.trg_LM)
        
        #------------------------------------------------------------------------------
        
        #Distorion>
        
#         #Reordering pentaly
        rm_cost = 0
        if i > 0:
            
            if pair in decoder.RM:
            
                #continuous
                if trace[i][0] == trace[i-1][1] + 1:
                    rm_cost += decoder.RM[pair][3]

                elif trace[i][1] == trace[i-1][0] - 1:
                    rm_cost += decoder.RM[pair][4]

                else:
                    rm_cost += decoder.RM[pair][5]
            
            else:
                rm_cost = 0
            
            
            
            
            
        total_cost += tm_cost + lm_cost + PHRASE_PENALTY
    
         
    return total_cost


for i in range(10):
    cost = get_translation_cost(src_file[i], traces[i], decoder)
    print(i, cost)
    print
    print

0 -73.2276788907
1 -110.043402691
2 -191.224779675
3 -109.614728338
4 -151.9036692
5 -184.737497751
6 -74.7789894865
7 -53.7341135257
8 -57.8084059581
9 -43.1085104777


In [69]:
# for line in traces:
#     trace = decoder.parse_trace(line)
    
#     for pos in range(len(trace) - 1):
        
#         if trace[pos][1] + 1 != trace[pos + 1][0]:
#             print(pos, trace)
#             break
        
     

In [70]:
get_lm_cost('same as', 'for junior doctors', decoder.trg_LM)

-7.2063388

In [61]:
decoder.trg_LM['for junior doctors same']

KeyError: 'for junior doctors same'

In [60]:
# decoder.trg_LM['junior doctors same']
decoder.trg_LM['for junior doctors']

KeyError: 'for junior doctors'

In [59]:
# decoder.trg_LM['doctors same']
decoder.trg_LM['junior doctors']

[-0.06360954, -0.2467188]

In [58]:
decoder.trg_LM['same']
decoder.trg_LM['doctors']

[-4.204124, -0.288043]

In [63]:
-4.349341 - 0.288043 - 0.2467188 

-4.8841028