In [None]:
import requests
from collections import defaultdict
from functools import cache
import pandas as pd
from utils.utils import SimpleThreadPoolExecutor

In [None]:
from psql_utils.epsql import get_schema, get_table_name, sanitize_table_name
from psql_utils import epsql

@cache
def engine():
    """Return the process-wide ``epsql.Engine``.

    ``functools.cache`` memoizes this zero-argument call, so the engine is
    constructed on the first call and the same object is returned afterwards.
    """
    shared_engine = epsql.Engine()
    return shared_engine

def census_api_get(base_url, payload):
    """Issue one GET request to the Census API and return the rows as a DataFrame.

    Args:
        base_url: Full URL of the dataset endpoint to query.
        payload: Query parameters for the request. The dict is copied, so the
            caller's object never receives the API key.

    Returns:
        A ``pandas.DataFrame`` whose column names are the first element of the
        JSON response and whose rows are the remaining elements.

    Raises:
        OSError: If ``secrets/census_api_key.txt`` cannot be read.
        requests.HTTPError: If the server answers with a 4xx/5xx status.
    """
    payload = payload.copy()  # Don't modify the original
    # Read the key through a context manager so the file handle is closed
    # promptly instead of being left to the garbage collector.
    with open("secrets/census_api_key.txt") as key_file:
        payload['key'] = key_file.read().strip()
    # NOTE(review): no timeout is passed, so a stalled connection can block
    # indefinitely -- consider adding one once a sensible bound is known.
    response = requests.get(base_url, params=payload)
    if response.status_code != 200:
        print(response.status_code, response.text)
        # raise_for_status() only raises for 4xx/5xx; any other non-200 status
        # falls through to the JSON parsing below.
        response.raise_for_status()
    # Decode the body once; the first row holds the column headers.
    rows = response.json()
    return pd.DataFrame(rows[1:], columns=rows[0])

class CensusApiDataset():
    """Download the tables of one Census API dataset and load them into SQL.

    On construction the dataset's ``variables.json`` is fetched and its
    variables are grouped by table ("group") name.  The ``download_*`` methods
    then pull each table at several geometry levels and append the rows to SQL
    tables through the module-level ``engine()``.
    """

    # The Census API accepts at most this many fields in a single request.
    _API_FIELD_LIMIT = 50
    # Geometry levels accepted by download_table.
    _GEOMETRY_LEVELS = ('block', 'blockgroup', 'tract', 'county', 'place', 'state')

    def __init__(self, dataset_name):
        """Fetch the variable catalogue for ``dataset_name`` and index it by table.

        Args:
            dataset_name: Dataset path below ``https://api.census.gov/data/``.
                ``sql_table_name`` requires the form ``"<year>/<dataset>/<subfile>"``,
                e.g. ``"2020/dec/pl"``.
        """
        self.dataset_name = dataset_name
        # SQL schema that receives every table of this dataset.
        self.schema_name = "census"
        variables = requests.get(f'https://api.census.gov/data/{dataset_name}/variables.json').json()['variables']
        tables_build = defaultdict(dict)
        for var_name in sorted(variables):
            var_info = variables[var_name]
            # A variable may belong to several comma-separated groups; 'N/A'
            # marks variables that belong to no table.  A missing 'group' key
            # is treated like 'N/A' (defensive; TODO confirm it can be absent).
            for table_name in set(var_info.get('group', 'N/A').split(',')) - {'N/A'}:
                tables_build[table_name][var_name] = var_info

        # Maps table name -> {variable name -> variable metadata}, sorted by table name.
        self.tables = dict(sorted(tables_build.items()))
        print(f"{dataset_name}: found tables ({', '.join(self.tables.keys())})")

    @classmethod
    def _request_field_chunks(cls, fields: list[str]) -> list[list[str]]:
        """Split ``fields`` into per-request lists that each include GEO_ID.

        Every chunk takes one field fewer than the API limit from ``fields`` so
        that GEO_ID can be appended (when not already present) without exceeding
        the limit.  ``fields`` itself is never modified.
        """
        chunk_size = cls._API_FIELD_LIMIT - 1
        chunks = []
        for start in range(0, len(fields), chunk_size):
            chunk = fields[start:start + chunk_size]  # slicing copies
            if "GEO_ID" not in chunk:
                chunk.append("GEO_ID")
            chunks.append(chunk)
        return chunks

    def get_data(self, fields: list[str], geometry_level: str, in_: str = "") -> pd.DataFrame:
        """Fetch ``fields`` for every geography of ``geometry_level``.

        The request is split into several API calls when ``fields`` exceeds the
        per-request limit; the partial results are merged on the columns they
        have in common.

        Args:
            fields: Non-empty list of variable names to request.
            geometry_level: Geography for the ``for`` clause, e.g. ``"state"``.
                ``"blockgroup"`` is translated to the API's ``"block group"``.
            in_: Optional value for the API's ``in`` clause,
                e.g. ``"state:01 county:001"``.

        Returns:
            One row per geography.  A ``"block group"`` column is renamed to
            ``"blockgroup"``.  When ``"GEO_ID"`` is among ``fields`` a ``geoid``
            column is added; otherwise the GEO_ID column is dropped and no
            ``geoid`` column is produced.

        Raises:
            AssertionError: If ``fields`` is empty.
        """
        if geometry_level == "blockgroup":
            geometry_level = "block group"
        assert fields, "fields must not be empty"
        data: pd.DataFrame | None = None
        # Observe the census api limit of 50 fields per request, and make sure to add GEO_ID
        # to each subrequest so we can join the results
        for subfields in self._request_field_chunks(fields):
            payload = {
                'get': ','.join(subfields),
                'for': f'{geometry_level}:*'
            }
            if in_:
                payload['in'] = in_

            subdata = census_api_get(f'https://api.census.gov/data/{self.dataset_name}', payload)
            if data is None:
                data = subdata
            else:
                # Join on all columns that are in both dataframes, as merge will not duplicate and add suffixes
                # to any columns in the "on" argument.
                # We assume that any duplicate columns (e.g. state, county) have the same data, and that GEO_ID
                # in particular is specific enough to guarantee like records are combined.
                # The shared columns are listed in the left frame's order so the
                # join key order is deterministic.
                shared = set(subdata.columns)
                join_columns = [column for column in data.columns if column in shared]
                data = data.merge(subdata, on=join_columns)

        assert data is not None
        if 'GEO_ID' not in fields:
            # GEO_ID was only requested as a join key; do not hand it back.
            data.drop(columns=['GEO_ID'], inplace=True)

        if "block group" in data.columns:
            data.rename(columns={"block group": "blockgroup"}, inplace=True)

        if "GEO_ID" in data.columns:
            # Keep everything after the first 9 characters of GEO_ID
            # (presumably a prefix such as "0400000US" -- TODO confirm).
            data['geoid'] = data['GEO_ID'].str[9:]
        return data

    def _geoid_index(self, geometry_level: str) -> dict[str, dict]:
        """Return ``{geoid: record}`` for ``geometry_level``, fetched once per instance.

        The result is stored on the instance rather than in a ``functools.cache``
        wrapped around the method: such a cache is keyed on ``self`` and would
        keep every instance (and its variable catalogue) alive for the life of
        the process.
        """
        # Created lazily so instances built without __init__ work as well.
        cached = self.__dict__.setdefault("_geoid_index_cache", {})
        if geometry_level not in cached:
            rows = self.get_data(["NAME", "GEO_ID"], geometry_level)
            cached[geometry_level] = dict(sorted(zip(rows["geoid"], rows.to_dict('records'))))
        return cached[geometry_level]

    def get_states(self) -> dict[str, dict]:
        """Return all state records keyed by geoid, sorted; cached after the first call."""
        return self._geoid_index("state")

    def get_counties(self) -> dict[str, dict]:
        """Return all county records keyed by geoid, sorted; cached after the first call."""
        return self._geoid_index("county")

    def sql_table_name(self, table_name: str, geometry_level: str):
        """Build the sanitized, schema-qualified SQL name for a table and geometry level.

        Raises:
            AssertionError: If ``dataset_name`` does not consist of exactly
                three ``/``-separated parts.
        """
        tokens = self.dataset_name.split("/")
        assert len(tokens) == 3, "dataset_name must look like '<year>/<dataset>/<subfile>'"
        (year, dataset, subfile) = tokens
        return sanitize_table_name(f"{self.schema_name}.{dataset}{year}{subfile}_{table_name}_{geometry_level}")

    def _table_fields(self, table_name: str) -> list[str]:
        """List every variable of ``table_name`` followed by its attribute variables.

        A variable without an ``attributes`` entry contributes only itself, and
        repeated names are collapsed to their first occurrence so the same
        field is never requested twice.
        """
        fields = []
        for var_name, var_info in self.tables[table_name].items():
            fields.append(var_name)
            attributes = var_info.get('attributes', '')
            fields.extend(name for name in attributes.split(',') if name)
        # dict.fromkeys removes duplicates while preserving order.
        return list(dict.fromkeys(fields))

    def _download_plan(self, geometry_level: str) -> list[dict]:
        """Return the list of shards a geometry level is downloaded in.

        Each shard is a dict that is either empty (a single unrestricted
        request) or holds ``"in"``, the API ``in`` clause for the shard, and
        ``"sql"``, a WHERE condition selecting that shard's rows in the SQL table.
        """
        if geometry_level == 'tract':
            # Tracts are fetched one state at a time.
            return [{
                "sql": f"geoid between '{geoid}' and '{geoid}z'",
                "in": f"state:{geoid}"
            } for geoid in self.get_states()]
        if geometry_level in ('blockgroup', 'block'):
            # Block groups and blocks are fetched one county at a time; a county
            # geoid is split into its first two and next three characters.
            return [{
                "sql": f"geoid between '{geoid}' and '{geoid}z'",
                "in": f"state:{geoid[0:2]} county:{geoid[2:5]}"
            } for geoid in self.get_counties()]
        return [{}]

    # Download all records for table_name
    def download_table(self, table_name: str, geometry_level: str):
        """Download ``table_name`` at ``geometry_level`` and append it to its SQL table.

        The download is split into shards (see ``_download_plan``).  A shard
        whose rows are already present in the SQL table is skipped, so an
        interrupted run can be resumed.  After a shard is written, a primary
        key on ``geoid`` is added if the table has none.

        Raises:
            AssertionError: If ``table_name`` is not a table of this dataset or
                ``geometry_level`` is not a supported level.
        """
        # Validate before printing or touching the database.
        assert table_name in self.tables
        assert geometry_level in self._GEOMETRY_LEVELS

        print(f"Downloading {table_name} ({geometry_level})")
        sql_table_name = self.sql_table_name(table_name, geometry_level)
        engine().execute(f"CREATE SCHEMA IF NOT EXISTS {get_schema(sql_table_name)}")

        fields = self._table_fields(table_name)
        downloads = self._download_plan(geometry_level)
        n_done = 0
        for download in downloads:
            sql = f"SELECT count(*) from {sql_table_name}"
            if "sql" in download:
                sql += f" WHERE {download['sql']}"
            in_ = download.get("in", "")
            # Any existing row for this shard means it was loaded by an earlier run.
            if engine().table_exists(sql_table_name) and engine().execute_returning_value(sql) > 0:
                n_done += 1
                continue
            table = self.get_data(fields, geometry_level, in_)
            table.to_sql(get_table_name(sql_table_name), engine().engine, schema=get_schema(sql_table_name), if_exists='append', index=False)
            if not engine().table_has_primary_key(sql_table_name):
                engine().execute(f'ALTER TABLE {sql_table_name} ADD PRIMARY KEY (geoid)')
            print(f"{sql_table_name} {in_} loaded {len(table)} records of {len(table.columns)} fields")
        if n_done:
            print(f"{sql_table_name} {n_done} of {len(downloads)} downloads previously completed")
        print(f"{table_name} ({geometry_level}) is complete")

    def download_table_geometries(self, table_name: str):
        """Download ``table_name`` at the state, county, tract, blockgroup and block levels, in that order."""
        for geometry_level in ['state', 'county', 'tract', 'blockgroup', 'block']:
            self.download_table(table_name, geometry_level)

    def download_tables_geometries(self, nthreads: int = 10):
        """Download every table of the dataset at all geometry levels using a thread pool.

        Args:
            nthreads: Value passed to ``SimpleThreadPoolExecutor`` (presumably
                the number of worker threads -- TODO confirm).
        """
        pool = SimpleThreadPoolExecutor(nthreads)
        print(f"Downloading tables {', '.join(self.tables.keys())}")
        for table_name in self.tables:
            pool.submit(self.download_table_geometries, table_name)
        # NOTE(review): presumably waits for the submitted jobs to finish -- confirm
        # against SimpleThreadPoolExecutor.
        pool.shutdown()

    # def geocode_in_place(self, table_name: str, idx_name: str = 'idx', chunk_size: int = 500, nthreads:int = 15):
    #     # Performance is around 400 geocodes per second, on hal21 with 15 threads
    #     min_idx = self.execute_returning_dicts(f'select min(idx) from {table_name}')[0]['min']
    #     max_idx = self.execute_returning_dicts(f'select max(idx) from {table_name}')[0]['max']
    #     print(f'geocode_in_place: {idx_name} ranges from {min_idx} to {max_idx}')
    #     print(max_idx)
    #     pool = SimpleThreadPoolExecutor(nthreads)
    #     chunks = list(range(min_idx, max_idx + 1, chunk_size))
    #     print(f'Geocoding in {len(chunks)} chunks of size {chunk_size}')
    #     for chunk in chunks:
    #         pool.submit(self.geocode_chunk_in_place, table_name, chunk, min(chunk + chunk_size - 1, max_idx), idx_name)
    #     pool.shutdown()

    #     for table_name in list(self.tables.keys()):
    #         for geometry_level in ['state', 'county', 'tract', 'blockgroup', 'block']:
    #             self.get_table(table_name, geometry_level)


# Driver cell: build the catalogue for the 2020 decennial PL dataset and download
# every table at every geometry level.  The commented-out lines are scratch
# commands for interactive use.
# NOTE(review): the get_table calls below refer to a method that the
# CensusApiDataset class in this file does not define.
#engine().execute("drop table census.dec2020pl_p2_tract")
ds = CensusApiDataset(dataset_name="2020/dec/pl")
#ds.get_table("H1", "state")
#ds.get_table("H1", "county")["state"].unique()
#ds.get_table("H1", "tract")
#ds.get_table("H1", "blockgroup")
#ds.get_data(["NAME", "GEO_ID"], "state")
ds.download_tables_geometries()
#ds.download_table_geometries("P1")

