# Finetuning Evaluation and Prediction Script

This notebook evaluates finetuned models and uses them to perform predictions on test data.

Note: if using a TPU from Google Cloud (not the Colab TPU), make sure to run this notebook on a VM with access to all GCP APIs, and make sure TPUs are enabled for the GCP project.

# Downgrade Python and TensorFlow

(The default Python version in Colab does not support TensorFlow 1.15.)

* **Note:** because the Python used in this notebook is not installed at the default path, syntax highlighting will most likely not work.

#### 1. First, download and install Python 3.7:

In [None]:
!wget -O mini.sh https://repo.anaconda.com/miniconda/Miniconda3-py37_22.11.1-1-Linux-x86_64.sh
!chmod +x mini.sh
!bash ./mini.sh -b -f -p /usr/local
!conda install -q -y jupyter
!conda install -q -y google-colab -c conda-forge
!python -m ipykernel install --name "py37" --user


#### 2. Then, reload the web page (do not restart the runtime) so that Colab recognizes the newly installed Python.
#### 3. Finally, run the following commands to install TensorFlow 1.15:

In [1]:
!python3 -m pip install tensorflow==1.15
!python3 -m pip install numpy==1.19.5
!python3 -m pip install protobuf==3.20.1


# Configure settings

In [2]:
#@markdown ## General Config
#@markdown If the inference database is large and a long period of continuous runtime is required, it may be preferable to run this notebook on a GCP TPU/runtime. If so, check this box:
GCP_RUNTIME = False #@param {type:"boolean"}
#@markdown How many cores the TPU has (on Colab, NUM_TPU_CORES is 8):
NUM_TPU_CORES = 8 #@param {type:"number"}
#@markdown Which mode to use (each mode corresponds to a different finetuning task). Options are:
#@markdown * "MRPC" - paired sequence method
#@markdown * "MRPC_w_ex_data" - paired sequence method with external data
#@markdown * "RE" - single sequence method
#@markdown * "NER" - single sequence per residue prediction 
#@markdown 
#@markdown You can add more modes by creating a new processor and/or a new model_fn inside the "mutformer_model_code" folder downloaded from GitHub, then changing the corresponding code snippets in the code segment named "Authorize for GCS, Imports, and General Setup" (and editing the dropdown below).
MODE = "MRPC_w_ex_data" #@param   ["MRPC_w_ex_data", "MRPC", "RE", "NER"]   {type:"string"} 
                        ####      ^^^^^ dropdown list for all modes ^^^^^
#@markdown Name of the GCS bucket to use (make sure to set this to the name of your own GCS bucket):
BUCKET_NAME = "" #@param {type:"string"}
BUCKET_PATH = "gs://"+BUCKET_NAME
#@markdown Where the processed data was stored in GCS:
PROCESSED_DATA_DIR = "all_snp_prediction_data_loaded" #@param {type:"string"}
#@markdown Folder to write predictions into (this folder will be located in either GCS or Google Drive; the PREDICTIONS_FOLDER variable can be the same across all finetuning notebooks):
PREDICTIONS_FOLDER = "full_database_prediction" #@param {type:"string"}
#@markdown Folder to write evaluation results into (this folder will be located in either GCS or Google Drive; the EVALUATIONS_FOLDER variable can be the same across all finetuning notebooks):
EVALUATIONS_FOLDER = "" #@param {type:"string"}
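The MODE dropdown above is turned into a processor, a training script, and the USING_EX_DATA flag by an if/elif chain in the "Authorize for GCS, Imports, and General Setup" cell. As a minimal sketch of one alternative, a dictionary lookup keeps that mapping in one place. The processor and script values here are placeholder strings standing in for the real classes and modules in mutformer_model_code, so nothing below imports MutFormer.

```python
# Sketch: map each MODE string to (processor, training script, uses external data).
# The string values are placeholders for the real MutFormer processor classes/modules.
MODE_TABLE = {
    "MRPC": ("MrpcProcessor", "run_classifier", False),
    "MRPC_w_ex_data": ("MrpcWithExDataProcessor", "run_classifier", True),
    "RE": ("REProcessor", "run_classifier", False),
    "NER": ("NERProcessor", "run_ner_for_pathogenic", False),
}


def resolve_mode(mode):
    """Return (processor, script, using_ex_data) for `mode`, or raise ValueError for an unknown mode."""
    try:
        return MODE_TABLE[mode]
    except KeyError:
        raise ValueError(
            f"The mode specified was not one of the available modes: {sorted(MODE_TABLE)}"
        ) from None


processor_name, script_name, using_ex_data = resolve_mode("MRPC_w_ex_data")
print(processor_name, script_name, using_ex_data)  # MrpcWithExDataProcessor run_classifier True
```

With this layout, adding a mode would mean adding one entry to the table (plus the dropdown option); the notebook itself keeps the if/elif chain.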

# If running on a GCP runtime, follow these instructions to set it up

### 1) Create a VM from the GCP website
### 2) Open a command prompt on your computer and perform the following steps:
To SSH into the VM, run:

```
gcloud beta compute ssh --zone <COMPUTE ZONE> <VM NAME> --project <PROJECT NAME> -- -L 8888:localhost:8888
```

Note: make sure the port above matches the port below (in this case, 8888).
\
\
In the new command prompt that opens, either run each of the commands below individually, or copy and paste the one-liner below:
```
sudo apt-get update
sudo apt-get -y install python3 python3-pip
sudo apt-get -y install pkg-config
sudo apt-get -y install libhdf5-serial-dev
sudo apt-get -y install libffi6 libffi-dev
sudo -H pip3 install jupyter tensorflow==1.14 google-api-python-client tqdm
sudo -H pip3 install jupyter_http_over_ws
jupyter serverextension enable --py jupyter_http_over_ws
jupyter notebook   --NotebookApp.allow_origin='https://colab.research.google.com'   --port=8888   --NotebookApp.port_retries=0   --no-browser
```
Or, as a single command:
```
sudo apt-get update ; sudo apt-get -y install python3 python3-pip ; sudo apt-get -y install pkg-config ; sudo apt-get -y install libhdf5-serial-dev ; sudo apt-get -y install libffi6 libffi-dev ; sudo -H pip3 install jupyter tensorflow==1.14 google-api-python-client tqdm ; sudo -H pip3 install jupyter_http_over_ws ; jupyter serverextension enable --py jupyter_http_over_ws ; jupyter notebook   --NotebookApp.allow_origin='https://colab.research.google.com'   --port=8888   --NotebookApp.port_retries=0   --no-browser
```
### 3) In this notebook, click the "connect to local runtime" option under the connect button, and paste in the link printed by the command prompt that starts with "localhost: ..."
### 4) Finally, run the code segment below, which creates a TPU and gives it access to the GCS bucket


In [None]:
GCE_PROJECT_NAME = "" #@param {type:"string"}
TPU_ZONE = "us-central1-f" #@param {type:"string"}
TPU_NAME = "mutformer-tpu" #@param {type:"string"}

!gcloud alpha compute tpus create $TPU_NAME --accelerator-type=tpu-v2 --version=1.15.5 --zone=$TPU_ZONE ##create new TPU

!gsutil iam ch serviceAccount:`gcloud alpha compute tpus describe $TPU_NAME --zone=$TPU_ZONE | grep serviceAccount | cut -d' ' -f2`:admin $BUCKET_PATH && echo 'Successfully set permissions!' ##give TPU access to GCS

# Clone the MutFormer repo

In [3]:
if GCP_RUNTIME:
  !sudo apt-get -y install git
#@markdown Where to clone the repo into:
REPO_DESTINATION_PATH = "mutformer" #@param {type:"string"}
import os,shutil
##start from an empty destination folder so that the clone is always fresh
if os.path.exists(REPO_DESTINATION_PATH):
  shutil.rmtree(REPO_DESTINATION_PATH)
os.makedirs(REPO_DESTINATION_PATH)
cmd = "git clone https://github.com/WGLab/mutformer.git \"" + REPO_DESTINATION_PATH + "\""
!{cmd}

Cloning into 'mutformer'...
remote: Enumerating objects: 1574, done.
remote: Counting objects: 100% (454/454), done.
remote: Compressing objects: 100% (192/192), done.
remote: Total 1574 (delta 313), reused 364 (delta 256), pack-reused 1120
Receiving objects: 100% (1574/1574), 5.93 MiB | 21.22 MiB/s, done.
Resolving deltas: 100% (1102/1102), done.


# Authorize for GCS, Imports, and General Setup

In [4]:
import sys
import json
import random
import logging
import tensorflow as tf
import time
import importlib
import os
import shutil
import re

if not GCP_RUNTIME:
  print("Authorize for GCS:")
  def authenticate_user(): ##authentication function that uses link authentication instead of popup
    if os.path.exists("/content/.config/application_default_credentials.json"): 
      return
    !gcloud auth application-default login  --no-launch-browser
    with tf.Session() as sess:
      with open("/content/.config/application_default_credentials.json", 'r') as f:
        auth_info = json.load(f)
      tf.contrib.cloud.configure_gcs(sess, credentials=auth_info)
  authenticate_user()
  print("Authorize done")
  
if REPO_DESTINATION_PATH == "mutformer":
  if os.path.exists("mutformer_code"):
    shutil.rmtree("mutformer_code")
  shutil.copytree(REPO_DESTINATION_PATH,"mutformer_code")
  REPO_DESTINATION_PATH = "mutformer_code"
if not os.path.exists("mutformer"):
  shutil.copytree(REPO_DESTINATION_PATH+"/mutformer_model_code","mutformer")
else:
  shutil.rmtree("mutformer")
  shutil.copytree(REPO_DESTINATION_PATH+"/mutformer_model_code","mutformer")
if "mutformer" in sys.path:
  sys.path.remove("mutformer")
sys.path.append("mutformer")

from mutformer import modeling, optimization, tokenization,run_classifier,run_ner_for_pathogenic  #### <<<<< if you added more modes, change these imports to import the correct processors, 
from mutformer.modeling import BertModel,BertModelModified                                        #### <<<<< correct training scripts (i.e. run_classifier and run_ner_for_pathogenic), and
from mutformer.run_classifier import MrpcProcessor,REProcessor,MrpcWithExDataProcessor            #### <<<<< correct model classes
from mutformer.run_ner_for_pathogenic import NERProcessor  

##reload modules in case that's needed
modules2reload = [modeling, 
                  optimization, 
                  tokenization,
                  run_classifier,
                  run_ner_for_pathogenic]
for module in modules2reload:
    importlib.reload(module)

# configure logging
log = logging.getLogger('tensorflow')
log.setLevel(logging.INFO)

log.handlers = []

formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')

#@markdown Whether or not to write logs to a file
DO_FILE_LOGGING = True #@param {type:"boolean"}
if DO_FILE_LOGGING:
  #@markdown * If using file logging, what path to write logs to
  FILE_LOGGING_PATH = 'file_logging/spam.log' #@param {type:"string"}
  if not os.path.exists("/".join(FILE_LOGGING_PATH.split("/")[:-1])):
    os.makedirs("/".join(FILE_LOGGING_PATH.split("/")[:-1]))
  fh = logging.FileHandler(FILE_LOGGING_PATH)
  fh.setLevel(logging.INFO)
  fh.setFormatter(formatter)
  log.addHandler(fh)

ch = logging.StreamHandler()
ch.setLevel(logging.INFO)
ch.setFormatter(formatter)
log.addHandler(ch)


if 'COLAB_TPU_ADDR' in os.environ:
  log.info("Using TPU runtime")
  TPU_ADDRESS = 'grpc://' + os.environ['COLAB_TPU_ADDR']

  with tf.Session(TPU_ADDRESS) as session:
    log.info('TPU address is ' + TPU_ADDRESS)
    ##upload credentials to TPU.
    with open(f"/content/.config/application_default_credentials.json", 'r') as f:
      auth_info = json.load(f)
    tf.contrib.cloud.configure_gcs(session, credentials=auth_info)
    
else:
  log.warning('Not connected to TPU runtime')


if MODE=="MRPC":      ####       vvvvv if you added more modes, change this part to set the processors and training scripts correctly vvvvv
  processor = MrpcProcessor()
  script = run_classifier
  USING_EX_DATA = False
elif MODE=="MRPC_w_ex_data":
  processor = MrpcWithExDataProcessor()
  script = run_classifier
  USING_EX_DATA = True
elif MODE=="RE":
  processor = REProcessor()
  script = run_classifier
  USING_EX_DATA = False
elif MODE=="NER":
  processor = NERProcessor()
  script = run_ner_for_pathogenic
  USING_EX_DATA = False
else:
  raise Exception("The mode specified was not one of the available modes: [\"MRPC\",\"MRPC_w_ex_data\",\"RE\",\"NER\"].")
label_list = processor.get_labels()
                      ####       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^


Authorize for GCS:

Authorize done
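The setup cell above routes TensorFlow's log messages to the console and, if DO_FILE_LOGGING is set, to FILE_LOGGING_PATH. It clears `log.handlers` first, which is what keeps re-running the cell from printing every message twice. The sketch below shows the same standard-library pattern in isolation; the logger name `mutformer_demo` and the temporary log path are hypothetical, and unlike the notebook it also closes the handlers it removes.

```python
import logging
import os
import tempfile


def configure_logger(name, log_path=None, level=logging.INFO):
    """Attach a console handler and, optionally, a file handler to the logger `name`.

    Existing handlers are removed first, so calling this again (e.g. re-running a
    notebook cell) does not make every message appear twice.
    """
    log = logging.getLogger(name)
    log.setLevel(level)
    for handler in list(log.handlers):
        log.removeHandler(handler)
        handler.close()
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    if log_path:
        # Create the folder that will hold the log file if it does not exist yet
        parent = os.path.dirname(log_path)
        if parent:
            os.makedirs(parent, exist_ok=True)
        fh = logging.FileHandler(log_path)
        fh.setLevel(level)
        fh.setFormatter(formatter)
        log.addHandler(fh)
    ch = logging.StreamHandler()
    ch.setLevel(level)
    ch.setFormatter(formatter)
    log.addHandler(ch)
    return log


log_path = os.path.join(tempfile.mkdtemp(), "file_logging", "spam.log")
log = configure_logger("mutformer_demo", log_path)
log = configure_logger("mutformer_demo", log_path)  # re-running does not stack handlers
log.info("logger configured")
```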


# Specify location preferences for Google Drive vs. GCS / Mount Drive if needed



In [5]:
#@markdown ###### Note: if using GCP_RUNTIME, all of these parameters must point to GCS, because a GCP TPU cannot access Google Drive
#@markdown \
#@markdown If the original data was stored in Drive and was not split into more than one shard, the full Drive path to the original data (used to detect the number of steps per epoch; this should match the "INPUT_DATA_DIR" variable in the data generation script). This limits interaction with GCS; it can also be left blank, in which case steps are detected automatically from the tfrecords stored in GCS:
#@markdown * If GCP_RUNTIME is enabled, Drive paths will not work, so step detection automatically falls back to the tfrecords
ORIG_DATA_FOLDER = "" #@param {type: "string"}
DRIVE_PATH = "/content/drive/My Drive"
#@markdown Whether to write predictions to GCS (if unchecked, they are written to Drive)
GCS_PREDICTIONS = True #@param {type:"boolean"}
#@markdown Whether to write evaluation results to GCS (if unchecked, they are written to Drive)
GCS_EVAL = True #@param {type:"boolean"}

PREDS_PATH = BUCKET_PATH if GCS_PREDICTIONS else DRIVE_PATH
EVALS_PATH = BUCKET_PATH if GCS_EVAL else DRIVE_PATH

if GCP_RUNTIME:
  FILES_PATH = BUCKET_PATH

if ("/content/drive" in ORIG_DATA_FOLDER and not GCP_RUNTIME) or not GCS_PREDICTIONS or not GCS_EVAL:
  def mount_drive(): ##mount drive function which uses link mounting instead of popup mounting
    if not os.path.exists("/content/drive/MyDrive"):
      os.makedirs("/content/drive/MyDrive")
      !sudo add-apt-repository -y ppa:alessandro-strada/ppa &> /dev/null ##install google-drive-ocamlfuse
      !sudo apt-get update -qq &> /dev/null
      !sudo apt -y install -qq google-drive-ocamlfuse &> /dev/null
    if len(os.listdir("/content/drive/MyDrive")) >0:
      print("Drive already mounted.")
      return

    if not os.path.exists("/content/driveauthlink.txt") or not open("/content/driveauthlink.txt").read(): ##if the auth link has not been generated, generate it
      !google-drive-ocamlfuse &> /content/driveauthlink.txt
    !sudo apt-get install -qq w3m &> /dev/null
    !xdg-settings set default-web-browser w3m.desktop &> /dev/null
    import re
    link = re.findall("https://.+",[x for x in open("/content/driveauthlink.txt").read().split("\n") if x][-1])[0].split("\"")[0]
    print(f"Click this link to authenticate for mounting drive: {link}") ##print auth link
    print("Waiting for valid authentication...")
    error = None
    while True: ##keep retrying the google-drive-ocamlfuse mount until it succeeds (i.e. until the user has authenticated)
      if os.path.exists("/content/drivemounterror.txt"):
        os.remove("/content/drivemounterror.txt")
      !google-drive-ocamlfuse /content/drive/MyDrive 2> "/content/drivemounterror.txt" 1> /dev/null
      if error and open("/content/drivemounterror.txt").read()!=error:
        raise Exception(f"Drive mount failed. Error: \n\n {open('/content/drivemounterror.txt').read()}")
      error = open("/content/drivemounterror.txt").read()
      no_error = not len(error) >0
      if no_error:
        if len(os.listdir("/content/drive/MyDrive")) >0:
          print("Drive mounted successfully!")
        else:
          raise Exception(f"Drive mount failed. Error: Unknown (likely Keyboard Interrupt)")
        break
  mount_drive()
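Whichever source is used for step detection (line counts from the original TSV or record counts from the tfrecord shards), the evaluation loop divides the count by the batch size and truncates. Since the input function is built with `drop_remainder=True`, examples beyond the last full batch are not covered by those steps; the PRECISE_TESTING option handles them through a separate trailing dataset run with batch size 1. A sketch of that arithmetic, assuming (this split is not shown in the notebook) that data generation puts the largest multiple of the batch size in the main file and the remainder in the trailing file:

```python
def steps_for(num_examples, batch_size):
    """Number of full batches; with drop_remainder=True a final partial batch is skipped."""
    return num_examples // batch_size


def split_main_and_trailing(num_examples, batch_size):
    """Assumed layout: the main file holds a multiple of batch_size examples and
    the trailing file holds the rest (predicted one example at a time)."""
    main = steps_for(num_examples, batch_size) * batch_size
    return main, num_examples - main


num_examples = 1300  # hypothetical number of test examples
batch_size = 512     # e.g. MAX_BATCH_SIZE
main, trailing = split_main_and_trailing(num_examples, batch_size)
print(steps_for(num_examples, batch_size), main, trailing)  # 2 1024 276
```

When the dataset is smaller than one batch, the main file ends up empty and every example is in the trailing file, which is the situation the `size() > 0` check before prediction guards against.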
  






# Run Eval/prediction

The following section performs evaluation and/or prediction on either the eval (dev) dataset or the test dataset.
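Which files get evaluated is decided from the listing of DATA_GCS_DIR. For the test set, every `test_<name>.tf_record` file (excluding trailing files) is treated as a separate dataset; for the dev set, or when no named test files exist, the single `dev.tf_record`/`test.tf_record` file is used. With USING_SHARDS, only the `<file>_<index>` shards with START_SHARD <= index < END_SHARD are kept (END_SHARD = -1 means no upper bound). The sketch below reproduces that filename matching in plain Python on a hypothetical listing; it uses `re` on a list of names instead of `tf.io.gfile.listdir`, and it drops duplicate dataset names produced by sharded files.

```python
import re


def find_test_datasets(filenames):
    """Return the <name> of each test_<name>.tf_record file, skipping trailing files.

    Shards such as test_<name>.tf_record_0 share a name, so duplicates are dropped
    (listing order is kept). [None] means "use the single, unnamed file".
    """
    names = []
    for filename in filenames:
        match = re.search(r"test_(\w+)\.tf_record", filename)
        if match and "trailing" not in filename and match.group(1) not in names:
            names.append(match.group(1))
    return names or [None]


def select_shards(filenames, base_name, start_shard, end_shard):
    """Return sorted (index, filename) pairs of <base_name>_<index> shards in [start_shard, end_shard).

    An end_shard of -1 means there is no upper bound.
    """
    pattern = re.compile(re.escape(base_name) + r"_(\d+)")
    shards = []
    for filename in filenames:
        match = pattern.match(filename)
        if match:
            shards.append((int(match.group(1)), filename))
    shards.sort()
    return [
        (index, filename)
        for index, filename in shards
        if index >= start_shard and (end_shard == -1 or index < end_shard)
    ]


# Hypothetical listing of DATA_GCS_DIR
listing = [
    "test_example.tf_record_0",
    "test_example.tf_record_1",
    "test_example.tf_record_2",
    "test_trailing_example.tf_record",
    "dev.tf_record",
]
print(find_test_datasets(listing))                           # ['example']
print(select_shards(listing, "test_example.tf_record", 1, -1))
```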

### General Setup and Definitions
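Two helpers in the cell below, `latest_checkpoint` and `rewrite_ckpt_file2_restore_ckpt`, choose the checkpoint to restore by reading step numbers out of TensorFlow checkpoint filenames of the form `model.ckpt-<step>.index`, `.meta`, or `.data-00000-of-00001`: either the highest step (USE_LATEST) or the one matching CHECKPOINT_STEP. A simplified pure-Python sketch of that selection, on a hypothetical directory listing and without rewriting the `checkpoint` file as the notebook does:

```python
import re

# Matches e.g. "model.ckpt-1000.index" and captures ("model.ckpt-1000", "1000")
CKPT_PATTERN = re.compile(r"^(model\.ckpt-(\d+))\.")


def checkpoint_steps(filenames):
    """Map step -> checkpoint prefix for every model.ckpt-<step>.* file in the listing."""
    steps = {}
    for filename in filenames:
        match = CKPT_PATTERN.match(filename)
        if match:
            steps[int(match.group(2))] = match.group(1)
    return steps


def choose_checkpoint(filenames, use_latest=True, step=None):
    """Return the prefix of the latest checkpoint (or of `step`); None if it is not present."""
    steps = checkpoint_steps(filenames)
    if not steps:
        return None
    if use_latest:
        return steps[max(steps)]
    return steps.get(step)


# Hypothetical contents of a finetuned model folder
listing = [
    "checkpoint",
    "graph.pbtxt",
    "model.ckpt-200.index",
    "model.ckpt-200.meta",
    "model.ckpt-200.data-00000-of-00001",
    "model.ckpt-1000.index",
    "model.ckpt-1000.meta",
    "model.ckpt-1000.data-00000-of-00001",
]
print(choose_checkpoint(listing))                              # model.ckpt-1000
print(choose_checkpoint(listing, use_latest=False, step=200))  # model.ckpt-200
```

Comparing the parsed steps as integers matters here: sorted as strings, `model.ckpt-1000` would come before `model.ckpt-200`.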

In [6]:
#@markdown When performing prediction, whether to ensure that all datapoints are predicted by also running on the trailing test dataset (if so, make sure this option was also set to True during data generation):
PRECISE_TESTING = True #@param {type:"boolean"}
#@markdown Maximum batch size the runtime can handle during prediction without running out of memory, for all models being evaluated/tested. This value should match the "MAX_BATCH_SIZE" variable in the data generation script.
MAX_BATCH_SIZE = 512 #@param {type:"integer"}

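def latest_checkpoint(ckpt_dir):
  ##fallback checkpoint finder: list ckpt_dir with gsutil and return the prefix
  ##(e.g. ".../model.ckpt-1000") of the checkpoint with the highest global step
  cmd = "gsutil ls "+ckpt_dir
  files = !{cmd}
  steps_and_prefixes = []
  for file in files:
    match = re.search(r"(.*model\.ckpt-(\d+))\.", file)
    if match:
      steps_and_prefixes.append((int(match.group(2)), match.group(1)))
  if steps_and_prefixes:
    return max(steps_and_prefixes)[1]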

def write_metrics(metrics,dir):
  tf.logging.info("writing metrics to "+dir)
  if os.path.exists(dir):
    shutil.rmtree(dir)
  os.makedirs(dir)
  gs = metrics["global_step"]
  tf.logging.info("global step "+str(gs))

  tf.compat.v1.disable_eager_execution()
  tf.reset_default_graph()
  for key,value in metrics.items():
    tf.logging.info(str(key)+":"+str(value))
    x_scalar = tf.constant(value)
    first_summary = tf.summary.scalar(name=key, tensor=x_scalar)

    init = tf.global_variables_initializer()

    with tf.Session() as sess:
        writer = tf.summary.FileWriter(dir)
        sess.run(init)
        summary = sess.run(first_summary)
        writer.add_summary(summary, gs)
        writer.flush()
        tf.logging.info("Done with writing the scalar summary")
    time.sleep(1)

  if GCS_EVAL:
    cmd = "gsutil -m cp -r \""+dir+"/.\" \""+EVALS_PATH+"/"+dir+"\""
    !{cmd}  
  else:
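    ##shutil.copytree requires that the destination not exist yet (it creates missing parent folders itself), so remove any older copy first
    if os.path.exists(EVALS_PATH+"/"+dir):
      shutil.rmtree(EVALS_PATH+"/"+dir)
    shutil.copytree(dir,EVALS_PATH+"/"+dir)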
  

def write_predictions(PREDICTIONS_DIR,
                      result,
                      result_trailing,
                      shard_id=""):
  output_dir = PREDS_PATH+"/"+PREDICTIONS_DIR
  if not tf.io.gfile.exists(output_dir): ##tf.io.gfile handles both local/Drive paths and gs:// paths
    tf.io.gfile.makedirs(output_dir)
  with tf.gfile.Open(output_dir+"/predictions"+shard_id+".txt", "w") as writer:
    tf.logging.info("***** Predict results *****")
    ##write the main predictions first, then the trailing-dataset predictions (either may be None)
    for predictions in (result, result_trailing):
      if predictions:
        for prediction in predictions:
          ##one line per datapoint: "key:value" pairs separated by tabs
          output_line = "\t".join([str(k)+":"+str(v) for k,v in prediction.items()]) + "\n"
          writer.write(output_line)


def evaluation_loop(RUN_EVAL,
                    RUN_PREDICTION,
                    EVALUATE_WHILE_PREDICT,
                    test_or_dev,
                    MODEL,
                    total_metrics,
                    MAX_SEQ_LENGTH,
                    current_ORIG_DATA_FOLDER,
                    BERT_GCS_DIR,
                    USE_LATEST,
                    CHECKPOINT_STEP,
                    DATA_GCS_DIR,
                    USING_SHARDS,
                    START_SHARD,
                    END_SHARD,
                    USING_EX_DATA,
                    PRED_NUM,
                    EVAL_WHILE_PREDICT_PREDICTIONS_DIR,
                    PREDICTIONS_DIR,
                    EVALUATIONS_DIR,
                    CONFIG_FILE):

  try: ##wrap everything in a giant try except so that any 
       ##glitches won't completely stop evaluation in the middle
    current_ckpt = ""

    tf.logging.info("Using data from: "+DATA_GCS_DIR)
    tf.logging.info("Loading model from: "+BERT_GCS_DIR)

    
    def steps_getter(input_files):
      tot_sequences = []
      for input_file in input_files:
        tf.logging.info("reading:"+input_file+" for steps")

        d = tf.data.TFRecordDataset(input_file)

        with tf.Session() as sess:
          tot_sequences.append(sess.run(d.reduce(0, lambda x,_: x+1)))

      return tot_sequences

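    ##names of all test_<name>.tf_record datasets; sharded files (test_<name>.tf_record_<i>) share a name, so each name is kept once, in listing order
    test_datasets = list(dict.fromkeys(re.findall(r"test_(\w+)\.tf_record",file)[0]
                                       for file in tf.io.gfile.listdir(DATA_GCS_DIR)
                                       if re.findall(r"test_(\w+)\.tf_record",file) and "trailing" not in file))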
    if not test_datasets or test_or_dev!="test":
      test_datasets = [None]
    for dataset in test_datasets:
      evaluating_file = f"{test_or_dev}_{dataset}.tf_record" if dataset else f"{test_or_dev}.tf_record"
      eval_file = os.path.join(DATA_GCS_DIR, evaluating_file)
      PREDICTIONS_DIR_for_dataset = f"{PREDICTIONS_DIR}/{dataset}"  if dataset else PREDICTIONS_DIR
      EVALUATIONS_DIR_for_dataset = f"{EVALUATIONS_DIR}/{dataset}" if dataset else EVALUATIONS_DIR
      if USING_SHARDS:
        shards_folder = DATA_GCS_DIR
        input_file = os.path.join(DATA_GCS_DIR, evaluating_file)
        file_name = input_file.split("/")[-1]
        all_shards = [[int(re.match(re.escape(file_name)+r"_(\d+)", file).groups()[0]), shards_folder + "/" + file] for file in tf.io.gfile.listdir(shards_folder) if
                  re.match(re.escape(file_name)+r"_\d+", file)]
        all_shards = sorted(all_shards,key=lambda x:x[0])
        shards_and_inds = [[shard_ind,shard] for shard_ind,shard in all_shards if START_SHARD<=shard_ind and ((shard_ind<END_SHARD) if END_SHARD!=-1 else True)]

        shards = [shard for shard_ind,shard in shards_and_inds]
        shard_inds = [shard_ind for shard_ind,shard in shards_and_inds]
      else:
        all_shards = [[0,eval_file]]
        shards = [eval_file]

      if USING_SHARDS:
        tf.logging.info("\nUSING SHARDs:")
        for shard in shards: ##shards were already restricted to [START_SHARD, END_SHARD) above
          tf.logging.info(shard)
        tf.logging.info("\n")

      if RUN_EVAL:
        try:
          if len(shards)>1:
            raise Exception("more than one shard needs detection of steps from tfrecords. Reverting to tfrecord steps detection...")
          if test_or_dev=="dev":
            data_path = DRIVE_PATH+"/"+current_ORIG_DATA_FOLDER+"/dev.tsv"
          else:
            data_path = DRIVE_PATH+"/"+current_ORIG_DATA_FOLDER+(f"/test_{dataset}.tsv" if dataset else "/test.tsv")
          lines = open(data_path).read().split("\n")
          EVAL_STEPSs = [int(len(lines)/EVAL_BATCH_SIZE)]
        except Exception:
          SEQUENCES_PER_EPOCHs = steps_getter(shards)
          EVAL_STEPSs = [int(SEQUENCES_PER_EPOCH/EVAL_BATCH_SIZE) for SEQUENCES_PER_EPOCH in SEQUENCES_PER_EPOCHs]

      
      if EVALUATE_WHILE_PREDICT:
        cmd = "gsutil -m rm -r "+EVAL_WHILE_PREDICT_PREDICTIONS_DIR
        !{cmd}
      def rewrite_ckpt_file2_restore_ckpt():
        if USE_LATEST:
          try:
            latest_ckpt = tf.train.latest_checkpoint(BERT_GCS_DIR).split("/")[-1]
            max_step = max([int(ckpt.split(".")[-2].split("-")[-1]) for ckpt in tf.io.gfile.listdir(BERT_GCS_DIR) if ckpt.startswith("model.ckpt-")]) ##only model.ckpt-<step>.* files carry a step number
            RESTORE_CHECKPOINT = [".".join(ckpt.split(".")[:-1]) 
                                  for ckpt in tf.io.gfile.listdir(BERT_GCS_DIR) 
                                  if len(ckpt.split("."))==3 and str(max_step) == ckpt.split(".")[-2].split("-")[-1]][0]
            old_file_lines = tf.gfile.Open(BERT_GCS_DIR+"/checkpoint").read().split("\n")
            new_file_lines = old_file_lines.copy()
            new_file_lines[0] = new_file_lines[0].replace(latest_ckpt,RESTORE_CHECKPOINT)
            RESTORE_CHECKPOINT = BERT_GCS_DIR+"/"+RESTORE_CHECKPOINT

            tf.gfile.Open(BERT_GCS_DIR+"/checkpoint","w+").write("\n".join(new_file_lines))
            
          except Exception:
            try:
              RESTORE_CHECKPOINT = latest_checkpoint(BERT_GCS_DIR)
            except Exception:
              raise Exception("No checkpoints were found in the given location")
        else:
          try:
            latest_ckpt = tf.train.latest_checkpoint(BERT_GCS_DIR).split("/")[-1]
            RESTORE_CHECKPOINT = [".".join(ckpt.split(".")[:-1]) 
                                  for ckpt in tf.io.gfile.listdir(BERT_GCS_DIR) 
                                  if len(ckpt.split("."))==3 and str(CHECKPOINT_STEP) == ckpt.split(".")[-2].split("-")[-1]][0]
            old_file_lines = tf.gfile.Open(BERT_GCS_DIR+"/checkpoint").read().split("\n")
            new_file_lines = old_file_lines.copy()
            new_file_lines[0] = new_file_lines[0].replace(latest_ckpt,RESTORE_CHECKPOINT)
            RESTORE_CHECKPOINT = BERT_GCS_DIR+"/"+RESTORE_CHECKPOINT

            tf.gfile.Open(BERT_GCS_DIR+"/checkpoint","w+").write("\n".join(new_file_lines))
          except Exception as e:
            tf.logging.info("\n\nCould not find the checkpoint specified. Error:"+str(e)+". Skipping...\n\n")
            return None
        return RESTORE_CHECKPOINT

      RESTORE_CHECKPOINT = rewrite_ckpt_file2_restore_ckpt()
      if RESTORE_CHECKPOINT is None: ##the requested checkpoint was not found; skip this model
        return False,total_metrics,current_ckpt
      current_ckpt=RESTORE_CHECKPOINT
      tf.logging.info("USING CHECKPOINT:"+RESTORE_CHECKPOINT)
        
      config = modeling.BertConfig.from_json_file(CONFIG_FILE)

      model_fn = script.model_fn_builder(
          bert_config=config,
          num_labels=len(label_list),
          init_checkpoint=None,
          restore_checkpoint=RESTORE_CHECKPOINT,
          init_learning_rate=0,
          decay_per_step=0,
          num_warmup_steps=10,
          use_tpu=True,
          use_one_hot_embeddings=True,
          bert=MODEL,
          test_results_dir=EVAL_WHILE_PREDICT_PREDICTIONS_DIR,
          yield_predictions=EVALUATE_WHILE_PREDICT,
          using_ex_data=USING_EX_DATA)

      
      tf.logging.info("USING FILE:"+eval_file)

      def load_stuff(batch_size,file):
          tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(TPU_ADDRESS)
          run_config = tf.contrib.tpu.RunConfig(
            cluster=tpu_cluster_resolver,
            model_dir=BERT_GCS_DIR,
            tpu_config=tf.contrib.tpu.TPUConfig(
                num_shards=min(NUM_TPU_CORES,batch_size),
                per_host_input_for_training=tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2))

          estimator = tf.contrib.tpu.TPUEstimator(
              use_tpu=True,
              model_fn=model_fn,
              config=run_config,
              train_batch_size=1,
              eval_batch_size=batch_size,
              predict_batch_size=batch_size)
          input_fn = script.file_based_input_fn_builder(
              input_file=file,
              seq_length=MAX_SEQ_LENGTH,
              is_training=False,
              drop_remainder=True,
              pred_num=PRED_NUM if USING_EX_DATA else None)
          return estimator, input_fn

      all_eval_metrics = []
      tf.logging.info("***** Running evaluation/prediction *****")
      tf.logging.info(" Eval Batch size = "+str(EVAL_BATCH_SIZE))
      tf.logging.info(" Predict Batch size = "+str(MAX_BATCH_SIZE))
      for n,shard in enumerate(shards):
        tf.logging.info(f"\n\nUSING SHARD: {shard}...\n\n")
        if RUN_EVAL:
          estimator,input_fn = load_stuff(EVAL_BATCH_SIZE,shard)
          if EVAL_STEPSs[n] > 0:
            RESTORE_CHECKPOINT = rewrite_ckpt_file2_restore_ckpt()
            eval_metrics = estimator.evaluate(input_fn=input_fn, steps=EVAL_STEPSs[n])
            all_eval_metrics.append([eval_metrics,EVAL_STEPSs[n]*EVAL_BATCH_SIZE])
          if PRECISE_TESTING and shard == all_shards[-1][1] and test_or_dev=="test":
            trailing_test_file = os.path.join(DATA_GCS_DIR, (f"test_trailing_{dataset}.tf_record" if dataset else "test_trailing.tf_record"))
            if tf.gfile.Open(trailing_test_file).size() > 0:
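              steps = steps_getter([trailing_test_file])[0] ##the trailing file is evaluated with batch size 1, so one step per trailing example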
              estimator_trailing,eval_input_fn_trailing = load_stuff(1,trailing_test_file)
              RESTORE_CHECKPOINT = rewrite_ckpt_file2_restore_ckpt()
              eval_metrics=estimator_trailing.evaluate(input_fn=eval_input_fn_trailing, steps=steps)
              all_eval_metrics.append([eval_metrics,steps])
        if RUN_PREDICTION:
          result, result_trailing = [None, None]
          if tf.gfile.Open(shard).size() > 0: ##sometimes, if test dataset is really small, the test dataset doesn't have data; the data is all in the trailing dataset
            estimator,input_fn = load_stuff(MAX_BATCH_SIZE,shard)
            RESTORE_CHECKPOINT = rewrite_ckpt_file2_restore_ckpt()
            result=estimator.predict(input_fn=input_fn)
          if PRECISE_TESTING and shard == all_shards[-1][1] and test_or_dev=="test":
            trailing_test_file = f"{DATA_GCS_DIR}/{'test_trailing_'+dataset+'.tf_record' if dataset else 'test_trailing.tf_record'}" 
            if tf.gfile.Open(trailing_test_file).size() > 0:
              tf.logging.info(f"\n\nUSING TRAILING DATASET: {trailing_test_file}...\n\n")
              estimator_trailing,test_input_fn_trailing = load_stuff(1,trailing_test_file)
              RESTORE_CHECKPOINT = rewrite_ckpt_file2_restore_ckpt()
              result_trailing=estimator_trailing.predict(input_fn=test_input_fn_trailing)
          write_predictions(PREDICTIONS_DIR_for_dataset,
                            result,
                            result_trailing,
                            shard_id=str(shard_inds[n]) if USING_SHARDS else "")

      if RUN_EVAL:
        combined_metrics = {}
        weight_divisor = sum([v for k,v in all_eval_metrics])
        for metric_set,weight in all_eval_metrics:
          for k,v in metric_set.items():
            ##average of each metric, weighted by the number of examples behind each evaluate() call
            combined_metrics[k] = combined_metrics.get(k,0)+v*weight/weight_divisor
        
        ##write out evaluation metrics data
        write_metrics(combined_metrics,EVALUATIONS_DIR_for_dataset)
        tf.logging.info(f"\n\n\n\n\n\nEVAL METRICS ({EVALUATIONS_DIR_for_dataset}):")
        print(all_eval_metrics)
        for k,v in combined_metrics.items():
          tf.logging.info(k+":"+str(v))
        tf.logging.info("\n\n\n\n\n\n\n")

        if not REPEAT_LOOP:
            total_metrics[EVALUATIONS_DIR_for_dataset] = combined_metrics
    return True,total_metrics,current_ckpt
  except Exception as e:
      tf.logging.info("\n\nFAILED-error:"+str(e)+". Skipping...\n\n")
      return False,total_metrics,current_ckpt
  

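At the end of `evaluation_loop`, the metrics from each shard (and from the trailing dataset) are merged into one dictionary by weighting each shard's metrics by the number of evaluation steps it ran. A standalone sketch of that weighted average follows; `combine_shard_metrics` is a hypothetical helper name and the metric values are made up.

```python
from typing import Dict, List, Tuple


def combine_shard_metrics(shard_metrics: List[Tuple[Dict[str, float], int]]) -> Dict[str, float]:
    """Step-weighted average of per-shard metric dictionaries.

    Each entry is (metrics, steps), mirroring the [eval_metrics, steps]
    pairs collected in all_eval_metrics (hypothetical helper).
    """
    total_steps = sum(steps for _, steps in shard_metrics)
    if total_steps == 0:
        raise ValueError("no evaluation steps were recorded")
    combined: Dict[str, float] = {}
    for metrics, steps in shard_metrics:
        for name, value in metrics.items():
            # each shard contributes in proportion to how many eval steps it ran
            combined[name] = combined.get(name, 0.0) + value * steps / total_steps
    return combined


# made-up numbers: a 3-step shard and a 1-step trailing shard (accuracy is approximately 0.7)
print(combine_shard_metrics([({"accuracy": 0.8}, 3), ({"accuracy": 0.4}, 1)]))
```

Weighting by steps is the same as weighting by examples only when every step processes the same number of examples; if the trailing set is evaluated with batch size 1 while the main shards use a larger batch size, the two weightings differ.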
###Eval/prediction loops

Following are two code segments for running the finetuning evaluation/prediction loops:
1. Model/sequence length: performs evaluation/prediction for the train loop of the same name in the "mutformer_finetuning_benchmark" file
1. Freezing/batch size: performs evaluation/prediction for the train loop of the same name in the "mutformer_finetuning_benchmark" file

Choose the code segment that matches the training loop that was used, enter the desired evaluation/prediction options, and run that code segment.

Note: More evaluation/prediction loops for other tests can be written following the same format as the two example loops below (which iterate over model/sequence length and freezing/batch size, respectively).
\
\
Note: All evaluation results will be written into the previously specified logging directory, either in Google Drive or in GCS depending on the values of GCS_COMS, GCS_PREDICTIONS, and GCS_EVAL specified earlier. To view the results, use the Colab notebook titled "mutformer processing and viewing finetuning..._results", which can also be used to view prediction results.
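Both loops locate each finetuned model through a folder name built from the settings they iterate over: `mn_<model>_sl_<sequence length>` for the Model/Sequence Length loop and `fl_<freezing>_bs_<batch size>` for the Freezing/Batch Size loop, both under `BUCKET_PATH/FINETUNED_MODEL_DIR`. The minimal sketch below enumerates those folders with `itertools.product` in the same order the cells loop over them; the function names and the bucket path are hypothetical, and the layout is assumed to match what the training notebook wrote.

```python
from itertools import product
from typing import List


def subdir_by_length(model_name: str, max_seq_length: int) -> str:
    # Model/Sequence Length loop: mn_<model name>_sl_<max sequence length>
    return f"mn_{model_name}_sl_{max_seq_length}"


def subdir_by_freezing(freezing: int, batch_size: int) -> str:
    # Freezing/Batch Size loop: fl_<freezing>_bs_<batch size>
    return f"fl_{freezing}_bs_{batch_size}"


def length_grid_dirs(bucket_path: str, finetuned_model_dir: str,
                     models: List[str], max_seq_lengths: List[int]) -> List[str]:
    # outer loop over sequence lengths, inner loop over models
    return [f"{bucket_path}/{finetuned_model_dir}/{subdir_by_length(m, sl)}"
            for sl, m in product(max_seq_lengths, models)]


def freezing_grid_dirs(bucket_path: str, finetuned_model_dir: str,
                       freezings: List[int], batch_sizes: List[int]) -> List[str]:
    # outer loop over freezing values, inner loop over batch sizes
    return [f"{bucket_path}/{finetuned_model_dir}/{subdir_by_freezing(f, bs)}"
            for f, bs in product(freezings, batch_sizes)]


# hypothetical bucket and folder names
print(length_grid_dirs("gs://example-bucket", "finetuned", ["MutFormer_em_adap8L"], [512, 1024]))
print(freezing_grid_dirs("gs://example-bucket", "finetuned", [0], [32]))
```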

###Model/Sequence Length

In [None]:
#@markdown ### IO config
#@markdown Folder for where to load the finetuned model from
FINETUNED_MODEL_DIR = "" #@param {type:"string"}
#@markdown Which folder inside of PREDICTIONS_DIR and EVALUATIONS_DIR to write predictions and evaluations, respectively, into:
RUN_NAME = "" #@param {type:"string"}
#@markdown \
#@markdown 
#@markdown 
#@markdown ### Evaluation/prediction procedure config
#@markdown The evaluation loop will loop through a list of models and a list of sequence lengths, attempting to evaluate a finetuned model for each combination of pretrained model and sequence length (failed combinations will be skipped).
#@markdown * List of pretrained models that were used for finetuning (should indicate the names of the model folders inside INIT_MODEL_DIR from the finetuning training script):
MODELS = ["MutFormer_em_adap8L"] #@param
#@markdown * List of model architectures for each model in the "MODELS" list defined in the entry above: each position in this list must correctly indicate the model architecture of its corresponding model folder in the list "MODELS" (BertModel indicates the original BERT, BertModelModified indicates MutFormer's architecture without integrated convs, MutFormer_embedded_convs indicates MutFormer with integrated convolutions).
MODEL_ARCHITECTURES = ["MutFormer_embedded_convs"] #@param
#@markdown * List of maximum sequence lengths of the finetuned models to test
MAX_SEQ_LENGTHS = [1024] #@param
#@markdown Whether to evaluate on the test set or the dev set ("test" or "dev")
dataset = "test" #@param{type:"string"}
#@markdown Whether or not to run evaluation
RUN_EVAL = False #@param {type:"boolean"}
#@markdown Whether or not to run prediction (in a separate loop from evaluation; EVALUATE_WHILE_PREDICT will override this value to False)
RUN_PREDICTION = True #@param {type:"boolean"}
#@markdown Whether or not to repeat this operation in a loop (set to True if evaluating in parallel with training, False otherwise)
#@markdown * If using REPEAT_LOOP, to prevent the script from evaluating every model trained on every combination of model and sequence length every loop, the script will only evaluate models that are currently being trained (i.e., only the model folders that have seen a new latest checkpoint since the script started running).
REPEAT_LOOP = False #@param {type:"boolean"}
#@markdown When using REPEAT_LOOP, how long to wait in between each loop before checking again for updated train progress:
CHECK_MODEL_EVERY_N_SECS =  150#@param {type:"integer"}
#@markdown If evaluating, whether or not to evaluate and predict in the same loop; this is useful when the amount of test data is very small and the time it takes to restart a loop is significant (if enabled, prediction results will be written as tfevent files into GCS and must be viewed using the notebook titled "mutformer processing and viewing finetuning results")
#@markdown 
#@markdown Note: If using EVALUATE_WHILE_PREDICT, prediction results must be read using the previously mentioned Colab notebook; otherwise, predictions will be written directly as text files and will be accessible from Google Drive under the folder specified above
EVALUATE_WHILE_PREDICT =  False #@param {type:"boolean"}
#@markdown What batch size to use during evaluation (larger batch size will increase evaluation speed but may skip more datapoints)
EVAL_BATCH_SIZE = 1 #@param {type:"integer"}
#@markdown Whether or not testing/evaluating data was generated in shards
USING_SHARDS = False #@param {type:"boolean"}
#@markdown * If using shards, set this value to indicate which shard index to start at (default 0 for first shard)
START_SHARD = 0 #@param {type:"integer"}
#@markdown * If using shards, set this value to indicate which shard index to evaluate until (not inclusive) (default -1 for last shard)
END_SHARD = 0 #@param {type:"integer"}
#@markdown Whether to use the latest checkpoint in the folder (set to false if an intermediate checkpoint should be used)
USE_LATEST = False #@param {type:"boolean"}
#@markdown * If not using latest checkpoint, which step's checkpoint to use
CHECKPOINT_STEP =  None#@param {type:"integer"}

total_metrics = {}  ## a dictionary for all metrics to  
                    ## print at the end during testing, 
                    ## not necessary during evaluation   
if dataset=="test":
  evaluating_file = "test.tf_record"
elif dataset=="dev":
  evaluating_file = "eval.tf_record"
else:
  raise Exception("only datasets supported are dev and test")

DATA_INFOS = [["N/A" for MODEL_NAME in MODELS]            ##create an empty 2D list to store all
              for MAX_SEQ_LENGTH in MAX_SEQ_LENGTHS]      ##the data info dictionaries

current_ckpts = [["N/A" for MODEL_NAME in MODELS]
                 for MAX_SEQ_LENGTH in MAX_SEQ_LENGTHS]
for M,MAX_SEQ_LENGTH in enumerate(MAX_SEQ_LENGTHS):
  for m,MODEL_NAME in enumerate(MODELS):
        BERT_GCS_DIR = BUCKET_PATH+"/"+FINETUNED_MODEL_DIR+"/mn_"+MODEL_NAME+"_sl_"+str(MAX_SEQ_LENGTH)
        try:
          current_ckpts[M][m] = tf.train.latest_checkpoint(BERT_GCS_DIR)
        except Exception:
          try:
            current_ckpts[M][m] = latest_checkpoint(BERT_GCS_DIR)
          except Exception:
            raise Exception(f"could not find any checkpoints in the model dir specified: {BERT_GCS_DIR}")

def get_new_ckpts(current_ckpts):
  new_ckpts = []
  for M,MAX_SEQ_LENGTH in enumerate(MAX_SEQ_LENGTHS):
    for m,MODEL_NAME in enumerate(MODELS):
          BERT_GCS_DIR = BUCKET_PATH+"/"+FINETUNED_MODEL_DIR+"/mn_"+MODEL_NAME+"_sl_"+str(MAX_SEQ_LENGTH)
          try:
            current_ckpt = tf.train.latest_checkpoint(BERT_GCS_DIR)
            if current_ckpts[M][m]!=current_ckpt:
              new_ckpts.append([M,m])
          except Exception:
            try:
              current_ckpt = latest_checkpoint(BERT_GCS_DIR)
              if current_ckpts[M][m]!=current_ckpt:
                new_ckpts.append([M,m])
            except Exception:
              raise Exception(f"could not find any checkpoints in the model dir specified: {BERT_GCS_DIR}")
  return new_ckpts

while True:
  sleeping = True   ##to prevent excessive interaction with GCS, 
                    ##if an eval/pred loop fails, the script 
                    ##will wait for a while before trying again

  if REPEAT_LOOP:                             ##if using REPEAT_LOOP, only evaluate on new checkpoints
    new_ckpts = get_new_ckpts(current_ckpts)
    if len(new_ckpts) == 0:
      print("No new checkpoints have been written since script start/last evaluation. Trying again in another",CHECK_MODEL_EVERY_N_SECS,"seconds.")

  for M,MAX_SEQ_LENGTH in enumerate(MAX_SEQ_LENGTHS):
    for m,MODEL_NAME in enumerate(MODELS):

      if REPEAT_LOOP:
        if [M,m] not in new_ckpts:
          continue

      print("\n\n\nMODEL NAME:",MODEL_NAME,
            "\nINPUT MAX SEQ LENGTH:",MAX_SEQ_LENGTH)
      
      MODEL = getattr(modeling, MODEL_ARCHITECTURES[m])
      current_ORIG_DATA_FOLDER= ORIG_DATA_FOLDER+"/"+str(MAX_SEQ_LENGTH)

      BERT_GCS_DIR = BUCKET_PATH+"/"+FINETUNED_MODEL_DIR+"/mn_"+MODEL_NAME+"_sl_"+str(MAX_SEQ_LENGTH)
      DATA_GCS_DIR = BUCKET_PATH+"/"+PROCESSED_DATA_DIR+"/"+str(MAX_SEQ_LENGTH)
          
      EVAL_WHILE_PREDICT_PREDICTIONS_DIR = BUCKET_PATH+"/"+PREDICTIONS_FOLDER+"/"+RUN_NAME+"/mn_"+MODEL_NAME+"_sl_"+str(MAX_SEQ_LENGTH)
      EVALUATIONS_DIR = EVALUATIONS_FOLDER+"/"+RUN_NAME+"/mn_"+MODEL_NAME+"_sl_"+str(MAX_SEQ_LENGTH)
      PREDICTIONS_DIR = PREDICTIONS_FOLDER+"/"+RUN_NAME+"/mn_"+MODEL_NAME+"_sl_"+str(MAX_SEQ_LENGTH)
      CONFIG_FILE = BUCKET_PATH+"/"+FINETUNED_MODEL_DIR+"/mn_"+MODEL_NAME+"_sl_"+str(MAX_SEQ_LENGTH)+"/config.json"
      
      if DATA_INFOS[M][m] == "N/A":
        DATA_INFOS[M][m] = json.load(tf.gfile.Open(DATA_GCS_DIR+"/info.json"))
      
      EX_DATA_NUM = DATA_INFOS[M][m]["ex_data_num"] if USING_EX_DATA else 0


      ##run the evaluation/prediction loop
      success,total_metrics,current_ckpt = \
          evaluation_loop(RUN_EVAL,
                          RUN_PREDICTION,
                          EVALUATE_WHILE_PREDICT,
                          dataset,
                          MODEL,
                          total_metrics,
                          MAX_SEQ_LENGTH,
                          current_ORIG_DATA_FOLDER,
                          BERT_GCS_DIR,
                          USE_LATEST,
                          CHECKPOINT_STEP,
                          DATA_GCS_DIR,
                          USING_SHARDS,
                          START_SHARD,
                          END_SHARD,
                          USING_EX_DATA,
                          EX_DATA_NUM,
                          EVAL_WHILE_PREDICT_PREDICTIONS_DIR,
                          PREDICTIONS_DIR,
                          EVALUATIONS_DIR,
                          CONFIG_FILE)

      current_ckpts[M][m] = current_ckpt
      if success:
        sleeping = False
  if not REPEAT_LOOP:
    break
  time.sleep(CHECK_MODEL_EVERY_N_SECS if sleeping else 0)
if not REPEAT_LOOP:
  for evals_dir,metrics in total_metrics.items():
    print("Printing metrics for:",evals_dir,"\n")
    for key,metric in metrics.items():
      print(key+":",metric)
    print("\n")



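With REPEAT_LOOP enabled, the cell above snapshots each model folder's latest checkpoint at startup and, on each pass, only evaluates folders whose latest checkpoint has changed. The sketch below shows that polling pattern without TensorFlow; `find_updated`, `poll`, and the injected `get_latest`/`evaluate` callables are hypothetical names (in this notebook, `tf.train.latest_checkpoint` would be a natural choice for `get_latest`). Unlike the cell, this sketch sleeps whenever no folder changed, and `max_rounds` exists only so it can terminate in a test.

```python
import time
from typing import Callable, Dict, Iterable, Optional


def find_updated(model_dirs: Iterable[str],
                 last_seen: Dict[str, Optional[str]],
                 get_latest: Callable[[str], Optional[str]]) -> Dict[str, str]:
    """Map each folder whose latest checkpoint changed to that new checkpoint."""
    updated = {}
    for d in model_dirs:
        latest = get_latest(d)
        if latest is not None and latest != last_seen.get(d):
            updated[d] = latest
    return updated


def poll(model_dirs: Iterable[str],
         get_latest: Callable[[str], Optional[str]],
         evaluate: Callable[[str, str], None],
         check_every_secs: float = 150,
         max_rounds: Optional[int] = None) -> Dict[str, Optional[str]]:
    model_dirs = list(model_dirs)
    # snapshot at startup so that only folders still being trained get evaluated
    last_seen = {d: get_latest(d) for d in model_dirs}
    rounds = 0
    while max_rounds is None or rounds < max_rounds:
        rounds += 1
        updated = find_updated(model_dirs, last_seen, get_latest)
        for d, ckpt in updated.items():
            evaluate(d, ckpt)
            last_seen[d] = ckpt  # remember which checkpoint was evaluated
        if not updated:
            print("No new checkpoints; checking again in", check_every_secs, "seconds.")
            time.sleep(check_every_secs)
    return last_seen


# usage with a fake checkpoint source that advances on every query
calls = {"n": 0}


def fake_latest(d: str) -> str:
    calls["n"] += 1
    return f"{d}/model.ckpt-{calls['n']}"


poll(["run_a"], fake_latest, lambda d, c: print("evaluating", c), check_every_secs=0, max_rounds=2)
```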
###Freezing/Batch Size

In [None]:
#@markdown ### IO config
#@markdown Folder for where to load the finetuned model from
FINETUNED_MODEL_DIR = "" #@param {type:"string"}
#@markdown Which folder inside of PREDICTIONS_DIR and EVALUATIONS_DIR to write predictions and evaluations, respectively, into:
RUN_NAME = "" #@param {type:"string"}
#@markdown \
#@markdown 
#@markdown 
#@markdown ### Evaluation/prediction procedure config
#@markdown The evaluation loop will loop through a list of freezing values and a list of batch sizes, attempting to evaluate a finetuned model for each combination of freezing value and batch size (failed combinations will be skipped).
#@markdown * List of layer-freezing values used during finetuning (the "fl_" value in each model folder name):
FREEZINGS = [0] #@param
#@markdown * List of batch sizes used during finetuning (the "bs_" value in each model folder name):
BATCH_SIZES =  [32] #@param
#@markdown All combinations use the same pretrained model and sequence length, specified below:
#@markdown * Model name to use (should indicate the name of a model folder inside the specified INIT_MODEL_DIR)
MODEL_NAME =  "MutFormer_em_adap8L"#@param
#@markdown * Model architecture to use. Must correctly correspond to the model indicated by the model folder specified by the above "MODEL_NAME" parameter (BertModel indicates the original BERT, BertModelModified indicates MutFormer's architecture without integrated convs, MutFormer_embedded_convs indicates MutFormer with integrated convolutions).
MODEL_ARCHITECTURE = "MutFormer_embedded_convs" #@param
#@markdown * Maximum sequence length of the finetuned models and processed data to use
MAX_SEQ_LENGTH = 1024 #@param
#@markdown What dataset to evaluate/predict (either "dev" or "test"):
dataset = "test" #@param{type:"string"}
#@markdown Whether or not to run evaluation
RUN_EVAL = False #@param {type:"boolean"}
#@markdown Whether or not to run prediction (in a separate loop from evaluation; EVALUATE_WHILE_PREDICT will override this value to False)
RUN_PREDICTION = True #@param {type:"boolean"}
#@markdown Whether or not to repeat this operation in a loop (set to True if evaluating in parallel with training, False otherwise)
#@markdown * If using REPEAT_LOOP, to prevent the script from evaluating every model trained on every combination of freezing value and batch size every loop, the script will only evaluate models that are currently being trained (i.e., only the model folders that have seen a new latest checkpoint since the script started running).
REPEAT_LOOP = False #@param {type:"boolean"}
#@markdown When using REPEAT_LOOP, how long to wait in between each loop before checking again for updated train progress:
CHECK_MODEL_EVERY_N_SECS =  150#@param {type:"integer"}
#@markdown If evaluating, whether or not to evaluate and predict in the same loop; this is useful when the amount of test data is very small and the time it takes to restart a loop is significant (if enabled, prediction results will be written as tfevent files into GCS and must be viewed using the notebook titled "mutformer processing and viewing finetuning results")
#@markdown 
#@markdown Note: If using EVALUATE_WHILE_PREDICT, prediction results must be read using the previously mentioned Colab notebook; otherwise, predictions will be written directly as text files and will be accessible from Google Drive under the folder specified above
EVALUATE_WHILE_PREDICT =  False #@param {type:"boolean"}
#@markdown What batch size to use during evaluation (larger batch size will increase evaluation speed but may skip more datapoints)
EVAL_BATCH_SIZE =  2#@param {type:"integer"}
#@markdown Whether or not testing/evaluating data was generated in shards
USING_SHARDS = False #@param {type:"boolean"}
#@markdown * If using shards, set this value to indicate which shard index to start at (default 0 for first shard)
START_SHARD = 0 #@param {type:"integer"}
#@markdown * If using shards, set this value to indicate which shard index to evaluate until (not inclusive) (default -1 for last shard)
END_SHARD = 0 #@param {type:"integer"}
#@markdown Whether to use the latest checkpoint in the folder (set to false if an intermediate checkpoint should be used)
USE_LATEST = False #@param {type:"boolean"}
#@markdown * If not using latest checkpoint, which step's checkpoint to use
CHECKPOINT_STEP =  None#@param {type:"integer"}

total_metrics = {}  ## a dictionary for all metrics to  
                    ## print at the end during testing, 
                    ## not necessary during evaluation   

  

DATA_INFOS = [["N/A" for BATCH_SIZE in BATCH_SIZES]            ##create an empty 2D list to store all
              for FREEZING in FREEZINGS]      ##the data info dictionaries

current_ckpts = [["N/A" for BATCH_SIZE in BATCH_SIZES] for FREEZING in FREEZINGS]
for M,FREEZING in enumerate(FREEZINGS):
    for m,BATCH_SIZE in enumerate(BATCH_SIZES):
        BERT_GCS_DIR = BUCKET_PATH+"/"+FINETUNED_MODEL_DIR+f"/fl_{FREEZING}_bs_{BATCH_SIZE}"
        try:
          current_ckpts[M][m] = tf.train.latest_checkpoint(BERT_GCS_DIR)
        except Exception:
          try:
            current_ckpts[M][m] = latest_checkpoint(BERT_GCS_DIR)
          except Exception:
            pass

def get_new_ckpts(current_ckpts):
  new_ckpts = []
  for M,FREEZING in enumerate(FREEZINGS):
    for m,BATCH_SIZE in enumerate(BATCH_SIZES):
          BERT_GCS_DIR = BUCKET_PATH+"/"+FINETUNED_MODEL_DIR+f"/fl_{FREEZING}_bs_{BATCH_SIZE}"
          try:
            current_ckpt = tf.train.latest_checkpoint(BERT_GCS_DIR)
            if current_ckpts[M][m]!=current_ckpt:
              new_ckpts.append([M,m])
          except Exception:
            try:
              current_ckpt = latest_checkpoint(BERT_GCS_DIR)
              if current_ckpts[M][m]!=current_ckpt:
                new_ckpts.append([M,m])
            except Exception:
              pass
  return new_ckpts

while True:
  sleeping = True   ##to prevent excessive interaction with GCS, 
                    ##if an eval/pred loop fails, the script 
                    ##will wait for a while before trying again

  if REPEAT_LOOP:                             ##if using REPEAT_LOOP, only evaluate on new checkpoints
    new_ckpts = get_new_ckpts(current_ckpts)
    if len(new_ckpts) == 0:
      print("No new checkpoints have been written since script start/last evaluation. Trying again in another",CHECK_MODEL_EVERY_N_SECS,"seconds.")

  for M,FREEZING in enumerate(FREEZINGS):
    for m,BATCH_SIZE in enumerate(BATCH_SIZES):

      if REPEAT_LOOP:
        if [M,m] not in new_ckpts:
          continue

      print("\n\n\nFreezing layers:",FREEZING,
            "\nBATCH SIZE:",BATCH_SIZE)
      
      MODEL = getattr(modeling, MODEL_ARCHITECTURE)
      current_ORIG_DATA_FOLDER= ORIG_DATA_FOLDER+"/"+str(MAX_SEQ_LENGTH)

      BERT_GCS_DIR = BUCKET_PATH+"/"+FINETUNED_MODEL_DIR+f"/fl_{FREEZING}_bs_{BATCH_SIZE}"
      DATA_GCS_DIR = BUCKET_PATH+"/"+PROCESSED_DATA_DIR+"/"+str(MAX_SEQ_LENGTH)
          
      EVAL_WHILE_PREDICT_PREDICTIONS_DIR = BUCKET_PATH+"/"+PREDICTIONS_FOLDER+"/"+RUN_NAME+f"/fl_{FREEZING}_bs_{BATCH_SIZE}"
      EVALUATIONS_DIR = EVALUATIONS_FOLDER+"/"+RUN_NAME+f"/fl_{FREEZING}_bs_{BATCH_SIZE}"
      PREDICTIONS_DIR = PREDICTIONS_FOLDER+"/"+RUN_NAME+f"/fl_{FREEZING}_bs_{BATCH_SIZE}"
      CONFIG_FILE = BUCKET_PATH+"/"+FINETUNED_MODEL_DIR+f"/fl_{FREEZING}_bs_{BATCH_SIZE}"+"/config.json"
      
      if DATA_INFOS[M][m] == "N/A":
        DATA_INFOS[M][m] = json.load(tf.gfile.Open(DATA_GCS_DIR+"/info.json"))
      
      EX_DATA_NUM = DATA_INFOS[M][m]["ex_data_num"] if USING_EX_DATA else 0


      ##run the evaluation/prediction loop
      success,total_metrics,current_ckpt = \
          evaluation_loop(RUN_EVAL,
                          RUN_PREDICTION,
                          EVALUATE_WHILE_PREDICT,
                          dataset,
                          MODEL,
                          total_metrics,
                          MAX_SEQ_LENGTH,
                          current_ORIG_DATA_FOLDER,
                          BERT_GCS_DIR,
                          USE_LATEST,
                          CHECKPOINT_STEP,
                          DATA_GCS_DIR,
                          USING_SHARDS,
                          START_SHARD,
                          END_SHARD,
                          USING_EX_DATA,
                          EX_DATA_NUM,
                          EVAL_WHILE_PREDICT_PREDICTIONS_DIR,
                          PREDICTIONS_DIR,
                          EVALUATIONS_DIR,
                          CONFIG_FILE)

      current_ckpts[M][m] = current_ckpt
      if success:
        sleeping = False
  if not REPEAT_LOOP:
    break
  time.sleep(CHECK_MODEL_EVERY_N_SECS if sleeping else 0)
if not REPEAT_LOOP:
  for evals_dir,metrics in total_metrics.items():
    print("Printing metrics for:",evals_dir,"\n")
    for key,metric in metrics.items():
      print(key+":",metric)
    print("\n")

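The EVAL_BATCH_SIZE notes in both cells say that larger batch sizes may skip more datapoints. A likely reason, assuming this notebook follows the common TPU practice of dropping the final partial batch so every batch has a static shape, is plain integer arithmetic: only `num_examples // batch_size` full batches are run. The trailing tf_record that `evaluation_loop` evaluates with batch size 1 (when PRECISE_TESTING is on) appears to exist to cover those leftovers. The sketch below uses a hypothetical helper name and a made-up example count.

```python
from typing import Tuple


def split_for_fixed_batches(num_examples: int, batch_size: int) -> Tuple[int, int, int]:
    """Return (examples covered by full batches, number of full batches, leftover examples)."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    full_batches = num_examples // batch_size
    covered = full_batches * batch_size
    # examples that would not fill a final batch and would be dropped
    leftover = num_examples - covered
    return covered, full_batches, leftover


# 1000 is a made-up example count
for bs in (1, 2, 32):
    covered, steps, leftover = split_for_fixed_batches(1000, bs)
    print(f"batch size {bs}: {steps} steps cover {covered} examples, {leftover} left over")
```

With 1000 examples and a batch size of 32, 31 full batches cover 992 examples and 8 are left over, which is the kind of remainder a batch-size-1 trailing set can pick up.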

  