#### Local run command
`blaze run -c opt learning/brain/research/babelfish/colab:colab_notebook --define=babelfish_task=multimodal`

In [None]:
import lingvo.compat as tf
import matplotlib.pyplot as plt
import numpy as np
import pprint
import os

from lingvo.core import py_utils
from google3.learning.brain.research.babelfish import tokenizers
from google3.learning.brain.research.babelfish.multimodal.params.experimental import nlu_baselines as nlu_params

# from google3.pyglib import gfiler

from google3.perftools.accelerators.xprof.api.colab import xprof

# The cells below build a TF1-style graph and evaluate it through an explicit
# tf.Session, so eager execution is turned off up front.
tf.disable_eager_execution()

In [None]:
# Experiment config for SST-2 classification with the "mixed" model. Dropout
# is set to 0 because this notebook only evaluates the model.
mdl_mixed = nlu_params.SST2ClassificationMixed()

mdl_mixed.DROPOUT_RATE = 0.0

p_mixed = mdl_mixed.Task()

# Note: We use the name as part of var/name scopes, you need to ensure that
# the name here matches for checkpoints to load successfully.

# NOTE(review): presumably these flags must match the softmax weight layout
# the checkpoint was written with — confirm against the training config.
p_mixed.encoder_ex.shared_emb_ex.softmax.use_num_classes_major_weight = True
p_mixed.decoder_ex.shared_emb_ex.softmax.use_num_classes_major_weight = True

p_mixed.name = 'MixedFinetune'

# Input params are taken from the experiment's training input config.
p_mixed.input = mdl_mixed.Train()

In [None]:
# We are going to use the global graph for this entire colab.
tf.reset_default_graph()

# Instantiate the Task.
task_mixed = p_mixed.Instantiate()

# Create variables by running FProp.
_ = task_mixed.FPropDefaultTheta()

In [None]:
# Create a new session and initialize all the variables.
# (The initial values are overwritten from the checkpoint in the next cell.)
sess = tf.Session()
sess.run(tf.global_variables_initializer())

In [None]:
# Setup the checkpoint loading rules for OverrideVarsFromCheckpoints.
# Each rule is (regex over model variable names, format string that builds the
# checkpoint variable name from the regex groups).
loading_rules_mixed = [
    (
        "MixedFinetune/(.*/var:0$)",
        "MixedFinetune/%s"
    )
]

ignore_rules = []  # No ignore rules, parse all saved vars.


def ckpts_loading_rules(ckpt_path, loading_rules):
  """Maps one checkpoint path to its (loading_rules, ignore_rules) pair.

  This is the dict passed to py_utils.OverrideVarsFromCheckpoints below. The
  module-level `ignore_rules` list is read at call time.
  """
  return {ckpt_path: (loading_rules, ignore_rules)}


ckpt_path_mixed = '/cns/tp-d/home/runzheyang/brain/rs=6.3/sst2.mixed.small.encdec_it2t/train/ckpt-00010000'

# Load the saved checkpoint into the session.
# NOTE(review): the "//*" scope pattern is kept verbatim; if it is applied as a
# prefix regex it selects every variable under p_mixed.name.
py_utils.OverrideVarsFromCheckpoints(
    tf.all_variables(p_mixed.name+"//*"),
    ckpts_loading_rules(ckpt_path_mixed, loading_rules_mixed))(sess)

### Load top 5000 frequent words and Task examples

In [None]:
from google3.pyglib import gfile
# The whole word list is read as one string; it is tokenized in a single call
# further down.
with gfile.Open('/cns/tp-d/home/runzheyang/brain/rs=6.3/data/5000-words.txt', 'r') as f:
  freq_words = f.read()

import json
# SST-2 validation examples and the examples each model fails on. The paths are
# read verbatim ("validaiton", "failue"), so their spelling must match the
# stored files. Each loaded object is iterated as a collection of sentences.
with gfile.Open('/cns/tp-d/home/runzheyang/brain/rs=6.3/data/sst2_validaiton', 'r') as fh:  
  all_ex = json.load(fh)
with gfile.Open('/cns/tp-d/home/runzheyang/brain/rs=6.3/data/sst2_it2t_failue', 'r') as fh:  
  it2t_ex = json.load(fh)
with gfile.Open('/cns/tp-d/home/runzheyang/brain/rs=6.3/data/sst2_t2t_failue', 'r') as fh:  
  t2t_ex = json.load(fh)

In [None]:
# The input generator is instantiated for its vocabulary
# (input_gen._vocabulary does all encoding/decoding below).
input_p = mdl_mixed.Train()
input_gen = input_p.Instantiate()

In [None]:
# Tokenize the whole frequent-word list in one call.
freq_ids = input_gen._vocabulary._encode(freq_words)

In [None]:
# De-duplicate (np.unique also sorts) and convert to a list of Python ints;
# later cells use list methods such as freq_ids.index().
freq_ids = np.unique(freq_ids)
freq_ids = [int(i) for i in freq_ids]
len(freq_ids)

In [None]:
# Distinct token ids used anywhere in the SST-2 validation examples.
all_ex_tokens = []
for ex in all_ex:
  all_ex_tokens += input_gen._vocabulary._encode(ex)

all_ex_tokens = [int(ids) for ids in np.unique(all_ex_tokens)]
len(all_ex_tokens)

In [None]:
# check the overlap between top 5000 words and SST tokens
len(np.intersect1d(all_ex_tokens, freq_ids))

In [None]:
# Analysis vocabulary: union of frequent-word ids and SST-2 ids (sorted,
# unique). Row i of the embedding/similarity matrices below is freq_ids[i].
freq_ids = [int(ids) for ids in np.union1d(freq_ids, all_ex_tokens)]

In [None]:
input_gen._vocabulary._decode(freq_ids)

In [None]:
len(freq_ids)

In [None]:
# get token embedding (w/o positional embedding), assuming shared_emb=True.
# Both lookups use freq_ids, so row i of each result belongs to freq_ids[i].
it2t_token_embeddings = task_mixed.encoder.softmax.EmbLookup(
    task_mixed.theta.encoder.softmax, freq_ids)

t2t_token_embeddings = task_mixed.encoder_ex.softmax.EmbLookup(
    task_mixed.theta.encoder_ex.softmax, freq_ids)

# IT2T and T2T embeddings of each word side by side (concatenated on axis 1).
selector_input = tf.concat([it2t_token_embeddings, t2t_token_embeddings], axis=1)
# Select between IT2T and T2T embeddings
selection = task_mixed.emb_selector.FProp(task_mixed.theta.emb_selector, selector_input)

fetches = py_utils.NestedMap({
    "it2t_emb": it2t_token_embeddings,
    "t2t_emb": t2t_token_embeddings,
    "selection": selection})

print(fetches)

In [None]:
# Evaluate the lookups and the selector with the loaded checkpoint weights.
emb_output = sess.run(fetches)

In [None]:
it2t_emb = emb_output["it2t_emb"]
t2t_emb = emb_output["t2t_emb"]
selection = emb_output["selection"]

In [None]:
# Average selector output over the analyzed words.
selection.mean()

In [None]:
import seaborn as sns
sns.set_context("talk")

# Step histogram of the per-word selector output. kde=False, so only histogram
# styling is passed (KDE keyword arguments would be ignored).
sns.distplot(selection, kde=False,
             hist_kws={"histtype": "step", "linewidth": 3,
                       "alpha": 1, "color": "g"})
plt.xlabel("selection (t2t)")

In [None]:
def cos_similarity(vecs):
  """Pairwise cosine similarity between the rows of a 2-D array.

  One is subtracted from the diagonal, so a non-zero row's self-similarity
  (1) becomes ~0 and it normally does not show up as its own nearest
  neighbor. It can still outrank neighbors with negative similarity. Rows
  with zero norm produce NaN/inf entries (division by zero).

  Args:
    vecs: array of shape [num_rows, dim].

  Returns:
    Array of shape [num_rows, num_rows].
  """
  row_norms = np.sqrt(np.square(vecs).sum(axis=1))
  gram = vecs.dot(vecs.T)
  norm_products = row_norms[:, None] * row_norms[None, :]
  return gram / norm_products - np.identity(vecs.shape[0])

def k_nn(sim_matrix, k=-1):
  """Column indices of the most similar entries of each row, best first.

  Args:
    sim_matrix: 2-D array of similarity scores (higher = more similar).
    k: number of neighbors to keep per row. A negative value (the default) or
      None keeps every column. Using -1 directly as a slice end would
      silently drop the least-similar column.

  Returns:
    Integer array of shape [rows, k], or [rows, cols] when all are kept.
  """
  # argsort is ascending; reverse so the most similar column comes first.
  ranked = np.argsort(sim_matrix, -1)[:, ::-1]
  if k is None or k < 0:
    return ranked
  return ranked[:, :k]

In [None]:
# obtain top 10 similar words
# (row i lists indices into freq_ids of word i's neighbors, most similar first)
sim_matrix_it2t = cos_similarity(it2t_emb)
knn_it2t_10 = k_nn(sim_matrix_it2t, 10)

In [None]:
# obtain top 10 similar words
sim_matrix_t2t = cos_similarity(t2t_emb)
knn_t2t_10 = k_nn(sim_matrix_t2t, 10)

In [None]:
# obtain top 5 similar words
knn_it2t_5 = k_nn(sim_matrix_it2t, 5)
knn_t2t_5 = k_nn(sim_matrix_t2t, 5)

### Check nearest words

In [None]:
def check_id(rid, freq_ids, knn):
  """Prints the word at position `rid` of `freq_ids` and its neighbors.

  Args:
    rid: row index into `freq_ids` / `knn`.
    freq_ids: list of token ids analyzed in this notebook.
    knn: neighbor table; knn[rid] holds indices into `freq_ids`.
  """
  decode = input_gen._vocabulary._decode
  print("query:", decode([int(freq_ids[rid])]))
  neighbor_words = [decode([int(freq_ids[j])]) for j in knn[rid]]
  print("similar words:", neighbor_words)

def check_word(word, freq_ids, knn):
  """Prints nearest neighbors for every word piece of `word`.

  Args:
    word: query string; split into word pieces with the vocabulary.
    freq_ids: list of token ids; row i of `knn` belongs to freq_ids[i].
    knn: neighbor index table, e.g. from k_nn().

  Pieces with id 3 (treated as the empty token) are skipped. Pieces that are
  not in `freq_ids` have no row in `knn`; they are reported and skipped
  instead of aborting the whole query with a ValueError.
  """
  empty_token_id = 3
  for wid in input_gen._vocabulary._encode(word):
    # skip the empty token..
    if wid == empty_token_id:
      continue
    try:
      rid = freq_ids.index(wid)
    except ValueError:
      print("not among the analyzed ids:", wid)
      continue
    check_id(rid, freq_ids, knn)

In [None]:
# Nearest neighbors of a query word: IT2T embedding first, then T2T.
query = 'cat'
check_word(query, freq_ids, knn_it2t_10)
check_word(query, freq_ids, knn_t2t_10)

In [None]:
query = 'dream'
check_word(query, freq_ids, knn_it2t_10)
check_word(query, freq_ids, knn_t2t_10)

In [None]:
query = 'throw'
check_word(query, freq_ids, knn_it2t_10)
check_word(query, freq_ids, knn_t2t_10)

In [None]:
query = 'sing'
check_word(query, freq_ids, knn_it2t_10)
check_word(query, freq_ids, knn_t2t_10)

In [None]:
query = 'compromise'
check_word(query, freq_ids, knn_it2t_10)
check_word(query, freq_ids, knn_t2t_10)

### Quantitative comparison

In [None]:
def diff_scores(knn1, knn2, k):
  """Per-row fraction of neighbors shared by two k-NN index tables.

  Args:
    knn1: neighbor table; knn1[i] holds the neighbor indices of item i.
    knn2: second neighbor table, indexed the same way as knn1.
    k: neighborhood size used as the denominator.

  Returns:
    List with len(set(knn1[i]) & set(knn2[i])) / k for every row i of knn1.
  """
  overlaps = []
  for i, neighbors_a in enumerate(knn1):
    shared = np.intersect1d(neighbors_a, knn2[i])
    overlaps.append(len(shared) / k)
  return overlaps

In [None]:
# Per-word fraction of shared top-10 neighbors between the IT2T and T2T
# embedding spaces (1.0 = identical neighborhoods), and its average.
top_k = 10
it2t_vs_t2t = diff_scores(k_nn(sim_matrix_it2t, top_k), k_nn(sim_matrix_t2t, top_k), top_k)
np.mean(it2t_vs_t2t)

In [None]:
# The same per-word overlap for a range of neighborhood sizes k.
top_ks = [1, 5, 10, 20, 50, 100, 200, 500]
it2t_vs_t2t_k = {k:diff_scores(k_nn(sim_matrix_it2t, k), k_nn(sim_matrix_t2t, k), k) for k in top_ks}

In [None]:
[np.mean(it2t_vs_t2t_k[k]) for k in top_ks]

In [None]:
plt.plot(top_ks, [np.mean(it2t_vs_t2t_k[k]) for k in top_ks])
plt.xlabel("Number of neighbors (k)")
plt.ylabel("Avg. coherence Score")
plt.show()

In [None]:
def id2word(rid):
  """Decodes the token stored at position `rid` of the global `freq_ids`."""
  token_id = int(freq_ids[rid])
  return input_gen._vocabulary._decode([token_id])

In [None]:
# most dissimilar words (it2t vs t2t)
# (argsort is ascending, so the lowest neighbor overlap comes first)
np.vectorize(id2word)(np.argsort(it2t_vs_t2t))[:50]

In [None]:
# most similar words (it2t vs t2t)
np.vectorize(id2word)(np.argsort(it2t_vs_t2t))[-30:]

In [None]:
query = 'abstract'
check_word(query, freq_ids, knn_it2t_10)
check_word(query, freq_ids, knn_t2t_10)

In [None]:
query = 'decrease'
check_word(query, freq_ids, knn_it2t_10)
check_word(query, freq_ids, knn_t2t_10)

In [None]:
query = 'traditional'
check_word(query, freq_ids, knn_it2t_10)
check_word(query, freq_ids, knn_t2t_10)

In [None]:
import seaborn as sns
sns.set_context('talk')

# Step histograms of the per-word IT2T/T2T neighbor overlap for several k.
# kde=True makes distplot draw density-normalized histograms, so different k
# are comparable; the KDE curve itself is hidden with lw=0 / alpha=0.
distplot = lambda x, c, l: sns.distplot(x, kde=True,
                            kde_kws={"lw": 0,  "alpha": 0.0},
                            hist_kws={"histtype": "step", "linewidth": 3,
                            "alpha": 0.8, "color": c, "label": l})

distplot(it2t_vs_t2t_k[5], 'orange', 'k=5')
distplot(it2t_vs_t2t_k[10], 'g', 'k=10')
distplot(it2t_vs_t2t_k[100], 'b', 'k=100')
distplot(it2t_vs_t2t_k[500], 'r', 'k=500')

# The curves cover several k (see legend), so the label must not name one k.
plt.xlabel("top-k nearest neighbor coherence")
plt.legend(loc='upper right')

# Comparison with concreteness scores


In [None]:
from google3.pyglib import gfile
import pandas as pd

# Word-concreteness ratings table (read from an Excel sheet).
with gfile.Open('/cns/tp-d/home/runzheyang/brain/rs=6.3/data/concreteness.xlsx', 'rb') as fh:  
  concrete_scores = pd.read_excel(fh)

In [None]:
concrete_scores

In [None]:
# Mark words that the vocabulary encodes as exactly one word piece.
bool_wordpiece = []
for i, w in enumerate(list(concrete_scores["Word"])):
  ids = input_gen._vocabulary.encode(str(w))
  bool_wordpiece.append(len(ids) == 1)

In [None]:
concrete_scores['is_wordpiece'] = bool_wordpiece

In [None]:
concrete_scores

In [None]:
# Token ids of the single-piece words.
# NOTE(review): unlike the cell above, w is not passed through str() here;
# non-string entries could encode differently — confirm.
cr_wid = [input_gen._vocabulary.encode(w) for w in concrete_scores["Word"][concrete_scores["is_wordpiece"]]]

In [None]:
# One id per word, so flattening gives one token id per concreteness word.
cr_wid = np.array(cr_wid).reshape(-1)

In [None]:
len(np.intersect1d(cr_wid, freq_ids))

In [None]:
len(np.union1d(cr_wid, freq_ids))

In [None]:
len(np.unique(cr_wid))

In [None]:
from scipy import stats
def r2(x, y):
    """Squared Pearson correlation coefficient of x and y."""
    return stats.pearsonr(x, y)[0] ** 2

# x/y are passed by keyword: newer seaborn releases reject positional data
# arguments to regplot.
sns.regplot(x=emb_output["selection"], y=it2t_vs_t2t, color='green', scatter_kws={'alpha':0.1})
plt.xlabel("selection (t2t)")
plt.ylabel("it2t vs t2t coherence score")

In [None]:
# Concreteness rating (column "Conc.M") of every single-piece word, in the same
# order as cr_wid.
conc_m = concrete_scores["Conc.M"][concrete_scores["is_wordpiece"]]

In [None]:
len(conc_m)

In [None]:
len(freq_ids)

In [None]:
cr_it2t_vs_t2t = [it2t_vs_t2t[freq_ids.index(ids)] for ids in cr_wid if ids in freq_ids]

In [None]:
cr_selection = [float(emb_output["selection"][freq_ids.index(ids)]) for ids in cr_wid if ids in freq_ids]

In [None]:
conc_m = conc_m[[(w in freq_ids) for w in cr_wid]]

In [None]:
import nltk
nltk.download('averaged_perceptron_tagger')
# NOTE(review): nltk.word_tokenize also needs the 'punkt' tokenizer data, which
# is not downloaded here; presumably it is found under the path added below —
# confirm.
nltk.data.path.append('/usr/local/google/home/runzheyang/nltk_data')
# POS tag of one token id: the tag of its first NLTK token (raises IndexError
# if the decoded piece tokenizes to nothing).
get_pos = lambda x: nltk.pos_tag(nltk.word_tokenize(input_gen._vocabulary.decode([int(x)])))[0][1]

In [None]:
# POS tag for every concreteness word that is in freq_ids (same order and
# filter as cr_it2t_vs_t2t, cr_selection and conc_m).
POS = [get_pos(ids) for ids in cr_wid if ids in freq_ids]

In [None]:
np.unique(POS, return_counts=True)

In [None]:
sns.histplot(POS)
plt.xticks(rotation=70)
plt.show()

In [None]:
# Arrays so they can be indexed with boolean masks below.
cr_it2t_vs_t2t = np.array(cr_it2t_vs_t2t)
conc_m = np.array(conc_m)
POS = np.array(POS)

In [None]:
def is_in(x, y):
  """Element-wise membership mask: True where an item of `x` occurs in `y`."""
  return [item in y for item in x]

In [None]:
sns.regplot(conc_m, cr_it2t_vs_t2t, scatter_kws={'alpha':0.1})
pos_set = ['VB', 'VBD', 'VBG', 'VBN']
sns.regplot(conc_m[is_in(POS, pos_set)], cr_it2t_vs_t2t[is_in(POS, pos_set)], color='red', scatter_kws={'alpha':0.1})
plt.ylabel("it2t vs t2t coherence score")
plt.xlabel("concreteness")

In [None]:
sns.regplot(conc_m, cr_selection, scatter_kws={'alpha':0.05})
plt.ylabel("selection (t2t)")
plt.xlabel("concreteness")

In [None]:
np.sqrt(r2(conc_m, cr_selection))

In [None]:
# sns.regplot(conc_m, cr_selection, scatter_kws={'alpha':0.1})
pos_set = ['VB', 'VBD', 'VBG', 'VBN']
sns.regplot(conc_m[is_in(POS, pos_set)], np.array(cr_selection)[is_in(POS, pos_set)], color='red', scatter_kws={'alpha':0.1})
plt.ylabel("selection (t2t)")
plt.xlabel("concreteness")

In [None]:
np.sqrt(r2(conc_m[is_in(POS, pos_set)], np.array(cr_selection)[is_in(POS, pos_set)]))

In [None]:
# sns.regplot(conc_m, cr_selection, scatter_kws={'alpha':0.1})
pos_set = ['NN', 'NNS']
sns.regplot(conc_m[is_in(POS, pos_set)], np.array(cr_selection)[is_in(POS, pos_set)], color='orange', scatter_kws={'alpha':0.1})
plt.ylabel("selection (t2t)")
plt.xlabel("concreteness")

## Compare SST2 Performance

In [None]:
def seq_score(sentence_ids, freq_ids, score):
  """Mean per-token score over the tokens of one sentence.

  Args:
    sentence_ids: token ids of the sentence.
    freq_ids: list of token ids; position i in this list indexes `score`.
    score: sequence of per-token scores aligned with `freq_ids`.

  Returns:
    Mean score of the sentence tokens that appear in `freq_ids` (other tokens
    are ignored; repeated tokens count once per occurrence), or NaN when none
    of the tokens appear.
  """
  rids = [freq_ids.index(wid) for wid in sentence_ids if wid in freq_ids]
  if not rids:
    # Previously this path hit an unbound local variable.
    return np.nan
  return np.array(score)[rids].mean()

In [None]:
def seq_inconsist_words(sentence_ids, freq_ids, score, threshold=0.2):
  """Percentage of a sentence's tokens whose score is <= `threshold`.

  Args:
    sentence_ids: token ids of the sentence.
    freq_ids: list of token ids; position i in this list indexes `score`.
    score: sequence of per-token scores aligned with `freq_ids`.
    threshold: scores at or below this value count as "incoherent".

  Returns:
    100 * (#tokens with score <= threshold) / (#tokens with score <= 1.1),
    counting only tokens present in `freq_ids`, or NaN when none are present.
    For overlap fractions in [0, 1] the denominator is every counted token.
  """
  rids = [freq_ids.index(wid) for wid in sentence_ids if wid in freq_ids]
  if not rids:
    return np.nan
  scores = np.array(score)[rids]
  return 100 * (scores <= threshold).sum() / (scores <= 1.1).sum()

In [None]:
def exs_scores(exs, k):
  """Mean k-NN overlap score (seq_score) for each example sentence in `exs`."""
  encoded = [input_gen._vocabulary._encode(ex) for ex in exs]
  return [seq_score(ids, freq_ids, it2t_vs_t2t_k[k]) for ids in encoded]

In [None]:
def exs_icwords(exs, k):
  """Percentage of low-overlap tokens (seq_inconsist_words) per sentence."""
  encoded = [input_gen._vocabulary._encode(ex) for ex in exs]
  return [seq_inconsist_words(ids, freq_ids, it2t_vs_t2t_k[k]) for ids in encoded]

In [None]:
# Per-sentence mean top-20 neighbor overlap for the validation set and for
# each model's failures.
top_k = 20

val_scores = exs_scores(all_ex, top_k)
it2t_ex_scores = exs_scores(it2t_ex, top_k)
t2t_ex_scores = exs_scores(t2t_ex, top_k)

In [None]:
# Examples that appear in neither failure list.
correct_ex = [ex for ex in all_ex if ex not in it2t_ex and ex not in t2t_ex]
correct_scores = exs_scores(correct_ex, top_k)

In [None]:
# Examples that appear in both failure lists.
common_ex = [ex for ex in all_ex if ex in it2t_ex and ex in t2t_ex]
common_scores = exs_scores(common_ex, top_k)

In [None]:
# Failures of the IT2T model only.
it2t_only_ex = [ex for ex in it2t_ex if ex not in t2t_ex]
it2t_only_scores = exs_scores(it2t_only_ex, top_k)

In [None]:
# Failures of the T2T model only.
t2t_only_ex = [ex for ex in t2t_ex if ex not in it2t_ex]
t2t_only_scores = exs_scores(t2t_only_ex, top_k)

In [None]:
it2t_ex, t2t_ex

In [None]:
len(np.intersect1d(it2t_ex, t2t_ex)), len(it2t_only_ex), len(t2t_only_ex)

In [None]:
# Overlaid step histograms (plus faint KDE) of per-sentence coherence for
# correctly classified examples vs. each model's failures.
distplot = lambda x, b, c, l: sns.distplot(x, kde=True, bins=b,
                            kde_kws={"color": c, "lw": 3, "alpha": 0.2},
                            hist_kws={"histtype": "step", "linewidth": 3,
                            "alpha": 0.2, "color": c, "label": l})

bins = np.linspace(0.44,0.6,50)
distplot(correct_scores, bins, 'gray', 'correct')
distplot(it2t_ex_scores, bins, 'g', 'it2t failure')
distplot(t2t_ex_scores, bins, 'orange', 't2t failure')
plt.xlabel("embedding coherence score")
plt.legend(loc='upper right')

In [None]:
distplot = lambda x, b, c, l: sns.distplot(x, kde=True, bins=b,
                            kde_kws={"color": c, "lw": 3, "alpha": 0.2},
                            hist_kws={"histtype": "step", "linewidth": 3,
                            "alpha": 0.2, "color": c, "label": l})

bins = np.linspace(0.44,0.6,50)
# "specific" = failures of only one model, so plot the *_only_* scores; the
# full failure lists would also include the common failures shown in gray.
distplot(common_scores, bins, 'gray', 'common')
distplot(it2t_only_scores, bins, 'g', 'it2t specific')
distplot(t2t_only_scores, bins, 'orange', 't2t specific')
plt.xlabel("embedding coherence score")
plt.legend(loc='upper right')

In [None]:
# Percentage of each sentence's tokens whose top-20 neighbor overlap is
# <= 0.2 ("incoherent" tokens), for every example group.
top_k = 20

val_icscores = exs_icwords(all_ex, top_k)
it2t_ex_icscores = exs_icwords(it2t_ex, top_k)
t2t_ex_icscores = exs_icwords(t2t_ex, top_k)
it2t_only_icscores = exs_icwords(it2t_only_ex, top_k)
t2t_only_icscores = exs_icwords(t2t_only_ex, top_k)
common_icscores = exs_icwords(common_ex, top_k)
correct_icscores = exs_icwords(correct_ex, top_k)

In [None]:
distplot = lambda x, b, c, l: sns.distplot(x, kde=True, bins=b,
                            kde_kws={"color": c, "lw": 3, "alpha": 0.2},
                            hist_kws={"histtype": "step", "linewidth": 3,
                            "alpha": 0.2, "color": c, "label": l})

bins = np.linspace(0,20,50)
distplot(correct_icscores, bins, 'gray', 'correct')
distplot(it2t_ex_icscores, bins, 'g', 'it2t failure')
distplot(t2t_ex_icscores, bins, 'orange', 't2t failure')
plt.xlabel("% of incoherence tokens")
plt.legend(loc='upper right')
plt.xlim([-3,20])

In [None]:
distplot = lambda x, b, c, l: sns.distplot(x, kde=True, bins=b,
                            kde_kws={"color": c, "lw": 3, "alpha": 0.2},
                            hist_kws={"histtype": "step", "linewidth": 3,
                            "alpha": 0.2, "color": c, "label": l})

bins = np.linspace(0,20,50)
distplot(common_icscores, bins, 'gray', 'common')
distplot(it2t_only_icscores, bins, 'g', 'it2t specific')
distplot(t2t_only_icscores, bins, 'orange', 't2t specific')
plt.xlabel("% of incoherence tokens")
plt.legend(loc='upper right')
plt.xlim([-3,20])

In [None]:
def colored(r, g, b, text):
    """Returns `text` wrapped in a 24-bit ANSI foreground color, space-padded.

    After the text the foreground is set to white (255, 255, 255) rather than
    reset to the terminal default.
    """
    set_color = "\033[38;2;{};{};{}m".format(r, g, b)
    back_to_white = "\033[38;2;255;255;255m"
    return " {}{}{} ".format(set_color, text, back_to_white)

def print_diff(ex, score):
  """Prints `ex` with each word piece shaded by its score and returns them.

  Args:
    ex: sentence string, encoded with input_gen._vocabulary.encode.
    score: per-word scores aligned with the global `freq_ids` list.

  Returns:
    numpy array with the score of every word piece of `ex`, in order (empty
    for a sentence with no word pieces).

  Raises:
    ValueError: if a word piece of `ex` is not in `freq_ids`.
  """
  sentence_ids = input_gen._vocabulary.encode(ex)
  # Gather all scores once instead of re-indexing on every loop iteration.
  scores = np.array(score)[[freq_ids.index(wid) for wid in sentence_ids]]
  pieces = []
  for value, ids in zip(scores, sentence_ids):
    level = int(255 * value)  # gray level proportional to the score
    pieces.append(colored(level, level, level, input_gen._vocabulary.decode([ids])))
  print(' '.join(pieces))
  return scores

def obtain_diff(ex, score):
  """Returns the `score` value of every word piece of sentence `ex`.

  Args:
    ex: sentence string, encoded with input_gen._vocabulary.encode.
    score: per-word scores aligned with the global `freq_ids` list.

  Returns:
    numpy array of per-piece scores in sentence order (empty for a sentence
    with no word pieces).

  Raises:
    ValueError: if a word piece of `ex` is not in `freq_ids`.
  """
  sentence_ids = input_gen._vocabulary.encode(ex)
  rids = [freq_ids.index(wid) for wid in sentence_ids]
  return np.array(score)[rids]

In [None]:
top_k = 20

# IT2T-only failures, each word piece shaded by its top-20 neighbor overlap
# (gray level proportional to the overlap).
for ex in it2t_only_ex[:10]:
  print_diff(ex, it2t_vs_t2t_k[top_k])

In [None]:
top_k = 20

# Same for the T2T-only failures.
for ex in t2t_only_ex[:10]:
  print_diff(ex, it2t_vs_t2t_k[top_k])

In [None]:
query = 'disappointing'
check_word(query, freq_ids, k_nn(sim_matrix_it2t, 20))
check_word(query, freq_ids, k_nn(sim_matrix_t2t, 20))

In [None]:
query = 'coincidence'
check_word(query, freq_ids, knn_it2t_10)
check_word(query, freq_ids, knn_t2t_10)

In [None]:
query = 'facile'
check_word(query, freq_ids, knn_it2t_10)
check_word(query, freq_ids, knn_t2t_10)

## Perturbation-based analysis

In [None]:
def get_sensitivity(task, input_batch, max_seq_length, eps=0.5):
  """Builds predictions for copies of `input_batch` with one position perturbed.

  For every position i in [0, max_seq_length), the token embeddings at
  position i are scaled by (1 + eps). One unperturbed copy is appended last.
  All max_seq_length + 1 copies are stacked along the batch axis and run
  through the encoder Transformer layers, the decoder and the classifier of
  `task`. Per-position sensitivities can then be read off by comparing each
  row's predictions with the last, unperturbed one.

  Args:
    task: instantiated task; uses task.encoder (token/softmax and position
      embeddings, FPropTransformerLayers), task.decoder.ComputePredictions and
      the task's _extract_classifier_input / _apply_classifier helpers.
    input_batch: NestedMap with `ids` and `paddings`, both [batch, time]. The
      time dimension must be static because it sizes a numpy mask.
    max_seq_length: number of leading positions to perturb (<= time).
    eps: relative amount added to the perturbed position's embedding.

  Returns:
    Whatever task._apply_classifier returns for the (max_seq_length + 1) *
    batch stacked inputs. Copy j occupies rows [j * batch, (j + 1) * batch),
    and the unperturbed copy comes last.
  """
  p = task.encoder.params
  # [batch, time]
  input_ids = input_batch.ids
  # [batch, time]
  paddings = input_batch.paddings

  batch = py_utils.GetShape(input_ids)[0]
  time = py_utils.GetShape(input_ids)[1]
  num_copies = max_seq_length + 1

  # Embedding layer.
  # [batch, time, dim]
  if not p.shared_emb:
    input_embs = task.encoder.token_emb.EmbLookup(task.encoder.theta.token_emb, input_ids)
  else:
    input_embs = task.encoder.softmax.EmbLookup(task.encoder.theta.softmax, input_ids)

  perturbed_embs = []
  for i in range(max_seq_length):
    # [time, dim] mask scaling position i by (1 + eps); broadcasts over batch.
    mask = np.ones((time, p.model_dim))
    mask[i, :] += eps
    perturbed_embs.append(mask * input_embs)
  # The unperturbed reference copy goes last.
  perturbed_embs.append(input_embs)

  # [num_copies * batch, time, dim], copy-major.
  perturbed_embs = tf.reshape(tf.stack(perturbed_embs),
                              [num_copies * batch, time, p.model_dim])

  # [1, time, dim]
  position_embs = tf.expand_dims(
      task.encoder.position_emb.FProp(task.encoder.theta.position_emb, time), 0)
  perturbed_embs += position_embs

  if p.input_dropout_tpl.fprop_dtype:
    perturbed_embs = tf.cast(perturbed_embs, p.input_dropout_tpl.fprop_dtype)
    paddings = tf.cast(paddings, p.input_dropout_tpl.fprop_dtype)

  # Explicitly set the input shape of Transformer layers, to avoid
  # unknown shape error occurred to tf.einsum on nonTPU devices.
  transformer_input = tf.reshape(perturbed_embs,
                                 [num_copies * batch, time, p.model_dim])

  # Reshape to match with input shapes of other embeddings, e.g. image.
  # -> [time, num_copies * batch, dim]
  transformer_input = tf.transpose(transformer_input, [1, 0, 2])
  # Paddings follow the same time-major layout: [time, num_copies * batch].
  paddings = tf.transpose(tf.tile(paddings, [num_copies, 1]))

  encoder_embeddings = py_utils.NestedMap(input_embs=transformer_input, paddings=paddings)

  encoder_outputs = task.encoder.FPropTransformerLayers(task.theta.encoder,
                                                        encoder_embeddings)

  # Decoder: the tiled input ids/paddings also serve as decoder targets. Use
  # the function argument, not a same-named global.
  tiled_ids = tf.tile(input_batch.ids, [num_copies, 1])
  tiled_paddings = tf.tile(input_batch.paddings, [num_copies, 1])
  targets = py_utils.NestedMap(ids=tiled_ids, paddings=tiled_paddings)
  decoder_outputs = task.decoder.ComputePredictions(task.theta.decoder,
                                                    encoder_outputs, targets)

  classifier_input = task._extract_classifier_input(tiled_paddings, decoder_outputs)

  return task._apply_classifier(task.theta, classifier_input)

# One example at a time, padded to 512 positions (see process_ex).
feed_ids =  tf.placeholder(tf.int32, shape=[1,512])
feed_paddings = tf.placeholder(tf.float32, shape=[1,512])

# NOTE(review): task_it2t, task_t2t and max_seq_length are not defined in the
# cells above; they must come from elsewhere — confirm before running.
sources = py_utils.NestedMap(ids=feed_ids, paddings=feed_paddings)
sensitivity_it2t = get_sensitivity(task_it2t, sources, max_seq_length)
sensitivity_t2t = get_sensitivity(task_t2t, sources, max_seq_length)

# Absolute change of the class-0 probability of every perturbed copy relative
# to the unperturbed copy (the last row).
sen_it2t = abs(sensitivity_it2t.probs[:,0] - sensitivity_it2t.probs[-1][0])
sen_t2t = abs(sensitivity_t2t.probs[:,0] - sensitivity_t2t.probs[-1][0])

# Notice that we are calling this with task.theta which ensures that we are
# using the same variables which we have just loaded.
fetches = py_utils.NestedMap(
          {"sources": sources,
           "prob_it2t": sensitivity_it2t.probs,
           "prob_t2t": sensitivity_t2t.probs,
           "sensitivity_it2t":sen_it2t,
           "sensitivity_t2t":sen_t2t
           })

print(fetches)

In [None]:
# `ids` is whatever was last assigned to that name (e.g. a process_ex result),
# so these cells depend on execution order.
ids.shape

In [None]:
np.tile(ids, [30,1]).shape

In [None]:
def process_ex(ex, max_len=512):
  """Encodes one sentence into the fixed-size arrays fed to the placeholders.

  A 0 id is prepended to the sentence's word pieces, and the result is
  right-padded with 0 ids up to `max_len`. Paddings are 0 for the prepended id
  and the word pieces, and 1 for the padded tail.

  Args:
    ex: sentence string.
    max_len: total length; 512 matches the [1, 512] feed_ids / feed_paddings
      placeholders.

  Returns:
    ids: array of shape [1, max_len].
    paddings: float array of shape [1, max_len].
    input_len: number of unpadded positions (word pieces + 1).

  Raises:
    ValueError: if the sentence has more than max_len - 1 word pieces.
  """
  inputs = input_gen._vocabulary._encode(ex)
  input_len = len(inputs) + 1
  if input_len > max_len:
    raise ValueError(
        'Sentence has %d word pieces; at most %d fit in %d positions.' %
        (len(inputs), max_len - 1, max_len))
  num_pad = max_len - input_len
  ids = np.pad(inputs, (1, num_pad), 'constant',
               constant_values=(0, 0)).reshape(1, -1)
  paddings = np.pad(np.zeros(input_len), (0, num_pad),
                    'constant', constant_values=(1, 1)).reshape(1, -1)

  return ids, paddings, input_len

In [None]:
def print_sensitivity(ex, score):
  """Prints `ex` with each word piece shaded by its sensitivity score.

  Args:
    ex: sentence string, encoded with input_gen._vocabulary.encode.
    score: per-piece scores indexed by the piece's position in `ex`
      (presumably sensitivities with the prepended position removed —
      confirm).

  Each piece gets gray level 255 * (1 - score), with 1 - score clipped to
  [0, 1], so higher scores are darker. Clipping keeps the ANSI color
  components valid when callers pass scaled scores (e.g. 5 * sensitivity).
  No vocabulary lookup is needed, so pieces outside `freq_ids` are fine.
  """
  sentence_ids = input_gen._vocabulary.encode(ex)
  shades = np.clip(1 - np.array(score), 0.0, 1.0)
  pieces = []
  for i, ids in enumerate(sentence_ids):
    level = int(255 * shades[i])
    pieces.append(colored(level, level, level, input_gen._vocabulary.decode([ids])))
  print(' '.join(pieces))

In [None]:
it2t_ex[0] in all_ex

In [None]:
labels = []  # NOTE(review): never used below.
test_outputs = []

for ex in it2t_only_ex[:10]:
  ids, paddings, input_len = process_ex(ex)
  # NOTE(review): the fetches graph was already built with the previous
  # max_seq_length, so updating it here does not change what sess.run computes.
  max_seq_length = max(max_seq_length, input_len)
  test_outputs.append(sess.run(fetches, {feed_ids: ids, feed_paddings: paddings}))

In [None]:
def print_sensitivity(ex, score):
  """Prints `ex` with each word piece shaded by its sensitivity score.

  Args:
    ex: sentence string, encoded with input_gen._vocabulary.encode.
    score: per-piece scores indexed by the piece's position in `ex`
      (presumably sensitivities with the prepended position removed —
      confirm).

  Each piece gets gray level 255 * (1 - score), with 1 - score clipped to
  [0, 1], so higher scores are darker. Clipping keeps the ANSI color
  components valid when callers pass scaled scores (e.g. 5 * sensitivity).
  No vocabulary lookup is needed, so pieces outside `freq_ids` are fine.
  """
  sentence_ids = input_gen._vocabulary.encode(ex)
  shades = np.clip(1 - np.array(score), 0.0, 1.0)
  pieces = []
  for i, ids in enumerate(sentence_ids):
    level = int(255 * shades[i])
    pieces.append(colored(level, level, level, input_gen._vocabulary.decode([ids])))
  print(' '.join(pieces))

In [None]:
# Number of failures of each model.
len(it2t_ex)

In [None]:
len(t2t_ex)

### Activation-based analysis

In [None]:
# Number of perturbed positions the sensitivity graph below is built with.
max_seq_length

In [None]:
def get_sensitivity(task, input_batch, max_seq_length, eps=0.5):
  """Builds predictions for copies of `input_batch` with one position perturbed.

  For every position i in [0, max_seq_length), the token embeddings at
  position i are scaled by (1 + eps). One unperturbed copy is appended last.
  All max_seq_length + 1 copies are stacked along the batch axis and run
  through the encoder Transformer layers, the decoder and the classifier of
  `task`. Per-position sensitivities can then be read off by comparing each
  row's predictions with the last, unperturbed one.

  Args:
    task: instantiated task; uses task.encoder (token/softmax and position
      embeddings, FPropTransformerLayers), task.decoder.ComputePredictions and
      the task's _extract_classifier_input / _apply_classifier helpers.
    input_batch: NestedMap with `ids` and `paddings`, both [batch, time]. The
      time dimension must be static because it sizes a numpy mask.
    max_seq_length: number of leading positions to perturb (<= time).
    eps: relative amount added to the perturbed position's embedding.

  Returns:
    Whatever task._apply_classifier returns for the (max_seq_length + 1) *
    batch stacked inputs. Copy j occupies rows [j * batch, (j + 1) * batch),
    and the unperturbed copy comes last.
  """
  p = task.encoder.params
  # [batch, time]
  input_ids = input_batch.ids
  # [batch, time]
  paddings = input_batch.paddings

  batch = py_utils.GetShape(input_ids)[0]
  time = py_utils.GetShape(input_ids)[1]
  num_copies = max_seq_length + 1

  # Embedding layer.
  # [batch, time, dim]
  if not p.shared_emb:
    input_embs = task.encoder.token_emb.EmbLookup(task.encoder.theta.token_emb, input_ids)
  else:
    input_embs = task.encoder.softmax.EmbLookup(task.encoder.theta.softmax, input_ids)

  perturbed_embs = []
  for i in range(max_seq_length):
    # [time, dim] mask scaling position i by (1 + eps); broadcasts over batch.
    mask = np.ones((time, p.model_dim))
    mask[i, :] += eps
    perturbed_embs.append(mask * input_embs)
  # The unperturbed reference copy goes last.
  perturbed_embs.append(input_embs)

  # [num_copies * batch, time, dim], copy-major.
  perturbed_embs = tf.reshape(tf.stack(perturbed_embs),
                              [num_copies * batch, time, p.model_dim])

  # [1, time, dim]
  position_embs = tf.expand_dims(
      task.encoder.position_emb.FProp(task.encoder.theta.position_emb, time), 0)
  perturbed_embs += position_embs

  if p.input_dropout_tpl.fprop_dtype:
    perturbed_embs = tf.cast(perturbed_embs, p.input_dropout_tpl.fprop_dtype)
    paddings = tf.cast(paddings, p.input_dropout_tpl.fprop_dtype)

  # Explicitly set the input shape of Transformer layers, to avoid
  # unknown shape error occurred to tf.einsum on nonTPU devices.
  transformer_input = tf.reshape(perturbed_embs,
                                 [num_copies * batch, time, p.model_dim])

  # Reshape to match with input shapes of other embeddings, e.g. image.
  # -> [time, num_copies * batch, dim]
  transformer_input = tf.transpose(transformer_input, [1, 0, 2])
  # Paddings follow the same time-major layout: [time, num_copies * batch].
  paddings = tf.transpose(tf.tile(paddings, [num_copies, 1]))

  encoder_embeddings = py_utils.NestedMap(input_embs=transformer_input, paddings=paddings)

  encoder_outputs = task.encoder.FPropTransformerLayers(task.theta.encoder,
                                                        encoder_embeddings)

  # Decoder: the tiled input ids/paddings also serve as decoder targets. Use
  # the function argument, not a same-named global.
  tiled_ids = tf.tile(input_batch.ids, [num_copies, 1])
  tiled_paddings = tf.tile(input_batch.paddings, [num_copies, 1])
  targets = py_utils.NestedMap(ids=tiled_ids, paddings=tiled_paddings)
  decoder_outputs = task.decoder.ComputePredictions(task.theta.decoder,
                                                    encoder_outputs, targets)

  classifier_input = task._extract_classifier_input(tiled_paddings, decoder_outputs)

  return task._apply_classifier(task.theta, classifier_input)

# One example at a time, padded to 512 positions (see process_ex).
feed_ids =  tf.placeholder(tf.int32, shape=[1,512])
feed_paddings = tf.placeholder(tf.float32, shape=[1,512])

# NOTE(review): task_it2t and task_t2t are not defined in the visible cells;
# they must come from elsewhere — confirm before running.
sources = py_utils.NestedMap(ids=feed_ids, paddings=feed_paddings)
sensitivity_it2t = get_sensitivity(task_it2t, sources, max_seq_length)
sensitivity_t2t = get_sensitivity(task_t2t, sources, max_seq_length)

# Absolute change of the class-0 probability of every perturbed copy relative
# to the unperturbed copy (the last row).
sen_it2t = abs(sensitivity_it2t.probs[:,0] - sensitivity_it2t.probs[-1][0])
sen_t2t = abs(sensitivity_t2t.probs[:,0] - sensitivity_t2t.probs[-1][0])

# Notice that we are calling this with task.theta which ensures that we are
# using the same variables which we have just loaded.
fetches = py_utils.NestedMap(
          {"sources": sources,
           "prob_it2t": sensitivity_it2t.probs,
           "prob_t2t": sensitivity_t2t.probs,
           "sensitivity_it2t":sen_it2t,
           "sensitivity_t2t":sen_t2t
           })

print(fetches)

In [None]:
def process_ex(ex, max_len=512):
  """Encodes one sentence into the fixed-size arrays fed to the placeholders.

  A 0 id is prepended to the sentence's word pieces, and the result is
  right-padded with 0 ids up to `max_len`. Paddings are 0 for the prepended id
  and the word pieces, and 1 for the padded tail.

  Args:
    ex: sentence string.
    max_len: total length; 512 matches the [1, 512] feed_ids / feed_paddings
      placeholders.

  Returns:
    ids: array of shape [1, max_len].
    paddings: float array of shape [1, max_len].
    input_len: number of unpadded positions (word pieces + 1).

  Raises:
    ValueError: if the sentence has more than max_len - 1 word pieces.
  """
  inputs = input_gen._vocabulary._encode(ex)
  input_len = len(inputs) + 1
  if input_len > max_len:
    raise ValueError(
        'Sentence has %d word pieces; at most %d fit in %d positions.' %
        (len(inputs), max_len - 1, max_len))
  num_pad = max_len - input_len
  ids = np.pad(inputs, (1, num_pad), 'constant',
               constant_values=(0, 0)).reshape(1, -1)
  paddings = np.pad(np.zeros(input_len), (0, num_pad),
                    'constant', constant_values=(1, 1)).reshape(1, -1)

  return ids, paddings, input_len

In [None]:
# Run the sensitivity graph on every IT2T-only and T2T-only failure and on the
# first 30 examples that neither model gets wrong.
correct_test_outputs = []
it2t_only_test_outputs = []
t2t_only_test_outputs = []

# max_seq_length = 0

for ex in it2t_only_ex:
  ids, paddings, input_len = process_ex(ex)
  # max_seq_length = max(max_seq_length, input_len)
  it2t_only_test_outputs.append(sess.run(fetches, {feed_ids: ids, feed_paddings: paddings}))

In [None]:
for ex in t2t_only_ex:
  ids, paddings, input_len = process_ex(ex)
  # max_seq_length = max(max_seq_length, input_len)
  t2t_only_test_outputs.append(sess.run(fetches, {feed_ids: ids, feed_paddings: paddings}))

In [None]:
for ex in correct_ex[:30]:
  ids, paddings, input_len = process_ex(ex)
  # max_seq_length = max(max_seq_length, input_len)
  correct_test_outputs.append(sess.run(fetches, {feed_ids: ids, feed_paddings: paddings}))

In [None]:
correct_sscore_it2t = [correct_test_outputs[ids].sensitivity_it2t[correct_test_outputs[ids].sources.paddings[0][:85] == 0][1:] for ids in range(30)]
correct_sscore_t2t = [correct_test_outputs[ids].sensitivity_t2t[correct_test_outputs[ids].sources.paddings[0][:85] == 0][1:] for ids in range(30)]

In [None]:
it2t_sscore_it2t = [it2t_only_test_outputs[ids].sensitivity_it2t[it2t_only_test_outputs[ids].sources.paddings[0][:85] == 0][1:] for ids in range(len(it2t_only_ex))]
it2t_sscore_t2t = [it2t_only_test_outputs[ids].sensitivity_t2t[it2t_only_test_outputs[ids].sources.paddings[0][:85] == 0][1:] for ids in range(len(it2t_only_ex))]

In [None]:
t2t_sscore_it2t = [t2t_only_test_outputs[ids].sensitivity_it2t[t2t_only_test_outputs[ids].sources.paddings[0][:85] == 0][1:] for ids in range(len(t2t_only_ex))]
t2t_sscore_t2t = [t2t_only_test_outputs[ids].sensitivity_t2t[t2t_only_test_outputs[ids].sources.paddings[0][:85] == 0][1:] for ids in range(len(t2t_only_ex))]

In [None]:
# Per-piece neighbor-overlap scores (for the current top_k) of the same
# examples whose sensitivities were computed above.
correct_ic = []
it2t_only_ic = []
t2t_only_ic = []

for ex in it2t_only_ex:
  it2t_only_ic.append(obtain_diff(ex, it2t_vs_t2t_k[top_k]))

In [None]:
for ex in t2t_only_ex:
  t2t_only_ic.append(obtain_diff(ex, it2t_vs_t2t_k[top_k]))

In [None]:
for ex in correct_ex[:30]:
  correct_ic.append(obtain_diff(ex, it2t_vs_t2t_k[top_k]))

In [None]:
top_k

In [None]:
# Squared correlation between overlap scores and sensitivities per example.
# NOTE(review): the two vectors only have equal length when the sentence fits
# within max_seq_length positions — confirm for long examples.
r2_correct_it2t = [r2(correct_ic[i], correct_sscore_it2t[i]) for i in range(30)]
r2_it2t_it2t = [r2(it2t_only_ic[i], it2t_sscore_it2t[i]) for i in range(len(it2t_only_ex))]
r2_t2t_it2t = [r2(t2t_only_ic[i], t2t_sscore_it2t[i]) for i in range(len(t2t_only_ex))]

In [None]:
r2_correct_t2t = [r2(correct_ic[i], correct_sscore_t2t[i]) for i in range(30)]
r2_it2t_t2t = [r2(it2t_only_ic[i], it2t_sscore_t2t[i]) for i in range(len(it2t_only_ex))]
r2_t2t_t2t = [r2(t2t_only_ic[i], t2t_sscore_t2t[i]) for i in range(len(t2t_only_ex))]

In [None]:
np.mean(r2_correct_it2t), np.mean(r2_correct_t2t)

In [None]:
np.mean(r2_it2t_it2t), np.mean(r2_it2t_t2t)

In [None]:
np.mean(r2_t2t_it2t), np.mean(r2_t2t_t2t)

In [None]:
df = {"correct": r2_correct_it2t, "it2t correct\nt2t wrong": r2_t2t_it2t, "it2t wrong\nt2t correct": r2_it2t_it2t,}

In [None]:
df = dict2pandas(df, "examples", "Pearson r^2")

In [None]:
ax = sns.violinplot(data=df, x='examples', y='Pearson r^2')
for violin, alpha in zip(ax.collections[::2], [0.4,0.4,0.4]):
    violin.set_alpha(alpha)
Means = df.groupby('examples')['Pearson r^2'].median()
plt.scatter(x=range(len(Means)),y=Means,c="white", zorder=10)

In [None]:
Means

In [None]:
def dict2pandas(d, keyname, valname):
    dframes = []
    for k,v in d.items():
        dframes += [pd.DataFrame({keyname : [k] * len(v), valname : v})]
    return pd.concat(dframes)

In [None]:
# NOTE(review): sscore_it2t / sscore_t2t are not defined in the visible cells
# (only *_sscore_* lists are); this cell depends on earlier, removed state.
ids = 6

plt.plot(sscore_it2t[ids])
plt.plot(sscore_t2t[ids])

In [None]:
top_k = 20

# First IT2T-only failure: overlap shading, then sensitivity shading
# (sensitivities scaled by 5 for visibility).
for i, ex in enumerate(it2t_only_ex[:1]):
  print_diff(ex, it2t_vs_t2t_k[top_k])
  print_sensitivity(ex, 5* it2t_sscore_it2t[i])

In [None]:
from scipy import stats
def r2(x, y):
    """Squared Pearson correlation coefficient of x and y."""
    return stats.pearsonr(x, y)[0] ** 2

# NOTE(review): ex_ic_scores is not defined in the visible cells.
r2(ex_ic_scores, sscore_it2t[0])

In [None]:
r2(ex_ic_scores, sscore_t2t[0])