# Writing to a local or cloud PostGIS database
Execution time: ~2 minutes

In [1]:
import pandas as pd
import os
import psycopg2 as pg
from psycopg2 import sql
from sqlalchemy import create_engine

In [2]:
# Ask which database to write to, then build the connection settings used by
# the cells below (db_host, db_port, db_user, db_password, db_name, db_url).
# NOTE(review): both the local and the cloud choice read the same DB_*
# environment variables, so the answer currently only changes the message that
# is printed. Give the cloud database its own settings if the two differ.
from urllib.parse import quote_plus

# Keep asking until the answer is recognizable (case and surrounding spaces ignored).
while True:
    answer = input('Use Cloud DB? (y/n):').strip().lower()
    if answer in ('n', 'no'):
        print('Using local DB')
        break
    if answer in ('y', 'yes'):
        print('Using Cloud DB')
        break
    print('Invalid input. Please enter y or n.')

db_host = os.environ.get('DB_HOST')
db_port = os.environ.get('DB_PORT')
db_user = os.environ.get('DB_USER')
db_password = os.environ.get('DB_PASSWORD')
db_name = os.environ.get('DB_NAME')

# Fail here with a clear message instead of later with a URL containing "None".
_missing = [
    env_var
    for env_var, value in (
        ('DB_HOST', db_host),
        ('DB_PORT', db_port),
        ('DB_USER', db_user),
        ('DB_NAME', db_name),
    )
    if not value
]
if _missing:
    raise RuntimeError(f"Missing database environment variable(s): {', '.join(_missing)}")

# Percent-encode the credentials so characters such as '@', ':' or '/' in the
# user name or password cannot corrupt the URL. DB_PASSWORD is treated as
# optional; when it is unset it is left out of the URL rather than sent as "None".
_credentials = quote_plus(db_user)
if db_password is not None:
    _credentials += ':' + quote_plus(db_password)
db_url = f'postgresql://{_credentials}@{db_host}:{db_port}/{db_name}'

Using local DB


In [3]:
# Load the main table, indexed by cep_id. The pa/eco/country code columns are
# read with pandas' nullable Int64 dtype (integers that may hold missing
# values); afterwards every missing marker is replaced by None, presumably so
# it is stored as NULL when the frame is written to the database.
nullable_int_cols = {'pa': 'Int64', 'eco': 'Int64', 'country': 'Int64'}
df = pd.read_csv('final_output.csv', index_col=['cep_id'], dtype=nullable_int_cols)
df = df.replace({pd.NA: None})
print(df.head(1))

# Load the qid lookup table (qid -> quantile_name), indexed by qid.
qid_df = pd.read_csv('qid.csv', index_col=['qid'])
print(qid_df.head(1))

        qid  transition_0  transition_1  transition_2  transition_3  \
cep_id                                                                
1         0    895.792133           0.0           0.0           0.0   

        transition_4  transition_5  transition_6  transition_7  transition_8  \
cep_id                                                                         
1                0.0           0.0           0.0           0.0           0.0   

        ...  transition_10  country  country_name iso3    eco  \
cep_id  ...                                                     
1       ...            0.0      171     Lithuania  LTU  80412   

                              eco_name is_marine  pa  pa_name is_protected  
cep_id                                                                      
1       Central European mixed forests     False   0     None        False  

[1 rows x 21 columns]
    quantile_name
qid              
0         20E_60N


In [None]:
# Write the dataframes to the database as tables (Execution time =~ 1:30 min).
from sqlalchemy import text

engine = create_engine(db_url)

# to_sql(if_exists='replace') issues a plain DROP TABLE. On a re-run, the
# cep_grouped view and the fk_qid constraint (created after this cell) still
# depend on these tables, and PostgreSQL refuses the drop while they exist.
# Remove them first; both are recreated later in the notebook.
with engine.begin() as connection:
    connection.execute(text('DROP VIEW IF EXISTS cep_grouped'))
    connection.execute(text('ALTER TABLE IF EXISTS cep_water DROP CONSTRAINT IF EXISTS fk_qid'))

# cep_water: one row per dataframe row, with the cep_id index written as a column.
df.to_sql('cep_water', engine, if_exists='replace')

# cep_qid: the qid lookup table that cep_water.qid refers to.
qid_df.to_sql('cep_qid', engine, if_exists='replace')

# Release pooled connections; the engine is not needed by the psycopg2 cell below.
engine.dispose()

In [None]:
# Connect with psycopg2 to add the table constraints and build the cep_grouped view.
conn = pg.connect(
    database=db_name,
    user=db_user,
    password=db_password,
    host=db_host,
    port=db_port,
)

# cep_water carries transition_0 .. transition_10 (see the dataframe preview above).
N_TRANSITION_BANDS = 11
_band_columns = [sql.Identifier(f'transition_{i}') for i in range(N_TRANSITION_BANDS)]

try:
    # Make cep_qid.qid unique so that cep_water.qid can reference it as a foreign key.
    with conn.cursor() as cursor:
        cursor.execute("""
        ALTER TABLE cep_qid
        ADD CONSTRAINT unique_qid
        UNIQUE(qid);

        ALTER TABLE cep_water
        ADD CONSTRAINT fk_qid
        FOREIGN KEY (qid)
        REFERENCES cep_qid(qid);
        """)
    conn.commit()

    # (Re)create cep_grouped: the transition bands are summed per cep_id in the
    # `groupings` CTE, then joined back onto cep_water's descriptive columns.
    # DISTINCT collapses cep_water rows that share a cep_id and identical
    # descriptive values into a single row.
    view_query = sql.SQL("""
        DROP VIEW IF EXISTS "cep_grouped";
        CREATE VIEW cep_grouped AS
        WITH groupings AS (
            SELECT
                cep_water.cep_id,
                {bands_sum}
            FROM
                cep_water
            GROUP BY
                cep_water.cep_id
        )

        -- join temp groupings table with cep_water table, but don't select transition columns
        SELECT DISTINCT
            cep_water.cep_id,
            country,
            country_name,
            iso3,
            eco,
            eco_name,
            is_marine,
            pa,
            pa_name,
            is_protected,
            {grouped_bands}
        FROM
            cep_water
        JOIN
            groupings
        ON
            cep_water.cep_id = groupings.cep_id;
        """).format(
        # SUM(transition_i) AS transition_i for every band
        bands_sum=sql.SQL(', ').join(
            [sql.SQL('SUM({band}) AS {band}').format(band=band) for band in _band_columns]
        ),
        # the per-cep_id summed bands from the groupings CTE
        grouped_bands=sql.SQL(', ').join(
            [sql.SQL('groupings.{band}').format(band=band) for band in _band_columns]
        ),
    )
    with conn.cursor() as cursor:
        cursor.execute(view_query)
    conn.commit()
finally:
    # Always release the connection, even if one of the statements above fails
    # (any uncommitted work is discarded).
    conn.close()

# remove df and qid_df from memory
del df
del qid_df