# GCS Auth

In [None]:
print("Setting up GCS access...")
import os
# NOTE(review): presumably switches off Colab's ephemeral-credentials auth flow
# so that authenticate_user() below performs a regular user login — confirm.
# It is set before authenticate_user() is called, presumably on purpose.
os.environ['USE_AUTH_EPHEM'] = '0'
from google.colab import auth
# Authenticate the Colab user; the gs://lance2/... reads later in this notebook
# presumably depend on the resulting credentials.
auth.authenticate_user()

# Set up

In [None]:
# Install gcsfs into the runtime.
# NOTE(review): no visible code imports gcsfs — the gs:// reads below go through
# tf.io.gfile — confirm whether this install is still needed.
!pip install gcsfs
# TF1-compatibility namespace; only `tf.io.gfile` (glob / GFile) is used below.
import tensorflow.compat.v1 as tf

# Helper functions 

In [None]:
def check_model_accuracy(model_index, predictions_path, targets_path):
  """Return the exact-match accuracy, in percent, of a predictions file.

  Line i of the predictions file is compared with line i of the targets file
  after removing all whitespace from both, so spacing differences (e.g. the
  double-space issue) are not counted as mismatches.

  Args:
    model_index: Unused; kept so existing callers keep working.
    predictions_path: Path (or glob pattern) of the predictions file. If the
      pattern matches several files, the first match returned by
      ``tf.io.gfile.glob`` is used.
    targets_path: Path of the targets file, one expected output per line.

  Returns:
    Float percentage (0-100) of predictions that exactly match their target.

  Raises:
    FileNotFoundError: If ``predictions_path`` matches no file.
    ValueError: If the two files have different line counts, or are empty.
  """
  matches = tf.io.gfile.glob(predictions_path)
  if not matches:
    raise FileNotFoundError(f"No prediction file matches {predictions_path!r}")
  prediction_file = matches[0]

  with tf.io.gfile.GFile(targets_path, "r") as f_targets, \
       tf.io.gfile.GFile(prediction_file, "r") as f_pred:
    # Targets are the reference outputs; predictions were made by the model.
    targets = f_targets.readlines()
    predictions = f_pred.readlines()

  # Explicit checks instead of `assert`: asserts vanish under `python -O`, and
  # zip() below would then silently drop the unmatched tail.
  if len(targets) != len(predictions):
    raise ValueError(f"{len(targets)} != {len(predictions)}")
  if not targets:
    raise ValueError(f"Targets file {targets_path!r} is empty")

  # ''.join(s.split()) removes every whitespace run, fixing the double-space issue.
  perfect_predictions = sum(
      1 for target, prediction in zip(targets, predictions)
      if ''.join(target.split()) == ''.join(prediction.split()))
  return float(perfect_predictions * 100.0 / len(targets))

In [None]:
import numpy as np
def _is_improvement(monitor_value, reference_value, delta):
  """Tell whether ``monitor_value`` beats ``reference_value`` by more than ``delta``.

  The sign of ``delta`` is ignored, and the comparison is strict.

  Args:
    monitor_value: The accuracy being checked.
    reference_value: The accuracy it is checked against.
    delta: Minimum margin by which ``monitor_value`` must exceed
      ``reference_value`` to count as an improvement.

  Returns:
    The result of ``np.greater`` (a NumPy boolean for scalar inputs).
  """
  margin = abs(delta)
  lowered = monitor_value - margin
  return np.greater(lowered, reference_value)

In [None]:
def get_best_checkpoint(checkpoints, accuracies, delta=0.01, patience=5):
  """Pick the best checkpoint using patience-based early stopping.

  Checkpoints are scanned in order. The first one seeds the best; a later
  checkpoint becomes the new best only when ``_is_improvement`` says its
  accuracy beats the current best by more than ``delta``. Every new best
  restarts the patience window. After ``patience`` consecutive checkpoints
  without an improvement, 'stopped' is printed and the scan ends.

  Args:
    checkpoints: Checkpoint identifiers (e.g. training steps) in scan order.
    accuracies: Accuracy of each checkpoint, aligned with ``checkpoints``.
      If the lengths differ, the extra items are ignored (``zip`` semantics).
    delta: Minimum margin over the current best that counts as improvement.
    patience: Number of consecutive non-improving checkpoints tolerated.

  Returns:
    ``(best_checkpoint, best_accuracy)``, which is also printed. Both are
    ``None`` when the inputs are empty.
  """
  best_check = None
  best_acc = None
  wait = 0
  for current_check, current_acc in zip(checkpoints, accuracies):
    if wait >= patience:
      print('stopped')
      break
    # Seeding from the first checkpoint (not from 0) guarantees the reported
    # best is always a real checkpoint, even when every accuracy is ~0.
    if best_acc is None or _is_improvement(current_acc, best_acc, delta):
      best_check, best_acc = current_check, current_acc
      wait = 0
    else:
      wait += 1
  print(best_check, best_acc)
  return best_check, best_acc

# Variables and paths

In [None]:
scheduler = "isr" #@param ["polynomial", "constant", "isr", "slanted"]
finetune_task =  "classifier" #@param ['multi-log-injection', 'single-log-injection', 'classifier']
pretrain_task = "masking" #@param ['masking']
multi_task = "one-to-n" #@param ['one-to-n']
representation = "tokens" #@param ['tokens']
split = 'validation_eval' #@param ['validation_eval', 'test_eval']
task = 'classification' #@param ['log_injection', 'classification']

# Only multi-log-injection runs store their outputs one directory deeper,
# under the `multi_task` directory; every other task omits that level.
target_path = (
    f'gs://lance2/finetuned-model/{finetune_task}/{multi_task}/{representation}/{scheduler}/{split}/{task}_targets'
    if finetune_task == 'multi-log-injection'
    else f'gs://lance2/finetuned-model/{finetune_task}/{representation}/{scheduler}/{split}/{task}_targets'
)
print(target_path)


# Compute accuracies and early stopping

In [None]:
# Checkpoint steps to evaluate: 500000, 510000, ..., 600000
# (range's stop value 610000 is exclusive).
checkpoints = list(range(500000, 610000, 10000))

In [None]:
# Score each checkpoint's predictions file against the shared targets file.
# The directory does not depend on the checkpoint, so it is computed once;
# only multi-log-injection runs add the extra `multi_task` directory level.
if finetune_task == 'multi-log-injection':
  predictions_dir = f'gs://lance2/finetuned-model/{finetune_task}/{multi_task}/{representation}/{scheduler}/{split}'
else:
  predictions_dir = f'gs://lance2/finetuned-model/{finetune_task}/{representation}/{scheduler}/{split}'

accuracies = []
for checkpoint in checkpoints:
  prediction_path = f'{predictions_dir}/{task}_{checkpoint}_predictions'
  # Pass this checkpoint (not the whole `checkpoints` list) as the model index.
  accuracies.append(check_model_accuracy(checkpoint, prediction_path, target_path))

In [None]:
# Report the early-stopping choice; the trailing ';' keeps IPython from echoing
# the call's return value as cell output.
get_best_checkpoint(checkpoints=checkpoints, accuracies=accuracies);

In [None]:
# Global best for comparison: the highest accuracy over *all* evaluated
# checkpoints, ignoring early stopping. list.index returns the first match, so
# ties resolve to the earliest checkpoint. (No need to copy the list for max.)
max_acc = max(accuracies)
index = accuracies.index(max_acc)
print(checkpoints[index])