In [None]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import re

from itertools import product
from six.moves import xrange

import numpy as np
import tensorflow as tf

import utils


def read_articles(filename):
    """Stream the articles stored in a tagged text file.

    An article opens on a line matching ``<article ...>`` (the ``id="..."``
    attribute is captured only when it is the first attribute) and closes on a
    line that is exactly ``</article>`` once surrounding whitespace is removed.
    Lines outside any article are ignored.

    Yields:
        (index, text) tuples. ``index`` is the captured id string, or None when
        absent. ``text`` is the concatenation of the article's stripped lines,
        each followed by a newline character. An article that is never closed
        (end of file, or a new opening tag appears first) is not yielded.
    """
    opening_tag = r'<article( id="(.*?)")?( \w+=.*?)*>'
    with open(filename, mode="r", encoding="utf-8") as handle:
        article_id, body_lines = None, None
        for raw in handle:
            stripped = raw.strip()
            header = re.match(opening_tag, stripped)
            if header:
                # A new opening tag always starts a fresh article.
                article_id, body_lines = header.group(2), []
            elif body_lines is None:
                continue  # not inside an article
            elif stripped == "</article>":
                yield (article_id, "".join(ln + "\n" for ln in body_lines))
                body_lines = None
            else:
                body_lines.append(stripped)


def inference(sess, data_iterator, probs_op, predicted_class_op, placeholders, batch_size, threshold):
    """Get probability and predicted class of the examples in a data set.

    Args:
        sess: session used to evaluate ``probs_op`` and ``predicted_class_op``.
        data_iterator: object exposing ``size`` and ``next_batch(n)`` returning
            ``(source, target, label)`` batches. Assumed to be positioned at its
            first example (e.g. a freshly built ``EvalIterator``) — results are
            returned in the order the iterator serves them.
        probs_op: op evaluated to obtain per-example scores.
        predicted_class_op: op evaluated to obtain per-example predicted classes.
        placeholders: the six feed placeholders in the order
            (x_source, source_seq_length, x_target, target_seq_length,
            labels, decision_threshold).
        batch_size: maximum number of examples per ``sess.run`` call.
        threshold: value fed to the ``decision_threshold`` placeholder.

    Returns:
        (probs, predicted_class): numpy arrays with one entry per example along
        the first axis; ``predicted_class`` has an integer dtype.
    """
    x_source, source_seq_length, \
        x_target, target_seq_length, \
        labels, decision_threshold = placeholders

    size = data_iterator.size
    probs = []
    predicted_class = []
    for start in xrange(0, size, batch_size):
        # Ask only for the examples that remain: a full-size final request would
        # make the iterator wrap around and run the model on duplicated rows.
        n_examples = min(batch_size, size - start)
        source, target, label = data_iterator.next_batch(n_examples)
        source_len = utils.sequence_length(source)
        target_len = utils.sequence_length(target)

        feed_dict = {x_source: source, x_target: target, labels: label,
                     source_seq_length: source_len, target_seq_length: target_len,
                     decision_threshold: threshold}
        batch_probs, batch_predicted_class = sess.run([probs_op, predicted_class_op], feed_dict=feed_dict)
        probs.extend(batch_probs.tolist())
        predicted_class.extend(batch_predicted_class.tolist())
    probs = np.array(probs[:size])
    # np.int (an alias of the builtin int) was removed in NumPy 1.24.
    predicted_class = np.array(predicted_class[:size], dtype=int)
    return probs, predicted_class


def extract_pairs(sess, alignment_model, source_sentences, target_sentences,
                  source_sentences_ids, target_sentences_ids,
                  probs_op, predicted_class_op, placeholders,
                  batch_size, threshold, greedy=False):
    """Extract sentence pairs from articles pairs.

    Every source sentence is scored against every target sentence.

    Args:
        sess, probs_op, predicted_class_op, placeholders, batch_size, threshold:
            forwarded to ``inference``.
        alignment_model: unused; kept so existing callers keep working.
        source_sentences, target_sentences: sentence strings of the two articles.
        source_sentences_ids, target_sentences_ids: token-id sequences, parallel
            to the corresponding sentence lists.
        greedy: if True, walk pairs by descending score and keep those with
            score >= threshold whose source and target sentences are both still
            unused (each sentence is used at most once). Otherwise keep every
            pair whose predicted class is 1.

    Returns:
        A list of (source sentence, target sentence, probability score) tuples.
    """
    pairs = list(product(range(len(source_sentences)), range(len(target_sentences))))

    # Fill an explicit (n_pairs, 2) object array: np.array() on nested lists
    # infers a 3-D integer array when all id lists happen to share one length
    # and can reject some ragged inputs.
    data = np.empty((len(pairs), 2), dtype=object)
    for k, (i, j) in enumerate(pairs):
        data[k, 0] = source_sentences_ids[i]
        data[k, 1] = target_sentences_ids[j]

    y_score, y_label = inference(sess, EvalIterator(data), probs_op, predicted_class_op,
                                 placeholders, batch_size, threshold)

    if not greedy:
        return [(source_sentences[pairs[k][0]], target_sentences[pairs[k][1]], y_score[k])
                for k in np.where(y_label == 1)[0]]

    sentence_pairs = []
    seen_src = set()
    seen_trg = set()
    # Sorting (score, index) tuples in reverse visits scores in descending order.
    for s, k in sorted(((s, k) for k, s in enumerate(y_score)), reverse=True):
        if s < threshold:
            # Every remaining score is no larger, so nothing else can qualify.
            break
        i, j = pairs[k]
        if i in seen_src or j in seen_trg:
            continue
        seen_src.add(i)
        seen_trg.add(j)
        sentence_pairs.append((source_sentences[i], target_sentences[j], s))
    return sentence_pairs


class EvalIterator(object):
    """Serve zero-padded batches from an array of (source_ids, target_ids) rows.

    ``data`` must be indexable as an array of at least two dimensions whose
    rows unpack into a source and a target token-id sequence. Batches are taken
    sequentially; when fewer rows remain than requested, the batch is completed
    with rows from the start of ``data`` and ``epoch_completed`` is incremented.
    """

    def __init__(self, data):
        self.data = data
        self.size = len(data)
        # Counters maintained by next_batch().
        self.global_step = 0
        self.epoch_completed = 0
        self._index_in_epoch = 0

    def _sequence_length(self, data):
        """Return an int32 array of shape (len(data), 2): (source_len, target_len) per row."""
        lengths = np.zeros((len(data), 2), dtype=np.int32)
        for row, (src_ids, trg_ids) in enumerate(data):
            lengths[row, 0] = len(src_ids)
            lengths[row, 1] = len(trg_ids)
        return lengths

    def _pad_batch(self, data):
        """Right-pad every sequence with zeros to the longest one in the batch.

        Returns:
            (source, target, labels): int32 arrays of shape (batch, max_len) for
            source and target, and ``np.ones(batch)`` as labels.
        """
        n_rows = len(data)
        lengths = self._sequence_length(data)
        max_src_len, max_trg_len = np.max(lengths, axis=0)
        src_col, trg_col = np.hsplit(data, 2)
        padded_src = np.zeros((n_rows, max_src_len), dtype=np.int32)
        padded_trg = np.zeros((n_rows, max_trg_len), dtype=np.int32)
        for row, (src_len, trg_len) in enumerate(lengths):
            padded_src[row, :src_len] = src_col[row, 0]
            padded_trg[row, :trg_len] = trg_col[row, 0]
        return padded_src, padded_trg, np.ones(n_rows)

    def next_batch(self, batch_size):
        """Return the next padded batch of ``batch_size`` rows, wrapping to the start if needed."""
        self.global_step += 1
        begin = self._index_in_epoch
        stop = begin + batch_size
        if stop <= self.size:
            self._index_in_epoch = stop
            chunk = self.data[begin:stop]
        else:
            # Take whatever is left, then top up from the beginning of the data.
            self.epoch_completed += 1
            tail = self.data[begin:self.size]
            self._index_in_epoch = batch_size - (self.size - begin)
            head = self.data[:self._index_in_epoch]
            chunk = np.concatenate((tail, head), axis=0)
        return self._pad_batch(chunk)

### Restore trained model and load vocabularies

In [None]:
# Restore the saved TensorFlow model: import the graph definition from the
# .meta file, then load the variable values from the checkpoint.
sess = tf.Session()
checkpoint_dir = "data/tflogs"
checkpoint_path = checkpoint_dir + "/model.ckpt-410156"
saver = tf.train.import_meta_graph(checkpoint_path + ".meta")
saver.restore(sess, checkpoint_path)

# Recover placeholders and ops for extraction by their tensor names.
x_source = sess.graph.get_tensor_by_name("x_source:0")
source_seq_length = sess.graph.get_tensor_by_name("source_seq_length:0")

x_target = sess.graph.get_tensor_by_name("x_target:0")
target_seq_length = sess.graph.get_tensor_by_name("target_seq_length:0")

labels = sess.graph.get_tensor_by_name("labels:0")

decision_threshold = sess.graph.get_tensor_by_name("decision_threshold:0")

# Order must match the unpacking at the top of inference().
placeholders = [x_source, source_seq_length, x_target, target_seq_length, labels, decision_threshold]

probs = sess.graph.get_tensor_by_name("prediction_evaluation/probs:0")
predicted_class = sess.graph.get_tensor_by_name("prediction_evaluation/predicted_class:0")

# Read vocabularies.
source_vocab_path = "data/vocabulary.en"
target_vocab_path = "data/vocabulary.fr"
source_vocab, rev_source_vocab = utils.initialize_vocabulary(source_vocab_path)
target_vocab, rev_target_vocab = utils.initialize_vocabulary(target_vocab_path)

### Extract sentence pairs from article pairs

In [None]:
src_filename = "data/wikipedia.en"
trg_filename = "data/wikipedia.fr"

src_articles = read_articles(src_filename)
trg_articles = read_articles(trg_filename)

src_output = "data/wikipedia_extracted.en"
trg_output = "data/wikipedia_extracted.fr"
s_output = "data/wikipedia_extracted.s"

threshold = 0.99
batch_size = 2000
greedy = False
n_articles = 919000  # expected article count; only used for the progress percentage
progress_every = 50000  # report progress on article ids that are multiples of this

with open(src_output, mode="w", encoding="utf-8") as src_output_file,\
     open(trg_output, mode="w", encoding="utf-8") as trg_output_file,\
     open(s_output, mode="w", encoding="utf-8") as s_output_file:

    # Articles are paired positionally: assumes both files list them in the same order.
    for (index, src_txt), (_, trg_txt) in zip(src_articles, trg_articles):

        # Article ids may be absent (None) or non-numeric; only report numeric ones.
        if index is not None and index.isdigit() and int(index) % progress_every == 0:
            print("{:.2f}% done.".format(100 * int(index) / n_articles))

        # read_articles terminates every line with "\n", so split() leaves a
        # trailing "" and keeps blank lines; drop empty sentences so they are
        # neither scored nor written out.
        source_sentences = [sent for sent in src_txt.split("\n") if sent]
        target_sentences = [sent for sent in trg_txt.split("\n") if sent]
        if not source_sentences or not target_sentences:
            continue

        source_sentences_ids = [utils.sentence_to_token_ids(sent, source_vocab, 100) for sent in source_sentences]
        target_sentences_ids = [utils.sentence_to_token_ids(sent, target_vocab, 100) for sent in target_sentences]

        # extract_pairs does not use its alignment_model argument, so pass None.
        pairs = extract_pairs(sess, None, source_sentences, target_sentences,
                              source_sentences_ids, target_sentences_ids,
                              probs, predicted_class, placeholders,
                              batch_size, threshold, greedy=greedy)

        for src_line, trg_line, s in pairs:
            src_output_file.write(src_line + "\n")
            trg_output_file.write(trg_line + "\n")
            s_output_file.write(str(s) + "\n")