Skip to content

Commit

Permalink
train crf on all data
Browse files (browse the repository at this point in the history)
  • Loading branch information
tmbo committed Feb 7, 2018
1 parent ebe244f commit abae509
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 9 deletions.
5 changes: 4 additions & 1 deletion rasa_nlu/extractors/crf_entity_extractor.py
Expand Up @@ -95,9 +95,12 @@ def train(self, training_data, config, **kwargs):
self.L1_C = config.get("L1_c", 1)
self.L2_C = config.get("L2_c", 1e-3)

# checks whether there is at least one example with an entity annotation
if training_data.entity_examples:
# convert the dataset into features
dataset = self._create_dataset(training_data.entity_examples)
# this will train on ALL examples, even the ones
# without annotations
dataset = self._create_dataset(training_data.training_examples)
# train the model
self._train_model(dataset)

Expand Down
18 changes: 12 additions & 6 deletions rasa_nlu/training_data/formats/rasa.py
Expand Up @@ -4,11 +4,12 @@
from __future__ import unicode_literals

import logging

from collections import defaultdict

from rasa_nlu.training_data import Message, TrainingData
from rasa_nlu.training_data.formats.readerwriter import JsonTrainingDataReader, TrainingDataWriter
from rasa_nlu.training_data.formats.readerwriter import (
JsonTrainingDataReader,
TrainingDataWriter)
from rasa_nlu.training_data.util import transform_entity_synonyms
from rasa_nlu.utils import json_to_string

Expand All @@ -30,19 +31,23 @@ def read_from_json(self, js, **kwargs):
entity_synonyms = transform_entity_synonyms(entity_synonyms)

if intent_examples or entity_examples:
logger.warn("DEPRECATION warning: your rasa data contains 'intent_examples' "
logger.warn("DEPRECATION warning: your rasa data "
"contains 'intent_examples' "
"or 'entity_examples' which will be "
"removed in the future. Consider putting all your examples "
"removed in the future. Consider "
"putting all your examples "
"into the 'common_examples' section.")

all_examples = common_examples + intent_examples + entity_examples
training_examples = []
for ex in all_examples:
msg = Message.build(ex['text'], ex.get("intent"), ex.get("entities"))
msg = Message.build(ex['text'], ex.get("intent"),
ex.get("entities"))
training_examples.append(msg)

return TrainingData(training_examples, entity_synonyms, regex_features)


class RasaWriter(TrainingDataWriter):
def dumps(self, training_data, **kwargs):
"""Writes Training Data to a string in json format."""
Expand Down Expand Up @@ -82,6 +87,7 @@ def validate_rasa_nlu_data(data):
"https://rasahq.github.io/rasa_nlu/dataformat.html")
raise e


def _rasa_nlu_data_schema():
training_example_schema = {
"type": "object",
Expand Down Expand Up @@ -139,4 +145,4 @@ def _rasa_nlu_data_schema():
}
},
"additionalProperties": False
}
}
7 changes: 5 additions & 2 deletions rasa_nlu/training_data/message.py
Expand Up @@ -13,7 +13,11 @@ def __init__(self, text, data=None, output_properties=None, time=None):
self.text = text
self.time = time
self.data = data if data else {}
self.output_properties = output_properties if output_properties else set()

if output_properties:
self.output_properties = output_properties
else:
self.output_properties = set()

def set(self, prop, info, add_to_output=False):
self.data[prop] = info
Expand Down Expand Up @@ -42,7 +46,6 @@ def __eq__(self, other):
def __hash__(self):
return hash((self.text, str(ordered(self.data))))


@classmethod
def build(cls, text, intent=None, entities=None):
data = {}
Expand Down

0 comments on commit abae509

Please sign in to comment.