In [None]:
import pandas as pd
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, sum as _sum, lit, expr, when
from pyspark.sql.window import Window

# Initialize Spark session (getOrCreate reuses an active session if there is one)
spark = SparkSession.builder.appName("MetricCalculation").getOrCreate()

# Column names of the sample frame: the first seven are the period/segment
# keys, every remaining column is a numeric measure.
columns = ["per_num", "level_2_nam", "level_3_nam", "mob_seg", "trana_ind", "fico_seg", "argus_wallet_flag",
           "num_Accts", "sum_Balcon_resp_3M_ind", "sum_Flex_resp_3M_ind", "sum_SG_Resp_3M_ind", "sum_RS_Resp_3M_ind",
           "sum_RS_DM_3M_ind", "sum_SG_DM_3M_ind", "sum_Balcon_DM_3M_ind", "sum_Flex_Pay_DM_3M_ind",
           "sum_Balcon_sentcount_3M_ind", "sum_Flex_Pay_sentcount_3M_ind", "sum_SG_sentcount_3M_ind", "sum_RS_sentcount_3M_ind",
           "sum_SG_inviews_3M_ind", "sum_RS_inviews_3M_ind", "sum_Balcon_inviews_3M_ind", "sum_Flex_loan_inviews_3M_ind",
           "sum_Balcon_impr_3M_ind", "sum_Flex_loan_impr_3M_ind", "sum_RS_impr_3M_ind", "sum_SG_impr_3M_ind",
           "sum_Flex_transfer_3M", "sum_Balcon_ecm_transfer_3M", "sum_Flex_durtn_3M", "sum_Balcon_ecm_durtn_3M",
           "sum_Balcon_ecm_totl_fee_3M", "sum_Flex_pay_std_fee_3M", "sum_Flex_pay_pln_fee_3M", "sum_Flex_weighted_apr_3M",
           "sum_Balcon_ecm_weighted_apr_3M", "sum_Flex_pay_amt_3M", "sum_Flex_pay_weighted_apr_3M", "sum_sg_promo_sales_3M",
           "sum_rs_promo_sales_3M", "sum_Targeted"]

# Sample data: a single row whose key values line up with the first seven
# entries of `columns`; every measure column is filled with 1.
# The number of measure values is derived from `columns` instead of being
# typed out by hand, so the row can never be wider or narrower than the
# list of column names.
_sample_keys = (6482, "AA", "AA Executive", "2-5 Years", "R", "<660", "LW")
data = [
    _sample_keys + (1,) * (len(columns) - len(_sample_keys))
]

df_1 = spark.createDataFrame(data, columns)

# Run parameters.
# `per_num` and `timeline` share their names with metric_calc's parameters,
# where the comparison period is taken as `per_num - timeline`
# (presumably these module-level values are what gets passed in).
timeline, per_num = 12, 6482

# Segment values used to narrow the data set.
portfolio, product = "AA", "AA Executive"
mob = "2-5 Years"
fico_seg = "<660"
argus_wallet_flag = "LW"
trana_ind = "R"

# Column name -> required value. Key order is kept because metric_calc
# derives its column ordering from the keyword arguments it receives.
filter_args = dict(
    level_2_nam=portfolio,
    level_3_nam=product,
    mob_seg=mob,
    trana_ind=trana_ind,
    fico_seg=fico_seg,
    argus_wallet_flag=argus_wallet_flag,
)

# Indicator columns (names ending in `_3M_ind`). metric_calc selects these
# and replaces each with its per-partition window sum.
var_list = [
    # *_resp_* columns
    "sum_Balcon_resp_3M_ind",
    "sum_Flex_resp_3M_ind",
    "sum_SG_Resp_3M_ind",
    "sum_RS_Resp_3M_ind",
    # *_DM_* columns
    "sum_RS_DM_3M_ind",
    "sum_SG_DM_3M_ind",
    "sum_Balcon_DM_3M_ind",
    "sum_Flex_Pay_DM_3M_ind",
    # *_sentcount_* columns
    "sum_Balcon_sentcount_3M_ind",
    "sum_Flex_Pay_sentcount_3M_ind",
    "sum_SG_sentcount_3M_ind",
    "sum_RS_sentcount_3M_ind",
    # *_inviews_* columns
    "sum_SG_inviews_3M_ind",
    "sum_RS_inviews_3M_ind",
    "sum_Balcon_inviews_3M_ind",
    "sum_Flex_loan_inviews_3M_ind",
    # *_impr_* columns
    "sum_Balcon_impr_3M_ind",
    "sum_Flex_loan_impr_3M_ind",
    "sum_RS_impr_3M_ind",
    "sum_SG_impr_3M_ind",
]

# Amount-style columns (names ending in `_3M`). metric_calc selects these
# after `var_list` and window-sums them in the same loop.
metric_cols = [
    "sum_Flex_transfer_3M",
    "sum_Balcon_ecm_transfer_3M",
    "sum_Flex_durtn_3M",
    "sum_Balcon_ecm_durtn_3M",
    "sum_Balcon_ecm_totl_fee_3M",
    "sum_Flex_pay_std_fee_3M",
    "sum_Flex_pay_pln_fee_3M",
    "sum_Flex_weighted_apr_3M",
    "sum_Balcon_ecm_weighted_apr_3M",
    "sum_Flex_pay_amt_3M",
    "sum_Flex_pay_weighted_apr_3M",
    "sum_sg_promo_sales_3M",
    "sum_rs_promo_sales_3M",
]

# Define metric calculation function
def metric_calc(df_2, per_num, timeline, **kwargs):
    filter_args_col = [key for key, value in kwargs.items() if value is not None]
    
    for key, value in kwargs.items():
        if value is not None:
            df_2 = df_2.filter(col(key) == value)
    
    df_curr_grouped = df_2.select(["per_num"] + filter_args_col + ["num_Accts"] + var_list + ["sum_Targeted"] + metric_cols)
    df_curr_grouped = df_curr_grouped.orderBy(["per_num"] + filter_args_col)
    
    df_curr_grouped_1 = df_curr_grouped.filter(col("per_num") == per_num)
    df_prev_grouped_1 = df_curr_grouped.filter(col("per_num") == (per_num - timeline))
    
    for col_name in var_list + metric_cols:
        df_curr_grouped_1 = df_curr_grouped_1.withColumn(col_name, _sum(col_name).over(Window.partitionBy("per_num", *filter_args_col)))
        df_prev_grouped_1 = df_prev_grouped_1.withColumn(col_name, _sum(col_name).over(Window.partitionBy("per_num", *filter_args_col)))
    
    df_curr_grouped_2 = df_curr_grouped_1.withColumn("Balcon_enrollment", _sum("sum_Balcon_resp_3M_ind").over(Window.partitionBy("per_num", *filter_args_col)))
    df_curr_grouped_2 = df_curr_grouped_2.withColumn("Flex_enrollment", _sum("sum_Flex_resp_3M_ind").over(Window.partitionBy("per_num", *filter_args_col)))
    df_curr_grouped_2 = df_curr_grouped_2.withColumn("SG_enrollment", _sum("sum_SG_Resp_3M_ind").over(Window.partitionBy("per_num", *filter_args_col)))
    df_curr_grouped_2 = df_curr_grouped_2.withColumn("RS_enrollment", _sum("sum_RS_Resp_3M_ind").over(Window.partitionBy("per_num", *filter_args_col)))
    df_curr_grouped_2 = df_curr_grouped_2.withColumn("Flex_Pay_enrollment", _sum("sum_Flex_pay_resp_3M_ind").over(Window.partitionBy("per_num", *filter_args_col)))
    
    df_prev_grouped_2 = df_prev_grouped_1.withColumn("Total_enrollment", 
                        _sum("sum_Balcon_resp_3M_ind").over(Window.partitionBy("per_num", *filter_args_col)) + 
                        _sum("sum_Flex_resp_3M_ind").over(Window.partitionBy("per_num", *filter_args_col)) + 
                        _sum("sum_SG_Resp_3M_ind").over(Window.partitionBy("per_num", *filter_args_col)) + 
                        _sum("sum_RS_Resp_3M_ind").over(Window.partitionBy("per_num", *filter_args_col)) + 
                        _sum("sum_Flex_pay_resp_3M_ind").over(Window.partitionBy("per_num", *filter_args_col)))

    df_prev_grouped_2 = df_prev_grouped_2.withColumn("Total_targeting", _sum("sum_Targeted").over(Window.partitionBy("per_num", *filter_args_col)))

    df_curr_grouped_2 = df_curr_grouped_2.withColumn("Balcon_targeting", 
                    _sum("sum_Balcon_DM_3M_ind").over(Window.partitionBy("per_num", *filter_args_col)) + 
                    _sum("sum_Balcon_sentcount_3M_ind").over(Window.partitionBy("per_num", *filter_args_col)) + 
                    _sum("sum_Balcon_impr_3M_ind").over(Window.partitionBy("per_num", *filter_args_col)) + 
                    _sum("sum_Balcon_inviews_3M_ind").over(Window.partitionBy("per_num", *filter_args_col)) + 
                    _sum("sum_Balcon_impr_3M_ind").over(Window.partitionBy("per_num", *filter_args_col)))

    df_curr_grouped_2 = df_curr_grouped_2.withColumn("Flex_targeting", 
                    _sum("sum_Flex_pay_DM_3M_ind").over(Window.partitionBy("per_num", *filter_args_col)) + 
                    _sum("sum_Flex_pay_sentcount_3M_ind").over(Window.partitionBy("per_num", *filter_args_col)) + 
                    _sum("sum_Flex_pay_impr_3M_ind").over(Window.partitionBy("per_num", *filter_args_col)) + 
                    _sum("sum_Flex_pay_inviews_3M_ind").over(Window.partitionBy("per_num", *filter_args_col)) + 
                    _sum("sum_Flex_loan_incentives_3M_ind").over(Window.partitionBy("per_num", *filter_args_col)))

    df_curr_grouped_2 = df_curr_grouped_2.withColumn("SG_targeting", 
                    _sum("sum_SG_DM_3M_ind").over(Window.partitionBy("per_num", *filter_args_col)) + 
                    _sum("sum_SG_sentcount_3M_ind").over(Window.partitionBy("per_num", *filter_args_col)) + 
                    _sum("sum_SG_impr_3M_ind").over(Window.partitionBy("per_num", *filter_args_col)) + 
                    _sum("sum_SG_inviews_3M_ind").over(Window.partitionBy("per_num", *filter_args_col)))

    df_curr_grouped_2 = df_curr_grouped_2.withColumn("RS_targeting", 
                    _sum("sum_RS_DM_3M_ind").over(Window.partitionBy("per_num", *filter_args_col)) + 
                    _sum("sum_RS_sentcount_3M_ind").over(Window.partitionBy("per_num", *filter_args_col)) + 
                    _sum("sum_RS_impr_3M_ind").over(Window.partitionBy("per_num", *filter_args_col)) + 
                    _sum("sum_RS_inviews_3M_ind").over(Window.partitionBy("per_num", *filter_args_col)))

    df_curr_grouped_2 = df_curr_grouped_2.withColumn("Flex_Pay_targeting", 
                    _sum("sum_Flex_pay_DM_3M_ind").over(Window.partitionBy("per_num", *filter_args_col)) + 
                    _sum("sum_Flex_pay_sentcount_3M_ind").over(Window.partitionBy("per_num", *filter_args_col)) + 
                    _sum("sum_Flex_pay_inviews_3M_ind").over(Window.partitionBy("per_num", *filter_args_col)) + 
                    _sum("sum_Flex_pay_impr_3M_ind").over(Window.partitionBy("per_num", *filter_args_col)))

    df_curr_grouped_2 = df_curr_grouped_2.withColumn("Balcon_enrollment_per_targeting", 
                    col("sum_Balcon_resp_3M_ind") / 
                    (_sum("sum_Balcon_DM_3M_ind").over(Window.partitionBy("per_num", *filter_args_col)) + 
                    _sum("sum_Balcon_sentcount_3M_ind").over(Window.partitionBy("per_num", *filter_args_col)) + 
                    _sum("sum_Balcon_inviews_3M_ind").over(Window.partitionBy("per_num", *filter_args_col)) + 
                    _sum("sum_Balcon_impr_3M_ind").over(Window.partitionBy("per_num", *filter_args_col))))

    df_curr_grouped_2 = df_curr_grouped_2.withColumn("Flex_Loan_enrollment_per_targeting", 
                    col("sum_Flex_resp_3M_ind") / 
                    (_sum("sum_Flex_loan_inviews_3M_ind").over(Window.partitionBy("per_num", *filter_args_col)) + 
                    _sum("sum_Flex_loan_sentcount_3M_ind").over(Window.partitionBy("per_num", *filter_args_col)) + 
                    _sum("sum_Flex_loan_impr_3M_ind").over(Window.partitionBy("per_num", *filter_args_col))))

    df_curr_grouped_2 = df_curr_grouped_2.withColumn("SG_enrollment_per_targeting", 
                    col("sum_SG_Resp_3M_ind") / 
                    (_sum("sum_SG_DM_3M_ind").over(Window.partitionBy("per_num", *filter_args_col)) + 
                    _sum("sum_SG_sentcount_3M_ind").over(Window.partitionBy("per_num", *filter_args_col)) + 
                    _sum("sum_SG_inviews_3M_ind").over(Window.partitionBy("per_num", *filter_args_col)) + 
                    _sum("sum_SG_impr_3M_ind").over(Window.partitionBy("per_num", *filter_args_col))))

    df_curr_grouped_2 = df_curr_grouped_2.withColumn("RS_enrollment_per_targeting", 
                    col("sum_RS_Resp_3M_ind") / 
                    (_sum("sum_RS_DM_3M_ind").over(Window.partitionBy("per_num", *filter_args_col)) + 
                    _sum("sum_RS_sentcount_3M_ind").over(Window.partitionBy("per_num", *filter_args_col)) + 
                    _sum("sum_RS_inviews_3M_ind").over(Window.partitionBy("per_num", *filter_args_col)) + 
                    _sum("sum_RS_impr_3M_ind").over(Window.partitionBy("per_num", *filter_args_col))))

    df_curr_grouped_2 = df_curr_grouped_2.withColumn("Flex_pay_enrollment_per_targeting", 
                    col("sum_Flex_pay_resp_3M_ind") / 
                    (_sum("sum_Flex_pay_DM_3M_ind").over(Window.partitionBy("per_num", *filter_args_col)) + 
                    _sum("sum_Flex_pay_sentcount_3M_ind").over(Window.partitionBy("per_num", *filter_args_col)) + 
                    _sum("sum_Flex_pay_inviews_3M_ind").over(Window.partitionBy("per_num", *filter_args_col)) + 
                    _sum("sum_Flex_pay_impr_3M_ind").over(Window.partitionBy("per_num", *filter_args_col))))

    df_curr_grouped_2 = df_curr_grouped_2.withColumn("Total_enrollment_per_targeting", 
                    (col("sum_Balcon_resp_3M_ind") + col("sum_Flex_resp_3M_ind") + 
                     col("sum_SG_Resp_3M_ind") + col("sum_RS_Resp_3M_ind") + 
                     col("sum_Flex_pay_resp_3M_ind")) / 
                    _sum("sum_Targeted").over(Window.partitionBy("per_num", *filter_args_col)))

    df_curr_grouped_2 = df_curr_grouped_2.withColumn("avg_Balcon_ecm_transfer_3M", 
                    _sum("sum_Balcon_ecm_transfer_3M").over(Window.partitionBy("per_num", *filter_args_col)) / 
                    col("sum_Balcon_resp_3M_ind"))

    df_curr_grouped_2 = df_curr_grouped_2.withColumn("avg_Balcon_ecm_durtn_3M", 
                    _sum("sum_Balcon_ecm_durtn_3M").over(Window.partitionBy("per_num", *filter_args_col)) / 
                    col("sum_Balcon_resp_3M_ind"))

    df_curr_grouped_2 = df_curr_grouped_2.withColumn("avg_Balcon_ecm_totl_fee_3M", 
                    _sum("sum_Balcon_ecm_totl_fee_3M").over(Window.partitionBy("per_num", *filter_args_col)) / 
                    col("sum_Balcon_resp_3M_ind"))

    df_curr_grouped_2 = df_curr_grouped_2.withColumn("avg_Balcon_ecm_weighted_apr_3M", 
                    _sum("sum_Balcon_ecm_weighted_apr_3M").over(Window.partitionBy("per_num", *filter_args_col)) / 
                    col("sum_Balcon_resp_3M_ind"))

    df_curr_grouped_2 = df_curr_grouped_2.withColumn("avg_Flex_transfer_3M", 
                    _sum("sum_Flex_transfer_3M").over(Window.partitionBy("per_num", *filter_args_col)) / 
                    col("sum_Flex_resp_3M_ind"))

    df_curr_grouped_2 = df_curr_grouped_2.withColumn("avg_Flex_durtn_3M", 
                    _sum("sum_Flex_durtn_3M").over(Window.partitionBy("per_num", *filter_args_col)) / 
                    col("sum_Flex_resp_3M_ind"))

    df_curr_grouped_2 = df_curr_grouped_2.withColumn("avg_Flex_totl_fee_3M", 
                    _sum("sum_Flex_totl_fee_3M").over(Window.partitionBy("per_num", *filter_args_col)) / 
                    col("sum_Flex_resp_3M_ind"))

    df_curr_grouped_2 = df_curr_grouped_2.withColumn("avg_Flex_weighted_apr_3M", 
                    _sum("sum_Flex_weighted_apr_3M").over(Window.partitionBy("per_num", *filter_args_col)) / 
                    col("sum_Flex_resp_3M_ind"))

    df_curr_grouped_2 = df_curr_grouped_2.withColumn("avg_Flex_pay_amt_3M", 
                    _sum("sum_Flex_pay_amt_3M").over(Window.partitionBy("per_num", *filter_args_col)) / 
                    col("sum_Flex_resp_3M_ind"))

    df_curr_grouped_2 = df_curr_grouped_2.withColumn("avg_Flex_pay_durtn_3M", 
                    _sum("sum_Flex_pay_durtn_3M").over(Window.partitionBy("per_num", *filter_args_col)) / 
                    col("sum_Flex_resp_3M_ind"))

    df_curr_grouped_2 = df_curr_grouped_2.withColumn("avg_Flex_pay_std_fee_3M", 
                    _sum("sum_Flex_pay_std_fee_3M").over(Window.partitionBy("per_num", *filter_args_col)) / 
                    col("sum_Flex_resp_3M_ind"))

    df_curr_grouped_2 = df_curr_grouped_2.withColumn("avg_Flex_pay_pln_fee_3M", 
                    _sum("sum_Flex_pay_pln_fee_3M").over(Window.partitionBy("per_num", *filter_args_col)) / 
                    col("sum_Flex_resp_3M_ind"))

    df_curr_grouped_2 = df_curr_grouped_2.withColumn("avg_Flex_pay_weighted_apr_3M", 
                    _sum("sum_Flex_pay_weighted_apr_3M").over(Window.partitionBy("per_num", *filter_args_col)) / 
                    col("sum_Flex_resp_3M_ind"))

    df_curr_grouped_2 = df_curr_grouped_2.withColumn("avg_sg_promo_sales_3M", 
                    _sum("sum_sg_promo_sales_3M").over(Window.partitionBy("per_num", *filter_args_col)) / 
                    col("sum_SG_Resp_3M_ind"))

    df_curr_grouped_2 = df_curr_grouped_2.withColumn("avg_rs_promo_sales_3M", 
                    _sum("sum_rs_promo_sales_3M").over(Window.partitionBy("per_num", *filter_args_col)) / 
                    col("sum_RS_Resp_3M_ind"))

    df_prev_grouped_2 = df_prev_grouped_2.withColumn("Balcon_enrollment", 
                    _sum("sum_Balcon_resp_3M_ind").over(Window.partitionBy("per_num", *filter_args_col)))

    df_prev_grouped_2 = df_prev_grouped_2.withColumn("Flex_loan_enrollment", 
                    _sum("sum_Flex_resp_3M_ind").over(Window.partitionBy("per_num", *filter_args_col)))

    df_prev_grouped_2 = df_prev_grouped_2.withColumn("SG_enrollment", 
                    _sum("sum_SG_Resp_3M_ind").over(Window.partitionBy("per_num", *filter_args_col)))

    df_prev_grouped_2 = df_prev_grouped_2.withColumn("RS_enrollment", 
                    _sum("sum_RS_Resp_3M_ind").over(Window.partitionBy("per_num", *filter_args_col)))

    df_prev_grouped_2 = df_prev_grouped_2.withColumn("Flex_pay_enrollment", 
                    _sum("sum_Flex_pay_resp_3M_ind").over(Window.partitionBy("per_num", *filter_args_col)))

    df_prev_grouped_2 = df_prev_grouped_2.withColumn("Balcon_enrollment_per_targeting", 
                    col("sum_Balcon_resp_3M_ind") / 
                    (_sum("sum_Balcon_DM_3M_ind").over(Window.partitionBy("per_num", *filter_args_col)) + 
                    _sum("sum_Balcon_sentcount_3M_ind").over(Window.partitionBy("per_num", *filter_args_col)) + 
                    _sum("sum_Balcon_inviews_3M_ind").over(Window.partitionBy("per_num", *filter_args_col)) + 
                    _sum("sum_Balcon_impr_3M_ind").over(Window.partitionBy("per_num", *filter_args_col))))

    df_prev_grouped_2 = df_prev_grouped_2.withColumn("Flex_loan_enrollment_per_targeting", 
                    col("sum_Flex_resp_3M_ind") / 
                    (_sum("sum_Flex_loan_DM_3M_ind").over(Window.partitionBy("per_num", *filter_args_col)) + 
                    _sum("sum_Flex_loan_sentcount_3M_ind").over(Window.partitionBy("per_num", *filter_args_col)) + 
                    _sum("sum_Flex_loan_inviews_3M_ind").over(Window.partitionBy("per_num", *filter_args_col)) + 
                    _sum("sum_Flex_loan_impr_3M_ind").over(Window.partitionBy("per_num", *filter_args_col))))

    df_prev_grouped_2 = df_prev_grouped_2.withColumn("SG_enrollment_per_targeting", 
                    col("sum_SG_Resp_3M_ind") / 
                    (_sum("sum_SG_DM_3M_ind").over(Window.partitionBy("per_num", *filter_args_col)) + 
                    _sum("sum_SG_sentcount_3M_ind").over(Window.partitionBy("per_num", *filter_args_col)) + 
                    _sum("sum_SG_inviews_3M_ind").over(Window.partitionBy("per_num", *filter_args_col)) + 
                    _sum("sum_SG_impr_3M_ind").over(Window.partitionBy("per_num", *filter_args_col))))

    df_prev_grouped_2 = df_prev_grouped_2.withColumn("RS_enrollment_per_targeting", 
                    col("sum_RS_Resp_3M_ind") / 
                    (_sum("sum_RS_DM_3M_ind").over(Window.partitionBy("per_num", *filter_args_col)) + 
                    _sum("sum_RS_sentcount_3M_ind").over(Window.partitionBy("per_num", *filter_args_col)) + 
                    _sum("sum_RS_inviews_3M_ind").over(Window.partitionBy("per_num", *filter_args_col)) + 
                    _sum("sum_RS_impr_3M_ind").over(Window.partitionBy("per_num", *filter_args_col))))

    df_prev_grouped_2 = df_prev_grouped_2.withColumn("Flex_pay_enrollment_per_targeting", 
                    col("sum_Flex_pay_resp_3M_ind") / 
                    (_sum("sum_Flex_pay_DM_3M_ind").over(Window.partitionBy("per_num", *filter_args_col)) + 
                    _sum("sum_Flex_pay_sentcount_3M_ind").over(Window.partitionBy("per_num", *filter_args_col)) + 
                    _sum("sum_Flex_pay_inviews_3M_ind").over(Window.partitionBy("per_num", *filter_args_col)) + 
                    _sum("sum_Flex_pay_impr_3M_ind").over(Window.partitionBy("per_num", *filter_args_col))))

    df_prev_grouped_2 = df_prev_grouped_2.withColumn("Balcon_targeting", 
                    _sum("sum_Balcon_DM_3M_ind").over(Window.partitionBy("per_num", *filter_args_col)) + 
                    _sum("sum_Balcon_sentcount_3M_ind").over(Window.partitionBy("per_num", *filter_args_col)) + 
                    _sum("sum_Balcon_inviews_3M_ind").over(Window.partitionBy("per_num", *filter_args_col)) + 
                    _sum("sum_Balcon_impr_3M_ind").over(Window.partitionBy("per_num", *filter_args_col)))

    df_prev_grouped_2 = df_prev_grouped_2.withColumn("Flex_targeting", 
                    _sum("sum_Flex_DM_3M_ind").over(Window.partitionBy("per_num", *filter_args_col)) + 
                    _sum("sum_Flex_sentcount_3M_ind").over(Window.partitionBy("per_num", *filter_args_col)) + 
                    _sum("sum_Flex_inviews_3M_ind").over(Window.partitionBy("per_num", *filter_args_col)) + 
                    _sum("sum_Flex_impr_3M_ind").over(Window.partitionBy("per_num", *filter_args_col)))

    df_prev_grouped_2 = df_prev_grouped_2.withColumn("SG_targeting", 
                    _sum("sum_SG_DM_3M_ind").over(Window.partitionBy("per_num", *filter_args_col)) + 
                    _sum("sum_SG_sentcount_3M_ind").over(Window.partitionBy("per_num", *filter_args_col)) + 
                    _sum("sum_SG_inviews_3M_ind").over(Window.partitionBy("per_num", *filter_args_col)) + 
                    _sum("sum_SG_impr_3M_ind").over(Window.partitionBy("per_num", *filter_args_col)))

    df_prev_grouped_2 = df_prev_grouped_2.withColumn("RS_targeting", 
                    _sum("sum_RS_DM_3M_ind").over(Window.partitionBy("per_num", *filter_args_col)) + 
                    _sum("sum_RS_sentcount_3M_ind").over(Window.partitionBy("per_num", *filter_args_col)) + 
                    _sum("sum_RS_inviews_3M_ind").over(Window.partitionBy("per_num", *filter_args_col)) + 
                    _sum("sum_RS_impr_3M_ind").over(Window.partitionBy("per_num", *filter_args_col)))

    df_prev_grouped_2 = df_prev_grouped_2.withColumn("Flex_Pay_targeting", 
                    _sum("sum_Flex_Pay_DM_3M_ind").over(Window.partitionBy("per_num", *filter_args_col)) + 
                    _sum("sum_Flex_Pay_sentcount_3M_ind").over(Window.partitionBy("per_num", *filter_args_col)) + 
                    _sum("sum_Flex_Pay_inviews_3M_ind").over(Window.partitionBy("per_num", *filter_args_col)) + 
                    _sum("sum_Flex_Pay_impr_3M_ind").over(Window.partitionBy("per_num", *filter_args_col)))

    df_prev_grouped_2 = df_prev_grouped_2.withColumn("Total_targeting", 
                    _sum("sum_Targeted").over(Window.partitionBy("per_num", *filter_args_col)))

    df_prev_grouped_2 = df_prev_grouped_2.withColumn("avg_Balcon_ecm_transfer_3M", 
                    _sum("sum_Balcon_ecm_transfer_3M").over(Window.partitionBy("per_num", *filter_args_col)) / 
                    col("sum_Balcon_resp_3M_ind"))

    df_prev_grouped_2 = df_prev_grouped_2.withColumn("avg_Balcon_ecm_durtn_3M", 
                    _sum("sum_Balcon_ecm_durtn_3M").over(Window.partitionBy("per_num", *filter_args_col)) / 
                    col("sum_Balcon_resp_3M_ind"))

    df_prev_grouped_2 = df_prev_grouped_2.withColumn("avg_Balcon_ecm_totl_fee_3M", 
                    _sum("sum_Balcon_ecm_totl_fee_3M").over(Window.partitionBy("per_num", *filter_args_col)) / 
                    col("sum_Balcon_resp_3M_ind"))

    df_prev_grouped_2 = df_prev_grouped_2.withColumn("avg_Balcon_ecm_weighted_apr_3M", 
                    _sum("sum_Balcon_ecm_weighted_apr_3M").over(Window.partitionBy("per_num", *filter_args_col)) / 
                    col("sum_Balcon_resp_3M_ind"))

    df_prev_grouped_2 = df_prev_grouped_2.withColumn("avg_Flex_transfer_3M", 
                    _sum("sum_Flex_transfer_3M").over(Window.partitionBy("per_num", *filter_args_col)) / 
                    col("sum_Flex_resp_3M_ind"))

    df_prev_grouped_2 = df_prev_grouped_2.withColumn("avg_Flex_durtn_3M", 
                    _sum("sum_Flex_durtn_3M").over(Window.partitionBy("per_num", *filter_args_col)) / 
                    col("sum_Flex_resp_3M_ind"))

    df_prev_grouped_2 = df_prev_grouped_2.withColumn("avg_Flex_totl_fee_3M", 
                    _sum("sum_Flex_totl_fee_3M").over(Window.partitionBy("per_num", *filter_args_col)) / 
                    col("sum_Flex_resp_3M_ind"))

    df_prev_grouped_2 = df_prev_grouped_2.withColumn("avg_Flex_weighted_apr_3M", 
                    _sum("sum_Flex_weighted_apr_3M").over(Window.partitionBy("per_num", *filter_args_col)) / 
                    col("sum_Flex_resp_3M_ind"))

    df_prev_grouped_2 = df_prev_grouped_2.withColumn("avg_Flex_pay_amt_3M", 
                    _sum("sum_Flex_pay_amt_3M").over(Window.partitionBy("per_num", *filter_args_col)) / 
                    col("sum_Flex_resp_3M_ind"))

    df_prev_grouped_2 = df_prev_grouped_2.withColumn("avg_Flex_pay_durtn_3M", 
                    _sum("sum_Flex_pay_durtn_3M").over(Window.partitionBy("per_num", *filter_args_col)) / 
                    col("sum_Flex_resp_3M_ind"))

    df_prev_grouped_2 = df_prev_grouped_2.withColumn("avg_Flex_pay_std_fee_3M", 
                    _sum("sum_Flex_pay_std_fee_3M").over(Window.partitionBy("per_num", *filter_args_col)) / 
                    col("sum_Flex_resp_3M_ind"))

    df_prev_grouped_2 = df_prev_grouped_2.withColumn("avg_Flex_pay_pln_fee_3M", 
                    _sum("sum_Flex_pay_pln_fee_3M").over(Window.partitionBy("per_num", *filter_args_col)) / 
                    col("sum_Flex_resp_3M_ind"))

    df_prev_grouped_2 = df_prev_grouped_2.withColumn("avg_Flex_pay_weighted_apr_3M", 
                    _sum("sum_Flex_pay_weighted_apr_3M").over(Window.partitionBy("per_num", *filter_args_col)) / 
                    col("sum_Flex_pay_resp_3M_ind"))

    df_prev_grouped_2 = df_prev_grouped_2.withColumn("avg_sg_promo_sales_3M", 
                    _sum("sum_sg_promo_sales_3M").over(Window.partitionBy("per_num", *filter_args_col)) / 
                    col("sum_SG_Resp_3M_ind"))

    df_prev_grouped_2 = df_prev_grouped_2.withColumn("avg_rs_promo_sales_3M", 
                    _sum("sum_rs_promo_sales_3M").over(Window.partitionBy("per_num", *filter_args_col)) / 
                    col("sum_RS_Resp_3M_ind"))

    df_prev_grouped_2 = df_prev_grouped_2.withColumn("sum_Balcon_ecm_transfer_3M", 
                    _sum("sum_Balcon_ecm_transfer_3M").over(Window.partitionBy("per_num", *filter_args_col)))

    df_prev_grouped_2 = df_prev_grouped_2.withColumn("sum_Balcon_ecm_durtn_3M", 
                    _sum("sum_Balcon_ecm_durtn_3M").over(Window.partitionBy("per_num", *filter_args_col)))

    df_prev_grouped_2 = df_prev_grouped_2.withColumn("sum_Balcon_ecm_totl_fee_3M", 
                    _sum("sum_Balcon_ecm_totl_fee_3M").over(Window.partitionBy("per_num", *filter_args_col)))

    df_prev_grouped_2 = df_prev_grouped_2.withColumn("sum_Balcon_ecm_weighted_apr_3M", 
                    _sum("sum_Balcon_ecm_weighted_apr_3M").over(Window.partitionBy("per_num", *filter_args_col)))

    df_prev_grouped_2 = df_prev_grouped_2.withColumn("sum_Flex_transfer_3M", 
                    _sum("sum_Flex_transfer_3M").over(Window.partitionBy("per_num", *filter_args_col)))

    df_prev_grouped_2 = df_prev_grouped_2.withColumn("sum_Flex_durtn_3M", 
                    _sum("sum_Flex_durtn_3M").over(Window.partitionBy("per_num", *filter_args_col)))

    df_prev_grouped_2 = df_prev_grouped_2.withColumn("sum_Flex_totl_fee_3M", 
                    _sum("sum_Flex_totl_fee_3M").over(Window.partitionBy("per_num", *filter_args_col)))

    df_prev_grouped_2 = df_prev_grouped_2.withColumn("sum_Flex_weighted_apr_3M", 
                    _sum("sum_Flex_weighted_apr_3M").over(Window.partitionBy("per_num", *filter_args_col)))

    df_prev_grouped_2 = df_prev_grouped_2.withColumn("sum_Flex_pay_amt_3M", 
                    _sum("sum_Flex_pay_amt_3M").over(Window.partitionBy("per_num", *filter_args_col)))

    df_prev_grouped_2 = df_prev_grouped_2.withColumn("sum_Flex_pay_durtn_3M", 
                    _sum("sum_Flex_pay_durtn_3M").over(Window.partitionBy("per_num", *filter_args_col)))

    df_prev_grouped_2 = df_prev_grouped_2.withColumn("sum_Flex_pay_std_fee_3M", 
                    _sum("sum_Flex_pay_std_fee_3M").over(Window.partitionBy("per_num", *filter_args_col)))

    df_prev_grouped_2 = df_prev_grouped_2.withColumn("sum_Flex_pay_pln_fee_3M", 
                    _sum("sum_Flex_pay_pln_fee_3M").over(Window.partitionBy("per_num", *filter_args_col)))

    df_prev_grouped_2 = df_prev_grouped_2.withColumn("sum_Flex_pay_weighted_apr_3M", 
                    _sum("sum_Flex_pay_weighted_apr_3M").over(Window.partitionBy("per_num", *filter_args_col)))

    df_prev_grouped_2 = df_prev_grouped_2.withColumn("sum_sg_promo_sales_3M", 
                    _sum("sum_sg_promo_sales_3M").over(Window.partitionBy("per_num", *filter_args_col)))

    df_prev_grouped_2 = df_prev_grouped_2.withColumn("sum_rs_promo_sales_3M", 
                    _sum("sum_rs_promo_sales_3M").over(Window.partitionBy("per_num", *filter_args_col)))

    df_prev_grouped_2 = df_prev_grouped_2.withColumn("sum_Balcon_resp_3M_ind", 
                    _sum("sum_Balcon_resp_3M_ind").over(Window.partitionBy("per_num", *filter_args_col)))

    df_prev_grouped_2 = df_prev_grouped_2.withColumn("sum_Flex_resp_3M_ind", 
                    _sum("sum_Flex_resp_3M_ind").over(Window.partitionBy("per_num", *filter_args_col)))

    df_prev_grouped_2 = df_prev_grouped_2.withColumn("sum_SG_Resp_3M_ind", 
                    _sum("sum_SG_Resp_3M_ind").over(Window.partitionBy("per_num", *filter_args_col)))

    df_prev_grouped_2 = df_prev_grouped_2.withColumn("sum_RS_Resp_3M_ind", 
                    _sum("sum_RS_Resp_3M_ind").over(Window.partitionBy("per_num", *filter_args_col)))

    df_prev_grouped_2 = df_prev_grouped_2.withColumn("sum_Flex_pay_resp_3M_ind", 
                    _sum("sum_Flex_pay_resp_3M_ind").over(Window.partitionBy("per_num", *filter_args_col)))
    
    df_prev_grouped_2 = df_prev_grouped_2.toPandas()
    df_curr_grouped_2 = df_curr_grouped_2.toPandas()

    df_change_grouped = pd.DataFrame()

    if len(df_curr_grouped_1.filter(col("per_num") == per_num).select("num_Accts").collect()) >= 1:
        df_prev_grouped_2["Balcon_Enrollment_Share"] = df_prev_grouped_2["Balcon_enrollment"] / (
                df_prev_grouped_2["Balcon_enrollment"] + df_prev_grouped_2["Flex_loan_enrollment"] +
                df_prev_grouped_2["SG_enrollment"] + df_prev_grouped_2["RS_enrollment"] + df_prev_grouped_2["Flex_Pay_enrollment"])

        df_prev_grouped_2["Flex_Loan_Enrollment_Share"] = df_prev_grouped_2["Flex_loan_enrollment"] / (
                df_prev_grouped_2["Balcon_enrollment"] + df_prev_grouped_2["Flex_loan_enrollment"] +
                df_prev_grouped_2["SG_enrollment"] + df_prev_grouped_2["RS_enrollment"] + df_prev_grouped_2["Flex_Pay_enrollment"])

        df_prev_grouped_2["Sales_Growth_Enrollment_Share"] = df_prev_grouped_2["SG_enrollment"] / (
                df_prev_grouped_2["Balcon_enrollment"] + df_prev_grouped_2["Flex_loan_enrollment"] +
                df_prev_grouped_2["SG_enrollment"] + df_prev_grouped_2["RS_enrollment"] + df_prev_grouped_2["Flex_Pay_enrollment"])

        df_prev_grouped_2["Rate_Sale_Enrollment_Share"] = df_prev_grouped_2["RS_enrollment"] / (
                df_prev_grouped_2["Balcon_enrollment"] + df_prev_grouped_2["Flex_loan_enrollment"] +
                df_prev_grouped_2["SG_enrollment"] + df_prev_grouped_2["RS_enrollment"] + df_prev_grouped_2["Flex_Pay_enrollment"])

        df_prev_grouped_2["Flex_Pay_Enrollment_Share"] = df_prev_grouped_2["Flex_Pay_enrollment"] / (
                df_prev_grouped_2["Balcon_enrollment"] + df_prev_grouped_2["Flex_loan_enrollment"] +
                df_prev_grouped_2["SG_enrollment"] + df_prev_grouped_2["RS_enrollment"] + df_prev_grouped_2["Flex_Pay_enrollment"])

        df_prev_grouped_2["Total_Enrollment_Share"] = (
                                                               df_prev_grouped_2["Balcon_enrollment"] +
                                                               df_prev_grouped_2["Flex_loan_enrollment"] +
                                                               df_prev_grouped_2["SG_enrollment"] +
                                                               df_prev_grouped_2["RS_enrollment"] +
                                                               df_prev_grouped_2["Flex_Pay_enrollment"]) / \
                                                       df_prev_grouped_1.groupby(["per_num"] + filter_args_col)["num_Accts"].sum().collect()[
                                                           0]["num_Accts"]

        df_curr_grouped_2["Balcon_Enrollment_Share"] = df_curr_grouped_2["Balcon_enrollment"] / (
                df_curr_grouped_2["Balcon_enrollment"] + df_curr_grouped_2["Flex_loan_enrollment"] +
                df_curr_grouped_2["SG_enrollment"] + df_curr_grouped_2["RS_enrollment"] + df_curr_grouped_2["Flex_Pay_enrollment"])

        df_curr_grouped_2["Flex_Loan_Enrollment_Share"] = df_curr_grouped_2["Flex_loan_enrollment"] / (
                df_curr_grouped_2["Balcon_enrollment"] + df_curr_grouped_2["Flex_loan_enrollment"] +
                df_curr_grouped_2["SG_enrollment"] + df_curr_grouped_2["RS_enrollment"] + df_curr_grouped_2["Flex_Pay_enrollment"])

        df_curr_grouped_2["Sales_Growth_Enrollment_Share"] = df_curr_grouped_2["SG_enrollment"] / (
                df_curr_grouped_2["Balcon_enrollment"] + df_curr_grouped_2["Flex_loan_enrollment"] +
                df_curr_grouped_2["SG_enrollment"] + df_curr_grouped_2["RS_enrollment"] + df_curr_grouped_2["Flex_Pay_enrollment"])

        df_curr_grouped_2["Rate_Sale_Enrollment_Share"] = df_curr_grouped_2["RS_enrollment"] / (
                df_curr_grouped_2["Balcon_enrollment"] + df_curr_grouped_2["Flex_loan_enrollment"] +
                df_curr_grouped_2["SG_enrollment"] + df_curr_grouped_2["RS_enrollment"] + df_curr_grouped_2["Flex_Pay_enrollment"])

        df_curr_grouped_2["Flex_Pay_Enrollment_Share"] = df_curr_grouped_2["Flex_Pay_enrollment"] / (
                df_curr_grouped_2["Balcon_enrollment"] + df_curr_grouped_2["Flex_loan_enrollment"] +
                df_curr_grouped_2["SG_enrollment"] + df_curr_grouped_2["RS_enrollment"] + df_curr_grouped_2["Flex_Pay_enrollment"])

        df_curr_grouped_2["Total_Enrollment_Share"] = (
                                                               df_curr_grouped_2["Balcon_enrollment"] +
                                                               df_curr_grouped_2["Flex_loan_enrollment"] +
                                                               df_curr_grouped_2["SG_enrollment"] +
                                                               df_curr_grouped_2["RS_enrollment"] +
                                                               df_curr_grouped_2["Flex_Pay_enrollment"]) / \
                                                       df_curr_grouped_1.groupby(["per_num"] + filter_args_col)["num_Accts"].sum().collect()[
                                                           0]["num_Accts"]

        for var in df_prev_grouped_2.columns:
            if var not in ["per_num", "yearmonth"]:
                df_change_grouped[f"{var}_Change"] = ((df_curr_grouped_2[f"{var}"]) - (df_prev_grouped_2[f"{var}"])) / abs(
                    df_prev_grouped_2[f"{var}"])

    return df_prev_grouped_2, df_curr_grouped_2, df_change_grouped

# Run the metric calculation for the configured period number, timeline and
# segment filters; the call yields (previous, current, change) in that order.
res_prev1, res_curr1, res_change1 = metric_calc(
    df_1,
    per_num,
    timeline,
    **filter_args,
)

# Display results: change first, then the previous and current snapshots.
for _result in (res_change1, res_prev1, res_curr1):
    _result.show()


In [None]:
    df_prev_grouped_2 = df_prev_grouped_2.toPandas()
    df_curr_grouped_2 = df_curr_grouped_2.toPandas()

    df_change_grouped = pd.DataFrame()

    if len(df_curr_grouped_1.filter(col("per_num") == per_num).select("num_Accts").collect()) >= 1:
        df_prev_grouped_2["Balcon_Enrollment_Share"] = df_prev_grouped_2["Balcon_enrollment"] / (
                df_prev_grouped_2["Balcon_enrollment"] + df_prev_grouped_2["Flex_loan_enrollment"] +
                df_prev_grouped_2["SG_enrollment"] + df_prev_grouped_2["RS_enrollment"] + df_prev_grouped_2["Flex_Pay_enrollment"])

        df_prev_grouped_2["Flex_Loan_Enrollment_Share"] = df_prev_grouped_2["Flex_loan_enrollment"] / (
                df_prev_grouped_2["Balcon_enrollment"] + df_prev_grouped_2["Flex_loan_enrollment"] +
                df_prev_grouped_2["SG_enrollment"] + df_prev_grouped_2["RS_enrollment"] + df_prev_grouped_2["Flex_Pay_enrollment"])

        df_prev_grouped_2["Sales_Growth_Enrollment_Share"] = df_prev_grouped_2["SG_enrollment"] / (
                df_prev_grouped_2["Balcon_enrollment"] + df_prev_grouped_2["Flex_loan_enrollment"] +
                df_prev_grouped_2["SG_enrollment"] + df_prev_grouped_2["RS_enrollment"] + df_prev_grouped_2["Flex_Pay_enrollment"])

        df_prev_grouped_2["Rate_Sale_Enrollment_Share"] = df_prev_grouped_2["RS_enrollment"] / (
                df_prev_grouped_2["Balcon_enrollment"] + df_prev_grouped_2["Flex_loan_enrollment"] +
                df_prev_grouped_2["SG_enrollment"] + df_prev_grouped_2["RS_enrollment"] + df_prev_grouped_2["Flex_Pay_enrollment"])

        df_prev_grouped_2["Flex_Pay_Enrollment_Share"] = df_prev_grouped_2["Flex_Pay_enrollment"] / (
                df_prev_grouped_2["Balcon_enrollment"] + df_prev_grouped_2["Flex_loan_enrollment"] +
                df_prev_grouped_2["SG_enrollment"] + df_prev_grouped_2["RS_enrollment"] + df_prev_grouped_2["Flex_Pay_enrollment"])

        df_prev_grouped_2["Total_Enrollment_Share"] = (
                                                               df_prev_grouped_2["Balcon_enrollment"] +
                                                               df_prev_grouped_2["Flex_loan_enrollment"] +
                                                               df_prev_grouped_2["SG_enrollment"] +
                                                               df_prev_grouped_2["RS_enrollment"] +
                                                               df_prev_grouped_2["Flex_Pay_enrollment"]) / \
                                                       df_prev_grouped_2.groupby(["per_num"] + filter_args_col)["num_Accts"].sum().reset_index()[
                                                           "num_Accts"]

        df_curr_grouped_2["Balcon_Enrollment_Share"] = df_curr_grouped_2["Balcon_enrollment"] / (
                df_curr_grouped_2["Balcon_enrollment"] + df_curr_grouped_2["Flex_loan_enrollment"] +
                df_curr_grouped_2["SG_enrollment"] + df_curr_grouped_2["RS_enrollment"] + df_curr_grouped_2["Flex_Pay_enrollment"])

        df_curr_grouped_2["Flex_Loan_Enrollment_Share"] = df_curr_grouped_2["Flex_loan_enrollment"] / (
                df_curr_grouped_2["Balcon_enrollment"] + df_curr_grouped_2["Flex_loan_enrollment"] +
                df_curr_grouped_2["SG_enrollment"] + df_curr_grouped_2["RS_enrollment"] + df_curr_grouped_2["Flex_Pay_enrollment"])

        df_curr_grouped_2["Sales_Growth_Enrollment_Share"] = df_curr_grouped_2["SG_enrollment"] / (
                df_curr_grouped_2["Balcon_enrollment"] + df_curr_grouped_2["Flex_loan_enrollment"] +
                df_curr_grouped_2["SG_enrollment"] + df_curr_grouped_2["RS_enrollment"] + df_curr_grouped_2["Flex_Pay_enrollment"])

        df_curr_grouped_2["Rate_Sale_Enrollment_Share"] = df_curr_grouped_2["RS_enrollment"] / (
                df_curr_grouped_2["Balcon_enrollment"] + df_curr_grouped_2["Flex_loan_enrollment"] +
                df_curr_grouped_2["SG_enrollment"] + df_curr_grouped_2["RS_enrollment"] + df_curr_grouped_2["Flex_Pay_enrollment"])

        df_curr_grouped_2["Flex_Pay_Enrollment_Share"] = df_curr_grouped_2["Flex_Pay_enrollment"] / (
                df_curr_grouped_2["Balcon_enrollment"] + df_curr_grouped_2["Flex_loan_enrollment"] +
                df_curr_grouped_2["SG_enrollment"] + df_curr_grouped_2["RS_enrollment"] + df_curr_grouped_2["Flex_Pay_enrollment"])

        df_curr_grouped_2["Total_Enrollment_Share"] = (
                                                               df_curr_grouped_2["Balcon_enrollment"] +
                                                               df_curr_grouped_2["Flex_loan_enrollment"] +
                                                               df_curr_grouped_2["SG_enrollment"] +
                                                               df_curr_grouped_2["RS_enrollment"] +
                                                               df_curr_grouped_2["Flex_Pay_enrollment"]) / \
                                                       df_curr_grouped_2.groupby(["per_num"] + filter_args_col)["num_Accts"].sum().reset_index()[
                                                           "num_Accts"]

        for var in df_prev_grouped_2.columns:
            if var not in ["per_num", "yearmonth"]:
                df_change_grouped[f"{var}_Change"] = ((df_curr_grouped_2[f"{var}"]) - (df_prev_grouped_2[f"{var}"])) / abs(
                    df_prev_grouped_2[f"{var}"])

    return df_prev_grouped_2, df_curr_grouped_2, df_change_grouped

# Run the metric calculation for the configured period number, timeline and
# segment filters; the call yields (previous, current, change) in that order.
res_prev1, res_curr1, res_change1 = metric_calc(
    df_1,
    per_num,
    timeline,
    **filter_args,
)

# Display results: one object per line, change first, then previous, then
# current (same output as three consecutive print calls).
print(res_change1, res_prev1, res_curr1, sep="\n")