In [1]:
from sqlalchemy import create_engine, MetaData
import json
from sqlalchemy.engine import URL
from sqlalchemy.types import VARCHAR
from pyodbc import connect, SQL_WVARCHAR
import logging


In [2]:
def get_logger(name, log_file="logfile.log"):
    """Return the logger called ``name``, writing INFO+ records to ``log_file``.

    ``logging.getLogger`` returns the same object for the same name, so
    re-running this cell in a notebook would otherwise attach one more
    FileHandler each time and every message would be written repeatedly.
    A handler for ``log_file`` is therefore only added if the logger does
    not already have one for that path.

    Args:
        name: Logger name, usually ``__name__``.
        log_file: Path of the log file; defaults to ``logfile.log`` in the
            current working directory.

    Returns:
        The configured ``logging.Logger``.
    """
    import os

    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)

    # FileHandler stores the absolute path in ``baseFilename``.
    log_path = os.path.abspath(log_file)
    already_attached = any(
        isinstance(handler, logging.FileHandler) and handler.baseFilename == log_path
        for handler in logger.handlers
    )

    if not already_attached:
        file_formatter = logging.Formatter(
            '%(levelname)s - %(asctime)s - %(message)s - %(module)s',
            "%Y-%m-%d %H:%M:%S")

        file_handler = logging.FileHandler(log_file)
        file_handler.setLevel(logging.INFO)
        file_handler.setFormatter(file_formatter)

        logger.addHandler(file_handler)

    return logger

# Module-wide logger used by the classes below; records go to the file
# handler configured in get_logger.
LOGGER = get_logger(name=__name__)

In [3]:
# Connection settings are read from ``credentials.json`` in the working
# directory instead of being written into the notebook. Expected layout —
# note this is JSON, so keys and strings must use double quotes, and
# ``port`` is passed to ``URL.create`` so it should be a number (e.g. 5432):
#
# {
#     "destination": {
#         "drivername": "postgresql+psycopg2",
#         "username": "",
#         "password": "",
#         "host": "",
#         "port": 5432,
#         "database": ""
#     },
#     "origin": {
#         "drivername": "postgresql+psycopg2",
#         "username": "",
#         "password": "",
#         "host": "",
#         "port": 5432,
#         "database": ""
#     }
# }

# JSON is UTF-8 by specification; don't depend on the platform's locale encoding.
with open('credentials.json', 'r', encoding='utf-8') as creds_file:
    CREDS = json.load(creds_file)

In [4]:
# One SQLAlchemy connection URL per credentials section, origin first.
URL_ORIGIN, URL_DESTINATION = (
    URL.create(**CREDS[section]) for section in ('origin', 'destination')
)

In [5]:
# SQLAlchemy engines for both databases; 'prefer' uses SSL when the server
# offers it without requiring it.
ORIGIN_ENGINE = create_engine(URL_ORIGIN, connect_args={'sslmode': 'prefer'})
DESTINATION_ENGINE = create_engine(URL_DESTINATION, connect_args={'sslmode': 'prefer'})

# pyodbc connection attributes for the destination database (used for the
# bulk inserts). The driver path defaults to a local .dylib install of the
# Amazon Redshift ODBC driver; on other machines set an optional top-level
# "odbc_driver" key in credentials.json instead of editing this cell.
ODBC_CREDS = {
    'Driver': CREDS.get(
        'odbc_driver',
        '/opt/amazon/redshift/lib/libamazonredshiftodbc.dylib'),
    'Server': CREDS['destination']['host'],
    'Database': CREDS['destination']['database'],
    'UID': CREDS['destination']['username'],
    'PWD': CREDS['destination']['password'],
    'Port': CREDS['destination']['port'],
    # NOTE(review): driver-specific flag, presumably controls whether booleans
    # are returned as characters — confirm against the Redshift ODBC docs.
    'BoolsAsChar': 0,
}

In [6]:
class DuplicateSchema:
    """Copy table definitions (no data) from an origin DB schema to a destination DB."""

    def __init__(self, origin_engine, destination_engine):
        # SQLAlchemy engines for origin and destination DB.
        self.origin_engine = origin_engine
        self.destination_engine = destination_engine
        # MetaData objects bound to each engine; the origin one is filled by
        # reflection in _meta_refresh.
        self.origin_engine_meta = MetaData(bind=origin_engine)
        self.destination_engine_meta = MetaData(bind=destination_engine)
        # Schema currently reflected into origin_engine_meta (see _meta_refresh).
        self.previous_schema = ''

    def _meta_refresh(self, schema_name):
        """Reflect ``schema_name`` from the origin DB unless it is already loaded.

        ``MetaData.reflect()`` takes a while to run, so it only runs when the
        requested schema differs from the one reflected last. Tables reflected
        with an explicit schema are keyed ``"<schema>.<table>"`` in
        ``origin_engine_meta.tables``.
        """
        if self.previous_schema != schema_name:
            LOGGER.info("Set sqlalchemy meta schema to %s.", schema_name)
            self.origin_engine_meta.clear()
            self.origin_engine_meta.reflect(schema=schema_name)
            self.previous_schema = schema_name
        return None

    def setup_schema(self, schema_name):
        """Create ``schema_name`` in the destination DB if it does not exist yet."""
        from sqlalchemy import text

        LOGGER.info("Creating the %s schema.", schema_name)
        # Quote the name exactly as SQLAlchemy does in the CREATE TABLE
        # statements issued later; otherwise e.g. a mixed-case name would be
        # case-folded here but quoted there, and the tables would not find it.
        preparer = self.destination_engine.dialect.identifier_preparer
        statement = text(
            'CREATE SCHEMA IF NOT EXISTS {};'.format(preparer.quote_schema(schema_name)))
        # engine.begin() commits explicitly instead of relying on autocommit.
        with self.destination_engine.begin() as con:
            con.execute(statement)

        return None

    def _remove_identity_clause(self, table_name):
        """Remove server-side default values from a reflected table's columns.

        Trying to insert values into an identity/serial column causes errors.
        As a workaround the defaults are dropped from the reflected metadata,
        so they are not recreated in the destination database.

        Args:
            table_name: Key of the table in ``origin_engine_meta.tables``.
        """
        table = self.origin_engine_meta.tables[table_name]

        for column in table.columns:
            column.server_default = None

        return None

    def create_1_table(self, schema_name, table_name):
        """Create one empty table in the destination DB.

        Args:
            schema_name: Origin schema containing the table.
            table_name: Either the bare table name or the schema-qualified
                ``"<schema>.<table>"`` key.

        Raises:
            KeyError: If the table does not exist in the reflected schema.
        """
        LOGGER.info("Creating the %s table.", table_name)
        self._meta_refresh(schema_name)

        tables = self.origin_engine_meta.tables
        # Reflected tables are keyed "<schema>.<table>"; accept a bare name too.
        qualified_name = '{}.{}'.format(schema_name, table_name)
        if table_name not in tables and qualified_name in tables:
            table_name = qualified_name

        self._remove_identity_clause(table_name)
        tables[table_name].create(bind=self.destination_engine)

        return None

    def create_all_tables(self, schema_name):
        """Create every table of ``schema_name`` in the destination DB."""
        LOGGER.info("Creating all tables in %s schema.", schema_name)

        self._meta_refresh(schema_name)

        for table_key in self.origin_engine_meta.tables:
            self._remove_identity_clause(table_key)

        self.origin_engine_meta.create_all(bind=self.destination_engine)

        return None

    def create_full_schema(self, schema_name):
        """Create the schema and all of its tables in the destination DB."""
        # Reflect first so an invalid origin schema fails before anything is
        # created in the destination.
        self._meta_refresh(schema_name)

        self.setup_schema(schema_name)

        self.create_all_tables(schema_name)

        return None


In [7]:
class SampleData(DuplicateSchema):
    """Populate newly duplicated tables in the destination DB with sample rows from the origin DB.

    Rows are read from the origin through SQLAlchemy and inserted into the
    destination through a pyodbc connection using ``fast_executemany``.
    Call ``close()`` when done to release the ODBC connection.
    """

    # VARCHAR columns with at least this declared length, or with no declared
    # length at all, get an explicit input size (see _odbc_data_types).
    LARGE_VARCHAR_LENGTH = 100
    # Column size passed to setinputsizes for those columns.
    LARGE_VARCHAR_BUFFER = 100000

    def __init__(self, origin_engine, destination_engine, odbc_creds):
        # Engines, MetaData objects and the schema tracker come from the parent.
        super().__init__(origin_engine, destination_engine)
        self.odbc_creds = odbc_creds
        # One pyodbc connection to the destination, reused for every insert.
        self.odbc_connection = connect(**self.odbc_creds)

    def close(self):
        """Close the pyodbc connection to the destination DB."""
        self.odbc_connection.close()

    def _get_data_sample(self, table_name, sample_size):
        """Fetch up to ``sample_size`` rows from one reflected origin table."""
        LOGGER.info("Fetch data sample from %s table.", table_name)

        table = self.origin_engine_meta.tables[table_name]

        # Explicit connection rather than bound-metadata implicit execution
        # (deprecated in SQLAlchemy 1.4); it is released when the block exits.
        with self.origin_engine.connect() as conn:
            rows = conn.execute(table.select().limit(sample_size)).fetchall()

        return rows

    def _odbc_executemany_args(self, table_name, rows):
        """Generate inputs for the ODBC insert query.

        The query has one ``?`` placeholder per value in a row, and the rows
        are converted to a list of tuples for ``executemany``.

        Args:
            table_name: Table to insert into (used verbatim in the SQL).
            rows: Non-empty sequence of rows; ``rows[0]`` sets the value count.

        Returns:
            ``(insert_query, list_of_tuples)``.
        """
        LOGGER.info("Generate odbc arguments for the %s table.", table_name)

        query_template = "INSERT INTO {table_name} VALUES ({question_marks})"

        # One question mark per value in the source rows.
        question_mark_list = ['?'] * len(rows[0])

        insert_query = query_template.format(
            table_name=table_name,
            question_marks=", ".join(question_mark_list)
        )

        list_of_tuples = [tuple(row) for row in rows]

        return insert_query, list_of_tuples

    def _odbc_data_types(self, table_name):
        """Build the ``setinputsizes`` list that overrides pyodbc's VARCHAR limit.

        pyodbc can raise an overflow error when inserting long strings into
        large TEXT or VARCHAR columns, even when the destination column is big
        enough. Declaring the parameter type explicitly works around it.

        Returns:
            One entry per column: ``None`` (use the cursor default) or a
            ``(SQL_WVARCHAR, LARGE_VARCHAR_BUFFER, 0)`` tuple for VARCHAR
            columns that are long or have no declared length.
        """
        LOGGER.info("Generate data types for the %s table.", table_name)

        columns = self.origin_engine_meta.tables[table_name].columns

        result = []
        for column in columns:
            column_type = column.type
            # An unbounded VARCHAR reflects with length None; comparing None
            # with an int raises TypeError, so treat it as "large".
            is_large_varchar = type(column_type) == VARCHAR and (
                column_type.length is None
                or column_type.length >= self.LARGE_VARCHAR_LENGTH)
            if is_large_varchar:
                result.append((SQL_WVARCHAR, self.LARGE_VARCHAR_BUFFER, 0))
            else:
                result.append(None)

        return result

    def populate_1_table(self, schema_name, table_name, sample_size):
        """Fetch a data sample and insert it into a single destination table.

        Args:
            schema_name: Origin schema containing the table.
            table_name: Bare table name or schema-qualified ``"<schema>.<table>"`` key.
            sample_size: Maximum number of rows to copy.
        """
        LOGGER.info("Begin populating the %s table.", table_name)

        self._meta_refresh(schema_name)

        tables = self.origin_engine_meta.tables
        # Reflected tables are keyed "<schema>.<table>"; accept a bare name too.
        qualified_name = '{}.{}'.format(schema_name, table_name)
        if table_name not in tables and qualified_name in tables:
            table_name = qualified_name

        rows = self._get_data_sample(table_name, sample_size)

        # Empty table: nothing to insert (and _odbc_executemany_args needs rows[0]).
        if not rows:
            return None

        insert_query, list_of_tuples = self._odbc_executemany_args(table_name, rows)
        data_types = self._odbc_data_types(table_name)

        cursor = self.odbc_connection.cursor()
        try:
            # Override the parameter types of large VARCHAR columns.
            cursor.setinputsizes(data_types)
            cursor.fast_executemany = True
            cursor.executemany(insert_query, list_of_tuples)
            cursor.commit()
        except Exception:
            # Discard the partial batch so the shared connection stays usable.
            self.odbc_connection.rollback()
            raise
        finally:
            cursor.close()

        return None

    def populate_all_tables(self, schema_name, sample_size):
        """Populate all tables in a given destination schema with sample data."""
        LOGGER.info("Begin populating tables in the %s schema.", schema_name)

        self._meta_refresh(schema_name)

        # Snapshot the keys before inserting into each table.
        table_names = list(self.origin_engine_meta.tables.keys())

        for table in table_names:
            self.populate_1_table(schema_name, table, sample_size)

        return None
        