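# Deep Meme Caption Generator

This notebook restores a trained `MemeModel` from a checkpoint and uses beam search to generate captions for meme template images.

<img src="meme_characters/futurama-fry/futurama-fry.jpg" align="left">
<img src="meme_characters/philosoraptor/philosoraptor.jpg" align="right">
<img src="meme_characters/y-u-no/y-u-no.jpg" align="center">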

In [1]:
%matplotlib inline
import os
import math
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import configuration as config
from utils.vocabulary import Vocabulary
from utils.caption_generator import CaptionGenerator
from utils.inception_v3 import preprocess_input
from model import MemeModel
from keras.models import load_model
from keras.preprocessing import image
from utils.inception_v3 import preprocess_input

tf.logging.set_verbosity(tf.logging.INFO)

Using TensorFlow backend.


In [2]:
# Record the Keras version this notebook was last run with.
import keras
keras.__version__

'2.0.6'

In [3]:
# Record the TensorFlow version; the cells below use the TF 1.x API
# (tf.logging, tf.gfile, tf.InteractiveSession).
tf.__version__

'1.2.1'

## Initial arguments

In [4]:
# Input locations, relative to the notebook's working directory.
checkpoint_path, vocab_file, dataset_dir = (
    'model/train/',        # checkpoint directory; its latest checkpoint is restored
    '10k/word_count.txt',  # vocabulary file for MemeModel / Vocabulary
    'meme_characters/',    # dataset directory passed to MemeModel
)

## Wrapper functions to generate captions

In [5]:
def build_model(dataset_dir, image_format='jpeg'):
    """Create a MemeModel in 'inference' mode and build its graph.

    The vocabulary path is taken from the module-level ``vocab_file``.

    Args:
        dataset_dir: Directory handed to MemeModel as ``dataset_dir``.
        image_format: Image format string forwarded to ``MemeModel.build``.

    Returns:
        The built MemeModel instance.
    """
    meme_model = MemeModel('inference', vocab_file, dataset_dir=dataset_dir)
    meme_model.build(image_format)
    return meme_model

def feed_image(sess, encoded_image):
    """Feed one image input through the graph and return the LSTM initial state.

    Args:
        sess: Session used to evaluate the graph.
        encoded_image: Value for the ``image_feed:0`` placeholder. In this
            notebook CaptionGenerator.beam_search is given the Keras model's
            predictions, which are presumably forwarded here -- confirm in
            utils/caption_generator.py.

    Returns:
        The evaluated ``lstm/initial_state:0`` tensor.
    """
    fetch_name = "lstm/initial_state:0"
    return sess.run(fetches=fetch_name, feed_dict={"image_feed:0": encoded_image})

def inference_step(sess, input_feed, state_feed):
    """Run one decoding step of the LSTM.

    Args:
        sess: Session used to evaluate the graph.
        input_feed: Value for the ``input_feed:0`` placeholder (current word ids).
        state_feed: Value for the ``lstm/state_feed:0`` placeholder (previous state).

    Returns:
        ``(softmax_output, state_output, None)``. The trailing ``None`` fills
        the third slot the caption generator unpacks; this model supplies
        nothing for it.
    """
    feeds = {
        "input_feed:0": input_feed,
        "lstm/state_feed:0": state_feed,
    }
    softmax_output, state_output = sess.run(
        fetches=["softmax:0", "lstm/state:0"], feed_dict=feeds)
    return softmax_output, state_output, None

## Build and restore model functions

In [None]:
def create_restore_fn(checkpoint_path, saver):
    """Return a callback that restores model variables from a checkpoint.

    Args:
        checkpoint_path: A checkpoint file prefix, or a directory. For a
            directory, the latest checkpoint inside it is used.
        saver: A ``tf.train.Saver``; its ``restore`` method loads the variables.

    Returns:
        A function ``_restore_fn(sess)`` that restores the checkpoint into ``sess``.

    Raises:
        ValueError: If ``checkpoint_path`` is a directory without any checkpoint.
    """
    if tf.gfile.IsDirectory(checkpoint_path):
        checkpoint_dir = checkpoint_path
        checkpoint_path = tf.train.latest_checkpoint(checkpoint_dir)
        if not checkpoint_path:
            # Report the directory that was searched: checkpoint_path is falsy here.
            raise ValueError("No checkpoint file found in: %s" % checkpoint_dir)

    def _restore_fn(sess):
        """Restore the resolved checkpoint into ``sess``."""
        tf.logging.info("Loading model from checkpoint: %s", checkpoint_path)
        saver.restore(sess, checkpoint_path)
        tf.logging.info("Successfully loaded checkpoint: %s",
                        os.path.basename(checkpoint_path))

    return _restore_fn

def build_graph_from_config(data_dir, checkpoint_path, image_format='jpeg'):
    """Build the inference model and a matching checkpoint-restore callback.

    Args:
        data_dir: Dataset directory forwarded to ``build_model``.
        checkpoint_path: Checkpoint file or directory for ``create_restore_fn``.
        image_format: Image format forwarded to ``build_model``.

    Returns:
        ``(restore_fn, model)``: the restore callback and the built model.
    """
    tf.logging.info("Building model.")
    meme_model = build_model(data_dir, image_format)
    # Create the Saver only after the model is built, so tf.global_variables()
    # includes the variables build_model created.
    model_saver = tf.train.Saver(tf.global_variables())
    restore = create_restore_fn(checkpoint_path, model_saver)
    return restore, meme_model

## Build model and inference graph

In [None]:
# Start from an empty default graph so re-running this cell builds a clean
# graph, then create the inference model and its checkpoint-restore callback.
tf.reset_default_graph()
restore_fn, mememodel = build_graph_from_config(
    data_dir=dataset_dir,
    checkpoint_path=checkpoint_path,
    image_format='jpeg',
)

INFO:tensorflow:Building model.


In [None]:
# Load the vocabulary used to turn generated word ids back into words.
vocab = Vocabulary(vocab_file)

## Inspect the model, sanity-check a prediction, and restore the checkpoint

In [None]:
# Print the layer-by-layer structure of the Keras image model.
mememodel.model.summary()

In [None]:
from keras.preprocessing import image
from keras.applications.inception_v3 import preprocess_input

# Sanity check: push one known template image through the Keras image model,
# printing the array shape after each preparation step.
img = image.load_img('10k/part-0-to-1000/y-u-no/y-u-no.jpg',
                     target_size=(299, 299))
x = image.img_to_array(img)
print(x.shape)  # as loaded
x = np.expand_dims(x, axis=0)
print(x.shape)  # with a leading axis of size 1
x = preprocess_input(x)
print(x.shape)  # after Inception-v3 preprocessing
preds = mememodel.model.predict(x)
print(preds.shape)

In [None]:
# Interactive session that the beam search below runs in; load the trained
# weights into it from the checkpoint.
sess = tf.InteractiveSession()
restore_fn(sess)

Prepare the caption generator. We use the default beam search parameters, except `max_caption_length`, which is set to 10.
See [`caption_generator.py`](utils/caption_generator.py) for a description of the
available beam search parameters.

In [None]:
# Beam-search caption generator over the restored graph; only the maximum
# caption length is overridden.
generator = CaptionGenerator(
    feed_image, inference_step, vocab, max_caption_length=10)

## Caption images

In [None]:
def list_dir_files(directory):
    """Return sorted paths of the regular, non-hidden files in ``directory``.

    Returns an empty list when ``directory`` does not exist, so the notebook
    still runs without a ``test/`` folder. Subdirectories (e.g.
    ``.ipynb_checkpoints``) and dotfiles are skipped, because the captioning
    loop below would otherwise try to read them as images. Sorting makes the
    output order stable; ``os.listdir`` order is arbitrary.
    """
    if not os.path.isdir(directory):
        return []
    return [
        os.path.join(directory, name)
        for name in sorted(os.listdir(directory))
        if not name.startswith('.')
        and os.path.isfile(os.path.join(directory, name))
    ]


# Meme paths to be captioned: one known template plus the files in `test/`.
input_files = ['meme_characters/american-pride-eagle/american-pride-eagle.jpg']
input_files += list_dir_files('test')

In [None]:
# Caption every image in `input_files`, showing each image just above its captions.
for filename in input_files:
    if not os.path.exists(filename):
        print("Skipping missing file: %s" % filename)
        continue

    # Create the figure first and render it immediately with plt.show(), so
    # each image appears next to its captions and no empty figure is left over.
    im = mpimg.imread(filename)
    plt.figure()
    plt.imshow(im)
    plt.show()

    # Encode the image with the Keras image model, then beam-search captions.
    img = image.load_img(filename, target_size=(299, 299))
    x = np.expand_dims(image.img_to_array(img), axis=0)
    # NOTE(review): the sanity-check cell applies preprocess_input before
    # predict, but it is skipped here -- confirm which input range
    # mememodel.model expects.
    preds = mememodel.model.predict(x)
    captions = generator.beam_search(sess, preds)
    print("Captions for image %s:" % os.path.basename(filename))
    for i, caption in enumerate(captions):
        # Ignore begin and end words.
        word_ids = caption.sentence[1:-1]
        print('raw sentence:', word_ids)
        sentence = " ".join(vocab.id_to_word(w) for w in word_ids)
        print("  %d) %s (p=%f)" % (i, sentence, math.exp(caption.logprob)))


In [None]:
# Caption a single template image end to end.
original_image = 'meme_characters/one-does-not-simply/one-does-not-simply.jpg'

# Display the image first; plt.show() renders it immediately, so it appears
# above its captions and no empty figure is left over.
im = mpimg.imread(original_image)
plt.figure()
plt.imshow(im)
plt.show()

img = image.load_img(original_image, target_size=(299, 299))
x = np.expand_dims(image.img_to_array(img), axis=0)
# NOTE(review): preprocess_input is applied in the sanity-check cell but not
# here -- confirm which input range mememodel.model expects.
preds = mememodel.model.predict(x)
captions = generator.beam_search(sess, preds)
print("Captions for image %s:" % os.path.basename(original_image))
for i, caption in enumerate(captions):
    # Ignore begin and end words.
    word_ids = caption.sentence[1:-1]
    print('raw sentence:', word_ids)
    sentence = " ".join(vocab.id_to_word(w) for w in word_ids)
    print("  %d) %s (p=%f)" % (i, sentence, math.exp(caption.logprob)))

In [None]:
# Number of non-zero entries in the most recent `preds` array.
np.count_nonzero(preds)

In [None]:
# Shape of the entries in the first prediction row that are >= 0.15,
# i.e. how many entries clear that threshold.
np.shape(preds[0][preds[0] >= 0.15])