Skip to content

Commit

Permalink
Add train_subprocess for more stable training
Browse files Browse the repository at this point in the history
  • Loading branch information
MatthewScholefield authored and forslund committed Mar 13, 2019
1 parent c865e29 commit e9c1f29
Show file tree
Hide file tree
Showing 4 changed files with 172 additions and 8 deletions.
58 changes: 58 additions & 0 deletions padatious/__main__.py
@@ -0,0 +1,58 @@
import inspect
import json
from os.path import basename, splitext

from argparse import ArgumentParser

from padatious import IntentContainer


def train_setup(parser):
parser.add_argument('intent_cache', help='Folder to write trained intents to')
parser.add_argument('input_files', nargs='*', help='Input .intent and .entity files')
parser.add_argument('-d', '--data', help='Serialized training args', type=json.loads)
parser.add_argument('-s', '--single-thread', help='Run training in a single thread')
parser.add_argument('-f', '--force', help='Force retraining if already trained')
parser.add_argument('-a', '--args', help='Extra args (list) for function', type=json.loads)
parser.add_argument('-k', '--kwargs', help='Extra kwargs (json) for function', type=json.loads)


def train(parser, args):
if bool(args.input_files) == bool(args.data):
parser.error('You must specify one of input_files or --data (but not both)')

cont = IntentContainer(args.intent_cache)
if args.data:
cont.apply_training_args(args.data)
else:
for fn in args.input_files:
obj_name, ext = splitext(basename(fn))
if ext == '.intent':
cont.load_intent(obj_name, fn)
elif ext == '.entity':
cont.load_entity(obj_name, fn)
else:
parser.error('Unknown file extension: {}'.format(ext))
kwargs = inspect.signature(cont.train).bind(*(args.args or [])).arguments
kwargs.update(args.kwargs or {})
kwargs.setdefault('debug', True)
kwargs.setdefault('single_thread', args.single_thread)
kwargs.setdefault('force', args.force)
if cont.train(**kwargs):
return 0
return 10 # timeout


def main():
parser = ArgumentParser(description='Tool to interact with padatious via command line')
p = parser.add_subparsers(dest='action')
p.required = True
train_setup(p.add_parser('train'))

args = parser.parse_args()
if args.action == 'train':
exit(train(parser, args))


if __name__ == '__main__':
main()
90 changes: 83 additions & 7 deletions padatious/intent_container.py
Expand Up @@ -11,9 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import json
import os

import padaos
import sys
from functools import wraps
from subprocess import call, check_output
from threading import Thread

from padatious.match_data import MatchData
Expand All @@ -23,21 +28,47 @@
from padatious.util import tokenize


def _save_args(func):
@wraps(func)
def wrapper(*args, **kwargs):
func(*args, **kwargs)
bound_args = inspect.signature(func).bind(*args, **kwargs)
bound_args.apply_defaults()
kwargs = bound_args.arguments
kwargs['__name__'] = func.__name__
kwargs.pop('self').serialized_args.append(kwargs)

return wrapper


class IntentContainer(object):
"""
Creates an IntentContainer object used to load and match intents
Args:
cache_dir (str): Place to put all saved neural networks
"""

def __init__(self, cache_dir):
os.makedirs(cache_dir, exist_ok=True)
self.cache_dir = cache_dir
self.must_train = False
self.intents = IntentManager(cache_dir)
self.entities = EntityManager(cache_dir)
self.padaos = padaos.IntentContainer()
self.train_thread = None # type: Thread
self.serialized_args = [] # Arguments of all calls to register intents/entities

def clear(self):
os.makedirs(self.cache_dir, exist_ok=True)
self.must_train = False
self.intents = IntentManager(self.cache_dir)
self.entities = EntityManager(self.cache_dir)
self.padaos = padaos.IntentContainer()
self.train_thread = None
self.serialized_args = []

@_save_args
def add_intent(self, name, lines, reload_cache=False):
"""
Creates a new intent, optionally checking the cache first
Expand All @@ -51,6 +82,7 @@ def add_intent(self, name, lines, reload_cache=False):
self.padaos.add_intent(name, lines)
self.must_train = True

@_save_args
def add_entity(self, name, lines, reload_cache=False):
"""
Adds an entity that matches the given lines.
Expand All @@ -69,6 +101,7 @@ def add_entity(self, name, lines, reload_cache=False):
self.padaos.add_entity(name, lines)
self.must_train = True

@_save_args
def load_entity(self, name, file_name, reload_cache=False):
"""
Loads an entity, optionally checking the cache first
Expand All @@ -83,10 +116,12 @@ def load_entity(self, name, file_name, reload_cache=False):
with open(file_name) as f:
self.padaos.add_entity(name, f.read().split('\n'))

@_save_args
def load_file(self, *args, **kwargs):
"""Legacy. Use load_intent instead"""
self.load_intent(*args, **kwargs)

@_save_args
def load_intent(self, name, file_name, reload_cache=False):
"""
Loads an intent, optionally checking the cache first
Expand All @@ -100,52 +135,85 @@ def load_intent(self, name, file_name, reload_cache=False):
with open(file_name) as f:
self.padaos.add_intent(name, f.read().split('\n'))

@_save_args
def remove_intent(self, name):
"""Unload an intent"""
self.intents.remove(name)
self.padaos.remove_intent(name)
self.must_train = True

@_save_args
def remove_entity(self, name):
"""Unload an entity"""
self.entities.remove(name)
self.padaos.remove_entity(name)

def _train(self, *args, **kwargs):
t1 = Thread(target=self.intents.train, args=args, kwargs=kwargs, daemon=True)
t2 = Thread(target=self.entities.train, args=args, kwargs=kwargs, daemon=True)
t2 = Thread(target=self.entities.train, args=args, kwargs=kwargs, daemon=True)
t1.start()
t2.start()
t1.join()
t2.join()
self.entities.calc_ent_dict()

def train(self, *args, force=False, **kwargs):
def train(self, debug=True, force=False, single_thread=False, timeout=20):
"""
Trains all the loaded intents that need to be updated
If a cache file exists with the same hash as the intent file,
the intent will not be trained and just loaded from file
Args:
print_updates (bool): Whether to print a message to stdout
each time a new intent is trained
debug (bool): Whether to print a message to stdout each time a new intent is trained
force (bool): Whether to force training if already finished
single_thread (bool): Whether to force running in a single thread
timeout (float): Seconds before cancelling training
Returns:
bool: True if training succeeded without timeout
"""
if not self.must_train and not force:
return
self.padaos.compile()

timeout = kwargs.setdefault('timeout', 20)
self.train_thread = Thread(target=self._train, args=args, kwargs=kwargs, daemon=True)
self.train_thread = Thread(target=self._train, kwargs=dict(
debug=debug,
single_thread=single_thread,
timeout=timeout
), daemon=True)
self.train_thread.start()
self.train_thread.join(timeout)

self.must_train = False
return not self.train_thread.is_alive()

def train_subprocess(self, *args, **kwargs):
"""
Trains in a subprocess which provides a timeout guarantees everything shuts down properly
Args:
See <train>
Returns:
bool: True for success, False if timed out
"""
ret = call([
sys.executable, '-m', 'padatious', 'train', self.cache_dir,
'-d', json.dumps(self.serialized_args),
'-a', json.dumps(args),
'-k', json.dumps(kwargs),
])
if ret == 2:
raise TypeError('Invalid train arguments: {} {}'.format(args, kwargs))
data = self.serialized_args
self.clear()
self.apply_training_args(data)
self.padaos.compile()
if ret == 0:
self.must_train = False
return True
elif ret == 10: # timeout
return False
else:
raise ValueError('Training failed and returned code: {}'.format(ret))

def calc_intents(self, query):
"""
Tests all the intents against the query and returns
Expand Down Expand Up @@ -184,3 +252,11 @@ def calc_intent(self, query):
best_match = max(matches, key=lambda x: x.conf)
best_matches = (match for match in matches if match.conf == best_match.conf)
return min(best_matches, key=lambda x: sum(map(len, x.matches.values())))

def get_training_args(self):
return self.serialized_args

def apply_training_args(self, data):
for params in data:
func_name = params.pop('__name__')
getattr(self, func_name)(**params)
5 changes: 5 additions & 0 deletions setup.py
Expand Up @@ -35,5 +35,10 @@
'Programming Language :: Python :: 3.5',
'Programming Language :: Python :: 3.6',
],
entry_points={
'console_scripts': [
'padatious=padatious.__main__:main'
]
},
keywords='intent-parser parser text text-processing',
)
27 changes: 26 additions & 1 deletion tests/test_container.py
Expand Up @@ -85,6 +85,32 @@ def test_train_timeout(self):
b = monotonic()
assert b - a <= 0.1

def test_train_timeout_subprocess(self):
self.cont.add_intent('a', [
' '.join(random.choice('abcdefghijklmnopqrstuvwxyz') for _ in range(5))
for __ in range(300)
])
self.cont.add_intent('b', [
' '.join(random.choice('abcdefghijklmnopqrstuvwxyz') for _ in range(5))
for __ in range(300)
])
a = monotonic()
assert not self.cont.train_subprocess(timeout=0.1)
b = monotonic()
assert b - a <= 1

def test_train_subprocess(self):
self.cont.add_intent('timer', [
'set a timer for {time} minutes',
])
self.cont.add_entity('time', [
'#', '##', '#:##', '##:##'
])
assert self.cont.train_subprocess(False, timeout=20)
intent = self.cont.calc_intent('set timer for 3 minutes')
assert intent.name == 'timer'
assert intent.matches == {'time': '3'}

def test_calc_intents(self):
self.test_add_intent()
self.cont.train(False)
Expand Down Expand Up @@ -156,7 +182,6 @@ def test_generalize(self):
])
self.cont.train(False)
intent = self.cont.calc_intent('make a timer for 3 minute')
print(intent)
assert intent.name == 'timer'
assert intent.matches == {'time': '3'}

Expand Down

0 comments on commit e9c1f29

Please sign in to comment.