In [220]:
import os
import time
import numpy as np
import tensorflow as tf
import tensorflow_text as tf_text
import matplotlib.pyplot as plt
from data_utils import DataManager
from par_model import PARTransformerXL
from par_model import create_lookahead_mask, positional_encoding

# %load_ext autoreload
# %autoreload 2

In [274]:
import time
def printBar(step, tot, diff, loss):
    num_eq = int(10*(step+1)/tot)
    num_pd = 10-num_eq
    bar = '['+'='*num_eq+'>'+'.'*num_pd+']'
    time_left = (tot-step)*diff
    m = int(time_left)//60
    s = int(time_left)%60
    iter_message = f"Iteration {step+1:02d}/{tot}:"
    time_message = f"{1/diff:.2f} it/s. Est: {m:02d}m {s:02d}s"
    loss_message = f"Loss: {loss:.3f}"
    end = '\r' if step<tot-1 else '\n'
    print(iter_message, bar, time_message, loss_message, end=end)
    
start = time.time()
for i in range(100):
    time.sleep(0.01)
    x = 5/np.sqrt(i+1)
    diff = (time.time()-start)/(i+1)
    printBar(i, 100, diff, x)



## Load the wikitext2 train, validation and test data

In [227]:
config = {'tfrecords_directory':'data/wikitext2_bsz32_seqlen32_tfrecords_train',
                'sp_model_prefix': 'wiki2_12k'}
train_dm = DataManager.initialize_from_tfrecord(config)

config['tfrecords_directory'] = 'data/wikitext2_bsz32_seqlen32_tfrecords_valid'
valid_dm = DataManager.initialize_from_tfrecord(config)

config['tfrecords_directory'] = 'data/wikitext2_bsz32_seqlen32_tfrecords_test'
test_dm = DataManager.initialize_from_tfrecord(config)

Loading tokenizer from wiki2_12k.model...
Loading tfrecords from directory
Loading tokenizer from wiki2_12k.model...
Loading tfrecords from directory
Loading tokenizer from wiki2_12k.model...
Loading tfrecords from directory


## Initialize model architecture

In [228]:
train_dm.

<tf.Tensor: shape=(), dtype=int64, numpy=2827>

In [231]:
tf.keras.backend.clear_session()

config = {
    'd_model':256, 
    'num_heads':4, 
    'max_position':512, 
    'd_ffn':1024,
    'num_layers':12, 
    'mem_len':32, 
    'vocab_size':12000,
    'dropout_rate':0.1, 
    'cutoffs':[250, 2500, 12000], 
    'proj_factor':4, 
    'proj_dims':None,
}

ds_size = train_dm.ds_size
max_position = config['max_position']
pos_enc = positional_encoding(max_position, config['d_model'])
lookahead_mask = create_lookahead_mask(max_position, max_position)
model = PARTransformerXL(**config)

In [252]:
train_ds = train_dm.get_inp_tar_pairs()
x, y = next(iter(train_ds))
model(x, None, labels=y, training=True)
model.summary()

Model: "par_transformer_xl"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
embedding (Embedding)        multiple                  3072000   
_________________________________________________________________
adaptive_softmax (AdaptiveSo multiple                  393074    
_________________________________________________________________
stochastic_block (Stochastic multiple                  856068    
_________________________________________________________________
stochastic_block_1 (Stochast multiple                  856068    
_________________________________________________________________
stochastic_block_2 (Stochast multiple                  856068    
_________________________________________________________________
stochastic_block_3 (Stochast multiple                  856068    
_________________________________________________________________
stochastic_block_4 (Stochast multiple           

### Create simulated annealing schedule for gumbel softmax tau. We use exponential decay.

In [235]:
class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
    def __init__(self, d_model, warmup_steps=4000):
        super(CustomSchedule, self).__init__()
        self.d_model = tf.cast(d_model, tf.float32)
        self.warmup_steps = warmup_steps
    def __call__(self, step):
        arg1 = tf.math.rsqrt(step)
        arg2 = step * (self.warmup_steps ** -1.5)
        return tf.math.rsqrt(self.d_model) * tf.math.minimum(arg1, arg2)
    
learning_rate = CustomSchedule(config['d_model'], 4000)

In [253]:
train_loss = tf.keras.metrics.Mean()
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)

# @tf.function(input_signature=[
#     tf.TensorSpec(shape=(32,32), dtype=tf.int32),
#      tf.TensorSpec(shape=(32,32), dtype=tf.int32)
# ])
def train_step(inp, x_mems, labels, tau):
    with tf.GradientTape() as tape:
        loss, mems = model(x, x_mems, labels=labels, training=True, tau=tau)
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    train_loss(loss)
    return mems

In [254]:
# checkpoint_path = "./checkpoints/train"
# ckpt = tf.train.Checkpoint(model=model, optimizer=optimizer)
# ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)
# # if a checkpoint exists, restore the latest checkpoint.
# if ckpt_manager.latest_checkpoint:
#     try:
#         ckpt.restore(ckpt_manager.latest_checkpoint)
#         print ('Latest checkpoint restored!!')
#     except:
#         print("Model may have changed, could not restore checkpoint.")

In [272]:
# make tau untrainable
for layer in model.layers:
    if hasattr(layer, 'tau'):
        layer.tau = tf.cast(tf.constant(1.), tf.float32)

In [205]:
dm.tokenizer.detokenize([5])

<tf.Tensor: shape=(), dtype=string, numpy=b''>

In [273]:
num_epochs = 5

tau_start = 2.0
tau_end = 0.2
decay_steps = ds_size*num_epochs
eta = tf.cast((tau_end/tau_start)**(1/decay_steps), dtype=tf.float32)
tau = tf.constant(tau_start, dtype=tf.float32)
num_batches = ds_size.numpy()

history={'loss':[], 'tau':[]}
for epoch in range(num_epochs):
    print('-'*10,f' Epoch {epoch+1} ', '-'*10)
    start = time.time()
    train_loss.reset_states()
    mems = None
    for step, (inp, lbl) in enumerate(train_ds):
        mems = train_step(inp, mems, lbl, tau)
        diff = (time.time()-start)/(step+1)
        printBar(step, num_batches, diff, train_loss.result().numpy())
        tau *= eta
        
        history['loss'].append(train_loss.result().numpy())
        history['tau'].append(tau.numpy())

----------  Epoch 1  ----------
Iteration 09/2827: [>..........] 0.50 it/s. Est: 94m 49s Loss: 8.150

KeyboardInterrupt: 

In [152]:
for i in range(2,6):
    print(f"Layer {i-1}: {model.layers[i].pi.numpy()}")
    
for i in range(2,6):
    print(f"Layer {i-1}: {model.layers[i].tau.numpy()}")

Layer 1: [0.32119352 0.32047126 0.35673234]
Layer 2: [0.3197162  0.34020415 0.35368592]
Layer 3: [0.35318667 0.28377616 0.34028876]
Layer 4: [0.3391504  0.30977973 0.3718491 ]
Layer 1: 1.0
Layer 2: 1.0
Layer 3: 1.0
Layer 4: 1.0
