In [None]:
# Deprecation notice printed each time this cell/module is executed.
print('postgres-utils.ipynb is deprecated;  consider switching to psql.py')

def drop_table(con, table_name):
    """Drop ``table_name`` and commit the change.

    Args:
        con: Open database connection providing ``cursor()`` and ``commit()``.
        table_name: Name of the table to drop.  It is interpolated into the
            SQL text verbatim, so it must come from a trusted source.

    Raises:
        Whatever the driver raises when the statement fails.  The cursor is
        still closed in that case and nothing is committed.
    """
    cur = con.cursor()
    try:
        cur.execute(f"DROP TABLE {table_name};")
    finally:
        # Release the cursor even when the statement fails.
        cur.close()
    con.commit()

def psql_sanitize_column_name(colname):
    """Reduce a column label to word characters separated by underscores.

    Leading and trailing runs of non-word characters are removed, then every
    remaining run of non-word characters is replaced by a single ``_``.

    Args:
        colname: Column label as a string.

    Returns:
        The sanitized name; an empty string if ``colname`` contains no word
        characters.
    """
    # Raw strings are required here: '\W' inside a plain string literal is an
    # invalid escape sequence that newer Python versions warn about.
    trimmed = re.sub(r'^\W+|\W+$', '', colname)
    return re.sub(r'\W+', '_', trimmed)

# Short alias for psql_sanitize_column_name; the table helpers in this file
# call the sanitizer through this name.
sanitize_column_name = psql_sanitize_column_name

def psql_create_empty_table_from_df(con, table_name, df, override_types=None, primary_key=None, dry_run=False):
    """Build a CREATE TABLE statement mirroring ``df``'s columns and run it.

    Each DataFrame column becomes one table column: the name is passed through
    ``sanitize_column_name`` and the SQL type is derived from the column dtype.

    Args:
        con: Open database connection, handed to ``psql_execute``.
        table_name: Name of the table to create (interpolated verbatim).
        df: DataFrame whose columns and dtypes define the table layout.
        override_types: Optional mapping of original (unsanitized) column
            name to SQL type; takes precedence over the dtype-based mapping.
        primary_key: Optional original column name to mark ``PRIMARY KEY``.
        dry_run: If true, print the statement instead of executing it.

    Raises:
        KeyError: If a column's dtype has no SQL mapping and no override was
            supplied for it.
    """
    if override_types is None:
        override_types = {}

    type_map = {
        np.dtype('O'): 'text',
        np.dtype('float64'): 'float8',
        np.dtype('int64'): 'int8',
        np.dtype('bool'): 'bool',
        np.dtype('datetime64[ns]'): 'date',  # date without time or timezone.  use override_types for other choices
    }

    def col_constraint(col):
        return " PRIMARY KEY" if col == primary_key else ""

    def col_type(col):
        if col in override_types:
            return override_types[col]
        dtype = df[col].dtype
        try:
            # If geopandas is loaded, look for GeometryDtype
            if isinstance(dtype, gpd.array.GeometryDtype):
                return 'geometry'
        except (NameError, AttributeError):
            # geopandas is not available under the name ``gpd`` (or lacks
            # GeometryDtype); fall back to the plain dtype mapping.
            pass
        try:
            return type_map[dtype]
        except KeyError:
            raise KeyError(
                f"No SQL type mapping for column {col!r} with dtype {dtype}; "
                f"supply one via override_types"
            ) from None

    # Names are left-justified to 63 characters so the types line up
    # (presumably PostgreSQL's maximum identifier length -- confirm).
    sql_cols = ',\n'.join(
        f"    {sanitize_column_name(col):63s} {col_type(col)}{col_constraint(col)}"
        for col in df.columns
    )

    cmd = f"""CREATE TABLE {table_name} (
{sql_cols}
);"""
    if dry_run:
        print(cmd)
    else:
        psql_execute(con, cmd)


# Short aliases for psql_create_empty_table_from_df.  The misspelled "colums"
# name is kept for backward compatibility; the correctly spelled name is
# provided so that either spelling resolves.
create_empty_table_from_df_colums = psql_create_empty_table_from_df
create_empty_table_from_df_columns = psql_create_empty_table_from_df

def psql_append_df_to_table(con, df, table_name):
    """Append every row of ``df`` to ``table_name`` using a CSV ``COPY``.

    The DataFrame is serialized to an in-memory CSV (without the index) and
    streamed to the server with ``cursor.copy_expert``; the transaction is
    committed afterwards and a summary line is written to stdout.

    Args:
        con: Open database connection providing ``cursor()`` and ``commit()``;
            the cursor must support ``copy_expert``.
        df: DataFrame to append.  Its column names are passed through
            ``sanitize_column_name`` to form the target column list.
        table_name: Destination table (interpolated verbatim).

    Raises:
        Whatever the driver raises if the copy fails.  The cursor is still
        closed in that case and nothing is committed.
    """
    col_names = [sanitize_column_name(c) for c in df.columns]
    with Stopwatch(f'Creating csv of {len(df)} records for {table_name}'):
        csv_buffer = io.StringIO(df.to_csv(index=False))
    with Stopwatch(f'Appending csv of {len(df)} records to {table_name}'):
        cur = con.cursor()
        try:
            # postgres ignores CSV header!  so be sure we specify the column names correctly
            cur.copy_expert(sql=f"COPY {table_name} ({','.join(col_names)}) FROM stdin DELIMITER ',' CSV header;",
                            file=csv_buffer)
        finally:
            # Release the cursor even when the copy fails.
            cur.close()
        con.commit()
    sys.stdout.write(f'Wrote {len(df)} records to {table_name}\n')
    sys.stdout.flush()

# Short alias for psql_append_df_to_table.
append_df_to_table=psql_append_df_to_table

def psql_read_table_as_df(con, table_name):
    """Load an entire table into a pandas DataFrame.

    Args:
        con: Open database connection providing ``cursor()``.
        table_name: Table to read (interpolated verbatim into the query).

    Returns:
        A DataFrame with one row per record, with columns named after the
        result-set columns reported by the cursor.
    """
    cur = con.cursor()
    try:
        cur.execute(f'SELECT * FROM {table_name};')
        col_names = [col.name for col in cur.description]
        data = cur.fetchall()
    finally:
        # The cursor is only needed to fetch the rows; always release it.
        cur.close()
    print(f'Read {len(data)} records from {table_name}')
    return pd.DataFrame(data, columns=col_names)

# Short alias for psql_read_table_as_df.
read_table_as_df = psql_read_table_as_df

def read_table_as_gdf(con, table_name, cmd=None, cmd_args=None):
    """Run a query and return its rows as a GeoDataFrame.

    Args:
        con: Open database connection, handed to ``psql_select_records``.
        table_name: Table to read in full when ``cmd`` is not given; unused
            when ``cmd`` is supplied.
        cmd: Optional SQL statement to run instead of ``SELECT *``.
        cmd_args: Optional query parameters for ``cmd``; defaults to ``()``.

    Returns:
        A GeoDataFrame built from the selected records with its CRS set to
        EPSG:4326, or ``None`` when the query returns no records.
    """
    query = f'SELECT * FROM {table_name};' if cmd is None else cmd
    params = () if cmd_args is None else cmd_args

    rows = psql_select_records(con, query, params)
    if not rows:
        return None

    # NOTE(review): the {'init': ...} CRS form is reported as deprecated by
    # newer pyproj/geopandas releases -- consider 'EPSG:4326' after confirming
    # that callers do not depend on the exact CRS object.
    return gpd.GeoDataFrame(rows, crs={'init': 'epsg:4326'})

def psql_select_records(con, cmd, args=()):
    """Run a query and return its rows as a list of dicts.

    Columns whose type code matches the database's ``geometry`` type (looked
    up in ``pg_catalog.pg_type``) are decoded from hex WKB into shapely
    objects; NULL values in those columns become ``np.nan``.

    Args:
        con: Open database connection providing ``cursor()``.
        cmd: SQL statement to execute.
        args: Query parameters passed to ``cursor.execute``.

    Returns:
        A list with one ``{column_name: value}`` dict per record, empty when
        the query returns no rows.
    """
    cur = con.cursor()
    try:
        # find type_code for postgis geometry type
        cur.execute("SELECT oid FROM pg_catalog.pg_type WHERE typname='geometry';")
        type_rows = cur.fetchall()
        geometry_type_code = type_rows[0][0] if len(type_rows) == 1 else None

        cur.execute(cmd, args)
        # Keep a reference to the column metadata so it can be used after the
        # cursor has been closed.
        description = cur.description
        col_names = [col.name for col in description]
        records = cur.fetchall()
    finally:
        # The cursor is only needed to fetch the rows; always release it.
        cur.close()

    if records and geometry_type_code is not None:
        geometry_columns = [
            i for i, col in enumerate(description)
            if col.type_code == geometry_type_code
        ]
        if geometry_columns:
            # Rows may be tuples; copy them into lists so they are mutable.
            records = [list(record) for record in records]
            for i in geometry_columns:
                print(f'Converting {description[i].name} (column {i}) to geometry')
                # Convert from hex to object
                for record in records:
                    record[i] = shapely.wkb.loads(record[i], hex=True) if not pd.isna(record[i]) else np.nan

    return [dict(zip(col_names, record)) for record in records]

def psql_select_record(con, cmd, args=()):
    """Run a query that must return exactly one record and return it.

    Args:
        con: Open database connection, handed to ``psql_select_records``.
        cmd: SQL statement to execute.
        args: Query parameters for ``cmd``.

    Returns:
        The single record as a ``{column_name: value}`` dict.

    Raises:
        Exception: If the query returns zero records or more than one.
    """
    records = psql_select_records(con, cmd, args)
    if len(records) != 1:
        # Report the size of the result list that was actually received.
        raise Exception(f'Expected 1 record but received {len(records)}')
    return records[0]

def psql_select_record_or_none(con, cmd, args=()):
    """Run a query that returns at most one record.

    Args:
        con: Open database connection, handed to ``psql_select_records``.
        cmd: SQL statement to execute.
        args: Query parameters for ``cmd``.

    Returns:
        The single record as a ``{column_name: value}`` dict, or ``None``
        when the query returns no records.

    Raises:
        Exception: If the query returns more than one record.
    """
    records = psql_select_records(con, cmd, args)
    if not records:
        return None
    if len(records) != 1:
        # Report the size of the result list that was actually received.
        raise Exception(f'Expected 0 or 1 record but received {len(records)}')
    return records[0]

def psql_insert_record(con, table, dic, verbose=False):
    """Insert one row into ``table`` from a column-name -> value mapping.

    Column names and the table name are written into the statement text
    as-is; the values are passed separately as query parameters, one ``%s``
    placeholder per entry.

    Args:
        con: Open database connection, handed to ``psql_execute``.
        table: Destination table name.
        dic: Mapping of column name to the value to insert.
        verbose: Forwarded to ``psql_execute``.
    """
    column_list = ','.join(dic.keys())
    placeholders = ','.join('%s' for _ in range(len(dic)))
    params = tuple(dic.values())
    statement = f"INSERT INTO {table} ({column_list}) VALUES ({placeholders})"
    psql_execute(con, statement, params, verbose=verbose)

def psql_execute(con, cmd, args=(), verbose=True):
    """Execute one statement on a fresh cursor and commit.

    Args:
        con: Open database connection providing ``cursor()`` and ``commit()``.
        cmd: SQL statement to execute.
        args: Query parameters passed to ``cursor.execute``.
        verbose: If true, print ``cmd`` before executing it.

    Raises:
        Whatever the driver raises when the statement fails.  The cursor is
        still closed in that case and nothing is committed.
    """
    cur = con.cursor()
    try:
        if verbose:
            print(cmd)
        cur.execute(cmd, args)
    finally:
        # Release the cursor even when the statement fails.
        cur.close()
    con.commit()

# Currently, only ingests xlsx.  Should be easy to add other formats
# dry_run:  display CREATE_TABLE command and data as read by pandas.  Do not create table or load data
# delete_first:  delete table before creating and loading
def psql_create_table_and_populate(con, path, dry_run=False, delete_first=False):
    """Load a spreadsheet into a new table named after its path.

    The file is read with ``pd.read_excel`` from under
    ``copaftp.state.pa.us/``.  The table name is ``path`` without its
    extension and with ``/`` replaced by ``_``.

    Args:
        con: Open database connection.
        path: Spreadsheet path relative to ``copaftp.state.pa.us/``.
        dry_run: If true, display the data and print the CREATE TABLE
            statement, but do not create, drop or load anything.
        delete_first: If true (and not a dry run), drop the table before
            creating and loading it.
    """
    df = pd.read_excel(f'copaftp.state.pa.us/{path}')
    display(df)
    table_name = os.path.splitext(path)[0].replace('/', '_')
    if delete_first and not dry_run:
        psql_execute(con, f'DROP TABLE {table_name};', verbose=True)
    # Call the helper by its canonical name: the short alias defined in this
    # file is spelled "..._colums", so the "..._columns" spelling used here
    # previously was undefined.
    psql_create_empty_table_from_df(con, table_name, df, dry_run=dry_run)
    if dry_run:
        return
    psql_append_df_to_table(con, df, table_name)
    display(psql_read_table_as_df(con, table_name))