In [None]:
import cPickle
import math
import os

import numpy as np
import tensorflow as tf

import configuration
from configuration import *
from caption_gen.CaptionWraper import *
from skipthought import encoder_manager
from storyteller.inference_utils import vocabulary
from storyteller.inference_utils import story_generator
from storyteller import inference_wrapper

In [None]:
# Hardware / reproducibility settings shared by every session in this notebook.
# NOTE(review): THEANO_FLAGS is normally read when theano is first imported; if
# caption_gen.CaptionWraper imports theano in the imports cell, setting it here is
# too late to take effect — TODO confirm and move these env vars above the imports.
os.environ.update({
    "THEANO_FLAGS": "device=cuda0",
    "CUDA_VISIBLE_DEVICES": "0",
})

# Session options reused by both the skip-thought encoder and the storyteller
# session; the memory cap presumably lets them share one GPU.
config_hardware = tf.ConfigProto()
config_hardware.gpu_options.per_process_gpu_memory_fraction = 0.15

# Graph-level seed; this applies to the graph that is default at this moment.
tf.set_random_seed(9)

In [None]:
FLAGS = tf.flags.FLAGS
# Notebook parameters. Every path flag defaults to "" and must be supplied.
# NOTE(review): book_data_dir and book_category are not read by the cells shown here.
tf.flags.DEFINE_string("stv_vocab", "", "path to vocab used by stv (expanded)")
tf.flags.DEFINE_string("stv_embedding", "", "path stv embeddings of vocabulary")
tf.flags.DEFINE_string("stv_model", "", "path of pretrained skipthought model checkpoint")
tf.flags.DEFINE_string("checkpoint_path", "", "checkpoint of story teller decoder.")
tf.flags.DEFINE_string("vocab_file", "", " vocab file generated by build data in story decoder.")
tf.flags.DEFINE_string("book_data_dir", "", "directory of book data")
tf.flags.DEFINE_string("book_category", "", "category of book")
tf.flags.DEFINE_string("image_path", "", "path of image used to generate story")
# Was misspelled "num_captoins", so FLAGS.num_captions (read by the caption cell)
# raised AttributeError.
tf.flags.DEFINE_integer("num_captions", 100, "number of captions used to generate story")

tf.logging.set_verbosity(tf.logging.INFO)

In [None]:
# Caption the input image; the captions are later encoded and averaged into the
# description vector that conditions the story decoder.
if not FLAGS.image_path:
    raise ValueError("--image_path must point to the image to caption")
# Keep the config instance under its own name instead of rebinding the
# vse_config factory, so this cell can be re-run without a TypeError.
caption_config = vse_config()
wrapper = CapgenWrapper(caption_config)
captions = wrapper.get_caption(image_loc=FLAGS.image_path, k=FLAGS.num_captions)

In [None]:
# Skip-thought sentence encoder, created with the shared GPU session options.
encoder = encoder_manager.EncoderManager(
    config=config_hardware,
)

In [None]:
# Load the pretrained skip-thought model: its config, the expanded vocabulary,
# the matching embedding matrix, and the checkpoint weights.
# (Fixes a stray '"' after FLAGS.stv_embedding that made this cell a SyntaxError.)
encoder.load_model(model_config=stv_config(),
                   vocabulary_file=FLAGS.stv_vocab,
                   embedding_matrix_file=FLAGS.stv_embedding,
                   checkpoint_path=FLAGS.stv_model)

In [None]:
# Storyteller decoder vocabulary with its start, end and unknown tokens.
vocab = vocabulary.Vocabulary(
    vocab_file=FLAGS.vocab_file,
    start_word='<sos>',
    end_word='<eos>',
    unk_word='<unk>',
)

In [None]:
# Decoder configuration; its output vocabulary size must match the loaded vocab.
# The instance below rebinds the star-imported storyteller_config name, so the
# factory is called through the module to keep this cell re-runnable.
storyteller_config = configuration.storyteller_config()
storyteller_config.vocab_size = len(vocab.vocab)

In [None]:
# Build the storyteller decoder in a dedicated graph. restore_fn is kept so the
# checkpoint at FLAGS.checkpoint_path can be loaded into a session later.
# tf.set_random_seed only seeds the graph that was default when it was called,
# so carry that seed over to the new graph explicitly.
graph_seed = tf.get_default_graph().seed
g = tf.Graph()
with g.as_default():
    tf.set_random_seed(graph_seed)
    model = inference_wrapper.InferenceWrapper()
    restore_fn = model.build_graph_from_config(storyteller_config,
                                               FLAGS.checkpoint_path)
# Freeze the graph so nothing accidentally adds ops to it afterwards.
g.finalize()

In [None]:
# Encode each caption and average the encodings (over axis 0) into one vector
# describing the image.
caption_vectors = encoder.encode(captions)
if len(caption_vectors) == 0:
    # np.mean of an empty array yields NaN (with only a warning), which would
    # then silently feed a NaN description into beam search.
    raise ValueError("no captions were generated for %r" % FLAGS.image_path)
description = np.mean(caption_vectors, axis=0)

In [None]:
# Release the skip-thought encoder now that the description vector is computed.
# NOTE(review): presumably closes the encoder's TF session(s) and frees GPU memory
# before the storyteller session is created; cells that call `encoder` cannot be
# re-run after this without re-creating it — confirm in encoder_manager.
encoder.close()

In [None]:
STYLE_BIAS_DIR = "./style_bias"


def _load_style_bias(filename):
    """Unpickle one precomputed style-bias vector from STYLE_BIAS_DIR.

    Pickles are binary, so the file is opened in 'rb' mode. Text mode can corrupt
    them (newline translation on Windows, str/bytes mismatch on Python 3).
    NOTE(review): unpickling can execute arbitrary code — only load trusted files.
    """
    with open(os.path.join(STYLE_BIAS_DIR, filename), 'rb') as handle:
        return cPickle.load(handle)


# Source-style bias (subtracted) and long/short target-style biases (added).
bias_source = _load_style_bias("bias_cap.pkl")
bias_target_long = _load_style_bias("bias_advent_long_100.pkl")
bias_target_short = _load_style_bias("bias_advent_short_100.pkl")

In [None]:
# Style transfer in embedding space: first remove the source bias
# (bias_cap.pkl), then add the long-form target bias.
description_debiased = description - bias_source
description_style = description_debiased + bias_target_long

In [None]:
# Session bound to the storyteller graph `g`, with the shared GPU options.
sess = tf.Session(
    config=config_hardware,
    graph=g,
)

In [None]:
MAX_STORY_LENGTH = 100  # passed to StoryGenerator as max_caption_length
BEAM_SIZE = 10

# Load decoder weights into the session, then decode the styled description.
restore_fn(sess)
generator = story_generator.StoryGenerator(model, vocab,
                                           max_caption_length=MAX_STORY_LENGTH,
                                           beam_size=BEAM_SIZE)

# Add a leading axis of size 1 (presumably a batch of one) before beam search.
stories = generator.beam_search(sess, np.expand_dims(description_style, 0))
for i, story in enumerate(stories):
    # Drop the first and last tokens (the begin / end words).
    sentence = " ".join(vocab.id_to_word(w) for w in story.sentence[1:-1])
    # exp(logprob) is tiny for long stories, so "%f" printed 0.000000; use "%g"
    # and also show the raw log-probability.
    print("  %d) %s (p=%g, logp=%.3f)"
          % (i, sentence, math.exp(story.logprob), story.logprob))