In [None]:
from IPython.display import clear_output

# Install the notebook's dependencies into the runtime: t5 (quietly, -q), then
# upgrade jax/jaxlib and flax (-U).
# NOTE(review): none of these installs is version-pinned, so a fresh runtime may
# resolve different versions from one run to the next — consider pinning.
# %env USE_AUTH_EPHEM=0
!pip install -q t5
!pip install -U jax jaxlib
!pip install -U flax

# Wipe the pip logs from the cell output.
clear_output()

In [None]:
from IPython.display import clear_output

# Install the TensorFlow stack pinned to 2.12.0: the GCS config helper, core
# TensorFlow and TensorFlow Text all use the same version number.
!pip install -U tensorflow-gcs-config==2.12.0
!pip install tensorflow==2.12.0
!pip install tensorflow-text==2.12.0
# !pip install -U jax jaxlib

# Wipe the pip logs from the cell output.
clear_output()

In [None]:
# Earlier setup commands, kept for reference only (all disabled).
# %tensorflow_version 2.x
# !pip3 install --upgrade pip
# !pip install t5==0.9.2
# !pip install -U jax jaxlib

# Set the USE_AUTH_EPHEM environment variable to 0 before the imports below run.
# NOTE(review): presumably this selects the non-ephemeral Colab auth flow used by
# auth.authenticate_user() later in this cell — confirm.
%env USE_AUTH_EPHEM=0
import functools
import os
import gin
import tensorflow_gcs_config
from google.colab import auth
import tensorflow.compat.v1 as tf
import tensorflow_datasets as tfds
from contextlib import contextmanager
import logging as py_logging
import t5
# Register a string flag named 'f' (default '', help text 'kernel').
# NOTE(review): presumably this absorbs the `-f <connection file>` argument the
# Jupyter kernel is launched with, so TF flag parsing does not reject it — confirm.
tf.app.flags.DEFINE_string('f', '', 'kernel')

#Set the base dir(Google cloud bucket)
BASE_DIR = "" #@param { type: "string" }


# Set credentials for GCS reading/writing from Colab and TPU.
TPU_TOPOLOGY = "2x2"
try:
  tpu = tf.distribute.cluster_resolver.TPUClusterResolver()  # TPU detection
  TPU_ADDRESS = tpu.get_master()
  print('Running on TPU:', TPU_ADDRESS)
except ValueError as err:
  # Raise a regular Exception subclass (a bare BaseException would slip past
  # `except Exception` handlers) and chain the resolver's ValueError so the
  # underlying cause stays visible in the traceback.
  raise RuntimeError('ERROR: Not connected to a TPU runtime; please see the previous cell in this notebook for instructions!') from err
# Authenticate the Colab user, connect to the TPU host, then hand the Colab
# credentials to the GCS file system layer.
auth.authenticate_user()
# auth.authenticate_service_account()
tf.config.experimental_connect_to_host(TPU_ADDRESS)
tensorflow_gcs_config.configure_gcs_from_colab_auth()

# Everything after this point runs with TF2 behaviour disabled.
tf.disable_v2_behavior()


#LOGGING
# Stop the TF logger from propagating records to ancestor loggers, and set the
# Python root logger to INFO.
tf.get_logger().propagate = False
py_logging.root.setLevel('INFO')

@contextmanager
def tf_verbosity_level(level):
  """Temporarily switch TensorFlow's logging verbosity to `level`.

  The verbosity in effect on entry is captured and put back on exit. The
  restore sits in a `finally` clause so it also happens when the body of the
  `with` statement raises; otherwise a failed cell would leave the new
  verbosity in place for the rest of the session.

  Args:
    level: Verbosity value handed to `tf.logging.set_verbosity`.

  Yields:
    None.
  """
  og_level = tf.logging.get_verbosity()
  tf.logging.set_verbosity(level)
  try:
    yield
  finally:
    tf.logging.set_verbosity(og_level)

In [None]:
from t5.data import postprocessors as t5_postprocessors
from t5.seqio import Feature,SentencePieceVocabulary


# Passed as the first argument to SentencePieceVocabulary in
# get_default_vocabulary(); presumably the path of the SentencePiece .model file.
vocab_model_path = '' #@param { type: "string" }
# NOTE(review): not referenced anywhere else in the visible notebook — confirm
# whether it is still needed.
vocab_path = '' #@param { type: "string" }


# Short aliases for the t5.data task registry and TFDS task class.
TaskRegistry = t5.data.TaskRegistry
TfdsTask = t5.data.TfdsTask


def get_default_vocabulary():
  """Create the SentencePiece vocabulary used for both inputs and targets.

  Returns:
    A new `SentencePieceVocabulary` built from `vocab_model_path`, with 100
    passed as the second positional argument.
  """
  # NOTE(review): presumably the number of extra (sentinel) ids appended to the
  # vocabulary — confirm against the SentencePieceVocabulary signature.
  second_positional_arg = 100
  vocabulary = SentencePieceVocabulary(vocab_model_path, second_positional_arg)
  return vocabulary

# Feature specification handed to every task registration in this notebook.
# Each feature gets its own vocabulary object from get_default_vocabulary();
# both add an EOS token, and only "inputs" is marked as not required.
DEFAULT_OUTPUT_FEATURES = dict(
    inputs=Feature(
        vocabulary=get_default_vocabulary(),
        add_eos=True,
        required=False,
    ),
    targets=Feature(
        vocabulary=get_default_vocabulary(),
        add_eos=True,
    ),
)

In [None]:
# Skip this cell when running the pre-training on the second task only.

# File with the first-task examples; 6755884 is the example count that is also
# recorded in num_nq_examples_task1 below.
path_pretraining_task1 = ''#@param { type: "string" }

# Split name -> file path for the first task.
# NOTE(review): the second-task cell assigns a new dict to this same global name.
nq_tsv_path = {
    "train":      path_pretraining_task1,
}

# Examples per split; passed to the task registration as num_input_examples.
num_nq_examples_task1 = dict(train=6755884)

def nq_dataset_task1(split, shuffle_files=True):
  """Load the raw first-task examples for `split` as a tf.data dataset.

  The file path is resolved from `path_pretraining_task1` instead of the shared
  `nq_tsv_path` global. That global is reassigned by the second-task cell, and
  because this function is only called later (when the task is materialised),
  looking it up here would make the first task read the second task's file.

  Args:
    split: Split name; only "train" is available.
    shuffle_files: Ignored, since there is a single file per split.

  Returns:
    A dataset yielding {"input": ..., "output": ...} dicts, one per line of the
    tab-separated file.

  Raises:
    KeyError: If `split` is not "train".
  """
  # We only have one file for each split.
  del shuffle_files

  # Task-local split -> path mapping (see the docstring for why).
  split_to_path = {"train": path_pretraining_task1}

  # Load lines from the text file as examples.
  ds = tf.data.TextLineDataset(split_to_path[split])
  # NOTE(review): record_defaults supplies default *values*, so an empty column
  # is decoded as the literal text "string" — confirm this is intended.
  ds = ds.map(
      functools.partial(tf.io.decode_csv, record_defaults=["string","string"],
                        field_delim="\t", use_quote_delim=True),
      num_parallel_calls=tf.data.experimental.AUTOTUNE)

  # Name the two decoded columns.
  ds = ds.map(lambda *ex: dict(zip(["input", "output"], ex)))
  return ds

# Sanity check: take five raw examples from the first-task dataset, convert them
# to NumPy values and print them.
print("A few raw train examples...")
for ex in tfds.as_numpy(nq_dataset_task1("train").take(5)):
    print(ex)


def preprocessing_task1(ds):
  """Convert raw first-task examples into text-to-text form.

  For every example, 'input' is prefixed with 'DENOISE: ' and stored under
  'inputs', while 'output' is stored under 'targets'.

  Args:
    ds: Dataset of dicts with 'input' and 'output' entries.

  Returns:
    The mapped dataset of {'inputs', 'targets'} dicts.
  """

  def _as_text_pair(example):
    prefixed_source = tf.strings.join(
        ['DENOISE: ' + example['input']], separator=' ')
    target_text = tf.strings.join([example['output']], separator=' ')
    return {'inputs': prefixed_source, 'targets': target_text}

  parallelism = tf.data.experimental.AUTOTUNE
  return ds.map(_as_text_pair, num_parallel_calls=parallelism)

# Register the first pre-training task under the name "masking_task", removing
# any registration under that name first. TaskRegistry is the notebook's alias
# for t5.data.TaskRegistry.
TaskRegistry.remove("masking_task")
TaskRegistry.add(
    "masking_task",
    splits=["train"],
    dataset_fn=nq_dataset_task1,
    text_preprocessor=preprocessing_task1,
    num_input_examples=num_nq_examples_task1,
    output_features=DEFAULT_OUTPUT_FEATURES,
)

In [None]:
# Skip this cell when running the pre-training on the first task only.

# File with the second-task examples; 133082 is the example count that is also
# recorded in num_nq_examples_task2 below.
path_pretraining_task2 = ''#@param { type: "string" }

# Split name -> file path for the second task.
# NOTE(review): this rebinds the global of the same name created by the
# first-task cell.
nq_tsv_path = {
    "train":      path_pretraining_task2,

}
# Examples per split; passed to the task registration as num_input_examples.
num_nq_examples_task2 = dict(train=133082)

def nq_dataset_task2(split, shuffle_files=True):
  """Load the raw second-task examples for `split` as a tf.data dataset.

  The file path is resolved from `path_pretraining_task2` instead of the shared
  `nq_tsv_path` global. Both task cells assign that global, and this function is
  only called later (when the task is materialised), so looking it up here would
  make the file that is read depend on which cell happened to run last.

  Args:
    split: Split name; only "train" is available.
    shuffle_files: Ignored, since there is a single file per split.

  Returns:
    A dataset yielding {"input": ..., "output": ...} dicts, one per line of the
    tab-separated file.

  Raises:
    KeyError: If `split` is not "train".
  """
  # We only have one file for each split.
  del shuffle_files

  # Task-local split -> path mapping (see the docstring for why).
  split_to_path = {"train": path_pretraining_task2}

  # Load lines from the text file as examples.
  ds = tf.data.TextLineDataset(split_to_path[split])
  # NOTE(review): record_defaults supplies default *values*, so an empty column
  # is decoded as the literal text "string" — confirm this is intended.
  ds = ds.map(
      functools.partial(tf.io.decode_csv, record_defaults=["string","string"],
                        field_delim="\t", use_quote_delim=True),
      num_parallel_calls=tf.data.experimental.AUTOTUNE)

  # Name the two decoded columns.
  ds = ds.map(lambda *ex: dict(zip(["input", "output"], ex)))
  return ds

# Sanity check for the second task: take five raw examples, convert them to
# NumPy values and print them. This must call nq_dataset_task2: the first-task
# loader previews the wrong dataset and is not even defined when the first-task
# cell is skipped.
print("A few raw train examples...")
for ex in tfds.as_numpy(nq_dataset_task2("train").take(5)):
    print(ex)


def preprocessing_task2(ds):
  """Convert raw second-task examples into text-to-text form.

  For every example, 'input' is prefixed with 'LOG_STMT: ' and stored under
  'inputs', while 'output' is stored under 'targets'.

  Args:
    ds: Dataset of dicts with 'input' and 'output' entries.

  Returns:
    The mapped dataset of {'inputs', 'targets'} dicts.
  """

  def _as_text_pair(example):
    prefixed_source = tf.strings.join(
        ['LOG_STMT: ' + example['input']], separator=' ')
    target_text = tf.strings.join([example['output']], separator=' ')
    return {'inputs': prefixed_source, 'targets': target_text}

  parallelism = tf.data.experimental.AUTOTUNE
  return ds.map(_as_text_pair, num_parallel_calls=parallelism)

# Register the second pre-training task under the name "log_stmt_task", removing
# any registration under that name first. TaskRegistry is the notebook's alias
# for t5.data.TaskRegistry.
TaskRegistry.remove("log_stmt_task")
TaskRegistry.add(
    "log_stmt_task",
    splits=["train"],
    dataset_fn=nq_dataset_task2,
    text_preprocessor=preprocessing_task2,
    num_input_examples=num_nq_examples_task2,
    output_features=DEFAULT_OUTPUT_FEATURES,
)

In [None]:
def _rate_num_input_examples(task):
  if "train" in task.splits:
    return float(task.num_input_examples("train"))
  elif "validation" in task.splits:
    return float(task.num_input_examples("validation"))
  else:
    raise ValueError("Task %s does not have a train or validation split." % (task.name))


### Adjust the mixture according to the selected experiment

#For denoising only task
# t5.data.MixtureRegistry.remove("pretraining")
# t5.data.MixtureRegistry.add(
#     "pretraining",
#     ["masking_task"],
#     default_rate=_rate_num_input_examples
# )


#For log stmt only task
# t5.data.MixtureRegistry.remove("pretraining")
# t5.data.MixtureRegistry.add(
#     "pretraining",
#     ["log_stmt_task"],
#     default_rate=_rate_num_input_examples
# )

# Multi-task (MT) mixture: register "pretraining" as a mixture of both tasks,
# removing any registration under that name first. default_rate is the
# _rate_num_input_examples function, so each task's rate is its example count.
t5.data.MixtureRegistry.remove("pretraining")
t5.data.MixtureRegistry.add(
    "pretraining",
    ["masking_task","log_stmt_task"],
    default_rate=_rate_num_input_examples
)

<seqio.dataset_providers.Mixture at 0x7f6bc009a650>

In [None]:
from mesh_tensorflow.transformer.learning_rate_schedules import learning_rate_schedule_noam
from t5 import models


# Key into the per-size settings table below.
MODEL_SIZE = "small"

# Directory passed to MtfModel as model_dir; it is created below.
MODEL_DIR = ''#@param { type: "string" }

# Per-size settings: (model_parallelism, train_batch_size, keep_checkpoint_max).
model_parallelism, train_batch_size, keep_checkpoint_max = {
    "small": (1, 128, 16),
    "base": (2, 16, 8),
    "large": (8, 64, 4),
    "3B": (8, 16, 1),
    "11B": (8, 16, 1)}[MODEL_SIZE]

tf.io.gfile.makedirs(MODEL_DIR)

# When ON_CLOUD is False, keep_checkpoint_max is passed as None instead of the
# value from the table above.
ON_CLOUD = True
# Build the Mesh TensorFlow T5 model on the TPU detected earlier, using the noam
# learning-rate schedule, 512-token inputs and targets, and a checkpoint every
# 5000 steps.
model = models.mtf_model.MtfModel(
    model_dir=MODEL_DIR,
    tpu=TPU_ADDRESS,
    tpu_topology=TPU_TOPOLOGY,
    model_parallelism=model_parallelism,
    batch_size=train_batch_size,
    learning_rate_schedule = learning_rate_schedule_noam,
    sequence_length={"inputs": 512, "targets": 512},
    save_checkpoints_steps=5000,
    keep_checkpoint_max=keep_checkpoint_max if ON_CLOUD else None,
    iterations_per_loop=100,
)

In [None]:
# Gin configuration file parsed right before training.
PATH_GIN_FILE = ''#@param { type: "string" }
import gin

# Parse the gin file with the gin config unlocked, then start training inside
# the same unlocked context.
with gin.unlock_config():
    gin.parse_config_file(PATH_GIN_FILE)
    # Run pre-training: train on the "pretraining" mixture for TRAIN_STEPS steps.
    TRAIN_STEPS = 250000
    model.train("pretraining", TRAIN_STEPS)
