In [None]:
def viterbi_topk(test_file, output_file, transition_params, emission_params, labels, top_k, i_th):
    """Tag every sentence in ``test_file`` with its i_th-best label sequence.

    The input file holds one token per line (only the first whitespace-separated
    field is used); blank lines separate sentences.  For each sentence the
    k-best lattice is built with ``node_scores`` and the ``i_th``-ranked path is
    recovered with ``backtracking_topk``.  Each output line is ``"<word> <label>"``,
    with a blank line after every sentence, written as UTF-8 with a BOM
    (``utf-8-sig``).

    Args:
        test_file: path of the token-per-line input file (read as UTF-8).
        output_file: path of the file to (over)write with the tagged output.
        transition_params: ``transition_params[next_label][prev_label]`` -> probability,
            including a ``'STOP'`` row and ``'START'`` as a previous label.
        emission_params: ``emission_params[word][label]`` -> probability, with a
            ``'#UNK#'`` row used for words not present as keys.
        labels: mapping whose keys are the candidate labels (only the keys are used).
        top_k: number of partial paths kept per label at every position.
        i_th: 1-based rank of the full path to output (must be <= top_k).
    """
    sentences = []
    sentence = []
    with open(test_file, encoding='utf-8') as ifile:
        for line in ifile:
            tokens = line.split()
            if tokens:
                sentence.append(tokens[0])
            else:
                sentences.append(sentence)
                sentence = []
    # Without this, a file that does not end in a blank line loses its last sentence.
    if sentence:
        sentences.append(sentence)

    with codecs.open(output_file, 'w', 'utf-8-sig') as ofile:
        for s in sentences:
            nodes = node_scores(s, transition_params, emission_params, labels, top_k)
            # The top-k decoder pairs with backtracking_topk (the old call named an
            # undefined `backtracking`).
            for word in backtracking_topk(s, nodes, i_th):
                ofile.write(word + '\n')
            ofile.write("\n")


def _top_k_candidates(prev_nodes_dict, transition_row, b_node, top_k):
    """Score every (previous label, previous rank) pair and keep the best ``top_k``.

    Each candidate is ``[prev_score * a * b_node, prev_label, prev_rank]``, where
    ``a = transition_row.get(prev_label, 0)``.  Candidates are sorted by score,
    highest first; the sort is stable, so ties keep their enumeration order.
    """
    scores = []
    for u_node, prev_entries in prev_nodes_dict.items():
        a_node = transition_row.get(u_node, 0)
        for index, prev_entry in enumerate(prev_entries):
            scores.append([prev_entry[0] * a_node * b_node, u_node, index])
    scores.sort(key=lambda x: x[0], reverse=True)
    return scores[:top_k]


def node_scores(s, transition_params, emission_params, labels, top_k):
    """Build the k-best Viterbi lattice for sentence ``s``.

    Args:
        s: list of words.
        transition_params: ``transition_params[next_label][prev_label]`` -> probability.
            Must contain a ``'STOP'`` row.  Missing entries (or a missing row for a
            label) count as probability 0.
        emission_params: ``emission_params[word][label]`` -> probability.  Words not
            present as keys use the ``'#UNK#'`` row; missing labels count as 0.
        labels: mapping whose keys are the candidate labels.
        top_k: number of best partial paths kept per label at each position.

    Returns:
        dict ``nodes`` where ``nodes[k][label]`` is a list (best first) of
        ``[score, prev_label, prev_rank]`` entries.  Entries use positions 0..len(s)+1:
        ``nodes[0] == {'START': [[1, 'nil', 0]]}``, positions 1..len(s) hold one
        entry per label, and ``nodes[len(s) + 1] == {'STOP': [...]}``.
        ``prev_rank`` indexes into ``nodes[k - 1][prev_label]``.
    """
    nodes = {0: {'START': [[1, 'nil', 0]]}}

    for k in range(1, len(s) + 1):
        word = s[k - 1]
        # Only unseen words fall back to the '#UNK#' emission row.
        emission_row = emission_params[word] if word in emission_params else emission_params['#UNK#']
        prev_nodes_dict = nodes[k - 1]
        nodes[k] = {}
        for v_node in labels:
            b_node = emission_row.get(v_node, 0)
            transition_row = transition_params.get(v_node, {})
            nodes[k][v_node] = _top_k_candidates(prev_nodes_dict, transition_row, b_node, top_k)

    # STOP emits nothing, so its emission factor is 1 (prev * a * 1 == prev * a exactly).
    nodes[len(s) + 1] = {
        'STOP': _top_k_candidates(nodes[len(s)], transition_params['STOP'], 1, top_k)
    }
    return nodes


def backtracking_topk(s, nodes, i_th):
    """Recover the ``i_th``-best label sequence from a k-best lattice.

    Args:
        s: list of words of the sentence.
        nodes: lattice as returned by ``node_scores`` for ``s``: ``nodes[k][label]``
            is a best-first list of ``[score, prev_label, prev_rank]`` entries, with
            the final ``'STOP'`` node at ``nodes[len(s) + 1]``.
        i_th: 1-based rank of the full path to recover (must be <= the lattice's top_k,
            otherwise an IndexError is raised).

    Returns:
        A new list of ``"<word> <label>"`` strings; ``s`` itself is not modified.

    Raises:
        ValueError: if ``i_th`` is smaller than 1 (a 0 or negative rank would otherwise
            silently select paths from the end of the ranking).
    """
    if i_th < 1:
        raise ValueError(f"i_th must be >= 1, got {i_th}")

    tags = [None] * len(s)
    # Start from the i_th-ranked entry of STOP, then follow (prev_label, prev_rank)
    # back-pointers down to position 1.
    state, rank = 'STOP', i_th - 1
    for position in range(len(s) + 1, 1, -1):
        entry = nodes[position][state][rank]
        state, rank = entry[1], entry[2]
        # Entry at position p points to the label chosen for word p-1 (index p-2).
        tags[position - 2] = state
    return [word + " " + tag for word, tag in zip(s, tags)]

# Decode the English and French test sets: keep the 7 best partial paths per label
# at each position and write the 1st-ranked (best) complete path for every sentence.
# NOTE(review): with i_th=1 this is ordinary best-path output; confirm that is what
# the *_maxmin output files are meant to contain.
viterbi_topk(EN_test, EN_maxmin, transition_params_EN, emission_params_EN,
             label_count_EN, top_k=7, i_th=1)
viterbi_topk(FR_test, FR_maxmin, transition_params_FR, emission_params_FR,
             label_count_FR, top_k=7, i_th=1)


def forward_backward(observations, states, start_prob, trans_prob, emm_prob, end_st):
    """Run the forward-backward algorithm for an HMM with an explicit end state.

    Args:
        observations: non-empty sequence (list or tuple) of observation symbols.
        states: iterable of hidden states.
        start_prob: ``start_prob[state]`` -> initial probability.
        trans_prob: ``trans_prob[from_state][to_state]`` -> probability; must also
            contain ``trans_prob[state][end_st]``.
        emm_prob: ``emm_prob[state][observation]`` -> probability.
        end_st: key of the end state inside each ``trans_prob`` row.

    Returns:
        ``(forward, backward, posterior)``: three lists with one dict per position,
        keyed by state.  ``posterior[i][st]`` is ``forward[i][st] * backward[i][st] / P``,
        where ``P`` is the total sequence probability from the forward pass.

    Raises:
        ValueError: if ``observations`` is empty.
        AssertionError: if the forward and backward totals disagree beyond
            floating-point tolerance (a sanity check).
    """
    import math

    # Accept lists as well as tuples; the backward pass appends a sentinel tuple.
    observations = tuple(observations)
    if not observations:
        raise ValueError("observations must contain at least one symbol")

    # Forward pass.
    forward = []
    f_prev = {}
    for i, observation_i in enumerate(observations):
        f_curr = {}
        for st in states:
            if i == 0:
                # Base case: start directly in st.
                prev_f_sum = start_prob[st]
            else:
                prev_f_sum = sum(f_prev[k] * trans_prob[k][st] for k in states)
            f_curr[st] = emm_prob[st][observation_i] * prev_f_sum
        forward.append(f_curr)
        f_prev = f_curr

    p_forward = sum(f_prev[k] * trans_prob[k][end_st] for k in states)

    # Backward pass, filled from the last position to the first. The None sentinel
    # marks the last position, whose value is just the transition into end_st.
    backward = []
    b_prev = {}
    for i, observation_i_plus in enumerate(reversed(observations[1:] + (None,))):
        b_curr = {}
        for st in states:
            if i == 0:
                b_curr[st] = trans_prob[st][end_st]
            else:
                b_curr[st] = sum(trans_prob[st][l] * emm_prob[l][observation_i_plus] * b_prev[l]
                                 for l in states)
        backward.append(b_curr)
        b_prev = b_curr
    backward.reverse()  # O(n) instead of repeated insert(0, ...)

    p_backward = sum(start_prob[l] * emm_prob[l][observations[0]] * b_prev[l] for l in states)

    # Merge the two passes into per-position posteriors.
    posterior = [
        {st: forward[i][st] * backward[i][st] / p_forward for st in states}
        for i in range(len(observations))
    ]

    # The two totals are sums taken in different orders, so compare with a tolerance.
    assert math.isclose(p_forward, p_backward, rel_tol=1e-9, abs_tol=1e-300)
    return forward, backward, posterior