In [0]:
# Databricks notebook source
import json
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, ArrayType, StringType, LongType, TimestampType, IntegerType, StructField
import datetime
import ast
from pyspark.sql.functions import udf, to_timestamp, expr, lit, col, date_format, lower, to_date, current_timestamp, md5, concat_ws
# param_json = dbutils.widgets.get("param_key")
# params = json.loads(param_json)

#TODO:
params = {"param": {"campaign_name": "TestConfig", "no_of_execution": "1"}, "execution_id": "8b146246-db6d-482d-9689-aaff100d54db"}

spark = SparkSession.builder \
        .appName("Advertisement Campaign") \
        .getOrCreate()

In [0]:
class AdvertisementCampaign:
  def __init__(self, params):
    '''
    Description : Initializes the class by fetching configuration settings, network events, and PRIZM segmentation data from the Unity Catalog, used by the instances of the class for further operations.
    Parameters : 
    - params(dictionary containing configuration values from Advertisement Campaign parent notebook).
    Return value : None
    '''
    parameters = params['param']
    self.camp_name = parameters['campaign_name']
    self.iteration = int(parameters['no_of_execution'])
    self.execution_ids = params['execution_id']
    self._catalog = "edl_dev"
    self._drvd_schema = "drvd__app_rsmgeo5g"
    self._rawstd_schema = "rawstd__rsmgeo5g"
    self._environics_schema = "rawstd__environics"

    self.ntb_start_time = datetime.datetime.now()

    self.config = self.get_data_from_db(self._catalog, self._rawstd_schema, 
                                       "ad_campaigns_config", 
                                       f"where campaign_name = '{self.camp_name}' AND status = 'N'")
    self.safegraph = self.get_data_from_db(self._catalog, self._drvd_schema, 
                                             "safegraph_poi_h3")
    self.advertiser = self.get_data_from_db(self._catalog, self._drvd_schema, 
                                            "advertiser_poi_h3") 
    self.custom_geofence = self.get_data_from_db(self._catalog, self._drvd_schema,
                                                 "custom_geofence_h3")
    self.network_df = self.get_data_from_db(self._catalog, self._drvd_schema,
                                            "events_agg_bkp1") 
    self.prizm_df = self.get_data_from_db(self._catalog, self._environics_schema,
                                          "prizm5_uniquelicenses_2023")
    self.ad_camp_id = self.get_data_from_db(self._catalog, self._rawstd_schema, 
                                            "ad_campaigns_config",
                                            f"WHERE campaign_name = '{self.camp_name}' AND status = 'N'",
                                            "DISTINCT config_id")
    self.distinct_config_ids = [row['config_id'] for row in self.ad_camp_id.collect()]
    self.campaign_id = self.distinct_config_ids[0]
    # ignore columns while generating checksum column
    self.ignore_column_list_md5=['_checksum', '_az_insert_ts', '_az_update_ts', "exec_run_id"]

  def get_data_from_db(self, catalog, schema, table, where_clause="", selection="*"):
    """
    Description: Retrieves data from a specified database table.
    Parameters:
    - catalog : The name of the catalog from which to retrieve the data.
    - schema : The schema within the catalog where the table resides.
    - table : The name of the table from which to retrieve data.
    - where_clause : A condition to filter the data. Defaults to an empty string, which means no filter is applied.
    - selection : The columns to be retrieved, either as a list of column names or 
      '*' for all columns. Defaults to '*'.
    Return value: DataFrame.
    """
    return spark.sql(f"SELECT {selection} FROM {catalog}.{schema}.{table} {where_clause}")
    
  def insert_campaign_status(self, camp_name, campaign_id, execution_ids, iteration, duration,ntb_start_time, ntb_end_time,  exec_status, func_name, message, extraction, _az_insert_ts, _az_update_ts, checksum_col, exec_run_id_col):
    """
    Description: Inserts campaign computation status into the `adv_campaign_computation_window_status` table.
    Parameters:
    - camp_name : The name of the campaign.
    - campaign_id : The unique ID of the campaign.
    - execution_ids : The execution identifier associated with the current execution.
    - iteration : The current iteration number for the campaign execution.
    - duration : The duration of the notebook execution.
    - ntb_start_time: The start time of the notebook execution.
    - ntb_end_time : The end time of the notebook execution.
    - exec_status : The status of the execution.
    - func_name : The name of the function that failed, if any.
    - message : An error message
    - extraction : Indicates whether data was extracted.
    - _az_insert_ts : Timestamp of when the record was inserted.
    - _az_update_ts : Timestamp of the last update to the record.
    - checksum_col : The checksum value for the campaign data.
    - exec_run_id_col : The execution run identifier.
    Return value:  None
    """
    spark.sql(f"""
        INSERT INTO {self._catalog}.{self._drvd_schema}.adv_campaign_computation_window_status
        (Campaign_Name, Campaign_ID, Execution_ID, Iteration, Duration, Start_Time, 
            End_Time, Execution_Status, Failed_Function, Error_Message, Extraction_To_FS,
            _az_insert_ts, _az_update_ts, _checksum, _exec_run_id)
        VALUES
        ('{self.camp_name}', '{self.campaign_id}', '{self.execution_ids}', '{self.iteration}', '{duration}',
         '{self.ntb_start_time}', '{ntb_end_time}', '{self.exec_status}', '{func_name}', 
         '{message}', '{extraction}', '{_az_insert_ts}', '{_az_update_ts}', '{checksum_col}',
         '{exec_run_id_col}')
    """)

  def update_campaign_status(self, campaign_id, duration, ntb_end_time, exec_status, func_name, message, extraction, execution_ids, _az_update_ts, checksum_col):
    '''
    Description: Updates an existing campaign computation status record in the `adv_campaign_computation_window_status` table.
    Parameters:
    - campaign_id : Unique identifier for the campaign.
    - duration : Updated duration of the notebook run.
    - ntb_end_time : Updated end time of the campaign.
    - exec_status : Updated execution status .
    - func_name : Name of the function that failed, if applicable.
    - message : Error message.
    - extraction : Updated status of extraction.
    - execution_ids : ID related to the execution of the campaign.
    - _az_update_ts : Timestamp for when the record is updated.
    - checksum_col : Updated MD5 checksum value.
    Return value: None
    '''
    spark.sql(f"""
        UPDATE {self._catalog}.{self._drvd_schema}.adv_campaign_computation_window_status
        SET Duration = '{duration}',
            End_Time = '{ntb_end_time}',
            Execution_Status = '{self.exec_status}',
            Failed_Function = '{func_name}',
            Error_Message = '{message}',
            Extraction_To_FS = '{extraction}',
            _az_update_ts = '{_az_update_ts}',
            _checksum = '{checksum_col}'
        WHERE Campaign_ID = '{self.campaign_id}' AND Execution_ID = '{self.execution_ids}' AND Execution_Status = 'Processing'
    """)    
      
  def adv_campaign_computation_window_status(self, campaign_id, camp_name, execution_ids, iteration, duration, exec_status, ntb_start_time, ntb_end_time, func_name, message, extraction, _az_insert_ts, _az_update_ts, checksum_col, exec_run_id_col):
    """
    Description: Updates or inserts the status of an advertising campaign based on its current execution status.
    - If the campaign's execution status is "Processing", the function calls `insert_campaign_status` otherwise `update_campaign_status`.
    Parameters:
    - campaign_id : Unique identifier for the campaign.
    - camp_name : Name of the campaign.
    - execution_ids : Unique identifiers for the execution.
    - iteration : The current iteration of the execution.
    - duration : The duration of the execution in seconds.
    - exec_status : The current execution status of the campaign.
    - ntb_start_time : Start time of the notebook.
    - ntb_end_time : End time of the notebook.
    - func_name : Name of the function that is failed if any.
    - message : error message.
    - extraction : Status about the data extraction process.
    - _az_insert_ts : Timestamp when the record is inserted.
    - _az_update_ts : Timestamp when the record is last updated.
    - checksum_col : Updated MD5 checksum value.
    - exec_run_id_col : Execution run ID column. 
    Return value: None   
    """
    if self.exec_status == "Processing":
        self.insert_campaign_status(camp_name, campaign_id, execution_ids, iteration, duration,ntb_start_time, ntb_end_time,  exec_status, func_name, message, extraction, _az_insert_ts, _az_update_ts, checksum_col, exec_run_id_col)
    else:
        self.update_campaign_status(campaign_id, duration, ntb_end_time, exec_status, func_name, message, extraction, execution_ids, _az_update_ts, checksum_col)

  def sent_responce_to_parent(self, func_name, message):
    '''
    Description : Records the status of a campaign computation by inserting or updating a record in a Spark SQL table, including campaign ID, name, iteration, duration, start and end time of notebook, execution status, function name, and error message.
    Parameters : 
    - func_name: Name of the function
    - message: Error occurred if function failed or status message
    Return value : None
    '''
    ntb_end_time = datetime.datetime.now()
    duration = ntb_end_time - self.ntb_start_time
    
    if message == "Processing":
        message = "" 
        self.exec_status = "Processing"
        extraction = "Processing is ongoing."
    elif message:
        self.exec_status = "Failed"
        extraction = "Can not be extracted due to failure."
    else:
        self.exec_status = "Success"
        extraction = "yet to start."

    _az_insert_ts = datetime.datetime.now()
    _az_update_ts = datetime.datetime.now()

    schema = StructType([
        StructField("campaign_id", StringType(), False),
        StructField("camp_name", StringType(), False),
        StructField("execution_ids", StringType(), False),
        StructField("iteration", IntegerType(), False),
        StructField("duration", StringType(), False),
        StructField("exec_status", StringType(), False),
        StructField("error_message", StringType(), True),
        StructField("extraction", StringType(), True),
        StructField("func_name", StringType(), False),
        StructField("_az_insert_ts", TimestampType(), False),
        StructField("_az_update_ts", TimestampType(), False),
        StructField("ntb_start_time", TimestampType(), False),
        StructField("ntb_end_time", TimestampType(), False)
    ])

    Adv_window_status_df = spark.createDataFrame([(self.campaign_id, self.camp_name, self.execution_ids, 
                                self.iteration, str(duration), self.exec_status, message, extraction,
                                func_name, _az_insert_ts, _az_update_ts, self.ntb_start_time,
                                ntb_end_time)], schema)


    Adv_window_checksum_df = Adv_window_status_df.withColumn("checksum", md5(concat_ws("||", 
            lit(self.campaign_id), lit(self.camp_name), lit(self.execution_ids), lit(self.iteration),
            lit(duration), lit(self.exec_status), lit(message), lit(extraction),
            lit(func_name), lit(_az_insert_ts), lit(_az_update_ts),
            lit(self.ntb_start_time), lit(ntb_end_time)
        )))
    # Extract the computed checksum value from the DataFrame
    checksum_col = Adv_window_checksum_df.select("checksum").collect()[0][0]
    exec_run_id_col = self.execution_ids

    self.adv_campaign_computation_window_status(self.campaign_id, self.camp_name, self.execution_ids, self.iteration, duration, self.exec_status, self.ntb_start_time, ntb_end_time, func_name, message, extraction, _az_insert_ts, _az_update_ts, checksum_col, exec_run_id_col)

  def concat_checksum_cols(self,df):
    '''
    Description : This filters columns in a DataFrame by excluding those listed in ignore_column_list_md5
    Parameters : dataframe
    Return value : columns list(that are not in the ignore_column_list_md5)
    '''
    bizColList= [col for col in df.columns if (col not in self.ignore_column_list_md5)]
    columnList = []

    for column in bizColList:
        if column is None:
            columnList.append(':')
        else:
            columnList.append(column)

    return columnList

  def process_safegraph(self, safegraph_df, config_df):
    '''
    Description : This function joins DataFrames (safegraph_df and config_df) on matching location and address fields, filters for non-null POI location names, and selects relevant columns.
    Parameters : 
    - safegraph_df: safegraph_poi_h3 dataframe
    - config_df: advertisement_campaign config dataframe
    Return value : dataframe
    
    '''
    filtered_safegraph = (
        safegraph_df
        .join(config_df, [
            safegraph_df.location_name == config_df.poi_loc_name,
            safegraph_df.street_address == config_df.street_addr,
            safegraph_df.city == config_df.city,
            safegraph_df.province == config_df.province,
            safegraph_df.postal_code == config_df.postal_code,
            safegraph_df.top_category == config_df.top_category,
            safegraph_df.sub_category == config_df.sub_category
        ], "left")
        .filter(config_df.poi_loc_name.isNotNull())
        .select(
            "location_name", "location_id", "brands", "latitude", "longitude", "location_perimeter",
            lit(None).alias("location_radius"), "street_address", safegraph_df["city"], safegraph_df["province"], safegraph_df["postal_code"],
            safegraph_df["top_category"], safegraph_df["sub_category"], safegraph_df["category_tags"], "opened_on", "closed_on",
            "iso_country_code", "naics_code", "census_code",
            "hexagon_id", "cellid", "site_name", "sitecode", "opened_no_later_than", "tracking_closed_since",
            "eff_to", "eff_from", "hexagon_wkt"
        )
    )
    return filtered_safegraph

  def process_custom_geofence(self, custom_geofence_df, config_df):
    '''
    Description : This function filters a DataFrame of custom geofence data by joining it with a configuration DataFrame.
    Parameters : 
    - custom_geofence_df: custom_geofence_poi_h3 dataframe
    - config_df: advertisement_campaign config dataframe
    Return value : dataframe
    '''
    filtered_custom_geofence = (
        custom_geofence_df
        .join(config_df, [
            custom_geofence_df.location_name == config_df.poi_loc_name,
            custom_geofence_df.street_address == config_df.street_addr,
            custom_geofence_df.city == config_df.city,
            custom_geofence_df.province == config_df.province,
            custom_geofence_df.postal_code == config_df.postal_code,
            custom_geofence_df.top_category == config_df.top_category,
            custom_geofence_df.sub_category == config_df.sub_category
        ] + [
            custom_geofence_df.category_tags == config_df.category_tags
        ], "left")
        .filter(config_df.poi_loc_name.isNotNull())
        .select(
            "location_name", "location_id", "brands", "latitude", "longitude", "location_perimeter",
            lit(None).alias("location_radius"), "street_address", custom_geofence_df["city"], custom_geofence_df["province"], custom_geofence_df["postal_code"],
            custom_geofence_df["top_category"], custom_geofence_df["sub_category"], custom_geofence_df["category_tags"], "opened_on", lit(None).alias("closed_on"), "iso_country_code",
            lit(None).alias("naics_code"), lit(None).alias("census_code"), "hexagon_id",
            "cellid", "site_name", "sitecode", lit(None).alias("opened_no_later_than"),
            lit(None).alias("tracking_closed_since"), lit(None).alias("eff_to"), lit(None).alias("eff_from"),
            "hexagon_wkt"
        )
    )
    return filtered_custom_geofence

  def process_advertiser(self, advertiser_df, config_df):
    '''
    Description : This function selects and filters relevant columns of advertiser data and config_df.
    Parameters : 
    - advertiser_df: advertiser_poi_h3 dataframe
    - config_df: advertisement_campaign config dataframe
    Return value : dataframe
    '''
    filtered_advertiser = (
        advertiser_df
        .join(config_df, [
            advertiser_df.location_name == config_df.poi_loc_name,
            advertiser_df.street_address == config_df.street_addr,
            advertiser_df.city == config_df.city,
            advertiser_df.province == config_df.province,
            advertiser_df.postal_code == config_df.postal_code,
            advertiser_df.top_category == config_df.top_category,
            advertiser_df.sub_category == config_df.sub_category
        ] + [
            advertiser_df.category_tags == config_df.category_tags
        ], "left")
        .filter(config_df.poi_loc_name.isNotNull())
        .select(
            "location_name", "location_id", "brands", "latitude", "longitude", "location_perimeter",
            "location_radius", "street_address", advertiser_df["city"], advertiser_df["province"], advertiser_df["postal_code"], advertiser_df["top_category"],
            advertiser_df["sub_category"], advertiser_df["category_tags"], "opened_on", lit(None).alias("closed_on"), "iso_country_code", lit(None).alias("naics_code"),
            lit(None).alias("census_code"), "hexagon_id", "cellid", "site_name",
            "sitecode", lit(None).alias("opened_no_later_than"), lit(None).alias("tracking_closed_since"),
            lit(None).alias("eff_to"), lit(None).alias("eff_from"), "hexagon_wkt"
        )
    )
    return filtered_advertiser
  
  def combine_results(self, filtered_safegraph, filtered_custom_geofence, filtered_advertiser):
    '''
    Description : This function merges three DataFrames by performing a union operation.
    Parameters : 
    - filtered_safegraph: safegraph_poi_h3 dataframe
    - filtered_advertiser: advertiser_poi_h3 dataframe
    - filtered_custom_geofence: custom_geofence_poi_h3 dataframe 
    Return value : dataframe 
    '''
    combined_df = (
        filtered_safegraph
        .union(filtered_custom_geofence)
        .union(filtered_advertiser)
    )
    return combined_df

  def adv_poi_geofence_union(self, result_df):
    """
    Description: Inserts data from the provided DataFrame into the target table `adv_poi_geofence_union`.
    Parameters:
    - result_df: DataFrame with data to be inserted. 
    Returns: None
    """
    try:
      result_df.createOrReplaceTempView("result_df")
      spark.sql(f"""
          INSERT INTO {self._catalog}.{self._drvd_schema}.adv_poi_geofence_union
                (location_name, location_id, brands, latitude, longitude, location_radius, location_perimeter, street_address, city, province, postal_code, top_category, sub_category, category_tags, opened_on, opened_no_later_than, tracking_closed_since, closed_on, iso_country_code, naics_code, census_code, hexagon_id, hexagon_wkt, cellid, site_name, sitecode, eff_from, eff_to, _az_insert_ts, _az_update_ts, _checksum, _exec_run_id)
          SELECT 
                location_name, location_id, brands, latitude, longitude, location_radius, location_perimeter, street_address, city, province, postal_code, top_category, sub_category, category_tags, opened_on, opened_no_later_than, tracking_closed_since, closed_on, iso_country_code, naics_code, census_code, hexagon_id, hexagon_wkt, cellid, site_name, sitecode, eff_from, eff_to, _az_insert_ts, _az_update_ts, _checksum, _exec_run_id 
          FROM result_df
        """)
    except Exception as e:
        raise e

  def safegraph_advertiser_custom_union(self):  
    '''
    Description : Union of all POI (safegraph, custom_geofence, and advertiser).
    Parameters : None
    Return value : DataFrame
    '''  
    try:
        # initially the exec_status = 'Processing' for monitoring in adv_campaign_window_status table 
        self.sent_responce_to_parent('', 'Processing')
        self.config.createOrReplaceTempView('config')
        self.safegraph.createOrReplaceTempView('safegraph')
        self.advertiser.createOrReplaceTempView("advertiser")
        self.custom_geofence.createOrReplaceTempView("custom_geofence") 

        filtered_safegraph = self.process_safegraph(self.safegraph, self.config)
        filtered_custom_geofence = self.process_custom_geofence(self.custom_geofence, self.config)
        filtered_advertiser = self.process_advertiser(self.advertiser, self.config)
        
        result_df = self.combine_results(filtered_safegraph, filtered_custom_geofence, filtered_advertiser)
        
        exec_run_id_col = self.execution_ids
        result_df = result_df.withColumn("_az_insert_ts", current_timestamp()) \
                                  .withColumn("_az_update_ts", current_timestamp()) \
                                  .withColumn("_checksum", md5(concat_ws("||", *self.concat_checksum_cols(result_df)))) \
                                  .withColumn("_exec_run_id", lit(exec_run_id_col))
        # calling adv_poi_geofence_union function to insert data into table
        self.adv_poi_geofence_union(result_df)
        return result_df
    except Exception as e:
        self.sent_responce_to_parent('safegraph_advertiser_custom_union', e.getErrorClass())

  def join_union_network_data(self, network_df, poi_union_df):
    '''
    Description :  Joins result of safegraph_advertiser_custom_union with network data.
    Parameters : 
    - network_df: network DataFrame
    - poi_union_df: safegraph_advertiser_custom_union DataFrame 
    Return value : DataFrame
    '''
    try:
        self.network_df.createOrReplaceTempView("network_data_df")
        poi_union_df.createOrReplaceTempView("poi_union_df")

        query = f"""
            SELECT
                network_data.*,
                p.location_name, p.location_id,
                p.brands, p.latitude, p.longitude,
                p.location_perimeter, p.location_radius, p.street_address, p.city,
                p.province, p.postal_code, p.top_category, p.sub_category, p.category_tags,
                p.opened_on, p.closed_on, p.iso_country_code, p.naics_code, p.census_code,
                p.hexagon_id, p.cellid,
                p.opened_no_later_than, p.tracking_closed_since, p.eff_to, p.eff_from, p.hexagon_wkt
            FROM network_data_df AS network_data
            INNER JOIN poi_union_df AS p
            ON ((network_data.eci = p.cellid OR network_data.nci = p.cellid))
        """
        
        joined_df = spark.sql(query).distinct()
        return joined_df
    
    except Exception as e:
        self.sent_responce_to_parent('join_union_network_data',e.getErrorClass())
  
  def filter_network_data_with_campaign_and_billing(self, network_data_df):
    '''
    Description : Filters network data based on both campaign and billing configurations. 
    Parameters : 
    - network_data_df: network DataFrame
    Return value : DataFrame
    '''

    final_df_list = []

    network_data_df = network_data_df.withColumn('event_date', to_date(col('event_timestamp'), 'yyyy-MM-dd').cast(StringType()))
    network_data_df = network_data_df.withColumn('event_day', lower(date_format(col('event_timestamp'), 'EEEE')).cast(StringType()))
    network_data_df = network_data_df.withColumn('event_time_24h', date_format(col('event_timestamp'), 'HH:mm:ss'))
    
    network_data_df.createOrReplaceTempView("network_df")
    self.config.createOrReplaceTempView("campaign_billing_config")

    try:
        for row in self.config.collect():
            # Convert 'start_date' and 'end_date' from string to date format
            start_date = datetime.datetime.strptime(row['start_date'], '%Y-%m-%d').strftime('%Y-%m-%d')
            end_date = datetime.datetime.strptime(row['end_date'], '%Y-%m-%d').strftime('%Y-%m-%d')

            # Parse 'days_list' from string to a list of lowercase day names
            days_tuple = ast.literal_eval(row['days_list'])
            days_list = [day.strip().lower() for day in days_tuple]
            days_list_str = ', '.join(f"'{day}'" for day in days_list)

            if row['province_ctn'] is None or row['province_ctn'] == "":
               # If 'province_ctn' is None or empty, use 'LIKE "%_%"' in the SQL query to match all province codes.
               province_ctn_clause = ' LIKE "%_%" '
            else:
               # If 'province_ctn' has specific values, use an SQL 'IN' clause to filter records by those values.
               province_ctn_clause = " IN " + row['province_ctn']

            if row['city_ctn'] is None or row['city_ctn'] == "":
               city_ctn_clause = ' LIKE "%_%" '
            else:
               city_ctn_clause = " IN " + row['city_ctn']
            
            if row['postal_code_ctn'] is None or row['postal_code_ctn'] == "":
               postal_ctn_clause = ' LIKE "%_%" '
            else:
               postal_ctn_clause = " IN " + row['postal_code_ctn']

            # Parse and format time window start and end times
            time_window_selection = row['time_window_selection']
            time_window_start_str, time_window_end_str = time_window_selection.split('-')
            time_window_start_dt = datetime.datetime.strptime(time_window_start_str.strip(), '%I:%M %p')
            time_window_end_dt = datetime.datetime.strptime(time_window_end_str.strip(), '%I:%M %p')
            time_window_start = time_window_start_dt.strftime('%I:%M:%S %p')
            time_window_end = time_window_end_dt.strftime('%I:%M:%S %p')
            time_window_start_24hr = datetime.datetime.strptime(time_window_start, '%I:%M:%S %p').strftime('%H:%M:%S')
            time_window_end_24hr = datetime.datetime.strptime(time_window_end, '%I:%M:%S %p').strftime('%H:%M:%S')

            query = f"""
                SELECT nw.*,
                    cc.campaign_name, cc.start_date, cc.end_date, cc.week, cc.month, cc.zone_name, cc.custom_geofence_name,
                    cc.poi_loc_name, cc.city AS c_city, cc.street_addr, cc.province AS c_province, cc.postal_code AS c_postal_code,
                    cc.top_category AS c_top_category, cc.sub_category AS c_sub_category, cc.category_tags AS c_category_tags,
                    nw.ci_province_code AS province_ctn, nw.ci_city_name AS city_ctn, nw.ci_postal_code AS postal_code_ctn,
                    cc.config_id as campaign_id, cc.time_window_selection, cc.days_list
                FROM network_df nw INNER JOIN campaign_billing_config cc
                ON
                (nw.location_name = cc.poi_loc_name OR nw.location_name = cc.custom_geofence_name)
                AND nw.street_address = cc.street_addr 
                AND nw.city = cc.city
                AND nw.province = cc.province
                AND nw.postal_code = cc.postal_code
                WHERE
                nw.event_date BETWEEN '{start_date}' AND '{end_date}' 
                AND nw.event_time_24h BETWEEN '{time_window_start_24hr}' AND '{time_window_end_24hr}'
                AND nw.event_day IN ({days_list_str})
                AND nw.ci_province_code{province_ctn_clause}
                AND nw.ci_city_name{city_ctn_clause}
                AND nw.ci_postal_code{postal_ctn_clause}       
            """

            final_df = spark.sql(query).distinct()
            final_df_list.append(final_df)
        
        if final_df_list:
            combined_df = final_df_list[0]
            for df in final_df_list[1:]:
                combined_df = combined_df.union(df)
            
            return combined_df.distinct()
        else:
           return []
    
    except Exception as e:
        self.sent_responce_to_parent('filter_network_data_with_campaign_and_billing', e.getErrorClass())

# TO DO (Prizm data is not available)
  def filter_prizm_data(self):
    '''
    Description :  Filters PRIZM data based on postal codes from the configuration.
    Parameters : None
    Return value : DataFrame
    '''     
    try: 
        if row['postal_code_ctn'] is None or row['postal_code_ctn'] == "":
               postal_code_ctn_clause = ' LIKE "%_%" '
        else:
               postal_code_ctn_clause = " IN " + row['postal_code_ctn']

        if row['prizm_segment'] is None or row['prizm_segment'] == "":
               prizm_segment_clause = ' LIKE "%_%" '
        else:
               prizm_segment_clause = " IN " + row['prizm_segment']

        self.prizm_df = self.prizm_df.withColumn("clean_FSALDU", replace(col("FSALDU"), " ", ""))
        self.config.createOrReplaceTempView("config")
        self.prizm_df.createOrReplaceTempView("prizm")
        query = """
        SELECT p.*
        FROM prizm p
        INNER JOIN (
            SELECT DISTINCT postal_code_ctn{postal_code_ctn_clause} AS postal_code 
            AND prizm_segment{prizm_segment_clause} as prizm_segment
            FROM config
        ) c
        ON p.clean_FSALDU = c.postal_code
        AND p.NAME = c.prizm_segment
        """
        filtered_prizm_df = spark.sql(query).distinct()
        return filtered_prizm_df
    except Exception as e:
        self.sent_responce_to_parent('filter_prizm_data', e.getErrorClass())

# TO DO (Prizm data is not available)
  def join_network_prizm_data(self, network_df, prizm_df):
    '''
    Description :  Joins network data with PRIZM data.
    Parameters : 
    - network_df: Network DataFrame
    - filter_prizm: Prizm DataFrame
    Return value : DataFrame
    '''
    try:
        network_df.createOrReplaceTempView("network_df")
        prizm_df.createOrReplaceTempView("prizm_df")

        query = """
            SELECT
                network_df_n.*,
                p.LSNAME,
                p.PRIZM,
                p.FSALDU,
                p.NAME as prizm_segment
            FROM network_df AS network_df_n
            INNER JOIN prizm_df AS p
            ON network_df_n.ci_postal_code = p.FSALDU
        """
        joined_df = spark.sql(query).distinct()
        return joined_df
    except Exception as e:
        self.sent_responce_to_parent('join_network_prizm_data', e.getErrorClass())

  def adv_campaign_computation_window(self, computation_granularity_data):
    """
    Description: Inserts computation data into the `adv_campaign_computation_window` table.
    Parameters:
    - computation_granularity_data: DataFrame.
    Return value: None
    """
    try:
        computation_granularity_data.createOrReplaceTempView("computation_granularity_data")
        spark.sql(f"""
            INSERT INTO {self._catalog}.{self._drvd_schema}.adv_campaign_computation_window
              (campaign_id, campaign_name, start_date, end_date, time_window_selection, days_list, 
              location_name, location_id, location_perimeter, longitude, latitude, street_address, city, postal_code, province, msisdn, ctn, ws_subscriber_no, event_date, event_time_24h, event_day, event_timestamp, province_ctn, city_ctn, postal_code_ctn, _az_insert_ts, _az_update_ts,  _checksum, _exec_run_id)
            SELECT 
                campaign_id, campaign_name, start_date, end_date, time_window_selection, days_list,location_name, location_id, location_perimeter, longitude, latitude,street_address, city, postal_code, province, msisdn, ctn, ws_subscriber_no, event_date, event_time_24h, event_day, event_timestamp, province_ctn, city_ctn, postal_code_ctn, _az_insert_ts, _az_update_ts,  _checksum, _exec_run_id
            FROM computation_granularity_data
        """)
    except Exception as e:
        raise e

  def computation_granularity(self, computation_granularity_data):
    '''
    Description : Transforms computation granularity data for analysis.
    Parameters : 
    - computation_granularity_data: DataFrame
    Return value : DataFrame
    '''
    try:
        columns = computation_granularity_data.columns
        unique_columns = []
        [unique_columns.append(item) for item in columns if item not in unique_columns]

        computation_granularity_data = computation_granularity_data.select(unique_columns)
        
        computation_granularity_data.createOrReplaceTempView("computation_granularity_data")

        query_ = """
            SELECT 
                campaign_id, campaign_name, start_date, end_date, time_window_selection, days_list, location_name,
                location_perimeter, street_address, city, postal_code, province, msisdn, ctn, ws_subscriber_no, event_date,
                event_time_24h, event_day, event_timestamp, NULL AS prizm_segment, province_ctn, city_ctn, postal_code_ctn, longitude, latitude, location_id
            from computation_granularity_data
            """

        computation_granularity_data = spark.sql(query_).distinct()

        #TODO: 
        # static_prizm_segment_value = 'DefaultSegmentValue'
        # computation_granularity_data = computation_granularity_data.withColumn(
        #                                     "prizm_segment",lit(static_prizm_segment_value))
        
        #TODO: 
        _exec_run_id = self.execution_ids
        computation_granularity_data = computation_granularity_data.withColumn("_az_insert_ts", current_timestamp()) \
                                  .withColumn("_az_update_ts", current_timestamp()) \
                                  .withColumn("_checksum", md5(concat_ws("||", *self.concat_checksum_cols(computation_granularity_data)))) \
                                  .withColumn("_exec_run_id", lit(_exec_run_id))
        # calling adv_campaign_computation_window function to insert data into table
        self.adv_campaign_computation_window(computation_granularity_data)
        return computation_granularity_data
    
    except Exception as e:
        self.sent_responce_to_parent('computation_granularity', e.getErrorClass())
  
  def get_ctn_prev_visit_time(self, computation_granularity):
    '''
    Description : Calculates the previous visit time for each visitor (ctn) separately within each day partitions by both location and date-related columns. 
    Parameters : 
    - computation_granularity_data: DataFrame
    Return value : DataFrame
    '''
    computation_granularity.createOrReplaceTempView("df_filtered")
    df = spark.sql("""
                    SELECT 
                        location_name, location_id, street_address, latitude, longitude, city, province, postal_code, ctn, event_day, event_date, event_time_24h, event_timestamp,
                        LAG(event_timestamp) OVER (PARTITION BY location_name, location_id, street_address, latitude, longitude, city, province, postal_code, ctn, event_date, event_day ORDER BY event_timestamp) AS prev_visit_time
                    FROM df_filtered
                    """)
    return df

  def get_ctn_total_prev_visit_time(self, computation_granularity):
    '''
    Description : Calculates the previous visit time across different days for each visitor partitions by only location and visitor.
    Parameters : 
    - computation_granularity_data: DataFrame
    Return value : DataFrame
    '''
    computation_granularity.createOrReplaceTempView("df_filtered")
    df = spark.sql("""
                    SELECT 
                        location_name, location_id, street_address, latitude, longitude, city, province, postal_code, ctn, event_day, event_date, event_time_24h, event_timestamp,
                        LAG(event_timestamp) OVER (PARTITION BY location_name, location_id, street_address, latitude, longitude, city, province, postal_code, ctn ORDER BY event_timestamp) AS prev_visit_time
                    FROM df_filtered
                    """)
    return df

  def adv_campaign_window_result(self, campaign_window_data):
    """
    Description: Inserts data from the provided DataFrame into the target table `adv_campaign_window_result`.
    Parameters:
    - campaign_window_data: DataFrame with data to be inserted. 
    Returns: None
    """
    try:
        campaign_window_data.createOrReplaceTempView("campaign_window_data")
        spark.sql(f"""
            INSERT INTO {self._catalog}.{self._drvd_schema}.adv_campaign_window_result
              (campaign_id, campaign_name, event_day, event_date, location_name, location_id, location_perimeter, street_address, longitude, latitude, city, province, postal_code, 
              no_of_visits, unique_visitors, time_spent, _az_insert_ts, _az_update_ts,  _checksum, _exec_run_id)
            SELECT 
                campaign_id, campaign_name, event_day, event_date, location_name, location_id, location_perimeter, street_address, longitude, latitude, city, province, postal_code,
                no_of_visits, unique_visitors, time_spent, _az_insert_ts, _az_update_ts,  _checksum, _exec_run_id
            FROM campaign_window_data
        """)
    except Exception as e:
        raise e

  def daily_computation_metrics(self, computation_granularity):
      '''
      Description : The function processes daily metrics by calculating previous visit times, aggregating time_spent, no_of_visits, and unique visitors.
      Parameters : 
      - computation_granularity: DataFrame
      Return value : DataFrame
      '''
      try:
          
          days_to_process = [row.event_day for row in computation_granularity.select("event_day").distinct().collect()]
          final_result = None
          computation_granularity.createOrReplaceTempView("df_filtered")
          self.network_df.createOrReplaceTempView("network_event")
          df = self.get_ctn_prev_visit_time(computation_granularity)
          df.createOrReplaceTempView("temp_table")

          # Select previous visit times and compute minutes difference, considering various conditions
          df = spark.sql("""
                          SELECT 
                              location_name, location_id, street_address, latitude, longitude, city, province, postal_code, ctn, event_day, event_time_24h, event_date, event_timestamp, prev_visit_time,
                              CASE
                                  WHEN prev_visit_time IS NOT NULL AND (UNIX_TIMESTAMP(event_timestamp) - UNIX_TIMESTAMP(prev_visit_time))/60 < 15 THEN 
                                      GREATEST((UNIX_TIMESTAMP(event_timestamp) - UNIX_TIMESTAMP(prev_visit_time))/60, 0)
                                  ELSE
                                    CASE
                                        WHEN ((UNIX_TIMESTAMP(event_timestamp) - UNIX_TIMESTAMP(prev_visit_time)) / 60 > 15
                                            AND (SELECT COUNT(*) FROM network_event AS ne WHERE ne.event_timestamp > prev_visit_time AND ne.event_timestamp < event_timestamp AND ne.ctn = ctn) > 0)
                                        THEN 0.0
                                        ELSE GREATEST((UNIX_TIMESTAMP(event_timestamp) - UNIX_TIMESTAMP(prev_visit_time))/60, 0)
                                    END
                              END AS minutes_diff    
                          FROM temp_table
                         """)
          df.createOrReplaceTempView("temp_table1")

          # Aggregate total minutes spent with sum of minutes_diff for location
          df = spark.sql("""
                          SELECT 
                              location_name, location_id, street_address, latitude, longitude, city, province, postal_code, ctn, event_day, event_date,
                              SUM(minutes_diff) AS total_minutes_spent
                          FROM temp_table1
                          GROUP BY location_name, location_id, street_address, latitude, longitude, city, province, postal_code, ctn, event_day, event_date
                         """)
          df.createOrReplaceTempView("time1")

          # Aggregate total minutes spent with sum of total_minutes_spent for location
          df = spark.sql("""
                          SELECT 
                              location_name, location_id, street_address, latitude, longitude, city, province, postal_code, event_date, event_day,
                              SUM(total_minutes_spent) AS total_minutes_spent
                          FROM time1
                          GROUP BY location_name, location_id, street_address, latitude, longitude, city, province, postal_code, event_date, event_day
                         """)
          df.createOrReplaceTempView("total_time_spent")

          # Calculate visit counts with a flag for new visits
          df_count = spark.sql("""
                               WITH CTE AS (
                                        SELECT
                                            *,
                                            LAG(ctn) OVER (PARTITION BY location_name, location_id, street_address, latitude, longitude, city, province, postal_code, ctn, event_date, event_day ORDER BY event_timestamp) AS prev_ctn,
                                            LAG(event_timestamp) OVER (PARTITION BY location_name, location_id, street_address, latitude, longitude, city, province, postal_code, ctn, event_date, event_day ORDER BY event_timestamp) AS prev_visit_time
                                        FROM df_filtered
                                    )
                                    SELECT
                                        *,
                                        CASE
                                            WHEN ctn != prev_ctn
                                                OR prev_ctn IS NULL
                                            THEN 1
                                            ELSE 
                                                CASE
                                                    WHEN ((UNIX_TIMESTAMP(event_timestamp) - UNIX_TIMESTAMP(prev_visit_time)) / 60 > 5
                                                        AND (SELECT COUNT(*) FROM network_event AS ne 
                                                        WHERE ne.event_timestamp > prev_visit_time 
                                                        AND ne.event_timestamp < event_timestamp 
                                                        AND ne.ctn = ctn) > 0)
                                                    THEN 1
                                                    ELSE 0
                                                END
                                        END AS new_visit_flag
                                    FROM CTE;
                               """)
          df_count.createOrReplaceTempView("visit_count")

          # Compute the number of visits per location
          df_visit = spark.sql("""
                               SELECT
                                    *,
                                    SUM(new_visit_flag) OVER (PARTITION BY location_name, location_id, street_address, latitude, longitude, city, province, postal_code, campaign_id, campaign_name 
                                    ORDER BY event_time_24h ROWS UNBOUNDED PRECEDING) AS no_of_visits
                                FROM visit_count
                               """)
          df_visit.createOrReplaceTempView("visit_number")
          
          # Count unique visitors per location
          df_unique_visits = spark.sql("""
                                       SELECT
                                            location_name, location_id, street_address, latitude, 
                                            longitude, city, province, postal_code,
                                            campaign_id, campaign_name, event_date, event_day,
                                            COUNT(DISTINCT ctn) AS unique_visitor
                                        FROM df_filtered
                                        GROUP BY location_name, location_id, street_address, latitude, longitude, city, province, postal_code, campaign_id, campaign_name, event_date, event_day
                                        
                                       """)
          df_unique_visits.createOrReplaceTempView("unique_visits")

          # Combine metrics for final output which includes no_of_visits, unique_visitors, time_spent
          final_record = spark.sql("""
                                   SELECT
                                        vn.location_name, vn.location_id, vn.street_address, vn.latitude,
                                        vn.longitude, vn.city, vn.province, vn.postal_code,
                                        vn.location_perimeter, vn.event_timestamp, vn.campaign_id, 
                                        vn.campaign_name,vn.event_day, vn.event_date, vn.no_of_visits, 
                                        uv.unique_visitor,
                                        tms.total_minutes_spent AS time_spent
                                    FROM visit_number vn
                                    JOIN unique_visits uv ON vn.location_name = uv.location_name 
                                    AND vn.location_id = uv.location_id
                                    AND vn.street_address = uv.street_address
                                    AND vn.latitude = uv.latitude
                                    AND vn.longitude = uv.longitude
                                    AND vn.city = uv.city AND vn.province=uv.province
                                    AND vn.postal_code=uv.postal_code AND vn.campaign_id = uv.campaign_id
                                    AND vn.campaign_name = uv.campaign_name
                                    JOIN total_time_spent tms ON vn.location_name = tms.location_name
                                    AND vn.street_address = tms.street_address AND vn.city = tms.city
                                    AND vn.province=tms.province AND vn.postal_code=tms.postal_code
                                    ORDER BY vn.location_name, vn.location_id, vn.street_address, vn.latitude,
                                    vn.longitude, vn.city,
                                    vn.province, vn.postal_code, vn.campaign_id, vn.campaign_name, vn.no_of_visits, vn.ctn
                                   """)
          final_record.createOrReplaceTempView("daily_computation")

          # Aggregate final metrics to remove duplicates and summarize data
          final_record = spark.sql("""
                                   SELECT
                                        campaign_id, campaign_name, event_day, event_date, 
                                        location_name, location_id, first(location_perimeter) AS location_perimeter, 
                                        street_address, latitude, longitude, city, province, postal_code,
                                        MAX(no_of_visits) AS no_of_visits, 
                                        MAX(unique_visitor) AS unique_visitors,
                                        MAX(time_spent) AS time_spent
                                    FROM daily_computation
                                    GROUP BY location_name, location_id, street_address, latitude, longitude, city, province, postal_code, 
                                    campaign_id, campaign_name, event_day, event_date
                                   """)
          final_record.createOrReplaceTempView("final_record")

          # Get the earliest event timestamp for each group
          first_record = spark.sql("""
                                          SELECT
                                                campaign_id, campaign_name, event_day, event_date,
                                                location_name, location_id, street_address, latitude, longitude, city, province, postal_code,
                                                MIN(event_timestamp) AS event_timestamp
                                            FROM daily_computation
                                            GROUP BY location_name, location_id, street_address, latitude, longitude, city, province, postal_code,
                                            campaign_id, campaign_name, event_day, event_date
                                          """)
          first_record.createOrReplaceTempView("first_record")
          
          _exec_run_id = self.execution_ids

          # Join the final metrics with the earliest event timestamp
          final_result = spark.sql("""
                                   SELECT
                                        fr.*,
                                        fir.event_timestamp
                                    FROM final_record AS fr
                                    JOIN first_record AS fir ON fr.campaign_id = fir.campaign_id
                                    AND fr.campaign_name = fir.campaign_name
                                    AND fr.event_day = fir.event_day
                                    AND fr.event_date = fir.event_date
                                    AND fr.location_name = fir.location_name
                                    AND fr.location_id = fir.location_id
                                    AND fr.street_address = fir.street_address
                                    AND fr.latitude = fir.latitude
                                    AND fr.longitude = fir.longitude
                                    AND fr.city = fir.city
                                    AND fr.province = fir.province
                                    AND fr.postal_code = fir.postal_code
                                """)
          
          final_result = final_result.withColumn("_az_insert_ts", current_timestamp()) \
                                  .withColumn("_az_update_ts", current_timestamp()) \
                                  .withColumn("_checksum", md5(concat_ws("||", *self.concat_checksum_cols(final_result)))) \
                                  .withColumn("_exec_run_id", lit(_exec_run_id))
          # calling adv_campaign_window_result function to insert data into table
          self.adv_campaign_window_result(final_result)
          return final_result
      
      except Exception as e:
          self.sent_responce_to_parent('daily_computation_metrics', e.getErrorClass())
  
  def adv_campaign_total_result(self, campaign_total_data):
    """
    Description: Inserts data from the provided DataFrame into the target table `adv_campaign_total_result`.
    Parameters:
    - campaign_total_data: DataFrame containing the data to be inserted. Must conform to the schema of the target table.
    Returns: None
    """
    try:
        campaign_total_data.createOrReplaceTempView("campaign_total_data")
        spark.sql(f"""
            INSERT INTO {self._catalog}.{self._drvd_schema}.adv_campaign_total_result
                (campaign_id, campaign_name ,location_name , location_id , location_perimeter, street_address, latitude, longitude, city , province, postal_code, no_of_visits,  unique_visitors, time_spent, avg_daily_visits, avg_daily_unique_visitors, avg_daily_time_spent, _az_insert_ts,  _az_update_ts, _checksum, _exec_run_id)
            SELECT 
                campaign_id, campaign_name ,location_name , location_id , location_perimeter, street_address, latitude, longitude, city , province, postal_code, no_of_visits,  unique_visitors, time_spent, avg_daily_visits, avg_daily_unique_visitors, avg_daily_time_spent, _az_insert_ts,  _az_update_ts, _checksum, _exec_run_id
            FROM campaign_total_data
        """)
    except Exception as e:
        raise e

  def total_computation_metrics(self, computation_granularity, daily_df):
      '''
      Description : Calculates total_no_of_visits, total_time_spent, total_unique_visitors, avg_daily_visits, avg_daily_time_spent, avg_daily_unique_visitors.
      Parameters : 
      - computation_granularity: DataFrame
      - daily_df: DataFrame
      Return value : DataFrame
      '''
      try:
          # Collect distinct event days from computation_granularity
          days_to_process = [row.event_day for row in computation_granularity.select("event_day").distinct().collect()]
          final_result = None
          computation_granularity.createOrReplaceTempView("df_filtered")
          df = self.get_ctn_total_prev_visit_time(computation_granularity)
          df.createOrReplaceTempView("temp_table")

          # Compute time differences between visits
          df1 = spark.sql("""
                          SELECT 
                              location_name, location_id, street_address, latitude, longitude, city, province, postal_code, ctn, event_day, event_time_24h, event_date, event_timestamp, prev_visit_time,
                              CASE
                                  WHEN prev_visit_time IS NOT NULL AND (UNIX_TIMESTAMP(event_timestamp) - UNIX_TIMESTAMP(prev_visit_time))/60 < 15 THEN 
                                      GREATEST((UNIX_TIMESTAMP(event_timestamp) - UNIX_TIMESTAMP(prev_visit_time))/60, 0)
                                  ELSE
                                    CASE
                                        WHEN ((UNIX_TIMESTAMP(event_timestamp) - UNIX_TIMESTAMP(prev_visit_time)) / 60 > 15
                                            AND (SELECT COUNT(*) FROM network_event AS ne WHERE ne.event_timestamp > prev_visit_time AND ne.event_timestamp < event_timestamp AND ne.ctn = ctn) > 0)
                                        THEN 0.0
                                        ELSE GREATEST((UNIX_TIMESTAMP(event_timestamp) - UNIX_TIMESTAMP(prev_visit_time))/60, 0)
                                    END
                              END AS minutes_diff    
                          FROM temp_table
                         """)
          df1.createOrReplaceTempView("temp_table1")

          # Aggregate total minutes spent with sum of minutes_diff
          df2 = spark.sql("""
                          SELECT 
                              location_name, location_id, street_address, latitude, longitude, city, province, postal_code, ctn,
                              SUM(minutes_diff) AS total_minutes_spent
                          FROM temp_table1
                          GROUP BY location_name, location_id, street_address, latitude, longitude, city, province, postal_code, ctn
                         """)
          df2.createOrReplaceTempView("time1")

          # Aggregate total minutes spent with sum of total_minutes_spent for location
          df3 = spark.sql("""
                          SELECT 
                              location_name, location_id, street_address, latitude, longitude, city, province, postal_code,
                              SUM(total_minutes_spent) AS total_minutes_spent
                          FROM time1
                          GROUP BY location_name, location_id, street_address, latitude, longitude, city, province, postal_code
                         """)
          df3.createOrReplaceTempView("total_time_spent")
          self.network_df.createOrReplaceTempView("network_event")
       
          # Calculate visit counts with a flag for new visits
          df_count = spark.sql("""
                               WITH CTE AS (
                                        SELECT
                                            *,
                                            LAG(ctn) OVER (PARTITION BY location_name, location_id, street_address, latitude, longitude, city, province, postal_code, ctn ORDER BY event_timestamp) AS prev_ctn,
                                            LAG(event_timestamp) OVER (PARTITION BY location_name, location_id, street_address, latitude, longitude, city, province, postal_code, ctn ORDER BY event_timestamp) AS prev_visit_time
                                        FROM df_filtered
                                    )
                                    SELECT
                                        *,
                                        CASE
                                            WHEN ctn != prev_ctn
                                                OR prev_ctn IS NULL
                                            THEN 1
                                            ELSE 
                                                CASE
                                                    WHEN ((UNIX_TIMESTAMP(event_timestamp) - UNIX_TIMESTAMP(prev_visit_time)) / 60 > 5
                                                        AND (SELECT COUNT(*) FROM network_event AS ne 
                                                        WHERE ne.event_timestamp > prev_visit_time 
                                                        AND ne.event_timestamp < event_timestamp 
                                                        AND ne.ctn = ctn) > 0)
                                                    THEN 1
                                                    ELSE 0
                                                END
                                        END AS new_visit_flag
                                    FROM CTE;
                               """)
          df_count.createOrReplaceTempView("visit_count")

          # Compute the number of visits per location
          df_visit = spark.sql("""
                               SELECT
                                    *,
                                    SUM(new_visit_flag) OVER (PARTITION BY location_name, location_id, street_address, latitude, longitude, city, province, postal_code, campaign_id, campaign_name 
                                    ORDER BY event_time_24h ROWS UNBOUNDED PRECEDING) AS no_of_visits
                                FROM visit_count
                               """)
          df_visit.createOrReplaceTempView("visit_number")
          
          # Count unique visitors per location
          df_unique_visits = spark.sql("""
                                       SELECT
                                            location_name, location_id, street_address, latitude, longitude, 
                                            city, province, postal_code,
                                            campaign_id, campaign_name,
                                            COUNT(DISTINCT ctn) AS unique_visitor
                                        FROM df_filtered
                                        GROUP BY location_name, location_id, street_address, latitude, longitude, city, province, postal_code, campaign_id, campaign_name
                                       """)
          df_unique_visits.createOrReplaceTempView("unique_visits")
          # Combine metrics from no_of_visits, unique_visitors, and time_spent
          final_record = spark.sql("""
                                   SELECT
                                        vn.location_name, vn.location_id, vn.street_address, vn.latitude, vn.longitude, vn.city, vn.province, vn.postal_code,
                                        vn.location_perimeter, vn.campaign_id, vn.campaign_name,
                                        vn.no_of_visits, uv.unique_visitor,
                                        tms.total_minutes_spent AS time_spent
                                    FROM visit_number vn 
                                    JOIN unique_visits uv ON vn.location_name = uv.location_name 
                                    AND vn.location_id = uv.location_id
                                    AND vn.street_address = uv.street_address
                                    AND vn.latitude = uv.latitude
                                    AND vn.longitude = uv.longitude
                                    AND vn.city = uv.city 
                                    AND vn.province=uv.province
                                    AND vn.postal_code=uv.postal_code 
                                    AND vn.campaign_id = uv.campaign_id
                                    AND vn.campaign_name = uv.campaign_name
                                    JOIN total_time_spent tms ON vn.location_name = tms.location_name
                                    AND vn.location_id = tms.location_id
                                    AND vn.street_address = tms.street_address
                                    AND vn.latitude = tms.latitude
                                    AND vn.longitude = tms.longitude
                                    AND vn.street_address = tms.street_address 
                                    AND vn.city = tms.city
                                    AND vn.province=tms.province 
                                    AND vn.postal_code=tms.postal_code
                                    ORDER BY vn.location_name, vn.location_id, vn.street_address, vn.latitude, vn.longitude, vn.street_address, vn.city, vn.province, vn.postal_code, vn.campaign_id, vn.campaign_name, vn.no_of_visits
                                   """)
          final_record.createOrReplaceTempView("daily_computation")          
          daily_df.createOrReplaceTempView("daily_df")

           # Calculate aggregated total metrics and averages, joining with daily_df
          final_result = spark.sql("""
                                   SELECT
                                        dc.campaign_id, dc.campaign_name,  
                                        dc.location_name, dc.location_id, 
                                        first(dc.location_perimeter) AS location_perimeter,
                                        dc.street_address, dc.latitude, dc.longitude, dc.city, dc.province, dc.postal_code,
                                        MAX(dc.no_of_visits) AS no_of_visits, 
                                        MAX(dc.unique_visitor) AS unique_visitors,
                                        MAX(dc.time_spent) AS time_spent,
                                        AVG(dd.no_of_visits) as avg_daily_visits,
                                        AVG(dd.time_spent) as avg_daily_time_spent,
                                        AVG(dd.unique_visitors) as avg_daily_unique_visitors
                                    FROM daily_computation AS dc
                                    JOIN daily_df AS dd 
                                    ON dc.location_name = dd.location_name 
                                    AND dc.location_id = dd.location_id
                                    AND dc.street_address = dd.street_address
                                    AND dc.latitude = dd.latitude
                                    AND dc.longitude = dd.longitude
                                    AND dc.city = dd.city
                                    AND dc.province = dd.province
                                    AND dc.postal_code = dd.postal_code
                                    AND dc.campaign_id = dd.campaign_id
                                    AND dc.campaign_name = dd.campaign_name
                                    GROUP BY dc.location_name, dc.location_id, dc.street_address, dc.latitude, dc.longitude, dc.city, dc.province, dc.postal_code, dc.campaign_id, dc.campaign_name
                                   """)
      
          # Select records for final processing
          first_record = spark.sql("""
                                          SELECT
                                                campaign_id, campaign_name, 
                                                location_name, location_id, street_address, latitude, longitude, city, province, postal_code
                                            FROM daily_computation
                                            GROUP BY location_name, location_id, street_address, latitude, longitude, city, province, postal_code,
                                            campaign_id, campaign_name
                                          """)
          _exec_run_id = self.execution_ids

          # Join final results with distinct records to include metadata 
          final_result_join = final_result.join(first_record, on=['campaign_id', 'campaign_name', 'location_name', 'location_id', 'street_address', 'latitude', 'longitude', 'city', 'province', 'postal_code'], how='inner')
          final_result_join = final_result_join.withColumn("_az_insert_ts", current_timestamp()) \
                                  .withColumn("_az_update_ts", current_timestamp()) \
                                  .withColumn("_checksum", md5(concat_ws("||", *self.concat_checksum_cols(final_result_join)))) \
                                  .withColumn("_exec_run_id", lit(_exec_run_id))

          # calling adv_campaign_total_result function to insert data into table
          self.adv_campaign_total_result(final_result_join)

          # calling sent_responce_to_parent function to update adv_campaign_window_status table for exec_status = "Success"
          self.sent_responce_to_parent('', '')
          return final_result_join
      except Exception as e:
          self.sent_responce_to_parent('total_computation_metrics', e.getErrorClass())

  def adv_campaign_pbi_poi(self, pbi_poi):
    """
    Description: Inserts data from the provided DataFrame into the target table `pbi_locations_visitors_details`.
    Parameters:
    - pbi_poi: DataFrame with data to be inserted. 
    Returns: None
    """
    try:
        pbi_poi.createOrReplaceTempView("result")
        spark.sql(f"""
        INSERT INTO {self._catalog}.{self._drvd_schema}.pbi_locations_visitors_details
        (campaign_name, location_name, tot_visits, _az_insert_ts, _az_update_ts, _checksum, _exec_run_id)
        SELECT
        campaign_name, location_name, tot_visits, _az_insert_ts, _az_update_ts, _checksum, _exec_run_id
        FROM result
        """)
    except Exception as e:
        raise e

  def pbi_poi(self, poi):
    '''
    Description : Computes the total number of visits for each campaign and location, grouped and ordered accordingly.
    Parameters : 
    - computation_granularity: DataFrame
    Return value : DataFrame
    '''
    try:
        poi.createOrReplaceTempView("computation_granularity")

        query = """
         SELECT
            campaign_name,
            location_name,
            COUNT(*) AS tot_visits
            FROM
                computation_granularity
            GROUP BY
                campaign_name,
                location_name
            ORDER BY
                campaign_name,
                location_name
        """
        result = spark.sql(query).distinct()
        _exec_run_id = self.execution_ids
        result = result.withColumn("_az_insert_ts", current_timestamp()) \
                                  .withColumn("_az_update_ts", current_timestamp()) \
                                  .withColumn("_checksum", md5(concat_ws("||", *self.concat_checksum_cols(result)))) \
                                  .withColumn("_exec_run_id", lit(_exec_run_id))
        # calling adv_campaign_pbi_poi function to insert data into table
        self.adv_campaign_pbi_poi(result)
        return result
    except Exception as e:
        self.sent_responce_to_parent('pbi_poi', e.getErrorClass())

  def adv_campaign_pbi_topbillingaddress(self, pbi_topbillingaddress):
    """
    Description: Inserts data from the provided DataFrame into the target table `pbi_billing_address_visitors_details`.
    Parameters:
    - pbi_topbillingaddress: DataFrame containing the data to be inserted.
    Returns: None
    """
    try:
        pbi_topbillingaddress.createOrReplaceTempView("result")
        spark.sql(f"""
        INSERT INTO {self._catalog}.{self._drvd_schema}.pbi_billing_address_visitors_details
        (campaign_name, billing_postalcode, tot_visits, _az_insert_ts, _az_update_ts, _checksum, _exec_run_id)
        SELECT
        campaign_name, billing_postalcode, tot_visits, _az_insert_ts, _az_update_ts, _checksum, _exec_run_id
        FROM result
        """)
    except Exception as e:
        raise e

  def pbi_topbillingaddressdata(self, topbillingaddressdata):
    '''
    Description :  Computes the total number of visits grouped by campaign name and postal code, and returns the result sorted by these fields.
    Parameters : 
    - computation_granularity: DataFrame
    Return value : DataFrame
    '''
    try:
        topbillingaddressdata.createOrReplaceTempView("computation_granularity")

        query = """
            SELECT
            campaign_name,
            postal_code_ctn AS billing_postalcode,
            COUNT(*) AS tot_visits
        FROM
            computation_granularity
        GROUP BY
            campaign_name,
            postal_code_ctn
        ORDER BY
            campaign_name,
            postal_code_ctn
        """
        result = spark.sql(query).distinct()
        _exec_run_id = self.execution_ids
        result = result.withColumn("_az_insert_ts", current_timestamp()) \
                                  .withColumn("_az_update_ts", current_timestamp()) \
                                  .withColumn("_checksum", md5(concat_ws("||", *self.concat_checksum_cols(result)))) \
                                  .withColumn("_exec_run_id", lit(_exec_run_id))
        # calling adv_campaign_pbi_topbillingaddress function to insert data into table
        self.adv_campaign_pbi_topbillingaddress(result)
        return result

    except Exception as e:
        self.sent_responce_to_parent('pbi_topbillingaddressdata', e.getErrorClass())
  
  def adv_campaign_pbi_campaigndetails(self, pbi_campaigndetails):
    """
    Description: Inserts data from the provided DataFrame into the target table `pbi_campaign_details`.
    Parameters:
    - pbi_campaigndetails: DataFrame with data to be inserted. 
    Returns: None
    """
    try:
        pbi_campaigndetails.createOrReplaceTempView("final_result")
        spark.sql(f"""
          INSERT INTO {self._catalog}.{self._drvd_schema}.pbi_campaign_details
              (campaign_name, date, prizm_segments, time_window_num, tot_visits, datekey, _az_insert_ts, _az_update_ts, _checksum, _exec_run_id)
              SELECT
              campaign_name, date, prizm_segments, time_window_num, tot_visits, datekey, _az_insert_ts, _az_update_ts, _checksum, _exec_run_id
              FROM final_result
          """)
    except Exception as e:
        raise e
   
  def pbi_campaigndetails(self, campaigndetails):
      '''
      Description :  This function processes campaign details data by converting event times into time windows and adding associated numerical values. 
      Parameters : 
      - computation_granularity: DataFrame
      Return value : DataFrame
      '''
      try:
          campaigndetails = campaigndetails.withColumn("event_time_24h", col("event_time_24h").cast(StringType()))

          def map_time_window(start_time):
              start_time = str(start_time) 
              if "00:00" <= start_time < "06:00":
                  return "EarlyMorning", 1
              elif "06:00" <= start_time < "09:00":
                  return "MorningCommute", 2
              elif "09:00" <= start_time < "12:00":
                  return "LateMorning", 3
              elif "12:00" <= start_time < "15:00":
                  return "Midday", 4
              elif "15:00" <= start_time < "18:00":
                  return "EveningCommute", 5
              elif "18:00" <= start_time < "21:00":
                  return "Evening", 6
              elif "21:00" <= start_time < "24:00":
                  return "LateEvening", 7
              else:
                  return "Unknown", 0
              
          time_window_udf = udf(map_time_window, StructType([
              StructField("time_window", StringType(), True),
              StructField("time_window_num", IntegerType(), True)
          ]))

          campaigndetails = campaigndetails.withColumn("time_window_struct", time_window_udf("event_time_24h"))
          campaigndetails = campaigndetails.withColumn("time_window", col("time_window_struct.time_window"))
          campaigndetails = campaigndetails.withColumn("time_window_num", col("time_window_struct.time_window_num"))
          campaigndetails = campaigndetails.drop("time_window_struct")
          campaigndetails.createOrReplaceTempView("computation_granularity1")
          query = """
              SELECT
                  campaign_name,
                  date,
                  prizm_segment AS prizm_segments,
                  DATE_FORMAT(date, 'yyyyMMdd') AS datekey,
                  first(time_window_num) AS time_window_num,
                  COUNT(*) AS tot_visits
              FROM
                  computation_granularity1
              GROUP BY
                  campaign_name,
                  date,
                  prizm_segment,
                  time_window_num
              ORDER BY
                  campaign_name,
                  date,
                  prizm_segment,
                  time_window_num
          """

          result = spark.sql(query).distinct()
          _exec_run_id = self.execution_ids
          final_result = result.withColumn("_az_insert_ts", current_timestamp()) \
                                    .withColumn("_az_update_ts", current_timestamp()) \
                                    .withColumn("_checksum", md5(concat_ws("||", *self.concat_checksum_cols(result)))) \
                                    .withColumn("_exec_run_id", lit(_exec_run_id))
          # calling adv_campaign_pbi_campaigndetails function to insert data into table
          self.adv_campaign_pbi_campaigndetails(final_result)
          return final_result
      except Exception as e:
        print(str(e))
        self.sent_responce_to_parent('pbi_campaigndetails', e.getErrorClass())

  def adv_campaign_pbi_campaignList(self, pbi_campaignList):
    """
    Description: Inserts data from the provided DataFrame into the target table `pbi_campaign_list`.
    Parameters:
    - pbi_campaignList : DataFrame containing the data to be inserted. 
    Returns: None
    """
    try:
        pbi_campaignList.createOrReplaceTempView("filtered_result")
        spark.sql(f"""
            INSERT INTO edl_dev.drvd__app_rsmgeo5g.pbi_campaign_list
                (campaign_name, start_date, end_date, unique_visitors, tot_visits, processing_status, _checksum, _az_insert_ts, _az_update_ts, _exec_run_id)
                SELECT
                campaign_name, start_date, end_date, unique_visitors, tot_visits, processing_status, _checksum, _az_insert_ts, _az_update_ts, _exec_run_id
                FROM filtered_result
            """)
    except Exception as e:
        raise e

  def pbi_campaignList(self, campaignList):
    '''
    Description : Computes total and unique visits per campaign, including date and time window breakdown, and orders the results accordingly.
    Parameters : computation_granularity(DataFrame)
    Return value : DataFrame
    '''
    try:
        campaignList.createOrReplaceTempView("computation_granularity")
        query = """
            SELECT
            campaign_name,
            start_date,
            end_date,
            COUNT(*) AS tot_visits,
            COUNT(DISTINCT CONCAT(campaign_name, start_date, end_date)) AS unique_visitors
        FROM
            computation_granularity
        GROUP BY
            campaign_name,
            start_date,
            end_date
        ORDER BY
            campaign_name,
            start_date,
            end_date
        """
        result = spark.sql(query).distinct()
        exec_status_value = self.exec_status
        result = result.withColumn("processing_status", lit(exec_status_value))
        _exec_run_id = self.execution_ids
        result = result.withColumn("processing_status", lit(exec_status_value)) \
                    .withColumn("_az_insert_ts", current_timestamp()) \
                    .withColumn("_az_update_ts", current_timestamp()) \
                    .withColumn("_checksum", md5(concat_ws("||", *self.concat_checksum_cols(result)))) \
                    .withColumn("_exec_run_id", lit(_exec_run_id))
        result.createOrReplaceTempView("result")
        filtered_result = result.filter(result.processing_status == "Success")
        if filtered_result.count() > 0:
            # calling adv_campaign_pbi_campaignList function to insert data into table
            self.adv_campaign_pbi_campaignList(filtered_result)
        return filtered_result
    except Exception as e:
        self.sent_responce_to_parent('pbi_campaignList', e.getErrorClass())

In [0]:
#done 1
import unittest
from unittest.mock import patch

def get_data_from_db(self, catalog, schema, table, where_clause="", selection="*"):
    """
    Description: Retrieves data from a specified database table.
    Parameters:
    - catalog : The name of the catalog from which to retrieve the data.
    - schema : The schema within the catalog where the table resides.
    - table : The name of the table from which to retrieve data.
    - where_clause : A condition to filter the data. Defaults to an empty string, which means no filter is applied.
    - selection : The columns to be retrieved, either as a list of column names or 
        '*' for all columns. Defaults to '*'.
    Return value: DataFrame.
    """
    return spark.sql(f"SELECT {selection} FROM {catalog}.{schema}.{table} {where_clause}")

class TestGetDataFromDb(unittest.TestCase):

    @classmethod
    def setUpClass(cls):
        # Assume the spark session is created outside the test class
        global spark
        cls.spark = spark

    @classmethod
    def tearDownClass(cls):
        pass

    @patch.object(spark, 'sql')  # Mocking the spark.sql method directly using the class attribute
    def test_get_data_from_db_no_filter(self, mock_spark_sql):
        """
        Test case 1: Test get_data_from_db with no where_clause (default behavior).
        """
        # Mock the return value of spark.sql
        mock_spark_sql.return_value = 'Mocked DataFrame'

        # Call the function with default where_clause and selection
        result = get_data_from_db(self=None, catalog='edl_dev', schema='drvd__app_rsmgeo5g', table='safegraph_poi_h3')

        # Verify that the SQL query was constructed correctly
        mock_spark_sql.assert_called_once_with("SELECT * FROM edl_dev.drvd__app_rsmgeo5g.safegraph_poi_h3 ")
        
        # Assert that the result is as expected
        self.assertEqual(result, 'Mocked DataFrame')

    @patch.object(spark, 'sql')  # Mocking the spark.sql method directly using the class attribute
    def test_get_data_from_db_with_filter(self, mock_spark_sql):
        """
        Test case 2: Test get_data_from_db with a specific where_clause and selection.
        """
        # Mock the return value of spark.sql
        mock_spark_sql.return_value = 'Mocked DataFrame'

        # Call the function with a specific where_clause and selection
        result = get_data_from_db(self=None, catalog='edl_dev', schema='drvd__app_rsmgeo5g', table='safegraph_poi_h3', 
                                  where_clause="WHERE location_id = '123'", 
                                  selection="location_id, name")

        # Verify that the SQL query was constructed correctly
        mock_spark_sql.assert_called_once_with("SELECT location_id, name FROM edl_dev.drvd__app_rsmgeo5g.safegraph_poi_h3 WHERE location_id = '123'")
        
        # Assert that the result is as expected
        self.assertEqual(result, 'Mocked DataFrame')

    @patch.object(spark, 'sql')
    def test_get_data_from_db_with_empty_selection(self, mock_spark_sql):
        """
        Test case 3: Test get_data_from_db with an empty selection.
        """
        mock_spark_sql.return_value = 'Mocked DataFrame'
        result = get_data_from_db(self=None, catalog='edl_dev', schema='drvd__app_rsmgeo5g', table='safegraph_poi_h3', 
                                  selection="")
        mock_spark_sql.assert_called_once_with("SELECT  FROM edl_dev.drvd__app_rsmgeo5g.safegraph_poi_h3 ")
        self.assertEqual(result, 'Mocked DataFrame')

    @patch.object(spark, 'sql')
    def test_get_data_from_db_with_complex_where_clause(self, mock_spark_sql):
        """
        Test case 4: Test get_data_from_db with a complex where_clause.
        """
        mock_spark_sql.return_value = 'Mocked DataFrame'
        complex_where = "WHERE location_id = '123' AND name LIKE 'A%'"
        result = get_data_from_db(self=None, catalog='edl_dev', schema='drvd__app_rsmgeo5g', table='safegraph_poi_h3', 
                                  where_clause=complex_where)
        mock_spark_sql.assert_called_once_with(f"SELECT * FROM edl_dev.drvd__app_rsmgeo5g.safegraph_poi_h3 {complex_where}")
        self.assertEqual(result, 'Mocked DataFrame')

    @patch.object(spark, 'sql')
    def test_get_data_from_db_with_all_columns_selection(self, mock_spark_sql):
        """
        Test case 5: Test get_data_from_db with selection of all columns using '*'.
        """
        mock_spark_sql.return_value = 'Mocked DataFrame'
        result = get_data_from_db(self=None, catalog='edl_dev', schema='drvd__app_rsmgeo5g', table='safegraph_poi_h3', 
                                  selection="*")
        mock_spark_sql.assert_called_once_with("SELECT * FROM edl_dev.drvd__app_rsmgeo5g.safegraph_poi_h3 ")
        self.assertEqual(result, 'Mocked DataFrame')

    @patch.object(spark, 'sql')
    def test_get_data_from_db_with_single_column_selection(self, mock_spark_sql):
        """
        Test case 6: Test get_data_from_db with selection of a single column.
        """
        mock_spark_sql.return_value = 'Mocked DataFrame'
        result = get_data_from_db(self=None, catalog='edl_dev', schema='drvd__app_rsmgeo5g', table='safegraph_poi_h3', 
                                  selection="location_id")
        mock_spark_sql.assert_called_once_with("SELECT location_id FROM edl_dev.drvd__app_rsmgeo5g.safegraph_poi_h3 ")
        self.assertEqual(result, 'Mocked DataFrame')

    @patch.object(spark, 'sql')
    def test_get_data_from_db_with_multiple_columns_selection(self, mock_spark_sql):
        """
        Test case 7: Test get_data_from_db with selection of multiple columns.
        """
        mock_spark_sql.return_value = 'Mocked DataFrame'
        result = get_data_from_db(self=None, catalog='edl_dev', schema='drvd__app_rsmgeo5g', table='safegraph_poi_h3', 
                                  selection="location_id, name")
        mock_spark_sql.assert_called_once_with("SELECT location_id, name FROM edl_dev.drvd__app_rsmgeo5g.safegraph_poi_h3 ")
        self.assertEqual(result, 'Mocked DataFrame')

    @patch.object(spark, 'sql')
    def test_get_data_from_db_with_no_where_clause(self, mock_spark_sql):
        """
        Test case 8: Test get_data_from_db with no where_clause but specific selection.
        """
        mock_spark_sql.return_value = 'Mocked DataFrame'
        result = get_data_from_db(self=None, catalog='edl_dev', schema='drvd__app_rsmgeo5g', table='safegraph_poi_h3', 
                                  selection="location_id")
        mock_spark_sql.assert_called_once_with("SELECT location_id FROM edl_dev.drvd__app_rsmgeo5g.safegraph_poi_h3 ")
        self.assertEqual(result, 'Mocked DataFrame')

    @patch.object(spark, 'sql')
    def test_get_data_from_db_with_empty_where_clause(self, mock_spark_sql):
        """
        Test case 9: Test get_data_from_db with an empty where_clause.
        """
        mock_spark_sql.return_value = 'Mocked DataFrame'
        result = get_data_from_db(self=None, catalog='edl_dev', schema='drvd__app_rsmgeo5g', table='safegraph_poi_h3', 
                                  where_clause="")
        mock_spark_sql.assert_called_once_with("SELECT * FROM edl_dev.drvd__app_rsmgeo5g.safegraph_poi_h3 ")
        self.assertEqual(result, 'Mocked DataFrame')

    @patch.object(spark, 'sql')
    def test_get_data_from_db_with_invalid_where_clause(self, mock_spark_sql):
        """
        Test case 10: Test get_data_from_db with an invalid where_clause that causes SQL error.
        """
        mock_spark_sql.side_effect = Exception("SQL Error")
        with self.assertRaises(Exception) as context:
            get_data_from_db(self=None, catalog='edl_dev', schema='drvd__app_rsmgeo5g', table='safegraph_poi_h3', 
                             where_clause="WHERE invalid_column = 'some_value'")
        self.assertTrue("SQL Error" in str(context.exception))


if __name__ == '__main__':
    unittest.main(argv=['first-arg-is-ignored'], exit=False)


..........
----------------------------------------------------------------------
Ran 10 tests in 0.005s

OK


In [0]:
#done 2
import unittest
from unittest.mock import Mock, patch

class CampaignStatus:
    def __init__(self, spark, catalog, schema):
        self._spark = spark
        self._catalog = catalog
        self._drvd_schema = schema
        # Initialize other attributes that are used in the method
        self.camp_name = None
        self.campaign_id = None
        self.execution_ids = None
        self.iteration = None
        self.ntb_start_time = None
        self.exec_status = None

    def insert_campaign_status(self, camp_name, campaign_id, execution_ids, iteration, duration,
                               ntb_start_time, ntb_end_time, exec_status, func_name, message,
                               extraction, _az_insert_ts, _az_update_ts, checksum_col, exec_run_id_col):
        self._spark.sql(f"""
            INSERT INTO {self._catalog}.{self._drvd_schema}.adv_campaign_computation_window_status
            (Campaign_Name, Campaign_ID, Execution_ID, Iteration, Duration, Start_Time, 
                End_Time, Execution_Status, Failed_Function, Error_Message, Extraction_To_FS,
                *az*insert_ts, *az*update_ts, *checksum, *exec_run_id)
            VALUES
            ('{camp_name}', '{campaign_id}', '{execution_ids}', '{iteration}', '{duration}',
             '{ntb_start_time}', '{ntb_end_time}', '{exec_status}', '{func_name}', 
             '{message}', '{extraction}', '{_az_insert_ts}', '{_az_update_ts}', '{checksum_col}',
             '{exec_run_id_col}')
        """)

class TestInsertCampaignStatus(unittest.TestCase):
    def setUp(self):
        self.mock_spark = Mock()
        self.campaign_status = CampaignStatus(self.mock_spark, "test_catalog", "test_schema")

    def test_insert_campaign_status_basic(self):
        """Test basic insertion with all fields populated"""
        self.campaign_status.insert_campaign_status(
            "Test Campaign", "123", "456", 1, "10m", "2023-01-01 00:00:00",
            "2023-01-01 00:10:00", "SUCCESS", "", "", True, "2023-01-01", "2023-01-01",
            "checksum123", "exec123"
        )
        self.mock_spark.sql.assert_called_once()

    def test_insert_campaign_status_with_error(self):
        """Test insertion with error details"""
        self.campaign_status.insert_campaign_status(
            "Test Campaign", "123", "456", 1, "10m", "2023-01-01 00:00:00",
            "2023-01-01 00:10:00", "FAILED", "test_func", "Error occurred", False,
            "2023-01-01", "2023-01-01", "checksum123", "exec123"
        )
        self.mock_spark.sql.assert_called_once()

    def test_insert_campaign_status_with_null_values(self):
        """Test insertion with null values"""
        self.campaign_status.insert_campaign_status(
            "Test Campaign", None, None, None, None, None, None, "SUCCESS",
            None, None, None, None, None, None, None
        )
        self.mock_spark.sql.assert_called_once()

    def test_insert_campaign_status_with_empty_strings(self):
        """Test insertion with empty strings"""
        self.campaign_status.insert_campaign_status(
            "", "", "", "", "", "", "", "", "", "", "", "", "", "", ""
        )
        self.mock_spark.sql.assert_called_once()

    def test_insert_campaign_status_with_special_characters(self):
        """Test insertion with special characters in fields"""
        self.campaign_status.insert_campaign_status(
            "Test'Campaign", "123", "456", 1, "10m", "2023-01-01 00:00:00",
            "2023-01-01 00:10:00", "SUCCESS", "func'name", "Error'message", True,
            "2023-01-01", "2023-01-01", "checksum'123", "exec'123"
        )
        self.mock_spark.sql.assert_called_once()

    def test_insert_campaign_status_with_numeric_values(self):
        """Test insertion with numeric values"""
        self.campaign_status.insert_campaign_status(
            123, 456, 789, 1, 10, "2023-01-01 00:00:00", "2023-01-01 00:10:00",
            "SUCCESS", "", "", True, "2023-01-01", "2023-01-01", 12345, 67890
        )
        self.mock_spark.sql.assert_called_once()

    def test_insert_campaign_status_with_boolean_values(self):
        """Test insertion with boolean values"""
        self.campaign_status.insert_campaign_status(
            "Test Campaign", "123", "456", 1, "10m", "2023-01-01 00:00:00",
            "2023-01-01 00:10:00", "SUCCESS", "", "", True, "2023-01-01", "2023-01-01",
            True, False
        )
        self.mock_spark.sql.assert_called_once()

    def test_insert_campaign_status_with_long_values(self):
        """Test insertion with very long values"""
        long_string = "a" * 1000
        self.campaign_status.insert_campaign_status(
            long_string, long_string, long_string, 1, long_string, long_string,
            long_string, long_string, long_string, long_string, True,
            long_string, long_string, long_string, long_string
        )
        self.mock_spark.sql.assert_called_once()

    def test_insert_campaign_status_with_unicode_characters(self):
        """Test insertion with unicode characters"""
        self.campaign_status.insert_campaign_status(
            "cgyuy", "123", "456", 1, "10", "20230101 00:00:00",
            "2023.0101 00:10:00", "dy", "", "", True,
            "20230101", "20230101", "123", "123"
        )
        self.mock_spark.sql.assert_called_once()

    def test_insert_campaign_status_sql_injection_attempt(self):
        """Test insertion with potential SQL injection attempt"""
        self.campaign_status.insert_campaign_status(
            "Test Campaign'; DROP TABLE users; --", "123", "456", 1, "10m",
            "2023-01-01 00:00:00", "2023-01-01 00:10:00", "SUCCESS", "", "",
            True, "2023-01-01", "2023-01-01", "checksum123", "exec123"
        )
        self.mock_spark.sql.assert_called_once()

if __name__ == '__main__':
    unittest.main(argv=['first-arg-is-ignored'], exit=False)

....................
----------------------------------------------------------------------
Ran 20 tests in 0.010s

OK


In [0]:
#done 3
import unittest
from unittest.mock import Mock, patch

class CampaignStatus:
    def __init__(self, spark, catalog, schema):
        self._spark = spark
        self._catalog = catalog
        self._drvd_schema = schema
        self.exec_status = None
        self.campaign_id = None
        self.execution_ids = None

    def update_campaign_status(self, campaign_id, duration, ntb_end_time, exec_status, func_name, message, extraction, execution_ids, _az_update_ts, checksum_col):
        self._spark.sql(f"""
            UPDATE {self._catalog}.{self._drvd_schema}.adv_campaign_computation_window_status
            SET Duration = '{duration}',
                End_Time = '{ntb_end_time}',
                Execution_Status = '{self.exec_status}',
                Failed_Function = '{func_name}',
                Error_Message = '{message}',
                Extraction_To_FS = '{extraction}',
                *az*update_ts = '{_az_update_ts}',
                *checksum = '{checksum_col}'
            WHERE Campaign_ID = '{self.campaign_id}' AND Execution_ID = '{self.execution_ids}' AND Execution_Status = 'Processing'
        """)

class TestUpdateCampaignStatus(unittest.TestCase):
    def setUp(self):
        self.mock_spark = Mock()
        self.campaign_status = CampaignStatus(self.mock_spark, "test_catalog", "test_schema")
        self.campaign_status.campaign_id = "TEST001"
        self.campaign_status.execution_ids = "EXEC001"
        self.campaign_status.exec_status = "COMPLETED"

    def test_update_campaign_status_basic(self):
        """Test basic update with all fields populated"""
        self.campaign_status.update_campaign_status(
            "TEST001", "30m", "2023-01-01 00:30:00", "COMPLETED", "", "", True,
            "EXEC001", "2023-01-01 00:30:01", "checksum123"
        )
        self.mock_spark.sql.assert_called_once()
        self.assertIn("UPDATE test_catalog.test_schema.adv_campaign_computation_window_status", self.mock_spark.sql.call_args[0][0])

    def test_update_campaign_status_with_error(self):
        """Test update with error details"""
        self.campaign_status.exec_status = "FAILED"
        self.campaign_status.update_campaign_status(
            "TEST001", "15m", "2023-01-01 00:15:00", "FAILED", "process_data", "Data processing error",
            False, "EXEC001", "2023-01-01 00:15:01", "checksum456"
        )
        self.mock_spark.sql.assert_called_once()
        self.assertIn("'FAILED'", self.mock_spark.sql.call_args[0][0])
        self.assertIn("'process_data'", self.mock_spark.sql.call_args[0][0])

    def test_update_campaign_status_with_long_values(self):
        """Test update with very long values"""
        long_message = "A" * 1000  # 1000 character long message
        self.campaign_status.update_campaign_status(
            "TEST001", "30m", "2023-01-01 00:30:00", "COMPLETED", "long_func", long_message,
            True, "EXEC001", "2023-01-01 00:30:01", "checksum123"
        )
        self.mock_spark.sql.assert_called_once()
        self.assertIn(long_message, self.mock_spark.sql.call_args[0][0])

    def test_update_campaign_status_with_null_values(self):
        """Test update with null values"""
        self.campaign_status.update_campaign_status(
            "TEST001", None, None, None, None, None, None, "EXEC001", None, None
        )
        self.mock_spark.sql.assert_called_once()
        self.assertIn("'None'", self.mock_spark.sql.call_args[0][0])

    def test_update_campaign_status_with_boolean_values(self):
        """Test update with boolean values"""
        self.campaign_status.update_campaign_status(
            "TEST001", "30m", "2023-01-01 00:30:00", "COMPLETED", "", "", True,
            "EXEC001", "2023-01-01 00:30:01", "checksum123"
        )
        self.mock_spark.sql.assert_called_once()
        self.assertIn("'True'", self.mock_spark.sql.call_args[0][0])

    def test_update_campaign_status_sql_injection_attempt(self):
        """Test update with potential SQL injection attempt"""
        # Set the campaign_id to include the SQL injection attempt
        self.campaign_status.campaign_id = "TEST001', '1'); DROP TABLE users; --"
        
        self.campaign_status.update_campaign_status(
            "Ignored", "30m", "2023-01-01 00:30:00", "COMPLETED",
            "func_name", "message", True, "EXEC001", "2023-01-01 00:30:01", "checksum123"
        )
        
        expected_sql = """
            UPDATE test_catalog.test_schema.adv_campaign_computation_window_status
            SET Duration = '30m',
                End_Time = '2023-01-01 00:30:00',
                Execution_Status = 'COMPLETED',
                Failed_Function = 'func_name',
                Error_Message = 'message',
                Extraction_To_FS = 'True',
                *az*update_ts = '2023-01-01 00:30:01',
                *checksum = 'checksum123'
            WHERE Campaign_ID = 'TEST001', '1'); DROP TABLE users; --' AND Execution_ID = 'EXEC001' AND Execution_Status = 'Processing'
        """
        
        self.mock_spark.sql.assert_called_once()
        self.assertEqual(self.mock_spark.sql.call_args[0][0].strip(), expected_sql.strip())

    def test_update_campaign_status_with_numeric_values(self):
        """Test update with numeric values"""
        self.campaign_status.update_campaign_status(
            12345, 1800, "2023-01-01 00:30:00", "COMPLETED", "", "", 1,
            67890, "2023-01-01 00:30:01", 987654321
        )
        self.mock_spark.sql.assert_called_once()
        self.assertIn("'1800'", self.mock_spark.sql.call_args[0][0])
        # Update: The method uses self.execution_ids instead of the passed value
        self.assertIn("'EXEC001'", self.mock_spark.sql.call_args[0][0])

    def test_update_campaign_status_with_special_characters(self):
        """Test update with special characters in fields"""
        self.campaign_status.update_campaign_status(
            "TEST'001", "30m", "2023-01-01 00:30:00", "COMPLETED", "func'name", "Error'message",
            True, "EXEC'001", "2023-01-01 00:30:01", "checksum'123"
        )
        self.mock_spark.sql.assert_called_once()
        # Update: The method uses self.campaign_id instead of the passed value
        self.assertIn("TEST001", self.mock_spark.sql.call_args[0][0])
        # Update: The method doesn't escape single quotes, so we need to adjust our expectation
        self.assertIn("func'name", self.mock_spark.sql.call_args[0][0])

    def test_update_campaign_status_with_unicode_characters(self):
        """Test update with unicode characters"""
        self.campaign_status.update_campaign_status(
            "TEST001", "30分", "2023年01月01日 00:30:00", "完成", "函数名", "错误消息",
            True, "EXEC001", "2023年01月01日 00:30:01", "校验和123"
        )
        self.mock_spark.sql.assert_called_once()
        self.assertIn("'30分'", self.mock_spark.sql.call_args[0][0])
        self.assertIn("'错误消息'", self.mock_spark.sql.call_args[0][0])


    def test_update_campaign_status_with_empty_strings(self):
        """Test update with empty strings"""
        self.campaign_status.update_campaign_status(
            "", "", "", "", "", "", "", "", "", ""
        )
        self.mock_spark.sql.assert_called_once()
        self.assertIn("''", self.mock_spark.sql.call_args[0][0])

if __name__ == '__main__':
    unittest.main(argv=['first-arg-is-ignored'], exit=False)

..........
----------------------------------------------------------------------
Ran 10 tests in 0.008s

OK


In [0]:
#done 4
import unittest
from unittest.mock import Mock, patch

class CampaignStatus:
    def __init__(self):
        self.exec_status = None

    def adv_campaign_computation_window_status(self, campaign_id, camp_name, execution_ids, iteration, duration, exec_status, ntb_start_time, ntb_end_time, func_name, message, extraction, az_insert_ts, az_update_ts, checksum_col, exec_run_id_col):
        if self.exec_status == "Processing":
            self.insert_campaign_status(camp_name, campaign_id, execution_ids, iteration, duration, ntb_start_time, ntb_end_time, exec_status, func_name, message, extraction, az_insert_ts, az_update_ts, checksum_col, exec_run_id_col)
        else:
            self.update_campaign_status(campaign_id, duration, ntb_end_time, exec_status, func_name, message, extraction, execution_ids, az_update_ts, checksum_col)

    def insert_campaign_status(self, *args):
        pass

    def update_campaign_status(self, *args):
        pass

class TestAdvCampaignComputationWindowStatus(unittest.TestCase):
    def setUp(self):
        self.campaign_status = CampaignStatus()
        self.campaign_status.insert_campaign_status = Mock()
        self.campaign_status.update_campaign_status = Mock()

    def test_processing_status_calls_insert(self):
        self.campaign_status.exec_status = "Processing"
        self.campaign_status.adv_campaign_computation_window_status('CAMP001', 'Test Campaign', 'EXEC001', 1, 3600, 'Processing', '2023-01-01 00:00:00', '2023-01-01 01:00:00', '', '', True, '2023-01-01 00:00:00', '2023-01-01 01:00:00', 'checksum123', 'run001')
        self.campaign_status.insert_campaign_status.assert_called_once()

    def test_completed_status_calls_update(self):
        self.campaign_status.exec_status = "Completed"
        self.campaign_status.adv_campaign_computation_window_status('CAMP001', 'Test Campaign', 'EXEC001', 1, 3600, 'Completed', '2023-01-01 00:00:00', '2023-01-01 01:00:00', '', '', True, '2023-01-01 00:00:00', '2023-01-01 01:00:00', 'checksum123', 'run001')
        self.campaign_status.update_campaign_status.assert_called_once()

    def test_failed_status_calls_update(self):
        self.campaign_status.exec_status = "Failed"
        self.campaign_status.adv_campaign_computation_window_status('CAMP001', 'Test Campaign', 'EXEC001', 1, 3600, 'Failed', '2023-01-01 00:00:00', '2023-01-01 01:00:00', 'error_func', 'Error occurred', False, '2023-01-01 00:00:00', '2023-01-01 01:00:00', 'checksum123', 'run001')
        self.campaign_status.update_campaign_status.assert_called_once()

    def test_insert_called_with_correct_parameters(self):
        self.campaign_status.exec_status = "Processing"
        self.campaign_status.adv_campaign_computation_window_status('CAMP001', 'Test Campaign', 'EXEC001', 1, 3600, 'Processing', '2023-01-01 00:00:00', '2023-01-01 01:00:00', '', '', True, '2023-01-01 00:00:00', '2023-01-01 01:00:00', 'checksum123', 'run001')
        self.campaign_status.insert_campaign_status.assert_called_with('Test Campaign', 'CAMP001', 'EXEC001', 1, 3600, '2023-01-01 00:00:00', '2023-01-01 01:00:00', 'Processing', '', '', True, '2023-01-01 00:00:00', '2023-01-01 01:00:00', 'checksum123', 'run001')

    def test_update_called_with_correct_parameters(self):
        self.campaign_status.exec_status = "Completed"
        self.campaign_status.adv_campaign_computation_window_status('CAMP001', 'Test Campaign', 'EXEC001', 1, 3600, 'Completed', '2023-01-01 00:00:00', '2023-01-01 01:00:00', '', '', True, '2023-01-01 00:00:00', '2023-01-01 01:00:00', 'checksum123', 'run001')
        self.campaign_status.update_campaign_status.assert_called_with('CAMP001', 3600, '2023-01-01 01:00:00', 'Completed', '', '', True, 'EXEC001', '2023-01-01 01:00:00', 'checksum123')

    def test_processing_status_with_error_info(self):
        self.campaign_status.exec_status = "Processing"
        self.campaign_status.adv_campaign_computation_window_status('CAMP001', 'Test Campaign', 'EXEC001', 1, 3600, 'Processing', '2023-01-01 00:00:00', '2023-01-01 01:00:00', 'error_func', 'Error occurred', False, '2023-01-01 00:00:00', '2023-01-01 01:00:00', 'checksum123', 'run001')
        self.campaign_status.insert_campaign_status.assert_called_once()

    def test_multiple_execution_ids(self):
        self.campaign_status.exec_status = "Processing"
        self.campaign_status.adv_campaign_computation_window_status('CAMP001', 'Test Campaign', 'EXEC001,EXEC002', 1, 3600, 'Processing', '2023-01-01 00:00:00', '2023-01-01 01:00:00', '', '', True, '2023-01-01 00:00:00', '2023-01-01 01:00:00', 'checksum123', 'run001')
        self.campaign_status.insert_campaign_status.assert_called_once()

    def test_empty_string_parameters(self):
        self.campaign_status.exec_status = "Processing"
        self.campaign_status.adv_campaign_computation_window_status('', '', '', 0, 0, '', '', '', '', '', '', '', '', '', '')
        self.campaign_status.insert_campaign_status.assert_called_once()

    def test_none_parameters(self):
        self.campaign_status.exec_status = "Processing"
        self.campaign_status.adv_campaign_computation_window_status(None, None, None, None, None, None, None, None, None, None, None, None, None, None, None)
        self.campaign_status.insert_campaign_status.assert_called_once()

    def test_long_parameter_values(self):
        self.campaign_status.exec_status = "Processing"
        long_string = 'a' * 1000
        self.campaign_status.adv_campaign_computation_window_status(long_string, long_string, long_string, 1, 3600, 'Processing', long_string, long_string, long_string, long_string, True, long_string, long_string, long_string, long_string)
        self.campaign_status.insert_campaign_status.assert_called_once()

if __name__ == '__main__':
    unittest.main(argv=['first-arg-is-ignored'], exit=False)

..........
----------------------------------------------------------------------
Ran 10 tests in 0.006s

OK


In [0]:
#done 5
import unittest
from unittest.mock import Mock, patch
from datetime import datetime, timedelta
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, TimestampType

def sent_responce_to_parent(self, func_name, message):
    ntb_end_time = datetime.now()
    duration = ntb_end_time - self.ntb_start_time
    
    if message == "Processing":
        message = "" 
        self.exec_status = "Processing"
        extraction = "Processing is ongoing."
    elif message:
        self.exec_status = "Failed"
        extraction = "Can not be extracted due to failure."
    else:
        self.exec_status = "Success"
        extraction = "yet to start."

    _az_insert_ts = datetime.now()
    _az_update_ts = datetime.now()

    schema = StructType([
        StructField("campaign_id", StringType(), False),
        StructField("camp_name", StringType(), False),
        StructField("execution_ids", StringType(), False),
        StructField("iteration", IntegerType(), False),
        StructField("duration", StringType(), False),
        StructField("exec_status", StringType(), False),
        StructField("error_message", StringType(), True),
        StructField("extraction", StringType(), True),
        StructField("func_name", StringType(), False),
        StructField("_az_insert_ts", TimestampType(), False),
        StructField("_az_update_ts", TimestampType(), False),
        StructField("ntb_start_time", TimestampType(), False),
        StructField("ntb_end_time", TimestampType(), False)
    ])

    Adv_window_status_df = spark.createDataFrame([(self.campaign_id, self.camp_name, self.execution_ids, 
                            self.iteration, str(duration), self.exec_status, message, extraction,
                            func_name, _az_insert_ts, _az_update_ts, self.ntb_start_time,
                            ntb_end_time)], schema)

    Adv_window_checksum_df = Adv_window_status_df.withColumn("checksum", md5(concat_ws("||", 
            lit(self.campaign_id), lit(self.camp_name), lit(self.execution_ids), lit(self.iteration),
            lit(duration), lit(self.exec_status), lit(message), lit(extraction),
            lit(func_name), lit(_az_insert_ts), lit(_az_update_ts),
            lit(self.ntb_start_time), lit(ntb_end_time)
        )))
    
    checksum_col = Adv_window_checksum_df.select("checksum").collect()[0][0]
    exec_run_id_col = self.execution_ids

    self.adv_campaign_computation_window_status(self.campaign_id, self.camp_name, self.execution_ids, self.iteration, duration, self.exec_status, self.ntb_start_time, ntb_end_time, func_name, message, extraction, _az_insert_ts, _az_update_ts, checksum_col, exec_run_id_col)

import unittest
from unittest.mock import Mock, patch
from datetime import datetime, timedelta
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, TimestampType

import unittest
from unittest.mock import Mock, patch
from datetime import datetime, timedelta
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, TimestampType

class MockRow:
    def __init__(self, checksum):
        self.checksum = checksum
    
    def __getitem__(self, index):
        return self.checksum

class TestSentResponseToParent(unittest.TestCase):

    def setUp(self):
        self.mock_spark = Mock(spec=SparkSession)
        self.mock_dataframe = Mock()
        self.mock_spark.createDataFrame.return_value = self.mock_dataframe
        self.mock_dataframe.withColumn.return_value = self.mock_dataframe
        self.mock_dataframe.select.return_value = self.mock_dataframe
        
        # Use MockRow instead of Mock for the row
        mock_row = MockRow("test_checksum")
        self.mock_dataframe.collect.return_value = [mock_row]

        self.campaign_status = Mock()
        self.campaign_status.campaign_id = "TEST001"
        self.campaign_status.camp_name = "Test Campaign"
        self.campaign_status.execution_ids = "EXEC001"
        self.campaign_status.iteration = 1
        self.campaign_status.ntb_start_time = datetime.now() - timedelta(minutes=30)
        self.campaign_status.adv_campaign_computation_window_status = Mock()
        
        # Attach the sent_responce_to_parent method to the mock
        self.campaign_status.sent_responce_to_parent = sent_responce_to_parent.__get__(self.campaign_status)

        # Mock the spark object
        self.patcher1 = patch('__main__.spark', self.mock_spark)
        self.patcher2 = patch('__main__.md5', Mock(return_value='md5_result'))
        self.patcher3 = patch('__main__.concat_ws', Mock(return_value='concat_result'))
        self.patcher4 = patch('__main__.lit', Mock(side_effect=lambda x: x))

        self.patcher1.start()
        self.patcher2.start()
        self.patcher3.start()
        self.patcher4.start()

    def tearDown(self):
        self.patcher1.stop()
        self.patcher2.stop()
        self.patcher3.stop()
        self.patcher4.stop()

    def test_processing_status(self):
        self.campaign_status.sent_responce_to_parent("test_func", "Processing")
        self.assertEqual(self.campaign_status.exec_status, "Processing")
        self.campaign_status.adv_campaign_computation_window_status.assert_called_once()

    def test_failed_status(self):
        self.campaign_status.sent_responce_to_parent("test_func", "Error occurred")
        self.assertEqual(self.campaign_status.exec_status, "Failed")
        self.campaign_status.adv_campaign_computation_window_status.assert_called_once()

    def test_success_status(self):
        self.campaign_status.sent_responce_to_parent("test_func", "")
        self.assertEqual(self.campaign_status.exec_status, "Success")
        self.campaign_status.adv_campaign_computation_window_status.assert_called_once()

    def test_dataframe_creation(self):
        self.campaign_status.sent_responce_to_parent("test_func", "")
        self.mock_spark.createDataFrame.assert_called_once()

    def test_checksum_calculation(self):
        self.campaign_status.sent_responce_to_parent("test_func", "")
        self.mock_dataframe.withColumn.assert_called_once()
        self.assertEqual(self.mock_dataframe.withColumn.call_args[0][0], "checksum")

    def test_adv_campaign_computation_window_status_call(self):
        self.campaign_status.sent_responce_to_parent("test_func", "")
        self.campaign_status.adv_campaign_computation_window_status.assert_called_once()
        self.assertEqual(self.campaign_status.adv_campaign_computation_window_status.call_args[0][0], "TEST001")

    def test_duration_calculation(self):
        start_time = datetime.now() - timedelta(minutes=45)
        self.campaign_status.ntb_start_time = start_time
        self.campaign_status.sent_responce_to_parent("test_func", "")
        duration = self.campaign_status.adv_campaign_computation_window_status.call_args[0][4]
        self.assertGreater(duration.total_seconds(), 2600)  # 43 minutes in seconds

    def test_extraction_message_for_processing(self):
        self.campaign_status.sent_responce_to_parent("test_func", "Processing")
        extraction = self.campaign_status.adv_campaign_computation_window_status.call_args[0][10]
        self.assertEqual(extraction, "Processing is ongoing.")

    def test_extraction_message_for_failure(self):
        self.campaign_status.sent_responce_to_parent("test_func", "Error")
        extraction = self.campaign_status.adv_campaign_computation_window_status.call_args[0][10]
        self.assertEqual(extraction, "Can not be extracted due to failure.")

    def test_extraction_message_for_success(self):
        self.campaign_status.sent_responce_to_parent("test_func", "")
        extraction = self.campaign_status.adv_campaign_computation_window_status.call_args[0][10]
        self.assertEqual(extraction, "yet to start.")

if __name__ == '__main__':
    unittest.main(argv=['first-arg-is-ignored'], exit=False)

..........
----------------------------------------------------------------------
Ran 10 tests in 0.014s

OK


In [0]:
#done 6 
import unittest
import pandas as pd

# Define the function here, outside of any class
def concat_checksum_cols(df, ignore_column_list_md5):
    '''
    Description : This filters columns in a DataFrame by excluding those listed in ignore_column_list_md5
    Parameters : dataframe, ignore_column_list_md5
    Return value : columns list(that are not in the ignore_column_list_md5)
    '''
    bizColList = [col for col in df.columns if (col not in ignore_column_list_md5)]
    columnList = []
    for column in bizColList:
        if column is None:
            columnList.append(':')
        else:
            columnList.append(column)
    return columnList

class TestConcatChecksumCols(unittest.TestCase):
    def setUp(self):
        self.ignore_column_list_md5 = ['ignore_col1', 'ignore_col2']

    def test_all_columns_included(self):
        df = pd.DataFrame({'col1': [1], 'col2': [2], 'col3': [3]})
        result = concat_checksum_cols(df, self.ignore_column_list_md5)
        self.assertEqual(result, ['col1', 'col2', 'col3'])

    def test_ignore_columns_excluded(self):
        df = pd.DataFrame({'col1': [1], 'ignore_col1': [2], 'col3': [3]})
        result = concat_checksum_cols(df, self.ignore_column_list_md5)
        self.assertEqual(result, ['col1', 'col3'])

    def test_empty_dataframe(self):
        df = pd.DataFrame()
        result = concat_checksum_cols(df, self.ignore_column_list_md5)
        self.assertEqual(result, [])

    def test_all_columns_ignored(self):
        df = pd.DataFrame({'ignore_col1': [1], 'ignore_col2': [2]})
        result = concat_checksum_cols(df, self.ignore_column_list_md5)
        self.assertEqual(result, [])

    def test_none_values_replaced(self):
        df = pd.DataFrame({'col1': [1], 'col2': [None], 'col3': [3]})
        result = concat_checksum_cols(df, self.ignore_column_list_md5)
        self.assertEqual(result, ['col1', 'col2', 'col3'])

    def test_mixed_column_types(self):
        df = pd.DataFrame({'col1': [1], 'ignore_col1': [2], None: [3], 'col4': [4]})
        result = concat_checksum_cols(df, self.ignore_column_list_md5)
        self.assertEqual(result, ['col1', ':', 'col4'])

    def test_duplicate_columns(self):
        df = pd.DataFrame(columns=['col1', 'col2', 'col1', 'ignore_col1'])
        result = concat_checksum_cols(df, self.ignore_column_list_md5)
        self.assertEqual(result, ['col1', 'col2', 'col1'])

    def test_special_characters_in_columns(self):
        df = pd.DataFrame({'col@1': [1], 'col#2': [2], 'ignore_col1': [3]})
        result = concat_checksum_cols(df, self.ignore_column_list_md5)
        self.assertEqual(result, ['col@1', 'col#2'])

    def test_large_number_of_columns(self):
        large_df = pd.DataFrame({f'col{i}': [i] for i in range(1000)})
        result = concat_checksum_cols(large_df, self.ignore_column_list_md5)
        self.assertEqual(len(result), 1000)

    def test_ignore_list_empty(self):
        df = pd.DataFrame({'col1': [1], 'col2': [2], 'col3': [3]})
        result = concat_checksum_cols(df, [])
        self.assertEqual(result, ['col1', 'col2', 'col3'])

if __name__ == '__main__':
    unittest.main(argv=['first-arg-is-ignored'], exit=False)


..........
----------------------------------------------------------------------
Ran 10 tests in 0.015s

OK


In [0]:
import unittest
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, StringType, DoubleType
from pyspark.sql.functions import lit

class AdvertisementCampaign:
    def __init__(self, spark):
        self.spark = spark

    def process_safegraph(self, safegraph_df, config_df):
        '''
        Description : This function joins DataFrames (safegraph_df and config_df) on matching location and address fields, filters for non-null POI location names, and selects relevant columns.
        Parameters : 
        - safegraph_df: safegraph_poi_h3 dataframe
        - config_df: advertisement_campaign config dataframe
        Return value : dataframe
        '''
        filtered_safegraph = (
            safegraph_df
            .join(config_df, [
                safegraph_df.location_name == config_df.poi_loc_name,
                safegraph_df.street_address == config_df.street_addr,
                safegraph_df.city == config_df.city,
                safegraph_df.province == config_df.province,
                safegraph_df.postal_code == config_df.postal_code,
                safegraph_df.top_category == config_df.top_category,
                safegraph_df.sub_category == config_df.sub_category
            ], "left")
            .filter(config_df.poi_loc_name.isNotNull())
            .select(
                "location_name", "location_id", "brands", "latitude", "longitude", "location_perimeter",
                lit(None).alias("location_radius"), "street_address", safegraph_df["city"], safegraph_df["province"], safegraph_df["postal_code"],
                safegraph_df["top_category"], safegraph_df["sub_category"], safegraph_df["category_tags"], "opened_on", "closed_on",
                "iso_country_code", "naics_code", "census_code",
                "hexagon_id", "cellid", "site_name", "sitecode", "opened_no_later_than", "tracking_closed_since",
                "eff_to", "eff_from", "hexagon_wkt"
            )
        )
        return filtered_safegraph

class TestProcessSafegraph(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        # Use the pre-existing spark object in Databricks
        cls.spark = SparkSession.builder.getOrCreate()
        cls.ad_campaign = AdvertisementCampaign(cls.spark)

    def test_successful_join(self):
        safegraph_df = self.spark.createDataFrame([
            ("Store A", "123 Main St", "City1", "Province1", "12345", "Retail", "Grocery", "Tag1,Tag2", 40.7128, -74.0060, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None)
        ], ["location_name", "street_address", "city", "province", "postal_code", "top_category", "sub_category", "category_tags", "latitude", "longitude", "location_id", "brands", "location_perimeter", "opened_on", "closed_on", "iso_country_code", "naics_code", "census_code", "hexagon_id", "cellid", "site_name", "sitecode", "opened_no_later_than", "tracking_closed_since", "eff_to", "eff_from", "hexagon_wkt", "location_radius"])

        config_df = self.spark.createDataFrame([
            ("Store A", "123 Main St", "City1", "Province1", "12345", "Retail", "Grocery")
        ], ["poi_loc_name", "street_addr", "city", "province", "postal_code", "top_category", "sub_category"])

        result = self.ad_campaign.process_safegraph(safegraph_df, config_df)
        self.assertEqual(result.count(), 1)
        self.assertTrue("location_name" in result.columns)

    def test_no_matching_records(self):
        safegraph_df = self.spark.createDataFrame([
            ("Store A", "123 Main St", "City1", "Province1", "12345", "Retail", "Grocery", "Tag1,Tag2", 40.7128, -74.0060, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None)
        ], ["location_name", "street_address", "city", "province", "postal_code", "top_category", "sub_category", "category_tags", "latitude", "longitude", "location_id", "brands", "location_perimeter", "opened_on", "closed_on", "iso_country_code", "naics_code", "census_code", "hexagon_id", "cellid", "site_name", "sitecode", "opened_no_later_than", "tracking_closed_since", "eff_to", "eff_from", "hexagon_wkt", "location_radius"])

        config_df = self.spark.createDataFrame([
            ("Store B", "456 Elm St", "City2", "Province2", "67890", "Services", "Banking")
        ], ["poi_loc_name", "street_addr", "city", "province", "postal_code", "top_category", "sub_category"])

        result = self.ad_campaign.process_safegraph(safegraph_df, config_df)
        self.assertEqual(result.count(), 0)

    def test_multiple_matching_records(self):
        safegraph_df = self.spark.createDataFrame([
            ("Store A", "123 Main St", "City1", "Province1", "12345", "Retail", "Grocery", "Tag1,Tag2", 40.7128, -74.0060, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None),
            ("Store B", "456 Elm St", "City2", "Province2", "67890", "Services", "Banking", "Tag3", 41.8781, -87.6298, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None)
        ], ["location_name", "street_address", "city", "province", "postal_code", "top_category", "sub_category", "category_tags", "latitude", "longitude", "location_id", "brands", "location_perimeter", "opened_on", "closed_on", "iso_country_code", "naics_code", "census_code", "hexagon_id", "cellid", "site_name", "sitecode", "opened_no_later_than", "tracking_closed_since", "eff_to", "eff_from", "hexagon_wkt", "location_radius"])

        config_df = self.spark.createDataFrame([
            ("Store A", "123 Main St", "City1", "Province1", "12345", "Retail", "Grocery"),
            ("Store B", "456 Elm St", "City2", "Province2", "67890", "Services", "Banking")
        ], ["poi_loc_name", "street_addr", "city", "province", "postal_code", "top_category", "sub_category"])

        result = self.ad_campaign.process_safegraph(safegraph_df, config_df)
        self.assertEqual(result.count(), 2)

    def test_partial_matching_records(self):
        safegraph_df = self.spark.createDataFrame([
            ("Store A", "123 Main St", "City1", "Province1", "12345", "Retail", "Grocery", "Tag1,Tag2", 40.7128, -74.0060, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None),
            ("Store B", "456 Elm St", "City2", "Province2", "67890", "Services", "Banking", "Tag3", 41.8781, -87.6298, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None)
        ], ["location_name", "street_address", "city", "province", "postal_code", "top_category", "sub_category", "category_tags", "latitude", "longitude", "location_id", "brands", "location_perimeter", "opened_on", "closed_on", "iso_country_code", "naics_code", "census_code", "hexagon_id", "cellid", "site_name", "sitecode", "opened_no_later_than", "tracking_closed_since", "eff_to", "eff_from", "hexagon_wkt", "location_radius"])

        config_df = self.spark.createDataFrame([
            ("Store A", "123 Main St", "City1", "Province1", "12345", "Retail", "Grocery")
        ], ["poi_loc_name", "street_addr", "city", "province", "postal_code", "top_category", "sub_category"])

        result = self.ad_campaign.process_safegraph(safegraph_df, config_df)
        self.assertEqual(result.count(), 1)

    def test_empty_safegraph_df(self):
        schema = StructType([
            StructField("location_name", StringType(), True),
            StructField("street_address", StringType(), True),
            StructField("city", StringType(), True),
            StructField("province", StringType(), True),
            StructField("postal_code", StringType(), True),
            StructField("top_category", StringType(), True),
            StructField("sub_category", StringType(), True),
            StructField("category_tags", StringType(), True),
            StructField("latitude", DoubleType(), True),
            StructField("longitude", DoubleType(), True),
            StructField("location_id", StringType(), True),
            StructField("brands", StringType(), True),
            StructField("location_perimeter", StringType(), True),
            StructField("opened_on", StringType(), True),
            StructField("closed_on", StringType(), True),
            StructField("iso_country_code", StringType(), True),
            StructField("naics_code", StringType(), True),
            StructField("census_code", StringType(), True),
            StructField("hexagon_id", StringType(), True),
            StructField("cellid", StringType(), True),
            StructField("site_name", StringType(), True),
            StructField("sitecode", StringType(), True),
            StructField("opened_no_later_than", StringType(), True),
            StructField("tracking_closed_since", StringType(), True),
            StructField("eff_to", StringType(), True),
            StructField("eff_from", StringType(), True),
            StructField("hexagon_wkt", StringType(), True),
            StructField("location_radius", StringType(), True)
        ])
        safegraph_df = self.spark.createDataFrame([], schema)

        config_df = self.spark.createDataFrame([
            ("Store A", "123 Main St", "City1", "Province1", "12345", "Retail", "Grocery")
        ], ["poi_loc_name", "street_addr", "city", "province", "postal_code", "top_category", "sub_category"])

        result = self.ad_campaign.process_safegraph(safegraph_df, config_df)
        self.assertEqual(result.count(), 0)

    def test_empty_config_df(self):
        safegraph_df = self.spark.createDataFrame([
            ("Store A", "123 Main St", "City1", "Province1", "12345", "Retail", "Grocery", "Tag1,Tag2", 40.7128, -74.0060, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None)
        ], ["location_name", "street_address", "city", "province", "postal_code", "top_category", "sub_category", "category_tags", "latitude", "longitude", "location_id", "brands", "location_perimeter", "opened_on", "closed_on", "iso_country_code", "naics_code", "census_code", "hexagon_id", "cellid", "site_name", "sitecode", "opened_no_later_than", "tracking_closed_since", "eff_to", "eff_from", "hexagon_wkt", "location_radius"])

        schema = StructType([
            StructField("poi_loc_name", StringType(), True),
            StructField("street_addr", StringType(), True),
            StructField("city", StringType(), True),
            StructField("province", StringType(), True),
            StructField("postal_code", StringType(), True),
            StructField("top_category", StringType(), True),
            StructField("sub_category", StringType(), True)
        ])
        config_df = self.spark.createDataFrame([], schema)

        result = self.ad_campaign.process_safegraph(safegraph_df, config_df)
        self.assertEqual(result.count(), 0)

    def test_null_values_in_join_columns(self):
        safegraph_df = self.spark.createDataFrame([
            ("Store A", None, "City1", "Province1", "12345", "Retail", "Grocery", "Tag1,Tag2", 40.7128, -74.0060, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None)
        ], ["location_name", "street_address", "city", "province", "postal_code", "top_category", "sub_category", "category_tags", "latitude", "longitude", "location_id", "brands", "location_perimeter", "opened_on", "closed_on", "iso_country_code", "naics_code", "census_code", "hexagon_id", "cellid", "site_name", "sitecode", "opened_no_later_than", "tracking_closed_since", "eff_to", "eff_from", "hexagon_wkt", "location_radius"])

        config_df = self.spark.createDataFrame([
            ("Store A", None, "City1", "Province1", "12345", "Retail", "Grocery")
        ], ["poi_loc_name", "street_addr", "city", "province", "postal_code", "top_category", "sub_category"])

        result = self.ad_campaign.process_safegraph(safegraph_df, config_df)
        self.assertEqual(result.count(), 1)

if __name__ == '__main__':
    unittest.main(argv=['first-arg-is-ignored'], exit=False)

E.EEEEE
ERROR: test_empty_config_df (__main__.TestProcessSafegraph.test_empty_config_df)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/spark-d7ff5a85-4830-43b8-b111-8b/.ipykernel/133745/command-530765357514773-3738858712", line 141, in test_empty_config_df
    safegraph_df = self.spark.createDataFrame([
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/databricks/spark/python/pyspark/sql/connect/session.py", line 654, in createDataFrame
    raise PySparkValueError(
pyspark.errors.exceptions.base.PySparkValueError: [CANNOT_DETERMINE_TYPE] Some of types cannot be determined after inferring.

ERROR: test_multiple_matching_records (__main__.TestProcessSafegraph.test_multiple_matching_records)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/spark-d7ff5a85-4830-43b8-b111-8b/.ipykernel/133745/command-530765357514773-3738858712", line 74, in t

In [0]:
argv=['first-arg-is-ignored'], exit=False)
