In [None]:
import os, requests, sys, threading, time
from collections import defaultdict
from functools import cache
import pandas as pd
from utils.utils import SimpleThreadPoolExecutor, PrCall, ThCall
import numpy as np
from psql_utils.epsql import get_schema, get_table_name, sanitize_table_name, sanitize_column_names
from psql_utils import epsql
from tqdm.notebook import tqdm
from typing import Dict
import sqlalchemy.exc
import psycopg2.errors


# Cache of database engines, one per (process id, thread id) pair; filled lazily by engine().
engine_dict = {}
def engine() -> epsql.Engine:
    """Return the epsql.Engine owned by the calling process/thread, building it on first use.

    The cache key includes the PID, so a child process (which has a different PID)
    creates its own engine instead of reusing an entry copied from its parent.
    """
    owner = (os.getpid(), threading.get_ident())
    cached = engine_dict.get(owner)
    if cached is None:
        cached = epsql.Engine(verbose=False)
        engine_dict[owner] = cached
    return cached

In [None]:
def census_api_get(base_url, payload, retries=5, retry_delay_s=5, timeout=600):
    """GET a Census API endpoint and return the response table as a DataFrame.

    The API key is read from ``secrets/census_api_key.txt`` and sent as the ``key``
    query parameter. The first row of the JSON response becomes the column names and
    the remaining rows become the data (values are left exactly as returned).

    Args:
        base_url: Dataset endpoint, e.g. ``https://api.census.gov/data/2000/dec/sf1``.
        payload: Query parameters. Copied; the caller's dict is never modified.
        retries: Total number of attempts before giving up.
        retry_delay_s: Seconds to wait before each retry.
        timeout: Passed to ``requests.get`` so a stalled connection cannot block
            the calling thread forever.

    Raises:
        requests.HTTPError: Immediately on a 4xx response (client errors are not
            retried), or on a 5xx response at the last attempt.
        Exception: Whatever ``requests.get`` raised on the last attempt.
        RuntimeError: If no attempt returned status 200 for any other reason
            (e.g. every attempt returned a different 2xx status).
    """
    with open("secrets/census_api_key.txt") as key_file:
        api_key = key_file.read().strip()
    params = dict(payload, key=api_key)  # copy so the caller's payload is untouched

    def redact(text):
        # The key travels in the query string, so it appears in response URLs and in
        # requests' exception messages; keep it out of notebook output.
        text = str(text)
        return text.replace(api_key, "<API_KEY>") if api_key else text

    # Human-readable description of the request that never contains the key.
    request_desc = f"{base_url} {payload}"
    last_status = None
    for attempt in range(1, retries + 1):
        if attempt > 1:
            # Back off before every retry, including after network errors and 5xx responses.
            time.sleep(retry_delay_s)
            print(f"Retry {attempt} of {retries} for GET {request_desc}")
        try:
            response = requests.get(base_url, params=params, timeout=timeout)
        except Exception as e:
            print(f"During try {attempt} of {retries} for GET {request_desc}, received exception {redact(e)}")
            if attempt == retries:
                print("Aborting since this is the last retry")
                raise
            continue

        status = response.status_code
        last_status = status
        if status == 200:
            if attempt > 1:
                print(f"On retry {attempt}, successful GET {redact(response.url)}")
            rows = response.json()
            return pd.DataFrame(rows[1:], columns=rows[0])
        if status // 100 == 4:
            # 4xx errors are client errors, so don't retry
            print(f"During try {attempt} of {retries} for GET {request_desc}, aborting due to client error "
                  f"status code {status} {redact(response.text)}")
            raise requests.HTTPError(f"{status} Client Error for GET {request_desc}", response=response)
        if status // 100 == 5:
            # Server-side problem: retry unless this was the last attempt.
            print(f"During try {attempt} of {retries} for GET {request_desc}, received server error status code {status}")
            if attempt == retries:
                print("Aborting since this is the last retry")
                raise requests.HTTPError(f"{status} Server Error for GET {request_desc}", response=response)
        # Any other status (e.g. a 2xx other than 200): retry.
    raise RuntimeError(f"GET {request_desc} did not return status 200 in {retries} tries (last status {last_status})")

class CensusApiDataset():
    """A Census API dataset (e.g. ``"2000/dec/sf1"``) whose tables are mirrored into PostgreSQL.

    Construction fetches the dataset's ``variables.json``, applies the metadata patches
    listed in ``__init__``, and groups variables into tables by their ``group`` field.
    Each (table, geometry level) pair is stored as
    ``<schema_name>.<dataset><year><subfile>_<table>_<geometry level>`` after
    ``sanitize_table_name``.
    """

    # Census ``predicateType`` -> dtype that get_data converts each column to.
    # 'int' uses pandas' nullable Int32; get_data still rejects NULLs explicitly.
    _PREDICATE_DTYPES = {
        'int': pd.Int32Dtype(),
        'string': object,
        'float': np.float32,
    }

    # Serialises the one-time state/county lookups so concurrent download threads share
    # a single Census API request instead of each issuing their own. It lives on the
    # class because a Lock stored on the instance would make instances unpicklable.
    _geography_cache_lock = threading.Lock()

    def __init__(self, dataset_name):
        """Fetch and index the variable metadata of `dataset_name` (e.g. ``"2010/dec/sf1"``)."""
        self.dataset_name = dataset_name
        self.schema_name = "census"
        response = requests.get(f'https://api.census.gov/data/{dataset_name}/variables.json', timeout=600)
        # Fail with the HTTP status rather than a confusing JSON decode error.
        response.raise_for_status()
        unsorted_vars = response.json()['variables']

        # Metadata overrides, e.g. supplying the predicateType that get_data requires
        # when the published variable metadata omits it.
        patches = {
            '2000/dec/sf1': {
                'P001001': {'predicateType': 'int'},
                'P004001': {'predicateType': 'int'}
            },
            '2000/dec/sf2': {
                'HCT004001': {'predicateType': 'int'}
            },
            '2010/dec/sf1': {
                'P001001': {'predicateType': 'int'},
            }
        }
        for col, patch in patches.get(dataset_name, {}).items():
            for key, value in patch.items():
                print(f"Patching {dataset_name}[{repr(col)}][{repr(key)}]={repr(value)}")
                unsorted_vars[col][key] = value

        self.variables = dict(sorted(unsorted_vars.items()))

        # A variable may belong to several comma-separated groups; 'N/A' means no table.
        tables_build = defaultdict(dict)
        for var_name, var_info in self.variables.items():
            for table_name in set(var_info['group'].split(',')) - {'N/A'}:
                tables_build[table_name][var_name] = var_info
        self.tables: Dict[str, dict] = dict(sorted(tables_build.items()))

        # Every table is fetched at state and county level; tables whose names start with
        # PCO/HCO stop there, PCT/HCT tables also get tract, and all others also get
        # blockgroup and block.
        self.table_geom_levels = defaultdict(list)
        for table_name in self.tables:
            levels = self.table_geom_levels[table_name]
            levels.extend(['state', 'county'])
            if table_name.startswith('PCO') or table_name.startswith('HCO'):
                continue
            levels.append('tract')
            if table_name.startswith('PCT') or table_name.startswith('HCT'):
                continue
            levels.extend(['blockgroup', 'block'])

        # Columns that failed dtype conversion in get_data, kept for developer inspection.
        self.bad_cols = {}
        print(f"{dataset_name}: found tables ({', '.join(self.tables.keys())})")

    def get_data(self, fields: list[str], geometry_level: str, in_: str = ""):
        """Query `fields` for every geography at `geometry_level` and return a typed DataFrame.

        Args:
            fields: Census variables (or expressions such as ``group(P1)``) for ``get=``.
            geometry_level: 'state', 'county', 'tract', 'blockgroup' or 'block'.
            in_: Optional ``in=`` clause restricting the query, e.g. ``"state:42 county:*"``.

        Columns described in ``self.variables`` are converted to the dtype of their
        ``predicateType``. A ``block group`` column is renamed ``blockgroup``, and when
        ``GEO_ID`` is present a ``geoid`` column is added holding GEO_ID without its
        first 9 characters.

        Raises:
            RuntimeError: if a variable lacks or has an unsupported predicateType, if a
                column contains NULLs, or if any column cannot be converted (the
                offending raw columns are stored in ``self.bad_cols``).
        """
        if geometry_level == "blockgroup":
            geometry_level = "block group"  # the API's name for this level contains a space
        assert fields
        payload = {
            'get': ','.join(fields),
            'for': f'{geometry_level}:*'
        }
        if in_:
            payload['in'] = in_

        data = census_api_get(f'https://api.census.gov/data/{self.dataset_name}', payload)

        conversion_exceptions = []
        for col in data.columns:
            var_info = self.variables.get(col)
            if not var_info:
                # No metadata for this column (presumably geography columns); leave as returned.
                continue
            if 'predicateType' not in var_info:
                raise RuntimeError(
                    f"var_info for {self.dataset_name}.{col} is missing predicateType\n"
                    f"var_info = {var_info}"
                )
            predicate_type = var_info['predicateType']
            if predicate_type not in self._PREDICATE_DTYPES:
                raise RuntimeError(
                    f"{self.dataset_name}.{col} has unsupported predicateType {predicate_type!r}"
                )
            required_type = self._PREDICATE_DTYPES[predicate_type]
            if data[col].dtype != required_type:
                try:
                    data[col] = data[col].astype(required_type)
                except (TypeError, ValueError) as e:
                    # astype raises ValueError for unparseable strings and TypeError for
                    # unsupported objects; collect both so every bad column is reported.
                    self.bad_cols[col] = data[col]
                    conversion_exceptions.append(
                        f"Cannot convert {self.dataset_name}.{col} to type {required_type}\n"
                        f"(Column stored as self.bad_cols[{col}] for developer inspection)\n"
                        f"Source data from census API contains:\n"
                        f"{data[col].map(type).value_counts().to_string()}\n"
                        f"(Exception: {e})")
            null_count = data[col].isna().sum()
            if null_count:
                raise RuntimeError(f"{self.dataset_name}.{col} {geometry_level} contains {null_count} NULLs of {len(data[col])} values")

        if conversion_exceptions:
            raise RuntimeError("----------------\n".join(conversion_exceptions))

        if "block group" in data.columns:
            data = data.rename(columns={"block group": "blockgroup"})

        if "GEO_ID" in data.columns:
            data['geoid'] = data['GEO_ID'].str[9:]
        return data

    def _geographies(self, geometry_level: str):
        """Return ``{geoid: row record}`` sorted by geoid for every geography at `geometry_level`.

        Fetched at most once per instance; later calls return the same dict object.
        """
        cache_attr = f"_geographies_{geometry_level}"
        with self._geography_cache_lock:
            cached = getattr(self, cache_attr, None)
            if cached is None:
                frame = self.get_data(["NAME", "GEO_ID"], geometry_level)
                cached = dict(sorted(zip(frame["geoid"], frame.to_dict('records'))))
                setattr(self, cache_attr, cached)
        return cached

    def get_states(self):
        """Return ``{state geoid: row record}`` for all states (cached per instance)."""
        return self._geographies("state")

    def get_counties(self):
        """Return ``{county geoid: row record}`` for all counties (cached per instance)."""
        return self._geographies("county")

    def sql_table_name(self, table_name: str, geometry_level: str):
        """Return the sanitized ``schema.table`` name storing `table_name` at `geometry_level`."""
        tokens = self.dataset_name.split("/")
        assert len(tokens) == 3
        (year, dataset, subfile) = tokens
        return sanitize_table_name(f"{self.schema_name}.{dataset}{year}{subfile}_{table_name}_{geometry_level}")

    def download_table(self, table_name: str, geometry_level: str):
        """Download all records of `table_name` at `geometry_level` into PostgreSQL.

        Tract, blockgroup and block levels are downloaded one state at a time
        (``in=state:<fips> county:*``); state and county levels in a single request.
        A shard is skipped when the table already holds a row whose geoid falls in
        that shard's range, so interrupted downloads can be resumed.
        """
        sql_table_name = self.sql_table_name(table_name, geometry_level)
        engine().execute(f"CREATE SCHEMA IF NOT EXISTS {get_schema(sql_table_name)}")

        assert table_name in self.tables
        assert geometry_level in self.table_geom_levels[table_name]

        if geometry_level in ['tract', 'blockgroup', 'block']:
            downloads = [{
                "sql": f"geoid between '{geoid}' and '{geoid}z'",
                "in": f"state:{geoid} county:*"
            } for geoid in self.get_states().keys()]
        else:
            downloads = [{}]

        # The progress bar is created lazily at the first shard that actually needs
        # downloading, so tables that are already fully loaded print nothing.
        pbar = None
        try:
            for i, download in enumerate(downloads):
                sql = f"SELECT geoid from {sql_table_name}"
                if "sql" in download:
                    sql += f" WHERE {download['sql']}"
                sql += " LIMIT 1"
                in_ = download.get("in", "")
                if engine().table_exists(sql_table_name) and len(engine().execute_returning_dicts(sql)) > 0:
                    if pbar is not None:
                        pbar.update()
                    continue

                if pbar is None:
                    # Start at i: every earlier shard was skipped as already loaded.
                    pbar = tqdm(total=len(downloads), desc=sql_table_name, initial=i)

                PrCall(self.get_data_and_insert, table_name, geometry_level, sql_table_name, in_).value()
                pbar.update()

            self.add_primary_key(sql_table_name)
        finally:
            if pbar is not None:
                pbar.close()

    def get_data_and_insert(self, table_name, geometry_level, sql_table_name, in_):
        """Fetch group `table_name` for `geometry_level` (restricted by `in_`) and append it to `sql_table_name`."""
        table = self.get_data([f"group({table_name})"], geometry_level, in_)
        sanitize_column_names(table, inplace=True)
        try:
            table.to_sql(get_table_name(sql_table_name), engine().engine, schema=get_schema(sql_table_name), if_exists='append', index=False)
        except (sqlalchemy.exc.IntegrityError, psycopg2.errors.UniqueViolation) as e:
            print(f"While get_data_and_insert for {sql_table_name}, in_ {in_}, got exception {e}", flush=True)
            raise
        self.add_primary_key(sql_table_name)

    def add_primary_key(self, sql_table_name):
        """Give `sql_table_name` a primary key if it has none: (geoid), or (geoid, popgroup) when geoids repeat.

        Raises:
            RuntimeError: if geoids repeat and the table has no popgroup column.
        """
        if engine().table_has_primary_key(sql_table_name):
            return
        try:
            engine().execute(f'ALTER TABLE {sql_table_name} ADD PRIMARY KEY (geoid)')
        except (sqlalchemy.exc.IntegrityError, psycopg2.errors.UniqueViolation) as e:
            # Duplicate geoids are accepted only when a popgroup column can disambiguate them.
            if not engine().table_column_exists(sql_table_name, 'popgroup'):
                raise RuntimeError(f"Table {sql_table_name} has duplicate geoids, but no popgroup column.") from e
            engine().execute(f'ALTER TABLE {sql_table_name} ADD PRIMARY KEY (geoid, popgroup)')

    def download_table_geometries(self, table_name: str):
        """Download `table_name` at every geometry level it supports."""
        for geometry_level in self.table_geom_levels[table_name]:
            self.download_table(table_name, geometry_level)

    def download_tables_geometries(self, nthreads: int = 15):
        """Download every table at every supported level, using `nthreads` worker threads."""
        print(f"Downloading tables {', '.join(self.tables.keys())}", flush=True)
        if nthreads == 1:
            for table_name in self.tables.keys():
                self.download_table_geometries(table_name)
        else:
            pool = SimpleThreadPoolExecutor(nthreads)
            for table_name in self.tables.keys():
                pool.submit(self.download_table_geometries, table_name)
            pool.shutdown(tqdm=tqdm(desc=self.dataset_name, colour="red"))

def display_storage(schema_name="census", poll_interval_s=60):
    """Show a live tqdm bar with the on-disk size of `schema_name`, refreshed every `poll_interval_s` seconds.

    Never returns; meant to be run on a background thread. The size comes from
    ``engine().list_schema_sizes()`` ('size_mb' column, scaled by 1e6 to bytes).
    """
    bar = None
    while True:
        sizes = engine().list_schema_sizes()
        matching = sizes.loc[sizes['schema_name'] == schema_name, 'size_mb']
        # The schema may not exist yet (it is created by the first download), so report
        # 0 instead of letting .iloc[0] raise IndexError and kill this monitoring thread.
        size = matching.iloc[0] * 1e6 if len(matching) else 0
        if bar is None:
            bar = tqdm(desc=f"{schema_name} schema size", colour="red", unit="B", initial=size, unit_scale=True)
        else:
            bar.update(size - bar.n)
        time.sleep(poll_interval_s)

# Start the storage monitor only once per kernel: display_storage loops forever, so
# re-running this cell would otherwise stack up another polling thread and progress bar.
if "storage_monitor" not in globals():
    storage_monitor = ThCall(display_storage)

# Datasets to mirror; uncomment entries to (re)download them.
DATASETS = [
    # "2020/dec/pl",
    # "2010/dec/sf1",
    # "2010/dec/sf2",
    "2000/dec/sf1",
    "2000/dec/sf2",
    "2000/dec/sf3",
]

for dataset in DATASETS:
    ds = CensusApiDataset(dataset)
    ds.download_tables_geometries(nthreads=15)

In [None]:
# for table in engine().list_tables("census"):
#     if table.startswith("dec2020"):
#         engine().execute("drop table census." + table)
#     else:
#         print(table)

In [None]:
# Spot-check the POPGROUP column of one downloaded SF2 table.
# NOTE(review): add_primary_key looks for a lowercase 'popgroup' column; if
# sanitize_column_names lowercases names, the quoted "POPGROUP" below will not match — confirm.
popgroup_sample_sql = 'select "POPGROUP" from census.dec2010sf2_hct1_state limit 10'
engine().execute_returning_df(popgroup_sample_sql)