#### Local run command
`blaze run -c opt learning/brain/research/babelfish/colab:colab_notebook --define=babelfish_task=multimodal`

In [None]:
import lingvo.compat as tf
import matplotlib.pyplot as plt
import numpy as np
import pprint
import os

from lingvo.core import py_utils
from google3.learning.brain.research.babelfish import tokenizers
from google3.learning.brain.research.babelfish.multimodal.params.experimental import image_text_finetune as gen_params

from google3.perftools.accelerators.xprof.api.colab import xprof

# The tasks below are built as TF1-style graphs and evaluated through an
# explicit tf.Session, so the colab runs in graph (non-eager) mode.
tf.disable_eager_execution()

In [None]:
mdl_it2t = gen_params.FinetuneCNNDMSmall()
mdl_t2t = gen_params.FinetuneCNNDMSmall()

# Relative attention: on for imagetext2text, off for text2text, mirroring the
# per-model use_num_classes_major_weight split below. Each setting in this cell
# is applied once per model.
# NOTE(review): this must match how each checkpoint loaded later was trained;
# confirm against the two training configs.
mdl_it2t.USE_RELATIVE_ATTN = True
mdl_t2t.USE_RELATIVE_ATTN = False

# No dropout: the models are only used for variable creation and decoding here.
mdl_it2t.DROPOUT = 0.0
mdl_t2t.DROPOUT = 0.0

p_it2t = mdl_it2t.Task()
p_t2t = mdl_t2t.Task()

# Note: We use the name as part of var/name scopes, you need to ensure that
# the name here matches for checkpoints to load successfully.

p_it2t.name = 'CNNDM_IT2T'
p_t2t.name = 'CNNDM_T2T'

# imagetext2text:
p_it2t.decoder.shared_emb.softmax.use_num_classes_major_weight = True
p_it2t.encoder.shared_emb.softmax.use_num_classes_major_weight = True

# text2text:
p_t2t.decoder.shared_emb.softmax.use_num_classes_major_weight = False
p_t2t.encoder.shared_emb.softmax.use_num_classes_major_weight = False

# Training input params on the tasks (presumably required by Instantiate();
# TODO confirm). The batches decoded later come from mdl_t2t.Test() instead.
p_it2t.input = mdl_it2t.Train()
p_t2t.input = mdl_t2t.Train()

In [None]:
# Everything in this colab lives in the global default graph; start clean.
tf.reset_default_graph()

# Instantiate both tasks first (it2t, then t2t), then run a default-theta FProp
# on each, in the same order, so that their variables get created.
task_it2t, task_t2t = p_it2t.Instantiate(), p_t2t.Instantiate()
for _task in (task_it2t, task_t2t):
  _ = _task.FPropDefaultTheta()

In [None]:
# Open a session on the default graph and run the initializer for every global
# variable created above; checkpoint values are loaded over them afterwards.
sess = tf.Session()
init_op = tf.global_variables_initializer()
sess.run(init_op)

In [None]:
# Checkpoint loading rules for py_utils.OverrideVarsFromCheckpoints. Each rule
# pairs a regex over variable names in this graph with a format string that
# builds the checkpoint variable name from the captured group. Both checkpoints
# keep their variables under 'CNNDMTask/', while this colab creates them under
# the task names set earlier ('CNNDM_IT2T' / 'CNNDM_T2T').
loading_rules_it2t = [
    (
        "CNNDM_IT2T/(.*/var:0$)",
        "CNNDMTask/%s",
    )
]

loading_rules_t2t = [
    (
        "CNNDM_T2T/(.*/var:0$)",
        "CNNDMTask/%s",
    )
]

ignore_rules = []  # No ignore rules, parse all saved vars.


def ckpts_loading_rules(ckpt_path, loading_rules):
  """Maps `ckpt_path` to its (loading_rules, ignore_rules) pair."""
  return {ckpt_path: (loading_rules, ignore_rules)}


ckpt_path_it2t = '/cns/tp-d/home/runzheyang/brain/rs=6.3/cnndm.imagetext2textlm.small/train/ckpt-00100000'
ckpt_path_t2t = '/cns/tp-d/home/runzheyang/brain/rs=6.3/cnndm.text2textlm.small/train/ckpt-00100000'

# Load each saved checkpoint into the session. tf.global_variables(scope) keeps
# only variables whose names match `scope` from the start, i.e. the ones created
# under that task's name. (tf.all_variables accepts no scope argument.)
py_utils.OverrideVarsFromCheckpoints(
    tf.global_variables(p_it2t.name + "/"),
    ckpts_loading_rules(ckpt_path_it2t, loading_rules_it2t))(sess)
py_utils.OverrideVarsFromCheckpoints(
    tf.global_variables(p_t2t.name + "/"),
    ckpts_loading_rules(ckpt_path_t2t, loading_rules_t2t))(sess)

## Generation

In [None]:
import t5
import tensorflow_datasets as tfds

# 128-example batches for both models; the test input generator is built from
# the text2text params, and both tasks decode the batches it produces.
for mdl in (mdl_it2t, mdl_t2t):
  mdl.TRAIN_BATCH_SIZE = 128
  mdl.EVAL_BATCH_SIZE = 128
input_p = mdl_t2t.Test()

input_gen = input_p.Instantiate()
input_gen.Initialize(sess)

In [None]:
from google3.learning.brain.research.babelfish.multimodal.params.experimental import image_text_baselines as it_params
# Instantiated test input generator whose `_vocabulary` is used below to encode
# norm words and decoded text into vocabulary ids.
mdl_vocab = it_params.Text2TextLMSmall()
input_v = mdl_vocab.Test().Instantiate()

In [None]:
# One preprocessed test batch; its encoder ids/paddings become the `sources`
# NestedMap fed to both tasks' encoders.
input_batch = input_gen.GetPreprocessedInputBatch()
encoder_inputs, encoder_paddings = (
    input_batch.encoder_inputs, input_batch.encoder_paddings)
sources = py_utils.NestedMap(ids=encoder_inputs, paddings=encoder_paddings)
def BeamsearchDec(task, sources, batch=None):
  """Builds the beam-search decode graph of `task` for `sources`.

  Args:
    task: an instantiated task with `encoder`, `decoder` and `theta`.
    sources: NestedMap with encoder `ids` and `paddings`.
    batch: input batch passed to AddExtraDecodingInfo and
      _ProcessBeamSearchDecodeOut. Defaults to the notebook-level
      `input_batch`, so two-argument calls keep their existing behavior.

  Returns:
    The result of task._ProcessBeamSearchDecodeOut (graph tensors, later
    fetched with sess.run and indexed by 'topk_decoded').
  """
  if batch is None:
    batch = input_batch  # Notebook-level batch built in this cell.
  encoder_embeddings = task.encoder.FPropEmbeddings(task.theta.encoder, sources)
  encoder_outputs = task.encoder.FPropTransformerLayers(
      task.theta.encoder, encoder_embeddings)
  encoder_outputs = task.decoder.AddExtraDecodingInfo(encoder_outputs, batch)
  decoded = task.decoder.BeamSearchDecode(encoder_outputs)

  return task._ProcessBeamSearchDecodeOut(batch, decoded)

# Decode graphs for both tasks on the same `sources` (it2t built first), keyed
# by short model names.
decode_outs = py_utils.NestedMap(
    it2t=BeamsearchDec(task_it2t, sources),
    t2t=BeamsearchDec(task_t2t, sources),
)

In [None]:
# Display the (not yet evaluated) decode output tensors; sess.run fetches them.
decode_outs

In [None]:
# Evaluate both decode graphs in a single session call; the fetched result has
# the same nested structure as decode_outs.
test_out = sess.run(fetches=decode_outs)

In [None]:
# Display the fetched decode results for both models.
test_out

In [None]:
id = 3

In [None]:
test_out['it2t']['topk_decoded'][id]

In [None]:
test_out['t2t']['topk_decoded'][id]

In [None]:
from google3.pyglib import gfile
import pandas as pd

# Word-level concreteness norms; later cells use the 'Word' and 'Conc.M'
# columns.
concreteness_path = '/cns/tp-d/home/runzheyang/brain/rs=6.3/data/concreteness.xlsx'
with gfile.Open(concreteness_path, 'rb') as fh:
  concrete_scores = pd.read_excel(fh)

In [None]:
# Display the loaded concreteness table.
concrete_scores

In [None]:
# True for each norm word that encodes to exactly one vocabulary id, so it can
# be matched against single generated tokens later. str() guards against
# non-string cells parsed from the spreadsheet.
bool_wordpiece = [
    len(input_v._vocabulary.encode(str(w))) == 1
    for w in concrete_scores["Word"]
]

In [None]:
# Store the single-wordpiece flag as a column for boolean row selection below.
concrete_scores['is_wordpiece'] = bool_wordpiece

In [None]:
# Display the table with the new is_wordpiece column.
concrete_scores

In [None]:
# 'Conc.M' values (presumably mean concreteness) of the single-wordpiece words.
concrete_scores.loc[concrete_scores["is_wordpiece"], "Conc.M"]

In [None]:
# Vocabulary id of each single-wordpiece norm word. Encodes str(w), the same
# form used to compute 'is_wordpiece', so non-string cells (numbers, NaN)
# behave consistently with that flag instead of failing or mis-encoding.
cr_wid = [
    input_v._vocabulary.encode(str(w))[0]
    for w in concrete_scores["Word"][concrete_scores["is_wordpiece"]]
]

In [None]:
# Convert the ids to a flat 1-D numpy array.
cr_wid = np.ravel(cr_wid)

In [None]:
# Concreteness scores aligned one-to-one with cr_wid (same row selection).
cr_score = concrete_scores.loc[concrete_scores["is_wordpiece"], "Conc.M"].to_numpy()

In [None]:
# Wordpiece id -> concreteness score (later entries win for duplicate ids).
cr_dict = dict(zip(cr_wid, cr_score))

In [None]:
# Re-encode every it2t decoded string (flattened over topk_decoded) into a list
# of vocabulary ids.
it2t_gen_words = list(
    map(input_v._vocabulary.encode, test_out['it2t']['topk_decoded'].flat))

In [None]:
# Per decoded string: concreteness scores of its tokens that are scored
# wordpieces. Membership uses the cr_dict hash lookup instead of a linear scan
# of the cr_wid array per token; both hold exactly the same ids.
cscore_it2t = [[cr_dict[w] for w in s if w in cr_dict] for s in it2t_gen_words]

In [None]:
# Re-encode every t2t decoded string (flattened over topk_decoded) into a list
# of vocabulary ids.
t2t_gen_words = list(
    map(input_v._vocabulary.encode, test_out['t2t']['topk_decoded'].flat))

In [None]:
# Per decoded string: concreteness scores of its tokens that are scored
# wordpieces. Membership uses the cr_dict hash lookup instead of a linear scan
# of the cr_wid array per token; both hold exactly the same ids.
cscore_t2t = [[cr_dict[w] for w in s if w in cr_dict] for s in t2t_gen_words]

In [None]:
def lflatten(t):
  """Flattens one level of nesting: an iterable of iterables into a list.

  Example: lflatten([[1, 2], [3]]) == [1, 2, 3].
  """
  import itertools  # Local import keeps this notebook cell self-contained.

  return list(itertools.chain.from_iterable(t))

In [None]:
import seaborn as sns

# Use seaborn's 'talk' context (scaled-up fonts and lines) for the plot below.
sns.set_context(context='talk')

In [None]:
# Overlay the distributions of token concreteness scores found in the it2t and
# t2t decoded outputs.
# NOTE(review): seaborn deprecated distplot in 0.11; if the available seaborn
# supports it, histplot(..., kde=True, stat="density") is the replacement.
sns.distplot(lflatten(cscore_it2t), color='orange', label='it2t')
sns.distplot(lflatten(cscore_t2t), color='green', label='t2t')

plt.xlabel("Concreteness")
plt.legend()