In [0]:
# def write_read(spark, sdf, table_path):
#     """
#     Writes a Spark DataFrame to a Hive table at the specified path and then reads it back.

#     Parameters:
#     - spark: The Spark Session
#     - sdf (DataFrame): The Spark DataFrame to be written and read.
#     - table_path (str): The Hive table path where the DataFrame is to be written.

#     Returns:
#     - DataFrame: The Spark DataFrame that has been read from the Hive table.
#     Note:
#     This function overwrites any existing data at the table_path.
#     """

#     sdf.write.mode("overwrite").saveAsTable(table_path)
#     sdf = spark.read.table(table_path)

#     return sdf

# # def fill_with_zero(df, columns):
# #     """
# #     Fill specified columns in a DataFrame with zeros.

# #     Parameters:
# #         df: DataFrame - The input DataFrame
# #         columns: list - A list of column names to fill with zero

# #     Returns:
# #         DataFrame - The DataFrame with specified columns filled with zero
# #     """
# #     fill_dict = {col: 0 for col in columns}
# #     return df.na.fill(fill_dict)



In [0]:
# Date columns from which the Days_Since_* / Event_Occurred_* features are derived.
# "LastRenewalDate" must remain listed: get_features filters on the derived
# Days_Since_LastRenewalDate column.
selected_date_columns = [
    "AccountStatusChangedAt",
    "AutoRenewChangeDate",
    "GroupJoinedAt",
    "InactivityDate",
    "InfoRequestAt",
    "LastRenewalDate",
    "LostSimAt",
    # "PACRequestAt",   # currently excluded
    # "STACRequestAt",  # currently excluded
    "SalesDate",
    "LastPlanChangeAt",
]

# Usage columns feeding the volatility features: data, text and voice-minute usage,
# each over three windows. Order matters: create_volatility_features takes the
# three columns of each usage type, in list order, as weekly, monthly, six-monthly.
selected_volatility_features = [
    f"{usage_column}_{window}"
    for usage_column in (
        "PlanDataUKGB",        # Data
        "PlanTextUKCount",     # Text
        "PlanVoiceUKMinute",   # Minutes
    )
    for window in ("LastWeek", "LastMonth", "Last6Month")
]

In [0]:
def write_read(spark, sdf, table_path):
    """
    Saves a Spark DataFrame as a table in overwrite mode, then returns a DataFrame
    read back from that same table.

    Parameters:
    - spark: The Spark Session used to read the table back.
    - sdf (DataFrame): The Spark DataFrame to persist.
    - table_path (str): The table identifier passed to both saveAsTable and read.table.

    Returns:
    - DataFrame: The result of spark.read.table(table_path), i.e. a DataFrame backed by
      the saved table rather than by the input sdf.

    Note:
    The write uses mode("overwrite"), so any existing data in the table is replaced.
    """
    overwrite_writer = sdf.write.mode("overwrite")
    overwrite_writer.saveAsTable(table_path)

    # Hand back the persisted copy, not the DataFrame that was passed in.
    return spark.read.table(table_path)

def filter_permanent_live_customers(lead_time_sdf: DataFrame) -> DataFrame:
    """
    Keeps only the rows of established customers whose account is live.

    A row is kept when both conditions hold:
    - ActivationDate is on or before date_sub(current_date(), 90), i.e. the customer
      was activated at least 90 days before the date on which the query runs;
    - AccountStatus equals 'Live'.

    :param lead_time_sdf: Input DataFrame; must contain ActivationDate and AccountStatus columns.
    :return: The filtered DataFrame (same columns, subset of rows).
    """
    # Latest activation date that still counts as an established customer.
    activation_cutoff = date_sub(current_date(), 90)

    activated_before_cutoff = col("ActivationDate") <= activation_cutoff
    account_is_live = col('AccountStatus') == 'Live'

    return lead_time_sdf.filter(activated_before_cutoff & account_is_live)

def add_days_since_features(df: DataFrame, date_columns: list, snapshot_date_col: str) -> DataFrame:
    """
    Adds, for every column in date_columns, an event flag and a days-since-event column.

    The snapshot column is first passed through to_date (and overwritten in place), so it
    does not have to be of date type already. Then, for each name in date_columns:
    - Event_Occurred_<name>: 1 when the column is not null, otherwise 0.
    - Days_Since_<name>: datediff(snapshot date, to_date(<name>)) when the flag is 1,
      otherwise the sentinel -1.

    :param df: Input DataFrame containing the snapshot column and all date_columns.
    :param date_columns: List of column names that contain date information.
    :param snapshot_date_col: The name of the column containing the snapshot date.
    :return: DataFrame with the converted snapshot column and two new columns per date column.
    """
    # Normalise the snapshot column so datediff compares two dates.
    result = df.withColumn(snapshot_date_col, to_date(col(snapshot_date_col)))
    snapshot_date = col(snapshot_date_col)

    for source_col in date_columns:
        flag_col = f"Event_Occurred_{source_col}"
        elapsed_col = f"Days_Since_{source_col}"

        # 1/0 indicator of whether the event date is present.
        has_event = when(col(source_col).isNotNull(), 1).otherwise(0)
        result = result.withColumn(flag_col, has_event)

        # Days between the event and the snapshot; -1 marks "event never occurred".
        elapsed_days = datediff(snapshot_date, to_date(col(source_col)))
        days_or_sentinel = when(col(flag_col) == 1, elapsed_days).otherwise(-1)
        result = result.withColumn(elapsed_col, days_or_sentinel)

    return result


def _divide_or_null(numerator, denominator):
    """Returns numerator / denominator as a column expression, or null where denominator is 0 (or null)."""
    return when(denominator != 0, numerator / denominator).otherwise(None)


def create_volatility_features(df: DataFrame, selected_features: list) -> DataFrame:
    """
    Adds absolute-change, relative-change and ratio features comparing usage across time windows.

    For each usage type in ('Data', 'Text', 'Minute') the entries of selected_features whose
    name contains that string are collected. A type is processed only when exactly three
    entries match; they are taken, in list order, as the weekly, monthly and six-monthly
    columns. Types with any other number of matches are skipped with a printed message.

    Six columns are added per processed type (<T> is the type string):
    - <T>_Week_Month_AbsChange   = |weekly - monthly|
    - <T>_Month_6Month_AbsChange = |monthly - six_monthly|
    - <T>_Week_Month_PctChange   = <T>_Week_Month_AbsChange / monthly
    - <T>_Month_6Month_PctChange = <T>_Month_6Month_AbsChange / six_monthly
    - <T>_Week_Month_Ratio       = weekly / monthly
    - <T>_Month_6Month_Ratio     = monthly / six_monthly
    Divisions by a zero denominator yield null, which is then replaced by 0.

    Null handling: when at least one type is processed, df.fillna(0) is applied to the
    WHOLE DataFrame (not only the usage columns) once before the features are computed and
    once afterwards. Previously the fill ran inside the per-type loop, so the first type
    was computed from raw (nullable) inputs while later types were computed from
    zero-filled inputs; the same null input therefore produced different feature values
    depending on the type's position. Filling up front treats every type identically and
    keeps the derived features consistent with the zero-filled input columns that are
    returned alongside them. When no type is processed, df is returned untouched.

    :param df: Input DataFrame containing every column named in selected_features.
    :param selected_features: Usage column names; per type they must be ordered
        weekly, monthly, six-monthly.
    :return: DataFrame with the new volatility columns and nulls replaced by 0 wherever
        fillna(0) applies.
    """
    # Suffixes for the new column names
    abs_change_suffix = '_AbsChange'
    pct_change_suffix = '_PctChange'
    ratio_suffix = '_Ratio'

    # Resolve the (weekly, monthly, six-monthly) triple of every usage type first, so the
    # null handling below can be applied once and identically for all types.
    windows_by_type = {}
    for feature_type in ['Data', 'Text', 'Minute']:
        relevant_features = [feature for feature in selected_features if feature_type in feature]
        if len(relevant_features) == 3:
            windows_by_type[feature_type] = relevant_features
        else:
            print(f"Skipping {feature_type} as it does not have all three time frames (weekly, monthly, six-monthly).")

    if not windows_by_type:
        # Nothing to compute: leave the DataFrame exactly as received.
        return df

    # Zero-fill BEFORE computing so every usage type sees the same inputs.
    df = df.fillna(0)

    for feature_type, (weekly_feature, monthly_feature, six_monthly_feature) in windows_by_type.items():
        week_month = feature_type + '_Week_Month'
        month_6month = feature_type + '_Month_6Month'

        # Absolute changes.
        # NOTE(review): abs() is applied to a Spark column expression, so it is presumably
        # pyspark.sql.functions.abs imported elsewhere in this file — confirm.
        df = df.withColumn(week_month + abs_change_suffix,
                           abs(col(weekly_feature) - col(monthly_feature)))
        df = df.withColumn(month_6month + abs_change_suffix,
                           abs(col(monthly_feature) - col(six_monthly_feature)))

        # Relative changes: absolute change scaled by the longer window's value.
        df = df.withColumn(week_month + pct_change_suffix,
                           _divide_or_null(col(week_month + abs_change_suffix), col(monthly_feature)))
        df = df.withColumn(month_6month + pct_change_suffix,
                           _divide_or_null(col(month_6month + abs_change_suffix), col(six_monthly_feature)))

        # Ratios of the shorter window to the longer one.
        df = df.withColumn(week_month + ratio_suffix,
                           _divide_or_null(col(weekly_feature), col(monthly_feature)))
        df = df.withColumn(month_6month + ratio_suffix,
                           _divide_or_null(col(monthly_feature), col(six_monthly_feature)))

    # Replace the nulls produced by zero denominators (and any others) with zeros.
    return df.fillna(0)

def get_features(data: DataFrame, snapshot_date_col: str):
    """
    Builds the curated feature set from the raw customer DataFrame.

    Steps, in order:
    1. filter_permanent_live_customers: keep established customers with a live account.
    2. add_days_since_features: add Days_Since_* / Event_Occurred_* columns for every
       column in the module-level selected_date_columns list.
    3. create_volatility_features: add usage volatility columns for the module-level
       selected_volatility_features list.
    4. Keep only rows whose Days_Since_LastRenewalDate is 17 or more. Rows without a
       LastRenewalDate carry the sentinel -1 from step 2 and are therefore dropped.

    Parameters:
    - data: The input Spark DataFrame to be processed.
    - snapshot_date_col: The name of the column containing the snapshot date.

    Returns:
    - DataFrame: The processed Spark DataFrame. It is not persisted by this function;
      the write_read round-trip below is commented out.
    """
    established_live = filter_permanent_live_customers(data)

    with_days_since = add_days_since_features(established_live, selected_date_columns, snapshot_date_col)

    with_volatility = create_volatility_features(with_days_since, selected_volatility_features)

    # Customers must be at least 17 days past their last renewal.
    renewal_filtered = with_volatility.filter(col("Days_Since_LastRenewalDate") >= 17)

    # Write to Hive table and read back (currently disabled)
    #final_data = write_read(spark, renewal_filtered, f"{table_output_schema}.processed_data")

    return renewal_filtered