Merge pull request #33 from eschnou/bugfix/issue-32
Fixes #32 - Embeddings were not disabled when restoring model
Conchylicultor committed Dec 21, 2016
2 parents 88fee0e + 68f1df6 commit ffb039b
Showing 1 changed file with 22 additions and 17 deletions.
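
Background on the fix: before this change, loadEmbedding() was only called on a fresh model (globStep == 0), so a restored model's embedding matrices were never removed from the trainable-variables collection and kept being trained. The diff below calls loadEmbedding() unconditionally and moves the globStep check inside it, after the freeze. A minimal sketch of that freeze-by-collection-removal pattern, assuming the TensorFlow 1.x API (the scope name, shapes, and values here are illustrative, not the project's):

import numpy as np
import tensorflow as tf

# Illustrative variable; in DeepQA these are the encoder/decoder embeddings.
with tf.variable_scope("embedding_scope"):
    embedding = tf.get_variable("embedding", shape=[1000, 64])

# Removing the variable from TRAINABLE_VARIABLES freezes it: any optimizer
# built afterwards will not compute gradients for it.
trainable = tf.get_collection_ref(tf.GraphKeys.TRAINABLE_VARIABLES)
trainable.remove(embedding)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Only a brand-new model (globStep == 0 in the diff) would then overwrite
    # the weights with pre-trained word2vec values.
    sess.run(embedding.assign(np.zeros((1000, 64), dtype=np.float32)))
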
39 changes: 22 additions & 17 deletions chatbot/chatbot.py
@@ -185,10 +185,8 @@ def main(self, args=None):
         if self.args.test != Chatbot.TestMode.ALL:
             self.managePreviousModel(self.sess)
 
-        # Initialize embeddings with pre-trained word2vec vectors unless we are opening
-        # a restored model, in which case the embeddings were saved as part of the
-        # checkpoint.
-        if self.args.initEmbeddings and self.globStep == 0:
+        # Initialize embeddings with pre-trained word2vec vectors
+        if self.args.initEmbeddings:
             print("Loading pre-trained embeddings from GoogleNews-vectors-negative300.bin")
             self.loadEmbedding(self.sess)
 
@@ -377,6 +375,23 @@ def loadEmbedding(self, sess):
         Will modify the embedding weights of the current loaded model
         Uses the GoogleNews pre-trained values (path hardcoded)
         """
+
+        # Fetch embedding variables from model
+        with tf.variable_scope("embedding_rnn_seq2seq/RNN/EmbeddingWrapper", reuse=True):
+            em_in = tf.get_variable("embedding")
+        with tf.variable_scope("embedding_rnn_seq2seq/embedding_rnn_decoder", reuse=True):
+            em_out = tf.get_variable("embedding")
+
+        # Disable training for embeddings
+        variables = tf.get_collection_ref(tf.GraphKeys.TRAINABLE_VARIABLES)
+        variables.remove(em_in)
+        variables.remove(em_out)
+
+        # If restoring a model, we can leave here
+        if self.globStep != 0:
+            return
+
+        # New model, we load the pre-trained word2vec data and initialize embeddings
         with open(os.path.join(self.args.rootDir, 'data/word2vec/GoogleNews-vectors-negative300.bin'), "rb", 0) as f:
             header = f.readline()
             vocab_size, vector_size = map(int, header.split())
@@ -403,20 +418,10 @@ def loadEmbedding(self, sess):
                 S[:vector_size, :vector_size] = np.diag(s)
                 initW = np.dot(U[:, :self.args.embeddingSize], S[:self.args.embeddingSize, :self.args.embeddingSize])
 
-            # Initialize input embeddings
-            with tf.variable_scope("embedding_rnn_seq2seq/RNN/EmbeddingWrapper", reuse=True):
-                em_in = tf.get_variable("embedding")
-                sess.run(em_in.assign(initW))
+            # Initialize input and output embeddings
+            sess.run(em_in.assign(initW))
+            sess.run(em_out.assign(initW))
 
-            # Initialize output embeddings
-            with tf.variable_scope("embedding_rnn_seq2seq/embedding_rnn_decoder", reuse=True):
-                em_out = tf.get_variable("embedding")
-                sess.run(em_out.assign(initW))
-
-            # Disable training for embeddings
-            variables = tf.get_collection_ref(tf.GraphKeys.TRAINABLE_VARIABLES)
-            variables.remove(em_in)
-            variables.remove(em_out)
 
     def managePreviousModel(self, sess):
         """ Restore or reset the model, depending of the parameters
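
A note on the unchanged lines at the top of the last hunk: they reduce the 300-dimensional GoogleNews vectors to the model's embeddingSize with a truncated SVD before the assignment. A small self-contained sketch of that reduction, with made-up dimensions (the real code reads vocab_size and vector_size from the word2vec binary's header):

import numpy as np

# Made-up dimensions; DeepQA reads these from the word2vec file header.
vocab_size, vector_size, embedding_size = 5000, 300, 64
initW = np.random.uniform(-0.25, 0.25, (vocab_size, vector_size))

# Truncated SVD: keep only the first embedding_size singular directions.
U, s, Vt = np.linalg.svd(initW, full_matrices=False)  # U: (vocab_size, 300)
S = np.zeros((vector_size, vector_size))
S[:vector_size, :vector_size] = np.diag(s)
initW = np.dot(U[:, :embedding_size], S[:embedding_size, :embedding_size])

print(initW.shape)  # (5000, 64)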
