In [None]:
! pip install --upgrade google-cloud-bigquery-storage

In [None]:
import google, numpy as np, pandas as pd, geopandas as gpd, networkx as nx
import matplotlib.pyplot as plt, plotly.express as px 
from shapely.ops import orient
from google.cloud import aiplatform, bigquery
from google.cloud.bigquery_storage import BigQueryReadClient, types
# `import google` alone does not guarantee the `google.auth` submodule is loaded; import it explicitly.
import google.auth

# Project hosting the custom BLOCK_CENTROIDS and Block_Equivalency_Files datasets queried below.
proj_id = 'cmat-315920'
cred, proj = google.auth.default(scopes=["https://www.googleapis.com/auth/cloud-platform"])
# Application-default credentials may carry no project; fall back to proj_id for billing.
bqclient = bigquery.Client(credentials=cred, project=proj or proj_id)

In [None]:
# --- analysis parameters -------------------------------------------------
yr = 2017          # ACS 5-year vintage; the Congress number is derived from it via yr_to_congress
state_abbr = 'RI'  # postal abbreviation of the state to load
min_degree = 2     # intended minimum node degree of the block-group graph (verify the make_graph call passes it)

# --- coordinate reference systems ----------------------------------------
# Input geometries are WKT in NAD83:
#   https://www2.census.gov/geo/pdfs/maps-data/data/tiger/tgrshp2020/TGRSHP2020_TechDoc_Ch3.pdf
crs_map = 'NAD83'           # unprojected CRS used for mapping
crs_area = 'ESRI:102003'    # area calculations   - https://epsg.io/102003
crs_length = 'ESRI:102005'  # length calculations - https://epsg.io/102005

In [None]:
# Look up the FIPS code and name of the configured state.
query_str = f"""
select
    state_fips_code as fips
    , state_postal_abbreviation as abbr
    , state_name as name
from
    bigquery-public-data.census_utility.fips_codes_states
"""
states = bqclient.query(query_str).result().to_dataframe()
state_rows = states[states['abbr'] == state_abbr]
if state_rows.empty:
    raise ValueError(f'unknown state abbreviation {state_abbr!r}')
state = state_rows.iloc[0]
# A bare expression here would not render: it is not the last line of the cell.
display(state)

def yr_to_congress(yr):
    """Map a calendar year to a U.S. Congress number as int((yr - 1786) / 2).

    Examples from the formula: 2016 -> 115, 2017 -> 115, 2018 -> 116, 2019 -> 116.
    Even years therefore map to the Congress returned by congress_to_yr for that
    year. NOTE(review): this looks like an "elected by the end of yr" convention;
    confirm it is the intended pairing of ACS year and district map.
    """
    elapsed = yr - 1786
    return int(elapsed / 2)

def congress_to_yr(congress):
    """Return the (even) year 1786 + 2 * congress, e.g. 115 -> 2016.

    Inverse of yr_to_congress on even years.
    """
    return 2 * congress + 1786

congress = yr_to_congress(yr=yr)  # selects the CD{congress} column / {congress}th_BEF table in get_data

In [None]:
# %%time
class myGeoDataFrame(gpd.GeoDataFrame):
    """GeoDataFrame that carries a second point column, 'centroid', plus km-based helpers.

    WARNING: `set_crs` here REPROJECTS in place (it calls `to_crs(..., inplace=True)`)
    and shadows geopandas' own `GeoDataFrame.set_crs`. Because `get_perim` and
    `get_area` call it, they leave the frame in `crs_length` / `crs_area` respectively.
    """

    def set_crs(self, crs='NAD83'):
        """Reproject both the 'centroid' column and the active geometry to `crs`, in place.

        Returns self so calls can be chained.
        """
        self['centroid'] = self['centroid'].to_crs(crs)
        self.to_crs(crs, inplace=True)
        return self

    def get_perim(self, col=None):
        """Return perimeters divided by 1000 (km, assuming `crs_length` is in metres).

        If `col` is given, rows are first dissolved by that column.
        Side effect: reprojects self to `crs_length`.
        """
        X = self.set_crs(crs_length)
        if col:
            X = X.dissolve(by=col)
        return X.length / 1000

    def get_area(self, col=None):
        """Return areas divided by 1000**2 (km^2, assuming `crs_area` is in metres).

        If `col` is given, rows are first dissolved by that column.
        Side effect: reprojects self to `crs_area`.
        """
        X = self.set_crs(crs_area)
        if col:
            X = X.dissolve(by=col)
        return X.area / (1000**2)

    def copy(self, deep=True):
        """Return a copy that keeps this subclass.

        `deep` is forwarded to the parent implementation. NOTE(review): pandas
        itself may call `copy(deep=...)` internally (e.g. from reset_index or
        sort_values), which a zero-argument override would reject with TypeError.
        """
        return self.__class__(super().copy(deep=deep))

def get_data(fips):
    """Load one state's block groups with population, centroid and congressional district.

    Runs one BigQuery query that inner-joins:
      * block-group shapes (bigquery-public-data.geo_census_blockgroups),
      * ACS `yr` 5-year total population,
      * population-weighted centroids from `{proj_id}.BLOCK_CENTROIDS`,
      * the district of Congress `congress` from the block equivalency file,
        aggregated from blocks to block groups by majority of blocks.
    Reads the module-level `yr`, `congress`, `proj_id`, `bqclient` and `crs_map`.

    Parameters
    ----------
    fips : state FIPS code as returned in `states['fips']`; interpolated verbatim
        into the table names `blockgroups_{fips}` and `block_centroid_{fips}`.

    Returns
    -------
    myGeoDataFrame indexed by geo_id, with state_fips, county_fips, tract_ce,
    blockgroup_ce, lon, lat, cd, pop, geometry, centroid, area (km^2) and perim (km).
    NOTE: get_area/get_perim reproject in place, so the returned frame is in
    `crs_length`, not `crs_map`. Downstream distance computations depend on that.

    Raises
    ------
    ValueError if the query returns the same geo_id more than once.
    """
    query_str = f"""
    select
        --geo_id structure - https://www.census.gov/programs-surveys/geography/guidance/geo-identifiers.html
        --SSCCCTTTTTTB: state 1-2, county 3-5, tract 6-11, block group 12 (substr is 1-based)
        geo.geo_id
        , cast(substr(geo.geo_id, 1 , 2) as int) as state_fips
        , cast(substr(geo.geo_id, 3 , 3) as int) as county_fips
        , cast(substr(geo.geo_id, 6 , 6) as int) as tract_ce
        , cast(substr(geo.geo_id, 12, 1) as int) as blockgroup_ce
        , centroids.lon
        , centroids.lat
        , cast(cd.cd as int) as cd
        , cast(acs.total_pop as int) as pop
        , geo.geometry
    from (
        -- get shapes
        select
            geo_id
            , blockgroup_geom as geometry
        from
            bigquery-public-data.geo_census_blockgroups.blockgroups_{fips}
        )  as geo
    inner join (
        -- get shapes demographic data
        select distinct
            geo_id
            , total_pop
        from
            bigquery-public-data.census_bureau_acs.blockgroup_{yr}_5yr
        ) as acs
    on
        geo.geo_id = acs.geo_id
    inner join (
        -- get population weighted centroids
        -- must build geo_id because data source does not include it
        select distinct
            concat(
                lpad(cast(STATEFP as string), 2, "0"),
                lpad(cast(COUNTYFP as string), 3, "0"),
                lpad(cast(TRACTCE as string), 6, "0"),
                lpad(cast(BLKGRPCE as string), 1, "0")
                ) as geo_id
            , LONGITUDE as lon
            , LATITUDE as lat
        from
            {proj_id}.BLOCK_CENTROIDS.block_centroid_{fips}
        ) as centroids
    on
        geo.geo_id = centroids.geo_id
    inner join (
        -- get congressional district
        -- at block level -> must aggregate to blockgroup
        -- 7141 (3%) of blockgroups span multiple congressional districts
        -- We assign that entire bg to the cd with the most blocks.
        -- row_number (not rank) keeps exactly one row per bg when two cds tie; ties go to the lower cd.
        select
            *
        from (
            select
                A.*
                , row_number() over (partition by A.geo_id order by A.num_blocks_in_cd desc, A.cd) as r
            from (
                select
                    left(BLOCKID, 12) as geo_id   -- 12-char blockgroup prefix of the block id
                    , CD{congress} as cd
                    , count(*) as num_blocks_in_cd
                from
                    {proj_id}.Block_Equivalency_Files.{congress}th_BEF
                group by
                    1, 2
                ) as A
            ) as B
        where
            r = 1
        ) as cd
    on
        geo.geo_id = cd.geo_id
    """
    df = bqclient.query(query_str).result().to_dataframe().set_index('geo_id')
    # Duplicate geo_ids would break per-node lookups later (e.g. to_dict('index')).
    if not df.index.is_unique:
        dupes = df.index[df.index.duplicated()].unique().tolist()
        raise ValueError(f'duplicate geo_ids returned by query: {dupes[:10]}')
    # Reorient rings with sign=-1 (clockwise exterior per shapely.ops.orient).
    # NOTE(review): presumably required by the plotly choropleth below; confirm.
    df['geometry'] = gpd.GeoSeries.from_wkt(df['geometry']).apply(lambda p: orient(p, -1))
    df = myGeoDataFrame(df, geometry='geometry', crs=crs_map)
    df['centroid'] = gpd.points_from_xy(df['lon'], df['lat'], crs=crs_map)
    # Both helpers reproject df in place: first to crs_area, then to crs_length.
    df['area'] = df.get_area()
    df['perim'] = df.get_perim()
    return df

df = get_data(fips=state['fips'])
cds = np.unique(df['cd'])  # distinct congressional district numbers present in the state

In [None]:
def get_pairs(df, min_degree=4):
    """Build every pair of populated block groups with distance and adjacency features.

    Parameters
    ----------
    df : geo frame indexed by geo_id with 'pop', 'geometry' and 'centroid' columns.
        Distances and lengths are in the units of df's CRS. The m/min conversion
        below assumes metres, which holds when df is in `crs_length`, as get_data
        leaves it.
    min_degree : unused; kept only for backward compatibility.

    Uses the module-level numpy Generator `rng` for the transit-time noise.

    Returns
    -------
    Frame with columns geo_id_x, geo_id_y, distance (centroid-to-centroid),
    perim_shared (length of the shared boundary), touch (perim_shared > 1) and
    transit_time (minutes at 50 mph, scaled by a per-pair factor in [0.5, 1.5)).
    Each unordered pair appears twice, once per direction. Both directions have
    identical features and the same index label.
    """
    cols = ['geo_id', 'geometry', 'centroid']
    A = df.reset_index().query('pop > 0')[cols]
    pairs = A.merge(A, how='cross').query('geo_id_x < geo_id_y').reset_index(drop=True)
    pairs['distance']     = pairs.set_geometry('centroid_x').distance(    pairs.set_geometry('centroid_y'), align=False)
    pairs['perim_shared'] = pairs.set_geometry('geometry_x').intersection(pairs.set_geometry('geometry_y'), align=False).length
    # Corner-only contacts have zero-length intersections; require more than 1 unit of shared boundary.
    pairs['touch'] = pairs['perim_shared'] > 1
    # One independent speed factor per pair. A single scalar would merely rescale all times.
    speed_factor = rng.uniform(0.5, 1.5, size=len(pairs))
    pairs['transit_time'] = pairs['distance'] / 1341 * speed_factor  # 50 mph → 1341 m/min
    pairs = pairs.drop(columns=[c+z for c in cols[1:] for z in ['_x', '_y']])
    # Append the reversed direction. The index is intentionally not reset: the reversed
    # row keeps its forward row's label, and downstream code looks pairs up by label.
    pairs = pd.concat([pairs, pairs.rename(columns={'geo_id_x':'geo_id_y', 'geo_id_y':'geo_id_x'})])
    return pairs

# Seeded generator; get_pairs reads this global for its transit-time noise.
rng = np.random.default_rng(seed=42)
pairs = get_pairs(df=df)

In [None]:
def edges_to_graph(edges):
    """Turn a pairs-style edge table into a networkx graph.

    Endpoints come from 'geo_id_x' / 'geo_id_y'. The perim_shared, touch,
    distance and transit_time columns are copied onto each edge as attributes.
    The graph type is from_pandas_edgelist's default (no create_using is passed).
    """
    return nx.from_pandas_edgelist(
        edges,
        source='geo_id_x',
        target='geo_id_y',
        edge_attr=['perim_shared', 'touch', 'distance', 'transit_time'],
    )

def connect_districts(pairs, nodes, G):
    """Add shortest-distance edges until each district's induced subgraph is connected.

    For each congressional district in `nodes` (grouped by 'cd'), repeatedly
    joins two connected components via the closest pair (by 'distance') in
    `pairs`. This continues until at most one component remains. Nodes of the
    district that are absent from G are ignored by the subgraph. G is mutated
    in place and returned.
    """
    for cd, X in nodes.groupby('cd'):
        while True:
            H = G.subgraph(X.index)
            components = list(nx.connected_components(H))
            print(f'CD {cd} has {len(components)} connected components')
            # `<= 1`: a district with no nodes in G yields zero components.
            if len(components) <= 1:
                break
            mask = pairs['geo_id_x'].isin(components[0]) & pairs['geo_id_y'].isin(components[1])
            i = pairs.loc[mask, 'distance'].idxmin()
            # `[[i]]` always yields a DataFrame, whether or not the label i is repeated.
            edges = pairs.loc[[i]]
            G.update(edges_to_graph(edges))
    return G

def make_min_degree(pairs, G, min_degree):
    """Top up nodes whose degree is below `min_degree` with edges to their nearest non-neighbours.

    Degrees are read before any edge is added, so one node's new edges do not
    reduce another node's shortfall during the scan. Every node needing edges is
    printed with its degree. G is mutated in place and returned.
    """
    new_edge_tables = []
    for node, deg in G.degree:
        shortfall = min_degree - deg
        if shortfall <= 0:
            continue
        print(node, deg)
        is_candidate = (pairs['geo_id_x'] == node) & ~pairs['geo_id_y'].isin(G.neighbors(node))
        new_edge_tables.append(pairs[is_candidate].nsmallest(shortfall, 'distance'))
    if new_edge_tables:
        G.update(edges_to_graph(pd.concat(new_edge_tables)))
    return G

def make_graph(pairs, nodes, min_degree=2):
    """Build the block-group graph used for districting.

    Steps:
      1. edges between block groups that share boundary ('touch'),
      2. every block group present in `pairs` added as a node,
      3. node attributes area, perim, cd and pop taken from `nodes`,
      4. shortest-distance edges added until each district is connected,
      5. nearest-neighbour edges added to bring nodes toward `min_degree`.
    """
    edges = pairs[pairs['touch']]
    G = edges_to_graph(edges)
    # Populated block groups that touch no other populated block group (e.g. islands)
    # never appear in `edges`. Add them so they are connected below instead of dropped.
    G.add_nodes_from(pairs['geo_id_x'].unique())
    node_attr = ['area', 'perim', 'cd', 'pop']
    nx.set_node_attributes(G, nodes[node_attr].to_dict('index'))

    G = connect_districts(pairs, nodes, G)
    G = make_min_degree(pairs, G, min_degree)
    return G
            
G = make_graph(pairs, nodes=df, min_degree=min_degree)  # use the config-cell setting rather than a hard-coded value

In [None]:
%%time
df_map = df.copy().set_crs(crs_map).sort_values('cd')
df_map['geometry'] = df_map['geometry'].simplify(0.001)
df_map['cd_pop'] = df_map.groupby('cd')['pop'].transform('sum')
df_map['cd_label'] = df_map['cd'].astype(str) + ': pop=' + df_map['cd_pop'].astype(str)

fig = px.choropleth(df_map,
                    geojson=df_map['geometry'],
                    locations=df_map.index,
                    color="cd_label",
                    hover_data={'area': ':.0f',
                                'pop' : ':',
                               }
                   )
fig.update_geos(fitbounds="locations", visible=True)
fig.update_layout({
    'title' : {'text':f'{state["name"]} {yr}', 'x':0.5, 'y':1.0},
    'margin' : {"r":0,"t":20,"l":0,"b":0},
})
fig.show()