Skip to content

Commit

Permalink
Merge branch 'master' into new-simulation-testing
Browse files Browse the repository at this point in the history
Conflicts:
	urbansim/sim/simulation.py
  • Loading branch information
fscottfoti committed Jul 28, 2014
2 parents d7dbb95 + 2f11343 commit ebb7b7c
Show file tree
Hide file tree
Showing 5 changed files with 360 additions and 1 deletion.
149 changes: 148 additions & 1 deletion urbansim/sim/simulation.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import print_function

import inspect
from collections import Callable
from collections import Callable, namedtuple

import pandas as pd
import toolz
Expand All @@ -10,9 +10,12 @@

logger = logging.getLogger(__name__)

from ..utils.misc import column_map

# Module-level registries; clear_sim() empties every one of them.
_TABLES = {}
_COLUMNS = {}
_MODELS = {}
# Maps (cast, onto) table-name pairs to the _Broadcast rule registered
# for them by broadcast().
_BROADCASTS = {}


def clear_sim():
Expand All @@ -23,6 +26,7 @@ def clear_sim():
_TABLES.clear()
_COLUMNS.clear()
_MODELS.clear()
_BROADCASTS.clear()


class _DataFrameWrapper(object):
Expand Down Expand Up @@ -504,3 +508,146 @@ def run(models, years=None):
t1 = time.time()
model(year=year)
logger.debug("Time to execute model = %.3fs" % (time.time()-t1))


# Immutable record of one merge rule registered through `broadcast`.
# Fields mirror the arguments of `broadcast`, in the same order.
_Broadcast = namedtuple(
    '_Broadcast',
    'cast onto cast_on onto_on cast_index onto_index')


def broadcast(cast, onto, cast_on=None, onto_on=None,
              cast_index=False, onto_index=False):
    """
    Record how the table `cast` is to be merged onto the table `onto`.

    The rule is stored in the module-level broadcast registry under the
    key ``(cast, onto)``; registering the same pair again replaces the
    previously stored rule.

    Parameters
    ----------
    cast, onto : str
        Names of registered tables.
    cast_on, onto_on : str, optional
        Column names used for the merge. `merge_tables` passes
        `onto_on` as ``left_on`` and `cast_on` as ``right_on``
        to pandas.merge.
    cast_index, onto_index : bool, optional
        Whether to use the table indexes for the merge. `merge_tables`
        passes `onto_index` as ``left_index`` and `cast_index` as
        ``right_index`` to pandas.merge.

    """
    rule = _Broadcast(
        cast=cast, onto=onto,
        cast_on=cast_on, onto_on=onto_on,
        cast_index=cast_index, onto_index=onto_index)
    _BROADCASTS[(cast, onto)] = rule


def _get_broadcasts(tables):
    """
    Collect the registered broadcasts that link a set of tables.

    Only broadcasts whose cast *and* onto tables are both in `tables`
    are returned.

    Parameters
    ----------
    tables : sequence of str
        Table names for which broadcasts have been registered.

    Returns
    -------
    casts : dict of `_Broadcast`
        Keys are tuples of strings like (cast_name, onto_name).

    Raises
    ------
    ValueError
        If any name in `tables` appears in none of the selected
        broadcasts.

    """
    wanted = set(tables)
    casts = {
        pair: rule for pair, rule in _BROADCASTS.items()
        if pair[0] in wanted and pair[1] in wanted}

    # every requested table has to show up on at least one side of a link
    linked = {name for pair in casts for name in pair}
    if wanted - linked:
        raise ValueError('Not enough links to merge all tables.')
    return casts


# utilities for merge_tables
def _all_reachable_tables(t):
    """
    Lazily yield every table name found in the nested merge tree `t`.

    `t` maps a table name to a dict of the same shape. All names nested
    under a table are yielded before the table itself.
    """
    for name in t:
        for nested in _all_reachable_tables(t[name]):
            yield nested
        yield name


def _is_leaf_node(merge_node):
    """
    Return True when no table in `merge_node` has anything nested under it.

    `merge_node` maps table names to dicts; the node counts as a leaf when
    every one of those dicts is empty (an empty node is a leaf as well).
    """
    return all(not nested for nested in merge_node.values())


def _next_merge(merge_node):
    """
    Locate the part of a merge tree that can be merged next.

    Parameters
    ----------
    merge_node : dict
        Nested dict mapping an "onto" table name to a dict of the tables
        cast onto it (each of which has the same structure).

    Returns
    -------
    node : dict
        `merge_node` itself when every value in it is a leaf node, i.e.
        none of the cast tables still has tables of its own waiting to be
        merged. Otherwise the result of searching the first value that is
        not a leaf node.

    """
    children = list(merge_node.values())
    if all(_is_leaf_node(child) for child in children):
        return merge_node

    for child in children:
        # Descend only where work remains below. A child that is merely
        # non-empty but already a leaf node (e.g. {'a': {}}) has nothing
        # to merge, and returning it would stall the caller's loop.
        if not _is_leaf_node(child):
            return _next_merge(child)


def merge_tables(target, tables, columns=None):
    """
    Merge a number of tables onto a target table. Tables must have
    registered merge rules via the `broadcast` function.

    Parameters
    ----------
    target : str
        Name of the table onto which tables will be merged.
    tables : list of _DataFrameWrapper or _TableFuncWrapper
        All of the tables to merge. Should include the target table.
    columns : list of str, optional
        If given, columns will be mapped to `tables` and only those columns
        will be requested from each table. Any ``cast_on``/``onto_on``
        columns named by the broadcasts linking `tables` are requested
        as well so that the merges can be performed. By default all
        columns are used from every table.

    Returns
    -------
    merged : pandas.DataFrame

    Raises
    ------
    ValueError
        If some table takes part in none of the broadcasts registered
        between `tables`.
    RuntimeError
        If some table cannot be reached from `target` by following the
        registered broadcasts, or if no merge can be found while tables
        are still waiting to be merged.
    KeyError
        If `target` is not the name of one of `tables`.

    """
    merges = {t.name: {} for t in tables}
    tables = {t.name: t for t in tables}
    casts = _get_broadcasts(tables.keys())

    # relate all the tables by registered broadcasts
    for table, onto in casts:
        merges[onto][table] = merges[table]
    merges = {target: merges[target]}

    # verify that all the tables can be merged to the target
    all_tables = set(_all_reachable_tables(merges))

    if all_tables != set(tables.keys()):
        raise RuntimeError(
            ('Not all tables can be merged to target "{}". Unlinked tables: {}'
             ).format(target, list(set(tables.keys()) - all_tables)))

    # add any columns necessary for indexing into columns
    if columns:
        columns = list(columns)  # copy so the caller's list is not modified
        for c in casts.values():
            if c.onto_on:
                columns.append(c.onto_on)
            if c.cast_on:
                columns.append(c.cast_on)

    # get column map for which columns go with which table
    colmap = column_map(tables.values(), columns)

    # get frames
    frames = {name: t.to_frame(columns=colmap[name])
              for name, t in tables.items()}

    while merges[target]:
        nm = _next_merge(merges)
        merged_any = False

        # Handle every table at this level that still has casts pending.
        # Looking only at the first key would never finish once that key
        # has been emptied while a sibling is still waiting (and indexing
        # ``nm.keys()`` does not work on Python 3 dict views).
        for onto, pending in list(nm.items()):
            if not pending:
                continue
            onto_table = frames[onto]
            for cast in pending:
                bc = casts[(cast, onto)]
                onto_table = pd.merge(
                    onto_table, frames[cast],
                    left_on=bc.onto_on, right_on=bc.cast_on,
                    left_index=bc.onto_index, right_index=bc.cast_index)
            frames[onto] = onto_table
            nm[onto] = {}
            merged_any = True

        if not merged_any:
            # fail loudly instead of looping forever without progress
            raise RuntimeError(
                'No tables ready to merge were found for target "{}".'
                .format(target))

    return frames[target]
119 changes: 119 additions & 0 deletions urbansim/sim/tests/test_mergetables.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
import pandas as pd
import pytest
from pandas.util import testing as pdt

from .. import simulation as sim
from .test_simulation import clear_sim
from ...utils.testing import assert_frames_equal


@pytest.fixture
def dfa():
    """Wrapped table 'a': columns a1-a3, indexed 'aa', 'ab', 'ac'."""
    frame = pd.DataFrame(
        {'a1': [1, 2, 3],
         'a2': [4, 5, 6],
         'a3': [7, 8, 9]},
        index=['aa', 'ab', 'ac'])
    return sim._DataFrameWrapper('a', frame)


@pytest.fixture
def dfz():
    """Wrapped table 'z': columns z1-z5, indexed 'za', 'zb'."""
    frame = pd.DataFrame(
        {'z1': [90, 91],
         'z2': [92, 93],
         'z3': [94, 95],
         'z4': [96, 97],
         'z5': [98, 99]},
        index=['za', 'zb'])
    return sim._DataFrameWrapper('z', frame)


@pytest.fixture
def dfb():
    """Wrapped table 'b': its a_id/z_id columns hold index labels of a and z."""
    frame = pd.DataFrame(
        {'b1': range(10, 15),
         'b2': range(15, 20),
         'a_id': ['ac', 'ac', 'ab', 'aa', 'ab'],
         'z_id': ['zb', 'zb', 'za', 'za', 'zb']},
        index=['ba', 'bb', 'bc', 'bd', 'be'])
    return sim._DataFrameWrapper('b', frame)


@pytest.fixture
def dfc():
    """Wrapped table 'c': its b_id column holds index labels of table b."""
    labels = ['ca', 'cb', 'cc', 'cd', 'ce', 'cf', 'cg', 'ch', 'ci', 'cj']
    frame = pd.DataFrame(
        {'c1': range(20, 30),
         'c2': range(30, 40),
         'b_id': ['ba', 'bd', 'bb', 'bc', 'bb', 'ba', 'bb', 'bc', 'bd', 'bb']},
        index=labels)
    return sim._DataFrameWrapper('c', frame)


@pytest.fixture
def dfg():
    """Wrapped table 'g': single column g1, indexed 'ga', 'gb', 'gc'."""
    frame = pd.DataFrame(
        {'g1': [1, 2, 3]},
        index=['ga', 'gb', 'gc'])
    return sim._DataFrameWrapper('g', frame)


@pytest.fixture
def dfh():
    """Wrapped table 'h': its g_id column holds index labels of table g."""
    frame = pd.DataFrame(
        {'h1': range(10, 15),
         'g_id': ['ga', 'gb', 'gc', 'ga', 'gb']},
        index=['ha', 'hb', 'hc', 'hd', 'he'])
    return sim._DataFrameWrapper('h', frame)


def all_broadcasts():
    """Register the broadcasts shared by the merge tests.

    Each rule joins the cast table's index to an id column of the onto
    table.
    """
    links = [
        ('a', 'b', 'a_id'),
        ('z', 'b', 'z_id'),
        ('b', 'c', 'b_id'),
        ('g', 'h', 'g_id'),
    ]
    for cast, onto, id_column in links:
        sim.broadcast(cast, onto, cast_index=True, onto_on=id_column)


def test_merge_tables_raises(clear_sim, dfa, dfz, dfb, dfg, dfh):
    """Tables g and h are not linked to target 'b', so merging must fail."""
    all_broadcasts()
    not_all_linked = [dfa, dfb, dfz, dfg, dfh]

    with pytest.raises(RuntimeError):
        sim.merge_tables('b', not_all_linked)


def test_merge_tables1(clear_sim, dfa, dfz, dfb):
    """Merging a and z onto b matches the equivalent manual pd.merge calls."""
    all_broadcasts()

    merged = sim.merge_tables('b', [dfa, dfz, dfb])

    a_with_b = pd.merge(
        dfa.to_frame(), dfb.to_frame(), left_index=True, right_on='a_id')
    expected = pd.merge(
        a_with_b, dfz.to_frame(), left_on='z_id', right_index=True)

    assert_frames_equal(merged, expected)


def test_merge_tables2(clear_sim, dfa, dfz, dfb, dfc):
    """Merging a chain (a, z -> b -> c) matches the manual pd.merge calls."""
    all_broadcasts()

    merged = sim.merge_tables('c', [dfa, dfz, dfb, dfc])

    a_with_b = pd.merge(
        dfa.to_frame(), dfb.to_frame(), left_index=True, right_on='a_id')
    with_z = pd.merge(
        a_with_b, dfz.to_frame(), left_on='z_id', right_index=True)
    expected = pd.merge(
        with_z, dfc.to_frame(), left_index=True, right_on='b_id')

    assert_frames_equal(merged, expected)


def test_merge_tables_cols(clear_sim, dfa, dfz, dfb, dfc):
    """Request one column from each table when merging everything onto c."""
    all_broadcasts()
    wanted = ['a1', 'b1', 'z1', 'c1']

    merged = sim.merge_tables('c', [dfa, dfz, dfb, dfc], columns=wanted)

    row_labels = ['ca', 'cb', 'cc', 'cd', 'ce', 'cf', 'cg', 'ch', 'ci', 'cj']
    expected = pd.DataFrame(
        {'c1': range(20, 30),
         'b1': [10, 13, 11, 12, 11, 10, 11, 12, 13, 11],
         'a1': [3, 1, 3, 2, 3, 3, 3, 2, 1, 3],
         'z1': [91, 90, 91, 90, 91, 91, 91, 90, 90, 91]},
        index=row_labels)

    assert_frames_equal(merged, expected)
17 changes: 17 additions & 0 deletions urbansim/sim/tests/test_simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,3 +185,20 @@ def test_model2(test_table):
2000: [2012, 2015, 2018],
3000: [3012, 3017, 3024]},
index=['x', 'y', 'z']))


def test_get_broadcasts(clear_sim):
    """_get_broadcasts returns only links between the requested tables."""
    for cast, onto in [('a', 'b'), ('b', 'c'), ('z', 'b'), ('f', 'g')]:
        sim.broadcast(cast, onto)

    # 'g' is only linked to 'f', which is not part of the request
    with pytest.raises(ValueError):
        sim._get_broadcasts(['a', 'b', 'g'])

    cases = [
        (['a', 'b', 'c', 'z'], {('a', 'b'), ('b', 'c'), ('z', 'b')}),
        (['a', 'b', 'z'], {('a', 'b'), ('z', 'b')}),
        (['a', 'b', 'c'], {('a', 'b'), ('b', 'c')}),
    ]
    for names, expected_links in cases:
        assert set(sim._get_broadcasts(names).keys()) == expected_links
41 changes: 41 additions & 0 deletions urbansim/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@
from __future__ import print_function

import os

import numpy as np
import pandas as pd
import toolz


def _mkifnotexists(folder):
Expand Down Expand Up @@ -37,6 +39,14 @@ def runs_dir():
return _mkifnotexists("runs")


def models_dir():
    """
    Return the directory for the model configuration files (used by the
    website).

    The value is whatever ``_mkifnotexists`` returns for the folder name
    "configs".
    """
    folder = "configs"
    return _mkifnotexists(folder)


def charts_dir():
"""
Return the directory for the chart configuration files (used by the
Expand Down Expand Up @@ -285,3 +295,34 @@ def pandasdfsummarytojson(df, ndigits=3):
"""
df = df.transpose()
return {k: _pandassummarytojson(v, ndigits) for k, v in df.iterrows()}


def column_map(tables, columns):
    """
    Take a list of tables and a list of column names and resolve which
    columns come from which table.

    Parameters
    ----------
    tables : sequence of _DataFrameWrapper or _TableFuncWrapper
        Could also be sequence of modified pandas.DataFrames, the important
        thing is that they have ``.name`` and ``.columns`` attributes.
    columns : sequence of str
        The column names of interest.

    Returns
    -------
    col_map : dict
        Maps table names to lists of column names. Each list holds the
        requested columns present in that table, without repeats, in the
        order of the table's own ``.columns``. If `columns` is empty or
        None every table name maps to None.

    Raises
    ------
    RuntimeError
        If any requested column is found in none of the tables (this
        includes the case where `tables` is empty).

    """
    if not columns:
        return {t.name: None for t in tables}

    wanted = set(columns)
    colmap = {}
    foundcols = set()
    for t in tables:
        # Walk the table's own columns rather than intersecting sets so
        # the resulting order is deterministic from run to run.
        matched = []
        for col in t.columns:
            if col in wanted and col not in matched:
                matched.append(col)
        colmap[t.name] = matched
        foundcols.update(matched)

    if foundcols != wanted:
        raise RuntimeError('Not all required columns were found. '
                           'Missing: {}'.format(list(wanted - foundcols)))
    return colmap
Loading

0 comments on commit ebb7b7c

Please sign in to comment.