<a href="https://colab.research.google.com/github/ImpactPretraining/impact_pre-training/blob/main/code/pre_training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
# Authenticate this Colab session, select the GCP project and install the
# pinned libraries used by the rest of the notebook.

# The environment variable is set before `auth.authenticate_user()` is called.
# NOTE(review): presumably USE_AUTH_EPHEM='0' switches off Colab's ephemeral
# credential flow -- confirm against the Colab documentation.
import os
os.environ['USE_AUTH_EPHEM'] = '0'

from google.colab import auth
auth.authenticate_user()

#@title ## Set Your GCS credential
project_id = 'literaturereview-358312' #@param {type:"string"}
bucket_name = 'literature_review' #@param {type:"string"}

# Point the gcloud CLI at the project entered in the form above.
!gcloud config set project {project_id}

# Upgrade pip, then install the dependencies. t5, tensorflow-text and keras
# are pinned to exact versions; gin-config is left unpinned.
!pip3 install --upgrade pip
!pip install -qU t5==0.9.2
!pip install -q tensorflow-text==2.8.0rc0
!pip3 install keras==2.7.0
!pip3 install gin-config
import functools
import os
import time
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)
import tensorflow.compat.v1 as tf
import tensorflow_datasets as tfds
import t5

# NOTE(review): presumably registers a dummy `f` flag so that flag parsing
# tolerates the `-f <connection file>` argument given to notebook kernels --
# confirm.
tf.flags.DEFINE_string('f','','')

# Base directory: the Google Cloud Storage bucket entered in the form above.
BASE_DIR = "gs://" + bucket_name

# An empty bucket name leaves only the bare scheme, which is rejected here.
if not BASE_DIR or BASE_DIR == "gs://":
  raise ValueError("You must enter a BASE_DIR.")
ON_CLOUD = True


if ON_CLOUD:
  import tensorflow_gcs_config
  from google.colab import auth
  # Set credentials for GCS reading/writing from Colab and TPU.
  TPU_TOPOLOGY = "2x2"
  try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()  # TPU detection
    TPU_ADDRESS = tpu.get_master()
    print('Running on TPU:', TPU_ADDRESS)
  except ValueError as err:
    # Raise an ordinary Exception subclass (not bare BaseException) so generic
    # `except Exception` handlers and tooling see the failure; the original
    # ValueError is kept as the explicit cause.
    raise RuntimeError('ERROR: Not connected to a TPU runtime; please see the previous cell in this notebook for instructions!') from err
  auth.authenticate_user()
  # NOTE(review): eager execution is enabled here and `tf.disable_v2_behavior()`
  # is called right after this block; the original order is preserved --
  # confirm both calls are needed.
  tf.compat.v1.enable_eager_execution(config=None, device_policy=None, execution_mode=None)
  tf.config.experimental_connect_to_host(TPU_ADDRESS)
  tensorflow_gcs_config.configure_gcs_from_colab_auth()

tf.disable_v2_behavior()

# Improve logging.
from contextlib import contextmanager
import logging as py_logging

# On cloud runs: set the root Python logger to INFO and stop the TensorFlow
# logger from propagating its records to ancestor loggers.
if ON_CLOUD:
  py_logging.root.setLevel(py_logging.INFO)
  tf.get_logger().propagate = False

@contextmanager
def tf_verbosity_level(level):
  """Temporarily switch TensorFlow's logging verbosity to `level`.

  The verbosity that was in effect on entry is restored on exit, including
  when the body of the `with` block raises.

  Args:
    level: Verbosity value handed to `tf.logging.set_verbosity`.

  Yields:
    None.
  """
  og_level = tf.logging.get_verbosity()
  tf.logging.set_verbosity(level)
  try:
    yield
  finally:
    # Without the finally clause an exception inside the `with` body would
    # leave the temporary verbosity in place for the rest of the session.
    tf.logging.set_verbosity(og_level)


In [None]:
#@title ## Set pre-training dataset path
masked_pretraining_dataset_path = "gs://literature_review/data/pre-training/MLM/MLM.tsv" #@param {type:"string"}

# Split name -> TSV location. Only a "train" entry is defined here.
nq_tsv_path = dict(train=masked_pretraining_dataset_path)

In [None]:
from t5.data import postprocessors as t5_postprocessors
from t5.seqio import Feature,SentencePieceVocabulary

#@title ## Set tokenizer's model and vocab paths
vocab_model_path = 'gs://literature_review/tokenizer/BPE_Model.model' #@param {type:"string"} 
vocab_path = 'gs://literature_review/tokenizer/BPE_Model.vocab' #@param {type:"string"}

# Short module-level aliases for two classes exposed by `t5.data`.
TaskRegistry, TfdsTask = t5.data.TaskRegistry, t5.data.TfdsTask

def get_default_vocabulary():
  """Build a SentencePieceVocabulary from the file at `vocab_model_path`.

  A new vocabulary object is constructed on every call.

  NOTE(review): the second positional argument (100) is presumably the number
  of extra sentinel ids -- confirm against the SentencePieceVocabulary
  signature.

  Returns:
    The newly constructed SentencePieceVocabulary.
  """
  vocabulary = SentencePieceVocabulary(vocab_model_path, 100)
  return vocabulary


In [None]:
# Output feature spec shared by the task definition. Each feature gets its own
# vocabulary instance and is created with add_eos=False. "inputs" sets
# required=True explicitly; "targets" leaves `required` at the Feature default.
DEFAULT_OUTPUT_FEATURES = dict(
    inputs=Feature(
        vocabulary=get_default_vocabulary(),
        add_eos=False,
        required=True,
    ),
    targets=Feature(
        vocabulary=get_default_vocabulary(),
        add_eos=False,
    ),
)

def nq_dataset_fn(split, shuffle_files=True):
  """Load one split of the pre-training TSV as a dataset of string pairs.

  Args:
    split: Key into `nq_tsv_path` selecting the TSV file to read.
    shuffle_files: Ignored; each split is backed by a single file.

  Returns:
    A dataset with one element per line of the file, each a dict with the
    keys "input" and "output" holding the two tab-separated columns.

  Raises:
    KeyError: If `split` has no entry in `nq_tsv_path`.
  """
  # We only have one file for each split.
  del shuffle_files

  # Load lines from the text file as examples.
  ds = tf.data.TextLineDataset(nq_tsv_path[split])

  # `record_defaults` gives the per-column fallback VALUE (its type also fixes
  # the column dtype). Using "" keeps an empty field empty; the previous
  # "string" default silently replaced empty fields with the text "string".
  parse_line = functools.partial(
      tf.io.decode_csv, record_defaults=["", ""],
      field_delim="\t", use_quote_delim=False)
  ds = ds.map(parse_line, num_parallel_calls=tf.data.experimental.AUTOTUNE)

  # Name the two parsed columns.
  ds = ds.map(lambda *ex: dict(zip(["input", "output"], ex)))
  return ds

In [3]:
def preprocessing(ds):
  """Map raw examples onto the "inputs"/"targets" keys used by the task.

  Args:
    ds: Dataset whose elements are dicts with "input" and "output" entries.

  Returns:
    The mapped dataset; every element is a dict with "inputs" and "targets".
  """
  def _rename_fields(example):
    # Each join receives a one-element list, so the separator is never used.
    source_text = tf.strings.join([example['input']], separator=' ')
    target_text = tf.strings.join([example['output']], separator=' ')
    return dict(inputs=source_text, targets=target_text)

  autotune = tf.data.experimental.AUTOTUNE
  return ds.map(_rename_fields, num_parallel_calls=autotune)

#Create a new training task
# Remove any existing 'pretraining' registration before adding it again.
t5.data.TaskRegistry.remove('pretraining')
t5.data.TaskRegistry.add(
    "pretraining",
    t5.data.Task,
    dataset_fn=nq_dataset_fn,
    # Advertise only the splits that actually have a file in `nq_tsv_path`.
    # Listing a split without a path (previously "validation") makes
    # `nq_dataset_fn` fail with a KeyError as soon as that split is requested.
    splits=list(nq_tsv_path),
    text_preprocessor=[preprocessing],
    output_features = DEFAULT_OUTPUT_FEATURES,
    metric_fns=[t5.evaluation.metrics.accuracy],
)

# Fetch the registered task and build its "train" dataset once with the
# sequence lengths used for training.
nq_task = t5.data.TaskRegistry.get("pretraining")
ds = nq_task.get_dataset(split="train", sequence_length={"inputs": 512, "targets": 512})

In [None]:
from mesh_tensorflow.transformer.learning_rate_schedules import learning_rate_schedule_noam

# Model size preset; must be one of the keys of the table below.
MODEL_SIZE = "small"

#@title ## Set the output model folder path
MODEL_DIR = 'gs://literature_review/models/pre-training/MLM/' #@param {type:"string"} 

# Look up the settings for MODEL_SIZE. Each table entry is ordered as
# (model_parallelism, train_batch_size, keep_checkpoint_max).
(model_parallelism,
 train_batch_size,
 keep_checkpoint_max) = {
    "small": (1, 256, 16),
    "base": (2, 128, 8),
    "large": (8, 64, 4),
    "3B": (8, 16, 1),
    "11B": (8, 16, 1),
}[MODEL_SIZE]

# Create the model output directory before the model is constructed.
tf.io.gfile.makedirs(MODEL_DIR)
from t5 import models

# Build the MtfModel with the TPU settings and size preset chosen above.
model = t5.models.MtfModel(
    model_dir=MODEL_DIR,
    tpu=TPU_ADDRESS,
    tpu_topology=TPU_TOPOLOGY,
    model_parallelism=model_parallelism,
    batch_size=train_batch_size,
    sequence_length=dict(inputs=512, targets=512),
    learning_rate_schedule=learning_rate_schedule_noam,
    save_checkpoints_steps=10000,
    # The preset's checkpoint limit is only passed on cloud runs; otherwise
    # None is passed.
    keep_checkpoint_max=(keep_checkpoint_max if ON_CLOUD else None),
)

In [1]:
# The reported run used 156250 TRAIN_STEPS (40 epochs).
#@title ## Set the gin file path
PATH_GIN_FILE = "gs://literature_review/utils/operative_config.gin" #@param {type:"string"}
import gin

TRAIN_STEPS = 156250

# Parse the gin file and run training while the gin config is unlocked.
with gin.unlock_config():
  gin.parse_config_file(PATH_GIN_FILE)
  model.train("pretraining", steps=TRAIN_STEPS)