# Sample captions & visualize attention

In [None]:
# Reset sys.argv to a single empty entry before any argument parsing happens;
# presumably this keeps parser.parse_args() below from seeing the notebook
# kernel's own command-line arguments -- TODO confirm.
import sys; sys.argv=['']; del sys
import numpy as np
import os
from args import get_parser
from utils.dataloader import DataLoader
from utils.config import get_opt
from utils.lang_proc import idx2word, sample, beamsearch
from model import get_model
import pickle

In [None]:
# Build the project's argument parser and parse whatever is in sys.argv
# (reset to [''] in the first cell), then override two settings for sampling.
parser = get_parser()
args_dict = parser.parse_args()
args_dict.mode = 'test'
args_dict.bs = 1  # batch size; reused below as N and as the generator's batch_size

Choose model weights to load:

In [None]:
# Path of the saved weights file; it is handed to model.load_weights() below.
# NOTE(review): machine-specific absolute path -- point it at a local file.
args_dict.model_file = '/work/asalvador/sat_keras/models/model_weights.07-3.16.h5'

Initialize model and load weights:

In [None]:
# Build the network and the optimizer from the parsed arguments.
model = get_model(args_dict)
opt = get_opt(args_dict)

# Restore the weights from the file configured in args_dict.model_file.
weights = args_dict.model_file
model.load_weights(weights)

# Compile with the project optimizer and a categorical cross-entropy loss.
model.compile(optimizer=opt,loss='categorical_crossentropy')

In [None]:
# Print the layer listing; the layer name used to build the extractor below
# ('timedistributed_7') has to match one of the layers of this model.
model.summary()

Create the extractor that outputs class probabilities and attention weights:

In [None]:
from keras.models import Model
# Name of the layer whose output is read as the attention weights.
# NOTE(review): looks like an auto-generated layer name -- verify it against
# the model.summary() output for the loaded model.
att_layer = 'timedistributed_7'
# Second model on the same input that returns two outputs, in this order:
# [model.output, output of the attention layer].
# NOTE(review): `input=`/`output=` are the Keras 1 keyword names; Keras 2
# spells them `inputs=`/`outputs=` -- confirm against the installed version.
extractor = Model(input=model.input,output=[model.output,model.get_layer(att_layer).output])

Load the word vocabulary to map indices to words:

In [None]:
# Load the pickled vocabulary and build its inverse, which idx2word uses
# further down to turn predicted indices back into words.
vocab_file = os.path.join(args_dict.data_folder,'data',args_dict.vfile)
# Open through a context manager so the file handle is closed as soon as the
# pickle has been read (the bare open() call left it to the garbage collector).
# NOTE: unpickling can execute arbitrary code; only load vocabulary files you trust.
with open(vocab_file,'rb') as vocab_fh:
    vocab = pickle.load(vocab_fh)
# Swap keys and values of the vocabulary.
inv_vocab = {v:k for k,v in vocab.items()}

Initialize the generator to yield samples:

In [None]:
# Data loader built from the parsed arguments.
dataloader = DataLoader(args_dict)
# Batch size (args_dict.bs was set to 1 above); the sampling loop uses N to
# shape its per-batch buffers.
N = args_dict.bs
# Generator over the 'val' split; the sampling loop unpacks each item as
# (image batch, caption batch, image metadata).
val_gen = dataloader.generator('val',batch_size=args_dict.bs,train_flag=False) # N samples

In [None]:
from utils.im_proc import process_image, center_crop
import matplotlib.pyplot as plt  
%matplotlib inline
import os
from scipy.ndimage.interpolation import zoom
from scipy.ndimage.filters import gaussian_filter
import copy

IMPATH = os.path.join(args_dict.coco_path,'images','val' + args_dict.year)

count = 0
sig = 5
th = 0.3
for im,cap,imids in val_gen:
    
    atts = np.zeros((args_dict.seqlen, args_dict.imsize,args_dict.imsize))
    
    prevs = np.ones((N,1))
    word_idxs = np.zeros((N,args_dict.seqlen))
    imname = imids[0]['file_name']
    
    img = process_image(os.path.join(IMPATH,imname),args_dict.resize)
    img = center_crop(img,args_dict.imsize)

    for i in range(args_dict.seqlen):
        # get predictions
        preds,att = extractor.predict([im,prevs]) #(N,1,vocab_size)
        
        s_att = np.shape(att)[-1]
        s = int(np.sqrt(s_att))
        att = np.reshape(att,(s,s))
        att = zoom(att,float(img.shape[0])/att.shape[-1],order=1)
        att = gaussian_filter(att,sigma=sig)
        att[att>=th] = 1
        att[att<th] = 0.5
        atts[i] = att

        preds = preds.squeeze()

        word_idxs[:,i] = np.argmax(preds,axis=-1)
        prevs = np.argmax(preds,axis=-1)
        prevs = np.reshape(prevs,(N,1))

    pred_caps = idx2word(word_idxs,inv_vocab)
    true_caps = idx2word(np.argmax(cap,axis=-1),inv_vocab)
    
    n_words = len(pred_caps[0])
    
    f, axarr = plt.subplots(1,n_words,figsize=(30,30))
    
    for i in range(n_words):
        im = copy.deepcopy(img)
        for c in range(3):
            im[:,:,c] = im[:,:,c]*atts[i]
        axarr[i].imshow(im)
        axarr[i].axis('off')
        axarr[i].set_title(pred_caps[0][i])
    
    plt.show()
    
    pred_cap = ' '.join(pred_caps[0])
    true_cap = ' '.join(true_caps[0])
    
    # true captions
    print ("ID:", imids[0]['file_name'])
    print ("True:", true_cap)
    print ("Gen:", pred_cap)
    print ("-"*10)

    model.reset_states()
    count+=1
    if count > 10:
        break