In [None]:
from pathlib import Path
from IPython.display import Image as IPImage
from PIL import Image
from matplotlib import pylab as plt
from matplotlib import gridspec
import matplotlib.cm as cm

import seaborn as sns
sns.set(font_scale=1.2)

from nmtpytorch.translator import Translator
from nmtpytorch.utils.data import to_var
from nmtpytorch.utils.data import make_dataloader

import numpy as np
import torch

import tqdm

from skimage import transform

In [None]:
def standardize(x, mode=None):
    """Min-max rescale ``x`` into the range [0, 1].

    Args:
        x: Object exposing ``min()`` and ``max()`` and supporting element-wise
            arithmetic (for example a NumPy array).
        mode: Unused; kept so that existing call sites keep working.

    Returns:
        ``(x - x.min()) / (x.max() - x.min())``. When all elements of ``x``
        are equal the value range is zero, so an all-zero result is returned
        instead of dividing by zero.
    """
    lowest = x.min()
    value_range = x.max() - lowest
    if value_range == 0:
        # Constant input: 0/0 would yield NaNs, so return zeros instead.
        return (x - lowest) * 0.0
    return (x - lowest) / value_range

def torch2ndarray(x):
    """Undo the per-channel normalisation of an image tensor for display.

    Args:
        x: Channel-first image tensor of shape ``(3, H, W)``. A leading
            singleton batch axis, ``(1, 3, H, W)``, is also accepted and
            dropped; previously such input made ``np.transpose`` raise.

    Returns:
        ``numpy.ndarray`` of shape ``(H, W, 3)`` holding
        ``pixel * std + mean`` per channel, clipped to ``[0, 1]``.
    """
    # Per-channel statistics (presumably the ImageNet values used when the
    # images were normalised -- confirm against the preprocessing pipeline).
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    pixels = x.cpu().numpy()
    if pixels.ndim == 4 and pixels.shape[0] == 1:
        # Drop the batch axis so the three-axis transpose below applies.
        pixels = pixels[0]
    # Channel-first -> channel-last; mean/std then broadcast over the last axis.
    img = np.transpose(pixels, (1, 2, 0))
    return np.clip(img * std + mean, a_min=0, a_max=1)

def get_sentences(data_root="/home/rebekka/t2b/Projekte/vision/multi30k-dataset/data/task1"):
    """Map every English test sentence to its French reference and image name.

    Reads three line-aligned files below ``data_root``:
    ``tok/test_2017_flickr.lc.norm.tok.en``,
    ``tok/test_2017_flickr.lc.norm.tok.fr`` and
    ``image_splits/test_2017_flickr.txt``.

    Args:
        data_root: Directory containing the ``tok`` and ``image_splits``
            sub-directories. Defaults to the location used originally.

    Returns:
        dict mapping each English sentence to a ``(french_sentence,
        image_name)`` tuple. If an English sentence occurs more than once,
        the last occurrence wins.

    Raises:
        ValueError: If the three files do not have the same number of lines.
    """
    def read_lines(path):
        # The files contain accented French text, so decode explicitly.
        with open(path, encoding="utf-8") as f:
            lines = f.read().split("\n")
        # Only drop the empty string caused by a trailing newline; blindly
        # slicing off the last element loses a sentence when the file does
        # not end with a newline.
        if lines and lines[-1] == "":
            lines.pop()
        return lines

    root = Path(data_root)
    english_sentences = read_lines(root / "tok" / "test_2017_flickr.lc.norm.tok.en")
    french_sentences = read_lines(root / "tok" / "test_2017_flickr.lc.norm.tok.fr")
    image_indices = read_lines(root / "image_splits" / "test_2017_flickr.txt")

    if not (len(english_sentences) == len(french_sentences) == len(image_indices)):
        raise ValueError(
            "sentence/image files are not line-aligned: "
            f"{len(english_sentences)} en, {len(french_sentences)} fr, "
            f"{len(image_indices)} images")

    indices = {}
    for english, french, image_name in zip(english_sentences, french_sentences, image_indices):
        indices[english] = (french, image_name)
    return indices

In [None]:
def plot(model, data, idx, model_name=None, english2goldstandard=None, smooth=True, transpose=False, sigma=20, maxnorm=True, only_mean_std=False, image=None):
    """Visualise (or just extract) the attention of one decoded example.

    The function works in one of three modes:

    * ``english2goldstandard`` is a dict: look up the reference translation
      and image file name of the source sentence, show that image, save it
      under ``sentence_pngs/`` and return the image path. Nothing else is
      plotted.
    * ``only_mean_std`` is true: return the attention arrays without plotting.
    * otherwise: draw a heat map of the textual attention next to the image
      overlaid with the visual attention of every generated token.

    Args:
        model: Object exposing ``src_vocab.idxs_to_sent`` and
            ``trg_vocab.idxs_to_sent``.
        data: Sequence of ``(src, hyp, img, ta, ia)`` tuples as built by
            ``process``.
        idx: Index of the example in ``data``; also the prefix of the saved
            file name.
        model_name: If a string, the figure is written to
            ``<model_name>/<idx>_<source sentence>.png`` and all figures are
            closed afterwards.
        english2goldstandard: Mapping as returned by ``get_sentences``.
        smooth: Upsample the attention maps with
            ``transform.pyramid_expand`` if true, ``transform.resize``
            otherwise.
        transpose: Swap the two spatial axes of the visual attention maps.
        sigma: ``sigma`` forwarded to ``transform.pyramid_expand``.
        maxnorm: Use the maximum over all overlays as a shared ``vmax``.
        only_mean_std: Return right after the attention arrays are built.
        image: Optional path of an image file shown instead of the image
            stored in ``data``; it is resized to a square whose side is the
            file's original height.

    Returns:
        In gold-standard mode the image path (``str``), or ``None`` if the
        sentence is not in the dict. Otherwise ``(tas, ias)``: ``tas`` has one
        row per source token and one column per generated token, ``ias`` has
        shape ``(steps, d, d)`` with ``d`` the side of the attention grid.
    """
    image_dir = "/home/rebekka/t2b/Projekte/vision/multi30k-dataset/data/task1/image_splits/mmt_images/task1/"
    # Unpack
    src, hyp, img, ta, ia = data[idx]

    if isinstance(english2goldstandard, dict):
        src_sentence = model.src_vocab.idxs_to_sent(src, debug=True)
        try:
            # Only the lookup can legitimately miss, so only it is guarded.
            fr, imagename = english2goldstandard[src_sentence.replace(" <eos>", "")]
        except KeyError:
            print("unk makes finding image impossible")
            return None
        print(f"{src_sentence}\n{fr}")
        image_path = image_dir + imagename
        Path("sentence_pngs").mkdir(parents=True, exist_ok=True)
        # Context manager closes the underlying file handle.
        with Image.open(image_path) as gold_image:
            plt.imshow(np.asarray(gold_image))
            gold_image.save(f"sentence_pngs/{src_sentence}.png")
        return image_path

    if image:
        with Image.open(image) as pil_image:
            # Square with side equal to the original height.
            side = pil_image.height
            img = np.asarray(pil_image.resize((side, side)))

    # Textual attention as matrix
    tas = np.stack(ta).T
    src = model.src_vocab.idxs_to_sent(src, debug=True).split(' ')
    hyp = model.trg_vocab.idxs_to_sent(hyp, debug=True).split(' ')
    # 6 for 6x6 or 14 for 14x14, etc.
    img_att_dim = int(np.sqrt(ia[0].shape[0]))
    ias = np.stack(ia).reshape((-1, img_att_dim, img_att_dim))
    if only_mean_std:
        return tas, ias

    if transpose:
        ias = np.transpose(ias, (0, 2, 1))
    if smooth:
        upscale = img.shape[0] / img_att_dim
        if upscale <= 1:
            # Keep the factor marginally above 1 when the image is not
            # larger than the attention grid.
            upscale = 1.000000001
        alpha = [transform.pyramid_expand(att_map, upscale=upscale, sigma=sigma)
                 for att_map in ias]
    else:
        alpha = [transform.resize(att_map, [img.shape[0], img.shape[1]], mode='reflect')
                 for att_map in ias]

    kwargs = {}
    if maxnorm:
        kwargs['vmax'] = max([a.max() for a in alpha])

    # Grid for the image panels: one cell for the plain image plus one per
    # generated token, five cells per row.
    n_words = len(hyp) + 1
    n_cols = 5
    n_rows = n_words // n_cols + int(n_words % n_cols > 0)

    # Create a grid
    grid = gridspec.GridSpec(1, 2, width_ratios=[1, 2])

    # Get subplot
    plt.figure(figsize=(16, 8))
    ax1 = plt.subplot(grid[0])

    # Plot text attention
    sns.heatmap(tas, xticklabels=hyp, yticklabels=src, cmap='Greys',
                square=True, linewidths=0, cbar=False, ax=ax1)
    plt.yticks(rotation=0)

    inner_grid = gridspec.GridSpecFromSubplotSpec(n_rows, n_cols, grid[1], wspace=0.05, hspace=0.1)
    ax = plt.subplot(inner_grid[0])
    print(" ".join(src[:-1]), " ".join(hyp[:-1]))
    # Show image
    ax.imshow(img)
    ax.set_axis_off()
    # 1->black, 0->white
    plt.set_cmap(cm.jet)
    for ii in range(ias.shape[0]):
        ax = plt.subplot(inner_grid[ii + 1])
        ax.text(0, -10, hyp[ii], color='black', backgroundcolor='#ffffff', fontsize=12)
        # Show image
        ax.imshow(img)
        ax.imshow(alpha[ii], alpha=0.6, **kwargs)
        ax.set_axis_off()

    if isinstance(model_name, str):
        filename = "_".join(src)
        if len(filename) > 200:
            # Too long for a file name: fall back to "<token count>_unk".
            filename = "{}_unk".format(len(src))
        Path(model_name).mkdir(parents=True, exist_ok=True)
        plt.tight_layout()
        plt.savefig(f"{model_name}/{idx}_{filename}.png")
        # Saved to disk; nothing is left open for inline display.
        plt.close("all")
    else:
        # Additional figure: visual attention averaged over all steps.
        plt.figure()
        plt.imshow(ias.mean(0))
        plt.axis('off')
        plt.colorbar()
    return tas, ias

In [None]:
def process(modeltype, split='test_2017_flickr'):
    """Greedily decode ``split`` with a trained model and record its attention.

    The checkpoint with the highest METEOR score encoded in its file name
    (``*.best.meteor_<score>.ckpt``) is loaded from the experiment directory.
    Every batch is then decoded step by step, always taking the single most
    probable token, while the decoder's textual and visual attention weights
    are stored for each generated token.

    Args:
        modeltype: Experiment sub-directory below the hard-coded experiments
            root.
        split: Name of the data split handed to ``Translator``.

    Returns:
        ``(data, model)``. ``data`` is a list with one
        ``(src_ids, hyp_ids, image, txt_att, img_att)`` tuple per sentence,
        sorted by increasing source length; ``model`` is the loaded model on
        the GPU.

    Raises:
        FileNotFoundError: If no matching checkpoint exists.
    """
    eos_idx = 2      # token id that terminates a hypothesis in this script
    max_steps = 100  # upper bound on decoding steps per batch

    root = Path('/home/rebekka/t2b/Projekte/vision/mmt_experiments') / modeltype
    checkpoints = list(root.glob('*.best.meteor_*ckpt'))
    if not checkpoints:
        raise FileNotFoundError(f"no '*.best.meteor_*ckpt' checkpoint found in {root}")
    model_file = sorted(
        checkpoints,
        key=lambda x: float(x.name.split('best.meteor_')[-1].replace('.ckpt', '')))[-1]
    parameter_dict = {"disable_filters": False,
                      "splits": split,
                      "batch_size": 64,
                      "beam_size": 12,
                      "max_len": 200,
                      "lp_alpha": 0,
                      "device_id": "gpu",
                      "models": [model_file],
                      "override": [],
                      "stochastic": False,
                      "beam_func": "beam_search",
                      "source": None,
                      "task_id": None}
    translator = Translator(**parameter_dict)
    model = translator.instances[0]
    model = model.cuda()
    dataset = model.load_data(translator.splits[0], parameter_dict["batch_size"], mode="test")
    # Load and decode
    loader = make_dataloader(dataset)
    data = []
    for batch in tqdm.tqdm(loader, unit='batch'):
        batch.device("cuda")
        img_att = [[] for i in range(batch.size)]
        txt_att = [[] for i in range(batch.size)]
        hyps = [[] for i in range(batch.size)]
        srcs = batch[model.sl].data.cpu().t().tolist()
        imgs = batch['image'].data.cpu()
        # One flag per sentence: set once that sentence has produced EOS.
        fini = torch.zeros(batch.size).long().cuda()
        ctx_dict = model.encode(batch)
        # Get initial hidden state
        h_t = model.dec.f_init(ctx_dict)
        y_t = model.get_bos(batch.size).cuda()
        for t in range(max_steps):
            logp, h_t = model.dec.f_next(ctx_dict, model.dec.emb(y_t), h_t)
            tatt = model.dec.txt_alpha_t.cpu().data.clone().numpy()
            iatt = model.dec.img_alpha_t.cpu().data.clone().numpy()
            # Greedy choice: keep only the most probable token.
            _, y_t = logp.data.topk(1, largest=True)
            step_tokens = y_t.cpu().numpy().tolist()
            y_t = y_t.squeeze(1)
            for idx, w in enumerate(step_tokens):
                # Stop recording a sentence once it already contains EOS.
                if eos_idx not in hyps[idx]:
                    hyps[idx].append(w[0])
                    txt_att[idx].append(tatt[:, idx])
                    img_att[idx].append(iatt[:, idx])

            fini = fini | y_t.eq(eos_idx).squeeze().long()
            if fini.sum() == batch.size:
                break

        for s, h, ta, ia, img in zip(srcs, hyps, txt_att, img_att, imgs):
            # NOTE(review): unsqueeze(0) adds a leading axis before
            # torch2ndarray transposes three axes -- confirm the rank of the
            # per-sample entries of batch['image'].
            data.append((s, h, torch2ndarray(img.unsqueeze(0)), ta, ia))

    # Put into correct order
    data = [data[i] for i, j in sorted(
        enumerate(loader.batch_sampler.orig_idxs), key=lambda k: k[1])]
    # Sort by increasing length
    data = sorted(data, key=lambda x: len(x[0]))
    return data, model

In [None]:

# Decode the test split once per trained variant. Each call returns the
# decoded examples (res*) together with the loaded model (m*).
def _ablation_dir(tag):
    """Return 'mmt_<tag>/mmt-task-en-fr-multimodalatt_<tag>/best/'."""
    return f'mmt_{tag}/mmt-task-en-fr-multimodalatt_{tag}/best/'


res4f, m1 = process('mmt_complete_dataset_paper/mmt-task-en-fr-multimodalatt/best/')
res4f_random, m2 = process(_ablation_dir('random_replacement'))
res4f_0_remaining, m3 = process(_ablation_dir('0_remaining'))
res4f_1_remaining, m4 = process(_ablation_dir('1_remaining'))
res4f_2_remaining, m5 = process(_ablation_dir('2_remaining'))
res4f_3_remaining, m6 = process(_ablation_dir('3_remaining'))
res4f_4_remaining, m7 = process(_ablation_dir('4_remaining'))
res4f_6_remaining, m8 = process(_ablation_dir('6_remaining'))
res4f_12_remaining, m9 = process(_ablation_dir('12_remaining'))
res4f_20_remaining, m10 = process(_ablation_dir('20_remaining'))
# Runs below do not follow the regular directory pattern, so their paths are
# spelled out. Disabled runs are kept for reference.
#res4f_random_pos, m11 = process('mmt_random_pos_replacement/mmt-task-en-fr-multimodalatt_random_pos_replacement/best/')
#res4f_random_pos_dt, m12 = process('mmt_random_pos_replacement_of_dt/mmt-task-en-fr-multimodalatt_random_pos_replacement_dt/best/')
res4f_random_pos_nn, m13 = process('mmt_random_pos_replacement_of_nn/best/')
res4f_random_pos_vb, m14 = process('mmt_random_pos_replacement_of_vb/best/')
res4f_random_pos_jj, m15 = process('mmt_random_pos_replacement_of_jj/best/')
#res4f_random_pos_prp, m16 = process('mmt_random_pos_replacement_of_prp/mmt-task-en-fr-multimodalatt_random_pos_replacement_prp/best/')
#res4f_random_pos_in, m17 = process('mmt_random_pos_replacement_of_in/mmt-task-en-fr-multimodalatt_random_pos_replacement_in/best/')
resf_random_pos_nn4, m18 = process(
    'mmt_random_pos_replacement_of_nn4/mmt-task-en-fr-multimodalatt_random_pos_replacement_nn_only4/best/')
#resf_random_pos_nn3, m19 = process('mmt_random_pos_replacement_of_nn3/mmt-task-en-fr-multimodalatt_random_pos_replacement_nn_only3/best/')
res4f_random_4, m20 = process('mmt_random_replacement4/')

In [None]:
english2gold = get_sentences()
print(len(res4f))


def _attention_stats(model, data):
    """Collect per-example summary statistics of both attention types.

    Calls ``plot(..., only_mean_std=True)`` for every example of ``data`` and
    reduces the returned textual (``ta``) and visual (``ia``) attention arrays.

    Returns:
        Seven lists, each with one entry per example, in this order:
        ``ta_means, ta_stds, ta_maxs, ia_means, ia_stds, ia_maxs, ia_mins``.
    """
    ta_means, ta_stds, ta_maxs = [], [], []
    ia_means, ia_stds, ia_maxs, ia_mins = [], [], [], []
    for example_idx in range(len(data)):
        ta, ia = plot(model, data, example_idx, only_mean_std=True)
        ta_means.append(ta.mean())
        ta_stds.append(ta.std())
        ta_maxs.append(ta.max())
        ia_means.append(ia.mean())
        ia_stds.append(ia.std())
        ia_maxs.append(ia.max())
        ia_mins.append(ia.min())
    return ta_means, ta_stds, ta_maxs, ia_means, ia_stds, ia_maxs, ia_mins


# One call per model; the suffix identifies the model (c = complete,
# r = random, r4 / nn4 = the "4" variants, a number = k remaining words).
ta_means_c, ta_stds_c, ta_maxs_c, ia_means_c, ia_stds_c, ia_maxs_c, ia_mins_c = _attention_stats(m1, res4f)
ta_means_r4, ta_stds_r4, ta_maxs_r4, ia_means_r4, ia_stds_r4, ia_maxs_r4, ia_mins_r4 = _attention_stats(m20, res4f_random_4)
ta_means_nn4, ta_stds_nn4, ta_maxs_nn4, ia_means_nn4, ia_stds_nn4, ia_maxs_nn4, ia_mins_nn4 = _attention_stats(m18, resf_random_pos_nn4)
ta_means_r, ta_stds_r, ta_maxs_r, ia_means_r, ia_stds_r, ia_maxs_r, ia_mins_r = _attention_stats(m2, res4f_random)
ta_means_0, ta_stds_0, ta_maxs_0, ia_means_0, ia_stds_0, ia_maxs_0, ia_mins_0 = _attention_stats(m3, res4f_0_remaining)
ta_means_1, ta_stds_1, ta_maxs_1, ia_means_1, ia_stds_1, ia_maxs_1, ia_mins_1 = _attention_stats(m4, res4f_1_remaining)
ta_means_2, ta_stds_2, ta_maxs_2, ia_means_2, ia_stds_2, ia_maxs_2, ia_mins_2 = _attention_stats(m5, res4f_2_remaining)
ta_means_3, ta_stds_3, ta_maxs_3, ia_means_3, ia_stds_3, ia_maxs_3, ia_mins_3 = _attention_stats(m6, res4f_3_remaining)
ta_means_4, ta_stds_4, ta_maxs_4, ia_means_4, ia_stds_4, ia_maxs_4, ia_mins_4 = _attention_stats(m7, res4f_4_remaining)
ta_means_6, ta_stds_6, ta_maxs_6, ia_means_6, ia_stds_6, ia_maxs_6, ia_mins_6 = _attention_stats(m8, res4f_6_remaining)
ta_means_12, ta_stds_12, ta_maxs_12, ia_means_12, ia_stds_12, ia_maxs_12, ia_mins_12 = _attention_stats(m9, res4f_12_remaining)
ta_means_20, ta_stds_20, ta_maxs_20, ia_means_20, ia_stds_20, ia_maxs_20, ia_mins_20 = _attention_stats(m10, res4f_20_remaining)

# Figure: mean (top row) and std (bottom row) of the textual attention per
# example. Left column: complete, random, k=0..3; right column: k=4..20.
fig, axes = plt.subplots(2, 2, figsize=(14.8, 9.6), sharex=True)
ax, ax3 = axes[0]
ax2, ax4 = axes[1]
for panel in (ax, ax2, ax3, ax4):
    panel.set_xticks(list(range(0, len(res4f), 100)))

# (legend label, colour, per-example means, per-example stds, left column?)
text_series = [
    ("complete", "gold", ta_means_c, ta_stds_c, True),
    ("random", "navy", ta_means_r, ta_stds_r, True),
    ("k=0", "aqua", ta_means_0, ta_stds_0, True),
    ("k=1", "coral", ta_means_1, ta_stds_1, True),
    ("k=2", "red", ta_means_2, ta_stds_2, True),
    ("k=3", "indigo", ta_means_3, ta_stds_3, True),
    ("k=4", "deepskyblue", ta_means_4, ta_stds_4, False),
    ("k=6", "tab:purple", ta_means_6, ta_stds_6, False),
    ("k=12", "springgreen", ta_means_12, ta_stds_12, False),
    ("k=20", "fuchsia", ta_means_20, ta_stds_20, False),
]
labels = []
handles = []
arrays_t = []
arrays_ts = []
for label, colour, means, stds, on_left in text_series:
    mean_ax, std_ax = (ax, ax2) if on_left else (ax3, ax4)
    line, = mean_ax.plot(means, color=colour)
    std_ax.plot(stds, color=colour)
    if label == "complete":
        # The complete model is also drawn in the right column as reference.
        ax3.plot(means, color=colour)
        ax4.plot(stds, color=colour)
    handles.append(line)
    labels.append(label)
    # Mean and spread of the per-example means, one entry per legend label.
    arrays_t.append(np.array(means).mean())
    arrays_ts.append(np.array(means).std())
l1, l2, l3, l4, l5, l6, l7, l8, l9, l10 = handles

# Axis descriptions are shared by all panels, so they are placed on the
# figure; per-axes labels would only land on the last (bottom-right) panel.
fig.legend(handles=handles, labels=labels, loc='center')
fig.text(0.5, 0.04, 'example index', ha='center', va='center')
fig.text(0.06, 0.5, 'std textual attention                mean textual attention', ha='center', va='center', rotation='vertical')
plt.show()
# Good samples: 464, 615, 7, 103
idx = 7
print(idx)

# Figure: mean (top row) and std (bottom row) of the image attention per
# example. Left column: complete, random, k=0..3; right column: k=4..20.
fig, axes = plt.subplots(2, 2, figsize=(14.8, 9.6), sharex=True, sharey=True)
ax, ax3 = axes[0]
ax2, ax4 = axes[1]
# (colour, per-example means, per-example stds, mean panel, std panel)
image_series = [
    ("gold", ia_means_c, ia_stds_c, ax, ax2),
    ("navy", ia_means_r, ia_stds_r, ax, ax2),
    ("aqua", ia_means_0, ia_stds_0, ax, ax2),
    ("coral", ia_means_1, ia_stds_1, ax, ax2),
    ("red", ia_means_2, ia_stds_2, ax, ax2),
    ("indigo", ia_means_3, ia_stds_3, ax, ax2),
    ("deepskyblue", ia_means_4, ia_stds_4, ax3, ax4),
    ("tab:purple", ia_means_6, ia_stds_6, ax3, ax4),
    ("springgreen", ia_means_12, ia_stds_12, ax3, ax4),
    ("fuchsia", ia_means_20, ia_stds_20, ax3, ax4),
]
handles = []
for colour, means, stds, mean_ax, std_ax in image_series:
    line, = mean_ax.plot(means, color=colour)
    std_ax.plot(stds, color=colour)
    handles.append(line)
l1, l2, l3, l4, l5, l6, l7, l8, l9, l10 = handles
fig.legend(handles=handles, labels=labels, loc="center")
fig.text(0.5, 0.04, 'example index', ha='center', va='center')
fig.text(0.06, 0.5, 'std of image attention                mean image attention', ha='center', va='center', rotation='vertical')
# Figure: maximum (top row) and minimum (bottom row) of the image attention
# per example.
fig, axes = plt.subplots(2, 2, figsize=(14.8, 9.6), sharex=True)
ax, ax3 = axes[0]
ax2, ax4 = axes[1]
# (colour, per-example maxima, per-example minima, max panel, min panel)
# NOTE(review): k=4 (deepskyblue) sits in the left column here, whereas the
# mean/std figures put it in the right column -- confirm this is intended.
extreme_series = [
    ("gold", ia_maxs_c, ia_mins_c, ax, ax2),
    ("navy", ia_maxs_r, ia_mins_r, ax, ax2),
    ("aqua", ia_maxs_0, ia_mins_0, ax, ax2),
    ("coral", ia_maxs_1, ia_mins_1, ax, ax2),
    ("red", ia_maxs_2, ia_mins_2, ax, ax2),
    ("indigo", ia_maxs_3, ia_mins_3, ax, ax2),
    ("deepskyblue", ia_maxs_4, ia_mins_4, ax, ax2),
    ("tab:purple", ia_maxs_6, ia_mins_6, ax3, ax4),
    ("springgreen", ia_maxs_12, ia_mins_12, ax3, ax4),
    ("fuchsia", ia_maxs_20, ia_mins_20, ax3, ax4),
]
handles = []
for colour, maxima, minima, max_ax, min_ax in extreme_series:
    line, = max_ax.plot(maxima, color=colour)
    min_ax.plot(minima, color=colour)
    handles.append(line)
l1, l2, l3, l4, l5, l6, l7, l8, l9, l10 = handles
fig.legend(handles=[l1, l2, l3, l4, l5, l6, l7, l8, l9, l10], labels=labels, loc="center")
fig.text(0.5, 0.04, 'example index', ha='center', va='center')
fig.text(0.06, 0.5, 'minimum image attention            maximum image attention', ha='center', va='center', rotation='vertical')
# Figure: maximum textual attention per example, spread over four panels.
fig, axes = plt.subplots(2, 2, figsize=(14.8, 9.6), sharex=True)
ax, ax3 = axes[0]
ax2, ax4 = axes[1]
# (colour, per-example maxima, panel)
text_max_series = [
    ("gold", ta_maxs_c, ax),
    ("navy", ta_maxs_r, ax),
    ("aqua", ta_maxs_0, ax),
    ("coral", ta_maxs_1, ax3),
    ("red", ta_maxs_2, ax3),
    ("indigo", ta_maxs_3, ax3),
    ("deepskyblue", ta_maxs_4, ax2),
    ("tab:purple", ta_maxs_6, ax2),
    ("springgreen", ta_maxs_12, ax4),
    ("fuchsia", ta_maxs_20, ax4),
]
handles = []
for colour, maxima, panel in text_max_series:
    line, = panel.plot(maxima, color=colour)
    handles.append(line)
l1, l2, l3, l4, l5, l6, l7, l8, l9, l10 = handles
fig.legend(handles=[l1, l2, l3, l4, l5, l6, l7, l8, l9, l10], labels=labels, loc="center")
fig.text(0.5, 0.04, 'example index', ha='center', va='center')
fig.text(0.06, 0.5, 'maximum textual attention', ha='center', va='center', rotation='vertical')

# Per-example maximum attention of every model, ordered as
# k = 0, 1, 2, 3, 4, 6, 12, 20, complete (matches the x tick labels below).
text_max_lists = [ta_maxs_0, ta_maxs_1, ta_maxs_2, ta_maxs_3, ta_maxs_4,
                  ta_maxs_6, ta_maxs_12, ta_maxs_20, ta_maxs_c]
image_max_lists = [ia_maxs_0, ia_maxs_1, ia_maxs_2, ia_maxs_3, ia_maxs_4,
                   ia_maxs_6, ia_maxs_12, ia_maxs_20, ia_maxs_c]

# Mean / std over examples: textual (arrays, arrays_s), image (arrays_i, arrays_si).
arrays = [np.asarray(values).mean() for values in text_max_lists]
arrays_s = [np.asarray(values).std() for values in text_max_lists]
arrays_i = [np.asarray(values).mean() for values in image_max_lists]
arrays_si = [np.asarray(values).std() for values in image_max_lists]


def _top_examples(values, count=15):
    """Return the ``count`` largest entries of ``values`` as (index, value) pairs."""
    return sorted(enumerate(values), reverse=True, key=lambda pair: pair[1])[0:count]


# Position in image_max_lists -> name of the corresponding model.
k2abl = {1:1, 2:2, 3:3, 4:4, 0:0, 5:6, 6:12, 7:20, 8:"complete", "random":"random"}
highest_image_attention_indices = {
    position: _top_examples(values) for position, values in enumerate(image_max_lists)
}
for key, top in highest_image_attention_indices.items():
    print(key, top)

fig, ax = plt.subplots()
positions = list(range(len(arrays)))
ax.set_xticks(positions)
ax.set_xticklabels([0, 1, 2, 3, 4, 6, 12, 20, "all"])
ax.set_xlabel("k-remaining words")
ax.set_ylabel("mean maximum attention")
# Each curve gets its own spread: the textual curve previously reused the
# image-attention std for its error bars.
ax.errorbar(positions, arrays, arrays_s, label="textual attention")
ax.errorbar(positions, arrays_i, arrays_si, label="image attention")
ax.legend()
# Top examples of the random-replacement model, computed from its own maxima
# (previously the list left over from the last loop iteration was sorted).
highest_image_attention_indices["random"] = _top_examples(ia_maxs_r)
plt.tight_layout()
plt.savefig("tex_im_att.png")
print(np.mean(ia_stds_c), np.mean(ia_stds_r4), np.mean(ia_stds_nn4))
print(np.mean(ta_stds_c), np.mean(ta_stds_r4), np.mean(ta_stds_nn4))

In [None]:
"""
for k, top_a in highest_image_attention_indices.items():
    print(k)
    
    if isinstance(k, int):
        if k < 8:
            continue
    
    k = k2abl[k]
    for sample in top_a:
        imagename = plot(m1, res4f, sample[0], f"complete_model_{k}", english2gold, smooth=True, maxnorm=True)
        tas, ias = plot(m1, res4f, sample[0], f"complete_model_{k}", image=imagename, smooth=True, maxnorm=True)
        tas, ias = plot(m2, res4f_random, sample[0], f"random_model_{k}", image=imagename, smooth=True, maxnorm=True)
        tas, ias = plot(m3, res4f_0_remaining, sample[0], f"0_rem_model_{k}", image=imagename, smooth=True, maxnorm=True)
        tas, ias = plot(m4, res4f_1_remaining, sample[0], f"1_rem_model_{k}", image=imagename, smooth=True, maxnorm=True)
        tas, ias = plot(m5, res4f_2_remaining, sample[0], f"2_rem_model_{k}", image=imagename, smooth=True, maxnorm=True)
        tas, ias = plot(m6, res4f_3_remaining, sample[0], f"3_rem_model_{k}", image=imagename, smooth=True, maxnorm=True)
        tas, ias = plot(m7, res4f_4_remaining, sample[0], f"4_rem_model_{k}", image=imagename, smooth=True, maxnorm=True)
        tas, ias = plot(m8, res4f_6_remaining, sample[0], f"6_rem_model_{k}", image=imagename, smooth=True, maxnorm=True)
        tas, ias = plot(m9, res4f_12_remaining, sample[0], f"12_rem_model_{k}", image=imagename, smooth=True, maxnorm=True)
        tas, ias = plot(m10, res4f_20_remaining, sample[0], f"20_rem_model_{k}", image=imagename, smooth=True, maxnorm=True)
        tas, ias = plot(m20, res4f_random_4, sample[0], f"random4_model_{k}", image=imagename, smooth=True, maxnorm=True)

for sample in highest_image_attention_indices["random"]:
    imagename = plot(m1, res4f, sample[0], f"complete_model_complete", english2gold, smooth=True, maxnorm=True)
    #tas, ias = plot(m11, res4f_random_pos, sample[0], f"random_pos_model", image=imagename, smooth=True, maxnorm=True)
    #tas, ias = plot(m12, res4f_random_pos_dt, sample[0], f"random_pos_model_dt", image=imagename, smooth=True, maxnorm=True)
    #tas, ias = plot(m14, res4f_random_pos_vb, sample[0], f"random_pos_model_vb", image=imagename, smooth=True, maxnorm=True)
    tas, ias = plot(m13, res4f_random_pos_nn, sample[0], f"random_pos_model_nn", image=imagename, smooth=True, maxnorm=True)
    #tas, ias = plot(m15, res4f_random_pos_jj, sample[0], f"random_pos_model_jj", image=imagename, smooth=True, maxnorm=True)
    #tas, ias = plot(m16, res4f_random_pos_prp, sample[0], f"random_pos_model_prp", image=imagename, smooth=True, maxnorm=True)
    #tas, ias = plot(m17, res4f_random_pos_in, sample[0], f"random_pos_model_in", image=imagename, smooth=True, maxnorm=True)

for index in range(len(res4f_random_pos_vb)):
    print(index)
    plot(m14, res4f_random_pos_vb, index)
    plt.close("all")
"""
def _plot_matched_examples(matched_indices, model, data, gold_name, model_name):
    """Plot the counterparts of the top "random" examples for another model.

    Args:
        matched_indices: List index-aligned with
            ``highest_image_attention_indices["random"]``. Each entry is the
            index of the corresponding example in ``data``, or ``False`` when
            there is none.
        model, data: Model and decoded examples to plot.
        gold_name: ``model_name`` passed to the gold-standard lookup call.
        model_name: Directory the attention figures are saved to.
    """
    top_samples = highest_image_attention_indices["random"]
    for sample, matched in zip(top_samples, matched_indices):
        if matched is False:
            # Explicit sentinel test, so a legitimate index 0 is not skipped.
            continue
        print(sample)
        # First call only resolves (and saves) the gold image of the sample.
        imagename = plot(m1, res4f, sample[0], gold_name, english2gold, smooth=True, maxnorm=True)
        plot(model, data, matched, model_name, image=imagename, smooth=True, maxnorm=True)


vb_corresponding_indices = [527, 93, 508, False, 166, 10, 24, False, 171, 489, False, 395, 19, 633, 668]
_plot_matched_examples(vb_corresponding_indices, m14, res4f_random_pos_vb,
                       "complete_model_complete", "random_pos_model_vb")
"""
for index in range(len(res4f_random_pos_jj)):
    print(index)
    plot(m15, res4f_random_pos_jj, index)
    plt.close("all")
"""
jj_corresponding_indices = [False, False, False, False, False, 107, False, False, 111, 356, False, 283, False, False, 511]
_plot_matched_examples(jj_corresponding_indices, m15, res4f_random_pos_jj,
                       "complete_model_complete", "random_pos_model_jj")

# Render every example of the nn4 model once; figures are closed right away.
for sample in range(len(resf_random_pos_nn4)):
    plot(m18, resf_random_pos_nn4, sample, smooth=True, maxnorm=True)
    plt.close("all")

nn_corresponding_indices = [248, False, 225, False, False, False, False, False, False, 210, False, False, False, 348, 387]
_plot_matched_examples(nn_corresponding_indices, m18, resf_random_pos_nn4,
                       "complete_model_random", "random_pos_model_nn4")

In [None]:
# Example shown in this and the following cells.
idx = 7
# Attention of the random-replacement model.
tas, ias = plot(model=m2, data=res4f_random, idx=idx, smooth=True, maxnorm=True)

In [None]:
# Same example, k=0 model.
tas, ias = plot(model=m3, data=res4f_0_remaining, idx=idx, smooth=True, maxnorm=True)

In [None]:
# Same example, k=1 model.
tas, ias = plot(model=m4, data=res4f_1_remaining, idx=idx, smooth=True, maxnorm=True)

In [None]:
# Same example, k=2 model.
tas, ias = plot(model=m5, data=res4f_2_remaining, idx=idx, smooth=True, maxnorm=True)

In [None]:
# Same example, k=3 model.
tas, ias = plot(model=m6, data=res4f_3_remaining, idx=idx, smooth=True, maxnorm=True)

In [None]:
# Same example, k=4 model.
tas, ias = plot(model=m7, data=res4f_4_remaining, idx=idx, smooth=True, maxnorm=True)

In [None]:
# Same example, k=6 model.
tas, ias = plot(model=m8, data=res4f_6_remaining, idx=idx, smooth=True, maxnorm=True)

In [None]:
# Same example, k=12 model.
tas, ias = plot(model=m9, data=res4f_12_remaining, idx=idx, smooth=True, maxnorm=True)

In [None]:
# Same example, k=20 model.
tas, ias = plot(model=m10, data=res4f_20_remaining, idx=idx, smooth=True, maxnorm=True)

In [None]:
# Resolve the original image of the example, then overlay the attention of
# the complete, k=12 and k=0 models on it.
imagename = plot(m1, res4f, idx, english2goldstandard=english2gold, smooth=True, maxnorm=True)
for shown_model, shown_data in ((m1, res4f), (m9, res4f_12_remaining), (m3, res4f_0_remaining)):
    tas, ias = plot(shown_model, shown_data, idx, image=imagename, smooth=True, maxnorm=True)

In [None]:
# Example 678 of the complete/random models and its counterpart (225) in the
# nn4 model, all drawn on the original image.
idx = 678
nn_idx = 225
imagename = plot(m1, res4f, idx, english2goldstandard=english2gold, smooth=True, maxnorm=True)
tas, ias = plot(m2, res4f_random, idx, image=imagename, smooth=True, maxnorm=True)
# The models below are not loaded (their process() calls are commented out),
# so these calls stay disabled to keep the cell runnable.
#tas, ias = plot(m11, res4f_random_pos, idx, image=imagename, smooth=True, maxnorm=True)
#tas, ias = plot(m13, res4f_random_pos_nn, idx, image=imagename, smooth=True, maxnorm=True)
#tas, ias = plot(m14, res4f_random_pos_vb, idx, image=imagename, smooth=True, maxnorm=True)
#tas, ias = plot(m15, res4f_random_pos_jj, idx, image=imagename, smooth=True, maxnorm=True)
#tas, ias = plot(m16, res4f_random_pos_prp, idx, image=imagename, smooth=True, maxnorm=True)
#tas, ias = plot(m17, res4f_random_pos_in, idx, image=imagename, smooth=True, maxnorm=True)
tas, ias = plot(m18, resf_random_pos_nn4, nn_idx, image=imagename, smooth=True, maxnorm=True)

# Trailing semicolon: show the figures, not the returned attention arrays.
plot(m1, res4f, idx, smooth=True, maxnorm=True);

In [None]:
# Example 138: attention of every k-remaining model on the original image.
idx = 138
imagename = plot(m1, res4f, idx, english2goldstandard=english2gold, smooth=True, maxnorm=True)
k_remaining_runs = [
    (m3, res4f_0_remaining),
    (m4, res4f_1_remaining),
    (m5, res4f_2_remaining),
    (m6, res4f_3_remaining),
    (m7, res4f_4_remaining),
    (m8, res4f_6_remaining),
    (m9, res4f_12_remaining),
    (m10, res4f_20_remaining),
]
for shown_model, shown_data in k_remaining_runs:
    tas, ias = plot(shown_model, shown_data, idx, image=imagename, smooth=True, maxnorm=True)
# Trailing semicolon: show the figures, not the returned attention arrays.
plot(m1, res4f, idx, smooth=True, maxnorm=True);