Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Custom Variable Types #571

Merged
merged 9 commits into from
Jun 4, 2019
6 changes: 4 additions & 2 deletions featuretools/entityset/deserialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
import pandas as pd

from featuretools.entityset.relationship import Relationship
from featuretools.entityset.serialize import FORMATS, VARIABLE_TYPES
from featuretools.entityset.serialize import FORMATS
from featuretools.variable_types.variable import find_variable_types


def description_to_variable(description, entity=None):
Expand All @@ -17,9 +18,10 @@ def description_to_variable(description, entity=None):
Returns:
variable (Variable) : Returns :class:`.Variable`.
'''
variable_types = find_variable_types()
is_type_string = isinstance(description['type'], str)
type = description['type'] if is_type_string else description['type'].pop('value')
variable = VARIABLE_TYPES.get(type, VARIABLE_TYPES.get('None'))
variable = variable_types.get(type, variable_types.get('None')) # 'None' will return the Unknown variable type
if entity is not None:
kwargs = {} if is_type_string else description['type']
variable = variable(description['id'], entity, **kwargs)
Expand Down
6 changes: 0 additions & 6 deletions featuretools/entityset/serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,7 @@
import os
import shutil

from featuretools import variable_types

FORMATS = ['csv', 'pickle', 'parquet']
VARIABLE_TYPES = {
str(getattr(variable_types, type).type_string): getattr(variable_types, type) for type in dir(variable_types)
if hasattr(getattr(variable_types, type), 'type_string')
}


def entity_to_description(entity):
Expand Down
16 changes: 2 additions & 14 deletions featuretools/primitives/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
PrimitiveBase,
TransformPrimitive
)
from featuretools.utils import is_python_2
from featuretools.utils.gen_utils import find_descendents, is_python_2

if is_python_2():
import imp
Expand Down Expand Up @@ -153,7 +153,7 @@ class PrimitivesDeserializer(object):

def __init__(self):
self.class_cache = {} # (class_name, module_name) -> class
self.primitive_classes = _descendants(PrimitiveBase)
self.primitive_classes = find_descendents(PrimitiveBase)

def deserialize_primitive(self, primitive_dict):
"""
Expand Down Expand Up @@ -183,15 +183,3 @@ def _find_class_in_descendants(self, search_key):

if cls_key == search_key:
return cls


def _descendants(cls):
"""
A generator which yields all descendant classes of the given class
(including the given class).
"""
yield cls

for sub in cls.__subclasses__():
for c in _descendants(sub):
yield c
30 changes: 28 additions & 2 deletions featuretools/tests/entityset_tests/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,26 @@
from featuretools.demo import load_mock_customer
from featuretools.entityset import EntitySet, deserialize, serialize
from featuretools.tests import integration_data
from featuretools.variable_types.variable import (
Categorical,
Index,
TimeIndex,
find_variable_types
)

CACHE = os.path.join(os.path.dirname(integration_data.__file__), '.cache')


def test_all_variable_descriptions():
variable_types = find_variable_types()
es = EntitySet()
dataframe = pd.DataFrame(columns=list(serialize.VARIABLE_TYPES))
dataframe = pd.DataFrame(columns=list(variable_types))
es.entity_from_dataframe(
'variable_types',
dataframe,
index='index',
time_index='datetime_time_index',
variable_types=serialize.VARIABLE_TYPES,
variable_types=variable_types,
)
entity = es['variable_types']
for variable in entity.variables:
Expand All @@ -29,6 +36,25 @@ def test_all_variable_descriptions():
assert variable.__eq__(_variable)


def test_custom_variable_descriptions():

class ItemList(Categorical):
type_string = "item_list"
_default_pandas_dtype = list

es = EntitySet()
variables = {'item_list': ItemList, 'time_index': TimeIndex, 'index': Index}
dataframe = pd.DataFrame(columns=list(variables))
es.entity_from_dataframe(
'custom_variable', dataframe, index='index',
time_index='time_index', variable_types=variables)
entity = es['custom_variable']
for variable in entity.variables:
description = variable.to_data_description()
_variable = deserialize.description_to_variable(description, entity=entity)
assert variable.__eq__(_variable)


def test_variable_descriptions(es):
for entity in es.entities:
for variable in entity.variables:
Expand Down
14 changes: 14 additions & 0 deletions featuretools/utils/gen_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,3 +82,17 @@ def get_relationship_variable_id(path):

def is_python_2():
return sys.version_info.major < 3


def find_descendents(cls):
"""
A generator which yields all descendent classes of the given class
(including the given class)

Args:
cls (Class): the class to find descendents of
"""
yield cls
for sub in cls.__subclasses__():
for c in find_descendents(sub):
yield c
11 changes: 4 additions & 7 deletions featuretools/variable_types/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import numpy as np
import pandas as pd

from featuretools.utils import is_string
from featuretools.utils.gen_utils import find_descendents, is_string


class Variable(object):
Expand Down Expand Up @@ -421,12 +421,9 @@ class FilePath(Variable):
_default_pandas_dtype = str


ALL_VARIABLE_TYPES = [Datetime, Numeric, Timedelta,
Categorical, Text, Ordinal, Discrete,
Boolean, LatLong, ZIPCode, IPAddress,
FullName, EmailAddress, URL, PhoneNumber,
DateOfBirth, CountryCode, SubRegionCode,
FilePath, DatetimeTimeIndex]
def find_variable_types():
return {str(vtype.type_string): vtype for vtype in find_descendents(
Variable) if hasattr(vtype, 'type_string')}


DEFAULT_DTYPE_VALUES = {
Expand Down