<a href="https://colab.research.google.com/github/NiloyPurkait/GSoC-2020/blob/master/adverserial_trainingl.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Adversarial training script

Loads the pretrained discriminator and generator, and trains them in an adversarial fashion.

- Adapted from : https://www.tensorflow.org/tutorials/generative/dcgan 

In [None]:
import tensorflow_datasets as tfds
import tensorflow as tf
import time
import numpy as np
import matplotlib.pyplot as plt
import io
import unicodedata
import re
from re import finditer

In [None]:
from pretraining import *
from transformer_generator import *
from transformer_discriminator import *

## Setup input pipeline

In [None]:

# Mount Google Drive under /content/gdrive so the data files below are reachable.
from google.colab import drive
drive.mount('/content/gdrive')

# Path handed to create_generator_dataset further down to build train_dataset.
file_path = "/content/gdrive/My Drive/data/processed_graphs/eng/gat/f_data.txt"
# NOTE(review): test_path is not referenced by any other visible cell — confirm it is still needed.
test_path = "/content/gdrive/My Drive/data/processed_graphs/eng/gat/test_data.txt"

## Set hyperparameters



The values used in the base model of the transformer were: *num_layers=6*, *d_model = 512*, *dff = 2048*. See the [paper](https://arxiv.org/abs/1706.03762) for all the other versions of the transformer.



In [None]:
# Size hyperparameters handed to the generator Transformer constructor.
num_layers = 4
d_model = 128
dff = 512
num_heads = 8
# Used as `maxlen` when the discriminator is constructed.
# NOTE(review): a later cell reassigns DATA_MAX_LEN to 200 — confirm which value is intended.
DATA_MAX_LEN = 250
# NOTE(review): tokenizer_txt is assigned by the create_generator_dataset cell near the
# end of the notebook; unless one of the star-imported modules also provides it, this
# line raises NameError on a fresh top-to-bottom run — confirm.
target_vocab_size = tokenizer_txt.vocab_size + 2
input_vocab_size = target_vocab_size
dropout_rate = 0.1

In [None]:

# Learning-rate schedule built from d_model (CustomSchedule comes from a star import).
learning_rate = CustomSchedule(d_model)

# NOTE(review): `optimizer` is not referenced by train_step, which uses
# generator_optimizer / discriminator_optimizer instead — confirm it is still needed.
optimizer = tf.keras.optimizers.Adam(
    learning_rate=learning_rate,
    beta_1=0.9,
    beta_2=0.98,
    epsilon=1e-9,
)

# Running metrics; both are reset at the start of every epoch by the training loop.
train_loss = tf.keras.metrics.Mean(name='train_loss')
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')


## Training and checkpointing

In [None]:
# Build the generator Transformer from the hyperparameters defined above.
# Both positional-encoding sizes (pe_input / pe_target) are set to the vocabulary sizes.
generator = Transformer(num_layers, d_model, num_heads, dff,
                          input_vocab_size, target_vocab_size, 
                          pe_input=input_vocab_size, 
                          pe_target=target_vocab_size,
                          rate=dropout_rate)

# Restore pretrained generator weights from the current working directory.
generator.load_weights('./generator_weights.h5')

In [None]:

# Build the discriminator with the same vocabulary size as the generator (vocab_size + 2).
# NOTE(review): when cells run in order DATA_MAX_LEN is 250 here, but a later cell sets it
# to 200 before the discriminator inputs are padded — confirm the two lengths are compatible.
discriminator = TransformerDiscriminator(tokenizer_txt.vocab_size+2, maxlen=DATA_MAX_LEN)
# Restore pretrained discriminator weights from the current working directory.
discriminator.load_weights('./discriminator_weights.h5')


In [None]:
def timer(func):
  """Decorator that prints how long each call to `func` takes.

  Args:
    func: The callable to wrap.

  Returns:
    A wrapper that forwards all positional and keyword arguments to `func`,
    prints the elapsed wall-clock time in seconds, and returns `func`'s
    result unchanged. The wrapper keeps `func`'s name and docstring.
  """
  import functools  # local import so this cell does not depend on the import cell

  @functools.wraps(func)  # preserve __name__/__doc__ of the wrapped function
  def wrapper(*args, **kwargs):
    t = time.time()
    rv = func(*args, **kwargs)
    print('Took ', time.time()-t, 'secs')
    return rv
  return wrapper

In [None]:

@tf.function
def evaluate_batch(inp_batch):
  """Greedily decode one token sequence per row of `inp_batch` with the generator.

  Decoding starts from a single column filled with `tokenizer_txt.vocab_size`
  and appends the arg-max token of the last position once per step, for
  MAX_LENGTH steps. MAX_LENGTH is not defined in this notebook's visible cells;
  presumably it comes from one of the star imports — confirm.

  Args:
    inp_batch: Batch of encoder inputs; only its first dimension (batch size)
      is read directly here.

  Returns:
    An int32 tensor with one row per input row and 1 + MAX_LENGTH columns
    (the start column followed by every predicted id).
  """
  batch_size = tf.shape(inp_batch)[0]
  start_column = tf.ones((batch_size, 1), tf.int32)
  output = start_column * tokenizer_txt.vocab_size

  for _ in range(MAX_LENGTH):
    enc_padding_mask, combined_mask, dec_padding_mask = create_masks(
        inp_batch, output)

    # predictions.shape == (batch_size, seq_len, vocab_size)
    # NOTE(review): the third positional argument is True — presumably the
    # training flag, which would keep dropout active while decoding; confirm.
    predictions, _ = generator(inp_batch,
                               output,
                               True,
                               enc_padding_mask,
                               combined_mask,
                               dec_padding_mask)

    # Keep only the final position: (batch_size, 1, vocab_size).
    last_step = predictions[:, -1:, :]
    predicted_id = tf.cast(tf.argmax(last_step, axis=-1), tf.int32)

    # Append the new ids so the next step's decoder input includes them.
    output = tf.concat([output, predicted_id], axis=-1)

  return output

In [None]:
def parse_batch(preds, max_len):
  """Cut each predicted sequence at its first '<end>' marker and re-pad it.

  Every row of `preds` is decoded to text with `decode_text`, truncated at the
  first '<end>' (which is then re-appended), re-encoded with `tokenizer_txt`
  and right-padded with zeros to `max_len`.

  Args:
    preds: Iterable of predicted token-id sequences.
    max_len: Length every re-encoded row is padded to.

  Returns:
    A NumPy array built from the padded rows, one row per prediction.
  """
  def _clean_row(token_ids):
    text = decode_text(token_ids, tokenizer_txt)
    kept_text = text.split('<end>')[0] + '<end>'
    ids = tokenizer_txt.encode(kept_text)
    pad_amount = max_len - len(ids)
    return np.pad(np.array(ids), (0, pad_amount), 'constant')

  return np.array([_clean_row(row) for row in preds])

def pad_sequences(data, max_len):
  """Right-pad every row of `data` with zeros to `max_len` and stack the rows.

  Args:
    data: Iterable of 1-D sequences (lists or arrays).
    max_len: Target length of every row.

  Returns:
    A 2-D NumPy array of shape (number_of_rows, max_len). A single-row input
    yields shape (1, max_len) and an empty input yields shape (0, max_len),
    so the result always has a batch dimension.

  Raises:
    ValueError: If any row is longer than `max_len`.
  """
  rows = [np.asarray(row) for row in data]
  if not rows:
    return np.zeros((0, max_len), dtype=np.int64)

  for idx, row in enumerate(rows):
    # np.pad would also fail on a negative pad width, but with an opaque message.
    if len(row) > max_len:
      raise ValueError(
          'pad_sequences: row {} has length {} which exceeds max_len={}'.format(
              idx, len(row), max_len))

  # Pad each row once and stack in a single call instead of growing with vstack.
  padded_rows = [np.pad(row, (0, max_len - len(row)), 'constant') for row in rows]
  return np.stack(padded_rows)


def prepare_generated_data(rdf_batch, predicted_text_batch, max_len):
  """Join each RDF row with its generated text row and pad the result.

  Args:
    rdf_batch: 2-D batch of RDF token ids.
    predicted_text_batch: 2-D batch of generated text token ids, with the
      same number of rows as `rdf_batch`.
    max_len: Length every joined row is padded to by `pad_sequences`.

  Returns:
    Whatever `pad_sequences` returns for the column-wise concatenation.
  """
  joined_rows = np.concatenate([rdf_batch, predicted_text_batch], axis=1)
  padded = pad_sequences(joined_rows, max_len)
  return padded

def prepare_true_data(rdf_batch, text_batch, max_len):
  """Join each RDF row with its reference text row and pad the result.

  Args:
    rdf_batch: 2-D batch of RDF token ids.
    text_batch: 2-D batch of reference text token ids, with the same number
      of rows as `rdf_batch`.
    max_len: Length every joined row is padded to by `pad_sequences`.

  Returns:
    Whatever `pad_sequences` returns for the column-wise concatenation.
  """
  joined_rows = np.concatenate([rdf_batch, text_batch], axis=1)
  padded = pad_sequences(joined_rows, max_len)
  return padded

def get_disc_batch(rdf_batch, txt_batch):
  """Build the fake and real discriminator inputs for one batch.

  The generator decodes text for `rdf_batch` via `evaluate_batch`; the result
  is cleaned by `parse_batch` (padded to GEN_DATA_MAX_LEN), and both the
  generated and the reference text are joined with the RDF rows and padded to
  DATA_MAX_LEN.

  Args:
    rdf_batch: Batch of RDF token ids.
    txt_batch: Batch of reference text token ids.

  Returns:
    A `(generated, true)` pair of padded arrays.
  """
  decoded_ids = evaluate_batch(rdf_batch)
  cleaned_ids = parse_batch(decoded_ids, GEN_DATA_MAX_LEN)
  fake_inputs = prepare_generated_data(rdf_batch, cleaned_ids, DATA_MAX_LEN)
  real_inputs = prepare_true_data(rdf_batch, txt_batch, DATA_MAX_LEN)
  return fake_inputs, real_inputs





In [None]:

# Length that get_disc_batch asks parse_batch to pad each generated text row to.
GEN_DATA_MAX_LEN = 200#600
# Length that get_disc_batch asks pad_sequences to pad each joined RDF+text row to.
# NOTE(review): generated rows are already GEN_DATA_MAX_LEN long before the RDF columns
# are concatenated, so with DATA_MAX_LEN == GEN_DATA_MAX_LEN any non-empty RDF row makes
# `max_len - len(i)` negative in pad_sequences — confirm the intended values.
DATA_MAX_LEN = 200#800

## Loss functions
Separate loss functions for the discriminator and the generator.
- Source : https://www.tensorflow.org/tutorials/generative/dcgan

### Discriminator loss
This method quantifies how well the discriminator is able to distinguish real sequences from fakes. It compares the discriminator's predictions on real sequences to an array of 1s, and the discriminator's predictions on fake (generated) sequences to an array of 0s.

In [None]:
def discriminator_loss(real_output, fake_output):
    """Binary cross-entropy loss for the discriminator.

    Real outputs are scored against all-ones targets and fake outputs against
    all-zeros targets; both outputs are treated as logits (from_logits=True).

    Args:
        real_output: Discriminator output for real sequences.
        fake_output: Discriminator output for generated sequences.

    Returns:
        The sum of the real-side and fake-side losses.
    """
    bce = tf.keras.losses.BinaryCrossentropy(from_logits=True)
    loss_on_real = bce(tf.ones_like(real_output), real_output)
    loss_on_fake = bce(tf.zeros_like(fake_output), fake_output)
    return loss_on_real + loss_on_fake

### Generator loss
The generator's loss quantifies how well it was able to trick the discriminator. Intuitively, if the generator is performing well, the discriminator will classify the generated sequences as real (or 1). Here, we will compare the discriminator's decisions on the generated sequences to an array of 1s.

In [None]:
# Primary loss for plain adversarial training.
def generator_loss(fake_output):
    """Binary cross-entropy between the discriminator's fake outputs and all-ones targets.

    Args:
        fake_output: Discriminator output for generated sequences, treated as
            logits (from_logits=True).

    Returns:
        The loss value; it is low when the discriminator scores fakes as real.
    """
    bce = tf.keras.losses.BinaryCrossentropy(from_logits=True)
    wanted_labels = tf.ones_like(fake_output)
    return bce(wanted_labels, fake_output)

In [None]:
# Separate Adam optimizers for the two networks, each with a fixed 1e-4 learning rate.
generator_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)


In [None]:

# Signature for a graph-compiled train_step: two int64 tensors of shape (None, None).
train_step_signature = [
    tf.TensorSpec(shape=(None, None), dtype=tf.int64) for _ in range(2)
]

# The tf.function decorator below is left commented out, so train_step runs eagerly.
#@tf.function(input_signature=train_step_signature)
def train_step(inp, tar):
    """Run one adversarial update of generator and discriminator on a batch.

    NOTE(review): the fake batch reaches the discriminator through
    get_disc_batch, which goes through tf.argmax and NumPy arrays; the
    generator gradients are therefore presumably all None — confirm that the
    generator is actually being updated.

    NOTE(review): both the generator and the discriminator loss are
    accumulated into the single `train_loss` metric.

    Args:
        inp: Batch of RDF token ids.
        tar: Batch of reference text token ids.
    """
    gen_vars = generator.trainable_variables
    disc_vars = discriminator.trainable_variables

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        fake_batch, real_batch = get_disc_batch(inp, tar)

        real_scores = discriminator(real_batch, training=True)
        fake_scores = discriminator(fake_batch, training=True)

        gen_loss = generator_loss(fake_scores)
        disc_loss = discriminator_loss(real_scores, fake_scores)

    gen_grads = gen_tape.gradient(gen_loss, gen_vars)
    disc_grads = disc_tape.gradient(disc_loss, disc_vars)

    generator_optimizer.apply_gradients(zip(gen_grads, gen_vars))
    discriminator_optimizer.apply_gradients(zip(disc_grads, disc_vars))

    # Fold both losses into the shared running mean.
    for loss_value in (gen_loss, disc_loss):
        train_loss(loss_value)
    

In [None]:
# Build the training dataset (batch size 1) and the text tokenizer from file_path.
# NOTE(review): earlier cells already read tokenizer_txt (vocabulary sizes, discriminator
# construction, evaluate_batch); unless a star-imported module also defines it, this cell
# has to run before them on a fresh kernel — confirm the intended cell order.
train_dataset, tokenizer_txt = create_generator_dataset(file_path, BATCH_SIZE=1)

In [None]:
# The training loop in the next cell only runs when this flag is truthy.
train=True
# Number of passes the training loop makes over train_dataset.
EPOCHS = 10

In [None]:

# Adversarial training loop: runs train_step on every batch for EPOCHS epochs,
# printing the running metrics after each batch and at the end of each epoch.
if train:

  for epoch in range(EPOCHS):

    start = time.time()

    # Start each epoch with fresh running metrics.
    train_loss.reset_states()
    train_accuracy.reset_states()

    for (batch, (inp, tar)) in enumerate(train_dataset):

      train_step(inp, tar)

      # `batch % 1` is always 0, so progress is printed after every batch.
      if batch % 1 == 0:

        # NOTE(review): train_step never updates train_accuracy, so the accuracy
        # printed here presumably stays at its reset value — confirm.
        print ('Epoch {} Batch {} Loss {:.4f} Accuracy {:.4f}'.format(
            epoch + 1, batch, train_loss.result(), train_accuracy.result()))

    # Save a checkpoint every fifth epoch.
    if (epoch + 1) % 5 == 0:

      # NOTE(review): ckpt_manager is not defined in any visible cell; presumably it
      # comes from one of the star imports — confirm, otherwise this raises NameError.
      ckpt_save_path = ckpt_manager.save()

      print ('Saving checkpoint for epoch {} at {}'.format(epoch+1,
                                                          ckpt_save_path))

    print ('Epoch {} Loss {:.4f} Accuracy {:.4f}'.format(epoch + 1, 
                                                  train_loss.result(), 
                                                  train_accuracy.result()))

    print ('Time taken for 1 epoch: {} secs\n'.format(time.time() - start))