#### local run command
`blaze run -c opt learning/brain/research/babelfish/colab:colab_notebook --define=babelfish_task=multimodal`

In [None]:
import lingvo.compat as tf
import matplotlib.pyplot as plt
import numpy as np
import pprint
import os
import pandas as pd

from lingvo.core import py_utils
from google3.learning.brain.research.babelfish import tokenizers
from google3.learning.brain.research.babelfish.multimodal.params.experimental import image_text_baselines as it_params

# from google3.pyglib import gfiler

from google3.perftools.accelerators.xprof.api.colab import xprof

tf.disable_eager_execution()

In [None]:
import nltk

In [None]:
# Build three model configurations: one image+text model and two text-only
# models created from the same params class (the second is the "twin").
mdl_it2t = it_params.ImageText2TextLMSmall()
mdl_t2t = it_params.Text2TextLMSmall()
mdl_t2t_twin = it_params.Text2TextLMSmall()

# Task params for each model configuration.
p_it2t = mdl_it2t.Task()
p_t2t = mdl_t2t.Task()
p_t2t_twin = mdl_t2t_twin.Task()

# Note: The task name is used as part of the var/name scopes, so the names set
# here must match the checkpoints for them to load successfully.

p_it2t.name = 'ImageText2TextLMTask'
p_t2t.name = 'Text2TextLM'
p_t2t_twin.name = 'Text2TextLM_Twin'
# Attach each model's training input params to its task params.
p_it2t.input = mdl_it2t.Train()
p_t2t.input = mdl_t2t.Train()
p_t2t_twin.input = mdl_t2t_twin.Train()

In [None]:
# We are going to use the global graph for this entire colab.
tf.reset_default_graph()

# Instantiate the Task.
task_it2t = p_it2t.Instantiate()
task_t2t = p_t2t.Instantiate()
task_t2t_twin = p_t2t_twin.Instantiate()

# Create variables by running FProp.
_ = task_it2t.FPropDefaultTheta()
_ = task_t2t.FPropDefaultTheta()
_ = task_t2t_twin.FPropDefaultTheta()

In [None]:
# Create a new session and initialize all the variables.
sess = tf.Session()
sess.run(tf.global_variables_initializer())

In [None]:
# Setup the checkpoint loading rules for OverrideVarsFromCheckpoints.
# Each rule is a (regexp, format string) pair.
loading_rules = [
    (
        "(.*)",  # Regexp match all variables in the ckpt.
        "%s"     # Format string to use the saved var name as is.
    )
]
# Rules for the twin model, whose task name is 'Text2TextLM_Twin'.
# NOTE(review): presumably this maps twin-scoped variables onto the
# 'Text2TextLM/...' names stored in the twin checkpoint; confirm against
# py_utils.OverrideVarsFromCheckpoints.
loading_rules_twin = [
    (
        "Text2TextLM_Twin/(.*/var:0$)",  # Captures the part after the twin scope.
        "Text2TextLM/%s"    # Re-prefixes the captured part with 'Text2TextLM/'.
    )
]

ignore_rules = []  # No ignore rules, parse all saved vars.

# Builds the {checkpoint path: (loading rules, ignore rules)} dict passed to
# OverrideVarsFromCheckpoints below; x is the path and y the loading rules.
ckpts_loading_rules = lambda x, y:{
    x: (y, ignore_rules)
}

# ckpt_path_it2t = '/cns/jn-d/home/ziruiw/brain/rs=6.3/ImageText2TextLM.small.ibz4096.tbz512.PrefixLM.Res50.Trans2.BatchMajor.RelPos.LR5e4.WD1e2/train/ckpt-01000000'
ckpt_path_t2t = '/cns/tp-d/home/runzheyang/brain/rs=6.3/text2textlm.small.fixedtranspose.1/train/ckpt-01000000'
ckpt_path_t2t_twin = '/cns/tp-d/home/runzheyang/brain/rs=6.3/text2textlm.small.fixedtranspose.1.twin/train/ckpt-01000000'
ckpt_path_it2t = '/cns/mb-d/home/yuancao/brain/rs=6.3/mm_it2t_10_0.5/train/ckpt-00759000'


# Load the saved checkpoint into the session.
# One call per model: the variables are selected by passing the task name as
# the scope argument, and the returned callable is applied to `sess`.
py_utils.OverrideVarsFromCheckpoints(
    tf.all_variables(p_it2t.name+"//*"), ckpts_loading_rules(ckpt_path_it2t, loading_rules))(sess)
py_utils.OverrideVarsFromCheckpoints(
    tf.all_variables(p_t2t.name+"//*"), ckpts_loading_rules(ckpt_path_t2t, loading_rules))(sess)
# The twin model uses its own checkpoint together with the twin renaming rules.
py_utils.OverrideVarsFromCheckpoints(
    tf.all_variables(p_t2t_twin.name+"//*"), ckpts_loading_rules(ckpt_path_t2t_twin, loading_rules_twin))(sess)

### Load top 5000 frequent words

In [None]:
from google3.pyglib import gfile
with gfile.Open('/cns/tp-d/home/runzheyang/brain/rs=6.3/data/5000-words.txt', 'r') as f:
  freq_words = f.read()

In [None]:
input_p = mdl_t2t.Train()
input_gen = input_p.Instantiate()

In [None]:
freq_ids = input_gen._vocabulary._encode(freq_words)

In [None]:
freq_ids = np.unique(freq_ids)
freq_ids = [int(i) for i in freq_ids]

In [None]:
input_gen._vocabulary._decode(freq_ids)

In [None]:
len(freq_ids)

In [None]:
# get token embedding (w/o positional embedding), assuming share_emd=True.
# The lookup goes through each task's encoder softmax layer, using that
# layer's own theta, for the ids in `freq_ids`.
it2t_token_embeddings = task_it2t.encoder.softmax.EmbLookup(
    task_it2t.theta.encoder.softmax, freq_ids)

t2t_token_embeddings = task_t2t.encoder.softmax.EmbLookup(
    task_t2t.theta.encoder.softmax, freq_ids)

t2t_twin_token_embeddings = task_t2t_twin.encoder.softmax.EmbLookup(
    task_t2t_twin.theta.encoder.softmax, freq_ids)

# Group the three lookups under named keys so a single sess.run can fetch them.
fetches = py_utils.NestedMap({
    "it2t_emb": it2t_token_embeddings,
    "t2t_emb": t2t_token_embeddings,
    "t2t_twin_emb": t2t_twin_token_embeddings
})

print(fetches)

In [None]:
emb_output = sess.run(fetches)

In [None]:
it2t_emb = emb_output["it2t_emb"]
t2t_emb = emb_output["t2t_emb"]
t2t_twin_emb = emb_output["t2t_twin_emb"]

In [None]:
def cos_similarity(vecs):
  """Pairwise cosine similarity between the rows of `vecs`, minus the identity.

  Args:
    vecs: 2-D numpy array with one vector per row.

  Returns:
    Square array whose (i, j) entry is the cosine similarity of rows i and j.
    The identity matrix is subtracted, so every diagonal (self-similarity)
    entry is lowered by 1.
  """
  row_norms = np.sqrt(np.sum(vecs ** 2, axis=1))
  gram = np.dot(vecs, vecs.T)
  scale = np.outer(row_norms, row_norms)
  identity = np.eye(len(vecs))
  return gram / scale - identity

def k_nn(sim_matrix, k=-1):
  """Column indices of the highest-scoring entries in each row of `sim_matrix`.

  Args:
    sim_matrix: 2-D array of similarity scores.
    k: number of leading columns to keep. The value is used directly as a
      slice bound, so the default of -1 keeps every column except the last.

  Returns:
    Integer array of per-row column indices, ordered from the highest score
    to the lowest.
  """
  ascending = np.argsort(sim_matrix, axis=-1)
  descending = ascending[:, ::-1]
  return descending[:, :k]

In [None]:
# obtain top 10 similar words
sim_matrix_it2t = cos_similarity(it2t_emb)
knn_it2t_10 = k_nn(sim_matrix_it2t, 10)

In [None]:
# obtain top 10 similar words
sim_matrix_t2t = cos_similarity(t2t_emb)
knn_t2t_10 = k_nn(sim_matrix_t2t, 10)

In [None]:
# obtain top 10 similar words
sim_matrix_t2t_twin = cos_similarity(t2t_twin_emb)
knn_t2t_twin_10 = k_nn(sim_matrix_t2t_twin, 10)

In [None]:


# obtain top 5 similar words
knn_it2t_5 = k_nn(sim_matrix_it2t, 5)
knn_t2t_5 = k_nn(sim_matrix_t2t, 5)
knn_t2t_twin_5 = k_nn(sim_matrix_t2t_twin, 5)

In [None]:
knn_it2t_10.shape

In [None]:
knn_t2t_twin_10

### Check nearest words

In [None]:
def check_id(rid, freq_ids, knn):
  """Print the word at row `rid` together with its neighbour words.

  Args:
    rid: row index of the query word in `freq_ids` and `knn`.
    freq_ids: sequence of token ids; the entries of `knn` index into it.
    knn: per-row neighbour indices (e.g. the output of `k_nn`).
  """
  vocab = input_gen._vocabulary

  def to_word(row):
    return vocab._decode([int(freq_ids[row])])

  print("query:", to_word(rid))
  print("similar words:", [to_word(i) for i in knn[rid]])

def check_word(word, freq_ids, knn):
  """Print the nearest neighbours of every token that `word` encodes to.

  Args:
    word: text to look up; encoded with `input_gen._vocabulary._encode`.
    freq_ids: list of token ids; the entries of `knn` index into it.
    knn: per-row neighbour indices (e.g. the output of `k_nn`).

  Raises:
    ValueError: if an encoded token id is not present in `freq_ids`.
  """
  wids = input_gen._vocabulary._encode(word)
  for wid in wids:
    # skip the empty token..
    if wid == 3:
      continue
    try:
      rid = freq_ids.index(wid)
    except ValueError:
      # list.index only reports "N is not in list"; say which word failed.
      raise ValueError(
          "token id %r (from %r) is not among the %d ids in freq_ids"
          % (wid, word, len(freq_ids))) from None
    check_id(rid, freq_ids, knn)

In [None]:
query = 'cat'
check_word(query, freq_ids, knn_it2t_10)
check_word(query, freq_ids, knn_t2t_10)
check_word(query, freq_ids, knn_t2t_twin_10)

In [None]:
query = 'dream'
check_word(query, freq_ids, knn_it2t_10)
check_word(query, freq_ids, knn_t2t_10)
check_word(query, freq_ids, knn_t2t_twin_10)

In [None]:
np.intersect1d(['fantasy', 'imagine', 'passion', 'vision', 'desire', 'delight', 'ambition', 'imagination', 'miracle', 'wish'], 
                ['nightmare', 'vision', 'imagine', 'envision', 'imagination', 'fantasy', 'desire', 'wish', 'joy', 'wake'])

In [None]:
query = 'throw'
check_word(query, freq_ids, knn_it2t_10)
check_word(query, freq_ids, knn_t2t_10)
check_word(query, freq_ids, knn_t2t_twin_10)

In [None]:
query = 'sing'
check_word(query, freq_ids, knn_it2t_10)
check_word(query, freq_ids, knn_t2t_10)
check_word(query, freq_ids, knn_t2t_twin_10)

In [None]:
query = 'compromise'
check_word(query, freq_ids, knn_it2t_10)
check_word(query, freq_ids, knn_t2t_10)
check_word(query, freq_ids, knn_t2t_twin_10)

### Quantitative comparison

In [None]:
def diff_scores(knn1, knn2, k):
  """Per-row overlap between two neighbour tables, as a fraction of `k`.

  Args:
    knn1: first table of neighbour indices, one row per word.
    knn2: second table, indexed with the same row numbers as `knn1`.
    k: divisor applied to every overlap count.

  Returns:
    List with one float per row of `knn1`: the number of distinct indices
    that row i of both tables have in common, divided by `k`.
  """
  scores = []
  for row, neighbours in enumerate(knn1):
    shared = np.intersect1d(neighbours, knn2[row])
    scores.append(len(shared) / k)
  return scores

In [None]:
top_k = 10
it2t_vs_t2t = diff_scores(k_nn(sim_matrix_it2t, top_k), k_nn(sim_matrix_t2t, top_k), top_k)
np.mean(it2t_vs_t2t)

In [None]:
t2t_vs_t2t = diff_scores(k_nn(sim_matrix_t2t, top_k), k_nn(sim_matrix_t2t_twin, top_k), top_k)
np.mean(t2t_vs_t2t)

In [None]:
import seaborn as sns
sns.set_context('talk')

# Helper: step-style histogram (no KDE) of `x` drawn in colour `c` and
# registered in the legend under label `l`.
distplot = lambda x, c, l: sns.distplot(x, kde=False,
                            hist_kws={"histtype": "step", "linewidth": 3,
                            "alpha": 0.8, "color": c, "label": l})

# Overlay the two per-word score distributions on the same axes.
distplot(it2t_vs_t2t, 'g', 'it2t vs t2t')
distplot(t2t_vs_t2t, 'orange', 't2t vs t2t')
plt.xlabel("top 10 nearest neighbor coherence")
plt.legend(loc='upper left')

In [None]:
def id2word(rid):
  """Decode the token stored at position `rid` of the global `freq_ids`."""
  token_id = int(freq_ids[rid])
  return input_gen._vocabulary._decode([token_id])

In [None]:
# most dissimilar words (it2t vs t2t)
np.vectorize(id2word)(np.argsort(it2t_vs_t2t))[:100]

In [None]:
# most similar words (it2t vs t2t)
np.vectorize(id2word)(np.argsort(it2t_vs_t2t))[-30:]

In [None]:
query = 'abstract'
check_word(query, freq_ids, knn_it2t_10)
check_word(query, freq_ids, knn_t2t_10)
check_word(query, freq_ids, knn_t2t_twin_10)

In [None]:
# most dissimilar words (t2t vs t2t)
np.vectorize(id2word)(np.argsort(t2t_vs_t2t))[:30]

In [None]:
from scipy import stats
def r2(x, y):
    """Return the squared Pearson correlation coefficient of `x` and `y`."""
    return stats.pearsonr(x, y)[0] ** 2

# Scatter plot with a fitted regression line of the two per-word scores.
# x/y are passed by keyword: seaborn >= 0.12 makes them keyword-only and
# rejects the positional form.
sns.regplot(x=t2t_vs_t2t, y=it2t_vs_t2t)
plt.xlabel("t2t vs t2t")
plt.ylabel("it2t vs t2t")

In [None]:
from scipy import stats
def r2(x, y):
    """Return the squared Pearson correlation coefficient of `x` and `y`."""
    pearson_r = stats.pearsonr(x, y)[0]
    return pearson_r ** 2

r2(it2t_vs_t2t, t2t_vs_t2t)

In [None]:
hard_words = np.arange(len(freq_ids))[(np.array(it2t_vs_t2t) < 0.3) & (np.array(t2t_vs_t2t) < 0.3)]

In [None]:
np.vectorize(id2word)(hard_words)

In [None]:
easy_words = np.arange(len(freq_ids))[(np.array(it2t_vs_t2t) > 0.8) & (np.array(t2t_vs_t2t) > 0.8)]

In [None]:
np.vectorize(id2word)(easy_words)

In [None]:
query = 'decrease'
check_word(query, freq_ids, knn_it2t_10)
check_word(query, freq_ids, knn_t2t_10)
check_word(query, freq_ids, knn_t2t_twin_10)

In [None]:
query = 'traditional'
check_word(query, freq_ids, knn_it2t_10)
check_word(query, freq_ids, knn_t2t_10)
check_word(query, freq_ids, knn_t2t_twin_10)

## Compare Concreteness Scores

In [None]:
from google3.pyglib import gfile

with gfile.Open('/cns/tp-d/home/runzheyang/brain/rs=6.3/data/concreteness.xlsx', 'rb') as fh:  
  concrete_scores = pd.read_excel(fh)

In [None]:
concrete_scores

In [None]:
# Flag every word that the vocabulary encodes to exactly one id.
# NOTE(review): str() presumably guards against non-string cells in the
# spreadsheet's "Word" column; confirm against the data.
bool_wordpiece = [
    len(input_gen._vocabulary.encode(str(w))) == 1
    for w in concrete_scores["Word"]
]

In [None]:
concrete_scores['is_wordpiece'] = bool_wordpiece

In [None]:
concrete_scores

In [None]:
# Encoded ids of the words flagged as single wordpieces. str() mirrors the
# conversion used when `is_wordpiece` was computed, so any non-string cell is
# encoded the same way here instead of being passed to encode() as-is.
cr_wid = [
    input_gen._vocabulary.encode(str(w))
    for w in concrete_scores["Word"][concrete_scores["is_wordpiece"]]
]

In [None]:
cr_wid = np.array(cr_wid).reshape(-1)

In [None]:
len(np.intersect1d(cr_wid, freq_ids))

In [None]:
len(np.union1d(cr_wid, freq_ids))

In [None]:
len(np.unique(cr_wid))

## REDO ALL PREVIOUS ANALYSIS w/ MORE TOKENS

In [None]:
freq_ids = np.intersect1d(cr_wid, freq_ids)

In [None]:
freq_ids = list(freq_ids)

In [None]:
len(freq_ids)

In [None]:
# get token embedding (w/o positional embedding), assuming share_emd=True.
# Same lookups as earlier in the notebook, rebuilt because `freq_ids` has been
# replaced by its intersection with the concreteness word ids.
it2t_token_embeddings = task_it2t.encoder.softmax.EmbLookup(
    task_it2t.theta.encoder.softmax, freq_ids)

t2t_token_embeddings = task_t2t.encoder.softmax.EmbLookup(
    task_t2t.theta.encoder.softmax, freq_ids)

t2t_twin_token_embeddings = task_t2t_twin.encoder.softmax.EmbLookup(
    task_t2t_twin.theta.encoder.softmax, freq_ids)

# Group the three lookups under named keys so a single sess.run can fetch them.
fetches = py_utils.NestedMap({
    "it2t_emb": it2t_token_embeddings,
    "t2t_emb": t2t_token_embeddings,
    "t2t_twin_emb": t2t_twin_token_embeddings
})

print(fetches)

In [None]:
emb_output = sess.run(fetches)

In [None]:
it2t_emb = emb_output["it2t_emb"]
t2t_emb = emb_output["t2t_emb"]
t2t_twin_emb = emb_output["t2t_twin_emb"]

In [None]:
def cos_similarity(vecs):
  """Pairwise cosine similarity between the rows of `vecs`, minus the identity.

  Args:
    vecs: 2-D numpy array with one vector per row.

  Returns:
    Square array whose (i, j) entry is the cosine similarity of rows i and j.
    The identity matrix is subtracted, so every diagonal (self-similarity)
    entry is lowered by 1.
  """
  row_norms = np.sqrt(np.sum(vecs ** 2, axis=1))
  gram = np.dot(vecs, vecs.T)
  scale = np.outer(row_norms, row_norms)
  identity = np.eye(len(vecs))
  return gram / scale - identity

def k_nn(sim_matrix, k=-1):
  """Column indices of the highest-scoring entries in each row of `sim_matrix`.

  Args:
    sim_matrix: 2-D array of similarity scores.
    k: number of leading columns to keep. The value is used directly as a
      slice bound, so the default of -1 keeps every column except the last.

  Returns:
    Integer array of per-row column indices, ordered from the highest score
    to the lowest.
  """
  ascending = np.argsort(sim_matrix, axis=-1)
  descending = ascending[:, ::-1]
  return descending[:, :k]

In [None]:
# obtain top 10 similar words
sim_matrix_it2t = cos_similarity(it2t_emb)
knn_it2t_10 = k_nn(sim_matrix_it2t, 10)

In [None]:
# obtain top 10 similar words
sim_matrix_t2t = cos_similarity(t2t_emb)
knn_t2t_10 = k_nn(sim_matrix_t2t, 10)

In [None]:
# obtain top 10 similar words
sim_matrix_t2t_twin = cos_similarity(t2t_twin_emb)
knn_t2t_twin_10 = k_nn(sim_matrix_t2t_twin, 10)

In [None]:
# obtain top 5 similar words
knn_it2t_5 = k_nn(sim_matrix_it2t, 5)
knn_t2t_5 = k_nn(sim_matrix_t2t, 5)
knn_t2t_twin_5 = k_nn(sim_matrix_t2t_twin, 5)

In [None]:
knn_it2t_10.shape

In [None]:
knn_t2t_twin_10

### Check nearest words

In [None]:
def check_id(rid, freq_ids, knn):
  """Print the word at row `rid` together with its neighbour words.

  Args:
    rid: row index of the query word in `freq_ids` and `knn`.
    freq_ids: sequence of token ids; the entries of `knn` index into it.
    knn: per-row neighbour indices (e.g. the output of `k_nn`).
  """
  vocab = input_gen._vocabulary

  def to_word(row):
    return vocab._decode([int(freq_ids[row])])

  print("query:", to_word(rid))
  print("similar words:", [to_word(i) for i in knn[rid]])

def check_word(word, freq_ids, knn):
  """Print the nearest neighbours of every token that `word` encodes to.

  Args:
    word: text to look up; encoded with `input_gen._vocabulary._encode`.
    freq_ids: list of token ids; the entries of `knn` index into it.
    knn: per-row neighbour indices (e.g. the output of `k_nn`).

  Raises:
    ValueError: if an encoded token id is not present in `freq_ids`.
  """
  wids = input_gen._vocabulary._encode(word)
  for wid in wids:
    # skip the empty token..
    if wid == 3:
      continue
    try:
      rid = freq_ids.index(wid)
    except ValueError:
      # list.index only reports "N is not in list"; say which word failed.
      raise ValueError(
          "token id %r (from %r) is not among the %d ids in freq_ids"
          % (wid, word, len(freq_ids))) from None
    check_id(rid, freq_ids, knn)

### Quantitative comparison

In [None]:
def diff_scores(knn1, knn2, k):
  """Per-row overlap between two neighbour tables, as a fraction of `k`.

  Args:
    knn1: first table of neighbour indices, one row per word.
    knn2: second table, indexed with the same row numbers as `knn1`.
    k: divisor applied to every overlap count.

  Returns:
    List with one float per row of `knn1`: the number of distinct indices
    that row i of both tables have in common, divided by `k`.
  """
  scores = []
  for row, neighbours in enumerate(knn1):
    shared = np.intersect1d(neighbours, knn2[row])
    scores.append(len(shared) / k)
  return scores

In [None]:
# Neighbourhood size used for the scores in this section.
top_k = 100
# Overlap of the it2t neighbourhoods with the first t2t model ...
it2t_vs_t2t = np.array(diff_scores(k_nn(sim_matrix_it2t, top_k), k_nn(sim_matrix_t2t, top_k), top_k))
# ... plus the overlap with the twin t2t model ...
it2t_vs_t2t += np.array(diff_scores(k_nn(sim_matrix_it2t, top_k), k_nn(sim_matrix_t2t_twin, top_k), top_k))
# ... averaged over the two comparisons.
it2t_vs_t2t = it2t_vs_t2t/2
np.mean(it2t_vs_t2t)

In [None]:
t2t_vs_t2t = diff_scores(k_nn(sim_matrix_t2t, top_k), k_nn(sim_matrix_t2t_twin, top_k), top_k)
np.mean(t2t_vs_t2t)

In [None]:
import seaborn as sns
sns.set_context('talk')

# Helper: histogram plus KDE curve of `x` in colour `c`; the KDE line carries
# the legend label `l`.
distplot = lambda x, c, l: sns.distplot(x, kde=True,
                            kde_kws={"linewidth": 3,
                            "alpha": 0.8, "color": c, "label": l},
                            hist_kws={"histtype": "step", "linewidth": 3,
                            "alpha": 0.5, "color": c})

# Overlay the two per-word score distributions on the same axes.
distplot(it2t_vs_t2t, 'g', 'it2t vs t2t')
distplot(t2t_vs_t2t, 'orange', 't2t vs t2t')
# The scores in this section are computed with `top_k` (100), so derive the
# label from it rather than hard-coding "top 10".
plt.xlabel("top %d nearest neighbor coherence" % top_k)
plt.legend(loc='upper left')

In [None]:
top_ks = [1, 5, 10, 20, 50, 100, 200, 500]
it2t_vs_t2t_k = {k:diff_scores(k_nn(sim_matrix_it2t, k), k_nn(sim_matrix_t2t, k), k) for k in top_ks}

In [None]:
[np.mean(it2t_vs_t2t_k[k]) for k in top_ks]

In [None]:
plt.plot(top_ks, [np.mean(it2t_vs_t2t_k[k]) for k in top_ks])
plt.xlabel("Number of neighbors (k)")
plt.ylabel("Avg. coherence score")
plt.show()

In [None]:
def id2word(rid):
  """Decode the token stored at position `rid` of the global `freq_ids`."""
  token_id = int(freq_ids[rid])
  return input_gen._vocabulary._decode([token_id])

In [None]:
query = 'abstract'
check_word(query, freq_ids, knn_it2t_10)
check_word(query, freq_ids, knn_t2t_10)
check_word(query, freq_ids, knn_t2t_twin_10)

In [None]:
# most dissimilar words (t2t vs t2t)
np.vectorize(id2word)(np.argsort(t2t_vs_t2t))[:30]

In [None]:
from scipy import stats
def r2(x, y):
    """Return the squared Pearson correlation coefficient of `x` and `y`."""
    return stats.pearsonr(x, y)[0] ** 2

# Scatter plot with a fitted regression line of the two per-word scores.
# x/y are passed by keyword: seaborn >= 0.12 makes them keyword-only and
# rejects the positional form.
sns.regplot(x=t2t_vs_t2t, y=it2t_vs_t2t, color='green', scatter_kws={'alpha':0.1})
plt.xlabel("t2t vs t2t coherence score")
plt.ylabel("it2t vs t2t coherence score")

In [None]:
from scipy import stats
def r2(x, y):
    """Return the squared Pearson correlation coefficient of `x` and `y`."""
    pearson_r = stats.pearsonr(x, y)[0]
    return pearson_r ** 2

np.sqrt(r2(it2t_vs_t2t, t2t_vs_t2t))

In [None]:
concrete_scores["C"]

In [None]:
conc_m = concrete_scores["Conc.M"][concrete_scores["is_wordpiece"]]

In [None]:
len(conc_m)

In [None]:
len(freq_ids)

In [None]:
# Coherence score of every concreteness word that is also in `freq_ids`, in
# `cr_wid` order. The id -> row mapping is built once so each lookup is a dict
# access instead of a linear `list.index` scan.
_row_of_id = {}
for _row, _token_id in enumerate(freq_ids):
  _row_of_id.setdefault(_token_id, _row)  # first occurrence, like list.index
cr_it2t_vs_t2t = [it2t_vs_t2t[_row_of_id[ids]] for ids in cr_wid if ids in _row_of_id]

In [None]:
# Keep only the concreteness scores whose token id is in `freq_ids`; a set
# makes each membership test O(1) instead of a scan of the list.
_known_ids = set(freq_ids)
conc_m = conc_m[[(w in _known_ids) for w in cr_wid]]

In [None]:
# nltk.download('averaged_perceptron_tagger')
# nltk.data.path.append('/usr/local/google/home/runzheyang/nltk_data')
def get_pos(x):
  """Return the NLTK part-of-speech tag of the first token of decoded id `x`."""
  text = input_gen._vocabulary.decode([int(x)])
  tagged = nltk.pos_tag(nltk.word_tokenize(text))
  return tagged[0][1]

In [None]:
# POS tag of every concreteness word that is also in `freq_ids`, in `cr_wid`
# order; a set makes each membership test O(1) instead of a scan of the list.
_freq_id_set = set(freq_ids)
POS = [get_pos(ids) for ids in cr_wid if ids in _freq_id_set]

In [None]:
np.unique(POS, return_counts=True)

In [None]:
sns.histplot(POS)
plt.xticks(rotation=70)
plt.show()

In [None]:
cr_it2t_vs_t2t = np.array(cr_it2t_vs_t2t)
conc_m = np.array(conc_m)
POS = np.array(POS)

In [None]:
def is_in(x, y):
  """Return a list of booleans: for each element of `x`, whether it is in `y`."""
  return [element in y for element in x]

In [None]:
# All words first, then the verb-tagged subset overlaid in red.
# x/y are passed by keyword: seaborn >= 0.12 makes them keyword-only and
# rejects the positional form.
sns.regplot(x=conc_m, y=cr_it2t_vs_t2t, scatter_kws={'alpha':0.1})
pos_set = ['VB', 'VBD', 'VBG', 'VBN']
_in_pos_set = is_in(POS, pos_set)  # same mask for both axes
sns.regplot(x=conc_m[_in_pos_set], y=cr_it2t_vs_t2t[_in_pos_set], color='red', scatter_kws={'alpha':0.1})
plt.ylabel("it2t vs t2t coherence score")
plt.xlabel("concreteness")

In [None]:
# All words first, then the noun-tagged subset overlaid in orange.
# x/y are passed by keyword: seaborn >= 0.12 makes them keyword-only and
# rejects the positional form.
sns.regplot(x=conc_m, y=cr_it2t_vs_t2t, scatter_kws={'alpha':0.1})
pos_set = ['NN', 'NNS']
_in_pos_set = is_in(POS, pos_set)  # same mask for both axes
sns.regplot(x=conc_m[_in_pos_set], y=cr_it2t_vs_t2t[_in_pos_set], color='orange', scatter_kws={'alpha':0.1})
plt.ylabel("it2t vs t2t coherence score")
plt.xlabel("concreteness")

In [None]:
# All words first, then the adjective-tagged subset overlaid in pink.
# x/y are passed by keyword: seaborn >= 0.12 makes them keyword-only and
# rejects the positional form.
sns.regplot(x=conc_m, y=cr_it2t_vs_t2t, scatter_kws={'alpha':0.1})
pos_set = ['JJ', 'JJR', 'JJS']
_in_pos_set = is_in(POS, pos_set)  # same mask for both axes
sns.regplot(x=conc_m[_in_pos_set], y=cr_it2t_vs_t2t[_in_pos_set], color='pink', scatter_kws={'alpha':0.1})
plt.ylabel("it2t vs t2t coherence score")
plt.xlabel("concreteness")

In [None]:
# Concreteness against coherence score for all words.
# x/y are passed by keyword: seaborn >= 0.12 makes them keyword-only and
# rejects the positional form.
sns.regplot(x=conc_m, y=cr_it2t_vs_t2t, scatter_kws={'alpha':0.1})
plt.ylabel("it2t vs t2t coherence score")
plt.xlabel("concreteness")

In [None]:
r2(cr_it2t_vs_t2t[POS=='NN'], conc_m[POS=='NN'])

In [None]:
r2(cr_it2t_vs_t2t, conc_m)

In [None]:
def _cr_aligned_ids():
  """Token ids in the order of `cr_it2t_vs_t2t` / `conc_m`.

  Both arrays are built by walking `cr_wid` and keeping the ids found in
  `freq_ids`, so the same walk is repeated here. Indexing `freq_ids` itself
  with their masks would pair the scores with the wrong words, because
  `freq_ids` is in a different order.
  """
  known_ids = set(freq_ids)
  return np.array([ids for ids in cr_wid if ids in known_ids])


def _decode_selected(mask):
  """Decode the words selected by boolean `mask` over the aligned id list."""
  return [input_gen._vocabulary.decode([int(ids)]) for ids in _cr_aligned_ids()[mask]]


# Each helper takes a coherence-score threshold `x` and a concreteness
# threshold `y` and returns the decoded words on the named side of both.
check_up_right = lambda x, y: _decode_selected((np.array(cr_it2t_vs_t2t) > x) & (np.array(conc_m) > y))
check_up_left = lambda x, y: _decode_selected((np.array(cr_it2t_vs_t2t) < x) & (np.array(conc_m) > y))
check_bt_right = lambda x, y: _decode_selected((np.array(cr_it2t_vs_t2t) > x) & (np.array(conc_m) < y))
check_bt_left = lambda x, y: _decode_selected((np.array(cr_it2t_vs_t2t) < x) & (np.array(conc_m) < y))

In [None]:
len(cr_wid)

In [None]:
len(freq_ids)

In [None]:
np.array(check_up_right(0.65, 4))

In [None]:
np.array(check_up_left(0.25, 4))

In [None]:
np.array(check_bt_right(0.5, 2))

In [None]:
np.array(check_bt_left(0.2, 2))

In [None]:
query = 'sodium'
check_word(query, freq_ids, knn_it2t_10)
check_word(query, freq_ids, knn_t2t_10)
check_word(query, freq_ids, knn_t2t_twin_10)

In [None]:
from google3.learning.brain.research.babelfish.multimodal import datasets

In [None]:
datasets.open_image_text_train()

In [None]:
1