---
AUTOMATE TEST SCRIPT

---

In [1]:
from snowflake.snowpark import Session
import os
from datetime import datetime

In [2]:
# Connection settings for the notebook.
# Each value can be overridden through an environment variable so the notebook
# can run against another account without editing code. The password has NO
# in-file default: export SNOWFLAKE_PASSWORD before starting the kernel so the
# secret is never committed with the notebook.
user = os.environ.get('SNOWFLAKE_USER', 'DOLPHIN')
password = os.environ.get('SNOWFLAKE_PASSWORD')  # None when the variable is not set
account = os.environ.get('SNOWFLAKE_ACCOUNT', 'URB63596')
database = os.environ.get('SNOWFLAKE_DATABASE', 'mimic_iv_medi_assist')
schema = os.environ.get('SNOWFLAKE_SCHEMA', 'raw')
warehouse = os.environ.get('SNOWFLAKE_WAREHOUSE', 'my_warehouse')

In [3]:
def snowpark_basic_auth() -> Session:
    """Open a Snowpark session using basic (user / password) authentication.

    Each connection value is read from an environment variable
    (SNOWFLAKE_ACCOUNT, SNOWFLAKE_USER, SNOWFLAKE_PASSWORD). When a variable
    is unset or empty, the notebook-level ``account`` / ``user`` / ``password``
    settings from the configuration cell are used instead, so the credentials
    are defined in one place only.

    Returns:
        Session: the session created by ``Session.builder``.

    Raises:
        ValueError: if any of the three connection values is missing or empty.
    """
    # `or` short-circuits, so the notebook-level variable is only looked up
    # when the corresponding environment variable is not provided.
    connection_parameters = {
        "ACCOUNT": os.environ.get("SNOWFLAKE_ACCOUNT") or account,
        "USER": os.environ.get("SNOWFLAKE_USER") or user,
        "PASSWORD": os.environ.get("SNOWFLAKE_PASSWORD") or password,
    }

    # Fail early with a readable message instead of an opaque login error.
    missing = [key for key, value in connection_parameters.items() if not value]
    if missing:
        raise ValueError(
            "Missing Snowflake connection setting(s): " + ", ".join(missing)
        )

    return Session.builder.configs(connection_parameters).create()


In [4]:
def generate_ddl_statement(column_names, data_types, table_name):
    """Build a ``CREATE TABLE IF NOT EXISTS`` statement as a string.

    Args:
        column_names: iterable of column names.
        data_types: iterable of SQL type names, paired positionally with
            ``column_names`` (pairing stops at the shorter iterable).
        table_name: name of the table to create, inserted verbatim.

    Returns:
        str: the DDL text, one column definition per line.
    """
    # One "   <name> <type>" line per column, in the order given.
    column_lines = [
        f"   {column} {sql_type}"
        for column, sql_type in zip(column_names, data_types)
    ]
    column_block = ",\n".join(column_lines)
    return f"CREATE TABLE IF NOT EXISTS {table_name} (\n{column_block})"


In [5]:
def generate_copy_statement(table_name,stage_name,csv_file_path,file_format):
    """Build a ``COPY INTO`` statement that loads one staged file into a table.

    Args:
        table_name: target table, inserted verbatim.
        stage_name: stage name without the leading ``@``.
        csv_file_path: path of the file inside the stage.
        file_format: name of the file format referenced via ``FORMAT_NAME``.

    Returns:
        str: the COPY statement, indented by four spaces and terminated by ``;``.
    """
    source_location = f"@{stage_name}/{csv_file_path}"
    # The leading empty entry and the trailing four-space entry reproduce the
    # newline / indentation that surround the statement text.
    statement_lines = [
        "",
        f"    COPY INTO {table_name}",
        f"    FROM {source_location}",
        f"    FILE_FORMAT = (FORMAT_NAME = '{file_format}')",
        "    ;",
        "    ",
    ]
    return "\n".join(statement_lines)


In [6]:
def create_file_format(session):
    """Create (or replace) the CSV file format ``file_format_csv``.

    The statement is executed through ``session.sql(...).collect()`` and
    progress messages are printed before and after.

    NOTE(review): a later cell in this notebook defines another function with
    the same name; once that cell runs, it replaces this definition.

    Args:
        session: object exposing ``sql(text)`` whose result has ``collect()``.
    """
    print("Creating file format...")

    csv_format_sql = """
        CREATE OR REPLACE FILE FORMAT file_format_csv
        TYPE = 'CSV'
        COMPRESSION = 'GZIP'                   -- Specify GZIP compression for .gz files
        FIELD_DELIMITER = ','                   -- Specify the field delimiter
        PARSE_HEADER = TRUE                      -- Parse the header row for column names
        FIELD_OPTIONALLY_ENCLOSED_BY = '"';     -- Optional field enclosure
    """
    statement = session.sql(csv_format_sql)
    statement.collect()

    for message in (
        "File format 'file_format_csv' created successfully.",
        "===========",
    ):
        print(message)

In [None]:
def create_file_format(session):
    """Create (or replace) the three CSV file formats used by this notebook.

    Formats are created in this order, each via ``session.sql(...).collect()``:
    ``file_format_ddl`` (PARSE_HEADER, column-count mismatch tolerated),
    ``file_format_load`` (SKIP_HEADER = 1) and ``file_format_generic``
    (GZIP, PARSE_HEADER). A confirmation line and a separator are printed
    after each one.

    Args:
        session: object exposing ``sql(text)`` whose result has ``collect()``.
    """
    print("Creating file format for DDL...")

    # '\042' is the octal escape for a double quote.
    ddl_format_sql = """
        CREATE OR REPLACE FILE FORMAT file_format_ddl
        TYPE = 'CSV'
        COMPRESSION = 'auto'                   -- Specify GZIP compression for .gz files
        FIELD_DELIMITER = ','                   -- Specify the field delimiter
        PARSE_HEADER = TRUE                    -- Parse the header row for column names
        FIELD_OPTIONALLY_ENCLOSED_BY = '\042'
        ESCAPE_UNENCLOSED_FIELD = NONE 
        TRIM_SPACE = TRUE 
        ERROR_ON_COLUMN_COUNT_MISMATCH = FALSE;
    """

    load_format_sql = """
        CREATE OR REPLACE FILE FORMAT file_format_load
        TYPE = 'CSV'
        COMPRESSION = 'auto'                   
        FIELD_DELIMITER = ','                   
        RECORD_DELIMITER = '\n'
        SKIP_HEADER = 1                      
        FIELD_OPTIONALLY_ENCLOSED_BY = '"'
    """

    generic_format_sql = """
        CREATE OR REPLACE FILE FORMAT file_format_generic
        TYPE = 'CSV'
        COMPRESSION = 'GZIP'
        FIELD_DELIMITER = ','
        PARSE_HEADER = TRUE
        FIELD_OPTIONALLY_ENCLOSED_BY = '"'
        ESCAPE_UNENCLOSED_FIELD = None
    """

    format_statements = (
        ("file_format_ddl", ddl_format_sql),
        ("file_format_load", load_format_sql),
        ("file_format_generic", generic_format_sql),
    )

    # Execute each definition, then report it before moving on to the next.
    for format_name, statement_sql in format_statements:
        session.sql(statement_sql).collect()

        print(f"File format '{format_name}' created successfully.")
        print("===========")


In [19]:
# Record the start time and open the Snowflake session for the rest of the notebook.
utc_start_time = datetime.utcnow()
session_wih_pwd = snowpark_basic_auth()

# Select the working database / schema / warehouse from the configuration
# cell's variables rather than repeating the identifiers as literals here.
session_wih_pwd.sql(f"USE DATABASE {database}").collect()
session_wih_pwd.sql(f"USE SCHEMA {schema}").collect()
session_wih_pwd.sql(f"USE WAREHOUSE {warehouse}").collect()

# Register the named file formats that the later INFER_SCHEMA and COPY steps reference.
create_file_format(session_wih_pwd)

Creating file format for DDL...
File format 'file_format_ddl' created successfully.
File format 'file_format_load' created successfully.
File format 'file_format_generic' created successfully.


In [20]:
# List the files in the internal stage; the result is a list of Row objects
# with name, size, md5 and last_modified fields (see the printed output).
stg_files = session_wih_pwd.sql("list @my_internal_stage").collect()
print(stg_files)

[Row(name='my_internal_stage/admissions.csv.gz', size=19652448, md5='e030d760b3e15ce1c6aa9e8e8637ebcd', last_modified='Wed, 9 Oct 2024 06:45:27 GMT'), Row(name='my_internal_stage/d_icd_diagnoses.csv.gz', size=849392, md5='d77fe5f8249b16ed3692e52bd781dc06', last_modified='Wed, 9 Oct 2024 06:45:17 GMT'), Row(name='my_internal_stage/d_icd_procedures.csv.gz', size=549936, md5='d4a9b85d7ee6a75d32f6e1d61ac173cf', last_modified='Wed, 9 Oct 2024 06:44:37 GMT'), Row(name='my_internal_stage/discharge_two.csv.gz', size=1138715888, md5='b85ea55aef83c331f7c821e2e425f0de-136', last_modified='Sat, 12 Oct 2024 04:05:15 GMT'), Row(name='my_internal_stage/drgcodes.csv.gz', size=9509520, md5='89a92ba0a394557ceaa16cfcf7c93bce', last_modified='Wed, 9 Oct 2024 06:45:16 GMT'), Row(name='my_internal_stage/pharmacy.csv.gz', size=28197696, md5='880636d6f06cccdcb30a3b541320f35b', last_modified='Tue, 5 Nov 2024 20:32:14 GMT')]


In [None]:
import os
from datetime import datetime

# For every staged file selected by `target_files`:
#   1. infer the column layout with INFER_SCHEMA,
#   2. generate the CREATE TABLE and COPY INTO statements,
#   3. save both statements to "<table>.sql",
#   4. execute them against Snowflake.
# Files not listed in `target_files` are skipped.

# One-element tuple: without the trailing comma this would be a plain string
# and `not in` would perform a substring test instead of a membership test.
target_files = ('pharmacy.csv.gz',)

utc_start_time = datetime.utcnow()
print("Process started at:", utc_start_time)

for row in stg_files:
    print("======================================================\n")
    print("Processing row:", row)
    
    # Convert row to dictionary
    row_value = row.as_dict()
    print("Row as dictionary:", row_value)
    
    # Extract the staged file path value
    stg_file_path_value = row_value.get('name')
    print("Staged file path value:", stg_file_path_value)

    # Split file path and name
    file_path, file_name = os.path.split(stg_file_path_value)
    print("File path:", file_path)
    print("File name:", file_name)

    # Create staged location variable
    stg_location = "@" + file_path
    print("Staged location:", stg_location)

    # Filter for specific file (exact name match against the target tuple)
    if file_name not in target_files:
        continue
    
    print(f"Processing target file: {file_name}")
    
    # Generate SQL for inferring schema
    infer_schema_sql = """\
        SELECT * 
        FROM TABLE(
            INFER_SCHEMA(
            LOCATION=>'{}/',
            files => '{}',
            FILE_FORMAT => 'file_format_ddl'
        )    
    )
    """.format(stg_location, file_name)
    
    print("\n=========== INFER SCHEMA SQL =============")
    print(f"File: {file_name}")
    print(infer_schema_sql)

    # Execute schema inference
    inferred_schema_rows = session_wih_pwd.sql(infer_schema_sql).collect()
    print("\nSchema inference completed. Inferred schema rows:")
    print(inferred_schema_rows)

    # Prepare lists for column names and types
    col_name_lst = []
    col_data_type_lst = []

    # Process each row in inferred schema. Distinct loop names are used so the
    # outer loop's `row` / `row_value` are not overwritten.
    for inferred_row in inferred_schema_rows:
        inferred_row_value = inferred_row.as_dict()
        print("Inferred schema row:", inferred_row_value)

        col_name_lst.append(inferred_row_value.get('COLUMN_NAME'))
        col_data_type_lst.append(inferred_row_value.get('TYPE'))

    print("Column names list:", col_name_lst)
    print("Column data types list:", col_data_type_lst)

    # Generate table name and DDL statement
    table_name = file_name.split('.')[0] + "_raw"
    create_ddl_stmt = generate_ddl_statement(col_name_lst, col_data_type_lst, table_name.upper())
    print("=================== DDL STATEMENT =====================")
    print(create_ddl_stmt)

    # Generate copy statement for loading data. The stage path comes from the
    # listing (file_path) so COPY reads from the same location that
    # INFER_SCHEMA inspected, rather than from a hardcoded stage name.
    copy_stmt = generate_copy_statement(table_name, file_path, file_name, 'file_format_load')
    print("=================== COPY STATEMENT =====================")
    print(copy_stmt)

    # Define SQL file path and save DDL and copy statements to file
    sql_file_path = table_name + ".sql"
    print("=================== SQL FILE PATH =====================")
    print("File path for saving SQL:", sql_file_path)
    with open(sql_file_path, "w") as sql_file:
        sql_file.write("---- Following statement is creating table\n\n")
        sql_file.write(create_ddl_stmt)
        sql_file.write("\n-- Following statement is executing copy command\n")
        sql_file.write(copy_stmt)
    print("SQL statements written to file:", sql_file_path)

    # Execute DDL to create the table
    session_wih_pwd.sql(create_ddl_stmt).collect()
    print("Table created successfully with DDL statement.")

    # Execute copy command to load data into the table
    session_wih_pwd.sql(copy_stmt).collect()
    print("Data loaded into the table with COPY statement.")

# End of processing and time calculation
utc_end_time = datetime.utcnow()
print("Process completed at:", utc_end_time)
print("Total processing time:", utc_end_time - utc_start_time)


Process started at: 2024-11-05 21:46:58.122099

Processing row: Row(name='my_internal_stage/admissions.csv.gz', size=19652448, md5='e030d760b3e15ce1c6aa9e8e8637ebcd', last_modified='Wed, 9 Oct 2024 06:45:27 GMT')
Row as dictionary: {'name': 'my_internal_stage/admissions.csv.gz', 'size': 19652448, 'md5': 'e030d760b3e15ce1c6aa9e8e8637ebcd', 'last_modified': 'Wed, 9 Oct 2024 06:45:27 GMT'}
Staged file path value: my_internal_stage/admissions.csv.gz
File path: my_internal_stage
File name: admissions.csv.gz
Staged location: @my_internal_stage
Skipping file admissions.csv.gz as it doesn't match the target file.

Processing row: Row(name='my_internal_stage/d_icd_diagnoses.csv.gz', size=849392, md5='d77fe5f8249b16ed3692e52bd781dc06', last_modified='Wed, 9 Oct 2024 06:45:17 GMT')
Row as dictionary: {'name': 'my_internal_stage/d_icd_diagnoses.csv.gz', 'size': 849392, 'md5': 'd77fe5f8249b16ed3692e52bd781dc06', 'last_modified': 'Wed, 9 Oct 2024 06:45:17 GMT'}
Staged file path value: my_internal_st

SnowparkSQLException: (1304): 01b82d1a-0004-20d3-0000-13e7001a5132: 100332 (22000): Error with CSV header: header defined 27 columns while data contains 10 columns. 

  File 'pharmacy.csv.gz'
  Row 0 starts at line 0, column 