In [None]:
!pip install pyathena --quiet

In [None]:
# AWS Imports
import boto3
from botocore.client import ClientError
import sagemaker
from pyathena import connect
import awswrangler as wr

# Data Transformation Imports
import pandas as pd
from io import StringIO

# Misc Imports
from IPython.display import display, HTML

## SageMaker Details and Variable Initialization

In [None]:
# Initializing variables for reproducibility.
#
# `bucket` is resolved here because every S3 path below is built from it.
# Without this line the cell raises NameError on a fresh kernel
# (Restart Kernel -> Run All), since no earlier cell defines `bucket`.
bucket = sagemaker.Session().default_bucket()

FILE_NAME = "data.csv"
DATA_SOURCE = "db_source"
DATA_FOLDER = f"s3://{bucket}/aai-540-group-3-final-project/data/"
FILE_LOCATION = f"{DATA_FOLDER}{FILE_NAME}"
# S3 prefix used as the LOCATION of the Athena table.
DATA_PATH = f"{DATA_FOLDER}{DATA_SOURCE}/"
DATABASE = "retainAI"
PROD_DIR = f"s3://{bucket}/athena/prod"
# Used as the Athena `s3_staging_dir` (query results) when connecting.
STAGE_DIR = f"s3://{bucket}/athena/staging"
EMPLOYEE_TABLE = "employee_table"

# Making sure all variables are correct

print(f"File location with all the data: {FILE_LOCATION}")
print(f"Data Path for database creation: {DATA_PATH}")
print(f"Production and Staging Database Directories: {PROD_DIR},{STAGE_DIR}")
# Label matches what is actually printed: one database and one table.
print(f"Database Name and Employee Table: {DATABASE}, {EMPLOYEE_TABLE}")

In [None]:
# check what is in DATA_FOLDER
!aws s3 ls $DATA_FOLDER --recursive

In [None]:
# SageMaker session object plus the default S3 bucket that belongs to it.
sess = sagemaker.Session()
bucket = sess.default_bucket()

# IAM execution role of the current SageMaker notebook/environment.
role = sagemaker.get_execution_role()

# AWS region of the default boto3 session.
region = boto3.Session().region_name

# Account ID of the caller, looked up through STS.
sts_client = boto3.client("sts")
caller_identity = sts_client.get_caller_identity()
account_id = caller_identity.get("Account")

# Low-level SageMaker client bound to the region resolved above.
sm = boto3.Session().client(
    service_name="sagemaker",
    region_name=region,
)

## Creating Athena Schema

In [None]:
# Open the Athena connection; query results are written to the S3 staging directory.
conn = connect(s3_staging_dir=STAGE_DIR, region_name=region)

# DDL that creates the database named by `DATABASE` unless it already exists.
db_create_statement = "CREATE DATABASE IF NOT EXISTS {}".format(DATABASE)

# Send the statement to Athena through pandas using the connection above.
pd.read_sql(sql=db_create_statement, con=conn)

In [None]:
# Ask Athena for the list of existing databases and load it into a DataFrame.
show_db_statement = "SHOW DATABASES"
df_show = pd.read_sql(sql=show_db_statement, con=conn)

# Show at most the first five rows of the result.
df_show.head(n=5)

## Registering Data with Athena

### Creating Employee Data Table

In [None]:
# DDL for the external employee table. The database, table name and S3 location
# are interpolated from DATABASE, EMPLOYEE_TABLE and DATA_PATH; the first line of
# each file is skipped as a header (skip.header.line.count = 1).
CREATE_STATEMENT = f"""
CREATE EXTERNAL TABLE IF NOT EXISTS {DATABASE}.{EMPLOYEE_TABLE} (
    employee_id INT,
    age INT,
    gender STRING,
    years_at_company INT,
    job_role STRING,
    monthly_income INT,
    work_life_balance STRING,
    job_satisfaction STRING,
    performance_rating STRING,
    number_of_promotions INT,
    distance_from_home INT,
    education_level STRING,
    marital_status STRING,
    number_of_dependents INT,
    job_level STRING,
    company_size STRING,
    company_tenure INT,
    remote_work STRING,
    leadership_opportunities STRING,
    innovation_opportunities STRING,
    company_reputation STRING,
    employee_recognition STRING,
    attrition STRING
)
ROW FORMAT DELIMITED
FIELDS TERMINATED BY ','
LINES TERMINATED BY '\\n'
LOCATION '{DATA_PATH}'
TBLPROPERTIES ('skip.header.line.count'='1')
"""

# Print the rendered statement to confirm the interpolation before running it.
print(CREATE_STATEMENT)

In [None]:
# Run the CREATE_STATEMENT DDL against Athena to register the employee table.
pd.read_sql(sql=CREATE_STATEMENT, con=conn)

### Check Tables and Ensure Correctness

If everything runs correctly, you should see tab_name = employee_table

In [None]:
# List the tables registered in DATABASE.
check_table_statement = f"SHOW TABLES in {DATABASE}"

df_show = pd.read_sql(sql=check_table_statement, con=conn)
df_show

In [None]:
# Sanity check: the employee table is expected to contain 74,498 rows.
count_records = "SELECT COUNT(*) FROM {}.{}".format(DATABASE, EMPLOYEE_TABLE)
count_show = pd.read_sql(sql=count_records, con=conn)
count_show

## Release Resources

In [None]:
%%html

<!-- Shows a notice, then programmatically clicks a hidden button whose
     data-commandlinker-command attribute is "kernelmenu:shutdown".
     NOTE(review): presumably the JupyterLab / SageMaker Studio front end maps
     that attribute to its kernel-shutdown command; confirm in the target environment. -->
<p><b>Shutting down your kernel for this notebook to release resources.</b></p>
<button class="sm-command-button" data-commandlinker-command="kernelmenu:shutdown" style="display:none;">Shutdown Kernel</button>
        
<script>
// Click the first element with class "sm-command-button"; any error
// (for example, no such element) is swallowed so the cell never fails.
try {
    els = document.getElementsByClassName("sm-command-button");
    els[0].click();
}
catch(err) {
    // NoOp
}    
</script>

In [None]:
%%javascript

try {
    Jupyter.notebook.save_checkpoint();
    Jupyter.notebook.session.delete();
}
catch(err) {
    // NoOp
}