In [None]:
import logging
from pyspark.sql import functions as F
from pyspark.sql.types import IntegerType, DoubleType

# Set up root logging at INFO level and a logger for this module/cell.
# Per the stdlib docs, basicConfig does nothing if the root logger already
# has handlers attached.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)

try:
    # Sales-prediction ETL. Reads the source tables from Unity Catalog, links
    # them through a chain of inner joins, projects growth rate, investment and
    # revenue for target years 2024-2026, and overwrites the Delta table
    # catalog.db.sales_prediction_output. Uses the `spark` session provided by
    # the notebook environment (it is not created in this cell).
    # NOTE(review): the column names referenced below are assumed to exist in
    # the source tables; their schemas are not visible here — confirm against
    # the catalog.

    # Load data from Unity Catalog tables
    hospital_stats_df = spark.table("catalog.db.hospital_stats_north_america")
    sales_assignments_df = spark.table("catalog.db.hospital_sales_assignments")
    employment_details_df = spark.table("catalog.db.sales_associates_employment_details")
    compensation_guidelines_df = spark.table("catalog.db.compensation_guidelines")
    logistics_channels_df = spark.table("catalog.db.logistics_channels")
    growth_opportunities_df = spark.table("catalog.db.growth_opportunities")
    company_goals_df = spark.table("catalog.db.company_goals")
    historical_sales_df = spark.table("catalog.db.historical_sales")
    third_party_sales_trends_df = spark.table("catalog.db.third_party_sales_trends")

    # Join hospital statistics with sales assignments
    hospital_sales_df = hospital_stats_df.join(
        sales_assignments_df,
        on=["Hospital_ID", "Hospital_Name"],
        how="inner",
    )

    # Join employment details with compensation guidelines
    employment_compensation_df = employment_details_df.join(
        compensation_guidelines_df,
        on="Associate_ID",
        how="inner",
    )

    # Combine the results of previous joins
    combined_df = hospital_sales_df.join(
        employment_compensation_df,
        on=["Associate_ID", "Associate_Name"],
        how="inner",
    )

    # Total compensation = base salary + commission on base salary + bonus.
    # NOTE(review): Commission_Percentage is multiplied directly (no / 100,
    # unlike the growth rates below), so it is presumably stored as a fraction
    # such as 0.05 rather than 5 — confirm.
    combined_df = combined_df.withColumn(
        "Compensation",
        F.col("Base_Salary")
        + (F.col("Commission_Percentage") * F.col("Base_Salary"))
        + F.col("Bonus"),
    )

    # Select relevant fields
    selected_df = combined_df.select(
        "Associate_ID", "Associate_Name", "Compensation", "Director_Name", "Hospital_ID", "Manager_Name"
    )

    # Join logistics channels with growth opportunities
    logistics_growth_df = logistics_channels_df.join(
        growth_opportunities_df,
        on=["Channel_ID", "Channel_Type", "Hospital_ID"],
        how="inner",
    )

    # Join compensation data with logistics and growth data.
    # NOTE(review): only Hospital_ID from the compensation branch survives the
    # select that follows (Compensation and the associate/manager columns are
    # dropped), so that branch acts purely as an inner-join filter on
    # hospitals. It also fans rows out once per associate of a hospital; the
    # dropDuplicates on the goals data below collapses those copies.
    compensation_logistics_growth_df = selected_df.join(
        logistics_growth_df,
        on="Hospital_ID",
        how="inner",
    )

    # Select fields for further processing
    selected_logistics_df = compensation_logistics_growth_df.select(
        "Hospital_ID", "Channel_Type", "Growth_Opportunities", "Projected_Growth_Rate"
    )

    # Join selected data with company goals
    goals_df = selected_logistics_df.join(
        company_goals_df,
        on=["Hospital_ID", "Channel_Type"],
        how="inner",
    )

    # Keep one row per (hospital, channel type, growth rate, planned investment).
    # dropDuplicates keeps an arbitrary row per key, so other columns on the
    # surviving row (e.g. Growth_Opportunities) are not deterministic.
    unique_goals_df = goals_df.dropDuplicates(
        ["Hospital_ID", "Channel_Type", "Projected_Growth_Rate", "Investment_Planned"]
    )

    # Join historical sales data with third-party sales trends
    sales_trends_df = historical_sales_df.join(
        third_party_sales_trends_df,
        on="Channel_Type",
        how="inner",
    )

    # Ensure unique sales records (an arbitrary row is kept per key).
    unique_sales_df = sales_trends_df.dropDuplicates(["Year", "Channel_Type", "Sales_Revenue"])

    # Join unique sales data with growth and investment data.
    # NOTE(review): Channel_ID is not among the columns selected into
    # selected_logistics_df, so this join resolves only if company_goals_df
    # supplies Channel_ID, and the sales side must carry Hospital_ID and
    # Channel_ID as well — confirm against the table schemas.
    final_join_df = unique_sales_df.join(
        unique_goals_df,
        on=["Hospital_ID", "Channel_ID", "Channel_Type"],
        how="inner",
    )

    # One row per target year. spark.range's end is exclusive, so this yields
    # 2023 (the base year, filtered out below) through 2026.
    target_years_df = spark.range(2023, 2027).withColumnRenamed("id", "TargetYear")

    # Projected growth rate: Projected_Growth_Rate plus one hundredth of itself
    # per year after 2023 (one increment for 2024, two for 2025, three for
    # 2026). The `otherwise` branch only matches 2023, which is filtered out.
    final_df = final_join_df.crossJoin(target_years_df).withColumn(
        "projected_sales_growth_rate",
        F.when(F.col("TargetYear") == 2024, F.col("Projected_Growth_Rate") + (F.col("Projected_Growth_Rate") / 100))
         .when(F.col("TargetYear") == 2025, (F.col("Projected_Growth_Rate") + (F.col("Projected_Growth_Rate") / 100)) + (F.col("Projected_Growth_Rate") / 100))
         .when(F.col("TargetYear") == 2026, ((F.col("Projected_Growth_Rate") + (F.col("Projected_Growth_Rate") / 100)) + (F.col("Projected_Growth_Rate") / 100)) + (F.col("Projected_Growth_Rate") / 100))
         .otherwise(F.col("Projected_Growth_Rate"))
    )

    # Projected investments: each target year scales the planned (base)
    # investment by (1 + growth_rate / 100). 2024 uses the same multiplier as
    # 2025/2026; `growth_rate / 100` alone would yield only the growth
    # increment, on a different scale from the other years in this column.
    final_df = final_df.withColumn(
        "projected_investments",
        F.when(F.col("TargetYear") == 2024, F.col("Investment_Planned") * (1 + F.col("projected_sales_growth_rate") / 100))
         .when(F.col("TargetYear") == 2025, F.col("Investment_Planned") * (1 + F.col("projected_sales_growth_rate") / 100))
         .when(F.col("TargetYear") == 2026, F.col("Investment_Planned") * (1 + F.col("projected_sales_growth_rate") / 100))
         .otherwise(F.col("Investment_Planned"))
    )

    # Projected revenue: same multiplier scheme as the investments, applied to
    # the historical Sales_Revenue (again consistent for 2024).
    final_df = final_df.withColumn(
        "Projected_Revenue",
        F.when(F.col("TargetYear") == 2024, F.col("Sales_Revenue") * (1 + F.col("projected_sales_growth_rate") / 100))
         .when(F.col("TargetYear") == 2025, F.col("Sales_Revenue") * (1 + F.col("projected_sales_growth_rate") / 100))
         .when(F.col("TargetYear") == 2026, F.col("Sales_Revenue") * (1 + F.col("projected_sales_growth_rate") / 100))
         .otherwise(F.col("Sales_Revenue"))
    )

    # Drop the 2023 base-year rows; only projected years are written.
    filtered_df = final_df.filter(F.col("TargetYear") > 2023)

    # Select fields for final output
    output_df = filtered_df.select(
        "Channel_Type", "Hospital_ID", "Market_Trend", "Political_Impact", "Economic_Impact",
        "TargetYear", "projected_sales_growth_rate", "projected_investments", "Projected_Revenue"
    )

    # Sort records by target year.
    # NOTE: the stored table does not guarantee this order to later readers.
    sorted_output_df = output_df.orderBy("TargetYear")

    # Replace the target table in a single Delta transaction rather than
    # DROP + create. A separate DROP leaves no table at all if the write then
    # fails, and it discards the table's Delta history. overwriteSchema lets
    # the column set change between runs.
    # NOTE(review): overwrite expects any existing table to be Delta; a table
    # in another format needs a one-time manual DROP.
    (
        sorted_output_df.write.format("delta")
        .mode("overwrite")
        .option("overwriteSchema", "true")
        .saveAsTable("catalog.db.sales_prediction_output")
    )

    logger.info("ETL process completed successfully.")

except Exception as e:
    # Log with the traceback, then re-raise so the notebook/job run fails.
    logger.exception("Error during ETL process: %s", e)
    raise
