#### Local run command
`blaze run -c opt learning/brain/research/babelfish/colab:colab_notebook --define=babelfish_task=multimodal`

In [None]:
import lingvo.compat as tf
import matplotlib.pyplot as plt
import numpy as np
import pprint
import os

from lingvo.core import py_utils
from google3.learning.brain.research.babelfish import tokenizers
from google3.learning.brain.research.babelfish.multimodal.params.experimental import image_text_baselines as it_params
from google3.learning.brain.research.babelfish.multimodal.params.experimental import nlu_baselines as nlu_params

# from google3.pyglib import gfiler

from google3.perftools.accelerators.xprof.api.colab import xprof

# The cells below build a TF1-style graph (tf.placeholder, tf.Session), so
# eager execution is disabled before any ops are created.
tf.disable_eager_execution()

## Load IT2T and T2T models

In [None]:
# Param configs for the two models under comparison. Both are built from the
# same nlu_params.QNLIClassification config.
# NOTE(review): `it_params` (image_text_baselines) is imported but unused, and
# the "IT2T" model here is also built from nlu_params; confirm this is intended.
mdl_it2t = nlu_params.QNLIClassification()
mdl_t2t = nlu_params.QNLIClassification()

# Zero dropout for evaluation. Set before calling Task() so the generated task
# params presumably pick it up.
mdl_it2t.DROPOUT_RATE = 0.0
mdl_t2t.DROPOUT_RATE = 0.0

p_it2t = mdl_it2t.Task()
p_t2t = mdl_t2t.Task()

# Note: We use the name as part of var/name scopes, you need to ensure that
# the name here matches for checkpoints to load successfully.

p_it2t.name = 'GLUETask_IT2T'
p_t2t.name = 'GLUETask_T2T'

# The two models differ in softmax weight layout; presumably each setting must
# match how the corresponding checkpoint stored its softmax weights (confirm).
# imagetext2text:
p_it2t.decoder.shared_emb.softmax.use_num_classes_major_weight = True
p_it2t.encoder.shared_emb.softmax.use_num_classes_major_weight = True

# text2text:
p_t2t.decoder.shared_emb.softmax.use_num_classes_major_weight = False
p_t2t.encoder.shared_emb.softmax.use_num_classes_major_weight = False

# Attach input params to each task (the Train() input config is used here).
p_it2t.input = mdl_it2t.Train()
p_t2t.input = mdl_t2t.Train()

In [None]:
# We are going to use the global graph for this entire colab.
# Resetting it first means re-running this cell rebuilds both tasks from an
# empty graph.
tf.reset_default_graph()

# Instantiate the Task.
task_it2t = p_it2t.Instantiate()
task_t2t = p_t2t.Instantiate()

# Create variables by running FProp. The outputs are discarded; the call is
# made only so the variables get created.
_ = task_it2t.FPropDefaultTheta()
_ = task_t2t.FPropDefaultTheta()

In [None]:
# Create a new session and initialize all the variables.
# These are initializer values only; the checkpoint-loading cell below then
# overrides the task variables with trained values.
sess = tf.Session()
sess.run(tf.global_variables_initializer())

In [None]:
# Checkpoint loading rules for py_utils.OverrideVarsFromCheckpoints.
#
# Each rule is a (regex, format) pair. The regex is matched against the
# in-graph variable name, and the captured group is presumably substituted
# into the format string to form the variable's name inside the checkpoint.
# The two tasks live under different scopes ('GLUETask_IT2T' / 'GLUETask_T2T')
# so they can coexist in one graph. The checkpoints store their variables
# under 'GLUETask/'.
loading_rules_it2t = [
    ("GLUETask_IT2T/(.*/var:0$)", "GLUETask/%s"),
]

loading_rules_t2t = [
    ("GLUETask_T2T/(.*/var:0$)", "GLUETask/%s"),
]

ignore_rules = []  # No ignore rules, parse all saved vars.


def ckpts_loading_rules(ckpt_path, loading_rules):
  """Builds the `ckpts_loading_rules` mapping for OverrideVarsFromCheckpoints.

  Args:
    ckpt_path: path of the checkpoint to restore from.
    loading_rules: list of (regex, format) pairs as defined above.

  Returns:
    {ckpt_path: (loading_rules, ignore_rules)}, using the module-level
    `ignore_rules`.
  """
  return {ckpt_path: (loading_rules, ignore_rules)}


ckpt_path_it2t = '/cns/tp-d/home/runzheyang/brain/rs=6.3/qnli.imagetext2textlm.small.fixedtranspose.1m.lr3e-5.lineardecay.dropout01/train/ckpt-00010000'
ckpt_path_t2t = '/cns/tp-d/home/runzheyang/brain/rs=6.3/qnli.text2textlm.small.fixedtranspose.1m.lr3e-5.lineardecay.dropout01/train/ckpt-00010000'

# Load the saved checkpoints into the session. The scope argument ("<name>/",
# matched as a regex prefix of each variable name) restricts each restore to
# that task's own variables, so the two models do not clobber each other.
py_utils.OverrideVarsFromCheckpoints(
    tf.all_variables(p_it2t.name + "/"),
    ckpts_loading_rules(ckpt_path_it2t, loading_rules_it2t))(sess)
py_utils.OverrideVarsFromCheckpoints(
    tf.all_variables(p_t2t.name + "/"),
    ckpts_loading_rules(ckpt_path_t2t, loading_rules_t2t))(sess)

## Load Dataset

In [None]:
import t5
import tensorflow_datasets as tfds

# total 5463 (examples in the QNLI validation split used further below).
# Batch size 1 for the input generator. These attributes are set after the
# tasks above were instantiated, so they presumably only affect input params
# created from here on, i.e. input_p below.
mdl_t2t.TRAIN_BATCH_SIZE = 1
mdl_t2t.EVAL_BATCH_SIZE = 1
input_p = mdl_t2t.Test()

# Later cells use this generator's vocabulary (input_gen._vocabulary) to
# decode token ids back to text.
input_gen = input_p.Instantiate()
input_gen.Initialize(sess)

In [None]:
#@title Sample Tasks/Mixtures examples
task_name = 'glue_qnli_v002' #@param
split = 'validation' #@param
inputs_length = 256  #@param
targets_length = 32  #@param
num_samples =      5#@param

# Look up the T5 task and report how many examples each split has.
# Note: `task` here is the T5 task object; later cells call
# task.num_input_examples("validation") on it.
task = t5.data.TaskRegistry.get(task_name)
for s in task.splits:
    print('%s: %d' % (s, task.num_input_examples(s)))
print()

# Preprocessed dataset for `split`. Sequence lengths are presumably capped at
# inputs_length / targets_length by t5's get_dataset (confirm).
ds = task.get_dataset(split=split, sequence_length={"inputs": inputs_length, "targets": targets_length})
print("A few preprocessed {} examples...".format(split))

def print_example(example):
  """Prints a preprocessed example: its decoded input ids and its target text.

  Decoding uses the module-level `input_gen` vocabulary.
  """
  print('===')
  print('input:')
  token_ids = list(map(int, example['inputs']))
  print(input_gen._vocabulary._decode(token_ids))
  print('output:')
  print(example['targets_pretokenized'])

# Print the first `num_samples` preprocessed examples of the split.
for ex in tfds.as_numpy(ds.take(num_samples)):
  print_example(ex)

In [None]:
def process_ex(ex, seq_len=512):
  """Converts one preprocessed QNLI example into fixed-length model inputs.

  A single 0 id is prepended to the input ids (presumably a start/BOS id —
  confirm against the training input pipeline). The result is then
  right-padded with 0 ids to `seq_len`.

  Args:
    ex: dict with 'inputs' (1-D sequence of token ids) and
      'targets_pretokenized' (bytes label text).
    seq_len: total length of the returned sequences. It must match the shape
      of the feed placeholders (512 in this notebook).

  Returns:
    (ids, paddings, label):
      ids: [1, seq_len] array with the dtype of ex['inputs'].
      paddings: [1, seq_len] float array, 0.0 over the prepended id and the
        real tokens and 1.0 over the padding.
      label: 1 if the target is b'entailment', else 0.

  Raises:
    ValueError: if the example has more than seq_len - 1 tokens.
  """
  inputs = np.asarray(ex['inputs'])
  num_tokens = len(inputs)
  # One slot is taken by the prepended 0 id; the rest is right padding.
  num_pad = seq_len - 1 - num_tokens
  if num_pad < 0:
    raise ValueError('Example has %d tokens; at most %d fit in seq_len=%d.' %
                     (num_tokens, seq_len - 1, seq_len))
  ids = np.pad(inputs, (1, num_pad), 'constant',
               constant_values=(0, 0)).reshape(1, -1)
  paddings = np.pad(np.zeros(num_tokens + 1), (0, num_pad), 'constant',
                    constant_values=(1, 1)).reshape(1, -1)
  label = 1 if ex['targets_pretokenized'] == b'entailment' else 0
  return ids, paddings, label

## Evaluation on the whole validation set

In [None]:
def get_predictions(task, sources):
  """Builds the classifier predictions of `task` for `sources`.

  Runs the encoder (embeddings, then transformer layers), feeds the decoder
  the same ids/paddings as the encoder, and applies the task's classifier
  head to the extracted classifier input. Every layer is called with
  `task.theta`, so the graph reads that task's own variables.

  Args:
    task: an instantiated task exposing `encoder`, `decoder` and the private
      helpers `_extract_classifier_input` / `_apply_classifier`.
    sources: NestedMap with `ids` and `paddings`.

  Returns:
    Whatever `task._apply_classifier` returns (later cells read its "probs").
  """
  # Encoder pass.
  embedded = task.encoder.FPropEmbeddings(task.theta.encoder, sources)
  encoded = task.encoder.FPropTransformerLayers(task.theta.encoder, embedded)

  # Decoder pass, using the encoder inputs as the decoder targets.
  # NOTE(review): confirm this is the intended decoder input for classification.
  decoder_targets = py_utils.NestedMap(
      ids=sources.ids, paddings=sources.paddings)
  decoded = task.decoder.ComputePredictions(
      task.theta.decoder, encoded, decoder_targets)

  # Classifier head.
  clf_input = task._extract_classifier_input(sources.paddings, decoded)
  return task._apply_classifier(task.theta, clf_input)

# Feed placeholders for one example at a time (batch size 1). They are 512
# tokens long to match the arrays produced by process_ex.
feed_ids =  tf.placeholder(tf.int32, shape=[1,512])
feed_paddings = tf.placeholder(tf.float32, shape=[1,512])

# Both models read the same fed inputs. get_predictions builds each model's
# graph with that task's task.theta, i.e. the variables loaded from the
# checkpoints above.
sources = py_utils.NestedMap(ids=feed_ids, paddings=feed_paddings)
predictions_it2t = get_predictions(task_it2t, sources)
predictions_t2t = get_predictions(task_t2t, sources)


# Everything fetched per sess.run: the fed sources (echoed back so the inputs
# can be decoded later) and both models' predictions.
fetches = py_utils.NestedMap(
          {"sources": sources,
           "predictions_it2t":predictions_it2t,
           "predictions_t2t":predictions_t2t
           })

print(fetches)

In [None]:
# Longest tokenized input in the dataset. process_ex needs this to be at most
# 511, so that the prepended id plus the tokens fit in 512 positions.
max([len(ex['inputs']) for ex in tfds.as_numpy(ds.take(task.num_input_examples("validation")))])

In [None]:
# Run both models on each example of `ds` (capped at the validation-set size),
# one at a time. The gold label and the fetched outputs are collected in
# dataset order.
labels = []
test_outputs = []

for ex in tfds.as_numpy(ds.take(task.num_input_examples("validation"))):
  ids, paddings, label = process_ex(ex) 
  labels.append(label)
  test_outputs.append(sess.run(fetches, {feed_ids: ids, feed_paddings: paddings}))

In [None]:
# Validation-set size as reported by the T5 task.
VAL_SIZE = task.num_input_examples("validation")


def _accuracy(outputs, gold_labels, key):
  """Returns the fraction of examples whose argmax prediction equals the label.

  Args:
    outputs: list of fetched sess.run results, one per example.
    gold_labels: list of int labels, aligned with `outputs`.
    key: which prediction to score, e.g. "predictions_it2t".

  Raises:
    ValueError: if `outputs` and `gold_labels` differ in length.
  """
  if len(outputs) != len(gold_labels):
    raise ValueError('Got %d outputs but %d labels.' %
                     (len(outputs), len(gold_labels)))
  predicted = np.array([np.argmax(out[key]["probs"]) for out in outputs])
  return np.mean(predicted == np.asarray(gold_labels))


# Accuracy of the IT2T and T2T models on the evaluated validation examples.
print(_accuracy(test_outputs, labels, "predictions_it2t"))
print(_accuracy(test_outputs, labels, "predictions_t2t"))

In [None]:
def pretty_print_examples(input_str, label, prediction_it2t, prediction_t2t):
  """Prints an example with its gold label and both models' predictions.

  A value equal to 1 is shown as a green POSITIVE; anything else is shown as
  a red NEGATIVE. A blank line follows each example.
  """
  positive = "\x1b[32mPOSITIVE\x1b[0m"
  negative = "\x1b[31mNEGATIVE\x1b[0m"
  print(input_str)
  rows = (("label: ", label),
          ("IT2T prediction: ", prediction_it2t),
          ("T2T prediction: ", prediction_t2t))
  for caption, value in rows:
    print(caption + (positive if value == 1 else negative))
  print()

# Check a few examples from the validation set: print the inputs, among the
# first 100 evaluated examples, on which the IT2T and T2T models disagree.
for i in range(100):
  # probs is presumably [1, num_classes] (batch size 1), so each argmax is a
  # length-1 array and comparing two of them in `if` is well-defined.
  pred_it2t = np.argmax(test_outputs[i]["predictions_it2t"]["probs"], axis=1)
  pred_t2t = np.argmax(test_outputs[i]["predictions_t2t"]["probs"], axis=1)
  if pred_it2t != pred_t2t:
    pretty_print_examples(input_gen._vocabulary._decode(
        [int(ids) for ids in test_outputs[i]["sources"]["ids"][0]]),
        labels[i],
        pred_it2t,
        pred_t2t)

In [None]:
# Decoded input text of every evaluated example, in evaluation order (so
# all_ex[i] corresponds to test_outputs[i] and labels[i]). The ids decoded are
# the padded ids echoed back through `fetches`.
all_ex = [input_gen._vocabulary._decode(
              [int(tok) for tok in out["sources"]["ids"][0]])
          for out in test_outputs]

In [None]:
# Decoded inputs of the evaluated examples that each model gets wrong.
it2t_ex, t2t_ex = [], []
for out, label in zip(test_outputs, labels):
  # Length-1 arrays (batch size 1); truthiness of the comparisons is defined.
  pred_it2t = np.argmax(out["predictions_it2t"]["probs"], axis=1)
  pred_t2t = np.argmax(out["predictions_t2t"]["probs"], axis=1)
  wrong_it2t = pred_it2t != label
  wrong_t2t = pred_t2t != label
  if not (wrong_it2t or wrong_t2t):
    continue  # Skip decoding examples both models classify correctly.
  ex = input_gen._vocabulary._decode(
      [int(tok) for tok in out["sources"]["ids"][0]])
  if wrong_it2t:
    it2t_ex.append(ex)
  if wrong_t2t:
    t2t_ex.append(ex)

In [None]:
# Number of distinct IT2T-failure texts found in all_ex (np.intersect1d returns
# the sorted unique common elements).
len(np.intersect1d(all_ex, it2t_ex))

In [None]:
# Number of evaluated examples the IT2T model gets wrong.
len(it2t_ex)

In [None]:
# The QNLI example inspected in the next cells. It is defined here as well, so
# this cell works without first running the later cell that also defines it.
query = 'qnli question: What branch is independant of the other branches? sentence: The Judiciary is independent of the executive and the legislature.'
# Index of that example in all_ex (and hence in test_outputs / labels).
all_ex.index(query)

In [None]:
# IT2T class probabilities for example 1429.
# NOTE(review): 1429 is presumably the index found by all_ex.index(query); confirm.
test_outputs[1429]["predictions_it2t"]["probs"]

In [None]:
# T2T class probabilities for example 1429.
# NOTE(review): 1429 is presumably the index found by all_ex.index(query); confirm.
test_outputs[1429]["predictions_t2t"]["probs"]

In [None]:
# Whether this example is misclassified by the IT2T and by the T2T model.
query = 'qnli question: What branch is independant of the other branches? sentence: The Judiciary is independent of the executive and the legislature.'
print(query in it2t_ex)
print(query in t2t_ex)

In [None]:
# Number of distinct example texts that both models get wrong.
len(np.intersect1d(it2t_ex, t2t_ex))

In [None]:
# Number of evaluated examples the IT2T model gets wrong.
len(it2t_ex)

In [None]:
# Number of evaluated examples the T2T model gets wrong.
len(t2t_ex)

In [None]:
import json
from google3.pyglib import gfile

# Save the decoded inputs of all evaluated examples as a JSON list.
# NOTE(review): the file name is spelled "validaiton"; it is left as-is because
# other saved files (e.g. sst2_validaiton) appear to use the same spelling.
with gfile.Open('/cns/tp-d/home/runzheyang/brain/rs=6.3/data/qnli_validaiton', 'wt') as fh:  
  json.dump(all_ex, fh)

In [None]:
# Save the decoded inputs that each model misclassifies as JSON lists.
# NOTE(review): "failue" in the file names is presumably a typo; it is kept
# because the paths may already be relied on.
with gfile.Open('/cns/tp-d/home/runzheyang/brain/rs=6.3/data/qnli_it2t_failue', 'wt') as fh:  
  json.dump(it2t_ex, fh)

with gfile.Open('/cns/tp-d/home/runzheyang/brain/rs=6.3/data/qnli_t2t_failue', 'wt') as fh:  
  json.dump(t2t_ex, fh)

In [None]:
# with gfile.Open('/cns/tp-d/home/runzheyang/brain/rs=6.3/data/sst2_validaiton', 'r') as fh:  
#   all_ex_ = json.load(fh)

In [None]:
# Read the whole frequent-words file as a single string (presumably a list of
# common words, one per line; confirm).
with gfile.Open('/cns/tp-d/home/runzheyang/brain/rs=6.3/data/5000-words.txt', 'r') as f:
  freq_words = f.read()

In [None]:
# Rebind input_gen to a Train() input generator. Below it is only used for its
# vocabulary (_encode / _decode), so it is not Initialize()d with the session.
input_p = mdl_t2t.Train()
input_gen = input_p.Instantiate()

In [None]:
# Tokenize the whole frequent-words text in one call. The result is presumably
# the token ids making up those words (not one id per word).
freq_ids = input_gen._vocabulary._encode(freq_words)

In [None]:
# Deduplicate the frequent-word token ids (np.unique also sorts them) and
# convert them to plain Python ints for the vocabulary calls below.
freq_ids = list(map(int, np.unique(freq_ids)))

In [None]:
# Decode the unique frequent-word token ids back to text, as a sanity check.
input_gen._vocabulary._decode(freq_ids)

In [None]:
# Number of distinct token ids among the frequent words.
len(freq_ids)

In [None]:
# Token ids of each input the IT2T model misclassifies.
it2t_fm_ids = [input_gen._vocabulary._encode(ex) for ex in it2t_ex]

In [None]:
# For each IT2T failure, the frequent-word token ids it contains (sorted, unique).
[np.intersect1d(freq_ids, ids) for ids in  it2t_fm_ids]