Skip to content

Commit

Permalink
Column Based Windows (#151)
Browse files Browse the repository at this point in the history
  • Loading branch information
jeff-hernandez committed Jul 21, 2020
1 parent 26cf588 commit c2f17d1
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 4 deletions.
1 change: 1 addition & 0 deletions composeml/data_slice/extension.py
Expand Up @@ -5,6 +5,7 @@

class DataSliceContext:
"""Tracks contextual attributes about a data slice."""

def __init__(self, slice_number=0, slice_start=None, slice_stop=None, next_start=None):
"""Creates the data slice context.
Expand Down
29 changes: 25 additions & 4 deletions composeml/data_slice/generator.py
@@ -1,8 +1,9 @@
import composeml.data_slice.extension # noqa
from composeml.data_slice.extension import DataSliceContext, DataSliceFrame


class DataSliceGenerator:
"""Generates data slices for the label maker."""
"""Generates data slices for the label maker."""

def __init__(self, window_size, gap=None, min_data=None, drop_empty=True):
self.window_size = window_size
self.gap = gap
Expand All @@ -11,10 +12,30 @@ def __init__(self, window_size, gap=None, min_data=None, drop_empty=True):

def __call__(self, df):
"""Applies the data slice generator to the data frame."""
return self._slice_by_time(df)
is_column = self.window_size in df
method = 'column' if is_column else 'time'
attr = f'_slice_by_{method}'
return getattr(self, attr)(df)

def _slice_by_column(self, df):
    """Yields one data slice for each group of the window-size column.

    The data frame is grouped by the column named by ``self.window_size``
    with ``sort=False``. Each group is wrapped in a ``DataSliceFrame``
    whose context records a slice number counted from 1 and the first and
    last valid index of the slice. The group key is stored on the context
    under the column's name, and the context's ``next_start`` attribute
    is removed.
    """
    grouped = df.groupby(self.window_size, sort=False)

    for number, (key, frame) in enumerate(grouped, start=1):
        data_slice = DataSliceFrame(frame)
        start = data_slice.first_valid_index()
        stop = data_slice.last_valid_index()
        data_slice.context = DataSliceContext(
            slice_number=number,
            slice_start=start,
            slice_stop=stop,
        )

        # Expose the group key on the context under the column's own name.
        setattr(data_slice.context, self.window_size, key)
        del data_slice.context.next_start
        yield data_slice

def _slice_by_time(self, df):
"""Slices data along the time index."""
"""Slices the data frame along the time index."""
data_slices = df.slice(
size=self.window_size,
start=self.min_data,
Expand Down
1 change: 1 addition & 0 deletions composeml/data_slice/offset.py
Expand Up @@ -5,6 +5,7 @@

class DataSliceOffset:
"""Offsets for calculating data slice indices."""

def __init__(self, value):
self.value = value
self._check()
Expand Down
1 change: 1 addition & 0 deletions composeml/label_times/object.py
Expand Up @@ -12,6 +12,7 @@

class LabelTimes(pd.DataFrame):
"""The data frame that contains labels and cutoff times for the target entity."""

def __init__(
self,
data=None,
Expand Down
27 changes: 27 additions & 0 deletions composeml/tests/test_label_maker.py
Expand Up @@ -409,6 +409,33 @@ def test_search_invalid_n_examples(transactions, total_spent_fn):
lm.search(transactions, num_examples_per_instance=2)


def test_column_based_windows(transactions, total_spent_fn):
    """Checks a label search whose window size is a column name."""
    sessions = [1, 2, 3, 3, 4, 5, 5, 5, 6, 7]
    data = transactions.assign(session_id=sessions)

    label_maker = LabelMaker(
        target_entity='customer_id',
        time_index='time',
        window_size='session_id',
        labeling_function=total_spent_fn,
    )

    label_times = label_maker.search(data, -1)
    actual = label_times.pipe(to_csv, index=False)

    header = 'customer_id,time,total_spent'
    rows = [
        '0,2019-01-01 08:00:00,1',
        '0,2019-01-01 08:30:00,1',
        '1,2019-01-01 09:00:00,2',
        '1,2019-01-01 10:00:00,1',
        '2,2019-01-01 10:30:00,3',
        '2,2019-01-01 12:00:00,1',
        '3,2019-01-01 12:30:00,1',
    ]
    expected = [header, *rows]

    assert actual == expected


def test_search_with_invalid_index(transactions, total_spent_fn):
lm = LabelMaker(
target_entity='customer_id',
Expand Down

0 comments on commit c2f17d1

Please sign in to comment.