Commit b195105

adapted code to comply with PEP8 via `autopep8 --in-place --aggressive --aggressive <file.py>`
repodiac authored and krisgesling committed Oct 14, 2019
1 parent 275cc10 commit b195105
Showing 3 changed files with 82 additions and 28 deletions.
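For readers who want to reproduce the formatting pass, the CLI call in the commit message has a library-level equivalent. The sketch below is illustrative only and not part of the commit; it assumes `autopep8` is installed and uses one of the changed files as an example target.

    import autopep8

    # Two --aggressive flags on the CLI correspond to an aggressiveness level of 2.
    path = 'padatious/intent_container.py'  # example target; any of the changed files works
    with open(path) as f:
        source = f.read()

    fixed = autopep8.fix_code(source, options={'aggressive': 2})

    # Mirror --in-place by writing the result back to the same file.
    with open(path, 'w') as f:
        f.write(fixed)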
68 changes: 52 additions & 16 deletions padatious/intent_container.py
@@ -70,18 +70,26 @@ def clear(self):

def instantiate_from_disk(self):
"""
Instantiates the necessary (internal) data structures when loading persisted model from disk.
This is done via injecting entities and intents back from cached file versions.
"""
Instantiates the necessary (internal) data structures when loading persisted model from disk.
This is done via injecting entities and intents back from cached file versions.
"""

# ToDo: still padaos.compile (regex compilation) is redone when loading
for f in os.listdir(self.cache_dir):
if f.startswith('{') and f.endswith('}.hash'):
entity_name = f[1:f.find('}.hash')]
self.add_entity(name=entity_name, lines=[], reload_cache=False, must_train=False)
self.add_entity(
name=entity_name,
lines=[],
reload_cache=False,
must_train=False)
elif not f.startswith('{') and f.endswith('.hash'):
intent_name = f[0:f.find('.hash')]
self.add_intent(name=intent_name, lines=[], reload_cache=False, must_train=False)
self.add_intent(
name=intent_name,
lines=[],
reload_cache=False,
must_train=False)

@_save_args
def add_intent(self, name, lines, reload_cache=False, must_train=True):
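A note on the cache layout that the new `instantiate_from_disk` above relies on: entity caches are stored as `{name}.hash` (with literal braces), intent caches as `name.hash`. A minimal standalone sketch of that scan follows; the helper name is made up and is not part of padatious.

    import os

    def scan_cache(cache_dir):
        """Split cached .hash files into entity and intent names (illustrative only)."""
        entities, intents = [], []
        for f in os.listdir(cache_dir):
            if f.startswith('{') and f.endswith('}.hash'):
                entities.append(f[1:f.find('}.hash')])  # strip the braces and suffix
            elif not f.startswith('{') and f.endswith('.hash'):
                intents.append(f[:f.find('.hash')])
        return entities, intents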
@@ -110,15 +118,24 @@ def add_entity(self, name, lines, reload_cache=False, must_train=True):
name (str): The name of the entity
lines (list<str>): Lines of example extracted entities
reload_cache (bool): Whether to refresh all of cache
must_train (bool): Whether to dismiss model if present and train from scratch again
must_train (bool): Whether to dismiss model if present and train from scratch again
"""
Entity.verify_name(name)
self.entities.add(Entity.wrap_name(name), lines, reload_cache, must_train)
self.entities.add(
Entity.wrap_name(name),
lines,
reload_cache,
must_train)
self.padaos.add_entity(name, lines)
self.must_train = must_train

@_save_args
def load_entity(self, name, file_name, reload_cache=False, must_train=True):
def load_entity(
self,
name,
file_name,
reload_cache=False,
must_train=True):
"""
Loads an entity, optionally checking the cache first
@@ -140,7 +157,12 @@ def load_file(self, *args, **kwargs):
self.load_intent(*args, **kwargs)

@_save_args
def load_intent(self, name, file_name, reload_cache=False, must_train=True):
def load_intent(
self,
name,
file_name,
reload_cache=False,
must_train=True):
"""
Loads an intent, optionally checking the cache first
@@ -169,8 +191,16 @@ def remove_entity(self, name):
self.padaos.remove_entity(name)

def _train(self, *args, **kwargs):
t1 = Thread(target=self.intents.train, args=args, kwargs=kwargs, daemon=True)
t2 = Thread(target=self.entities.train, args=args, kwargs=kwargs, daemon=True)
t1 = Thread(
target=self.intents.train,
args=args,
kwargs=kwargs,
daemon=True)
t2 = Thread(
target=self.entities.train,
args=args,
kwargs=kwargs,
daemon=True)
t1.start()
t2.start()
t1.join()
@@ -221,7 +251,9 @@ def train_subprocess(self, *args, **kwargs):
'-k', json.dumps(kwargs),
])
if ret == 2:
raise TypeError('Invalid train arguments: {} {}'.format(args, kwargs))
raise TypeError(
'Invalid train arguments: {} {}'.format(
args, kwargs))
data = self.serialized_args
self.clear()
self.apply_training_args(data)
@@ -232,7 +264,8 @@ def train_subprocess(self, *args, **kwargs):
elif ret == 10: # timeout
return False
else:
raise ValueError('Training failed and returned code: {}'.format(ret))
raise ValueError(
'Training failed and returned code: {}'.format(ret))

def calc_intents(self, query):
"""
@@ -253,7 +286,8 @@ def calc_intents(self, query):
sent = tokenize(query)
for perfect_match in self.padaos.calc_intents(query):
name = perfect_match['name']
intents[name] = MatchData(name, sent, matches=perfect_match['entities'], conf=1.0)
intents[name] = MatchData(
name, sent, matches=perfect_match['entities'], conf=1.0)
return list(intents.values())

def calc_intent(self, query):
@@ -270,8 +304,10 @@ def calc_intent(self, query):
if len(matches) == 0:
return MatchData('', '')
best_match = max(matches, key=lambda x: x.conf)
best_matches = (match for match in matches if match.conf == best_match.conf)
return min(best_matches, key=lambda x: sum(map(len, x.matches.values())))
best_matches = (
match for match in matches if match.conf == best_match.conf)
return min(best_matches, key=lambda x: sum(
map(len, x.matches.values())))

def get_training_args(self):
return self.serialized_args
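Taken together, the methods reformatted above are the public surface of `IntentContainer`. A minimal usage sketch, assuming this checkout of padatious is importable; the cache path and example phrases are made up.

    from padatious import IntentContainer

    container = IntentContainer('intent_cache')
    container.add_intent('hello', ['hi there', 'hello', 'good morning'])
    container.add_intent('goodbye', ['bye', 'see you later'])
    container.train()

    match = container.calc_intent('hi there friend')
    print(match.name, match.conf, match.matches)

    # On a later run the cached model can be restored without retraining,
    # which is what the new instantiate_from_disk() above enables.
    restored = IntentContainer('intent_cache')
    restored.instantiate_from_disk()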
30 changes: 22 additions & 8 deletions padatious/training_manager.py
@@ -37,6 +37,7 @@ class TrainingManager(object):
cls (Type[Trainable]): Class to wrap
cache_dir (str): Place to store cache files
"""

def __init__(self, cls, cache_dir):
self.cls = cls
self.cache = cache_dir
@@ -47,10 +48,15 @@ def __init__(self, cls, cache_dir):

def add(self, name, lines, reload_cache=False, must_train=True):

# special case: load persisted (aka. cached) resource (i.e. entity or intent) from file into memory data structures
# special case: load persisted (aka. cached) resource (i.e.
# entity or intent) from file into memory data structures
if not must_train:
self.objects.append(self.cls.from_file(name=name, folder=self.cache))
# general case: load resource (entity or intent) to training queue or if no change occurred to memory data structures
self.objects.append(
self.cls.from_file(
name=name,
folder=self.cache))
# general case: load resource (entity or intent) to training queue
# or if no change occurred to memory data structures
else:
hash_fn = join(self.cache, name + '.hash')
old_hsh = None
@@ -62,7 +68,9 @@ def add(self, name, lines, reload_cache=False, must_train=True):
if reload_cache or old_hsh != new_hsh:
self.objects_to_train.append(self.cls(name=name, hsh=new_hsh))
else:
self.objects.append(self.cls.from_file(name=name, folder=self.cache))
self.objects.append(
self.cls.from_file(
name=name, folder=self.cache))
self.train_data.add_lines(name, lines)

def load(self, name, file_name, reload_cache=False):
@@ -71,13 +79,16 @@ def load(self, name, file_name, reload_cache=False):

def remove(self, name):
self.objects = [i for i in self.objects if i.name != name]
self.objects_to_train = [i for i in self.objects_to_train if i.name != name]
self.objects_to_train = [
i for i in self.objects_to_train if i.name != name]
self.train_data.remove_lines(name)

def train(self, debug=True, single_thread=False, timeout=20):
train = partial(
_train_and_save, cache=self.cache, data=self.train_data, print_updates=debug
)
_train_and_save,
cache=self.cache,
data=self.train_data,
print_updates=debug)

if single_thread:
for i in self.objects_to_train:
@@ -97,7 +108,10 @@ def train(self, debug=True, single_thread=False, timeout=20):
# Load saved objects from disk
for obj in self.objects_to_train:
try:
self.objects.append(self.cls.from_file(name=obj.name, folder=self.cache))
self.objects.append(
self.cls.from_file(
name=obj.name,
folder=self.cache))
except IOError:
if debug:
print('Took too long to train', obj.name)
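The branch reformatted in `add()` above is the cache-invalidation core of `TrainingManager`: an object is queued for retraining only when the cache is explicitly reloaded or the stored hash no longer matches the hash of the current example lines; otherwise it is loaded straight from the cached file. A simplified standalone restatement of that decision follows; the hashing helper here is illustrative, not padatious' own.

    import hashlib
    import os

    def needs_retraining(cache_dir, name, lines, reload_cache=False):
        """Return True when a cached model should be retrained (illustrative only)."""
        new_hsh = hashlib.sha256('\n'.join(lines).encode('utf-8')).digest()
        hash_fn = os.path.join(cache_dir, name + '.hash')
        old_hsh = None
        if os.path.isfile(hash_fn):
            with open(hash_fn, 'rb') as g:
                old_hsh = g.read()
        return reload_cache or old_hsh != new_hsh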
12 changes: 8 additions & 4 deletions tests/test_container.py
@@ -64,11 +64,11 @@ def test_instantiate_from_disk(self):
self.setup()
self.test_add_intent()
self.cont.train()

# instantiate from disk (load cached files)
self.setup()
self.cont.instantiate_from_disk()

assert len(self.cont.intents.train_data.sent_lists) == 0
assert len(self.cont.intents.objects_to_train) == 0
assert len(self.cont.intents.objects) == 2
@@ -78,7 +78,9 @@ def _create_large_intent(self, depth):
return '(a|b|)'
return '{0} {0}'.format(self._create_large_intent(depth - 1))

@pytest.mark.skipif(not os.environ.get('RUN_LONG'), reason="Takes a long time")
@pytest.mark.skipif(
not os.environ.get('RUN_LONG'),
reason="Takes a long time")
def test_train_timeout(self):
self.cont.add_intent('a', [
' '.join(random.choice('abcdefghijklmnopqrstuvwxyz') for _ in range(5))
@@ -130,7 +132,9 @@ def test_calc_intents(self):
self.cont.train(False)

intents = self.cont.calc_intents('this is another test')
assert (intents[0].conf > intents[1].conf) == (intents[0].name == 'test')
assert (
intents[0].conf > intents[1].conf) == (
intents[0].name == 'test')
assert self.cont.calc_intent('this is another test').name == 'test'

def test_empty(self):
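The `skipif` marker reformatted above gates the slow timeout test behind the RUN_LONG environment variable; the same pattern works for any long-running test. A small sketch follows (the test body here is a placeholder, not from the repository).

    import os

    import pytest

    @pytest.mark.skipif(
        not os.environ.get('RUN_LONG'),
        reason="Takes a long time")
    def test_expensive_path():
        # Only exercised when invoked as, e.g.: RUN_LONG=1 pytest tests/
        assert sum(range(10_000_000)) >= 0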
