Skip to content

Commit

Permalink
Custom Variable Types (#571)
Browse files Browse the repository at this point in the history
Added to entityset/deserialize.py so that description_to_variable() looks at all variables that are (recursively) a subclass of Variable, so that it will recognize custom variable types.
  • Loading branch information
allisonportis committed Jun 4, 2019
1 parent 8f75909 commit 1c00b98
Show file tree
Hide file tree
Showing 6 changed files with 52 additions and 31 deletions.
6 changes: 4 additions & 2 deletions featuretools/entityset/deserialize.py
Expand Up @@ -4,7 +4,8 @@
import pandas as pd

from featuretools.entityset.relationship import Relationship
from featuretools.entityset.serialize import FORMATS, VARIABLE_TYPES
from featuretools.entityset.serialize import FORMATS
from featuretools.variable_types.variable import find_variable_types


def description_to_variable(description, entity=None):
Expand All @@ -17,9 +18,10 @@ def description_to_variable(description, entity=None):
Returns:
variable (Variable) : Returns :class:`.Variable`.
'''
variable_types = find_variable_types()
is_type_string = isinstance(description['type'], str)
type = description['type'] if is_type_string else description['type'].pop('value')
variable = VARIABLE_TYPES.get(type, VARIABLE_TYPES.get('None'))
variable = variable_types.get(type, variable_types.get('None')) # 'None' will return the Unknown variable type
if entity is not None:
kwargs = {} if is_type_string else description['type']
variable = variable(description['id'], entity, **kwargs)
Expand Down
6 changes: 0 additions & 6 deletions featuretools/entityset/serialize.py
Expand Up @@ -2,13 +2,7 @@
import os
import shutil

from featuretools import variable_types

FORMATS = ['csv', 'pickle', 'parquet']
VARIABLE_TYPES = {
str(getattr(variable_types, type).type_string): getattr(variable_types, type) for type in dir(variable_types)
if hasattr(getattr(variable_types, type), 'type_string')
}


def entity_to_description(entity):
Expand Down
16 changes: 2 additions & 14 deletions featuretools/primitives/utils.py
Expand Up @@ -9,7 +9,7 @@
PrimitiveBase,
TransformPrimitive
)
from featuretools.utils import is_python_2
from featuretools.utils.gen_utils import find_descendents, is_python_2

if is_python_2():
import imp
Expand Down Expand Up @@ -153,7 +153,7 @@ class PrimitivesDeserializer(object):

def __init__(self):
self.class_cache = {} # (class_name, module_name) -> class
self.primitive_classes = _descendants(PrimitiveBase)
self.primitive_classes = find_descendents(PrimitiveBase)

def deserialize_primitive(self, primitive_dict):
"""
Expand Down Expand Up @@ -183,15 +183,3 @@ def _find_class_in_descendants(self, search_key):

if cls_key == search_key:
return cls


def _descendants(cls):
"""
A generator which yields all descendant classes of the given class
(including the given class).
"""
yield cls

for sub in cls.__subclasses__():
for c in _descendants(sub):
yield c
30 changes: 28 additions & 2 deletions featuretools/tests/entityset_tests/test_serialization.py
Expand Up @@ -8,19 +8,26 @@
from featuretools.demo import load_mock_customer
from featuretools.entityset import EntitySet, deserialize, serialize
from featuretools.tests import integration_data
from featuretools.variable_types.variable import (
Categorical,
Index,
TimeIndex,
find_variable_types
)

CACHE = os.path.join(os.path.dirname(integration_data.__file__), '.cache')


def test_all_variable_descriptions():
variable_types = find_variable_types()
es = EntitySet()
dataframe = pd.DataFrame(columns=list(serialize.VARIABLE_TYPES))
dataframe = pd.DataFrame(columns=list(variable_types))
es.entity_from_dataframe(
'variable_types',
dataframe,
index='index',
time_index='datetime_time_index',
variable_types=serialize.VARIABLE_TYPES,
variable_types=variable_types,
)
entity = es['variable_types']
for variable in entity.variables:
Expand All @@ -29,6 +36,25 @@ def test_all_variable_descriptions():
assert variable.__eq__(_variable)


def test_custom_variable_descriptions():

class ItemList(Categorical):
type_string = "item_list"
_default_pandas_dtype = list

es = EntitySet()
variables = {'item_list': ItemList, 'time_index': TimeIndex, 'index': Index}
dataframe = pd.DataFrame(columns=list(variables))
es.entity_from_dataframe(
'custom_variable', dataframe, index='index',
time_index='time_index', variable_types=variables)
entity = es['custom_variable']
for variable in entity.variables:
description = variable.to_data_description()
_variable = deserialize.description_to_variable(description, entity=entity)
assert variable.__eq__(_variable)


def test_variable_descriptions(es):
for entity in es.entities:
for variable in entity.variables:
Expand Down
14 changes: 14 additions & 0 deletions featuretools/utils/gen_utils.py
Expand Up @@ -82,3 +82,17 @@ def get_relationship_variable_id(path):

def is_python_2():
return sys.version_info.major < 3


def find_descendents(cls):
"""
A generator which yields all descendent classes of the given class
(including the given class)
Args:
cls (Class): the class to find descendents of
"""
yield cls
for sub in cls.__subclasses__():
for c in find_descendents(sub):
yield c
11 changes: 4 additions & 7 deletions featuretools/variable_types/variable.py
Expand Up @@ -5,7 +5,7 @@
import numpy as np
import pandas as pd

from featuretools.utils import is_string
from featuretools.utils.gen_utils import find_descendents, is_string


class Variable(object):
Expand Down Expand Up @@ -421,12 +421,9 @@ class FilePath(Variable):
_default_pandas_dtype = str


ALL_VARIABLE_TYPES = [Datetime, Numeric, Timedelta,
Categorical, Text, Ordinal, Discrete,
Boolean, LatLong, ZIPCode, IPAddress,
FullName, EmailAddress, URL, PhoneNumber,
DateOfBirth, CountryCode, SubRegionCode,
FilePath, DatetimeTimeIndex]
def find_variable_types():
return {str(vtype.type_string): vtype for vtype in find_descendents(
Variable) if hasattr(vtype, 'type_string')}


DEFAULT_DTYPE_VALUES = {
Expand Down

0 comments on commit 1c00b98

Please sign in to comment.