<a href="https://colab.research.google.com/github/ImpactPretraining/impact_pre-training/blob/main/fine_tuning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# from google.colab import auth
# auth.authenticate_user()

import os
os.environ['USE_AUTH_EPHEM'] = '0'

from google.colab import auth
auth.authenticate_user()

#@title ## Set Your GCS credential
project_id = 'literaturereview-358312' #@param {type:"string"}
bucket_name = 'literature_review' #@param {type:"string"}

!gcloud config set project {project_id}

!pip3 install --upgrade pip
!pip install -qU t5==0.9.2
!pip install -q tensorflow-text==2.8.0rc0
!pip3 install keras==2.7.0
!pip3 install gin-config

import functools
import os
import time
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)
import tensorflow.compat.v1 as tf
import tensorflow_datasets as tfds
import t5

tf.flags.DEFINE_string('f','','')

#Set the base dir(Google cloud bucket)
BASE_DIR = "gs://" + bucket_name 

if not BASE_DIR or BASE_DIR == "gs://":
  raise ValueError("You must enter a BASE_DIR.")
ON_CLOUD = True


if ON_CLOUD:
  import tensorflow_gcs_config
  from google.colab import auth
  # Set credentials for GCS reading/writing from Colab and TPU.
  TPU_TOPOLOGY = "2x2"
  try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()  # TPU detection
    TPU_ADDRESS = tpu.get_master()
    print('Running on TPU:', TPU_ADDRESS)
  except ValueError:
    raise BaseException('ERROR: Not connected to a TPU runtime; please see the previous cell in this notebook for instructions!')
  auth.authenticate_user()
  tf.compat.v1.enable_eager_execution(config=None, device_policy=None, execution_mode=None)
  tf.config.experimental_connect_to_host(TPU_ADDRESS)
  tensorflow_gcs_config.configure_gcs_from_colab_auth()

tf.disable_v2_behavior()

# Improve logging.
from contextlib import contextmanager
import logging as py_logging

if ON_CLOUD:
  tf.get_logger().propagate = False
  py_logging.root.setLevel('INFO')

@contextmanager
def tf_verbosity_level(level):
  """Temporarily set TensorFlow's logging verbosity inside a `with` block.

  Args:
    level: verbosity passed to `tf.logging.set_verbosity`.

  The previous verbosity is restored on exit, including when the body raises.
  """
  og_level = tf.logging.get_verbosity()
  tf.logging.set_verbosity(level)
  try:
    yield
  finally:
    # Without `finally`, an exception in the `with` body would leave the
    # temporary verbosity in effect for the rest of the session.
    tf.logging.set_verbosity(og_level)

In [2]:
## training and validation sets

#@title ## Set training and valitadion dataset paths
training_set_path = 'gs://literature_review/data/fine-tuning/bug-fix/train.tsv' #@param {type:"string"} 
validation_set_path = 'gs://literature_review/data/fine-tuning/bug-fix/val.tsv' #@param {type:"string"}

nq_tsv_path = {
    "train": training_set_path,
    "validation": validation_set_path
}

!gsutil cp {nq_tsv_path["train"]} ./train.tsv
!gsutil cp {nq_tsv_path["validation"]} ./val.tsv

data_train = len([line for line in open('./train.tsv', 'r')])
data_val = len([line for line in open('./val.tsv', 'r')])

num_nq_examples = dict(train=data_train, validation=data_val)

In [None]:
from t5.data import postprocessors as t5_postprocessors
from t5.seqio import Feature,SentencePieceVocabulary

#@title ## Set tokenizer's model and vocab paths
vocab_model_path = 'gs://literature_review/tokenizer/BPE_Model.model' #@param {type:"string"} 
vocab_path = 'gs://literature_review/tokenizer/BPE_Model.vocab' #@param {type:"string"}

# Short aliases for the T5 data registries.
TaskRegistry = t5.data.TaskRegistry
TfdsTask = t5.data.TfdsTask

# Second positional argument of SentencePieceVocabulary (the extra-ids count
# in seqio — confirm for the installed version).
_NUM_EXTRA_IDS = 100

def get_default_vocabulary():
  """Return a new SentencePieceVocabulary built from `vocab_model_path`."""
  return SentencePieceVocabulary(vocab_model_path, _NUM_EXTRA_IDS)

# Both features use a vocabulary built from the same SentencePiece model and
# append EOS; only "inputs" is marked as not required.
DEFAULT_OUTPUT_FEATURES = dict(
    inputs=Feature(vocabulary=get_default_vocabulary(), add_eos=True,
                   required=False),
    targets=Feature(vocabulary=get_default_vocabulary(), add_eos=True),
)

In [3]:
#@title ## Set fine-tuning task
fine_tuning_task = 'bug_fix' #@param ["bug_fix", "code_completion", "code_summarization"]

def nq_dataset(split, shuffle_files=True):
  """Load one split of the fine-tuning TSV as a tf.data.Dataset.

  Args:
    split: key into `nq_tsv_path` ("train" or "validation").
    shuffle_files: kept for the dataset_fn signature; ignored because each
      split is a single file.

  Returns:
    A dataset of dicts {"input": ..., "output": ...}, one per TSV line.
  """
  # We only have one file for each split.
  del shuffle_files

  # Load lines from the text file as examples.
  ds = tf.data.TextLineDataset(nq_tsv_path[split])
  # record_defaults holds each column's *default value* (its dtype sets the
  # column type). Use "" rather than the literal "string", which would
  # otherwise be substituted for any empty field.
  ds = ds.map(
      functools.partial(tf.io.decode_csv, record_defaults=["", ""],
                        field_delim="\t", use_quote_delim=False),
      num_parallel_calls=tf.data.experimental.AUTOTUNE)

  ds = ds.map(lambda *ex: dict(zip(["input", "output"], ex)))
  return ds

def preprocessing(ds):
  """Map {"input", "output"} examples to T5's {"inputs", "targets"} format.

  Each field goes through `tf.strings.join` on a one-element list, so the
  string value itself is passed through unchanged.
  """
  def _rename_fields(example):
    source_text = tf.strings.join([example['input']], separator=' ')
    target_text = tf.strings.join([example['output']], separator=' ')
    return {'inputs': source_text, 'targets': target_text}

  return ds.map(_rename_fields,
                num_parallel_calls=tf.data.experimental.AUTOTUNE)
  
# Remove any previous registration first so this cell can be re-run.
t5.data.TaskRegistry.remove(fine_tuning_task)
t5.data.TaskRegistry.add(
    fine_tuning_task,
    # Use the dataset loader, preprocessor and example counts defined above.
    dataset_fn=nq_dataset,
    splits=["train", "validation"],
    text_preprocessor=[preprocessing],
    output_features=DEFAULT_OUTPUT_FEATURES,
    metric_fns=[t5.evaluation.metrics.accuracy],
    num_input_examples=num_nq_examples
)

# Build the training split once; presumably a smoke test of the task definition.
nq_task = t5.data.TaskRegistry.get(fine_tuning_task)
ds = nq_task.get_dataset(split="train", sequence_length={"inputs": 512, "targets": 512})

In [None]:
def _rate_num_input_examples(task):
  if "train" in task.splits:
    return float(task.num_input_examples("train"))
  elif "validation" in task.splits:
    return float(task.num_input_examples("validation"))
  else:
    raise ValueError("Task %s does not have a train or validation split." % (task.name))

In [4]:
## Register a one-task mixture with the same name as the selected task,
## rated by its number of input examples.
t5.data.MixtureRegistry.remove(fine_tuning_task)
_mixture_members = [fine_tuning_task]
t5.data.MixtureRegistry.add(
    fine_tuning_task, _mixture_members, default_rate=_rate_num_input_examples)

In [None]:
#@title ## Select fine-tuning with or without pre-training, pre-trained checkpoint path, output model path
fine_tuning = "fine-tuning_with_pre-training/" #@param ["fine-tuning_with_pre-training/", "fine-tuning_without_pre-training/"]

# Pre-trained model directory, used only when fine-tuning with pre-training.
# It must contain the pre-trained model checkpoint files and operative_config.gin.
PRETRAINED_DIR = 'gs://literature_review/models/pre-training/MLM/' #@param {type:"string"} 

# Output directory for the fine-tuned model.
MODEL_DIR = 'gs://literature_review/models/fine-tuning/bug_fix/MLM/' #@param {type:"string"} 

# Selected T5 architecture size (a key of _SIZE_SETTINGS).
MODEL_SIZE = "small"

# size -> (model_parallelism, train_batch_size, keep_checkpoint_max)
_SIZE_SETTINGS = {
    "small": (1, 256, 200),
    "base": (2, 128, 8),
    "large": (8, 64, 4),
    "3B": (8, 16, 1),
    "11B": (8, 16, 1),
}
_size_setting = _SIZE_SETTINGS[MODEL_SIZE]
model_parallelism = _size_setting[0]
train_batch_size = _size_setting[1]
keep_checkpoint_max = _size_setting[2]


In [5]:
from mesh_tensorflow.transformer.learning_rate_schedules import slanted_triangular 
from mesh_tensorflow.transformer.learning_rate_schedules import truncated_rsqrt
from tensorflow.keras.optimizers.schedules import PolynomialDecay

# NOTE(review): this Keras schedule is built but never passed to the model
# (MtfModel receives `selected_learning_rate_scheduler`) — confirm it is needed.
starter_learning_rate = 0.05
end_learning_rate = 0.001
decay_steps = 10000

learning_rate_fn = PolynomialDecay(
    starter_learning_rate,
    decay_steps,
    end_learning_rate,
    power=0.5)

# Learning-rate schedule given to MtfModel. Its start/total steps are patched
# into the gin file below before training.
selected_learning_rate_scheduler = slanted_triangular
PATH_GIN_FILE = 'gs://literature_review/utils/operative_config_slanted.gin'

#@title Set the number of fine-tuning steps
number_of_steps = 300000 #@param {type:"integer"}

# Presumably the global step of the pre-trained checkpoint. The schedule's
# total length is set to pretraining_steps + number_of_steps.
pretraining_steps = 0
if fine_tuning == "fine-tuning_with_pre-training/":
  pretraining_steps = 156250

tf.io.gfile.makedirs(MODEL_DIR)

from t5 import models

# Constructor arguments for the Mesh TensorFlow T5 model. The 512/512 sequence
# lengths match those used when building the task dataset.
_mtf_model_kwargs = dict(
    model_dir=MODEL_DIR,
    tpu=TPU_ADDRESS,
    tpu_topology=TPU_TOPOLOGY,
    model_parallelism=model_parallelism,
    batch_size=train_batch_size,
    learning_rate_schedule=selected_learning_rate_scheduler,
    sequence_length={"inputs": 512, "targets": 512},
    save_checkpoints_steps=10000,
    keep_checkpoint_max=keep_checkpoint_max if ON_CLOUD else None,
    iterations_per_loop=100,
)
model = t5.models.MtfModel(**_mtf_model_kwargs)

!gsutil cp {PATH_GIN_FILE}  ./config.gin
# modify gin file
gin_lines = [line for line in open("./config.gin")]
f = open("./config.gin", "w+")
for i in range(len(gin_lines)):
  if i == 196 and fine_tuning == "fine-tuning_without_pre-training/":
    line = "slanted_triangular.start_step = 0\n"
    f.write(line)
    continue
  if i == 197:
    line = "slanted_triangular.total_train_steps = " + str(number_of_steps + pretraining_steps) + '\n'
    f.write(line)
    continue
  f.write(gin_lines[i])
f.close()

In [None]:
import gin

# Load the patched schedule config, then train. unlock_config() allows the
# config to be (re)parsed even if gin has already locked it.
with gin.unlock_config():
  gin.parse_config_file("./config.gin")
  if fine_tuning == "fine-tuning_without_pre-training/":
    # NON PRETRAINED: train without loading a pre-trained checkpoint.
    TRAIN_STEPS = number_of_steps
    model.train(fine_tuning_task, steps=number_of_steps)
  else:
    # PRETRAINED: fine-tune starting from the checkpoint in PRETRAINED_DIR.
    model.finetune(
        mixture_or_task_name=fine_tuning_task,
        pretrained_model_dir=PRETRAINED_DIR,
        finetune_steps=number_of_steps
    )

# Evaluation

---



In [6]:
# Evaluate every checkpoint saved by this run on the validation split.
# Checkpoints are saved every 10000 steps (save_checkpoints_steps) counted
# from the starting global step — presumably pretraining_steps when
# fine-tuning a pre-trained model (consistent with best_checkpoint = 196250).
checkpoints = list(range(pretraining_steps + 10000,
                         pretraining_steps + number_of_steps + 1, 10000))

# Use a larger batch size for evaluation, which requires less memory.
model.batch_size = 1024
model.eval(
    mixture_or_task_name=fine_tuning_task,
    # -1 evaluates the last checkpoint; a list such as [10000, 20000, 30000]
    # evaluates each listed checkpoint.
    checkpoint_steps=checkpoints,
    split="validation"
    )

# Predictions

---

In [None]:
#@title ## Set the model checkpoint you want to test
# Global step of the checkpoint used for the test-set predictions below;
# presumably chosen from the validation results above — confirm before reuse.
best_checkpoint = 196250 #@param {type:"integer"}

In [7]:
# load test data
import pandas as pd

#@title ## set the test set path
test_set_path = 'gs://literature_review/data/fine-tuning/bug-fix/test.tsv' #@param {type:"string"}
!gsutil cp {test_set_path} ./test.tsv

data = pd.read_csv('./test.tsv', sep='\t', names=['source', 'target'])
source = list(data['source'])
target = list(data['target'])

f_src = open('./test_source.txt', 'w')
f_tgt = open('./test_target.txt', 'w')
for i in range(len(data)):
    f_src.write(source[i] + '\n')
    f_tgt.write(target[i] + '\n')
f_src.close()
f_tgt.close()

In [8]:
# Generate predictions for each line of ./test_source.txt with the chosen
# checkpoint; the output is read back from './output.txt-<step>' below.
# temperature=0.0 selects greedy (argmax) decoding. With beam_size=1, 1.0
# samples from the model's distribution, making exact-match scores random.
model.predict(input_file='./test_source.txt', output_file='./output.txt', checkpoint_steps=[best_checkpoint],
              beam_size=1, temperature=0.0, keep_top_k=-1, vocabulary=get_default_vocabulary())

In [9]:
# Exact-match accuracy of the predictions; the comparison ignores all spaces.
def _read_stripped_lines(path):
    """Return the lines of the text file at `path`, each stripped of surrounding whitespace."""
    with open(path, 'r') as fh:
        return [line.strip() for line in fh]

predictions = _read_stripped_lines('./output.txt-' + str(best_checkpoint))
target = _read_stripped_lines('./test_target.txt')

print('num predictions:', len(predictions))
print('num target:', len(target))
# Predictions and targets are paired by line; differing counts make the
# pairing (and therefore the score) meaningless.
if len(predictions) != len(target):
    raise ValueError('Expected one prediction per target, got %d predictions for %d targets.'
                     % (len(predictions), len(target)))

n = len(target)
correct_predictions = sum(
    pred.replace(' ', '') == tgt.replace(' ', '')
    for pred, tgt in zip(predictions, target))
print('correct predictions: ' + str(correct_predictions)+ '/' + str(n))
# Guard against an empty test set (division by zero).
percent = round(correct_predictions * 100 / n, 2) if n else 0.0
print(str(percent) + '%')