# In [1]:
import json
from collections import defaultdict
import itertools
from io import BytesIO
from datetime import datetime, timedelta
import time
from time import perf_counter
import typing
import re
import logging
logging.basicConfig()
logger = logging.getLogger('MTA')
logging_level = logging.INFO
logger.setLevel(logging_level)

import awswrangler as wr
import boto3
import numpy as np
import psycopg2
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
import uuid


def func_time(func):
    """Decorator that logs how long each call to ``func`` takes.

    A logger named ``FUNC_<function name>`` is created when the decorator is
    applied. Its level is the module-level ``logging_level`` when that global
    exists, otherwise ``logging.INFO``. The timing message is emitted at DEBUG
    level and only when the wrapped call returns normally.

    Args:
        func: The callable to wrap.

    Returns:
        A wrapper that forwards all positional and keyword arguments to
        ``func`` and returns its result unchanged. The wrapper keeps
        ``func``'s ``__name__`` and ``__doc__``.
    """
    # Local import: keeps this block self-contained without touching the
    # file-level import section.
    from functools import wraps

    logger = logging.getLogger(f'FUNC_{func.__name__}')
    if 'logging_level' in globals():
        logger.setLevel(logging_level)
    else:
        logger.setLevel(logging.INFO)

    @wraps(func)  # preserve the wrapped function's metadata for introspection
    def decorated(*args, **kwargs):
        func_start_time = perf_counter()
        results = func(*args, **kwargs)
        logger.debug(f'Finished running; {func.__name__}. Took {round(perf_counter() - func_start_time, 3)} seconds')
        return results
    return decorated


class RedshiftConfig:
    """Holds AWS settings and a lazily refreshed Redshift connection/cursor.

    The database credentials are read from the AWS Secrets Manager secret
    named by ``REDSHIFT_MTA_SECRET_NAME``; the secret string is parsed as
    JSON with the keys HOST, PORT, DATABASE, USER and PASSWORD.
    """

    def __init__(self, REGION_NAME, AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, REDSHIFT_MTA_SECRET_NAME):
        """Store the AWS settings and open the first connection immediately."""
        self.REGION_NAME = REGION_NAME
        self.AWS_ACCESS_KEY_ID = AWS_ACCESS_KEY_ID
        self.AWS_SECRET_ACCESS_KEY = AWS_SECRET_ACCESS_KEY
        self.REDSHIFT_MTA_SECRET_NAME = REDSHIFT_MTA_SECRET_NAME
        self.connection = self.create_redshift_connection()
        # The cursor is created on first use by get_cursor().
        self.cursor = None

    def create_redshift_connection(self):
        """Fetch the secret and return a new autocommit psycopg2 connection."""
        secrets_client = boto3.session.Session().client(
            service_name="secretsmanager",
            region_name=self.REGION_NAME,
            aws_access_key_id=self.AWS_ACCESS_KEY_ID,
            aws_secret_access_key=self.AWS_SECRET_ACCESS_KEY
        )
        secret_payload = secrets_client.get_secret_value(SecretId=self.REDSHIFT_MTA_SECRET_NAME)
        creds = json.loads(secret_payload["SecretString"])
        connect_kwargs = {
            "host": creds["HOST"],
            "port": creds["PORT"],
            "database": creds["DATABASE"],
            "user": creds["USER"],
            "password": creds["PASSWORD"],
        }
        new_connection = psycopg2.connect(**connect_kwargs)
        new_connection.autocommit = True
        return new_connection

    def get_connection(self):
        """Return the stored connection, replacing it first if it is closed."""
        if not self.connection.closed:
            return self.connection
        self.connection = self.create_redshift_connection()
        return self.connection

    def get_cursor(self):
        """Return the stored cursor, creating a new one if missing or closed.

        When the current connection rejects the cursor request with
        ``psycopg2.InterfaceError``, the connection is refreshed through
        ``get_connection()`` and the cursor request is retried once.
        """
        cursor_usable = self.cursor and not self.cursor.closed
        if cursor_usable:
            return self.cursor
        try:
            self.cursor = self.connection.cursor()
        except psycopg2.InterfaceError:
            self.connection = self.get_connection()
            self.cursor = self.connection.cursor()
        return self.cursor


@func_time
def run_redshift_query(query):
    """Execute ``query`` on Redshift and return the rows as a DataFrame.

    The connection comes from the module-level ``redshift_config`` object
    via its ``get_connection()`` method. The query text is logged at DEBUG
    level before it is handed to ``pd.read_sql``.
    """
    connection = redshift_config.get_connection()
    logger.debug('Starting Redshift query execution')
    logger.debug(f'Redshift query:\n{query}')
    return pd.read_sql(query, connection)


def get_athena_connection(REGION_NAME, AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY):
    """Create and return a boto3 Athena client for the given region and keys."""
    client_options = {
        'region_name': REGION_NAME,
        'aws_access_key_id': AWS_ACCESS_KEY_ID,
        'aws_secret_access_key': AWS_SECRET_ACCESS_KEY,
    }
    return boto3.client('athena', **client_options)


def format_athena_results(query_results):
    """Convert an Athena ``get_query_results`` style payload to a DataFrame.

    Column names are taken from ``ResultSetMetadata.ColumnInfo``. The first
    entry of ``Rows`` is skipped; every remaining row becomes one DataFrame
    row. Cells without a ``VarCharValue`` key become ``None``.
    """
    result_set = query_results['ResultSet']
    header = [info['Name'] for info in result_set['ResultSetMetadata']['ColumnInfo']]
    body = [
        [cell.get('VarCharValue') for cell in record['Data']]
        for record in result_set['Rows'][1:]
    ]
    return pd.DataFrame(body, columns=header)

@func_time
def run_athena_query(query, only_metadata=False, poll_interval_seconds=0.5):
    """Run ``query`` on Athena, wait for it to finish and return the result.

    The query is started through the module-level ``ATHENA_CLIENT`` in
    ``ATHENA_DATABASE`` (workgroup ``juice_data``) with
    ``ATHENA_STATUS_OUTPUT_LOCATION`` as the output location. Its state is
    polled until it is SUCCEEDED, FAILED or CANCELLED.

    Args:
        query: SQL text to execute.
        only_metadata: When true, return the raw ``get_query_results``
            response instead of downloading the result CSV. (Previously this
            branch fetched the response but dropped it and returned None.)
        poll_interval_seconds: Pause between status polls, so the loop does
            not call the Athena API back-to-back while the query runs.

    Returns:
        The ``get_query_results`` response when ``only_metadata`` is true,
        otherwise a DataFrame read from ``<query id>.csv`` in
        ``S3_TARGET_BUCKET_PARQUET`` under the second-to-last path segment of
        ``ATHENA_STATUS_OUTPUT_LOCATION``.

    Raises:
        Exception: If the query ends in FAILED or CANCELLED state.
    """
    global S3, ATHENA_CLIENT, ATHENA_DATABASE, ATHENA_STATUS_OUTPUT_LOCATION, S3_TARGET_BUCKET_PARQUET
    logger.debug(f'Starting Athena query execution')
    logger.debug(f'Athena query:\n{query}')
    response = ATHENA_CLIENT.start_query_execution(
        QueryString=query,
        QueryExecutionContext={'Database': ATHENA_DATABASE},
        ResultConfiguration={'OutputLocation': ATHENA_STATUS_OUTPUT_LOCATION},
        WorkGroup='juice_data'
    )

    # Get the query execution ID
    query_execution_id = response['QueryExecutionId']

    # Wait for the query to reach a terminal state
    while True:
        query_status = ATHENA_CLIENT.get_query_execution(QueryExecutionId=query_execution_id)
        execution_status = query_status['QueryExecution']['Status']
        status = execution_status['State']

        if status in ('SUCCEEDED', 'FAILED', 'CANCELLED'):
            break
        time.sleep(poll_interval_seconds)

    if status != 'SUCCEEDED':
        # Surface both failure and cancellation instead of returning None.
        reason = execution_status.get('StateChangeReason', 'no reason reported')
        raise Exception(f'Query from Athena {status.lower()}: {reason}')

    if only_metadata:
        results = ATHENA_CLIENT.get_query_results(QueryExecutionId=query_execution_id)
        logger.debug(f'Athena query ran successfully')
        return results

    # Athena writes the result set as <query id>.csv under the output prefix.
    result_key = f"{ATHENA_STATUS_OUTPUT_LOCATION.split('/')[-2]}/{query_execution_id}.csv"
    s3_response_object = S3.get_object(Bucket=S3_TARGET_BUCKET_PARQUET, Key=result_key)
    object_content = s3_response_object['Body'].read()
    results = pd.read_csv(BytesIO(object_content))
    logger.debug(f'Athena query ran successfully')
    return results
        

        
############################################
############### Shapley Code ###############
############################################

# Shapley Functions
def power_set(List):
    """Return every non-empty combination of ``List`` as a list of lists.

    Combinations are ordered by size (1 up to ``len(List)``) and, within a
    size, in the order produced by ``itertools.combinations``.
    """
    combos = []
    for size in range(1, len(List) + 1):
        for members in itertools.combinations(List, size):
            combos.append(list(members))
    return combos


# return all possible subsets from the channels
def subsets(s):
    '''
    Return the keys of all non-empty subsets of a collection of channels.

    input :
            - s: a collection of channel names.

    A collection with exactly one element is returned as-is. Otherwise each
    subset is turned into a key by sorting its members and joining them with
    commas; keys are ordered by subset size.
    '''
    if len(s) == 1:
        return s
    keys = []
    for size in range(1, len(s) + 1):
        for members in itertools.combinations(s, size):
            keys.append(",".join(sorted(members)))
    return keys


# compute the worth of each coalition
def v_function(A, C_values):
    '''
    Compute the worth of a coalition of channels.

    inputs:
            - A : a coalition of channels.
            - C_values : a dictionary mapping a subset key (as produced by
            subsets()) to the number of conversions that subset yielded.

    The worth is the sum of the C_values entries for every subset of A;
    subsets with no entry contribute nothing.
    '''
    total_worth = 0
    for key in subsets(A):
        if key not in C_values:
            continue
        total_worth += C_values[key]
    return total_worth


# Cache of factorials already computed, keyed by n.
FACTORIAL_DICT = {}
def factorial(n):
    """Return n! using the module-level ``FACTORIAL_DICT`` as a memo table.

    Calling with 0 stores and returns 1. Any other value is looked up in the
    cache and, when missing, computed recursively from ``factorial(n - 1)``
    and stored.
    """
    if n == 0:
        FACTORIAL_DICT[n] = 1
        return 1
    cached = FACTORIAL_DICT.get(n)
    if cached is None:
        cached = FACTORIAL_DICT[n] = factorial(n - 1) * n
    return cached


# calculate shapley value
def calculate_shapley(df, channel_name, conv_name):
    '''
    Return the Shapley value of every single channel found in df.

            - df: A dataframe with the two columns: [channel_name, conv_name].
            The channel column holds the comma-separated channel(s) associated
            with the conversions and the conversion column holds their count.
            The dataframe is only read; it is no longer modified (earlier
            versions added a 'channels' column to the caller's frame).
            - channel_name: A string that is the name of the channel column
            - conv_name: A string that is the name of the column with conversions
            **Make sure that each value in the channel column is in alphabetical order.
            Email,PPC and PPC,Email are the same in regards to this analysis and
            should be combined under Email,PPC.

    Returns a defaultdict mapping each single channel to its Shapley value.
    Note that the work grows with the number of channel combinations
    (power_set of the single channels).
    '''
    # casting the subset into dict, and getting the unique single channels
    # (rows whose key contains no comma), in order of first appearance
    c_values = df.set_index(channel_name).to_dict()[conv_name]
    channels = list(dict.fromkeys(
        subset for subset in df[channel_name] if len(subset.split(",")) == 1
    ))

    v_values = {}
    for A in power_set(channels):  # generate all possible channel combination
        v_values[','.join(sorted(A))] = v_function(A, c_values)
    n = len(channels)  # no. of channels
    shapley_values = defaultdict(int)

    for channel in channels:
        for A, worth_of_A in v_values.items():
            members = A.split(",")
            if channel in members:
                continue
            cardinal_A = len(members)
            A_with_channel = ",".join(sorted(members + [channel]))
            weight = (factorial(cardinal_A)*factorial(n-cardinal_A-1)/factorial(n))  # Weight = |S|!(n-|S|-1)!/n!
            contrib = (v_values[A_with_channel]-worth_of_A)  # Marginal contribution = v(S U {i})-v(S)
            shapley_values[channel] += weight * contrib
        # Add the term corresponding to the empty set
        shapley_values[channel] += v_values[channel]/n

    return shapley_values


def get_ips_for_dimension(dimension, adserver_df, pixel_df):
    """Return the distinct (ip_address, dimension) pairs seen before an event.

    ``pixel_df`` is left-joined to ``adserver_df`` on ``ip_address``. Only
    rows whose ``event_timestamp`` is later than the (datetime-converted)
    ``adserver_timestamp`` are kept. The result has the columns
    ``ip_address``, ``dimension`` and a constant ``converted`` = 1, is sorted
    by ``ip_address`` then ``dimension`` and has duplicate rows removed.
    """
    joined = pixel_df.merge(
        adserver_df,
        how="left",
        left_on=["ip_address"],
        right_on=["ip_address"],
    )
    joined['adserver_timestamp'] = pd.to_datetime(joined['adserver_timestamp'])

    # Keep only ad exposures that happened strictly before the pixel event.
    exposed_before_event = joined["event_timestamp"] > joined["adserver_timestamp"]
    pairs = joined.loc[exposed_before_event, ["ip_address", dimension]]
    del joined

    result = pairs.sort_values(by=["ip_address", dimension])
    del pairs
    result["converted"] = 1
    return result.drop_duplicates()


def mta_process(dimension, adserver_df, pixel_df):
    """Compute Shapley values for ``dimension`` from adserver and pixel data.

    Steps:
      1. ``get_ips_for_dimension`` yields the distinct (ip, dimension) pairs.
      2. Per ip address, the dimension values are joined into one
         comma-separated key with ``converted`` = 1.
      3. One dummy row per distinct dimension value is appended with a tiny
         conversion weight (0.00001), so every single value has an entry of
         its own even when it only ever appears in combinations.
      4. Conversions are summed per key and passed to ``calculate_shapley``.

    Returns:
        The mapping returned by ``calculate_shapley``.

    Raises:
        Exception: Any error is logged and re-raised unchanged.
    """
    try:
        sorted_df = get_ips_for_dimension(dimension, adserver_df, pixel_df)
        # The string name 'max' replaces the builtin, which pandas deprecates
        # as an aggregator; the aggregation result is the same.
        grouped_ip_df = sorted_df.groupby(['ip_address'], as_index=False).agg(
            {dimension: lambda x: ','.join(map(str, x.unique())), 'converted': 'max'}
        )

        # create dummies for those channels without singular factorials
        dummy_list = list(set(sorted_df[dimension]))
        del sorted_df
        dummies_df = pd.DataFrame(
            list(zip(range(len(dummy_list)), dummy_list)),
            columns=('ip_address', dimension),
        )
        dummies_df['converted'] = 0.00001
        added_dummies_df = pd.concat([grouped_ip_df, dummies_df])
        del grouped_ip_df
        del dummies_df

        # A single groupby/sum is enough; the former second identical
        # groupby over the already-grouped frame was a no-op.
        grouped_kpi_df = added_dummies_df.groupby([dimension], as_index=False)['converted'].sum()
        del added_dummies_df

        # return shapley dictionary for upload
        return calculate_shapley(grouped_kpi_df, dimension, 'converted')
    except Exception as e:
        logger.error(f"ERROR: MTA.Process Failed: {e}")
        raise


def process_mta_by_channel_and_event_date(channel, adserver_df, pixel_df):
    """Run ``process_mta_by_event_date`` on the adserver rows of one channel.

    Rows are selected by a case-insensitive match of ``channel_name``
    against ``channel``.
    """
    lowered_names = adserver_df['channel_name'].str.lower()
    channel_rows = adserver_df[lowered_names == channel.lower()]
    return process_mta_by_event_date(channel_rows, pixel_df)

def process_mta_by_event_date(adserver_df, pixel_df):
    """Return each event date's share of the total Shapley value.

    Runs ``mta_process`` on the ``event_date`` dimension, ignores the key
    ``"nan"`` and divides every remaining value by their sum. An empty
    ``adserver_df`` is logged as an error and yields an empty dict.
    """
    if adserver_df.empty:
        logger.error('No click impression records returned from Adserver')
        return {}

    # Run MTA Attribution for Impression Dates
    shapley_dict = mta_process("event_date", adserver_df, pixel_df)
    contributions = {
        date_key: float(value)
        for date_key, value in shapley_dict.items()
        if date_key != "nan"
    }

    # Plain running total, accumulated in dictionary order.
    shap_total = 0.0
    for value in contributions.values():
        shap_total += value

    return {date_key: value / float(shap_total) for date_key, value in contributions.items()}


############################################
############# Utility Functions ############
############################################
        
        
def upload_results(data, branch, s3_destination):
    """Serialize ``data`` as CSV and upload it to ``S3_TARGET_BUCKET``.

    ``data`` is wrapped in a DataFrame and written without the index as
    UTF-8 CSV into an in-memory buffer, which is uploaded to the key
    ``s3_destination`` through the module-level ``S3`` client. ``branch`` is
    not used here. Errors are logged and re-raised.
    """
    try:
        bucket_name = S3_TARGET_BUCKET
        logger.info(f'Uploading file to S3 to prefix: {s3_destination}')
        csv_buffer = BytesIO()
        frame = pd.DataFrame(data)
        frame.to_csv(csv_buffer, mode='wb', encoding='UTF-8', index=False)
        csv_buffer.seek(0)  # rewind so the upload starts at the first byte
        S3.upload_fileobj(csv_buffer, bucket_name, s3_destination)
    except Exception as e:
        logger.error(f'Error: upload_results(): {e}')
        raise

def upload_results_parquet(data, branch, s3_destination):
    """Write ``data`` as Snappy-compressed Parquet to ``S3_TARGET_BUCKET_PARQUET``.

    Args:
        data: DataFrame converted with ``pa.Table.from_pandas``.
        branch: Not used by this function.
        s3_destination: Object key to write to.

    Raises:
        Exception: Any error is logged (now under this function's own name
            rather than ``upload_results()``) and re-raised.
    """
    global S3, S3_TARGET_BUCKET_PARQUET
    try:
        target_bucket = S3_TARGET_BUCKET_PARQUET
        logger.info(f'Uploading file to S3 to prefix: {s3_destination}')
        table = pa.Table.from_pandas(data)
        buf = BytesIO()
        pq.write_table(table, buf, compression='SNAPPY')
        S3.put_object(Bucket=target_bucket, Key=s3_destination, Body=buf.getvalue())
    except Exception as e:
        logger.error(f'Error: upload_results_parquet(): {e}')
        raise



############################################
############# Channel Pipeline #############
############################################

def parse_timestamp(time_stamp):
    """Parse ``YYYY-MM-DD HH:MM:SS[.fraction]`` into a ``datetime``.

    Anything after the first ``.`` (the fractional seconds) is dropped.

    Returns:
        The parsed ``datetime``, or an empty string when ``time_stamp`` is
        not a string or does not match the expected format.
    """
    try:
        whole_seconds = time_stamp.split(".")[0]
        return datetime.strptime(whole_seconds, "%Y-%m-%d %H:%M:%S")
    except (AttributeError, TypeError, ValueError):
        # Non-string input or unparsable text: keep the "" fallback, but no
        # longer swallow unrelated exceptions such as KeyboardInterrupt.
        return ""


def load_display_df(clientkey: str, impression_date: str, pixel_date: str) -> pd.DataFrame:
    """Load CM360 activity rows for the ``constantcontact`` client from Athena.

    Any other ``clientkey`` yields an empty DataFrame. The query selects
    activity rows whose ``interaction_timestamp`` date lies between
    ``impression_date`` and ``pixel_date`` (inclusive).

    Columns added to the query result:
      - ``event_date``: date part of the interaction timestamp.
      - ``adserver_timestamp``: the interaction timestamp as datetime.
      - ``channel_name``: ``campaign_id`` mapped through a fixed lookup.
      - ``soid_front``: index just after ``u5=`` in the lowercased
        ``other_data``, or -1 when the marker is absent.
      - ``soid``: up to 13 characters following ``u5=``; an empty string when
        the marker is absent or ``other_data`` is missing. (Previously the -1
        sentinel was used as a slice start, which produced the last character
        for short strings, and a missing value raised inside the loop.)
    """
    global ATHENA_DATABASE
    channel_dict = {25968151: 'display', 26065320: 'Dart Search', 29910233: 'PPC Display'}

    if clientkey.lower() != 'constantcontact':
        return pd.DataFrame()
    table_client_alias = 'ctct'

    query = f'''SELECT
        interaction_timestamp as event_date,
        interaction_timestamp as adserver_timestamp,
        campaign_id,
        other_data
    FROM {ATHENA_DATABASE}.bigquery_export_{table_client_alias}_cm360_activity
    WHERE cast(interaction_timestamp as date) between cast('{impression_date}' as date) and cast('{pixel_date}' as date)
        AND activity_id=11247128 and conversion_id > 0 
    '''

    display = run_athena_query(query)
    display['event_date'] = [datetime.fromisoformat(i).date() for i in display['event_date']]
    display['adserver_timestamp'] = pd.to_datetime(display['adserver_timestamp'])

    display['channel_name'] = display['campaign_id'].map(channel_dict)

    marker = 'u5='
    soid_length = 13
    lowered_data = display['other_data'].str.lower()
    marker_pos = lowered_data.str.find(marker)
    display['soid_front'] = np.where(marker_pos > -1, marker_pos + len(marker), -1)

    def soid_from(text):
        # Missing values come through as non-strings; treat them as "no soid".
        if not isinstance(text, str):
            return ''
        found_at = text.find(marker)
        if found_at == -1:
            return ''
        start = found_at + len(marker)
        return text[start:start + soid_length]

    display['soid'] = [soid_from(text) for text in lowered_data]
    return display


def get_display_matches_df(clientkey: str, impression_date: str, pixel_date: str, pixel_df: pd.DataFrame) -> pd.DataFrame:
    """Match display impressions to pixel ip addresses through the soid token.

    Only the ``constantcontact`` client is handled; other clients get an
    empty DataFrame. An empty ``pixel_df`` or an empty display load is
    returned as-is.

    Side effect: ``pixel_df`` gains the columns ``soid_front`` (index just
    after ``soid=`` in the lowercased ``uri_query``, or -1) and ``soid`` (up
    to 13 characters after ``soid=``; an empty string when the marker is
    absent or ``uri_query`` is missing — previously the -1 sentinel was used
    as a slice start, and a missing value raised inside the loop).

    Returns:
        A DataFrame with the columns ``channel_name``, ``ip_address``,
        ``adserver_timestamp`` and ``event_date`` for rows of channel
        ``display`` that matched a pixel row on ``soid``.
    """
    if clientkey.lower() != 'constantcontact':
        return pd.DataFrame()

    if pixel_df.empty:
        return pixel_df

    display_df = load_display_df(clientkey, impression_date, pixel_date)
    if display_df.empty:
        return display_df

    # pull out "soid" from uri_query and append as field to later match display impressions
    marker = 'soid='
    soid_length = 13
    lowered_query = pixel_df['uri_query'].str.lower()
    marker_pos = lowered_query.str.find(marker)
    pixel_df['soid_front'] = np.where(marker_pos > -1, marker_pos + len(marker), -1)

    def soid_from(text):
        # Missing values come through as non-strings; treat them as "no soid".
        if not isinstance(text, str):
            return ''
        found_at = text.find(marker)
        if found_at == -1:
            return ''
        start = found_at + len(marker)
        return text[start:start + soid_length]

    pixel_df['soid'] = [soid_from(text) for text in lowered_query]

    ips = pd.DataFrame(list(zip(pixel_df['ip_address'], pixel_df['soid'])), columns=('ip_address', 'soid'))

    # match display impressions on soid
    # NOTE(review): rows whose soid is '' on both sides also match each other
    # in this merge — confirm whether empty soids should be excluded.
    display_1 = display_df[display_df['channel_name'].isin(['display'])]
    display_2 = pd.merge(display_1, ips, how='left', left_on=['soid'], right_on=['soid'])
    # Keep rows with a non-empty ip; dropna() then removes unmatched rows
    # (and any row with another missing value).
    display1 = display_2[display_2['ip_address'] != ''].dropna()

    return pd.DataFrame(list(zip(display1['channel_name'], display1['ip_address'], display1['adserver_timestamp'], display1['event_date'])), columns=('channel_name', 'ip_address', 'adserver_timestamp', 'event_date'))



def format_records(
        clientkey,
        dimension,
        kpi,
        pixel_date,
        lookback_days,
        shapley_dict,
        s3_destination,
        first_write
    ):
    """Flatten nested Shapley results into a list of output row dicts.

    Args:
        clientkey: Written to ``advertiser_name``.
        dimension: Must be ``"channel_name"``; any other value raises as soon
            as a non-``"nan"`` item is encountered.
        kpi: Written to ``kpi_name``.
        pixel_date: ISO date or datetime string; the part before the first
            space becomes ``event_date``.
        lookback_days: Converted with ``int()``.
        shapley_dict: Mapping ``channel -> {imp_date: kpi_conversions}``. The
            key ``"nan"`` and zero conversions are skipped.
        s3_destination: Written to ``s3_source_file``.
        first_write: Accepted for compatibility; not used.

    Returns:
        A list of dicts with the keys event_date, imp_date, advertiser_name,
        channel_name, dimension, kpi_name, kpi_conversions, lookback_days,
        s3_source_file, ingest_timestamp (current time without microseconds)
        and row_id (a random UUID string).

    Raises:
        Exception: On an invalid dimension or any other error (for example
            an unparsable ``pixel_date``); the error is logged and re-raised.
    """
    try:
        # Fail early on a malformed pixel_date.
        datetime.fromisoformat(pixel_date)
        event_date = pixel_date.split(" ")[0]
        results = dict(shapley_dict)
        csv_records = []

        for item, per_imp_date in results.items():
            if item == "nan":
                continue
            if dimension != "channel_name":
                logger.error(f"ERROR: format_records(): Invalid Dimension: {dimension}")
                raise Exception(f'ERROR: format_records(): Invalid Dimension: {dimension}')
            channel_name = str(item)

            for imp_date in per_imp_date:
                kpi_conversions = per_imp_date[imp_date]
                if kpi_conversions == 0:
                    continue

                csv_records.append({
                    "event_date": event_date,
                    "imp_date": imp_date,
                    "advertiser_name": clientkey,
                    "channel_name": channel_name,
                    "dimension": dimension,
                    "kpi_name": kpi,
                    "kpi_conversions": kpi_conversions,
                    "lookback_days": int(lookback_days),
                    "s3_source_file": s3_destination,
                    # Second precision, same as parsing str(now) without its fraction.
                    "ingest_timestamp": datetime.now().replace(microsecond=0),
                    "row_id": str(uuid.uuid4()),
                })

        return csv_records

    except Exception as e:
        logger.error(f"ERROR: Error in format_records(). Exception: {e}")
        raise


def save_channel_conversions(
    postgres_connection, advertiser_name, event_name, event_date, adserver_df, pixel_df
):
    """Replace the per-channel conversion totals for one advertiser/event/day.

    The totals come from ``get_ips_for_dimension('channel_name', ...)``
    summed per channel. Existing ``public.channel_events`` rows for the same
    (event_date, event_name, advertiser_name) are deleted, then the new rows
    are written with an upsert. Nothing is written when there are no rows.

    The DELETE now passes its values as query parameters instead of
    interpolating them into the SQL text, so names containing quotes no
    longer break the statement.

    Args:
        postgres_connection: Connection used both for the DELETE and by
            ``wr.postgresql.to_sql``.
        advertiser_name: Advertiser the totals belong to.
        event_name: Event (KPI) name.
        event_date: Anything ``pd.to_datetime`` accepts; only the date is kept.
        adserver_df: Adserver rows passed to ``get_ips_for_dimension``.
        pixel_df: Pixel rows passed to ``get_ips_for_dimension``.
    """
    df = get_ips_for_dimension('channel_name', adserver_df, pixel_df)

    if df.empty:
        return

    df = df.groupby('channel_name')["converted"].sum().reset_index()
    if df.empty:
        # No usable channel after grouping: nothing to delete or write.
        return
    df.rename({"channel_name": "channel", "converted": "total"}, axis=1, inplace=True)

    event_day = pd.to_datetime(event_date).date()
    df["event_date"] = event_day
    df["event_name"] = event_name
    df["advertiser_name"] = advertiser_name
    df["updated_at"] = pd.Timestamp.utcnow()

    # delete channel conversions for this group
    # NOTE(review): uses DB-API "%s" placeholders, which psycopg2 and
    # pg8000's default "format" paramstyle accept — confirm for the driver in use.
    with postgres_connection.cursor() as cursor:
        cursor.execute(
            """
            DELETE FROM public.channel_events
            WHERE event_date = %s AND event_name = %s AND advertiser_name = %s
            """,
            (event_day, event_name, advertiser_name),
        )
        postgres_connection.commit()

    wr.postgresql.to_sql(
        df=df,
        con=postgres_connection,
        table="channel_events",
        schema='public',
        use_column_names=True,
        mode="upsert",
        upsert_conflict_columns=[
            "event_date",
            "event_name",
            "advertiser_name",
            "channel",
        ],
    )

    del df

def get_pixel_data(clientkey, kpi, source_channel_pixel, pixel_date, impression_date, dataset_start_date, source_pixel_s3ingest_exclusion):
    """Query Athena for the pixel events of ``kpi`` on ``pixel_date``.

    Selects distinct ``ip_address``/``event_timestamp`` rows (plus
    ``uri_query`` for ``constantcontact``) from ``source_channel_pixel`` for
    the advertiser. Rows from excluded s3 source files and empty ip
    addresses are dropped, as are ip addresses that already had the same
    event after ``impression_date`` and up to the day before ``pixel_date``.
    The module-level ``US_WHERE_COND`` fragment is appended to both filters.
    The client key ``newstwelveny`` is queried as ``news_twelve_ny``.

    Returns:
        The DataFrame returned by ``run_athena_query``.
    """
    logger.info(f"Getting pixel data for {kpi}")
    clientkey = 'news_twelve_ny' if clientkey == 'newstwelveny' else clientkey

    columns = 'DISTINCT ip_address, event_timestamp'
    if clientkey == 'constantcontact':
        columns = 'DISTINCT ip_address, event_timestamp, uri_query'

    import_query = f"""
        SELECT {columns}
            FROM {source_channel_pixel} px 
            WHERE advertiser_name = '{clientkey}'
                AND DATE(event_date) = DATE('{pixel_date}') 
                AND DATE(event_date) >= DATE('{dataset_start_date}')               
                AND event_name = '{kpi}' 
                AND s3_source_file NOT LIKE '{source_pixel_s3ingest_exclusion}%'
                {US_WHERE_COND}
                AND ip_address <> '' AND ip_address NOT IN (
                    SELECT DISTINCT ip_address 
                        FROM {source_channel_pixel} 
                        WHERE advertiser_name = '{clientkey}'
                            AND DATE(event_date) > DATE('{impression_date}') 
                            AND s3_source_file NOT LIKE '{source_pixel_s3ingest_exclusion}%'
                            AND DATE(event_date) <= CAST('{pixel_date}' AS TIMESTAMP) - INTERVAL '1' DAY
                            AND event_name = '{kpi}'
                            {US_WHERE_COND}
                )
            
    """
    return run_athena_query(import_query)

# parse the channel config json
def parse_channel_config(channel_config):
    """Parse the channel configuration JSON string into its settings.

    Args:
        channel_config: JSON text, or ``None``/empty for the defaults.

    Returns:
        A tuple ``(channel_type, channel_keys, channel_length,
        excluded_channels, labels)``:
          - channel_type: ``"channel_type"`` from the config, default ``''``.
          - channel_keys: lowercased ``"channel_keys"``, default
            ``["utm_source", "xtm_source", "pn"]``.
          - channel_length: positive ``"channel_length"``, default 1000.
          - excluded_channels: lowercased names of ``"channel_attributes"``
            entries whose ``"excluded"`` flag is truthy.
          - labels: lowercased name -> label for entries that have a
            ``"label"``.

    Parsing is best-effort: on any error a warning is logged and the values
    collected up to that point (defaults otherwise) are returned.
    """
    channel_type = ''
    channel_keys = ["utm_source", "xtm_source", "pn"]
    channel_length = 1000
    excluded_channels = []
    labels = {}
    if channel_config is None or len(channel_config) == 0:
        return channel_type, channel_keys, channel_length, excluded_channels, labels
    try:
        config_json = json.loads(channel_config)
        if "channel_type" in config_json and len(config_json["channel_type"]) > 0:
            channel_type = config_json["channel_type"]
        if "channel_keys" in config_json and len(config_json["channel_keys"]) > 0:
            channel_keys = [x.lower() for x in config_json["channel_keys"]]
        if "channel_length" in config_json and config_json["channel_length"] > 0:
            channel_length = config_json["channel_length"]
        for attribute in config_json.get("channel_attributes", []):
            if "name" not in attribute:
                continue
            name = attribute["name"].lower()
            if "label" in attribute:
                labels[name] = attribute["label"]
            if "excluded" in attribute and attribute["excluded"]:
                excluded_channels.append(name)
    except Exception as e:
        message = f"exception while parsing channel config: {e}"
        # logger.warn is a deprecated alias of logger.warning.
        logger.warning(message)
    return channel_type, channel_keys, channel_length, excluded_channels, labels

# returns a function that parses the channel name from url
def parse_channel_name(channel_type, channel_keys, channel_length):
    """Build a function that extracts the channel name from a site url.

    Args:
        channel_type: ``'z'`` and ``'p'`` enable the special rules described
            below; any other value uses only the generic lookup.
        channel_keys: Query-parameter names to look for, in priority order.
        channel_length: Maximum length kept for the value of the first key.

    Returns:
        ``get_channel_name(site_url) -> str``. The url is lowercased and the
        text after the last ``?`` (or encoded ``%3f``) is parsed as
        ``key=value`` pairs separated by ``&``. The result is ``''`` when the
        url is missing, has no query string or contains none of the keys.

        Special rules:
          - type ``'z'``: returns ``'email'`` when the second key equals
            ``email`` and the first key's value starts with ``0``.
          - type ``'p'``: when the first key equals ``google``, returns
            ``'google'`` if the second key equals ``cpc``, else ``''``.

        Generic lookup: the value of the first key present is returned. Only
        the value of ``channel_keys[0]`` is truncated to ``channel_length``
        and stripped of underscores and non-alphanumeric characters.
    """
    pattern = re.compile(r'[_\W]+') # matches underscore or non-alphanumeric
    def get_channel_name(site_url):
        if pd.isna(site_url):
            return ''
        query_params = site_url.lower().split('?')
        if len(query_params) == 1: # not found
            query_params = site_url.lower().split('%3f') # encoded '?'
            if len(query_params) == 1: # again not found
                return ''
        query_params = query_params[-1]  # use the last segment because some site_urls are badly formatted (contain two instances of '?')
        kv = {}
        for x in query_params.split('&'):
            tmp = x.split('=')
            k = tmp[0]
            v = '='.join(tmp[1:])
            kv[k] = v
        if channel_type == 'z': # zazzle. expect channel_keys = ['utm_source', 'utm_medium']. utm_source value should begin with 0.
            # startswith() instead of [0]: an empty utm_source value used to
            # raise IndexError here.
            if (len(channel_keys) > 1 and
                channel_keys[1] in kv and kv[channel_keys[1]] == 'email' and
                channel_keys[0] in kv and kv[channel_keys[0]].startswith('0')):
                return 'email'
        if channel_type == 'p': # pepper. expect channel_keys = ['utm_source', 'utm_medium'].
            if len(channel_keys) > 1 and channel_keys[0] in kv and kv[channel_keys[0]] == 'google':
                if channel_keys[1] in kv and kv[channel_keys[1]] == 'cpc':
                    return 'google'
                return ''
        for position, key in enumerate(channel_keys):
            if key not in kv:
                continue
            name = kv[key]
            if position == 0:
                name = name[:channel_length] # truncate
                name = pattern.sub('', name) # remove offending chars
            return name
        return ''
    return get_channel_name

def get_adserver_data(clientkey, kpi, source_channel_pixel, source_ott_adserver, source_audio_adserver, source_digital_adserver, 
                      include_ott, include_audio, include_digital,
                      pixel_date, impression_date, dataset_start_date, source_pixel_s3ingest_exclusion, athena_pixel_table, channel_type, channel_keys,
                      channel_length, excluded_channels, labels):
    """Load impressions and clicks for the ip addresses that converted.

    A single Athena query is built:
      - ``converted_ips``: ip addresses with the ``kpi`` event on
        ``pixel_date`` that did not already have it after
        ``impression_date`` and up to the day before ``pixel_date``.
      - One impression SELECT per enabled adserver source (OTT, Audio,
        Display), restricted to those ip addresses and to timestamps after
        ``impression_date`` and up to ``pixel_date``.
      - One click SELECT from ``source_channel_pixel`` (non-empty
        ``utm_source``, non-organic traffic) with ``site_url`` as the raw
        channel name.

    Click rows then get their channel name parsed from the url
    (``parse_channel_name``); excluded and empty channels are dropped and
    ``labels`` are applied. Impression and click rows are concatenated.

    A source is queried only when its table name is non-empty and its
    ``include_*`` flag ``is True``. The digital source used to be gated on
    ``include_digital is False``, the opposite of the OTT and audio gates and
    of the parameter name; it now follows the same rule.
    NOTE(review): confirm the inverted gate was not a deliberate way to keep
    digital impressions out of the results.

    ``athena_pixel_table`` is accepted for compatibility and not used. The
    client key ``newstwelveny`` is queried as ``news_twelve_ny``.

    Returns:
        DataFrame with the query columns ``is_click``, ``channel_name``,
        ``ip_address``, ``adserver_timestamp`` (as datetime) and
        ``event_date``.
    """
    global US_WHERE_COND

    if clientkey == 'newstwelveny':
        clientkey = 'news_twelve_ny'

    logger.info(f"Getting adserver data for channel_name for {kpi} ({pixel_date})")

    def impression_select(source_table, media_type, channel_label):
        # One SELECT per adserver source; the sources differ only in table,
        # media type and the channel label reported.
        return f"""
            SELECT DISTINCT
                0 as is_click,
                '{channel_label}' as channel_name,
                ads.ip_address,
                ads.event_timestamp as adserver_timestamp
            FROM {source_table} ads
            WHERE LOWER(ads.media_type) = '{media_type}'
                AND ads.advertiser_name = '{clientkey}'
                AND DATE(ads.event_timestamp) > DATE('{impression_date}')
                AND DATE(ads.event_timestamp) <= DATE('{pixel_date}')
                AND DATE(ads.event_timestamp) >= DATE('{dataset_start_date}')
                AND ads.ip_address IN (SELECT DISTINCT ip_address FROM converted_ips)
        """

    impression_sources = (
        (source_ott_adserver, include_ott, 'ott', 'OTT'),
        (source_audio_adserver, include_audio, 'audio', 'Audio'),
        (source_digital_adserver, include_digital, 'digital', 'Display'),
    )
    selects = [
        impression_select(source_table, media_type, channel_label)
        for source_table, enabled, media_type, channel_label in impression_sources
        if source_table is not None and source_table != '' and enabled is True
    ]

    # Click timestamps are stored as text; try the fractional-second format
    # first and fall back to a plain cast.
    click_ts = (
        "COALESCE(TRY(CAST(DATE_PARSE(click.event_timestamp, '%Y-%m-%d %H:%i:%s.%f') AS TIMESTAMP)), "
        "TRY(CAST(click.event_timestamp AS TIMESTAMP)))"
    )
    selects.append(f"""
            SELECT DISTINCT
                1 as is_click,
                click.site_url as channel_name,
                click.ip_address,
                {click_ts} as adserver_timestamp
            FROM {source_channel_pixel} click
            WHERE advertiser_name = '{clientkey}'
                AND click.utm_source <> '' AND click.utm_source IS NOT NULL AND click.traffic_source!='organic'
                AND DATE({click_ts}) > DATE('{impression_date}')
                AND DATE({click_ts}) <= DATE('{pixel_date}')
                AND DATE({click_ts}) >= DATE('{dataset_start_date}')
                {US_WHERE_COND}
                AND click.ip_address IN (SELECT DISTINCT ip_address FROM converted_ips)
        """)
    union_sql = "\n            UNION ALL\n".join(selects)

    import_query = f"""
        WITH converted_ips AS (
            SELECT DISTINCT ip_address
            FROM {source_channel_pixel}
            WHERE advertiser_name = '{clientkey}'
            AND ip_address <> '' 
            AND s3_source_file NOT LIKE '{source_pixel_s3ingest_exclusion}%'
            AND DATE(event_date) = DATE('{pixel_date}') AND DATE(event_date) >= DATE('{dataset_start_date}') AND event_name = '{kpi}'
            {US_WHERE_COND}
            AND ip_address NOT IN (
                SELECT DISTINCT ip_address 
                FROM {source_channel_pixel} 
                WHERE advertiser_name = '{clientkey}'
                    AND DATE(event_date) > DATE('{impression_date}') 
                    AND s3_source_file NOT LIKE '{source_pixel_s3ingest_exclusion}%'
                    AND DATE(event_date) <= CAST('{pixel_date}' AS TIMESTAMP) - INTERVAL '1' DAY 
                    AND event_name = '{kpi}'
                    {US_WHERE_COND}
            )
        )
        SELECT
            is_click,
            channel_name,
            ip_address,
            adserver_timestamp,
            DATE(adserver_timestamp) as event_date
        FROM (
            {union_sql}
        )
    """
    df = run_athena_query(import_query)
    df['adserver_timestamp'] = pd.to_datetime(df['adserver_timestamp'])

    # work on click data
    df_click = df[df['is_click'] == 1].copy().reset_index(drop=True)

    # remove is_click data from the frame
    df = df[df['is_click'] == 0]

    # convert site_url to channel name for click data
    df_click['channel_name'] = df_click['channel_name'].apply(parse_channel_name(channel_type, channel_keys, channel_length))

    # drop rows for excluded channels
    df_click = df_click[~(df_click['channel_name'].isin(excluded_channels))]

    # drop rows for empty channels
    df_click = df_click[~(df_click['channel_name'] == '')]

    # use label instead of channel name
    df_click['channel_name'] = df_click['channel_name'].apply(lambda name: labels.get(name, name))

    # union
    df = pd.concat((df, df_click), axis=0).reset_index(drop=True)
    return df

def _require_channel_source(include_flag, flag_name, configured_source, column_name):
    """Raise when a media type is explicitly enabled but its adserver source column is empty."""
    if include_flag is True and (configured_source is None or str(configured_source) == ""):
        message = f"ERROR: {flag_name} is set to true but there is no value for {column_name}"
        logger.error(message)
        raise Exception(message)


def pipeline_channel_run(clientkey, impression_date, pixel_date, kpi, lookback_days,
                         include_ott, include_audio, include_digital, branch,
                         write_results=True):
    """Run the channel-level MTA attribution for one client, KPI and pixel date.

    Steps performed:
      1. Read the client's channel settings from
         ``portal_global_settings.client_mappings`` (Redshift).
      2. Load pixel, display-match and adserver data.
      3. Keep channels with at least a 0.1% share of rows, relabel everything
         outside the ten most frequent channels (plus OTT/Audio) as ``other``.
      4. Run ``mta_process`` on ``channel_name`` and split each channel's value
         by impression date.
      5. Optionally persist conversions to Postgres and results to S3.

    Args:
        clientkey: Client key matched case-insensitively against ``clientkey_map``.
        impression_date: ISO date/datetime string; normalised to ``YYYY-MM-DD``
            and passed to the data loaders.
        pixel_date: ISO date/datetime string; normalised to ``YYYY-MM-DD``, used
            for the output file names and passed to the data loaders.
        kpi: Event name to attribute.
        lookback_days: Lookback window; used in output naming and records.
        include_ott / include_audio / include_digital: When exactly ``False`` the
            corresponding adserver source is blanked; when exactly ``True`` the
            client must have a value configured for it.
        branch: ``'master'`` selects the prod Postgres secret, anything else dev;
            also forwarded to the upload helpers.
        write_results: When falsy, nothing is written to Postgres or S3.

    Returns:
        Tuple ``(channel_results, shapley_dict)`` where ``channel_results`` is a
        DataFrame built from ``format_records`` output.

    Raises:
        Exception: Any failure is logged and re-raised as a plain ``Exception``
            prefixed with ``ERROR: pipeline_channel.Run Failed``; the original
            error is attached as ``__cause__``.
    """
    logger.debug(f'input vars are: {clientkey, impression_date, pixel_date, kpi, lookback_days, include_ott, include_audio, include_digital, branch}')
    global redshift_config, ATHENA_PIXEL_TABLE, REGION_NAME, AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, ATHENA_OTT_ADSERVER, ATHENA_AUDIO_ADSERVER, ATHENA_DIGITAL_ADSERVER
    global S3_DIRECTORY_PREFIX_CSV, S3_FILE_PREFIX_CSV, S3_DIRECTORY_PREFIX_PARQUET, S3_FILE_PREFIX_PARQUET
    redshift_cursor = redshift_config.get_cursor()
    athena_pixel_table = ATHENA_PIXEL_TABLE
    source_ott_adserver = ATHENA_OTT_ADSERVER
    source_audio_adserver = ATHENA_AUDIO_ADSERVER
    source_digital_adserver = ATHENA_DIGITAL_ADSERVER

    pipeline_channel_run_start_time = perf_counter()
    try:
        # Normalise both inputs to plain YYYY-MM-DD strings.
        impression_date = datetime.fromisoformat(impression_date).date().isoformat()
        pixel_dt = datetime.fromisoformat(pixel_date)
        pixel_date = pixel_dt.date().isoformat()

        # S3 paths for the channel run
        date_suffix = f'{pixel_dt.year}_{pixel_dt.month:02d}_{pixel_dt.day:02d}'
        upload_file_name_csv = f'{S3_FILE_PREFIX_CSV}_channel_{clientkey}_{kpi}_{lookback_days}day_{date_suffix}.csv'
        s3_destination_csv = f'{S3_DIRECTORY_PREFIX_CSV}_channel/{clientkey}/{pixel_dt.year}/{pixel_dt.month:02d}/{upload_file_name_csv}'
        upload_file_name_parquet = f'{S3_FILE_PREFIX_PARQUET}_channel_{clientkey}_{kpi}_{lookback_days}day_{date_suffix}.parquet'
        s3_destination_parquet = f'{S3_DIRECTORY_PREFIX_PARQUET}_channel_parquet/advertiser_name={clientkey}/event_date={pixel_date}/{upload_file_name_parquet}'

        # Get client variables from Redshift. The client key is bound as a query
        # parameter instead of being formatted into the SQL text.
        query = (
            "SELECT source_channel_pixel, source_ott_adserver, source_audio_adserver, source_digital_adserver "
            "FROM portal_global_settings.client_mappings "
            "WHERE LOWER(clientkey_map) = %s AND enabled_mta_channel = 'true'"
        )
        redshift_cursor.execute(query, (clientkey.lower(),))
        query_results = redshift_cursor.fetchall()

        if len(query_results) == 0:
            message = "WARNING: No records returned from portal_global_settings.client_mappings.  Pipeline aborted"
            logger.error(message)
            raise Exception(message)

        client_row = query_results[0]

        # Channel source: the configured value is only checked for presence; the
        # Athena tables from the module globals are what is actually queried.
        if client_row[0] is None or str(client_row[0]) == "":
            message = "ERROR: There is no value for source_channel_pixel"
            logger.error(message)
            raise Exception(message)

        # OTT source
        if include_ott is False:
            source_ott_adserver = ""
        _require_channel_source(include_ott, "include_ott", client_row[1], "source_ott_adserver")

        # Audio source
        if include_audio is False:
            source_audio_adserver = ""
        _require_channel_source(include_audio, "include_audio", client_row[2], "source_audio_adserver")

        # Digital source
        if include_digital is False:
            source_digital_adserver = ""
        _require_channel_source(include_digital, "include_digital", client_row[3], "source_digital_adserver")

        # Get the client's dataset start/end dates, pixel s3ingest exclusion and
        # channel config from the global settings table.
        query = (
            "SELECT dataset_start_date, dataset_end_date, NVL(source_pixel_s3ingest_exclusion,'none'), channel_config "
            "FROM portal_global_settings.client_mappings WHERE LOWER(clientkey_map) = %s"
        )
        redshift_cursor.execute(query, (clientkey.lower(),))
        query_results = redshift_cursor.fetchall()
        dataset_start_date = query_results[0][0]
        source_pixel_s3ingest_exclusion = query_results[0][2]
        channel_config = query_results[0][3]
        channel_type, channel_keys, channel_length, excluded_channels, labels = parse_channel_config(channel_config)

        # Get pixel data for dimension attribution
        pixel_df = get_pixel_data(clientkey, kpi, athena_pixel_table, pixel_date, impression_date, dataset_start_date, source_pixel_s3ingest_exclusion)

        if pixel_df.empty:
            message = 'No records returned for Pixel'
            logger.error(message)
            raise Exception(message)

        logger.info(f"Loaded {len(pixel_df)} rows")

        # channel_name must always run first since all other compound values are based on it
        dimensions = ("channel_name",)
        first_write = True

        display_df = get_display_matches_df(clientkey, impression_date, pixel_date, pixel_df)

        for dimension in dimensions:
            if dimension != "channel_name":
                continue

            # Get adserver data for the dimension
            adserver_df1 = get_adserver_data(clientkey, kpi,
                athena_pixel_table, source_ott_adserver, source_audio_adserver, source_digital_adserver,
                include_ott, include_audio, include_digital,
                pixel_date, impression_date, dataset_start_date, source_pixel_s3ingest_exclusion, athena_pixel_table, channel_type, channel_keys,
                channel_length, excluded_channels, labels)

            if not display_df.empty:
                adserver_df1 = pd.concat([adserver_df1, display_df], ignore_index=True)
                adserver_df1.reset_index(drop=True, inplace=True)

            adserver_df1['event_date'] = adserver_df1['event_date'].astype(str)

            # Share of rows per channel, ignoring the always-kept OTT/Audio channels.
            non_fixed = adserver_df1[~adserver_df1['channel_name'].isin(['OTT', 'Audio'])]
            counts = non_fixed['channel_name'].value_counts().rename_axis('unique_values').reset_index(name='counts')
            counts['ratio'] = counts['counts'] / counts['counts'].sum()
            kept_counts = counts[counts['ratio'] >= .001]
            del non_fixed

            # Drop rows whose channel is below the 0.1% threshold.
            kept_channels = list(kept_counts['unique_values'])
            kept_channels.append('OTT')
            kept_channels.append('Audio')
            keeps = dict(zip(kept_channels, kept_channels))
            adserver_df1['dimension1'] = adserver_df1['channel_name'].map(keeps)
            adserver_df2 = adserver_df1.dropna(subset=['dimension1'])
            del adserver_df1

            # value_counts sorts by frequency, so the first ten entries are the
            # most common channels; everything else is relabelled 'other'.
            top_channels = list(kept_counts['unique_values'])[:10]
            top_channels.append('OTT')
            top_channels.append('Audio')
            top_channels = list(set(top_channels))
            keeps1 = dict(zip(top_channels, top_channels))
            adserver_df2['channel_name'] = adserver_df2['dimension1'].map(keeps1)
            # NOTE: fillna applies to every column, not only channel_name.
            adserver_df = adserver_df2.fillna('other')

            if adserver_df.empty:
                message = "No records returned from Pixel Data intersecting with Adserver Data"
                logger.error(message)
                raise Exception(message)

            logger.info(f"Loaded {len(adserver_df)} rows")

            if write_results:
                postgres_connection = wr.postgresql.connect(
                    boto3_session=boto3.session.Session(aws_access_key_id=AWS_ACCESS_KEY_ID,
                                                        aws_secret_access_key=AWS_SECRET_ACCESS_KEY,
                                                        region_name=REGION_NAME),
                    secret_id="juice-rds-xmp-prod" if branch == 'master' else "juice-rds-xmp-dev",
                )
                try:
                    save_channel_conversions(
                        postgres_connection=postgres_connection,
                        advertiser_name=clientkey,
                        event_name=kpi,
                        event_date=pixel_date,
                        adserver_df=adserver_df,
                        pixel_df=pixel_df
                    )
                finally:
                    # Always release the connection, even when the save fails.
                    postgres_connection.close()

            # Run MTA attribution for all of the channels
            shapley_dict = mta_process(dimension, adserver_df, pixel_df)

            # Compute the impression breakdown by day for each channel
            combined_dict = {}
            for key in shapley_dict:
                if key == "nan":
                    continue
                date_dict = process_mta_by_channel_and_event_date(key, adserver_df, pixel_df)
                combined_dict[key] = {k: date_dict[k] * shapley_dict[key] for k in date_dict}

            if not combined_dict:
                raise Exception('MTA result for channel is empty')

            logger.info(f"Formatting data for {dimension} for {kpi}")
            channel_results = format_records(
                clientkey,
                dimension,
                kpi,
                pixel_date,
                lookback_days,
                combined_dict,
                s3_destination_csv,
                first_write
            )
            logger.info(f'Channel run for clientkey: {clientkey}, KPI: {kpi} finished in {round(perf_counter() - pipeline_channel_run_start_time, 3)} seconds')

            channel_results = pd.DataFrame(channel_results)
            logger.debug(f'channel_results:\n{channel_results}')
            channel_results['ingest_timestamp'] = channel_results['ingest_timestamp'].astype(str)

            if write_results:
                upload_results_parquet(channel_results, branch, s3_destination_parquet)
                upload_results(channel_results.to_dict(orient='records'), branch, s3_destination_csv)
            return channel_results, shapley_dict

    except Exception as e:
        message = f"ERROR: pipeline_channel.Run Failed: {e}"
        logger.error(message)
        raise Exception(message) from e


############################################
########## OTT and Audio Pipeline ##########
############################################


class RunContext:
    """Schema of the run context used by the OTT/Audio pipeline functions.

    The class is only used as a type annotation: the context built by
    ``get_run_context`` is a plain ``dict`` whose keys are the names declared
    here, and the pipeline functions read it with ``context['key']``.
    """
    media_type: str
    branch: str  # declared for completeness; not set by get_run_context

    client: str
    kpi: str
    pixel_date: str
    lookback_date: str
    lookback_days: int
    max_conversions: float

    # from client globals
    pixel_table: str
    adserver_table: str
    dataset_start_date: str
    dataset_end_date: str
    source_pixel_s3ingest_exclusion: str

    # Athena tables taken from module-level configuration
    athena_pixel_table: str
    athena_adserver_table: str

    output_region: str
    output_bucket: str
    output_prefix: str
    output_file_prefix: str  # legacy name kept for compatibility; get_run_context sets output_prefix
    output_file_name_csv: str
    output_file_path_csv: str
    output_file_name_parquet: str
    output_file_path_parquet: str


class ClientSettings:
    """Schema of the per-client settings returned by ``load_client_settings``.

    Used as a type annotation only; ``load_client_settings`` returns a plain
    ``dict`` with these keys, read from
    ``portal_global_settings.client_mappings``.
    """
    # value of the source_<media_type>_pixel column
    pixel_table: str
    # value of the source_<media_type>_adserver column
    adserver_table: str
    dataset_start_date: str
    dataset_end_date: str
    # the query substitutes 'none' when the column is NULL
    source_pixel_s3ingest_exclusion: str


def load_client_settings(
        client: str, media_type: str
    ) -> typing.Optional[ClientSettings]:
    """Load one client's MTA settings for a media type from Redshift.

    Reads the first row of ``portal_global_settings.client_mappings`` whose
    ``clientkey_map`` matches ``client`` case-insensitively and whose
    ``enabled_mta_<media_type>`` flag is ``'true'``.

    Args:
        client: Client key to look up.
        media_type: Media type used to pick the ``source_<media_type>_pixel``,
            ``source_<media_type>_adserver`` and ``enabled_mta_<media_type>``
            columns.

    Returns:
        A dict with the ``ClientSettings`` keys, or ``None`` when no enabled
        row exists for the client.

    Raises:
        ValueError: If ``media_type`` contains characters that are not valid in
            a column name.
    """
    global redshift_config

    # media_type becomes part of column names, which cannot be bound as query
    # parameters, so restrict it to identifier characters before formatting.
    if not re.fullmatch(r"[A-Za-z0-9_]+", media_type):
        raise ValueError(f"invalid media_type: {media_type!r}")

    redshift_cursor = redshift_config.get_cursor()

    # The client key is passed as a bound parameter rather than formatted in.
    query = f"""
        SELECT
            source_{media_type}_pixel,
            source_{media_type}_adserver,
            dataset_start_date,
            dataset_end_date,
            NVL(source_pixel_s3ingest_exclusion, 'none') as source_pixel_s3ingest_exclusion
        FROM portal_global_settings.client_mappings
        WHERE LOWER(clientkey_map) = %s AND enabled_mta_{media_type} = 'true'
    """
    redshift_cursor.execute(query, (client.lower(),))
    settings = redshift_cursor.fetchone()

    if settings is None:
        return None

    return {
        "pixel_table": settings[0],
        "adserver_table": settings[1],
        "dataset_start_date": settings[2],
        "dataset_end_date": settings[3],
        "source_pixel_s3ingest_exclusion": settings[4],
    }

def converted_ips_query_athena(context: RunContext, fields: str):
    """Build the Athena SQL selecting pixel rows of IPs that converted on the pixel date.

    The query keeps rows for the context's KPI on ``pixel_date`` and excludes
    IPs that already fired the same KPI after ``lookback_date`` and up to the
    day before ``pixel_date``. The module-level ``US_WHERE_COND`` fragment is
    appended to both the outer query and the exclusion sub-query.

    Args:
        context: Run context dict (see ``get_run_context``).
        fields: Select list inserted verbatim after ``SELECT``.

    Returns:
        The SQL text as a string.
    """
    global US_WHERE_COND
    table = context['athena_pixel_table']
    # This client is stored under a different advertiser name in the pixel table.
    advertiser = 'news_twelve_ny' if context['client'] == 'newstwelveny' else context['client']
    exclusion_prefix = context['source_pixel_s3ingest_exclusion']
    pixel_date = context['pixel_date']
    dataset_start = context['dataset_start_date']
    kpi = context['kpi']
    lookback_date = context['lookback_date']
    return f"""
        SELECT {fields}
        FROM {table}
        WHERE advertiser_name = '{advertiser}'
        AND ip_address <> '' 
        AND s3_source_file NOT LIKE '{exclusion_prefix}%'
        AND date(event_date) = date('{pixel_date}')
        AND date(event_date) >= date('{dataset_start}')
        AND event_name = '{kpi}'
        {US_WHERE_COND}
        AND ip_address NOT IN (
            SELECT DISTINCT ip_address 
            FROM {table} 
            WHERE advertiser_name = '{advertiser}'
                AND date(event_date) > date('{lookback_date}')
                AND s3_source_file NOT LIKE '{exclusion_prefix}%'
                AND date(event_date) <= CAST('{pixel_date}' AS TIMESTAMP) - INTERVAL '1' DAY
                AND event_name = '{kpi}'
                {US_WHERE_COND}
        )
    """


def load_pixels_df(context: RunContext):
    """Run the converted-IP query on Athena and return distinct (ip_address, event_timestamp) rows."""
    media_type = context['media_type']
    kpi = context['kpi']
    logger.info(f"""MTA[{media_type}]: Loading pixels for KPI "{kpi}"...""")
    pixel_query = converted_ips_query_athena(context, "DISTINCT ip_address, event_timestamp")
    return run_athena_query(pixel_query)


def enrich_adserver(ads, media_type):
    """Attach a ``genre_name`` column to adserver rows according to the media type.

    * ``'audio'``: every row gets the constant genre ``'audio'`` (in place).
    * ``'ott'``: genres are looked up via ``get_genre_mappings`` and merged in.
    * anything else: the frame is returned untouched.
    """
    if media_type == 'audio':
        ads['genre_name'] = 'audio'
        return ads
    if media_type == 'ott':
        return merge_adserver_genre(ads, get_genre_mappings())
    return ads


def merge_adserver_genre(ads, genre_map):
    """Left-join adserver rows to the genre mapping on lower-cased publisher/network name.

    Used for ``media_type == 'ott'`` only. Publishers without a mapping get the
    genre ``'Entertainment'``. Both input frames receive a helper ``*_lower``
    column as a side effect.

    Returns:
        A DataFrame restricted to the genre plus the adserver columns.
    """
    output_columns = [
        'genre_name',
        'publisher_name',
        'audience_name',
        'creative_isci',
        'device_type',
        'adserver_timestamp',
        'event_date',
        'ip_address',
    ]

    # Case-insensitive join keys.
    ads['publisher_name_lower'] = ads['publisher_name'].str.lower()
    genre_map['network_name_map_lower'] = genre_map['network_name_map'].str.lower()

    merged = pd.merge(
        ads,
        genre_map,
        how='left',
        left_on='publisher_name_lower',
        right_on='network_name_map_lower',
    )
    merged = merged.assign(genre_name=merged['genre_name_map'].fillna('Entertainment'))
    return merged[output_columns]


def get_genre_mappings():
    """Fetch the OTT network-name to genre-name mapping from Redshift."""
    genre_query = """
        select network_name_map, genre_name_map from portal_global_settings.genre_mappings genres 
        where genres.media_type_map = 'ott'
    """
    return run_redshift_query(genre_query)


def load_impressions_df(context: RunContext):
    """Load adserver impressions for the IPs that converted on the pixel date.

    Queries the context's Athena adserver table for the run's media type and
    client, restricted to impressions after ``lookback_date`` up to and
    including ``pixel_date`` (and not before ``dataset_start_date``), then adds
    the genre column via ``enrich_adserver``.

    Returns:
        DataFrame of impressions with ``event_date`` converted to ``str``.
    """
    media_type = context['media_type']
    logger.info(
        f"""MTA[{media_type}]: Loading impressions for KPI "{context['kpi']}"..."""
    )

    # A per-client source filter (source = 'casualprecision' for selected
    # clients) existed here previously; it is currently disabled.
    cp_where_condition = ''

    converted_ips_sql = converted_ips_query_athena(context, 'DISTINCT ip_address')
    adserver_table = context['athena_adserver_table']
    client = context['client']
    lookback_date = context['lookback_date']
    pixel_date = context['pixel_date']
    dataset_start = context['dataset_start_date']

    query = f"""
    WITH converted_ips AS (
        {converted_ips_sql}
    )
    SELECT DISTINCT
        COALESCE(ads.publisher_name, 'Unknown') as publisher_name,
        COALESCE(ads.audience_name, 'Unknown') as audience_name,
        COALESCE(ads.creative_isci, 'Unknown') as creative_isci,
        CASE
            WHEN LOWER(ads.device_type) = 'desktop' THEN 'DESKTOP'
            WHEN LOWER(ads.device_type) = 'tv' THEN 'TV'
            WHEN LOWER(ads.device_type) = 'tablet' THEN 'TABLET'
            WHEN LOWER(ads.device_type) = 'phablet' THEN 'TABLET'
            WHEN LOWER(ads.device_type) = 'none' THEN 'UNKNOWN'
            WHEN LOWER(ads.device_type) = 'television' THEN 'TV'
            WHEN LOWER(ads.device_type) = 'mobile' THEN 'MOBILE'
            WHEN LOWER(ads.device_type) = 'smart speaker' THEN 'MOBILE'
            WHEN LOWER(ads.device_type) = 'camera' THEN 'MOBILE'
            WHEN LOWER(ads.device_type) = 'smartphone' THEN 'MOBILE'
            WHEN LOWER(ads.device_type) = 'feature phone' THEN 'MOBILE'
            WHEN LOWER(ads.device_type) = 'console' THEN 'MOBILE'
            WHEN LOWER(ads.device_type) = 'smart display' THEN 'MOBILE'
            WHEN LOWER(ads.device_type) = 'portable media player' THEN 'MOBILE'
            WHEN LOWER(ads.device_type) = 'car' THEN 'MOBILE'
            ELSE 'UNKNOWN'
        END AS device_type,
        ads.event_timestamp as adserver_timestamp,
        DATE(ads.event_timestamp) as event_date,
        ads.ip_address
    FROM {adserver_table} ads
    WHERE LOWER(ads.media_type) = '{media_type}'
    {cp_where_condition}
    AND ads.advertiser_name = '{client}'
    AND DATE(ads.event_timestamp) > DATE('{lookback_date}')
    AND DATE(ads.event_timestamp) <= DATE('{pixel_date}')
    AND DATE(ads.event_timestamp) >= DATE('{dataset_start}')
    AND ads.ip_address IN (SELECT DISTINCT ip_address FROM converted_ips)
    """

    raw_impressions = run_athena_query(query)
    impressions = enrich_adserver(raw_impressions, context['media_type'])
    impressions["event_date"] = impressions["event_date"].astype(str)
    return impressions


class DimensionColumns:
    """Schema of the ``dimension_columns`` dict passed to ``get_dimension_rows``.

    Used as a type annotation only; callers pass a plain ``dict``.
    ``genre_name`` is required, ``publisher_name`` must be present by the time a
    row is emitted, and the remaining keys default to ``""``.
    """
    genre_name: str
    publisher_name: str
    audience_name: str
    creative_isci: str
    device_type: str


def get_dimension_rows(
        context: RunContext,
        dimension_name: str,
        dimension_column: str,
        dimension_columns: DimensionColumns,
        date_shapley_values: dict,
    ):
    """Flatten per-dimension, per-date conversion values into output records.

    Args:
        context: Run context dict; supplies the pixel date, client, KPI,
            lookback days and CSV output path written into every record.
        dimension_name: Label stored in each record's ``dimension`` field.
        dimension_column: Key of ``dimension_columns`` that is overwritten with
            each dimension value in turn.
        dimension_columns: Fixed column values for this group. Must contain
            ``genre_name``; missing optional keys are filled with ``""``.
            The dict is modified in place.
        date_shapley_values: Mapping ``dimension value -> {impression date -> conversions}``.

    Returns:
        A list with one record dict per (dimension value, impression date) pair.
    """
    assert isinstance(dimension_columns, dict), 'invalid "dimension_columns" argument'
    assert "genre_name" in dimension_columns, 'missing "genre_name"'

    for optional_key in ("audience_name", "creative_isci", "device_type"):
        dimension_columns.setdefault(optional_key, "")

    records = []

    for dimension_value, conversions_by_date in date_shapley_values.items():
        dimension_columns[dimension_column] = dimension_value

        for impression_date, conversions in conversions_by_date.items():
            assert "publisher_name" in dimension_columns, 'missing "publisher_name"'

            record = {
                "event_date": context["pixel_date"],
                "imp_date": impression_date,
                "advertiser_name": context["client"],
                "genre_name": dimension_columns["genre_name"],
                "publisher_name": dimension_columns["publisher_name"],
                "audience_name": dimension_columns["audience_name"],
                "creative_isci": dimension_columns["creative_isci"],
                "device_type": dimension_columns["device_type"],
                "dimension": dimension_name,
                "kpi_name": context["kpi"],
                "kpi_conversions": conversions,
                "lookback_days": context["lookback_days"],
                "s3_source_file": context["output_file_path_csv"],
                "ingest_timestamp": datetime.now(),
                "row_id": str(uuid.uuid4()),
            }
            records.append(record)

    return records


def get_kpi_responses_for_dimension(
        dimension_column: str,
        dimension_df: pd.DataFrame,
        pixel_df: pd.DataFrame,
        max_conversions: float,
    ):
    """Attribute conversions to the values of one dimension and split them by date.

    ``mta_process`` is run on ``dimension_column`` and the resulting values are
    rescaled so that they sum to ``max_conversions``. For every dimension value
    a second ``mta_process`` run on ``event_date`` (over that value's rows only)
    gives the weights used to distribute the value's total across dates.

    Args:
        dimension_column: Column of ``dimension_df`` to attribute over.
        dimension_df: Impression rows for the current group.
        pixel_df: Pixel (conversion) rows.
        max_conversions: Cap for the sum of the returned totals; a cap of zero
            or less yields all-zero totals.

    Returns:
        ``{"totals": {value: conversions}, "event_dates": {value: {date: conversions}}}``.
    """
    shapley_values = mta_process(
        dimension_column,
        dimension_df,
        pixel_df,
    )
    shapley_values_sum = sum(shapley_values.values())

    # apply proportional values based on the maximum amount of conversions.
    # we do this because the channel MTA run assigns partial credit to media type and
    # we need to not let totals exceed those.
    proportional_ratio = 0
    if max_conversions > 0 and shapley_values_sum > 0:
        proportional_ratio = max_conversions / shapley_values_sum

    shapley_values_by_date = {}
    for shapley_dimension in shapley_values:
        shapley_values[shapley_dimension] *= proportional_ratio
        dimension_total = shapley_values[shapley_dimension]

        date_weights = mta_process(
            "event_date",
            dimension_df[dimension_df[dimension_column] == shapley_dimension],
            pixel_df,
        )
        date_weights_total = sum(date_weights.values())

        if date_weights_total == 0:
            # No usable per-date weights: dividing would raise ZeroDivisionError
            # (or produce NaN), so spread the dimension total evenly instead.
            date_count = len(date_weights)
            shapley_values_by_date[shapley_dimension] = {
                date: dimension_total / date_count for date in date_weights
            }
            continue

        shapley_values_by_date[shapley_dimension] = {
            date: (weight / date_weights_total) * dimension_total
            for date, weight in date_weights.items()
        }

    return {"totals": shapley_values, "event_dates": shapley_values_by_date}


# This is carried over from the original way MTA was run; it should probably compute Shapley values
# for each genre and then apply proportional values against those Shapley values.
def get_genre_totals(context: RunContext, impression_df: pd.DataFrame):
    """Split the run's ``max_conversions`` across genres by share of distinct IPs.

    Each genre's share is its number of distinct ``ip_address`` values divided
    by the sum of those counts over all genres.

    Returns:
        Dict mapping genre name to its conversion cap.
    """
    unique_pairs = impression_df[["genre_name", "ip_address"]].drop_duplicates()
    ips_per_genre = unique_pairs.groupby(["genre_name"]).size().to_dict()
    all_ips = sum(ips_per_genre.values())

    return {
        genre: (ip_count / all_ips) * context["max_conversions"] if all_ips > 0 else 0
        for genre, ip_count in ips_per_genre.items()
    }


def get_kpi_responses(
        context: RunContext, impression_df: pd.DataFrame, pixel_df: pd.DataFrame
    ):
    """Build the KPI response rows for every level of the dimension hierarchy.

    The hierarchy is genre -> publisher -> audience -> creative -> device type.
    At each level the conversions of the parent are used as the cap for the
    children, so a level never exceeds the totals of the level above it.

    Returns:
        DataFrame with one row per (dimension value, impression date), or an
        empty DataFrame when no rows were produced.
    """
    genre_caps = get_genre_totals(context, impression_df)

    collected = []
    for genre_name, genre_df in impression_df.groupby("genre_name"):
        # publishers within the genre, capped by the genre total
        publisher_result = get_kpi_responses_for_dimension(
            dimension_column="publisher_name",
            dimension_df=genre_df,
            pixel_df=pixel_df,
            max_conversions=genre_caps[genre_name],
        )
        collected.extend(
            get_dimension_rows(
                context=context,
                dimension_name="publisher_name",
                dimension_column="publisher_name",
                dimension_columns={
                    "genre_name": genre_name,
                },
                date_shapley_values=publisher_result["event_dates"],
            )
        )

        for group_key, publisher_df in genre_df.groupby(["genre_name", "publisher_name"]):
            genre_name, publisher_name = group_key

            # audiences within the publisher, capped by the publisher total
            audience_result = get_kpi_responses_for_dimension(
                dimension_column="audience_name",
                dimension_df=publisher_df,
                pixel_df=pixel_df,
                max_conversions=publisher_result["totals"][publisher_name],
            )
            collected.extend(
                get_dimension_rows(
                    context=context,
                    dimension_name="publisher_audience",
                    dimension_column="audience_name",
                    dimension_columns={
                        "publisher_name": publisher_name,
                        "genre_name": genre_name,
                    },
                    date_shapley_values=audience_result["event_dates"],
                )
            )

            for audience_name, audience_df in publisher_df.groupby("audience_name"):
                # creatives within the audience, capped by the audience total
                creative_result = get_kpi_responses_for_dimension(
                    dimension_column="creative_isci",
                    dimension_df=audience_df,
                    pixel_df=pixel_df,
                    max_conversions=audience_result["totals"][audience_name],
                )
                collected.extend(
                    get_dimension_rows(
                        context=context,
                        dimension_name="publisher_audience_creative",
                        dimension_column="creative_isci",
                        dimension_columns={
                            "publisher_name": publisher_name,
                            "genre_name": genre_name,
                            "audience_name": audience_name,
                        },
                        date_shapley_values=creative_result["event_dates"],
                    )
                )

                for creative_isci, creative_df in audience_df.groupby("creative_isci"):
                    creative_total = creative_result["totals"][creative_isci]
                    # creatives without any credit have nothing to split by device
                    if creative_total == 0:
                        continue

                    # device types within the creative, capped by the creative total
                    device_result = get_kpi_responses_for_dimension(
                        dimension_column="device_type",
                        dimension_df=creative_df,
                        pixel_df=pixel_df,
                        max_conversions=creative_total,
                    )
                    collected.extend(
                        get_dimension_rows(
                            context=context,
                            dimension_name="publisher_audience_creative_devicetype",
                            dimension_column="device_type",
                            dimension_columns={
                                "publisher_name": publisher_name,
                                "genre_name": genre_name,
                                "audience_name": audience_name,
                                "creative_isci": creative_isci,
                            },
                            date_shapley_values=device_result["event_dates"],
                        )
                    )

    if not collected:
        return pd.DataFrame()

    responses = pd.DataFrame.from_records(collected)
    responses["ingest_timestamp"] = responses["ingest_timestamp"].astype("datetime64[s]")

    return responses


def get_run_context(
        media_type: str,
        clientkey: str,
        kpi: str,
        pixel_iso_date: str,
        impression_iso_date: str,
        lookback_days: int,
        max_conversions: float,
        output_region: str,
        output_bucket: str,
        output_prefix: str,
        source_pixel: str
    ) -> RunContext:
    """Assemble the context dict shared by the OTT/Audio pipeline functions.

    Combines the caller's arguments, the client's settings loaded from
    Redshift, the module-level Athena table names and the derived S3 output
    file names/paths.

    Args:
        media_type: Media type of the run; used for the settings lookup and in
            the output names.
        clientkey: Client key.
        kpi: KPI (event name) being attributed.
        pixel_iso_date: ISO date/datetime string of the pixel date.
        impression_iso_date: ISO date/datetime string stored as ``lookback_date``.
        lookback_days: Lookback window in days.
        max_conversions: Conversion cap for the run.
        output_region / output_bucket / output_prefix: Stored in the context as given.
        source_pixel: Overrides the ``pixel_table`` value loaded from the settings.

    Returns:
        The run context as a plain dict.

    Raises:
        ValueError: If the client has no enabled settings row for ``media_type``.
    """
    global S3_DIRECTORY_PREFIX_CSV, S3_FILE_PREFIX_CSV, S3_DIRECTORY_PREFIX_PARQUET, S3_FILE_PREFIX_PARQUET, ATHENA_PIXEL_TABLE, ATHENA_OTT_ADSERVER
    client_settings = load_client_settings(client=clientkey, media_type=media_type)
    if client_settings is None:
        # Fail with a clear message instead of a TypeError on the None below.
        raise ValueError(
            f'No enabled MTA {media_type} settings found for client "{clientkey}"'
        )
    client_settings['pixel_table'] = source_pixel

    pixel_dt = datetime.fromisoformat(pixel_iso_date)
    impression_dt = datetime.fromisoformat(impression_iso_date)

    # strftime already zero-pads month and day to two digits.
    pix_year = pixel_dt.strftime("%Y")
    pix_month = pixel_dt.strftime("%m")
    pix_day = pixel_dt.strftime("%d")
    pixel_date_date = pixel_dt.date().isoformat()

    upload_file_name_csv = f'{S3_FILE_PREFIX_CSV}_{media_type}_{clientkey}_{kpi}_{lookback_days}day_{pix_year}_{pix_month}_{pix_day}.csv'
    s3_destination_csv = f'{S3_DIRECTORY_PREFIX_CSV}_{media_type}/{clientkey}/{pix_year}/{pix_month}/{upload_file_name_csv}'
    upload_file_name_parquet = f'{S3_FILE_PREFIX_PARQUET}_{media_type}_{clientkey}_{kpi}_{lookback_days}day_{pix_year}_{pix_month}_{pix_day}.parquet'
    s3_destination_parquet = f'{S3_DIRECTORY_PREFIX_PARQUET}_{media_type}_parquet/advertiser_name={clientkey}/event_date={pixel_date_date}/{upload_file_name_parquet}'

    return {
        "media_type": media_type,
        "client": clientkey,
        "kpi": kpi,
        "pixel_date": pixel_dt.strftime("%Y-%m-%d"),
        "lookback_date": impression_dt.strftime("%Y-%m-%d"),
        "lookback_days": lookback_days,
        "max_conversions": max_conversions,
        "pixel_table": client_settings["pixel_table"],
        "athena_pixel_table": ATHENA_PIXEL_TABLE,
        "adserver_table": client_settings["adserver_table"],
        # NOTE(review): the OTT adserver table is used for every media type,
        # including audio - confirm this is intended.
        "athena_adserver_table": ATHENA_OTT_ADSERVER,
        "dataset_start_date": client_settings["dataset_start_date"],
        "dataset_end_date": client_settings["dataset_end_date"],
        "source_pixel_s3ingest_exclusion": client_settings["source_pixel_s3ingest_exclusion"],
        "output_region": output_region,
        "output_bucket": output_bucket,
        "output_prefix": output_prefix,
        "output_file_name_csv": upload_file_name_csv,
        "output_file_path_csv": s3_destination_csv,
        "output_file_name_parquet": upload_file_name_parquet,
        "output_file_path_parquet" : s3_destination_parquet
    }


def pipeline_ott_audio_run(
        media_type: str,
        clientkey: str,
        kpi: str,
        pixel_iso_date: str,
        impression_iso_date: str,
        lookback_days: int,
        max_conversions: float,
        write_s3: bool = True,
        write_local: bool = False,
        branch: str = 'dev',
        source_pixel: str = ''
    ):
    """Run the OTT or Audio MTA pipeline for one client, KPI and pixel date.

    Builds the run context, loads pixels and impressions, computes the KPI
    responses and optionally uploads them as parquet and CSV. Reads the
    module-level ``REGION_NAME`` and ``S3_TARGET_BUCKET`` settings.

    Args:
        media_type: Media type of the run.
        clientkey: Client key.
        kpi: KPI (event name) being attributed.
        pixel_iso_date: ISO date string of the pixel date.
        impression_iso_date: ISO date string of the lookback start.
        lookback_days: Lookback window in days.
        max_conversions: Conversion cap for the run.
        write_s3: When true, results are uploaded via the upload helpers.
        write_local: Accepted for compatibility; not used by this function.
        branch: Forwarded to the upload helpers.
        source_pixel: Forwarded to ``get_run_context``.

    Returns:
        DataFrame of KPI responses with ``ingest_timestamp`` as ``str``.

    Raises:
        Exception: On any failure, including when no pixels or no KPI
            responses are found. The original error is attached as the cause.
    """
    try:
        total_start_time = time.time()

        output_prefix = f'fact_mta_{media_type}'
        context = get_run_context(
            media_type=media_type,
            clientkey=clientkey,
            kpi=kpi,
            pixel_iso_date=pixel_iso_date,
            impression_iso_date=impression_iso_date,
            lookback_days=lookback_days,
            max_conversions=max_conversions,
            output_region=REGION_NAME,
            output_bucket=S3_TARGET_BUCKET,
            output_prefix=output_prefix,
            source_pixel=source_pixel
        )

        pixel_df = load_pixels_df(context)

        if pixel_df.empty:
            # Report the media type of the run (the builtin `type` was formatted here before).
            message = f'No {media_type} pixels loaded for "{clientkey}" and kpi "{kpi}" for "{pixel_iso_date}"'
            logger.info(f'''MTA[{context['media_type']}]: {message}''')
            raise Exception(message)

        impression_df = load_impressions_df(context)
        logger.debug(f'pixel_df: {pixel_df.shape},   impression_df: {impression_df.shape}')
        kpi_responses_df = get_kpi_responses(
            context=context, impression_df=impression_df, pixel_df=pixel_df
        )

        if kpi_responses_df.empty:
            message = f'No KPI responses for "{clientkey}" and kpi "{kpi}" for "{pixel_iso_date}"'
            logger.info(f'''MTA[{context['media_type']}]: {message}''')
            raise Exception(message)
        kpi_responses_df['ingest_timestamp'] = kpi_responses_df['ingest_timestamp'].astype(str)

        if write_s3:
            upload_results_parquet(kpi_responses_df, branch, context['output_file_path_parquet'])
            upload_results(kpi_responses_df.to_dict(orient='records'), branch, context['output_file_path_csv'])

        total_execution_time = time.time() - total_start_time
        total_minutes = float(total_execution_time) / 60.0

        logger.info(
            f"""MTA[{context['media_type']}]: Process completed at {datetime.now()} in {round(float(total_minutes),2)} minutes"""
        )
        return kpi_responses_df
    except Exception as e:
        message = f'ott_audio run failed: {e}'
        logger.error(message)
        # Keep raising a plain Exception for callers, but preserve the cause.
        raise Exception(e) from e


def run_ott_audio(media_type, max_conversions_cap, clientkey, kpi, pixel_dts, impression_dts, lookback_days, source_pixel,
                 write_results=True):
    """Run the OTT/Audio pipeline once and summarise the outcome.

    Any exception raised while invoking ``pipeline_ott_audio_run`` is caught
    and turned into a "failed" status entry rather than propagated.

    Returns:
        tuple: ``(results, log)``. ``results`` is whatever
        ``pipeline_ott_audio_run`` returned, or an empty DataFrame when the
        call failed. ``log`` is a dict keyed by ``<media_type>_success`` plus
        either ``<media_type>_count`` or ``<media_type>_error``, and always
        ``<media_type>_run_time_minutes``.
    """
    results = pd.DataFrame()
    started_at = perf_counter()

    # Anything that is not exactly 'ott' is labelled as Audio in the log lines.
    if media_type == 'ott':
        run_type = 'OTT'
    else:
        run_type = 'Audio'

    logger.info("--------------------------------")
    logger.info(f"STARTING {run_type} PROCESSING")

    try:
        # output_bucket / output_prefix are not passed here (they used to be
        # branch-dependent overrides); the pipeline's own values apply.
        results = pipeline_ott_audio_run(
            media_type=media_type,
            clientkey=clientkey,
            kpi=kpi,
            pixel_iso_date=pixel_dts,
            impression_iso_date=impression_dts,
            lookback_days=abs(lookback_days),
            max_conversions=max_conversions_cap,
            write_local=False,
            write_s3=write_results,
            source_pixel=source_pixel,
        )
        log = {}
        log[f'{media_type}_success'] = 'success'
        log[f'{media_type}_count'] = results.shape[0]
    except Exception as err:
        log = {}
        log[f'{media_type}_success'] = 'failed'
        log[f'{media_type}_error'] = str(err)

    elapsed_minutes = round((perf_counter() - started_at) / 60, 2)
    log[f'{media_type}_run_time_minutes'] = elapsed_minutes
    logger.info(f'{run_type} run finished for client {clientkey}, KPI {kpi}. Shape of results: {results.shape}')
    return results, log


def run_channel(clientkey,
            impression_dts,
            pixel_dts,
            kpi,
            lookback_days,
            include_channel_ott,
            include_channel_audio,
            include_channel_digital,
            branch,
            write_results=True
    ):
    """Run the channel pipeline once and summarise the outcome.

    Any exception raised while invoking ``pipeline_channel_run`` (or while
    unpacking its result) is caught and recorded as a "failed" status entry.

    Returns:
        tuple: ``(channel_results, shapley_dict, channel_log)``. On failure
        ``channel_results`` is an empty DataFrame and ``shapley_dict`` is an
        empty ``defaultdict(int)``. ``channel_log`` holds ``channel_success``
        plus either ``channel_count`` or ``channel_error``, and always
        ``channel_run_time_minutes``.
    """
    # Fallbacks returned when the pipeline call fails.
    channel_results = pd.DataFrame()
    shapley_dict = defaultdict(int)

    timer_start = perf_counter()
    try:
        channel_results, shapley_dict = pipeline_channel_run(
            clientkey,
            impression_dts,
            pixel_dts,
            kpi,
            abs(lookback_days),
            include_channel_ott,
            include_channel_audio,
            include_channel_digital,
            branch,
            write_results=write_results,
        )
        channel_log = {}
        channel_log['channel_success'] = 'success'
        channel_log['channel_count'] = channel_results.shape[0]
    except Exception as err:
        channel_log = {}
        channel_log['channel_success'] = 'failed'
        channel_log['channel_error'] = str(err)

    channel_log['channel_run_time_minutes'] = round((perf_counter() - timer_start) / 60, 2)
    logger.info(f'Channel run finished for client {clientkey}, KPI {kpi}. Shape of results: {channel_results.shape}')
    return channel_results, shapley_dict, channel_log


def run(data, branch, write_results=True):
    """Dispatch one MTA request to the OTT/Audio or the Channel pipeline.

    Args:
        data: Sequence whose first element is a dict with the keys
            ``clientkey``, ``impression_dts``, ``pixel_dts``, ``kpi``,
            ``media_type``, ``lookback_days``, ``include_channel_ott``,
            ``include_channel_audio`` and ``include_channel_digital``.
        branch: Forwarded unchanged to ``run_channel``.
        write_results: Forwarded to the pipelines to control result writing.

    Returns:
        tuple: ``(final_results, final_log)``. ``final_results`` is a
        ``defaultdict(pd.DataFrame)`` keyed by ``<name>_results``;
        ``final_log`` merges the status dicts of every sub-run executed.

    ``media_type`` is matched case-insensitively for all three supported
    values ('ott', 'audio', 'channel'). An unsupported value runs nothing,
    logs a warning and returns empty results.
    """
    start_time = perf_counter()
    request = data[0]
    clientkey                 = request['clientkey']
    impression_dts            = request['impression_dts']
    pixel_dts                 = request['pixel_dts']
    kpi                       = request['kpi']
    raw_media_type            = request['media_type']
    lookback_days             = request['lookback_days']
    include_channel_ott       = request['include_channel_ott']
    include_channel_audio     = request['include_channel_audio']
    include_channel_digital   = request['include_channel_digital']

    # Normalise once so every branch below compares the same way.
    media_type = raw_media_type.lower()
    # OTT and Audio read pixels from the same table.
    source_pixel = f'{ATHENA_DATABASE}.pixels'

    logger.info('==============================================')
    logger.info(f'MTA run started for client: {clientkey}, KPI: {kpi}, pixel date: {pixel_dts}, lookback_days: {lookback_days}')

    final_results = defaultdict(pd.DataFrame)
    final_log = {}

    if media_type in ('ott', 'audio'):
        # Stand-alone run: a conversions cap of 0 is passed to the pipeline.
        ott_audio_results, ott_audio_log = run_ott_audio(
            media_type, 0, clientkey, kpi, pixel_dts, impression_dts,
            lookback_days, source_pixel, write_results=write_results
        )
        final_results[f'{media_type}_results'] = ott_audio_results
        final_log.update(ott_audio_log)
        logger.info(f'Shape of {media_type} results: {ott_audio_results.shape}')

    elif media_type == 'channel':
        channel_results, shapley_dict, channel_log = run_channel(
            clientkey,
            impression_dts,
            pixel_dts,
            kpi,
            lookback_days,
            include_channel_ott,
            include_channel_audio,
            include_channel_digital,
            branch,
            write_results=write_results
        )
        final_log.update(channel_log)

        # Follow-up runs stay empty unless enabled and given a positive cap.
        follow_up = {'ott': pd.DataFrame(), 'audio': pd.DataFrame()}
        follow_up_plan = (
            ('ott', 'OTT', include_channel_ott),
            ('audio', 'Audio', include_channel_audio),
        )
        for sub_type, shapley_key, enabled in follow_up_plan:
            # Only the literal True enables a follow-up run.
            if enabled is not True:
                continue
            # The channel run's value for this media type caps its conversions.
            max_conversions_cap = float(shapley_dict[shapley_key])
            if max_conversions_cap > 0:
                follow_up[sub_type], sub_log = run_ott_audio(
                    sub_type, max_conversions_cap, clientkey, kpi, pixel_dts,
                    impression_dts, lookback_days, source_pixel,
                    write_results=write_results
                )
            else:
                info = f"max_conversions_cap = {max_conversions_cap}"
                logger.info(f'{shapley_key} run could not be completed: {info}')
                sub_log = {
                    f'{sub_type}_success': 'na',
                    f'{sub_type}_info': info,
                }
            final_log.update(sub_log)

        ott_results = follow_up['ott']
        audio_results = follow_up['audio']
        final_results['channel_results'] = channel_results
        final_results['ott_results'] = ott_results
        final_results['audio_results'] = audio_results
        logger.info(f'Shape of channel results: {channel_results.shape};   shape of ott results: {ott_results.shape};   shape of audio results: {audio_results.shape}')

    else:
        # Make the no-op visible instead of silently returning empty results.
        logger.warning(f'Unsupported media_type "{raw_media_type}" for client {clientkey}; nothing was run')

    logger.info(f'MTA run finished for client:{clientkey}, KPI:{kpi} in {round(perf_counter() - start_time, 3)} seconds')
    return final_results, final_log


def _strip_one_trailing_slash(prefix):
    """Return *prefix* without one trailing '/'; an empty string is returned as-is."""
    if prefix.endswith('/'):
        return prefix[:-1]
    return prefix


def _resolve_us_where_cond(use_us_condition, list_clients_no_us_condition, clientkey):
    """Return the SQL fragment restricting rows to the US, or '' when it does not apply.

    Raises:
        ValueError: if ``use_us_condition`` is neither 'true' nor 'false'
            (case-insensitive; a bool is accepted as well).
    """
    flag = str(use_us_condition).strip().lower()
    if flag == 'false':
        return ''
    if flag != 'true':
        raise ValueError(
            f"use_us_condition must be 'true' or 'false', got {use_us_condition!r}"
        )
    # Clients on the comma-separated exemption list never get the US filter.
    exempt_clients = [name.strip() for name in (list_clients_no_us_condition or '').split(',')]
    if clientkey in exempt_clients:
        return ''
    return "AND UPPER(geo_country_code) = 'US'"


def init_globals(aws_access_key_id, aws_secret_access_key, region_name, redshift_mta_secret_name, 
        athena_database, athena_status_output_location, athena_pixel_table,
        athena_ott_adserver, athena_audio_adserver, athena_digital_adserver,
        branch, s3_target_bucket, s3_target_bucket_parquet,
        s3_directory_prefix_csv, s3_file_prefix_csv,
        s3_directory_prefix_parquet, s3_file_prefix_parquet,
        use_us_condition, list_clients_no_us_condition, clientkey
    ):
    """Populate the module-level configuration and client globals for a run.

    Sets the AWS credential globals, ``redshift_config``, ``S3``, the
    ``ATHENA_*`` globals, the ``S3_*`` bucket/prefix globals and
    ``US_WHERE_COND``. Directory prefixes lose one trailing '/' if present.
    ``branch`` is accepted but not used here.

    Raises:
        ValueError: if ``use_us_condition`` is neither 'true' nor 'false'.
            This is checked before any connection is opened, so
            ``US_WHERE_COND`` can never be left unset or stale.
    """
    # Validate and normalise the pure inputs first, so a bad argument fails
    # before any Redshift / AWS client is created.
    us_where_cond = _resolve_us_where_cond(use_us_condition, list_clients_no_us_condition, clientkey)
    directory_prefix_csv = _strip_one_trailing_slash(s3_directory_prefix_csv)
    directory_prefix_parquet = _strip_one_trailing_slash(s3_directory_prefix_parquet)

    logger.debug('Initing globals')
    global REGION_NAME, AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, REDSHIFT_MTA_SECRET_NAME
    REGION_NAME = region_name
    AWS_ACCESS_KEY_ID = aws_access_key_id
    AWS_SECRET_ACCESS_KEY = aws_secret_access_key
    REDSHIFT_MTA_SECRET_NAME = redshift_mta_secret_name

    global redshift_config
    redshift_config = RedshiftConfig(REGION_NAME, AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, REDSHIFT_MTA_SECRET_NAME)

    global S3
    S3 = boto3.client(
        "s3",
        region_name=REGION_NAME,
        aws_access_key_id=AWS_ACCESS_KEY_ID,
        aws_secret_access_key=AWS_SECRET_ACCESS_KEY
    )

    global ATHENA_CLIENT, ATHENA_DATABASE, ATHENA_STATUS_OUTPUT_LOCATION, ATHENA_PIXEL_TABLE, ATHENA_OTT_ADSERVER, ATHENA_AUDIO_ADSERVER, ATHENA_DIGITAL_ADSERVER
    ATHENA_CLIENT = get_athena_connection(REGION_NAME, AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY)
    ATHENA_DATABASE = athena_database
    ATHENA_STATUS_OUTPUT_LOCATION = athena_status_output_location
    ATHENA_PIXEL_TABLE = athena_pixel_table
    ATHENA_OTT_ADSERVER = athena_ott_adserver
    ATHENA_AUDIO_ADSERVER = athena_audio_adserver
    ATHENA_DIGITAL_ADSERVER = athena_digital_adserver

    global S3_TARGET_BUCKET, S3_TARGET_BUCKET_PARQUET
    S3_TARGET_BUCKET = s3_target_bucket
    S3_TARGET_BUCKET_PARQUET = s3_target_bucket_parquet

    global S3_DIRECTORY_PREFIX_CSV, S3_FILE_PREFIX_CSV
    S3_DIRECTORY_PREFIX_CSV = directory_prefix_csv
    S3_FILE_PREFIX_CSV = s3_file_prefix_csv

    global S3_DIRECTORY_PREFIX_PARQUET, S3_FILE_PREFIX_PARQUET
    S3_DIRECTORY_PREFIX_PARQUET = directory_prefix_parquet
    S3_FILE_PREFIX_PARQUET = s3_file_prefix_parquet

    global US_WHERE_COND
    US_WHERE_COND = us_where_cond


def close_connections():
    """Close the Redshift cursor (if one exists) and the Redshift connection.

    Operates on the module-level ``redshift_config``. The boto3 ``S3`` and
    ``ATHENA_CLIENT`` globals are not closed here.
    """
    cursor = redshift_config.cursor
    try:
        # RedshiftConfig.__init__ initialises cursor to None, so there may be
        # nothing to close.
        if cursor is not None:
            cursor.close()
    finally:
        # Always release the connection, even if closing the cursor failed.
        redshift_config.connection.close()
    logger.debug('Closed all connections')


def repair_mta_tables(clientkey, pixel_date):
    """Add the (advertiser_name, event_date) partition to each fact_mta_* table.

    Issues one ``ALTER TABLE ... ADD IF NOT EXISTS PARTITION`` statement per
    table (channel, ott, audio) in ``ATHENA_DATABASE`` via ``run_athena_query``.
    """
    table_name_prefix = 'fact_mta'
    qualified_tables = [
        f'{ATHENA_DATABASE}.{table_name_prefix}_{suffix}'
        for suffix in ('channel', 'ott', 'audio')
    ]

    for table_name in qualified_tables:
        # IF NOT EXISTS is part of the statement, so an existing partition is not an error.
        add_partition_sql = f"ALTER TABLE {table_name} ADD IF NOT EXISTS PARTITION (advertiser_name = '{clientkey}', event_date = '{pixel_date}');"
        run_athena_query(add_partition_sql, only_metadata=True)


def start_func(data, 
               aws_access_key_id, aws_secret_access_key, region_name, secret_name, 
               athena_database, athena_status_output_location, 
               branch, s3_target_bucket, s3_target_bucket_parquet,
               s3_directory_prefix_csv, s3_file_prefix_csv,
               s3_directory_prefix_parquet, s3_file_prefix_parquet,
               use_us_condition, list_clients_no_us_condition,
               debug=False, write_results=True):
    """Entry point: initialise globals, run MTA, register partitions, clean up.

    Args:
        data: Sequence whose first element is the request dict; ``clientkey``,
            ``pixel_dts``, ``kpi`` and ``lookback_days`` are read here and the
            whole sequence is forwarded to ``run``.
        debug: When true, the result DataFrames are returned alongside the log.
        write_results: Forwarded to ``run``.
        (remaining arguments are forwarded to ``init_globals``)

    Returns:
        A one-element list holding the merged status log; with ``debug`` set,
        a tuple ``(channel_results, ott_results, audio_results, final_log)``.

    Connections opened by ``init_globals`` are closed even when the run or
    the partition registration raises.
    """
    request = data[0]
    clientkey     = request['clientkey']
    pixel_date    = request['pixel_dts'][:10]  # first 10 characters of pixel_dts
    kpi           = request['kpi']
    lookback_days = request['lookback_days']

    # The pixel table and all three adserver sources live in the same database.
    athena_pixel_table = f'{athena_database}.pixels'
    athena_ott_adserver = f'{athena_database}.adserver'
    athena_audio_adserver = f'{athena_database}.adserver'
    athena_digital_adserver = f'{athena_database}.adserver'

    init_globals(aws_access_key_id, aws_secret_access_key, region_name, secret_name, 
                 athena_database, athena_status_output_location, athena_pixel_table,
                 athena_ott_adserver, athena_audio_adserver, athena_digital_adserver,
                 branch, s3_target_bucket, s3_target_bucket_parquet,
                 s3_directory_prefix_csv, s3_file_prefix_csv,
                 s3_directory_prefix_parquet, s3_file_prefix_parquet,
                 use_us_condition, list_clients_no_us_condition, clientkey)

    start_time = perf_counter()

    try:
        mta_results, results_log = run(data, branch, write_results)
        repair_mta_tables(clientkey, pixel_date)
    finally:
        # Release the Redshift connection even if the run or repair step raised.
        close_connections()

    logger.info(f'MTA process finished in {round(perf_counter() - start_time, 3)} seconds')
    final_log = {
        'clientkey'       : clientkey,
        'kpi'             : kpi,
        'event_date'      : pixel_date,
        'lookback_days'   : lookback_days,
    }
    final_log.update(results_log)
    final_log = [final_log]

    if debug:
        return mta_results['channel_results'], mta_results['ott_results'], mta_results['audio_results'], final_log
    return final_log