In [1]:
import argparse
import sys
import random
import time
import imp
import os
import pickle as cPickle
import tensorflow as tf
import numpy as np
import gensim.models as g
#from util import *
# from sonnet_model import SonnetModel
from sklearn.metrics import roc_auc_score
from nltk.corpus import cmudict
from tqdm import tqdm

# constants
pad_symbol = "<pad>"
end_symbol = "<eos>"
unk_symbol = "<unk>"
dummy_symbols = [pad_symbol, end_symbol, unk_symbol]


In [2]:
%%writefile config.py
###preprocessing options###
word_minfreq=3

###hyper-parameters###
seed=0
batch_size=32
keep_prob=0.7
epoch_size=1
max_grad_norm=5
#language model
word_embedding_dim=100
word_embedding_model="pretrain_word2vec/dim100/word2vec.bin"
lm_enc_dim=200
lm_dec_dim=600
lm_dec_layer_size=1
lm_attend_dim=25
lm_learning_rate=0.2
#pentameter model
char_embedding_dim=150
pm_enc_dim=50
pm_dec_dim=200
pm_attend_dim=50
pm_learning_rate=0.001
repeat_loss_scale=1.0
cov_loss_scale=1.0
cov_loss_threshold=0.7
sigma=1.00
#rhyme model
rm_dim=100
rm_neg=5 #extra randomly sampled negative examples
rm_delta=0.5
rm_learning_rate=0.001

###sonnet hyper-parameters###
bptt_truncate=2 #number of sonnet lines to truncate bptt
doc_lines=14 #total number of lines for a sonnet

###misc###
verbose=False
save_model=True

###input/output###
output_dir="output"
train_data="datasets/gutenberg/sonnet_train.txt"
valid_data="datasets/gutenberg/sonnet_valid.txt"
test_data="datasets/gutenberg/sonnet_test.txt"
output_prefix="wmin%d_sd%d_bat%d_kp%.1f_eph%d_grd%d_wdim%d_lmedim%d_lmddim%d_lmdlayer%d_lmadim%d_lmlr%.1f_cdim%d_pmedim%d_pmddim%d_pmadim%d_pmlr%.1E_loss%.1f-%.1f-%.1f_sm%.2f_rmdim%d_rmn%d_rmd%.1f_rmlr%.1E_son%d-%d" % \
    (word_minfreq, seed, batch_size, keep_prob, epoch_size, max_grad_norm, word_embedding_dim, lm_enc_dim,
    lm_dec_dim, lm_dec_layer_size, lm_attend_dim, lm_learning_rate,
    char_embedding_dim, pm_enc_dim, pm_dec_dim, pm_attend_dim, pm_learning_rate, repeat_loss_scale,
    cov_loss_scale, cov_loss_threshold, sigma, rm_dim, rm_neg, rm_delta, rm_learning_rate, bptt_truncate, doc_lines)


Overwriting config.py


In [3]:
import config as cf

In [4]:
import codecs
import operator
import numpy as np
import random
import math
import codecs
import sys
from collections import defaultdict


def load_vocab(corpus, word_minfreq, dummy_symbols):
    idxword, idxchar = [], []
    wordxid, charxid = defaultdict(int), defaultdict(int)
    word_freq, char_freq = defaultdict(int), defaultdict(int)
    wordxchar = defaultdict(list)

    def update_dic(symbol, idxvocab, vocabxid):
        if symbol not in vocabxid:
            idxvocab.append(symbol)
            vocabxid[symbol] = len(idxvocab) - 1 

    for line_id, line in enumerate(codecs.open(corpus, "r", "utf-8")):
        for word in line.strip().split():
            word_freq[word] += 1
        for char in line.strip():
            char_freq[char] += 1

    #add in dummy symbols into dictionaries
    for s in dummy_symbols:
        update_dic(s, idxword, wordxid)
        update_dic(s, idxchar, charxid)

    #remove low fequency words/chars
    def collect_vocab(vocab_freq, idxvocab, vocabxid):
        for w, f in sorted(list(vocab_freq.items()), key=operator.itemgetter(1), reverse=True):
            if f < word_minfreq:
                break
            else:
                update_dic(w, idxvocab, vocabxid)

    collect_vocab(word_freq, idxword, wordxid)
    collect_vocab(char_freq, idxchar, charxid)

    #word id to [char ids]
    dummy_symbols_set = set(dummy_symbols)
    for wi, w in enumerate(idxword):
        if w in dummy_symbols:
            wordxchar[wi] = [wi]
        else:
            for c in w:
                wordxchar[wi].append(charxid[c] if c in charxid else charxid[dummy_symbols[2]])

    return idxword, wordxid, idxchar, charxid, wordxchar


def only_symbol(word):
    for c in word:
        if c.isalpha():
            return False

    return True


def remove_punct(string):
    return " ".join("".join([ item for item in string if (item.isalpha() or item == " ") ]).split())


def load_data(corpus, wordxid, idxword, charxid, idxchar, pad_symbol, end_symbol, unk_symbol):
    """
    Translate text data. Note that it works with 14 lines text
    :param corpus:
    :param wordxid:
    :param idxword:
    :param charxid:
    :param idxchar:
    :param pad_symbol:
    :param end_symbol:
    :param unk_symbol:
    :return: word_data, char_data, rhyme_data, nwords, nchars
    word_data: list of word indicies, order of sentences is reversed, order of words is reversed also
    char_data: list of char indicies, order of sentences is reversed, order of words is same (normal)
    rhyme_data: list of special words in chars presentation.
    For every sonnet, since third sentence, divides on gtoups of 4 sentences. For every group gets
    last words of sentence and stores.
    [word on current sentences, other words in group, idx of sentence]
    """
    nwords     = [] #number of words for each line
    nchars     = [] #number of chars for each line
    word_data  = [] #data[doc_id][0][line_id] = list of word ids; data[doc_id][1][line_id] = list of [char_ids]
    char_data  = [] #data[line_id] = list of char ids
    rhyme_data = [] #list of ( target_word, [candidate_words], target_word_line_id ); word is a list of characters

    def word_to_char(word):
        if word in set([pad_symbol, end_symbol, unk_symbol]):
            return [ wordxid[word] ]
        else:
            return [ charxid[item] if item in charxid else charxid[unk_symbol] for item in word ]


    for doc in codecs.open(corpus, "r", "utf-8"):

        word_lines, char_lines = [[], []], []
        last_words = []

         #reverse the order of lines and words as we are generating from end to start
        for line in reversed(doc.strip().split(end_symbol)):

            if len(line.strip()) > 0:

                word_seq = [ wordxid[item] if item in wordxid else wordxid[unk_symbol] \
                    for item in reversed(line.strip().split()) ] + [wordxid[end_symbol]]

                char_seq = [ word_to_char(item) for item in reversed(line.strip().split()) ] + [word_to_char(end_symbol)]

                word_lines[0].append(word_seq)
                word_lines[1].append(char_seq)
                char_lines.append([ charxid[item] if item in charxid else charxid[unk_symbol] \
                    for item in remove_punct(line.strip())])
                nwords.append(len(word_lines[0][-1]))
                nchars.append(len(char_lines[-1]))

                last_words.append(line.strip().split()[-1])

        if len(word_lines[0]) == 14: #14 lines for sonnets

            word_data.append(word_lines)
            char_data.extend(char_lines)

            last_words = last_words[2:] #remove couplets (since they don't always rhyme)

            for wi, w in enumerate(last_words):
                rhyme_data.append((word_to_char(w),
                                   [word_to_char(item)
                                    for item_id, item in enumerate(last_words[(wi // 4) * 4:(wi // 4 + 1) * 4])
                                    if item_id != (wi % 4)], (11 - wi)))

    return word_data, char_data, rhyme_data, nwords, nchars
            

def print_stats(partition, word_data, rhyme_data, nwords, nchars):
    print(partition, "statistics:")
    print("  Number of documents         =", len(word_data))
    print("  Number of rhyme examples    =", len(rhyme_data))
    print("  Total number of word tokens =", sum(nwords))
    print("  Mean/min/max words per line = %.2f/%d/%d" % (np.mean(nwords), min(nwords), max(nwords)))
    print("  Total number of char tokens =", sum(nchars))
    print("  Mean/min/max chars per line = %.2f/%d/%d" % (np.mean(nchars), min(nchars), max(nchars)))


def init_embedding(model, idxword):
    word_emb = []
    for vi, v in enumerate(idxword):
        if v in model:
            word_emb.append(model[v])
        else:
            word_emb.append(np.random.uniform(-0.5/model.vector_size, 0.5/model.vector_size, [model.vector_size,]))
    return np.array(word_emb)


def pad(lst, max_len, pad_symbol):
    if len(lst) > max_len:
        print("\nERROR: padding")
        print("length of list greater than maxlen; list =", lst, "; maxlen =", max_len)
        raise SystemExit
    return lst + [pad_symbol] * (max_len - len(lst))


def get_vowels():
    return set(["a", "e", "i", "o", "u"])


def coverage_mask(char_ids, idxchar):
    vowels = get_vowels()
    return [ float(idxchar[c] in vowels) for c in char_ids ]


def flatten_list(l):
    return [item for sublist in l for item in sublist]


def create_word_batch(data, batch_size, lines_per_doc, nlines_per_batch, pad_symbol, end_symbol, unk_symbol, shuffle_data):
    """
    Creates word batch
    :param data:
    :param batch_size:
    :param lines_per_doc:
    :param nlines_per_batch:
    :param pad_symbol:
    :param end_symbol:
    :param unk_symbol:
    :param shuffle_data:
    :return:
    Generates sentence from couple of sonnet strings with features:
    list (x, y, docs, doc_lens, doc_lines, xchar, xchar_lens, hist, hist_lens)
    x: tokenized sentence with spec symbol at the end (1)
    y: tokenized sentence with spec symbol at the begining (1)
    docs: idx for every sonnet (document)
    doc_lens: length for generated sentence (len of couple sonnet strings)
    xchar: list of translated word -> chars sentence, len of every word`s chars list equals len of word. If word ends,
    list appended by zeroes
    xchar_lens: list of every word lens
    hist: previous couple of strings
    hist_lens: lens of prevous string couple
    """
    docs_per_batch = len(data) // batch_size
    batches = []
    doc_ids = list(range(len(data)))
    if shuffle_data:
        random.shuffle(doc_ids)

    if lines_per_doc % nlines_per_batch != 0:
        print("\nERROR:")
        print(("lines_per_doc (%d) %% nlines_per_batch (%d) must equal 0" % (lines_per_doc, nlines_per_batch)))
        raise SystemExit

    for i in range(docs_per_batch):

        for j in range(lines_per_doc // nlines_per_batch):

            docs       = []
            doc_lens   = []
            doc_lines  = []
            x          = []
            y          = []
            xchar      = []
            xchar_lens = []
            hist       = []
            hist_lens  = []

            for k in range(batch_size):

                d       = doc_ids[i*batch_size+k]
                wordseq = flatten_list(data[d][0][j*nlines_per_batch:(j+1)*nlines_per_batch])
                charseq = flatten_list(data[d][1][j*nlines_per_batch:(j+1)*nlines_per_batch])
                histseq = flatten_list(data[d][0][:j*nlines_per_batch])
    
                x.append([end_symbol] + wordseq[:-1])
                y.append(wordseq)

                docs.append(d)
                doc_lens.append(len(wordseq))
                doc_lines.append(list(range(j*nlines_per_batch, (j+1)*nlines_per_batch)))

                xchar.append([[end_symbol]] + charseq[:-1])
                xchar_lens.append( [1] + [ len(item) for item in charseq[:-1] ])

                hist.append(histseq if len(histseq) > 0 else [unk_symbol])
                hist_lens.append(len(histseq) if len(histseq) > 0 else 1)

            #pad the data
            word_pad_len = max(doc_lens)
            char_pad_len = max(flatten_list(xchar_lens))
            hist_pad_len = max(hist_lens)
            for k in range(batch_size):

                x[k] = pad(x[k], word_pad_len, pad_symbol)
                y[k] = pad(y[k], word_pad_len, pad_symbol)

                xchar_lens[k].extend( [1]*(word_pad_len-len(xchar[k])) ) #add len for pad symbols

                xchar[k] = pad(xchar[k], word_pad_len, [pad_symbol]) #pad the word lengths
                xchar[k] = [pad(item, char_pad_len, pad_symbol) for item in xchar[k]] #pad the characters

                hist[k] = pad(hist[k], hist_pad_len, pad_symbol)

            batches.append((x, y, docs, doc_lens, doc_lines, xchar, xchar_lens, hist, hist_lens))

    return batches


def create_char_batch(data, batch_size, pad_symbol, pentameter, idxchar, shuffle_data):
    """
    Generates char batch
    :param data:
    :param batch_size:
    :param pad_symbol:
    :param pentameter:
    :param idxchar:
    :param shuffle_data:
    :return:
    for every line returns [list of encoded chars, len of line, coverage_mask]
    coverage mask - shows vowels in sentences
    """
    batches = []
    batch_len = len(data) // batch_size
    print("BATCH_LEN", batch_len)
    if shuffle_data:
        random.shuffle(data)

    for i in range(batch_len):
        enc_x    = []
        enc_xlen = []

        for j in range(batch_size):
            enc_x.append(data[i*batch_size+j])
            enc_xlen.append(len(data[i*batch_size+j]))

        xlen_max = max(enc_xlen)
        cov_mask = np.zeros((batch_size, xlen_max)) #coverage mask

        for j in range(batch_size):
            enc_x[j]    = pad(enc_x[j], xlen_max, pad_symbol)
            cov_mask[j] = coverage_mask(enc_x[j], idxchar)
            
        batches.append((enc_x, enc_xlen, cov_mask))

    return batches


def create_rhyme_batch(data, batch_size, pad_symbol, wordxchar, num_neg, shuffle_data):
    """

    :param data:
    :param batch_size:
    :param pad_symbol:
    :param wordxchar:
    :param num_neg:
    :param shuffle_data:
    :return: list of [xc, xclen, xid]
    xc: like rhyme data - but special words extends by random word

    """
    if shuffle_data:
        random.shuffle(data)

    batches = []

    for i in range(len(data) // batch_size):
        x, xid, c  = [], [], []
        xlen, clen = [], []

        for j in range(batch_size):
            x.append(data[i*batch_size+j][0])
            xid.append(data[i*batch_size+j][2])
            xlen.append(len(data[i*batch_size+j][0]))
            for context in data[i*batch_size+j][1]:
                c.append(context)
                clen.append(len(context))

            # From preliminary
            # experiments we found that we can improve
            # the model by introducing additional non-rhyming
            # or negative reference words. Negative reference
            # words are sampled uniform randomly from the vocabulary,
            # and the number of additional negative
            # words is a hyper-parameter.
            for _ in range(num_neg):
                c.append(wordxchar[random.randrange(3, len(wordxchar))])
                clen.append(len(c[-1]))

        #merging target and context words
        # (first batch_size = target words; following batch_size*(3+num_neg) = context words)
        xc        = x + c
        xclen     = xlen + clen
        xclen_max = max(xclen)

        #pad the target words and context words
        for xci, xcv in enumerate(xc):
            xc[xci] = pad(xcv, xclen_max, pad_symbol)

        batches.append((xc, xclen, xid))

    return batches
        

def print_lm_attention(bi, b, attentions, idxword, cf):

    print(("\n", "="*100))
    for ex in range(cf.batch_size)[-1:]:
        xword = [ idxword[item] for item in b[1][ex] ]
        hword = [ idxword[item] for item in b[7][ex] ]
        print(("\nBatch ID =", bi))
        print(("Example =", ex))
        print(("x_word =", " ".join(xword)))
        print(("hist_word=", " ".join(hword)))
        for xi, x in enumerate(xword):
            print(("\nWord =", x))
            print(("\tSum dist =", sum(attentions[ex][xi])))
            attn_dist_sort = np.argsort(-attentions[ex][xi])
            print("\t")
            for hi in attn_dist_sort[:5]:
                print(("[%d]%s:%.3f  " % (hi, hword[hi], attentions[ex][xi][hi])))



def print_pm_attention(b, batch_size, costs, logits, attentions, mius, idxchar):

    print(("\n", "="*100))
    for ex in range(batch_size)[-10:]:
        print(("\nSentence =", ex))
        print(("x =", b[0][ex]))
        print(("x len =", b[1][ex]))
        print(("x char=", "".join(idxchar[item] for item in b[0][ex])))
        print(("losses =", costs[ex]))
        print(("pentameter output =", logits[ex]))
        print(("coverage mask =", b[2][ex]))
        for attni, attn in enumerate(attentions):
            print(("attention at time step", attni, ":"))
            print(("\tmiu_p =", mius[attni][ex] * (b[1][ex] - 1.0)))
            for xid in reversed(np.argsort(attn[ex])):
                if attn[ex][xid] > 0.05:
                    print(("\t%.3f %d %s" % (attn[ex][xid], xid, (idxchar[b[0][ex][xid]]))))


def print_rm_attention(b, batch_size, num_context, attentions, pad_id, idxchar):

    print(("\n", "="*100))
    for exid in range(batch_size)[-10:]:
        print(("\nTarget word =", "".join([idxchar[item] for item in b[0][exid] if item != pad_id])))
        for ci, c in enumerate(b[0][(exid*num_context+batch_size):((exid+1)*num_context+batch_size)]):
            print(("\t", ("%.2f" % attentions[exid][ci]), "=", "".join([idxchar[item] for item in c if item != pad_id])))


def get_word_stress(cmu, word):

    stresses = set([])

    def valid(stress):
        for sti in range(len(stress)-1):
            if abs(int(stress[sti]) - int(stress[sti+1])) != 1:
                return False
        return True

    if word in cmu:
        for res in cmu[word]:
            stress = ""
            for syl in res:
                if syl[-1] == "0":
                    stress += "0"
                elif syl[-1] == "1" or syl[-1] == "2":
                    stress += "1"

            if valid(stress):
                stresses.add(stress)

    return stresses


def update_stress_accs(accs, word_len, score):

    word_buckets   = [4,8,float("inf")]

    for wi, wb in enumerate(word_buckets):
        if word_len <= wb:
            accs[wi].append(score)
            break


def eval_stress(accs, cmu, attns, pentameter, batch, idxchar, charxid, pad_symbol, cf):

    attn_threshold = 0.2

    for ex in range(cf.batch_size):

        chars     = [idxchar[item] for item in batch[ex]]
        space_ids = [-1] + [i for i, ch in enumerate(chars) if ch == " "] + \
            [batch[ex].index(charxid[pad_symbol]) if charxid[pad_symbol] in batch[ex] else len(batch[ex])]

        for spi, sp in enumerate(space_ids[:-1]):

            start  = sp+1
            end    = space_ids[spi+1]
            word   = "".join(chars[start:end])

            gold_stress = get_word_stress(cmu, word)
            sys_stress  = ""

            if len(gold_stress) == 0:
                continue

            for attni, attn in enumerate(attns):

                for ch in range(start, end):

                    if attn[ex][ch] >= attn_threshold: 
                        sys_stress += str(pentameter[attni])
                        break

            update_stress_accs(accs, (end-start), float(sys_stress in gold_stress))


def eval_rhyme(pr, thresholds, cmu, attns, b, idxchar, charxid, pad_symbol, cf, cmu_rhyme=None, cmu_norhyme=None,
    em_vocab=None, em_theta=None):

    def syllable_to_rhyme(syllable):

        stresses = set(["0", "1", "2"])
        r = []
        for s in reversed(syllable):
            if s[-1] in stresses:
                r.append(s[:-1])
                break
            else:
                r.append(s)

        return "-".join(list(reversed(r)))

    def get_rhyme(words):

        rhymes = []
        
        for word in words:
            word_rhymes = set([])
            if word in cmu:
                for res in cmu[word]:
                    word_rhymes.add(syllable_to_rhyme(res))
            rhymes.append(word_rhymes)

        return rhymes

    def rhyme_score(target_rhymes, context_rhymes):
        
        if len(target_rhymes) == 0 or len(context_rhymes) == 0:
            return None
        else:
            for tr in target_rhymes:
                if tr in context_rhymes: 
                    return 1.0
            return 0.0

    def last_syllable(word):

        i = len(word)
        for c in reversed(word):
            i -= 1
            if c in get_vowels():
                break

        return word[i:]

    def em_rhyme_score(x, y):

        if x in em_vocab and y in em_vocab:
            xi = em_vocab.index(x)
            yi = em_vocab.index(y)

            return max(em_theta[xi][yi], em_theta[yi][xi])
    
        return 0.0


    num_c = 3 + cf.rm_neg
    for ex in range(cf.batch_size):
        target  = "".join([idxchar[item] for item in b[0][ex][:b[1][ex]]])
        context = []
        for ci, c in enumerate(b[0][(ex*num_c+cf.batch_size):(ex*num_c+cf.batch_size+3)]):
            context.append("".join([idxchar[item] for item in c[:b[1][ex*num_c+cf.batch_size+ci]]]))

        target_rhyme  = get_rhyme([target])[0]
        context_rhyme = get_rhyme(context)

        for t in thresholds:
            for ci, c in enumerate(context):
                score = rhyme_score(target_rhyme, context_rhyme[ci])
                system_score = attns[ex][ci] if em_vocab == None else em_rhyme_score(target, context[ci])

                #precision
                if system_score >= t:
                    if score != None:
                        pr[t][0].append(score)
                        
                #recall
                if score == 1.0:
                    pr[t][1].append(float(system_score >= t))
            
                if cmu_rhyme != None and cmu_norhyme != None:
                    if score == 1.0:
                        if (target, context[ci]) not in cmu_rhyme and (context[ci], target) not in cmu_rhyme:
                            cmu_rhyme[(target, context[ci])] = system_score
                    elif score == 0.0:
                        if (target, context[ci]) not in cmu_norhyme and (context[ci], target) not in cmu_norhyme:
                            cmu_norhyme[(target, context[ci])] = system_score


def collect_rhyme_pattern(rhyme_pattern, attentions, b, batch_size, num_context, idxchar, pad_id):

    def get_context_line_id(target_line_id, context_id):
        p = target_line_id % 4
        q = 3-context_id
        if q <= p:
            q -= 1
        return q

    #print "\n", "="*100
    for exid in range(batch_size):
        target_line_id = b[2][exid]
        for ci, c in enumerate(b[0][(exid*num_context+batch_size):((exid+1)*num_context+batch_size)][:3]):
            context_line_id = get_context_line_id(target_line_id, ci)
            rhyme_pattern[target_line_id][context_line_id].append(attentions[exid][ci])


def postprocess_sentence(line):
    cleaned_sent = ""
    for w in line.strip().split():
        spacing = " "
        if w.startswith("'") or only_symbol(w):
            spacing = ""
        cleaned_sent += spacing + w
    return cleaned_sent.strip()


In [5]:
import tensorflow as tf
import numpy as np
import math
import time
# from util import *
from collections import Counter
#from rnn_cell import ExtendedMultiRNNCell

class SonnetModel(object):

    def get_last_hidden(self, h, xlen):

        ids = tf.range(tf.shape(xlen)[0])
        gather_ids = tf.concat(axis=1, values=[tf.expand_dims(ids, 1), tf.expand_dims(xlen-1, 1)])
        return tf.gather_nd(h, gather_ids)


    def gated_layer(self, s, h, sdim, hdim):

        update_w = tf.get_variable("update_w", [sdim+hdim, hdim])
        update_b = tf.get_variable("update_b", [hdim], initializer=tf.constant_initializer(1.0))
        reset_w  = tf.get_variable("reset_w", [sdim+hdim, hdim])
        reset_b  = tf.get_variable("reset_b", [hdim], initializer=tf.constant_initializer(1.0))
        c_w      = tf.get_variable("c_w", [sdim+hdim, hdim])
        c_b      = tf.get_variable("c_b", [hdim], initializer=tf.constant_initializer())

        z = tf.sigmoid(tf.matmul(tf.concat(axis=1, values=[s, h]), update_w) + update_b)
        r = tf.sigmoid(tf.matmul(tf.concat(axis=1, values=[s, h]), reset_w) + reset_b)
        c = tf.tanh(tf.matmul(tf.concat(axis=1, values=[s, r*h]), c_w) + c_b)
        
        return (1-z)*h + z*c


    def selective_encoding(self, h, s, hdim):

        h1 = tf.shape(h)[0]
        h2 = tf.shape(h)[1]
        h_ = tf.reshape(h, [-1, hdim])
        s_ = tf.reshape(tf.tile(s, [1, h2]), [-1, hdim])

        attend_w = tf.get_variable("attend_w", [hdim*2, hdim])
        attend_b = tf.get_variable("attend_b", [hdim], initializer=tf.constant_initializer())

        g = tf.sigmoid(tf.matmul(tf.concat(axis=1, values=[h_, s_]), attend_w) + attend_b)

        return tf.reshape(h_* g, [h1, h2, -1])


    def __init__(self, is_training, batch_size, word_type_size, char_type_size, space_id, pad_id, cf):

        self.config = cf

        ###########
        #constants#
        ###########
        self.pentameter = [0,1]*5
        self.pentameter_len = len(self.pentameter)

        ##############
        #placeholders#
        ##############

        #language model placeholders
        self.lm_x    = tf.placeholder(tf.int32, [None, None])
        self.lm_xlen = tf.placeholder(tf.int32, [None])
        self.lm_y    = tf.placeholder(tf.int32, [None, None])
        self.lm_hist = tf.placeholder(tf.int32, [None, None])
        self.lm_hlen = tf.placeholder(tf.int32, [None])

        #pentameter model placeholders
        self.pm_enc_x    = tf.placeholder(tf.int32, [None, None])
        self.pm_enc_xlen = tf.placeholder(tf.int32, [None])
        self.pm_cov_mask = tf.placeholder(tf.float32, [None, None])

        #rhyme model placeholders
        self.rm_num_context = tf.placeholder(tf.int32)

        ##################
        #pentameter model#
        ##################
#         with tf.variable_scope("pentameter_model"):
#             self.init_pm(is_training, batch_size, char_type_size, space_id, pad_id)

        ################
        #language model#
        ################
        with tf.variable_scope("language_model"):
            self.init_lm(is_training, batch_size, word_type_size)

        #############
        #rhyme model#
        #############
#         with tf.variable_scope("rhyme_model"):
#             self.init_rm(is_training, batch_size, char_type_size)

       
    # -- language model network
    def init_lm(self, is_training, batch_size, word_type_size):

        cf = self.config

        #shared word embeddings (used by encoder and decoder)
        self.word_embedding = tf.get_variable("word_embedding", [word_type_size, cf.word_embedding_dim],
            initializer=tf.random_uniform_initializer(-0.05/cf.word_embedding_dim, 0.05/cf.word_embedding_dim))
    
        #########
        #decoder#
        #########

        #define lstm cells
        lm_dec_cell = tf.nn.rnn_cell.LSTMCell(cf.lm_dec_dim, use_peepholes=True, forget_bias=1.0)
        if is_training and cf.keep_prob < 1.0:
            lm_dec_cell = tf.nn.rnn_cell.DropoutWrapper(lm_dec_cell, output_keep_prob=cf.keep_prob)
        self.lm_dec_cell = tf.nn.rnn_cell.MultiRNNCell([lm_dec_cell] * cf.lm_dec_layer_size)

        #initial states
        self.lm_initial_state = self.lm_dec_cell.zero_state(batch_size, tf.float32)
        state = self.lm_initial_state

        #pad symbol vocab ID = 0; create mask = 1.0 where vocab ID > 0 else 0.0
        lm_mask = tf.cast(tf.greater(self.lm_x, tf.zeros(tf.shape(self.lm_x), dtype=tf.int32)), dtype=tf.float32)

        #embedding lookup
        word_inputs = tf.nn.embedding_lookup(self.word_embedding, self.lm_x)
        if is_training and cf.keep_prob < 1.0:
            word_inputs = tf.nn.dropout(word_inputs, cf.keep_prob)

        #process character encodings
        #concat last hidden state of fw RNN with first hidden state of bw RNN
        fw_hidden = self.get_last_hidden(self.char_encodings[0], self.pm_enc_xlen)
        char_inputs = tf.concat(axis=1, values=[fw_hidden, self.char_encodings[1][:,0,:]])
        char_inputs = tf.reshape(char_inputs, [batch_size, -1, cf.pm_enc_dim*2]) #reshape into same dimension as inputs
        
        #concat word and char encodings
        inputs = tf.concat(axis=2, values=[word_inputs, char_inputs])
        #inputs = word_inputs

        #apply mask to zero out pad embeddings
        inputs = inputs * tf.expand_dims(lm_mask, -1)

        #dynamic rnn
        dec_outputs, final_state = tf.nn.dynamic_rnn(self.lm_dec_cell, inputs, sequence_length=self.lm_xlen, \
            dtype=tf.float32, initial_state=self.lm_initial_state)
        self.lm_final_state = final_state

        #########################
        #encoder (history words)#
        #########################

        #embedding lookup
        hist_inputs = tf.nn.embedding_lookup(self.word_embedding, self.lm_hist)
        if is_training and cf.keep_prob < 1.0:
            hist_inputs = tf.nn.dropout(hist_inputs, cf.keep_prob)

        #encoder lstm cell
        lm_enc_cell = tf.nn.rnn_cell.LSTMCell(cf.lm_enc_dim, forget_bias=1.0)
        if is_training and cf.keep_prob < 1.0:
            lm_enc_cell = tf.nn.rnn_cell.DropoutWrapper(lm_enc_cell, output_keep_prob=cf.keep_prob)

        #history word encodings
        hist_outputs, _ = tf.nn.bidirectional_dynamic_rnn(cell_fw=lm_enc_cell, cell_bw=lm_enc_cell,
            inputs=hist_inputs, sequence_length=self.lm_hlen, dtype=tf.float32)

        #full history encoding
        full_encoding = tf.concat(axis=1, values=[hist_outputs[0][:,-1,:], hist_outputs[1][:,0,:]])

        #concat fw and bw hidden states
        hist_outputs = tf.concat(axis=2, values=hist_outputs)

        #selective encoding
        with tf.variable_scope("selective_encoding"):
            hist_outputs = self.selective_encoding(hist_outputs, full_encoding, cf.lm_enc_dim*2)

        #attention (concat)
        with tf.variable_scope("lm_attention"):
            attend_w = tf.get_variable("attend_w", [cf.lm_enc_dim*2+cf.lm_dec_dim, cf.lm_attend_dim])
            attend_b = tf.get_variable("attend_b", [cf.lm_attend_dim], initializer=tf.constant_initializer())
            attend_v = tf.get_variable("attend_v", [cf.lm_attend_dim, 1])

        enc_steps = tf.shape(hist_outputs)[1]
        dec_steps = tf.shape(dec_outputs)[1]

        #prepare encoder and decoder
        hist_outputs_t = tf.tile(hist_outputs, [1, dec_steps, 1])
        dec_outputs_t  = tf.reshape(tf.tile(dec_outputs, [1, 1, enc_steps]),
            [batch_size, -1, cf.lm_dec_dim])

        #compute e
        hist_dec_concat = tf.concat(axis=1, values=[tf.reshape(hist_outputs_t, [-1, cf.lm_enc_dim*2]),
            tf.reshape(dec_outputs_t, [-1, cf.lm_dec_dim])])
        e = tf.matmul(tf.tanh(tf.matmul(hist_dec_concat, attend_w) + attend_b), attend_v)
        e = tf.reshape(e, [-1, enc_steps])

        #mask out pad symbols to compute alpha and weighted sum of history words
        alpha    = tf.reshape(tf.nn.softmax(e), [batch_size, -1, 1])
        context  = tf.reduce_sum(tf.reshape(alpha * hist_outputs_t,
            [batch_size,dec_steps,enc_steps,-1]), 2)

        #save attention weights
        self.lm_attentions = tf.reshape(alpha, [batch_size, dec_steps, enc_steps])

        ##############
        #output layer#
        ##############

        #reshape both into [batch_size*len, hidden_dim]
        dec_outputs = tf.reshape(dec_outputs, [-1, cf.lm_dec_dim])
        context     = tf.reshape(context, [-1, cf.lm_enc_dim*2])
        
        #combine context and decoder hidden state with a gated unit
        with tf.variable_scope("gated_unit"):
            hidden = self.gated_layer(context, dec_outputs, cf.lm_enc_dim*2, cf.lm_dec_dim)
            #hidden = dec_outputs

        #output embeddings
        lm_output_proj = tf.get_variable("lm_output_proj", [cf.word_embedding_dim, cf.lm_dec_dim])
        lm_softmax_b   = tf.get_variable("lm_softmax_b", [word_type_size], initializer=tf.constant_initializer())
        lm_softmax_w   = tf.transpose(tf.tanh(tf.matmul(self.word_embedding, lm_output_proj)))

        #compute logits and cost
        lm_logits     = tf.matmul(hidden, lm_softmax_w) + lm_softmax_b
        lm_crossent   = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=lm_logits, labels=tf.reshape(self.lm_y, [-1]))
        lm_crossent_m = lm_crossent * tf.reshape(lm_mask, [-1])
        self.lm_cost  = tf.reduce_sum(lm_crossent_m) / batch_size

        if not is_training:
            self.lm_probs = tf.nn.softmax(lm_logits)
            return

        #run optimiser and backpropagate (clipped) gradients for lm loss
        lm_tvars = tf.trainable_variables()
        lm_grads, _ = tf.clip_by_global_norm(tf.gradients(self.lm_cost, lm_tvars), cf.max_grad_norm)
        self.lm_train_op = tf.train.AdagradOptimizer(cf.lm_learning_rate).apply_gradients(zip(lm_grads, lm_tvars))


#     # -- pentameter model network
#     def init_pm(self, is_training, batch_size, char_type_size, space_id, pad_id):

#         cf = self.config

#         #character embeddings
#         self.char_embedding = tf.get_variable("char_embedding", [char_type_size, cf.char_embedding_dim],
#             initializer=tf.random_uniform_initializer(-0.05/cf.char_embedding_dim, 0.05/cf.char_embedding_dim))

#         #get bidirectional rnn states of the encoder
#         enc_cell = tf.nn.rnn_cell.LSTMCell(cf.pm_enc_dim)#, forget_bias=1.0)
#         if is_training and cf.keep_prob < 1.0:
#             enc_cell = tf.nn.rnn_cell.DropoutWrapper(enc_cell, output_keep_prob=cf.keep_prob)

#         char_inputs    = tf.nn.embedding_lookup(self.char_embedding, self.pm_enc_x)
#         enc_hiddens, _ = tf.nn.bidirectional_dynamic_rnn(cell_fw=enc_cell, cell_bw=enc_cell, inputs=char_inputs,
#             sequence_length=self.pm_enc_xlen, dtype=tf.float32)

#         #save enc_hiddens
#         self.char_encodings = enc_hiddens

#         #reshape enc_hiddens
#         enc_hiddens  = tf.reshape(tf.concat(axis=2, values=enc_hiddens), [-1, cf.pm_enc_dim*2]) #[batch_size*num_steps, hidden]

#         #get decoder hidden states
#         dec_cell = tf.nn.rnn_cell.LSTMCell(cf.pm_dec_dim)#, forget_bias=1.0)

#         if is_training and cf.keep_prob < 1.0:
#             dec_cell = tf.nn.rnn_cell.DropoutWrapper(dec_cell, output_keep_prob=cf.keep_prob)

#         #compute loss for pentameter
#         self.pm_costs     = self.compute_pm_loss(is_training, batch_size, enc_hiddens, dec_cell, space_id, pad_id)
#         self.pm_mean_cost = tf.reduce_sum(self.pm_costs) / batch_size

#         if not is_training:
#             return

#         #run optimiser and backpropagate (clipped) gradients for pm loss
#         pm_tvars         = tf.trainable_variables()
#         pm_grads, _      = tf.clip_by_global_norm(tf.gradients(self.pm_mean_cost, pm_tvars), cf.max_grad_norm)
#         self.pm_train_op = tf.train.AdamOptimizer(cf.pm_learning_rate).apply_gradients(zip(pm_grads, pm_tvars))


#     # -- rhyme model network
#     def init_rm(self, is_training, batch_size, char_type_size):

#         cf = self.config

#         #get char encodings
#         rnn_cell = tf.nn.rnn_cell.LSTMCell(cf.rm_dim)#, forget_bias=1.0)
#         if is_training and cf.keep_prob < 1.0:
#             rnn_cell = tf.nn.rnn_cell.DropoutWrapper(rnn_cell, output_keep_prob=cf.keep_prob)

#         initial_state = rnn_cell.zero_state(tf.shape(self.pm_enc_x)[0], tf.float32)
#         char_enc, _   = tf.nn.dynamic_rnn(rnn_cell, tf.nn.embedding_lookup(self.char_embedding, self.pm_enc_x),
#             sequence_length=self.pm_enc_xlen, dtype=tf.float32, initial_state=initial_state)

#         #get last hidden states
#         char_enc = self.get_last_hidden(char_enc, self.pm_enc_xlen)
        
#         #slice it into target_words and context words
#         target  = char_enc[:batch_size,:]
#         context = char_enc[batch_size:,:]

#         target_tiled   = tf.reshape(tf.tile(target, [1,self.rm_num_context]), [-1, cf.rm_dim])
#         target_context = tf.concat(axis=1, values=[target_tiled, context])

#         #cosine similarity
#         e = tf.reduce_sum(tf.nn.l2_normalize(target_tiled, 1) * tf.nn.l2_normalize(context, 1), 1)
#         e = tf.reshape(e, [batch_size, -1])

#         #save the attentions
#         self.rm_attentions = e

#         #max margin loss
#         min_cos = tf.nn.top_k(e, 2)[0][:, -1] #second highest cos similarity
#         max_cos = tf.reduce_max(e, 1)
#         self.rm_cost = tf.reduce_mean(tf.maximum(0.0, cf.rm_delta - max_cos + min_cos))

#         if not is_training:
#             return
        
#         self.rm_train_op = tf.train.AdamOptimizer(cf.rm_learning_rate).minimize(self.rm_cost)


#     # -- compute pentameter model loss, given a pentameter input
#     def compute_pm_loss(self, is_training, batch_size, enc_hiddens, dec_cell, space_id, pad_id):

#         cf             = self.config
#         xlen_max       = tf.reduce_max(self.pm_enc_xlen)

#         #use decoder hidden states to select encoder hidden states to predict stress for next time step
#         repeat_loss    = tf.zeros([batch_size])
#         attentions     = tf.zeros([batch_size, xlen_max]) #historical attention weights
#         prev_miu       = tf.zeros([batch_size,1])
#         outputs        = []
#         attention_list = []
#         miu_list       = []

#         #initial inputs (learnable) and state
#         initial_inputs = tf.get_variable("dec_init_input", [cf.pm_enc_dim*2])
#         inputs         = tf.reshape(tf.tile(initial_inputs, [batch_size]), [batch_size, -1])
#         state          = dec_cell.zero_state(batch_size, tf.float32)

#         #manual unroll of time steps because attention depends on previous attention weights
#         with tf.variable_scope("RNN"):
#             for time_step in range(self.pentameter_len):

#                 if time_step > 0:
#                     tf.get_variable_scope().reuse_variables()

#                 def attend(enc_hiddens, dec_hidden, attn_hist, prev_miu):
#                     with tf.variable_scope("pm_attention"):
#                         attend_w = tf.get_variable("attend_w", [cf.pm_enc_dim*2+cf.pm_dec_dim, cf.pm_attend_dim])
#                         attend_b = tf.get_variable("attend_b", [cf.pm_attend_dim], initializer=tf.constant_initializer())
#                         attend_v = tf.get_variable("attend_v", [cf.pm_attend_dim, 1])
#                         miu_w    = tf.get_variable("miu_w", [cf.pm_dec_dim+1, cf.pm_attend_dim])
#                         miu_b    = tf.get_variable("miu_b", [cf.pm_attend_dim], initializer=tf.constant_initializer())
#                         miu_v    = tf.get_variable("miu_v", [cf.pm_attend_dim, 1])
                    
#                     #position attention
#                     miu     = tf.minimum(tf.sigmoid(tf.matmul(tf.tanh(tf.matmul(tf.concat(axis=1,
#                         values=[dec_hidden, prev_miu]), miu_w) + miu_b), miu_v)) + prev_miu, tf.ones([batch_size, 1]))
#                     miu_p   = miu * tf.reshape(tf.cast(self.pm_enc_xlen-1, tf.float32), [-1, 1])
#                     pos     = tf.cast(tf.reshape(tf.tile(tf.range(xlen_max), [batch_size]), [batch_size, -1]),
#                         dtype=tf.float32)
#                     pos_lp  = -(pos - miu_p)**2 / (2 * tf.reshape(tf.tile([tf.square(cf.sigma)], [batch_size]),
#                         [batch_size,-1]))
            
#                     #char encoding attention
#                     pos_weight = tf.reshape(tf.exp(pos_lp), [-1, 1])
#                     inp_concat = tf.concat(axis=1, values=[enc_hiddens * pos_weight,
#                         tf.reshape(tf.tile(dec_hidden, [1,xlen_max]), [-1,cf.pm_dec_dim])])
#                     x       = self.pm_enc_x
#                     e       = tf.matmul(tf.tanh(tf.matmul(inp_concat, attend_w) + attend_b), attend_v)
#                     e       = tf.reshape(e, [batch_size, xlen_max])
#                     mask1   = tf.cast(tf.equal(x, tf.fill(tf.shape(x), pad_id)), dtype=tf.float32)
#                     mask2   = tf.cast(tf.equal(x, tf.fill(tf.shape(x), space_id)), dtype=tf.float32)
#                     e_mask  = tf.maximum(mask1, mask2)
#                     e_mask *= tf.constant(-1e20)

#                     #combine alpha with position probability
#                     alpha   = tf.nn.softmax(e + e_mask + pos_lp)
#                     #alpha   = tf.nn.softmax(e + e_mask)

#                     #weighted sum
#                     c       = tf.reduce_sum(tf.expand_dims(alpha, 2)*tf.reshape(enc_hiddens,
#                         [batch_size, xlen_max, cf.pm_enc_dim*2]), 1)

#                     return c, alpha, miu

#                 dec_hidden, state               = dec_cell(inputs, state)
#                 enc_hiddens_sum, attn, prev_miu = attend(enc_hiddens, dec_hidden, attentions, prev_miu)

#                 repeat_loss += tf.reduce_sum(tf.minimum(attentions, attn), 1)
#                 attentions  += attn
#                 inputs       = enc_hiddens_sum

#                 attention_list.append(attn)
#                 miu_list.append(prev_miu)
#                 outputs.append(enc_hiddens_sum)

#         #reshape output into [batch_size*num_steps,hidden_size]
#         outputs = tf.reshape(tf.concat(axis=1, values=outputs), [-1, cf.pm_enc_dim*2])

#         #compute loss
#         pm_softmax_w = tf.get_variable("pm_softmax_w", [cf.pm_enc_dim*2, 1])
#         pm_softmax_b = tf.get_variable("pm_softmax_b", [1], initializer=tf.constant_initializer())
#         pm_logit     = tf.squeeze(tf.matmul(outputs, pm_softmax_w) + pm_softmax_b)
#         pm_crossent  = tf.nn.sigmoid_cross_entropy_with_logits(logits=pm_logit,
#             labels=tf.tile(tf.cast(self.pentameter, tf.float32), [batch_size]))
#         cov_loss     = tf.reduce_sum(tf.nn.relu(self.pm_cov_mask*cf.cov_loss_threshold - attentions), 1)
#         pm_cost      = tf.reduce_sum(tf.reshape(pm_crossent, [batch_size, -1]), 1) + \
#             cf.repeat_loss_scale*repeat_loss + cf.cov_loss_scale*cov_loss

#         #save some variables
#         self.pm_logits     = tf.sigmoid(tf.reshape(pm_logit, [batch_size, -1]))
#         self.pm_attentions = attention_list
#         self.mius          = miu_list

#         return pm_cost


    # -- sample a word given probability distribution (with option to normalise the distribution with temperature)
    # -- temperature = 0 means argmax
    def sample_word(self, sess, probs, temperature, unk_symbol_id, pad_symbol_id, wordxchar, idxword,
        rm_target_pos, rm_target_neg, rm_threshold_pos, avoid_words):

        def rhyme_pair_to_char(rhyme_pair):
            char_ids, char_id_len = [], []

            for word in rhyme_pair:
                char_ids.append(wordxchar[word])
                char_id_len.append(len(char_ids[-1]))

            #pad char_ids
            for ci, c in enumerate(char_ids):
                char_ids[ci] = pad(c, max(char_id_len), pad_symbol_id)

            return char_ids, char_id_len

        def rhyme_cos(x, y):
            rm_char, rm_char_len = rhyme_pair_to_char([x, y])

            feed_dict = {self.pm_enc_x: rm_char, self.pm_enc_xlen: rm_char_len, self.rm_num_context: 1}
            rm_attns  = sess.run(self.rm_attentions, feed_dict)

            return rm_attns[0][0]

        rm_threshold_neg = 0.7 #non-rhyming words A and B shouldn't have similarity larger than this threshold

        if temperature == 0:
            return np.argmax(probs)

        probs = probs.astype(np.float64) #convert to float64 for higher precision
        probs = np.log(probs) / temperature
        probs = np.exp(probs) / math.fsum(np.exp(probs))

        #avoid unk_symbol_id if possible
        sampled = None
        pw      = idxword[rm_target_pos] if rm_target_pos != None else "None"
        nw      = idxword[rm_target_neg] if rm_target_neg != None else "None"
        for i in range(1000):
            sampled = np.argmax(np.random.multinomial(1, probs, 1))

            #resample if it's a word to avoid
            if sampled in avoid_words:
                continue

            #if it needs to rhyme, resample until we find a rhyming word
            if rm_threshold_pos != 0.0 and rm_target_pos != None and rhyme_cos(sampled, rm_target_pos) < rm_threshold_pos:
                continue

            if rm_target_neg != None and rhyme_cos(sampled, rm_target_neg) > rm_threshold_neg:
                continue

            return sampled

        return None


    # -- generate a sentence by sampling one word at a time
    def sample_sent(self, sess, state, x, hist, hlen, xchar, xchar_len, avoid_symbols, stopwords,temp_min, temp_max,
        unk_symbol_id, pad_symbol_id, end_symbol_id, space_id, idxchar, charxid, idxword, wordxchar,
        rm_target_pos, rm_target_neg, rm_threshold, last_words, max_words):

        def filter_stop_symbol(word_ids):
            cleaned = set([])
            for w in word_ids:
                if w not in (stopwords | set([pad_symbol_id, end_symbol_id])) and not only_symbol(idxword[w]):
                    cleaned.add(w)
            return cleaned

        def get_freq_words(word_ids, freq_threshold):
            words     = []
            word_freq = Counter(word_ids)
            for k, v in word_freq.items():
                #if v >= freq_threshold and not only_symbol(idxword[k]) and k != end_symbol_id:
                if v >= freq_threshold and k != end_symbol_id:
                    words.append(k)
            return set(words)

        sent   = []

        while True:
            probs, state = sess.run([self.lm_probs, self.lm_final_state],
                {self.lm_x: x, self.lm_initial_state: state, self.lm_xlen: [1],
                self.lm_hist: hist, self.lm_hlen: hlen,
                self.pm_enc_x: xchar, self.pm_enc_xlen: xchar_len})

            #avoid words previously generated            
            avoid_words = filter_stop_symbol(sent + hist[0])
            freq_words  = get_freq_words(sent + hist[0], 2) #avoid any words that occur >= N times
            avoid_words = avoid_words | freq_words | set(sent[-3:] + last_words + avoid_symbols + [unk_symbol_id])
            #avoid_words = set(sent[-3:] + last_words + avoid_symbols + [unk_symbol_id])

            word = self.sample_word(sess, probs[0], np.random.uniform(temp_min, temp_max), unk_symbol_id,
                pad_symbol_id, wordxchar, idxword, rm_target_pos, rm_target_neg, rm_threshold, avoid_words)

            if word != None:
                sent.append(word)
                x             = [[ sent[-1] ]]
                xchar         = [wordxchar[sent[-1]]]
                xchar_len     = [len(xchar[0])]
                rm_target_pos = None
                rm_target_neg = None
            else:
                return None, None, None

            if sent[-1] == end_symbol_id or len(sent) >= max_words:

                if len(sent) > 1:
                    pm_loss  = self.eval_pm_loss(sess, sent, end_symbol_id, space_id, idxchar, charxid, idxword, wordxchar)
                    return sent, state, pm_loss
                else:
                    return None, None, None
    

    # -- compute pentameter loss
    def eval_pm_loss(self, sess, sent, end_symbol, space_id, idxchar, charxid, idxword, wordxchar):

        def sent_to_char(words):
            char_ids = []

            for word in reversed(words):
                if word != end_symbol:
                    char_ids.extend(wordxchar[word])
                    char_ids.append(space_id)

            #remove punctuation
            chars = "".join([ idxchar[item] for item in char_ids])
            char_ids = [charxid[item] for item in remove_punct(chars)]

            return char_ids

        #pentameter check
        pm_char  = sent_to_char(sent)
        cov_mask = coverage_mask(pm_char, idxchar)

        #create pseudo batch
        b = ([pm_char], [len(pm_char)], [cov_mask])

        feed_dict = {self.pm_enc_x: b[0], self.pm_enc_xlen: b[1], self.pm_cov_mask: b[2]}
        pm_attns, pm_costs, logits, mius = sess.run([self.pm_attentions, self.pm_costs, self.pm_logits,
            self.mius], feed_dict)

        return pm_costs[0]


    # -- quatrain generation
    def generate(self, sess, idxword, idxchar, charxid, wordxchar, pad_symbol_id, end_symbol_id, unk_symbol_id, space_id,
        avoid_symbols, stopwords, temp_min, temp_max, max_lines, max_words, sent_sample, rm_threshold, verbose=False):

        def reset():
            rhyme_aabb  = ([None, 0, None, 2], [None, None, 1, None])
            rhyme_abab  = ([None, None, 0, 1], [None, 0, None, None])
            rhyme_abba  = ([None, None, 1, 0], [None, 0, None, None])
            rhyme_pttn  = random.choice([rhyme_aabb, rhyme_abab, rhyme_abba])
            state       = sess.run(self.lm_dec_cell.zero_state(1, tf.float32))
            prev_state  = state
            x           = [[end_symbol_id]]
            xchar       = [wordxchar[end_symbol_id]]
            xchar_len   = [1]
            sonnet      = []
            sent_probs  = []
            last_words  = []
            total_words = 0
            total_lines = 0
            
            return state, prev_state, x, xchar, xchar_len, sonnet, sent_probs, last_words, total_words, total_lines, \
                rhyme_pttn[0], rhyme_pttn[1]

        end_symbol = idxword[end_symbol_id]
        sent_temp  = 0.1 #sentence sampling temperature

        state, prev_state, x, xchar, xchar_len, sonnet, sent_probs, last_words, total_words, total_lines, \
            rhyme_pttn_pos, rhyme_pttn_neg = reset()

        #verbose prints during generation
        if verbose:
            sys.stdout.write("  Number of generated lines = 0/4\r")
            sys.stdout.flush()

        while total_words < max_words and total_lines < max_lines:

            #add history context
            if len(sonnet) == 0 or sonnet.count(end_symbol_id) < 1:
                hist = [[unk_symbol_id] + [pad_symbol_id]*5]
            else:
                hist = [sonnet + [pad_symbol_id]*5]
            hlen = [len(hist[0])]

            """
            print "\n\n", "="*80
            print "state =", state[0][1][0][:10]
            print "prev state =", prev_state[0][1][0][:10]
            print "x =", x, idxword[x[0][0]]
            print "hist =", hist, " ".join(idxword[item] for item in hist[0])
            print "sonnet =", sonnet, " ".join(idxword[item] for item in sonnet)
            """

            #get rhyme targets for the 'first' word
            rm_target_pos, rm_target_neg = None, None
            if rhyme_pttn_pos[total_lines] != None:
                rm_target_pos = last_words[rhyme_pttn_pos[total_lines]]
            if rhyme_pttn_neg[total_lines] != None:
                rm_target_neg = last_words[rhyme_pttn_neg[total_lines]]

            #genereate N sentences and sample from them (using softmax(-one_pl) as probability)
            all_sent, all_state, all_pl = [], [], []
            for _ in range(sent_sample):
                one_sent, one_state, one_pl = self.sample_sent(sess, state, x, hist, hlen, xchar, xchar_len,
                    avoid_symbols, stopwords, temp_min, temp_max, unk_symbol_id, pad_symbol_id, end_symbol_id, space_id,
                    idxchar, charxid, idxword, wordxchar, rm_target_pos, rm_target_neg, rm_threshold, last_words, max_words)
                if one_sent != None:
                    all_sent.append(one_sent)
                    all_state.append(one_state)
                    all_pl.append(-one_pl)
                else:
                    all_sent = []
                    break

            #unable to generate sentences; reset whole quatrain
            if len(all_sent) == 0:

                state, prev_state, x, xchar, xchar_len, sonnet, sent_probs, last_words, total_words, total_lines, \
                    rhyme_pttn_pos, rhyme_pttn_neg = reset()

            else:

                #convert pm_loss to probability using softmax
                probs = np.exp(np.array(all_pl)/sent_temp)
                probs = probs.astype(np.float64) #convert to float64 for higher precision
                probs = probs / math.fsum(probs)

                #sample a sentence
                sent_id = np.argmax(np.random.multinomial(1, probs, 1))
                sent    = all_sent[sent_id]
                state   = all_state[sent_id]
                pl      = all_pl[sent_id]

                total_words += len(sent)
                total_lines += 1
                prev_state   = state

                sonnet.extend(sent)
                sent_probs.append(-pl)
                last_words.append(sent[0])

            if verbose:
                sys.stdout.write("  Number of generated lines = %d/4\r" % (total_lines))
                sys.stdout.flush()

        #postprocessing
        sonnet = sonnet[:-1] if sonnet[-1] == end_symbol_id else sonnet
        sonnet = [ postprocess_sentence(item) for item in \
            " ".join(list(reversed([ idxword[item] for item in sonnet ]))).strip().split(end_symbol) ]

        return sonnet, list(reversed(sent_probs))


In [6]:
def remove_punct(string):
    return " ".join("".join([ item for item in string if (item.isalpha() or item == " ") ]).split())


In [7]:
import codecs
import operator
import numpy as np
import random
import math
import codecs
import sys
from collections import defaultdict


def load_vocab(corpus, word_minfreq, dummy_symbols):
    idxword, idxchar = [], []
    wordxid, charxid = defaultdict(int), defaultdict(int)
    word_freq, char_freq = defaultdict(int), defaultdict(int)
    wordxchar = defaultdict(list)

    def update_dic(symbol, idxvocab, vocabxid):
        if symbol not in vocabxid:
            idxvocab.append(symbol)
            vocabxid[symbol] = len(idxvocab) - 1 

    for line_id, line in enumerate(codecs.open(corpus, "r", "utf-8")):
        for word in line.strip().split():
            word_freq[word] += 1
        for char in line.strip():
            char_freq[char] += 1

    #add in dummy symbols into dictionaries
    for s in dummy_symbols:
        update_dic(s, idxword, wordxid)
        update_dic(s, idxchar, charxid)

    #remove low fequency words/chars
    def collect_vocab(vocab_freq, idxvocab, vocabxid):
        for w, f in sorted(list(vocab_freq.items()), key=operator.itemgetter(1), reverse=True):
            if f < word_minfreq:
                break
            else:
                update_dic(w, idxvocab, vocabxid)

    collect_vocab(word_freq, idxword, wordxid)
    collect_vocab(char_freq, idxchar, charxid)

    #word id to [char ids]
    dummy_symbols_set = set(dummy_symbols)
    for wi, w in enumerate(idxword):
        if w in dummy_symbols:
            wordxchar[wi] = [wi]
        else:
            for c in w:
                wordxchar[wi].append(charxid[c] if c in charxid else charxid[dummy_symbols[2]])

    return idxword, wordxid, idxchar, charxid, wordxchar



In [8]:
# load vocab
print("\nFirst pass to collect word and character vocabulary...")
idxword, wordxid, idxchar, charxid, wordxchar = load_vocab(cf.train_data, cf.word_minfreq, dummy_symbols)
print("\nWord type size =", len(idxword))
print("\nChar type size =", len(idxchar))



First pass to collect word and character vocabulary...

Word type size = 8398

Char type size = 77


In [9]:
# del load_data
def load_data(corpus, wordxid, idxword, charxid, idxchar, pad_symbol, end_symbol, unk_symbol):
    nwords     = [] #number of words for each line
    nchars     = [] #number of chars for each line
    word_data  = [] #data[doc_id][0][line_id] = list of word ids; data[doc_id][1][line_id] = list of [char_ids]
    char_data  = [] #data[line_id] = list of char ids
    rhyme_data = [] #list of ( target_word, [candidate_words], target_word_line_id ); word is a list of characters

    def word_to_char(word):
        if word in set([pad_symbol, end_symbol, unk_symbol]):
            return [ wordxid[word] ]
        else:
            return [ charxid[item] if item in charxid else charxid[unk_symbol] for item in word ]


    for doc in codecs.open(corpus, "r", "utf-8"):

        word_lines, char_lines = [[], []], []
        last_words = []

         #reverse the order of lines and words as we are generating from end to start
        for line in reversed(doc.strip().split(end_symbol)):
            print(line)

            if len(line.strip()) > 0:

                word_seq = [ wordxid[item] if item in wordxid else wordxid[unk_symbol] \
                    for item in reversed(line.strip().split()) ] + [wordxid[end_symbol]]

                char_seq = [ word_to_char(item) for item in reversed(line.strip().split()) ] + [word_to_char(end_symbol)]

                word_lines[0].append(word_seq)
                word_lines[1].append(char_seq)
                char_lines.append([ charxid[item] if item in charxid else charxid[unk_symbol] \
                    for item in remove_punct(line.strip())])
                nwords.append(len(word_lines[0][-1]))
                nchars.append(len(char_lines[-1]))

                last_words.append(line.strip().split()[-1])
                
                print('1', last_words)

        if len(word_lines[0]) == 14: #14 lines for sonnets

            word_data.append(word_lines)
            char_data.extend(char_lines)

            last_words = last_words[2:] #remove couplets (since they don't always rhyme)

            for wi, w in enumerate(last_words):
                print('+++', wi, w)
                rhyme_data.append((word_to_char(w),
                                   [word_to_char(item)
                                    for item_id, item in enumerate(last_words[(wi // 4) * 4:(wi // 4 + 1) * 4])
                                    if item_id != (wi % 4)], (11 - wi)))

    return word_data, char_data, rhyme_data, nwords, nchars


In [10]:
def print_stats(partition, word_data, rhyme_data, nwords, nchars):
    print(partition, "statistics:")
    print("  Number of documents         =", len(word_data))
    print("  Number of rhyme examples    =", len(rhyme_data))
    print("  Total number of word tokens =", sum(nwords))
    print("  Mean/min/max words per line = %.2f/%d/%d" % (np.mean(nwords), min(nwords), max(nwords)))
    print("  Total number of char tokens =", sum(nchars))
    print("  Mean/min/max chars per line = %.2f/%d/%d" % (np.mean(nchars), min(nchars), max(nchars)))



In [11]:
global wordxid, idxword, charxid, idxchar, \
wordxchar, train_lm, train_pm, train_rm

random.seed(cf.seed)
np.random.seed(cf.seed)

if cf.word_embedding_model:
    print("\nLoading word embedding model...")
    mword = g.Word2Vec.load(cf.word_embedding_model)
    cf.word_embedding_dim = mword.vector_size

# load vocab
print("\nFirst pass to collect word and character vocabulary...")
idxword, wordxid, idxchar, charxid, wordxchar = load_vocab(cf.train_data, cf.word_minfreq, dummy_symbols)
print("\nWord type size =", len(idxword))
print("\nChar type size =", len(idxchar))


# load train and valid data
print("\nLoading train and valid data...")

train_word_data, train_char_data, train_rhyme_data, train_nwords, train_nchars = \
    load_data(cf.train_data, wordxid, idxword, charxid, idxchar, pad_symbol, end_symbol, unk_symbol)
valid_word_data, valid_char_data, valid_rhyme_data, valid_nwords, valid_nchars = \
    load_data(cf.valid_data, wordxid, idxword, charxid, idxchar, pad_symbol, end_symbol, unk_symbol)
print_stats("\nTrain", train_word_data, train_rhyme_data, train_nwords, train_nchars)
print_stats("\nValid", valid_word_data, valid_rhyme_data, valid_nwords, valid_nchars)

# load test data if it's given
if cf.test_data:
    test_word_data, test_char_data, test_rhyme_data, test_nwords, test_nchars = \
        load_data(cf.test_data, wordxid, idxword, charxid, idxchar, pad_symbol, end_symbol, unk_symbol)
    print_stats("\nTest", test_word_data, test_rhyme_data, test_nwords, test_nchars)    
    


Loading word embedding model...

First pass to collect word and character vocabulary...

Word type size = 8398

Char type size = 77

Loading train and valid data...

 a curse now dumb upon the lips of night 
1 ['night']
 uttered by death in hate of heaven and light 
1 ['night', 'light']
 hell yawns on him whose life was as a word 
1 ['night', 'light', 'word']
 take him , for he is yours , o night and death 
1 ['night', 'light', 'word', 'death']
 that seemed so long but weak and wasted breath 
1 ['night', 'light', 'word', 'death', 'breath']
 the bond is cancelled , and the prayer is heard 
1 ['night', 'light', 'word', 'death', 'breath', 'heard']
 in all the days whereof his eye takes ken 
1 ['night', 'light', 'word', 'death', 'breath', 'heard', 'ken']
 that shall behold not so abhorred an one 
1 ['night', 'light', 'word', 'death', 'breath', 'heard', 'ken', 'one']
 nor this face look upon the living sun 
1 ['night', 'light', 'word', 'death', 'breath', 'heard', 'ken', 'one', 'sun']
 shal

In [12]:
def run_epoch(sess, word_batches, char_batches, rhyme_batches, model, pname, is_training):
    start_time = time.time()

    # lm variables
    lm_costs = 0.0
    total_words = 0
    zero_state = sess.run(model.lm_initial_state)
    model_state = None
    prev_doc = -1
    lm_train_op = model.lm_train_op if is_training else tf.no_op()

    # pm variables
    pm_costs = 0.0
    pm_train_op = model.pm_train_op if is_training else tf.no_op()

    # rm variables
    rm_costs = 0.0
    rm_train_op = model.rm_train_op if is_training else tf.no_op()

    # mix lm and pm batches
    mixed_batch_types = [0] * len(word_batches) + [1] * len(char_batches) + [2] * len(rhyme_batches)
    random.shuffle(mixed_batch_types)
    mixed_batches = [word_batches, char_batches, rhyme_batches]

    word_batch_id = 0
    char_batch_id = 0
    rhyme_batch_id = 0

    # cmu pronounciation dictionary for stress and rhyme evaluation
    cmu = cmudict.dict()

    # stress prediction
    stress_acc = [[], [], []]  # buckets for char length: [1-4], [5-8], [9-inf]

    # rhyme predition
    rhyme_pr = {}  # precision/recall for each rhyme threshold
    for rt in rhyme_thresholds:
        rhyme_pr[rt] = [[], []]

    # rhyme pattern
    rhyme_pattern = []  # collection of cosine similarities
    for i in range(12):
        rhyme_pattern.append([])
        for j in range(4):
            if i % 4 == j:
                rhyme_pattern[i].append([-2.0])
            else:
                rhyme_pattern[i].append([])
                
    print('stress_acc', stress_acc)
    print('rhyme_pr', rhyme_pr)
    print('rhyme_pattern', rhyme_pattern)
    
    print('mixed_batch_types', mixed_batch_types)
    print('len mixed_batch_types', len(mixed_batch_types))

    for bi, batch_type in tqdm(enumerate(mixed_batch_types)):

        if batch_type == 0 and train_lm:

            b = mixed_batches[batch_type][word_batch_id]

            # reset model state if it's a different set of documents
            if prev_doc != b[2][0]:
                model_state = zero_state
                prev_doc = b[2][0]

            # preprocess character input to [batch_size*doc_len, char_len]
            pm_enc_x = np.array(b[5]).reshape((cf.batch_size * max(b[3]), -1))

            feed_dict = {model.lm_x: b[0], 
                         model.lm_y: b[1], 
                         model.lm_xlen: b[3], 
                         model.pm_enc_x: pm_enc_x,
                         model.pm_enc_xlen: np.array(b[6]).reshape((-1)), 
                         model.lm_initial_state: model_state,
                         model.lm_hist: b[7], 
                         model.lm_hlen: b[8]}

            cost, model_state, attns, _ = sess.run([model.lm_cost, 
                                                    model.lm_final_state, 
                                                    model.lm_attentions, 
                                                    lm_train_op],
                                                   feed_dict)

            lm_costs += cost * cf.batch_size  # keep track of full cost
            total_words += sum(b[3])

            word_batch_id += 1

#         elif batch_type == 1 and train_pm:

#             b = mixed_batches[batch_type][char_batch_id]

#             feed_dict = {model.pm_enc_x: b[0], model.pm_enc_xlen: b[1], model.pm_cov_mask: b[2]}
#             cost, attns, _, = sess.run([model.pm_mean_cost, model.pm_attentions, pm_train_op], feed_dict)
#             pm_costs += cost

#             char_batch_id += 1

#             if not is_training:
#                 eval_stress(stress_acc, cmu, attns, model.pentameter, b[0], idxchar, charxid, pad_symbol, cf)

#         elif batch_type == 2 and train_rm:

#             b = mixed_batches[batch_type][rhyme_batch_id]
#             num_c = 3 + cf.rm_neg

#             feed_dict = {model.pm_enc_x: b[0], model.pm_enc_xlen: b[1], model.rm_num_context: num_c}
#             cost, attns, _ = sess.run([model.rm_cost, model.rm_attentions, rm_train_op], feed_dict)
#             rm_costs += cost

#             rhyme_batch_id += 1

#             # if rhyme_batch_id < 10 and not is_training:
#             #    print_rm_attention(b, cf.batch_size, num_c, attns, charxid[pad_symbol], idxchar)

#             if not is_training:
#                 eval_rhyme(rhyme_pr, rhyme_thresholds, cmu, attns, b, idxchar, charxid, pad_symbol, cf)
#             else:
#                 collect_rhyme_pattern(rhyme_pattern, attns, b, cf.batch_size, num_c, idxchar, charxid[pad_symbol])

        if (((bi % 10) == 0) and cf.verbose) or (bi == len(mixed_batch_types) - 1):

            partition = "  " + pname
            sent_end = "\n" if bi == (len(mixed_batch_types) - 1) else "\r"
            speed = (bi + 1) / (time.time() - start_time)

            sys.stdout.write("%s %d/%d: lm ppl = %.1f; pm loss = %.2f; rm loss = %.2f; batch/sec = %.1f%s" % \
                             (partition, bi + 1, len(mixed_batch_types), np.exp(lm_costs / max(total_words, 1)),
                              pm_costs / max(char_batch_id, 1), rm_costs / max(rhyme_batch_id, 1), speed, sent_end))
            sys.stdout.flush()

            if not is_training and (bi == len(mixed_batch_types) - 1):

                if train_pm:
                    all_acc = [item for sublist in stress_acc for item in sublist]
                    stress_acc.append(all_acc)
                    for acci, acc in enumerate(stress_acc):
                        sys.stdout.write("    Stress acc [%d]   = %.3f (%d)\n" % (acci, np.mean(acc), len(acc)))

                if train_rm:
                    for t in rhyme_thresholds:
                        p = np.mean(rhyme_pr[t][0])
                        r = np.mean(rhyme_pr[t][1])
                        f = 2 * p * r / (p + r) if (p != 0.0 and r != 0.0) else 0.0
                        sys.stdout.write("    Rhyme P/R/F@%.1f  = %.3f / %.3f / %.3f\n" % (t, p, r, f))

                sys.stdout.flush()

    # return avg batch loss for lm, pm and rm
    return lm_costs / max(word_batch_id, 1), pm_costs / max(char_batch_id, 1), rm_costs / max(rhyme_batch_id, 1), \
           rhyme_pattern, (np.mean(stress_acc[-1]) if not is_training else 0.0)


In [13]:






        
    

def flatten_list(l):
    return [item for sublist in l for item in sublist]

def pad(lst, max_len, pad_symbol):
    if len(lst) > max_len:
        print("\nERROR: padding")
        print("length of list greater than maxlen; list =", lst, "; maxlen =", max_len)
        raise SystemExit
    return lst + [pad_symbol] * (max_len - len(lst))

def coverage_mask(char_ids, idxchar):
    vowels = get_vowels()
    return [ float(idxchar[c] in vowels) for c in char_ids ]

def get_vowels():
    return set(["a", "e", "i", "o", "u"])



def collect_rhyme_pattern(rhyme_pattern, attentions, b, batch_size, num_context, idxchar, pad_id):

    def get_context_line_id(target_line_id, context_id):
        p = target_line_id % 4
        q = 3-context_id
        if q <= p:
            q -= 1
        return q

    #print "\n", "="*100
    for exid in range(batch_size):
        target_line_id = b[2][exid]
        for ci, c in enumerate(b[0][(exid*num_context+batch_size):((exid+1)*num_context+batch_size)][:3]):
            context_line_id = get_context_line_id(target_line_id, ci)
            rhyme_pattern[target_line_id][context_line_id].append(attentions[exid][ci])




In [14]:
# save model

        #         if cf.save_model:
        #             if not os.path.exists(os.path.join(cf.output_dir, cf.output_prefix)):
        #                 os.makedirs(os.path.join(cf.output_dir, cf.output_prefix))
        #             # create saver object to save model
        #             saver = tf.train.Saver(max_to_keep=0)

        # train model
        
        
# training parameters
train_lm = True
train_pm = True
train_rm = True

with tf.Graph().as_default(), tf.Session() as sess:

    tf.set_random_seed(cf.seed)

    with tf.variable_scope("model", reuse=None):
        mtrain = SonnetModel(True, cf.batch_size, len(idxword), len(idxchar),
                             charxid[" "], charxid[pad_symbol], cf)
#     with tf.variable_scope("model", reuse=True):
#         mvalid = SonnetModel(False, cf.batch_size, len(idxword), len(idxchar),
#                              charxid[" "], charxid[pad_symbol], cf)
#     with tf.variable_scope("model", reuse=True):
#         mgen = SonnetModel(False, 1, len(idxword), len(
#             idxchar), charxid[" "], charxid[pad_symbol], cf)

    tf.global_variables_initializer().run()

    # initialise word embedding
    if cf.word_embedding_model:
        word_emb = init_embedding(mword, idxword)
        sess.run(mtrain.word_embedding.assign(word_emb))
    # train model
    prev_lm_loss, prev_pm_loss, prev_rm_loss, rhyme_pattern = None, None, None, None
    for i in range(1):

        print("\nEpoch =", i + 1)

        # create batches for language model
        train_word_batch = create_word_batch(train_word_data, cf.batch_size,
                                             cf.doc_lines, cf.bptt_truncate, wordxid[pad_symbol],
                                             wordxid[end_symbol], wordxid[unk_symbol], True)
#         valid_word_batch = create_word_batch(train_word_data, 1,
#                                              cf.doc_lines, cf.bptt_truncate, wordxid[pad_symbol],
#                                              wordxid[end_symbol], wordxid[unk_symbol], False)

#         # create batches for pentameter model
        train_char_batch = create_char_batch(train_char_data, cf.batch_size,
                                             charxid[pad_symbol], mtrain.pentameter, idxchar, True)
#         valid_char_batch = create_char_batch(valid_char_data, cf.batch_size,
#                                              charxid[pad_symbol], mtrain.pentameter, idxchar, False)

#         # create batches for rhyme model
        train_rhyme_batch = create_rhyme_batch(train_rhyme_data, cf.batch_size, charxid[pad_symbol], wordxchar,
                                               cf.rm_neg, True)
#         valid_rhyme_batch = create_rhyme_batch(valid_rhyme_data, cf.batch_size, charxid[pad_symbol], wordxchar,
#                                                cf.rm_neg, False)

        # train an epoch
        _, _, _, new_rhyme_pattern, _ = run_epoch(sess, train_word_batch, train_char_batch,
                                                  train_rhyme_batch, mtrain, "TRAIN", True)
#         lm_loss, pm_loss, rm_loss, _, sacc = run_epoch(sess, valid_word_batch, valid_char_batch,
#                                                        valid_rhyme_batch, mvalid, "VALID", False)

#         # create batch for test model and run an epoch if it's given
#         if cf.test_data:
#             test_word_batch = create_word_batch(test_word_data, cf.batch_size,
#                                                 cf.doc_lines, cf.bptt_truncate, wordxid[pad_symbol],
#                                                 wordxid[end_symbol], wordxid[unk_symbol], False)
#             test_char_batch = create_char_batch(test_char_data, cf.batch_size,
#                                                 charxid[pad_symbol], mtrain.pentameter, idxchar, False)
#             test_rhyme_batch = create_rhyme_batch(test_rhyme_data, cf.batch_size,
#                                                   charxid[pad_symbol], wordxchar, cf.rm_neg, False)
#             run_epoch(sess, test_word_batch, test_char_batch,
#                       test_rhyme_batch, mvalid, "TEST", False)

#         # if pm performance is really poor, re-initialize network weights
#         if train_pm and sacc < stress_acc_threshold and prev_pm_loss == None:
#             print(
#                 "\n  Valid stress accuracy performance is very poor; re-initializing network with random weights...")
#             tf.global_variables_initializer().run()
#             continue

#         # save model
#         if cf.save_model:
#             if prev_lm_loss == None or prev_pm_loss == None or prev_rm_loss == None or \
#                     ((lm_loss <= prev_lm_loss or not train_lm) and
#                      (pm_loss <= prev_pm_loss * reset_scale or not train_pm) and
#                      (rm_loss <= prev_rm_loss * reset_scale or not train_rm)):
#                 saver.save(sess, os.path.join(
#                     cf.output_dir, cf.output_prefix, "model.ckpt"))
#                 prev_lm_loss, prev_pm_loss, prev_rm_loss = lm_loss, pm_loss, rm_loss
#                 rhyme_pattern = new_rhyme_pattern
#             else:
#                 saver.restore(sess, os.path.join(
#                     cf.output_dir, cf.output_prefix, "model.ckpt"))
#                 print(
#                     "New valid performance is worse; restoring previous parameters...")
#                 print("  lm loss: %.5f --> %.5f" % (prev_lm_loss, lm_loss))
#                 print("  pm loss: %.5f --> %.5f" % (prev_pm_loss, pm_loss))
#                 print("  rm loss: %.5f --> %.5f" % (prev_rm_loss, rm_loss))
#                 sys.stdout.flush()
#         else:
#             rhyme_pattern = new_rhyme_pattern

#     # print global rhyme pattern
#     if train_rm:
#         print("\nAggregated Rhyme Pattern:")
#         for i in range(len(rhyme_pattern)):
#             if i % 4 == 0:
#                 print("\n  Quatrain", i / 4, ":")
#             print("    Line %02d =" % i),
#             for j in range(len(rhyme_pattern[i])):
#                 print("%.2f" % np.mean(rhyme_pattern[i][j])).rjust(7),
#         print

#     # save vocab information and config
#     if cf.save_model:
#         # vocab
#         cPickle.dump((idxword, idxchar, wordxchar),
#                      open(os.path.join(cf.output_dir, cf.output_prefix, "vocabs.pickle"), "w"))

#         # create a dictionary object for config
#         cf_dict = {}
#         for k, v in vars(cf).items():
#             if not k.startswith("__"):
#                 cf_dict[k] = v
#         cPickle.dump(cf_dict, open(os.path.join(
#             cf.output_dir, cf.output_prefix, "config.pickle"), "w"))

AttributeError: 'SonnetModel' object has no attribute 'init_pm'