Skip to content

Commit

Permalink
Equal and NotEqual fix (#968)
Browse files Browse the repository at this point in the history
* switch from numpy to pandas to avoid error

* dask compatibility

* changelog

* lint

* fix tests
  • Loading branch information
frances-h committed May 21, 2020
1 parent d8843c0 commit 4f270d5
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 3 deletions.
3 changes: 2 additions & 1 deletion docs/source/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ Changelog
* Add ``get_default_aggregation_primitives`` and ``get_default_transform_primitives`` (:pr:`945`)
* Allow cutoff time dataframe columns to be in any order (:pr:`969`)
* Fixes
* Fix errors with ``Equal`` and ``NotEqual`` primitives when comparing categoricals or different dtypes (:pr:`968`)
* Normalized type_strings of ``Variable`` classes so that the ``find_variable_types`` function produces a
dictionary with a clear key to name transition (:pr:`982`)
* Changes
Expand All @@ -17,7 +18,7 @@ Changelog
* Update testing dependencies (:pr:`976`)

Thanks to the following people for contributing to this release:
:user:`gsheni`, :user:`rwedge`, :user:`thehomebrewnerd`, :user:`sebrahimi1988`, :user:`tuethan1999`
:user:`gsheni`, :user:`rwedge`, :user:`thehomebrewnerd`, :user:`sebrahimi1988`, :user:`tuethan1999`, :user:`frances-h`

**Breaking Changes**

Expand Down
20 changes: 18 additions & 2 deletions featuretools/primitives/standard/binary_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,15 @@ class Equal(TransformPrimitive):
commutative = True

def get_function(self):
    """Return the element-wise equality function for two columns.

    The returned callable takes two pandas-style Series (objects exposing
    ``dtype``, ``cat`` and ``eq``) and returns the result of
    ``x_vals.eq(y_vals)``: True wherever the paired values are equal.
    """
    def equal(x_vals, y_vals):
        # pandas raises when comparing categoricals whose category sets
        # differ, so when both inputs are categorical give each side the
        # categories it is missing before comparing.
        if isinstance(x_vals.dtype, pd.CategoricalDtype) and \
                isinstance(y_vals.dtype, pd.CategoricalDtype):
            x_cats = x_vals.cat.categories
            y_cats = y_vals.cat.categories
            x_known = set(x_cats)
            y_known = set(y_cats)
            # Lists (rather than sets) keep the appended categories in a
            # stable, reproducible order.
            x_vals = x_vals.cat.add_categories(
                [cat for cat in y_cats if cat not in x_known])
            y_vals = y_vals.cat.add_categories(
                [cat for cat in x_cats if cat not in y_known])
        # Series.eq also handles mismatched dtypes without raising.
        return x_vals.eq(y_vals)

    return equal

def generate_name(self, base_feature_names):
return "%s = %s" % (base_feature_names[0], base_feature_names[1])
Expand Down Expand Up @@ -301,7 +309,15 @@ class NotEqual(TransformPrimitive):
commutative = True

def get_function(self):
    """Return the element-wise inequality function for two columns.

    The returned callable takes two pandas-style Series (objects exposing
    ``dtype``, ``cat`` and ``ne``) and returns the result of
    ``x_vals.ne(y_vals)``: True wherever the paired values differ.
    """
    def not_equal(x_vals, y_vals):
        # pandas raises when comparing categoricals whose category sets
        # differ, so when both inputs are categorical give each side the
        # categories it is missing before comparing.
        if isinstance(x_vals.dtype, pd.CategoricalDtype) and \
                isinstance(y_vals.dtype, pd.CategoricalDtype):
            x_cats = x_vals.cat.categories
            y_cats = y_vals.cat.categories
            x_known = set(x_cats)
            y_known = set(y_cats)
            # Lists (rather than sets) keep the appended categories in a
            # stable, reproducible order.
            x_vals = x_vals.cat.add_categories(
                [cat for cat in y_cats if cat not in x_known])
            y_vals = y_vals.cat.add_categories(
                [cat for cat in x_cats if cat not in y_known])
        # Series.ne also handles mismatched dtypes without raising.
        return x_vals.ne(y_vals)

    return not_equal

def generate_name(self, base_feature_names):
return "%s != %s" % (base_feature_names[0], base_feature_names[1])
Expand Down
67 changes: 67 additions & 0 deletions featuretools/tests/primitive_tests/test_transform_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,73 @@ def test_make_trans_feat(es):
assert v == 10


@pytest.fixture
def simple_es():
    """Build a single-entity EntitySet used by the Equal/NotEqual tests.

    The ``values`` entity has four rows with two categorical columns whose
    category sets differ, a string column and a datetime column.
    """
    timestamps = pd.Series([pd.Timestamp('2001-01-01'),
                            pd.Timestamp('2001-01-02'),
                            pd.Timestamp('2001-01-03'),
                            pd.Timestamp('2001-01-04')])
    columns = {
        'id': range(4),
        'value': pd.Categorical(['a', 'c', 'b', 'd']),
        'value2': pd.Categorical(['a', 'b', 'a', 'd']),
        'object': ['time1', 'time2', 'time3', 'time4'],
        'datetime': timestamps,
    }
    frame = pd.DataFrame(columns)

    entityset = ft.EntitySet('equal_test')
    entityset.entity_from_dataframe('values', frame, index='id')
    return entityset


def test_equal_categorical(simple_es):
    """Equal compares two categorical columns whose category sets differ."""
    entity = simple_es['values']
    feature = ft.Feature([entity['value'], entity['value2']], primitive=Equal)

    feature_matrix = ft.calculate_feature_matrix(entityset=simple_es,
                                                 features=[feature])

    # Precondition: the two columns really do have different categories.
    left_categories = set(entity.df['value'].cat.categories)
    right_categories = set(entity.df['value2'].cat.categories)
    assert left_categories != right_categories

    expected = [True, False, False, True]
    assert feature_matrix['value = value2'].to_list() == expected


def test_equal_different_dtypes(simple_es):
    """Equal on a string column and a datetime column is False in every row."""
    entity = simple_es['values']
    object_var = entity['object']
    datetime_var = entity['datetime']

    # Build the feature in both argument orders; the result must not
    # depend on which dtype comes first.
    features = [
        ft.Feature([object_var, datetime_var], primitive=Equal),
        ft.Feature([datetime_var, object_var], primitive=Equal),
    ]
    feature_matrix = ft.calculate_feature_matrix(entityset=simple_es,
                                                 features=features)

    all_false = [False, False, False, False]
    assert feature_matrix['object = datetime'].to_list() == all_false
    assert feature_matrix['datetime = object'].to_list() == all_false


def test_not_equal_categorical(simple_es):
    """NotEqual compares two categorical columns whose category sets differ."""
    entity = simple_es['values']
    feature = ft.Feature([entity['value'], entity['value2']], primitive=NotEqual)

    feature_matrix = ft.calculate_feature_matrix(entityset=simple_es,
                                                 features=[feature])

    # Precondition: the two columns really do have different categories.
    left_categories = set(entity.df['value'].cat.categories)
    right_categories = set(entity.df['value2'].cat.categories)
    assert left_categories != right_categories

    expected = [False, True, True, False]
    assert feature_matrix['value != value2'].to_list() == expected


def test_not_equal_different_dtypes(simple_es):
    """NotEqual on a string column and a datetime column is True in every row."""
    f1 = ft.Feature([simple_es['values']['object'], simple_es['values']['datetime']],
                    primitive=NotEqual)
    f2 = ft.Feature([simple_es['values']['datetime'], simple_es['values']['object']],
                    primitive=NotEqual)

    # verify that not-equals works for different dtypes regardless of order
    df = ft.calculate_feature_matrix(entityset=simple_es, features=[f1, f2])

    assert df['object != datetime'].to_list() == [True, True, True, True]
    assert df['datetime != object'].to_list() == [True, True, True, True]


def test_diff(es):
value = ft.Feature(es['log']['value'])
customer_id_feat = ft.Feature(es['sessions']['customer_id'], entity=es['log'])
Expand Down

0 comments on commit 4f270d5

Please sign in to comment.