In [1]:
from datetime import datetime, timedelta
import json
import logging
# Module-wide logger used by every function below. basicConfig() installs a
# default handler on the root logger (only if it has none yet) so records from
# the 'MTA' logger are actually emitted; the logger itself is set to INFO.
logging.basicConfig()
logger = logging.getLogger('MTA')
logger.setLevel(logging.INFO)

import boto3
import pandas as pd
import psycopg2
import base64


def get_redshift_secret():
    """Fetch the Redshift connection secret from AWS Secrets Manager.

    Reads the module-level globals AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY,
    REGION_NAME and SECRET_NAME (assigned by start_func) to build the client.

    Returns:
        dict: the secret payload parsed as JSON. get_redshift_connection reads
        the keys DATABASE, HOST, PORT, USER and PASSWORD from it.

    Raises:
        Any exception from the Secrets Manager call is logged and re-raised;
        json.JSONDecodeError if the payload is not valid JSON.
    """
    session = boto3.session.Session()
    client = session.client(
        service_name='secretsmanager',
        region_name=REGION_NAME,
        aws_access_key_id=AWS_ACCESS_KEY_ID,
        aws_secret_access_key=AWS_SECRET_ACCESS_KEY,
    )

    try:
        response = client.get_secret_value(SecretId=SECRET_NAME)
    except Exception:
        logger.error("ERROR in getting redshift secrets")
        raise

    if 'SecretString' in response:
        secret = response['SecretString']
    else:
        # Previously the decoded value was stored in a throwaway variable and
        # `secret` was left unbound, so binary secrets crashed with
        # UnboundLocalError. json.loads accepts the decoded bytes directly.
        # NOTE(review): botocore may already return SecretBinary as raw bytes;
        # confirm whether this b64decode (kept from the original) is required.
        secret = base64.b64decode(response['SecretBinary'])
    return json.loads(secret)


def get_redshift_connection():
    """Open a new psycopg2 connection to Redshift and hand it to the caller.

    Credentials are taken from the dict returned by get_redshift_secret()
    (keys DATABASE, HOST, PORT, USER, PASSWORD). Any failure while fetching
    the secret or connecting is logged and then re-raised unchanged.
    """
    try:
        creds = get_redshift_secret()
        connect_kwargs = {
            'dbname': creds['DATABASE'],
            'host': creds['HOST'],
            'port': creds['PORT'],
            'user': creds['USER'],
            'password': creds['PASSWORD'],
        }
        return psycopg2.connect(**connect_kwargs)
    except Exception as exc:
        logger.error(f'Error getting redshift connection details: Exception: {str(exc)}')
        raise

        
def repair_table(client, date, poll_interval=1.0):
    """Add one (advertiser_name, event_date) partition to the Athena table.

    Runs ``ALTER TABLE <db>.<table> ADD IF NOT EXISTS PARTITION (...)`` and
    blocks until Athena reports a terminal state. Settings come from the
    module-level globals assigned by start_func: ATHENA_DATABASE, ATHENA_TABLE,
    ATHENA_STATUS_OUTPUT_LOCATION, AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY
    and REGION_NAME.

    Args:
        client: value for the ``advertiser_name`` partition key.
        date: value for the ``event_date`` partition key (start_func passes
            a ``YYYY-MM-DD`` string).
        poll_interval: seconds to sleep between status polls.

    Raises:
        Exception: if the query finishes in the FAILED or CANCELLED state.
    """
    import time  # local import keeps this block self-contained

    athena = boto3.client('athena', region_name=REGION_NAME, aws_access_key_id=AWS_ACCESS_KEY_ID, aws_secret_access_key=AWS_SECRET_ACCESS_KEY)

    database_name = ATHENA_DATABASE
    table_name = ATHENA_TABLE

    # NOTE(review): client/date are interpolated straight into the DDL; a value
    # containing a single quote would break the statement. They come from
    # Redshift in start_func -- confirm they are trusted.
    table_repair_query = f"ALTER TABLE {database_name}.{table_name} ADD IF NOT EXISTS PARTITION (advertiser_name = '{client}', event_date = '{date}')"

    response = athena.start_query_execution(
        QueryString=table_repair_query,
        QueryExecutionContext={'Database': database_name},
        # S3 location where Athena writes the (empty) DDL result/status files.
        ResultConfiguration={'OutputLocation': ATHENA_STATUS_OUTPUT_LOCATION},
    )
    query_execution_id = response['QueryExecutionId']
    logger.info(f"Started table repair for table with query execution ID: {query_execution_id}")

    # Poll until a terminal state; sleep between polls instead of spinning so
    # we do not hammer the Athena API (and risk throttling).
    while True:
        execution = athena.get_query_execution(QueryExecutionId=query_execution_id)['QueryExecution']
        status = execution['Status']['State']
        if status in ('SUCCEEDED', 'FAILED', 'CANCELLED'):
            break
        time.sleep(poll_interval)

    if status == 'SUCCEEDED':
        logger.debug('Athena query ran successfully')
        return

    # A cancelled query also leaves the partition unregistered, so treat it
    # as a failure rather than returning silently.
    reason = execution['Status'].get('StateChangeReason', 'no reason given')
    raise Exception(f'Query from Athena failed (state={status}): {reason}')
        
        
def prepare_input(date_delta, lookback_days, media_type, branch, kpi_table, channel_table
    ):
    """Build the per-client KPI input frame for MTA processing.

    Args:
        date_delta: number of days before today (local time) of the pixel date.
        lookback_days: length in days of the impression window, counted back
            from the pixel date.
        media_type: default media type written to every row.
        branch: branch label written to every row.
        kpi_table: DataFrame with ``clientkey_map`` and ``kpi_name_map`` columns.
        channel_table: DataFrame with a ``clientkey_map`` column and the
            ``enabled_mta_*`` channel flags.

    Returns:
        DataFrame: ``kpi_table`` left-joined with ``channel_table`` on
        ``clientkey_map``, plus constant ``impression_dts`` / ``pixel_dts``
        (``'YYYY-MM-DD 00:00:00'`` strings), ``media_type``, ``lookback_days``
        and ``branch`` columns; mapping columns are renamed to ``clientkey``,
        ``kpi`` and ``include_channel_{digital,ott,audio}``.
    """
    logger.info('Preparing client KPI input')

    # Pixel day is `date_delta` days before today; the impression window opens
    # `lookback_days` earlier. Both are rendered as midnight timestamps.
    pixel_day = (datetime.now() - timedelta(days=date_delta)).date()
    impression_day = pixel_day - timedelta(days=lookback_days)
    pixel_dts = f'{pixel_day.isoformat()} 00:00:00'
    impression_dts = f'{impression_day.isoformat()} 00:00:00'

    client_kpi_input = kpi_table.join(channel_table.set_index('clientkey_map'), on='clientkey_map')

    # Scalar assignment broadcasts to every row (and to zero rows). The old
    # one-row DataFrame built against the full index raised ValueError
    # whenever the joined frame did not have exactly one row.
    client_kpi_input['impression_dts'] = impression_dts
    client_kpi_input['pixel_dts'] = pixel_dts
    client_kpi_input['media_type'] = media_type
    client_kpi_input['lookback_days'] = lookback_days
    client_kpi_input['branch'] = branch

    # Clients with the combined channel disabled fall back to a single media
    # type. Rules are applied in order and later ones overwrite earlier ones,
    # so digital beats audio beats ott when several flags are enabled.
    # Explicit == comparisons are elementwise and treat missing values as False.
    channel_off = client_kpi_input['enabled_mta_channel'] == False  # noqa: E712
    for flag_column, fallback_type in (
        ('enabled_mta_ott', 'ott'),
        ('enabled_mta_audio', 'audio'),
        ('enabled_mta_digital', 'digital'),
    ):
        enabled = client_kpi_input[flag_column] == True  # noqa: E712
        client_kpi_input.loc[channel_off & enabled, 'media_type'] = fallback_type

    return client_kpi_input.rename(columns={
        'clientkey_map': 'clientkey',
        'kpi_name_map': 'kpi',
        'enabled_mta_digital': 'include_channel_digital',
        'enabled_mta_ott': 'include_channel_ott',
        'enabled_mta_audio': 'include_channel_audio',
    })


def start_func(df, aws_access_key_id, aws_secret_access_key, region_name, secret_name, 
        date_delta, lookback_days, media_type, 
        branch, athena_database, athena_status_output_location
    ):
    """Build the MTA client/KPI input and register the matching Athena partitions.

    Stores the AWS / Athena settings in module-level globals (read by
    get_redshift_secret and repair_table), opens a Redshift connection kept in
    the module-level ``conn`` / ``cursor`` globals, loads KPI mappings and
    channel settings, builds the input via prepare_input, drops excluded
    clients, and adds the pixel-date partition of the Athena ``pixels`` table
    for every remaining client.

    Args:
        df: not used by this function; kept for interface compatibility.
        aws_access_key_id, aws_secret_access_key, region_name: AWS credentials
            and region used for Secrets Manager and Athena.
        secret_name: Secrets Manager id holding the Redshift credentials.
        date_delta: days before today of the pixel date (int-convertible).
        lookback_days: impression look-back window in days (int-convertible).
        media_type: default media type for every row.
        branch: must be "dev" or "master".
        athena_database: Athena database containing the ``pixels`` table.
        athena_status_output_location: S3 location for Athena query results.

    Returns:
        DataFrame: the filtered client/KPI input (see prepare_input); empty if
        no clients remain after filtering.

    Raises:
        ValueError: if ``branch`` is not "dev" or "master"; errors from int()
        on the day counts propagate as well.
    """
    athena_table = 'pixels'
    global AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, REGION_NAME, SECRET_NAME, conn, cursor, ATHENA_DATABASE, ATHENA_TABLE, ATHENA_STATUS_OUTPUT_LOCATION
    AWS_ACCESS_KEY_ID = aws_access_key_id
    AWS_SECRET_ACCESS_KEY = aws_secret_access_key
    REGION_NAME = region_name
    SECRET_NAME = secret_name
    ATHENA_DATABASE = athena_database
    ATHENA_TABLE = athena_table
    ATHENA_STATUS_OUTPUT_LOCATION = athena_status_output_location

    # Validate cheap inputs before opening a Redshift connection and running
    # queries, so a bad call fails fast.
    if str(branch) not in ["dev", "master"]:
        logger.error("ERROR: Invalid branch (dev, master)")
        raise ValueError('Invalid branch (dev, master)')
    date_delta = int(date_delta)
    lookback_days = int(lookback_days)

    # conn/cursor are module globals on purpose (other cells may reuse them),
    # so the connection is intentionally not closed here.
    conn = get_redshift_connection()
    cursor = conn.cursor()
    logger.info('Fetching data from tables')

    query1 = """SELECT
        DISTINCT(k.clientkey_map),
        k.kpi_name_map
        FROM portal_global_settings.kpi_mappings k
        LEFT JOIN portal_global_settings.client_mappings c ON(c.clientkey_map = k.clientkey_map)
        WHERE k.media_type IN('ott')
        AND k.enabled = true
        AND c.client_status != 'inactive'
        GROUP BY 1,2
        ORDER BY 1,2;"""
    kpi_list = pd.read_sql(query1, conn)
    logger.info(f"Fetched client-kpi mappings for {len(kpi_list['clientkey_map'].unique())} clients")

    query2 = """SELECT  DISTINCT clientkey_map, enabled_mta_channel, enabled_mta_ott, enabled_mta_audio, enabled_mta_digital, enabled_mta 
        FROM portal_global_settings.client_mappings c
        WHERE c.client_status != 'inactive'
    """
    channel_results = pd.read_sql(query2, conn)
    logger.info(f"Fetched channel settings for {len(channel_results['clientkey_map'].unique())} clients")

    results = prepare_input(date_delta, lookback_days, media_type, branch, kpi_list, channel_results)

    # Client-specific overrides: drop excluded clients, and never run the
    # digital channel for overstock.
    results = results[~results['clientkey'].isin(('demo', 'varsitytutors', 'singlecare'))].reset_index(drop=True)
    results.loc[results['clientkey'] == 'overstock', 'include_channel_digital'] = False
    # Logged after filtering so the count reflects what is actually sent.
    logger.info(f"Sending final input data for MTA processing for {len(results['clientkey'].unique())} clients")

    if results.empty:
        # Previously results.loc[0, ...] raised KeyError here.
        logger.warning('No clients left after filtering; skipping Athena partition repair')
        return results

    # Every row carries the same pixel timestamp; keep just the YYYY-MM-DD part.
    pixel_date = results.loc[0, 'pixel_dts'][:10]
    for client in results['clientkey'].unique():
        repair_table(client, pixel_date)
    return results