In [None]:
%tensorflow_version 2.x
#!pip3 install --upgrade pip
#!pip install -qU t5
!pip install -q git+https://github.com/google-research/text-to-text-transfer-transformer.git@1e269e72a981fde4ea64a88a0a0d8cc88871e20a #temporary fix


import functools
import os
import time
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)

import tensorflow.compat.v1 as tf
import tensorflow_datasets as tfds

import t5

# Root Google Cloud Storage location for data and models. Replace the bare
# "gs://" placeholder with your own bucket before running the notebook.
BASE_DIR = "gs://"

# Refuse to continue while the placeholder (or an empty value) is still set.
if BASE_DIR == "gs://" or not BASE_DIR:
  raise ValueError("You must enter a BASE_DIR.")

DATA_DIR, MODELS_DIR = (os.path.join(BASE_DIR, sub) for sub in ("data", "models"))
ON_CLOUD = True


# Colab/TPU environment setup. When ON_CLOUD is True: locate the TPU,
# authenticate the Colab user, connect TF to the TPU host and configure GCS
# access. When ON_CLOUD is False, TPU_ADDRESS / TPU_TOPOLOGY are set to None
# so later cells that pass them on (e.g. to t5.models.MtfModel) do not fail
# with a NameError.
# NOTE(review): presumably tpu=None makes MtfModel run without a TPU — verify
# against the installed t5 version.
if ON_CLOUD:
  import tensorflow_gcs_config
  from google.colab import auth
  TPU_TOPOLOGY = "2x2"
  try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()  # TPU detection
    TPU_ADDRESS = tpu.get_master()
    print('Running on TPU:', TPU_ADDRESS)
  except ValueError as err:
    # RuntimeError (an Exception subclass) rather than BaseException, so
    # ordinary `except Exception` handlers see it. The original detection
    # error is chained for context.
    raise RuntimeError('ERROR: Not connected to a TPU runtime; please see the previous cell in this notebook for instructions!') from err
  # Set credentials for GCS reading/writing from Colab and TPU.
  auth.authenticate_user()
  tf.config.experimental_connect_to_host(TPU_ADDRESS)
  tensorflow_gcs_config.configure_gcs_from_colab_auth()
else:
  TPU_ADDRESS = None
  TPU_TOPOLOGY = None

# The rest of the notebook uses TF1-style (tf.compat.v1) behavior.
tf.disable_v2_behavior()

# Improve logging.
from contextlib import contextmanager
import logging as py_logging

if ON_CLOUD:
  # Raise the root Python logger to INFO, and stop TensorFlow's logger from
  # forwarding its records to the root handlers (presumably to avoid
  # duplicated log lines in the notebook output).
  py_logging.getLogger().setLevel(py_logging.INFO)
  tf.get_logger().propagate = False

@contextmanager
def tf_verbosity_level(level):
  """Temporarily change the tf.compat.v1 logging verbosity.

  Args:
    level: a verbosity value accepted by `tf.logging.set_verbosity`, applied
      for the duration of the `with` block.

  Yields:
    None. On exit the verbosity that was active before entering is restored,
    even if the body of the `with` block raises.
  """
  og_level = tf.logging.get_verbosity()
  tf.logging.set_verbosity(level)
  try:
    yield
  finally:
    # Restore in `finally` so an exception inside the block cannot leave the
    # notebook stuck at the temporary verbosity.
    tf.logging.set_verbosity(og_level)

In [None]:
# Locations of the SentencePiece model/vocab files and of the pre-training
# corpus (a single TSV file backing the "train" split).
# NOTE(review): vocab_path is not read anywhere else in this notebook; only
# vocab_model_path is used.
vocab_model_path = 'gs://........model'
vocab_path = 'gs://.........vocab'

DATA_DIR = os.path.join(BASE_DIR, "data/datasets/UNSUPERVISED/pre-training")

nq_tsv_path = dict(
    train=os.path.join(DATA_DIR, "unsupervised_training.tsv"),
)

In [None]:
from t5.data.utils import Feature
import t5.data.preprocessors
from t5.data import sentencepiece_vocabulary

# Short aliases for the t5 task registry and TFDS-backed task class.
TaskRegistry, TfdsTask = t5.data.TaskRegistry, t5.data.TfdsTask

# Number of extra ids reserved on top of the SentencePiece vocabulary
# (passed to SentencePieceVocabulary below).
DEFAULT_EXTRA_IDS = 100

def get_default_vocabulary():
  """Return a SentencePiece vocabulary for the notebook's tokenizer.

  Built from the model file at `vocab_model_path`, with `DEFAULT_EXTRA_IDS`
  extra ids reserved. A new instance is created on every call.
  """
  vocab = sentencepiece_vocabulary.SentencePieceVocabulary(
      vocab_model_path, DEFAULT_EXTRA_IDS)
  return vocab

# Output features shared by the task: both "inputs" and "targets" are encoded
# with the default SentencePiece vocabulary and get an EOS token appended.
FEATURES = {
    feature_name: Feature(vocabulary=get_default_vocabulary(), add_eos=True)
    for feature_name in ("inputs", "targets")
}


In [None]:
def nq_dataset_fn(split, shuffle_files=False):
  """Load the raw pre-training corpus for `split` as a tf.data.Dataset.

  Every line of the file at `nq_tsv_path[split]` is parsed as a single
  tab-delimited string column and becomes one example {"text": <string>}.

  Args:
    split: key into `nq_tsv_path`.
    shuffle_files: part of the t5 dataset_fn interface; ignored because each
      split is a single file.

  Returns:
    A tf.data.Dataset of {"text": string tensor} dicts.

  Raises:
    KeyError: if `split` is not a key of `nq_tsv_path`.
  """
  # We only have one file for each split.
  del shuffle_files

  # Load lines from the text file as examples.
  ds = tf.data.TextLineDataset(nq_tsv_path[split])
  ds = ds.map(
      # One string column. record_defaults values are substituted for empty
      # fields, so the default must be "" -- the previous default "string"
      # turned every blank line into the literal text "string".
      # NOTE(review): with use_quote_delim=True, lines containing unbalanced
      # double quotes will fail to parse; confirm the corpus is quoted/escaped.
      functools.partial(tf.io.decode_csv, record_defaults=[""],
                        field_delim="\t", use_quote_delim=True),
      num_parallel_calls=tf.data.experimental.AUTOTUNE)

  ds = ds.map(lambda *ex: dict(zip(["text"], ex)))
  return ds

print("A few raw train examples...")
raw_train_preview = nq_dataset_fn("train").take(5)
for raw_example in tfds.as_numpy(raw_train_preview):
  print(raw_example)

In [None]:
# Register the pre-training task.
# rekey maps each example to an empty "inputs" and the raw "text" as
# "targets"; t5.data.preprocessors.unsupervised then applies T5's unsupervised
# objective on the tokenized targets to build the actual input/target pairs.
# Without a token_preprocessor the model would only learn to produce the whole
# text from an empty input.
# NOTE(review): the exact objective used by `unsupervised` is set through gin
# (see the gin file parsed later) — confirm it matches the intended setup.
# remove() first, presumably so re-running this cell does not hit a
# duplicate-registration error.
t5.data.TaskRegistry.remove('unsupervised_training')
t5.data.TaskRegistry.add(
    "unsupervised_training",
    dataset_fn=nq_dataset_fn,
    splits=["train"],
    output_features=FEATURES,
    text_preprocessor=functools.partial(
        t5.data.preprocessors.rekey, key_map={"inputs": None, "targets": "text"}),
    token_preprocessor=t5.data.preprocessors.unsupervised,
    metric_fns=[])

In [None]:
# Sanity check: run the registered task's full preprocessing pipeline and
# print a handful of tokenized training examples.
nq_task = t5.data.TaskRegistry.get("unsupervised_training")
preview_lengths = dict(inputs=512, targets=512)
ds = nq_task.get_dataset(split="train", sequence_length=preview_lengths)
print("A few preprocessed training examples...")
for preprocessed_example in tfds.as_numpy(ds.take(5)):
  print(preprocessed_example)

In [None]:
from mesh_tensorflow.transformer.learning_rate_schedules import learning_rate_schedule_noam

# Model size. See
# https://github.com/google-research/text-to-text-transfer-transformer
# if you want to scale up the model.
MODEL_SIZE = "small"

MODEL_DIR = 'gs://......../'

# Per-size settings: (model_parallelism, train_batch_size, keep_checkpoint_max).
_SIZE_SETTINGS = {
    "small": (1, 256, 16),
    "base": (2, 128, 8),
    "large": (8, 64, 4),
    "3B": (8, 16, 1),
    "11B": (8, 16, 1),
}
model_parallelism, train_batch_size, keep_checkpoint_max = _SIZE_SETTINGS[MODEL_SIZE]

tf.io.gfile.makedirs(MODEL_DIR)

# Checkpoint retention limit is only applied when running on the cloud.
checkpoints_to_keep = keep_checkpoint_max if ON_CLOUD else None

model = t5.models.MtfModel(
    model_dir=MODEL_DIR,
    tpu=TPU_ADDRESS,
    tpu_topology=TPU_TOPOLOGY,
    model_parallelism=model_parallelism,
    batch_size=train_batch_size,
    sequence_length={"inputs": 512, "targets": 512},
    learning_rate_schedule=learning_rate_schedule_noam,
    save_checkpoints_steps=5000,
    keep_checkpoint_max=checkpoints_to_keep,
)

In [None]:
# Load and parse the gin configuration for the T5 small network
# (the file is provided in the replication package).
import gin

PATH_GIN_FILE = 'gs://.....train_t5_small.gin'
with gin.unlock_config():
  gin.parse_config_file(PATH_GIN_FILE)

# Echo the resulting gin configuration.
print(gin.config_str())

In [None]:
# Pre-train the model on the registered task for a fixed number of steps.
TRAIN_STEPS = 500000
PRETRAINING_TASK = "unsupervised_training"
model.train(PRETRAINING_TASK, steps=TRAIN_STEPS)