Skip to content

Commit

Permalink
Merge branch 'nlu_lstm' into akela-nlu_lstm
Browse files Browse the repository at this point in the history
  • Loading branch information
Ghostvv committed Mar 5, 2019
2 parents 17938f2 + f0fe837 commit 2dbd8be
Showing 1 changed file with 32 additions and 15 deletions.
47 changes: 32 additions & 15 deletions rasa_nlu/classifiers/embedding_intent_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -1131,21 +1131,6 @@ def train(self,

self.all_intents_embed_values = self._create_all_intents_embed(self.encoded_all_intents)

if self.gpu_lstm:
# rebuild tf graph for prediction
self.word_embed = self._create_tf_gpu_predict_embed(self.a_in,
self.hidden_layer_sizes['a'],
name='a_and_b' if self.share_embedding else 'a')
shape = tf.shape(self.b_in)
self.b_in = tf.reshape(self.b_in, [-1, shape[-2], self.b_in.shape[-1]])
emb_b = self._create_tf_gpu_predict_embed(self.b_in,
self.hidden_layer_sizes['b'],
name='a_and_b' if self.share_embedding else 'b')
# reshape back
self.intent_embed = tf.reshape(emb_b, [shape[0], shape[1], self.embed_dim])

self.sim_op, _ = self._tf_sim(self.word_embed, self.intent_embed)

self.all_intents_embed_in = tf.placeholder(tf.float32, (None, None, self.embed_dim),
name='all_intents_embed')

Expand Down Expand Up @@ -1328,6 +1313,38 @@ def persist(self, model_dir: Text) -> Dict[Text, Any]:
saver = tf.train.Saver()
saver.save(self.session, checkpoint)

if self.gpu_lstm:
# rebuild tf graph for prediction
self.word_embed = self._create_tf_gpu_predict_embed(self.a_in,
self.hidden_layer_sizes['a'],
name='a_and_b' if self.share_embedding else 'a')
shape = tf.shape(self.b_in)
b_in = tf.reshape(self.b_in, [-1, shape[-2], self.b_in.shape[-1]])
emb_b = self._create_tf_gpu_predict_embed(b_in,
self.hidden_layer_sizes['b'],
name='a_and_b' if self.share_embedding else 'b')
# reshape back
self.intent_embed = tf.reshape(emb_b, [shape[0], shape[1], self.embed_dim])

self.sim_op, _ = self._tf_sim(self.word_embed, self.intent_embed)

self.sim_all, _ = self._tf_sim(self.word_embed, self.all_intents_embed_in)

self.graph.clear_collection('similarity_op')
self.graph.add_to_collection('similarity_op',
self.sim_op)

self.graph.clear_collection('sim_all')
self.graph.add_to_collection('sim_all',
self.sim_all)

self.graph.clear_collection('word_embed')
self.graph.add_to_collection('word_embed',
self.word_embed)
self.graph.clear_collection('intent_embed')
self.graph.add_to_collection('intent_embed',
self.intent_embed)

with io.open(os.path.join(
model_dir,
self.name + "_inv_intent_dict.pkl"), 'wb') as f:
Expand Down

0 comments on commit 2dbd8be

Please sign in to comment.