In [1]:
import argparse
import sys
import random
import time
import imp
import os
import pickle as cPickle
import tensorflow as tf
import numpy as np
import gensim.models as g
#from util import *
# from sonnet_model import SonnetModel
from sklearn.metrics import roc_auc_score
from nltk.corpus import cmudict
from tqdm import tqdm

# Special vocabulary tokens (padding, end and unknown markers).
pad_symbol, end_symbol, unk_symbol = "<pad>", "<eos>", "<unk>"
# Order matters: load_vocab falls back to dummy_symbols[2] for unseen characters.
dummy_symbols = [pad_symbol, end_symbol, unk_symbol]
# Thresholds for which run_epoch reports rhyme precision/recall/F-score.
rhyme_thresholds = [0.9, 0.8, 0.7, 0.6]

In [2]:
%%writefile config.py
# ---------------- preprocessing options ----------------
word_minfreq = 3

# ---------------- hyper-parameters ----------------
seed = 0
batch_size = 4
keep_prob = 0.7
epoch_size = 1
max_grad_norm = 5

# language model
word_embedding_dim = 100
word_embedding_model = "pretrain_word2vec/dim100/word2vec.bin"
lm_enc_dim = 200
lm_dec_dim = 600
lm_dec_layer_size = 1
lm_attend_dim = 25
lm_learning_rate = 0.2

# pentameter model
char_embedding_dim = 150
pm_enc_dim = 50
pm_dec_dim = 200
pm_attend_dim = 50
pm_learning_rate = 0.001
repeat_loss_scale = 1.0
cov_loss_scale = 1.0
cov_loss_threshold = 0.7
sigma = 1.00

# rhyme model
rm_dim = 100
rm_neg = 5  # extra randomly sampled negative examples
rm_delta = 0.5
rm_learning_rate = 0.001

# ---------------- sonnet hyper-parameters ----------------
bptt_truncate = 2  # number of sonnet lines to truncate bptt
doc_lines = 14  # total number of lines for a sonnet

# ---------------- misc ----------------
verbose = False
save_model = True

# ---------------- input/output ----------------
output_dir = "output"
# All three partitions currently point at the small debug file; the full
# dataset paths are kept in the trailing comments.
train_data = "datasets/gutenberg/debug.txt"  # "datasets/gutenberg/sonnet_train.txt"
valid_data = "datasets/gutenberg/debug.txt"  # "datasets/gutenberg/sonnet_valid.txt"
test_data = "datasets/gutenberg/debug.txt"  # "datasets/gutenberg/sonnet_test.txt"

# Run identifier that encodes the hyper-parameters above, in this order.
output_prefix = (
    "wmin%d_sd%d_bat%d_kp%.1f_eph%d_grd%d_wdim%d_lmedim%d_lmddim%d_lmdlayer%d_lmadim%d_lmlr%.1f_cdim%d_pmedim%d_pmddim%d_pmadim%d_pmlr%.1E_loss%.1f-%.1f-%.1f_sm%.2f_rmdim%d_rmn%d_rmd%.1f_rmlr%.1E_son%d-%d"
    % (
        word_minfreq, seed, batch_size, keep_prob, epoch_size, max_grad_norm,
        word_embedding_dim, lm_enc_dim, lm_dec_dim, lm_dec_layer_size,
        lm_attend_dim, lm_learning_rate,
        char_embedding_dim, pm_enc_dim, pm_dec_dim, pm_attend_dim,
        pm_learning_rate, repeat_loss_scale, cov_loss_scale,
        cov_loss_threshold, sigma,
        rm_dim, rm_neg, rm_delta, rm_learning_rate,
        bptt_truncate, doc_lines,
    )
)


Overwriting config.py


In [3]:
import config as cf

In [4]:
from models import sonnet_nmt_model

In [5]:
from utils.utils_loaders import *

In [6]:
def remove_punct(string):
    """Keep only alphabetic characters and spaces, then collapse whitespace runs.

    Every other character (digits, punctuation, tabs, newlines) is dropped
    without inserting a separator, and the remaining words are re-joined with
    single spaces.
    """
    letters_and_spaces = "".join(
        filter(lambda ch: ch == " " or ch.isalpha(), string)
    )
    words = letters_and_spaces.split()
    return " ".join(words)


In [7]:
import codecs
import operator
import numpy as np
import random
import math
import codecs
import sys
from collections import defaultdict


def load_vocab(corpus, word_minfreq, dummy_symbols):
    """Build word and character vocabularies from a UTF-8 text corpus.

    Words are the whitespace-separated tokens of each stripped line; characters
    are every character of each stripped line (inner spaces included).

    Args:
        corpus: path of the text file to read (decoded as UTF-8).
        word_minfreq: minimum frequency an entry needs to be kept. The same
            threshold is applied to words and to characters.
        dummy_symbols: special symbols inserted first, in order, into both
            vocabularies. The third entry (index 2) is the fallback for
            characters that did not make it into the character vocabulary.

    Returns:
        Tuple ``(idxword, wordxid, idxchar, charxid, wordxchar)``:
        ``idxword``/``idxchar`` are id -> symbol lists, ``wordxid``/``charxid``
        are symbol -> id ``defaultdict(int)`` maps, and ``wordxchar`` maps a
        word id to the list of character ids spelling that word.
    """
    idxword, idxchar = [], []
    wordxid, charxid = defaultdict(int), defaultdict(int)
    word_freq, char_freq = defaultdict(int), defaultdict(int)
    wordxchar = defaultdict(list)

    def update_dic(symbol, idxvocab, vocabxid):
        # Register a symbol once; its id is its position in idxvocab.
        if symbol not in vocabxid:
            idxvocab.append(symbol)
            vocabxid[symbol] = len(idxvocab) - 1

    # Count frequencies; the context manager closes the file even on errors.
    with codecs.open(corpus, "r", "utf-8") as corpus_file:
        for line in corpus_file:
            stripped = line.strip()
            for word in stripped.split():
                word_freq[word] += 1
            for char in stripped:
                char_freq[char] += 1

    # add in dummy symbols into dictionaries
    for s in dummy_symbols:
        update_dic(s, idxword, wordxid)
        update_dic(s, idxchar, charxid)

    # remove low frequency words/chars
    def collect_vocab(vocab_freq, idxvocab, vocabxid):
        # Most frequent first; the sort is stable, so ties keep first-seen order.
        for w, f in sorted(vocab_freq.items(), key=operator.itemgetter(1), reverse=True):
            if f < word_minfreq:
                break
            update_dic(w, idxvocab, vocabxid)

    collect_vocab(word_freq, idxword, wordxid)
    collect_vocab(char_freq, idxchar, charxid)

    # word id to [char ids]
    dummy_symbols_set = set(dummy_symbols)
    for wi, w in enumerate(idxword):
        if w in dummy_symbols_set:
            # Dummy symbols are inserted first into both vocabularies, so the
            # word id is reused as the single "character" id.
            wordxchar[wi] = [wi]
        else:
            for c in w:
                wordxchar[wi].append(charxid[c] if c in charxid else charxid[dummy_symbols[2]])

    return idxword, wordxid, idxchar, charxid, wordxchar



In [8]:
# load vocab
# Builds the vocabularies from the training file only, using the frequency
# threshold from the config and the module-level dummy symbols.
print("\nFirst pass to collect word and character vocabulary...")
idxword, wordxid, idxchar, charxid, wordxchar = load_vocab(cf.train_data, cf.word_minfreq, dummy_symbols)
# Both sizes include the dummy symbols, which load_vocab inserts first.
print("\nWord type size =", len(idxword))
print("\nChar type size =", len(idxchar))



First pass to collect word and character vocabulary...

Word type size = 22

Char type size = 29


In [9]:
# Data-loading cell. Names assigned at cell level are already module globals,
# so no `global` declaration is needed here.

# Seed Python's and NumPy's RNGs from the config.
random.seed(cf.seed)
np.random.seed(cf.seed)

# Optional pre-trained word2vec model: when configured, its vector size
# overrides cf.word_embedding_dim.
if cf.word_embedding_model:
    print("\nLoading word embedding model...")
    mword = g.Word2Vec.load(cf.word_embedding_model)
    cf.word_embedding_dim = mword.vector_size

# load vocab
print("\nFirst pass to collect word and character vocabulary...")
idxword, wordxid, idxchar, charxid, wordxchar = load_vocab(cf.train_data, cf.word_minfreq, dummy_symbols)
print("\nWord type size =", len(idxword))
print("\nChar type size =", len(idxchar))


# load train and valid data
print("\nLoading train and valid data...")

train_word_data, train_char_data, train_rhyme_data, train_nwords, train_nchars = \
    load_data(cf.train_data, wordxid, idxword, charxid, idxchar, pad_symbol, end_symbol, unk_symbol)
valid_word_data, valid_char_data, valid_rhyme_data, valid_nwords, valid_nchars = \
    load_data(cf.valid_data, wordxid, idxword, charxid, idxchar, pad_symbol, end_symbol, unk_symbol)
print_stats("\nTrain", train_word_data, train_rhyme_data, train_nwords, train_nchars)
print_stats("\nValid", valid_word_data, valid_rhyme_data, valid_nwords, valid_nchars)

# load test data if it's given (the test_* names stay undefined otherwise)
if cf.test_data:
    test_word_data, test_char_data, test_rhyme_data, test_nwords, test_nchars = \
        load_data(cf.test_data, wordxid, idxword, charxid, idxchar, pad_symbol, end_symbol, unk_symbol)
    print_stats("\nTest", test_word_data, test_rhyme_data, test_nwords, test_nchars)
    


Loading word embedding model...

First pass to collect word and character vocabulary...

Word type size = 22

Char type size = 29

Loading train and valid data...

Train statistics:
  Number of documents         = 2
  Number of rhyme examples    = 24
  Total number of word tokens = 278
  Mean/min/max words per line = 9.93/7/13
  Total number of char tokens = 1193
  Mean/min/max chars per line = 42.61/37/51

Valid statistics:
  Number of documents         = 2
  Number of rhyme examples    = 24
  Total number of word tokens = 278
  Mean/min/max words per line = 9.93/7/13
  Total number of char tokens = 1193
  Mean/min/max chars per line = 42.61/37/51

Test statistics:
  Number of documents         = 2
  Number of rhyme examples    = 24
  Total number of word tokens = 278
  Mean/min/max words per line = 9.93/7/13
  Total number of char tokens = 1193
  Mean/min/max chars per line = 42.61/37/51


In [10]:
from models.sonnet_nmt_model import  SonnetModel

In [11]:
def run_epoch(sess, word_batches, char_batches, rhyme_batches, model, pname,
              is_training):
    """Run one pass over the shuffled batches and report running losses.

    Only the language-model (lm) branch is active. The pentameter (pm) and
    rhyme (rm) branches are commented out below, so their batches are iterated
    but skipped and their losses are reported as 0.0.

    Relies on the module-level names ``cf``, ``train_lm``, ``train_pm``,
    ``train_rm`` and ``rhyme_thresholds``.

    Args:
        sess: session used to run the model ops.
        word_batches: language-model batches; the fields read from each batch
            are the ones fed in ``feed_dict`` below.
        char_batches: pentameter batches (currently only counted).
        rhyme_batches: rhyme batches (currently only counted).
        model: object exposing the lm placeholders/ops used below
            (``lm_initial_state``, ``lm_train_op``, ``lm_cost``, ...).
        pname: partition label printed in the progress line.
        is_training: when True ``model.lm_train_op`` is run, otherwise a no-op
            is run and the evaluation summary is printed after the last batch.

    Returns:
        Tuple ``(lm_loss, pm_loss, rm_loss, rhyme_pattern, stress_acc)`` where
        the losses are averaged over the processed batches of each type and
        ``stress_acc`` is ``np.mean(stress_acc[-1])`` in evaluation mode and
        0.0 in training mode.
    """
    start_time = time.time()

    # lm variables
    lm_costs = 0.0
    total_words = 0
    zero_state = sess.run(model.lm_initial_state)
    model_state = None
    prev_doc = -1
    lm_train_op = model.lm_train_op if is_training else tf.no_op()

    # pm / rm variables: initialised locally so the progress line and the
    # return value do not depend on globals defined in another cell.
    pm_costs = 0.0
    rm_costs = 0.0
    #     pm_train_op = model.pm_train_op if is_training else tf.no_op()
    #     rm_train_op = model.rm_train_op if is_training else tf.no_op()

    # mix lm and pm batches
    mixed_batch_types = [0] * len(word_batches) + [1] * len(
        char_batches) + [2] * len(rhyme_batches)
    random.shuffle(mixed_batch_types)
    mixed_batches = [word_batches, char_batches, rhyme_batches]
    num_batches = len(mixed_batch_types)

    word_batch_id = 0
    char_batch_id = 0
    rhyme_batch_id = 0

    # cmu pronounciation dictionary for stress and rhyme evaluation
    # NOTE(review): only the commented-out pm/rm branches use `cmu`; it is
    # still loaded here so re-enabling them keeps working.
    cmu = cmudict.dict()

    # stress prediction
    stress_acc = [[], [], []]  # buckets for char length: [1-4], [5-8], [9-inf]

    # rhyme predition
    rhyme_pr = {}  # precision/recall for each rhyme threshold
    for rt in rhyme_thresholds:
        rhyme_pr[rt] = [[], []]

    # rhyme pattern
    rhyme_pattern = []  # collection of cosine similarities
    for i in range(12):
        rhyme_pattern.append([])
        for j in range(4):
            if i % 4 == j:
                # -2.0 marks the slot where a line is compared with itself
                rhyme_pattern[i].append([-2.0])
            else:
                rhyme_pattern[i].append([])

    print('stress_acc', stress_acc)
    print('rhyme_pr', rhyme_pr)
    print('rhyme_pattern', rhyme_pattern)

    print('mixed_batch_types', mixed_batch_types)
    print('len mixed_batch_types', len(mixed_batch_types))

    # `total` lets tqdm show real progress; enumerate() has no length.
    for bi, batch_type in tqdm(enumerate(mixed_batch_types), total=num_batches):

        if batch_type == 0 and train_lm:

            b = mixed_batches[batch_type][word_batch_id]

            # reset model state if it's a different set of documents
            if prev_doc != b[2][0]:
                model_state = zero_state
                prev_doc = b[2][0]

            # preprocess character input to [batch_size*doc_len, char_len]
            pm_enc_x = np.array(b[5]).reshape((cf.batch_size * max(b[3]), -1))

            feed_dict = {
                model.lm_x: b[0],
                model.lm_y: b[1],
                model.lm_xlen: b[3],
                model.pm_enc_x: pm_enc_x,
                model.pm_enc_xlen: np.array(b[6]).reshape((-1)),
                model.lm_initial_state: model_state,
                model.lm_hist: b[7],
                model.lm_hlen: b[8]
            }

            cost, model_state, attns, _ = sess.run([
                model.lm_cost, model.lm_final_state, model.lm_attentions,
                lm_train_op
            ], feed_dict)

            lm_costs += cost * cf.batch_size  # keep track of full cost
            total_words += sum(b[3])

            word_batch_id += 1


#         elif batch_type == 1 and train_pm:

#             b = mixed_batches[batch_type][char_batch_id]

#             feed_dict = {model.pm_enc_x: b[0], model.pm_enc_xlen: b[1], model.pm_cov_mask: b[2]}
#             cost, attns, _, = sess.run([model.pm_mean_cost, model.pm_attentions, pm_train_op], feed_dict)
#             pm_costs += cost

#             char_batch_id += 1

#             if not is_training:
#                 eval_stress(stress_acc, cmu, attns, model.pentameter, b[0], idxchar, charxid, pad_symbol, cf)

#         elif batch_type == 2 and train_rm:

#             b = mixed_batches[batch_type][rhyme_batch_id]
#             num_c = 3 + cf.rm_neg

#             feed_dict = {model.pm_enc_x: b[0], model.pm_enc_xlen: b[1], model.rm_num_context: num_c}
#             cost, attns, _ = sess.run([model.rm_cost, model.rm_attentions, rm_train_op], feed_dict)
#             rm_costs += cost

#             rhyme_batch_id += 1

#             # if rhyme_batch_id < 10 and not is_training:
#             #    print_rm_attention(b, cf.batch_size, num_c, attns, charxid[pad_symbol], idxchar)

#             if not is_training:
#                 eval_rhyme(rhyme_pr, rhyme_thresholds, cmu, attns, b, idxchar, charxid, pad_symbol, cf)
#             else:
#                 collect_rhyme_pattern(rhyme_pattern, attns, b, cf.batch_size, num_c, idxchar, charxid[pad_symbol])

        is_last_batch = bi == num_batches - 1

        # Progress line: every 10th batch in verbose mode, and always at the end.
        if (((bi % 10) == 0) and cf.verbose) or is_last_batch:

            partition = "  " + pname
            sent_end = "\n" if is_last_batch else "\r"
            speed = (bi + 1) / (time.time() - start_time)

            sys.stdout.write("%s %d/%d: lm ppl = %.1f; pm loss = %.2f; rm loss = %.2f; batch/sec = %.1f%s" % \
                             (partition, bi + 1, num_batches, np.exp(lm_costs / max(total_words, 1)),
                              pm_costs / max(char_batch_id, 1), rm_costs / max(rhyme_batch_id, 1), speed, sent_end))
            sys.stdout.flush()

            if not is_training and is_last_batch:

                if train_pm:
                    # Append the flattened accuracies as an extra "all" bucket.
                    all_acc = [
                        item for sublist in stress_acc for item in sublist
                    ]
                    stress_acc.append(all_acc)
                    for acci, acc in enumerate(stress_acc):
                        sys.stdout.write("    Stress acc [%d]   = %.3f (%d)\n"
                                         % (acci, np.mean(acc), len(acc)))

                if train_rm:
                    for t in rhyme_thresholds:
                        p = np.mean(rhyme_pr[t][0])
                        r = np.mean(rhyme_pr[t][1])
                        f = 2 * p * r / (p + r) if (p != 0.0
                                                    and r != 0.0) else 0.0
                        sys.stdout.write(
                            "    Rhyme P/R/F@%.1f  = %.3f / %.3f / %.3f\n" %
                            (t, p, r, f))

                sys.stdout.flush()

    # return avg batch loss for lm, pm and rm
    return lm_costs / max(word_batch_id, 1), pm_costs / max(char_batch_id, 1), rm_costs / max(rhyme_batch_id, 1), \
           rhyme_pattern, (np.mean(stress_acc[-1]) if not is_training else 0.0)

In [12]:
# save model

        #         if cf.save_model:
        #             if not os.path.exists(os.path.join(cf.output_dir, cf.output_prefix)):
        #                 os.makedirs(os.path.join(cf.output_dir, cf.output_prefix))
        #             # create saver object to save model
        #             saver = tf.train.Saver(max_to_keep=0)

        # train model
        
        
# training parameters
train_lm = True
train_pm = True
train_rm = True

pm_costs ,rm_costs= 0.0, 0.0
with tf.Graph().as_default(), tf.Session() as sess:

    tf.set_random_seed(cf.seed)

    with tf.variable_scope("model", reuse=None):
        mtrain = SonnetModel(True, cf.batch_size, len(idxword), len(idxchar),
                             charxid[" "], charxid[pad_symbol], cf)
#     with tf.variable_scope("model", reuse=True):
#         mvalid = SonnetModel(False, cf.batch_size, len(idxword), len(idxchar),
#                              charxid[" "], charxid[pad_symbol], cf)
#     with tf.variable_scope("model", reuse=True):
#         mgen = SonnetModel(False, 1, len(idxword), len(
#             idxchar), charxid[" "], charxid[pad_symbol], cf)

    tf.global_variables_initializer().run()

    # initialise word embedding
    if cf.word_embedding_model:
        word_emb = init_embedding(mword, idxword)
        sess.run(mtrain.word_embedding.assign(word_emb))
    # train model
    prev_lm_loss, prev_pm_loss, prev_rm_loss, rhyme_pattern = None, None, None, None
    for i in range(1):

        print("\nEpoch =", i + 1)

        # create batches for language model
        train_word_batch = create_word_batch(train_word_data, cf.batch_size,
                                             cf.doc_lines, cf.bptt_truncate, wordxid[pad_symbol],
                                             wordxid[end_symbol], wordxid[unk_symbol], True)
#         valid_word_batch = create_word_batch(train_word_data, 1,
#                                              cf.doc_lines, cf.bptt_truncate, wordxid[pad_symbol],
#                                              wordxid[end_symbol], wordxid[unk_symbol], False)

#         # create batches for pentameter model
        train_char_batch = create_char_batch(train_char_data, cf.batch_size,
                                             charxid[pad_symbol], mtrain.pentameter, idxchar, True)
#         valid_char_batch = create_char_batch(valid_char_data, cf.batch_size,
#                                              charxid[pad_symbol], mtrain.pentameter, idxchar, False)

#         # create batches for rhyme model
        train_rhyme_batch = create_rhyme_batch(train_rhyme_data, cf.batch_size, charxid[pad_symbol], wordxchar,
                                               cf.rm_neg, True)
#         valid_rhyme_batch = create_rhyme_batch(valid_rhyme_data, cf.batch_size, charxid[pad_symbol], wordxchar,
#                                                cf.rm_neg, False)

        # train an epoch
        _, _, _, new_rhyme_pattern, _ = run_epoch(sess, train_word_batch, train_char_batch,
                                                  train_rhyme_batch, mtrain, "TRAIN", True)
#         lm_loss, pm_loss, rm_loss, _, sacc = run_epoch(sess, valid_word_batch, valid_char_batch,
#                                                        valid_rhyme_batch, mvalid, "VALID", False)

#         # create batch for test model and run an epoch if it's given
#         if cf.test_data:
#             test_word_batch = create_word_batch(test_word_data, cf.batch_size,
#                                                 cf.doc_lines, cf.bptt_truncate, wordxid[pad_symbol],
#                                                 wordxid[end_symbol], wordxid[unk_symbol], False)
#             test_char_batch = create_char_batch(test_char_data, cf.batch_size,
#                                                 charxid[pad_symbol], mtrain.pentameter, idxchar, False)
#             test_rhyme_batch = create_rhyme_batch(test_rhyme_data, cf.batch_size,
#                                                   charxid[pad_symbol], wordxchar, cf.rm_neg, False)
#             run_epoch(sess, test_word_batch, test_char_batch,
#                       test_rhyme_batch, mvalid, "TEST", False)

#         # if pm performance is really poor, re-initialize network weights
#         if train_pm and sacc < stress_acc_threshold and prev_pm_loss == None:
#             print(
#                 "\n  Valid stress accuracy performance is very poor; re-initializing network with random weights...")
#             tf.global_variables_initializer().run()
#             continue

#         # save model
#         if cf.save_model:
#             if prev_lm_loss == None or prev_pm_loss == None or prev_rm_loss == None or \
#                     ((lm_loss <= prev_lm_loss or not train_lm) and
#                      (pm_loss <= prev_pm_loss * reset_scale or not train_pm) and
#                      (rm_loss <= prev_rm_loss * reset_scale or not train_rm)):
#                 saver.save(sess, os.path.join(
#                     cf.output_dir, cf.output_prefix, "model.ckpt"))
#                 prev_lm_loss, prev_pm_loss, prev_rm_loss = lm_loss, pm_loss, rm_loss
#                 rhyme_pattern = new_rhyme_pattern
#             else:
#                 saver.restore(sess, os.path.join(
#                     cf.output_dir, cf.output_prefix, "model.ckpt"))
#                 print(
#                     "New valid performance is worse; restoring previous parameters...")
#                 print("  lm loss: %.5f --> %.5f" % (prev_lm_loss, lm_loss))
#                 print("  pm loss: %.5f --> %.5f" % (prev_pm_loss, pm_loss))
#                 print("  rm loss: %.5f --> %.5f" % (prev_rm_loss, rm_loss))
#                 sys.stdout.flush()
#         else:
#             rhyme_pattern = new_rhyme_pattern

#     # print global rhyme pattern
#     if train_rm:
#         print("\nAggregated Rhyme Pattern:")
#         for i in range(len(rhyme_pattern)):
#             if i % 4 == 0:
#                 print("\n  Quatrain", i / 4, ":")
#             print("    Line %02d =" % i),
#             for j in range(len(rhyme_pattern[i])):
#                 print("%.2f" % np.mean(rhyme_pattern[i][j])).rjust(7),
#         print

#     # save vocab information and config
#     if cf.save_model:
#         # vocab
#         cPickle.dump((idxword, idxchar, wordxchar),
#                      open(os.path.join(cf.output_dir, cf.output_prefix, "vocabs.pickle"), "w"))

#         # create a dictionary object for config
#         cf_dict = {}
#         for k, v in vars(cf).items():
#             if not k.startswith("__"):
#                 cf_dict[k] = v
#         cPickle.dump(cf_dict, open(os.path.join(
#             cf.output_dir, cf.output_prefix, "config.pickle"), "w"))

Instructions for updating:
Colocations handled automatically by placer.
Instructions for updating:
This class is equivalent as tf.keras.layers.LSTMCell, and will be replaced by that in Tensorflow 2.0.
Instructions for updating:
This class is equivalent as tf.keras.layers.StackedRNNCells, and will be replaced by that in Tensorflow 2.0.
Instructions for updating:
Please use `rate` instead of `keep_prob`. Rate should be set to `rate = 1 - keep_prob`.
Instructions for updating:
Please use `keras.layers.RNN(cell)`, which is equivalent to this API
Instructions for updating:
Use tf.cast instead.
Instructions for updating:
Please use `keras.layers.Bidirectional(keras.layers.RNN(cell))`, which is equivalent to this API

Epoch = 1
BATCH_LEN 7


0it [00:00, ?it/s]

stress_acc [[], [], []]
rhyme_pr {0.9: [[], []], 0.8: [[], []], 0.7: [[], []], 0.6: [[], []]}
rhyme_pattern [[[-2.0], [], [], []], [[], [-2.0], [], []], [[], [], [-2.0], []], [[], [], [], [-2.0]], [[-2.0], [], [], []], [[], [-2.0], [], []], [[], [], [-2.0], []], [[], [], [], [-2.0]], [[-2.0], [], [], []], [[], [-2.0], [], []], [[], [], [-2.0], []], [[], [], [], [-2.0]]]
mixed_batch_types [2, 1, 2, 1, 2, 1, 1, 2, 1, 1, 2, 1, 2]
len mixed_batch_types 13
  TRAIN 13/13: lm ppl = 1.0; pm loss = 0.00; rm loss = 0.00; batch/sec = 14.4


13it [00:00, 11817.50it/s]


In [13]:


# Build the model in a fresh graph, run one training pass and write the
# checkpoint. The save has to happen inside this `with` block: the session and
# the graph owning the variables are gone once the block exits.
with tf.Graph().as_default(), tf.Session() as sess:

    tf.set_random_seed(cf.seed)

    with tf.variable_scope("model", reuse=None):
        mtrain = SonnetModel(True, cf.batch_size, len(idxword), len(idxchar),
                             charxid[" "], charxid[pad_symbol], cf)
#     with tf.variable_scope("model", reuse=True):
#         mvalid = SonnetModel(False, cf.batch_size, len(idxword), len(idxchar),
#                              charxid[" "], charxid[pad_symbol], cf)
#     with tf.variable_scope("model", reuse=True):
#         mgen = SonnetModel(False, 1, len(idxword), len(
#             idxchar), charxid[" "], charxid[pad_symbol], cf)

    tf.global_variables_initializer().run()

    # create saver object to save model (max_to_keep=0 keeps every checkpoint)
    saver = tf.train.Saver(max_to_keep=0)
    # Module-level pm/rm loss placeholders; run_epoch only executes the
    # language-model branch, so these stay at 0.0.
    pm_costs = 0.0
    rm_costs = 0.0

    # create batches for language model
    train_word_batch = create_word_batch(train_word_data, cf.batch_size,
                                         cf.doc_lines, cf.bptt_truncate, wordxid[pad_symbol],
                                         wordxid[end_symbol], wordxid[unk_symbol], True)
#         valid_word_batch = create_word_batch(train_word_data, 1,
#                                              cf.doc_lines, cf.bptt_truncate, wordxid[pad_symbol],
#                                              wordxid[end_symbol], wordxid[unk_symbol], False)

    # create batches for pentameter model
    train_char_batch = create_char_batch(train_char_data, cf.batch_size,
                                         charxid[pad_symbol], mtrain.pentameter, idxchar, True)
#         valid_char_batch = create_char_batch(valid_char_data, cf.batch_size,
#                                              charxid[pad_symbol], mtrain.pentameter, idxchar, False)

    # create batches for rhyme model
    train_rhyme_batch = create_rhyme_batch(train_rhyme_data, cf.batch_size, charxid[pad_symbol], wordxchar,
                                           cf.rm_neg, True)
#         valid_rhyme_batch = create_rhyme_batch(valid_rhyme_data, cf.batch_size, charxid[pad_symbol], wordxchar,
#                                                cf.rm_neg, False)

    _, _, _, new_rhyme_pattern, _ = run_epoch(sess, train_word_batch, train_char_batch,
                                              train_rhyme_batch, mtrain, "TRAIN", True)

    # Persist the trained variables while the session is still open.
    saver.save(sess, "/tmp/deepspear_model.ckpt")

BATCH_LEN 7


0it [00:00, ?it/s]

stress_acc [[], [], []]
rhyme_pr {0.9: [[], []], 0.8: [[], []], 0.7: [[], []], 0.6: [[], []]}
rhyme_pattern [[[-2.0], [], [], []], [[], [-2.0], [], []], [[], [], [-2.0], []], [[], [], [], [-2.0]], [[-2.0], [], [], []], [[], [-2.0], [], []], [[], [], [-2.0], []], [[], [], [], [-2.0]], [[-2.0], [], [], []], [[], [-2.0], [], []], [[], [], [-2.0], []], [[], [], [], [-2.0]]]
mixed_batch_types [1, 2, 2, 2, 1, 2, 2, 2, 1, 1, 1, 1, 1]
len mixed_batch_types 13
  TRAIN 13/13: lm ppl = 1.0; pm loss = 0.00; rm loss = 0.00; batch/sec = 12.1


13it [00:00, 13416.82it/s]


In [14]:
# NOTE(review): as written this cell fails with the RuntimeError recorded in its
# output ("The Session graph is empty"). The `with` statement opens a brand-new,
# empty graph and session, while `saver` is the object created in the previous
# cell. Presumably the save has to run inside the session that initialised the
# model variables - confirm and move the call there.
with tf.Graph().as_default(), tf.Session() as sess:
    saver.save(sess, "/tmp/deepspear_model.ckpt")

RuntimeError: The Session graph is empty.  Add operations to the graph before calling run().

In [None]:
# import the inspect_checkpoint library
from tensorflow.python.tools import inspect_checkpoint as chkp

# print all tensors in checkpoint file
# NOTE(review): no cell shown in this notebook writes /tmp/model1.ckpt; the only
# save call targets /tmp/deepspear_model.ckpt - confirm which path is intended.
chkp.print_tensors_in_checkpoint_file("/tmp/model1.ckpt", tensor_name='', all_tensors=True)

In [None]:


# print all tensors in checkpoint file
# With all_tensors=False and all_tensor_names=True this presumably lists only
# the tensor names - confirm against the inspect_checkpoint documentation.
# Depends on `chkp` imported in an earlier cell.
chkp.print_tensors_in_checkpoint_file("/tmp/model1.ckpt", tensor_name='', all_tensors=False,all_tensor_names=True)

In [None]:
# import the inspect_checkpoint library
from tensorflow.python.tools import inspect_checkpoint as chkp

# print all tensors in checkpoint file
# Reads the path used by the notebook's saver.save call; the checkpoint must
# have been written successfully before this cell is run.
chkp.print_tensors_in_checkpoint_file("/tmp/deepspear_model.ckpt", tensor_name='', all_tensors=False,all_tensor_names=True)




In [None]:
# List the variables of the current default graph. tf.all_variables() is a
# deprecated alias; tf.global_variables() is its replacement and matches the
# tf.global_variables_initializer() calls used elsewhere in this notebook.
tf.global_variables()

In [None]:
# NOTE(review): `saver` and `sess` come from earlier cells. Every
# `with ... tf.Session() as sess` block above has already exited by this point,
# so `sess` is presumably closed - confirm before running. The path also
# differs from the one passed to saver.save.
saver.restore(sess,"/tmp/model1.ckpt")

In [None]:
# Rebuilds the training model under the "model" variable scope. No
# tf.Graph().as_default() wraps this cell, so the ops go into whatever graph is
# the current default.
with tf.variable_scope("model", reuse=None):
    mtrain = SonnetModel(True, cf.batch_size, len(idxword), len(idxchar),
                             charxid[" "], charxid[pad_symbol], cf)
#     with tf.variable_scope("model", reuse=True):
#         mvalid = SonnetModel(False, cf.batch_size, len(idxword), len(idxchar),
#                              charxid[" "], charxid[pad_symbol], cf)
#     with tf.variable_scope("model", reuse=True):
#         mgen = SonnetModel(False, 1, len(idxword), len(
#             idxchar), charxid[" "], charxid[pad_symbol], cf)
    # NOTE(review): `sess` is the name bound by earlier `with` blocks that have
    # already exited - presumably a closed session tied to another graph; confirm.
    tf.global_variables_initializer().run(session=sess)


In [None]:
# Template for restoring a subset of variables: maps a checkpoint tensor name to
# the variable object that should receive its value.
# NOTE(review): `var1` is not defined anywhere in the visible notebook and the
# key looks like a placeholder name - fill in before running.
VARS_DICT = {'selector_network/c2w/var1': var1}
with tf.Session() as sess:
    # Saver limited to the variables listed in VARS_DICT.
    saver = tf.train.Saver(var_list=VARS_DICT)
    saver.restore(sess, '/tmp/model1.ckpt')
# get the graph


In [None]:
# Use a dedicated name for the graph: `g` is the alias of gensim.models
# (`import gensim.models as g`) that the embedding-loading cell relies on, and
# rebinding it would break that cell on re-run.
default_graph = tf.get_default_graph()
# The tensor name below is a placeholder to be replaced with a real name.
w1 = default_graph.get_tensor_by_name('some_variable_name as per your definition in the model')
