In [None]:
import numpy as np
import tensorflow as tf
import tensorflow_text as text
import tensorflow_hub as hub
from tensorflow.keras import Input, Model
from tensorflow.keras.layers import GlobalMaxPooling1D, Dense
from tensorflow.keras import backend as K

In [None]:
def toy_model(input_dim=128, hidden_units=(2048, 64), num_classes=2):
    """Build a small two-input MLP classifier.

    The source and target feature vectors are concatenated and passed through a
    stack of ReLU ``Dense`` layers, then a softmax output layer.

    Args:
        input_dim: Length of each input feature vector (both inputs share it).
        hidden_units: Width of each hidden ReLU layer, applied in order.
        num_classes: Number of output units / classes.

    Returns:
        An uncompiled Keras ``Model`` taking ``[input_src_text, input_tgt_text]``
        (each shaped ``(batch, input_dim)``) and returning ``(batch, num_classes)``
        softmax probabilities. The outputs are NOT logits, so pair the model with
        a loss built with ``from_logits=False``.
    """
    # Input names must match the feature-dict keys used by the tf.data pipeline.
    input_src = Input(shape=(input_dim,), name="input_src_text")
    input_tgt = Input(shape=(input_dim,), name="input_tgt_text")
    x = tf.keras.layers.concatenate([input_src, input_tgt])
    for units in hidden_units:
        x = Dense(units, activation="relu")(x)
    output = Dense(num_classes, activation="softmax")(x)

    return Model([input_src, input_tgt], output)

In [None]:
# Build a first instance of the model pinned to GPU 4 so its architecture can be inspected.
# NOTE(review): the training setup cell later rebuilds `my_model` on /gpu:7, replacing this one.
with tf.device("/gpu:4"):
    my_model = toy_model()

In [None]:
# Print the layer-by-layer architecture and parameter counts of the freshly built model.
my_model.summary()

In [None]:
# Synthetic toy data: random (source, target) feature pairs with random class labels.
# Seed both RNGs so Restart & Run All reproduces the same data and weight init.
tf.random.set_seed(42)
rng = np.random.default_rng(42)

NUM_EXAMPLES = 5000
FEATURE_DIM = 128
NUM_CLASSES = 2
BATCH_SIZE = 32

input_src_text = tf.random.uniform((NUM_EXAMPLES, FEATURE_DIM))
input_tgt_text = tf.random.uniform((NUM_EXAMPLES, FEATURE_DIM))
# One-hot labels: the model ends in a softmax over NUM_CLASSES outputs, so every
# target row must sum to 1. Drawing each column independently (as before) gave
# rows such as [1, 1] or [0, 0] that a softmax can never produce.
class_ids = rng.integers(0, NUM_CLASSES, size=NUM_EXAMPLES)
labels = np.eye(NUM_CLASSES, dtype=np.float32)[class_ids]

# Feature-dict keys match the names of toy_model's Input layers.
train_data = tf.data.Dataset.from_tensor_slices(
    ({"input_src_text": input_src_text, "input_tgt_text": input_tgt_text}, labels)
).batch(BATCH_SIZE)

In [None]:
def train_step(my_model, example, optimizer, loss_fn):
    """Low level train step: one forward/backward pass plus an optimizer update.

    Args:
        my_model: Keras model being trained.
        example: ``(inputs, targets)`` pair, e.g. one batch drawn from ``train_data``.
        optimizer: Optimizer whose ``apply_gradients`` performs the update.
        loss_fn: Callable invoked as ``loss_fn(targets, predictions)``.

    Returns:
        The loss tensor computed on this batch, before the update is applied.
    """
    inputs, targets = example[0], example[1]
    with tf.GradientTape() as tape:
        predictions = my_model(inputs, training=True)
        batch_loss = loss_fn(targets, predictions)

    trainable = my_model.trainable_variables
    grads = tape.gradient(batch_loss, trainable)
    optimizer.apply_gradients(zip(grads, trainable))

    return batch_loss


def train_and_checkpoint_per_step(model, steps=50, iterator=None):
    """Train ``model`` for up to ``steps`` batches, saving a checkpoint every 10 steps.

    Resumes from ``manager.latest_checkpoint`` when one exists. Relies on the
    notebook-level ``ckpt``, ``manager``, ``optimizer``, ``loss_fn`` and
    ``train_step``.

    Args:
        model: Keras model to train.
        steps: Maximum number of batches to train on.
        iterator: Iterator of ``(inputs, targets)`` batches, advanced with
            ``next()``. Defaults to the notebook-level ``train_iterator``. The
            previous version read a global named ``iterator`` that the notebook
            never defines.
    """
    if iterator is None:
        iterator = train_iterator

    ckpt.restore(manager.latest_checkpoint)
    if manager.latest_checkpoint:
        print("Restoring from {}".format(manager.latest_checkpoint))
    else:
        print("Initializing training from scratch")

    for completed in range(steps):
        try:
            example = next(iterator)
        except StopIteration:
            # Stop cleanly rather than letting StopIteration escape mid-run.
            print("Data exhausted after {} of {} steps; stopping early.".format(completed, steps))
            break
        loss = train_step(model, example, optimizer, loss_fn)
        ckpt.step.assign_add(1)
        if int(ckpt.step) % 10 == 0:
            save_path = manager.save()
            print("Saved checkpoint for step {}: {}".format(int(ckpt.step), save_path))
            print("Loss {:1.2f}".format(loss.numpy()))


def train_and_checkpoint_per_epoch(model, iterator, epochs=10):
    """Train ``model`` for ``epochs`` passes over ``iterator``, checkpointing after each.

    Resumes from ``manager.latest_checkpoint`` when one exists. Relies on the
    notebook-level ``ckpt``, ``manager``, ``optimizer``, ``loss_fn`` and
    ``train_step``.

    Args:
        model: Keras model to train.
        iterator: Source of ``(inputs, targets)`` batches. It is looped over once
            per epoch, so it must be re-iterable, e.g. the ``tf.data.Dataset``
            ``train_data``. A one-shot iterator such as ``iter(train_data)`` is
            exhausted after the first epoch.
        epochs: Number of passes over ``iterator``.

    Raises:
        ValueError: If an epoch yields no batches (for example an exhausted
            one-shot iterator). Without this check, no training happens, yet a
            checkpoint is still saved and a stale loss is reported.
    """
    ckpt.restore(manager.latest_checkpoint)
    if manager.latest_checkpoint:
        print("Restoring from {}".format(manager.latest_checkpoint))
    else:
        print("Initializing training from scratch")

    for epoch in range(epochs):
        print("\nTraining epoch: {}".format(epoch + 1))
        loss_value = None  # reset so an empty epoch is detectable below
        for example in iterator:
            loss_value = train_step(model, example, optimizer, loss_fn)

        if loss_value is None:
            raise ValueError(
                "No batches in epoch {}: `iterator` is empty or exhausted. Pass a "
                "re-iterable tf.data.Dataset (e.g. train_data) instead of a one-shot "
                "iterator such as iter(train_data).".format(epoch + 1))

        save_path = manager.save()
        print("\tSaved checkpoint for epoch {}: {}".format(epoch + 1, save_path))
        print("\tLoss at final step {:1.2f}".format(loss_value.numpy()))

In [None]:
# Training for the first time or restoring training: re-run this cell each time so
# that `ckpt` and `manager` track the freshly created `my_model` and `optimizer`.
with tf.device("/gpu:7"):
    my_model = toy_model()

# NOTE(review): 0.1 is very high for Adam (Keras default is 1e-3); lower it if the loss diverges.
optimizer = tf.keras.optimizers.Adam(0.1)
# toy_model ends in a softmax, so its outputs are probabilities, not logits.
# from_logits=True would apply a second sigmoid on top and distort the loss/gradients.
loss_fn = tf.keras.losses.BinaryCrossentropy(from_logits=False)
# One-shot iterator, consumed batch by batch by train_and_checkpoint_per_step.
# For per-epoch training pass the re-iterable `train_data` instead.
train_iterator = iter(train_data)

CKPT_DIR = "/linguistics/ethan/DL_Prototype/models/example_ckpt"
ckpt = tf.train.Checkpoint(step=tf.Variable(1),
                           optimizer=optimizer,
                           net=my_model,
                           # The data iterator is not tracked (it could be added as
                           # iterator=train_iterator), so resuming does not restore
                           # the input position.
                           )
manager = tf.train.CheckpointManager(ckpt, CKPT_DIR, max_to_keep=3)

In [None]:
# Pass the re-iterable Dataset, not the one-shot `train_iterator`. Each epoch then
# starts a fresh pass over `train_data`, whereas `iter(train_data)` is exhausted
# after the first epoch. Step-wise alternative:
#     train_and_checkpoint_per_step(my_model, steps=50)
train_and_checkpoint_per_epoch(my_model, train_data, epochs=10)

In [None]:
# Show the architecture of the model that was just trained and checkpointed.
# To resume from disk instead, re-run the setup cell and then call
# `ckpt.restore(manager.latest_checkpoint)`.
my_model.summary()

In [None]:
# Sanity-check inference on a small random batch: one tensor per model input
# (source, target), each with 5 rows of 128 features.
input_data = [tf.random.uniform((5, 128)), tf.random.uniform((5, 128))]
# Use the batch defined above. The previous version built a second, unrelated
# pair of tensors and left `input_data` unused.
my_model.predict(input_data)

In [None]:
# Export the trained model (architecture + weights) so it can be reloaded without this notebook.
# NOTE(review): with no file suffix, tf.keras 2.x presumably writes a SavedModel directory;
# Keras 3 requires a `.keras` suffix instead. Confirm against the installed version.
my_model.save(filepath="/linguistics/ethan/DL_Prototype/models/example_ckpt/export/toy_model")

In [None]:
# my_model.load_weights("/linguistics/ethan/DL_Prototype/models/example_ckpt/ckpt-10")
# tf.keras.models.load_model("/linguistics/ethan/DL_Prototype/models/example_ckpt/ckpt-10")

In [None]:
# ckpt.step.numpy()
# ckpt.step.assign_add(1)
# ckpt.step.numpy()
# manager.latest_checkpoint

In [None]:
# Destination for a second export, presumably converting a checkpoint to a protobuf model (per the "ckpt2pb" name).
export_dir = "/linguistics/ethan/DL_Prototype/models/example_ckpt/export/toy_model_ckpt2pb"