# Investigate BAD-tag approaches

In [None]:
# Language pair of the WMT task-2 data; interpolated into the dataset paths below.
LANGUAGE_PAIR = 'de-en'

In [None]:
import codecs
import io
import json

def read_file(file_path):
    with codecs.open(file_path, 'r', 'utf-8') as fid:
        return [line.rstrip() for line in fid.readlines()]
   
def read_error_detail(file_path):
    """Parse a UTF-8 JSON-lines file: one JSON document per line, in file order.

    Only the LF character terminates a line, so entry ``i`` stays aligned
    with line ``i``. Splitting on other Unicode line boundaries (as
    ``codecs.open`` does) would cut a JSON string containing e.g. a raw
    U+2028 in two and make ``json.loads`` fail.

    A blank line raises ``ValueError`` from ``json.loads`` rather than being
    skipped: skipping it would shift every later index.
    """
    with io.open(file_path, 'r', encoding='utf-8', newline='\n') as fid:
        return [json.loads(line.strip()) for line in fid]

def red(string):
    """Return *string* wrapped in the ANSI escape codes for red foreground text."""
    ansi_red_template = "\033[31m%s\033[0m"
    return ansi_red_template % string

def display(tokens, tags=None):
    """Print *tokens* on one line, colouring in red every token tagged 'BAD'.

    tokens: sequence of token strings.
    tags: optional sequence with one tag per token (same length as *tokens*).
        When None or empty, tokens are printed uncoloured. A non-empty *tags*
        shorter than *tokens* raises IndexError.

    Uses a single-argument ``print(...)`` call, which behaves identically on
    Python 2 and Python 3 (a ``print`` statement is a SyntaxError on 3).
    """
    rendered = []
    for word_index, token in enumerate(tokens):
        if tags and tags[word_index] == 'BAD':
            rendered.append(red(token))
        else:
            rendered.append(token)
    print(" ".join(rendered))

def display_v001(tokens, tags):
    """Print *tokens* with word and gap tags, colouring BAD items in red.

    tags interleaves gap and word tags as
    [gap_0, word_0, gap_1, word_1, ..., word_{N-1}, gap_N], so it must hold
    2 * len(tokens) + 1 entries (one gap before the first token plus a
    word/gap pair per token). A BAD gap is drawn as a red '___'; OK gaps are
    not drawn. Fewer tags than that raises IndexError; extra tags are ignored.

    Uses a single-argument ``print(...)`` call so it runs on Python 2 and 3.
    """
    # Gap before the first token.
    rendered = [red('___')] if tags[0] == 'BAD' else []

    # After the leading gap, tags alternate word, gap, word, gap, ...
    word_tags = tags[1::2]
    gap_tags = tags[2::2]

    for word_index, token in enumerate(tokens):
        rendered.append(red(token) if word_tags[word_index] == 'BAD' else token)
        # Gap immediately after this token.
        if gap_tags[word_index] == 'BAD':
            rendered.append(red('___'))
    print(" ".join(rendered))

## WMT 2017 Data

In [None]:
# BAD-tag propagation variant to load; it selects a subdirectory of the
# generated tag data (see tags_v001 below):
#   'normal'            -- every BAD token is propagated to its aligned words
#   'ignore-shift-set'  -- if a BAD token also appears in the PE, do not
#                          propagate it to the source
#   'missing-only'      -- only propagate for missing words
ERROR_TYPE = "missing-only"

In [None]:
import os

wmt2017 = '/mnt/data/datasets/WMT2017/WMT2017/task2_%s_training/' % LANGUAGE_PAIR
tags_v001 = '/home/ramon/TMP/redefine_word_qe/DATA/temporal_files/%s/task2_%s_training/' % (ERROR_TYPE, LANGUAGE_PAIR)
# Data (os.path.join avoids the '//' that '%s/...' produced with the
# trailing slash on the directory strings above)
source_tokens = [x.split() for x in read_file(os.path.join(wmt2017, 'train.src'))]
mt_tokens = [x.split() for x in read_file(os.path.join(wmt2017, 'train.mt'))]
pe_tokens = [x.split() for x in read_file(os.path.join(wmt2017, 'train.pe'))]
# Tags v0.0.1
# To generate this data see redefine_word_qe repository
source_tags = [x.split() for x in read_file(os.path.join(tags_v001, 'train.source_tags'))]
target_tags = [x.split() for x in read_file(os.path.join(tags_v001, 'train.tags'))]
# Error detail
error_details = read_error_detail(os.path.join(tags_v001, 'train.json'))

# Later cells index all of these lists with the same sentence number, so a
# length mismatch would silently pair unrelated sentences; fail early instead.
_sentence_counts = {
    'train.src': len(source_tokens),
    'train.mt': len(mt_tokens),
    'train.pe': len(pe_tokens),
    'train.source_tags': len(source_tags),
    'train.tags': len(target_tags),
    'train.json': len(error_details),
}
if len(set(_sentence_counts.values())) != 1:
    raise ValueError(
        "Parallel inputs have different numbers of lines: %s"
        % sorted(_sentence_counts.items()))

In [None]:
from collections import Counter, defaultdict
# Map each error type to the set of sentence indices with at least one error
# of that type; looking up an unseen type yields an empty set (defaultdict).
indices_by_error = defaultdict(set)
for sentence_index, sentence_errors in enumerate(error_details):
    for error_type in (err['type'] for err in sentence_errors):
        indices_by_error[error_type].add(sentence_index)

### Check error types

In [None]:
# en-de
# 11387     Case where the shift rule would work well

In [None]:
# Random sample of specific error types
import numpy as np
# Pick a random sentence containing at least one error of the chosen type.
ERROR_TO_SAMPLE = u'deletion'
# sorted() fixes the candidate order, so seeding np.random reproduces a sample.
indices = sorted(indices_by_error[ERROR_TO_SAMPLE])
if not indices:
    raise ValueError("No sentences with error type %r" % ERROR_TO_SAMPLE)
# np.random.randint returns a scalar directly. int() on the size-1 array
# returned by np.random.choice(n, 1) is deprecated since NumPy 1.25.
index = indices[np.random.randint(len(indices))]
index

In [None]:
# Source sentence coloured by its source tags, then the post-edit (untagged),
# then the MT output with its interleaved gap/word tags.
display(source_tokens[index], source_tags[index])
display(pe_tokens[index])
display_v001(mt_tokens[index], target_tags[index])

In [None]:
# Raw error annotations recorded for the sampled sentence.
sampled_errors = error_details[index]
sampled_errors