We originally started with district 4841, so some of our tables and code still use that label. Districts 4841 and 4840 should really be merged so that together they form the whole of district 4840.
The purpose of this script is to relabel every occurrence of 4841 as 4840.

In [None]:
from dotenv import load_dotenv
import sqlalchemy as sq
import geopandas as gpd  # type: ignore
import pandas as pd
import os, sys

sys.path.append("../")
from Shared.DataService import DataService
from Shared.GenericQueryBuilder import GenericQueryBuilder

In [None]:
# Table that stores the census agricultural region geometries.
AG_REGIONS_TABLENAME = "census_ag_regions"

# Let python-dotenv populate the process environment first, then read the
# Postgres connection settings; any setting that is absent stays None.
load_dotenv()
PG_DB = os.environ.get("POSTGRES_DB")
PG_ADDR = os.environ.get("POSTGRES_ADDR")
PG_PORT = os.environ.get("POSTGRES_PORT")
PG_USER = os.environ.get("POSTGRES_USER")
PG_PW = os.environ.get("POSTGRES_PW")

In [None]:
def getTables(conn: sq.engine.Connection) -> pd.DataFrame:
    """List the tables that live in the ``public`` schema.

    Args:
        conn: SQLAlchemy connection the query is executed on.

    Returns:
        DataFrame holding the query result, i.e. a single ``table_name``
        column.
    """
    publicTablesSql = """
        SELECT table_name FROM information_schema.tables
        WHERE table_schema='public';
        """
    return pd.read_sql_query(sq.text(publicTablesSql), conn)

In [None]:
def getColumns(table: str, conn: sq.engine.Connection) -> pd.DataFrame:
    """List the column names of one table in the ``public`` schema.

    Args:
        table: name of the table to inspect.
        conn: SQLAlchemy connection the query is executed on.

    Returns:
        DataFrame holding the query result, i.e. a single ``column_name``
        column.
    """
    # The table name is compared as a *value* here, so it is passed as a bound
    # parameter instead of being formatted into the SQL text; a name that
    # contains a quote can then neither break nor alter the query.
    columnQuery = sq.text(
        """
        SELECT column_name FROM information_schema.columns
        WHERE table_schema = 'public'
        AND table_name = :table;
        """
    ).bindparams(table=table)

    return pd.read_sql_query(columnQuery, conn)

In [None]:
def updateColumnName(table: str, oldName: str, newName: str, db: DataService):
    """Rename a column of a table in the ``public`` schema.

    The identifiers are formatted straight into the statement (SQL identifiers
    cannot be sent as bound parameters), so only pass trusted names.

    Args:
        table: table whose column is renamed.
        oldName: current column name.
        newName: name the column should have afterwards.
        db: service object used to execute the statement.
    """
    renameStatement = f"""
        ALTER TABLE public.{table}
        RENAME COLUMN {oldName} to {newName};
        COMMIT;
        """
    db.execute(sq.text(renameStatement))

In [None]:
def updateAllValues(
    table: str, colName: str, oldVal: int, newVal: int, db: DataService
):
    """Replace every occurrence of one value in a column with another value.

    The table name, column name and both values are formatted straight into
    the statement, so only pass trusted input.

    Args:
        table: table (in the ``public`` schema) to update.
        colName: column whose values are rewritten.
        oldVal: value to look for.
        newVal: value written in place of ``oldVal``.
        db: service object used to execute the statement.
    """
    replaceStatement = f"""
        UPDATE public.{table}
        SET {colName} = {newVal}
        WHERE {colName} = {oldVal};
        COMMIT;
        """
    db.execute(sq.text(replaceStatement))

In [None]:
def correctRegions(conn: sq.engine.Connection, db: DataService):
    """Fold district 4841 into district 4840 in the census ag regions table.

    Loads the whole regions table, gives row 4840 the union of the 4840 and
    4841 geometries, renames that region, removes row 4841, renames the
    ``car_uid`` column to ``district`` and writes the result back in place of
    the old table.

    Args:
        conn: SQLAlchemy connection used to read and re-create the table.
        db: service object used to execute the ``DROP TABLE`` statement.
    """
    regionQuery = sq.text("select * FROM public.census_ag_regions")
    agRegions = gpd.GeoDataFrame.from_postgis(
        regionQuery, conn, crs="EPSG:3347", geom_col="geometry"
    )

    # row selectors for the two districts being merged
    is4840 = agRegions["car_uid"] == 4840
    is4841 = agRegions["car_uid"] == 4841

    # merge both geometries into the row with district 4840
    agRegions.loc[is4840, "geometry"] = agRegions.loc[is4840 | is4841].unary_union

    # update the name to reflect the entirety of the district
    agRegions.loc[is4840, "car_name"] = "Census Agricultural Region 4"

    # Delete district 4841, which is now a part of 4840. DataFrame.drop returns
    # a new frame unless inplace=True, so the result has to be kept; the
    # original code discarded it and wrote the 4841 row back to the table.
    agRegions = agRegions.drop(index=agRegions.index[is4841])

    # rename the car_uid column to district
    agRegions = agRegions.rename(columns={"car_uid": "district"})

    # NOTE(review): to_postgis(if_exists="replace") already replaces an
    # existing table, so this explicit DROP looks redundant and leaves no
    # table behind if the write below fails - confirm before removing it.
    dropRegionsQuery = sq.text("DROP TABLE public.census_ag_regions;")
    db.execute(dropRegionsQuery)
    agRegions.to_postgis(AG_REGIONS_TABLENAME, conn, index=False, if_exists="replace")

In [None]:
def main():
    """Relabel district 4841 as 4840 in every table of the ``public`` schema.

    For each table:
      * a ``car_uid`` column in the ag regions table triggers the geometry
        merge done by ``correctRegions``;
      * a ``car_uid`` column in any other table is renamed to ``district`` and
        its 4841 values are rewritten to 4840;
      * an existing ``district`` column only has its values rewritten.

    A failure while updating one table is printed and the loop moves on to the
    next table. The database service is cleaned up even if listing the tables
    or columns fails.

    Raises:
        ValueError: if any of the Postgres environment variables is unset.
    """
    if (
        PG_DB is None
        or PG_ADDR is None
        or PG_PORT is None
        or PG_USER is None
        or PG_PW is None
    ):
        raise ValueError("Environment variables not set")

    db = DataService(PG_DB, PG_ADDR, int(PG_PORT), PG_USER, PG_PW)
    conn = db.connect()

    try:
        tableNames = getTables(conn)["table_name"].tolist()
        total = len(tableNames)

        # enumerate keeps the progress counter independent of the frame index
        for position, tableName in enumerate(tableNames, start=1):
            # list of the column names pulled from the current table
            columnList = getColumns(tableName, conn)["column_name"].tolist()

            try:
                if "car_uid" in columnList:
                    if tableName == AG_REGIONS_TABLENAME:
                        correctRegions(conn, db)
                    else:
                        updateColumnName(tableName, "car_uid", "district", db)
                        updateAllValues(tableName, "district", 4841, 4840, db)

                elif "district" in columnList:
                    updateAllValues(tableName, "district", 4841, 4840, db)

                print(f"[{position}/{total}] Finished updating {tableName}")
            except Exception as e:
                # deliberate best-effort: report the table and keep going
                print(f"[{position}/{total}] ERROR - could not update {tableName}")
                print(e)
    finally:
        # previously skipped whenever getTables/getColumns raised
        db.cleanup()

In [None]:
# Run the relabeling only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()