<a href="https://colab.research.google.com/github/NiloyPurkait/GSoC-2020/blob/master/V2.0/Transformers/experimental_generator_pretraining.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

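## Pre-train the generator before adversarial training

The generator is pretrained with teacher forcing. Later cells reload saved generator and discriminator weights and run an experimental adversarial training loop.

- Adapted from: https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/tutorials/text/transformer.ipynb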

In [None]:
#! pip install tf-nightly-gpu

In [None]:
import tensorflow_datasets as tfds
import tensorflow as tf
import time
import numpy as np
import matplotlib.pyplot as plt
import io
import unicodedata
import re
from re import finditer

## Set up the input pipeline

In [None]:

# Data lives on Google Drive; it is mounted at /content/gdrive below.
from google.colab import drive

# Training corpus and held-out test split, both on the mounted Drive.
file_path = "/content/gdrive/My Drive/f_data.txt"
test_path = "/content/gdrive/My Drive/data/processed_graphs/eng/gat/test_data.txt"

drive.mount('/content/gdrive')

In [None]:
from pretraining import *
from transformer_generator import *
from transformer_discriminator import *

In [None]:

# Build the batched training dataset and the text tokenizer used by every
# encode/decode call in the rest of the notebook.
train_dataset, tokenizer_txt = create_generator_dataset(
    file_path,
    BATCH_SIZE=16,
)

## Loss and metrics

In [None]:
def discriminator_loss(real_output, fake_output):
    """GAN discriminator loss computed on raw logits.

    Real samples are scored against a target of 1 and generated samples
    against a target of 0. The two binary cross-entropy terms are summed.

    Args:
      real_output: discriminator logits for real (input, target) sequences.
      fake_output: discriminator logits for generated sequences.

    Returns:
      BCE(real -> 1) + BCE(fake -> 0).
    """
    bce = tf.keras.losses.BinaryCrossentropy(from_logits=True)
    return (bce(tf.ones_like(real_output), real_output)
            + bce(tf.zeros_like(fake_output), fake_output))



# Primary loss for plain adversarial training.
def generator_loss(real_output, fake_output):
    """Generator loss: reward fakes that the discriminator labels as real.

    Binary cross-entropy of the fake logits against an all-ones target.
    `real_output` is accepted only so the signature mirrors
    `discriminator_loss`; it is not used.
    """
    bce = tf.keras.losses.BinaryCrossentropy(from_logits=True)
    all_real = tf.ones_like(fake_output)
    return bce(all_real, fake_output)




## Set hyperparameters

To keep this example small and relatively fast, the values for *num_layers, d_model, and dff* have been reduced. 

The values used in the base model of transformer were; *num_layers=6*, *d_model = 512*, *dff = 2048*. See the [paper](https://arxiv.org/abs/1706.03762) for all the other versions of the transformer.

Note: By changing the values below, you can get the model that achieved state of the art on many tasks.

In [None]:
# --- Generator (Transformer) hyperparameters ---
num_layers = 4
d_model = 128
dff = 512
num_heads = 8
dropout_rate = 0.1

# Two ids beyond the tokenizer vocabulary are reserved: evaluate_ uses
# vocab_size as the start id and vocab_size + 1 as the end id.
target_vocab_size = tokenizer_txt.vocab_size + 2
# Inputs and targets are encoded with the same tokenizer.
input_vocab_size = target_vocab_size

# Optimizer for the generator's adversarial updates in train_step.
generator_optimizer = tf.keras.optimizers.Adam(1e-4)


In [None]:
# Learning-rate schedule for teacher-forcing pretraining, scaled by d_model
# (CustomSchedule comes from the star imports; presumably the tutorial's
# warmup schedule -- confirm).
learning_rate = CustomSchedule(d_model)

# `pretrain_step` and the checkpoint both use `optimizer`, but nothing visible
# in this notebook defined it, and `learning_rate` above was otherwise unused.
# Build it with the Adam settings of the TensorFlow Transformer tutorial this
# notebook is adapted from.
optimizer = tf.keras.optimizers.Adam(learning_rate, beta_1=0.9, beta_2=0.98,
                                     epsilon=1e-9)

# Running metrics; the training loop resets them at the start of each epoch.
train_loss = tf.keras.metrics.Mean(name='train_loss')
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
    name='train_accuracy')

In [None]:
# Encoder-decoder Transformer acting as the text generator. pe_input and
# pe_target (presumably maximum positional-encoding lengths -- confirm in
# transformer_generator) are set to the vocabulary sizes.
generator = Transformer(
    num_layers, d_model, num_heads, dff,
    input_vocab_size, target_vocab_size,
    pe_input=input_vocab_size,
    pe_target=target_vocab_size,
    rate=dropout_rate,
)

In [None]:
def pretrain_loss_function(real, pred):
  """Masked sparse categorical cross-entropy for teacher-forced pretraining.

  Positions whose target id is 0 are treated as padding and excluded from
  the average.

  Args:
    real: integer target ids, shape (batch, seq_len).
    pred: logits, shape (batch, seq_len, vocab_size).

  Returns:
    Scalar mean loss over non-padding positions. Returns 0 (rather than NaN)
    when every position is padding.
  """
  loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True, reduction='none')

  mask = tf.math.logical_not(tf.math.equal(real, 0))
  loss_ = loss_object(real, pred)

  mask = tf.cast(mask, dtype=loss_.dtype)
  loss_ *= mask

  # An all-padding batch would otherwise give 0/0 = NaN, which would poison
  # the running `train_loss` mean and the gradients.
  return tf.math.divide_no_nan(tf.reduce_sum(loss_), tf.reduce_sum(mask))

def pretrain_step(inp, tar):
  """Run one teacher-forcing update of `generator` on a batch.

  The decoder is fed the target without its last token (`tar[:, :-1]`) and
  is trained to predict the target without its first token (`tar[:, 1:]`).
  Side effects: updates generator weights via `optimizer`, and accumulates
  `train_loss` and `train_accuracy`.
  """
  decoder_in, decoder_gold = tar[:, :-1], tar[:, 1:]
  enc_mask, combined, dec_mask = create_masks(inp, decoder_in)

  with tf.GradientTape() as tape:
    logits, _ = generator(inp, decoder_in, True,
                          enc_mask, combined, dec_mask)
    loss = pretrain_loss_function(decoder_gold, logits)

  weights = generator.trainable_variables
  optimizer.apply_gradients(zip(tape.gradient(loss, weights), weights))

  train_loss(loss)
  train_accuracy(decoder_gold, logits)



In [None]:
# Take exactly one teacher-forcing step on the first batch. Presumably this
# exists so the generator's variables are created before `load_weights` in
# the next cell; that load then overwrites this step's weight update.
for inpt, targ in train_dataset:
  pretrain_step(inpt, targ)
  break

In [None]:
# Replace the generator's current weights with the saved ones in
# ./generator_weights.h5 (the file the training section later writes to).
generator.load_weights('./generator_weights.h5')

In [None]:
DATA_MAX_LEN = 250  # sequence length the discriminator is constructed with

# Transformer discriminator over the same vocabulary as the generator
# (tokenizer ids plus the two reserved start/end ids), restored from disk.
discriminator = TransformerDiscriminator(tokenizer_txt.vocab_size+2,
                                         maxlen=DATA_MAX_LEN)
discriminator.load_weights('./discriminator_weights.h5')

discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)

Create the checkpoint path and the checkpoint manager. The training loop uses them to save a checkpoint every 5 epochs.

In [None]:
# Keep the 5 most recent checkpoints under this directory.
checkpoint_path = "./content/checkpoints/train"

# Track everything the adversarial training loop updates. Without the
# discriminator and the two adversarial optimizers, a resumed run would start
# the discriminator again from its .h5 weights and lose all Adam state.
# The original `generator` / `optimizer` keys are unchanged, so checkpoints
# written before this change still restore into them.
ckpt = tf.train.Checkpoint(generator=generator,
                           optimizer=optimizer,
                           discriminator=discriminator,
                           generator_optimizer=generator_optimizer,
                           discriminator_optimizer=discriminator_optimizer)

ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)


In [None]:

# Resume from the newest checkpoint under checkpoint_path, when one exists.
if ckpt_manager.latest_checkpoint:
  ckpt.restore(ckpt_manager.latest_checkpoint)
  print('Latest checkpoint restored!!')


In [None]:
# Number of passes over train_dataset in the adversarial training loop.
EPOCHS = 10

In [None]:
def render_preds(batch_pred, inp, tar, n=2):
    """Print decoded predictions next to their reference text and input.

    Shows at most `n` examples from the start of the batch. Each example is
    printed as "Predicted", "True" and "Input RDF" lines decoded with
    `tokenizer_txt`.

    Args:
      batch_pred: batch of predicted token-id sequences.
      inp: batch of encoded input sequences, aligned with `batch_pred`.
      tar: batch of encoded reference sequences, aligned with `batch_pred`.
      n: maximum number of examples to print; n <= 0 prints nothing.
    """
    for ind, pred in enumerate(batch_pred):
        # Stop *before* printing index n so exactly min(n, batch) rows show.
        if ind >= n:
            break
        print('\n| Predicted: ', decode_text(pred, tokenizer_txt))
        print('| True: ', decode_text(tar[ind], tokenizer_txt))
        print('| Input RDF: ', decode_text(inp[ind], tokenizer_txt))
        print()

In [None]:
def gen_batch(preds, inp, tar, max_len = 100):
  """Build (generated, real) discriminator inputs from generator predictions.

  Each predicted id sequence is decoded to text and cut after the first
  '<end>' marker ('<end>' is re-appended). It is then re-encoded with
  `tokenizer_txt` and right-padded with 0 to exactly `max_len` ids.
  Encodings longer than `max_len` are truncated to their first `max_len`
  ids; a negative pad width would make `np.pad` raise.

  Args:
    preds: (batch, seq_len) predicted token ids, e.g. argmax of the logits.
    inp: (batch, inp_len) encoder input ids; its dtype is used for the
      generated tensor.
    tar: (batch, tar_len) reference target ids.
    max_len: fixed length of every re-encoded generation.

  Returns:
    (gens, real): `inp` concatenated along the last axis with the padded
    generations, and `inp` concatenated with `tar`.
  """
  disc_data = []
  for sent in preds:
    unparsed = decode_text(sent, tokenizer_txt)
    retokenized = tokenizer_txt.encode(unparsed.split('<end>')[0]+'<end>')
    # Truncate instead of crashing when the re-encoded text exceeds max_len.
    retokenized = retokenized[:max_len]
    padded = np.pad(np.array(retokenized), (0, max_len - len(retokenized)), 'constant')

    disc_data.append(padded)
  disc_data = tf.convert_to_tensor(disc_data, dtype=inp.dtype)
  gens = tf.concat([inp, disc_data], axis=-1, name='concat')
  real = tf.concat([inp, tar], axis=-1, name='concat')

  return gens, real


In [None]:
 def pad(tensor, maxlen=250):
   """Post-pad every sequence in `tensor` with 0 to a fixed length `maxlen`.

   Wraps Keras `pad_sequences`, which returns a NumPy array. Over-long
   sequences are truncated with the Keras default (`truncating='pre'`), so
   tokens are dropped from the *front*.
   NOTE(review): for concatenated (input, generation) rows longer than
   `maxlen` this discards the beginning of the input -- confirm intended.
   """
   keras_pad = tf.keras.preprocessing.sequence.pad_sequences
   return keras_pad(tensor, maxlen=maxlen, value=0, padding='post')


In [None]:
#generator.trainable_variables

In [None]:


train_step_signature = [
    tf.TensorSpec(shape=(None, None), dtype=tf.int64),
    tf.TensorSpec(shape=(None, None), dtype=tf.int64),
]


def _apply_available_gradients(optimizer_, grads, variables):
  """Apply only the (gradient, variable) pairs whose gradient is not None.

  Keras optimizers raise when every gradient handed to `apply_gradients` is
  None, so pairs are filtered first and the update is skipped when nothing
  remains.

  Returns:
    The number of variables that were updated.
  """
  pairs = [(g, v) for g, v in zip(grads, variables) if g is not None]
  if pairs:
    optimizer_.apply_gradients(pairs)
  return len(pairs)


# tf.function stays disabled: gen_batch decodes and re-encodes text in Python
# on concrete values, which presumably cannot be traced into a graph. The
# signature above is kept for when that path becomes pure tensor ops.
#@tf.function(input_signature=train_step_signature)
def train_step(inp, tar):
  """Run one adversarial update of the generator and the discriminator.

  The generator runs with teacher forcing on `tar`. Its argmax predictions
  are turned into discriminator inputs by `gen_batch` + `pad` and scored
  against the real (input, target) pairs.

  Side effects: updates discriminator weights via `discriminator_optimizer`
  and (if it has gradients) the generator via `generator_optimizer`. Also
  accumulates `train_loss` and `train_accuracy`.

  NOTE: `tf.argmax` and the NumPy/Python path through `gen_batch`/`pad` are
  not differentiable. The generator therefore gets no gradient from
  `gen_loss` and is not updated here; a warning is emitted instead of
  crashing. TODO: real adversarial generator training needs a
  differentiable or policy-gradient path (e.g. Gumbel-softmax / REINFORCE).
  """
  tar_inp = tar[:, :-1]
  tar_real = tar[:, 1:]

  enc_padding_mask, combined_mask, dec_padding_mask = create_masks(inp, tar_inp)

  # One tape per network; both must be open while the forward pass runs.
  with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
    predictions, _ = generator(inp, tar_inp,
                               True,
                               enc_padding_mask,
                               combined_mask,
                               dec_padding_mask)

    batch_pred = tf.argmax(predictions, axis=-1)

    generated, real = gen_batch(batch_pred, inp, tar)
    generated, real = pad(generated), pad(real)

    real_output = discriminator(real, training=True)
    fake_output = discriminator(generated, training=True)

    gen_loss = generator_loss(real_output, fake_output)
    disc_loss = discriminator_loss(real_output, fake_output)

  gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
  gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

  n_gen_updated = _apply_available_gradients(
      generator_optimizer, gradients_of_generator, generator.trainable_variables)
  if n_gen_updated == 0:
    import warnings
    warnings.warn('train_step: generator received no gradients from the '
                  'adversarial loss and was not updated')
  _apply_available_gradients(
      discriminator_optimizer, gradients_of_discriminator,
      discriminator.trainable_variables)

  # NOTE(review): both losses feed the same Mean metric, so the reported
  # "Loss" averages generator and discriminator losses together.
  train_loss(gen_loss)
  train_loss(disc_loss)
  train_accuracy(tar_real, predictions)







In [None]:
train = True  # set to False to skip the adversarial training loop below

In [None]:
# Adversarial training loop; runs only when the `train` flag is truthy.
if train:
  for epoch in range(EPOCHS):
    start = time.time()

    # Metrics report per-epoch averages, so clear them first.
    train_loss.reset_states()
    train_accuracy.reset_states()

    for batch, (inp, tar) in enumerate(train_dataset):
      train_step(inp, tar)

      # Progress line every 50 batches.
      if not batch % 50:
        print('Epoch {} Batch {} Loss {:.4f} Accuracy {:.4f}'.format(
            epoch + 1, batch, train_loss.result(), train_accuracy.result()))

    epoch_num = epoch + 1
    # Checkpoint every 5th epoch.
    if epoch_num % 5 == 0:
      ckpt_save_path = ckpt_manager.save()
      print('Saving checkpoint for epoch {} at {}'.format(epoch_num,
                                                          ckpt_save_path))

    print('Epoch {} Loss {:.4f} Accuracy {:.4f}'.format(
        epoch_num, train_loss.result(), train_accuracy.result()))
    print('Time taken for 1 epoch: {} secs\n'.format(time.time() - start))

In [None]:
# Persist the generator's weights, overwriting ./generator_weights.h5 (the
# file loaded earlier). The model in this notebook is named `generator`;
# `transformer` is never defined here, so saving through it raised NameError.
generator.save_weights('./generator_weights.h5')

## Evaluate

In [None]:
def evaluate_(inp_sentence, max_length=None):
  """Greedily decode generator output for one encoded input sequence.

  Decoding starts from the start id (`tokenizer_txt.vocab_size`). At each
  step the argmax token of the last position is appended. Decoding stops when
  the end id (`tokenizer_txt.vocab_size + 1`) is predicted or after
  `max_length` steps.

  Args:
    inp_sentence: token ids of a single example, with no batch dimension
      (one is added here).
    max_length: maximum number of decoding steps. Defaults to the global
      `MAX_LENGTH`, looked up at call time because that constant is defined
      in a later cell.

  Returns:
    1-D int32 tensor of decoded ids. It begins with the start id and excludes
    the end id.
  """
  if max_length is None:
    max_length = MAX_LENGTH

  encoder_input = tf.expand_dims(inp_sentence, 0)

  start_id = tokenizer_txt.vocab_size
  end_id = tokenizer_txt.vocab_size + 1
  output = tf.expand_dims([start_id], 0)

  for _ in range(max_length):
    enc_padding_mask, combined_mask, dec_padding_mask = create_masks(
        encoder_input, output)

    # predictions.shape == (batch_size, seq_len, vocab_size)
    # The model here is `generator`; `transformer` is undefined in this notebook.
    predictions, _ = generator(encoder_input,
                               output,
                               False,
                               enc_padding_mask,
                               combined_mask,
                               dec_padding_mask)

    # Keep only the last position's logits: (batch_size, 1, vocab_size).
    predictions = predictions[:, -1:, :]
    predicted_id = tf.cast(tf.argmax(predictions, axis=-1), tf.int32)

    if predicted_id == end_id:
      return tf.squeeze(output, axis=0)

    # Feed the prediction back as the next decoder input token.
    output = tf.concat([output, predicted_id], axis=-1)

  return tf.squeeze(output, axis=0)

In [None]:
MAX_LENGTH = 250  # step limit used by evaluate_ when decoding

# Grab the first (input, target) batch to try the generator on.
rdfb, txtb = next(iter(train_dataset))

In [None]:
# Greedy-decode the first input of the batch.
predicted_sentence = evaluate_(rdfb[0])

In [None]:
# Turn the decoded ids back into text; shown as the cell's output.
decode_text(predicted_sentence, tokenizer_txt)