In [None]:
import numpy as np
import tensorflow as tf
import tensorflow_text as text
import tensorflow_hub as hub
from tensorflow.keras import Input, Model
from tensorflow.keras.layers import GlobalMaxPooling1D, Dense
from tensorflow.keras import backend as K

In [None]:
# Configuration for the text-pair model built by build_model_with_preprocessor.
# NOTE(review): max_seq_len is passed to build_model_with_preprocessor but is not
# read inside it — confirm whether the preprocessor was meant to receive it.
max_seq_len = 128
# Path loaded through hub.KerasLayer as the text preprocessor.
preprocessor_dir = "/linguistics/ethan/DL_Prototype/models/universal-sentence-encoder-cmlm_multilingual-preprocess_2"
# Path loaded through hub.KerasLayer as the sentence encoder.
LaBSE_dir = "/linguistics/ethan/DL_Prototype/models/LaBSE2"

def build_model_with_preprocessor(max_seq_len, preprocessor_dir, LaBSE_dir):
    """Build a two-input Keras model that scores a (source, target) text pair.

    Both raw-string inputs go through the same non-trainable preprocessor and
    encoder hub layers. The encoder's "default" outputs are L2-normalised along
    axis 1, concatenated, and fed through Dense(512, relu) -> Dense(64, relu)
    -> Dense(1, sigmoid).

    Args:
        max_seq_len: Accepted for interface compatibility; not used in the body.
        preprocessor_dir: Path handed to ``hub.KerasLayer`` for the preprocessor.
        LaBSE_dir: Path handed to ``hub.KerasLayer`` for the encoder.

    Returns:
        A ``Model`` taking ``[input_src_text, input_tgt_text]`` and returning
        the output of the final sigmoid Dense layer.
    """
    text_inputs = [
        tf.keras.layers.Input(shape=(), dtype=tf.string, name=input_name)
        for input_name in ("input_src_text", "input_tgt_text")
    ]

    # One preprocessor and one encoder instance, shared by both branches.
    preprocessor = hub.KerasLayer(preprocessor_dir, trainable=False)
    encoder = hub.KerasLayer(LaBSE_dir, trainable=False)

    # Each stage is applied to the source branch first, then the target branch.
    encoder_inputs = [preprocessor(texts) for texts in text_inputs]
    sentence_vectors = [encoder(features)["default"] for features in encoder_inputs]
    unit_vectors = [
        tf.math.l2_normalize(vector, axis=1, epsilon=1e-12)
        for vector in sentence_vectors
    ]

    # Classification head on top of the concatenated pair representation.
    hidden = tf.concat(unit_vectors, axis=1)
    for units, activation in ((512, "relu"), (64, "relu"), (1, "sigmoid")):
        hidden = Dense(units, activation=activation)(hidden)

    return Model(text_inputs, hidden)

In [None]:
# Build the text-pair model from the configured preprocessor and encoder paths.
model = build_model_with_preprocessor(max_seq_len, preprocessor_dir, LaBSE_dir)

In [None]:
# Load each saved checkpoint into `model` in turn and keep a snapshot of the
# resulting weights, one list of arrays per checkpoint.
num_ckpts = 2
ckpt_weights = []

for epoch in range(1, num_ckpts + 1):
    # Zero-pad the epoch to four digits instead of hard-coding a "000" prefix,
    # so epoch 10 resolves to tqc-0010.ckpt rather than tqc-00010.ckpt.
    # (Assumes the files follow the 4-digit pattern of tqc-0001/tqc-0002.)
    ckpt_path = f"/linguistics/ethan/DL_Prototype/models/LaBSE2_ckpts/tqc-{epoch:04d}.ckpt"
    model.load_weights(ckpt_path)
    weights = model.get_weights()
    ckpt_weights.append(weights)

In [None]:
# Shape of the second-to-last weight array in the first checkpoint snapshot
# (presumably the kernel of the final Dense layer — verify against model.weights).
ckpt_weights[0][-2].shape

In [None]:
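# NOTE: disabled draft of checkpoint weight averaging. As written it would not
# average anything: zip(weights_list_tuple) receives a single iterable, so it
# yields 1-tuples and .mean(axis=0) returns each checkpoint's array unchanged.
# np.mean(weights_list_tuple, axis=0) would average one layer across checkpoints.
# averaged_weights = []
# for weights_list_tuple in zip(*ckpt_weights):

#     averaged_layer = [np.array(weights_).mean(axis=0) for weights_ in zip(weights_list_tuple)]
#     averaged_weights.append(averaged_layer)

# averaged_weights = np.array(averaged_weights)
# model.set_weights(averaged_weights)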

In [None]:
# Quick check of zip(*...): prints the position-wise tuples (1, 4), (2, 5), (3, 6).
it = [[1,2,3], [4,5,6]]
for t in zip(*it):
    print(t)

In [None]:
def toy_model():
    """Create a small two-input Keras model.

    Two 10-feature inputs named "input_src_text" and "input_tgt_text" are
    concatenated and passed through a single Dense layer with 2 softmax units.

    Returns:
        A ``tf.keras.Model`` mapping ``[input_src, input_tgt]`` to the softmax
        output.
    """
    layers = tf.keras.layers
    inputs = [
        layers.Input(shape=(10, ), name=input_name)
        for input_name in ("input_src_text", "input_tgt_text")
    ]
    merged = layers.concatenate(inputs)
    probabilities = layers.Dense(2, activation="softmax")(merged)

    return tf.keras.Model(inputs, probabilities)

In [None]:
# Instantiate the toy two-input model.
my_model = toy_model()

In [None]:
# NOTE(review): this summarises `model` (built by build_model_with_preprocessor),
# not `my_model` from toy_model() — confirm which one was intended.
model.summary()

In [None]:
# Synthetic data for the toy model: 30 examples with 10 uniform-random features
# for each of the two inputs.
input_src_text = tf.random.uniform((30, 10))
input_tgt_text = tf.random.uniform((30, 10))
# Random 0/1 labels, two per example; no seed is set, so they vary between runs.
labels = np.random.randint(2, size=(30,2))
# Earlier attempt that wrapped features and labels in a single dict (kept disabled):
# train_data = tf.data.Dataset.from_tensor_slices(dict(
#                                                  x = {"input_src_text": input_src_text,
#                                                       "input_tgt_text": input_tgt_text},
#                                                  y = labels)).batch(5)
# Dataset of (features-dict, labels) pairs in batches of 7; the dict keys match
# the Input names used in toy_model.
train_data = tf.data.Dataset.from_tensor_slices(({"input_src_text": input_src_text,
                                                  "input_tgt_text": input_tgt_text},
                                                  labels)).batch(7)

In [None]:
# Print the label tensor (second element) of every batch to inspect the batching.
for data in train_data:
    print(data[1])
    print("\n")

In [None]:
# high-level training
# model.compile(optimizer="adam",
#               metrics=["accuracy"],
#               loss="binary_crossentropy")
# model.fit(train_data, epochs=2)

# low level training with checkpoint storing and restoring
def train_step(my_model, example, optimizer, loss_fn):
    """Run one optimisation step on a single example/batch and return its loss.

    Args:
        my_model: Callable model exposing ``trainable_variables``.
        example: Indexable pair; ``example[0]`` is fed to the model and
            ``example[1]`` is passed to ``loss_fn`` as the target.
        optimizer: Object providing ``apply_gradients``.
        loss_fn: Callable invoked as ``loss_fn(target, prediction)``.

    Returns:
        The loss value computed for this example.
    """
    features, targets = example[0], example[1]

    # Record the forward pass so the loss can be differentiated afterwards.
    with tf.GradientTape() as tape:
        predictions = my_model(features, training=True)
        loss = loss_fn(targets, predictions)

    trainable = my_model.trainable_variables
    grads_and_vars = zip(tape.gradient(loss, trainable), trainable)
    optimizer.apply_gradients(grads_and_vars)

    return loss


def train_and_checkpoint_per_step(model, steps=50, iterator=None):
    """Train for up to ``steps`` steps, saving a checkpoint every 10th step.

    Uses the notebook-level ``ckpt``, ``manager``, ``optimizer`` and
    ``loss_fn``. The latest checkpoint known to ``manager`` is restored first.

    Args:
        model: Model passed to ``train_step``.
        steps: Maximum number of training steps to run.
        iterator: Iterator yielding examples for ``train_step``. When omitted,
            the iterator tracked by the checkpoint (``ckpt.iterator``) is used.

    Raises:
        ValueError: If no iterator is given and ``ckpt`` does not track one.
    """
    if iterator is None:
        # The body used to read a global named `iterator` that the notebook
        # never defines; fall back to the iterator attached to the checkpoint.
        iterator = getattr(ckpt, "iterator", None)
        if iterator is None:
            raise ValueError("No iterator given and ckpt does not track one")

    ckpt.restore(manager.latest_checkpoint)
    if manager.latest_checkpoint:
        print("Restoring from {}".format(manager.latest_checkpoint))
    else:
        print("Initializing training from scratch")

    for _ in range(steps):
        try:
            example = next(iterator)
        except StopIteration:
            # The data ran out before `steps` was reached; stop cleanly
            # instead of letting StopIteration escape from the cell.
            print("Iterator exhausted at step {}; stopping early".format(int(ckpt.step)))
            break
        loss = train_step(model, example, optimizer, loss_fn)
        ckpt.step.assign_add(1)
        if int(ckpt.step) % 10 == 0:
            save_path = manager.save()
            print("Saved checkpoint for step {}: {}".format(int(ckpt.step), save_path))
            print("Loss {:1.2f}".format(loss.numpy()))

def train_and_checkpoint_per_epoch(model, iterator, epochs=10):
    """Train for ``epochs`` passes over ``iterator``, checkpointing after each.

    Uses the notebook-level ``ckpt``, ``manager``, ``optimizer`` and
    ``loss_fn``. The latest checkpoint known to ``manager`` is restored first.

    Args:
        model: Model passed to ``train_step``.
        iterator: Iterable of examples, looped over once per epoch. It should
            be re-iterable (e.g. a ``tf.data.Dataset``); a one-shot iterator
            yields nothing after its first full pass.
        epochs: Number of passes to run.
    """
    ckpt.restore(manager.latest_checkpoint)
    if manager.latest_checkpoint:
        print("Restoring from {}".format(manager.latest_checkpoint))
    else:
        print("Initializing training from scratch")

    for epoch in range(epochs):
        print("\nTraining epoch: {}".format(epoch + 1))
        # Reset per epoch so an empty pass is reported as such, instead of
        # raising on an unbound name or repeating an earlier epoch's loss.
        loss_value = None
        for example in iterator:
            loss_value = train_step(model, example, optimizer, loss_fn)

        save_path = manager.save()
        print("\tSaved checkpoint for epoch {}: {}".format(epoch + 1, save_path))
        if loss_value is None:
            print("\tNo batches were produced in this epoch (iterator exhausted?)")
        else:
            print("\tLoss at final step {:1.2f}".format(loss_value.numpy()))

In [None]:
"""Training for first time or restoring training, remember to re-initiate ckpt and manager."""
my_model = toy_model()
optimizer = tf.keras.optimizers.Adam(0.1)
# toy_model ends in a softmax activation, so its outputs are already
# probabilities; from_logits=True would apply a second squashing to them.
loss_fn = tf.keras.losses.BinaryCrossentropy(from_logits=False)
train_iterator = iter(train_data)
# Track everything needed to resume training: step counter, optimizer state,
# model weights and the position of the data iterator.
ckpt = tf.train.Checkpoint(step=tf.Variable(1),
                           optimizer=optimizer,
                           net=my_model,
                           iterator=train_iterator)
# Keep only the three most recent checkpoints in the target directory.
manager = tf.train.CheckpointManager(ckpt, "/linguistics/ethan/DL_Prototype/models/example_ckpt", max_to_keep=3)

In [None]:
# train_and_checkpoint_per_step(my_model, steps=50)
# Pass the dataset itself rather than the one-shot `train_iterator`: the
# per-epoch loop re-iterates its argument every epoch, and an iterator created
# with iter() is exhausted after the first pass, leaving later epochs empty.
train_and_checkpoint_per_epoch(my_model, train_data, epochs=10)
# for e in iterator:
#     print(e)

In [None]:
# Alternatives tried for reloading the trained state (kept disabled):
# my_model.load_weights(manager.latest_checkpoint)
# tf.keras.models.load_model(manager.latest_checkpoint)
# Restore every object tracked by `ckpt` from the manager's latest checkpoint.
ckpt.restore(manager.latest_checkpoint)

In [None]:
# Fixed pair of random inputs (5 examples x 10 features each) reused by the
# cells below so model outputs can be compared.
input_data = [tf.random.uniform((5, 10)), tf.random.uniform((5, 10))]

In [None]:
# Output of a newly created toy model on the fixed inputs, before any restore.
my_model = toy_model()
my_model(input_data)

In [None]:
# Same as the previous cell with another newly created model (presumably to
# compare outputs across fresh initialisations — confirm).
my_model = toy_model()
my_model(input_data)

In [None]:
# The training checkpoint tracked the model under the attribute name `net`.
# Restoring matches objects by attribute name, so the same name must be used
# here; with `model=` nothing matches and the weights silently keep their
# fresh initial values.
ckpt = tf.train.Checkpoint(net=my_model)
# Only the model is restored; expect_partial() acknowledges that the saved
# optimizer, step and iterator entries are intentionally left unused.
ckpt.restore("/linguistics/ethan/DL_Prototype/models/example_ckpt/ckpt-20").expect_partial()
my_model(input_data)

In [None]:
# A tf.train.Checkpoint object is not callable; run the inputs through the
# model it tracks instead.
my_model(input_data)

In [None]:
# Disabled scratch code for inspecting and bumping the checkpoint step counter:
# ckpt.step.numpy()
# ckpt.step.assign_add(1)
# ckpt.step.numpy()
# Show the most recent checkpoint known to the manager.
manager.latest_checkpoint

In [None]:
# `model` takes two scalar string inputs; these (2, 10) float tensors match
# the toy model's inputs, so predict with `my_model`.
my_model.predict([tf.random.uniform((2, 10)), tf.random.uniform((2, 10))])