#### local run command
`blaze run -c opt learning/brain/research/babelfish/colab:colab_notebook --define=babelfish_task=multimodal`

In [None]:
import lingvo.compat as tf
import matplotlib.pyplot as plt
import numpy as np
import pprint
import os

from lingvo.core import py_utils
from google3.learning.brain.research.babelfish import tokenizers
from google3.learning.brain.research.babelfish.multimodal.params.experimental import image_text_baselines as it_params
from google3.learning.brain.research.babelfish.multimodal.params.experimental import nlu_baselines as nlu_params

# from google3.pyglib import gfiler

from google3.perftools.accelerators.xprof.api.colab import xprof

tf.disable_eager_execution()

In [None]:
# Pick the experiment configuration. Swap in the commented-out line to use
# the image+text LM config instead of the QNLI classifier.
# mdl = it_params.ImageText2TextLMSmall()
mdl = nlu_params.QNLIClassification()
p = mdl.Task()

# The task name is used as part of the variable/name scopes, so it has to
# match the name the checkpoint was saved under for loading to succeed.
p.name = 'GLUETask'

# if text2text:
# Apply the same softmax weight-layout setting to the decoder and the encoder.
for stack_params in (p.decoder, p.encoder):
  stack_params.shared_emb.softmax.use_num_classes_major_weight = False

p.input = mdl.Train()

In [None]:
# We are going to use the global graph for this entire colab.
# Resetting it first means re-running this cell starts from an empty graph.
tf.reset_default_graph()

# Instantiate the Task from the params `p` configured in the previous cell.
task = p.Instantiate()

# Create variables by running FProp.
# The result is bound to `_` so the notebook does not echo it.
_ = task.FPropDefaultTheta()

In [None]:
# Create a new session and initialize all the variables.
# `sess` is reused by every later cell; the checkpoint-loading cell below
# is applied to this same session.
sess = tf.Session()
sess.run(tf.global_variables_initializer())

In [None]:
# Setup the checkpoint loading rules for OverrideVarsFromCheckpoints.
# Each rule is a (regexp, format string) pair.
loading_rules = [
    (
        "(.*)",  # Regexp match all variables in the ckpt.
        "%s"     # Format string to use the saved var name as is.
    )
]
ignore_rules = []  # No ignore rules, parse all saved vars.

# Checkpoint to restore. The commented-out path is the image+text variant.
# ckpt_path = '/cns/tp-d/home/runzheyang/brain/rs=6.3/qnli.imagetext2textlm/train/ckpt-00010000'
ckpt_path = '/cns/tp-d/home/runzheyang/brain/rs=6.3/qnli.text2textlm/train/ckpt-00010000'

# Maps checkpoint path -> (loading rules, ignore rules).
ckpts_loading_rules = {
    ckpt_path: (loading_rules, ignore_rules)
}

# Load the saved checkpoint into the session.
# tf.global_variables() replaces the deprecated tf.all_variables() alias and
# returns the same collection of variables.
py_utils.OverrideVarsFromCheckpoints(
    tf.global_variables(), ckpts_loading_rules)(sess)

## Check examples from the training set

In [None]:
# Build the ops that push one training batch through the task's encoder,
# decoder and classifier, and collect the intermediate tensors to fetch.

# Override the batch sizes before the input params are created.
# NOTE(review): presumably mdl.Train() reads TRAIN_BATCH_SIZE - confirm.
mdl.TRAIN_BATCH_SIZE = 128
mdl.EVAL_BATCH_SIZE = 128
input_p = mdl.Train()

input_gen = input_p.Instantiate()
input_batch = input_gen.GetPreprocessedInputBatch()

ids = input_batch.ids
paddings = input_batch.paddings

# encoder: embeddings first, then the transformer layers on top of them.
sources = py_utils.NestedMap(ids=ids, paddings=paddings)
encoder_embeddings = task.encoder.FPropEmbeddings(task.theta.encoder, sources)
encoder_outputs = task.encoder.FPropTransformerLayers(task.theta.encoder, 
                                                      encoder_embeddings)

# decoder: the targets reuse the same ids/paddings as the encoder sources.
targets = py_utils.NestedMap(ids=ids, paddings=paddings)
decoder_outputs = task.decoder.ComputePredictions(task.theta.decoder,
                                                  encoder_outputs, targets)

# Classifier head, called through the task's private helpers.
classifier_input = task._extract_classifier_input(paddings, decoder_outputs)

predictions = task._apply_classifier(task.theta, classifier_input)

# Notice that the calls above are made with task.theta, which ensures that we
# are using the same variables which we have just loaded.
# Everything listed here is evaluated together by a single sess.run(fetches).
fetches = py_utils.NestedMap(
          {"labels": input_batch.labels,
           "sources": sources,
           "encoder_embeddings": encoder_embeddings,
           "encoder_outputs": encoder_outputs,
           "decoder_outputs": decoder_outputs,
           "classifier_input": classifier_input,
           "predictions":predictions})

# Show the structure of what will be fetched.
print(fetches)

In [None]:
# Initialize the training input generator in the shared session before any
# batch is fetched from it.
input_gen.Initialize(sess)

In [None]:
# Evaluate every tensor in `fetches` in one run. Despite its name,
# `test_output` holds a training-set batch here (input_p = mdl.Train()).
test_output = sess.run(fetches)

In [None]:
def pretty_print_examples(input_str, label, prediction):
  """Print one decoded example with its gold label and predicted class.

  The text is broken onto two lines at the first occurrence of 'sentence'.
  When that marker is absent the text is printed unchanged. The previous
  version used the -1 returned by `find` as a slice index, which split off
  the last character of the text.

  Args:
    input_str: Decoded input text.
    label: Gold label. A value equal to 1 is shown as ENTAIL (green),
      anything else as NOT ENTAIL (red).
    prediction: Predicted class, rendered the same way as `label`.
  """
  def _verdict(value):
    # ANSI escape codes: 32 = green, 31 = red, 0 = reset.
    return "\x1b[32mENTAIL\x1b[0m" if value == 1 else "\x1b[31mNOT ENTAIL\x1b[0m"

  head, marker, tail = input_str.partition('sentence')
  if marker:
    print(head + "\n" + marker + tail)
  else:
    print(input_str)
  print("label: " + _verdict(label))
  print("prediction: " + _verdict(prediction))
  print()

In [None]:
# check batch accuracy
# Average over the examples that were actually fetched instead of dividing by
# the configured batch-size constant, so the result stays correct if the
# fetched batch has a different size.
np.mean(test_output["labels"].reshape(-1) ==
        np.argmax(test_output["predictions"]["probs"], axis=1))

In [None]:
# check a few examples from training set
# Compute the predicted classes once rather than on every iteration, and walk
# the fetched batch itself so the loop cannot index past its end.
pred_classes = np.argmax(test_output["predictions"]["probs"], axis=1)
for token_ids, gold, pred in zip(test_output["sources"]["ids"],
                                 test_output["labels"],
                                 pred_classes):
  pretty_print_examples(
      input_gen._vocabulary._decode([int(t) for t in token_ids]), gold, pred)

In [None]:
# Keep only the embedding vectors at positions whose padding value is 0
# (presumably the non-padded tokens - confirm the padding convention).
emb_fetch = test_output["encoder_embeddings"]
emb_out = emb_fetch["input_embs"][emb_fetch["paddings"] == 0]
emb_out.shape

In [None]:
# Keep only the encoder outputs at positions whose padding value is 0
# (presumably the non-padded tokens - confirm the padding convention).
enc_fetch = test_output["encoder_outputs"]
enc_out = enc_fetch["encoded"][enc_fetch["padding"] == 0]
enc_out.shape

In [None]:
# Flatten the decoder output to 2-D, one row per position. The feature width
# is read from the array's last axis instead of being hard-coded to 512, so
# this keeps working if the model dimension changes.
# NOTE(review): unlike emb_out / enc_out, no padding mask is applied here -
# confirm whether padded positions should be excluded.
dec_out = test_output["decoder_outputs"]
dec_out = dec_out.reshape(-1, dec_out.shape[-1])
dec_out.shape

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_context('talk')

In [None]:
def cov(X):
    """Return the covariance of the columns of X, normalised by the row count.

    X is centred by subtracting its mean along axis 0; the result is
    (centred^T . centred) / X.shape[0], i.e. divided by N rather than N - 1.
    """
    centered = X - X.mean(axis=0)
    n_rows = centered.shape[0]
    return centered.T.dot(centered) / n_rows

In [None]:
# Eigen-decompose the covariance of the embeddings and of the encoder outputs.
# A covariance matrix is symmetric, so use np.linalg.eigh: it returns real
# eigenvalues in ascending order, whereas np.linalg.eig guarantees no ordering
# and may return complex values. Reverse to get largest-first, which is what
# a "top k" variance analysis expects; eigenvector columns are reordered to match.
emb_cov = cov(emb_out)
emb_eigvals, emb_eigvecs = np.linalg.eigh(emb_cov)
emb_eigvals, emb_eigvecs = emb_eigvals[::-1], emb_eigvecs[:, ::-1]

enc_cov = cov(enc_out)
enc_eigvals, enc_eigvecs = np.linalg.eigh(enc_cov)
enc_eigvals, enc_eigvecs = enc_eigvals[::-1], enc_eigvecs[:, ::-1]

In [None]:
# Eigen-decompose the covariance of the decoder outputs. The matrix is
# symmetric, so np.linalg.eigh is used (real eigenvalues, ascending order);
# np.linalg.eig guarantees no ordering. Reverse to largest-first and reorder
# the eigenvector columns to match.
dec_cov = cov(dec_out)
dec_eigvals, dec_eigvecs = np.linalg.eigh(dec_cov)
dec_eigvals, dec_eigvecs = dec_eigvals[::-1], dec_eigvecs[:, ::-1]

In [None]:
# check if neural activity lies on a low-dimensional manifold
# For each representation, plot the cumulative fraction of total variance
# captured by its top_k largest covariance eigenvalues.
top_k = 100
for eigvals, curve_label in ((emb_eigvals, "embedding"),
                             (enc_eigvals, "encoder output"),
                             (dec_eigvals, "decoder output")):
  # A general eigen-solver does not guarantee any ordering (and may return
  # complex dtype), so take the real part and sort largest-first here.
  ranked = np.sort(np.real(eigvals))[::-1]
  explained = np.cumsum(ranked[:top_k]) / ranked.sum()
  # The x-axis is sized from this curve itself; the decoder curve previously
  # took its x-axis from the encoder eigenvalues.
  plt.plot(np.arange(len(explained)), explained, label=curve_label)
plt.ylabel("variance explained")
plt.legend()
plt.show()

In [None]:
# Number of distinct token ids in the fetched batch. np.unique flattens its
# input, so no explicit reshape is needed.
int(np.unique(test_output["sources"]["ids"]).size)

## Check examples from the test dataset

In [None]:
# Rebuild the same fetches as for the training set, this time reading from
# the test input (mdl.Test()). The names input_gen, input_batch, fetches, etc.
# are rebound, so later cells operate on the test split.

# Disabled alternative: feed raw text through a placeholder.
# feeds = {
#     'text': tf.placeholder(tf.string, shape=[1,])
# }
# input_batch = py_utils.NestedMap(encoder_inputs=feeds['text'], 
#                                  decoder_inputs=feeds['text'])

# Override the batch sizes before the input params are created.
# NOTE(review): presumably mdl.Test() reads EVAL_BATCH_SIZE - confirm.
mdl.TRAIN_BATCH_SIZE = 64
mdl.EVAL_BATCH_SIZE = 64
input_p = mdl.Test()

input_gen = input_p.Instantiate()
input_batch = input_gen.GetPreprocessedInputBatch()

ids = input_batch.ids
paddings = input_batch.paddings

# encoder: embeddings first, then the transformer layers on top of them.
sources = py_utils.NestedMap(ids=ids, paddings=paddings)
encoder_embeddings = task.encoder.FPropEmbeddings(task.theta.encoder, sources)
encoder_outputs = task.encoder.FPropTransformerLayers(task.theta.encoder, 
                                                      encoder_embeddings)

# decoder: the targets reuse the same ids/paddings as the encoder sources.
targets = py_utils.NestedMap(ids=ids, paddings=paddings)
decoder_outputs = task.decoder.ComputePredictions(task.theta.decoder,
                                                  encoder_outputs, targets)

# Classifier head, called through the task's private helpers.
classifier_input = task._extract_classifier_input(paddings, decoder_outputs)

predictions = task._apply_classifier(task.theta, classifier_input)

# Notice that the calls above are made with task.theta, which ensures that we
# are using the same variables which we have just loaded.
# Everything listed here is evaluated together by a single sess.run(fetches).
fetches = py_utils.NestedMap(
          {"labels": input_batch.labels,
           "sources": sources,
           "encoder_embeddings": encoder_embeddings,
           "encoder_outputs": encoder_outputs,
           "decoder_outputs": decoder_outputs,
           "classifier_input": classifier_input,
           "predictions":predictions})

# Show the structure of what will be fetched.
print(fetches)

In [None]:
# Initialize the test input generator in the shared session, then evaluate
# every tensor in `fetches` on one batch from it.
input_gen.Initialize(sess)
test_output = sess.run(fetches)

In [None]:
# check batch accuracy
# Average over the examples that were actually fetched. The previous version
# divided by mdl.TRAIN_BATCH_SIZE even though this batch comes from the test
# input, which is only correct while that constant equals the real batch size.
np.mean(test_output["labels"].reshape(-1) ==
        np.argmax(test_output["predictions"]["probs"], axis=1))

In [None]:
# check a few examples from test set
# Compute the predicted classes once rather than on every iteration, and walk
# the fetched batch itself instead of range(mdl.TRAIN_BATCH_SIZE), so the loop
# cannot index past the end of the test batch.
pred_classes = np.argmax(test_output["predictions"]["probs"], axis=1)
for token_ids, gold, pred in zip(test_output["sources"]["ids"],
                                 test_output["labels"],
                                 pred_classes):
  pretty_print_examples(
      input_gen._vocabulary._decode([int(t) for t in token_ids]), gold, pred)

In [None]:
# Keep only the embedding vectors at positions whose padding value is 0
# (presumably the non-padded tokens - confirm the padding convention).
emb_fetch = test_output["encoder_embeddings"]
emb_out = emb_fetch["input_embs"][emb_fetch["paddings"] == 0]
emb_out.shape

In [None]:
# Keep only the encoder outputs at positions whose padding value is 0
# (presumably the non-padded tokens - confirm the padding convention).
enc_fetch = test_output["encoder_outputs"]
enc_out = enc_fetch["encoded"][enc_fetch["padding"] == 0]
enc_out.shape

In [None]:
# Flatten the decoder output to 2-D, one row per position. The feature width
# is read from the array's last axis instead of being hard-coded to 512, so
# this keeps working if the model dimension changes.
# NOTE(review): unlike emb_out / enc_out, no padding mask is applied here -
# confirm whether padded positions should be excluded.
dec_out = test_output["decoder_outputs"]
dec_out = dec_out.reshape(-1, dec_out.shape[-1])
dec_out.shape

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_context('talk')

In [None]:
def cov(X):
    """Return the covariance of the columns of X, normalised by the row count.

    X is centred by subtracting its mean along axis 0; the result is
    (centred^T . centred) / X.shape[0], i.e. divided by N rather than N - 1.
    """
    centered = X - X.mean(axis=0)
    n_rows = centered.shape[0]
    return centered.T.dot(centered) / n_rows

In [None]:
# Eigen-decompose the covariance of the embeddings and of the encoder outputs.
# A covariance matrix is symmetric, so use np.linalg.eigh: it returns real
# eigenvalues in ascending order, whereas np.linalg.eig guarantees no ordering
# and may return complex values. Reverse to get largest-first, which is what
# a "top k" variance analysis expects; eigenvector columns are reordered to match.
emb_cov = cov(emb_out)
emb_eigvals, emb_eigvecs = np.linalg.eigh(emb_cov)
emb_eigvals, emb_eigvecs = emb_eigvals[::-1], emb_eigvecs[:, ::-1]

enc_cov = cov(enc_out)
enc_eigvals, enc_eigvecs = np.linalg.eigh(enc_cov)
enc_eigvals, enc_eigvecs = enc_eigvals[::-1], enc_eigvecs[:, ::-1]

In [None]:
# Eigen-decompose the covariance of the decoder outputs. The matrix is
# symmetric, so np.linalg.eigh is used (real eigenvalues, ascending order);
# np.linalg.eig guarantees no ordering. Reverse to largest-first and reorder
# the eigenvector columns to match.
dec_cov = cov(dec_out)
dec_eigvals, dec_eigvecs = np.linalg.eigh(dec_cov)
dec_eigvals, dec_eigvecs = dec_eigvals[::-1], dec_eigvecs[:, ::-1]

In [None]:
# check if neural activity lies on a low-dimensional manifold
# For each representation, plot the cumulative fraction of total variance
# captured by its top_k largest covariance eigenvalues.
top_k = 100
for eigvals, curve_label in ((emb_eigvals, "embedding"),
                             (enc_eigvals, "encoder output"),
                             (dec_eigvals, "decoder output")):
  # A general eigen-solver does not guarantee any ordering (and may return
  # complex dtype), so take the real part and sort largest-first here.
  ranked = np.sort(np.real(eigvals))[::-1]
  explained = np.cumsum(ranked[:top_k]) / ranked.sum()
  # The x-axis is sized from this curve itself; the decoder curve previously
  # took its x-axis from the encoder eigenvalues.
  plt.plot(np.arange(len(explained)), explained, label=curve_label)
plt.ylabel("variance explained")
plt.legend()
plt.show()

In [None]:
# Number of distinct token ids in the fetched batch. np.unique flattens its
# input, so no explicit reshape is needed.
int(np.unique(test_output["sources"]["ids"]).size)