
Commit

Merge pull request #1547 from RasaHQ/add-analyzer
Add analyzer to vectorizer
Ghostvv committed Dec 12, 2018
2 parents e790177 + 9ba15d5 commit f91cd6a
Showing 7 changed files with 164 additions and 39 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.rst
@@ -12,6 +12,8 @@ Added
- environment variables specified with ``${env_variable}`` in a yaml
configuration file are now replaced with the value of the environment variable
- more documentation on how to run NLU with Docker
- ``analyzer`` parameter to ``intent_featurizer_count_vectors`` featurizer to
configure whether to use word or character n-grams

Changed
-------
6 changes: 6 additions & 0 deletions data/test/config_embedding_test.yml
@@ -0,0 +1,6 @@
language: en
pipeline:
- name: "intent_featurizer_count_vectors"
  max_ngram: 3
- name: "intent_classifier_tensorflow_embedding"
  epochs: 10
25 changes: 22 additions & 3 deletions docs/components.rst
@@ -126,16 +126,32 @@ intent_featurizer_count_vectors
`sklearn's CountVectorizer <http://scikit-learn.org/stable/modules/generated/sklearn.feature_extraction.text.CountVectorizer.html>`_.
All tokens which consist only of digits (e.g. 123 and 99 but not a123d) will be assigned to the same feature.
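A quick way to see this digit handling in isolation (a minimal sketch; the ``__NUMBER__`` placeholder mirrors the substitution used in ``count_vectors_featurizer.py`` later in this diff):

import re

# Minimal sketch of the behaviour described above: tokens made up only of
# digits collapse to a single shared placeholder (and hence one feature),
# while mixed tokens such as "a123d" are left untouched.
text = "table for 2 people at 8 but not a123d"
print(re.sub(r'\b[0-9]+\b', '__NUMBER__', text))
# table for __NUMBER__ people at __NUMBER__ but not a123d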

.. note:: If the words in the model language cannot be split by whitespace,
.. note::
If the words in the model language cannot be split by whitespace,
a language-specific tokenizer is required in the pipeline before this component
(e.g. using ``tokenizer_jieba`` for Chinese).

:Configuration:
See `sklearn's CountVectorizer docs <http://scikit-learn.org/stable/modules/generated/sklearn.feature_extraction.text.CountVectorizer.html>`_
for a detailed description of the configuration parameters.

This featurizer can be configured to use word or character n-grams, using the ``analyzer`` config parameter.
By default ``analyzer`` is set to ``word``, so word token counts are used as features.
To use character n-grams, set ``analyzer`` to ``char`` or ``char_wb``.

.. note::
Option ``char_wb`` creates character n-grams only from text inside word boundaries;
n-grams at the edges of words are padded with space.
This option can be used to create `Subword Semantic Hashing <https://arxiv.org/abs/1810.07150>`_.

.. note::
For character n-grams, do not forget to increase the ``min_ngram`` and ``max_ngram`` parameters;
otherwise the vocabulary will contain only single letters.
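To see what ``char_wb`` produces with a larger n-gram range, here is a small standalone check against sklearn's ``CountVectorizer`` (a sketch for illustration only; the exact vocabulary depends on your sklearn version):

from sklearn.feature_extraction.text import CountVectorizer

# 'char_wb' builds character n-grams only inside word boundaries;
# n-grams at the edges of words are padded with a space.
vect = CountVectorizer(analyzer='char_wb', ngram_range=(2, 3))
vect.fit(["play music"])
print(sorted(vect.vocabulary_))
# e.g. [' m', ' mu', ' p', ' pl', 'ay', 'ay ', ...]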

Handling Out-Of-Vocabulary (OOV) words:

.. note:: Enabled only if ``analyzer`` is ``word``.

Since the training is performed on limited vocabulary data, it cannot be guaranteed that during prediction
an algorithm will not encounter an unknown word (a word that was not seen during training).
In order to teach an algorithm how to treat unknown words, some words in the training data can be substituted by the generic word ``OOV_token``.
@@ -145,9 +161,8 @@ intent_featurizer_count_vectors
maybe some additional general words. Then an algorithm will likely classify a message with unknown words as this intent ``outofscope``.

.. note::

This featurizer creates a bag-of-words representation by **counting** words,
so the number of ``OOV_token`` s might be important.
so the number of ``OOV_token`` in the sentence might be important.
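As a rough illustration of the substitution described above (a hypothetical sketch, not Rasa's actual implementation; ``replace_oov`` and the sample vocabulary are invented for this example):

# Hypothetical sketch: at prediction time, words that are missing from the
# training vocabulary are replaced by the configured OOV token before counting.
def replace_oov(tokens, vocabulary, oov_token="OOV_token"):
    return [t if t in vocabulary else oov_token for t in tokens]

vocab = {"please", "play", "some", "music"}
print(replace_oov(["please", "play", "klezmer", "music"], vocab))
# ['please', 'play', 'OOV_token', 'music']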

- ``OOV_token`` set a keyword for unseen words; if training data contains ``OOV_token`` as words in some messages,
during prediction the words that were not seen during training will be substituted with provided ``OOV_token``;
@@ -163,6 +178,10 @@ intent_featurizer_count_vectors
pipeline:
- name: "intent_featurizer_count_vectors"
  # whether to use word or character n-grams
  # 'char_wb' creates character n-grams only inside word boundaries
  # n-grams at the edges of words are padded with space.
  "analyzer": 'word',  # use 'char' or 'char_wb' for character
  # the parameters are taken from
  # sklearn's CountVectorizer
  # regular expression for tokens
57 changes: 38 additions & 19 deletions rasa_nlu/classifiers/embedding_intent_classifier.py
@@ -52,7 +52,8 @@ class EmbeddingIntentClassifier(Component):
Based on the starspace idea from: https://arxiv.org/abs/1709.03856.
However, in this implementation the `mu` parameter is treated differently
and additional hidden layers are added together with dropout."""
and additional hidden layers are added together with dropout.
"""

name = "intent_classifier_tensorflow_embedding"

Expand Down Expand Up @@ -133,6 +134,7 @@ def __init__(self,
):
# type: (...) -> None
"""Declare instant variables with default values"""

self._check_tensorflow()
super(EmbeddingIntentClassifier, self).__init__(component_config)

@@ -196,16 +198,14 @@ def _load_visual_params(self, config):

self.evaluate_on_num_examples = config['evaluate_on_num_examples']

def _load_params(self, **kwargs):
# type: (Dict[Text, Any]) -> None
config = copy.deepcopy(self.defaults)
config.update(kwargs)
def _load_params(self):
# type: () -> None

self._load_nn_architecture_params(config)
self._load_embedding_params(config)
self._load_regularization_params(config)
self._load_flag_if_tokenize_intents(config)
self._load_visual_params(config)
self._load_nn_architecture_params(self.component_config)
self._load_embedding_params(self.component_config)
self._load_regularization_params(self.component_config)
self._load_flag_if_tokenize_intents(self.component_config)
self._load_visual_params(self.component_config)

# package safety checks
@classmethod
@@ -247,7 +247,9 @@ def _create_intent_token_dict(intents, intent_split_symbol):
def _create_encoded_intents(self, intent_dict):
# type: (Dict[Text, int]) -> np.ndarray
"""Create matrix with intents encoded in rows as bag of words.
If intent_tokenization_flag is off, returns identity matrix"""
If intent_tokenization_flag is off, returns identity matrix.
"""

if self.intent_tokenization_flag:
intent_token_dict = self._create_intent_token_dict(
@@ -267,8 +269,11 @@ def _create_encoded_intents(self, intent_dict):
def _create_all_Y(self, size):
# type: (int) -> np.ndarray
"""Stack encoded_all_intents on top of each other
to create candidates for training examples
to calculate training accuracy"""
to create candidates for training examples and
to calculate training accuracy
"""

return np.stack([self.encoded_all_intents] * size)

# noinspection PyPep8Naming
@@ -330,9 +335,12 @@ def _create_tf_embed(self,

def _tf_sim(self, a, b):
# type: (tf.Tensor, tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]
"""Define similarity in two cases:
"""Define similarity
in two cases:
sim: between embedded words and embedded intent labels
sim_emb: between individual embedded intent labels only"""
sim_emb: between individual embedded intent labels only
"""

if self.similarity_type == 'cosine':
# normalize embedding vectors for cosine similarity
@@ -377,8 +385,11 @@ def _tf_loss(self, sim, sim_emb):
# training helpers:
def _create_batch_b(self, batch_pos_b, intent_ids):
# type: (np.ndarray, np.ndarray) -> np.ndarray
"""Create batch of intents, where the first is correct intent
and the rest are wrong intents sampled randomly"""
"""Create batch of intents.
Where the first is correct intent
and the rest are wrong intents sampled randomly
"""

batch_pos_b = batch_pos_b[:, np.newaxis, :]

@@ -400,7 +411,10 @@ def _create_batch_b(self, batch_pos_b, intent_ids):
def _linearly_increasing_batch_size(self, epoch):
# type: (int) -> int
"""Linearly increase batch size with every epoch.
The idea comes from https://arxiv.org/abs/1711.00489"""
The idea comes from https://arxiv.org/abs/1711.00489
"""

if not isinstance(self.batch_size, list):
return int(self.batch_size)

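For context on the linearly increasing batch size referenced in the docstring above, a standalone sketch of the idea (an assumed shape for illustration; the real method is an instance method that uses ``self.batch_size``, visible above):

def linearly_increasing_batch_size(epoch, batch_size, epochs):
    """Sketch: grow the batch size linearly from batch_size[0] to
    batch_size[1] over training, following https://arxiv.org/abs/1711.00489."""
    if not isinstance(batch_size, list):
        return int(batch_size)
    if epochs > 1:
        return int(batch_size[0] +
                    epoch * (batch_size[1] - batch_size[0]) / (epochs - 1))
    return int(batch_size[0])

# e.g. with batch_size=[64, 256] and epochs=300:
# epoch 0 -> 64, epoch 150 -> ~160, epoch 299 -> 256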
@@ -422,6 +436,7 @@ def _train_tf(self,
):
# type: (...) -> None
"""Train tf graph"""

self.session.run(tf.global_variables_initializer())

if self.evaluate_on_num_examples:
@@ -482,6 +497,7 @@ def _train_tf(self,
def _output_training_stat(self, X, intents_for_X, is_training):
# type: (np.ndarray, np.ndarray, tf.Tensor) -> np.ndarray
"""Output training statistics"""

n = self.evaluate_on_num_examples
ids = np.random.permutation(len(X))[:n]
all_Y = self._create_all_Y(X[ids].shape[0])
@@ -620,7 +636,10 @@ def process(self, message, **kwargs):
def persist(self, model_dir):
# type: (Text) -> Dict[Text, Any]
"""Persist this model into the passed directory.
Return the metadata necessary to load the model again."""
Return the metadata necessary to load the model again.
"""

if self.session is None:
return {"classifier_file": None}

55 changes: 47 additions & 8 deletions rasa_nlu/featurizers/count_vectors_featurizer.py
@@ -25,7 +25,12 @@ class CountVectorsFeaturizer(Featurizer):
Creates bag-of-words representation of intent features
using sklearn's `CountVectorizer`.
All tokens which consist only of digits (e.g. 123 and 99
but not ab12d) will be represented by a single feature."""
but not ab12d) will be represented by a single feature.
Set `analyzer` to 'char_wb'
to use the idea of Subword Semantic Hashing
from https://arxiv.org/abs/1810.07150.
"""

name = "intent_featurizer_count_vectors"

@@ -37,7 +42,13 @@ class CountVectorsFeaturizer(Featurizer):
# the parameters are taken from
# sklearn's CountVectorizer

# whether to use word or character n-grams
# 'char_wb' creates character n-grams inside word boundaries
# n-grams at the edges of words are padded with space.
"analyzer": 'word', # use 'char' or 'char_wb' for character

# regular expression for tokens
# only used if analyzer == 'word'
"token_pattern": r'(?u)\b\w\w+\b',

# remove accents during the preprocessing step
@@ -78,6 +89,9 @@ def required_packages(cls):
return ["sklearn"]

def _load_count_vect_params(self):
# set analyzer
self.analyzer = self.component_config['analyzer']

# regular expression for tokens
self.token_pattern = self.component_config['token_pattern']

@@ -121,6 +135,20 @@ def _load_OOV_params(self):
if self.OOV_words:
self.OOV_words = [w.lower() for w in self.OOV_words]

def _check_analyzer(self):
if self.analyzer != 'word':
if self.OOV_token is not None:
logger.warning("Analyzer is set to character, "
"provided OOV word token will be ignored.")
if self.stop_words is not None:
logger.warning("Analyzer is set to character, "
"provided stop words will be ignored.")
if self.max_ngram == 1:
logger.warning("Analyzer is set to character, "
"but max n-gram is set to 1. "
"It means that the vocabulary will "
"contain single letters only.")

def __init__(self, component_config=None):
"""Construct a new count vectorizer using the sklearn framework."""

@@ -132,11 +160,15 @@ def __init__(self, component_config=None):
# handling Out-Of-Vocabulary (OOV) words
self._load_OOV_params()

# warn that some of config parameters might be ignored
self._check_analyzer()

# declare class instance for CountVectorizer
self.vect = None

def _tokenizer(self, text):
"""Override tokenizer in CountVectorizer"""
"""Override tokenizer in CountVectorizer."""

text = re.sub(r'\b[0-9]+\b', '__NUMBER__', text)

token_pattern = re.compile(self.token_pattern)
@@ -181,9 +213,13 @@ def _check_OOV_present(self, examples):
"".format(self.OOV_token))

def train(self, training_data, cfg=None, **kwargs):
# type: (TrainingData, RasaNLUModelConfig, **Any) -> None
"""Take parameters from config and
construct a new count vectorizer using the sklearn framework."""
# type: (TrainingData, RasaNLUModelConfig, Any) -> None
"""Train the featurizer.
Take parameters from config and
construct a new count vectorizer using the sklearn framework.
"""

from sklearn.feature_extraction.text import CountVectorizer

spacy_nlp = kwargs.get("spacy_nlp")
@@ -202,7 +238,8 @@ def train(self, training_data, cfg=None, **kwargs):
max_df=self.max_df,
min_df=self.min_df,
max_features=self.max_features,
tokenizer=self._tokenizer)
tokenizer=self._tokenizer,
analyzer=self.analyzer)

lem_exs = [self._get_message_text(example)
for example in training_data.intent_examples]
@@ -239,7 +276,9 @@ def process(self, message, **kwargs):
def persist(self, model_dir):
# type: (Text) -> Dict[Text, Any]
"""Persist this model into the passed directory.
Returns the metadata necessary to load the model again."""
Returns the metadata necessary to load the model again.
"""

featurizer_file = os.path.join(model_dir, self.name + ".pkl")
utils.pycloud_pickle(featurizer_file, self)
@@ -250,7 +289,7 @@ def load(cls,
model_dir=None, # type: Text
model_metadata=None, # type: Metadata
cached_component=None, # type: Optional[Component]
**kwargs # type: **Any
**kwargs # type: Any
):
# type: (...) -> CountVectorsFeaturizer

27 changes: 20 additions & 7 deletions tests/base/test_config.py
@@ -6,12 +6,10 @@
import tempfile

import pytest
from typing import Text

import rasa_nlu
from rasa_nlu import config, utils
from rasa_nlu.config import RasaNLUModelConfig, InvalidConfigError
from rasa_nlu.registry import registered_pipeline_templates
from rasa_nlu.components import ComponentBuilder
from tests.conftest import CONFIG_DEFAULTS_PATH
from tests.utilities import write_file_config

@@ -31,17 +29,18 @@ def test_blank_config():

def test_invalid_config_json():
file_config = """pipeline: [spacy_sklearn""" # invalid yaml
with tempfile.NamedTemporaryFile("w+", suffix="_tmp_config_file.json") as f:
with tempfile.NamedTemporaryFile("w+",
suffix="_tmp_config_file.json") as f:
f.write(file_config)
f.flush()
with pytest.raises(rasa_nlu.config.InvalidConfigError):
with pytest.raises(config.InvalidConfigError):
config.load(f.name)


def test_invalid_pipeline_template():
args = {"pipeline": "my_made_up_name"}
f = write_file_config(args)
with pytest.raises(InvalidConfigError) as execinfo:
with pytest.raises(config.InvalidConfigError) as execinfo:
config.load(f.name)
assert "unknown pipeline template" in str(execinfo.value)

@@ -56,7 +55,7 @@ def test_pipeline_looksup_registry():


def test_default_config_file():
final_config = RasaNLUModelConfig()
final_config = config.RasaNLUModelConfig()
assert len(final_config) > 1


@@ -68,3 +67,17 @@ def test_set_attr_on_component(default_config):

assert cfg.for_component("intent_classifier_sklearn") == expected
assert cfg.for_component("tokenizer_spacy") == {"name": "tokenizer_spacy"}


def test_override_defaults_tensorflow_embedding_pipeline():
cfg = config.load("data/test/config_embedding_test.yml")
builder = ComponentBuilder()

name1 = "intent_featurizer_count_vectors"

component1 = builder.create_component(name1, cfg)
assert component1.max_ngram == 3

name2 = "intent_classifier_tensorflow_embedding"
component2 = builder.create_component(name2, cfg)
assert component2.epochs == 10
