In [1]:
%matplotlib inline

In [132]:
import numpy as np
import matplotlib.pyplot as plt
import random

import sys
sys.path.append('./python')
import caffe

sys.path.append('./examples/coco_caption')

In [3]:
!head examples/coco_caption/h5_data/buffer_100/vocabulary.txt

<unk>
a
on
of
the
in
with
and
is
man


In [58]:
# Index 0 is reserved for the end-of-sentence marker; the words read from
# vocabulary.txt follow in file order, so '<unk>' is index 1, 'a' is index 2, ...
# The file is opened in a `with` block so the handle is closed deterministically.
with open('examples/coco_caption/h5_data/buffer_100/vocabulary.txt') as vocabulary_file:
    vocabulary = ['<EOS>'] + [line.strip() for line in vocabulary_file]
print(len(vocabulary))

8801


In [59]:
# Load the trained language model in test mode and show the shape of the
# 'probs' output blob.
iter_num = 110000
model_def = './examples/coco_caption/lstm_lm.deploy.prototxt'
model_weights = './examples/coco_caption/lstm_lm_iter_%d.caffemodel' % iter_num
net = caffe.Net(model_def, model_weights, caffe.TEST)
print(net.blobs['probs'].data.shape)

(1, 1, 8801)


In [61]:
def predict_single_word(net, previous_word, output='probs'):
    """Run one forward step of `net` and return the values of blob `output`.

    Args:
        net: object exposing `forward(cont_sentence=..., input_sentence=...)`
            and a `blobs` mapping whose entries have a `.data` array.
        previous_word: integer index of the word fed to the net. A value of 0
            is treated as the start of a sequence: the continuation flag
            passed as `cont_sentence` is 0 in that case and 1 otherwise.
        output: name of the blob to read after the forward pass.

    Returns:
        A 1-D array holding an independent copy of
        `net.blobs[output].data[0, 0, :]`.
    """
    cont = 0 if previous_word == 0 else 1
    cont_input = np.array([cont])
    word_input = np.array([previous_word])
    net.forward(cont_sentence=cont_input, input_sentence=word_input)
    # Indexing the blob yields a view into the net's own buffer, which the
    # next forward() call overwrites in place. Copy so that results saved by
    # the caller are not silently replaced by later predictions.
    return net.blobs[output].data[0, 0, :].copy()

In [64]:
# Word index 0 marks the start of a sentence, so this is the distribution
# over the first word.
first_word_dist = predict_single_word(net, previous_word=0)

In [65]:
# Word indices ordered from most to least probable (argsort of the negated scores).
top_preds = np.argsort(-first_word_dist)

In [66]:
# Ten most probable first words: raw indices, then the decoded words.
top_ten = top_preds[:10]
print(top_ten)
print([vocabulary[word_id] for word_id in top_ten])

[  2  14   5  13  64  77  30  18  93 142]
['a', 'two', 'the', 'an', 'there', 'three', 'some', 'people', 'several', 'this']


In [68]:
# Continue the sentence with 'two' and list the ten most probable next words.
two_index = vocabulary.index('two')
second_word_dist = predict_single_word(net, two_index)
second_word_top = np.argsort(-second_word_dist)[:10]
print([vocabulary[word_id] for word_id in second_word_top])

['people', 'men', 'women', 'giraffes', 'zebras', 'young', 'cats', 'elephants', 'horses', 'children']


In [69]:
# Continue with 'giraffes' and list the ten most probable next words.
# Rank the distribution returned by *this* step (third_word_dist), not the
# one saved from the previous step.
third_word_dist = predict_single_word(net, vocabulary.index('giraffes'))
print([vocabulary[index] for index in np.argsort(-1 * third_word_dist)[:10]])

['standing', 'are', 'in', 'stand', 'walking', 'and', 'eating', 'that', 'walk', 'with']


In [70]:
# Continue with 'eating' and list the ten most probable next words.
# This is the fourth step, so it gets its own name and that result is the
# one ranked and printed.
fourth_word_dist = predict_single_word(net, vocabulary.index('eating'))
print([vocabulary[index] for index in np.argsort(-1 * fourth_word_dist)[:10]])

['leaves', 'from', 'grass', 'hay', 'out', 'some', 'in', 'food', 'off', 'a']


In [136]:
def softmax(softmax_inputs, temp):
    """Normalise `exp(temp * (x - max(x)))` so that it sums to one.

    `temp` multiplies the shifted inputs, so larger values concentrate the
    mass on the largest input and values below 1 flatten the distribution.

    Returns an all-NaN array if the sum of exponentials is NaN and an
    all-zero array if that sum is infinite.
    """
    # Subtract the maximum before exponentiating so the largest term is exp(0).
    weights = np.exp((softmax_inputs - softmax_inputs.max()) * temp)
    total = weights.sum()
    if np.isnan(total):
        return float('nan') * weights
    assert total > 0
    if np.isinf(total):
        return np.zeros_like(weights)
    # Guard the division with a tiny floor on the denominator.
    smallest_denominator = 1e-20
    return weights / max(total, smallest_denominator)

def random_choice_from_probs(softmax_inputs, temp=1):
    """Pick an index from `softmax_inputs`, either greedily or by sampling.

    With `temp == float('inf')` the index of the largest input is returned.
    Otherwise the inputs go through `softmax(softmax_inputs, temp)` and an
    index is drawn by walking the cumulative sum until it reaches a uniform
    random number. If the cumulative sum never reaches that number, index 1
    is returned as a fallback.
    """
    # An infinite temp means "always take the max"; no sampling needed.
    greedy = temp == float('inf')
    if greedy:
        return np.argmax(softmax_inputs)
    probs = softmax(softmax_inputs, temp)
    threshold = random.random()
    running_total = 0.
    position = 0
    for prob in probs:
        running_total += prob
        if running_total >= threshold:
            return position
        position += 1
    # Fallback when no index was selected above (index 1; the original
    # author's note here reads "return UNK?").
    return 1

In [120]:
def generate_sentence(net, temp=float('inf'), output='predict', max_words=50):
    """Generate a sequence of word indices by feeding the net its own output.

    The first step feeds word index 0 with continuation flag 0; every later
    step feeds the previously chosen word with continuation flag 1. Each word
    is chosen by `random_choice_from_probs` from
    `net.blobs[output].data[0, 0, :]` using `temp`.

    Generation stops once `max_words` words have been produced or the last
    chosen index is 0. The returned list includes that final 0 when present.
    """
    cont_input = np.array([0])
    word_input = np.array([0])
    words = []
    while True:
        if not len(words) < max_words:
            break
        if words and words[-1] == 0:
            break
        net.forward(cont_sentence=cont_input, input_sentence=word_input)
        scores = net.blobs[output].data[0, 0, :]
        chosen = random_choice_from_probs(scores, temp=temp)
        words.append(chosen)
        # After the first step the sentence is being continued; the word just
        # chosen is the next input.
        cont_input[0] = 1
        word_input[0] = words[-1]
    return words

In [121]:
# Default temp is infinity, i.e. the most probable word is taken at every step.
sentence = generate_sentence(net)
print(sentence)
print([vocabulary[word_id] for word_id in sentence])

[2, 10, 9, 16, 6, 2, 35, 7, 2, 118, 0]
['a', 'man', 'is', 'standing', 'in', 'a', 'field', 'with', 'a', 'frisbee', '<EOS>']


In [122]:
# Same greedy call again: with the default temp no random choice is involved.
sentence = generate_sentence(net)
print(sentence)
print([vocabulary[word_id] for word_id in sentence])

[2, 10, 9, 16, 6, 2, 35, 7, 2, 118, 0]
['a', 'man', 'is', 'standing', 'in', 'a', 'field', 'with', 'a', 'frisbee', '<EOS>']


In [137]:
# temp=1.0: sample each word from the unscaled softmax of the net's scores.
sentence = generate_sentence(net, temp=1.0)
print(sentence)
print([vocabulary[word_id] for word_id in sentence])

[2, 22, 9, 294, 7, 2, 178, 113, 11, 87, 905, 0]
['a', 'woman', 'is', 'posing', 'with', 'a', 'cell', 'phone', 'to', 'her', 'ear', '<EOS>']


In [138]:
# A second sample at temp=1.0.
sentence = generate_sentence(net, temp=1.0)
print(sentence)
print([vocabulary[word_id] for word_id in sentence])

[2, 28, 26, 2, 38, 209, 3, 2, 38, 152, 0]
['a', 'person', 'holding', 'a', 'tennis', 'racket', 'on', 'a', 'tennis', 'court', '<EOS>']


In [139]:
# temp=1.5: scores are multiplied by 1.5 before the softmax, sharpening it.
sentence = generate_sentence(net, temp=1.5)
print(sentence)
print([vocabulary[word_id] for word_id in sentence])

[2, 10, 26, 2, 38, 363, 3, 2, 38, 152, 0]
['a', 'man', 'holding', 'a', 'tennis', 'racquet', 'on', 'a', 'tennis', 'court', '<EOS>']


In [140]:
# A second sample at temp=1.5.
sentence = generate_sentence(net, temp=1.5)
print(sentence)
print([vocabulary[word_id] for word_id in sentence])

[2, 33, 4, 18, 12, 106, 2, 23, 7, 60, 0]
['a', 'group', 'of', 'people', 'sitting', 'around', 'a', 'table', 'with', 'food', '<EOS>']


In [141]:
# temp=3.0: a sharper distribution still.
sentence = generate_sentence(net, temp=3.0)
print(sentence)
print([vocabulary[word_id] for word_id in sentence])

[2, 10, 6, 2, 261, 8, 217, 16, 6, 2, 43, 0]
['a', 'man', 'in', 'a', 'suit', 'and', 'tie', 'standing', 'in', 'a', 'room', '<EOS>']


In [142]:
# A second sample at temp=3.0.
sentence = generate_sentence(net, temp=3.0)
print(sentence)
print([vocabulary[word_id] for word_id in sentence])

[2, 10, 26, 2, 38, 363, 3, 2, 38, 152, 0]
['a', 'man', 'holding', 'a', 'tennis', 'racquet', 'on', 'a', 'tennis', 'court', '<EOS>']


In [143]:
# temp=10.0: scores are multiplied by 10 before the softmax.
sentence = generate_sentence(net, temp=10.0)
print(sentence)
print([vocabulary[word_id] for word_id in sentence])

[2, 10, 9, 16, 6, 2, 35, 7, 2, 118, 0]
['a', 'man', 'is', 'standing', 'in', 'a', 'field', 'with', 'a', 'frisbee', '<EOS>']


In [144]:
# Back to temp=1.0 for another sample from the unscaled softmax.
sentence = generate_sentence(net, temp=1.0)
print(sentence)
print([vocabulary[word_id] for word_id in sentence])

[1993, 1074, 86, 6, 40, 4, 2, 126, 0]
['staircase', 'laid', 'out', 'in', 'front', 'of', 'a', 'window', '<EOS>']


In [146]:
# temp=0.8: scores are scaled down before the softmax, flattening it.
sentence = generate_sentence(net, temp=0.8)
print(sentence)
print([vocabulary[word_id] for word_id in sentence])

[2, 28, 3, 2, 113, 46, 2, 129, 0]
['a', 'person', 'on', 'a', 'phone', 'riding', 'a', 'car', '<EOS>']


In [147]:
# A second sample at temp=0.8.
sentence = generate_sentence(net, temp=0.8)
print(sentence)
print([vocabulary[word_id] for word_id in sentence])

[2, 16, 60, 6, 136, 192, 7, 641, 16, 20, 11, 27, 0]
['a', 'standing', 'food', 'in', 'each', 'hand', 'with', 'cattle', 'standing', 'next', 'to', 'it', '<EOS>']


In [148]:
# temp=0.6: a flatter distribution again.
sentence = generate_sentence(net, temp=0.6)
print(sentence)
print([vocabulary[word_id] for word_id in sentence])

[28, 236, 1042, 7, 69, 1257, 487, 1769, 0]
['person', 'taking', 'noodles', 'with', 'other', 'homemade', 'birthday', 'cereal', '<EOS>']


In [145]:
# temp=0.5: the flattest setting tried here; scores are halved before the softmax.
sentence = generate_sentence(net, temp=0.5)
print(sentence)
print([vocabulary[word_id] for word_id in sentence])

[5623, 1087, 15, 6888, 472, 361, 8634, 8, 7241, 3, 77, 299, 935, 1296, 15, 12, 5165, 2867, 3979, 743, 4991, 4470, 640, 9, 259, 2308, 4386, 2552, 3797, 2448, 15, 3617, 5364, 4267, 4549, 8086, 176, 2529, 6434, 5445, 370, 7959, 5672, 1742, 4041, 4258, 1153, 8, 610, 2044]
['chilli', 'frosting', ',', 'medley', 'salad', 'items', 'sideboard', 'and', 'garnishes', 'on', 'three', 'colorful', 'gold', 'desserts', ',', 'sitting', 'knifes', 'need', 'workspace', 'where', 'exchanging', 'hoses', 'left', 'is', 'pink', 'clearing', 'obstacles', 'vandalized', 'idly', 'afternoon', ',', 'halloween', 'rich', 'fixed', 'aid', 'advertise', 'light', 'times', 'delicate', 'dealership', 'like', 'snowsuits', 'florida', 'than', 'ornamental', 'dr', 'curtains', 'and', 'multiple', 'electrical']
