Performs fine-tuning with varying batch sizes, models, and sequence lengths in order to find the best model. Note that support for running this file on a GCP TPU is not included, since this file should not need more memory than Google Colab provides and does not require constant uptime.

# Configure settings

In [None]:
#@markdown ## General Config
#@markdown If preferred, a GCP TPU/runtime can be used to run this notebook
USE_GCP_TPU = False #@param {type:"boolean"}
#@markdown Which task to perform: options are "MRPC" for paired sequence method, "MRPC_w_preds" for paired sequence method with external data, "RE" for single sequence method, or "NER" for single sequence per residue prediction (if you add more modes make sure to change the corresponding code segments)
MODE = "MRPC_w_preds" #@param {type:"string"}
MAX_SEQ_LENGTH =  1024#@param {type:"integer"}
PROCESSES = 2 #@param {type:"integer"}
# Name of the GCS bucket; the later cells turn it into "gs://<BUCKET_NAME>"
BUCKET_NAME = "theodore_jiang" #@param {type:"string"}
#@markdown ###### For if multiple models fine tuned: xxx is the placeholder for the individual model identifier (if only one is being evaluated replace xxx with the actual name of the model)
#@markdown \
#@markdown folder for where to save the finetuned model
# The training cells substitute "xxx" in each *_format value via str.replace
MODEL_DIR_format = "bert_model_mrpc_adding_preds_xxx" #@param {type:"string"}
#@markdown folder for the pretrained model
INIT_MODEL_DIR_format = "bert_model_xxx" #@param {type:"string"}
DATA_DIR_format = "MRPC_adding_preds_xxx" #@param {type:"string"}
RUN_NAME_format = "MRPC_adding_preds_xxx" #@param {type:"string"}

#@markdown ### Training procedure config
# The training cells derive a per-step decay as
# (END_LEARNING_RATE - INIT_LEARNING_RATE) / total planned steps
INIT_LEARNING_RATE =  1e-5 #@param {type:"number"}
END_LEARNING_RATE = 5e-7 #@param {type:"number"}
# Also used as iterations_per_loop in the TPU config built by training_loop
SAVE_CHECKPOINTS_STEPS =  1000 #@param {type:"integer"}
#@markdown ###### TPUEstimator will keep this number of checkpoints; older checkpoints will all be deleted
KEEP_N_CHECKPOINTS_AT_A_TIME =  10#@param {type:"integer"}
NUM_TPU_CORES = 8 #@param {type:"number"}
PLANNED_TOTAL_SEQUENCES_SEEN =  2e5 #@param {type:"number"}
#@markdown PLANNED_TOTAL_STEPS will override PLANNED_TOTAL_SEQUENCES_SEEN; if using PLANNED_TOTAL_SEQUENCES_SEEN, set PLANNED_TOTAL_STEPS to -1 (PLANNED TOTAL STEPS will be based on the train batch size used)
PLANNED_TOTAL_STEPS = 8000 #@param {type:"number"}


#If running on a GCP TPU, use these commands prior to running this notebook

To ssh into the VM:

```
gcloud beta compute ssh --zone <COMPUTE ZONE> <VM NAME> --project <PROJECT NAME> -- -L 8888:localhost:8888
```

Make sure the port above matches the port below (in this case it's 8888)

```
sudo apt-get update
sudo apt-get -y install python3 python3-pip
sudo apt-get install pkg-config
sudo apt-get install libhdf5-serial-dev
sudo apt-get install libffi6 libffi-dev
sudo -H pip3 install jupyter tensorflow==1.14 google-api-python-client tqdm
sudo -H pip3 install jupyter_http_over_ws
jupyter serverextension enable --py jupyter_http_over_ws
jupyter notebook   --NotebookApp.allow_origin='https://colab.research.google.com'   --port=8888   --NotebookApp.port_retries=0   --no-browser

(one command):sudo apt-get update ; sudo apt-get -y install python3 python3-pip ; sudo apt-get install pkg-config ; sudo apt-get -y install libhdf5-serial-dev ; sudo apt-get install libffi6 libffi-dev; sudo -H pip3 install jupyter tensorflow==1.14 google-api-python-client tqdm ; sudo -H pip3 install jupyter_http_over_ws ; jupyter serverextension enable --py jupyter_http_over_ws ; jupyter notebook   --NotebookApp.allow_origin='https://colab.research.google.com'   --port=8888   --NotebookApp.port_retries=0   --no-browser
```
Then copy the printed link containing "localhost:..." and paste it into Colab's "Connect to a local runtime" option.


###Also run this code segment, which creates a TPU

In [None]:
GCE_PROJECT_NAME = "genome-project-319100" #@param {type:"string"}
TPU_ZONE = "us-central1-f" #@param {type:"string"}
TPU_NAME = "mutformer-tpu" #@param {type:"string"}

!gcloud alpha compute tpus create $TPU_NAME --accelerator-type=tpu-v2 --version=1.15.5 --zone=$TPU_ZONE ##create new TPU

!gsutil iam ch serviceAccount:`gcloud alpha compute tpus describe $TPU_NAME | grep serviceAccount | cut -d' ' -f2`:admin gs://theodore_jiang && echo 'Successfully set permissions!' ##give TPU access to GCS

#Clone the repo

In [None]:
if USE_GCP_TPU:
  !sudo apt-get -y install git
#@markdown ######where to clone the repo into (only value that it can't be is "mutformer"):
REPO_DESTINATION_PATH = "code/mutformer" #@param {type:"string"}
import os,shutil
if not os.path.exists(REPO_DESTINATION_PATH):
  os.makedirs(REPO_DESTINATION_PATH)
else:
  shutil.rmtree(REPO_DESTINATION_PATH)
  os.makedirs(REPO_DESTINATION_PATH)
cmd = "git clone https://github.com/WGLab/mutformer.git \"" + REPO_DESTINATION_PATH + "\""
!{cmd}

#Imports

In [None]:
## Colab-only setup, skipped when USE_GCP_TPU is set:
## select TensorFlow 1.x (must happen before tensorflow is imported below)
## and authenticate the user for GCS access.
if not USE_GCP_TPU:
  %tensorflow_version 1.x
  from google.colab import auth
  print("Authorize for GCS:")
  auth.authenticate_user()
  print("Authorize done")

import sys
import json
import random
import logging
import tensorflow as tf
import time
import importlib

## Refresh the local "mutformer" package from the cloned repo:
## drop any stale copy, then copy the model code in again.
if os.path.exists("mutformer"):
  shutil.rmtree("mutformer")
shutil.copytree(REPO_DESTINATION_PATH+"/mutformer_model_code","mutformer")

## Move "mutformer" to the end of sys.path (remove one earlier entry if present, then append).
try:
  sys.path.remove("mutformer")
except ValueError:
  pass
sys.path.append("mutformer")

from mutformer import modeling, optimization, tokenization,run_classifier,run_ner_for_pathogenic

##reload modules in case that's needed (e.g. the repo was re-cloned in this session).
##This must happen BEFORE the "from ... import ..." lines below: importlib.reload()
##does not rebind names that were already imported out of a module, so importing the
##classes first would leave them pointing at the pre-reload code.
modules2reload = [modeling, 
                  optimization, 
                  tokenization,
                  run_classifier,
                  run_ner_for_pathogenic]
for module in modules2reload:
    importlib.reload(module)

from mutformer.modeling import BertModel,BertModelModified
from mutformer.run_classifier import MrpcProcessor,REProcessor,MrpcWithPredsProcessor ##change this part if you add more modes--
from mutformer.run_ner_for_pathogenic import NERProcessor      ##--

# configure logging: route the 'tensorflow' logger to the console and, optionally, a file
log = logging.getLogger('tensorflow')
log.setLevel(logging.INFO)

# detach (and close) whatever handlers are currently attached so that re-running
# this cell neither duplicates output nor leaks open log files
for old_handler in list(log.handlers):
  log.removeHandler(old_handler)
  old_handler.close()

formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')

# console handler, always attached (created exactly once)
ch = logging.StreamHandler()
ch.setLevel(logging.INFO)
ch.setFormatter(formatter)
log_handlers = [ch]

#@markdown ###### Whether or not to write logs to a file
DO_FILE_LOGGING = True #@param {type:"boolean"}
if DO_FILE_LOGGING:
  #@markdown ###### If using file logging, what path to write logs to
  FILE_LOGGING_PATH = 'file_logging/spam.log' #@param {type:"string"}
  # only create a folder when the path actually has a directory component
  log_dir = os.path.dirname(FILE_LOGGING_PATH)
  if log_dir and not os.path.exists(log_dir):
    os.makedirs(log_dir)
  fh = logging.FileHandler(FILE_LOGGING_PATH)
  fh.setLevel(logging.INFO)
  fh.setFormatter(formatter)
  # keep the file handler first, matching the previous [fh, ch] ordering
  log_handlers.insert(0, fh)

# file handler is included only when file logging is enabled
log.handlers = log_handlers

## Detect a Colab TPU runtime through the COLAB_TPU_ADDR environment variable.
## NOTE(review): TPU_ADDRESS is only defined when the variable is present, yet
## training_loop reads it unconditionally -- confirm the non-Colab path sets it elsewhere.
if 'COLAB_TPU_ADDR' not in os.environ:
  log.warning('Not connected to TPU runtime')
else:
  log.info("Using TPU runtime")
  TPU_ADDRESS = 'grpc://' + os.environ['COLAB_TPU_ADDR']

  with tf.Session(TPU_ADDRESS) as session:
    log.info('TPU address is ' + TPU_ADDRESS)
    # Upload credentials to TPU.
    with open('/content/adc.json', 'r') as credentials_file:
      auth_info = json.load(credentials_file)
    tf.contrib.cloud.configure_gcs(session, credentials=auth_info)


##change this part if you added more modes
##maps each supported MODE to (processor class, script module providing model_fn_builder/input_fn builders)
mode2task = {
    "MRPC": (MrpcProcessor, run_classifier),
    "MRPC_w_preds": (MrpcWithPredsProcessor, run_classifier),
    "RE": (REProcessor, run_classifier),
    "NER": (NERProcessor, run_ner_for_pathogenic),
}
if MODE not in mode2task:
  ##list every supported mode (built from the table so the message cannot go stale)
  raise ValueError("The mode specified (\"{}\") was not one of the available modes: [{}].".format(
      MODE, ", ".join("\"{}\"".format(mode) for mode in mode2task)))
processor_class, script = mode2task[MODE]
processor = processor_class()
label_list = processor.get_labels()

#Select preference for communication with eval script/Mount drive if necessary

In [None]:
import os
import shutil

#@markdown ###### Note: for all of these, if using USE_GCP_TPU, all of these parameters must use GCS, because a GCP TPU can't access google drive
#@markdown \
DRIVE_PATH = "/content/drive/My Drive"
BUCKET_PATH = "gs://{}".format(BUCKET_NAME)
#@markdown whether to use GCS for communicating with eval script, if not, defaults to drive. Note that by defualt, training logs have to be stored in GCS because the Google TPU requires it
GCS_COMS = False #@param {type:"boolean"}

COMS_PATH = BUCKET_PATH if GCS_COMS else DRIVE_PATH

if not GCS_COMS:
  from google.colab import drive,auth
  !fusermount -u /content/drive
  drive.flush_and_unmount()
  drive.mount('/content/drive', force_remount=True)
  


# Run Training

###General definitions

In [None]:
## Lookup from model identifier (the value substituted for "xxx" in the *_format settings)
## to the model class handed to model_fn_builder.
name2model = dict(
    modified_large=BertModelModified,
    modified_medium=BertModelModified,
    modified=BertModelModified,
    orig=BertModel,
    large=BertModel,
)

def latest_checkpoint(dir):
  cmd = "gsutil ls "+dir
  files = !{cmd}
  for file in files:
    if "model.ckpt" in file:
      return file.replace("."+file.split(".")[-1],"")

def training_loop(BATCH_SIZE,
                  RESUMING,
                  PLANNED_TOTAL_STEPS,
                  DECAY_PER_STEP,
                  DATA_SEQ_LENGTH,
                  MODEL_NAME,
                  MODEL,
                  INIT_CHECKPOINT_DIR,
                  BERT_GCS_DIR,
                  DATA_GCS_DIR,
                  USING_SHARDS,
                  START_SHARD,
                  USING_PREDS,
                  PRED_NUM,
                  GCS_LOGGING_DIR,
                  CONFIG_FILE):
  
  RESTORE_CHECKPOINT = None if not RESUMING else tf.train.latest_checkpoint(BERT_GCS_DIR)
  if not RESUMING:
    cmd = "gsutil -m rm -r "+BERT_GCS_DIR
    !{cmd}

  ## if using a directory with only a single checkpoint and no "checkpoint" file, 
  ## tf.train.latest_checkpoint will not work, so get fold name manually via latest_checkpoint(dir)
  try: 
    INIT_CHECKPOINT = tf.train.latest_checkpoint(INIT_CHECKPOINT_DIR)
  except:
    INIT_CHECKPOINT = latest_checkpoint(INIT_CHECKPOINT_DIR)
  print("init checkpoint:",INIT_CHECKPOINT,"restore/save checkpont:",RESTORE_CHECKPOINT)

  config = modeling.BertConfig.from_json_file(CONFIG_FILE)
  config.hidden_dropout_prob = 0.1
  config.attention_probs_dropout_prob = 0.1

  model_fn = script.model_fn_builder(
      bert_config=config,
      logging_dir=GCS_LOGGING_DIR,
      num_labels=len(label_list),
      init_checkpoint=INIT_CHECKPOINT,
      restore_checkpoint=RESTORE_CHECKPOINT,
      init_learning_rate=INIT_LEARNING_RATE,
      decay_per_step=DECAY_PER_STEP,
      num_warmup_steps=10,
      use_tpu=True,
      use_one_hot_embeddings=True,
      bert=MODEL,
      weight_decay=0.01,
      epsilon=1e-6,
      clip_grads=False,
      using_preds=USING_PREDS)

  tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(TPU_ADDRESS)

  run_config = tf.contrib.tpu.RunConfig(
      cluster=tpu_cluster_resolver,
      model_dir=BERT_GCS_DIR,
      save_checkpoints_steps=SAVE_CHECKPOINTS_STEPS,
      keep_checkpoint_max=KEEP_N_CHECKPOINTS_AT_A_TIME,
      tpu_config=tf.contrib.tpu.TPUConfig(
          iterations_per_loop=SAVE_CHECKPOINTS_STEPS,
          num_shards=NUM_TPU_CORES,
          per_host_input_for_training=tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2))

  estimator = tf.contrib.tpu.TPUEstimator(
      use_tpu=True,
      model_fn=model_fn,
      config=run_config,
      train_batch_size=BATCH_SIZE)
  
  train_file_name = "train.tf_record"
  train_file = os.path.join(DATA_GCS_DIR, train_file_name)

  if USING_SHARDS:
    shards_folder = DATA_GCS_DIR
    input_file = os.path.join(DATA_GCS_DIR, train_file_name)
    import re
    file_name = input_file.split("/")[-1]
    shards = [shards_folder + "/" + file for file in tf.io.gfile.listdir(shards_folder) if
              re.match(file_name + "_\d+", file)]
    shards = sorted(shards,key=lambda shard:int(shard.split("_")[-1]))[START_SHARD:]
  else:
    shards = [train_file]

  if USING_SHARDS:
    print("\nUSING SHARDs:")
    for shard in shards:
      print(shard)
    print("\n")

  tf.logging.info("***** Running training *****")
  tf.logging.info("  Batch size = %d", BATCH_SIZE)
  for n,shard in enumerate(shards):
      train_input_fn = script.file_based_input_fn_builder( ##if using external data with MRPC_w_preds, make sure to specify "pred_num=xxx"
          input_file=shard,
          seq_length=DATA_SEQ_LENGTH,
          is_training=True,
          drop_remainder=True,
          pred_num=PRED_NUM if USING_PREDS else None)

      ##writing data to drive so that the parallel eval script can know which model to evaluate
      try:
        tf.gfile.Open(COMS_PATH+"/finetuning_run_paired_model.txt","w+").write(MODEL_NAME)
        tf.gfile.Open(COMS_PATH+"/finetuning_run_paired_seq_length.txt","w+").write(str(DATA_SEQ_LENGTH))
        tf.gfile.Open(COMS_PATH+"/finetuning_run_paired_batch_size.txt","w+").write(str(BATCH_SIZE))
      except:
        pass
      estimator.train(input_fn=train_input_fn, max_steps=PLANNED_TOTAL_STEPS)



####Model/sequence length

In [None]:
#@markdown train batch size to use
BATCH_SIZE=16 #@param
#@markdown list of models to test
models = ["modified_medium","modified_large"] #@param
#@markdown list of maximum sequence lengths to test
lengths = [256,512,1024] #@param
LOGGING_DIR = "mrpc_loss_spam_model_comparison_final" #@param {type:"string"}
#@markdown whether or not to resume training from a previous finetuned checkpoint; if no, always train from pretrained model
RESUMING = False #@param {type:"boolean"}
#@markdown whether or not training data was generated in shards (for really large databases)
USING_SHARDS = True #@param {type:"boolean"}
#@markdown if using shards, which shard index to start at (default 0 for first shard)
START_SHARD =   0#@param {type:"integer"}
#@markdown whether or not external data is being used
USING_PREDS = True #@param {type:"boolean"}
#@markdown if using external data, how many datapoints are included in total
PRED_NUM =   27#@param {type:"integer"}

##resolve the step count into a cell-local name: overwriting the global PLANNED_TOTAL_STEPS
##would erase its -1 sentinel, so later cells could not recompute it for their batch size
TOTAL_STEPS = int(PLANNED_TOTAL_STEPS) if PLANNED_TOTAL_STEPS != -1 else int(PLANNED_TOTAL_SEQUENCES_SEEN//BATCH_SIZE)
##per-step learning rate change, so that the rate goes from INIT to END over TOTAL_STEPS
DECAY_PER_STEP = (END_LEARNING_RATE-INIT_LEARNING_RATE)/TOTAL_STEPS

for DATA_SEQ_LENGTH in lengths:
  for MODEL_NAME in models:
    print("\n\n\nMODEL NAME:",MODEL_NAME,
          "\nINPUT MAX SEQ LENGTH:",DATA_SEQ_LENGTH,
          "\nTRAIN_BATCH_SIZE:",BATCH_SIZE,"\n\n\n")

    ##identifier substituted for "xxx" in the output folder / run name of this combination
    RUN_ID = MODEL_NAME+"_"+str(DATA_SEQ_LENGTH)

    MODEL = name2model[MODEL_NAME]
    INIT_CHECKPOINT_DIR = "{}/{}".format(BUCKET_PATH, INIT_MODEL_DIR_format.replace("xxx",MODEL_NAME))
    BERT_GCS_DIR = "{}/{}".format(BUCKET_PATH, MODEL_DIR_format.replace("xxx",RUN_ID))
    DATA_GCS_DIR = "{}/{}".format(BUCKET_PATH, DATA_DIR_format.replace("xxx",str(DATA_SEQ_LENGTH)))
    
    GCS_LOGGING_DIR = "{}/{}".format(BUCKET_PATH, LOGGING_DIR+"/"+RUN_NAME_format.replace("xxx",RUN_ID))

    ##the config file lives next to the pretrained checkpoint
    CONFIG_FILE = "{}/config.json".format(INIT_CHECKPOINT_DIR)

    training_loop(BATCH_SIZE,
                  RESUMING,
                  TOTAL_STEPS,
                  DECAY_PER_STEP,
                  DATA_SEQ_LENGTH,
                  MODEL_NAME,
                  MODEL,
                  INIT_CHECKPOINT_DIR,
                  BERT_GCS_DIR,
                  DATA_GCS_DIR,
                  USING_SHARDS,
                  START_SHARD,
                  USING_PREDS,
                  PRED_NUM,
                  GCS_LOGGING_DIR,
                  CONFIG_FILE)
  
  

####Batch size/sequence length

In [None]:
#@markdown list of batch sizes to test
batch_sizes = [64] #@param
#@markdown list of maximum sequence lengths to test
lengths = [1024] #@param
#@markdown model to use
MODEL_NAME="modified_large" #@param {type:"string"}
LOGGING_DIR = "mrpc_loss_spam_model_comparison_final" #@param {type:"string"}
#@markdown whether or not to resume training from a previous finetuned checkpoint; if no, always train from pretrained model
RESUMING = False #@param {type:"boolean"}
#@markdown whether or not training data was generated in shards (for really large databases)
USING_SHARDS = True #@param {type:"boolean"}
#@markdown if using shards, which shard index to start at (default 0 for first shard)
START_SHARD =   0#@param {type:"integer"}
#@markdown whether or not external data is being used
USING_PREDS = True #@param {type:"boolean"}
#@markdown if using external data, how many datapoints are included in total
PRED_NUM =   27#@param {type:"integer"}

BUCKET_PATH = "gs://{}".format(BUCKET_NAME)

for DATA_SEQ_LENGTH in lengths:
    for BATCH_SIZE in batch_sizes:
        print("\n\n\nMODEL NAME:",MODEL_NAME,
              "\nINPUT MAX SEQ LENGTH:",DATA_SEQ_LENGTH,
              "\nTRAIN_BATCH_SIZE:",BATCH_SIZE,"\n\n\n")

        ##step count and decay depend on the batch size when derived from
        ##PLANNED_TOTAL_SEQUENCES_SEEN, so they are resolved per batch size (into cell-local
        ##names, leaving the -1 sentinel of the global PLANNED_TOTAL_STEPS intact)
        TOTAL_STEPS = int(PLANNED_TOTAL_STEPS) if PLANNED_TOTAL_STEPS != -1 else int(PLANNED_TOTAL_SEQUENCES_SEEN//BATCH_SIZE)
        DECAY_PER_STEP = (END_LEARNING_RATE-INIT_LEARNING_RATE)/TOTAL_STEPS

        ##identifier substituted for "xxx" in the output folder / run name of this combination
        RUN_ID = MODEL_NAME+"_"+str(DATA_SEQ_LENGTH)+"_"+str(BATCH_SIZE)

        MODEL = name2model[MODEL_NAME]
        INIT_CHECKPOINT_DIR = "{}/{}".format(BUCKET_PATH, INIT_MODEL_DIR_format.replace("xxx",MODEL_NAME))
        BERT_GCS_DIR = "{}/{}".format(BUCKET_PATH, MODEL_DIR_format.replace("xxx",RUN_ID))
        DATA_GCS_DIR = "{}/{}".format(BUCKET_PATH, DATA_DIR_format.replace("xxx",str(DATA_SEQ_LENGTH)))

        GCS_LOGGING_DIR = "{}/{}".format(BUCKET_PATH, LOGGING_DIR+"/"+RUN_NAME_format.replace("xxx",RUN_ID))

        ##the config file lives next to the pretrained checkpoint
        CONFIG_FILE = "{}/config.json".format(INIT_CHECKPOINT_DIR)

        training_loop(BATCH_SIZE,
                      RESUMING,
                      TOTAL_STEPS,
                      DECAY_PER_STEP,
                      DATA_SEQ_LENGTH,
                      MODEL_NAME,
                      MODEL,
                      INIT_CHECKPOINT_DIR,
                      BERT_GCS_DIR,
                      DATA_GCS_DIR,
                      USING_SHARDS,
                      START_SHARD,
                      USING_PREDS,
                      PRED_NUM,
                      GCS_LOGGING_DIR,
                      CONFIG_FILE)

###Train a single model

In [None]:
#@markdown batch size to use
BATCH_SIZE = 32 #@param
#@markdown maximum sequence length to use
DATA_SEQ_LENGTH = 512 #@param
#@markdown model to use
MODEL_NAME="modified_large" #@param {type:"string"}
LOGGING_DIR = "mrpc_loss_spam_model_comparison_final" #@param {type:"string"}
#@markdown whether or not to resume training from a previous checkpoint; if no, always train from scratch
RESUMING = True #@param {type:"boolean"}
#@markdown ###### identifier for the model to use (replaces "xxx" from the variable "MODEL_DIR_format")
model_name_extension = "added_preds_512_32" #@param {type:"string"}
#@markdown whether or not training data was generated in shards (for really large databases)
USING_SHARDS = True #@param {type:"boolean"}
#@markdown if using shards, which shard index to start at (default 0 for first shard)
START_SHARD =   0#@param {type:"integer"}
#@markdown whether or not external data is being used
USING_PREDS = True #@param {type:"boolean"}
#@markdown if using external data, how many datapoints are included in total
PRED_NUM =   27#@param {type:"integer"}


##resolve the step count into a cell-local name: overwriting the global PLANNED_TOTAL_STEPS
##would erase its -1 sentinel, so later cells could not recompute it for their batch size
TOTAL_STEPS = int(PLANNED_TOTAL_STEPS) if PLANNED_TOTAL_STEPS != -1 else int(PLANNED_TOTAL_SEQUENCES_SEEN//BATCH_SIZE)
##per-step learning rate change, so that the rate goes from INIT to END over TOTAL_STEPS
DECAY_PER_STEP = (END_LEARNING_RATE-INIT_LEARNING_RATE)/TOTAL_STEPS


print("\n\n\nMODEL NAME:",MODEL_NAME,
      "\nINPUT MAX SEQ LENGTH:",DATA_SEQ_LENGTH,
      "\nTRAIN_BATCH_SIZE:",BATCH_SIZE,"\n\n\n")

MODEL = name2model[MODEL_NAME]
INIT_CHECKPOINT_DIR = "{}/{}".format(BUCKET_PATH, INIT_MODEL_DIR_format.replace("xxx",MODEL_NAME))
##the finetuned model folder is named by model_name_extension, not by MODEL_NAME
BERT_GCS_DIR = "{}/{}".format(BUCKET_PATH, MODEL_DIR_format.replace("xxx",model_name_extension))
DATA_GCS_DIR = "{}/{}".format(BUCKET_PATH, DATA_DIR_format.replace("xxx",str(DATA_SEQ_LENGTH)))

GCS_LOGGING_DIR = "{}/{}".format(BUCKET_PATH, LOGGING_DIR+"/"+RUN_NAME_format.replace("xxx",MODEL_NAME+"_"+str(DATA_SEQ_LENGTH)+"_"+str(BATCH_SIZE)))

##the config file lives next to the pretrained checkpoint
CONFIG_FILE = "{}/config.json".format(INIT_CHECKPOINT_DIR)


training_loop(BATCH_SIZE,
              RESUMING,
              TOTAL_STEPS,
              DECAY_PER_STEP,
              DATA_SEQ_LENGTH,
              MODEL_NAME,
              MODEL,
              INIT_CHECKPOINT_DIR,
              BERT_GCS_DIR,
              DATA_GCS_DIR,
              USING_SHARDS,
              START_SHARD,
              USING_PREDS,
              PRED_NUM,
              GCS_LOGGING_DIR,
              CONFIG_FILE)