In [None]:
%tensorflow_version 2.x
!pip3 install --upgrade pip
!pip install t5==0.9.2


import functools
import os
import gin
import tensorflow_gcs_config
from google.colab import auth
import tensorflow.compat.v1 as tf
import tensorflow_datasets as tfds
from contextlib import contextmanager
import logging as py_logging
import t5

# Set the base dir (Google Cloud Storage bucket).
BASE_DIR = "" #@param { type: "string" }

# Detect the Colab TPU, then set credentials for GCS reading/writing from Colab and TPU.
TPU_TOPOLOGY = "2x2"
try:
  tpu = tf.distribute.cluster_resolver.TPUClusterResolver()  # TPU detection
  TPU_ADDRESS = tpu.get_master()
  print('Running on TPU:', TPU_ADDRESS)
except ValueError as err:
  # A ValueError from the resolver is treated as "no TPU runtime attached".
  # Raise a RuntimeError (an ordinary Exception, unlike BaseException, so normal
  # `except Exception` handlers and notebook tooling behave) and keep the cause.
  raise RuntimeError('ERROR: Not connected to a TPU runtime; please see the previous cell in this notebook for instructions!') from err
auth.authenticate_user()
tf.config.experimental_connect_to_host(TPU_ADDRESS)
tensorflow_gcs_config.configure_gcs_from_colab_auth()

# `tf` is tensorflow.compat.v1; the rest of the notebook runs in TF1 graph mode.
tf.disable_v2_behavior()


#LOGGING
# Stop TF log records from being duplicated through the root Python logger.
tf.get_logger().propagate = False
py_logging.root.setLevel('INFO')

@contextmanager
def tf_verbosity_level(level):
  """Temporarily change the TF (compat.v1) logging verbosity.

  Args:
    level: a verbosity value accepted by ``tf.logging.set_verbosity``; applied
      while the ``with`` body runs.

  The previous verbosity is restored on exit, including when the body raises.
  """
  og_level = tf.logging.get_verbosity()
  tf.logging.set_verbosity(level)
  try:
    yield
  finally:
    # Restore even on exceptions so a failing cell does not leave the session's
    # logging level permanently changed.
    tf.logging.set_verbosity(og_level)

In [2]:
path_finetuning = '' #@param { type: "string" }
path_eval = '' #@param { type: "string" }
path_test = '' #@param { type: "string" }


#### Dataset sizes ####

#path_finetuning --> 1122864
#path_eval --> 521779
#path_test --> 437384


# Split name -> TSV file that nq_dataset_task reads for that split.
# NOTE(review): "validation" points at path_test, and path_eval is not used
# anywhere in this notebook. Confirm this is the intended evaluation split.
nq_tsv_path = dict(
    train=path_finetuning,
    validation=path_test,
)

# Examples per split, passed to the task registration as num_input_examples.
# NOTE(review): these counts differ from the "Dataset sizes" figures above, which
# may be a different unit. Verify them against the actual files.
num_nq_examples = {"train": 106382, "validation": 12020}

In [3]:
from t5.data import postprocessors as t5_postprocessors
from t5.seqio import Feature,SentencePieceVocabulary


vocab_model_path = '' #@param { type: "string" }
# NOTE(review): vocab_path is not referenced anywhere else in this notebook.
vocab_path = '' #@param { type: "string" }


# Short aliases for the T5 registries (not referenced later in this notebook).
TaskRegistry = t5.data.TaskRegistry
TfdsTask = t5.data.TfdsTask


def get_default_vocabulary():
  """Return a new SentencePiece vocabulary loaded from ``vocab_model_path``.

  The second positional argument (100) is presumably the number of extra
  sentinel ids (``extra_ids``). Confirm this against the installed t5/seqio version.
  """
  num_extra = 100
  return SentencePieceVocabulary(vocab_model_path, num_extra)

# Output features of the task. Each feature builds its own vocabulary instance
# (inputs first, then targets). EOS is appended to both, and "inputs" is optional.
DEFAULT_OUTPUT_FEATURES = dict(
    inputs=Feature(vocabulary=get_default_vocabulary(),
                   add_eos=True,
                   required=False),
    targets=Feature(vocabulary=get_default_vocabulary(),
                    add_eos=True),
)

In [None]:
def nq_dataset_task(split, shuffle_files=True):
  """Load one split of the tab-separated data as a tf.data.Dataset.

  Args:
    split: key into ``nq_tsv_path`` ("train" or "validation").
    shuffle_files: accepted to match the T5 ``dataset_fn`` interface. It is
      ignored because each split is a single file.

  Returns:
    A dataset of dicts whose "input" and "output" entries are string tensors
    taken from the first and second tab-separated fields of each line.
  """
  # We only have one file for each split.
  del shuffle_files

  # Load lines from the text file as examples.
  ds = tf.data.TextLineDataset(nq_tsv_path[split])
  # record_defaults sets each column's dtype AND the value used for an empty
  # field. "" keeps empty fields empty instead of replacing them with the
  # literal text "string".
  ds = ds.map(
      functools.partial(tf.io.decode_csv, record_defaults=["", ""],
                        field_delim="\t", use_quote_delim=True),
      num_parallel_calls=tf.data.experimental.AUTOTUNE)

  # Turn each (input, output) tuple into a {"input": ..., "output": ...} dict.
  ds = ds.map(lambda *ex: dict(zip(["input", "output"], ex)))
  return ds

print("A few raw train examples...")
for ex in tfds.as_numpy(nq_dataset_task("train").take(5)):
  print(ex)

In [5]:
def preprocessing(ds):
  """Map raw {"input", "output"} examples to T5's {"inputs", "targets"} format.

  Args:
    ds: a tf.data.Dataset of dicts with string entries "input" and "output".

  Returns:
    A dataset of dicts with "inputs" and "targets". Each value is the
    corresponding raw field with leading/trailing whitespace stripped.
  """
  def to_inputs_and_targets(ex):
    # Stripping is the only transformation: no task prefix or joining is added.
    return {
        'inputs': tf.strings.strip(ex['input']),
        'targets': tf.strings.strip(ex['output']),
    }

  return ds.map(to_inputs_and_targets,
                num_parallel_calls=tf.data.experimental.AUTOTUNE)

In [None]:
# Drop any previous registration first, presumably so this cell can be re-run
# without a duplicate-name error.
t5.data.TaskRegistry.remove('log_injection')
t5.data.TaskRegistry.add(
    "log_injection",
    # Where the data comes from and how it is preprocessed.
    dataset_fn=nq_dataset_task,
    splits=["train","validation"],
    text_preprocessor=[preprocessing],
    num_input_examples=num_nq_examples,
    # Vocabulary/features and evaluation metric.
    output_features=DEFAULT_OUTPUT_FEATURES,
    metric_fns=[t5.evaluation.metrics.accuracy],
)

In [None]:
# Preview the first five training examples after the task's preprocessing.
nq_task = t5.data.TaskRegistry.get("log_injection")
ds = nq_task.get_dataset(
    split="train",
    sequence_length=dict(inputs=512, targets=512),
)
print("A few preprocessed training examples...")
for ex in tfds.as_numpy(ds.take(5)):
  print(ex)

In [None]:
# Mixture "task" wraps only the log_injection task. It is the name passed to
# model.finetune / model.eval. It is removed first so the cell can be re-run.
t5.data.MixtureRegistry.remove("task")
t5.data.MixtureRegistry.add("task", ["log_injection"], default_rate=1.0)

In [None]:
from mesh_tensorflow.transformer.learning_rate_schedules import slanted_triangular 
from mesh_tensorflow.transformer.learning_rate_schedules import truncated_rsqrt
from tensorflow.keras.optimizers.schedules import PolynomialDecay
from t5 import models

starter_learning_rate = 0.01
end_learning_rate = 0.001
decay_steps = 10000

# NOTE(review): learning_rate_fn is never passed to the model below, which uses
# slanted_triangular. It is kept for experiments that switch schedules.
learning_rate_fn = PolynomialDecay(
    initial_learning_rate=starter_learning_rate,
    decay_steps=decay_steps,
    end_learning_rate=end_learning_rate,
    power=0.5)

MODEL_SIZE = "small"

MODEL_DIR = 'gs://'#@param { type: "string" }

PRETRAINED_DIR='gs://'#@param { type: "string" }


# Per-size settings: (model_parallelism, train_batch_size, keep_checkpoint_max).
model_parallelism, train_batch_size, keep_checkpoint_max = {
    "small": (1, 128, 16),
    "base": (2, 128, 8),
    "large": (8, 64, 4),
    "3B": (8, 16, 1),
    "11B": (8, 16, 1),
}[MODEL_SIZE]

tf.io.gfile.makedirs(MODEL_DIR)

model = t5.models.MtfModel(
    model_dir=MODEL_DIR,
    tpu=TPU_ADDRESS,
    tpu_topology=TPU_TOPOLOGY,
    # Parallelism and batching.
    model_parallelism=model_parallelism,
    batch_size=train_batch_size,
    sequence_length={"inputs": 512, "targets": 512},
    # Pick the scheduler that matches the model you want to train.
    learning_rate_schedule=slanted_triangular,
    # Checkpointing and TPU loop size.
    save_checkpoints_steps=5000,
    keep_checkpoint_max=keep_checkpoint_max,
    iterations_per_loop=100,
)

In [None]:
# Operative gin configs for the different experiments (uploaded to /content).
PATH_GIN_FILE_NO_PT = '/content/no_pretraining_operative_config.gin'
PATH_GIN_FILE_MT = '/content/multi-task_operative_config.gin'
PATH_GIN_FILE_DENOISE = '/content/denoise_only_operative_config.gin'
PATH_GIN_FILE_LOG_STMT = '/content/log_stmt_only_operative_config.gin'

# The gin config that is actually parsed below. It must be defined before
# parse_config_file is called. Set it to the PATH_GIN_FILE_* entry matching the
# experiment you are running.
PATH_GIN_FILE = PATH_GIN_FILE_MT

with gin.unlock_config():
    gin.parse_config_file(PATH_GIN_FILE)
    #RUN FINE-TUNING
    TRAIN_STEPS = 200000
    model.finetune(mixture_or_task_name="task",
                   finetune_steps=TRAIN_STEPS,
                   pretrained_model_dir=PRETRAINED_DIR)

    # If the no-pretraining experiment is the one you want to run, then, uncomment the following and comment model.finetune
    # Also, make sure to upload the slanted_operative.gin
    #model.train("task", TRAIN_STEPS)




In [None]:
# %%capture
# Evaluate on the "task" mixture. Per the T5 MtfModel docs, checkpoint_steps=-1
# selects the latest checkpoint in MODEL_DIR.
# The attribute must be spelled `batch_size`: a misspelled name silently creates
# a new attribute and leaves the training batch size in effect.
model.batch_size = 32
model.eval(
    mixture_or_task_name="task",
    checkpoint_steps=-1
)



# model.batch_size = 256
# input_file = 'gs://'#@param { type: "string" }
# output_file = 'gs://'#@param { type: "string" }
# model.predict(input_file, output_file, checkpoint_steps=-1, vocabulary=get_default_vocabulary())


