In [1]:
# bilingual bpe2char model in pytorch
# based on theano code https://github.com/nyu-dl/dl4mt-c2c/blob/master/bpe2char/char_base.py
import torch
import numpy as np
import torch
import os
from collections import OrderedDict
from torch.autograd import Variable
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.utils as vutils
import torch.utils.data as data

In [4]:
# import cPickle
import warnings
import sys
import time

from collections import OrderedDict
# from mixer import *
# mixer contains the functions for model building blocks



In [5]:
import unicodedata
import string
import re
import random
import time
import math

import torch
import torch.nn as nn
from torch.autograd import Variable
from torch import optim
import torch.nn.functional as F

In [None]:
# In the paper, they mention basing their model on Bahdanau et al.
# Implementation of this original model in pytorch 
# https://github.com/spro/practical-pytorch/blob/master/seq2seq-translation/seq2seq-translation.ipynb
# USE_CUDA = False

class EncoderRNN(nn.Module):
    """Bidirectional GRU encoder that reads one token sequence at a time (batch size 1).

    Args:
        input_size: number of rows in the embedding table.
        hidden_size: embedding width, also the per-direction GRU state width.
        n_layers: number of stacked GRU layers.
    """

    def __init__(self, input_size, hidden_size, n_layers=1):
        super(EncoderRNN, self).__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.n_layers = n_layers
        # The GRU below is bidirectional, so it keeps two states per layer.
        self.num_directions = 2

        self.embedding = nn.Embedding(input_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size, n_layers, bidirectional=True)

    def forward(self, word_inputs, hidden):
        """Run the GRU over the whole input sequence in a single call.

        Args:
            word_inputs: 1-D tensor of token indices (its length is the sequence length).
            hidden: initial GRU state, e.g. the value returned by ``init_hidden``.

        Returns:
            ``(output, hidden)`` exactly as returned by ``nn.GRU``; because the GRU
            is bidirectional the last dimension of ``output`` is ``2 * hidden_size``.
        """
        seq_len = len(word_inputs)
        # Reshape to (seq_len, batch=1, hidden_size), the layout nn.GRU expects.
        embedded = self.embedding(word_inputs).view(seq_len, 1, -1)
        output, hidden = self.gru(embedded, hidden)
        return output, hidden

    def init_hidden(self):
        """Return an all-zero initial state of shape (n_layers * 2, 1, hidden_size).

        The state is moved to the GPU when the module's parameters live there, so
        no module-level CUDA flag is needed.
        """
        hidden = Variable(
            torch.zeros(self.n_layers * self.num_directions, 1, self.hidden_size))
        if next(self.parameters()).is_cuda:
            hidden = hidden.cuda()
        return hidden

# BahdanauAttnDecoderRNN
class AttnDecoderRNN(nn.Module):
    """Single-step GRU decoder with additive ("concat") attention over encoder outputs.

    Args:
        hidden_size: embedding width, GRU state width and expected feature width
            of the encoder outputs.
        output_size: size of the output vocabulary.
        n_layers: number of stacked GRU layers.
        dropout_p: dropout probability for the embedded input and the GRU.
        max_length: optional maximum sequence length; only stored on the module.

    NOTE(review): the GRU input is sized ``hidden_size * 2`` (embedding + context),
    so ``encoder_outputs`` must have feature width ``hidden_size``. The bidirectional
    ``EncoderRNN`` in this notebook emits ``2 * hidden_size`` features, so its outputs
    need to be projected or summed before being passed here -- confirm intended wiring.
    """

    def __init__(self, hidden_size, output_size, n_layers=1, dropout_p=0.1,
                 max_length=None):
        super(AttnDecoderRNN, self).__init__()

        # Define parameters
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.n_layers = n_layers
        self.dropout_p = dropout_p
        # Previously read from an undefined name; now an optional constructor argument.
        self.max_length = max_length

        # Define layers
        self.embedding = nn.Embedding(output_size, hidden_size)
        self.dropout = nn.Dropout(dropout_p)
        self.attn = Attn("concat", hidden_size)
        self.gru = nn.GRU(hidden_size * 2, hidden_size, n_layers, dropout=dropout_p)
        # The readout sees the GRU output concatenated with the context vector.
        self.out = nn.Linear(hidden_size * 2, output_size)

    def forward(self, word_input, last_hidden, encoder_outputs):
        """Run one decoder time step, attending over all encoder outputs.

        Args:
            word_input: index of the previous output token.
            last_hidden: previous GRU state; its last layer is used as attention query.
            encoder_outputs: tensor of shape (seq_len, 1, hidden_size).

        Returns:
            ``(output, hidden, attn_weights)``: log-probabilities of shape
            (1, output_size), the new GRU state, and the attention weights of
            shape (1, 1, seq_len).
        """
        # Get the embedding of the current input word (last output word)
        word_embedded = self.embedding(word_input).view(1, 1, -1)  # 1 x 1 x N
        word_embedded = self.dropout(word_embedded)

        # Calculate attention weights and apply to encoder outputs
        attn_weights = self.attn(last_hidden[-1], encoder_outputs)
        context = attn_weights.bmm(encoder_outputs.transpose(0, 1))  # 1 x 1 x N

        # Combine embedded input word and attended context, run through RNN
        rnn_input = torch.cat((word_embedded, context), 2)
        output, hidden = self.gru(rnn_input, last_hidden)

        # Final output layer: drop the singleton time/sequence axes so both
        # tensors are (1, N) before concatenating along the feature axis.
        output = output.squeeze(0)
        context = context.squeeze(1)
        output = F.log_softmax(self.out(torch.cat((output, context), 1)), dim=1)

        # Return final output, hidden state, and attention weights (for visualization)
        return output, hidden, attn_weights

class Attn(nn.Module):
    """Attention scorer producing softmax weights over encoder time steps.

    Args:
        method: one of ``'dot'``, ``'general'`` or ``'concat'``.
        hidden_size: feature width of both the query and the encoder outputs.
        max_length: accepted for signature compatibility; not used.

    Raises:
        ValueError: if ``method`` is not one of the supported names.
    """

    def __init__(self, method, hidden_size, max_length=None):
        super(Attn, self).__init__()

        if method not in ('dot', 'general', 'concat'):
            raise ValueError("unknown attention method: %r" % (method,))

        self.method = method
        self.hidden_size = hidden_size

        if self.method == 'general':
            self.attn = nn.Linear(self.hidden_size, hidden_size)

        elif self.method == 'concat':
            self.attn = nn.Linear(self.hidden_size * 2, hidden_size)
            # Initialise explicitly: a bare FloatTensor(1, hidden_size) holds
            # whatever bytes were in memory (possibly NaN/inf).
            bound = hidden_size ** -0.5
            self.other = nn.Parameter(
                torch.FloatTensor(1, hidden_size).uniform_(-bound, bound))

    def forward(self, hidden, encoder_outputs):
        """Return attention weights of shape (1, 1, seq_len).

        Args:
            hidden: query of shape (1, hidden_size).
            encoder_outputs: tensor whose first axis is time; each
                ``encoder_outputs[i]`` is scored against ``hidden``.
        """
        seq_len = len(encoder_outputs)

        # Score every encoder step; stacking keeps the result on the same device
        # as the inputs and inside the autograd graph.
        attn_energies = torch.stack(
            [self.score(hidden, encoder_outputs[i]) for i in range(seq_len)]
        ).view(-1)

        # Normalize energies to weights in range 0 to 1, resize to 1 x 1 x seq_len
        return F.softmax(attn_energies, dim=0).unsqueeze(0).unsqueeze(0)

    def score(self, hidden, encoder_output):
        """Return the scalar attention energy between ``hidden`` and one encoder output.

        Operands are flattened before ``dot`` because ``torch.dot`` only accepts
        1-D tensors while the inputs here carry a leading batch axis of size 1.
        """
        if self.method == 'dot':
            return hidden.reshape(-1).dot(encoder_output.reshape(-1))

        elif self.method == 'general':
            energy = self.attn(encoder_output)
            return hidden.reshape(-1).dot(energy.reshape(-1))

        elif self.method == 'concat':
            energy = self.attn(torch.cat((hidden, encoder_output), 1))
            return self.other.reshape(-1).dot(energy.reshape(-1))

        raise ValueError("unknown attention method: %r" % (self.method,))
    

In [None]:
# base encoder decoder models from pytorch docs 



# # embedding
# params['Wemb'] = norm_weight(options['n_words_src'], options['dim_word_src'])
# params['Wemb_dec'] = norm_weight(options['n_words'], options['dim_word'])

# # encoder
# params = get_layer('gru')[0](options, params,
#                              prefix='encoder',
#                              nin=options['dim_word_src'],
#                              dim=options['enc_dim'])
# params = get_layer('gru')[0](options, params,
#                              prefix='encoderr',
#                              nin=options['dim_word_src'],
#                              dim=options['enc_dim'])
# ctxdim = 2 * options['enc_dim']

# slower and faster layer only used in "biscale encoder model" ?
# use bidirectional RNN

#     # context
#     ctx = concatenate([proj, projr[::-1]], axis=proj.ndim-1)

#     # context mean
#     ctx_mean = (ctx * x_mask[:, :, None]).sum(0) / x_mask.sum(0)[:, None]

# def fflayer(tparams, state_below, options, prefix='rconv',
#             activ='lambda x: tensor.tanh(x)', **kwargs):
#     return eval(activ)(
#         tensor.dot(state_below, tparams[_p(prefix, 'W')]) +
#         tparams[_p(prefix, 'b')])
# def ffflayer(tparams, state_below1, state_below2, options, prefix='rconv',
#              activ='lambda x: tensor.tanh(x)', **kwargs):
#     return eval(activ)(
#         tensor.dot(state_below1, tparams[_p(prefix, 'W')]) +
#         tensor.dot(state_below2, tparams[_p(prefix, 'U')]) +
#         tparams[_p(prefix, 'b')])


class EncoderB2C(nn.Module):
    """Two-GRU encoder that processes one token per call and concatenates both states.

    Args:
        input_size: number of rows in the embedding table.
        hidden_size: embedding width and state width of each GRU.
        n_layers: stored on the module; both GRUs are single-layer.
    """

    def __init__(self, input_size, hidden_size, n_layers=1):
        super(EncoderB2C, self).__init__()
        self.n_layers = n_layers
        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(input_size, hidden_size)

        self.gru1 = nn.GRU(hidden_size, hidden_size) #, bidirectional=True)
        self.gru2 = nn.GRU(hidden_size, hidden_size)
        self.tanh1 = nn.Tanh()
        self.tanh2 = nn.Tanh()
        self.context=None

    def forward(self, in_data, hidden1, hidden2):
        """Feed one token through both GRUs.

        Args:
            in_data: tensor holding a single token index.
            hidden1: previous state of ``gru1``, shape (1, 1, hidden_size).
            hidden2: previous state of ``gru2``, shape (1, 1, hidden_size).

        Returns:
            ``(output1, output2, hidden1, hidden2, context)`` where ``context`` is
            the two new states concatenated along the last (feature) axis.

        NOTE(review): both GRUs receive the same single-step input; to mirror the
        forward/backward encoders of the reference model the second GRU would have
        to be driven over the reversed sequence by the caller -- confirm.
        """
        embedded = self.embedding(in_data).view(1, 1, -1)
        # can potentially accomplish this with bidirectional=True argument
        # would need to translate that to the context
        output1, hidden1 = self.gru1(embedded, hidden1)
        output2, hidden2 = self.gru2(embedded, hidden2)
        context = torch.cat((hidden1, hidden2), dim=-1)
        return output1, output2, hidden1, hidden2, context

    def initHidden(self):
        """Return a zero state of shape (1, 1, hidden_size) on the module's device."""
        result = Variable(torch.zeros(1, 1, self.hidden_size))
        if next(self.parameters()).is_cuda:
            return result.cuda()
        return result


        
        
# params['Wemb_dec'] = norm_weight(options['n_words'], options['dim_word'])
        
# # init_state of decoder
# params = get_layer('ff')[0](options, params,
#                             prefix='ff_init_state_char',
#                             nin=ctxdim,
#                             nout=options['dec_dim'])
# params = get_layer('ff')[0](options, params,
#                             prefix='ff_init_state_word',
#                             nin=ctxdim,
#                             nout=options['dec_dim'])

# print "target dictionary size: %d" % options['n_words'] 

# # decoder
# params = get_layer('two_layer_gru_decoder')[0](options, params,
#                                                prefix='decoder',
#                                                nin=options['dim_word'],
#                                                dim_char=options['dec_dim'],
#                                                dim_word=options['dec_dim'],
#                                                dimctx=ctxdim)
  
#           'two_layer_gru_decoder': ('param_init_two_layer_gru_decoder',
#                                     'two_layer_gru_decoder'),
 
class DecoderB2C(nn.Module):
    """Plain GRU decoder producing log-probabilities for one output token per call.

    Args:
        hidden_size: embedding width and GRU state width.
        output_size: size of the output vocabulary.
        n_layers: how many times the same GRU is applied in ``forward``.
    """

    def __init__(self, hidden_size, output_size, n_layers=1):
        super(DecoderB2C, self).__init__()
        self.n_layers = n_layers
        self.hidden_size = hidden_size

        self.embedding = nn.Embedding(output_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size)
        self.out = nn.Linear(hidden_size, output_size)
        # Normalise over the vocabulary axis of the (1, output_size) logits;
        # leaving `dim` implicit is deprecated.
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, in_data, hidden):
        """Decode one step.

        Args:
            in_data: tensor holding a single token index.
            hidden: previous GRU state, shape (1, 1, hidden_size).

        Returns:
            ``(output, hidden)``: log-probabilities of shape (1, output_size)
            and the updated GRU state.
        """
        output = self.embedding(in_data).view(1, 1, -1)
        # The same GRU (shared weights) is applied n_layers times.
        for i in range(self.n_layers):
            output = F.relu(output)
            output, hidden = self.gru(output, hidden)
        output = self.softmax(self.out(output[0]))
        return output, hidden

    def initHidden(self):
        """Return a zero state of shape (1, 1, hidden_size) on the module's device."""
        result = Variable(torch.zeros(1, 1, self.hidden_size))
        if next(self.parameters()).is_cuda:
            return result.cuda()
        return result

## comment
# # readout
# params = get_layer('fff')[0](options, params, prefix='ff_logit_rnn',
#                              nin1=options['dec_dim'], nin2=options['dec_dim'],
#                              nout=options['dim_word'], ortho=False)
# params = get_layer('ff')[0](options, params, prefix='ff_logit_prev',
#                             nin=options['dim_word'],
#                             nout=options['dim_word'],
#                             ortho=False)
# params = get_layer('ff')[0](options, params, prefix='ff_logit_ctx',
#                             nin=ctxdim,
#                             nout=options['dim_word'],
#                             ortho=False)
# params = get_layer('ff')[0](options, params, prefix='ff_logit',
#                             nin=options['dim_word'],
#                             nout=options['n_words'])

class SoftAlignment(nn.Module):
    """Placeholder for the soft-alignment module; nothing is implemented yet."""

    def __init__(self):
        super(SoftAlignment, self).__init__()
        # Not populated anywhere in this class yet.
        self.C = None

    def forward(self):
        """Stub: takes no inputs and yields ``None``."""
        return

    def init_weights(self):
        """Stub: no weights to initialise; yields ``None``."""
        return
    
# loss function: negative log likelihood
# Bound to a name so training code can use it; the bare `nn.NLLLoss()`
# expression built the loss object and immediately discarded it.
criterion = nn.NLLLoss()
# optimizer: SGD (TODO: not created yet)


In [None]:
def train_bi_b2c_model():
    """Placeholder for the bilingual bpe2char training routine; currently a no-op.

    TODO: construct the model and the optimizer here.
    """
    #     model =
    #     optimizer =
    return None



### Functions defined below (in original code)
- def init_params(options):
 - initializes layers of model for encoder and decoder using "get_layer" function from mixer.py

- def build_model(tparams, options):
 - declares the symbolic inputs `x`, `x_mask`, `y`, `y_mask`, each shaped #words x #samples
 - generates embeddings
 - passes data through GRU
 
 
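- def build_sampler(tparams, options, trng, use_noise):
 - builds the two compiled functions used at decoding time: `f_init` (encodes the source and returns the initial decoder states plus the context) and `f_next` (runs one decoder step and returns the next-token probabilities, a sampled token and the updated decoder states)
- def gen_sample(tparams, f_init, f_next, x, options, trng=None,k=1, maxlen=500, stochastic=True, argmax=False):
 - generates an output sequence for input `x` with `f_init`/`f_next`: token-by-token sampling or greedy argmax when `stochastic=True`, beam search with beam size `k` when `stochastic=False`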
 

In [None]:
##### ORIGINAL CODE #####
'''
Build a simple neural language model using GRU units
'''
# import theano
# from theano import tensor
# from theano.sandbox.rng_mrg import MRG_RandomStreams as RandomStreams

import cPickle
import numpy
import copy

import os
import warnings
import sys
import time

from collections import OrderedDict
from mixer import *


def init_params(options):
    """Create the parameter dictionary for the bpe2char encoder-decoder.

    Reads ``n_words_src``, ``dim_word_src``, ``n_words``, ``dim_word``,
    ``enc_dim`` and ``dec_dim`` from ``options`` and fills an ``OrderedDict``
    by calling the initialiser half (index 0) of each ``get_layer`` entry.

    ``norm_weight`` and ``get_layer`` are not defined in this notebook;
    presumably they come from ``from mixer import *`` -- TODO confirm.

    Args:
        options: dict of model hyper-parameters (keys listed above).

    Returns:
        OrderedDict mapping parameter names to their initial values.
    """
    params = OrderedDict()

    # print(...) with a single pre-formatted string behaves the same on
    # Python 2 and Python 3; the bare print statement is a SyntaxError on 3.
    print("source dictionary size: %d" % options['n_words_src'])
    # embedding
    params['Wemb'] = norm_weight(options['n_words_src'], options['dim_word_src'])
    params['Wemb_dec'] = norm_weight(options['n_words'], options['dim_word'])

    # encoder: one GRU per direction ('encoder' forward, 'encoderr' backward)
    params = get_layer('gru')[0](options, params,
                                 prefix='encoder',
                                 nin=options['dim_word_src'],
                                 dim=options['enc_dim'])
    params = get_layer('gru')[0](options, params,
                                 prefix='encoderr',
                                 nin=options['dim_word_src'],
                                 dim=options['enc_dim'])
    # both directions are concatenated, hence twice the encoder width
    ctxdim = 2 * options['enc_dim']

    # init_state of decoder
    params = get_layer('ff')[0](options, params,
                                prefix='ff_init_state_char',
                                nin=ctxdim,
                                nout=options['dec_dim'])
    params = get_layer('ff')[0](options, params,
                                prefix='ff_init_state_word',
                                nin=ctxdim,
                                nout=options['dec_dim'])

    print("target dictionary size: %d" % options['n_words'])

    # decoder
    params = get_layer('two_layer_gru_decoder')[0](options, params,
                                                   prefix='decoder',
                                                   nin=options['dim_word'],
                                                   dim_char=options['dec_dim'],
                                                   dim_word=options['dec_dim'],
                                                   dimctx=ctxdim)

    # readout
    params = get_layer('fff')[0](options, params, prefix='ff_logit_rnn',
                                 nin1=options['dec_dim'], nin2=options['dec_dim'],
                                 nout=options['dim_word'], ortho=False)
    params = get_layer('ff')[0](options, params, prefix='ff_logit_prev',
                                nin=options['dim_word'],
                                nout=options['dim_word'],
                                ortho=False)
    params = get_layer('ff')[0](options, params, prefix='ff_logit_ctx',
                                nin=ctxdim,
                                nout=options['dim_word'],
                                ortho=False)
    params = get_layer('ff')[0](options, params, prefix='ff_logit',
                                nin=options['dim_word'],
                                nout=options['n_words'])

    return params


def build_model(tparams, options):
    """Build the symbolic (Theano) training graph and its per-sample cost.

    ``theano``, ``tensor``, ``RandomStreams``, ``concatenate``, ``get_layer``
    and ``dropout_layer`` are not defined in this notebook; presumably they
    come from ``from mixer import *`` -- TODO confirm.

    Args:
        tparams: mapping of parameter names to shared variables
            (``'Wemb'`` and ``'Wemb_dec'`` are indexed directly here).
        options: dict of hyper-parameters; reads ``dim_word_src``, ``dim_word``,
            ``use_dropout`` and ``n_words``.

    Returns:
        ``(trng, use_noise, x, x_mask, y, y_mask, opt_ret, cost)`` where
        ``opt_ret['dec_alphas']`` holds the decoder attention weights and
        ``cost`` is the masked negative log-likelihood summed over time steps.
    """
    opt_ret = OrderedDict()

    trng = RandomStreams(numpy.random.RandomState(numpy.random.randint(1024)).randint(numpy.iinfo(numpy.int32).max))
    use_noise = theano.shared(numpy.float32(0.))

    # description string: #words x #samples
    x = tensor.matrix('x', dtype='int64')
    x_mask = tensor.matrix('x_mask', dtype='float32')
    y = tensor.matrix('y', dtype='int64')
    y_mask = tensor.matrix('y_mask', dtype='float32')
    x.tag.test_value = numpy.zeros((5, 63), dtype='int64')
    x_mask.tag.test_value = numpy.ones((5, 63), dtype='float32')
    y.tag.test_value = numpy.zeros((7, 63), dtype='int64')
    y_mask.tag.test_value = numpy.ones((7, 63), dtype='float32')

    # time-reversed copies feed the backward encoder
    xr = x[::-1]
    xr_mask = x_mask[::-1]

    n_samples = x.shape[1]
    n_timesteps = x.shape[0]
    n_timesteps_trg = y.shape[0]

    # word embedding for forward RNN (source)
    emb = tparams['Wemb'][x.flatten()]
    emb = emb.reshape([n_timesteps, n_samples, options['dim_word_src']])

    # word embedding for backward RNN (source)
    embr = tparams['Wemb'][xr.flatten()]
    embr = embr.reshape([n_timesteps, n_samples, options['dim_word_src']])

    # pass through gru layer, recurrence here
    proj = get_layer('gru')[1](tparams, emb, options,
                               prefix='encoder', mask=x_mask)
    projr = get_layer('gru')[1](tparams, embr, options,
                                prefix='encoderr', mask=xr_mask)

    # context: backward states are flipped back into source order before joining
    ctx = concatenate([proj, projr[::-1]], axis=proj.ndim-1)

    # context mean (masked average over time)
    ctx_mean = (ctx * x_mask[:, :, None]).sum(0) / x_mask.sum(0)[:, None]

    # initial decoder state
    init_state_char = get_layer('ff')[1](tparams, ctx_mean, options,
                                         prefix='ff_init_state_char', activ='tanh')
    init_state_word = get_layer('ff')[1](tparams, ctx_mean, options,
                                         prefix='ff_init_state_word', activ='tanh')

    # word embedding and shifting for targets: step t sees the embedding of
    # target t-1, and the first step sees zeros
    yemb = tparams['Wemb_dec'][y.flatten()]
    yemb = yemb.reshape([n_timesteps_trg, n_samples, options['dim_word']])
    yemb_shifted = tensor.zeros_like(yemb)
    yemb_shifted = tensor.set_subtensor(yemb_shifted[1:], yemb[:-1])
    yemb = yemb_shifted

    char_h, word_h, ctxs, alphas = \
            get_layer('two_layer_gru_decoder')[1](tparams, yemb, options,
                                                  prefix='decoder',
                                                  mask=y_mask,
                                                  context=ctx,
                                                  context_mask=x_mask,
                                                  one_step=False,
                                                  init_state_char=init_state_char,
                                                  init_state_word=init_state_word)

    opt_ret['dec_alphas'] = alphas

    # compute word probabilities
    logit_rnn = get_layer('fff')[1](tparams, char_h, word_h, options,
                                    prefix='ff_logit_rnn', activ='linear')
    logit_prev = get_layer('ff')[1](tparams, yemb, options,
                                    prefix='ff_logit_prev', activ='linear')
    logit_ctx = get_layer('ff')[1](tparams, ctxs, options,
                                   prefix='ff_logit_ctx', activ='linear')
    logit = tensor.tanh(logit_rnn + logit_prev + logit_ctx)

    if options['use_dropout']:
        # print(...) with one string argument is valid on Python 2 and 3.
        print('Using dropout')
        logit = dropout_layer(logit, use_noise, trng)

    logit = get_layer('ff')[1](tparams, logit, options,
                               prefix='ff_logit', activ='linear')
    logit_shp = logit.shape
    probs = tensor.nnet.softmax(logit.reshape([logit_shp[0]*logit_shp[1], logit_shp[2]]))

    # cost: pick the probability of each gold token out of the flattened
    # (positions x vocabulary) matrix, then mask and sum over time
    y_flat = y.flatten()
    y_flat_idx = tensor.arange(y_flat.shape[0]) * options['n_words'] + y_flat
    cost = -tensor.log(probs.flatten()[y_flat_idx])
    cost = cost.reshape([y.shape[0], y.shape[1]])
    cost = (cost * y_mask).sum(0)

    return trng, use_noise, x, x_mask, y, y_mask, opt_ret, cost



In [None]:

def build_sampler(tparams, options, trng, use_noise):
    """Compile the two Theano functions used for decoding.

    ``theano``, ``tensor``, ``concatenate``, ``get_layer``, ``dropout_layer``
    and ``profile`` are not defined in this notebook; presumably they come
    from ``from mixer import *`` -- TODO confirm.

    Args:
        tparams: mapping of parameter names to shared variables.
        options: dict of hyper-parameters; reads ``dim_word_src`` and ``use_dropout``.
        trng: random stream used for dropout and for sampling the next token.
        use_noise: shared flag forwarded to ``dropout_layer``.

    Returns:
        ``(f_init, f_next)``. ``f_init(x)`` returns
        ``[init_state_char, init_state_word, ctx]``; ``f_next(y, ctx,
        init_state_char, init_state_word)`` returns ``[next_probs, next_sample,
        next_state_char, next_state_word]``.
    """
    x = tensor.matrix('x', dtype='int64')
    xr = x[::-1]

    n_timesteps = x.shape[0]
    n_samples = x.shape[1]

    emb = tparams['Wemb'][x.flatten()]
    emb = emb.reshape([n_timesteps, n_samples, options['dim_word_src']])
    embr = tparams['Wemb'][xr.flatten()]
    embr = embr.reshape([n_timesteps, n_samples, options['dim_word_src']])

    proj = get_layer('gru')[1](tparams, emb, options, prefix='encoder')
    projr = get_layer('gru')[1](tparams, embr, options, prefix='encoderr')

    ctx = concatenate([proj, projr[::-1]], axis=proj.ndim-1)
    # no mask is used here, so this is a plain mean over time
    ctx_mean = ctx.mean(0)

    init_state_char = get_layer('ff')[1](tparams, ctx_mean, options,
                                         prefix='ff_init_state_char', activ='tanh')
    init_state_word = get_layer('ff')[1](tparams, ctx_mean, options,
                                         prefix='ff_init_state_word', activ='tanh')

    # sys.stdout.write reproduces Python 2's `print '...',` (text followed by a
    # space, no newline) and also runs on Python 3.
    sys.stdout.write('Building f_init... ')
    outs = [init_state_char, init_state_word, ctx]
    f_init = theano.function([x], outs, name='f_init', profile=profile)
    print('Done')

    y = tensor.vector('y_sampler', dtype='int64')
    init_state_char = tensor.matrix('init_state_char', dtype='float32')
    init_state_word = tensor.matrix('init_state_word', dtype='float32')

    # if it's the first word, emb should be all zero and it is indicated by -1
    yemb = tensor.switch(y[:, None] < 0,
                         tensor.alloc(0., 1, tparams['Wemb_dec'].shape[1]),
                         tparams['Wemb_dec'][y])

    next_state_char, next_state_word, next_ctx, next_alpha = \
            get_layer('two_layer_gru_decoder')[1](tparams, yemb, options,
                                                  prefix='decoder',
                                                  context=ctx,
                                                  mask=None,
                                                  one_step=True,
                                                  init_state_char=init_state_char,
                                                  init_state_word=init_state_word)

    logit_rnn = get_layer('fff')[1](tparams,
                                    next_state_char,
                                    next_state_word,
                                    options,
                                    prefix='ff_logit_rnn',
                                    activ='linear')
    logit_prev = get_layer('ff')[1](tparams,
                                    yemb,
                                    options,
                                    prefix='ff_logit_prev',
                                    activ='linear')
    logit_ctx = get_layer('ff')[1](tparams,
                                   next_ctx,
                                   options,
                                   prefix='ff_logit_ctx',
                                   activ='linear')
    logit = tensor.tanh(logit_rnn + logit_prev + logit_ctx)

    if options['use_dropout']:
        print('Sampling for dropoutted model')
        logit = dropout_layer(logit, use_noise, trng)

    logit = get_layer('ff')[1](tparams, logit, options,
                               prefix='ff_logit',
                               activ='linear')
    next_probs = tensor.nnet.softmax(logit)
    next_sample = trng.multinomial(pvals=next_probs).argmax(1)

    # next word probability
    sys.stdout.write('Building f_next... ')
    inps = [y, ctx, init_state_char, init_state_word]
    outs = [next_probs, next_sample, next_state_char, next_state_word]
    f_next = theano.function(inps, outs, name='f_next', profile=profile)
    print('Done')

    return f_init, f_next


def gen_sample(tparams, f_init, f_next, x, options, trng=None,
               k=1, maxlen=500, stochastic=True, argmax=False):
    """Generate output token sequence(s) for ``x`` using ``f_init`` / ``f_next``.

    Token index 0 is treated as end-of-sequence and -1 as the
    beginning-of-sequence indicator passed to the first ``f_next`` call.

    Args:
        tparams: unused here; kept for interface compatibility.
        f_init: callable ``f_init(x)`` returning
            ``[state_char, state_word, context]``.
        f_next: callable ``f_next(next_w, ctx, state_char, state_word)`` returning
            ``[next_probs, next_sample, state_char, state_word]``.
        x: source input forwarded to ``f_init``.
        options: unused here; kept for interface compatibility.
        trng: unused here; kept for interface compatibility.
        k: beam size; values above 1 require ``stochastic=False``.
        maxlen: maximum number of decoding steps.
        stochastic: if True decode a single sequence token by token, otherwise
            run beam search.
        argmax: in stochastic mode, take the most probable token instead of the
            token sampled by ``f_next``.

    Returns:
        ``(sample, sample_score)``. In stochastic mode ``sample`` is a list of
        token indices and ``sample_score`` a single number; in beam-search mode
        ``sample`` is a list of hypotheses and ``sample_score`` the list of their
        accumulated negative log-probabilities.

    Raises:
        AssertionError: if ``k > 1`` together with ``stochastic=True``.
    """
    # k is the beam size we have
    if k > 1:
        assert not stochastic, \
            'Beam search does not support stochastic sampling'

    sample = []
    sample_score = []
    if stochastic:
        sample_score = 0

    live_k = 1
    dead_k = 0

    hyp_samples = [[]] * live_k
    hyp_scores = numpy.zeros(live_k).astype('float32')

    # get initial state of decoder rnn and encoder context
    ret = f_init(x)
    next_state_char, next_state_word, ctx0 = ret[0], ret[1], ret[2]
    next_w = -1 * numpy.ones((1,)).astype('int64')  # bos indicator

    # `range` works on both Python 2 and 3 (`xrange` is Python 2 only).
    for _ in range(maxlen):
        # replicate the source context once per live hypothesis
        ctx = numpy.tile(ctx0, [live_k, 1])
        inps = [next_w, ctx, next_state_char, next_state_word]
        ret = f_next(*inps)
        next_p, next_w, next_state_char, next_state_word = ret[0], ret[1], ret[2], ret[3]
        if stochastic:
            if argmax:
                nw = next_p[0].argmax()
            else:
                nw = next_w[0]
            sample.append(nw)
            # NOTE(review): this adds raw probabilities, whereas the beam-search
            # branch accumulates negative log-probabilities; kept as in the original.
            sample_score += next_p[0, nw]
            if nw == 0:
                break
        else:
            cand_scores = hyp_scores[:, None] - numpy.log(next_p)
            cand_flat = cand_scores.flatten()
            ranks_flat = cand_flat.argsort()[:(k-dead_k)]

            voc_size = next_p.shape[1]
            # Floor division keeps the indices integral on Python 3 too; true
            # division would yield floats that cannot index `hyp_samples`.
            trans_indices = ranks_flat // voc_size
            word_indices = ranks_flat % voc_size
            costs = cand_flat[ranks_flat]

            new_hyp_samples = []
            new_hyp_scores = numpy.zeros(k-dead_k).astype('float32')
            new_hyp_states_char = []
            new_hyp_states_word = []

            for idx, [ti, wi] in enumerate(zip(trans_indices, word_indices)):
                new_hyp_samples.append(hyp_samples[ti]+[wi])
                new_hyp_scores[idx] = copy.copy(costs[idx])
                new_hyp_states_char.append(copy.copy(next_state_char[ti]))
                new_hyp_states_word.append(copy.copy(next_state_word[ti]))

            # check the finished samples
            new_live_k = 0
            hyp_samples = []
            hyp_scores = []
            hyp_states_char = []
            hyp_states_word = []

            for idx in range(len(new_hyp_samples)):
                if new_hyp_samples[idx][-1] == 0:
                    sample.append(new_hyp_samples[idx])
                    sample_score.append(new_hyp_scores[idx])
                    dead_k += 1
                else:
                    new_live_k += 1
                    hyp_samples.append(new_hyp_samples[idx])
                    hyp_scores.append(new_hyp_scores[idx])
                    hyp_states_char.append(new_hyp_states_char[idx])
                    hyp_states_word.append(new_hyp_states_word[idx])
            hyp_scores = numpy.array(hyp_scores)
            live_k = new_live_k

            if new_live_k < 1:
                break
            if dead_k >= k:
                break

            next_w = numpy.array([w[-1] for w in hyp_samples])
            next_state_char = numpy.array(hyp_states_char)
            next_state_word = numpy.array(hyp_states_word)

    if not stochastic:
        # dump every remaining one
        if live_k > 0:
            for idx in range(live_k):
                sample.append(hyp_samples[idx])
                sample_score.append(hyp_scores[idx])

    return sample, sample_score

In [None]:
# train function from nmt file
# https://github.com/nyu-dl/dl4mt-c2c/blob/master/bpe2char/nmt.py


def train(
      dim_word=100,
      dim_word_src=200,
      enc_dim=1000,
      dec_dim=1000,  # the number of LSTM units
      patience=-1,  # early stopping patience
      max_epochs=5000,
      finish_after=-1,  # finish after this many updates
      decay_c=0.,  # L2 regularization penalty
      alpha_c=0.,  # alignment regularization
      clip_c=-1.,  # gradient clipping threshold
      lrate=0.01,  # learning rate
      n_words_src=100000,  # source vocabulary size
      n_words=100000,  # target vocabulary size
      maxlen=100,  # maximum length of the description
      maxlen_trg=None,  # maximum length of the description
      maxlen_sample=1000,
      optimizer='rmsprop',
      batch_size=16,
      valid_batch_size=16,
      sort_size=20,
      save_path=None,
      save_file_name='model',
      save_best_models=0,
      dispFreq=100,
      validFreq=100,
      saveFreq=1000,   # save the parameters after every saveFreq updates
      sampleFreq=-1,
      verboseFreq=10000,
      datasets=[
          'data/lisatmp3/chokyun/europarl/europarl-v7.fr-en.en.tok',
          '/data/lisatmp3/chokyun/europarl/europarl-v7.fr-en.fr.tok'],
      valid_datasets=['../data/dev/newstest2011.en.tok',
                      '../data/dev/newstest2011.fr.tok'],
      dictionaries=[
          '/data/lisatmp3/chokyun/europarl/europarl-v7.fr-en.en.tok.pkl',
          '/data/lisatmp3/chokyun/europarl/europarl-v7.fr-en.fr.tok.pkl'],
      source_word_level=0,
      target_word_level=0,
      use_dropout=False,
      re_load=False,
      re_load_old_setting=False,
      uidx=None,
      eidx=None,
      cidx=None,
      layers=None,
      save_every_saveFreq=0,
      save_burn_in=20000,
      use_bpe=0,
      gru='gru',
      init_params=None,
      build_model=None,
      build_sampler=None,
      gen_sample=None,
      **kwargs
    ):

    if gru not in "gru lngru".split():
        raise

    print "GRU:", gru

    if maxlen_trg is None:
        maxlen_trg = maxlen * 10
    # Model options
    model_options = locals().copy()
    del model_options['init_params']
    del model_options['build_model']
    del model_options['build_sampler']
    del model_options['gen_sample']

    # load dictionaries and invert them
    worddicts = [None] * len(dictionaries)
    worddicts_r = [None] * len(dictionaries)
    for ii, dd in enumerate(dictionaries):
        with open(dd, 'rb') as f:
            worddicts[ii] = cPickle.load(f)
        worddicts_r[ii] = dict()
        for kk, vv in worddicts[ii].iteritems():
            worddicts_r[ii][vv] = kk

    print 'Building model'
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    file_name = '%s%s.npz' % (save_path, save_file_name)
    best_file_name = '%s%s.best.npz' % (save_path, save_file_name)
    opt_file_name = '%s%s%s.npz' % (save_path, save_file_name, '.grads')
    best_opt_file_name = '%s%s%s.best.npz' % (save_path, save_file_name, '.grads')
    model_name = '%s%s.pkl' % (save_path, save_file_name)
    params = init_params(model_options)
    cPickle.dump(model_options, open(model_name, 'wb'))
    history_errs = []

    # reload options
    if re_load and os.path.exists(file_name):
        print 'You are reloading your experiment.. do not panic dude..'
        if re_load_old_setting:
            with open(model_name, 'rb') as f:
                model_options = cPickle.load(f)
        params = load_params(file_name, params)
        # reload history
        model = numpy.load(file_name)
        history_errs = list(model['history_errs'])
        if uidx is None:
            uidx = model['uidx']
        if eidx is None:
            eidx = model['eidx']
        if cidx is None:
            cidx = model['cidx']
    else:
        if uidx is None:
            uidx = 0
        if eidx is None:
            eidx = 0
        if cidx is None:
            cidx = 0

    print 'Loading data'
    train = TextIterator(source=datasets[0],
                         target=datasets[1],
                         source_dict=dictionaries[0],
                         target_dict=dictionaries[1],
                         n_words_source=n_words_src,
                         n_words_target=n_words,
                         source_word_level=source_word_level,
                         target_word_level=target_word_level,
                         batch_size=batch_size,
                         sort_size=sort_size)
    valid = TextIterator(source=valid_datasets[0],
                         target=valid_datasets[1],
                         source_dict=dictionaries[0],
                         target_dict=dictionaries[1],
                         n_words_source=n_words_src,
                         n_words_target=n_words,
                         source_word_level=source_word_level,
                         target_word_level=target_word_level,
                         batch_size=valid_batch_size,
                         sort_size=sort_size)

    # create shared variables for parameters
    tparams = init_tparams(params)

    trng, use_noise, \
        x, x_mask, y, y_mask, \
        opt_ret, \
        cost = \
        build_model(tparams, model_options)
    inps = [x, x_mask, y, y_mask]

    print 'Building sampler...\n',
    f_init, f_next = build_sampler(tparams, model_options, trng, use_noise)
    #print 'Done'

    # before any regularizer
    print 'Building f_log_probs...',
    f_log_probs = theano.function(inps, cost, profile=profile)
    print 'Done'
    if re_load:
        use_noise.set_value(0.)
        valid_errs = pred_probs(f_log_probs, prepare_data,
                                model_options, valid, verboseFreq=verboseFreq)
        valid_err = valid_errs.mean()

        if numpy.isnan(valid_err):
            import ipdb
            ipdb.set_trace()

        print 'Reload sanity check: Valid ', valid_err

    cost = cost.mean()

    # apply L2 regularization on weights
    if decay_c > 0.:
        decay_c = theano.shared(numpy.float32(decay_c), name='decay_c')
        weight_decay = 0.
        for kk, vv in tparams.iteritems():
            weight_decay += (vv ** 2).sum()
        weight_decay *= decay_c
        cost += weight_decay

    # regularize the alpha weights
    if alpha_c > 0. and not model_options['decoder'].endswith('simple'):
        alpha_c = theano.shared(numpy.float32(alpha_c), name='alpha_c')
        alpha_reg = alpha_c * (
            (tensor.cast(y_mask.sum(0) // x_mask.sum(0), 'float32')[:, None] -
             opt_ret['dec_alphas'].sum(0))**2).sum(1).mean()
        cost += alpha_reg

    # after all regularizers - compile the computational graph for cost
    print 'Building f_cost...',
    f_cost = theano.function(inps, cost, profile=profile)
    print 'Done'

    print 'Computing gradient...',
    grads = tensor.grad(cost, wrt=itemlist(tparams))
    print 'Done'

    if clip_c > 0:
        grads, not_finite, clipped = gradient_clipping(grads, tparams, clip_c)
    else:
        not_finite = 0
        clipped = 0

    # compile the optimizer, the actual computational graph is compiled here
    lr = tensor.scalar(name='lr')
    print 'Building optimizers...',
    if re_load and os.path.exists(file_name):
        if clip_c > 0:
            f_grad_shared, f_update, toptparams = eval(optimizer)(lr, tparams, grads, inps, cost=cost,
                                                                  not_finite=not_finite, clipped=clipped,
                                                                  file_name=opt_file_name)
        else:
            f_grad_shared, f_update, toptparams = eval(optimizer)(lr, tparams, grads, inps, cost=cost,
                                                                  file_name=opt_file_name)
    else:
        if clip_c > 0:
            f_grad_shared, f_update, toptparams = eval(optimizer)(lr, tparams, grads, inps, cost=cost,
                                                                  not_finite=not_finite, clipped=clipped)
        else:
            f_grad_shared, f_update, toptparams = eval(optimizer)(lr, tparams, grads, inps, cost=cost)
    print 'Done'

    print 'Optimization'
    best_p = None
    bad_counter = 0

    if validFreq == -1:
        validFreq = len(train[0]) / batch_size
    if saveFreq == -1:
        saveFreq = len(train[0]) / batch_size

    # Training loop
    ud_start = time.time()
    estop = False

    if re_load:
        print "Checkpointed minibatch number: %d" % cidx
        for cc in xrange(cidx):
            if numpy.mod(cc, 1000)==0:
                print "Jumping [%d / %d] examples" % (cc, cidx)
            train.next()

    for epoch in xrange(max_epochs):
        time0 = time.time()
        n_samples = 0
        NaN_grad_cnt = 0
        NaN_cost_cnt = 0
        clipped_cnt = 0
        if re_load:
            re_load = 0
        else:
            cidx = 0

        for x, y in train:
            cidx += 1
            uidx += 1
            use_noise.set_value(1.)

            x, x_mask, y, y_mask, n_x = prepare_data(x, y, maxlen=maxlen,
                                                     maxlen_trg=maxlen_trg,
                                                     n_words_src=n_words_src,
                                                     n_words=n_words)
            if x is None:
                print 'Minibatch with zero sample under length ', maxlen
                uidx -= 1
                uidx = max(uidx, 0)
                continue

            n_samples += n_x

            # compute cost, grads and copy grads to shared variables
            if clip_c > 0:
                cost, not_finite, clipped = f_grad_shared(x, x_mask, y, y_mask)
            else:
                cost = f_grad_shared(x, x_mask, y, y_mask)

            if clipped:
                clipped_cnt += 1

            # check for bad numbers, usually we remove non-finite elements
            # and continue training - but not done here
            if numpy.isnan(cost) or numpy.isinf(cost):
                NaN_cost_cnt += 1

            if not_finite:
                NaN_grad_cnt += 1
                continue

            # do the update on parameters
            f_update(lrate)

            if numpy.isnan(cost) or numpy.isinf(cost):
                continue

            if float(NaN_grad_cnt) > max_epochs * 0.5 or float(NaN_cost_cnt) > max_epochs * 0.5:
                print 'Too many NaNs, abort training'
                return 1., 1., 1.

            # verbose
            if numpy.mod(uidx, dispFreq) == 0:
                ud = time.time() - ud_start
                wps = n_samples / float(time.time() - time0)
                print 'Epoch ', eidx, 'Update ', uidx, 'Cost ', cost, 'NaN_in_grad', NaN_grad_cnt,\
                      'NaN_in_cost', NaN_cost_cnt, 'Gradient_clipped', clipped_cnt, 'UD ', ud, "%.2f sentence/s" % wps
                ud_start = time.time()

            # generate some samples with the model and display them
            if numpy.mod(uidx, sampleFreq) == 0 and sampleFreq != -1:
                # FIXME: random selection?
                for jj in xrange(numpy.minimum(5, x.shape[1])):
                    stochastic = True
                    use_noise.set_value(0.)
                    sample, score = gen_sample(tparams, f_init, f_next,
                                               x[:, jj][:, None],
                                               model_options, trng=trng, k=1,
                                               maxlen=maxlen_sample,
                                               stochastic=stochastic,
                                               argmax=False)
                    print
                    print 'Source ', jj, ': ',
                    if source_word_level:
                        for vv in x[:, jj]:
                            if vv == 0:
                                break
                            if vv in worddicts_r[0]:
                                if use_bpe:
                                    print (worddicts_r[0][vv]).replace('@@', ''),
                                else:
                                    print worddicts_r[0][vv],
                            else:
                                print 'UNK',
                        print
                    else:
                        source_ = []
                        for vv in x[:, jj]:
                            if vv == 0:
                                break
                            if vv in worddicts_r[0]:
                                source_.append(worddicts_r[0][vv])
                            else:
                                source_.append('UNK')
                        print "".join(source_)
                    print 'Truth ', jj, ' : ',
                    if target_word_level:
                        for vv in y[:, jj]:
                            if vv == 0:
                                break
                            if vv in worddicts_r[1]:
                                if use_bpe:
                                    print (worddicts_r[1][vv]).replace('@@', ''),
                                else:
                                    print worddicts_r[1][vv],
                            else:
                                print 'UNK',
                        print
                    else:
                        truth_ = []
                        for vv in y[:, jj]:
                            if vv == 0:
                                break
                            if vv in worddicts_r[1]:
                                truth_.append(worddicts_r[1][vv])
                            else:
                                truth_.append('UNK')
                        print "".join(truth_)
                    print 'Sample ', jj, ': ',
                    if stochastic:
                        ss = sample
                    else:
                        score = score / numpy.array([len(s) for s in sample])
                        ss = sample[score.argmin()]
                    if target_word_level:
                        for vv in ss:
                            if vv == 0:
                                break
                            if vv in worddicts_r[1]:
                                if use_bpe:
                                    print (worddicts_r[1][vv]).replace('@@', ''),
                                else:
                                    print worddicts_r[1][vv],
                            else:
                                print 'UNK',
                        print
                    else:
                        sample_ = []
                        for vv in ss:
                            if vv == 0:
                                break
                            if vv in worddicts_r[1]:
                                sample_.append(worddicts_r[1][vv])
                            else:
                                sample_.append('UNK')
                        print "".join(sample_)
                    print

            # validate model on validation set and early stop if necessary
            if numpy.mod(uidx, validFreq) == 0:
                use_noise.set_value(0.)
                valid_errs = pred_probs(f_log_probs, prepare_data,
                                        model_options, valid, verboseFreq=verboseFreq)
                valid_err = valid_errs.mean()
                history_errs.append(valid_err)

                if uidx == 0 or valid_err <= numpy.array(history_errs).min():
                    best_p = unzip(tparams)
                    best_optp = unzip(toptparams)
                    bad_counter = 0

                if saveFreq != validFreq and save_best_models:
                    numpy.savez(best_file_name, history_errs=history_errs, uidx=uidx, eidx=eidx,
                                cidx=cidx, **best_p)
                    numpy.savez(best_opt_file_name, **best_optp)

                if len(history_errs) > patience and valid_err >= \
                        numpy.array(history_errs)[:-patience].min() and patience != -1:
                    bad_counter += 1
                    if bad_counter > patience:
                        print 'Early Stop!'
                        estop = True
                        break

                if numpy.isnan(valid_err):
                    import ipdb
                    ipdb.set_trace()

                print 'Valid ', valid_err

            # save the best model so far
            if numpy.mod(uidx, saveFreq) == 0:
                print 'Saving...',

                if not os.path.exists(save_path):
                    os.mkdir(save_path)

                params = unzip(tparams)
                optparams = unzip(toptparams)
                numpy.savez(file_name, history_errs=history_errs, uidx=uidx, eidx=eidx,
                            cidx=cidx, **params)
                numpy.savez(opt_file_name, **optparams)

                if save_every_saveFreq and (uidx >= save_burn_in):
                    this_file_name = '%s%s.%d.npz' % (save_path, save_file_name, uidx)
                    this_opt_file_name = '%s%s%s.%d.npz' % (save_path, save_file_name, '.grads', uidx)
                    numpy.savez(this_file_name, history_errs=history_errs, uidx=uidx, eidx=eidx,
                                cidx=cidx, **params)
                    numpy.savez(this_opt_file_name, history_errs=history_errs, uidx=uidx, eidx=eidx,
                                cidx=cidx, **params)
                    if best_p is not None and saveFreq != validFreq:
                        this_best_file_name = '%s%s.%d.best.npz' % (save_path, save_file_name, uidx)
                        numpy.savez(this_best_file_name, history_errs=history_errs, uidx=uidx, eidx=eidx,
                                    cidx=cidx, **best_p)
                print 'Done...',
                print 'Saved to %s' % file_name

            # finish after this many updates
            if uidx >= finish_after and finish_after != -1:
                print 'Finishing after %d iterations!' % uidx
                estop = True
                break

        print 'Seen %d samples' % n_samples
        eidx += 1

        if estop:
            break

    use_noise.set_value(0.)
    valid_err = pred_probs(f_log_probs, prepare_data,
                           model_options, valid).mean()

    print 'Valid ', valid_err

    params = unzip(tparams)
    optparams = unzip(toptparams)
    file_name = '%s%s.%d.npz' % (save_path, save_file_name, uidx)
    opt_file_name = '%s%s%s.%d.npz' % (save_path, save_file_name, '.grads', uidx)
    numpy.savez(file_name, history_errs=history_errs, uidx=uidx, eidx=eidx, cidx=cidx, **params)
    numpy.savez(opt_file_name, **optparams)
    if best_p is not None and saveFreq != validFreq:
        best_file_name = '%s%s.%d.best.npz' % (save_path, save_file_name, uidx)
        best_opt_file_name = '%s%s%s.%d.best.npz' % (save_path, save_file_name, '.grads',uidx)
        numpy.savez(best_file_name, history_errs=history_errs, uidx=uidx, eidx=eidx, cidx=cidx, **best_p)
        numpy.savez(best_opt_file_name, **best_optp)

    return valid_err