In [None]:
import os
# Enumerate GPUs in PCI bus order (the ordering nvidia-smi uses) and expose
# four of them to TensorFlow.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "0,1,2,3",
})

In [None]:
import tensorflow as tf
import time
import numpy as np

In [None]:
import tensorflow_datasets as tfds

In [None]:
# TED talks Portuguese->English translation pairs, yielded as (pt, en) tuples.
examples, metadata = tfds.load('ted_hrlr_translate/pt_to_en',
                               with_info=True,
                               as_supervised=True)
train_examples = examples['train']
val_examples = examples['validation']

In [None]:
tokenizer_path = '../wmt14/tokenizers/pt-en/'
dataset = 'ted_pt-en'

# Pre-built subword tokenizers: "_targets" is English, "_inputs" is Portuguese.
tokenizer_en, tokenizer_pt = (
    tfds.deprecated.text.SubwordTextEncoder.load_from_file(tokenizer_path + dataset + suffix)
    for suffix in ('_targets_tokenizer', '_inputs_tokenizer'))

In [None]:
if len(os.environ["CUDA_VISIBLE_DEVICES"]) > 1:
  strategy = tf.distribute.MirroredStrategy()
  distributed_flag = True
  print ('Number of devices: {}'.format(strategy.num_replicas_in_sync))
else:
  distributed_flag = False


In [None]:
BUFFER_SIZE = 20000  # shuffle buffer size for the training set
BATCH_SIZE = 64      # global batch size

if distributed_flag:
  # Report how many examples each replica receives per step.
  print('batch size per replica ', BATCH_SIZE // strategy.num_replicas_in_sync)


# Maximum tokenized length kept by filter_max_length (memory tokens excluded).
MAX_LENGTH = 40
# Number of memory tokens prepended to encoder / decoder sequences by `encode`.
ENC_MEM_SIZE = 0
DEC_MEM_SIZE = 0

In [None]:
def encode(lang1, lang2, mem_size1, mem_size2):
  """Tokenize a (pt, en) sentence pair into lists of ids.

  Each sequence is laid out as
      [MEM] * mem_size + [START] + subword ids + [END]
  with START = vocab_size, END = vocab_size + 1 and MEM = vocab_size + 2 of the
  corresponding tokenizer. Runs eagerly (called through tf.py_function), so the
  inputs are eager tensors.
  """
  def _wrap(tokenizer, text, mem_size):
    start_id = tokenizer.vocab_size
    end_id = start_id + 1
    mem_id = start_id + 2
    return ([mem_id] * tf.constant(mem_size).numpy()
            + [start_id]
            + tokenizer.encode(text.numpy())
            + [end_id])

  return (_wrap(tokenizer_pt, lang1, mem_size1),
          _wrap(tokenizer_en, lang2, mem_size2))


def tf_encode(pt, en, mem_size_pt, mem_size_en):
  """Graph-compatible wrapper around `encode` for use inside Dataset.map.

  Returns two int64 rank-1 tensors (Portuguese ids, English ids).
  """
  encoded = tf.py_function(encode, [pt, en, mem_size_pt, mem_size_en],
                           [tf.int64, tf.int64])
  result_pt, result_en = encoded
  # py_function drops static shape info; restore "rank 1, unknown length".
  for tensor in (result_pt, result_en):
    tensor.set_shape([None])
  return result_pt, result_en


def filter_max_length(x, y, max_length=MAX_LENGTH, mem_size_x=ENC_MEM_SIZE, mem_size_y=DEC_MEM_SIZE):
  """Keep a pair only if both sequences, ignoring memory tokens, fit in max_length."""
  x_fits = tf.size(x) - mem_size_x <= max_length
  y_fits = tf.size(y) - mem_size_y <= max_length
  return tf.logical_and(x_fits, y_fits)


# Tokenize, drop over-long pairs, then (train only) shuffle, and pad per batch.
train_dataset = (train_examples
                 .map(lambda x, y: tf_encode(x, y, ENC_MEM_SIZE, DEC_MEM_SIZE))
                 .filter(filter_max_length)
                 .shuffle(BUFFER_SIZE)
                 .padded_batch(BATCH_SIZE))

val_dataset = (val_examples
               .map(lambda x, y: tf_encode(x, y, ENC_MEM_SIZE, DEC_MEM_SIZE))
               .filter(filter_max_length)
               .padded_batch(BATCH_SIZE))

In [None]:
# Split each global batch across replicas when running under MirroredStrategy.
if distributed_flag:
  train_dist_dataset, val_dist_dataset = (
      strategy.experimental_distribute_dataset(ds)
      for ds in (train_dataset, val_dataset))

In [None]:
# Peek at one validation batch to sanity-check token ids against vocab sizes
# (start/end/memory ids are vocab_size, vocab_size+1, vocab_size+2).
pt_batch, en_batch = next(iter(val_dataset))
pt_batch,tokenizer_pt.vocab_size,  en_batch, tokenizer_en.vocab_size

In [None]:
def get_angles(pos, i, d_model):
  """Sinusoidal positional-encoding angles: pos / 10000^(2*(i//2) / d_model).

  `pos` and `i` are broadcast against each other (e.g. column and row vectors).
  """
  exponent = (2 * (i // 2)) / np.float32(d_model)
  angle_rates = 1 / np.power(10000, exponent)
  return pos * angle_rates


def positional_encoding(position, d_model):
  """Sinusoidal positional encoding as a float32 tensor of shape (1, position, d_model).

  Even feature indices (2i) use sin, odd indices (2i+1) use cos.
  """
  angles = get_angles(np.arange(position)[:, np.newaxis],
                      np.arange(d_model)[np.newaxis, :],
                      d_model)

  even_cols = slice(0, None, 2)
  odd_cols = slice(1, None, 2)
  angles[:, even_cols] = np.sin(angles[:, even_cols])
  angles[:, odd_cols] = np.cos(angles[:, odd_cols])

  return tf.cast(angles[np.newaxis, ...], dtype=tf.float32)


def create_padding_mask(seq):
  """Return 1.0 where seq == 0 (padding) and 0.0 elsewhere.

  The result is shaped (batch_size, 1, 1, seq_len) so it broadcasts over
  attention heads and query positions when added to the attention logits.
  """
  is_padding = tf.cast(tf.math.equal(seq, 0), tf.float32)
  return is_padding[:, tf.newaxis, tf.newaxis, :]

def create_look_ahead_mask(size):
  """(size, size) float mask with 1.0 strictly above the diagonal (future positions)."""
  lower_triangle = tf.linalg.band_part(tf.ones((size, size)), -1, 0)
  return 1 - lower_triangle


def scaled_dot_product_attention(q, k, v, mask, mask_special_tokens=False):
  """Calculate the attention weights.

  q, k, v must have matching leading dimensions.
  k, v must have matching penultimate dimension, i.e.: seq_len_k = seq_len_v.
  The mask has different shapes depending on its type (padding or look ahead)
  but it must be broadcastable for addition.

  Args:
    q: query shape == (..., seq_len_q, depth), e.g. (batch_size, num_heads, seq_len_q, depth)
    k: key shape == (..., seq_len_k, depth)
    v: value shape == (..., seq_len_v, depth_v)
    mask: float tensor (1.0 = masked) broadcastable to
          (..., seq_len_q, seq_len_k), or None for no mask.
    mask_special_tokens: if True, additionally hide key position 0 (first
          memory token) and key position DEC_MEM_SIZE (start token) from every
          query position >= DEC_MEM_SIZE + 1. Assumes q is rank 4
          (batch, heads, seq_len_q, depth).

  Returns:
    output (..., seq_len_q, depth_v), attention_weights (..., seq_len_q, seq_len_k)
  """

  matmul_qk = tf.matmul(q, k, transpose_b=True)  # (..., seq_len_q, seq_len_k)

  # scale matmul_qk
  dk = tf.cast(tf.shape(k)[-1], tf.float32)
  scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)

  # add the mask to the scaled tensor.
  if mask is not None:
    scaled_attention_logits += (mask * -1e9)

  # add mask for start token and first mem token when predicting target sequence
  if mask_special_tokens:
    b_size = tf.shape(q)[0]
    len_q = tf.shape(q)[-2]
    len_k = tf.shape(k)[-2]

    # Built densely rather than via a SparseTensor: when DEC_MEM_SIZE == 0 both
    # special columns are column 0, and duplicate sparse indices make
    # tf.sparse.to_dense fail. This also yields an all-zero mask (instead of a
    # negative-size tf.ones) when len_q <= DEC_MEM_SIZE + 1.
    query_pos = tf.range(len_q)[:, tf.newaxis]   # (len_q, 1)
    key_pos = tf.range(len_k)[tf.newaxis, :]     # (1, len_k)
    is_special_key = tf.logical_or(tf.equal(key_pos, 0),
                                   tf.equal(key_pos, DEC_MEM_SIZE))
    _mask = tf.cast(tf.logical_and(query_pos >= DEC_MEM_SIZE + 1, is_special_key),
                    tf.float32)                  # (len_q, len_k)
    _mask = _mask[tf.newaxis, :, :]
    # Broadcast to (batch_size, 1, len_q, len_k).
    mask = tf.maximum(tf.zeros((b_size, 1, 1, len_k)), _mask)

    scaled_attention_logits += (mask * -1e9)

  # softmax is normalized on the last axis (seq_len_k) so that the scores
  # add up to 1.
  attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)  # (..., seq_len_q, seq_len_k)

  output = tf.matmul(attention_weights, v)  # (..., seq_len_q, depth_v)

  return output, attention_weights


class MultiHeadAttention(tf.keras.layers.Layer):
  """Multi-head attention with optional decoder special-token masking.

  q/k/v are projected by separate Dense layers, split into `num_heads` heads of
  size d_model // num_heads, attended with scaled_dot_product_attention, then
  concatenated and projected back to d_model.
  """
  def __init__(self, d_model, num_heads):
    super(MultiHeadAttention, self).__init__()
    self.num_heads = num_heads
    self.d_model = d_model

    # d_model must split evenly across heads.
    assert d_model % self.num_heads == 0

    self.depth = d_model // self.num_heads

    # NOTE: object-based checkpoints (tf.train.Checkpoint) key variables by
    # these attribute names; renaming them breaks restoring saved checkpoints.
    self.wq = tf.keras.layers.Dense(d_model)
    self.wk = tf.keras.layers.Dense(d_model)
    self.wv = tf.keras.layers.Dense(d_model)

    self.dense = tf.keras.layers.Dense(d_model)

  def split_heads(self, x, batch_size):
    """
    Split the last dimension into (num_heads, depth).
    Transpose the result such that the shape is (batch_size, num_heads, seq_len, depth)
    """
    x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
    return tf.transpose(x, perm=[0, 2, 1, 3])

  def call(self, v, k, q, mask, decoder_masking=False):
    """Attend from `q` over `k`/`v` (note the argument order v, k, q).

    Args:
      v, k, q: (batch_size, seq_len, features) inputs.
      mask: additive mask (1.0 = masked) broadcastable to the logits, or None.
      decoder_masking: when True and the key length exceeds DEC_MEM_SIZE + 1,
        also mask the start token and first memory token for later queries.

    Returns:
      (output, attention_weights) with shapes (batch_size, seq_len_q, d_model)
      and (batch_size, num_heads, seq_len_q, seq_len_k).
    """
    batch_size = tf.shape(q)[0]

    q = self.wq(q)  # (batch_size, seq_len, d_model)
    k = self.wk(k)  # (batch_size, seq_len, d_model)
    v = self.wv(v)  # (batch_size, seq_len, d_model)

    # NOTE(review): tf.shape(k)[1] is a tensor, so inside tf.function autograph
    # converts this `if` and mask_special_tokens may become a tensor — confirm.
    mask_special_tokens = False
    if decoder_masking and tf.shape(k)[1] > DEC_MEM_SIZE + 1:
      mask_special_tokens = True

    q = self.split_heads(q, batch_size)  # (batch_size, num_heads, seq_len_q, depth)
    k = self.split_heads(k, batch_size)  # (batch_size, num_heads, seq_len_k, depth)
    v = self.split_heads(v, batch_size)  # (batch_size, num_heads, seq_len_v, depth)

    # scaled_attention.shape == (batch_size, num_heads, seq_len_q, depth)
    # attention_weights.shape == (batch_size, num_heads, seq_len_q, seq_len_k)
    scaled_attention, attention_weights = scaled_dot_product_attention(
        q, k, v, mask, mask_special_tokens=mask_special_tokens)

    scaled_attention = tf.transpose(scaled_attention, perm=[0, 2, 1, 3])  # (batch_size, seq_len_q, num_heads, depth)

    concat_attention = tf.reshape(scaled_attention, 
                                  (batch_size, -1, self.d_model))  # (batch_size, seq_len_q, d_model)

    output = self.dense(concat_attention)  # (batch_size, seq_len_q, d_model)

    return output, attention_weights


def point_wise_feed_forward_network(d_model, dff):
  """Position-wise FFN: Dense(dff, relu) followed by Dense(d_model)."""
  hidden = tf.keras.layers.Dense(dff, activation='relu')  # (batch_size, seq_len, dff)
  projection = tf.keras.layers.Dense(d_model)  # (batch_size, seq_len, d_model)
  return tf.keras.Sequential([hidden, projection])


class EncoderLayer(tf.keras.layers.Layer):
  """Encoder block: self-attention and a feed-forward network, each followed by
  dropout, a residual connection and layer normalization (post-norm)."""
  def __init__(self, d_model, num_heads, dff, rate=0.1):
    super(EncoderLayer, self).__init__()

    self.mha = MultiHeadAttention(d_model, num_heads)
    self.ffn = point_wise_feed_forward_network(d_model, dff)

    self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
    self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)

    self.dropout1 = tf.keras.layers.Dropout(rate)
    self.dropout2 = tf.keras.layers.Dropout(rate)

  def call(self, x, training, mask):
    """Return (output, self-attention weights); `mask` is the encoder padding mask."""

    attn_output, attn_w = self.mha(x, x, x, mask)  # (batch_size, input_seq_len, d_model)
    attn_output = self.dropout1(attn_output, training=training)
    out1 = self.layernorm1(x + attn_output)  # (batch_size, input_seq_len, d_model)

    ffn_output = self.ffn(out1)  # (batch_size, input_seq_len, d_model)
    ffn_output = self.dropout2(ffn_output, training=training)
    out2 = self.layernorm2(out1 + ffn_output)  # (batch_size, input_seq_len, d_model)

    return out2, attn_w


class DecoderLayer(tf.keras.layers.Layer):
  """Decoder block: masked self-attention, encoder-decoder attention and a
  feed-forward network, each with dropout, residual and layer norm.

  `masking=True` enables special-token masking in the self-attention
  (see MultiHeadAttention.call `decoder_masking`).
  """
  def __init__(self, d_model, num_heads, dff, rate=0.1, masking=False):
    super(DecoderLayer, self).__init__()

    self.mha1 = MultiHeadAttention(d_model, num_heads)
    self.mha2 = MultiHeadAttention(d_model, num_heads)

    self.ffn = point_wise_feed_forward_network(d_model, dff)

    self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
    self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
    self.layernorm3 = tf.keras.layers.LayerNormalization(epsilon=1e-6)

    self.dropout1 = tf.keras.layers.Dropout(rate)
    self.dropout2 = tf.keras.layers.Dropout(rate)
    self.dropout3 = tf.keras.layers.Dropout(rate)

    self.masking = masking

  def call(self, x, enc_output, training, 
           look_ahead_mask, padding_mask):
    """Return (output, self-attention weights, encoder-decoder attention weights).

    look_ahead_mask masks future/padded target positions in self-attention;
    padding_mask masks padded encoder positions in the second attention.
    """
    attn1, attn_weights_block1 = self.mha1(x, x, x, look_ahead_mask, 
                                           decoder_masking=self.masking)  # (batch_size, target_seq_len, d_model)
    attn1 = self.dropout1(attn1, training=training)
    out1 = self.layernorm1(attn1 + x)

    # Queries come from the decoder (out1); keys/values from the encoder.
    attn2, attn_weights_block2 = self.mha2(
        enc_output, enc_output, out1, padding_mask)  # (batch_size, target_seq_len, d_model)
    attn2 = self.dropout2(attn2, training=training)
    out2 = self.layernorm2(attn2 + out1)  # (batch_size, target_seq_len, d_model)

    ffn_output = self.ffn(out2)  # (batch_size, target_seq_len, d_model)
    ffn_output = self.dropout3(ffn_output, training=training)
    out3 = self.layernorm3(ffn_output + out2)  # (batch_size, target_seq_len, d_model)

    return out3, attn_weights_block1, attn_weights_block2


class Encoder(tf.keras.layers.Layer):
  """Token embedding + sinusoidal positions + a stack of EncoderLayers.

  Input sequences must not be longer than `maximum_position_encoding`, since the
  precomputed positional encoding is sliced to the input length.
  """
  def __init__(self, num_layers, d_model, num_heads, dff, input_vocab_size,
               maximum_position_encoding, rate=0.1):
    super(Encoder, self).__init__()

    self.d_model = d_model
    self.num_layers = num_layers

    self.embedding = tf.keras.layers.Embedding(input_vocab_size, d_model)
    self.pos_encoding = positional_encoding(maximum_position_encoding, 
                                            self.d_model)


    self.enc_layers = [EncoderLayer(d_model, num_heads, dff, rate) 
                       for _ in range(num_layers)]

    self.dropout = tf.keras.layers.Dropout(rate)

  def call(self, x, training, mask):
    """Return (encoder output, {'encoder_layer{i}': self-attention weights})."""
    seq_len = tf.shape(x)[1]
    attention_weights = {}

    # adding embedding and position encoding.
    x = self.embedding(x)  # (batch_size, input_seq_len, d_model)
    # Scale embeddings by sqrt(d_model) before adding positional encodings.
    x *= tf.math.sqrt(tf.cast(self.d_model, tf.float32))
    x += self.pos_encoding[:, :seq_len, :]

    x = self.dropout(x, training=training)

    for i in range(self.num_layers):
      x, attn = self.enc_layers[i](x, training, mask)
      attention_weights['encoder_layer{}'.format(i+1)] = attn

    return x, attention_weights


class Decoder(tf.keras.layers.Layer):
  """Target embedding + sinusoidal positions (+ optional token-type embedding)
  followed by a stack of DecoderLayers.

  The token-type embedding has 2 entries (type 0 / type 1) and is only added
  when `token_type_inp` is passed to `call`.
  """
  def __init__(self, num_layers, d_model, num_heads, dff, target_vocab_size,
               maximum_position_encoding, rate=0.1, masking=False):
    super(Decoder, self).__init__()

    self.d_model = d_model
    self.num_layers = num_layers

    self.embedding = tf.keras.layers.Embedding(target_vocab_size, d_model)
    self.pos_encoding = positional_encoding(maximum_position_encoding, d_model)
    self.token_type_embedding = tf.keras.layers.Embedding(2, d_model)

    self.dec_layers = [DecoderLayer(d_model, num_heads, dff, rate, masking) 
                       for _ in range(num_layers)]
    self.dropout = tf.keras.layers.Dropout(rate)


  def call(self, x, enc_output, training, 
           look_ahead_mask, padding_mask, token_type_inp=None):
    """Return (decoder output, attention-weight dict).

    token_type_inp, when given, must have the same shape as `x` (one type id,
    0 or 1, per target position).
    """

    seq_len = tf.shape(x)[1]
    attention_weights = {}

    x = self.embedding(x)  # (batch_size, target_seq_len, d_model)
    x *= tf.math.sqrt(tf.cast(self.d_model, tf.float32))
    x += self.pos_encoding[:, :seq_len, :]
    
    if token_type_inp is not None:
      x += self.token_type_embedding(token_type_inp)
    x = self.dropout(x, training=training)

    for i in range(self.num_layers):
      x, block1, block2 = self.dec_layers[i](x, enc_output, training,
                                             look_ahead_mask, padding_mask)

      attention_weights['decoder_layer{}_block1'.format(i+1)] = block1
      attention_weights['decoder_layer{}_block2'.format(i+1)] = block2

    return x, attention_weights





class Transformer(tf.keras.Model):
  """Encoder-decoder Transformer.

  With `token_type=True` the output layer has target_vocab_size + 2 units: the
  first target_vocab_size are token logits and the last 2 are logits for the
  type (0 / 1) of the next token (callers slice `[..., :-2]` and `[..., -2:]`).
  """
  def __init__(self, num_layers, d_model, num_heads, dff, input_vocab_size, 
               target_vocab_size, pe_input, pe_target, rate=0.1, 
               masking=False, token_type=False):
    super(Transformer, self).__init__()

    self.encoder = Encoder(num_layers, d_model, num_heads, dff, 
                           input_vocab_size, pe_input, rate)

    self.decoder = Decoder(num_layers, d_model, num_heads, dff, 
                           target_vocab_size, pe_target, rate, masking)

    self.token_type = token_type
    if token_type:
      self.final_layer = tf.keras.layers.Dense(target_vocab_size+2)
      print('token type, ',str(target_vocab_size+2))
    else:
      self.final_layer = tf.keras.layers.Dense(target_vocab_size)
      print('no token type, ', str(target_vocab_size))

  def call(self, inp, tar, training, enc_padding_mask, 
           look_ahead_mask, dec_padding_mask, token_type_inp=None):
    """Return (final logits, encoder attention dict, decoder attention dict)."""

    enc_output, enc_attn = self.encoder(inp, training, enc_padding_mask)  # (batch_size, inp_seq_len, d_model)

    dec_output, attention_weights = self.decoder(
        tar, enc_output, training, look_ahead_mask, dec_padding_mask, token_type_inp)

    final_output = self.final_layer(dec_output)  # (batch_size, tar_seq_len, target_vocab_size) or target_vocab_size+2
    return final_output, enc_attn, attention_weights

In [None]:
# Model size.
num_layers = 4
d_model = 128
dff = 512
num_heads = 8

# Vocab = subwords + start + end, plus one memory-token id when memory is used.
input_vocab_size = tokenizer_pt.vocab_size + (3 if ENC_MEM_SIZE > 0 else 2)
target_vocab_size = tokenizer_en.vocab_size + (3 if DEC_MEM_SIZE > 0 else 2)

dropout_rate = 0.1

masking = False    # special-token masking in decoder self-attention
token_type = True  # token-type embedding + 2 extra token-type output logits

print(input_vocab_size, target_vocab_size)

In [None]:
class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
  """Transformer warm-up schedule.

  lr(step) = d_model^-0.5 * min(step^-0.5, step * warmup_steps^-1.5)
  """
  def __init__(self, d_model, warmup_steps=4000):
    super(CustomSchedule, self).__init__()

    self.d_model = tf.cast(d_model, tf.float32)

    self.warmup_steps = warmup_steps

  def __call__(self, step):
    # Optimizers may pass the integer iteration counter; rsqrt only accepts
    # floating-point input, so cast explicitly (no-op for float steps).
    step = tf.cast(step, tf.float32)
    arg1 = tf.math.rsqrt(step)
    arg2 = step * (self.warmup_steps ** -1.5)

    return tf.math.rsqrt(self.d_model) * tf.math.minimum(arg1, arg2)


import contextlib

learning_rate = CustomSchedule(d_model)

loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True, reduction='none')

def loss_function(real, pred, global_batch_num_elems):
  """Masked token-level cross-entropy.

  Padding positions (real == 0) contribute nothing. The summed loss is divided
  by `global_batch_num_elems` (non-padding tokens in the whole global batch)
  when distributed, so summing replica losses gives the global mean; otherwise
  it is divided by this batch's own non-padding token count.
  """
  mask = tf.math.logical_not(tf.math.equal(real, 0))
  loss_ = loss_object(real, pred)
  mask = tf.cast(mask, dtype=loss_.dtype)
  loss_ *= mask
  #loss per example in global batch
  return tf.reduce_sum(loss_)/(global_batch_num_elems if distributed_flag else tf.reduce_sum(mask))


# Objects that own tf.Variables (optimizer, metrics, model) must be created
# under the distribution strategy's scope so their variables are mirrored
# across replicas; they are updated inside strategy.run in the train step.
_variable_scope = strategy.scope() if distributed_flag else contextlib.nullcontext()
with _variable_scope:
  optimizer = tf.keras.optimizers.Adam(learning_rate, beta_1=0.9, beta_2=0.98, 
                                       epsilon=1e-9)

  train_loss = tf.keras.metrics.Mean(name='train_loss')
  train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')

  # NOTE(review): positional-encoding lengths are set to the vocab sizes; they
  # only need to exceed the longest sequence, so this is generous but valid.
  transformer = Transformer(num_layers, d_model, num_heads, dff,
                            input_vocab_size, target_vocab_size, 
                            pe_input=input_vocab_size, 
                            pe_target=target_vocab_size,
                            rate=dropout_rate,
                            masking=masking,
                            token_type=token_type)


def create_masks(inp, tar):
  """Build the three attention masks for one (inp, tar) batch.

  Returns:
    enc_padding_mask: padding mask over the encoder input (encoder self-attn).
    combined_mask: look-ahead mask merged with the target padding mask
      (decoder self-attention).
    dec_padding_mask: padding mask over the encoder input, used by the
      decoder's encoder-decoder attention.
  """
  enc_padding_mask = create_padding_mask(inp)
  dec_padding_mask = create_padding_mask(inp)

  target_len = tf.shape(tar)[1]
  combined_mask = tf.maximum(create_padding_mask(tar),
                             create_look_ahead_mask(target_len))

  return enc_padding_mask, combined_mask, dec_padding_mask

## Baseline pre-training

In [None]:
# Baseline: no extra memory tokens generated, nucleus threshold unused.
MEM_TOKENS_NUM, NUCLEUS_P = 0, 0.0

In [None]:
# code from https://github.com/ShenakhtPajouh/GPT-language-model-tf.keras/blob/master/utils.py#L126
def top_k_sampling(x):
  """Sample one token id from the k highest logits.

  Args:
    x: tuple (logits, k, temperature); logits is a 1-D tensor, k must be > 0.
  Returns:
    int64 tensor of shape (1,) holding the sampled id.
  """
  logits, k, temperature = x[0], tf.cast(x[1], dtype=tf.int32), x[2]
  top_values, _ = tf.math.top_k(logits, k=k)
  threshold = tf.reduce_min(top_values)
  # Everything below the k-th largest logit is pushed to -1e9.
  suppressed = tf.ones_like(logits, dtype=logits.dtype) * -1e9
  scaled = tf.where(logits < threshold, suppressed, logits) / temperature
  draws = tf.random.categorical(tf.expand_dims(scaled, 0), 1)
  return draws[0]

def argmax(logits):
  """Greedy decoding: index of the largest logit along axis 0."""
  return tf.argmax(logits, axis=0)


def nucleus_sampling(x):
  """Top-p (nucleus) sampling of one token id.

  Args:
    x: tuple (logits, p); logits is a 1-D float32 tensor, p the probability
      mass threshold (e.g. 0.9).
  Returns:
    int64 tensor of shape (1,) holding the sampled id.

  The most probable token is always kept, so p <= 0 degenerates to greedy
  decoding instead of masking every token (which made sampling uniform).
  """
  logits = x[0]
  p = x[1]
  # code from https://github.com/royRLL/Zen-NLG-using-Tensorflow-and-Nucleus-Sampling/blob/main/ProjectCharacterPredictionHP60.ipynb
  sorted_logits = tf.sort(logits, direction='DESCENDING')
  sorted_probs = tf.nn.softmax(sorted_logits)
  # Probability mass strictly before each sorted position.
  mass_before = tf.cumsum(sorted_probs, exclusive=True)
  keep = tf.logical_or(mass_before < p,
                       tf.equal(tf.range(tf.shape(sorted_logits)[0]), 0))
  masked_logits = tf.where(keep, sorted_logits,
                           tf.ones_like(sorted_logits, dtype=tf.float32) * 1e9)
  # Smallest logit still inside the nucleus.
  min_logit = tf.reduce_min(masked_logits, keepdims=True)
  res_logits = tf.where(
      logits < min_logit,
      tf.ones_like(logits, dtype=tf.float32) * -1e9,
      logits,
  )

  sample = tf.random.categorical(tf.expand_dims(res_logits, 0), 1)
  return sample[0]


def sampling(x):
  """Plain temperature sampling.

  Args:
    x: tuple (logits, temperature); logits is 1-D.
  Returns:
    int64 tensor of shape (1,) holding the sampled id.
  """
  logits, temperature = x[0], x[1]
  scaled = tf.expand_dims(logits / temperature, 0)
  return tf.random.categorical(scaled, 1)[0]

In [None]:
def sample_different_tokens(x):
  """Discourage re-emitting memory token ids that already appear in tar_inp.

  Args:
    x: tuple (tar_inp, preds, token_type) for one sample. token_type has one
      more entry than tar_inp; its last entry is the type of the token about to
      be generated (0 = memory type).
  Returns:
    preds with -1e9 added at every vocab id that occurs at a non-type-1
    position of tar_inp when token_type[-1] == 0; preds unchanged otherwise.
  """
  tar_inp = x[0]
  preds = x[1]
  token_type = x[2]

  def mask_existing_mem_tokens(tar_inp, preds, token_type):
    tar_inp_mem = tf.boolean_mask(tar_inp, tf.not_equal(token_type[:-1], 1))
    # Ids can repeat (e.g. the identical initial memory tokens), and
    # tf.sparse.to_dense rejects repeated indices, so build the mask densely
    # from one-hot counts clamped to 1.
    counts = tf.reduce_sum(
        tf.one_hot(tf.cast(tar_inp_mem, tf.int32), target_vocab_size, dtype=tf.float32),
        axis=0)
    tokens_mask = tf.minimum(counts, 1.0)
    return preds + (tokens_mask * -1e9)

  return tf.cond(tf.equal(token_type[-1], 0),
                 true_fn=lambda: mask_existing_mem_tokens(tar_inp,preds,token_type),
                 false_fn=lambda: preds)


def sample_tokens(x):
  """Choose the next token id for one sample.

  Args:
    x: tuple (preds, token_type, nucleus_p). As called from act_like_body,
      preds is the last-position logits of shape (1, vocab).
  Returns:
    Nucleus-sampled id when token_type == 0 (memory type), otherwise the
    argmax id; both int64 with shape (1,).
  """
  preds, token_type, nucleus_p = x[0], x[1], x[2]

  def _nucleus():
    return nucleus_sampling((tf.squeeze(preds, [0]), nucleus_p))

  def _greedy():
    return tf.argmax(preds, axis=-1)

  return tf.cond(tf.equal(token_type, 0), true_fn=_nucleus, false_fn=_greedy)

In [None]:
def sample_teacher_forcing(x):
  """Append the next decoder input token for one sample.

  Args:
    x: tuple (predicted_id_sample, token_type_sample, tar_inp_sample,
      tar_real_sample). token_type_sample already includes the type predicted
      for the token being appended (its last entry); tar_inp_sample starts at
      the beginning of tar (memory tokens + start token).

  Returns:
    (tar_inp_sample extended by one id, possibly-updated token_type_sample).

  Rules:
    * If more than MEM_TOKENS_NUM type-0 (memory) tokens would exist after the
      start token, the new token's type is forced to 1.
    * If the new token is type 1 and ground truth remains, the next tar_real
      token is appended (teacher forcing); otherwise predicted_id_sample is.
  """
  predicted_id_sample = x[0]
  token_type_sample = x[1] 
  tar_inp_sample = x[2] 
  tar_real_sample = x[3]

  # Number of type-1 tokens after the start token, excluding the one being
  # appended: this is the index of the next ground-truth token in tar_real.
  idx = tf.math.reduce_sum(token_type_sample[DEC_MEM_SIZE+1:-1])

  # Count type-0 tokens after the start token (including the new one); if that
  # exceeds MEM_TOKENS_NUM, overwrite the new token's type with 1.
  token_type_sample = tf.cond(tf.math.greater(tf.math.subtract(tf.cast(tf.shape(token_type_sample[DEC_MEM_SIZE+1:])[0],
                                                                       dtype=tf.int64), 
                                                               tf.math.reduce_sum(token_type_sample[DEC_MEM_SIZE+1:])),
                                              MEM_TOKENS_NUM),
                             true_fn=lambda : tf.concat([token_type_sample[:-1], 
                                                         tf.constant([1],dtype=tf.int64)
                                                        ],0),
                             false_fn=lambda : token_type_sample)
  # Teacher forcing for sequence-type tokens while tar_real still has tokens;
  # memory-type tokens (or overflow) use the model's own prediction.
  tar_inp_sample = tf.cond(tf.math.logical_and(tf.equal(token_type_sample[-1], 1), 
                                               tf.less(tf.cast(idx,dtype=tf.int32), 
                                                       tf.shape(tar_real_sample)[0])),
                          true_fn=lambda : tf.concat([tar_inp_sample, 
                                                      [tar_real_sample[idx]]
                                                     ],0),
                          false_fn=lambda : tf.concat([tar_inp_sample, predicted_id_sample],0))
  return tar_inp_sample, token_type_sample

In [None]:
def act_like_body(inp, tar_inp, tar_real, token_type, predictions):
  """One autoregressive decoding step, used as a tf.while_loop body.

  Runs the transformer on the current tar_inp / token_type, predicts the type
  of the next token from the last 2 output logits, picks the next id with
  sample_tokens, and appends either it or the ground-truth token per sample
  (see sample_teacher_forcing). tar_inp and token_type each grow by one column.

  The incoming `predictions` is only a loop-carried value; it is replaced by
  this step's token logits (last 2 token-type logits removed).
  """
  #double predicting previous tokens but teacher forcing saves prev tokens
  enc_padding_mask, combined_mask, dec_padding_mask = create_masks(inp, tar_inp)
  predictions, _, _ = transformer(inp, tar_inp, 
                                  True, 
                                  enc_padding_mask, 
                                  combined_mask, 
                                  dec_padding_mask,
                                  token_type_inp=token_type)

  # Vocabulary logits; the last 2 outputs are token-type logits.
  token_id_preds = predictions[:,:,:-2]

  # Type (0 = memory, 1 = sequence) of the next token, from the last position.
  predicted_token_type = tf.cast(tf.argmax(predictions[: ,-1:, -2:], axis=-1), tf.int64)
  token_type = tf.concat([token_type, predicted_token_type], axis=-1)
  
  #nucleus sampling here
  predicted_id = tf.cast(tf.map_fn(sample_tokens, (
                                                   token_id_preds[:,-1:,:], 
                                                   predicted_token_type,
                                                   tf.ones(tf.shape(token_id_preds)[0]) * NUCLEUS_P
                                                  ),
                          fn_output_signature=tf.TensorSpec(shape=(None), dtype=tf.int64)),
                         tf.int64)                  
  
  # Per sample: append teacher-forced or predicted id, maybe fix token type.
  tar_inp, token_type = tf.map_fn(sample_teacher_forcing,
                      (predicted_id, token_type, tar_inp, tar_real), 
                      fn_output_signature=(tf.TensorSpec(shape=(None), dtype=tf.int64), 
                                           tf.TensorSpec(shape=(None), dtype=tf.int64)))
  
  return inp, tar_inp, tar_real, token_type, token_id_preds


def act_like_condition(inp, tar_inp, tar_real, token_type, predictions):
  """while_loop predicate: continue while tar_inp is at most
  len(tar_real) + MEM_TOKENS_NUM - 1 tokens long."""
  length_limit = tf.shape(tar_real)[1] + MEM_TOKENS_NUM - 1
  return tf.shape(tar_inp)[1] <= length_limit

In [None]:
def filter_sample_seq_tokens(x):
  """Collect one sample's sequence-type logits, padded to tar_real's length.

  Args:
    x: tuple (predicted_logits_sample, token_type_sample, tar_real). Logits and
      token types start at the first position after the start token (leading
      memory tokens and the start token are already discarded).

  Returns:
    (len(tar_real), target_vocab_size) float32: logits at positions whose token
    type is set, truncated to len(tar_real), followed by one-hot rows on the
    end-token id for any missing positions.
  """
  predicted_logits_sample = x[0]
  token_type_sample = x[1]
  tar_real = x[2]
  tar_seq_len = tf.shape(tar_real)[0]

  curr_token_logits = tf.boolean_mask(predicted_logits_sample, token_type_sample)[:tar_seq_len,:]
  num_missing = tar_seq_len - tf.shape(curr_token_logits)[0]

  # The end token id is tokenizer_en.vocab_size + 1 (see `encode`) regardless of
  # DEC_MEM_SIZE, which counts memory positions, not extra vocab ids.
  end_token_id = tokenizer_en.vocab_size + 1
  end_token_logits = tf.one_hot(tf.fill([num_missing], end_token_id),
                                target_vocab_size, dtype=tf.float32)
  return tf.concat([curr_token_logits,
                    end_token_logits
                   ], 0)

In [None]:
def process_sample(i, predicted_logits, token_type, seq_logits, tar_real):
  """while_loop body: append sample i's sequence-type logits to seq_logits.

  Rows of predicted_logits[i] selected by token_type[i] are truncated to
  tar_real's length and zero-padded up to it, then stacked onto seq_logits.
  token_type has no entry for the first token.
  """
  row = tf.cast(i, dtype=tf.int32)
  tar_seq_len = tf.shape(tar_real)[1]

  kept = tf.boolean_mask(predicted_logits[row, :, :], token_type[row, :])[:tar_seq_len, :]
  pad_rows = tar_seq_len - tf.shape(kept)[0]
  padding = tf.zeros((pad_rows, target_vocab_size), dtype=tf.float32)
  padded = tf.concat([kept, padding], 0)

  seq_logits = tf.concat([seq_logits, padded[tf.newaxis, ...]], 0)
  return i + 1, predicted_logits, token_type, seq_logits, tar_real

def process_sample_condition(i, predicted_logits, token_type, seq_logits, tar_real):
  """while_loop predicate: iterate over every sample in the batch."""
  batch_size = tf.shape(predicted_logits)[0]
  return tf.cast(i, dtype=tf.int32) < batch_size

In [None]:
def act_like_train_step(inputs, global_batch_num_elems):
  """Baseline teacher-forced training step on one (inp, tar) batch.

  tar layout (from `encode`): DEC_MEM_SIZE memory tokens, start token, subword
  ids, end token. The decoder input skips the memory prefix and drops the last
  token; the target is everything after the start token, so the decoder
  outputs line up with tar_real position by position.

  Args:
    inputs: (inp, tar) id batches.
    global_batch_num_elems: non-padding target tokens in the global batch
      (used by loss_function only when distributed).
  Returns:
    The scalar loss for this (replica's) batch.
  """
  inp, tar = inputs

  tar_inp = tar[:, DEC_MEM_SIZE:-1]
  tar_real = tar[:, DEC_MEM_SIZE + 1:]

  with tf.GradientTape() as tape:
    enc_padding_mask, combined_mask, dec_padding_mask = create_masks(inp, tar_inp)
    predictions, _, _ = transformer(inp, tar_inp, 
                                    True, 
                                    enc_padding_mask, 
                                    combined_mask, 
                                    dec_padding_mask,
                                    token_type_inp=tf.ones(tf.shape(tar_inp),dtype=tf.int64))

    # Keep vocabulary logits; the last 2 outputs are token-type logits.
    # tar_inp already excludes the memory prefix, so predictions has the same
    # length as tar_real; slicing off DEC_MEM_SIZE more positions would
    # misalign them whenever DEC_MEM_SIZE > 0.
    predictions = predictions[:, :, :-2]

    loss = loss_function(tar_real, predictions, global_batch_num_elems)

  gradients = tape.gradient(loss, transformer.trainable_variables)    
  optimizer.apply_gradients(zip(gradients, transformer.trainable_variables))

  train_loss(loss)
  train_accuracy(tar_real, predictions)
  return loss 


if distributed_flag:
  @tf.function(input_signature=[train_dist_dataset.element_spec])
  def distributed_train_step(dataset_inputs):
    """Run act_like_train_step on every replica and return the summed loss.

    Each replica divides its summed token loss by the number of non-padding
    target tokens in the whole global batch, so the SUM across replicas equals
    the global per-token mean loss.
    """
    distr_tar = dataset_inputs[1]
    # Reassemble the global target batch from the per-replica shards, keeping
    # the same slice act_like_train_step uses as tar_real.
    concat_tar = tf.concat(distr_tar.values,axis=0)[:, DEC_MEM_SIZE + 1:]
    global_batch_num_elems = tf.reduce_sum(tf.cast(tf.math.logical_not(tf.math.equal(concat_tar, 0)),
                                                   dtype=tf.float32))

    # NOTE(review): `act_like_train_step` is looked up by name when this
    # tf.function is first traced; this notebook defines it twice, so the
    # definition current at first call is the one baked in.
    per_replica_losses = strategy.run(act_like_train_step, args=(dataset_inputs,global_batch_num_elems,))
    return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses,
                           axis=None)

@tf.function(input_signature=[train_dataset.element_spec])
def train_step_single_gpu(dataset_inputs):
  """Single-device step. The 0 is ignored by loss_function when
  distributed_flag is False (it normalizes by the batch's own token count)."""
  return act_like_train_step(dataset_inputs,0)

In [None]:
checkpoint_path = './ckpts_baseline+work_mem_nucleus_0.9'

ckpt = tf.train.Checkpoint(transformer=transformer, optimizer=optimizer)
# max_to_keep=None: every saved checkpoint is kept on disk.
ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=None)

# Resume from the newest checkpoint in checkpoint_path, if any.
latest_ckpt = ckpt_manager.latest_checkpoint
if latest_ckpt:
  ckpt.restore(latest_ckpt)
  print(latest_ckpt)
  print('Latest checkpoint restored!!')

In [None]:
EPOCHS = 5  # training epochs per run of the training loop below

In [None]:
# Echo the run configuration (rendered as a tuple) for the notebook record.
ENC_MEM_SIZE,DEC_MEM_SIZE,NUCLEUS_P,EPOCHS,MEM_TOKENS_NUM

In [None]:
# Directory where checkpoints and log.txt for this run are written.
checkpoint_path

In [None]:
from datetime import datetime

# exist_ok avoids the check-then-create race of os.path.exists + makedirs.
os.makedirs(checkpoint_path, exist_ok=True)

log_file = checkpoint_path+'/log.txt'

with open(log_file,'a') as f: 
  f.write(f"\n {datetime.now().strftime('%Y-%m-%d %H:%M:%S')} curr mem tokens num = {MEM_TOKENS_NUM} \n")

from tqdm import tqdm

# The distributed dataset/step only exist when distributed_flag is True;
# fall back to the single-device versions otherwise.
if distributed_flag:
  epoch_dataset, train_step_fn = train_dist_dataset, distributed_train_step
else:
  epoch_dataset, train_step_fn = train_dataset, train_step_single_gpu

for epoch in range(EPOCHS):
  # TRAIN LOOP
  total_loss = 0.0
  num_batches = 0
  for x in tqdm(epoch_dataset):
    total_loss += train_step_fn(x)
    num_batches += 1
    if num_batches % 50 == 0:
      
      template1 = f"{datetime.now().strftime('%Y-%m-%d %H:%M:%S')} Epoch {epoch+1}, Batch {num_batches}, Loss: {total_loss / num_batches}, distr_elem_loss {train_loss.result()},  Accuracy: {train_accuracy.result()}\n"
      print(template1)
      with open(log_file,'a') as f: 
        f.write(template1)
  
  # Guard against an empty dataset (would otherwise divide by zero).
  train_loss_averaged = total_loss / max(num_batches, 1)
 
  ckpt_save_path = ckpt_manager.save()
  
  template = (f"{datetime.now().strftime('%Y-%m-%d %H:%M:%S')} Epoch {epoch+1}, Loss: {train_loss_averaged}, distr_elem_loss {train_loss.result()},  Accuracy: {train_accuracy.result()}, Checkpoint: {ckpt_save_path}\n")
  print(template)

  with open(log_file,'a') as f: 
    f.write(template)

  train_loss.reset_states()
  train_accuracy.reset_states()

## Memory fine-tuning

In [None]:
# NOTE: re-stating these does not rebuild `transformer`; it keeps the flags it
# was constructed with.
masking, token_type = False, True

In [None]:
# Memory fine-tuning: allow up to 10 generated memory tokens, sampled with
# nucleus top-p = 0.9. NOTE(review): tf.functions already traced keep the
# values captured at trace time.
MEM_TOKENS_NUM, NUCLEUS_P = 10, 0.9

In [None]:
def act_like_train_step(inputs, global_batch_num_elems):
  """One optimisation step where the decoder prefix is grown by the model itself.

  Instead of plain teacher forcing, the target prefix (DEC_MEM_SIZE memory ids plus the
  start id) is extended by `tf.while_loop(act_like_condition, act_like_body, ...)`. A
  final forward pass is then scored against `tar_real`, all under the gradient tape.
  `act_like_condition`, `act_like_body`, `sample_tokens`, `sample_teacher_forcing`,
  `filter_sample_seq_tokens` and `loss_function` are defined outside this section.

  Args:
    inputs: `(inp, tar)` pair of token-id batches; assumed (batch, seq_len) int64 as
      produced by tf_encode + padding/batching — TODO confirm.
    global_batch_num_elems: forwarded unchanged to `loss_function` as its third
      argument. distributed_train_step passes the global count of non-zero target ids;
      train_step_single_gpu passes 0.

  Returns:
    The value returned by `loss_function` (presumably a scalar — confirm).

  Side effects:
    Applies gradients to `transformer.trainable_variables` via `optimizer`, and updates
    the `train_loss` and `train_accuracy` metrics.
  """
  inp, tar = inputs
  tar_inp = tar[:, :DEC_MEM_SIZE+1] #send start token with type 1 (seq) and all the mem tokens with type 0

  # Everything after the memory slots and the start token is the ground truth to predict.
  tar_real = tar[:, DEC_MEM_SIZE + 1:]

  with tf.GradientTape() as tape:
    #run while loop to obtain predictions logits
    # Initial token types: 0 for every prefix position except the last (start) one, which is 1.
    init_token_type = tf.concat([tf.zeros(tf.shape(tar_inp[:, :-1]),dtype=tf.int64), 
                         tf.ones(tf.shape(tar_inp[:, -1:]),dtype=tf.int64)], 
                        axis=-1)

    # Placeholder loop variable; the loop's returned `predictions` is overwritten by the
    # final forward pass below.
    predictions = tf.zeros((tf.shape(tar_inp)[0], tf.shape(tar_inp)[1], target_vocab_size), dtype=tf.float32)
    inp, tar_inp, tar_real, token_type, predictions = tf.while_loop(act_like_condition, act_like_body, 
                                                      [inp, 
                                                       tar_inp,
                                                       tar_real,
                                                       init_token_type,
                                                       predictions
                                                       ])
    
    
    
    
    #last pass before loss to get logits for the whole sequence
    enc_padding_mask, combined_mask, dec_padding_mask = create_masks(inp, tar_inp)
    predictions, _, _ = transformer(inp, tar_inp, 
                                    True, 
                                    enc_padding_mask, 
                                    combined_mask, 
                                    dec_padding_mask,
                                    token_type_inp=token_type)

    # The last two output channels are the token-type head; the rest are vocabulary logits.
    token_id_preds = predictions[:,:,:-2]

    # Type (0 = memory, 1 = sequence) predicted for the next position, from the last step.
    predicted_token_type = tf.argmax(predictions[: ,-1:, -2:], axis=-1, output_type=tf.int64)
    token_type = tf.concat([token_type, predicted_token_type], axis=-1)

    #nucleus sampling here
    predicted_id = tf.map_fn(sample_tokens, (token_id_preds[:,-1:,:], 
                                             predicted_token_type,
                                             tf.ones(tf.shape(token_id_preds)[0]) * NUCLEUS_P
                                            ),
                            fn_output_signature=tf.TensorSpec(shape=(None), dtype=tf.int64))
    
    # NOTE(review): the updated `tar_inp` is not used after this point; only `token_type`
    # feeds the loss filtering below. sample_teacher_forcing's exact policy is defined
    # elsewhere — confirm whether this call is still needed.
    tar_inp, token_type = tf.map_fn(sample_teacher_forcing,
                        (predicted_id, token_type, tar_inp, tar_real), 
                        fn_output_signature=(tf.TensorSpec(shape=(None), dtype=tf.int64), 
                                             tf.TensorSpec(shape=(None), dtype=tf.int64)))

    
    #discard leading mem tokens from predictions
    predictions = token_id_preds[:, DEC_MEM_SIZE:, :]


    # Presumably keeps only the logits at sequence-typed positions so they align with
    # tar_real — filter_sample_seq_tokens is defined elsewhere; confirm.
    seq_predictions = tf.map_fn(filter_sample_seq_tokens,
                                (predictions, token_type[:,DEC_MEM_SIZE+1:], tar_real), 
                                fn_output_signature=tf.TensorSpec(shape=(None,None), dtype=tf.float32)
                               )
    
    loss = loss_function(tar_real, seq_predictions, global_batch_num_elems)

  gradients = tape.gradient(loss, transformer.trainable_variables)    
  optimizer.apply_gradients(zip(gradients, transformer.trainable_variables))

  train_loss(loss)
  train_accuracy(tar_real, seq_predictions)
  return loss 

# tf.function-compiled step wrappers. `distributed_train_step` is only defined when a
# MirroredStrategy was created (distributed_flag); `train_step_single_gpu` always is.
if distributed_flag:
  @tf.function(input_signature=[train_dist_dataset.element_spec])
  def distributed_train_step(dataset_inputs):
    """Run act_like_train_step on every replica and SUM-reduce the per-replica losses.

    The loss normaliser is computed once for the whole global batch: the number of
    non-zero target ids after the first DEC_MEM_SIZE + 1 positions (memory slots and
    start token), counted over the concatenation of every replica's target shard.
    Presumably id 0 is padding — confirm against the dataset padding.
    """
    distr_tar = dataset_inputs[1]
    # Reassemble the global target batch from the per-replica shards.
    concat_tar = tf.concat(distr_tar.values,axis=0)[:, DEC_MEM_SIZE + 1:]
    global_batch_num_elems = tf.reduce_sum(tf.cast(tf.math.logical_not(tf.math.equal(concat_tar, 0)),
                                                   dtype=tf.float32))

    per_replica_losses = strategy.run(act_like_train_step, args=(dataset_inputs,global_batch_num_elems,))
    # NOTE(review): SUM yields the global-batch loss only if loss_function divides by
    # global_batch_num_elems (loss_function is defined elsewhere) — confirm.
    return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses,
                           axis=None)

@tf.function(input_signature=[train_dataset.element_spec])
def train_step_single_gpu(dataset_inputs):
  """Single-device variant of the training step.

  NOTE(review): passes 0 as global_batch_num_elems, whereas the distributed wrapper passes
  the non-zero target-token count. If loss_function divides by this argument, the loss
  would be inf/NaN — confirm how loss_function treats 0.
  """
  return act_like_train_step(dataset_inputs,0)

In [None]:
# Memory fine-tuning writes into the same checkpoint directory as the baseline run.
checkpoint_path = './ckpts_baseline+work_mem_nucleus_0.9'

ckpt = tf.train.Checkpoint(transformer=transformer, optimizer=optimizer)
ckpt_manager = tf.train.CheckpointManager(
    ckpt, checkpoint_path, max_to_keep=None
)

# Resume explicitly from 'ckpt-5' (presumably the last save of the 5-epoch baseline
# phase). The assertion raises if an already-built object tracked by `ckpt` has no
# matching entry in that checkpoint.
restore_status = ckpt.restore(f'{checkpoint_path}/ckpt-5')
restore_status.assert_existing_objects_matched()


In [None]:
# Exclusive upper bound on the epoch index for the fine-tuning loop below (it resumes at 5).
EPOCHS = 20

In [None]:
# Memory fine-tuning phase: continue training from epoch index 5 up to EPOCHS, with one
# checkpoint + log entry per epoch and a progress log line every 50 batches.
from datetime import datetime

from tqdm import tqdm

# The log file lives inside the checkpoint directory, so make sure it exists.
os.makedirs(checkpoint_path, exist_ok=True)

log_file = checkpoint_path+'/log.txt'

with open(log_file,'a') as f:
  f.write(f"\n {datetime.now().strftime('%Y-%m-%d %H:%M:%S')} curr mem tokens num = {MEM_TOKENS_NUM} \n")

# distributed_train_step / train_dist_dataset only exist when a MirroredStrategy was
# created (distributed_flag); otherwise fall back to the single-device step and dataset.
if distributed_flag:
  train_step_fn, train_batches = distributed_train_step, train_dist_dataset
else:
  train_step_fn, train_batches = train_step_single_gpu, train_dataset

# Resume after the 5 baseline epochs; matches the 'ckpt-5' restored in the cell above.
start_epoch = 5

for epoch in range(start_epoch, EPOCHS):
  # TRAIN LOOP
  total_loss = 0.0
  num_batches = 0
  for x in tqdm(train_batches):
    total_loss += train_step_fn(x)
    num_batches += 1
    if num_batches % 50 == 0:
      template1 = f"{datetime.now().strftime('%Y-%m-%d %H:%M:%S')} Epoch {epoch+1}, Batch {num_batches}, Loss: {total_loss / num_batches}, distr_elem_loss {train_loss.result()},  Accuracy: {train_accuracy.result()}\n"
      print(template1)
      with open(log_file,'a') as f:
        f.write(template1)

  # max(..., 1) avoids ZeroDivisionError if the dataset yields no batches.
  train_loss_averaged = total_loss / max(num_batches, 1)

  ckpt_save_path = ckpt_manager.save()

  template = (f"{datetime.now().strftime('%Y-%m-%d %H:%M:%S')} Epoch {epoch+1}, Loss: {train_loss_averaged}, distr_elem_loss {train_loss.result()},  Accuracy: {train_accuracy.result()}, Checkpoint: {ckpt_save_path}\n")
  print(template)

  with open(log_file,'a') as f:
    f.write(template)

  # Metrics are accumulated per epoch.
  train_loss.reset_states()
  train_accuracy.reset_states()

## Evaluate

In [None]:
def _select_eval_output(output, token_type, raw_tokens):
  """Turn evaluate()'s decoder buffer into the id sequence it returns.

  Args:
    output: (1, length) tensor of generated ids, including the memory/start prefix.
    token_type: (1, length) tensor aligned with `output`; 1 marks sequence tokens,
      0 memory tokens.
    raw_tokens: when True, keep only the ids whose type is 1.

  Returns:
    The squeezed (length,) id tensor, or, for raw_tokens, a Python list of the
    sequence-typed ids.
  """
  final_res = tf.squeeze(output, axis=0)
  if not raw_tokens:
    return final_res
  predicted_sentence = []
  for j, token_id in enumerate(final_res):
    if token_type[0, j] == 1:
      predicted_sentence.append(token_id)
  return predicted_sentence


def evaluate(inp_sentence, raw_tokens=False):
  """Sample a translation for one source input, token by token.

  The decoder starts from DEC_MEM_SIZE memory ids (tokenizer_en.vocab_size + 2, type 0)
  followed by the start id (tokenizer_en.vocab_size, type 1). Each step:
  - runs the transformer;
  - predicts the next token's type from the last two logit channels;
  - samples its id with `sample_tokens` at NUCLEUS_P.
  Once MEM_TOKENS_NUM type-0 tokens have been generated after the start token, further
  tokens are forced to type 1. Decoding stops when the end id
  (tokenizer_en.vocab_size + 1) is sampled, whatever its type, or after at most
  MAX_LENGTH + MEM_TOKENS_NUM steps.

  Args:
    inp_sentence: a source-language string, or (with raw_tokens=True) an already-encoded
      id tensor of rank 1 or 2 (batch of one).
    raw_tokens: if True, `inp_sentence` is used as ids directly and only the type-1
      (sequence) ids are returned, as a Python list.

  Returns:
    (ids, enc_att_w, attention_weights, token_type). `ids` is either the full (length,)
    generated tensor or the list of type-1 ids. The attention weights come from the last
    transformer call. `token_type` is the (1, length) type tensor.
  """
  # Source-side start/end ids sit just past the Portuguese subword vocabulary.
  start_token = [tokenizer_pt.vocab_size]
  end_token = [tokenizer_pt.vocab_size + 1]

  if raw_tokens:
    encoder_input = inp_sentence
  else:
    inp_sentence = start_token + tokenizer_pt.encode(inp_sentence) + end_token
    encoder_input = tf.expand_dims(inp_sentence, 0)
  if len(encoder_input.shape) == 1:
    encoder_input = tf.expand_dims(encoder_input, 0)

  assert len(encoder_input.shape) == 2

  # DEC_MEM_SIZE memory ids followed by the start id (an empty prefix when DEC_MEM_SIZE == 0).
  decoder_input = [tokenizer_en.vocab_size + 2] * DEC_MEM_SIZE + [tokenizer_en.vocab_size]

  output = tf.expand_dims(decoder_input, 0)
  i = len(decoder_input) #i describes the length of currently generated translation

  # Memory prefix has type 0, the start token type 1.
  token_type = tf.concat([tf.zeros(tf.shape(output[:, :-1]),dtype=tf.int64),
                          tf.ones(tf.shape(output[:, -1:]),dtype=tf.int64)],
                         axis=-1)

  while i <= MAX_LENGTH + DEC_MEM_SIZE + MEM_TOKENS_NUM:
    i += 1

    enc_padding_mask, combined_mask, dec_padding_mask = create_masks(encoder_input, output)
    predictions, enc_att_w, attention_weights = transformer(encoder_input, output,
                                                            False,
                                                            enc_padding_mask,
                                                            combined_mask,
                                                            dec_padding_mask,
                                                            token_type_inp=token_type)

    # The last two output channels are the token-type head; the rest are vocabulary logits.
    token_id_preds = predictions[:,:,:-2]

    predicted_token_type = tf.cast(tf.argmax(predictions[: ,-1:, -2:], axis=-1), tf.int64)
    # Bound the number of generated memory (type-0) tokens after the start token.
    if tf.cast(tf.shape(token_type[:,DEC_MEM_SIZE+1:])[1],dtype=tf.int64) - tf.math.reduce_sum(token_type[:,DEC_MEM_SIZE+1:]) > MEM_TOKENS_NUM - 1:
      predicted_token_type = tf.constant([[1]],dtype=tf.int64)
    token_type = tf.concat([token_type, predicted_token_type], axis=-1)

    #nucleus sampling here
    predicted_id = tf.cast(tf.map_fn(sample_tokens, (
                                                     token_id_preds[:,-1:,:],
                                                     predicted_token_type,
                                                     tf.ones(tf.shape(token_id_preds)[0]) * NUCLEUS_P
                                                    ),
                            fn_output_signature=tf.TensorSpec(shape=(None), dtype=tf.int64)),
                           tf.int32)

    output = tf.concat([output, predicted_id], axis=-1)

    # Stop at the end id. Its token type is deliberately not checked
    # (an `and predicted_token_type == 1` condition was left commented out).
    if predicted_id == tokenizer_en.vocab_size+1:
      break

  return (_select_eval_output(output, token_type, raw_tokens),
          enc_att_w, attention_weights, token_type)


def translate(sentence, plot=None):
  """Translate `sentence` with evaluate() and print the decoded pieces.

  Every generated id is rendered as text:
  - ordinary ids through tokenizer_en.decode;
  - the special ids tokenizer_en.vocab_size, +1, +2 as '<start>', '<end>', '<mem>';
  - ids above vocab_size + 2 are skipped.
  Each rendered piece goes into the full sequence, and additionally into the translation
  (token type 1) or the memory (any other type).

  Args:
    sentence: source-language string passed to evaluate().
    plot: the results are printed only when this is None or empty. The default is None
      rather than a shared mutable list.
  """
  result, _, _, token_type = evaluate(sentence)

  vocab_size = tokenizer_en.vocab_size
  special_pieces = {
      vocab_size: '<start>',
      vocab_size + 1: '<end>',
      vocab_size + 2: '<mem>',
  }

  predicted_sequence = []
  predicted_mem = []
  predicted_sentence = []
  for j, token_id in enumerate(result):
    if token_id < vocab_size:
      piece = tokenizer_en.decode([token_id])
    else:
      # int() because eager tensors are not hashable dict keys.
      piece = special_pieces.get(int(token_id))
      if piece is None:
        continue
    predicted_sequence.append(piece)
    if token_type[0, j] == 1:
      predicted_sentence.append(piece)
    else:
      predicted_mem.append(piece)

  if plot is None or len(plot) == 0:
    print('Input: {}'.format(sentence))
    print('Predicted sequence: {}'.format(''.join(predicted_sequence)))
    print('Predicted memory : ({})'.format(')('.join(predicted_mem)))
    print('Predicted translation: {}'.format(''.join(predicted_sentence)))

In [None]:
# Print the sampled sequence, memory and translation for every validation pair
# (only the source side is used).
for pt_text, _en_text in val_examples:
  translate(pt_text.numpy().decode())

## BLEU

In [None]:
import collections
import math

def _get_ngrams(segment, max_order):
  """Extracts all n-grams up to a given maximum order from an input segment.

  Args:
    segment: token sequence from which n-grams will be extracted.
    max_order: maximum length in tokens of the n-grams returned by this
        method.

  Returns:
    The Counter containing all n-grams up to max_order in segment
    with a count of how many times each n-gram occurred.
  """
  ngram_counts = collections.Counter()
  for order in range(1, max_order + 1):
    for i in range(len(segment) - order + 1):
      ngram = tuple(segment[i:i+order])
      ngram_counts[ngram] += 1
  return ngram_counts


def compute_bleu(reference_corpus, translation_corpus, max_order=4, smooth=False):
  """Computes BLEU score of translated segments against one or more references.

  Args:
    reference_corpus: list of lists of references for each translation. Each
        reference should be tokenized into a list of tokens.
    translation_corpus: list of translations to score. Each translation
        should be tokenized into a list of tokens.
    max_order: Maximum n-gram order to use when computing BLEU score.
    smooth: Whether or not to apply Lin et al. 2004 smoothing.

  Returns:
    6-tuple ``(bleu, precisions, bp, ratio, translation_length,
    reference_length)``, where:
    - `precisions` is the list of modified n-gram precisions for orders
      1..max_order;
    - `bp` is the brevity penalty;
    - `ratio` is translation_length / reference_length;
    - the two lengths are corpus totals, using the shortest reference of each
      segment.
    An empty translation corpus gives ratio == 0 and bp == 0. An empty
    reference total gives ratio == inf and bp == 1. Neither case raises
    ZeroDivisionError.
  """
  matches_by_order = [0] * max_order
  possible_matches_by_order = [0] * max_order
  reference_length = 0
  translation_length = 0
  for (references, translation) in zip(reference_corpus,
                                       translation_corpus):
    reference_length += min(len(r) for r in references)
    translation_length += len(translation)

    merged_ref_ngram_counts = collections.Counter()
    for reference in references:
      merged_ref_ngram_counts |= _get_ngrams(reference, max_order)
    translation_ngram_counts = _get_ngrams(translation, max_order)
    # Clipped counts: each translation n-gram matches at most as often as in any reference.
    overlap = translation_ngram_counts & merged_ref_ngram_counts
    for ngram in overlap:
      matches_by_order[len(ngram)-1] += overlap[ngram]
    for order in range(1, max_order+1):
      possible_matches = len(translation) - order + 1
      if possible_matches > 0:
        possible_matches_by_order[order-1] += possible_matches

  precisions = [0] * max_order
  for i in range(0, max_order):
    if smooth:
      precisions[i] = ((matches_by_order[i] + 1.) /
                       (possible_matches_by_order[i] + 1.))
    else:
      if possible_matches_by_order[i] > 0:
        precisions[i] = (float(matches_by_order[i]) /
                         possible_matches_by_order[i])
      else:
        precisions[i] = 0.0

  if min(precisions) > 0:
    p_log_sum = sum((1. / max_order) * math.log(p) for p in precisions)
    geo_mean = math.exp(p_log_sum)
  else:
    geo_mean = 0

  # Guard both divisions: 1/ratio when nothing was translated, and the ratio
  # itself when the reference total is zero.
  if translation_length == 0:
    ratio, bp = 0.0, 0.0
  elif reference_length == 0:
    ratio, bp = float('inf'), 1.0
  else:
    ratio = float(translation_length) / reference_length
    bp = 1. if ratio > 1.0 else math.exp(1 - 1. / ratio)

  bleu = geo_mean * bp

  return (bleu, precisions, bp, ratio, translation_length, reference_length)

def explain_bleu(bleu_values):
  """Pretty-print the result of compute_bleu to stdout.

  Args:
    bleu_values: the 6-tuple ``(bleu, precisions, bp, ratio, translation_length,
      reference_length)`` returned by compute_bleu.
  """
  bleu, precisions, bp, ratio, translation_length, reference_length = bleu_values

  print(f"BLEU score: {bleu:.4}")
  print("----------------")
  print(f"Translated text total length\t {translation_length}")
  print(f"Reference text total length\t {reference_length}")
  print("----------------")
  print(f"n-gram max order was {len(precisions)}")
  print("n-gram precisions: ", end="")
  for val in precisions:
    print(f"{val:.3}", end=" ")
  print()
  print(f"Brevity penalty: {bp}")

In [None]:
# Validation pairs encoded like the training data (memory slots + start/end ids via
# tf_encode) and filtered by filter_max_length with its default length limits.
val_preprocessed = (val_examples.map(lambda x, y: tf_encode(x, y, ENC_MEM_SIZE, DEC_MEM_SIZE)).filter(filter_max_length))        

In [None]:
# Smoke test: translate a plain-text input and show the generated id tensor.
evaluate('ola')[0]

In [None]:
# Smoke test of the raw_tokens path with a hand-made (1, 3) id tensor; returns only the
# sequence-typed (type 1) ids.
evaluate(tf.constant([[2,3,4]]),raw_tokens=True)[0]

In [None]:
from tqdm import tqdm

# References: drop the DEC_MEM_SIZE memory slots and the trailing end id.
refs = []
for _, ref_tokens in tqdm(val_preprocessed):
  refs.append(np.array(ref_tokens[DEC_MEM_SIZE:-1]))

# Hypotheses: evaluate(raw_tokens=True) already drops type-0 (memory) ids.
# NOTE(review): slicing DEC_MEM_SIZE again is only harmless while DEC_MEM_SIZE == 0, and
# [:-1] assumes the last kept id is the end token — confirm.
trans = []
for inp_tokens, _ in tqdm(val_preprocessed):
  hyp_tokens = evaluate(inp_tokens, raw_tokens=True)[0]
  trans.append(np.array(hyp_tokens[DEC_MEM_SIZE:-1]))

# One reference per hypothesis.
explain_bleu(compute_bleu([[ref] for ref in refs], trans))

In [None]:
# Append the BLEU report for the current model to the run's log file.
import contextlib

log_file = checkpoint_path+'/log.txt'
bleu_values = compute_bleu([[r] for r in refs], trans)

with open(log_file,'a') as f:
  # The checkpoint number is hard-coded; update it when evaluating a different checkpoint.
  f.write('\n BLEU {} samples for ckpt # {}\n'.format(len(refs), 20))
  # redirect_stdout restores the notebook's stdout even if explain_bleu raises.
  with contextlib.redirect_stdout(f):
    explain_bleu(bleu_values)

## METEOR

In [None]:
def detokenize(s):
  """Render a sequence of English token ids as text.

  Ordinary ids are decoded with tokenizer_en. The special ids vocab_size,
  vocab_size + 1 and vocab_size + 2 become '<start>', '<end>' and '<mem>'; a memory id
  also prints 'mem found'. Ids above vocab_size + 2 are dropped.

  Returns:
    (sentence, sequence, memory) strings. The first two are identical. `memory` is
    always '' because no id is routed into it.
  """
  pieces = []
  for token_id in s:
    if token_id < tokenizer_en.vocab_size:
      pieces.append(tokenizer_en.decode([token_id]))
    elif token_id == tokenizer_en.vocab_size:
      pieces.append('<start>')
    elif token_id == tokenizer_en.vocab_size + 1:
      pieces.append('<end>')
    elif token_id == tokenizer_en.vocab_size + 2:
      print('mem found')
      pieces.append('<mem>')
  text = ''.join(pieces)
  return text, text, ''

# Sentence-level METEOR over the validation set, appended to the run's log file.
import sys
from datetime import datetime

import nltk
import numpy as np
from nltk.translate.meteor_score import single_meteor_score
from tqdm import tqdm

nltk.download("wordnet")

log_file = checkpoint_path+'/log.txt'

val_preprocessed = (val_examples.map(lambda x, y: tf_encode(x, y, ENC_MEM_SIZE, DEC_MEM_SIZE)).filter(filter_max_length))

# Same slicing as the BLEU cell: both sides keep the start id and drop the final id.
refs = [np.array(ref[DEC_MEM_SIZE:-1]) for (_, ref) in tqdm(val_preprocessed)]
trans = [np.array(evaluate(inp, raw_tokens=True)[0][DEC_MEM_SIZE:-1]) for (inp, _) in tqdm(val_preprocessed)]

# Score before opening the log so the file isn't held open during the long evaluation.
# NOTE(review): NLTK >= 3.6.6 expects pre-tokenized inputs (lists of str) and raises
# TypeError on plain strings; this call matches the older string-based API — confirm
# the installed NLTK version.
scores = [
    single_meteor_score(detokenize(ref)[0], detokenize(pred)[0], alpha=0.9, beta=3, gamma=0.5)
    for ref, pred in tqdm(zip(refs, trans))
]

with open(log_file,'a') as output:
  # EPOCHS is the final epoch bound reached by the training loops above (replaces a stale
  # hard-coded epoch count that disagreed with the BLEU log's checkpoint number).
  output.write(f"\n {datetime.now().strftime('%Y-%m-%d %H:%M:%S')} METEOR = {np.mean(scores)}, {len(refs)} samples, {EPOCHS} epochs \n")