From 874be4cd43131eb663e27cfd73c6421760ea0496 Mon Sep 17 00:00:00 2001 From: Tanja Bergmann Date: Tue, 16 Jun 2020 15:33:57 +0200 Subject: [PATCH 1/9] add confidence values for entities --- rasa/nlu/classifiers/diet_classifier.py | 22 ++- rasa/utils/tensorflow/crf.py | 181 ++++++++++++++++++++++++ rasa/utils/tensorflow/layers.py | 13 +- 3 files changed, 207 insertions(+), 9 deletions(-) create mode 100644 rasa/utils/tensorflow/crf.py diff --git a/rasa/nlu/classifiers/diet_classifier.py b/rasa/nlu/classifiers/diet_classifier.py index 20484bf4ea9b..4ef4df266436 100644 --- a/rasa/nlu/classifiers/diet_classifier.py +++ b/rasa/nlu/classifiers/diet_classifier.py @@ -798,10 +798,13 @@ def _predict_entities( if predict_out is None: return [] - predicted_tags = self._entity_label_to_tags(predict_out) + predicted_tags, confidence_values = self._entity_label_to_tags(predict_out) entities = self.convert_predictions_into_entities( - message.text, message.get(TOKENS_NAMES[TEXT], []), predicted_tags + message.text, + message.get(TOKENS_NAMES[TEXT], []), + predicted_tags, + confidence_values, ) entities = self.add_extractor_name(entities) @@ -811,11 +814,14 @@ def _predict_entities( def _entity_label_to_tags( self, predict_out: Dict[Text, Any] - ) -> Dict[Text, List[Text]]: + ) -> Tuple[Dict[Text, List[Text]], Dict[Text, List[float]]]: predicted_tags = {} + confidence_values = {} for tag_spec in self._entity_tag_specs: predictions = predict_out[f"e_{tag_spec.tag_name}_ids"].numpy() + confidences = predict_out[f"e_{tag_spec.tag_name}_scores"].numpy() + confidences = [float(c) for c in confidences[0]] tags = [tag_spec.ids_to_tags[p] for p in predictions[0]] if self.component_config[BILOU_FLAG]: @@ -823,8 +829,9 @@ def _entity_label_to_tags( tags = bilou_utils.remove_bilou_prefixes(tags) predicted_tags[tag_spec.tag_name] = tags + confidence_values[tag_spec.tag_name] = confidences - return predicted_tags + return predicted_tags, confidence_values def process(self, message: Message, **kwargs: Any) -> None: """Return the most likely label and its similarity to the input.""" @@ -1479,7 +1486,7 @@ def _calculate_entity_loss( logits = self._tf_layers[f"embed.{tag_name}.logits"](inputs) # should call first to build weights - pred_ids = self._tf_layers[f"crf.{tag_name}"](logits, sequence_lengths) + pred_ids, _ = self._tf_layers[f"crf.{tag_name}"](logits, sequence_lengths) # pytype cannot infer that 'self._tf_layers["crf"]' has the method '.loss' # pytype: disable=attribute-error loss = self._tf_layers[f"crf.{tag_name}"].loss( @@ -1671,9 +1678,12 @@ def _batch_predict_entities( _input = tf.concat([_input, _tags], axis=-1) _logits = self._tf_layers[f"embed.{name}.logits"](_input) - pred_ids = self._tf_layers[f"crf.{name}"](_logits, sequence_lengths - 1) + pred_ids, confidences = self._tf_layers[f"crf.{name}"]( + _logits, sequence_lengths - 1 + ) predictions[f"e_{name}_ids"] = pred_ids + predictions[f"e_{name}_scores"] = confidences if name == ENTITY_ATTRIBUTE_TYPE: # use the entity tags as additional input for the role diff --git a/rasa/utils/tensorflow/crf.py b/rasa/utils/tensorflow/crf.py new file mode 100644 index 000000000000..7ddbb68e0aaf --- /dev/null +++ b/rasa/utils/tensorflow/crf.py @@ -0,0 +1,181 @@ +import numpy as np +import tensorflow as tf + +from tensorflow_addons.utils.types import TensorLike +from typeguard import typechecked +from typing import Optional + + +class CrfDecodeForwardRnnCell(tf.keras.layers.AbstractRNNCell): + """Computes the forward decoding in a linear-chain CRF.""" + + @typechecked + 
def __init__(self, transition_params: TensorLike, **kwargs): + """Initialize the CrfDecodeForwardRnnCell. + + Args: + transition_params: A [num_tags, num_tags] matrix of binary + potentials. This matrix is expanded into a + [1, num_tags, num_tags] in preparation for the broadcast + summation occurring within the cell. + """ + super().__init__(**kwargs) + self._transition_params = tf.expand_dims(transition_params, 0) + self._num_tags = transition_params.shape[0] + + @property + def state_size(self): + return self._num_tags + + @property + def output_size(self): + return self._num_tags + + def build(self, input_shape): + super().build(input_shape) + + def call(self, inputs, state): + """Build the CrfDecodeForwardRnnCell. + + Args: + inputs: A [batch_size, num_tags] matrix of unary potentials. + state: A [batch_size, num_tags] matrix containing the previous step's + score values. + + Returns: + backpointers: A [batch_size, num_tags] matrix of backpointers. + new_state: A [batch_size, num_tags] matrix of new score values. + """ + state = tf.expand_dims(state[0], 2) + transition_scores = state + self._transition_params + new_state = inputs + tf.reduce_max(transition_scores, [1]) + backpointers = tf.argmax(transition_scores, 1) + backpointers = tf.cast(backpointers, tf.float32) + scores = tf.reduce_max(transition_scores, [1]) + return tf.concat([backpointers, scores], axis=1), new_state + + +def crf_decode_forward( + inputs: TensorLike, + state: TensorLike, + transition_params: TensorLike, + sequence_lengths: TensorLike, +) -> tf.Tensor: + """Computes forward decoding in a linear-chain CRF. + + Args: + inputs: A [batch_size, num_tags] matrix of unary potentials. + state: A [batch_size, num_tags] matrix containing the previous step's + score values. + transition_params: A [num_tags, num_tags] matrix of binary potentials. + sequence_lengths: A [batch_size] vector of true sequence lengths. + + Returns: + backpointers: A [batch_size, num_tags] matrix of backpointers. + new_state: A [batch_size, num_tags] matrix of new score values. + """ + sequence_lengths = tf.cast(sequence_lengths, dtype=tf.int32) + mask = tf.sequence_mask(sequence_lengths, tf.shape(inputs)[1]) + crf_fwd_cell = CrfDecodeForwardRnnCell(transition_params) + crf_fwd_layer = tf.keras.layers.RNN( + crf_fwd_cell, return_sequences=True, return_state=True + ) + return crf_fwd_layer(inputs, state, mask=mask) + + +def crf_decode_backward(inputs: TensorLike, state: TensorLike) -> tf.Tensor: + """Computes backward decoding in a linear-chain CRF. + + Args: + inputs: A [batch_size, num_tags] matrix of + backpointer of next step (in time order). + state: A [batch_size, 1] matrix of tag index of next step. + + Returns: + new_tags: A [batch_size, num_tags] + tensor containing the new tag indices. + """ + inputs = tf.transpose(inputs, [1, 0, 2]) + + def _scan_fn(state, inputs): + state = tf.squeeze(state, axis=[1]) + idxs = tf.stack([tf.range(tf.shape(inputs)[0]), state], axis=1) + new_tags = tf.expand_dims(tf.gather_nd(inputs, idxs), axis=-1) + return new_tags + + return tf.transpose(tf.scan(_scan_fn, inputs, state), [1, 0, 2]) + + +def crf_decode( + potentials: TensorLike, transition_params: TensorLike, sequence_length: TensorLike +) -> tf.Tensor: + """Decode the highest scoring sequence of tags. + + Args: + potentials: A [batch_size, max_seq_len, num_tags] tensor of + unary potentials. + transition_params: A [num_tags, num_tags] matrix of + binary potentials. + sequence_length: A [batch_size] vector of true sequence lengths. 
+ + Returns: + decode_tags: A [batch_size, max_seq_len] matrix, with dtype `tf.int32`. + Contains the highest scoring tag indices. + scores: A [batch_size, max_seq_len] vector, containing the score of `decode_tags`. + """ + sequence_length = tf.cast(sequence_length, dtype=tf.int32) + + # If max_seq_len is 1, we skip the algorithm and simply return the + # argmax tag and the max activation. + def _single_seq_fn(): + decode_tags = tf.cast(tf.argmax(potentials, axis=2), dtype=tf.int32) + best_score = tf.reshape(tf.reduce_max(potentials, axis=2), shape=[-1]) + return decode_tags, best_score + + def _multi_seq_fn(): + # Computes forward decoding. Get last score and backpointers. + initial_state = tf.slice(potentials, [0, 0, 0], [-1, 1, -1]) + initial_state = tf.squeeze(initial_state, axis=[1]) + inputs = tf.slice(potentials, [0, 1, 0], [-1, -1, -1]) + + sequence_length_less_one = tf.maximum( + tf.constant(0, dtype=tf.int32), sequence_length - 1 + ) + + backpointers, last_score = crf_decode_forward( + inputs, initial_state, transition_params, sequence_length_less_one + ) + + backpointers, scores = tf.split(backpointers, 2, axis=2) + + scores = tf.reduce_max(scores, axis=[2]) + initial_score = tf.reduce_max(last_score, axis=[1]) + initial_score = tf.expand_dims(initial_score, axis=1) + scores = tf.concat([scores, initial_score], axis=1) + + backpointers = tf.cast(backpointers, tf.int32) + backpointers = tf.reverse_sequence( + backpointers, sequence_length_less_one, seq_axis=1 + ) + + initial_state = tf.cast(tf.argmax(last_score, axis=1), dtype=tf.int32) + initial_state = tf.expand_dims(initial_state, axis=-1) + + decode_tags = crf_decode_backward(backpointers, initial_state) + decode_tags = tf.squeeze(decode_tags, axis=[2]) + decode_tags = tf.concat([initial_state, decode_tags], axis=1) + decode_tags = tf.reverse_sequence(decode_tags, sequence_length, seq_axis=1) + + return decode_tags, scores + + if potentials.shape[1] is not None: + # shape is statically know, so we just execute + # the appropriate code path + if potentials.shape[1] == 1: + return _single_seq_fn() + else: + return _multi_seq_fn() + else: + return tf.cond( + tf.equal(tf.shape(potentials)[1], 1), _single_seq_fn, _multi_seq_fn + ) diff --git a/rasa/utils/tensorflow/layers.py b/rasa/utils/tensorflow/layers.py index e8c1a171bab6..74665b7e116e 100644 --- a/rasa/utils/tensorflow/layers.py +++ b/rasa/utils/tensorflow/layers.py @@ -2,6 +2,7 @@ from typing import List, Optional, Text, Tuple, Callable, Union, Any import tensorflow as tf import tensorflow_addons as tfa +import rasa.utils.tensorflow.crf from tensorflow.python.keras.utils import tf_utils from tensorflow.python.keras import backend as K from rasa.utils.tensorflow.constants import SOFTMAX, MARGIN, COSINE, INNER @@ -460,7 +461,9 @@ def build(self, input_shape: tf.TensorShape) -> None: self.built = True # noinspection PyMethodOverriding - def call(self, logits: tf.Tensor, sequence_lengths: tf.Tensor) -> tf.Tensor: + def call( + self, logits: tf.Tensor, sequence_lengths: tf.Tensor + ) -> Tuple[tf.Tensor, tf.Tensor]: """Decodes the highest scoring sequence of tags. Arguments: @@ -471,8 +474,10 @@ def call(self, logits: tf.Tensor, sequence_lengths: tf.Tensor) -> tf.Tensor: Returns: A [batch_size, max_seq_len] matrix, with dtype `tf.int32`. Contains the highest scoring tag indices. + A [batch_size, max_seq_len] matrix, with dtype `tf.float32`. + Contains the confidence values of the highest scoring tag indices. 
""" - pred_ids, _ = tfa.text.crf.crf_decode( + pred_ids, scores = rasa.utils.tensorflow.crf.crf_decode( logits, self.transition_params, sequence_lengths ) # set prediction index for padding to `0` @@ -480,7 +485,9 @@ def call(self, logits: tf.Tensor, sequence_lengths: tf.Tensor) -> tf.Tensor: sequence_lengths, maxlen=tf.shape(pred_ids)[1], dtype=pred_ids.dtype ) - return pred_ids * mask + confidence_values = tf.nn.softmax(scores * tf.cast(mask, tf.float32)) + + return pred_ids * mask, confidence_values def loss( self, logits: tf.Tensor, tag_indices: tf.Tensor, sequence_lengths: tf.Tensor From 198c8c670efa30b296b031efb8fbce13d8ebf7d4 Mon Sep 17 00:00:00 2001 From: Tanja Bergmann Date: Tue, 16 Jun 2020 15:44:49 +0200 Subject: [PATCH 2/9] update EXTRACTORS_WITH_CONFIDENCES --- rasa/nlu/test.py | 2 +- rasa/utils/tensorflow/crf.py | 3 +-- rasa/utils/tensorflow/layers.py | 2 +- 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/rasa/nlu/test.py b/rasa/nlu/test.py index a0effebb026a..2edf7607859d 100644 --- a/rasa/nlu/test.py +++ b/rasa/nlu/test.py @@ -53,7 +53,7 @@ # performs entity extraction but those two classifiers don't ENTITY_PROCESSORS = {"EntitySynonymMapper", "ResponseSelector"} -EXTRACTORS_WITH_CONFIDENCES = {"CRFEntityExtractor"} +EXTRACTORS_WITH_CONFIDENCES = {"CRFEntityExtractor", "DIETClassifier"} CVEvaluationResult = namedtuple("Results", "train test") diff --git a/rasa/utils/tensorflow/crf.py b/rasa/utils/tensorflow/crf.py index 7ddbb68e0aaf..59c20ea5ff55 100644 --- a/rasa/utils/tensorflow/crf.py +++ b/rasa/utils/tensorflow/crf.py @@ -51,8 +51,7 @@ def call(self, inputs, state): new_state = inputs + tf.reduce_max(transition_scores, [1]) backpointers = tf.argmax(transition_scores, 1) backpointers = tf.cast(backpointers, tf.float32) - scores = tf.reduce_max(transition_scores, [1]) - return tf.concat([backpointers, scores], axis=1), new_state + return tf.concat([backpointers, new_state], axis=1), new_state def crf_decode_forward( diff --git a/rasa/utils/tensorflow/layers.py b/rasa/utils/tensorflow/layers.py index 74665b7e116e..2dfaf2a2730e 100644 --- a/rasa/utils/tensorflow/layers.py +++ b/rasa/utils/tensorflow/layers.py @@ -485,7 +485,7 @@ def call( sequence_lengths, maxlen=tf.shape(pred_ids)[1], dtype=pred_ids.dtype ) - confidence_values = tf.nn.softmax(scores * tf.cast(mask, tf.float32)) + confidence_values = tf.nn.sigmoid(scores * tf.cast(mask, tf.float32)) return pred_ids * mask, confidence_values From a585702d629f6cd2be5e7ea6d5b762390630f910 Mon Sep 17 00:00:00 2001 From: Tanja Bergmann Date: Wed, 17 Jun 2020 09:59:56 +0200 Subject: [PATCH 3/9] update crf_decode_backward --- rasa/utils/tensorflow/crf.py | 77 ++++++++++++++++++++------------- rasa/utils/tensorflow/layers.py | 11 +++-- 2 files changed, 55 insertions(+), 33 deletions(-) diff --git a/rasa/utils/tensorflow/crf.py b/rasa/utils/tensorflow/crf.py index 59c20ea5ff55..7ebb7b426005 100644 --- a/rasa/utils/tensorflow/crf.py +++ b/rasa/utils/tensorflow/crf.py @@ -1,16 +1,15 @@ -import numpy as np import tensorflow as tf from tensorflow_addons.utils.types import TensorLike from typeguard import typechecked -from typing import Optional +from typing import Tuple class CrfDecodeForwardRnnCell(tf.keras.layers.AbstractRNNCell): """Computes the forward decoding in a linear-chain CRF.""" @typechecked - def __init__(self, transition_params: TensorLike, **kwargs): + def __init__(self, transition_params: TensorLike, **kwargs) -> None: """Initialize the CrfDecodeForwardRnnCell. 
Args: @@ -24,17 +23,19 @@ def __init__(self, transition_params: TensorLike, **kwargs): self._num_tags = transition_params.shape[0] @property - def state_size(self): + def state_size(self) -> int: return self._num_tags @property - def output_size(self): + def output_size(self) -> int: return self._num_tags def build(self, input_shape): super().build(input_shape) - def call(self, inputs, state): + def call( + self, inputs: TensorLike, state: TensorLike + ) -> Tuple[tf.Tensor, tf.Tensor]: """Build the CrfDecodeForwardRnnCell. Args: @@ -43,7 +44,7 @@ def call(self, inputs, state): score values. Returns: - backpointers: A [batch_size, num_tags] matrix of backpointers. + output: A [batch_size, num_tags * 2] matrix of backpointers and scores. new_state: A [batch_size, num_tags] matrix of new score values. """ state = tf.expand_dims(state[0], 2) @@ -51,7 +52,8 @@ def call(self, inputs, state): new_state = inputs + tf.reduce_max(transition_scores, [1]) backpointers = tf.argmax(transition_scores, 1) backpointers = tf.cast(backpointers, tf.float32) - return tf.concat([backpointers, new_state], axis=1), new_state + scores = tf.reduce_max(tf.nn.softmax(transition_scores, axis=1), [1]) + return tf.concat([backpointers, scores], axis=1), new_state def crf_decode_forward( @@ -59,7 +61,7 @@ def crf_decode_forward( state: TensorLike, transition_params: TensorLike, sequence_lengths: TensorLike, -) -> tf.Tensor: +) -> Tuple[tf.Tensor, tf.Tensor]: """Computes forward decoding in a linear-chain CRF. Args: @@ -70,7 +72,7 @@ def crf_decode_forward( sequence_lengths: A [batch_size] vector of true sequence lengths. Returns: - backpointers: A [batch_size, num_tags] matrix of backpointers. + output: A [batch_size, num_tags * 2] matrix of backpointers and scores. new_state: A [batch_size, num_tags] matrix of new score values. """ sequence_lengths = tf.cast(sequence_lengths, dtype=tf.int32) @@ -82,32 +84,40 @@ def crf_decode_forward( return crf_fwd_layer(inputs, state, mask=mask) -def crf_decode_backward(inputs: TensorLike, state: TensorLike) -> tf.Tensor: +def crf_decode_backward( + inputs: TensorLike, scores: TensorLike, state: TensorLike +) -> Tuple[tf.Tensor, tf.Tensor]: """Computes backward decoding in a linear-chain CRF. Args: - inputs: A [batch_size, num_tags] matrix of - backpointer of next step (in time order). + inputs: A [batch_size, num_tags] matrix of backpointer of next step + (in time order). + scores: A [batch_size, num_tags] matrix of scores of next step (in time order). state: A [batch_size, 1] matrix of tag index of next step. Returns: - new_tags: A [batch_size, num_tags] - tensor containing the new tag indices. + new_tags: A [batch_size, num_tags] tensor containing the new tag indices. + new_scores: A [batch_size, num_tags] tensor containing the new score values. 
""" inputs = tf.transpose(inputs, [1, 0, 2]) + scores = tf.transpose(scores, [1, 0, 2]) def _scan_fn(state, inputs): - state = tf.squeeze(state, axis=[1]) + state = tf.cast(tf.squeeze(state, axis=[1]), dtype=tf.int32) idxs = tf.stack([tf.range(tf.shape(inputs)[0]), state], axis=1) new_tags = tf.expand_dims(tf.gather_nd(inputs, idxs), axis=-1) return new_tags - return tf.transpose(tf.scan(_scan_fn, inputs, state), [1, 0, 2]) + output_tags = tf.scan(_scan_fn, inputs, state) + state = tf.cast(state, dtype=tf.float32) + output_scores = tf.scan(_scan_fn, scores, state) + + return tf.transpose(output_tags, [1, 0, 2]), tf.transpose(output_scores, [1, 0, 2]) def crf_decode( potentials: TensorLike, transition_params: TensorLike, sequence_length: TensorLike -) -> tf.Tensor: +) -> Tuple[tf.Tensor, tf.Tensor]: """Decode the highest scoring sequence of tags. Args: @@ -128,7 +138,9 @@ def crf_decode( # argmax tag and the max activation. def _single_seq_fn(): decode_tags = tf.cast(tf.argmax(potentials, axis=2), dtype=tf.int32) - best_score = tf.reshape(tf.reduce_max(potentials, axis=2), shape=[-1]) + best_score = tf.reshape( + tf.reduce_max(tf.nn.softmax(potentials, axis=2), axis=2), shape=[-1] + ) return decode_tags, best_score def _multi_seq_fn(): @@ -141,31 +153,38 @@ def _multi_seq_fn(): tf.constant(0, dtype=tf.int32), sequence_length - 1 ) - backpointers, last_score = crf_decode_forward( + output, last_score = crf_decode_forward( inputs, initial_state, transition_params, sequence_length_less_one ) - backpointers, scores = tf.split(backpointers, 2, axis=2) + backpointers, scores = tf.split(output, 2, axis=2) - scores = tf.reduce_max(scores, axis=[2]) - initial_score = tf.reduce_max(last_score, axis=[1]) - initial_score = tf.expand_dims(initial_score, axis=1) - scores = tf.concat([scores, initial_score], axis=1) - - backpointers = tf.cast(backpointers, tf.int32) + backpointers = tf.cast(backpointers, dtype=tf.int32) backpointers = tf.reverse_sequence( backpointers, sequence_length_less_one, seq_axis=1 ) + scores = tf.reverse_sequence(scores, sequence_length_less_one, seq_axis=1) + initial_state = tf.cast(tf.argmax(last_score, axis=1), dtype=tf.int32) initial_state = tf.expand_dims(initial_state, axis=-1) - decode_tags = crf_decode_backward(backpointers, initial_state) + initial_score = tf.reduce_max(tf.nn.softmax(last_score, axis=1), axis=[1]) + initial_score = tf.expand_dims(initial_score, axis=-1) + + decode_tags, decode_scores = crf_decode_backward( + backpointers, scores, initial_state + ) + decode_tags = tf.squeeze(decode_tags, axis=[2]) decode_tags = tf.concat([initial_state, decode_tags], axis=1) decode_tags = tf.reverse_sequence(decode_tags, sequence_length, seq_axis=1) - return decode_tags, scores + decode_scores = tf.squeeze(decode_scores, axis=[2]) + decode_scores = tf.concat([initial_score, decode_scores], axis=1) + decode_scores = tf.reverse_sequence(decode_scores, sequence_length, seq_axis=1) + + return decode_tags, decode_scores if potentials.shape[1] is not None: # shape is statically know, so we just execute diff --git a/rasa/utils/tensorflow/layers.py b/rasa/utils/tensorflow/layers.py index 2dfaf2a2730e..582e2c4e7d0b 100644 --- a/rasa/utils/tensorflow/layers.py +++ b/rasa/utils/tensorflow/layers.py @@ -477,17 +477,20 @@ def call( A [batch_size, max_seq_len] matrix, with dtype `tf.float32`. Contains the confidence values of the highest scoring tag indices. 
""" - pred_ids, scores = rasa.utils.tensorflow.crf.crf_decode( + predicted_ids, scores = rasa.utils.tensorflow.crf.crf_decode( logits, self.transition_params, sequence_lengths ) # set prediction index for padding to `0` mask = tf.sequence_mask( - sequence_lengths, maxlen=tf.shape(pred_ids)[1], dtype=pred_ids.dtype + sequence_lengths, + maxlen=tf.shape(predicted_ids)[1], + dtype=predicted_ids.dtype, ) - confidence_values = tf.nn.sigmoid(scores * tf.cast(mask, tf.float32)) + confidence_values = scores * tf.cast(mask, tf.float32) + predicted_ids = predicted_ids * mask - return pred_ids * mask, confidence_values + return predicted_ids, confidence_values def loss( self, logits: tf.Tensor, tag_indices: tf.Tensor, sequence_lengths: tf.Tensor From 0f87326d7285b4c6c1de5f45cc6c3cec0750ff3f Mon Sep 17 00:00:00 2001 From: Tanja Bergmann Date: Wed, 17 Jun 2020 10:08:06 +0200 Subject: [PATCH 4/9] add changelog entry --- changelog/5481.improvement.rst | 1 + docs/nlu/entity-extraction.rst | 4 ++-- rasa/utils/tensorflow/crf.py | 5 +++++ 3 files changed, 8 insertions(+), 2 deletions(-) create mode 100644 changelog/5481.improvement.rst diff --git a/changelog/5481.improvement.rst b/changelog/5481.improvement.rst new file mode 100644 index 000000000000..438148dbd085 --- /dev/null +++ b/changelog/5481.improvement.rst @@ -0,0 +1 @@ +``DIETClassifier`` now also assigns a confidence value to entity predictions. diff --git a/docs/nlu/entity-extraction.rst b/docs/nlu/entity-extraction.rst index b35246c83de3..1605aff60c0f 100644 --- a/docs/nlu/entity-extraction.rst +++ b/docs/nlu/entity-extraction.rst @@ -60,9 +60,9 @@ exactly. Instead it will return the trained synonym. .. note:: - The ``confidence`` will be set by the ``CRFEntityExtractor`` component. The + The ``confidence`` will be set by the ``CRFEntityExtractor`` and the ``DIETClassifier`` component. The ``DucklingHTTPExtractor`` will always return ``1``. The ``SpacyEntityExtractor`` extractor - and ``DIETClassifier`` do not provide this information and return ``null``. + does not provide this information and returns ``null``. Some extractors, like ``duckling``, may include additional information. 
For example: diff --git a/rasa/utils/tensorflow/crf.py b/rasa/utils/tensorflow/crf.py index 7ebb7b426005..fc73fce8fdda 100644 --- a/rasa/utils/tensorflow/crf.py +++ b/rasa/utils/tensorflow/crf.py @@ -5,6 +5,11 @@ from typing import Tuple +# original code taken from +# https://github.com/tensorflow/addons/blob/master/tensorflow_addons/text/crf.py +# (modified to our neeeds) + + class CrfDecodeForwardRnnCell(tf.keras.layers.AbstractRNNCell): """Computes the forward decoding in a linear-chain CRF.""" From 512bd4d9bc7f40dcb44b77ed3329490451b9b668 Mon Sep 17 00:00:00 2001 From: Tanja Bergmann Date: Wed, 17 Jun 2020 10:15:38 +0200 Subject: [PATCH 5/9] update tests --- tests/nlu/test_evaluation.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/tests/nlu/test_evaluation.py b/tests/nlu/test_evaluation.py index e6dff4ebfe9f..1a0db9638c39 100644 --- a/tests/nlu/test_evaluation.py +++ b/tests/nlu/test_evaluation.py @@ -253,6 +253,21 @@ def test_determine_token_labels_with_extractors(): ["CRFEntityExtractor"], 0.87, ), + ( + Token("pizza", 4), + [ + { + "start": 4, + "end": 9, + "value": "pizza", + "entity": "food", + "confidence_entity": 0.87, + "extractor": "DIETClassfifier", + } + ], + ["DIETClassfifier"], + 0.87, + ), ], ) def test_get_entity_confidences( From def314f5bce6c3b90a7382d9081c5050d40f6fe9 Mon Sep 17 00:00:00 2001 From: Tanja Bergmann Date: Wed, 17 Jun 2020 11:21:28 +0200 Subject: [PATCH 6/9] fix typo --- tests/nlu/test_evaluation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/nlu/test_evaluation.py b/tests/nlu/test_evaluation.py index 1a0db9638c39..a4c3a9f4a5a6 100644 --- a/tests/nlu/test_evaluation.py +++ b/tests/nlu/test_evaluation.py @@ -262,10 +262,10 @@ def test_determine_token_labels_with_extractors(): "value": "pizza", "entity": "food", "confidence_entity": 0.87, - "extractor": "DIETClassfifier", + "extractor": "DIETClassifier", } ], - ["DIETClassfifier"], + ["DIETClassifier"], 0.87, ), ], From 8b6f43960895d696fb8bd7181149f2e922397e8b Mon Sep 17 00:00:00 2001 From: Tanja Bergmann Date: Wed, 17 Jun 2020 13:38:12 +0200 Subject: [PATCH 7/9] review comments --- rasa/utils/tensorflow/crf.py | 45 +++++++++++++++++++++++---------- rasa/utils/tensorflow/layers.py | 2 +- 2 files changed, 32 insertions(+), 15 deletions(-) diff --git a/rasa/utils/tensorflow/crf.py b/rasa/utils/tensorflow/crf.py index fc73fce8fdda..c3b679eb81d1 100644 --- a/rasa/utils/tensorflow/crf.py +++ b/rasa/utils/tensorflow/crf.py @@ -55,9 +55,17 @@ def call( state = tf.expand_dims(state[0], 2) transition_scores = state + self._transition_params new_state = inputs + tf.reduce_max(transition_scores, [1]) + backpointers = tf.argmax(transition_scores, 1) backpointers = tf.cast(backpointers, tf.float32) + + # apply softmax to transition_scores to get scores in range from 0 to 1 scores = tf.reduce_max(tf.nn.softmax(transition_scores, axis=1), [1]) + + # In the RNN implementation only the first value that is returned from a cell + # is kept throughout the RNN, so that you will have the values from each time + # step in the final output. As we need the backpointers as well as the scores + # for each time step, we concatenate them. return tf.concat([backpointers, scores], axis=1), new_state @@ -90,12 +98,12 @@ def crf_decode_forward( def crf_decode_backward( - inputs: TensorLike, scores: TensorLike, state: TensorLike + backpointers: TensorLike, scores: TensorLike, state: TensorLike ) -> Tuple[tf.Tensor, tf.Tensor]: """Computes backward decoding in a linear-chain CRF. 
Args: - inputs: A [batch_size, num_tags] matrix of backpointer of next step + backpointers: A [batch_size, num_tags] matrix of backpointer of next step (in time order). scores: A [batch_size, num_tags] matrix of scores of next step (in time order). state: A [batch_size, 1] matrix of tag index of next step. @@ -104,16 +112,17 @@ def crf_decode_backward( new_tags: A [batch_size, num_tags] tensor containing the new tag indices. new_scores: A [batch_size, num_tags] tensor containing the new score values. """ - inputs = tf.transpose(inputs, [1, 0, 2]) + backpointers = tf.transpose(backpointers, [1, 0, 2]) scores = tf.transpose(scores, [1, 0, 2]) - def _scan_fn(state, inputs): - state = tf.cast(tf.squeeze(state, axis=[1]), dtype=tf.int32) - idxs = tf.stack([tf.range(tf.shape(inputs)[0]), state], axis=1) - new_tags = tf.expand_dims(tf.gather_nd(inputs, idxs), axis=-1) - return new_tags + def _scan_fn(_state: TensorLike, _inputs: TensorLike) -> tf.Tensor: + _state = tf.cast(tf.squeeze(_state, axis=[1]), dtype=tf.int32) + idxs = tf.stack([tf.range(tf.shape(_inputs)[0]), _state], axis=1) + return tf.expand_dims(tf.gather_nd(_inputs, idxs), axis=-1) - output_tags = tf.scan(_scan_fn, inputs, state) + output_tags = tf.scan(_scan_fn, backpointers, state) + # the dtype of the input parameters of tf.scan need to match + # convert state to float32 to match the type of scores state = tf.cast(state, dtype=tf.float32) output_scores = tf.scan(_scan_fn, scores, state) @@ -122,7 +131,7 @@ def _scan_fn(state, inputs): def crf_decode( potentials: TensorLike, transition_params: TensorLike, sequence_length: TensorLike -) -> Tuple[tf.Tensor, tf.Tensor]: +) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]: """Decode the highest scoring sequence of tags. Args: @@ -135,7 +144,9 @@ def crf_decode( Returns: decode_tags: A [batch_size, max_seq_len] matrix, with dtype `tf.int32`. Contains the highest scoring tag indices. - scores: A [batch_size, max_seq_len] vector, containing the score of `decode_tags`. + decode_scores: A [batch_size, max_seq_len] matrix, containing the score of + `decode_tags`. + best_score: A [batch_size] vector, containing the best score of `decode_tags`. """ sequence_length = tf.cast(sequence_length, dtype=tf.int32) @@ -143,10 +154,11 @@ def crf_decode( # argmax tag and the max activation. def _single_seq_fn(): decode_tags = tf.cast(tf.argmax(potentials, axis=2), dtype=tf.int32) - best_score = tf.reshape( + decode_scores = tf.reshape( tf.reduce_max(tf.nn.softmax(potentials, axis=2), axis=2), shape=[-1] ) - return decode_tags, best_score + best_score = tf.reshape(tf.reduce_max(potentials, axis=2), shape=[-1]) + return decode_tags, decode_scores, best_score def _multi_seq_fn(): # Computes forward decoding. Get last score and backpointers. 
@@ -162,6 +174,9 @@ def _multi_seq_fn(): inputs, initial_state, transition_params, sequence_length_less_one ) + # output is a matrix of size [batch-size, max-seq-length, num-tags * 2] + # split the matrix on axis 2 to get the backpointers and scores, which are + # both of size [batch-size, max-seq-length, num-tags] backpointers, scores = tf.split(output, 2, axis=2) backpointers = tf.cast(backpointers, dtype=tf.int32) @@ -189,7 +204,9 @@ def _multi_seq_fn(): decode_scores = tf.concat([initial_score, decode_scores], axis=1) decode_scores = tf.reverse_sequence(decode_scores, sequence_length, seq_axis=1) - return decode_tags, decode_scores + best_score = tf.reduce_max(last_score, axis=1) + + return decode_tags, decode_scores, best_score if potentials.shape[1] is not None: # shape is statically know, so we just execute diff --git a/rasa/utils/tensorflow/layers.py b/rasa/utils/tensorflow/layers.py index 582e2c4e7d0b..c6d051191e1e 100644 --- a/rasa/utils/tensorflow/layers.py +++ b/rasa/utils/tensorflow/layers.py @@ -477,7 +477,7 @@ def call( A [batch_size, max_seq_len] matrix, with dtype `tf.float32`. Contains the confidence values of the highest scoring tag indices. """ - predicted_ids, scores = rasa.utils.tensorflow.crf.crf_decode( + predicted_ids, scores, _ = rasa.utils.tensorflow.crf.crf_decode( logits, self.transition_params, sequence_lengths ) # set prediction index for padding to `0` From ea17a0c0eb2b286c5e6e44cbf86e6e32c6af63b0 Mon Sep 17 00:00:00 2001 From: Tanja Bergmann Date: Wed, 17 Jun 2020 13:49:50 +0200 Subject: [PATCH 8/9] update single_seq_fn --- rasa/utils/tensorflow/crf.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/rasa/utils/tensorflow/crf.py b/rasa/utils/tensorflow/crf.py index c3b679eb81d1..4a1b5945ebbc 100644 --- a/rasa/utils/tensorflow/crf.py +++ b/rasa/utils/tensorflow/crf.py @@ -154,9 +154,7 @@ def crf_decode( # argmax tag and the max activation. def _single_seq_fn(): decode_tags = tf.cast(tf.argmax(potentials, axis=2), dtype=tf.int32) - decode_scores = tf.reshape( - tf.reduce_max(tf.nn.softmax(potentials, axis=2), axis=2), shape=[-1] - ) + decode_scores = tf.reduce_max(tf.nn.softmax(potentials, axis=2), axis=2) best_score = tf.reshape(tf.reduce_max(potentials, axis=2), shape=[-1]) return decode_tags, decode_scores, best_score From a9b0c76cb816f71d74a0a86d1dee723b3feab9df Mon Sep 17 00:00:00 2001 From: Tanja Bergmann Date: Wed, 17 Jun 2020 14:43:32 +0200 Subject: [PATCH 9/9] fix deepsource issue --- rasa/utils/tensorflow/crf.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/rasa/utils/tensorflow/crf.py b/rasa/utils/tensorflow/crf.py index 4a1b5945ebbc..e4e33511a851 100644 --- a/rasa/utils/tensorflow/crf.py +++ b/rasa/utils/tensorflow/crf.py @@ -211,9 +211,7 @@ def _multi_seq_fn(): # the appropriate code path if potentials.shape[1] == 1: return _single_seq_fn() - else: - return _multi_seq_fn() - else: - return tf.cond( - tf.equal(tf.shape(potentials)[1], 1), _single_seq_fn, _multi_seq_fn - ) + + return _multi_seq_fn() + + return tf.cond(tf.equal(tf.shape(potentials)[1], 1), _single_seq_fn, _multi_seq_fn)
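
Taken together, the series changes `crf_decode` to return per-token confidence values alongside the Viterbi tag indices. Below is a minimal usage sketch, assuming the series is applied so that `rasa.utils.tensorflow.crf` is importable with the three-value return signature introduced in [PATCH 7/9]; the tensor sizes and random inputs are purely illustrative.

    import tensorflow as tf

    from rasa.utils.tensorflow.crf import crf_decode

    batch_size, max_seq_len, num_tags = 2, 5, 4

    # Illustrative inputs; in DIET the unary potentials come from the
    # "embed.<tag_name>.logits" layer and the transition parameters from the CRF layer.
    logits = tf.random.normal((batch_size, max_seq_len, num_tags))
    transition_params = tf.random.normal((num_tags, num_tags))
    sequence_lengths = tf.constant([5, 3], dtype=tf.int32)

    tags, confidences, best_score = crf_decode(logits, transition_params, sequence_lengths)

    # tags:        [batch_size, max_seq_len] int32 tag indices of the best path
    # confidences: [batch_size, max_seq_len] float32 values in (0, 1) for positions
    #              within the true sequence length, taken from the softmax over the
    #              transition scores at each decoding step
    # best_score:  [batch_size] unnormalized score of the best tag sequence
    print(tags.numpy())
    print(confidences.numpy())
    print(best_score.numpy())

The `CRF` layer in `layers.py` additionally masks both outputs so that padded positions are zeroed, and `DIETClassifier._entity_label_to_tags` then attaches the per-token values to each extracted entity as `confidence_entity` — which is what the new `DIETClassifier` entry in `EXTRACTORS_WITH_CONFIDENCES` and the added `test_get_entity_confidences` case rely on.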