diff --git a/urbansim/sim/simulation.py b/urbansim/sim/simulation.py index 7a7cf1ab..e3991df7 100644 --- a/urbansim/sim/simulation.py +++ b/urbansim/sim/simulation.py @@ -32,7 +32,7 @@ class SimulationError(Exception): pass -class _DataFrameWrapper(object): +class DataFrameWrapper(object): """ Wraps a DataFrame so it can provide certain columns and handle computed columns. @@ -54,7 +54,15 @@ def columns(self): Columns in this table. """ - return list(self._frame.columns) + _list_columns_for_table(self.name) + return self.local_columns + _list_columns_for_table(self.name) + + @property + def local_columns(self): + """ + Columns that are part of the wrapped DataFrame. + + """ + return list(self._frame.columns) @property def index(self): @@ -136,7 +144,7 @@ def __len__(self): return len(self._frame) -class _TableFuncWrapper(object): +class TableFuncWrapper(object): """ Wrap a function that provides a DataFrame. @@ -159,11 +167,25 @@ def __init__(self, name, func): @property def columns(self): """ - Columns in this table. (May often be out of date.) + Columns in this table. (May contain only computed columns + if the wrapped function has not been called yet.) """ return self._columns + _list_columns_for_table(self.name) + @property + def local_columns(self): + """ + Only the columns contained in the DataFrame returned by the + wrapped function. (No registered columns included.) + + """ + if self._columns: + return self._columns + else: + self._call_func() + return self._columns + @property def index(self): """ @@ -173,6 +195,19 @@ def index(self): """ return self._index + def _call_func(self): + """ + Call the wrapped function and return the result. Also updates + attributes like columns, index, and length. + + """ + kwargs = _collect_injectables(self._arg_list) + frame = self._func(**kwargs) + self._columns = list(frame.columns) + self._index = frame.index + self._len = len(frame) + return frame + def to_frame(self, columns=None): """ Make a DataFrame with the given columns. @@ -188,12 +223,8 @@ def to_frame(self, columns=None): frame : pandas.DataFrame """ - kwargs = _collect_injectables(self._arg_list) - frame = self._func(**kwargs) - self._columns = list(frame.columns) - self._index = frame.index - self._len = len(frame) - return _DataFrameWrapper(self.name, frame).to_frame(columns) + frame = self._call_func() + return DataFrameWrapper(self.name, frame).to_frame(columns) def get_column(self, column_name): """ @@ -220,7 +251,7 @@ def __len__(self): return self._len -class _TableSourceWrapper(_TableFuncWrapper): +class TableSourceWrapper(TableFuncWrapper): """ Wraps a function that returns a DataFrame. After the function is evaluated the returned DataFrame replaces the function in the @@ -232,6 +263,15 @@ class _TableSourceWrapper(_TableFuncWrapper): func : callable """ + def convert(self): + """ + Evaluate the wrapped function, store the returned DataFrame as a + table, and return the new DataFrameWrapper instance created. + + """ + frame = self._call_func() + return add_table(self.name, frame) + def to_frame(self, columns=None): """ Make a DataFrame with the given columns. The first time this @@ -249,10 +289,7 @@ def to_frame(self, columns=None): frame : pandas.DataFrame """ - kwargs = _collect_injectables(self._arg_list) - frame = self._func(**kwargs) - add_table(self.name, frame) - return _DataFrameWrapper(self.name, frame).to_frame(columns) + return self.convert().to_frame(columns) class _ColumnFuncWrapper(object): @@ -391,16 +428,22 @@ def add_table(table_name, table): names will be matched to known tables, which will be injected when this function is called. + Returns + ------- + wrapped : `DataFrameWrapper` or `TableFuncWrapper` + """ if isinstance(table, pd.DataFrame): - table = _DataFrameWrapper(table_name, table) + table = DataFrameWrapper(table_name, table) elif isinstance(table, Callable): - table = _TableFuncWrapper(table_name, table) + table = TableFuncWrapper(table_name, table) else: raise TypeError('table must be DataFrame or function.') _TABLES[table_name] = table + return table + def table(table_name): """ @@ -430,8 +473,14 @@ def add_table_source(table_name, func): Function argument names will be matched to known injectables, which will be injected when this function is called. + Returns + ------- + wrapped : `TableSourceWrapper` + """ - _TABLES[table_name] = _TableSourceWrapper(table_name, func) + wrapped = TableSourceWrapper(table_name, func) + _TABLES[table_name] = wrapped + return wrapped def table_source(table_name): @@ -457,7 +506,7 @@ def get_table(table_name): Returns ------- - table : _DataFrameWrapper or _TableFuncWrapper + table : `DataFrameWrapper`, `TableFuncWrapper`, or `TableSourceWrapper` """ if table_name in _TABLES: @@ -755,7 +804,7 @@ def merge_tables(target, tables, columns=None): ---------- target : str Name of the table onto which tables will be merged. - tables : list of _DataFrameWrapper or _TableFuncWrapper + tables : list of `DataFrameWrapper` or `TableFuncWrapper` All of the tables to merge. Should include the target table. columns : list of str, optional If given, columns will be mapped to `tables` and only those columns diff --git a/urbansim/sim/tests/test_mergetables.py b/urbansim/sim/tests/test_mergetables.py index ffa05aad..846a2f8d 100644 --- a/urbansim/sim/tests/test_mergetables.py +++ b/urbansim/sim/tests/test_mergetables.py @@ -9,7 +9,7 @@ @pytest.fixture def dfa(): - return sim._DataFrameWrapper('a', pd.DataFrame( + return sim.DataFrameWrapper('a', pd.DataFrame( {'a1': [1, 2, 3], 'a2': [4, 5, 6], 'a3': [7, 8, 9]}, @@ -18,7 +18,7 @@ def dfa(): @pytest.fixture def dfz(): - return sim._DataFrameWrapper('z', pd.DataFrame( + return sim.DataFrameWrapper('z', pd.DataFrame( {'z1': [90, 91], 'z2': [92, 93], 'z3': [94, 95], @@ -29,7 +29,7 @@ def dfz(): @pytest.fixture def dfb(): - return sim._DataFrameWrapper('b', pd.DataFrame( + return sim.DataFrameWrapper('b', pd.DataFrame( {'b1': range(10, 15), 'b2': range(15, 20), 'a_id': ['ac', 'ac', 'ab', 'aa', 'ab'], @@ -39,7 +39,7 @@ def dfb(): @pytest.fixture def dfc(): - return sim._DataFrameWrapper('c', pd.DataFrame( + return sim.DataFrameWrapper('c', pd.DataFrame( {'c1': range(20, 30), 'c2': range(30, 40), 'b_id': ['ba', 'bd', 'bb', 'bc', 'bb', 'ba', 'bb', 'bc', 'bd', 'bb']}, @@ -48,14 +48,14 @@ def dfc(): @pytest.fixture def dfg(): - return sim._DataFrameWrapper('g', pd.DataFrame( + return sim.DataFrameWrapper('g', pd.DataFrame( {'g1': [1, 2, 3]}, index=['ga', 'gb', 'gc'])) @pytest.fixture def dfh(): - return sim._DataFrameWrapper('h', pd.DataFrame( + return sim.DataFrameWrapper('h', pd.DataFrame( {'h1': range(10, 15), 'g_id': ['ga', 'gb', 'gc', 'ga', 'gb']}, index=['ha', 'hb', 'hc', 'hd', 'he'])) diff --git a/urbansim/sim/tests/test_simulation.py b/urbansim/sim/tests/test_simulation.py index 388183a5..015bfc58 100644 --- a/urbansim/sim/tests/test_simulation.py +++ b/urbansim/sim/tests/test_simulation.py @@ -24,7 +24,7 @@ def df(): def test_tables(df, clear_sim): - sim.add_table('test_frame', df) + wrapped_df = sim.add_table('test_frame', df) @sim.table('test_func') def test_func(test_frame): @@ -33,7 +33,9 @@ def test_func(test_frame): assert set(sim.list_tables()) == {'test_frame', 'test_func'} table = sim.get_table('test_frame') + assert table is wrapped_df assert table.columns == ['a', 'b'] + assert table.local_columns == ['a', 'b'] assert len(table) == 3 pdt.assert_index_equal(table.index, df.index) pdt.assert_series_equal(table.get_column('a'), df.a) @@ -285,13 +287,50 @@ def source(): return df table = sim.get_table('source') - assert isinstance(table, sim._TableSourceWrapper) + assert isinstance(table, sim.TableSourceWrapper) test_df = table.to_frame() pdt.assert_frame_equal(test_df, df) + assert table.columns == list(df.columns) + assert len(table) == len(df) + pdt.assert_index_equal(table.index, df.index) table = sim.get_table('source') - assert isinstance(table, sim._DataFrameWrapper) + assert isinstance(table, sim.DataFrameWrapper) test_df = table.to_frame() pdt.assert_frame_equal(test_df, df) + + +def test_table_source_convert(clear_sim, df): + @sim.table_source('source') + def source(): + return df + + table = sim.get_table('source') + assert isinstance(table, sim.TableSourceWrapper) + + table = table.convert() + assert isinstance(table, sim.DataFrameWrapper) + pdt.assert_frame_equal(table.to_frame(), df) + + table2 = sim.get_table('source') + assert table2 is table + + +def test_table_func_local_cols(clear_sim, df): + @sim.table('table') + def table(): + return df + sim.add_column('table', 'new', pd.Series(['a', 'b', 'c'], index=df.index)) + + assert sim.get_table('table').local_columns == ['a', 'b'] + + +def test_table_source_local_cols(clear_sim, df): + @sim.table_source('source') + def source(): + return df + sim.add_column('source', 'new', pd.Series(['a', 'b', 'c'], index=df.index)) + + assert sim.get_table('source').local_columns == ['a', 'b']