In [None]:
import json
from urllib.parse import urlparse

import psycopg2
from psycopg2 import sql

In [None]:
def read_json_config(file_path, encoding='utf-8'):
    """Load a JSON file and return the decoded value.

    :param file_path: path to the JSON file.
    :param encoding: text encoding used to read the file. JSON text is UTF-8,
        so this defaults to UTF-8 instead of the platform's locale encoding
        (which is what ``open`` would otherwise use).
    :return: the decoded JSON value; for the DB config files read in this
        notebook it is expected to be a dict of connection settings.
    """
    with open(file_path, 'r', encoding=encoding) as file:
        return json.load(file)

In [None]:
# Connection settings are read from ./secrets. Only a redacted copy is printed:
# cell output is saved inside the .ipynb file, so printing the raw dicts would
# leak credentials wherever the notebook is shared or committed.
_SENSITIVE_CONFIG_KEYS = ('password', 'dsn')


def _redacted(config: dict) -> dict:
    """Return a copy of ``config`` with credential-bearing values masked."""
    return {
        key: '********' if str(key).lower() in _SENSITIVE_CONFIG_KEYS else value
        for key, value in config.items()
    }


# Configuration for source database
source_db_config = read_json_config('./secrets/db_config_source.json')
print(_redacted(source_db_config))

# Configuration for destination database
dest_db_config = read_json_config('./secrets/db_config_dest.json')
print(_redacted(dest_db_config))

In [None]:
class DBConnection:
    """Wrapper around one psycopg2 connection and a single cursor on it.

    Autocommit is disabled, so writes persist only after ``commit_transaction``;
    ``close`` rolls back anything left uncommitted. The object can be used as a
    context manager (``with DBConnection(cfg) as db:``) so the connection is
    closed deterministically instead of at garbage collection.

    :param _config: keyword arguments forwarded to ``psycopg2.connect``. Its
        ``host`` entry is also read by the environment guards and log lines.
    """

    def __init__(self, _config: dict):
        self.config = _config
        # Pre-set so close()/__del__ stay safe if connect() or cursor() raises.
        self.connection = None
        self.cursor = None
        self.connection = psycopg2.connect(**_config)
        self.cursor = self.connection.cursor()
        self.connection.autocommit = False

    def fetch_query(self, _sql_query_string: str, _default=None, _params=None):
        """Execute a query and return all rows, or ``_default`` if fetching fails.

        Errors raised by ``execute`` propagate; only the fetch step is guarded
        (for example a statement that produces no result set).

        :param _sql_query_string: SQL text (or a psycopg2 ``sql`` composable).
        :param _default: value returned when ``fetchall`` raises.
        :param _params: optional parameters bound to ``%s`` placeholders. With
            ``None`` (the default) the query is sent without parameter processing.
        """
        self.cursor.execute(_sql_query_string, _params)
        try:
            return self.cursor.fetchall()
        except Exception as e:
            print(f"an exception occurred when fetching the following query \n'{_sql_query_string}'\n-----\n ERROR : {str(e)}")
        return _default

    def execute(self, _sql_command_string: str, _params=None):
        """Execute a statement on the shared cursor without committing.

        :param _params: optional parameters bound to ``%s`` placeholders.
        """
        self.cursor.execute(_sql_command_string, _params)

    def commit_transaction(self):
        self.connection.commit()

    def rollback_transaction(self):
        self.connection.rollback()

    def assert_environment(self, _host_name: str, _message: str):
        """Raise ``Exception(_message)`` unless the configured host is ``_host_name``.

        A config without a ``host`` entry fails the check (fail closed) instead
        of raising ``KeyError``.
        """
        if self.config.get('host') != _host_name:
            raise Exception(_message)

    def assert_test_environment(self, _message: str):
        """Raise ``Exception(_message)`` unless this connection targets ``localhost``."""
        self.assert_environment('localhost', _message)

    def close(self):
        """Roll back uncommitted work, then close the cursor and the connection.

        Safe to call repeatedly and on a partially initialised instance: if no
        connection was opened, or it is already closed, nothing happens.
        """
        connection = getattr(self, 'connection', None)
        if connection is None or connection.closed:
            return
        try:
            connection.rollback()
        finally:
            # Release resources even if the rollback itself failed.
            cursor = getattr(self, 'cursor', None)
            if cursor is not None:
                cursor.close()
            connection.close()
        print(f"closing DB connection : {self.config.get('host')}")

    def __del__(self):
        # Exceptions escaping __del__ are only reported as warnings by Python,
        # so report a cleanup failure explicitly instead.
        try:
            self.close()
        except Exception as e:
            print(f"failed to close DB connection during cleanup : {e}")

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        return False  # never suppress exceptions raised inside the with-body

In [None]:
# Clear all tables in the destination database
def wipe_schema_tables(_schema_name: str, _dest_db_config: dict):
    """Drop every base table in ``_schema_name`` on the destination database.

    Raises (via ``assert_test_environment``) unless the destination host is
    ``localhost``. The schema itself is kept. Each table is dropped with
    ``CASCADE``, so objects that depend on it are dropped as well. All drops
    run in one transaction, which is committed only if every statement succeeds.

    :param _schema_name: schema whose base tables are dropped.
    :param _dest_db_config: connection settings for the destination database.
    :return: True on success, False if any query failed (changes rolled back).
    """
    db_dest = DBConnection(_dest_db_config)
    try:
        db_dest.assert_test_environment("DO NOT WIPE SOURCE DATABASE!!!!!")
        try:
            # Bind the schema name as a parameter instead of splicing it into SQL.
            db_dest.cursor.execute("""
                SELECT table_name
                FROM information_schema.tables
                WHERE table_schema = %s AND table_type = 'BASE TABLE';
            """, (_schema_name,))
            tables = db_dest.cursor.fetchall()
            print("tables: ", tables)
            for (table_name,) in tables:
                print(f"dropping table '{table_name}' from schema {_schema_name}")
                # Quote identifiers: information_schema reports exact names, and
                # unquoted names would be case-folded to lowercase by PostgreSQL.
                db_dest.execute(
                    sql.SQL('DROP TABLE {}.{} CASCADE;').format(
                        sql.Identifier(_schema_name), sql.Identifier(table_name)
                    )
                )
            db_dest.commit_transaction()
            print("schema tables wiped successfully")
            return True
        except Exception as e:
            print(f"Failed to wipe schema tables '{_schema_name}' : \nError Message = '{str(e).strip()}'")
        db_dest.rollback_transaction()
        return False
    finally:
        # Release the connection now rather than whenever it is garbage-collected.
        db_dest.close()

In [None]:
def check_if_schema_exists(_db: DBConnection, _schema_name: str):
    """Return whether a schema named ``_schema_name`` is listed in information_schema on ``_db``.

    The name is bound as a query parameter, so quotes or other special
    characters in it cannot break (or inject into) the SQL.
    """
    _db.cursor.execute(
        """
        SELECT EXISTS(
            SELECT 1
            FROM information_schema.schemata
            WHERE schema_name = %s
        )
        """,
        (_schema_name,),
    )
    (exists,) = _db.cursor.fetchone()
    return exists


def check_if_table_exists(_db: DBConnection, _schema_name: str, _table_name: str):
    """Return whether ``_schema_name._table_name`` is listed in information_schema.tables on ``_db``.

    No ``table_type`` filter is applied. Both names are bound as query
    parameters, so quotes in them cannot break (or inject into) the SQL.
    """
    _db.cursor.execute(
        """
        SELECT EXISTS (
            SELECT 1
            FROM information_schema.tables
            WHERE
                table_schema = %s
                AND table_name = %s
        )
        """,
        (_schema_name, _table_name),
    )
    (exists,) = _db.cursor.fetchone()
    return exists

In [None]:
def create_schema(_db_dest: DBConnection, _schema_name: str):
    """Create ``_schema_name`` on the destination database unless it already exists.

    Raises (via ``assert_test_environment``) unless the destination host is
    ``localhost``.

    :return: True if the schema exists afterwards; False if creating it failed
        (the destination transaction is rolled back in that case).
    """
    _db_dest.assert_test_environment("DO NOT WRITE TO SOURCE DATABASE!!!")
    try:
        if not check_if_schema_exists(_db_dest, _schema_name):
            # Quote the identifier so the schema is created under the exact name
            # the existence check looked up (unquoted names are case-folded).
            _db_dest.execute(sql.SQL('CREATE SCHEMA {};').format(sql.Identifier(_schema_name)))
            _db_dest.commit_transaction()
            print(f"Schema {_schema_name} created successfully.")
        return True
    except Exception as e:
        print(f"Failed to create schema '{_schema_name}' : \nError Message = '{str(e).strip()}'")
    _db_dest.rollback_transaction()
    return False

In [None]:
def create_table(_db_dest: DBConnection, _db_source: DBConnection, _schema_name: str, _table_name: str):
    """Create ``_schema_name._table_name`` on the destination with the source table's columns.

    Column names and types are read from the source's system catalog, in
    declaration order (``attnum``). System columns (``attnum <= 0``) and dropped
    columns are skipped. ``pg_catalog.format_type`` renders each type in full,
    keeping modifiers such as ``character varying(255)``, ``numeric(10,2)`` or
    ``character(8)`` and array types such as ``text[]``.
    ``information_schema.columns.data_type`` drops these details: it reports
    bare ``character`` (i.e. ``char(1)``, which truncates data) and ``ARRAY``
    (not a valid column type).

    Only column names and types are copied: no constraints, defaults or
    indexes. Does nothing if the destination table already exists. Raises (via
    ``assert_test_environment``) unless the destination host is ``localhost``.

    :return: True if the table exists afterwards; False on failure
        (the destination transaction is rolled back).
    """
    _db_dest.assert_test_environment("DO NOT WRITE TO SOURCE DATABASE!!!")
    try:
        if not check_if_table_exists(_db_dest, _schema_name, _table_name):
            column_query = """
                SELECT a.attname, pg_catalog.format_type(a.atttypid, a.atttypmod)
                FROM pg_catalog.pg_attribute AS a
                JOIN pg_catalog.pg_class AS c ON c.oid = a.attrelid
                JOIN pg_catalog.pg_namespace AS n ON n.oid = c.relnamespace
                WHERE n.nspname = %s
                  AND c.relname = %s
                  AND a.attnum > 0
                  AND NOT a.attisdropped
                ORDER BY a.attnum
            ;"""
            _db_source.cursor.execute(column_query, (_schema_name, _table_name))
            columns = _db_source.cursor.fetchall()

            # Column names are quoted identifiers; the type text comes straight
            # from the source catalog and is inserted as SQL.
            column_definitions = sql.SQL(', ').join(
                sql.SQL('{} {}').format(sql.Identifier(column_name), sql.SQL(column_type))
                for column_name, column_type in columns
            )
            _db_dest.execute(
                sql.SQL('CREATE TABLE {}.{} ({});').format(
                    sql.Identifier(_schema_name), sql.Identifier(_table_name), column_definitions
                )
            )
            _db_dest.commit_transaction()
            print(f"Table {_schema_name}.{_table_name} created successfully.")
        return True
    except Exception as e:
        print(f"Failed to create table '{_schema_name}.{_table_name}' : \n    !! Error Message = '{str(e).strip()}'")
    _db_dest.rollback_transaction()
    return False

In [None]:
def insert_row(_db: DBConnection, _schema_name: str, _table_name: str, _column_names: list[str], _row_values: list[any]):
    """Insert one row into ``_schema_name._table_name`` on ``_db`` without committing.

    Raises (via ``assert_test_environment``) unless the target host is ``localhost``.

    :param _column_names: column names written into the SQL verbatim, so they
        must already be valid (e.g. double-quoted) identifiers.
    :param _row_values: one value per column, in the same order. psycopg2 binds
        them to ``%s`` placeholders, so they are never spliced into the SQL text.

    The caller owns the transaction and must commit or roll back.
    """
    _db.assert_test_environment("DO NOT WRITE TO SOURCE DATABASE!!!")
    sql_command_insert_into = sql.SQL('INSERT INTO {}.{} ({}) VALUES ({});').format(
        # Quoted so table names that are not all-lowercase are not case-folded.
        sql.Identifier(_schema_name),
        sql.Identifier(_table_name),
        sql.SQL(',').join(sql.SQL(name) for name in _column_names),
        sql.SQL(',').join(sql.Placeholder() * len(_row_values)),
    )
    # Bind the row through the cursor so psycopg2 adapts each value.
    _db.cursor.execute(sql_command_insert_into, _row_values)

    # do NOT commit transaction here.

    # do NOT commit transaction here.

In [None]:
def create_and_populate_table(_db_dest: DBConnection, _db_source: DBConnection, _schema_name: str, _table_name: str):
    """Create ``_schema_name._table_name`` on the destination and copy all source rows into it.

    Raises (via ``assert_test_environment``) unless the destination host is
    ``localhost``. All rows are read from the source with one ``SELECT *``.
    They are inserted in a single destination transaction, which is committed
    only if every insert succeeds.

    :return: True if rows were copied; False if the source table is empty or
        inserting the rows failed (the inserts are rolled back).
    :raises Exception: re-raised for any other failure, e.g. the table could
        not be created or the source rows could not be read.
    """
    _db_dest.assert_test_environment("DO NOT POPULATE SOURCE DB!!!")
    try:
        if not create_table(_db_dest, _db_source, _schema_name, _table_name):
            raise Exception("failed to create table")

        # Quoted identifiers, so tables whose names are not all-lowercase are
        # read under their exact name instead of being case-folded.
        rows = _db_source.fetch_query(
            sql.SQL('SELECT * FROM {}.{};').format(sql.Identifier(_schema_name), sql.Identifier(_table_name))
        )
        if rows is None:
            raise Exception(
                f"  ! Error requesting rows from {_schema_name}.{_table_name} :\n" +
                f"    !! table population failed for table = {_table_name}"
            )
        if not rows:
            # TODO : check table record count
            print(f"  ? no records to populate {_schema_name}.{_table_name}")
            return False

        # insert_row writes column names verbatim, so quote them here; embedded
        # double quotes are doubled, as PostgreSQL quoted identifiers require.
        column_names = [
            '"' + desc[0].replace('"', '""') + '"'
            for desc in _db_source.cursor.description
        ]

        print(f"  > table '{_schema_name}.{_table_name}'")
        print(f"  > col ({len(column_names)}) : {column_names}")
        # Report only the row count: printing row contents would copy real
        # (possibly sensitive) source data into the saved notebook output.
        print(f"  > rows : {len(rows)}")
        try:
            for row_values in rows:
                insert_row(_db_dest, _schema_name, _table_name, column_names, row_values)
            _db_dest.commit_transaction()
            print(f"table population completed successfully : table = {_table_name}")
            return True
        except Exception as e:
            _db_dest.rollback_transaction()
            print(
                f"  ! An error occurred while populating table rows for : '{_schema_name}.{_table_name}' :\n" +
                f"    !! {str(e)}"
            )
    except Exception as e:
        _db_dest.rollback_transaction()
        print(
            f"  ! Failed to Populate table '{_schema_name}.{_table_name}' :\n" +
            f"    !! '{str(e).strip()}'"
        )
        raise  # unknown failure = catastrophic failure
    return False

In [None]:
# Copy tables and records from source to destination
def copy_all_tables_in_schema(_dest_db_config: dict, _source_db_config: dict, _schema_name: str):
    """Copy every base table of ``_schema_name`` (columns and rows) from source to destination.

    Opens one connection to each database and closes both before returning.
    Raises (via ``assert_test_environment``) unless the destination host is
    ``localhost``. Tables that end up without copied rows (the source table is
    empty or the inserts failed) are listed in the output rather than silently
    ignored.

    :return: True once every table has been processed; False if the schema
        could not be created or an error aborted the copy.
    """
    db_dest = DBConnection(_dest_db_config)
    db_source = None
    try:
        db_dest.assert_test_environment("DO NOT WRITE TO SOURCE DATABASE!!!")
        db_source = DBConnection(_source_db_config)
        try:
            if not create_schema(db_dest, _schema_name):
                return False

            # Bind the schema name as a parameter instead of splicing it into SQL.
            db_source.cursor.execute("""
                SELECT table_name
                FROM information_schema.tables
                WHERE table_schema = %s AND table_type = 'BASE TABLE';
            """, (_schema_name,))
            tables = db_source.cursor.fetchall()

            not_populated = []
            for (table_name,) in tables:
                print("")
                # False means the table was created but no rows were copied.
                if not create_and_populate_table(db_dest, db_source, _schema_name, table_name):
                    not_populated.append(table_name)
            if not_populated:
                print(f"tables created without copied rows (empty or failed) : {not_populated}")
            print(f"Database population completed successfully. : schema = {_schema_name}")
            return True
        except Exception as e:
            print(f'an error was received while populating tables for schema = {_schema_name} :\n    !! {str(e)}')
        db_dest.rollback_transaction()
        return False
    finally:
        # Release both connections now rather than at garbage collection.
        db_dest.close()
        if db_source is not None:
            db_source.close()

In [None]:
# Clear and populate destination database
def clear_and_populate(_schema_name, _source_db_config, _dest_db_config):
    """Wipe ``_schema_name`` on the destination, then copy it again from the source.

    The copy only runs if the wipe succeeded. The overall outcome is printed,
    not returned.
    """
    succeeded = wipe_schema_tables(_schema_name, _dest_db_config)
    if succeeded:
        succeeded = copy_all_tables_in_schema(_dest_db_config, _source_db_config, _schema_name)

    outcome = "Database population completed successfully." if succeeded else "Database population failed"
    print(outcome)

In [None]:
# Run the full wipe-and-copy for the 'bishnet' schema, source -> destination.
clear_and_populate(
    _schema_name='bishnet',
    _source_db_config=source_db_config,
    _dest_db_config=dest_db_config,
)

# TODO:

## Sanitize data
Replace sensitive information with randomly generated pseudonymous data.
- Create a mapping that lists every `schema.table.column` that must be sanitized, together with the kind of replacement data to generate for it.

## Nullify empty values
Some columns contain empty strings instead of `NULL`. Should these be converted to `NULL`?