# Step-by-step hierarchy of notebook cells

`import_libraries` -> `input_selection` -> `user_input` -> `utils` -> `arr_mapping_fields` -> `product_bundling` -> `round_off` -> `dimension_date_dim` -> `{ufr_logics}` -> `flows` -> `main` -> `main_ufr` -> `qc_mechanism` -> `test`

**Note:** Do not change the order of the cells.

# Import libraries for data processing and snowpark

**Key libraries**

- `Snowflake snowpark` libraries
    - Session, Mathematical Operations
    - Snowpark data types
- `Pandas` for data frame processing
- `functools.reduce`: For functional programming utilities
- `ThreadPoolExecutor`: For running functions in parallel.

In [None]:
# STANDARD LIBRARY IMPORTS
import time
import traceback
import logging
import json
import glob
from functools import reduce
from concurrent.futures import ThreadPoolExecutor, as_completed

# CORE SNOWPARK IMPORTS
import snowflake.snowpark as snowpark
from snowflake.snowpark import Session, DataFrame
from snowflake.snowpark.context import get_active_session
from snowflake.snowpark.exceptions import SnowparkSQLException

# WINDOW FUNCTIONS
from snowflake.snowpark import Window
from snowflake.snowpark.window import Window  # Note: This is duplicate - keep one
from snowflake.snowpark.functions import row_number, lag

# SNOWPARK FUNCTIONS MODULE
from snowflake.snowpark import functions as F

# COLUMN OPERATIONS
from snowflake.snowpark.functions import (
    col,
    when,
    lit,
    coalesce,
    is_null
)

# AGGREGATION FUNCTIONS
from snowflake.snowpark.functions import (
    sum,
    # Multiple aliases for sum
    sum as sum_sf,
    sum as spark_sum,
    sum as snowflake_sum,
    sum as snowpark_sum,
    count,
    max,
    max as spark_max,
    min,
    min as spark_min
)

# DATE/TIME FUNCTIONS
from snowflake.snowpark.functions import (
    to_date,
    dateadd,
    add_months,
    date_trunc,
    datediff,
    dayofmonth
)

# STRING FUNCTIONS
from snowflake.snowpark.functions import (
    upper,
    trim,
    regexp_replace,
    regexp_count,
    concat,
    length
)

# MATHEMATICAL FUNCTIONS
from snowflake.snowpark.functions import (
    abs,
    round
)

# SEQUENCE FUNCTIONS
from snowflake.snowpark.functions import (
    seq1,
    seq8
)

# DATA TYPES
from snowflake.snowpark.types import *
from snowflake.snowpark.types import (
    DecimalType,
    DateType,
    StringType
)

# SESSION INITIALIZATION
# Take the Snowpark session from the active Snowflake context
# (get_active_session) instead of building one from connection credentials.
session = get_active_session()

In [None]:
import streamlit as st

# Parameter form for the TMT Enabler run. The four selections below are
# module-level variables that later cells presumably consume (the
# "User input variables" section lists them as the run inputs).
st.title("TMT Enabler Input Parameters")

# Lookback periods in months: 1 = Month, 3 = Quarter, 12 = Year.
# Several may be chosen; nothing is pre-selected.
lookback_list = st.multiselect(
    "Select lookback period(s):",
    options=[1, 3, 12],
    help="Choose one or more of Month (1), Quarter (3), or Year (12).",
)

# Type of the input amount (ARR or MRR).
input_amount = st.selectbox(
    "Select input amount type:",
    options=["ARR", "MRR"],
)

# Levels to run the calculation at; ARR and MRR are pre-selected.
run_at_levels = st.multiselect(
    "Select run at levels:",
    options=["ARR", "MRR", "T3M", "TTM", "T3M (Annualized)"],
    default=["ARR", "MRR"],
)

# Granularity of the retention analysis; customer level is pre-selected.
retention_levels = st.multiselect(
    "Select retention level(s):",
    options=[
        "Customer_level",
        "Customer_Product_level",
        "Customer_Product_RetentionType_level",
        "Level4",
    ],
    default=["Customer_level"],
)

# Every selection must be non-empty before the parameters are echoed back.
# NOTE(review): the warning branch does not halt the notebook, so later cells
# still run with whatever (possibly empty) selections were made.
if not (lookback_list and input_amount and run_at_levels and retention_levels):
    st.warning("Please select all required parameters to proceed.")
else:
    st.success("All required inputs selected. Ready to run!")
    st.write("### Selected Parameters:")
    st.write(f"Lookback List: {lookback_list}")
    st.write(f"Input Amount: {input_amount}")
    st.write(f"Run At Levels: {run_at_levels}")
    st.write(f"Retention Levels: {retention_levels}")

In [None]:
# Step 1: Locate exactly one "*.config.json" file in the current working directory.
config_files = glob.glob("*.config.json")

if not config_files:
    raise FileNotFoundError("No config file matching '*.config.json' found in the current directory.")
if len(config_files) > 1:
    raise RuntimeError("Multiple '*.config.json' files found. Expected only one.")

config_file = config_files[0]
print(f"Loading config from: {config_file}")

# Step 2: Load the config file. JSON text is UTF-8, so decode it explicitly
# rather than relying on the platform's locale-dependent default encoding.
with open(config_file, "r", encoding="utf-8") as f:
    config = json.load(f)

# Step 3: Extract the input and output table names from "ODBC config" -> "tables".
try:
    input_table = config["ODBC config"]["tables"]["input_table"]
    pbi_table = config["ODBC config"]["tables"]["pbi_table"]
    excel_table = config["ODBC config"]["tables"]["excel_table"]
    fact_table = config["ODBC config"]["tables"]["fact_table"]
except KeyError as e:
    # Re-raise with context; chain the original error so the traceback keeps it.
    raise KeyError(f"Missing expected key in config: {e}") from e

# User input variables & levers

**Input Variables**

- `lookback_list`
- `input_amount`
- `run_at_levels`
- `retention_levels`
- `input_file_path`
- `input_table_templogic`
- `input_table_product_bundle`
- `pbi_retention_output_path`
- `excel_retention_output_path`
- `fact_table_output_path`

**Mapping columns**
- `column_mapping_file`

**Filter Condition**
- Can only be applied to base columns such as `CUSTOMERID, PRODUCT, REVENUETYPE, CURRENTPERIOD, ARR`
- SAMPLE: col("CustomerID") == "99999_TEST"

In [None]:
# #####################################################################
# ### Lookback periods (number of periods to compare against):
# ###   monthly ARR data:   YoY = 12, QoQ = 3, MoM = 1
# ###   quarterly ARR data: YoY = 4,  QoQ = 1
# #####################################################################
# lookback_list, input_amount, run_at_levels and retention_levels are set by
# the Streamlit input cell. To run without the UI, uncomment and edit:
# # Period Comparison. 1 = Month, 3 = Quarter, 12 = Year
# lookback_list = [1, 3, 12]

# # Input amount - "ARR" or "MRR"
# input_amount = "ARR"

# # run_at_levels - ["ARR", "MRR", "T3M", "TTM", "T3M (Annualized)"]
# run_at_levels = ["ARR", "MRR", "T3M", "TTM", "T3M (Annualized)"]

# # Retention levels. Valid values:
# # ["Customer_level", "Customer_Product_level", "Customer_Product_RetentionType_level", "Level4"]
# retention_levels = ["Customer_level", "Customer_Product_level", "Customer_Product_RetentionType_level"]

# Input table (name read from the config file).
input_file_path = input_table

# Input is a small sample of SUBSCRIPTION_ACCEL.PYTHON_TESTING.INPUT_BOOKING_TEST
# Output table paths.
# NOTE: Do not add _C, _CP or _CPR to the end of these paths; that is handled internally.
pbi_retention_output_path = pbi_table
excel_retention_output_path = excel_table
fact_table_output_path = fact_table

# Input table for the Temp logic. Set to None to skip the Temp logic.
# input_table_templogic = pbi_retention_output_path + "_C_NOTEBOOK"
input_table_templogic = None

# Input table for product bundling.
input_table_product_bundle = pbi_retention_output_path + "_C_NOTEBOOK"

# Maps each standard column name to the (quoted) column identifier in the
# input table, or to None when the input has no such column.
column_mapping_file = {
    "CUSTOMERID": '"CUSTOMERID"',
    "PRODUCT": '"PRODUCT"',
    "REVENUETYPE": '"REVENUETYPE"',
    "CURRENTPERIOD": '"CURRENTPERIOD"',
    "VALUE": '"VALUE"',
    "Account Size": None,
    "Region": None,
    "Industry": None,
    "Channel": None,
}


####################### Filter condition ############################
# Can only be applied to base columns such as
# CUSTOMERID, PRODUCT, REVENUETYPE, CURRENTPERIOD, VALUE.
# SAMPLE: col("CustomerID") == "99999_TEST"
# None (the default) applies no filter; otherwise assign a Snowpark Column
# expression. This must be an assignment: the previous `filter_condition: None`
# was only a type annotation and left the name unbound.
filter_condition = None
# filter_condition = col("CUSTOMERID") == "Customer 002651"
# filter_condition = col("CURRENTPERIOD").between("2020-01-01", "2020-06-30")

# Add column names here to run the all_qa_check function
qa_columns = ['ARR_ROLLCHECK', 'Count_RollCheck', 'Cohort_Max_Dates_Check']

# Thresholds (presumably the tolerance applied to the QA check columns — confirm in qc_mechanism)
qa_check_thresholds = 0.001

**This pipeline offers a comprehensive framework for customer retention analysis, valuable for data analysts and business intelligence professionals.**

## Key Functions

1. **`determine_account_size(arr_value)`**: Classifies account sizes based on ARR values.
2. **`add_rename_cols(df, mapping_file)`**: Adds and renames columns in a DataFrame based on a mapping.
3. **`add_fact_table_cols(df)`**: Adds dummy columns for reporting in a fact table.
4. **`data_loading(session, path, mapping_file, filter_condition, type, input_amount, retention_level)`**: Loads and processes data based on specified analysis type.
5. **`generate_months_snowpark(session, input_df)`**: Generates monthly records for all customers.
6. **`credit_df_prepare(df)`**: Prepares DataFrame for credit analysis by adjusting ARR values.
7. **`window_cols(df)`**: Adds windowed calculations for customer metrics.
8. **`base_table_creation(session, df)`**: Creates a base table for further analysis.
9. **`retention(df, l)`**: Applies retention logic to calculate various metrics.
10. **`save_results(result, path, type)`**: Saves processed results to a specified Snowflake table.

## Outputs
- DataFrames with retention metrics.
- Aggregated results saved in Snowflake.


In [None]:
# Ordered account-size bands: (upper bound, whether the bound is inclusive, label).
# Bands are tested in order, so each band's lower bound is implied by the
# previous band's upper bound.
_ACCOUNT_SIZE_BANDS = (
    (10000, False, "1. <10K"),
    (50000, True, "2. 10K - 50K"),
    (100000, True, "3. 50K - 100K"),
    (250000, True, "4. 100K - 250K"),
    (500000, True, "5. 250K - 500K"),
    (1000000, False, "6. 500K-1M"),
)


def determine_account_size(arr_value):
    """Return the account-size band label for a single ARR value.

    Values beyond every band (including exactly 1,000,000, and values that
    compare False against every bound, such as NaN) map to "7. >1M".
    """
    for upper_bound, inclusive, label in _ACCOUNT_SIZE_BANDS:
        within_band = arr_value <= upper_bound if inclusive else arr_value < upper_bound
        if within_band:
            return label
    return "7. >1M"

# To add missing columns and rename them to the required format
def add_rename_cols(df: DataFrame, mapping_file: dict) -> DataFrame:
    """Rename mapped source columns to their standard names and add missing ones.

    Args:
        df: Input Snowpark DataFrame.
        mapping_file: Maps each standard column name to the source column
            identifier in ``df`` (e.g. ``'"CUSTOMERID"'``), or to ``None`` when
            the source has no such column; in that case a constant ``"NA"``
            column is added under the standard name.

    Returns:
        The DataFrame with renamed / added columns. An existing
        ``Account Size`` column is left untouched so a previously calculated
        value is not overwritten.
    """
    # Snowpark's DataFrame.columns reports identifiers that need quoting
    # (e.g. ones containing a space) in their quoted form, so a bare
    # "Account Size" never matches on its own; accept both spellings.
    account_size_names = {"Account Size", '"Account Size"'}
    for key, value in mapping_file.items():
        if key == "Account Size" and account_size_names.intersection(df.columns):
            # Skip this to preserve our calculated Account Size
            continue
        if value is None:
            df = df.with_column(key, lit("NA"))
        else:
            df = df.with_column_renamed(value, key)
    return df


# Add dummy columns to fact table for PBI template
def add_fact_table_cols(df: DataFrame) -> DataFrame:
    """Add placeholder, composite-key, index and cohort columns for the PBI fact template.

    Expects ``CUSTOMERID``, ``"Product"``, ``"Revenue Type"`` and ``LEVEL4``
    columns, as produced by the fact branch of ``data_loading``.
    """
    # Placeholder flag. The double quotes belong to the column identifier only;
    # the stored value is the plain string, like the other "NA" placeholders.
    df = df.withColumn('"Boomerang flag"', lit("No Boomerang"))

    # Composite keys at each retention level.
    df = df.withColumn("Cust+Prdt", concat(col("CUSTOMERID"), lit("_"), col('"Product"')))
    df = df.withColumn('"Cust+Prdt+RevType"', concat(col("CUSTOMERID"), lit("_"), col('"Product"'), lit("_"), col('"Revenue Type"')))
    df = df.withColumn('"Cust+Prdt+RevType+Lv4"', concat(col("CUSTOMERID"), lit("_"), col('"Product"'), lit("_"), col('"Revenue Type"'), lit("_"), col('"LEVEL4"')))

    # Placeholder product hierarchy columns.
    df = df.withColumn('"Product Category"', lit("NA"))
    df = df.withColumn('"Product Level 2"', lit("NA"))

    # Row index. Ordering by a constant means the numbering is arbitrary and
    # is not guaranteed to be stable between runs.
    window_spec = Window.order_by(lit(1))
    df = df.withColumn('"Index"', row_number().over(window_spec))

    # cohort cols
    df = df.withColumn("CM", lit(1))

    return df

# Optional grouping dimensions, ordered from coarsest to finest. A retention
# level groups by the first N of them; the remaining ones are filled with "NA".
_LEVEL_DIMENSIONS = ("PRODUCT", "REVENUETYPE", "LEVEL4")

# Number of _LEVEL_DIMENSIONS each retention level groups by.
_RETENTION_LEVEL_DEPTH = {
    "Customer_level": 0,
    "Customer_Product_level": 1,
    "Customer_Product_RetentionType_level": 2,
    "Level4": 3,
}

# Output column name for each dimension in the fact table (None keeps the source name).
_FACT_DIMENSION_NAMES = {
    "PRODUCT": '"Product"',
    "REVENUETYPE": '"Revenue Type"',
    "LEVEL4": None,
}

# Descriptive columns always carried into the fact table grouping.
_FACT_ATTRIBUTE_COLUMNS = ("CUSTOMERID", "Account Size", "Region", "Industry", "Channel")


def _split_level_dimensions(retention_level):
    """Return (grouped, filled) dimension names for ``retention_level``; raise ValueError if unknown."""
    try:
        depth = _RETENTION_LEVEL_DEPTH[retention_level]
    except KeyError:
        raise ValueError(
            f"Unsupported retention_level {retention_level!r}; "
            f"expected one of {list(_RETENTION_LEVEL_DEPTH)}"
        ) from None
    return _LEVEL_DIMENSIONS[:depth], _LEVEL_DIMENSIONS[depth:]


def _fact_dimension_expr(name):
    """Return the grouping expression for a dimension in the fact table."""
    alias = _FACT_DIMENSION_NAMES[name]
    return col(name) if alias is None else col(name).alias(alias)


def _load_retention(df, filter_condition, retention_level):
    """Aggregate VALUE per customer/level dimensions/month for the retention flow."""
    grouped, filled = _split_level_dimensions(retention_level)
    df = df.group_by(
        col("CUSTOMERID"),
        *[col(name) for name in grouped],
        date_trunc('month', col('CURRENTPERIOD')).alias('CURRENTPERIOD'),
    ).agg(sum(col("VALUE")).alias("VALUE"))

    # Dimensions finer than the requested level become constant "NA" columns.
    for name in filled:
        df = df.with_column(name, lit("NA"))

    if filter_condition is not None:
        df = df.filter(filter_condition)
    return df


def _load_fact(df, filter_condition, retention_level):
    """Aggregate row counts and VALUE per attributes/month/level dimensions for the fact table."""
    grouped, filled = _split_level_dimensions(retention_level)
    df = df.group_by(
        *[col(name) for name in _FACT_ATTRIBUTE_COLUMNS],
        date_trunc('month', col('CURRENTPERIOD')).alias('"Date"'),
        *[_fact_dimension_expr(name) for name in grouped],
    ).agg(
        count(col("CUSTOMERID")).alias("Count of rows"),
        sum(col("VALUE")).alias("VALUE"),
    )

    # Dimensions finer than the requested level become constant "NA" columns,
    # under the same output names the grouped dimensions would have had.
    for name in filled:
        df = df.with_column(_FACT_DIMENSION_NAMES[name] or name, lit("NA"))

    df = add_fact_table_cols(df)  # adding new cols

    if filter_condition is not None:
        df = df.filter(filter_condition)

    return df.with_column_renamed("CUSTOMERID", '"CustomerID"')


# Generic method to loading data from table
def data_loading(session: Session, path: str, mapping_file: dict, filter_condition: str = None,
                 type: str = "retention", input_amount: str = "MRR", retention_level: str = "Customer_Product_RetentionType_level"):
    """Load ``path``, standardize its columns and aggregate it for the requested flow.

    Args:
        session: Active Snowpark session used to read the table.
        path: Fully qualified input table name.
        mapping_file: Column mapping passed to ``add_rename_cols``.
        filter_condition: Optional Snowpark Column expression applied after
            aggregation (despite the ``str`` annotation); ``None`` means no filter.
        type: ``"retention"`` aggregates VALUE per customer / level dimensions /
            month; ``"fact"`` additionally groups by the descriptive attributes,
            counts rows and adds the PBI fact columns. Any other value returns
            the renamed table without aggregation.
        input_amount: Not used by this function.
        retention_level: One of ``Customer_level``, ``Customer_Product_level``,
            ``Customer_Product_RetentionType_level`` or ``Level4``.

    Returns:
        The resulting Snowpark DataFrame.

    Raises:
        ValueError: ``type`` is ``"retention"`` or ``"fact"`` and
            ``retention_level`` is not one of the supported levels.
    """
    df = add_rename_cols(session.table(path), mapping_file)

    if type == "retention":
        return _load_retention(df, filter_condition, retention_level)
    if type == "fact":
        return _load_fact(df, filter_condition, retention_level)
    return df


# Generating all the months rows for all the customers
def generate_months_snowpark(session: Session, input_df: DataFrame) -> DataFrame:
    """Densify ``input_df`` to one row per grouping key per month.

    The months are every date with day-of-month 1 between the smallest and
    largest ``CURRENTPERIOD`` in the input. The grouping key is every column
    other than ``CURRENTPERIOD`` and ``VALUE``; each output row's ``VALUE`` is
    the sum of the key's VALUEs whose ``CURRENTPERIOD`` equals that month
    (0 when there are none). A constant ``Max Date`` column holds the largest
    input period.

    NOTE(review): an empty input yields None bounds and the day-count
    subtraction below fails; callers presumably never pass an empty frame.
    """
    print("generate_months_snowpark started...")
    started_at = time.time()

    # Earliest and latest period in the input (runs one aggregate query).
    bounds_row = input_df.agg(min(col("CURRENTPERIOD")), max(col("CURRENTPERIOD"))).collect()[0]
    min_date, max_date = bounds_row[0], bounds_row[1]

    # One row per calendar day in [min_date, max_date], then keep only the
    # first day of each month.
    day_count = (max_date - min_date).days + 1
    every_day = session.range(day_count).select(
        (to_date(lit(min_date)) + col("id")).alias("date")
    )
    month_starts = (
        every_day
        .filter(dayofmonth(col("date")) == 1)
        .select(col("date").alias("month_date"))
        .distinct()
    )

    # Pair each input row with every month; a row keeps its VALUE only in its
    # own month and contributes 0 to all the others.
    expanded = input_df.cross_join(month_starts)
    expanded = expanded.with_column(
        "VALUE",
        when(col("CURRENTPERIOD") == col("month_date"), col("VALUE")).otherwise(lit(0)),
    )

    key_columns = [name for name in expanded.columns if name not in ['CURRENTPERIOD', "VALUE"]]
    monthly = expanded.group_by(*key_columns).agg(sum(col("VALUE")).alias("VALUE"))

    monthly = monthly.with_column_renamed("month_date", "CURRENTPERIOD")
    monthly = monthly.with_column("Max Date", lit(max_date))

    finished_at = time.time()
    print(f"⏱️ Caching df generate_months_snowpark done: {finished_at - started_at:.2f} seconds")
    return monthly
    

# Generating rows for credit
def credit_df_prepare(df: DataFrame) -> DataFrame:
    """Split negative (credit) amounts into their own customer series.

    * Rows with ``VALUE >= 0`` are kept unchanged.
    * Each row with ``VALUE < 0`` is emitted twice: once under the customer id
      suffixed with ``-CR$`` carrying the negative value, and once under the
      original id with ``VALUE`` set to 0.0 so that series keeps the period.

    Rows whose VALUE is NULL satisfy neither filter and are dropped.
    The result is sorted by ``CURRENTPERIOD``.
    """
    credit_rows = df.filter(col("VALUE") < 0)

    pieces = [
        df.filter(col("VALUE") >= 0),
        credit_rows.with_column("CUSTOMERID", concat(col("CUSTOMERID"), lit("-CR$"))),
        credit_rows.with_column("VALUE", lit(0.0)),
    ]
    combined = reduce(lambda left, right: left.union_by_name(right), pieces)

    return combined.sort(col('CURRENTPERIOD'))

# Creating date columns
def window_cols(df: DataFrame) -> DataFrame:
    """Add customer size and cohort (first / last non-zero period) columns.

    Adds:
      * ``Customer_Max_ARR``: max VALUE per customer across all periods.
      * ``Account Size``: band label for ``Customer_Max_ARR`` (same bands as
        ``determine_account_size``).
      * ``<level>_MinDate`` / ``<level>_MaxDate`` for the levels ``Cust``,
        ``Cust_Prod``, ``Cust_Prod_Rev`` and ``Cust_Prod_Rev_Lv4``: first and
        last ``CurrentPeriod`` with non-zero VALUE within that grouping.
    """
    customer_window = Window.partitionBy("CustomerID")

    # NOTE(review): the bands below are applied to the raw max VALUE; nothing
    # here annualises MRR input (x12) — confirm that is intended.
    df = df.withColumn("Customer_Max_ARR", max(col("VALUE")).over(customer_window))

    # Determine account size based on max ARR
    df = df.withColumn("Account Size",
        when(col("Customer_Max_ARR") < 10000, "1. <10K")
        .when((col("Customer_Max_ARR") >= 10000) & (col("Customer_Max_ARR") <= 50000), "2. 10K - 50K")
        .when((col("Customer_Max_ARR") > 50000) & (col("Customer_Max_ARR") <= 100000), "3. 50K - 100K")
        .when((col("Customer_Max_ARR") > 100000) & (col("Customer_Max_ARR") <= 250000), "4. 100K - 250K")
        .when((col("Customer_Max_ARR") > 250000) & (col("Customer_Max_ARR") <= 500000), "5. 250K - 500K")
        .when((col("Customer_Max_ARR") > 500000) & (col("Customer_Max_ARR") < 1000000), "6. 500K-1M")
        .otherwise("7. >1M")
    )

    # Only periods with non-zero VALUE count toward the cohort / max dates.
    first_active_period = min(when(col('VALUE') != 0, col("CurrentPeriod")))
    last_active_period = max(when(col('VALUE') != 0, col("CurrentPeriod")))

    # Cohort and max-ARR dates from the coarsest level (customer) to the finest (Level4).
    cohort_levels = (
        ("Cust", customer_window),
        ("Cust_Prod", Window.partitionBy("CustomerID", "Product")),
        ("Cust_Prod_Rev", Window.partitionBy("CustomerID", "Product", "RevenueType")),
        ("Cust_Prod_Rev_Lv4", Window.partitionBy("CustomerID", "Product", "RevenueType", "Level4")),
    )
    for prefix, window in cohort_levels:
        df = df.withColumn(f"{prefix}_MinDate", first_active_period.over(window))
        df = df.withColumn(f"{prefix}_MaxDate", last_active_period.over(window))

    return df

# Base table creation: calling function for all row generating logic
def base_table_creation(session, df: DataFrame) -> DataFrame:
    """Build the monthly base table used by the retention logic.

    Steps: add size/cohort columns (``window_cols``), split credits into
    ``-CR$`` series (``credit_df_prepare``), densify to one row per month
    (``generate_months_snowpark``), then add ``Total_ARR`` — the sum of VALUE
    per customer and period.
    """
    prepared = credit_df_prepare(window_cols(df))
    monthly = generate_months_snowpark(session, prepared)

    per_customer_period = Window.partitionBy("CustomerID", "CurrentPeriod")
    return monthly.withColumn("Total_ARR", sum(col("VALUE")).over(per_customer_period))

# Creating credit related columns
def credit_cols(df: DataFrame) -> DataFrame:
    """Add credit movement columns derived from ``ARR_Variance``.

    * ``Credit``: the variance when the current VALUE is negative, else 0.
    * ``Credit_Reversal``: the variance when the prior value was negative and
      the current VALUE is non-negative, else 0.
    * ``Net_Credit``: ``Credit + Credit_Reversal``.
    """
    in_credit = col("VALUE") < 0
    credit_reversed = (col("Prior_ARR") < 0) & (col("VALUE") >= 0)

    df = df.withColumn("Credit", when(in_credit, col("ARR_Variance")).otherwise(0))
    df = df.withColumn("Credit_Reversal", when(credit_reversed, col("ARR_Variance")).otherwise(0))
    return df.withColumn("Net_Credit", col('Credit') + col('Credit_Reversal'))

# Creating ARR related columns
def arr_cols(df: DataFrame, l) -> DataFrame:
    """Add period-over-period ARR movement columns for a lookback of ``l`` rows.

    Adds ``Prior_ARR`` (VALUE ``l`` rows earlier in the same
    customer/product/revenue-type/Level4 series ordered by ``CurrentPeriod``,
    0 when there is none), ``ARR_Variance`` (VALUE - Prior_ARR) and one column
    per movement bucket. Each bucket holds ``ARR_Variance`` when its condition
    holds and 0 otherwise. The "new" and "churn" buckets are cascades from the
    coarsest level (customer) to the finest (Level4), each requiring the
    previous ones to be 0; the Temp buckets take what no level explains.

    Args:
        df: Monthly base table with ``VALUE``, ``CurrentPeriod``, ``Level4``
            and the cohort min/max date columns produced by ``window_cols``.
        l: Lookback in rows (months for the monthly base table), e.g. 1, 3 or 12.
    """
    # Series key for the lookback. LEVEL4 is part of the key so that at the
    # Level4 retention level each Level4 series is lagged on its own instead of
    # interleaving several Level4 rows per month; at the other retention levels
    # data_loading sets LEVEL4 to the constant "NA", so their results are unchanged.
    customer_product_revenue_date_order_window = Window.partitionBy(
        "CustomerID", "Product", "RevenueType", "Level4"
    ).orderBy(col("CurrentPeriod").asc())
    ARR_positive_check = (col("Prior_ARR") == 0) & (col("VALUE") > 0)
    Prior_ARR_positive_check = (col("VALUE") == 0) & (col("Prior_ARR") > 0)

    ### Assign lookback per customer/product/revenue type/Level4 series (Prior ARR)
    df = df.withColumn("Prior_ARR"
        , lag(col("VALUE"), offset=l, default_value=0)
        .over(customer_product_revenue_date_order_window))

    ### Calculate ARR Change Period of Period (PoP Variance)
    df = df.withColumn("ARR_Variance", col("VALUE") - col("Prior_ARR"))

    ### Define New Customer ARR
    df = df.withColumn("NewCust_ARR",
        when(
            ARR_positive_check
            & (col("CurrentPeriod") < add_months(col("Cust_MinDate"), l))
            ,col("ARR_Variance")
        ).otherwise(0)
    )

    ### Define New Product ARR (Cross-sell)
    df = df.withColumn("NewProd_ARR",
        when(
            ARR_positive_check
            & (col("NewCust_ARR") == 0)
            & (col("CurrentPeriod") < add_months(col("Cust_Prod_MinDate"), l))
            , col("ARR_Variance")
        ).otherwise(0)
    )

    ### Define New Revenue Type ARR (Cross-sell Rev Type)
    df = df.withColumn("NewRev_ARR",
        when(
            ARR_positive_check
            & (col("NewCust_ARR") == 0)
            & (col("NewProd_ARR") == 0)
            & (col("CurrentPeriod") < add_months(col("Cust_Prod_Rev_MinDate"), l))
            , col("ARR_Variance")
        ).otherwise(0)
    )

    ### Define New Lv4 ARR
    df = df.withColumn("NewLv4_ARR",
        when(
            ARR_positive_check
            & (col("NewCust_ARR") == 0)
            & (col("NewProd_ARR") == 0)
            & (col("NewRev_ARR") == 0)
            & (col("CurrentPeriod") < add_months(col("Cust_Prod_Rev_Lv4_MinDate"), l))
            , col("ARR_Variance")
        ).otherwise(0)
    )

    ### Define Temp New ARR (re-activation not explained by any "new" level)
    df = df.withColumn("TempNew_ARR",
        when(
            ARR_positive_check
            & (col("NewCust_ARR") == 0)
            & (col("NewProd_ARR") == 0)
            & (col("NewRev_ARR") == 0)
            & (col("NewLv4_ARR") == 0)
            ,col("ARR_Variance")
        ).otherwise(0)
    )

    ### Define Upsell ARR
    df = df.withColumn("Upsell_ARR",
        when(
            (col("VALUE") > col("Prior_ARR"))
            & (col("Prior_ARR") > 0)
            ,col("ARR_Variance")
        ).otherwise(0)
    )

    ### Define Customer Churn ARR
    df = df.withColumn("ChurnCust_ARR",
        when(
            Prior_ARR_positive_check
            & (col("CurrentPeriod") > col("Cust_MaxDate"))
            ,col("ARR_Variance")
        ).otherwise(0)
    )

    ### Define Product Churn ARR
    df = df.withColumn("ChurnProd_ARR",
        when(
            Prior_ARR_positive_check
            & (col("ChurnCust_ARR") == 0)
            & (col("CurrentPeriod") > col("Cust_Prod_MaxDate"))
            ,col("ARR_Variance")
        ).otherwise(0)
    )

    ### Define Revenue Type Churn ARR
    df = df.withColumn("ChurnRev_ARR",
        when(
            Prior_ARR_positive_check
            & (col("ChurnCust_ARR") == 0)
            & (col("ChurnProd_ARR") == 0)
            & (col("CurrentPeriod") > col("Cust_Prod_Rev_MaxDate"))
            ,col("ARR_Variance")
        ).otherwise(0)
    )

    ### Define Lv4 Churn ARR
    df = df.withColumn("ChurnLv4_ARR",
        when(
            Prior_ARR_positive_check
            & (col("ChurnCust_ARR") == 0)
            & (col("ChurnProd_ARR") == 0)
            & (col("ChurnRev_ARR") == 0)
            & (col("CurrentPeriod") > col("Cust_Prod_Rev_Lv4_MaxDate"))
            ,col("ARR_Variance")
        ).otherwise(0)
    )

    ### Define Temp Loss ARR (drop to 0 not explained by any "churn" level)
    df = df.withColumn("TempLoss_ARR",
        when(
            Prior_ARR_positive_check
            & (col("ChurnCust_ARR") == 0)
            & (col("ChurnProd_ARR") == 0)
            & (col("ChurnRev_ARR") == 0)
            & (col("ChurnLv4_ARR") == 0)
            ,col("ARR_Variance")
        ).otherwise(0)
    )

    ### Define Downsell ARR
    df = df.withColumn("Downsell_ARR",
        when(
            (col("VALUE") < col("Prior_ARR"))
            & (col("VALUE") > 0)
            ,col("ARR_Variance")
        ).otherwise(0)
    )
    return df

# Creating QA checks related columns
def qa_cols(df: DataFrame) -> DataFrame:
    """Add roll-forward and cohort-date QA columns, each expected to be 0 per row.

    * ``ARR_RollCheck``: Prior_ARR plus every ARR movement bucket (including
      the Level4 buckets ``NewLv4_ARR`` / ``ChurnLv4_ARR`` and ``Net_Credit``)
      minus VALUE, cast to float.
    * ``Count_RollCheck``: BoP_count plus every count movement minus
      EoP_count. The Level4 new/churn counts (+1 / -1) are derived inline from
      the Level4 ARR buckets, so no extra output columns are added.
    * ``Cohort_Max_Dates_Check``: 1 when a row has non-zero VALUE after any of
      its customer / product / revenue-type / Level4 max dates, else 0.

    The Level4 terms only differ from 0 (or from the revenue-type check) when
    LEVEL4 varies within a customer/product/revenue-type series.
    """
    ### Define Roll Forward Check
    df = df.withColumn(
        "ARR_RollCheck",
        (col('NewCust_ARR') + col('NewProd_ARR')
        + col('NewRev_ARR') + col('NewLv4_ARR') + col('TempNew_ARR')
        + col('Upsell_ARR') + col('ChurnCust_ARR')
        + col('ChurnProd_ARR') + col('ChurnRev_ARR') + col('ChurnLv4_ARR')
        + col('TempLoss_ARR') + col('Downsell_ARR')
        + col('Prior_ARR') + col('Net_Credit')
        - col("VALUE")).cast("float")
    )

    # Level4 counts, following the same +1 (new) / -1 (churn) convention as count_cols.
    new_lv4_count = when(col("NewLv4_ARR") != 0, 1).otherwise(0)
    churn_lv4_count = when(col("ChurnLv4_ARR") != 0, -1).otherwise(0)

    df = df.withColumn(
        "Count_RollCheck",
        col('BoP_count')
        + col('NewCust_count') + col('NewProd_count')
        + col('NewRev_count') + new_lv4_count + col('TempNew_count')
        + col('ChurnCust_count') + col('ChurnProd_count')
        + col('ChurnRev_count') + churn_lv4_count + col('TempLoss_count')
        - col('EoP_count')
    )

    df = df.withColumn(
        "Cohort_Max_Dates_Check",
        when( (col("VALUE")!=0 ) &
             (
                 (col('CurrentPeriod') > col('Cust_MaxDate')) |
                 (col('CurrentPeriod') > col('Cust_Prod_MaxDate')) |
                 (col('CurrentPeriod') > col('Cust_Prod_Rev_MaxDate')) |
                 (col('CurrentPeriod') > col('Cust_Prod_Rev_Lv4_MaxDate'))
             ),1).otherwise(0)
    )

    return df

# Creating count related columns
def count_cols(df: DataFrame) -> DataFrame:
    """Add per-row movement counts: +1 for gains, -1 for losses, 0 otherwise.

    Each count flags whether its matching ARR bucket is set; ``EoP_count`` /
    ``BoP_count`` flag a positive current / prior value. Columns are added in
    the order listed below.
    """
    # (output column, condition, value when the condition holds)
    count_rules = (
        ("NewCust_count", col("NewCust_ARR") != 0, 1),
        ("NewProd_count", col("NewProd_ARR") != 0, 1),
        ("NewRev_count", col("NewRev_ARR") > 0, 1),
        ("TempNew_count", col("TempNew_ARR") > 0, 1),
        ("Upsell_count", col("Upsell_ARR") > 0, 1),
        ("ChurnCust_count", col("ChurnCust_ARR") != 0, -1),
        ("ChurnProd_count", col("ChurnProd_ARR") != 0, -1),
        ("ChurnRev_count", col("ChurnRev_ARR") < 0, -1),
        ("Downsell_count", col("Downsell_ARR") < 0, -1),
        ("TempLoss_count", col("TempLoss_ARR") < 0, -1),
        ("EoP_count", col("VALUE") > 0, 1),
        ("BoP_count", col("Prior_ARR") > 0, 1),
    )
    for column_name, condition, flag in count_rules:
        df = df.with_column(column_name, when(condition, flag).otherwise(0))
    return df

# Creating columns not falling in any category for example, RetentionCategory and Period
def extra_cols(df: DataFrame, l) -> DataFrame:
    """Add labelling, date and placeholder columns.

    * ``RetentionCategory``: label of the first non-zero bucket in the priority
      order below, or ``'NoChange'``.
    * ``Period``: ``l`` as text, mapped 1/3/12 to Month/Quarter/Year.
    * ``Retained100%$``: ``ARR_Variance`` for ``NoChange`` rows, NULL otherwise.
    * ``Date SOM``: ``CURRENTPERIOD`` truncated to the month.
    * ``TTM Date``: ``Date SOM`` plus one year when it is after ``Max Date``,
      otherwise ``Date SOM``.
    * Constant UFR / winback placeholder columns.
    """
    # Priority order matters: the first non-zero bucket names the row.
    category_rules = (
        ("Credit", 'Credit'),
        ("Credit_Reversal", 'Credit Reversal'),
        ("NewCust_ARR", 'New Cust'),
        ("NewProd_ARR", 'New Prod'),
        ("NewRev_ARR", 'New Rev Type'),
        ("NewLv4_ARR", 'New Lv4'),
        ("TempNew_ARR", 'Increase_TempNew'),
        ("Upsell_ARR", 'Increase'),
        ("ChurnCust_ARR", 'Churn Cust'),
        ("ChurnProd_ARR", 'Churn Prod'),
        ("ChurnRev_ARR", 'Churn Rev Type'),
        ("ChurnLv4_ARR", 'Churn Lv4'),
        ("TempLoss_ARR", 'Decrease_TempLost'),
        ("Downsell_ARR", 'Decrease'),
    )
    category = None
    for source_column, label in category_rules:
        is_set = col(source_column) != 0
        category = when(is_set, label) if category is None else category.when(is_set, label)
    df = df.with_column("RetentionCategory", category.otherwise('NoChange'))

    # Lookback tag: known lookbacks get a name, anything else keeps its number.
    period_text = lit(l).cast("string")
    df = df.with_column(
        "Period",
        when(period_text == "1", "Month")
        .when(period_text == "3", "Quarter")
        .when(period_text == "12", "Year")
        .otherwise(period_text),
    )

    unchanged = col("RetentionCategory") == 'NoChange'
    df = df.with_column("Retained100%$", when(unchanged, col("ARR_Variance")))

    df = df.with_column("Date SOM", date_trunc('month', col('CURRENTPERIOD')))

    after_max_date = col('Date SOM') > col('Max Date')
    df = df.with_column(
        "TTM Date",
        when(after_max_date, dateadd('year', lit(1), 'Date SOM')).otherwise(col('Date SOM')),
    )

    # Placeholder columns
    placeholders = (
        ("UFR Amount", lit(0.0)),
        ("UFR Tag", lit("Not UFR")),
        ("Current Winback Tag", lit("")),
        ("UFR Date", lit("")),
        ("WinbackTag", lit("No Winback")),
    )
    for column_name, constant in placeholders:
        df = df.with_column(column_name, constant)

    return df

# Calling function for Retention logic
def retention(df: DataFrame, l) -> DataFrame:
    """Apply the full retention pipeline for lookback ``l``.

    Runs, in order: ``arr_cols``, ``credit_cols``, ``count_cols``,
    ``extra_cols`` and ``qa_cols``, each consuming the previous result.
    """
    stages = (
        lambda frame: arr_cols(frame, l),
        credit_cols,
        count_cols,
        lambda frame: extra_cols(frame, l),
        qa_cols,
    )
    result = df
    for stage in stages:
        result = stage(result)
    return result

# Syncing columns name to Alteryx version
def rename_cols(df: DataFrame) -> DataFrame:
    """Rename pipeline column names to the names used by the Alteryx workflow.

    Each key below is renamed, in dictionary order, to its double-quoted
    (case-sensitive) Alteryx name with ``withColumnRenamed``.
    """
    alteryx_col_mapping = {
        "ARR_VARIANCE": '"YoY Variance"',
        "VALUE": '"EoP$"',
        "PRIOR_ARR": '"BoP$"',
        "NEWCUST_ARR": '"NewCust$"',
        "NEWPROD_ARR": '"NewProd$"',
        "NEWREV_ARR": '"NewRevType$"',
        "TEMPNEW_ARR": '"Increase_TempNew"',
        "UPSELL_ARR": '"Increase$"',
        "CHURNCUST_ARR": '"ChurnCust$"',
        "CHURNPROD_ARR": '"ChurnProd$"',
        "CHURNREV_ARR": '"ChurnRevType$"',
        "TEMPLOSS_ARR": '"Decrease_TempLost"',
        "DOWNSELL_ARR": '"Decrease$"',
        "NET_CREDIT": '"Net Credit"',
        "NEWCUST_COUNT": '"NewCust_count"',
        "NEWPROD_COUNT": '"NewProd_count"',
        "NEWREV_COUNT": '"NewRevType_count"',
        "TEMPNEW_COUNT": '"Increases_TempNew_count"', #Added extra 's'
        "UPSELL_COUNT": '"Increase_count"',
        "CHURNCUST_COUNT": '"ChurnCust_count"',
        "CHURNPROD_COUNT": '"ChurnProd_count"',
        "CHURNREV_COUNT": '"ChurnRevType_count"',
        "DOWNSELL_COUNT": '"Decreases_count"',
        "TEMPLOSS_COUNT": '"Decreases_TempLost_count"',
        "CURRENTPERIOD": '"Current Period"',
        "CUSTOMERID": '"CustName"',
        "PERIOD": '"Period"',
        "PRODUCT": '"Product"',
        "RETENTIONCATEGORY": '"RetentionCategory"',
        "REVENUETYPE": '"Revenue Type"',
        "CUST_MINDATE": '"Cohort Date"',
        "CUST_MAXDATE": '"Max Date w ARR"',
        "CUST_PROD_MAXDATE": '"Max Date w ARR_2nd lv"',
        "CUST_PROD_MINDATE": '"Cohort Date_2nd lv"',
        "CUST_PROD_REV_MAXDATE": '"Max Date w ARR_3rd lv"',
        "CUST_PROD_REV_MINDATE": '"Cohort Date_3rd lv"',
        "CUST_PROD_REV_Lv4_MAXDATE": '"Max Date w ARR_4th lv"',
        "CUST_PROD_REV_Lv4_MINDATE": '"Cohort Date_4th lv"',
        "Credit": '"Credit_$"',
        "CREDIT_REVERSAL":'"Credit_Reversal$"',
        "EOP_COUNT": '"EoP_count"',
        "BOP_COUNT": '"BoP_count"',
        "ARR_ROLLCHECK": '"Check"',
        "COUNT_ROLLCHECK": '"Check_Count"'
    }

    # Fold the renames over the frame, one (old, new) pair at a time.
    return reduce(
        lambda renamed, names: renamed.withColumnRenamed(names[0], names[1]),
        alteryx_col_mapping.items(),
        df,
    )
    
# Cleaning step to make output equivalent to Alteryx version
def post_cleaning(df_retention: DataFrame) -> DataFrame:
    """Final clean-up so the output matches the Alteryx version.

    1. Strip a trailing "-CR$" marker from CUSTOMERID.
    2. Drop rows where both ARR and PRIOR_ARR are 0, then sort by CURRENTPERIOD.
    3. Rename the columns to their Alteryx names (``rename_cols``).
    """
    cleaned_names = df_retention.with_column(
        "CUSTOMERID", regexp_replace(col("CUSTOMERID"), r"-CR\$$", "")
    )

    arr_and_prior_both_zero = (col("ARR") == 0) & (col("PRIOR_ARR") == 0)
    kept_rows = cleaned_names.filter(~arr_and_prior_both_zero)
    ordered = kept_rows.sort(col('CURRENTPERIOD'))

    return rename_cols(ordered)

# Custom sum function to add list of columns
def sum_columns(columns):
    """Left-fold ``columns`` with ``+`` (c0 + c1 + ... + cn).

    Works for Snowpark Column expressions and any other type supporting ``+``.
    An empty input raises TypeError (reduce without an initial value).
    """
    def _plus(running_total, next_item):
        return running_total + next_item

    return reduce(_plus, columns)

# Creating amount columns
def add_amount_columns(df: DataFrame, input_amount) -> DataFrame:
    """Derive ARR/MRR and trailing-window amount columns from ``VALUE``.

    ``input_amount`` states what ``VALUE`` holds:
    - "MRR": ARR = VALUE * 12, MRR = VALUE
    - "ARR": MRR = VALUE / 12, ARR = VALUE

    T3M / TTM are the sum of MRR over the current row and the previous 2 / 11
    rows of each CustomerID/Product/RevenueType series ordered by CurrentPeriod.
    Missing earlier rows count as 0. The rows are presumably monthly; confirm
    against base_table_creation. "T3M (Annualized)" is T3M * 4.

    Raises:
        ValueError: if ``input_amount`` is neither "MRR" nor "ARR".
    """
    series_window = (
        Window.partitionBy("CustomerID", "Product", "RevenueType")
        .orderBy(col("CurrentPeriod").asc())
    )

    def trailing_mrr(n_rows):
        # MRR of the current row plus the n_rows - 1 rows before it.
        lagged = [
            lag(col("MRR"), offset=step, default_value=0).over(series_window)
            for step in range(n_rows)
        ]
        return sum_columns(lagged)

    if input_amount == "MRR":
        df = df.withColumn("ARR", col("VALUE") * 12).withColumn("MRR", col("VALUE"))
    elif input_amount == "ARR":
        df = df.withColumn("MRR", col("VALUE") / 12).withColumn("ARR", col("VALUE"))
    else:
        raise ValueError("Input amount not valid!")

    return (
        df.withColumn("T3M", trailing_mrr(3))
        .withColumn("TTM", trailing_mrr(12))
        .withColumn("T3M (Annualized)", col("T3M") * 4)
    )

# Main pipeline function
def retention_pipeline_v2(session, df: DataFrame, lb_periods=(3,), input_amount="MRR", run_at_levels=("ARR",)) -> DataFrame:
    """Run the retention logic for every amount level and lookback period.

    Args:
        session: Snowpark session, passed to ``base_table_creation``.
        df: Input frame, passed to ``base_table_creation``.
        lb_periods: Lookback periods to compute. Each must be one of 1, 3, 4 or 12.
            Defaults to ``(3,)``.
        input_amount: What the input ``VALUE`` column holds, "MRR" or "ARR".
        run_at_levels: Amount columns to run retention on. Each must be one of
            "ARR", "MRR", "T3M", "TTM", "T3M (Annualized)". Defaults to ``("ARR",)``.

    Returns:
        A frame built as follows:
        - For each run-at level, the retention frames of all lookback periods are
          combined with UNION ALL, cleaned with ``post_cleaning`` and tagged with
          an ``"Amount"`` column holding the level.
        - The frames of all levels are then combined with UNION ALL.
        Returns ``None`` when ``run_at_levels`` is empty.

    Raises:
        ValueError: if ``input_amount``, any run-at level or any lookback period
            is unsupported. All inputs are validated before the base table is
            built and cached, so bad input fails immediately.
    """
    print("retention_pipeline_v2 started...")
    cache_s_time = time.time()

    # Materialise once so generators can be validated and then iterated.
    lb_periods = list(lb_periods)
    run_at_levels = list(run_at_levels)

    # Validate all user inputs up front, before any expensive Snowflake work.
    valid_levels = ("ARR", "MRR", "T3M", "TTM", "T3M (Annualized)")
    valid_lookbacks = (1, 3, 12, 4)
    if input_amount not in ("MRR", "ARR"):
        raise ValueError("Input amount not valid!")
    if any(level not in valid_levels for level in run_at_levels):
        raise ValueError("Incorrect Run_at_level. Check user inputs.")
    if any(lb not in valid_lookbacks for lb in lb_periods):
        raise ValueError("Incorrect lookback period. Check user inputs.")

    base_table = base_table_creation(session, df)
    base_table_W_amount_cols = add_amount_columns(base_table, input_amount).cache_result()

    def union_all_frames(frames):
        # Positional UNION ALL of a non-empty list of frames, in list order.
        return reduce(lambda left, right: left.union_all(right), frames)

    run_at_level_dfs = []
    for run_at_level in run_at_levels:
        df_with_value = base_table_W_amount_cols.withColumn("VALUE", col(run_at_level))

        # One retention frame per lookback period for this amount level.
        retention_dfs = []
        for lb in lb_periods:
            retention_df = retention(df_with_value, lb)
            retention_dfs.append(retention_df.select(*retention_df.columns))

        if retention_dfs:
            df_retention = post_cleaning(union_all_frames(retention_dfs))
            df_retention = df_retention.withColumn('"Amount"', lit(run_at_level))
            run_at_level_dfs.append(df_retention.select(*df_retention.columns))

    # None (not an empty frame) when there was nothing to run.
    df_retention_all_amounts = union_all_frames(run_at_level_dfs) if run_at_level_dfs else None

    cache_e_time = time.time()
    print(f"⏱️ retention_pipeline_v2 done: {cache_e_time - cache_s_time:.2f} seconds")
    return df_retention_all_amounts

  
# Data prep for period on period qa check
def prep_data_pop_test_v2(df: DataFrame) -> DataFrame:
    """Pivot the total EoP$ and BoP$ per "Current Period" for the period-on-period QA check.

    Returns one row per metric. ``Source`` holds the metric name ("BoP$" or
    "EoP$"), rows are sorted by Source, and there is one column per distinct
    "Current Period" holding that metric's total.
    """
    totals = df.group_by("Current Period").agg(
        sum('"EoP$"').alias("EoP$"),
        sum('"BoP$"').alias("BoP$"),
    )

    # Melt the two totals into (Current Period, Value, Metric) rows, EoP$ first.
    per_metric = [
        totals.select(
            col("Current Period"),
            col(metric).alias("Value"),
            lit(metric).alias("Metric"),
        )
        for metric in ("EoP$", "BoP$")
    ]
    long_format = per_metric[0].union_all(per_metric[1])

    wide_format = long_format.group_by("Metric").pivot("Current Period").agg(sum("Value"))
    return wide_format.rename({col("Metric"): "Source"}).sort(col("Source"))
    
# Verifying QA check
def verify_qa_check(input_df: DataFrame, columns_to_check: list, qa_check_thresholds: float) -> DataFrame:
    """Return the rows where any checked column breaches the QA threshold.

    A row is kept when ``ABS(column) >= qa_check_thresholds`` for at least one
    column in ``columns_to_check``; the per-column conditions are OR-ed.

    Args:
        input_df (DataFrame): Snowpark DataFrame to filter.
        columns_to_check (list): Names of the numeric columns to test. Must be non-empty.
        qa_check_thresholds (float): Absolute-value threshold. Values at or above it are flagged.

    Returns:
        DataFrame: The (lazily evaluated) rows that fail the QA check.

    Raises:
        ValueError: if ``columns_to_check`` is empty. Without this check,
            ``filter`` would be called with no condition at all.
    """
    if not columns_to_check:
        raise ValueError("columns_to_check must contain at least one column name")

    # Use Snowpark's ABS() explicitly. The Python builtin abs() would depend on
    # the Column type implementing __abs__.
    breaches = [F.abs(col(column)) >= qa_check_thresholds for column in columns_to_check]
    filter_condition = reduce(lambda left, right: left | right, breaches)

    return input_df.filter(filter_condition)

# Save result in table
def save_results(result: DataFrame, path: str, type:str = 'overwrite'):
    """Write ``result`` to the Snowflake table ``path`` and report the elapsed time.

    Args:
        result (DataFrame): Snowpark DataFrame to persist.
        path (str): Fully qualified target table name.
        type (str): Save mode passed to ``DataFrameWriter.mode`` (default 'overwrite').
    """
    began_at = time.time()
    print("___Saving on path: ", path)

    writer = result.write.mode(type)
    writer.save_as_table(path)

    elapsed_seconds = time.time() - began_at
    print("_save_results.")
    print(f"⏱️ Save result: {elapsed_seconds:.2f} seconds")


# Clear console output
def clear_console():
    """Clear the terminal.

    Runs ``cls`` when ``os.name == 'nt'`` (Windows) and ``clear`` otherwise
    (macOS/Linux).
    """
    # Imported locally so the helper does not rely on a module-level ``import os``.
    import os

    os.system('cls' if os.name == 'nt' else 'clear')

# To convert seconds into hours, mins and seconds
def cal_time(start_time: int, end_time: int):
    """Split the interval ``end_time - start_time`` (in seconds) into hours, minutes and seconds.

    Args:
        start_time: start timestamp in seconds (int or float, e.g. from time.time())
        end_time: end timestamp in seconds

    Returns:
        (hours, minutes, seconds). hours and minutes are floor-divided; seconds
        keeps any fractional part.
    """
    elapsed = end_time - start_time
    hours, within_hour = divmod(elapsed, 3600)
    return hours, within_hour // 60, elapsed % 60

# To add suffix to file path based on retention level
def get_file_path(retention_level: str, output_path: str) -> str:
    """Return ``output_path`` with the table suffix for ``retention_level`` appended.

    Suffixes:
        "Customer_level"                        -> "_C_NOTEBOOK"
        "Customer_Product_level"                -> "_CP"
        "Customer_Product_RetentionType_level"  -> "_CPR"
        "Level4"                                -> "_L4"

    Raises:
        ValueError: if ``retention_level`` is not one of the levels above.
    """
    suffix_by_level = {
        "Customer_level": "_C_NOTEBOOK",
        "Customer_Product_level": "_CP",
        "Customer_Product_RetentionType_level": "_CPR",
        "Level4": "_L4",
    }
    try:
        suffix = suffix_by_level[retention_level]
    except KeyError:
        raise ValueError(
            f"Unknown retention_level {retention_level!r}; "
            f"expected one of {sorted(suffix_by_level)}"
        ) from None
    return output_path + suffix



def transform_retention_data(session, df1: DataFrame, table2: str,  retention_level: str) -> DataFrame:
    """
    Reclassify retention rows of ``df1`` using the categories stored in ``table2``.

    Steps:
    1. Keep the rows of ``table2`` whose "RetentionCategory" is
       "Decrease_TempLost" or "Increase_TempNew", reduced to distinct
       (CustName, Current Period, RetentionCategory, Amount, Period) keys.
    2. Full-outer-join them onto ``df1`` on Current Period, CustName, Period
       and Amount.
    3. Rebuild "RetentionCategory" and the YoY-variance bucket, check and count
       columns.

    Args:
        session: Snowpark session object, used to read ``table2``.
        df1: Retention output to transform, with Alteryx-style column names
            (e.g. '"YoY Variance"', '"BoP$"', '"EoP$"').
        table2: Name of the table supplying the reference categories. The
            original docstring described it as CUST_RETENTION (customer-level
            retention); confirm against the caller.
        retention_level: When 'Customer_Product_RetentionType_level', the
            New/Churn Rev Type columns are also rebuilt.

    Returns:
        Snowpark DataFrame with transformed data
    """
    # Load tables as DataFrames
    print("**********In Tranform Retention Data***********")
    df2 = session.table(table2)

    # Filter rows where RetentionCategory is in the list
    filtered_df = df2.filter(
        col('"RetentionCategory"').isin(["Decrease_TempLost", "Increase_TempNew"])
    )

    # Select the required columns
    selected_df = filtered_df.select(
        col('"Amount"'),
        col('"Current Period"'),
        col('"CustName"'),
        col('"RetentionCategory"'),
        col('"Period"')
    )
    # Group by all selected columns to mimic GROUP BY ALL
    # (agg() with no aggregate expressions leaves only the distinct key combinations)
    grouped_df = selected_df.group_by(
        '"CustName"', '"Current Period"', '"RetentionCategory"', '"Amount"', '"Period"'
    ).agg()

    df2 = grouped_df
    print(df2.schema)
    # grouped_df now contains distinct rows matching the filter
    # Columns to exclude
    # NOTE(review): this list mixes quoted ('"RetentionCategory"') and unquoted
    # ("Credit_$") names. df1.columns presumably reports case-sensitive names in
    # quoted form (as the '"..."' lookups below suggest). If so, the unquoted
    # entries never match, and those columns stay in select_cols next to the
    # rebuilt aliases further down. Confirm.
    exclude_cols = [
        '"RetentionCategory"', "Credit_Reversal$", "Credit_$", "NewCust$", "ChurnCust$", "Increase$", 
        "Decrease$", "Decrease_TempLost$", "Increase_TempNew$", "Retained100%$", "NewProd$", 
        "ChurnProd$", "BoP_count", "ChurnCust_count", "ChurnProd_count", "Decreases_count", 
        "Decreases_TempLost_count", "Increases_TempNew_count", "Increase_count", "NewCust_count", 
        "NewProd_count", "EoP_count",'"Current Period"','"CustName"','"Period"','"Amount"'
    ]
    # Revenue-type columns that are only rebuilt at the RetentionType level.
    new_cols = ['"NewRevType$"', '"ChurnRevType$"', '"NewRevType_count"', '"ChurnRevType_count"']

    if retention_level == 'Customer_Product_RetentionType_level':
        exclude_cols.extend(new_cols)

    # Select all columns from T1 except excluded ones
    select_cols = [col(c) for c in df1.columns if c not in exclude_cols]
    print(df1.columns)
    print(select_cols)
    # Define new RetentionCategory logic (Tool ID 3438)
    # Temporary moves with no match in df2 become New/Churn Prod. A df2 match
    # overrides the category. Everything else keeps df1's category.
    retention_category_expr = when(
        (df1['"RetentionCategory"'] == "Increase_TempNew") & (df2['"RetentionCategory"'].is_null()),
        lit("New Prod")
    ).when(
        (df1['"RetentionCategory"'] == "Decrease_TempLost") & (df2['"RetentionCategory"'].is_null()),
        lit("Churn Prod")
    ).when(
        df2['"RetentionCategory"'].is_not_null(),
        df2['"RetentionCategory"']
    ).otherwise(df1['"RetentionCategory"'])

    # Define YOY_VARIANCE-based columns (Tool ID 3439)
    # Each bucket takes df1's YoY Variance when the category matches, else 0.
    # NOTE(review): col('"RetentionCategory"') is unqualified. After the join
    # below, both df1 and df2 expose this column, and either way it is not
    # retention_category_expr. Confirm which category these buckets should key on.
    credit_reversal = when(col('"RetentionCategory"') == "Credit Reversal", df1['"YoY Variance"']).otherwise(lit(0))
    credit = when(col('"RetentionCategory"') == "Credit", df1['"YoY Variance"']).otherwise(lit(0))
    new_cust = when(col('"RetentionCategory"') == "New Cust", df1['"YoY Variance"']).otherwise(lit(0))
    churn_cust = when(col('"RetentionCategory"') == "Churn Cust", df1['"YoY Variance"']).otherwise(lit(0))
    increase = when(col('"RetentionCategory"') == "Increase", df1['"YoY Variance"']).otherwise(lit(0))
    decrease = when(col('"RetentionCategory"') == "Decrease", df1['"YoY Variance"']).otherwise(lit(0))
    decrease_temp_lost = when(col('"RetentionCategory"') == "Decrease_TempLost", df1['"YoY Variance"']).otherwise(lit(0))
    increase_temp_new = when(col('"RetentionCategory"') == "Increase_TempNew", df1['"YoY Variance"']).otherwise(lit(0))
    retained_100 = when(col('"RetentionCategory"') == "NoChange", df1['"YoY Variance"']).otherwise(lit(0))
    new_prod = when(col('"RetentionCategory"') == "New Prod", df1['"YoY Variance"']).otherwise(lit(0))
    churn_prod = when(col('"RetentionCategory"') == "Churn Prod", df1['"YoY Variance"']).otherwise(lit(0))
    if retention_level == 'Customer_Product_RetentionType_level':
        new_rev_type = when(col('"RetentionCategory"') == "New Rev Type", df1['"YoY Variance"']).otherwise(lit(0))
        churn_rev_type = when(col('"RetentionCategory"') == "Churn Rev Type", df1['"YoY Variance"']).otherwise(lit(0))
    # Compute CHECK column
    # Roll-forward check: BoP$ + movement buckets - EoP$.
    # NOTE(review): credit, credit_reversal and (at the RetentionType level)
    # new_rev_type / churn_rev_type are not part of this sum. Confirm that is intended.
    check_col = (
        df1['"BoP$"']+ 
        new_cust + 
        churn_cust + 
        increase + 
        increase_temp_new + 
        decrease + 
        decrease_temp_lost + 
        new_prod + 
        churn_prod + 
        retained_100 - 
        df1['"EoP$"']
    )

    # Define count columns (Tool ID 3441)
    # Gains count +1, losses count -1, when the matching bucket is non-zero.
    bop_count = when(df1['"BoP$"'] != 0, lit(1)).otherwise(lit(0))
    churn_cust_count = when(churn_cust != 0, lit(-1)).otherwise(lit(0))
    churn_prod_count = when(churn_prod != 0, lit(-1)).otherwise(lit(0))
    decreases_count = when(decrease != 0, lit(-1)).otherwise(lit(0))
    decreases_temp_lost_count = when(decrease_temp_lost != 0, lit(-1)).otherwise(lit(0))
    increases_temp_new_count = when(increase_temp_new != 0, lit(1)).otherwise(lit(0))
    increase_count = when(increase != 0, lit(1)).otherwise(lit(0))
    new_cust_count = when(new_cust != 0, lit(1)).otherwise(lit(0))
    new_prod_count = when(new_prod != 0, lit(1)).otherwise(lit(0))
    eop_count = when(df1['"EoP$"'] != 0, lit(1)).otherwise(lit(0))
    if retention_level == 'Customer_Product_RetentionType_level':
        new_rev_type_count = when(new_rev_type != 0, lit(1)).otherwise(lit(0))
        # NOTE(review): this yields +1, while every other churn/decrease count
        # above yields -1. Confirm the sign.
        churn_rev_type_count = when(churn_rev_type != 0, lit(1)).otherwise(lit(0))
    

    # Perform the full outer join
    # Rows that exist only in df2 get NULL df1 key columns, because the select
    # below projects df1's Current Period / Period / CustName / Amount.
    result_df = df1.join(
        df2,
        (df1['"Current Period"'] == df2['"Current Period"']) &
        (df1['"CustName"'] == df2['"CustName"']) &
        (df1['"Period"'] == df2['"Period"']) &
        (df1['"Amount"'] == df2['"Amount"']),
        "full_outer"
    )

    # Select final columns
    if retention_level == 'Customer_Product_RetentionType_level':
        result_df = result_df.select(
            *select_cols,
            df1['"Current Period"'].alias('"Current Period"'),df1['"Period"'].alias('"Period"'),
            df1['"CustName"'].alias('"CustName"'),df1['"Amount"'].alias('"Amount"'),
            retention_category_expr.alias('"RetentionCategory"'), credit_reversal.alias('"Credit_Reversal$"'),
            credit.alias('"Credit_$"'),new_cust.alias('"NewCust$"'),
            churn_cust.alias('"ChurnCust$"'),increase.alias('"Increase$"'),
            decrease.alias('"Decrease$"'),decrease_temp_lost.alias('"Decrease_TempLost$"'),
            increase_temp_new.alias('"Increase_TempNew$"'),retained_100.alias('"Retained100%$"'),
            new_prod.alias('"NewProd$"'),churn_prod.alias('"ChurnProd$"'),
            new_rev_type.alias('"NewRevType$"'),churn_rev_type.alias('"ChurnRevType$"'),
            check_col.alias('"Check"'),bop_count.alias('"BoP_count"'),
            churn_cust_count.alias('"ChurnCust_count"'),churn_prod_count.alias('"ChurnProd_count"'),
            decreases_count.alias('"Decreases_count"'),decreases_temp_lost_count.alias('"Decreases_TempLost_count"'),
            increases_temp_new_count.alias('"Increases_TempNew_count"'),increase_count.alias('"Increase_count"'),
            new_cust_count.alias('"NewCust_count"'),new_prod_count.alias('"NewProd_count"'),
            eop_count.alias('"EoP_count"'), new_rev_type_count.alias('"NewRevType_count"'),
            churn_rev_type_count.alias('"ChurnRevType_count"')
        )
        # Drop df1's original versions of the rebuilt columns.
        # NOTE(review): df1['"NewProd$"'] is listed twice, and '"Decrease_TempLost$"' /
        # '"Increase_TempNew$"' and the RevType columns are not in this list.
        result_df = result_df.drop(df1['"ChurnCust$"'],df1['"Credit_Reversal$"'],
                                   df1['"Credit_$"'],df1['"NewCust$"'],df1['"NewProd$"'],
                                   df1['"Increase$"'],df1['"Decrease$"'],df1['"Retained100%$"'],
                                   df1['"NewProd$"'],df1['"ChurnProd$"'],df1['"Check"'],
                                   df1['"BoP_count"'],df1['"ChurnCust_count"'],df1['"ChurnProd_count"'],
                                   df1['"Decreases_count"'],df1['"Decreases_TempLost_count"'],
                                   df1['"Increases_TempNew_count"'],df1['"Increase_count"'],
                                   df1['"NewCust_count"'],df1['"NewProd_count"'],df1['"EoP_count"'])
    else:
        # Same projection without the revenue-type columns.
        result_df = result_df.select(
            *select_cols,
            df1['"Current Period"'].alias('"Current Period"'),df1['"Period"'].alias('"Period"'),
            df1['"CustName"'].alias('"CustName"'),df1['"Amount"'].alias('"Amount"'),
            retention_category_expr.alias('"RetentionCategory"'), credit_reversal.alias('"Credit_Reversal$"'),
            credit.alias('"Credit_$"'),new_cust.alias('"NewCust$"'),
            churn_cust.alias('"ChurnCust$"'),increase.alias('"Increase$"'),
            decrease.alias('"Decrease$"'),decrease_temp_lost.alias('"Decrease_TempLost$"'),
            increase_temp_new.alias('"Increase_TempNew$"'),retained_100.alias('"Retained100%$"'),
            new_prod.alias('"NewProd$"'),churn_prod.alias('"ChurnProd$"'),
            check_col.alias('"Check"'),bop_count.alias('"BoP_count"'),
            churn_cust_count.alias('"ChurnCust_count"'),churn_prod_count.alias('"ChurnProd_count"'),
            decreases_count.alias('"Decreases_count"'),decreases_temp_lost_count.alias('"Decreases_TempLost_count"'),
            increases_temp_new_count.alias('"Increases_TempNew_count"'),increase_count.alias('"Increase_count"'),
            new_cust_count.alias('"NewCust_count"'),new_prod_count.alias('"NewProd_count"'),
            eop_count.alias('"EoP_count"')
        )
        # Drop df1's original versions of the rebuilt columns (see the note in the branch above).
        result_df = result_df.drop(df1['"ChurnCust$"'],df1['"Credit_Reversal$"'],
                                   df1['"Credit_$"'],df1['"NewCust$"'],df1['"NewProd$"'],
                                   df1['"Increase$"'],df1['"Decrease$"'],df1['"Retained100%$"'],
                                   df1['"NewProd$"'],df1['"ChurnProd$"'],df1['"Check"'],
                                   df1['"BoP_count"'],df1['"ChurnCust_count"'],df1['"ChurnProd_count"'],
                                   df1['"Decreases_count"'],df1['"Decreases_TempLost_count"'],
                                   df1['"Increases_TempNew_count"'],df1['"Increase_count"'],
                                   df1['"NewCust_count"'],df1['"NewProd_count"'],df1['"EoP_count"'])
    
    print(result_df.schema)
    print("**********Out Tranform Retention Data***********")
    return result_df

# Fields mapped according to highest ARR
**This logic forces each mapped field to be 1:1 with the customer, choosing the value with the highest ARR. It also outputs the customers affected by the logic so that users can quantify its impact.**

**Mapped Fields**
- `REGION`
- `INDUSTRY`
- `CHANNEL`

**If you do not want to use it, skip the "ARR_mapping_fields" block of code and use "optimized_main" instead of "optimized_main_with_mapping_bundling" to generate the output.**

In [None]:
def standardize_customer_fields(session: Session, input_table_path: str, column_mapping_file: dict) -> tuple:
    """
    Standardize fields so each customer has a single value, chosen by the highest total ARR.

    Fields are considered when all of these hold:
    - they map to the standard names 'Region', 'Industry' or 'Channel' in
      ``column_mapping_file``;
    - the mapped value is not None;
    - the column exists in the input table.

    For each such field, a customer with more than one distinct value gets
    every row set to the value whose rows have the largest summed ``VALUE``.
    Rows whose value changed are collected into an "affected customers" report.

    Side effects:
    - Writes the report (when non-empty) to ``<input_table_path>_AFFECTED_CUSTOMERS_REPORT``.
    - Writes the standardized data to ``<input_table_path>_ARR_Mapped``.

    Parameters:
    session (Session): The Snowpark session object.
    input_table_path (str): Path to the input table; also the prefix of the output tables.
    column_mapping_file (dict): Maps standard column names (e.g. "CUSTOMERID",
        "VALUE", "Region") to input table column names, or None when unmapped.

    Returns:
    tuple: (Standardized DataFrame, DataFrame of affected customers or None)
    """
    print("\n_Standardizing customer fields started...")
    start_time = time.time()

    # Read the input table
    df = session.table(input_table_path)

    # Only these standard fields are forced to be 1:1 with the customer.
    standardizable_fields = ('Region', 'Industry', 'Channel')
    fields_to_standardize = []
    for standard_col, input_col in column_mapping_file.items():
        if input_col is None:
            continue
        clean_input_col = input_col.replace('"', '')
        # NOTE(review): df.columns presumably returns quoted identifiers for
        # mixed-case names, so only upper-case input columns match here. Confirm.
        if standard_col in standardizable_fields and clean_input_col in df.columns:
            fields_to_standardize.append(clean_input_col)

    # Customer ID and value (ARR) columns, falling back to the standard names.
    customer_mapping = column_mapping_file.get("CUSTOMERID")
    value_mapping = column_mapping_file.get("VALUE")
    customer_id_col = customer_mapping.replace('"', '') if customer_mapping else "CUSTOMERID"
    value_col = value_mapping.replace('"', '') if value_mapping else "VALUE"

    affected_customers = []
    standardized_df = df

    for field in fields_to_standardize:
        # Customers carrying more than one distinct value of this field.
        multi_value_customers = (
            standardized_df.group_by(customer_id_col, field).count()
            .group_by(customer_id_col).count()
            .filter(col('COUNT') > 1)
            .select(customer_id_col)
        )
        n_multi_value = multi_value_customers.count()
        if n_multi_value == 0:
            continue

        # Dominant value per customer: the value with the highest total ARR.
        # Ties are broken arbitrarily by row_number.
        dominant_values = standardized_df.group_by(customer_id_col, field).agg(
            sum(value_col).alias('TOTAL_ARR')
        ).with_column(
            'RANK', row_number().over(
                Window.partition_by(customer_id_col).order_by(col('TOTAL_ARR').desc())
            )
        ).filter(col('RANK') == 1).select(customer_id_col, field)

        # Report rows whose value differs from the dominant one. The dominant
        # value is aliased explicitly, so no join suffix handling is needed.
        new_field_col = f'NEW_{field}'
        affected_field_df = (
            standardized_df
            .filter(col(customer_id_col).in_(multi_value_customers))
            .select(customer_id_col, field, value_col)
            .join(dominant_values.select(customer_id_col, col(field).alias(new_field_col)), [customer_id_col])
            .filter(col(field) != col(new_field_col))
            .select(
                customer_id_col,
                col(field).alias(f'ORIGINAL_{field}'),
                col(new_field_col),
                value_col,
            )
            .distinct()
        )
        if affected_field_df.count() > 0:
            affected_customers.append(affected_field_df)

        # Replace the field with the dominant value inside Snowflake via a left
        # join on the per-customer lookup, keeping source column types. Rows
        # without a lookup match (e.g. NULL customer id) keep their value.
        lookup = dominant_values.select(
            col(customer_id_col).alias('DOMINANT_LOOKUP_KEY'),
            col(field).alias('DOMINANT_LOOKUP_VALUE'),
        )
        joined = standardized_df.join(
            lookup,
            standardized_df[customer_id_col] == lookup['DOMINANT_LOOKUP_KEY'],
            'left',
        )
        standardized_df = joined.with_column(
            field,
            when(col('DOMINANT_LOOKUP_KEY').is_null(), col(field))
            .otherwise(col('DOMINANT_LOOKUP_VALUE')),
        ).drop('DOMINANT_LOOKUP_KEY', 'DOMINANT_LOOKUP_VALUE')

        print(f"Standardized '{field}' for {n_multi_value} customers")

    # Combine all affected customers into a single report.
    # NOTE(review): union is positional, so reports for different fields share
    # the first field's column names (ORIGINAL_<field>/NEW_<field>).
    affected_customers_df = None
    if affected_customers:
        affected_customers_df = reduce(
            lambda left, right: left.union(right),
            affected_customers
        )

        # Output tables are named after the input table passed to this function.
        affected_customers_report_path = input_table_path + "_AFFECTED_CUSTOMERS_REPORT"
        save_results(affected_customers_df, affected_customers_report_path)
        print(f"Saved affected customers report to {affected_customers_report_path}")
    else:
        print("No customers were affected by standardization")

    # Save standardized dataframe
    output_path = input_table_path + "_ARR_Mapped"
    save_results(standardized_df, output_path)
    print(f"___Saving on path: {output_path}")

    end_time = time.time()
    print(f"⏱️ Field standardization runtime: {end_time - start_time:.2f} seconds")

    return standardized_df, affected_customers_df

# Product Bundle Analysis
**Goal: Analyze customer-product bundles using Snowpark to generate ARR, MRR, and cross-sell insights.**

`Load Data`

- Read input Snowflake table: input_table_product_bundle.

`Map & Validate Columns`

- Rename columns using column_mapping_bundling.

- Check for required fields like CUSTOMERID, PRODUCT, VALUE, etc.

`Customer-Level Metrics`

- Group by period & customer to calculate:
    - ARR and CROSSSELL
    - PRODUCT_COMBO (the list of distinct products)

`Bundle Calculation`

- Compute PRODUCT_BUNDLE as number of products (via comma count in PRODUCT_COMBO).

`Final Aggregation`

- Group by period, cohort date, bundle info.

In [None]:
def product_bundle_analysis(session: Session):
    """
    Aggregate the customer/product retention output into product-bundle metrics.

    Steps:
    1. Read the global ``input_table_product_bundle`` table and map its
       Alteryx-style columns to internal names.
    2. Build one row per (period, customer, cohort date, Amount, Period) holding:
       - ARR and cross-sell totals;
       - PRODUCT_COMBO, the customer's distinct products sorted and
         comma-separated, so the same bundle always yields the same string.
    3. Aggregate ARR, MRR, cross-sell and distinct customers with ARR per
       (period, cohort date, Amount, Period, product combo, bundle size).
       Amount and Period stay in the grouping because the input holds one row
       per Amount/Period variant, and summing across them would double count.
    4. Write the result (overwrite) to ``<input_table_product_bundle>_PRODUCT_BUNDLE_ANALYSIS``.

    Args:
        session: Snowpark session used to read the input table.

    Returns:
        The aggregated DataFrame, or None if a step inside the try block fails
        (the error and traceback are printed).

    Raises:
        ValueError: if a required column is missing after the column mapping.
    """
    print("\n" + "="*50)
    print("STEP 3: GENERATING PRODUCT BUNDLE ANALYSIS")
    print("="*50)

    # Mapping File with respect to columns available in input_table_product_bundle
    column_mapping_bundling = {
    "CURRENTPERIOD": '"Current Period"',
    "CUSTOMERID": '"CustName"',
    "TRANSACTION_COHORT_DATE": '"Cohort Date"',
    "VALUE": '"EoP$"',
    "NEWPROD_ARR": '"NewRevType$"',
    "PRODUCT": '"Product"',
    "PERIOD": '"Period"',
    "AMOUNT": '"Amount"'}

    print(f"Reading data from: {input_table_product_bundle}")
    df = session.table(input_table_product_bundle)

    # Rename the mapped columns that exist in the source to their internal names.
    work_df = df
    for sql_col, df_col in column_mapping_bundling.items():
        if df_col in df.columns:
            work_df = work_df.rename(df_col, sql_col)

    # Every column used in the aggregations below must be present.
    required_columns = [
        'CURRENTPERIOD', 'CUSTOMERID', 'TRANSACTION_COHORT_DATE', 'VALUE',
        'NEWPROD_ARR', 'PRODUCT', 'AMOUNT', 'PERIOD',
    ]
    missing_cols = [name for name in required_columns if name not in work_df.columns]
    if missing_cols:
        raise ValueError(f"Required columns are missing: {missing_cols}")

    try:
        # STEP 1: Equivalent to the customer_level CTE
        print("Creating customer_level aggregation...")

        # LISTAGG(DISTINCT PRODUCT, ',') WITHIN GROUP (ORDER BY PRODUCT): the sort
        # makes the combo string independent of row order.
        customer_level = work_df.group_by(
            'CURRENTPERIOD', 'CUSTOMERID', 'TRANSACTION_COHORT_DATE', 'AMOUNT', 'PERIOD'
        ).agg(
            sum_sf('VALUE').alias('ARR'),
            sum_sf('NEWPROD_ARR').alias('CROSSSELL'),
            F.listagg(col('PRODUCT'), ',', is_distinct=True)
            .within_group(col('PRODUCT').asc())
            .alias('PRODUCT_COMBO'),
        )

        customer_level = customer_level.with_column('MRR', col('ARR') / 12)

        # STEP 2: Equivalent to the product_bundle CTE
        print("Calculating product bundles...")

        # Bundle size = number of commas in the combo + 1.
        customer_level = customer_level.with_column(
            'PRODUCT_BUNDLE',
            F.regexp_count(col('PRODUCT_COMBO'), ',') + 1
        )

        # STEP 3: Equivalent to the final SELECT query
        print("Performing final aggregation...")

        final_result = customer_level.group_by(
            'CURRENTPERIOD', 'TRANSACTION_COHORT_DATE', 'AMOUNT', 'PERIOD',
            'PRODUCT_COMBO', 'PRODUCT_BUNDLE'
        ).agg(
            sum_sf('ARR').alias('ARR'),
            sum_sf('MRR').alias('MRR'),
            sum_sf('CROSSSELL').alias('CROSSSELL'),
            # when() without otherwise() yields NULL for ARR <= 0, and NULLs are not counted.
            F.count_distinct(when(col('ARR') > 0, col('CUSTOMERID'))).alias('DISTINCT_CUSTOMERS_WITH_ARR')
        )

        output_table = input_table_product_bundle + "_PRODUCT_BUNDLE_ANALYSIS"
        final_result.write.mode("overwrite").save_as_table(output_table)

        print(f"___Saving on path: {output_table}")
        print("=" * 50)

        return final_result

    except Exception as e:
        # Best effort: report the failure and let the rest of the notebook continue.
        print(f"Error in product bundle analysis: {str(e)}")
        print("Detailed error information:")
        traceback.print_exc()
        return None

# Float Column Rounding
**Goal: Ensure consistency and control over decimal precision by rounding all float-type columns in a Snowflake table to a specified number of decimal places (default: 3).**

`Read Input Table`

- Loads table from the given table_path.

`Identify Float Columns`

- Uses DESCRIBE TABLE to detect columns of type: FLOAT, DOUBLE, REAL, DECIMAL, or NUMERIC.

`Round Float Values`

- Applies rounding to each float column using Snowpark's round() function.

In [None]:
def round_float_columns(session: Session, table_path, decimal_places=3):
    """
    Round every float-typed column of a Snowflake table and save the result.

    Column types are read with ``DESCRIBE TABLE``; a column is rounded when its
    reported type contains FLOAT, DOUBLE, REAL, DECIMAL or NUMERIC. The rounded
    data is written (overwrite) to ``<table_path>_ROUNDED``.

    Parameters:
    session (Session): The Snowpark session object.
    table_path (str): Fully qualified name of the input table.
    decimal_places (int): Number of decimal places to round to (default: 3).

    Returns:
    DataFrame: DataFrame over ``table_path`` with the float columns rounded.
    When no float column is found, the unmodified table DataFrame is returned
    and nothing is written.
    """
    # Aliased so the builtin ``round`` and the module-level ``col`` are not shadowed.
    from snowflake.snowpark.functions import round as sf_round, col as sf_col
    import time

    print(f"\n_Rounding float columns to {decimal_places} decimal places...")
    start_time = time.time()

    df = session.table(table_path)

    # DESCRIBE TABLE yields one row per column with 'name' and 'type' fields.
    table_columns = session.sql(f"DESCRIBE TABLE {table_path}").collect()

    # NOTE(review): Snowflake usually reports DECIMAL/NUMERIC columns as
    # NUMBER(p,s), so fixed-point columns are probably not matched here —
    # confirm whether they should be rounded too.
    float_types = ('FLOAT', 'DOUBLE', 'REAL', 'DECIMAL', 'NUMERIC')
    float_columns = [
        column_info['name']
        for column_info in table_columns
        if any(float_type in str(column_info['type']).upper() for float_type in float_types)
    ]

    if not float_columns:
        print("No float columns found in the table.")
        return df

    print(f"Found {len(float_columns)} float columns to round: {', '.join(float_columns)}")

    for column_name in float_columns:
        # DESCRIBE TABLE returns the raw identifier (e.g. Value, Cohort Date).
        # Quote it so Snowpark does not upper-case mixed-case names (which would
        # reference a non-existent VALUE column and add a new one instead of
        # replacing), and escape embedded double quotes.
        quoted_name = '"' + column_name.replace('"', '""') + '"'
        df = df.with_column(quoted_name, sf_round(sf_col(quoted_name), decimal_places))

    rounded_table_path = table_path + "_ROUNDED"
    df.write.mode("overwrite").save_as_table(rounded_table_path)
    print(f"___Saving on path: {rounded_table_path}")

    end_time = time.time()
    print(f"⏱️ Float column rounding runtime: {end_time - start_time:.2f} seconds")

    return df

In [None]:
def create_master_dimension_table(session: Session, source_table_name: str, target_table_name: str):
    """
    Build the customer/product dimension table from the fact table.

    One row is produced per distinct combination of the nine dimension columns
    below (CustomerID, Account Size, Revenue Type, Region, Industry, Channel,
    Cohort Date, Product Category, Product); "Boomerang flag" is collapsed with
    MAX inside each combination. The result overwrites ``target_table_name``.

    Args:
        session: Snowpark session
        source_table_name: Source table name (e.g., "SUBSCRIPTION_ACCEL.PYTHON_TESTING.FACT_TABLE_AMSTERDAM_INPUT_100K_24JUNE_MASTER")
        target_table_name: Target table name (e.g., "SUBSCRIPTION_ACCEL.PYTHON_TESTING.AMSTERDAM_INPUT_100K_24JUNE_MASTER_DIMENSION_DIM")

    Returns:
        DataFrame: The aggregated result (already saved to ``target_table_name``).
    """

    def resolve(frame, wanted):
        """Return a Column for ``wanted``, tolerating quoted and case-variant identifiers."""
        available = frame.columns
        if wanted in available:
            return col(wanted)
        quoted = f'"{wanted}"'
        if quoted in available:
            return col(quoted)
        case_match = next((name for name in available if name.upper() == wanted.upper()), None)
        if case_match is not None:
            return col(case_match)
        # Unknown name: let Snowflake raise the error when the query runs.
        return col(wanted)

    source_df = session.table(source_table_name)

    # Dimension columns before and after "Boomerang flag" in the output order.
    leading_refs = [
        resolve(source_df, name)
        for name in ("CustomerID", "Account Size", "Revenue Type", "Region", "Industry", "Channel")
    ]
    trailing_refs = [
        resolve(source_df, name)
        for name in ("Cohort Date", "Product Category", "Product")
    ]

    aggregated = source_df.group_by(*leading_refs, *trailing_refs).agg(
        spark_max(resolve(source_df, "Boomerang flag")).alias("Boomerang flag")
    )

    # Put "Boomerang flag" between Channel and Cohort Date to match the original SQL column order.
    result_df = aggregated.select(*leading_refs, col("Boomerang flag"), *trailing_refs)

    # Equivalent to CREATE OR REPLACE TABLE.
    result_df.write.mode("overwrite").save_as_table(target_table_name)

    print(f"___Saving on path:  {target_table_name}")

    return result_df

# Example usage:
# Assuming you have a session already created

# Source and target table names
# SOURCE_TABLE = fact_table_output_path
# TARGET_TABLE = fact_table_output_path + "_DIMENSION_DIM"

# Call the function
# result_df = create_master_dimension_table(session, SOURCE_TABLE, TARGET_TABLE)

# Display the results (optional)
# result_df.show()

# To get row count
# print(f"Total rows in result: {result_df.count()}")


def create_date_dimension_table(session: Session, source_table_name: str, target_table_name: str, row_count: int = 10000):
    """
    Build a monthly date dimension table from the dates in a source table.

    The sequence starts at the minimum "Date" of the source table and advances
    one month per generated row (``row_count`` rows). It is then cut off at the
    latest "Date" whose TTM value is non-zero, numbered with a 1-based INDEX,
    and written (overwrite) to ``target_table_name``.

    Args:
        session: Snowpark session
        source_table_name: Source table name (e.g., "SUBSCRIPTION_ACCEL.PYTHON_TESTING.FACT_TABLE_AMSTERDAM_INPUT_100K_24JUNE_MASTER")
        target_table_name: Target table name (e.g., "SUBSCRIPTION_ACCEL.PYTHON_TESTING.AMSTERDAM_INPUT_100K_24JUNE_MASTER_DATE_DIM")
        row_count: Number of months generated before filtering (default: 10000)

    Returns:
        DataFrame: The date dimension (columns DATE, INDEX, "TTM Date"),
        already saved to ``target_table_name``.
    """

    # Helper function to get column reference
    def get_col(df, col_name):
        """Get column reference, handling both quoted and unquoted identifiers"""
        columns = df.columns

        # Try exact match first
        if col_name in columns:
            return col(col_name)

        # Try with quotes
        quoted_name = f'"{col_name}"'
        if quoted_name in columns:
            return col(quoted_name)

        # Try case variations
        for c in columns:
            if c.upper() == col_name.upper():
                return col(c)

        # If not found, return the original and let Snowflake handle the error
        return col(col_name)

    df = session.table(source_table_name)
    date_col = get_col(df, "Date")

    # Fetch both sequence bounds in ONE aggregation (previously two queries):
    #   MIN_DATE     - earliest Date over all rows: start of the sequence.
    #   MAX_TTM_DATE - latest Date among rows with TTM <> 0: end of the sequence.
    # WHEN without OTHERWISE yields NULL for the other rows and MAX ignores
    # NULLs, which is equivalent to filtering on TTM <> 0 first.
    bounds = df.select(
        spark_min(date_col).alias("Min_Date"),
        spark_max(when(get_col(df, "TTM") != 0, date_col)).alias("Max_TTM_Date"),
    ).collect()[0]
    # Unquoted aliases are upper-cased by Snowpark/Snowflake, hence the keys.
    min_date_value = bounds["MIN_DATE"]
    max_ttm_date_value = bounds["MAX_TTM_DATE"]

    # SEQ4() alone is not guaranteed gap-free, and a gap would skip a month.
    # ROW_NUMBER() over it gives a contiguous 0..row_count-1 offset. int()
    # guarantees only a number is interpolated into the SQL text.
    generator_df = session.sql(
        "SELECT ROW_NUMBER() OVER (ORDER BY SEQ4()) - 1 AS seq_num "
        f"FROM TABLE(GENERATOR(ROWCOUNT=>{int(row_count)}))"
    )

    # Month offsets from the minimum date.
    date_sequence_df = generator_df.select(
        dateadd("month", col("seq_num"), lit(min_date_value)).alias("DATE")
    )

    # Keep only months up to the last date with non-zero TTM.
    filtered_dates_df = date_sequence_df.filter(col("DATE") <= lit(max_ttm_date_value))

    window_spec = Window.orderBy(col("DATE"))

    date_dimension_df = filtered_dates_df.select(
        # The unquoted alias "Date" resolves to DATE; rename_fixed_tables later
        # renames it to the case-sensitive "Date" for the _DATE_DIM table.
        col("DATE").alias("Date"),
        row_number().over(window_spec).alias("INDEX"),
        col("DATE").alias("TTM Date")
    )

    final_result_df = date_dimension_df.orderBy(col("Date").asc())

    # Equivalent to CREATE OR REPLACE TABLE.
    final_result_df.write.mode("overwrite").saveAsTable(target_table_name)

    print(f"___Saving on path:  {target_table_name}")

    return final_result_df

# Example usage:
# Source and target table names
# SOURCE_TABLE = fact_table_output_path
# TARGET_TABLE = fact_table_output_path + "_DATE_DIM"

# Call the function to create the table
# result_df = create_date_dimension_table(session, SOURCE_TABLE, TARGET_TABLE, row_count=10000)

# Display the results (optional)
# result_df.show(20)  # Show first 20 rows

# To get row count
# print(f"Total rows in date dimension: {result_df.count()}")

# You can also customize the row count if needed
# result_df = create_date_dimension_table(session, SOURCE_TABLE, TARGET_TABLE, row_count=5000)

In [None]:
def rename_fixed_tables(session):
    """
    Give selected columns of the fact, _DATE_DIM and _DIMENSION_DIM tables case-preserving names.

    For each table (derived from the notebook-level ``fact_table_output_path``),
    source columns are matched case-insensitively with surrounding double quotes
    ignored (consistent with ``rename_pbi_tables``) and renamed to quoted
    targets, e.g. REGION -> "Region". A target that already exists is left
    alone, and a table is rewritten (overwrite) only when at least one column
    actually changes, so re-running is safe and avoids needless full rewrites.

    Parameters:
    session: Snowpark session.
    """
    tables_and_renames = {
        fact_table_output_path: [("REGION", "Region"), ("INDUSTRY", "Industry"), ("CHANNEL", "Channel")],
        fact_table_output_path + "_DATE_DIM": [("DATE", "Date")],
        fact_table_output_path + "_DIMENSION_DIM": [("REGION", "Region"), ("INDUSTRY", "Industry"), ("CHANNEL", "Channel")]
    }

    for table_name, rename_list in tables_and_renames.items():
        print(f"\n🔧 Renaming columns in: {table_name}")
        df = session.table(table_name)
        actual_cols = df.columns
        # df.columns wraps case-sensitive names in double quotes (e.g. '"Region"');
        # compare on the bare name so those are recognised too.
        bare_names = [ac.replace('"', '') for ac in actual_cols]
        col_map = {}

        for old_col, new_col in rename_list:
            if new_col in bare_names:
                # Already renamed (e.g. by a previous run); nothing to do.
                continue
            match = next((ac for ac in actual_cols if ac.replace('"', '').upper() == old_col.upper()), None)
            if match is None:
                print(f"⚠️ Column '{old_col}' not found in {table_name}")
            else:
                col_map[match] = new_col

        if not col_map:
            print(f"ℹ️ No columns to rename in {table_name}; table left unchanged.")
            continue

        for src, tgt in col_map.items():
            df = df.with_column_renamed(src, f'"{tgt}"')  # Quote to preserve casing

        df.write.mode("overwrite").save_as_table(table_name)
        print(f"✅ Updated: {table_name}")


def rename_pbi_tables(session, retention_levels):
    """
    Rename "Cohort Date" -> "Cohort_Date" and WINBACKTAG -> "WinbackTag" in the PBI retention tables.

    For every retention level, the table ``pbi_retention_output_path`` + level
    suffix is inspected. Source columns are matched case-insensitively with
    surrounding double quotes ignored, and renamed to quoted, case-preserving
    targets. A target that already exists is left alone, and a table is
    rewritten (overwrite) only when at least one column actually changes.

    Parameters:
    session: Snowpark session.
    retention_levels: Retention level names; unknown levels are reported and skipped.
    """
    # Loop-invariant lookups, built once.
    suffix_by_level = {
        "Customer_level": "_C_NOTEBOOK",
        "Customer_Product_level": "_CP",
        "Customer_Product_RetentionType_level": "_CPR"
    }
    # Rename: "Cohort Date" → "Cohort_Date", WINBACKTAG → WinbackTag
    rename_targets = [("Cohort Date", "Cohort_Date"), ("WINBACKTAG", "WinbackTag")]

    for level in retention_levels:
        print(f"\n🔄 Processing PBI table for: {level}")

        suffix = suffix_by_level.get(level)
        if not suffix:
            print(f"⚠️ Unknown retention level: {level}")
            continue

        pbi_table = pbi_retention_output_path + suffix
        df = session.table(pbi_table)
        actual_cols = df.columns
        # df.columns wraps case-sensitive names in double quotes (e.g. '"Cohort Date"').
        bare_names = [ac.replace('"', '') for ac in actual_cols]
        col_map = {}

        for old_col, new_col in rename_targets:
            if new_col in bare_names:
                # Already renamed (e.g. by a previous run); nothing to do.
                continue
            match = next((ac for ac in actual_cols if ac.replace('"', '').upper() == old_col.upper()), None)
            if match is None:
                print(f"⚠️ Column '{old_col}' not found in {pbi_table}, skipping.")
            else:
                col_map[match] = new_col

        if not col_map:
            print(f"ℹ️ No columns to rename in {pbi_table}; table left unchanged.")
            continue

        for src, tgt in col_map.items():
            df = df.with_column_renamed(src, f'"{tgt}"')  # Use quoted column names to preserve case

        df.write.mode("overwrite").save_as_table(pbi_table)
        print(f"✅ Updated PBI table: {pbi_table}")

# Flow block
**Retention Flow (retention_flow())**

`Purpose: Builds retention tables based on the given level.`

Steps:

- Loads input data via data_loading().

- Runs retention_pipeline_v2() with session, lookbacks, and config.

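- Caches the result. For the Customer_Product and Customer_Product_RetentionType levels, it applies transform_retention_data() when input_table_templogic is set. It then saves the result to pbi_path.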

- Returns: Processed retention DataFrame.

**Fact Table Flow (fact_table_flow())**

`Purpose: Builds a summarized fact table using highest priority retention level.`

Steps:

- Determines the highest-level granularity from retention_levels.

- Loads retention table and base fact data.

- Filters retention data for selected amount/period.

- Aggregates and joins with fact data on customer and product keys.

- Calculates CM (Cohort Month; currently computed as the number of days between `Cohort Date` and `Date`) and drops redundant columns.

- Saves final fact table to fact_path.

**Databook Table Flow (databook_table_flow())**

`Purpose: Creates an Excel-aligned grouped version of retention data.`

Steps:

- Groups by key attributes (CustName, Product, Period, etc.).

- Aggregates financial metrics (EoP$, BoP$, YoY Variance, etc.).

- Truncates date to month granularity.

- Saves grouped results to the Excel output path.

In [None]:
def retention_flow(session: Session, retention_level: str, pbi_path: str, input_file_path=input_file_path) -> DataFrame:
    """
    Build, cache and persist the retention table for one retention level.

    The input is loaded with ``data_loading(type="retention")`` and run through
    ``retention_pipeline_v2``. The result is cached. For the
    Customer_Product and Customer_Product_RetentionType levels it is then
    post-processed with ``transform_retention_data``, but only when
    ``input_table_templogic`` is set. Finally it is saved to ``pbi_path``.

    Parameters:
    session (Session): The Snowpark session object.
    retention_level (str): The level at which retention is calculated.
    pbi_path (str): Table path where the resulting retention table is saved.
    input_file_path (str): Input table to load; defaults to the notebook-level
        ``input_file_path`` value captured when this function was defined.

    Returns:
    DataFrame: The cached (and possibly transformed) retention DataFrame.
    """
    started = time.time()
    print(f"\n_Building Retention Table Started at {retention_level}.")

    source_df = data_loading(
        session,
        input_file_path,
        column_mapping_file,
        type="retention",
        input_amount=input_amount,
        retention_level=retention_level,
    )

    # Cache before saving so the pipeline is not recomputed by later actions.
    retention_df = retention_pipeline_v2(
        session,
        source_df,
        lookback_list,
        input_amount,
        run_at_levels,
    ).cache_result()

    needs_template_logic = retention_level in (
        "Customer_Product_level",
        "Customer_Product_RetentionType_level",
    )
    if input_table_templogic is not None and needs_template_logic:
        retention_df = transform_retention_data(session, retention_df, input_table_templogic, retention_level)

    save_results(retention_df, pbi_path)
    print(f"_Building Retention Table Completed at {retention_level}.")
    finished = time.time()
    print(f"⏱️ Retention Flow Runtime: {finished - started:.2f} seconds")

    return retention_df


def fact_table_flow(session: Session, fact_path: str, input_file_path=input_file_path) -> None:
    """
    Join the highest-granularity retention table back onto the fact data and save it.

    The most granular level in ``retention_levels`` is chosen (ranked by the
    map below). Its PBI retention table is filtered to the first
    ``run_at_levels`` amount and the period implied by ``lookback_list[0]``,
    summed per customer/product/period key, and left-joined... more precisely,
    joined (default join type) onto the fact data. A CM column is added, and the
    result is saved to ``fact_path``.

    Parameters:
    session (Session): The Snowpark session object.
    fact_path (str): Table path where the resulting fact table is saved.
    input_file_path (str): Input table to load; defaults to the notebook-level
        ``input_file_path`` value captured when this function was defined.

    Returns:
    None
    """
    t0 = time.time()
    print("\n_Building Fact Table Started...")

    granularity = {
        "Customer_level": 1,
        "Customer_Product_level": 2,
        "Customer_Product_RetentionType_level": 3,
        "Level4": 4,
    }
    # First level with the highest rank wins; unknown levels rank 0 and an
    # empty retention_levels yields None.
    highest_level = max(retention_levels, key=lambda lvl: granularity.get(lvl, 0), default=None)
    print(f"Manual approach result: {highest_level}")

    retention_result = session.table(get_file_path(highest_level, pbi_retention_output_path))

    fact_df = data_loading(
        session,
        input_file_path,
        column_mapping_file,
        type="fact",
        retention_level=highest_level,
    )

    unique_amount = run_at_levels[0]

    # Lookback in months -> period label; anything other than 1 or 3 maps to Year.
    lb_value = lookback_list[0]
    if lb_value == 1:
        unique_period = "Month"
    elif lb_value == 3:
        unique_period = "Quarter"
    else:
        unique_period = "Year"

    selected = retention_result.filter(
        (col('"Amount"') == unique_amount) & (col('"Period"') == unique_period)
    )

    grouping = [
        col('"CustName"'),
        col('"Cohort Date"'),
        col('"Current Period"'),
        col('"Product"'),
        col('"Revenue Type"'),
        col("LEVEL4"),
        col("Account Size"),
    ]
    metrics = [
        count(col('"CustName"')).alias("Count of rows"),
        sum(col("T3M")).alias("T3M"),
        sum(col("TTM")).alias("TTM"),
        sum(col("MRR")).alias("MRR"),
        sum(col("ARR")).alias("ARR"),
        sum(col('"T3M (Annualized)"')).alias('"T3M (Annualized)"'),
    ]
    result_short = selected.group_by(*grouping).agg(*metrics).drop("Count of rows")

    # NOTE(review): result_short is also grouped by "Cohort Date" and
    # "Account Size", which are not join keys; if either varies within a join
    # key, fact rows will be duplicated — confirm they are constant per key.
    join_condition = (
        (fact_df['"CustomerID"'] == result_short['"CustName"'])
        & (fact_df['"Date"'] == result_short['"Current Period"'])
        & (fact_df['"Product"'] == result_short['"Product"'])
        & (fact_df['"Revenue Type"'] == result_short['"Revenue Type"'])
        & (fact_df['"LEVEL4"'] == result_short['"LEVEL4"'])
    )
    # Overlapping fact_df columns get the "_left" suffix; only Product_left and
    # Revenue Type_left are dropped below, so other overlaps (e.g. LEVEL4_left)
    # remain in the output.
    joined = fact_df.join(result_short, join_condition, lsuffix="_left")

    # NOTE(review): CM is a day difference although the name suggests
    # "Cohort Month" — confirm this is intended.
    joined = joined.with_column("CM", datediff("day", col('"Cohort Date"'), col('"Date"')))

    joined = joined.drop(
        ['"CustName"', '"Current Period"', '"Product_left"', '"Revenue Type_left"']
    )

    save_results(joined, fact_path)

    t1 = time.time()

    print("_Building Fact Table Completed.")
    print(f"⏱️ Fact Table Flow Runtime: {t1 - t0:.2f} seconds")


def databook_table_flow(result: DataFrame, excel_path: str) -> None:
    """
    Aggregate a retention table to the Excel databook grain and save it.

    Rows are grouped by the ten key columns below. EoP$, BoP$, YoY Variance,
    UFR Amount and Net Credit are summed. "Current Period" is then truncated to
    the month, and the result is saved to ``excel_path``.

    Parameters:
    result (DataFrame): Retention DataFrame to aggregate.
    excel_path (str): Table path for the databook output.

    Returns:
    None
    """
    began = time.time()
    print("_Building Databook Table Started...")

    group_keys = [
        '"Current Period"',
        '"CustName"',
        '"Product"',
        '"Revenue Type"',
        '"RetentionCategory"',
        '"Period"',
        '"Cohort Date"',
        '"Amount"',
        '"UFR Tag"',
        '"Account Size"',
    ]
    summed_metrics = ['"EoP$"', '"BoP$"', '"YoY Variance"', '"UFR Amount"', '"Net Credit"']

    grouped = result.group_by(group_keys).agg(
        *[sum(metric).alias(metric) for metric in summed_metrics]
    )

    # NOTE(review): truncation happens after grouping, so if "Current Period"
    # is not already month-start, several rows per month can remain — confirm.
    grouped = grouped.with_column(
        "Current Period", date_trunc("MONTH", col("Current Period"))
    )

    # The full retention table GROUPED to align with the Excel template.
    print("Saving Results...")
    save_results(grouped, excel_path)
    print("_Building Databook Table Completed.")
    done = time.time()
    print(f"⏱️ Databook Flow Runtime: {done - began:.2f} seconds")

# Optimized_main_with_mapping_bundling
This pipeline prepares, analyzes, and summarizes customer data to support retention reporting and revenue insights.

**Step 0: Standardize Data**

`standardize_customer_fields()`
Cleans and maps raw customer data using column mapping.

**Step 0.5: Round Float Columns**

`round_float_columns()`
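Rounds all float/decimal columns of the mapped input to 3 decimal places and saves them to a separate `_ROUNDED` table. The steps below currently read the un-rounded `_ARR_Mapped` table.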

**Step 1: Retention Analysis**

`retention_flow() (via retention_worker)`
Runs retention logic at multiple levels (Customer, Customer_Product, CPR).
Outputs to Power BI and Excel formats.
Parallelized using ThreadPoolExecutor.

**Step 2: Fact Table**

`fact_table_flow()`
Aggregates retention data into a summarized fact table for dashboarding. It then builds the `_DIMENSION_DIM` and `_DATE_DIM` tables from that fact table.

**Step 3: Product Bundle Analysis**

`product_bundle_analysis()`
Derives ARR/MRR, cross-sell, and bundle size insights.

In [None]:
# Original start timing code
# start_time is read again at the end of this cell to report total runtime.
start_time = time.time()

# Add standardization step first - before any retention processing
print("\n" + "="*50)
print("STEP 0: STANDARDIZING CUSTOMER FIELDS")
print("="*50)
standardized_df, affected_customers_report = standardize_customer_fields(session, input_file_path, column_mapping_file)

# If you want to use the standardized data for all subsequent operations,
# update the input path to the processed output
# NOTE(review): assumes standardize_customer_fields writes its output to
# input_file_path + "_ARR_Mapped" — confirm against that function.
processed_input_path = input_file_path + "_ARR_Mapped"

# Round all float columns to 3 decimal places
print("\n" + "="*50)
print("STEP 0.5: ROUNDING FLOAT COLUMNS TO 3 DECIMAL PLACES")
print("="*50)
# Use the properly implemented Snowpark decimal function
# NOTE(review): round_float_columns saves to processed_input_path + "_ROUNDED",
# but the retention and fact-table steps below are passed processed_input_path,
# so the rounded table (and rounded_df) is not consumed downstream — confirm
# whether those steps should read the _ROUNDED table instead.
rounded_df = round_float_columns(session, processed_input_path, decimal_places=3)

# Now proceed with the original workflow but using the standardized data
print("\n" + "="*50)
print("STEP 1: RUNNING RETENTION ANALYSIS")
print("="*50)

def retention_worker(retention_level):
    """
    Run the retention flow and the databook flow for one retention level.

    Submitted as a ThreadPoolExecutor task. Output tables are
    ``pbi_retention_output_path`` and ``excel_retention_output_path`` plus a
    level-specific suffix, and the input is ``processed_input_path``.

    Returns:
    The retention level on success, or None (after printing a message) for an
    unrecognized level.
    """
    print("Running for retention level:", retention_level)

    # level -> (PBI table suffix, Excel/databook table suffix)
    suffixes_by_level = {
        "Customer_level": ("_C_NOTEBOOK", "_C"),
        "Customer_Product_level": ("_CP", "_CP"),
        "Customer_Product_RetentionType_level": ("_CPR", "_CPR"),
    }
    if retention_level not in suffixes_by_level:
        print(f"Unknown retention level: {retention_level}")
        return None

    pbi_suffix, excel_suffix = suffixes_by_level[retention_level]
    pbi_path = pbi_retention_output_path + pbi_suffix
    excel_path = excel_retention_output_path + excel_suffix

    # Read the standardized (_ARR_Mapped) input instead of the raw input.
    retention_result = retention_flow(session, retention_level, pbi_path, input_file_path=processed_input_path)
    print("-" * 50)
    databook_table_flow(retention_result, excel_path)
    print("=" * 50)
    return retention_level

# Execute in parallel
# Retention levels whose worker raised; reported once the pool has finished.
failed_levels = []
with ThreadPoolExecutor(max_workers=3) as executor:
    futures = {executor.submit(retention_worker, level): level for level in retention_levels}
    for future in as_completed(futures):
        try:
            result = future.result()
        except Exception as e:
            print(f"Error processing {futures[future]}: {e}")
            # future.result() re-raises the worker's exception, so this prints
            # the real failure location (same handling as the bundle step below).
            traceback.print_exc()
            failed_levels.append(futures[future])

if failed_levels:
    # fact_table_flow and rename_pbi_tables read the PBI tables by name, so a
    # failed level can leave them reading a table from an earlier run.
    print(f"⚠️ Retention failed for: {', '.join(map(str, failed_levels))}. "
          "Tables built below may use stale or missing retention outputs.")

# Run fact table flow after all retention levels
print("\n" + "="*50)
print("STEP 2.1: GENERATING FACT TABLE")
print("="*50)
fact_table_flow(session, fact_table_output_path, input_file_path=processed_input_path)

# DIMENSION_DIM table creation
print("\n" + "="*50)
print("STEP 2.2: GENERATING DIMENSION_DIM TABLE")
output_dimension = fact_table_output_path + "_DIMENSION_DIM"
result_df = create_master_dimension_table(session, fact_table_output_path, output_dimension)
print("="*50)

# DATE_DIM table creation
print("\n" + "="*50)
print("STEP 2.3: GENERATING DATE_DIM TABLE")
output_date = fact_table_output_path + "_DATE_DIM"
result_df = create_date_dimension_table(session, fact_table_output_path, output_date, row_count=10000)
print("="*50)

#Add product bundle analysis
# The timer below measures the whole product_bundle_analysis call.
cache_s_time = time.time()
try:
    product_bundle_result = product_bundle_analysis(session)
except Exception as e:
    print(f"Error in product bundle analysis step: {str(e)}")
    traceback.print_exc()
    product_bundle_result = None
cache_e_time = time.time()
print(f"⏱️ Caching bundling: {cache_e_time - cache_s_time:.2f} seconds")

print("-" * 50)

# Rename columns
rename_fixed_tables(session)
rename_pbi_tables(session, retention_levels)

# Print execution time
end_time = time.time()
elapsed_seconds = end_time - start_time
hours = int(elapsed_seconds // 3600)
minutes = int((elapsed_seconds % 3600) // 60)
seconds = int(elapsed_seconds % 60)
print(f"Total execution time: {hours} hours {minutes} minutes {seconds} seconds")

# QC_Mechanism
Generates two kinds of outputs:  
(1) `From the input table` — two columns are generated: `Current Period` and `Total_Value`.  
(2) `From the PBI outputs` (filtered to `input_amount`) — four columns are generated: `Current Period`, `Period`, `SUM_EOP$` and `SUM_BOP$`.  
**Note** — filter the second output by `Period`, sort it by `Current Period`, and then compare it with the first output.

In [None]:
# Read the input table
# QC part 1: per-period totals of the RAW input (before any retention logic).
print("QC table generation started for input file...")
input_df = session.table(input_file_path)

# Sum the mapped VALUE column per mapped CURRENTPERIOD column.
# NOTE(review): order_by does not guarantee the row order of the saved table; sort when reading it.
monthly_arr_summary = input_df.group_by(col(column_mapping_file["CURRENTPERIOD"])) \
                             .agg(snowpark_sum(col(column_mapping_file["VALUE"])).alias("TOTAL_VALUE")) \
                             .order_by(col(column_mapping_file["CURRENTPERIOD"]))

output_qc= input_file_path + "_QC"
# Save the output to a new table
monthly_arr_summary.write.mode('overwrite').save_as_table(output_qc)

print(f"✅ Monthly value summary table created and saved to {output_qc} for raw input file.")
print("="*50)
print("QC table generation started for processed input file...")

# QC part 2: PBI retention tables written by the main cell (one per retention level).
input_files = [
    pbi_retention_output_path + "_C_NOTEBOOK",
    pbi_retention_output_path + "_CP",
    pbi_retention_output_path + "_CPR",
]

# Iterate through each input file
for input_table in input_files:
    try:
        # Attempt to load the table
        df = session.table(input_table)

        # Filter the DataFrame based on input_amount value
        df = df.filter(F.col('"Amount"') == input_amount)

        # Sum EoP$ and BoP$ per ("Current Period", "Period")
        aggregated_df = (
            df.group_by("Current Period", '"Period"')
              .agg(

                  snowflake_sum('"EoP$"').alias("Sum_EoP$"),
                  snowflake_sum('"BoP$"').alias("Sum_BoP$")
              )
        )

        # Define output table name
        output_table = input_table + "_QC"

        # Save the result as a new table (overwrite if exists)
        aggregated_df.write.mode("overwrite").save_as_table(output_table)

        print(f"✅ Output saved with EoP$, BoP$ and Current Period to: {output_table} ")

    except SnowparkSQLException as e:
        # Tables that cannot be queried (e.g. a level that was not run) are reported and skipped.
        print(f"⚠️ Skipping unavailable file: {input_table} (Reason: {e.message})")
print("="*50)
print("You can now compare the outputs of each file to verify the results for input and output files.")

# Retention QA & PoP Validation
`Input`
- Load retention table (e.g., "_C_NOTEBOOK")

- Extract distinct Period & Amount values

`Loop Logic`

- Filter data

- all_checks() – Validates QA columns vs. thresholds

- period_on_period_check_v2() – Compares BoP vs. lagged EoP (the call is currently commented out in the loop)

`Functions`

- all_checks(df, cols, thresholds)
→ Marks Pass/Fail, shows failed rows

- period_on_period_check_v2(period, df)
→ Ensures EoP → BoP continuity

In [None]:
# User inputs: choose which PBI retention table to validate.
# pbi_retention_output_path is extended in place with the level suffix (later
# cells may rely on that). The endswith guard keeps re-runs of this cell from
# appending the suffix twice (e.g. ..._C_NOTEBOOK_C_NOTEBOOK).
# NOTE(review): after this cell, re-running the main/QC cells would use the
# suffixed path — reset pbi_retention_output_path first.
qa_level_suffix = "_C_NOTEBOOK"  # Customer_level
# qa_level_suffix = "_CP"  # Customer_Product_level
# qa_level_suffix = "_CPR"  # Customer_Product_RetentionType_level
if not pbi_retention_output_path.endswith(qa_level_suffix):
    pbi_retention_output_path = pbi_retention_output_path + qa_level_suffix

retention_result = session.table(pbi_retention_output_path)
# Distinct Period / Amount values drive the QA loop below (rows; value at [0]).
periods = retention_result.select('"Period"').distinct().collect()
amounts = retention_result.select('"Amount"').distinct().collect()

def period_on_period_check_v2(period, filtered_df):
    """
    Check that each period's opening balance equals the previous period's closing balance.

    ``filtered_df`` (already limited to one Period/Amount) is aggregated per
    period key, summing EoP$ and BoP$. The key is the "Current Period" date for
    'Month' and its calendar year for 'Year'. Each key's sum_BoP$ is compared
    with the previous key's sum_EoP$ (LAG in key order). The earliest key has no
    predecessor, so it always fails the comparison. The check therefore passes
    when exactly one row fails.

    Parameters:
    period: Row/sequence whose first element is the period name ('Month' or 'Year').
    filtered_df (DataFrame): Retention rows for a single Period/Amount.

    Returns:
    None. Prints the outcome and, on failure, a snapshot of the comparison.
    Unsupported periods (e.g. 'Quarter') are reported and skipped.
    """
    period_name = period[0]
    if period_name == 'Month':
        key_expr = col("Current Period")
    elif period_name == 'Year':
        key_expr = F.year(col("Current Period"))
    else:
        # Previously an unsupported period left the key unbound (UnboundLocalError).
        print("Period on Period Check", "-->", f"Skipped: unsupported period '{period_name}'")
        print("=" * 50)
        return

    # Materialise the key as a named column: after aggregation the raw
    # expression (e.g. YEAR("Current Period")) references a column that no
    # longer exists, so the window must order by the aliased key instead.
    key_name = "PERIOD_KEY"
    result_df = (
        filtered_df.select(key_expr.alias(key_name), col('"EoP$"'), col('"BoP$"'))
        .group_by(col(key_name))
        .agg(
            sum(col('"EoP$"')).alias("sum_EoP$"),
            sum(col('"BoP$"')).alias("sum_BoP$")
        )
    )
    window_spec = Window.order_by(col(key_name))

    df_with_lag = result_df.with_column("sum_EoP$_pr", lag(col("sum_EoP$"), 1).over(window_spec))

    df_final = df_with_lag.with_column(
        "Check_value",
        when(col("sum_BoP$") == col("sum_EoP$_pr"), 'True').otherwise('False')
    ).sort(col(key_name))

    # Check_value holds the strings 'True'/'False', so compare with the string.
    chk_result = df_final.filter(col("Check_value") != 'True')

    # Exactly one failing row means only the first period (no predecessor) failed.
    if chk_result.count() == 1:
        print("Period on Period Check", "-->", "Check Passed!")
        print("=" * 50)
    else:
        print("Period on Period Check", "-->", "Check Failed!")
        print("=" * 50)
        print("Snapshot view")
        df_final.show()
        
# Period -> Amount pairing used to decide where the period-on-period check runs.
# Currently referenced only by the commented-out block in the QA loop.
period_on_period_mapping={
    "Year":"ARR",
    "Month":"MRR",
    # "Quarter":"T3M"
}

def all_checks(tmp, qa_columns, qa_check_thresholds):
    """
    Run verify_qa_check for each QA column and print Pass/Fail per column.

    A column passes when verify_qa_check returns no rows, detected with a cheap
    LIMIT 1 probe. Otherwise it fails and the offending rows are shown. If
    running the check raises (e.g. the column is missing from the retention
    result), a hint is printed and the column is skipped.

    Parameters:
    tmp (DataFrame): Retention rows for one Period/Amount combination.
    qa_columns: Column names to check one at a time.
    qa_check_thresholds: Thresholds forwarded to verify_qa_check.

    Returns:
    None
    """
    for qa_column in qa_columns:
        check_df = verify_qa_check(tmp, [qa_column], qa_check_thresholds)
        # Try Catch block handles failures surfacing when the filters built in
        # verify_qa_check are executed (e.g. a column that does not exist).
        try:
            has_failures = len(check_df.limit(1).collect()) > 0
        except Exception:
            print(qa_column,"-->", "Check if column exists in retention result: Specify correct coulumn in Input!")
            continue

        if not has_failures:
            print(qa_column,"-->", "Check Passed!")
            continue

        # At least one offending row exists; a separate count() query could only confirm that.
        print(qa_column,"-->", "Check Failed!")
        print("Snapshot view")
        check_df.show()
    print("-"*50)

for amount in amounts: 
    for period in periods:
        print(f"Running for amount: {amount[0]} and period: {period[0]}")
        print("="*50)
        tmp = retention_result.filter((col('"Period"') == period[0]) & (col('"Amount"')== amount[0]))
        
        all_checks(tmp, qa_columns, qa_check_thresholds)
        
        # if period_on_period_mapping.get(period[0]) == amount[0]:
        #     print("Period on period check!", period[0], amount[0])
        #     tmp = retention_result.filter((col('"Period"') == period[0]) & (col('"Amount"')== amount[0]))
        #     period_on_period_check_v2(period, tmp)