In [None]:
import logging
from pyspark.sql import functions as F
from pyspark.sql.types import DoubleType

# Configure logging
# NOTE: logging.basicConfig does nothing when the root logger already has
# handlers, so the INFO level only applies if logging was not configured
# earlier in the session.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# ETL: derive projected sales growth rate, investments and revenue for the
# target years 2024-2026 and persist them to a Delta table.
# Relies on a `spark` session object already being in scope; this script does
# not create one.
try:
    # Load data from Unity Catalog tables
    logger.info("Loading data from Unity Catalog tables...")
    hospital_stats_df = spark.table("genai_demo.cardinal_health.hospital_stats_north_america")
    sales_assignments_df = spark.table("genai_demo.cardinal_health.hospital_sales_assignments")
    employment_details_df = spark.table("genai_demo.cardinal_health.sales_associates_employment_details")
    compensation_guidelines_df = spark.table("genai_demo.cardinal_health.compensation_guidelines")
    logistics_channels_df = spark.table("genai_demo.cardinal_health.logistics_channels")
    growth_opportunities_df = spark.table("genai_demo.cardinal_health.growth_opportunities")
    company_goals_df = spark.table("genai_demo.cardinal_health.company_goals")
    historical_sales_df = spark.table("genai_demo.cardinal_health.historical_sales")
    third_party_trends_df = spark.table("genai_demo.cardinal_health.third_party_sales_trends")

    # NOTE(review): hospital_sales_df, employment_compensation_df and
    # combined_df (built below) are never read by the rest of this script;
    # only sales_trends_df feeds the output. They are kept because other
    # notebook cells may use them -- confirm whether the final select was
    # meant to draw on these joins.

    # Transformation: Join hospital statistics with sales assignments
    logger.info("Joining hospital statistics with sales assignments...")
    hospital_sales_df = hospital_stats_df.join(
        sales_assignments_df,
        on=["Hospital_ID", "Hospital_Name"],
        how="inner",
    )

    # Transformation: Join employment details with compensation guidelines
    logger.info("Joining employment details with compensation guidelines...")
    employment_compensation_df = employment_details_df.join(
        compensation_guidelines_df,
        on="Associate_ID",
        how="inner",
    )

    # Transformation: Calculate total compensation
    # Compensation = base salary + commission (a percentage of base) + bonus.
    logger.info("Calculating total compensation...")
    base_salary = F.col("Base_Salary")
    employment_compensation_df = employment_compensation_df.withColumn(
        "Compensation",
        base_salary + (F.col("Commission_Percentage") / 100) * base_salary + F.col("Bonus"),
    )

    # Transformation: Join logistics channels with growth opportunities
    logger.info("Joining logistics channels with growth opportunities...")
    logistics_growth_df = logistics_channels_df.join(
        growth_opportunities_df,
        on=["Channel_ID", "Channel_Type", "Hospital_ID"],
        how="inner",
    )

    # Transformation: Join with company goals
    logger.info("Joining with company goals...")
    combined_df = logistics_growth_df.join(
        company_goals_df,
        on=["Hospital_ID", "Channel_Type"],
        how="inner",
    )

    # Transformation: Join historical sales with third-party sales trends
    # NOTE(review): the only join key is Channel_Type; if it is not unique on
    # at least one side this is a many-to-many join -- verify against the data.
    logger.info("Joining historical sales with third-party sales trends...")
    sales_trends_df = historical_sales_df.join(
        third_party_trends_df,
        on="Channel_Type",
        how="inner",
    )

    # Transformation: Generate rows for target years
    # Each input row is emitted once per target year. F.array expects column
    # names or Column objects, so the year literals must be wrapped in F.lit;
    # passing bare ints raises before any data is processed.
    logger.info("Generating rows for target years...")
    target_years_df = sales_trends_df.withColumn(
        "TargetYear",
        F.explode(F.array(F.lit(2024), F.lit(2025), F.lit(2026))),
    )

    # Transformation: Calculate projected sales growth rate
    # The rate is the base rate plus one extra (base rate / 100) step per
    # year: one step for 2024, two for 2025, three for 2026.
    logger.info("Calculating projected sales growth rate...")
    target_year = F.col("TargetYear")
    base_rate = F.col("Projected_Growth_Rate")
    rate_step = base_rate / 100
    target_years_df = target_years_df.withColumn(
        "projected_sales_growth_rate",
        F.when(target_year == 2024, base_rate + rate_step)
        .when(target_year == 2025, base_rate + rate_step + rate_step)
        .when(target_year == 2026, base_rate + rate_step + rate_step + rate_step)
        .otherwise(base_rate),
    )

    # projected_sales_growth_rate is treated as a percentage in the two
    # projections below.
    # NOTE(review): for 2024 the base amount is multiplied by the growth
    # fraction alone, while 2025/2026 multiply by (1 + fraction). Confirm this
    # asymmetry is intended; it is preserved here as written.
    growth_fraction = F.col("projected_sales_growth_rate") / 100

    # Transformation: Calculate projected investments
    logger.info("Calculating projected investments...")
    planned_investment = F.col("Investment_Planned")
    target_years_df = target_years_df.withColumn(
        "projected_investments",
        F.when(target_year == 2024, planned_investment * growth_fraction)
        .when(target_year == 2025, planned_investment * (1 + growth_fraction))
        .when(target_year == 2026, planned_investment * (1 + growth_fraction))
        .otherwise(planned_investment),
    )

    # Transformation: Calculate projected revenue
    logger.info("Calculating projected revenue...")
    sales_revenue = F.col("Sales_Revenue")
    target_years_df = target_years_df.withColumn(
        "Projected_Revenue",
        F.when(target_year == 2024, sales_revenue * growth_fraction)
        .when(target_year == 2025, sales_revenue * (1 + growth_fraction))
        .when(target_year == 2026, sales_revenue * (1 + growth_fraction))
        .otherwise(sales_revenue),
    )

    # Filter records where TargetYear > 2023
    # Every exploded year is already above 2023, so with the current year
    # list this filter keeps all rows.
    logger.info("Filtering records where TargetYear > 2023...")
    filtered_df = target_years_df.filter(target_year > 2023)

    # Select specific fields for output
    logger.info("Selecting specific fields for output...")
    final_df = filtered_df.select(
        "Channel_Type", "Hospital_ID", "Market_Trend", "Political_Impact", "Economic_Impact",
        "TargetYear", "projected_sales_growth_rate", "projected_investments", "Projected_Revenue",
    )

    # Sort records by TargetYear in ascending order
    # NOTE(review): row order is presumably not guaranteed once written to
    # the table; readers that need ordering should sort at query time.
    logger.info("Sorting records by TargetYear in ascending order...")
    final_df = final_df.orderBy("TargetYear")

    # Write the processed data to Unity Catalog table
    # Overwrite in a single write instead of DROP TABLE followed by a
    # re-create: dropping first leaves no output table at all if the write
    # then fails. overwriteSchema lets the overwrite replace the schema too,
    # which the previous drop implicitly allowed.
    logger.info("Writing the processed data to Unity Catalog table...")
    (
        final_df.write.format("delta")
        .mode("overwrite")
        .option("overwriteSchema", "true")
        .saveAsTable("genai_demo.cardinal_health.sales_prediction_output")
    )

    logger.info("ETL process completed successfully.")

except Exception as e:
    # Log with traceback, then re-raise so the notebook/job run is marked as
    # failed instead of appearing to succeed.
    logger.exception("An error occurred during the ETL process: %s", e)
    raise
