In [None]:
from IPython.display import clear_output

!pip install seqio==0.0.7
!pip install t5==0.9.3
!pip install tensorflow-text==2.12.0

!pip install -U jax jaxlib
!pip install -U flax


clear_output()

In [None]:
import functools
import os
import gin
import tensorflow_gcs_config
from google.colab import auth
import tensorflow.compat.v1 as tf
import tensorflow_datasets as tfds
from contextlib import contextmanager
import logging as py_logging
import t5

# Base directory: the Google Cloud Storage bucket used by this notebook.
BASE_DIR = "gs://finetuning-ag-row"

# TPU topology string passed to the model constructor further down.
TPU_TOPOLOGY = "2x2"
try:
  tpu = tf.distribute.cluster_resolver.TPUClusterResolver()  # TPU detection
  TPU_ADDRESS = tpu.get_master()
  print('Running on TPU:', TPU_ADDRESS)
except ValueError as err:
  # Raise a regular exception type (still a BaseException subclass) and keep
  # the resolver's original error attached as the cause.
  raise RuntimeError('ERROR: Not connected to a TPU runtime; please see the previous cell in this notebook for instructions!') from err

# Set credentials for GCS reading/writing from Colab and TPU.
auth.authenticate_service_account()
tf.config.experimental_connect_to_host(TPU_ADDRESS)
tensorflow_gcs_config.configure_gcs_from_colab_auth()

# The rest of the notebook uses the TF1-style API (tf.logging, tf.app.flags).
tf.disable_v2_behavior()


# Logging: stop TF log records from propagating to the root logger and set the
# root logger level to INFO.
tf.get_logger().propagate = False
py_logging.root.setLevel('INFO')

@contextmanager
def tf_verbosity_level(level):
  """Temporarily switch the TensorFlow logging verbosity.

  Args:
    level: verbosity passed to `tf.logging.set_verbosity` for the duration of
      the `with` block.

  Yields:
    None. The verbosity that was active on entry is restored on exit, including
    when the body of the `with` block raises.
  """
  og_level = tf.logging.get_verbosity()
  tf.logging.set_verbosity(level)
  try:
    yield
  finally:
    # Without the finally clause an exception inside the block would leave the
    # temporary verbosity in place for the rest of the session.
    tf.logging.set_verbosity(og_level)

In [None]:
from t5.data import postprocessors as t5_postprocessors
from t5.seqio import Feature,SentencePieceVocabulary


# SentencePiece artifacts stored in the GCS bucket: the model file and its
# companion vocabulary file.
vocab_model_path, vocab_path = (
    'gs://finetuning-ag-row/dl4se_vocab.model',
    'gs://finetuning-ag-row/dl4se_vocab.vocab',
)

# Short aliases for the t5 task-registration classes.
TaskRegistry, TfdsTask = t5.data.TaskRegistry, t5.data.TfdsTask


def get_default_vocabulary():
  """Build a SentencePieceVocabulary from `vocab_model_path` with 100 extra ids."""
  return SentencePieceVocabulary(vocab_model_path, extra_ids=100)

# Feature specs used when registering the task: both sides use the notebook's
# SentencePiece vocabulary and get an EOS token appended; only "inputs" is
# marked as not required.
DEFAULT_OUTPUT_FEATURES = dict(
    inputs=Feature(
        vocabulary=get_default_vocabulary(),
        add_eos=True,
        required=False,
    ),
    targets=Feature(
        vocabulary=get_default_vocabulary(),
        add_eos=True,
    ),
)

In [None]:
# Dataset sizes:
#   - training: 138478
#   - eval:     17377
#   - test:     17318

# Folder inside the bucket that holds the TSV files for this task.
DATA_DIR_1 = os.path.join(BASE_DIR, "T5-Data/2-TS-null")

# Split name -> TSV path. Note that the "validation" split reads test.tsv.
nq_tsv_path_assert_raw = {
    split_name: os.path.join(DATA_DIR_1, tsv_name)
    for split_name, tsv_name in (
        ("train", "training.tsv"),
        ("validation", "test.tsv"),
    )
}

#num_nq_examples_assert_raw = dict(train=150523, validation=18815)
#num_nq_examples_assert_raw = dict(train=138478, validation=17318)

In [None]:
def nq_dataset_assert_raw(split, shuffle_files=False):
  """Load one split of the raw assert data as a dataset of string pairs.

  Args:
    split: key into `nq_tsv_path_assert_raw` ("train" or "validation").
    shuffle_files: accepted for the dataset_fn interface and ignored; each
      split is a single file.

  Returns:
    A tf.data dataset whose elements are dicts with the keys "method" and
    "assert", taken from the first and second tab-separated column.

  Raises:
    KeyError: if `split` has no entry in `nq_tsv_path_assert_raw`.
  """
  del shuffle_files

  # Load lines from the text file as examples.
  ds = tf.data.TextLineDataset(nq_tsv_path_assert_raw[split])
  # record_defaults holds per-column default VALUES (their dtype selects the
  # column type). Empty strings keep empty fields empty; the previous
  # "string" entries would have filled them with the literal text "string".
  # Quote handling is disabled, so quote characters in the TSV are kept as-is.
  ds = ds.map(
      functools.partial(tf.io.decode_csv, record_defaults=["", ""],
                        field_delim="\t", use_quote_delim=False),
      num_parallel_calls=tf.data.experimental.AUTOTUNE)

  # Turn the (column0, column1) tuple into a named dict.
  ds = ds.map(lambda *ex: dict(zip(["method", "assert"], ex)))
  return ds

# Sanity check: print the first five raw examples of the validation split.
print("A few raw valid examples...")
for ex in tfds.as_numpy(nq_dataset_assert_raw("validation").take(5)):
  print(ex)

A few raw valid examples...
{'method': b'"testGetLibrariesDoesDeDuplication ( ) { when ( design . getContentResource ( ) ) . thenReturn ( designContentResource ) ; setLibraries ( designContentResource , PageRegion . HEAD , new java . lang . String [ ] { ""css1"" , ""cssandjs1"" } , new java . lang . String [ ] { ""js1"" , ""cssandjs1"" } ) ; java . lang . String [ ] categories = instance . getLibraries ( design , PageRegion . HEAD ) ; ""<AssertPlaceHolder>"" ; }"', 'assert': b'org . junit . Assert . assertArrayEquals ( new java . lang . Object [ ] { ""css1"" , ""cssandjs1"" , ""js1"" } , categories )'}
{'method': b'"getUsersWaitingNotificationNoWatchExpectEmptyList ( ) { net . jforum . repository . TopicWatchRepository dao = this . newDao ( ) ; net . jforum . entities . Topic topic = new net . jforum . entities . Topic ( ) ; topic . setId ( 13 ) ; java . util . List < net . jforum . entities . User > users = dao . getUsersWaitingNotification ( topic ) ; ""<AssertPlaceHolder>"" ; }"', '

In [None]:
def atlas_preprocessing_raw(ds):
  """Map raw {"method", "assert"} examples to {"inputs", "targets"} text.

  Both fields are lower-cased; the input additionally gets the task prefix
  'generate raw assert:' prepended.

  Args:
    ds: dataset whose elements are dicts with "method" and "assert" strings.

  Returns:
    The result of `ds.map` with the conversion applied to every element.
  """

  def to_inputs_and_targets(ex):
    method_text = tf.strings.lower(ex['method'])
    assert_text = tf.strings.lower(ex['assert'])
    # Each join receives a single-element list, so the ' ' separator is never
    # inserted: the prefix is glued directly onto the method text.
    return {
        'inputs': tf.strings.join(['generate raw assert:' + method_text],
                                  separator=' '),
        'targets': tf.strings.join([assert_text], separator=' '),
    }

  return ds.map(to_inputs_and_targets,
                num_parallel_calls=tf.data.experimental.AUTOTUNE)

In [None]:
# Short aliases for the t5 task-registration classes.
TaskRegistry, TfdsTask = t5.data.TaskRegistry, t5.data.TfdsTask

ASSERT_TYPE='raw'

# Remove any earlier registration before adding the task again, then register
# it with the TSV loader, the text preprocessor and an accuracy metric.
TaskRegistry.remove('assert_raw')
TaskRegistry.add(
    "assert_raw",
    dataset_fn=nq_dataset_assert_raw,
    splits=["train", "validation"],
    text_preprocessor=atlas_preprocessing_raw,
    output_features=DEFAULT_OUTPUT_FEATURES,
    metric_fns=[t5.evaluation.metrics.accuracy],
)



In [None]:
nq_task = t5.data.TaskRegistry.get("assert_raw")

# Build the train and validation datasets (in that order) with 512-token
# limits on both inputs and targets.
ds, ds1 = (
    nq_task.get_dataset(split=split_name,
                        sequence_length={"inputs": 512, "targets": 512})
    for split_name in ("train", "validation")
)

print("A few preprocessed training examples...")
for ex in tfds.as_numpy(ds.take(5)):
  print(ex)


# The printed examples are cleared from the cell output right away.
clear_output()

In [None]:
FLAGS = tf.app.flags.FLAGS
# Define a string flag named 'f' (presumably to absorb the `-f <connection
# file>` argument Jupyter passes to the kernel -- TODO confirm). The guard
# keeps the cell re-runnable: defining the same flag twice raises an error.
if 'f' not in FLAGS:
  tf.app.flags.DEFINE_string('f','','kernel')

In [None]:
from mesh_tensorflow.transformer.learning_rate_schedules import learning_rate_schedule_noam
from t5 import models

# Which row of the per-size settings table below to use.
MODEL_SIZE = "small"

# GCS directory handed to the model as its model_dir; created if missing.
MODEL_DIR = 'gs://finetuning-ag-row/pre-trained-null'
tf.io.gfile.makedirs(MODEL_DIR)

# Per-size settings: (model_parallelism, train_batch_size, keep_checkpoint_max).
(model_parallelism,
 train_batch_size,
 keep_checkpoint_max) = {
    "small": (1, 128, 16),
    "base": (2, 16, 8),
    "large": (8, 64, 4),
    "3B": (8, 16, 1),
    "11B": (8, 16, 1),
}[MODEL_SIZE]

# Mesh-TensorFlow T5 model bound to the TPU and to MODEL_DIR.
model = models.mtf_model.MtfModel(
    # Where the model lives and which TPU it runs on.
    model_dir=MODEL_DIR,
    tpu=TPU_ADDRESS,
    tpu_topology=TPU_TOPOLOGY,
    # Settings selected from the per-size table.
    model_parallelism=model_parallelism,
    batch_size=train_batch_size,
    keep_checkpoint_max=keep_checkpoint_max,
    # Training schedule and sequence limits.
    learning_rate_schedule=learning_rate_schedule_noam,
    sequence_length={"inputs": 512, "targets": 512},
    save_checkpoints_steps=10000,
    iterations_per_loop=100,
)

In [None]:
import gin

# Gin configuration file expected on the Colab VM's local disk.
PATH_GIN_FILE = '/content/operative_config.gin'

# Number of steps passed to model.train.
FINETUNE_STEPS = 449900

# Parse the gin file with the config unlocked, then train on the registered
# "assert_raw" task.
with gin.unlock_config():
  gin.parse_config_file(PATH_GIN_FILE)
  model.train(mixture_or_task_name="assert_raw", steps=FINETUNE_STEPS)


# Evaluate the latest checkpoint on the "assert_raw" task and report how long
# the evaluation took.
PATH_GIN_FILE = '/content/operative_config.gin'
import gin
import time

with gin.unlock_config():
  gin.parse_config_file(PATH_GIN_FILE)
  # Evaluation batch size (128 equals the "small" training batch size).
  model.batch_size = 128
  # perf_counter is a monotonic clock, so the measured duration is not
  # affected by system clock adjustments.
  s_t = time.perf_counter()
  model.eval(
    mixture_or_task_name= "assert_raw",
    checkpoint_steps=-1,
    compute_sequence_length=False)
  e_t = time.perf_counter()
  print("infer time: %s s" % (e_t-s_t))


### After the eval phase completes, we extract only the predictions for the CS and MG tasks.
### For all other tasks, we found that the beam search implemented in Hugging Face works better than the one implemented in TensorFlow.

In [None]:
# Evaluate the latest checkpoint on the "assert_raw" task and report how long
# the evaluation took.
PATH_GIN_FILE = '/content/operative_config.gin'
import gin
import time

with gin.unlock_config():
  gin.parse_config_file(PATH_GIN_FILE)
  # Evaluation batch size (128 equals the "small" training batch size).
  model.batch_size = 128
  # perf_counter is a monotonic clock, so the measured duration is not
  # affected by system clock adjustments.
  s_t = time.perf_counter()
  model.eval(
    mixture_or_task_name= "assert_raw",
    checkpoint_steps=-1,
    compute_sequence_length=False)
  e_t = time.perf_counter()
  print("infer time: %s s" % (e_t-s_t))


### After the eval phase completes, we extract only the predictions for the CS and MG tasks.
### For all other tasks, we found that the beam search implemented in Hugging Face works better than the one implemented in TensorFlow.

In [None]:
### Use this cell if the eval procedure above fails

from google.cloud import storage
import time

# Directory holding the dumped validation inputs/targets.
# NOTE(review): this points at ".../pre-trained-combine/validation_eval" while
# MODEL_DIR is ".../pre-trained-null" -- confirm this is the intended folder.
base_validation_path = 'gs://finetuning-ag-row/pre-trained-combine/validation_eval'


# Make sure that in base_validation_path the following are present
input_files = ['assert_raw_inputs']
output_files = ['assert_raw_targets']

for input_file, output_file in zip(input_files, output_files):

  # perf_counter is a monotonic clock, so the measured duration is not
  # affected by system clock adjustments.
  s_t = time.perf_counter()
  model.predict(os.path.join(base_validation_path,input_file),
                os.path.join(base_validation_path,output_file),
                checkpoint_steps=-1,
                beam_size=1,
                temperature=1.0,
                # Same vocabulary the task was registered with.
                vocabulary=get_default_vocabulary())
  e_t = time.perf_counter()
  print("infer time: %s s" % (e_t-s_t))


# File names (relative to base_validation_path) of the reference targets and
# of the matching prediction files.
accuracy_only_task_real = ['assert_raw_targets']
accuracy_only_task_predictions = ['assert_raw_targets-449900']


def _strip_outer_quotes(line):
  """Trim whitespace and drop one leading and one trailing double quote.

  Empty lines and a lone '"' are handled without raising; the previous
  index-based checks raised IndexError on them.
  """
  line = line.strip()
  if line.startswith('"'):
    line = line[1:]
  if line.endswith('"'):
    line = line[:-1]
  return line


def _read_unquoted_lines(path):
  """Read `path` line by line and return the unquoted lines as a list.

  Blank lines are kept (as empty strings) so that targets and predictions stay
  aligned by position.
  """
  with tf.io.gfile.GFile(path) as handle:
    return [_strip_outer_quotes(line) for line in handle]


for target, pred in zip(accuracy_only_task_real, accuracy_only_task_predictions):

  target_list = _read_unquoted_lines(os.path.join(base_validation_path, target))
  pred_list = _read_unquoted_lines(os.path.join(base_validation_path, pred))

  # First two underscore-separated parts of the file name, e.g. "assert raw".
  task_name = ' '.join(target.split('_')[0:2])
  print('{} {}'.format(task_name, t5.evaluation.metrics.accuracy(target_list,pred_list)))

In [None]:
import subprocess

def is_path_exists_gsutil(gcs_path):
    """Report whether `gsutil ls <gcs_path>` succeeds.

    Args:
        gcs_path: path handed to `gsutil ls`.

    Returns:
        True when the command exits with status 0, False for any non-zero
        exit status (so every gsutil failure is reported as "does not exist").
    """
    # Capture stdout so the listing itself is not echoed into the notebook.
    listing = subprocess.run(["gsutil", "ls", gcs_path], stdout=subprocess.PIPE)
    return listing.returncode == 0

# Report whether the operative gin config is present in the bucket.
gcs_path = 'gs://finetuning-ag-row/pre-trained/operative_config.gin'
if not is_path_exists_gsutil(gcs_path):
    print(f'The path {gcs_path} does not exist in Google Cloud Storage.')
else:
    print(f'The path {gcs_path} exists in Google Cloud Storage.')

The path gs://finetuning-ag-row/pre-trained/operative_config.gin exists in Google Cloud Storage.


In [None]:
if ON_CLOUD:
  %reload_ext tensorboard
  import tensorboard as tb
tb.notebook.start("--logdir " + MODEL_DIR)