In [None]:
import logging
from pyspark.sql import functions as F
from pyspark.sql.types import DoubleType

# Logger named after this module; its records propagate to the root logger.
logger = logging.getLogger(__name__)
# Ask the root logger to emit INFO and above. basicConfig() does nothing when
# the root logger already has handlers installed by the host environment.
logging.basicConfig(level=logging.INFO)

# ETL cell: loads the genai_demo.cardinal_health source tables, joins them into
# one hospital/channel data set, projects growth rate, investment and revenue
# for target years 2024-2026, and publishes the result to
# genai_demo.cardinal_health.sales_prediction_output.
# `spark` is not defined here; it is expected to be the session supplied by the
# notebook/runtime.
try:
    # ---- Load data from Unity Catalog tables --------------------------------
    hospital_stats_df = spark.table("genai_demo.cardinal_health.hospital_stats_north_america")
    sales_assignments_df = spark.table("genai_demo.cardinal_health.hospital_sales_assignments")
    employment_details_df = spark.table("genai_demo.cardinal_health.sales_associates_employment_details")
    compensation_guidelines_df = spark.table("genai_demo.cardinal_health.compensation_guidelines")
    logistics_channels_df = spark.table("genai_demo.cardinal_health.logistics_channels")
    growth_opportunities_df = spark.table("genai_demo.cardinal_health.growth_opportunities")
    company_goals_df = spark.table("genai_demo.cardinal_health.company_goals")
    historical_sales_df = spark.table("genai_demo.cardinal_health.historical_sales")
    third_party_trends_df = spark.table("genai_demo.cardinal_health.third_party_sales_trends")

    # ---- Associates and their compensation per hospital ---------------------
    # Hospital statistics joined to the associates assigned to each hospital.
    hospital_sales_df = hospital_stats_df.join(
        sales_assignments_df,
        on=["Hospital_ID", "Hospital_Name"],
        how="inner",
    )

    # Employment details joined to the compensation guidelines per associate.
    employment_compensation_df = employment_details_df.join(
        compensation_guidelines_df,
        on="Associate_ID",
        how="inner",
    )

    # Combine the two results above on the associate identity.
    combined_df = hospital_sales_df.join(
        employment_compensation_df,
        on=["Associate_ID", "Associate_Name"],
        how="inner",
    )

    # Total compensation = base salary + commission on base salary + bonus.
    # Commission_Percentage is used as a direct multiplier (not divided by 100).
    # NOTE(review): presumably it is stored as a fraction; confirm against the
    # compensation_guidelines table.
    base_salary = F.col("Base_Salary")
    combined_df = combined_df.withColumn(
        "Compensation",
        base_salary + (F.col("Commission_Percentage") * base_salary) + F.col("Bonus"),
    )

    # Keep only the fields needed downstream.
    selected_df = combined_df.select(
        "Associate_ID", "Associate_Name", "Compensation", "Director_Name", "Hospital_ID", "Manager_Name"
    )

    # ---- Channel / growth / goal enrichment ---------------------------------
    logistics_growth_df = logistics_channels_df.join(
        growth_opportunities_df,
        on=["Channel_ID", "Channel_Type", "Hospital_ID"],
        how="inner",
    )

    enriched_df = selected_df.join(
        logistics_growth_df,
        on="Hospital_ID",
        how="inner",
    )

    goals_df = enriched_df.join(
        company_goals_df,
        on=["Hospital_ID", "Channel_Type"],
        how="inner",
    )

    # One row per (hospital, channel type, growth rate, planned investment).
    # dropDuplicates on a column subset does not specify WHICH duplicate row is
    # kept, so the surviving values of the other columns are not deterministic.
    unique_goals_df = goals_df.dropDuplicates(
        ["Hospital_ID", "Channel_Type", "Projected_Growth_Rate", "Investment_Planned"]
    )

    # ---- Historical sales and third-party trends ----------------------------
    # NOTE(review): joined on Channel_Type only; if either side has several
    # rows per channel type this fans out. Confirm that is intended.
    sales_trends_df = historical_sales_df.join(
        third_party_trends_df,
        on="Channel_Type",
        how="inner",
    )

    # One row per (year, channel type, revenue); same caveat as above about
    # which duplicate survives.
    unique_sales_trends_df = sales_trends_df.dropDuplicates(["Year", "Channel_Type", "Sales_Revenue"])

    # Bring sales trends and goals together.
    # NOTE(review): assumes the sales-trends side exposes Hospital_ID and
    # Channel_ID; the source schemas are not visible here, so verify.
    final_df = unique_sales_trends_df.join(
        unique_goals_df,
        on=["Hospital_ID", "Channel_ID", "Channel_Type"],
        how="inner",
    )

    # ---- Projections per target year ----------------------------------------
    # Replicate every row once per year 2023..2026; 2023 is filtered out below.
    target_years_df = final_df.withColumn(
        "Target Year",
        F.explode(F.array([F.lit(year) for year in range(2023, 2027)])),
    )

    target_year = F.col("Target Year")

    # Growth rate: the base rate plus 1% of the base rate for each year after
    # 2023. Expressions are chained so the tree matches
    # ((rate + step) + step) + step exactly.
    base_rate = F.col("Projected_Growth_Rate")
    rate_step = base_rate / 100
    rate_2024 = base_rate + rate_step
    rate_2025 = rate_2024 + rate_step
    rate_2026 = rate_2025 + rate_step
    target_years_df = target_years_df.withColumn(
        "projected_sales_growth_rate",
        F.when(target_year == 2024, rate_2024)
         .when(target_year == 2025, rate_2025)
         .when(target_year == 2026, rate_2026)
         .otherwise(base_rate),
    )

    # NOTE(review): for 2024 the amounts below are multiplied by rate/100 (the
    # increment only), whereas 2025 and 2026 use 1 + rate/100 (the grown
    # total). Kept as-is; confirm this asymmetry is intended.
    growth_fraction = F.col("projected_sales_growth_rate") / 100
    growth_multiplier = 1 + F.col("projected_sales_growth_rate") / 100

    investment_planned = F.col("Investment_Planned")
    target_years_df = target_years_df.withColumn(
        "projected_investments",
        F.when(target_year == 2024, investment_planned * growth_fraction)
         .when(target_year == 2025, investment_planned * growth_multiplier)
         .when(target_year == 2026, investment_planned * growth_multiplier)
         .otherwise(investment_planned),
    )

    sales_revenue = F.col("Sales_Revenue")
    target_years_df = target_years_df.withColumn(
        "Projected Revenue",
        F.when(target_year == 2024, sales_revenue * growth_fraction)
         .when(target_year == 2025, sales_revenue * growth_multiplier)
         .when(target_year == 2026, sales_revenue * growth_multiplier)
         .otherwise(sales_revenue),
    )

    # Keep only the projected years (drops the 2023 rows).
    filtered_df = target_years_df.filter(target_year > 2023)

    # ---- Shape and publish the output ---------------------------------------
    # NOTE(review): "Target Year" and "Projected Revenue" contain spaces. Delta
    # tables reject such column names unless column mapping is enabled; confirm
    # the target table supports it. Names are left unchanged so that downstream
    # consumers are not broken.
    output_df = filtered_df.select(
        "Channel_Type", "Hospital_ID", "Market_Trend", "Political_Impact", "Economic_Impact",
        "Target Year", "projected_sales_growth_rate", "projected_investments", "Projected Revenue"
    )

    sorted_output_df = output_df.orderBy("Target Year")

    # Overwrite the table in place instead of DROP + recreate: the swap is
    # atomic, the previous version stays readable if the write fails, and table
    # history and grants are kept. overwriteSchema keeps the old behaviour of
    # accepting a changed output schema.
    (
        sorted_output_df.write.format("delta")
        .mode("overwrite")
        .option("overwriteSchema", "true")
        .saveAsTable("genai_demo.cardinal_health.sales_prediction_output")
    )

except Exception as e:
    # Record the full traceback, then re-raise so the notebook/job run is
    # reported as failed instead of silently succeeding with no output.
    logger.exception("An error occurred during the ETL process: %s", e)
    raise
else:
    logger.info("ETL process completed successfully.")
