diff --git a/synthpop/ipf/ipf.py b/synthpop/ipf/ipf.py index 5b396d3..b3beddb 100644 --- a/synthpop/ipf/ipf.py +++ b/synthpop/ipf/ipf.py @@ -52,20 +52,16 @@ def calc_diff(x, y): iterations = 0 + list_of_loc = [ + ((flat_joint_dist[idx[0]] == idx[1]).values, marginals[idx]) + for idx in marginals.index + ] + while calc_diff(constraints, prev_constraints) > tolerance: prev_constraints[:] = constraints - for idx in marginals.index: - # get the locations of the things we're updating - cat = idx[0] # top level category (col name in flat_joint_dist) - subcat = idx[1] # subcategory (col values in flat_joint_dist) - loc = (flat_joint_dist[cat] == subcat).values - - # figure out the proportions for this update - proportions = constraints[loc] / constraints[loc].sum() - - # distribute new total for these classes - constraints[loc] = proportions * marginals[idx] + for loc, target in list_of_loc: + constraints[loc] *= target / constraints[loc].sum() iterations += 1