In [0]:
import os
# Every string setting below can be overridden with an environment variable named
# DB2YB_<NAME> (for example DB2YB_YB_PASSWORD) so that credentials do not have to be
# saved inside the notebook.  When the variable is unset the literal given here is
# used, so editing the placeholders in place keeps working.
def _setting(name, default):
    """Return the DB2YB_<name> environment variable, or `default` when it is unset."""
    return os.environ.get(f"DB2YB_{name}", default)

YB_HOST = _setting("YB_HOST", '[YB HOSTNAME]')                     #Ex: ybhost.elb.us-east-2.amazonaws.com
YB_USER = _setting("YB_USER", "[YB USERNAME]")                     #Ex: JoeDBA@yellowbrickcloud.com
YB_PASSWORD = _setting("YB_PASSWORD", "[YB PASSWORD]")             #Ex: SuperS3cret!
YB_DATABASE = _setting("YB_DATABASE", "[YB DATABASE]")             #Ex: proddb
YB_SCHEMA = _setting("YB_SCHEMA", "[YB SCHEMA]")                   #Ex: gold

DB_DATABASE = _setting("DB_DATABASE", "[DATABRICKS CATALOG]")      #Ex: corp
DB_SCHEMA = _setting("DB_SCHEMA", "[DATABRICKS SCHEMA]")           #Ex: silver

SPARK_PARITIONS = 20                            #Number of files generated, can increase if very large dataset
MIN_S3_ROWS = 1000000                           #Minimum number of rows for us to use S3, otherwise, use spark write

#AWS Information if using S3 for Staging.  Pass useS3=False to DB2YB if you want to stream directly
BUCKET_NAME = _setting("BUCKET_NAME", "[BUCKET NAME (NO URL INFO)]")   #Ex: mybucket
AWS_ACCESS_KEY_ID = _setting("AWS_ACCESS_KEY_ID", "[ACCESS KEY]")      #Ex: AVIAGI4VJJXF34LCFSN2
AWS_SECRET_KEY = _setting("AWS_SECRET_KEY", "[ACCESS SECRET]")         #Ex: 0Ps+NeWy1uf5cXfzg8qZoABdwv9oBbJh2Q0n2pB4
AWS_REGION = _setting("AWS_REGION", "[REGION]")                        #Ex: us-east-1

# Derived values: the regional S3 endpoint and the s3a:// root used for staging files.
AWS_ENDPOINT = f"s3.{AWS_REGION}.amazonaws.com"
S3_BUCKET = f"s3a://{BUCKET_NAME}"

In [0]:
import sys
import argparse
import re
import glob
import time
import psycopg2
import html
import boto3

def remove_s3_directory(bucket_name, prefix, aws_access_key_id, aws_secret_access_key, aws_region_name="us-east-1"):
    """Delete every S3 object in `bucket_name` whose key starts with `prefix`.

    Args:
        bucket_name: Bucket name only (no s3:// scheme).
        prefix: Key prefix to delete, e.g. "mytable.csv/".
        aws_access_key_id: Access key used for the boto3 session.
        aws_secret_access_key: Secret key used for the boto3 session.
        aws_region_name: Region the boto3 session is created for.
    """
    session = boto3.Session(
        aws_access_key_id=aws_access_key_id,
        aws_secret_access_key=aws_secret_access_key,
        region_name=aws_region_name,
    )
    s3 = session.client('s3')
    # A single list_objects_v2 call returns at most 1000 keys, so walk every page
    # and delete page by page (each page also fits the delete_objects batch limit).
    paginator = s3.get_paginator('list_objects_v2')
    for page in paginator.paginate(Bucket=bucket_name, Prefix=prefix):
        keys = [{'Key': obj['Key']} for obj in page.get('Contents', [])]
        if keys:
            s3.delete_objects(Bucket=bucket_name, Delete={'Objects': keys})

def configure_test_s3():
    """Configure the Spark session for S3 access and prove it works with a tiny round trip.

    Sets the s3a access key, secret key and endpoint on the Spark session, writes a
    two-row CSV to <bucket>/test.csv, loads it back, and removes the probe files.
    Any failure surfaces as the exception raised by Spark or boto3.
    """
    spark.conf.set("fs.s3a.access.key", AWS_ACCESS_KEY_ID)
    spark.conf.set("fs.s3a.secret.key", AWS_SECRET_KEY)
    spark.conf.set("fs.s3a.endpoint", AWS_ENDPOINT)

    data = [(3, 9), (4, 16)]
    columns = ["A", "B"]
    df = spark.createDataFrame(data, columns)

    probe_dir = "test.csv"
    df.write \
        .format("csv") \
        .mode("overwrite") \
        .option("header", "true") \
        .save(S3_BUCKET + "/" + probe_dir)
    df_read = spark.read \
        .format("csv") \
        .option("header", "true") \
        .load(S3_BUCKET + "/" + probe_dir)
    # S3 keys carry no leading slash: the files written above live under "test.csv/",
    # so the former prefix "/test.csv" matched nothing and the probe was never removed.
    remove_s3_directory(BUCKET_NAME, probe_dir + "/", AWS_ACCESS_KEY_ID, AWS_SECRET_KEY, AWS_REGION)
    

def configure_test_yb(createS3): 
    """Connect to Yellowbrick, ensure the target schema exists and optionally (re)create the S3 load objects.

    Args:
        createS3: When truthy, drop and recreate the external storage, format and
            location ("dbcopy-store", "dbcopy-fmt", "dbcopy-loc") used by the S3 load path.

    Raises:
        Whatever psycopg2 raises when the connection or any statement fails.
    """
    connYB = psycopg2.connect(host=YB_HOST, dbname=YB_DATABASE, user=YB_USER, password=YB_PASSWORD, application_name='DB2YB')
    try:
        # psycopg2's connection context manager only ends the transaction
        # (commit on success, rollback on error); it does not close the
        # connection, hence the explicit close() in the finally clause.
        with connYB:
            with connYB.cursor() as cursor:
                cursor.execute(f"CREATE SCHEMA IF NOT EXISTS {YB_SCHEMA};") 
                cursor.execute(f"SET search_path TO {YB_SCHEMA};")
                if createS3:
                    # Drop in dependency order: the location references the format and storage.
                    cursor.execute('drop external location if exists "dbcopy-loc";')
                    cursor.execute('drop external format if exists "dbcopy-fmt";')
                    cursor.execute('drop external storage if exists "dbcopy-store";')
                    cursor.execute(f"""create external storage "dbcopy-store" type s3 \
                                    endpoint 'https://{AWS_ENDPOINT}' \
                                    region '{AWS_REGION}' \
                                    identity '{AWS_ACCESS_KEY_ID}' \
                                    credential '{AWS_SECRET_KEY}';""")
                    # Delimiter, escape character and header line match the CSV written by s3_write.
                    cursor.execute(f"""create external format "dbcopy-fmt" type csv with (delimiter '^', escape_char '\\', num_header_lines '1');""")
                    cursor.execute(f"""create external location "dbcopy-loc" path '{BUCKET_NAME}' \
                                    external storage "dbcopy-store" \
                                    external format "dbcopy-fmt";""")
    finally:
        connYB.close()

def lc(name):
    """Normalise an identifier for use in SQL.

    Names that start with a digit, or that mix upper and lower case, are wrapped
    in double quotes unchanged.  Every other name (all upper case, all lower case,
    or without cased characters) is returned lower-cased and unquoted.
    """
    starts_with_digit = re.match(r'^\d', name) is not None
    single_case = name == name.upper() or name == name.lower()
    if starts_with_digit or not single_case:
        return f'"{name}"'
    return name.lower()

def make_fq(table_name, schema, database=None):
    """Build a dotted, fully qualified name, passing each part through lc().

    Returns "schema.table" when `database` is None, otherwise "database.schema.table".
    """
    parts = [schema, table_name] if database is None else [database, schema, table_name]
    return ".".join(lc(part) for part in parts)

def get_matching_tables(regex_patterns):
    """Return the sorted, de-duplicated fully qualified tables that match any pattern.

    Tables come from SHOW TABLES on the configured catalog and schema.  Each
    pattern is qualified with that catalog and schema via make_fq() and must
    match a table's fully qualified name in full.
    """
    listing = spark.sql(f"SHOW TABLES IN {DB_DATABASE}.{DB_SCHEMA}")
    candidates = []
    for row in listing.collect():
        candidates.append(make_fq(row["tableName"], row["database"], DB_DATABASE))

    selected = set()
    for pattern in regex_patterns:
        compiled = re.compile(make_fq(pattern, DB_SCHEMA, DB_DATABASE))
        selected.update(name for name in candidates if compiled.fullmatch(name))

    return sorted(selected)


def write_ddl_file(ddl, file_name):
    """Write the DDL text to `file_name`, creating the parent directory if needed.

    Args:
        ddl: DDL text to store.
        file_name: Destination path, e.g. "./ddl/mytable.ddl".
    """
    import os  # local import keeps this function independent of other cells

    start_time = time.time()
    # Without this, the first run against a fresh "./ddl" directory fails with FileNotFoundError.
    target_dir = os.path.dirname(file_name)
    if target_dir:
        os.makedirs(target_dir, exist_ok=True)
    with open(file_name, 'w') as f:
        f.write(ddl)
    print(f"   * {(time.time()-start_time):7.2f}    - Wrote DDL to file {file_name}")

def read_ddl_file(file_name):
    """Return the full text of the DDL file at `file_name` and log how long the read took."""
    began = time.time()
    with open(file_name, 'r') as handle:
        contents = handle.read()
    elapsed = time.time() - began
    print(f"   * {elapsed:7.2f}    - Read table DDL from file {file_name}")
    return contents


def getTableDDL(fq_tableName,yb_schema):
    """Build a Yellowbrick CREATE TABLE statement from a Spark table's column list.

    Runs DESCRIBE TABLE on `fq_tableName`, maps each column type through
    DB_TYPE_MAP and emits "CREATE TABLE <yb_schema>.<table> (...) distribute random;".

    Args:
        fq_tableName: Dotted source table name; the last component (with any
            surrounding double quotes removed) becomes the target table name.
        yb_schema: Target schema name.

    Returns:
        The CREATE TABLE statement as a string.

    Raises:
        ValueError: If a column's type string cannot be parsed.
    """
    DB_TYPE_MAP = {
            "STRING": "VARCHAR(2000)",
            "INT": "INTEGER",
            "INTEGER": "INTEGER",
            "BIGINT": "BIGINT",
            # Small integer types previously fell through to the VARCHAR default.
            "SMALLINT": "SMALLINT",
            "TINYINT": "SMALLINT",
            "DOUBLE": "DOUBLE PRECISION",
            "FLOAT": "REAL",
            "DECIMAL": "NUMERIC",
            "BOOLEAN": "BOOLEAN",
            "TIMESTAMP": "TIMESTAMP",
            "DATE": "DATE",
            "BINARY": "BYTEA",
            "ARRAY": "JSONB",
            "MAP": "JSONB",
            "STRUCT": "JSONB"
        }

    def split_type(type_str):
        """Split a type string into (base type, "(precision)" or "")."""
        text = type_str.strip()
        # Parameterised complex types such as ARRAY<INT> or STRUCT<A:INT,B:STRING>:
        # only the outer type matters for the mapping, the element types are dropped.
        nested = re.match(r'^(\w+)\s*<', text)
        if nested:
            return nested.group(1), ""
        match = re.match(r'^(\w+)(\([^)]+\))?$', text)
        if match:
            return match.group(1), match.group(2) or ""
        raise ValueError(f"Invalid type format: {type_str}")

    ddl_df = spark.sql(f"DESCRIBE TABLE {fq_tableName}")

    columns_sql = []
    for row in ddl_df.collect():
        # Read the row fields directly rather than re-splitting a joined string,
        # so column names that contain whitespace are not cut in half.
        col_name = (row['col_name'] or "").strip()
        # A blank row or a "# ..." row marks the end of the plain column list.
        if not col_name or col_name.startswith("#"):
            break
        col_name = col_name.strip('`"')

        data_type = (row['data_type'] or "").strip().upper()
        if "<" not in data_type:
            # For multi-word simple types keep only the first word.
            words = data_type.split(None, 1)
            data_type = words[0] if words else ""
        (raw_type, precision) = split_type(data_type)

        # First map entry whose key is a prefix of the source type; VARCHAR otherwise.
        yb_type = next((pg for db, pg in DB_TYPE_MAP.items() if raw_type.startswith(db)), "VARCHAR")
        if precision: 
            columns_sql.append(f'  {lc(col_name)} {yb_type}{precision}')
        elif yb_type == "NUMERIC":
            # DECIMAL without explicit precision gets a wide default.
            columns_sql.append(f'  {lc(col_name)} {yb_type}(38,18)')
        else:
            columns_sql.append(f'  {lc(col_name)} {yb_type}')

    # make_fq() may already have quoted the table component; strip those quotes
    # so lc() does not wrap the name in a second pair.
    yb_table_name = fq_tableName.split(".")[-1].strip('"')
    create_stmt = f'CREATE TABLE {lc(yb_schema)}.{lc(yb_table_name)} (\n' + ",\n".join(columns_sql) + "\n) distribute random;"
    return create_stmt

def create_table_in_yellowbrick(ddl, table_name):
    """Drop the target table if it exists and recreate it from `ddl`.

    Errors are reported on stdout rather than raised, as before; the success
    message is now printed only when the drop and create both went through.

    Args:
        ddl: CREATE TABLE statement to execute.
        table_name: Unqualified target table name (qualified here with YB_SCHEMA / YB_DATABASE).

    Returns:
        True when the table was recreated, False when an error was reported.
    """
    start_time=time.time()
    try:
        connYB = psycopg2.connect(host=YB_HOST, dbname=YB_DATABASE, user=YB_USER, password=YB_PASSWORD, application_name='DB2YB')
        try:
            # `with connYB` commits on success and rolls back on error, but does
            # not close the connection; that is done in the finally clause.
            with connYB:
                with connYB.cursor() as cursorYB:
                    print(f"   * {(time.time()-start_time):7.2f} YB - Recreating {table_name}")
                    cursorYB.execute(f'DROP TABLE IF EXISTS {make_fq(table_name, YB_SCHEMA, YB_DATABASE)};')
                    print(f"   * {(time.time()-start_time):7.2f} YB - -- Dropped {table_name}")
                    cursorYB.execute(ddl)
        finally:
            connYB.close()
    except Exception as e:
        print(f"Error creating table {table_name}: {e}")
        return False
    print(f"   * {(time.time()-start_time):7.2f} YB - -- Created Table {table_name}")
    return True

def append_prep_destination(append, table):
    """Delete the rows matching the append predicate from the target table.

    The predicate is HTML-unescaped, its [column] markers are turned into
    double-quoted identifiers and non-breaking spaces are replaced by plain
    spaces before it is placed in the DELETE statement.

    Args:
        append: Predicate text, e.g. "[col]>=20".
        table: Fully qualified target table name, used verbatim.

    Raises:
        Whatever psycopg2 raises when the connection or the DELETE fails.
    """
    start_time = time.time()
    predicate = fix_column_names(html.unescape(append),True).replace('\u00A0', ' ')
    print(f"   * {(time.time()-start_time):7.2f} YB - Append mode.  Deleting rows matching {predicate} from {table}")
    connYB = psycopg2.connect(host=YB_HOST, dbname=YB_DATABASE, user=YB_USER, password=YB_PASSWORD, application_name='DB2YB')
    try:
        # `with connYB` ends the transaction but leaves the connection open,
        # so it is closed explicitly below.
        with connYB:
            with connYB.cursor() as cursorYB:
                delete_statement = f'DELETE FROM {table} WHERE {predicate};'
                cursorYB.execute(delete_statement)
                connYB.commit()
    finally:
        connYB.close()
    

def fix_column_names(predicate,YB=True):
    """Rewrite [column] markers in a predicate.

    With YB truthy each bracket becomes a double quote ("column"); otherwise the
    brackets are simply removed (column).
    """
    bracket_replacement = '"' if YB else ''
    without_open = re.sub(r'\[', bracket_replacement, predicate)
    return re.sub(r'\]', bracket_replacement, without_open)

def get_data(full_table_name, limit, where, append):
    """Fetch the source rows as a Spark DataFrame and count them.

    When `append` or `where` is given (append takes precedence) a SELECT with
    that predicate is executed; otherwise the table is read directly.  A truthy
    `limit` caps the number of rows in both cases.

    Returns:
        A (row_count, dataframe) tuple.
    """
    predicate = append if append else where
    query = None
    if predicate:
        # Unescape HTML entities and drop the [column] brackets for Spark SQL.
        query = fix_column_names(html.unescape(f"SELECT * FROM {full_table_name} WHERE {predicate}"),False)

    start_time=time.time()
    if query is None:
        lText = f"up to {limit} rows of" if limit else "entire"
        print(f"             DB - Loading {lText} Spark table") 
        df = spark.table(full_table_name)
        if limit:
            df = df.limit(limit)
    else:
        if limit:
            query += f" LIMIT {limit}"
        print(f"             DB - Executing {query}") 
        df = spark.sql(query)

    row_count = df.count()
    print(f"   * {(time.time()-start_time):7.2f} DB - Received {row_count} rows from table: {full_table_name}")
    return (row_count, df)

def spark_write(df, yb_table_name):
    """Append the DataFrame to the Yellowbrick table through Spark's JDBC writer.

    Args:
        df: Spark DataFrame to write.
        yb_table_name: Unqualified target table name inside YB_SCHEMA.

    Returns:
        The number of rows in `df` on success, or 0 when the write failed
        (the failure is reported on stdout rather than raised).
    """
    start_time = time.time()
    # Use the same identifier normalisation as the CREATE TABLE and validation
    # statements, so mixed-case names address the table that was actually created.
    target_table = make_fq(yb_table_name, YB_SCHEMA)
    try:
        df.write \
            .format("jdbc") \
            .option("url", f"jdbc:postgresql://{YB_HOST}:5432/{YB_DATABASE}") \
            .option("dbtable", target_table) \
            .option("currentSchema", YB_SCHEMA) \
            .option("user", YB_USER) \
            .option("password", YB_PASSWORD) \
            .option("driver", "org.postgresql.Driver") \
            .option("numPartitions", SPARK_PARITIONS) \
            .mode("append") \
            .save() 
        count=df.count()
        print(f"   * {(time.time()-start_time):7.2f} YB - Spark Wrote {count} rows to table {yb_table_name} {(count/(time.time()-start_time)):,.0f} RPS")
    except Exception as e:
        print(f"❌ Spark write to {yb_table_name} failed: {e}")
        count=0
    return count

def s3_write(df, yb_table_name):
    """Load a DataFrame into Yellowbrick by staging it as CSV files in S3.

    The DataFrame is written to <S3_BUCKET>/<table>.csv, loaded with LOAD TABLE
    through the "dbcopy-loc" / "dbcopy-fmt" external objects, and the staged
    files are removed afterwards, including when the export or load raises.

    Args:
        df: Spark DataFrame to export.
        yb_table_name: Unqualified target table name.

    Returns:
        The number of rows exported to S3.  A failed load is reported on stdout
        and does not change the return value.
    """

    def export_table_to_csv(df, output_path, table_name):
        """Write `df` as '^'-delimited CSV with a header to `output_path`; return its row count."""
        start_time=time.time()
        row_count = df.count()
        df.repartition(SPARK_PARITIONS).write.csv(output_path, mode='overwrite', header=True, sep="^", escape="\\")
        print(f"   * {(time.time()-start_time):7.2f} DB - Exported {row_count} rows to CSV directory: {output_path}.")
        return row_count

    def copy_to_yellowbrick(file_dir, table_name, rows):
        """Run LOAD TABLE for the staged directory and report the outcome on stdout."""
        all_start_time = time.time()
        print(f"             YB - Loading S3 files to {table_name}")
        connYB = psycopg2.connect(host=YB_HOST, dbname=YB_DATABASE, user=YB_USER, password=YB_PASSWORD, application_name='DB2YB')
        try:
            with connYB:
                with connYB.cursor() as copy_cursor:
                    try:
                        copy_sql = f"""
                        LOAD TABLE {make_fq(table_name, YB_SCHEMA, YB_DATABASE)} FROM ('{file_dir}/') EXTERNAL LOCATION "dbcopy-loc" EXTERNAL FORMAT "dbcopy-fmt" WITH (num_readers '{SPARK_PARITIONS}', read_sources_concurrently 'ALWAYS')
                    """
                        copy_cursor.execute(copy_sql)   
                        connYB.commit()
                    except Exception as e:
                        print(f"❌ COPY failed: {e}")
                    else:
                        # Only claim success when the load actually went through.
                        print(f"   * {(time.time() - all_start_time):7.2f} YB - Copied all parts to {table_name}. {(rows/(time.time()-all_start_time)):,.0f} RPS")
        finally:
            # The psycopg2 `with` block does not close the connection.
            connYB.close()

    csv_dir_name = f"{yb_table_name}.csv"
    try:
        rows = export_table_to_csv(df,S3_BUCKET+"/"+csv_dir_name, yb_table_name)
        copy_to_yellowbrick(csv_dir_name, yb_table_name, rows)
    finally:
        # Always remove the staged files, even when the export or load raised.
        remove_s3_directory(BUCKET_NAME, csv_dir_name+'/', AWS_ACCESS_KEY_ID, AWS_SECRET_KEY, AWS_REGION)
    return rows


def validate_table(table_name, rows, append_mode):
    """Compare the Yellowbrick row count of `table_name` with the expected number of rows.

    Args:
        table_name: Unqualified target table name (qualified here with YB_SCHEMA / YB_DATABASE).
        rows: Number of rows the load reported.
        append_mode: When truthy, a differing count is reported as the new table
            total instead of as a validation failure.

    Raises:
        Whatever psycopg2 raises when the connection or the COUNT query fails.
    """
    start_time=time.time()   
    count = 0
    connYB = psycopg2.connect(host=YB_HOST, dbname=YB_DATABASE, user=YB_USER, password=YB_PASSWORD, application_name='DB2YB')
    try:
        # `with connYB` ends the transaction only; the connection is closed below.
        with connYB:
            with connYB.cursor() as cursorYB:
                query = f'SELECT COUNT(*) FROM {make_fq(table_name, YB_SCHEMA, YB_DATABASE)};'
                cursorYB.execute(query)
                count = cursorYB.fetchone()[0]
    finally:
        connYB.close()

    if count == rows:
        print(f"   * {(time.time()-start_time):7.2f} YB - Validation successful: {count} rows in Yellowbrick table {table_name}")
    elif append_mode:
        print(f"   * {(time.time()-start_time):7.2f} YB - Total Rows after append: {count} rows in Yellowbrick table {table_name}")
    else:
        print(f"❌ Validation failed: Expected {rows} rows, but found {count} rows in Yellowbrick table {table_name}")        

    
def DB2YB(table_patterns,
          limit=None,
          append=None,
          write_ddl=None,
          read_ddl=None,
          where=None,
          useS3=True
          ):
    """Copy the Databricks tables matching `table_patterns` into Yellowbrick.

    Args:
        table_patterns: A regex pattern, or a list of them, matched in full
            against table names in the configured catalog and schema.
        limit: Optional maximum number of rows to copy per table.
        append: Optional predicate.  Matching rows are deleted from the target
            table and the matching source rows are appended; the table is not recreated.
        write_ddl: Optional directory.  The generated DDL is written to
            <write_ddl>/<table>.ddl and the table is not created or loaded.
        read_ddl: Optional directory.  The DDL is read from
            <read_ddl>/<table>.ddl instead of being generated.
        where: Optional predicate restricting the source rows.
        useS3: When True, tables with at least MIN_S3_ROWS rows are staged through
            S3; smaller tables, or all tables when False, use the Spark JDBC writer.

    Raises:
        Exception: If both `append` and `where`, or both `write_ddl` and
            `read_ddl`, are given.
    """
    if append and where:
        raise Exception("Cannot specify both append and where")

    if write_ddl and read_ddl:
        raise Exception("Cannot specify both write_ddl and read_ddl")

    # A bare string would otherwise be iterated one character at a time,
    # turning 'call_center' into eleven single-character patterns.
    if isinstance(table_patterns, str):
        table_patterns = [table_patterns]

    if useS3: 
        start_time=time.time()
        print("S3 Acceleration Enabled - Testing connection: ", end="")
        configure_test_s3()
        print(f"{time.time()-start_time:.1f} seconds")
    else:
        print("S3 Acceleration not enabled")

    print(f"Connected to Databricks via Spark: Catalog={DB_DATABASE} Schema={DB_SCHEMA}")
    
    print(f"Connecting to Yellowbrick: {YB_DATABASE}, Schema: {YB_SCHEMA}: ", end="")
    start_time=time.time()
    configure_test_yb(useS3)
    print(f"{time.time()-start_time:.1f} seconds")

    table_list = get_matching_tables(table_patterns)
    print(f"\nFound {len(table_list)} tables matching the patterns:")
    for table in table_list:
        print(f" - {table}")
    print("")

    for table in table_list:
        # Last dotted component with any quotes added by make_fq() removed.
        yb_table_name = table.split(".")[-1].strip('"')
        fq_yb_table_name =  make_fq(yb_table_name, YB_SCHEMA, YB_DATABASE) 
        start_time=time.time()
        print(f"Processing table {table} -> {fq_yb_table_name}")
        print(f"        Secs DW - Event")
        print(f"   =====================================================")
        if not append:
            if read_ddl:
                ddl = read_ddl_file(f"{read_ddl}/{yb_table_name}.ddl")
            else:
                ddl = getTableDDL(table, YB_SCHEMA)
                print(f"   * {(time.time()-start_time):7.2f} DB - Getting DDL for {table}")
            if write_ddl:
                # DDL-only run: save the statement and move on without touching the target.
                write_ddl_file(ddl, f"{write_ddl}/{yb_table_name}.ddl")
                print()
                continue
            create_table_in_yellowbrick(ddl,yb_table_name)
        else:
            append_prep_destination(append, fq_yb_table_name)

        (rows, df) = get_data(table, limit, where, append)
        if rows==0:
            print(f"   * {(time.time()-start_time):7.2f}    - Nothing to process - No rows returned.")
        else:
            if not useS3 or rows<MIN_S3_ROWS:
                if useS3:
                    print(f"   * {(time.time()-start_time):7.2f} YB - Small table, using spark dataframe write")
                rows = spark_write(df, yb_table_name)
            else:
                rows = s3_write(df, yb_table_name)
            validate_table(yb_table_name, rows, append)
            print(f"   * {(time.time()-start_time):7.2f}    - Finished processing table {table} -> {fq_yb_table_name}")
    
        print()

    print("All tables processed successfully.")



In [0]:
# Copy every table in the configured catalog/schema: '.*' matches any table name.
DB2YB(table_patterns = ['.*'])


S3 Acceleration Enabled - Testing connection: 2.8 seconds
Connected to Databricks via Spark: Catalog=ybaws Schema=tpcds
Connecting to Yellowbrick: ms_test, Schema: tpcds: 0.2 seconds

Found 25 tables matching the patterns:
 - ybaws.tpcds.call_center
 - ybaws.tpcds.catalog_page
 - ybaws.tpcds.catalog_returns
 - ybaws.tpcds.catalog_sales
 - ybaws.tpcds.customer
 - ybaws.tpcds.customer_address
 - ybaws.tpcds.customer_demographics
 - ybaws.tpcds.date_dim
 - ybaws.tpcds.dbgen_version
 - ybaws.tpcds.household_demographics
 - ybaws.tpcds.income_band
 - ybaws.tpcds.inventory
 - ybaws.tpcds.item
 - ybaws.tpcds.promotion
 - ybaws.tpcds.reason
 - ybaws.tpcds.ship_mode
 - ybaws.tpcds.store
 - ybaws.tpcds.store_returns
 - ybaws.tpcds.store_sales
 - ybaws.tpcds.time_dim
 - ybaws.tpcds.warehouse
 - ybaws.tpcds.web_page
 - ybaws.tpcds.web_returns
 - ybaws.tpcds.web_sales
 - ybaws.tpcds.web_site

Processing table ybaws.tpcds.call_center -> ms_test.tpcds.call_center
        Secs DW - Event
   *    0.32 

In [0]:
# Recreate call_center and copy at most 10 of the rows matching the WHERE predicate.
# Column names are written in [brackets]; they are removed for the Spark query.
DB2YB(table_patterns = ['call_center'], where = '[cc_call_center_sk]<20', limit=10)

S3 Acceleration Enabled - Testing connection: 2.0 seconds
Connected to Databricks via Spark: Catalog=ybaws Schema=tpcds
Connecting to Yellowbrick: ms_test, Schema: tpcds: 0.1 seconds

Found 1 tables matching the patterns:
 - ybaws.tpcds.call_center

Processing table ybaws.tpcds.call_center -> ms_test.tpcds.call_center
        Secs DW - Event
   *    0.20 DB - Getting DDL for ybaws.tpcds.call_center
   *    0.02 YB - Recreating call_center
   *    0.09 YB - -- Dropped call_center
   *    0.09 YB - -- Created Table call_center
             DB - Executing SELECT * FROM ybaws.tpcds.call_center WHERE cc_call_center_sk<20 LIMIT 10
   *    0.35 DB - Received 10 rows from table: ybaws.tpcds.call_center
   *    0.63 YB - Small table, using spark dataframe write
   *    1.02 YB - Spark Wrote 10 rows to table call_center 10 RPS
   *    0.26 YB - Validation successful: 10 rows in Yellowbrick table call_center
   *    1.91    - Finished processing table ybaws.tpcds.call_center -> ms_test.tpcds.call

In [0]:
# Append mode: the target table is kept, rows matching the predicate are deleted
# from it, and the matching source rows are then appended.
DB2YB(table_patterns = ['call_center'], append = '[cc_call_center_sk]>=20')

S3 Acceleration Enabled - Testing connection: 2.2 seconds
Connected to Databricks via Spark: Catalog=ybaws Schema=tpcds
Connecting to Yellowbrick: ms_test, Schema: tpcds: 0.2 seconds

Found 1 tables matching the patterns:
 - ybaws.tpcds.call_center

Processing table ybaws.tpcds.call_center -> ms_test.tpcds.call_center
        Secs DW - Event
   *    0.00 YB - Append mode.  Deleting rows matching "cc_call_center_sk">=20 from ms_test.tpcds.call_center
             DB - Executing SELECT * FROM ybaws.tpcds.call_center WHERE cc_call_center_sk>=20
   *    0.36 DB - Received 11 rows from table: ybaws.tpcds.call_center
   *    0.90 YB - Small table, using spark dataframe write
   *    0.72 YB - Spark Wrote 11 rows to table call_center 15 RPS
   *    0.25 YB - Total Rows after append: 21 rows in Yellowbrick table call_center
   *    1.87    - Finished processing table ybaws.tpcds.call_center -> ms_test.tpcds.call_center

All tables processed successfully.


In [0]:
# DDL-only run: write ./ddl/<table>.ddl for each matching table; no table is created or loaded.
DB2YB(table_patterns = ['store.*', 'catalog.*'],  write_ddl = './ddl')

S3 Acceleration Enabled - Testing connection: 1.9 seconds
Connected to Databricks via Spark: Catalog=ybaws Schema=tpcds
Connecting to Yellowbrick: ms_test, Schema: tpcds: 0.1 seconds

Found 6 tables matching the patterns:
 - ybaws.tpcds.catalog_page
 - ybaws.tpcds.catalog_returns
 - ybaws.tpcds.catalog_sales
 - ybaws.tpcds.store
 - ybaws.tpcds.store_returns
 - ybaws.tpcds.store_sales

Processing table ybaws.tpcds.catalog_page -> ms_test.tpcds.catalog_page
        Secs DW - Event
   *    0.22 DB - Getting DDL for ybaws.tpcds.catalog_page
   *    0.56    - Wrote DDL to file ./ddl/catalog_page.ddl

Processing table ybaws.tpcds.catalog_returns -> ms_test.tpcds.catalog_returns
        Secs DW - Event
   *    0.20 DB - Getting DDL for ybaws.tpcds.catalog_returns
   *    0.47    - Wrote DDL to file ./ddl/catalog_returns.ddl

Processing table ybaws.tpcds.catalog_sales -> ms_test.tpcds.catalog_sales
        Secs DW - Event
   *    0.19 DB - Getting DDL for ybaws.tpcds.catalog_sales
   *    0.46

In [0]:
# Recreate each matching table from the DDL saved in ./ddl/<table>.ddl, then copy
# up to 2,000,000 rows per table.
DB2YB(table_patterns = ['store.*','catalog.*'],  read_ddl = './ddl', limit=2000000)

S3 Acceleration Enabled - Testing connection: 2.2 seconds
Connected to Databricks via Spark: Catalog=ybaws Schema=tpcds
Connecting to Yellowbrick: ms_test, Schema: tpcds: 0.1 seconds

Found 6 tables matching the patterns:
 - ybaws.tpcds.catalog_page
 - ybaws.tpcds.catalog_returns
 - ybaws.tpcds.catalog_sales
 - ybaws.tpcds.store
 - ybaws.tpcds.store_returns
 - ybaws.tpcds.store_sales

Processing table ybaws.tpcds.catalog_page -> ms_test.tpcds.catalog_page
        Secs DW - Event
   *    0.23    - Read table DDL from file ./ddl/catalog_page.ddl
   *    0.02 YB - Recreating catalog_page
   *    0.08 YB - -- Dropped catalog_page
   *    0.08 YB - -- Created Table catalog_page
             DB - Loading up to 2000000 rows of Spark table
   *    0.56 DB - Received 20400 rows from table: ybaws.tpcds.catalog_page
   *    0.87 YB - Small table, using spark dataframe write
   *    2.02 YB - Spark Wrote 20400 rows to table catalog_page 10,101 RPS
   *    0.25 YB - Validation successful: 20400 row