Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

test calculating features with instances that aren't in data #559

Merged
merged 3 commits into from May 22, 2019
Merged
Changes from 1 commit
Commits
File filter...
Filter file types
Jump to…
Jump to file or symbol
Failed to load files and symbols.

Always

Just for now

@@ -1170,3 +1170,66 @@ def test_no_data_for_cutoff_time():
# due to default values for each primitive
# count will be 0, but max will be nan
np.testing.assert_array_equal(fm.values, [[0, np.nan]])


def test_instances_not_in_data(es):
    """Check that instance ids with no rows in the data yield all-null features.

    Builds an identity feature, a comparison feature derived from it, and a
    direct feature over a ``Max`` aggregation, then computes the feature
    matrix for ids 20-29 twice: once normally and once with
    ``approximate="2 years"``. In both runs the matrix index must equal the
    requested ids and every column must be entirely null.

    Args:
        es: entityset fixture providing the ``log`` and ``sessions`` entities.
    """
    # Ids 20-29 are presumed to be absent from the fixture's 'log' entity.
    # Review note (CJStadler, May 22, 2019): to ensure that these are not in
    # the data (and make it explicit to the reader) you could do something
    # like take the max of the actual ids and then pick 10 higher numbers.
    instances = list(range(20, 30))

    identity_feature = IdentityFeature(es['log']['value'])
    property_feature = identity_feature > 10
    agg_feat = AggregationFeature(es['log']['value'],
                                  parent_entity=es["sessions"],
                                  primitive=Max)
    direct_feature = DirectFeature(agg_feat, es["log"])
    features = [identity_feature, property_feature, direct_feature]

    fm = calculate_feature_matrix(features, entityset=es, instance_ids=instances)
    # assert_array_equal reports a readable diff (including on a length
    # mismatch) instead of relying on an element-wise ``==`` wrapped in all().
    np.testing.assert_array_equal(fm.index.values, instances)
    for column in fm.columns:
        assert fm[column].isnull().all()

    # Same expectations when the approximate code path is used.
    fm = calculate_feature_matrix(features,
                                  entityset=es,
                                  instance_ids=instances,
                                  approximate="2 years")
    np.testing.assert_array_equal(fm.index.values, instances)
    for column in fm.columns:
        assert fm[column].isnull().all()


def test_some_instances_not_in_data(es):
    """Check feature values when only some requested instances have data.

    Ids 12-21 are paired with three cutoff times (annotated below as having
    only valid data, some missing data, and all missing data). The feature
    matrix is computed from that cutoff-time frame, first exactly and then
    with ``approximate="5 seconds"``, and each column is compared against a
    hand-written expected answer.

    Args:
        es: entityset fixture providing the ``log`` and ``sessions`` entities.
    """
    a_time = datetime(2011, 4, 10, 10, 41, 9)  # only valid data
    b_time = datetime(2011, 4, 10, 11, 10, 5)  # some missing data
    c_time = datetime(2011, 4, 10, 12, 0, 0)  # all missing data
    times = [a_time, b_time, a_time, a_time, b_time, b_time] + [c_time] * 4
    cutoff_time = pd.DataFrame({"instance_id": list(range(12, 22)),
                                "time": times})

    identity_feature = IdentityFeature(es['log']['value'])
    property_feature = identity_feature > 10
    agg_feat = AggregationFeature(es['log']['value'],
                                  parent_entity=es["sessions"],
                                  primitive=Max)
    direct_feature = DirectFeature(agg_feat, es["log"])
    features = [identity_feature, property_feature, direct_feature]

    fm = calculate_feature_matrix(features,
                                  entityset=es,
                                  cutoff_time=cutoff_time)

    # The expected index groups the ids by their cutoff time:
    # a_time (12, 14, 15), then b_time (13, 16, 17), then c_time (18-21).
    index_answer = [12, 14, 15, 13, 16, 17, 18, 19, 20, 21]
    ifeat_answer = [0, 14, np.nan, 7] + [np.nan] * 6
    prop_answer = [0, 1, np.nan, 0, 0] + [np.nan] * 5
    dfeat_answer = [14, 14, np.nan, 14] + [np.nan] * 6

    # assert_array_equal reports a readable diff (including on a length
    # mismatch) instead of relying on an element-wise ``==`` wrapped in all();
    # it is also what the value checks below already use.
    np.testing.assert_array_equal(fm.index.values, index_answer)
    for x, y in zip(fm.columns, [ifeat_answer, prop_answer, dfeat_answer]):
        np.testing.assert_array_equal(fm[x], y)

    fm = calculate_feature_matrix(features,
                                  entityset=es,
                                  cutoff_time=cutoff_time,
                                  approximate="5 seconds")

    # Only these expectations differ on the approximate path.
    dfeat_answer[0:2] = [7, 7]  # approximate calculated before 14 appears
    prop_answer[2] = 0  # no_unapproximated_aggs code ignores cutoff time

    np.testing.assert_array_equal(fm.index.values, index_answer)
    for x, y in zip(fm.columns, [ifeat_answer, prop_answer, dfeat_answer]):
        np.testing.assert_array_equal(fm[x], y)
ProTip! Use n and p to navigate between commits in a pull request.
You can’t perform that action at this time.