<!--
#  Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
#    Licensed under the Apache License, Version 2.0 (the "License").
#    You may not use this file except in compliance with the License.
#    You may obtain a copy of the License at
#
#        http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS,
#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#    See the License for the specific language governing permissions and
#    limitations under the License.
-->

# Prepare the Data Lake from Raw CSV Files
***Determine Data Schema and Create Parquet Output with Glue Metadata***

---
## Contents
1. [Introduction](#Introduction)
2. [Set Up](#Set-Up)
3. [Get Schema for Files](#Get-Schema-for-Files)
   1. [Add Helper Functions to Configure Schema Output](#Add-Helper-Functions-to-Configure-Schema-Output)
   2. [Get File Schema](#Get-File-Schema)
4. [Create External Table with Parquet Output](#Create-External-Table-with-Parquet-Output)
5. [Produce Loading Statistics](#Produce-Loading-Statistics)
6. [Final Check](#Final-Check)
---
## Introduction
In this notebook we find the schema for each of our data files and create Parquet output tables with that schema in our target directory on S3. The notebook uses AWS services such as Athena and S3.

In summary, this job will perform the following functions:

**1.** Find the corresponding schema for each file

**2.** Read the file using Athena external table

**3.** Create final tables in Parquet format, converting "date" fields from strings to type `date`

___

### Author: AWS Professional Services Emerging Technology and Intelligent Platforms Group


---

## Set Up


In [None]:
import os
# Show where the `os` module was loaded from, i.e. which Python installation this kernel runs on.
print(os.__file__)

In [None]:
%reload_ext sql

In [None]:
import os
import sys
import boto3
import json
from aws_orbit_sdk.database import get_athena
from aws_orbit_sdk.common import get_workspace,get_scratch_database
# import aws_orbit_sdk.glue as orbit_catalog_api
import matplotlib.pyplot as plt

from zipfile import ZipFile
import glob
import tempfile
import shutil
from time import time, sleep
from pathlib import Path

In [None]:
# Define constants
batch = '1'  # NOTE(review): not referenced in the cells shown
database_name = "cms_raw_db"  # Athena/Glue database that receives the raw (_raw) and Parquet tables
source_bucket_name = "orbit-test-base-accoun-testlakebucketfa111111-1111111111"  # bucket holding the schema JSON files
target_bucket_name = "orbit-test-base-accoun-testlakebucketfa111111-1111111111"  # bucket used for the CSV and Parquet table locations
schema_dir = "landing/cms/schema"  # key prefix of the schema JSON files in source_bucket_name
file_path = 'extracted/Beneficiary_Summary'  # key prefix of the CSV files to load; its last component names the tables
region = "us-east-2"  # NOTE(review): not referenced in the cells shown; boto3.client("athena") below is created without region_name

In [None]:
# Settings derived from the constants above and from the Orbit workspace.
target_db_dir = database_name + "/"  # key prefix (in target_bucket_name) for the Parquet output
incoming_dir = file_path  # NOTE(review): not referenced in the cells shown
basename = Path(file_path).name  # e.g. "Beneficiary_Summary"; used as the target table name
workspace = get_workspace()
team_space = workspace['team_space']
scratch_bucket = f"{workspace['ScratchBucket']}/{team_space}"  # Athena query results go under s3://<scratch_bucket>/athena/results
# workspace

#### Helper Python functions for working with S3 and Athena:

In [None]:
WAIT_RETRY_DELAY = 1 # seconds. Amount of time to wait before re-checking the status of the query that is being executed in Athena.

def getSchemas(source_bucket_name, prefix='', suffix=''):
    """Load every JSON schema file stored under ``prefix`` in an S3 bucket.

    :param source_bucket_name: name of the bucket holding the schema files.
    :param prefix: key prefix to search (e.g. ``"landing/cms/schema"``).
    :param suffix: only keys ending with this string are loaded (e.g.
        ``".json"``); the default ``''`` accepts every key.
    :returns: list of ``(name, schema)`` tuples, where ``name`` is the file
        name up to its first ``"."`` and ``schema`` is the parsed JSON.
    """
    bucket = boto3.resource("s3").Bucket(name=source_bucket_name)
    schemas = []
    # Let S3 filter by prefix server-side instead of listing the whole bucket.
    for obj in bucket.objects.filter(Prefix=prefix):
        # Skip "folder" placeholder keys (e.g. "landing/cms/schema/"): they
        # have an empty body that json.loads cannot parse.
        if obj.key.endswith("/") or not obj.key.endswith(suffix):
            continue
        name = os.path.basename(obj.key).split(".")[0]
        if not name:
            # An empty name would match every file in getSchema ("" in s is True).
            continue
        body = obj.get()['Body'].read().decode('utf-8')
        schemas.append((name, json.loads(body)))
    return schemas

def getSchema(schemas, filename):
    """Pick the schema whose name best matches ``filename``.

    A schema matches when its name is a substring of ``filename``. When
    several names match, the longest one wins. A short name such as
    ``"Summary"`` therefore cannot shadow ``"Beneficiary_Summary"`` just
    because it was listed first. Ties go to the earliest entry. Empty names
    are ignored, because ``"" in filename`` is always true.

    :param schemas: iterable of ``(name, schema)`` pairs, e.g. from ``getSchemas``.
    :param filename: file or table base name to match.
    :returns: ``(name, schema)`` of the best match, or ``(None, None)``.
    """
    best = (None, None)
    for schema_name, schema in schemas:
        if not schema_name or schema_name not in filename:
            continue
        if best[0] is None or len(schema_name) > len(best[0]):
            best = (schema_name, schema)
    return best

# Hive DDL template for the temporary raw table over comma-delimited files
# (first line skipped as a header). Placeholders, in order: database, table,
# column definitions, S3 location.
CREATE_CSV_TABLE = """
        CREATE EXTERNAL TABLE {}.{}
        ({})
        row format delimited
        fields terminated by ',' 
        LOCATION '{}'                      
        tblproperties ("skip.header.line.count"="1")
"""
def schema_2_ddl(schema, database_name, table_name, s3_target_dir, bucket_name=None):
    """Build the CREATE EXTERNAL TABLE statement for a raw CSV table.

    :param schema: parsed schema dict with a ``"fields"`` list of
        ``{"name": ..., "type": ...}`` entries.
    :param database_name: database to create the table in.
    :param table_name: name of the table to create.
    :param s3_target_dir: key prefix of the CSV files, given without a
        trailing slash (one is appended).
    :param bucket_name: bucket holding the CSV files; defaults to the
        notebook-level ``target_bucket_name``.
    :returns: the DDL string, or ``None`` when ``schema`` is empty or has no fields.
    """
    if not schema or not schema.get("fields"):
        return None
    if bucket_name is None:
        bucket_name = target_bucket_name
    # Backtick-quote column names so fields that collide with Athena DDL
    # reserved words (e.g. `date`) do not break the statement.
    columns_def = ",".join(
        f"`{column['name']}` {map_column_type(column['type'])}"
        for column in schema["fields"]
    )
    s3_target_location = "s3://{}/{}/".format(bucket_name, s3_target_dir)
    return CREATE_CSV_TABLE.format(database_name, table_name, columns_def, s3_target_location)

# Schema type names (Spark-style, presumably) that must be renamed for Athena DDL.
_ATHENA_TYPE_MAP = {
    # Dates are kept as strings in the raw CSV table; loadTable converts them
    # to DATE in the CTAS step.
    'date': 'string',
    'integer': 'int',
    'long': 'bigint',
    'short': 'smallint',
    'byte': 'tinyint',
}


def map_column_type(schema_type):
    """Return the Athena DDL type to use for a schema field type.

    Types without an entry in ``_ATHENA_TYPE_MAP`` (e.g. ``string``,
    ``double``) are returned unchanged. Non-string values (such as nested-type
    dicts) are also returned as-is.
    """
    if not isinstance(schema_type, str):
        return schema_type
    return _ATHENA_TYPE_MAP.get(schema_type, schema_type)


def _run_athena_query(athena_client, query, timeout=30):
    """Run one Athena query and wait for it to reach a terminal state.

    Results are written under ``s3://{scratch_bucket}/athena/results``. The
    status is polled every ``WAIT_RETRY_DELAY`` seconds.

    :param athena_client: boto3 Athena client.
    :param query: SQL statement to execute.
    :param timeout: maximum number of seconds to wait for completion.
    :returns: the Athena query execution id.
    :raises Exception: when the query ends FAILED or CANCELLED, or does not
        finish within ``timeout`` seconds (the query is then stopped). The
        message is ``"FAILED to execute DDL: <query>"`` followed by the reason
        when one is known.
    """
    start = time()
    response = athena_client.start_query_execution(
        QueryString=query,
        ResultConfiguration={'OutputLocation': f's3://{scratch_bucket}/athena/results'},
    )
    query_id = response['QueryExecutionId']
    state = None
    reason = None
    while time() - start < timeout:
        status = athena_client.get_query_execution(QueryExecutionId=query_id)['QueryExecution']['Status']
        state = status['State']
        if state == 'SUCCEEDED':
            return query_id
        if state in ('FAILED', 'CANCELLED'):
            # Athena reports why the query failed (syntax error, missing table, ...).
            reason = status.get('StateChangeReason')
            break
        sleep(WAIT_RETRY_DELAY)
    else:
        # Timed out before a terminal state: stop the query so it does not keep
        # running after it has been reported as failed.
        reason = f"timed out after {timeout}s, last state {state}"
        try:
            athena_client.stop_query_execution(QueryExecutionId=query_id)
        except Exception as e:
            print(f"Could not stop query {query_id}: {e}")
    detail = f" ({reason})" if reason else ""
    raise Exception(f"FAILED to execute DDL: {query}{detail}")


def delete_temp_tbl(athena_client, timeout=30):
    """Drop the temporary raw table ``{database_name}.{basename}_raw`` if it exists.

    :param athena_client: boto3 Athena client.
    :param timeout: seconds to wait for the DROP statement.
    :raises Exception: if the statement fails or times out.
    """
    _run_athena_query(athena_client, f"DROP TABLE IF EXISTS {database_name}.{basename}_raw", timeout)


def createTempExternalTable(athena_client, ddl, timeout=30):
    """Execute ``ddl`` (normally produced by ``schema_2_ddl``) to create the raw table.

    :param athena_client: boto3 Athena client.
    :param ddl: CREATE EXTERNAL TABLE statement.
    :param timeout: seconds to wait for the statement.
    :raises ValueError: if ``ddl`` is empty, e.g. ``schema_2_ddl`` returned
        ``None`` because no schema or fields were found.
    :raises Exception: if the statement fails or times out.
    """
    if not ddl:
        raise ValueError("No DDL to execute: schema_2_ddl returned nothing (missing schema or fields?)")
    _run_athena_query(athena_client, ddl, timeout)

CREATE_PARQUET_TABLE = """
        CREATE TABLE {}.{}
        WITH (
            format = 'Parquet',
            parquet_compression = 'SNAPPY',
            external_location = '{}'
        )
        AS
        (select {} from {}.{})
"""
def loadTable(athena_client, schema_name, schema, timeout=30):
    """Copy the raw CSV table into a Parquet table using CTAS.

    Drops ``{database_name}.{basename}`` and clears its S3 location
    ``s3://{target_bucket_name}/{target_db_dir}{basename}/``. It then creates
    the table from ``{database_name}.{basename}_raw`` as Snappy-compressed
    Parquet. Fields whose schema type is ``date`` are parsed from ``yyyyMMdd``
    strings into DATE values, and empty strings become NULL.

    :param athena_client: boto3 Athena client.
    :param schema_name: name of the matched schema (not used here; kept for
        interface compatibility).
    :param schema: parsed schema dict with a ``"fields"`` list.
    :param timeout: seconds to wait for each Athena statement. A CTAS over a
        large file may need more than the default.
    :returns: ``None``. Nothing is done when ``schema`` has no fields.
    :raises Exception: if an Athena statement fails or times out.
    """
    if not schema or not schema.get("fields"):
        return None

    _run_athena_query(athena_client, f"DROP TABLE IF EXISTS {database_name}.{basename}", timeout)

    # Dropping the table leaves its Parquet files behind, and Athena CTAS
    # requires an empty external_location, so clear the prefix. This is
    # best-effort: a failure is reported, and any remaining conflict surfaces
    # as a CTAS error below.
    try:
        bucket = boto3.resource('s3').Bucket(target_bucket_name)
        bucket.objects.filter(Prefix=f"{target_db_dir}{basename}/").delete()
    except Exception as e:
        print("Failed with " + str(e))

    select_exprs = []
    for column in schema["fields"]:
        name = column['name']
        if column['type'] == 'date':
            # The raw table holds dates as strings; parse them and map '' to NULL.
            select_exprs.append(f"case when {name} is not null and {name}!= '' then date(parse_datetime({name}, 'yyyyMMdd')) else null end {name}")
        else:
            select_exprs.append(name)
    columns_def = ", \n".join(select_exprs)

    ddl = CREATE_PARQUET_TABLE.format(
        database_name, basename,
        "s3://{}/{}{}/".format(target_bucket_name, target_db_dir, basename),
        columns_def,
        database_name,
        basename + "_raw",
    )
    _run_athena_query(athena_client, ddl, timeout)

    

In [None]:
# NOTE(review): the next code cell calls getSchemas() again with the same
# arguments and overwrites `schemas`, so this cell appears redundant.
schemas = getSchemas(source_bucket_name, schema_dir)

Time to run some code and see results:

In [None]:
# Get all schemas from the S3 location
schemas = getSchemas(source_bucket_name, schema_dir)
# Find a schema for given file
schema_name, schema = getSchema(schemas, basename)
if schema is None:
    # Fail early: otherwise ddl would be None and be sent to Athena as a query.
    raise ValueError(f"No schema under s3://{source_bucket_name}/{schema_dir} matches '{basename}'")
# Convert schema to DDL for temp table
ddl = schema_2_ddl(schema, database_name, basename + "_raw", file_path)
if ddl is None:
    raise ValueError(f"Schema '{schema_name}' has no fields; cannot create a table for '{basename}'")

athena_client = boto3.client("athena")

# Recreate the raw CSV table, then CTAS it into the Parquet target table.
delete_temp_tbl(athena_client)
createTempExternalTable(athena_client, ddl)

loadTable(athena_client, schema_name, schema)
# Uncomment to drop the temporary raw table after loading:
# delete_temp_tbl(athena_client)


In [None]:
athena = get_athena() # now, let's use SQL magic with Athena

%config SqlMagic.autocommit=False # for engines that do not support autocommit

# Re-read the workspace configuration (also loaded in the set-up cell) and display it.
workspace = get_workspace()
# scratch_glue_db = get_scratch_database()
team_space = workspace['team_space']
workspace

In [None]:
glue_db = "cms_raw_db"  # same value as `database_name`; the database the SQL cells below query
target_db = "users"  # NOTE(review): not referenced in the cells shown

In [None]:
%connect_to_athena -database $glue_db

* **Quick Check:** Ensure that all of our extracted data is in our database:

In [None]:
%%sql 

SELECT 1 as "Test"

In [None]:
%%sql

-- # %%spark -s spark -c sql 
SHOW TABLES

In [None]:
# Announce which table the next cell reads from.
print("Loading data from table {}.{}".format(glue_db, basename))

Load the data:

In [None]:
# Pull a 13-row sample of the table created by loadTable through the SQL magic
# and display it as a DataFrame.
query = f"SELECT * FROM {glue_db}.{basename} limit 13"
ds = %sql $query
df = ds.DataFrame()
df

In [None]:
print("Data load complete.")

---
## Final Check
Let's run two final checks on our loading stats:

 **1.** The count of the columns is greater than 0
 
 **2.** The count of rows is greater than 0
 

In [None]:
# The sampled table must have at least one column and at least one row.
n_rows, n_cols = df.shape
assert n_cols > 0

assert n_rows > 0  # number of rows


In [1]:
# End the interpreter session once the checks above have run.
exit()