In [1]:
import sys; sys.path.append('..')

import os
import json
import argparse
from random import choice
from time import time
from datetime import timedelta

import numpy as np
import tensorflow as tf
from tqdm import tqdm
import matplotlib.pyplot as plt
from pandas import ewma

from vocab import Vocab
from src.training_utils import *
from lib.tensor_utils import infer_mask, initialize_uninitialized_variables, all_shapes_equal
from lib.utils import save_score

from models.transformer_fused import Model
from models.transformer_lm import TransformerLM

  return f(*args, **kwds)
Using TensorFlow backend.


In [2]:
from os import path

model_name = 'transformer'

config = {
    'data_path': '../data_small',
    'src_lm_path': '../trained_models/lm1/model.npz',
    'target_lm_path': '../trained_models/lm2/model.npz',
    'hp_file_path': '../hp_files/trans_default.json',
    'use_early_stopping': True,
    'early_stopping_last_n': 10,
    'max_epochs': 1000,
    'max_time_seconds': 600,
    'batch_size_for_inference': 16,
    'max_len': 200,
    'validate_every_epoch': True,
    'warm_up_num_epochs': 10
}

In [3]:
model_path = 'trained_models/{}'.format(model_name)
if not os.path.isdir('trained_models'): os.mkdir('trained_models')
if not os.path.isdir(model_path): os.mkdir(model_path)

src_train_path = '{}/bpe_parallel_train1.txt'.format(config.get('data_path'))
dst_train_path = '{}/bpe_parallel_train2.txt'.format(config.get('data_path'))
src_val_path = '{}/bpe_parallel_val1.txt'.format(config.get('data_path'))
dst_val_path = '{}/bpe_parallel_val2.txt'.format(config.get('data_path'))
src_unlabeled_path = '{}/bpe_corpus1.txt'.format(config.get('data_path'))
dst_unlabeled_path = '{}/bpe_corpus2.txt'.format(config.get('data_path'))

src_train = open(src_train_path, 'r', encoding='utf-8').read().splitlines()
dst_train = open(dst_train_path, 'r', encoding='utf-8').read().splitlines()
src_val = open(src_val_path, 'r', encoding='utf-8').read().splitlines()
dst_val = open(dst_val_path, 'r', encoding='utf-8').read().splitlines()
src_unlabeled = open(src_unlabeled_path, 'r', encoding='utf-8').read().splitlines()
dst_unlabeled = open(dst_unlabeled_path, 'r', encoding='utf-8').read().splitlines()

inp_voc = Vocab.from_file('{}/1.voc'.format(config.get('data_path')))
out_voc = Vocab.from_file('{}/2.voc'.format(config.get('data_path')))
max_len = config.get('max_len', 200)

# Hyperparameters
hp = json.load(open(config.get('hp_file_path'), 'r', encoding='utf-8')) if config.get('hp_file_path') else {}

In [4]:
gpu_options = create_gpu_options(config)
sess = tf.InteractiveSession(config=tf.ConfigProto(gpu_options=gpu_options))

# Disabling GPU
# tf_config = tf.ConfigProto(device_count = {'GPU': 0})
# sess = tf.InteractiveSession(config=tf_config)

In [5]:
optimizer = create_optimizer(hp)
inp = tf.placeholder(tf.int32, [None, None])
out = tf.placeholder(tf.int32, [None, None])

In [6]:
#############################################
### Initializing our main model
#############################################
lm = TransformerLM('lm2', out_voc, **hp)
if config.get('target_lm_path'):
    lm_weights = np.load(config.get('target_lm_path'))
    ops = []
    for w in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, lm.name):
        if w.name in lm_weights:
            ops.append(tf.assign(w, lm_weights[w.name]))
        else:
            print(w.name, 'not initialized')
    sess.run(ops)
else:
    raise ValueError("Must specify LM path!")

model = Model(model_name, inp_voc, out_voc, lm, **hp)

logprobs = model.symbolic_score(inp, out, is_train=True)[:,:tf.shape(out)[1]]
nll = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logprobs, labels=out)
loss = nll * infer_mask(out, out_voc.eos, dtype=tf.float32)
loss = tf.reduce_sum(loss, axis=1)
loss = tf.reduce_mean(loss)
weights = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, model_name)

grads = tf.gradients(loss, weights)
grads = tf.clip_by_global_norm(grads, 100)[0]
train_step = optimizer.apply_gradients(zip(grads, weights))
#############################################

In [7]:
#############################################
### Initializing our back-translation model
### TODO(universome): DRY
#############################################
lm_bk = TransformerLM('lm1', inp_voc, **hp)
if config.get('src_lm_path'):
    lm_weights = np.load(config.get('src_lm_path'))
    ops = []
    for w in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, lm_bk.name):
        if w.name in lm_weights:
            ops.append(tf.assign(w, lm_weights[w.name]))
        else:
            print(w.name, 'not initialized')
    sess.run(ops)
else:
    raise ValueError("Must specify src LM path!")
model_bk = Model(model_name + "bk", out_voc, inp_voc, lm_bk, **hp)

logprobs_bk = model_bk.symbolic_score(out, inp, is_train=True)[:, :tf.shape(inp)[1]]
nll_bk = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logprobs_bk, labels=inp)
loss_bk = nll_bk * infer_mask(inp, inp_voc.eos, dtype=tf.float32)
loss_bk = tf.reduce_sum(loss_bk, axis=1)
loss_bk = tf.reduce_mean(loss_bk)
weights_bk = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, model_name + 'bk')

grads_bk = tf.gradients(loss_bk, weights_bk)
grads_bk = tf.clip_by_global_norm(grads_bk, 100)[0]
train_step_bk = optimizer.apply_gradients(zip(grads_bk, weights_bk))
#############################################

In [8]:
all_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
non_trainable_vars = list(set(all_vars).difference(set(weights + weights_bk)))

initialize_uninitialized_variables(sess)

In [9]:
def init_with_lms(model, weights, target_lm_path, src_lm_path, session):
    assigns = []
    # init model with LMs
    weights_by_common_name = {w.name[len(model.name) + 1:]: w for w in weights}
    with np.load(target_lm_path) as dic:
        for key in dic:  # decoder_init
            w_lm = dic[key]
            weights_key = '/'.join(key.split('/')[1:]).replace('main/', '').replace("enc", 'dec').replace("inp", "out")
            if "emb_out_bias" in weights_key:  # no such thing
                continue

            w_var = weights_by_common_name[weights_key]

            all_shapes_equal(w_lm, w_var, session=session, mode='assert')

            assigns.append(tf.assign(w_var, w_lm))
    with np.load(src_lm_path) as dic:
        for key in dic:  # encoder_init
            w_lm = dic[key]
            weights_key = '/'.join(key.split('/')[1:]).replace('main/', '')
            if "logits" in weights_key:  # encoder has no 'logits' layer for the logits to be initialised
                continue
            w_var = weights_by_common_name[weights_key]

            all_shapes_equal(w_lm, w_var, session=session, mode='assert')
            assigns.append(tf.assign(w_var, w_lm))
    session.run(assigns)

In [10]:
assert config.get('target_lm_path') and config.get('src_lm_path')
init_with_lms(model, weights, config.get('target_lm_path'), config.get('src_lm_path'), sess)
init_with_lms(model_bk, weights_bk, config.get('src_lm_path'), config.get('target_lm_path'), sess)

In [11]:
batch_size = hp.get('batch_size', 32)
epoch = 0
training_start_time = time()
loss_history = []
loss_history_bk = []
val_scores = []

num_iters_done = 0
should_start_next_epoch = True # We need this var to break outer loop

In [12]:
def save_model():
    save_path = '{}/model.npz'.format(model_path)
    print('Saving the model into %s' %save_path)

    w_values = sess.run(weights)
    weights_dict = {w.name: w_val for w, w_val in zip(weights, w_values)}
    np.savez(save_path, **weights_dict)

def validate(val_scores):
    """
    Returns should_continue flag, which tells us if we should continue or early stop
    """
    should_continue = True

    # if config.get('warm_up_num_epochs') and config.get('warm_up_num_epochs') > epoch:
    #     print('Skipping validation, becaused is not warmed up yet')
    #     return should_continue

    print('Validating')
    val_score = compute_bleu_for_model(model, sess, inp_voc, out_voc, src_val, dst_val,
                                        model_name, config, max_len=max_len)
    val_scores.append(val_score)
    print('Validation BLEU: {:0.3f}'.format(val_score))

    # Save model if this is our best model
    if np.argmax(val_scores) == len(val_scores)-1:
        print('Saving model because it has the highest validation BLEU.')
        save_model()

    if config.get('use_early_stopping') and should_stop_early(val_scores, config.get('early_stopping_last_n')):
        print('Model did not improve for last %s steps. Early stopping.' % config.get('early_stopping_last_n'))
        should_continue = False

    if config.get('warm_up_num_epochs') > epoch:
        return True
    else:
        return should_continue

In [13]:
syntethic_src_trg = []
syntethic_trg_src = []
max_num_synthetic_sents = config.get('max_num_synthetic_sents', 5000)

In [14]:
def translate_sents(model, sess, inp_voc, out_voc, src_val, model_name, config, max_len=200):
    src_val_ix = inp_voc.tokenize_many(src_val)

    inp = tf.placeholder(tf.int32, [None, None])
    translations = []

    if model_name == 'gnmt':
        raise NotImplemented("deprecated model")
    sy_translations = model.symbolic_translate(inp, mode='greedy', max_len=max_len,
                                               back_prop=False, swap_memory=True).best_out

    for batch in iterate_minibatches(src_val_ix, batchsize=config.get('batch_size_for_inference', 64)):
        translations += sess.run([sy_translations], feed_dict={inp: batch[0][:, :max_len]})[0].tolist()

    outputs = out_voc.detokenize_many(translations, unbpe=True, deprocess=True)

    return outputs

In [15]:
should_start_next_epoch = True

while should_start_next_epoch:
    batches = batch_generator_over_dataset(src_train, dst_train, batch_size, batches_per_epoch=None)
    with tqdm(batches) as t:
        for batch_src, batch_dst in t:
            # Note: we don't use voc.tokenize_many(batch, max_len=max_len)
            # cuz it forces batch length to be that long and we often get away with much less
            batch_src_ix = inp_voc.tokenize_many(batch_src)[:, :max_len]
            batch_dst_ix = out_voc.tokenize_many(batch_dst)[:, :max_len]

            feed_dict = {inp: batch_src_ix, out: batch_dst_ix}
            _, loss_t, _, loss_t_bk = sess.run([train_step, loss, train_step_bk, loss_bk], feed_dict)

            loss_history.append(np.mean(loss_t))
            loss_history_bk.append(np.mean(loss_t_bk))

            if len(syntethic_src_trg) > 0:
                # src -> dst back trans results
                syntethic_src_trg_batch = [choice(range(len(syntethic_src_trg)))
                                           for i in range(batch_size // 10 + 1)]

                batch_src = [syntethic_src_trg[i][0] for i in syntethic_src_trg_batch]
                batch_dst = [syntethic_src_trg[i][1] for i in syntethic_src_trg_batch]

                batch_src_ix = inp_voc.tokenize_many(batch_src)[:, :max_len]
                batch_dst_ix = out_voc.tokenize_many(batch_dst)[:, :max_len]

                feed_dict = {inp: batch_src_ix, out: batch_dst_ix}
                _, loss_t_bk = sess.run([train_step_bk, loss_bk], feed_dict)

            if len(syntethic_trg_src) > 0:
                # src <- dst back trans results
                syntethic_trg_src_batch = [choice(range(len(syntethic_trg_src)))
                                           for i in range(batch_size // 10 + 1)]

                batch_src = [syntethic_trg_src[i][1] for i in syntethic_trg_src_batch]
                batch_dst = [syntethic_trg_src[i][0] for i in syntethic_trg_src_batch]

                batch_src_ix = inp_voc.tokenize_many(batch_src)[:, :max_len]
                batch_dst_ix = out_voc.tokenize_many(batch_dst)[:, :max_len]

                feed_dict = {inp: batch_src_ix, out: batch_dst_ix}
                _, loss_t = sess.run([train_step, loss], feed_dict)

            # Note: we don't use voc.tokenize_many(batch, max_len=max_len)
            # cuz it forces batch length to be that long and we often get away with much less
            batch_src_ix = inp_voc.tokenize_many(batch_src)[:, :max_len]
            batch_dst_ix = out_voc.tokenize_many(batch_dst)[:, :max_len]

            feed_dict = {inp: batch_src_ix, out: batch_dst_ix}

            loss_t = sess.run([train_step, loss], feed_dict)[1]
            loss_history.append(np.mean(loss_t))

            loss_to_print = ewma(np.array(loss_history[-50:]), span=50)[-1]
            t.set_description('Iterations done: {}. Loss: {:.2f}'.format(num_iters_done, loss_to_print))

            if not config.get('validate_every_epoch') and (num_iters_done+1) % config.get('validate_every', 500) == 0:
                should_continue = validate(val_scores)
                if not should_continue:
                    should_start_next_epoch = False
                    break

            num_iters_done += 1

            if config.get('max_time_seconds'):
                seconds_elapsed = time()-training_start_time

                if seconds_elapsed > config.get('max_time_seconds'):
                    print('Maximum allowed training time reached. Training took %s. Stopping.' % seconds_elapsed)
                    should_start_next_epoch = False
                    break

        epoch +=1

        if config.get('validate_every_epoch') and should_start_next_epoch:
            should_start_next_epoch = validate(val_scores)

        if config.get('max_epochs') and config.get('max_epochs') == epoch:
            print('Maximum amount of epochs reached. Stopping.')
            break

        src_to_trans = [choice(src_unlabeled) for i in range(config.get('synthetic_per_epoch', 100))]
        dst_to_trans = [choice(dst_unlabeled) for i in range(config.get('synthetic_per_epoch', 100))]
        syntethic_src_trg += zip(src_to_trans, translate_sents(model, sess, inp_voc, out_voc,
                                                            src_to_trans, model_name, config))
        syntethic_trg_src += zip(dst_to_trans, translate_sents(model_bk, sess, out_voc, inp_voc,
                                                            dst_to_trans, model_name + 'bk', config))

        syntethic_trg_src = syntethic_trg_src[-max_num_synthetic_sents:]
        syntethic_src_trg = syntethic_src_trg[-max_num_synthetic_sents:]

Iterations done: 1. Loss: 145.76: : 2it [00:04,  2.32s/it]


Validating
Validation BLEU: 0.000
Saving model because it has the highest validation BLEU.
Saving the model into trained_models/transformer/model.npz


0it [00:00, ?it/s]


NameError: name 'syntethic_src_dst_batch' is not defined