In [None]:
from src import *

# try:
#     import pandas_bokeh
# except:
#     os.system('pip install --upgrade pandas-bokeh')
#     import pandas_bokeh

@dataclasses.dataclass
class Analysis(Base):
    nodes_tbl : str
    seeds     : typing.Any
        
    def __post_init__(self):
        results_stem = self.nodes_tbl.split('.')[-1][6:]
        self.abbr, self.yr, self.level, self.district_type = results_stem.split('_')
        ds = root_bq + f'.results_{results_stem}'
        bqclient.create_dataset(ds, exists_ok=True)
        self.results_bq = ds + f'.{seed}'
        results_path = root_path / f'results/{results_stem}/{seed}/'

        self.seeds_list = list()
        self.bq_list = list()
        for seed in self.seeds:
            tbl = self.results_bq + f'_{seed}_'
            if check_table(tbl + 'plans'):
                self.seeds_list.append(int(seed))
                self.bq_list.append(tbl)

        a, b = min(self.seeds_list), max(self.seeds_list)
        seeds_range = f'{str(a).rjust(4, "0")}_{str(b).rjust(4, "0")}'
        if all([s in self.seeds_list for s in range(a,b)]):
            seeds_range += '_complete'
        else:
            seeds_range += '_incomplete'
        self.tbl = self.results_bq + f'_{seeds_range}'
        self.pq = root_path / f'results/{results_stem}/{seeds_range}.parquet'
#         print(self.tbl, self.pq, self.bq_list)

    def compute_results(self):
        u = "\nunion all\n"
        stack = {key: u.join([f'select * from {bq}{key}' for bq in self.bq_list]) for key in ['plans', 'stats', 'summary']}

#         cols = [c for c in get_cols(self.nodes) if c not in Levels + District_types + ['geoid', 'county', 'total_pop', 'polygon', 'aland', 'perim', 'polsby_popper', 'density', 'point']]
        
        cols = [c for c in ['total_white'] if c not in Levels + District_types + ['geoid', 'county', 'total_pop', 'polygon', 'aland', 'perim', 'polsby_popper', 'density', 'point']]

        
        query = f"""
select
    B.seed,
    B.plan,
    C.{self.district_type},
    max(B.hash) as hash_plan,
    max(B.pop_imbalance) as pop_imbalance_plan,
    max(B.polsby_popper) as polsby_popper_plan,
    max(C.polsby_popper) as polsby_popper_district,
    max(C.aland) as aland,
    max(C.total_pop) as total_pop,
    max(C.total_pop) / sum(E.aland) as density,
    {join_str(1).join([f'sum(E.{c}) as {c}' for c in cols])}
from (
    select
        *
    from (
        select
            *,
            row_number() over (partition by A.hash order by plan asc, seed asc) as r
        from (
            {subquery(stack['summary'], indents=3)}
            ) as A
        )
    where r = 1
    ) as B
inner join (
    {subquery(stack['stats'], indents=1)}
    ) as C
on
    B.seed = C.seed and B.plan = C.plan
inner join (
    select
        *
    from (
        {subquery(stack['plans'], indents=2)}
        )
    ) as D
on
    C.seed = D.seed and C.plan = D.plan and C.{self.district_type} = D.{self.district_type}
inner join
    {self.nodes} as E
on
    D.geoid = E.geoid
group by
    seed, plan, {self.district_type}
order by
    seed, plan, {self.district_type}
"""
        load_table(tbl=self.tbl, query=query)
        self.fetch_results()
        self.save_results()
        
    def fetch_results(self):
        self.results = read_table(tbl=self.tbl)
        idx = ['seed', 'plan', 'cd']
        for col in idx:
            self.results[col] = rjust(self.results[col])
        self.results.sort_values(idx, inplace=True)
        return self.results
        
    def save_results(self):
        self.results.to_parquet(self.pq)
        to_gcs(self.pq)


    def plot(self, show=True):
        try:
            df = read_table(tbl=self.tbl+'_plans')
            df = df.pivot(index='geoid', columns='plan').astype(int)
            df.columns = df.columns.droplevel().rename(None)
            d = len(str(df.columns.max()))
            plans = ['plan_'+str(c).rjust(d, '0') for c in df.columns]
            df.columns = plans

            shapes = run_query(f'select geoid, county, total_pop, density, aland, perim, polsby_popper, polygon from {self.nodes}')
            df = df.merge(shapes, on='geoid')
            geo = gpd.GeoSeries.from_wkt(df['polygon'], crs='EPSG:4326').simplify(0.001).buffer(0) #<-- little white space @ .001 ~5.7 mb, minimal at .0001 ~10mb, with no white space ~37mb
#             geo = gpd.GeoSeries.from_wkt(df['polygon'], crs='EPSG:4326').buffer(0) # <-------------------- to not simplify at all
            self.gdf = gpd.GeoDataFrame(df.drop(columns='polygon'), geometry=geo)

            if show:
                pandas_bokeh.output_notebook() #<------------- uncommment to view in notebook
            fig = self.gdf.plot_bokeh(
                figsize = (900, 600),
                slider = plans,
                slider_name = "PLAN #",
                show_colorbar = False,
                colorbar_tick_format="0",
                colormap = "Category20",
                hovertool_string = '@geoid, @county<br>pop=@total_pop<br>density=@density{0.0}<br>land=@aland{0.0}<br>pp=@polsby_popper{0.0}',
                tile_provider = "CARTODBPOSITRON",
                return_html = True,
                show_figure = show,
                **{'fill_alpha' :.5,
                  'line_alpha':.05,}
            )
            fn = self.results_path / f'{self.run}_map.html'
            with open(fn, 'w') as file:
                file.write(fig)
#             rpt(f'map creation for {self.seed} - success')
        except Exception as e:
            rpt(f'map creation for {self.seed} - FAIL {e}')
            fig = None
        return fig


In [5]:
from src import *

@dataclasses.dataclass
class Analysis(Base):
    nodes_tbl : str
        
    def __post_init__(self):
        self.results_stem = self.nodes_tbl.split('.')[-1][6:]
        self.abbr, self.yr, self.level, self.district_type = self.results_stem.split('_')
#         bqclient.create_dataset(ds, exists_ok=True)
#         self.results_bq = ds + f'.{results_stem}_{self.seed}'
#         self.results_path = root_path / f'results/{results_stem}/{results_stem}_{self.seed}/'

        
        
        
#         self.seeds_list = list()
#         self.bq_list = list()
#         for seed in self.seeds:
#             tbl = self.results_bq + f'_{seed}_'
#             if check_table(tbl + 'plans'):
#                 self.seeds_list.append(int(seed))
#                 self.bq_list.append(tbl)

#         a, b = min(self.seeds_list), max(self.seeds_list)
#         seeds_range = f'{str(a).rjust(4, "0")}_{str(b).rjust(4, "0")}'
#         if all([s in self.seeds_list for s in range(a,b)]):
#             seeds_range += '_complete'
#         else:
#             seeds_range += '_incomplete'
#         self.tbl = self.results_bq + f'_{seeds_range}'
#         self.pq = root_path / f'results/{results_stem}/{seeds_range}.parquet'
# #         print(self.tbl, self.pq, self.bq_list)

    def compute_results(self):
        ds = f'{root_bq}.{self.results_stem}'
        self.tbls = {'plans':list(), 'stats':list(), 'summaries':list()}
        print(ds)
        for src_tbl in bqclient.list_tables(ds):
            key = src_tbl.table_id.split('_')[-1]
            self.tbls[key].append(src_tbl.full_table_id)
        
        u = "\nunion all\n"
        self.stack = {key: u.join([f'select * from {tbl}' for tbl in tbl_list]) for key, tbl_list in tbls.items()}
        print(stack)
        assert 1==2
        
A = Analysis('cmat-315920.redistricting_data.nodes_TX_2020_cntyvtd_cd')
A.compute_results()

cmat-315920.TX_2020_cntyvtd_cd
{'plans': 'select * from cmat-315920:TX_2020_cntyvtd_cd.1528_plans\nunion all\nselect * from cmat-315920:TX_2020_cntyvtd_cd.1529_plans\nunion all\nselect * from cmat-315920:TX_2020_cntyvtd_cd.1530_plans\nunion all\nselect * from cmat-315920:TX_2020_cntyvtd_cd.1531_plans\nunion all\nselect * from cmat-315920:TX_2020_cntyvtd_cd.1532_plans\nunion all\nselect * from cmat-315920:TX_2020_cntyvtd_cd.1533_plans\nunion all\nselect * from cmat-315920:TX_2020_cntyvtd_cd.1534_plans\nunion all\nselect * from cmat-315920:TX_2020_cntyvtd_cd.1535_plans\nunion all\nselect * from cmat-315920:TX_2020_cntyvtd_cd.1536_plans\nunion all\nselect * from cmat-315920:TX_2020_cntyvtd_cd.1537_plans\nunion all\nselect * from cmat-315920:TX_2020_cntyvtd_cd.1538_plans\nunion all\nselect * from cmat-315920:TX_2020_cntyvtd_cd.1539_plans\nunion all\nselect * from cmat-315920:TX_2020_cntyvtd_cd.1540_plans\nunion all\nselect * from cmat-315920:TX_2020_cntyvtd_cd.1541_plans\nunion all\nselect

AssertionError: 

In [None]:
from . import *

@dataclasses.dataclass
class MCMC(Base):
    max_steps             : int
    gpickle               : str
    seed                  : int = 1
    new_districts         : int = 0
    anneal                : float = 0.0
    pop_diff_exp          : int = 0
    pop_imbalance_target  : float = 1.0
    pop_imbalance_stop    : bool = True
    report_period         : int = 50
    

    def __post_init__(self):
        results_stem = self.gpickle.stem[6:]
        self.abbr, self.yr, self.level, self.district_type = results_stem.split('_')
        ds = root_bq + f'.results_{results_stem}'
        bqclient.create_dataset(ds, exists_ok=True)
        self.results_bq = ds + f'.{seed}'
        results_path = root_path / f'results/{results_stem}/{seed}/'
        self.rng = np.random.default_rng(int(self.seed))

        self.graph = nx.read_gpickle(self.gpickle)
        nx.set_node_attributes(self.graph, self.seed, 'seed')
        
        if self.new_districts > 0:
            M = int(self.nodes_df()[self.district_type].max())
            for n in self.nodes_df().nlargest(self.new_districts, 'total_pop').index:
                M += 1
                self.graph.nodes[n][self.district_type] = str(M)
        self.get_districts()
        self.num_districts = len(self.districts)
        self.pop_total = self.sum_nodes(self.graph, 'total_pop')
        self.pop_ideal = self.pop_total / self.num_districts

    def nodes_df(self, G=None):
        if G is None:
            G = self.graph
        return pd.DataFrame.from_dict(G.nodes, orient='index')
        
    def edges_tuple(self, G=None):
        if G is None:
            G = self.graph
        return tuple(sorted(tuple((min(u,v), max(u,v)) for u, v in G.edges)))
    
    def get_districts(self):
        grp = self.nodes_df().groupby(self.district_type)
        self.districts = {k:tuple(sorted(v)) for k,v in grp.groups.items()}
        self.partition = tuple(sorted(self.districts.values()))
        self.hash = self.partition.__hash__()

    def sum_nodes(self, G, attr='total_pop'):
        return sum(x for n, x in G.nodes(data=attr))
    
    def get_stats(self):
        self.get_districts()
        self.stat = pd.DataFrame()
        for d, N in self.districts.items():
            H = self.graph.subgraph(N)
            s = dict()
            internal_perim = 2*sum(x for a, b, x in H.edges(data='shared_perim') if x is not None)
            external_perim = self.sum_nodes(H, 'perim') - internal_perim
            s['aland'] = self.sum_nodes(H, 'aland')
            s['polsby_popper'] = 4 * np.pi * s['aland'] / (external_perim**2) * 100
            s['total_pop'] = self.sum_nodes(H, 'total_pop')
            for k, v in s.items():
                self.stat.loc[d, k] = v
        self.stat['total_pop'] = self.stat['total_pop'].astype(int)
        self.stat['plan'] = self.step
        self.stat['seed'] = self.seed
        
        
        self.pop_imbalance = (self.stat['total_pop'].max() - self.stat['total_pop'].min()) / self.pop_ideal * 100
        self.summary = pd.DataFrame()
        self.summary['seed'] = [self.seed]
        self.summary['plan'] = [self.step]
        self.summary['hash'] = [self.hash]
        self.summary['pop_imbalance'] = [self.pop_imbalance]
        self.summary['polsby_popper']  = [self.stat['polsby_popper'].mean()]


    def run_chain(self):
        self.step = 0
        self.overite_tbl = True
        nx.set_node_attributes(self.graph, self.step, 'plan')
        self.get_stats()
        self.plans      = [self.nodes_df()[['seed', 'plan', self.district_type]]]
        self.stats      = [self.stat.copy()]
        self.summaries  = [self.summary.copy()]
        self.hashes     = [self.hash]
        for k in range(1, self.max_steps+1):
            self.step = k
            nx.set_node_attributes(self.graph, self.step, 'plan')
            msg = f"seed {self.seed} step {self.step} pop_imbalance={self.pop_imbalance:.1f}"

            if self.recomb():
                self.plans.append(self.nodes_df()[['seed', 'plan', self.district_type]])
                self.stats.append(self.stat.copy())
                self.summaries.append(self.summary.copy())
                self.hashes.append(self.hash)
#                 print('success')
                if self.step % self.report_period == 0:
                    print(msg)
                if self.pop_imbalance_stop:
                    if self.pop_imbalance < self.pop_imbalance_target:
#                         rpt(f'pop_imbalance_target {self.pop_imbalance_target} satisfied - stopping')
                        break
            else:
                rpt(msg)
                break
            if self.step % 500 == 0:
                self.save(gcs=False)
        self.save(gcs=True)
#         print('MCMC done')



    def save(self, gcs=False):
        if self.results_bq is None:
            return
        self.results_path.mkdir(parents=True, exist_ok=True)
        self.file = self.results_path / f'graph.gpickle'
        nx.write_gpickle(self.graph, self.file)
        to_gcs(self.file)
        
        def reorder(df):
            idx = [c for c in ['seed', 'plan'] if c in df.columns]
            return df[idx + [c for c in df.columns if c not in idx]]

        tbls = {nm: self.results_bq+f'_{nm}' for nm in ['plans', 'stats', 'summaries']}
        if len(self.plans) > 0:
            self.plans     = pd.concat(self.plans    , axis=0).rename_axis('geoid').reset_index()
            self.stats     = pd.concat(self.stats    , axis=0).rename_axis(self.district_type).reset_index()
            self.summaries = pd.concat(self.summaries, axis=0)

            for nm, tbl in tbls.items():
                saved = False
                for i in range(1, 60):
                    try:
                        load_table(tbl=tlb, df=reorder(self[nm]), overwrite=self.overite_tbl)
                        self[nm] = list()
                        saved = True
                        break
                    except:
                        time.sleep(1)
                assert saved, f'I tried to write the result of seed {self.seed} {i} times without success - giving up'
            self.overite_tbl = False
        
        if gcs:
            for nm, tbl in tbls.items():
                to_gcs(tbl)


    def recomb(self):
        def gen(pop_diff):
            while len(pop_diff) > 0:
                pop_diff /= pop_diff.sum()
                a = self.rng.choice(pop_diff.index, p=pop_diff)
                pop_diff.pop(a)
                yield a
        L = self.stat['total_pop']
        pop_diff = pd.DataFrame([(x, y, abs(p-q)) for x, p in L.iteritems() for y, q in L.iteritems() if x < y]).set_index([0,1]).squeeze()
        pop_diff = (pop_diff / pop_diff.sum()) ** self.pop_diff_exp
        pairs = gen(pop_diff)
        while True:
            try:
                d0, d1 = next(pairs)
            except StopIteration:
                rpt(f'exhausted all district pairs - I think I am stuck')
                return False
            except Exception as e:
                raise Exception(f'unknown error {e}')
            m = list(self.districts[d0]+self.districts[d1])  # nodes in d0 or d1
            H = self.graph.subgraph(m).copy()  # subgraph on those nodes
            if not nx.is_connected(H):  # if H is not connect, go to next district pair
#                     rpt(f'{d0},{d1} not connected')
                continue
#                 else:
#                     rpt(f'{d0},{d1} connected')
            P = self.stat['total_pop'].copy()
            p0 = P.pop(d0)
            p1 = P.pop(d1)
            q = p0 + p1
            # q is population of d0 & d1
            # P lists all OTHER district populations
            P_min, P_max = P.min(), P.max()

            trees = []  # track which spanning trees we've tried so we don't repeat failures
            for i in range(100):  # max number of spanning trees to try
                for e in self.edges_tuple(H):
                    H.edges[e]['weight'] = self.rng.uniform()
                T = nx.minimum_spanning_tree(H)  # find minimum spanning tree - we assiged random weights so this is really a random spanning tress
                h = self.edges_tuple(T).__hash__()  # hash tree for comparion
                if h not in trees:  # prevents retrying a previously failed treee
                    trees.append(h)
                    # try to make search more efficient by searching for a suitable cut edge among edges with high betweenness-centrality
                    # Since cutting an edge near the perimeter of the tree is veru unlikely to produce population balance,
                    # we focus on edges near the center.  Betweenness-centrality is a good metric for this.
                    B = nx.edge_betweenness_centrality(T)
                    B = sorted(B.items(), key=lambda x:x[1], reverse=True)  # sort edges on betweenness-centrality (largest first)
                    max_tries = int(min(300, 0.2*len(B)))  # number of edge cuts to attempt before giving up on this tree
                    k = 0
                    for e, cent in B[:max_tries]:
                        T.remove_edge(*e)
                        comp = nx.connected_components(T)  # T nows has 2 components
                        next(comp)  # second one tends to be smaller → faster to sum over → skip over the first component
                        s = sum(H.nodes[n]['total_pop'] for n in next(comp))  # sum population in component 2
                        t = q - s  # pop of component 0 (recall q is the combined pop of d0&d1)
                        if s > t:  # ensure s < t
                            s, t = t, s
                        imb = (max(t, P_max) - min(s, P_min)) / self.pop_ideal * 100  # compute new pop imbalance
                        I = self.pop_imbalance - imb
                        if I < 0:
                            if self.anneal < 1e-7:
                                if I < -0.01:
                                    T.add_edge(*e)  #  if pop_balance not achieved, re-insert e
                                    continue
                            elif self.rng.uniform() > np.exp(I / self.anneal):
                                T.add_edge(*e)  #  if pop_balance not achieved, re-insert e
                                continue
                        # We found a good cut edge & made 2 new districts.  They will be label with the values of d0 & d1.
                        # But which one should get d0?  This is surprisingly important so colors "look right" in animations.
                        # Else, colors can get quite "jumpy" and give an impression of chaos and instability
                        # To achieve this, add aland of nodes that have the same od & new district label
                        # and subtract aland of nodes that change district label.  If negative, swap d0 & d1.
                        comp = get_components(T)
                        x = H.nodes(data=True)
                        s = (sum(x[n]['aland'] for n in comp[0] if x[n][self.district_type]==d0) -
                             sum(x[n]['aland'] for n in comp[0] if x[n][self.district_type]!=d0) +
                             sum(x[n]['aland'] for n in comp[1] if x[n][self.district_type]==d1) -
                             sum(x[n]['aland'] for n in comp[1] if x[n][self.district_type]!=d1))
                        if s < 0:
                            d0, d1 = d1, d0
                                
                        # Update district labels
                        for n in comp[0]:
                            self.graph.nodes[n][self.district_type] = d0
                        for n in comp[1]:
                            self.graph.nodes[n][self.district_type] = d1
                            
                        # update stats
                        self.get_stats()
                        assert abs(self.pop_imbalance - imb) < 1e-2, f'disagreement betwen pop_imbalance calculations {self.pop_imbalance} v {imb}'
                        if self.hash in self.hashes: # if we've already seen that plan before, reject and keep trying for a new one
#                             rpt(f'duplicate plan {self.hash}')
                            T.add_edge(*e)
                            # Restore old district labels
                            for n in H.nodes:
                                self.graph.nodes[n][self.district_type] = H.nodes[n][self.district_type]
                            self.get_stats()
                        else:  # if this is a never-before-seen plan, keep it and return happy
#                             rpt(f'recombed {self.district_type} {d0} & {d1} got pop_imbalance={self.pop_imbalance:.2f}%')
                            return True

In [6]:
from src import *
gpickle = pathlib.Path('/home/jupyter/redistricting_data/graph/TX/graph_TX_2020_cntyvtd_cd.gpickle')
gpickle.is_file()
gpickle.stem[6:]

'TX_2020_cntyvtd_cd'