Skip to content

Commit

Permalink
Update DFS primitive matching to use ColumnSchema (#1523)
Browse files Browse the repository at this point in the history
* update dfs primitive matching

* work on test_deep_feature_synthesis tests

* work on dfs tests

* fix more tests

* exclude foreign key cols from transform feats

* fix Trend

* more test updates

* more test work

* remove files

* fix dfs to match old features

* lint fix

* remove old print statement

* more naming updates

* update handling of foreign key columns

* lots of naming updates

* fix test names

* even more naming updates

* more cleanup and test fixes

* rename return_variable_types to return_types

* fix broken entityset tests

* pr naming updates

* remove unnecessary primitive

* add new _schemas_equal conditions

* lint fix
  • Loading branch information
thehomebrewnerd committed Jul 14, 2021
1 parent 864a9aa commit af195aa
Show file tree
Hide file tree
Showing 31 changed files with 909 additions and 781 deletions.
6 changes: 3 additions & 3 deletions docs/source/api_reference.rst
Expand Up @@ -282,7 +282,7 @@ EntitySet load and prepare data

EntitySet.add_dataframe
EntitySet.add_relationship
EntitySet.normalize_entity
EntitySet.normalize_dataframe
EntitySet.add_interesting_values
EntitySet.set_secondary_time_index

Expand Down Expand Up @@ -310,8 +310,8 @@ EntitySet query methods
EntitySet.__getitem__
EntitySet.find_backward_paths
EntitySet.find_forward_paths
EntitySet.get_forward_entities
EntitySet.get_backward_entities
EntitySet.get_forward_dataframes
EntitySet.get_backward_dataframes

EntitySet visualization
-----------------------
Expand Down
8 changes: 4 additions & 4 deletions featuretools/computational_backends/utils.py
Expand Up @@ -260,10 +260,10 @@ def _check_cutoff_time_type(cutoff_time, es_time_type):
is_datetime = time_type == 'datetime_time_index'
else:
raise NotImplementedError()
cutoff_time_dtype = cutoff_time['time'].dtype.name
# TODO: refactor for woodwork columns, maybe use ww is_datetime and is_numeric?
is_numeric = cutoff_time_dtype in PandasTypes._pandas_numerics
is_datetime = cutoff_time_dtype in PandasTypes._pandas_datetimes
# cutoff_time_dtype = cutoff_time['time'].dtype.name
# # TODO: refactor for woodwork columns, maybe use ww is_datetime and is_numeric?
# is_numeric = cutoff_time_dtype in PandasTypes._pandas_numerics
# is_datetime = cutoff_time_dtype in PandasTypes._pandas_datetimes

if es_time_type == "numeric_time_index" and not is_numeric:
raise TypeError("cutoff_time times must be numeric: try casting "
Expand Down
4 changes: 2 additions & 2 deletions featuretools/entityset/relationship.py
Expand Up @@ -125,7 +125,7 @@ def name(self):

return '.'.join(relationship_names)

def entities(self):
def dataframes(self):
if self:
# Yield first dataframe.
is_forward, relationship = self[0]
Expand Down Expand Up @@ -164,7 +164,7 @@ def __ne__(self, other):

def __repr__(self):
if self._relationships_with_direction:
path = '%s.%s' % (next(self.entities()), self.name)
path = '%s.%s' % (next(self.dataframes()), self.name)
else:
path = '[]'
return '<RelationshipPath %s>' % path
Expand Down
6 changes: 6 additions & 0 deletions featuretools/feature_base/feature_base.py
Expand Up @@ -184,6 +184,12 @@ def column_schema(self): # typing_info ? typing_schema? schema?
elif 'index' in column_schema.semantic_tags:
column_schema = ColumnSchema(logical_type=column_schema.logical_type,
semantic_tags=column_schema.semantic_tags - {"index"})
# Need to add back in the numeric standard tag so the schema can get recognized
# as a valid return type
if column_schema.is_numeric:
column_schema.semantic_tags.add('numeric')
if column_schema.is_categorical:
column_schema.semantic_tags.add('category')

# direct features should keep the Id return type, but all other features should get
# converted to Categorical
Expand Down
2 changes: 1 addition & 1 deletion featuretools/primitives/base/transform_primitive_base.py
Expand Up @@ -11,7 +11,7 @@ class TransformPrimitive(PrimitiveBase):
in that entity."""
# (bool) If True, feature function depends on all values of entity
# (and will receive these values as input, regardless of specified instance ids)
uses_full_entity = False
uses_full_dataframe = False

def generate_name(self, base_feature_names):
return u"%s(%s%s)" % (
Expand Down
159 changes: 80 additions & 79 deletions featuretools/primitives/options_utils.py
Expand Up @@ -10,50 +10,51 @@

def _get_primitive_options():
# all possible option keys: function that verifies value type
return {'ignore_entities': list_entity_check,
'include_entities': list_entity_check,
'ignore_variables': dict_to_list_variable_check,
'include_variables': dict_to_list_variable_check,
'ignore_groupby_entities': list_entity_check,
'include_groupby_entities': list_entity_check,
'ignore_groupby_variables': dict_to_list_variable_check,
'include_groupby_variables': dict_to_list_variable_check}
return {'ignore_dataframes': list_dataframe_check,
'include_dataframes': list_dataframe_check,
'ignore_columns': dict_to_list_column_check,
'include_columns': dict_to_list_column_check,
'ignore_groupby_dataframes': list_dataframe_check,
'include_groupby_dataframes': list_dataframe_check,
'ignore_groupby_columns': dict_to_list_column_check,
'include_groupby_columns': dict_to_list_column_check}


def dict_to_list_variable_check(option, es):
def dict_to_list_column_check(option, es):
if not (isinstance(option, dict) and
all([isinstance(option_val, list) for option_val in option.values()])):
return False
else:
for entity, variables in option.items():
if entity not in es:
warnings.warn("Entity '%s' not in entityset" % (entity))
for dataframe, columns in option.items():
if dataframe not in es:
warnings.warn("Dataframe '%s' not in entityset" % (dataframe))
else:
for invalid_var in [variable for variable in variables
if variable not in es[entity]]:
warnings.warn("Variable '%s' not in entity '%s'" % (invalid_var, entity))
for invalid_col in [column for column in columns
if column not in es[dataframe]]:
warnings.warn("Column '%s' not in dataframe '%s'" % (invalid_col, dataframe))
return True


def list_entity_check(option, es):
def list_dataframe_check(option, es):
if not isinstance(option, list):
return False
else:
for invalid_entity in [entity for entity in option if entity not in es]:
warnings.warn("Entity '%s' not in entityset" % (invalid_entity))
for invalid_dataframe in [dataframe for dataframe in option if dataframe not in es]:
warnings.warn("Dataframe '%s' not in entityset" % (invalid_dataframe))
return True


def generate_all_primitive_options(all_primitives,
primitive_options,
ignore_entities,
ignore_variables,
ignore_dataframes,
ignore_columns,
es):
entityset_dict = {entity.id: [variable.id for variable in entity.variables]
for entity in es.entities}
primitive_options = _init_primitive_options(primitive_options, entityset_dict)
global_ignore_entities = ignore_entities
global_ignore_variables = ignore_variables.copy()
dataframe_dict = {dataframe.ww.name: [col for col in dataframe.columns]
for dataframe in es.dataframes}

primitive_options = _init_primitive_options(primitive_options, dataframe_dict)
global_ignore_dataframes = ignore_dataframes
global_ignore_columns = ignore_columns.copy()
# for now, only use primitive names as option keys
for primitive in all_primitives:
if primitive in primitive_options and primitive.name in primitive_options:
Expand All @@ -64,36 +65,36 @@ def generate_all_primitive_options(all_primitives,
if primitive in primitive_options or primitive.name in primitive_options:
options = primitive_options.get(primitive, primitive_options.get(primitive.name))
# Reconcile global options with individually-specified options
included_entities = set().union(*[
option.get('include_entities', set()).union(
option.get('include_variables', {}).keys())
included_dataframes = set().union(*[
option.get('include_dataframes', set()).union(
option.get('include_columns', {}).keys())
for option in options])
global_ignore_entities = global_ignore_entities.difference(included_entities)
global_ignore_dataframes = global_ignore_dataframes.difference(included_dataframes)
for option in options:
# don't globally ignore a variable if it's included for a primitive
if 'include_variables' in option:
for entity, include_vars in option['include_variables'].items():
global_ignore_variables[entity] = \
global_ignore_variables[entity].difference(include_vars)
option['ignore_entities'] = option['ignore_entities'].union(
ignore_entities.difference(included_entities)
# don't globally ignore a column if it's included for a primitive
if 'include_columns' in option:
for dataframe, include_cols in option['include_columns'].items():
global_ignore_columns[dataframe] = \
global_ignore_columns[dataframe].difference(include_cols)
option['ignore_dataframes'] = option['ignore_dataframes'].union(
ignore_dataframes.difference(included_dataframes)
)
for entity, ignore_vars in ignore_variables.items():
# if already ignoring variables for this entity, add globals
for dataframe, ignore_cols in ignore_columns.items():
# if already ignoring columns for this dataframe, add globals
for option in options:
if entity in option['ignore_variables']:
option['ignore_variables'][entity] = option['ignore_variables'][entity].union(ignore_vars)
# if no ignore_variables and entity is explicitly included, don't ignore the variable
elif entity in included_entities:
if dataframe in option['ignore_columns']:
option['ignore_columns'][dataframe] = option['ignore_columns'][dataframe].union(ignore_cols)
# if no ignore_columns and dataframe is explicitly included, don't ignore the column
elif dataframe in included_dataframes:
continue
# Otherwise, keep the global option
else:
option['ignore_variables'][entity] = ignore_vars
option['ignore_columns'][dataframe] = ignore_cols
else:
# no user specified options, just use global defaults
primitive_options[primitive] = [{'ignore_entities': ignore_entities,
'ignore_variables': ignore_variables}]
return primitive_options, global_ignore_entities, global_ignore_variables
primitive_options[primitive] = [{'ignore_dataframes': ignore_dataframes,
'ignore_columns': ignore_columns}]
return primitive_options, global_ignore_dataframes, global_ignore_columns


def _init_primitive_options(primitive_options, es):
Expand Down Expand Up @@ -150,60 +151,60 @@ def _init_option_dict(key, option_dict, es):
initialized_option_dict[option_key] = set(option)
elif isinstance(option, dict):
initialized_option_dict[option_key] = {key: set(option[key]) for key in option}
# initialize ignore_entities and ignore_variables to empty sets if not present
if 'ignore_variables' not in initialized_option_dict:
initialized_option_dict['ignore_variables'] = dict()
if 'ignore_entities' not in initialized_option_dict:
initialized_option_dict['ignore_entities'] = set()
# initialize ignore_dataframes and ignore_columns to empty sets if not present
if 'ignore_columns' not in initialized_option_dict:
initialized_option_dict['ignore_columns'] = dict()
if 'ignore_dataframes' not in initialized_option_dict:
initialized_option_dict['ignore_dataframes'] = set()
return initialized_option_dict


def variable_filter(f, options, groupby=False):
if groupby and 'category' not in f.column_schema.semantic_tags:
def column_filter(f, options, groupby=False):
if groupby and not f.column_schema.semantic_tags.intersection({'category', 'foreign_key'}):
return False
include_vars = 'include_groupby_variables' if groupby else 'include_variables'
ignore_vars = 'ignore_groupby_variables' if groupby else 'ignore_variables'
include_entities = 'include_groupby_entities' if groupby else 'include_entities'
ignore_entities = 'ignore_groupby_entities' if groupby else 'ignore_entities'
include_cols = 'include_groupby_columns' if groupby else 'include_columns'
ignore_cols = 'ignore_groupby_columns' if groupby else 'ignore_columns'
include_dataframes = 'include_groupby_dataframes' if groupby else 'include_dataframes'
ignore_dataframes = 'ignore_groupby_dataframes' if groupby else 'ignore_dataframes'

dependencies = f.get_dependencies(deep=True) + [f]
for base_f in dependencies:
if isinstance(base_f, IdentityFeature):
if include_vars in options and base_f.entity.id in options[include_vars]:
if base_f.get_name() in options[include_vars][base_f.entity.id]:
if include_cols in options and base_f.dataframe_name in options[include_cols]:
if base_f.get_name() in options[include_cols][base_f.dataframe_name]:
continue # this is a valid feature, go to next
else:
return False # this is not an included feature
if ignore_vars in options and base_f.entity.id in options[ignore_vars]:
if base_f.get_name() in options[ignore_vars][base_f.entity.id]:
if ignore_cols in options and base_f.dataframe_name in options[ignore_cols]:
if base_f.get_name() in options[ignore_cols][base_f.dataframe_name]:
return False # ignore this feature
if include_entities in options and \
base_f.entity.id not in options[include_entities]:
return False # not an included entity
elif ignore_entities in options and \
base_f.entity.id in options[ignore_entities]:
return False # ignore the entity
if include_dataframes in options and \
base_f.dataframe_name not in options[include_dataframes]:
return False # not an included dataframe
elif ignore_dataframes in options and \
base_f.dataframe_name in options[ignore_dataframes]:
return False # ignore the dataframe
return True


def ignore_entity_for_primitive(options, entity, groupby=False):
# This logic handles whether given options ignore an entity or not
def should_ignore_entity(option):
def ignore_dataframe_for_primitive(options, dataframe, groupby=False):
# This logic handles whether given options ignore a dataframe or not
def should_ignore_dataframe(option):
if groupby:
if 'include_groupby_variables' not in option or entity.id not in option['include_groupby_variables']:
if 'include_groupby_entities' in option and entity.id not in option['include_groupby_entities']:
if 'include_groupby_columns' not in option or dataframe.ww.name not in option['include_groupby_columns']:
if 'include_groupby_dataframes' in option and dataframe.ww.name not in option['include_groupby_dataframes']:
return True
elif 'ignore_groupby_entities' in option and entity.id in option['ignore_groupby_entities']:
elif 'ignore_groupby_dataframes' in option and dataframe.ww.name in option['ignore_groupby_dataframes']:
return True
if 'include_variables' in option and entity.id in option['include_variables']:
if 'include_columns' in option and dataframe.ww.name in option['include_columns']:
return False
elif 'include_entities' in option and entity.id not in option['include_entities']:
elif 'include_dataframes' in option and dataframe.ww.name not in option['include_dataframes']:
return True
elif entity.id in option['ignore_entities']:
elif dataframe.ww.name in option['ignore_dataframes']:
return True
else:
return False
return any([should_ignore_entity(option) for option in options])
return any([should_ignore_dataframe(option) for option in options])


def filter_groupby_matches_by_options(groupby_matches, options):
Expand All @@ -216,13 +217,13 @@ def filter_matches_by_options(matches, options, groupby=False, commutative=False
# If more than one option, then need to handle each for each input
if len(options) > 1:
def is_valid_match(match):
if all([variable_filter(m, option, groupby) for m, option in zip(match, options)]):
if all([column_filter(m, option, groupby) for m, option in zip(match, options)]):
return True
else:
return False
else:
def is_valid_match(match):
if all([variable_filter(f, options[0], groupby) for f in match]):
if all([column_filter(f, options[0], groupby) for f in match]):
return True
else:
return False
Expand Down
2 changes: 1 addition & 1 deletion featuretools/primitives/standard/aggregation_primitives.py
Expand Up @@ -273,7 +273,7 @@ class PercentTrue(AggregationPrimitive):
0.6
"""
name = "percent_true"
input_types = [ColumnSchema(logical_type=Boolean), ColumnSchema(logical_type=BooleanNullable)]
input_types = [[ColumnSchema(logical_type=Boolean)], [ColumnSchema(logical_type=BooleanNullable)]]
return_type = ColumnSchema(semantic_tags={'numeric'})
stack_on = []
stack_on_exclude = []
Expand Down
15 changes: 7 additions & 8 deletions featuretools/primitives/standard/cum_transform_feature.py
Expand Up @@ -23,7 +23,7 @@ class CumSum(TransformPrimitive):
name = "cum_sum"
input_types = [ColumnSchema(semantic_tags={'numeric'})]
return_type = ColumnSchema(semantic_tags={'numeric'})
uses_full_entity = True
uses_full_dataframe = True
description_template = "the cumulative sum of {}"

def get_function(self):
Expand All @@ -48,10 +48,9 @@ class CumCount(TransformPrimitive):
[1, 2, 3, 4, 5, 6]
"""
name = "cum_count"
input_types = [[ColumnSchema(semantic_tags={'foreign_key'})],
[ColumnSchema(semantic_tags={'category'})]]
return_type = ColumnSchema(logical_type=Integer)
uses_full_entity = True
input_types = [ColumnSchema(semantic_tags={'category'})]
return_type = ColumnSchema(logical_type=Integer, semantic_tags={'numeric'})
uses_full_dataframe = True
description_template = "the cumulative count of {}"

def get_function(self):
Expand Down Expand Up @@ -79,7 +78,7 @@ class CumMean(TransformPrimitive):
name = "cum_mean"
input_types = [ColumnSchema(semantic_tags={'numeric'})]
return_type = ColumnSchema(semantic_tags={'numeric'})
uses_full_entity = True
uses_full_dataframe = True
description_template = "the cumulative mean of {}"

def get_function(self):
Expand Down Expand Up @@ -107,7 +106,7 @@ class CumMin(TransformPrimitive):
name = "cum_min"
input_types = [ColumnSchema(semantic_tags={'numeric'})]
return_type = ColumnSchema(semantic_tags={'numeric'})
uses_full_entity = True
uses_full_dataframe = True
description_template = "the cumulative minimum of {}"

def get_function(self):
Expand Down Expand Up @@ -135,7 +134,7 @@ class CumMax(TransformPrimitive):
name = "cum_max"
input_types = [ColumnSchema(semantic_tags={'numeric'})]
return_type = ColumnSchema(semantic_tags={'numeric'})
uses_full_entity = True
uses_full_dataframe = True
description_template = "the cumulative maximum of {}"

def get_function(self):
Expand Down

0 comments on commit af195aa

Please sign in to comment.