diff --git a/featuretools/entityset/entity.py b/featuretools/entityset/entity.py
index 2e3f8e4305..7e07477af6 100644
--- a/featuretools/entityset/entity.py
+++ b/featuretools/entityset/entity.py
@@ -288,7 +288,7 @@ def query_by_values(self, instance_vals, variable_id=None, columns=None,
                 Data older than time_last by more than this will be ignored
 
         Returns:
-            pd.DataFrame : instances that match constraints
+            pd.DataFrame : instances that match constraints with ids in order of underlying dataframe
         """
         instance_vals = self._vals_to_series(instance_vals, variable_id)
 
@@ -309,8 +309,8 @@ def query_by_values(self, instance_vals, variable_id=None, columns=None,
             df.dropna(subset=[self.index], inplace=True)
 
         else:
-            df = self.df.merge(instance_vals.to_frame(variable_id),
-                               how="inner", on=variable_id)
+            df = self.df[self.df[variable_id].isin(instance_vals)]
+
             df = df.set_index(self.index, drop=False)
 
             # ensure filtered df has same categories as original
diff --git a/featuretools/tests/entityset_tests/test_entity.py b/featuretools/tests/entityset_tests/test_entity.py
index 8434b35f62..4d236ea63c 100644
--- a/featuretools/tests/entityset_tests/test_entity.py
+++ b/featuretools/tests/entityset_tests/test_entity.py
@@ -1,6 +1,7 @@
 # -*- coding: utf-8 -*-
 from datetime import datetime
 
+import numpy as np
 import pandas as pd
 import pytest
 
@@ -95,3 +96,21 @@ def test_update_data(es):
     assert es["customers"].df["id"].iloc[0] == 0
     es["customers"].update_data(df.copy(deep=True), already_sorted=True)
     assert es["customers"].df["id"].iloc[0] == 2
+
+
+def test_query_by_values_returns_rows_in_given_order():
+    data = pd.DataFrame({
+        "id": [1, 2, 3, 4, 5],
+        "value": ["a", "c", "b", "a", "a"],
+        "time": [1000, 2000, 3000, 4000, 5000]
+    })
+
+    es = ft.EntitySet()
+    es = es.entity_from_dataframe(entity_id="test", dataframe=data, index="id",
+                                  time_index="time", variable_types={
+                                      "value": ft.variable_types.Categorical
+                                  })
+    query = es['test'].query_by_values(['b', 'a'], variable_id='value')
+    # Matches for 'b' (id 3) and 'a' (ids 1, 4, 5) must come back in the order
+    # of the underlying dataframe, not the order the values were passed in.
+    assert np.array_equal(query['id'], [1, 3, 4, 5])