# PAR Transformer XL Trainer

This notebook makes it easy to train the model on Google Colab, where you can use a GPU or TPU.

## Colab specific setup

In [None]:
%%capture
!pip install tensorflow_text

In [None]:
!git clone https://github.com/Jmkernes/PAR-Transformer-XL.git
%cd PAR-Transformer-XL/

## Load TensorBoard

Re-run this cell before every run to reload TensorBoard.

This sets up metric tracking. It is not required, since the code prints the loss every 100 steps and also writes it to a log file. TensorBoard additionally shows the learning rate, perplexity and validation metrics.
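As a rough guide to reading the printed loss next to the perplexity curve, the sketch below converts one into the other. It assumes the loss is the mean per-token cross-entropy measured in nats; the training code may compute or average it differently, so treat this as an illustration rather than the repository's exact formula.

```python
import math

def perplexity(mean_loss_nats):
    """Perplexity for a mean per-token cross-entropy given in nats (assumed unit)."""
    return math.exp(mean_loss_nats)

# A loss of ln(50) means the model is, on average, as uncertain as a
# uniform choice among 50 tokens.
example_ppl = perplexity(math.log(50.0))
```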

In [None]:
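# Load the TensorBoard notebook extension (needed before %tensorboard works)
%load_ext tensorboard
# -f avoids an error when the directories do not exist yet
!rm -rf logs
!rm -rf plots
!mkdir -p logs
%tensorboard --logdir logs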

## Run the model

Adjust the parameters in the `base_model.sh` script if you want to alter the model.

In [None]:
!./base_model.sh

## (Optional) Save results

The checkpoints directory can hold a lot of data, so zipping the whole thing is not advised (which is why that line is commented out). Instead, take only the checkpoints you want.

If the code runs to completion (about 37 minutes on a single GPU with the default model settings), you also have the option of downloading the saved model, which can be loaded into a fully functional model by executing `tf.keras.models.load_model('saved_models')`.
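To take only some checkpoints rather than the whole directory, something like the following sketch could be used. The helper `zip_selected_checkpoints` and the prefix `ckpt-5` are hypothetical names for illustration. The sketch assumes the usual TensorFlow layout in which each checkpoint is a group of files sharing a prefix (for example `ckpt-5.index` and `ckpt-5.data-...`), and it also copies `config.json` when present, since `load_from_checkpoint` reads that file.

```python
import glob
import os
import zipfile

def zip_selected_checkpoints(ckpt_dir, prefixes, out_path, extras=("config.json",)):
    """Zip only the checkpoint files whose names start with one of `prefixes`.

    Hypothetical helper: `prefixes` are checkpoint names such as "ckpt-5".
    Returns the list of file names written to the archive.
    """
    written = []
    with zipfile.ZipFile(out_path, "w", zipfile.ZIP_DEFLATED) as zf:
        for prefix in prefixes:
            # The trailing "." keeps "ckpt-5" from also matching "ckpt-50".
            pattern = os.path.join(ckpt_dir, prefix + ".*")
            for path in sorted(glob.glob(pattern)):
                zf.write(path, os.path.basename(path))
                written.append(os.path.basename(path))
        for name in extras:
            path = os.path.join(ckpt_dir, name)
            if os.path.exists(path):
                zf.write(path, name)
                written.append(name)
    return written
```

In the notebook this might be called as `zip_selected_checkpoints('checkpoints', ['ckpt-5'], 'checkpoints_subset.zip')`, followed by `files.download('checkpoints_subset.zip')`. Check the actual file names in your checkpoints directory first.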

In [772]:
from google.colab import files
!zip -r logs.zip logs
!zip -r plots.zip plots
# !zip -r checkpoints.zip checkpoints

# files.download('plots.zip')
# files.download('logs.zip')


## Interactive session

Since this is a notebook, you can load in different checkpoints of the model (or the final version) and play around with it.

In [None]:
import os
import json
import numpy as np
import tensorflow as tf
import tensorflow_text as tf_text
import matplotlib.pyplot as plt
from data_utils import DataManager
from utils import visualize_pi_weights
from par_model import PARTransformerXL
from par_model import create_lookahead_mask, positional_encoding

The code below loads a model from a checkpoint. You have to

1) recreate a model with the same architecture as the one that was checkpointed

2) create a checkpoint object with the keyword argument model=model. The key was fixed when the first model was checkpointed, so the model must always be passed under the name model.

3) restore the checkpoint object from a checkpoint path with ckpt.restore(PATH). This updates the weights of model in place, because ckpt keeps a reference to model rather than a copy.

All of that is wrapped up in the load_from_checkpoint function. Have fun!
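The three steps, and in particular the reference behavior in step 3, can be checked on a toy object before touching the real model. This is a minimal sketch: `TinyModel` is a hypothetical stand-in for `PARTransformerXL`, and the checkpoint is written to a temporary directory.

```python
import tempfile
import tensorflow as tf

class TinyModel(tf.Module):
    """Hypothetical stand-in for the real model: a single trainable weight."""
    def __init__(self, value):
        super().__init__()
        self.w = tf.Variable(value, dtype=tf.float32)

# Save a "trained" model under the key `model`.
ckpt_dir = tempfile.mkdtemp()
trained = TinyModel(3.0)
save_path = tf.train.Checkpoint(model=trained).save(ckpt_dir + "/ckpt")

# 1) recreate a model with the same structure
model = TinyModel(0.0)
# 2) wrap it in a checkpoint object using the same key, `model`
ckpt = tf.train.Checkpoint(model=model)
# 3) restore; the weights of `model` itself are overwritten in place
ckpt.restore(save_path)

restored_value = float(model.w.numpy())
```

After the restore, `model.w` holds the saved value even though `model` was never reassigned, which is why `load_from_checkpoint` can simply return the model it built.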

In [None]:
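def load_from_checkpoint(ckpt_path):
    """Rebuild the model from config.json and restore the latest checkpoint in ckpt_path."""
    # The first line of config.json holds the constructor arguments as JSON
    with open(os.path.join(ckpt_path, 'config.json'), 'r') as file:
        config = json.loads(file.readline())
    model = PARTransformerXL(**config)
    # The keyword must be `model`, matching the key used when the checkpoint was saved
    ckpt = tf.train.Checkpoint(model=model)
    ckpt_manager = tf.train.CheckpointManager(ckpt, ckpt_path, max_to_keep=5)
    # restore() loads the saved weights into `model` in place
    ckpt.restore(ckpt_manager.latest_checkpoint)
    return model

def load_from_savedmodel(path):
    """Load the exported model stored at `path`."""
    return tf.keras.models.load_model(path)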