In [None]:
%tensorflow_version 2.x
!pip3 install --upgrade pip
!pip install -qU t5

import functools
import os
import time
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)

import tensorflow.compat.v1 as tf
import tensorflow_datasets as tfds

import t5

# Base directory of the experiment (a Google Cloud Storage bucket).
# Make sure to use a valid GCS bucket containing the datasets.
BASE_DIR = "gs://tse_extension"  #@param { type: "string" }

if not BASE_DIR or BASE_DIR == "gs://":
  raise ValueError("You must enter a BASE_DIR.")
ON_CLOUD = True


if ON_CLOUD:
  import tensorflow_gcs_config
  from google.colab import auth
  TPU_TOPOLOGY = "2x2"
  try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()  # TPU detection
    TPU_ADDRESS = tpu.get_master()
    print('Running on TPU:', TPU_ADDRESS)
  except ValueError as err:
    # Raise a regular, catchable error (not a bare BaseException) and keep
    # the resolver's original ValueError attached as the cause.
    raise RuntimeError('ERROR: Not connected to a TPU runtime; please see the previous cell in this notebook for instructions!') from err
  # Set credentials for GCS reading/writing from Colab and TPU.
  auth.authenticate_user()
  tf.config.experimental_connect_to_host(TPU_ADDRESS)
  tensorflow_gcs_config.configure_gcs_from_colab_auth()

# Switch off TF2 behavior for the rest of the session.
tf.disable_v2_behavior()

# Improve logging.
from contextlib import contextmanager
import logging as py_logging

if ON_CLOUD:
  # Do not hand TensorFlow's log records on to ancestor loggers, and let the
  # root Python logger report from INFO upwards.
  tf.get_logger().propagate = False
  py_logging.getLogger().setLevel(py_logging.INFO)

@contextmanager
def tf_verbosity_level(level):
  """Temporarily set TensorFlow's logging verbosity inside a `with` block.

  Args:
    level: verbosity handed to `tf.logging.set_verbosity` for the duration
      of the block.

  Yields:
    None. The verbosity that was active on entry is restored on exit, also
    when the body of the `with` block raises.
  """
  og_level = tf.logging.get_verbosity()
  tf.logging.set_verbosity(level)
  try:
    yield
  finally:
    # Without try/finally an exception raised inside the `with` body would
    # skip the restore and leave the temporary verbosity in place.
    tf.logging.set_verbosity(og_level)

In [None]:
# TSV files (one per split) of the BFsmall fine-tuning dataset.
tsv_path_bf_small = {
    "train":      'gs://tse_extension/data/datasets/fine-tuning/BFsmall/training.tsv',
    "validation": 'gs://tse_extension/data/datasets/fine-tuning/BFsmall/eval.tsv'
}

# Number of examples per split (a stray closing brace used to make this line
# a SyntaxError, leaving the name undefined for the task registration).
examples_bf_small = dict(train=46680, validation=5835)

In [None]:
# TSV files (one per split) of the BFmedium fine-tuning dataset.
tsv_path_bf_medium = dict(
    train='gs://tse_extension/data/datasets/fine-tuning/BFmedium/training.tsv',
    validation='gs://tse_extension/data/datasets/fine-tuning/BFmedium/eval.tsv',
)

# Number of examples per split.
examples_bf_medium = {"train": 52364, "validation": 6546}

In [None]:
# TSV files (one per split) of the CS fine-tuning dataset.
# NOTE(review): "validation" points at test.tsv here, whereas other datasets
# in this notebook use eval.tsv -- confirm this is intended.
tsv_path_cs = dict(
    train='gs://tse_extension/data/datasets/fine-tuning/CS/training.tsv',
    validation='gs://tse_extension/data/datasets/fine-tuning/CS/test.tsv',
)

# Number of examples per split.
examples_cs = {"train": 1953940, "validation": 90908}

In [None]:
# TSV files (one per split) of the AGraw fine-tuning dataset.
tsv_path_assert_raw = dict(
    train='gs://tse_extension/data/datasets/fine-tuning/AGraw/training.tsv',
    validation='gs://tse_extension/data/datasets/fine-tuning/AGraw/eval.tsv',
)

# Number of examples per split.
examples_assert_raw = {"train": 150523, "validation": 18816}

In [None]:
# TSV files (one per split) of the AGabs fine-tuning dataset.
tsv_path_assert_abs = dict(
    train='gs://tse_extension/data/datasets/fine-tuning/AGabs/training.tsv',
    validation='gs://tse_extension/data/datasets/fine-tuning/AGabs/eval.tsv',
)

# Number of examples per split.
examples_assert_abs = {"train": 126477, "validation": 15809}

In [None]:
# TSV files (one per split) of the MG fine-tuning dataset.
# NOTE(review): "validation" points at test.tsv here, whereas other datasets
# in this notebook use eval.tsv -- confirm this is intended.
tsv_path_mg = dict(
    train='gs://tse_extension/data/datasets/fine-tuning/MG/training.tsv',
    validation='gs://tse_extension/data/datasets/fine-tuning/MG/test.tsv',
)

# Number of examples per split.
examples_mg = {"train": 92476, "validation": 11559}

In [None]:
from t5.data import postprocessors as t5_postprocessors
from t5.seqio import Feature,SentencePieceVocabulary


# Path of the SentencePiece model (vocabulary) file.
# It must be the same one that was used in the pre-training phase.
vocab_model_path = 'gs://tse_extension/data/SP_Model/dl4se.model' #@param { type: "string" }

# Short aliases for the t5 task registry and task class.
TaskRegistry, TfdsTask = t5.data.TaskRegistry, t5.data.TfdsTask


def get_default_vocabulary():
  """Build the SentencePiece vocabulary stored at `vocab_model_path`.

  Returns:
    A new `SentencePieceVocabulary` created from `vocab_model_path` with 100
    as second argument (presumably the number of extra ids -- confirm against
    the `SentencePieceVocabulary` signature).
  """
  second_argument = 100
  vocabulary = SentencePieceVocabulary(vocab_model_path, second_argument)
  return vocabulary

# Output features used by the task registrations in this notebook: both
# sides use the default vocabulary with add_eos=True; only "inputs" is
# declared with required=False.
DEFAULT_OUTPUT_FEATURES = dict(
    inputs=Feature(
        vocabulary=get_default_vocabulary(), required=False, add_eos=True),
    targets=Feature(
        vocabulary=get_default_vocabulary(), add_eos=True),
)

In [None]:
def nq_dataset_bfp_small(split, shuffle_files=False):
  """Load one split of the BFsmall TSV as a dataset of raw text pairs.

  Args:
    split: key of `tsv_path_bf_small` selecting the TSV file to read.
    shuffle_files: accepted (the function is registered as a `dataset_fn`)
      but ignored.

  Returns:
    A `tf.data.Dataset` of dicts with the string entries "buggy" and
    "fixed", taken from the first and second tab-separated column.
  """
  del shuffle_files  # Unused.

  # Load lines from the text file as examples.
  ds = tf.data.TextLineDataset(tsv_path_bf_small[split])
  ds = ds.map(
      # `record_defaults` holds per-column default VALUES (their dtype sets
      # the column type). Empty strings keep both columns typed as string;
      # the former "string" literal was substituted for every empty field.
      functools.partial(tf.io.decode_csv, record_defaults=["", ""],
                        field_delim="\t", use_quote_delim=False),
      num_parallel_calls=tf.data.experimental.AUTOTUNE)

  ds = ds.map(lambda *ex: dict(zip(["buggy", "fixed"], ex)))
  return ds

# Sanity check: print the first five raw examples of the validation split.
print("A few raw valid examples...")
for ex in tfds.as_numpy(
    nq_dataset_bfp_small(split="validation").take(5)):
  print(ex)

def bfp_preprocessing_small(ds):
  """Map raw buggy/fixed pairs to text-to-text examples.

  Args:
    ds: dataset of dicts with the entries "buggy" and "fixed".

  Returns:
    The mapped dataset: "inputs" is the "buggy" text prefixed with
    'generate small patch: ', "targets" is the "fixed" text.
  """

  def to_inputs_and_targets(ex):
    prefixed_source = 'generate small patch: ' + ex['buggy']
    model_input = tf.strings.join([prefixed_source], separator=' ')
    model_target = tf.strings.join([ex['fixed']], separator=' ')
    return dict(inputs=model_input, targets=model_target)

  return ds.map(
      to_inputs_and_targets,
      num_parallel_calls=tf.data.experimental.AUTOTUNE,
  )
  
TaskRegistry = t5.data.TaskRegistry
TfdsTask = t5.data.TfdsTask

# Remove any earlier "bfp_small" registration before adding the task
# (presumably so that the cell can be re-run).
TaskRegistry.remove('bfp_small')
TaskRegistry.add(
    "bfp_small",
    splits=["train", "validation"],
    dataset_fn=nq_dataset_bfp_small,
    text_preprocessor=[bfp_preprocessing_small],
    metric_fns=[t5.evaluation.metrics.accuracy],
    output_features=DEFAULT_OUTPUT_FEATURES,
    num_input_examples=examples_bf_small,
)



In [None]:
def nq_dataset_bfp_medium(split, shuffle_files=False):
  """Load one split of the BFmedium TSV as a dataset of raw text pairs.

  Args:
    split: key of `tsv_path_bf_medium` selecting the TSV file to read.
    shuffle_files: accepted (the function is registered as a `dataset_fn`)
      but ignored.

  Returns:
    A `tf.data.Dataset` of dicts with the string entries "buggy" and
    "fixed", taken from the first and second tab-separated column.
  """
  del shuffle_files  # Unused.

  # Load lines from the text file as examples.
  ds = tf.data.TextLineDataset(tsv_path_bf_medium[split])
  ds = ds.map(
      # `record_defaults` holds per-column default VALUES (their dtype sets
      # the column type). Empty strings keep both columns typed as string;
      # the former "string" literal was substituted for every empty field.
      functools.partial(tf.io.decode_csv, record_defaults=["", ""],
                        field_delim="\t", use_quote_delim=False),
      num_parallel_calls=tf.data.experimental.AUTOTUNE)

  ds = ds.map(lambda *ex: dict(zip(["buggy", "fixed"], ex)))
  return ds

# Sanity check: print the first five raw examples of the validation split.
print("A few raw valid examples...")
for ex in tfds.as_numpy(
    nq_dataset_bfp_medium(split="validation").take(5)):
  print(ex)

def bfp_preprocessing_medium(ds):
  """Map raw buggy/fixed pairs to text-to-text examples.

  Args:
    ds: dataset of dicts with the entries "buggy" and "fixed".

  Returns:
    The mapped dataset: "inputs" is the "buggy" text prefixed with
    'generate medium patch: ', "targets" is the "fixed" text.
  """

  def to_inputs_and_targets(ex):
    prefixed_source = 'generate medium patch: ' + ex['buggy']
    model_input = tf.strings.join([prefixed_source], separator=' ')
    model_target = tf.strings.join([ex['fixed']], separator=' ')
    return dict(inputs=model_input, targets=model_target)

  return ds.map(
      to_inputs_and_targets,
      num_parallel_calls=tf.data.experimental.AUTOTUNE,
  )
  
TaskRegistry = t5.data.TaskRegistry
TfdsTask = t5.data.TfdsTask

# Remove any earlier "bfp_medium" registration before adding the task
# (presumably so that the cell can be re-run).
TaskRegistry.remove('bfp_medium')
TaskRegistry.add(
    "bfp_medium",
    splits=["train", "validation"],
    dataset_fn=nq_dataset_bfp_medium,
    text_preprocessor=[bfp_preprocessing_medium],
    metric_fns=[t5.evaluation.metrics.accuracy],
    output_features=DEFAULT_OUTPUT_FEATURES,
    num_input_examples=examples_bf_medium,
)



In [None]:
def nq_dataset_assert_raw(split, shuffle_files=False):
  """Load one split of the AGraw TSV as a dataset of raw text pairs.

  Args:
    split: key of `tsv_path_assert_raw` selecting the TSV file to read.
    shuffle_files: accepted (the function is registered as a `dataset_fn`)
      but ignored.

  Returns:
    A `tf.data.Dataset` of dicts with the string entries "method" and
    "assert", taken from the first and second tab-separated column.
  """
  del shuffle_files  # Unused.

  # Load lines from the text file as examples.
  ds = tf.data.TextLineDataset(tsv_path_assert_raw[split])
  ds = ds.map(
      # `record_defaults` holds per-column default VALUES (their dtype sets
      # the column type). Empty strings keep both columns typed as string;
      # the former "string" literal was substituted for every empty field.
      functools.partial(tf.io.decode_csv, record_defaults=["", ""],
                        field_delim="\t", use_quote_delim=False),
      num_parallel_calls=tf.data.experimental.AUTOTUNE)

  ds = ds.map(lambda *ex: dict(zip(["method", "assert"], ex)))
  return ds

# Sanity check: print the first five raw examples of the validation split.
print("A few raw valid examples...")
for ex in tfds.as_numpy(
    nq_dataset_assert_raw(split="validation").take(5)):
  print(ex)

def atlas_preprocessing_raw(ds):
  """Map raw method/assert pairs to text-to-text examples.

  Args:
    ds: dataset of dicts with the entries "method" and "assert".

  Returns:
    The mapped dataset: "inputs" is the "method" text prefixed with
    'generate raw assert: ', "targets" is the "assert" text.
  """

  def to_inputs_and_targets(ex):
    prefixed_source = 'generate raw assert: ' + ex['method']
    model_input = tf.strings.join([prefixed_source], separator=' ')
    model_target = tf.strings.join([ex['assert']], separator=' ')
    return dict(inputs=model_input, targets=model_target)

  return ds.map(
      to_inputs_and_targets,
      num_parallel_calls=tf.data.experimental.AUTOTUNE,
  )
  
TaskRegistry = t5.data.TaskRegistry
TfdsTask = t5.data.TfdsTask

# Remove any earlier "ag_raw" registration before adding the task
# (presumably so that the cell can be re-run).
TaskRegistry.remove('ag_raw')
TaskRegistry.add(
    "ag_raw",
    splits=["train", "validation"],
    dataset_fn=nq_dataset_assert_raw,
    text_preprocessor=[atlas_preprocessing_raw],
    metric_fns=[t5.evaluation.metrics.accuracy],
    output_features=DEFAULT_OUTPUT_FEATURES,
    num_input_examples=examples_assert_raw,
)



In [None]:
def nq_dataset_assert_abs(split, shuffle_files=False):
  """Load one split of the AGabs TSV as a dataset of raw text pairs.

  Args:
    split: key of `tsv_path_assert_abs` selecting the TSV file to read.
    shuffle_files: accepted (the function is registered as a `dataset_fn`)
      but ignored.

  Returns:
    A `tf.data.Dataset` of dicts with the string entries "method" and
    "assert", taken from the first and second tab-separated column.
  """
  del shuffle_files  # Unused.

  # Load lines from the text file as examples.
  ds = tf.data.TextLineDataset(tsv_path_assert_abs[split])
  ds = ds.map(
      # `record_defaults` holds per-column default VALUES (their dtype sets
      # the column type). Empty strings keep both columns typed as string;
      # the former "string" literal was substituted for every empty field.
      functools.partial(tf.io.decode_csv, record_defaults=["", ""],
                        field_delim="\t", use_quote_delim=False),
      num_parallel_calls=tf.data.experimental.AUTOTUNE)

  ds = ds.map(lambda *ex: dict(zip(["method", "assert"], ex)))
  return ds

# Sanity check: print the first five raw examples of the validation split.
print("A few raw valid examples...")
for ex in tfds.as_numpy(
    nq_dataset_assert_abs(split="validation").take(5)):
  print(ex)

def atlas_preprocessing_abs(ds):
  """Map raw method/assert pairs to text-to-text examples.

  Args:
    ds: dataset of dicts with the entries "method" and "assert".

  Returns:
    The mapped dataset: "inputs" is the "method" text prefixed with
    'generate abt assert: ', "targets" is the "assert" text.
  """

  def to_inputs_and_targets(ex):
    prefixed_source = 'generate abt assert: ' + ex['method']
    model_input = tf.strings.join([prefixed_source], separator=' ')
    model_target = tf.strings.join([ex['assert']], separator=' ')
    return dict(inputs=model_input, targets=model_target)

  return ds.map(
      to_inputs_and_targets,
      num_parallel_calls=tf.data.experimental.AUTOTUNE,
  )
  
TaskRegistry = t5.data.TaskRegistry
TfdsTask = t5.data.TfdsTask

# Remove any earlier "ag_abs" registration before adding the task
# (presumably so that the cell can be re-run).
TaskRegistry.remove('ag_abs')
TaskRegistry.add(
    "ag_abs",
    splits=["train", "validation"],
    dataset_fn=nq_dataset_assert_abs,
    text_preprocessor=[atlas_preprocessing_abs],
    metric_fns=[t5.evaluation.metrics.accuracy],
    output_features=DEFAULT_OUTPUT_FEATURES,
    num_input_examples=examples_assert_abs,
)



In [None]:
def nq_dataset_cs(split, shuffle_files=False):
  """Load one split of the CS TSV as a dataset of raw text pairs.

  Args:
    split: key of `tsv_path_cs` selecting the TSV file to read.
    shuffle_files: accepted (the function is registered as a `dataset_fn`)
      but ignored.

  Returns:
    A `tf.data.Dataset` of dicts with the string entries "method" and
    "comment", taken from the first and second tab-separated column.
  """
  del shuffle_files  # Unused.

  # Load lines from the text file as examples.
  ds = tf.data.TextLineDataset(tsv_path_cs[split])
  ds = ds.map(
      # `record_defaults` holds per-column default VALUES (their dtype sets
      # the column type). Empty strings keep both columns typed as string;
      # the former "string" literal was substituted for every empty field.
      functools.partial(tf.io.decode_csv, record_defaults=["", ""],
                        field_delim="\t", use_quote_delim=False),
      num_parallel_calls=tf.data.experimental.AUTOTUNE)

  ds = ds.map(lambda *ex: dict(zip(["method", "comment"], ex)))
  return ds

# Sanity check: print the first five raw examples of the validation split.
print("A few raw valid examples...")
for ex in tfds.as_numpy(
    nq_dataset_cs(split="validation").take(5)):
  print(ex)

def preprocessing_cs(ds):
  """Map raw method/comment pairs to text-to-text examples.

  Args:
    ds: dataset of dicts with the entries "method" and "comment".

  Returns:
    The mapped dataset: "inputs" is the "method" text prefixed with
    'generate comment: ', "targets" is the "comment" text.
  """

  def to_inputs_and_targets(ex):
    prefixed_source = 'generate comment: ' + ex['method']
    model_input = tf.strings.join([prefixed_source], separator=' ')
    model_target = tf.strings.join([ex['comment']], separator=' ')
    return dict(inputs=model_input, targets=model_target)

  return ds.map(
      to_inputs_and_targets,
      num_parallel_calls=tf.data.experimental.AUTOTUNE,
  )
  
TaskRegistry = t5.data.TaskRegistry
TfdsTask = t5.data.TfdsTask

# Remove any earlier "cs" registration before adding the task
# (presumably so that the cell can be re-run).
TaskRegistry.remove('cs')
TaskRegistry.add(
    "cs",
    splits=["train", "validation"],
    dataset_fn=nq_dataset_cs,
    text_preprocessor=[preprocessing_cs],
    metric_fns=[t5.evaluation.metrics.accuracy],
    output_features=DEFAULT_OUTPUT_FEATURES,
    num_input_examples=examples_cs,
)



In [None]:
def nq_dataset_mutant(split, shuffle_files=False):
  """Load one split of the MG TSV as a dataset of raw text pairs.

  Args:
    split: key of `tsv_path_mg` selecting the TSV file to read.
    shuffle_files: accepted (the function is registered as a `dataset_fn`)
      but ignored.

  Returns:
    A `tf.data.Dataset` of dicts with the string entries "fixed" and
    "buggy", taken from the first and second tab-separated column
    (in that order).
  """
  del shuffle_files  # Unused.

  # Load lines from the text file as examples.
  ds = tf.data.TextLineDataset(tsv_path_mg[split])
  ds = ds.map(
      # `record_defaults` holds per-column default VALUES (their dtype sets
      # the column type). Empty strings keep both columns typed as string;
      # the former "string" literal was substituted for every empty field.
      functools.partial(tf.io.decode_csv, record_defaults=["", ""],
                        field_delim="\t", use_quote_delim=False),
      num_parallel_calls=tf.data.experimental.AUTOTUNE)

  ds = ds.map(lambda *ex: dict(zip(["fixed", "buggy"], ex)))
  return ds

# Sanity check: print the first five raw examples of the validation split.
print("A few raw valid examples...")
for ex in tfds.as_numpy(
    nq_dataset_mutant(split="validation").take(5)):
  print(ex)

def preprocessing_mg(ds):
  """Map raw fixed/buggy pairs to text-to-text examples.

  Args:
    ds: dataset of dicts with the entries "fixed" and "buggy".

  Returns:
    The mapped dataset: "inputs" is the "fixed" text prefixed with
    'generate mutant: ', "targets" is the "buggy" text.
  """

  def to_inputs_and_targets(ex):
    prefixed_source = 'generate mutant: ' + ex['fixed']
    model_input = tf.strings.join([prefixed_source], separator=' ')
    model_target = tf.strings.join([ex['buggy']], separator=' ')
    return dict(inputs=model_input, targets=model_target)

  return ds.map(
      to_inputs_and_targets,
      num_parallel_calls=tf.data.experimental.AUTOTUNE,
  )
  
TaskRegistry = t5.data.TaskRegistry
TfdsTask = t5.data.TfdsTask

# Remove any earlier "mg" registration before adding the task
# (presumably so that the cell can be re-run).
TaskRegistry.remove('mg')
TaskRegistry.add(
    "mg",
    splits=["train", "validation"],
    dataset_fn=nq_dataset_mutant,
    text_preprocessor=[preprocessing_mg],
    metric_fns=[t5.evaluation.metrics.accuracy],
    output_features=DEFAULT_OUTPUT_FEATURES,
    num_input_examples=examples_mg,
)



In [None]:
#Uncomment the following for the proportional sampling

# def _rate_num_input_examples(task):
#   if "train" in task.splits:
#     return float(task.num_input_examples("train"))
#   elif "validation" in task.splits:
#     return float(task.num_input_examples("validation"))
#   else:
#     raise ValueError("Task %s does not have a train or validation split." % (task.name))

# Balanced training strategy: all six tasks share the same default rate.
# Remove any earlier "all_tasks" registration first, as the task cells do
# for their tasks, so that re-running this cell does not fail on a name
# that is already registered.
t5.data.MixtureRegistry.remove("all_tasks")
t5.data.MixtureRegistry.add(
    "all_tasks",
    ["bfp_small", "bfp_medium", 'cs', 'mg', 'ag_abs', 'ag_raw'],
    default_rate=1.0
)

In [None]:
import t5.models
from mesh_tensorflow.transformer.learning_rate_schedules import truncated_rsqrt
 
# from tensorflow.keras.optimizers.schedules import PolynomialDecay

# starter_learning_rate = 0.01
# end_learning_rate = 0.001
# decay_steps = 10000

# learning_rate_fn = PolynomialDecay(
#     starter_learning_rate,
#     decay_steps,
#     end_learning_rate,
#     power=0.5)


MODEL_SIZE = "small"
MODEL_DIR = 'gs://tse_extension/experiments/with-pretraining-new/MT-Balanced' #@param { type: "string" }
PRETRAINED_MODEL = 'gs://tse_extension/models/pre-trained-new' #@param { type: "string" }


# One (model_parallelism, train_batch_size, keep_checkpoint_max) triple per
# model size; the triple for MODEL_SIZE is unpacked into the three names.
model_parallelism, train_batch_size, keep_checkpoint_max = dict(
    [
        ("small", (1, 128, 200)),
        ("base", (2, 128, 8)),
        ("large", (8, 64, 4)),
        ("3B", (8, 16, 1)),
        ("11B", (8, 16, 1)),
    ]
)[MODEL_SIZE]

# Create the output directory before the model is set up on it.
tf.io.gfile.makedirs(MODEL_DIR)

model = t5.models.MtfModel(
    tpu=TPU_ADDRESS,
    tpu_topology=TPU_TOPOLOGY,
    model_dir=MODEL_DIR,
    model_parallelism=model_parallelism,
    batch_size=train_batch_size,
    sequence_length=dict(inputs=512, targets=512),
    learning_rate_schedule=truncated_rsqrt,
    save_checkpoints_steps=5000,
    # The checkpoint cap is only passed when running on cloud.
    keep_checkpoint_max=(keep_checkpoint_max if ON_CLOUD else None),
    iterations_per_loop=100,
)

In [None]:
import gin

PATH_GIN_FILE = '/content/operative_config.gin'
STEP = 1750000

# Parse the operative gin file and fine-tune on the "all_tasks" mixture,
# both while gin's configuration is unlocked.
with gin.unlock_config():
  gin.parse_config_file(PATH_GIN_FILE)
  model.finetune(
      'all_tasks',
      pretrained_model_dir=PRETRAINED_MODEL,
      finetune_steps=STEP,
  )


In [None]:
# Use a larger batch size for evaluation, which requires less memory.
# For Code Summarization and Mutant Generation we rely on TF's predictions with beam size K=1
%%capture

PATH_GIN_FILE = '/content/operative_config.gin'
import gin

with gin.unlock_config():
  gin.parse_config_file(PATH_GIN_FILE)
    
  task_list = ["mg",'cs']
  model.batch_size = 16
  for task in task_list:

      model.eval(
          mixture_or_task_name=task,
          checkpoint_steps=-1 
      )
