Skip to content

Commit

Permalink
Varying first cutoff time for each target group (#258)
Browse files Browse the repository at this point in the history
* implement and test varying minimum data per group

* update release notes

* pin scikit-learn for doc builds

* update docstring

* add guide for controlling cutoff times

* fix dfs test

* add guide to index

* Revert "fix dfs test"

This reverts commit 584a5cb.

* pin version of featuretools

* update docstring

* update test case

* update docstring

* lint fix

* parametrize test

* lint fix
  • Loading branch information
jeff-hernandez committed Nov 2, 2021
1 parent f396057 commit 723c581
Show file tree
Hide file tree
Showing 8 changed files with 358 additions and 97 deletions.
190 changes: 94 additions & 96 deletions composeml/label_maker.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from sys import stdout

from pandas import Series
from tqdm import tqdm

from composeml.data_slice import DataSliceGenerator
Expand Down Expand Up @@ -62,13 +63,22 @@ def labeling_function(self, value):
assert isinstance(value, dict), 'value type for labeling function not supported'
self._labeling_function = value

def _check_cutoff_time(self, value):
if isinstance(value, Series):
if value.index.is_unique: return value.to_dict()
else: raise ValueError('more than one cutoff time exists for a target group')
else: return value

def slice(self, df, num_examples_per_instance, minimum_data=None, maximum_data=None, gap=None, drop_empty=True):
"""Generates data slices of target entity.
Args:
df (DataFrame): Data frame to create slices on.
num_examples_per_instance (int): Number of examples per unique instance of target entity.
minimum_data (str): Minimum data before starting the search. Default value is first time of index.
minimum_data (int or str or Series): The amount of data needed before starting the search. Defaults to the first value in the time index.
The value can be a datetime string to directly set the first cutoff time or a timedelta string to denote the amount of data needed before
the first cutoff time. The value can also be an integer to denote the number of rows needed before the first cutoff time.
If a Series, minimum_data should be datetime string, timedelta string, or integer values with a unique set of target groups as the corresponding index.
maximum_data (str): Maximum data before stopping the search. Default value is last time of index.
gap (str or int): Time between examples. Default value is window size.
If an integer, search will start on the first event after the minimum data.
Expand All @@ -79,24 +89,32 @@ def slice(self, df, num_examples_per_instance, minimum_data=None, maximum_data=N
"""
self._check_example_count(num_examples_per_instance, gap)
df = self.set_index(df)
entity_groups = df.groupby(self.target_entity)
target_groups = df.groupby(self.target_entity)
num_examples_per_instance = ExampleSearch._check_number(num_examples_per_instance)

generator = DataSliceGenerator(
window_size=self.window_size,
min_data=minimum_data,
max_data=maximum_data,
drop_empty=drop_empty,
gap=gap,
)
minimum_data = self._check_cutoff_time(minimum_data)
minimum_data_varies = isinstance(minimum_data, dict)

for group_key, df in target_groups:
if minimum_data_varies:
if group_key not in minimum_data: continue
min_data_for_group = minimum_data[group_key]
else:
min_data_for_group = minimum_data

generator = DataSliceGenerator(
window_size=self.window_size,
min_data=min_data_for_group,
max_data=maximum_data,
drop_empty=drop_empty,
gap=gap,
)

for entity_id, df in entity_groups:
for ds in generator(df):
setattr(ds.context, self.target_entity, entity_id)
setattr(ds.context, self.target_entity, group_key)
yield ds

if ds.context.slice_number >= num_examples_per_instance:
break
if ds.context.slice_number >= num_examples_per_instance: break

@property
def _bar_format(self):
Expand All @@ -107,72 +125,6 @@ def _bar_format(self):
value += self.target_entity + ": {n}/{total} "
return value

def _run_search(
self,
df,
generator,
search,
verbose=True,
*args,
**kwargs,
):
"""Search implementation to make label records.
Args:
df (DataFrame): Data frame to search and extract labels.
generator (DataSliceGenerator): The generator for data slices.
search (LabelSearch or ExampleSearch): The type of search to be done.
verbose (bool): Whether to render progress bar. Default value is True.
*args: Positional arguments for labeling function.
**kwargs: Keyword arguments for labeling function.
Returns:
records (list(dict)): Label Records
"""
df = self.set_index(df)
entity_groups = df.groupby(self.target_entity)
multiplier = search.expected_count if search.is_finite else 1
total = entity_groups.ngroups * multiplier

progress_bar, records = tqdm(
total=total,
bar_format=self._bar_format,
disable=not verbose,
file=stdout,
), []

def missing_examples(entity_count):
return entity_count * search.expected_count - progress_bar.n

for entity_count, (entity_id, df) in enumerate(entity_groups):
for ds in generator(df):
items = self.labeling_function.items()
labels = {name: lf(ds, *args, **kwargs) for name, lf in items}
valid_labels = search.is_valid_labels(labels)
if not valid_labels: continue

records.append({
self.target_entity: entity_id,
'time': ds.context.slice_start,
**labels,
})

search.update_count(labels)
# if finite search, progress bar is updated for each example found
if search.is_finite: progress_bar.update(n=1)
if search.is_complete: break

# if finite search, progress bar is updated for examples not found
# otherwise, progress bar is updated for each entity group
n = missing_examples(entity_count + 1) if search.is_finite else 1
progress_bar.update(n=n)
search.reset_count()

total -= progress_bar.n
progress_bar.update(n=total)
progress_bar.close()
return records

def _check_example_count(self, num_examples_per_instance, gap):
"""Checks whether example count corresponds to data slices."""
if self.window_size is None and gap is None:
Expand All @@ -195,8 +147,11 @@ def search(self,
df (DataFrame): Data frame to search and extract labels.
num_examples_per_instance (int or dict): The expected number of examples to return from each entity group.
A dictionary can be used to further specify the expected number of examples to return from each label.
minimum_data (str): Minimum data before starting the search. Default value is first time of index.
maximum_data (str): Maximum data before stopping the search. Default value is last time of index.
minimum_data (int or str or Series): The amount of data needed before starting the search. Defaults to the first value in the time index.
The value can be a datetime string to directly set the first cutoff time or a timedelta string to denote the amount of data needed before
the first cutoff time. The value can also be an integer to denote the number of rows needed before the first cutoff time.
If a Series, minimum_data should be datetime string, timedelta string, or integer values with a unique set of target groups as the corresponding index.
maximum_data (str): Maximum data before stopping the search. Defaults to the last value in the time index.
gap (str or int): Time between examples. Default value is window size.
If an integer, search will start on the first event after the minimum data.
drop_empty (bool): Whether to drop empty slices. Default value is True.
Expand All @@ -212,30 +167,73 @@ def search(self,
is_label_search = isinstance(num_examples_per_instance, dict)
search = (LabelSearch if is_label_search else ExampleSearch)(num_examples_per_instance)

generator = DataSliceGenerator(
window_size=self.window_size,
min_data=minimum_data,
max_data=maximum_data,
drop_empty=drop_empty,
gap=gap,
)
# check minimum data cutoff time
minimum_data = self._check_cutoff_time(minimum_data)
minimum_data_varies = isinstance(minimum_data, dict)

df = self.set_index(df)
total = search.expected_count if search.is_finite else 1
target_groups = df.groupby(self.target_entity)
total *= target_groups.ngroups

records = self._run_search(
df=df,
generator=generator,
search=search,
verbose=verbose,
*args,
**kwargs,
progress_bar = tqdm(
total=total,
file=stdout,
disable=not verbose,
bar_format=self._bar_format,
)

records = []
for group_count, (group_key, df) in enumerate(target_groups, start=1):
if minimum_data_varies:
if group_key not in minimum_data: continue
min_data_for_group = minimum_data[group_key]
else:
min_data_for_group = minimum_data

generator = DataSliceGenerator(
window_size=self.window_size,
min_data=min_data_for_group,
max_data=maximum_data,
drop_empty=drop_empty,
gap=gap,
)

for ds in generator(df):
setattr(ds.context, self.target_entity, group_key)

items = self.labeling_function.items()
labels = {name: lf(ds, *args, **kwargs) for name, lf in items}
valid_labels = search.is_valid_labels(labels)
if not valid_labels: continue

records.append({
self.target_entity: group_key,
'time': ds.context.slice_start,
**labels,
})

search.update_count(labels)
# if finite search, update progress bar for the example found
if search.is_finite: progress_bar.update(n=1)
if search.is_complete: break

# if finite search, update progress bar for missing examples
if search.is_finite: progress_bar.update(n=group_count * search.expected_count - progress_bar.n)
else: progress_bar.update(n=1) # otherwise, update progress bar once for each group
search.reset_count()

total -= progress_bar.n
progress_bar.update(n=total)
progress_bar.close()

lt = LabelTimes(
data=records,
target_columns=list(self.labeling_function),
target_entity=self.target_entity,
search_settings={
'num_examples_per_instance': num_examples_per_instance,
'minimum_data': str(minimum_data),
'minimum_data': minimum_data,
'maximum_data': str(maximum_data),
'window_size': str(self.window_size),
'gap': str(gap),
Expand Down
7 changes: 7 additions & 0 deletions composeml/tests/test_data_slice/test_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,10 @@ def test_time_index_error(transactions):
match = 'offset by frequency requires a time index'
with raises(AssertionError, match=match):
transactions.slice[::'1h']


def test_minimum_data_per_group(transactions):
lm = LabelMaker('customer_id', labeling_function=len, time_index='time', window_size='1h')
minimum_data = {1: '2019-01-01 09:00:00', 3: '2019-01-01 12:00:00'}
lengths = [len(ds) for ds in lm.slice(transactions, 1, minimum_data=minimum_data)]
assert lengths == [2, 1]
26 changes: 26 additions & 0 deletions composeml/tests/test_label_maker.py
Original file line number Diff line number Diff line change
Expand Up @@ -546,3 +546,29 @@ def test_search_with_maximum_data(transactions):

actual = lt.pipe(to_csv, index=False)
assert actual == expected


@pytest.mark.parametrize("minimum_data", [{1: '2019-01-01 09:30:00', 2: '2019-01-01 11:30:00'}, {1: '30min', 2: '1h'}, {1: 1, 2: 2}])
def test_minimum_data_per_group(transactions, minimum_data):
lm = LabelMaker('customer_id', labeling_function=len, time_index='time', window_size='1h')
for supported_type in [minimum_data, pd.Series(minimum_data)]:
lt = lm.search(transactions, 1, minimum_data=supported_type)
actual = to_csv(lt, index=False)

expected = [
'customer_id,time,len',
'1,2019-01-01 09:30:00,2',
'2,2019-01-01 11:30:00,2'
]

assert actual == expected


def test_minimum_data_per_group_error(transactions):
lm = LabelMaker('customer_id', labeling_function=len, time_index='time', window_size='1h')
data = ['2019-01-01 09:00:00', '2019-01-01 12:00:00']
minimum_data = pd.Series(data=data, index=[1, 1])
match = "more than one cutoff time exists for a target group"

with pytest.raises(ValueError, match=match):
lm.search(transactions, 1, minimum_data=minimum_data)
1 change: 1 addition & 0 deletions dev-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ jupyter==1.0.0
nbsphinx==0.8.6
pydata-sphinx-theme==0.6.3
evalml==0.28.0
scikit-learn>=0.24.0,<1.0
1 change: 1 addition & 0 deletions docs/source/release_notes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ Release Notes
* Enhancements
* Add ``maximum_data`` parameter to control when a search should stop (:pr:`216`)
* Add optional automatic update checker (:pr:`223`, :pr:`229`, :pr:`232`)
* Varying first cutoff time for each target group (:pr:`258`)
* Fixes
* Documentation Changes
* Update doc tutorials to the latest API changes (:pr:`227`)
Expand Down
1 change: 1 addition & 0 deletions docs/source/user_guide.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ Use these guides to learn how to use label transformations and generate better t
:glob:
:maxdepth: 1

user_guide/controlling_cutoff_times
user_guide/using_label_transforms
user_guide/data_slice_generator

Expand Down
Loading

0 comments on commit 723c581

Please sign in to comment.