# Pre-training ALBERT from scratch

## Set up environment

In [None]:
# Colab magic: select the TensorFlow 1.x runtime. The cells below rely on
# TF1-only APIs such as tf.Session, tf.contrib and tf.gfile.
%tensorflow_version 1.x

In [None]:
import os
import sys
import json
# import nltk
import random
import logging
import tensorflow as tf

# from glob import glob
from google.colab import auth
# from tensorflow.keras.utils import Progbar

In [None]:
# Authenticate this Colab session with Google Cloud.
auth.authenticate_user()

# Route TensorFlow's logger through a single formatted stream handler.
log = logging.getLogger('tensorflow')
log.setLevel(logging.INFO)

# create formatter and add it to the handlers
formatter = logging.Formatter('%(asctime)s :  %(message)s')
sh = logging.StreamHandler()
sh.setLevel(logging.INFO)
sh.setFormatter(formatter)
# Replace (rather than append to) the handler list so re-running this cell
# does not print every log line twice.
log.handlers = [sh]

if 'COLAB_TPU_ADDR' in os.environ:
  log.info("Using TPU runtime")
  USE_TPU = True
  TPU_ADDRESS = 'grpc://' + os.environ['COLAB_TPU_ADDR']

  with tf.Session(TPU_ADDRESS) as session:
    log.info('TPU address is ' + TPU_ADDRESS)
    # Upload credentials to TPU so it can access GCS.
    with open('/content/adc.json', 'r') as f:
      auth_info = json.load(f)
    tf.contrib.cloud.configure_gcs(session, credentials=auth_info)

else:
  log.warning('Not connected to TPU runtime')
  USE_TPU = False
  # Later cells reference TPU_ADDRESS unconditionally (e.g. as the estimator
  # master); define it so a non-TPU runtime does not hit a NameError.
  TPU_ADDRESS = None

## Import code and data

In [None]:
!pip -q install sentencepiece
!test -d albert || git clone https://github.com/google-research/albert

if not 'albert' in sys.path:
  sys.path += ['albert']

In [None]:
import sentencepiece as spm

In [None]:
#@title GCS configuration

FROM_TF_HUB = False #@param {type:"boolean"}

PROJECT_ID = '' #@param {type:"string"}
os.environ['GOOGLE_CLOUD_PROJECT'] = PROJECT_ID

!gcloud config set project $PROJECT_ID
PRETRAINING_DIR = "" #@param {type:"string"}
BUCKET_NAME = "" #@param {type:"string"}
MODEL_DIR = "" #@param {type:"string"}
tf.gfile.MkDir(MODEL_DIR)


if not BUCKET_NAME:
  log.warning("WARNING: BUCKET_NAME is not set. "
              "You will not be able to train the model.")

if BUCKET_NAME:
  BUCKET_PATH = "gs://{}".format(BUCKET_NAME)
else:
  BUCKET_PATH = "."

BERT_GCS_DIR = "{}/{}".format(BUCKET_PATH, MODEL_DIR)
DATA_GCS_DIR = "{}/{}".format(BUCKET_PATH, PRETRAINING_DIR)

#@markdown ---
#@markdown ### TF hub model (fill if FROM_TF_HUB)

if FROM_TF_HUB:
  ALBERT_MODEL_HUB = 'https://tfhub.dev/google/albert_' + ALBERT_MODEL + '/' + VERSION
  ALBERT_MODEL = 'base' #@param ["base", "large", "xlarge", "xxlarge"]
  VERSION = "1" #@param ["1", "2", "3"]
else:
  ALBERT_MODEL_HUB = None
  ALBERT_GCS_DIR = "{}/{}".format(BUCKET_PATH, MODEL_DIR)
  DATA_GCS_DIR = "{}/{}".format(BUCKET_PATH, PRETRAINING_DIR)
  ALBERT_CONFIG_FILE = os.path.join(ALBERT_GCS_DIR, "albert_config.json")

if not FROM_TF_HUB and (not BUCKET_PATH or BUCKET_PATH == "gs://"):
  raise ValueError("You must configure at least one of"
                   "`TF_HUB` and `BUCKET_NAME`")

In [None]:
#@title Vocabulary

VOC_SIZE = 30000 #@param {type: "integer"}
VOCAB_FILE = "30k-clean-v2.vocab" #@param {type:"string"}
SPM_MODEL_FILE = "30k-clean-v2.model" #@param {type:"string"}
DO_LOWER_CASE = True #@param {type:"boolean"}
# DO_LOWER_CASE = "True" #@param {type:"string"}
# DO_LOWER_CASE = bool(DO_LOWER_CASE)

VOCAB_FILE = os.path.join(BERT_GCS_DIR, VOCAB_FILE)
SPM_MODEL_FILE = os.path.join(BERT_GCS_DIR, SPM_MODEL_FILE)

VOCAB_UPOS = "deprel.conll.encoder" #@param {type:"string"}
VOCAB_DEPS = "upos.conll.encoder" #@param {type:"string"}

!gsutil cp gs://$BUCKET_NAME/$MODEL_DIR/$VOCAB_UPOS . >nul 2>&1
!gsutil cp gs://$BUCKET_NAME/$MODEL_DIR/$VOCAB_DEPS . >nul 2>&1

VOCAB_UPOS = os.path.join('.', VOCAB_UPOS)
VOCAB_DEPS = os.path.join('.', VOCAB_DEPS)

## Create Model

In [None]:
# Modules from the cloned ALBERT repository (importable thanks to the
# sys.path change in the setup cell).
from albert import (
    modeling, 
    optimization, 
    tokenization
)
# NOTE(review): model_fn_builder / input_fn_builder are called later with
# arguments such as tau, vocab_upos and add_dep_and_pos, which look specific
# to a modified run_pretraining — confirm the cloned repo provides them.
from albert.run_pretraining import (
    input_fn_builder,
    model_fn_builder
)

# Show INFO-level tf.logging output.
tf.logging.set_verbosity(tf.logging.INFO)

In [None]:
# Transformer hyperparameters exposed as Colab form fields.
attention_probs_dropout_prob = 0 #@param  {type:"number"}
hidden_act = 'gelu' #@param ["gelu", "relu", "tanh", "sigmoid", "linear"]
hidden_dropout_prob = 0 #@param  {type:"number"}
embedding_size = 128 #@param {type:"integer"}
hidden_size = 768 #@param {type:"integer"}
initializer_range = 0.02 #@param  {type:"number"}
intermediate_size = 3072 #@param {type:"integer"}
max_position_embeddings = 512 #@param {type:"integer"}
num_attention_heads = 12 #@param {type:"integer"}
num_hidden_layers = 12 #@param {type:"integer"}

# Full model config (serialized to albert_config.json in a later cell): the
# form values above, followed by fixed structural settings and VOC_SIZE from
# the vocabulary cell. Key order matches the written JSON.
albert_base_config = dict(
    attention_probs_dropout_prob=attention_probs_dropout_prob,
    hidden_act=hidden_act,
    hidden_dropout_prob=hidden_dropout_prob,
    embedding_size=embedding_size,
    hidden_size=hidden_size,
    initializer_range=initializer_range,
    intermediate_size=intermediate_size,
    max_position_embeddings=max_position_embeddings,
    num_attention_heads=num_attention_heads,
    num_hidden_layers=num_hidden_layers,
    num_hidden_groups=1,
    net_structure_type=0,
    gap_size=0,
    num_memory_blocks=0,
    inner_group_num=1,
    down_scale_factor=1,
    type_vocab_size=2,
    vocab_size=VOC_SIZE,
)

In [None]:
# Build the tokenizer (from the SentencePiece files in the bucket, or from the
# TF Hub module) and look up the special-token ids passed to model_fn_builder.
if not FROM_TF_HUB:
  tokenizer = tokenization.FullTokenizer(
        vocab_file=VOCAB_FILE,
        do_lower_case=DO_LOWER_CASE,
        spm_model_file=SPM_MODEL_FILE)
else:
  # fine_tuning_utils is not among the albert modules imported elsewhere in
  # this notebook, so import it where it is needed.
  from albert import fine_tuning_utils
  # NOTE(review): upstream ALBERT's create_vocab takes
  # (vocab_file, do_lower_case, spm_model_file, hub_module) with no defaults;
  # a non-empty spm_model_file requests the SentencePiece tokenizer for the
  # hub module — confirm against the cloned repo.
  tokenizer = fine_tuning_utils.create_vocab(
        vocab_file=None,
        do_lower_case=DO_LOWER_CASE,
        spm_model_file=SPM_MODEL_FILE,
        hub_module=ALBERT_MODEL_HUB)

MASK_TOKEN_ID = tokenizer.vocab['[MASK]']
CLS_TOKEN_ID = tokenizer.vocab['[CLS]']
SEP_TOKEN_ID = tokenizer.vocab['[SEP]']
print("MASK token ID is {}.".format(MASK_TOKEN_ID))
print("CLS token ID is {}.".format(CLS_TOKEN_ID))
print("SEP token ID is {}.".format(SEP_TOKEN_ID))

In [None]:
# Serialize the model config as indented JSON under the local MODEL_DIR.
# NOTE(review): training reads CONFIG_FILE from the bucket, so this local copy
# is only used if it is uploaded (the gsutil cp in the next cell is commented out).
config_path = "{}/albert_config.json".format(MODEL_DIR)
with open(config_path, "w") as config_out:
  json.dump(albert_base_config, config_out, indent=2)

In [None]:
# !gsutil -m cp $MODEL_DIR/albert_config.json gs://$BUCKET_NAME/$MODEL_DIR/

## Training

Phase 1: train the model for 90% of the training steps with a maximum sequence length of 128.
Phase 2: train the model for the remaining 10% of the training steps with a maximum sequence length of 512.

In [None]:
# Phase-1 hyperparameters (sequence length 128; 112,500 of the 125,000 total
# steps, i.e. 90%). Effective batch size is
# TRAIN_BATCH_SIZE_PHASE_1 * NUM_ACCUMULATION_STEPS_PHASE_1 = 4,096 by default.
# TRAIN_STEPS_PHASE_1 is also the global step at which phase detection
# switches to phase 2.
TRAIN_BATCH_SIZE_PHASE_1 = 64 #@param {type:"integer"}
NUM_ACCUMULATION_STEPS_PHASE_1 =  64#@param {type:"integer"}
MAX_SEQ_LENGTH_PHASE_1 = 128 #@param {type:"integer"}
TRAIN_STEPS_PHASE_1 = 112500 #@param {type:"integer"}
LEARNING_RATE_PHASE_1 = 0.00176 #@param {type:"number"}
WARMUP_STEPS_PHASE_1 = 3125 #@param {type:"integer"}



In [None]:
# Phase-2 hyperparameters (sequence length 512; the remaining 12,500 steps,
# i.e. 10%). Effective batch size stays 4,096 by default (128 * 32).
TRAIN_BATCH_SIZE_PHASE_2 = 128 #@param {type:"integer"}
NUM_ACCUMULATION_STEPS_PHASE_2 = 32 #@param {type:"integer"}
MAX_SEQ_LENGTH_PHASE_2 = 512 #@param {type:"integer"}
TRAIN_STEPS_PHASE_2 = 12500 #@param {type:"integer"}
LEARNING_RATE_PHASE_2 = 0.000275 #@param {type:"number"}
WARMUP_STEPS_PHASE_2 = 312 #@param {type:"integer"}

In [None]:
# Settings shared by both phases.
# Passed to model_fn_builder as `tau`; its meaning is defined there (not
# visible in this notebook) — TODO confirm and document.
TAU = 5e-4 #@param {type:"number"}
# TAU = float(TAU)
# Passed to input_fn_builder as max_predictions_per_seq.
MAX_PREDICTIONS = 20 #@param {type:"integer"}
EVAL_BATCH_SIZE = 64 #@param {type:"integer"}
# Passed to model_fn_builder as optimizer_name.
OPTIMIZER = "lamb" #@param ["adamw", "lamb"]
# Checkpoint cadence and retention for the TPU RunConfig.
SAVE_CHECKPOINTS_STEPS = 5000 #@param {type:"integer"}
KEEP_CHECKPOINTS_MAX = 10  #@param {type: "slider", min: 1, max: 15, step: 1}
# Used as TPUConfig num_shards.
NUM_TPU_CORES = 8 #@param {type:"integer"}
# Passed to model_fn_builder as poly_power (presumably the exponent of a
# polynomial learning-rate decay — confirm).
POLY_POWER = 1.0  #@param {type:"number"}
# POLY_POWER = float(POLY_POWER)
# Used as TPUConfig iterations_per_loop.
ITERATIONS_PER_LOOP = 1000 #@param {type:"integer"}
# Used as the EvalSpec throttle_secs in train_and_evaluate.
EVAL_EVERY_N_SECONDS = 3600 #@param {type:"integer"}

In [None]:
# Decide which pre-training phase to run and which checkpoint to start from.
phase = 1

if FROM_TF_HUB:
  CONFIG_FILE = None
else:
  CONFIG_FILE = os.path.join(BERT_GCS_DIR, "albert_config.json")

# Override for the detected phase: 1 or 2 forces that phase; None keeps the
# phase chosen from the checkpoints found below.
FORCE_PHASE = 1


def _phase_ckpt_dir(phase_num):
  """Return the checkpoint directory used by training phase `phase_num`."""
  return "{}/{}/phase_{}".format(BUCKET_PATH, MODEL_DIR, str(phase_num))


# Phase 1 counts as finished once its latest checkpoint reaches
# TRAIN_STEPS_PHASE_1. Phase 2 then resumes from its own latest checkpoint or,
# if it has none yet, starts from the final phase-1 checkpoint.
INIT_CHECKPOINT = tf.train.latest_checkpoint(_phase_ckpt_dir(1))

if INIT_CHECKPOINT:
  ckpt_reader = tf.train.NewCheckpointReader(INIT_CHECKPOINT)
  global_step = ckpt_reader.get_tensor('global_step')
  print('global step from phase 1 is {:,}.'.format(global_step))
  if global_step >= TRAIN_STEPS_PHASE_1:
    phase = 2
    INIT_CHECKPOINT_2 = tf.train.latest_checkpoint(_phase_ckpt_dir(2))
    if INIT_CHECKPOINT_2:
      print("loading latest checkpoint from phase 2.")
      INIT_CHECKPOINT = INIT_CHECKPOINT_2
      ckpt_reader = tf.train.NewCheckpointReader(INIT_CHECKPOINT)
      global_step = ckpt_reader.get_tensor('global_step')
      print('global step from phase 2 is {:,}.'.format(global_step))
    else:
      # First run of phase 2: keep the final phase-1 checkpoint.
      print("no checkpoint from phase 2 yet.")
  else:
    print("max step not reached for phase 1 yet.")
else:
  print("no checkpoint from phase 1 yet.")

if FORCE_PHASE is not None:
  print("Forcing phase {}.".format(FORCE_PHASE))
  if FORCE_PHASE != phase:
    # Re-resolve the checkpoint so the forced phase does not start from a
    # checkpoint that was selected for the other phase.
    forced_ckpt = tf.train.latest_checkpoint(_phase_ckpt_dir(FORCE_PHASE))
    if forced_ckpt is None and FORCE_PHASE == 2:
      forced_ckpt = tf.train.latest_checkpoint(_phase_ckpt_dir(1))
    INIT_CHECKPOINT = forced_ckpt
  phase = FORCE_PHASE


# Pick the hyperparameters that belong to the selected phase.
if phase == 1:
  TRAIN_BATCH_SIZE = TRAIN_BATCH_SIZE_PHASE_1
  NUM_ACCUMULATION_STEPS = NUM_ACCUMULATION_STEPS_PHASE_1
  MAX_SEQ_LENGTH = MAX_SEQ_LENGTH_PHASE_1
  TRAIN_STEPS = TRAIN_STEPS_PHASE_1
  LEARNING_RATE = LEARNING_RATE_PHASE_1
  WARMUP_STEPS = WARMUP_STEPS_PHASE_1
elif phase == 2:
  TRAIN_BATCH_SIZE = TRAIN_BATCH_SIZE_PHASE_2
  NUM_ACCUMULATION_STEPS = NUM_ACCUMULATION_STEPS_PHASE_2
  MAX_SEQ_LENGTH = MAX_SEQ_LENGTH_PHASE_2
  TRAIN_STEPS = TRAIN_STEPS_PHASE_2
  LEARNING_RATE = LEARNING_RATE_PHASE_2
  WARMUP_STEPS = WARMUP_STEPS_PHASE_2
else:
  raise ValueError("phase must be 1 or 2, got {!r}".format(phase))

LEARNING_RATE = float(LEARNING_RATE)
MODEL_DIR_CKPT = "{}/{}/phase_{}".format(BUCKET_PATH, MODEL_DIR, str(phase))
log.info("Total train batch size "
         "(train_batch_size * num_accumulation_steps): {:,}".format(
             TRAIN_BATCH_SIZE * NUM_ACCUMULATION_STEPS))
# TRAIN_STEPS is passed to the estimator unchanged (num_train_steps and
# max_steps); it is not divided by NUM_ACCUMULATION_STEPS, so log it as-is.
log.info("Total training steps: {:,}".format(int(TRAIN_STEPS)))
log.info("Using checkpoint: {}".format(INIT_CHECKPOINT))
log.info("Saving model to: {}".format(MODEL_DIR_CKPT))
albert_config = modeling.AlbertConfig.from_json_file(CONFIG_FILE)

In [None]:
input_files = !gsutil ls $DATA_GCS_DIR/$MAX_SEQ_LENGTH/shard*
dev_files = !gsutil ls $DATA_GCS_DIR/ptb/*.tfrecord

log.info("For train, Using {} data shards from {}".format(len(input_files), DATA_GCS_DIR+ "/" + str(MAX_SEQ_LENGTH)))

train_input_files = input_files[:-1]
eval_input_files = [input_files[-1]]
# eval_input_files = []
# eval_input_files.extend(dev_files)

log.info("For dev, using {} data shards from {}".format(len(eval_input_files), DATA_GCS_DIR+ "/ptb"))

num_eval_examples = 0
for f in eval_input_files:
  n_f = sum(1 for _ in tf.python_io.tf_record_iterator(f))
  print('{} contains {:,} examples'.format(f, n_f))
  num_eval_examples += n_f
print("Using {:,} dev examples".format(num_eval_examples))

In [None]:
# Build the model function, TPU run configuration, estimator and input
# pipelines for the selected phase.
model_fn = model_fn_builder(
      albert_config=albert_config,
      init_checkpoint=INIT_CHECKPOINT,
      init_lr=LEARNING_RATE,
      num_train_steps=TRAIN_STEPS,
      num_warmup_steps=WARMUP_STEPS,
      use_tpu=USE_TPU,
      use_one_hot_embeddings=USE_TPU,
      optimizer_name=OPTIMIZER,
      poly_power=POLY_POWER,
      start_warmup_step=0,
      num_accumulation_steps=NUM_ACCUMULATION_STEPS,
      model_dir=MODEL_DIR_CKPT,
      tau=TAU,
      vocab_upos=VOCAB_UPOS,
      vocab_deps=VOCAB_DEPS,
      mask_token_id=MASK_TOKEN_ID,
      cls_token_id=CLS_TOKEN_ID,
      sep_token_id=SEP_TOKEN_ID)

# TPU_ADDRESS is only guaranteed to exist on a TPU runtime; otherwise fall
# back to RunConfig's default master.
tpu_master = TPU_ADDRESS if USE_TPU else None
tpu_cluster_resolver = (
    tf.contrib.cluster_resolver.TPUClusterResolver(TPU_ADDRESS)
    if USE_TPU else None)
run_config = tf.contrib.tpu.RunConfig(
    # cluster=tpu_cluster_resolver,
    master=tpu_master,
    model_dir=MODEL_DIR_CKPT,
    # log_step_count_steps=1000,
    # save_summary_steps=1000,
    save_checkpoints_steps=SAVE_CHECKPOINTS_STEPS,
    keep_checkpoint_max=KEEP_CHECKPOINTS_MAX,
    tpu_config=tf.contrib.tpu.TPUConfig(
        iterations_per_loop=ITERATIONS_PER_LOOP,
        num_shards=NUM_TPU_CORES,
        per_host_input_for_training=tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2))

estimator = tf.contrib.tpu.TPUEstimator(
    use_tpu=USE_TPU,
    eval_on_tpu=False,
    model_fn=model_fn,
    config=run_config,
    train_batch_size=TRAIN_BATCH_SIZE,
    eval_batch_size=EVAL_BATCH_SIZE)

# Train only on the shards that were not held out for evaluation; using
# `input_files` here would also train on the eval shard.
train_input_fn = input_fn_builder(
        input_files=train_input_files,
        max_seq_length=MAX_SEQ_LENGTH,
        max_predictions_per_seq=MAX_PREDICTIONS,
        is_training=True)

eval_input_fn = input_fn_builder(
        input_files=eval_input_files,
        max_seq_length=MAX_SEQ_LENGTH,
        max_predictions_per_seq=MAX_PREDICTIONS,
        is_training=False,
        add_dep_and_pos=True)

In [None]:
# One-off evaluation over eval_input_fn, running only complete batches.
estimator.evaluate(steps=num_eval_examples // EVAL_BATCH_SIZE, input_fn=eval_input_fn)

In [None]:
# Alternate training with periodic evaluation until TRAIN_STEPS is reached.
# (Training only: estimator.train(input_fn=train_input_fn, max_steps=TRAIN_STEPS))
train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn, max_steps=TRAIN_STEPS)

# Each evaluation runs only complete batches; throttle_secs sets the minimum
# number of seconds between evaluations.
eval_spec = tf.estimator.EvalSpec(
    input_fn=eval_input_fn,
    steps=num_eval_examples // EVAL_BATCH_SIZE,
    throttle_secs=EVAL_EVERY_N_SECONDS)

tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)