## Imports

Import the modules needed for TensorFlow and Core ML.

In [1]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __builtin__ import any as b_any

import math
import os
os.environ["CUDA_VISIBLE_DEVICES"]=""
import numpy as np
from PIL import Image

import tensorflow as tf

import configuration
import inference_wrapper
import sys
sys.path.insert(0, 'im2txt/inference_utils')
sys.path.insert(0, 'im2txt/ops')
import caption_generator
import image_processing
import vocabulary

import urllib, os, sys, zipfile
from os.path import dirname
from tensorflow.core.framework import graph_pb2
from tensorflow.python.tools.freeze_graph import freeze_graph
from tensorflow.python.tools import strip_unused_lib
from tensorflow.python.framework import dtypes
from tensorflow.python.platform import gfile
import tfcoreml
import configuration
from coremltools.proto import NeuralNetwork_pb2

## Create the models

Create the TensorFlow model and strip all unused nodes.

In [2]:
# Checkpoint that the inference graph is built from and that freeze_graph
# reads as `input_checkpoint`.
checkpoint_file = './trainlogIncNEW/model.ckpt'
# Stripped (not yet frozen) graph. It gets its own path so that freezing never
# overwrites its own input and the freeze cell can be re-run on its own.
pre_frozen_model_file = './pre_frozen_model_textgenNEW.pb'
# Frozen graph; this is the file handed to the Core ML converter.
frozen_model_file = './frozen_model_textgenNEW.pb'

# Which nodes we want to input for the network
# Use ['image_feed'] for just Memeception
input_node_names = ['seq_embeddings', 'lstm/state_feed']

# Which nodes we want to output from the network
# Use ['lstm/initial_state'] for just Memeception
output_node_names = ['softmax', 'lstm/state']

In [5]:
# Assemble the inference model inside its own graph, then finalize that graph.
# `restore_fn` is kept: the captioning cell calls it with a session.

g = tf.Graph()
with g.as_default():
    model = inference_wrapper.InferenceWrapper()
    restore_fn = model.build_graph_from_config(
        configuration.ModelConfig(),
        checkpoint_file,
    )
g.finalize()

INFO:tensorflow:Building model.
About to decide if splitting
{'num_or_size_splits': 4, 'value': <tf.Tensor 'lstm/basic_lstm_cell/BiasAdd:0' shape=(1, 2048) dtype=float32>, 'axis': <tf.Tensor 'lstm/basic_lstm_cell/Const:0' shape=() dtype=int32>}


KeyError: 'value'

In [None]:
# Serialize the graph to disk in binary form, then read it back into a GraphDef

tf_model_path = './log/pre_graph_textgenNEW.pb'
# Directory and file name both come from the single path above.
graph_dir, graph_name = os.path.split(tf_model_path)
tf.train.write_graph(g, graph_dir, graph_name, as_text=False)

with open(tf_model_path, 'rb') as graph_file:
    serialized = graph_file.read()
tf.reset_default_graph()
original_gdef = tf.GraphDef()
original_gdef.ParseFromString(serialized)

In [None]:
# Strip graph elements that are unused for the selected input/output nodes,
# then write the stripped GraphDef to `pre_frozen_model_file`.

gdef = strip_unused_lib.strip_unused(
    input_graph_def=original_gdef,
    input_node_names=input_node_names,
    output_node_names=output_node_names,
    placeholder_type_enum=dtypes.float32.as_datatype_enum,
)
with gfile.GFile(pre_frozen_model_file, 'wb') as stripped_file:
    stripped_file.write(gdef.SerializeToString())

In [None]:
# Freeze the stripped graph together with the checkpoint data and save the result

freeze_graph(
    # Graph and checkpoint to read
    input_graph=pre_frozen_model_file,
    input_binary=True,
    input_saver='',
    input_checkpoint=checkpoint_file,
    # Nodes to keep and the file to write
    output_node_names=','.join(output_node_names),
    output_graph=frozen_model_file,
    # Remaining freeze_graph options
    restore_op_name='save/restore_all',
    filename_tensor_name='save/Const:0',
    clear_devices=True,
    initializer_nodes='',
)

## Verify the model

Check that the model produces sensible captions for the *One does not simply* meme image.

In [None]:
# Instantiate the model configuration, then load the vocabulary from its file

config = configuration.ModelConfig()

vocab_file = 'vocab4.txt'
vocab = vocabulary.Vocabulary(vocab_file)

In [None]:
# Generate captions on a hard-coded image and print every caption that the
# beam search returns.

with tf.Session(graph=g) as sess:
  restore_fn(sess)
  generator = caption_generator.CaptionGenerator(
      model, vocab, beam_size=config.beam_size)
  for filename in ['memes/one-does-not-simply.jpg']:
    with tf.gfile.GFile(filename, "rb") as f:
      # Force three channels so grayscale/RGBA files still give a
      # 299x299x3 array, then rescale pixel values from [0, 255] to [-1, 1].
      image = Image.open(f).convert('RGB').resize((299, 299))
      image = ((np.array(image) / 255.0) - 0.5) * 2.0
    # Run the beam search five times on the same image.
    for _ in range(5):
      captions = generator.beam_search(sess, image)
      for caption in captions:
        # The first and last token ids are skipped
        # (presumably start/end markers -- TODO confirm).
        sentence = [vocab.id_to_word(w) for w in caption.sentence[1:-1]]
        sentence = " ".join(sentence)
        # Printed inside the loop so each caption is shown, not just the last.
        print(sentence)

## Convert the model to CoreML

Specify the input tensor shapes and the output tensors of the graph to be used for the conversion.

In [None]:
# Depth of the beam search
beam_size = 2

# Shape of each input tensor, keyed by tensor name: [1, beam_size, width]
# If using Memeception, add 'image_feed:0': [299, 299, 3]
input_tensor_shapes = {
    tensor_name: [1, beam_size, width]
    for tensor_name, width in (
        ('seq_embeddings:0', 300),
        ('lstm/state_feed:0', 1024),
    )
}

# Destination of the converted Core ML model
coreml_model_file = './Textgen_NEW.mlmodel'

In [None]:
# Output tensor names are the output node names with ':0' appended
output_tensor_names = ['{}:0'.format(node) for node in output_node_names]

In [None]:
# Convert the frozen TensorFlow graph and write the result to `coreml_model_file`
coreml_model = tfcoreml.convert(
    tf_model_path=frozen_model_file,
    mlmodel_path=coreml_model_file,
    output_feature_names=output_tensor_names,
    input_name_shape_dict=input_tensor_shapes,
    add_custom_layers=True,
)