In [None]:
import logging
import requests
from pyspark.sql.functions import col, count, avg, max, datediff, current_date, when, lit
from pyspark.sql.types import IntegerType, DecimalType

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Function to fetch per-customer scores from a REST API
def fetch_api_data(api_url, customer_ids, api_key, timeout=10):
    """Fetch the ``score`` field for each customer from a REST endpoint.

    The lookup is best-effort: a customer whose request raises, returns a
    non-200 status, or yields a body that cannot be read is logged and
    skipped, so the result may contain fewer entries than ``customer_ids``.

    Args:
        api_url: URL template containing a ``{Customer_ID}`` placeholder.
        customer_ids: Iterable of customer identifiers to look up.
        api_key: Token sent as ``Authorization: Bearer <api_key>``.
        timeout: Timeout handed to ``requests`` for every call (seconds, or
            a ``(connect, read)`` tuple). Without one, a single stalled
            endpoint would block the whole job indefinitely.

    Returns:
        A list of ``(customer_id, score)`` tuples, one per successful call.
        ``score`` is ``None`` when the JSON body has no ``score`` key.
    """
    # Materialise once so generators are supported and emptiness can be tested.
    customer_ids = list(customer_ids)
    results = []
    if not customer_ids:
        return results

    # One Session for the whole batch: the connection is reused across the
    # per-customer calls instead of being re-established every time.
    with requests.Session() as session:
        session.headers.update({'Authorization': f'Bearer {api_key}'})
        for customer_id in customer_ids:
            # Deliberately broad: any failure for one customer (network
            # error, timeout, invalid JSON, unexpected payload shape) must
            # not abort the remaining lookups.
            try:
                response = session.get(
                    api_url.format(Customer_ID=customer_id),
                    timeout=timeout,
                )
                if response.status_code == 200:
                    results.append((customer_id, response.json().get('score')))
                else:
                    logger.error(f"Failed to fetch data for Customer_ID {customer_id}: {response.status_code}")
            except Exception as e:
                logger.error(f"Error fetching data for Customer_ID {customer_id}: {str(e)}")
    return results

# Read the policy, claims and demographics source tables from Unity Catalog.
# The tables are resolved in that order; any failure is logged and re-raised.
try:
    policy_df, claims_df, demographics_df = (
        spark.table(table_name)
        for table_name in (
            "postgresql_catalog.demo.policydb",
            "mysql_catalog.vsco.claimsdb",
            "sqlserver_catalog.dbo.demographicsdb",
        )
    )
except Exception as load_error:
    logger.error("Error loading data from Unity Catalog: %s", str(load_error))
    raise

# Attach customer demographics to each policy row.
# Inner join on Customer_ID: policies without a demographics match are dropped.
try:
    policy_demo_df = policy_df.join(
        demographics_df,
        on="Customer_ID",
        how="inner",
    )
except Exception as join_error:
    logger.error("Error joining policy and demographics data: %s", str(join_error))
    raise

# Attach claims to the policy + demographics rows.
# Inner join on Policy_ID: policies with no matching claim row are dropped.
try:
    policy_claims_df = policy_demo_df.join(
        claims_df,
        on="Policy_ID",
        how="inner",
    )
except Exception as join_error:
    logger.error("Error joining policy_demo and claims data: %s", str(join_error))
    raise

# Aggregate data at the customer level, then calculate derived fields.
#
# Local alias import: the extra aggregate functions used below (sum, first,
# countDistinct, months_between, floor) are not in the file-level import
# list, and `F.sum` / `F.max` avoid shadowing the Python builtins.
from pyspark.sql import functions as F

try:
    agg_df = policy_claims_df.groupBy("Customer_ID").agg(
        F.count("Claim_ID").alias("Total_Claims"),
        F.avg("Claim_Amount").alias("Average_Claim_Amount"),
        F.max("Claim_Date").alias("Recent_Claim_Date"),
        # Distinct policies, not joined rows: the input has one row per
        # claim, so a plain count would equal the claim count and make
        # Claims_Per_Policy meaningless.
        F.countDistinct("Policy_ID").alias("Policy_Count"),
        # The columns below are needed by the derived fields. A groupBy/agg
        # keeps only the grouping key and the aggregates, so they must be
        # carried through explicitly or the next step cannot resolve them.
        F.sum("Claim_Amount").alias("Total_Claim_Amount"),
        # NOTE(review): assumes Date_of_Birth and Total_Premium_Paid hold a
        # single value per Customer_ID (customer-level attributes) - TODO
        # confirm. If Total_Premium_Paid is per policy it must instead be
        # summed over distinct policies.
        F.first("Date_of_Birth", ignorenulls=True).alias("Date_of_Birth"),
        F.first("Total_Premium_Paid", ignorenulls=True).alias("Total_Premium_Paid"),
    )
except Exception as e:
    logger.error(f"Error aggregating data: {str(e)}")
    raise

# Calculate derived fields
try:
    derived_df = (
        agg_df
        # Age in completed years; months_between/12 is calendar-aware,
        # unlike dividing a day count by 365.
        .withColumn(
            "Age",
            F.floor(
                F.months_between(F.current_date(), F.col("Date_of_Birth")) / 12
            ).cast("int"),
        )
        # Guarded division: a zero (or NULL) premium yields 0 instead of
        # a divide-by-zero / NULL ratio.
        .withColumn(
            "Claim_To_Premium_Ratio",
            F.when(
                F.col("Total_Premium_Paid") != 0,
                F.col("Total_Claim_Amount") / F.col("Total_Premium_Paid"),
            ).otherwise(0),
        )
        .withColumn(
            "Claims_Per_Policy",
            F.when(
                F.col("Policy_Count") != 0,
                F.col("Total_Claims") / F.col("Policy_Count"),
            ).otherwise(0),
        )
        # The remaining columns are constants: the same literal on every row.
        .withColumn("Retention_Rate", F.lit(0.85))
        .withColumn("Cross_Sell_Opportunities", F.lit("Multi-Policy Discount, Home Coverage Add-on"))
        .withColumn("Upsell_Potential", F.lit("Premium Vehicle Coverage"))
    )
except Exception as e:
    logger.error(f"Error calculating derived fields: {str(e)}")
    raise

# Fetch fraud and credit scores from the scoring APIs.
#
# NOTE(review): both endpoints are plain http://, so the bearer token is sent
# unencrypted. Switch to https:// if the service supports it.
from pyspark.sql.types import DoubleType, StructField, StructType


def _scores_to_df(scores, score_column):
    """Turn ``fetch_api_data`` output into a (Customer_ID, score) DataFrame.

    An explicit schema is used because schema inference fails when the list
    is empty (every API call failed) or when every score is ``None``.

    Args:
        scores: List of ``(customer_id, score)`` tuples.
        score_column: Name for the score column in the result.

    Returns:
        DataFrame with ``Customer_ID`` (same type as in ``derived_df``) and
        a nullable double ``score_column``.
    """
    schema = StructType([
        StructField("Customer_ID", derived_df.schema["Customer_ID"].dataType, True),
        StructField(score_column, DoubleType(), True),
    ])
    rows = []
    for customer_id, score in scores:
        # NOTE(review): assumes the API's `score` is numeric - TODO confirm.
        # Values that cannot be converted are kept as NULL rather than
        # failing the whole batch.
        try:
            value = None if score is None else float(score)
        except (TypeError, ValueError):
            logger.warning(f"Non-numeric {score_column} for Customer_ID {customer_id}: {score!r}")
            value = None
        rows.append((customer_id, value))
    return spark.createDataFrame(rows, schema)


# Fetch fraud scores from API
try:
    fraud_api_url = "http://18.189.118.116:9010/fraudscore?Customer_ID={Customer_ID}"
    fraud_api_key = dbutils.secrets.get("api_secrets", "fraud_api_key")
    # Collect the IDs once and reuse them for both APIs: each collect re-runs
    # the upstream joins/aggregation. DataFrame.collect() also avoids the
    # RDD API.
    customer_ids = [row["Customer_ID"] for row in derived_df.select("Customer_ID").collect()]
    fraud_scores = fetch_api_data(fraud_api_url, customer_ids, fraud_api_key)
    fraud_df = _scores_to_df(fraud_scores, "Fraud_Score")
except Exception as e:
    logger.error(f"Error fetching fraud scores: {str(e)}")
    raise

# Fetch credit scores from API
try:
    credit_api_url = "http://18.189.118.116:9010/creditscore?Customer_ID={Customer_ID}"
    credit_api_key = dbutils.secrets.get("api_secrets", "credit_api_key")
    credit_scores = fetch_api_data(credit_api_url, customer_ids, credit_api_key)
    credit_df = _scores_to_df(credit_scores, "Credit_Score")
except Exception as e:
    logger.error(f"Error fetching credit scores: {str(e)}")
    raise

# Enrich the customer-level data with the API scores.
# Left joins keep every customer even when a score lookup returned nothing.
try:
    final_df = (
        derived_df
        .join(fraud_df, on="Customer_ID", how="left")
        .join(credit_df, on="Customer_ID", how="left")
    )
except Exception as join_error:
    logger.error("Error joining derived data with API scores: %s", str(join_error))
    raise

# Add AI-driven insights.
# Each column is a constant: the same literal value is written on every row.
try:
    for _insight_column, _insight_value in (
        ("Churn_Probability", 0.25),
        ("Next_Best_Offer", "Additional Life Coverage"),
        ("Claims_Fraud_Probability", 0.10),
        ("Revenue_Potential", 12000.00),
    ):
        final_df = final_df.withColumn(_insight_column, lit(_insight_value))
except Exception as insight_error:
    logger.error("Error adding AI-driven insights: %s", str(insight_error))
    raise

# Write the final DataFrame to Unity Catalog table
try:
    target_table = "genai_demo.cardinal_health.customer_360"
    # Overwrite in place instead of DROP TABLE + recreate. The Delta
    # overwrite is a single atomic commit, so a failed write leaves the
    # previous data readable, and the table's history and grants are kept.
    # overwriteSchema lets the column set change between runs, which is
    # what the former drop-and-recreate allowed.
    # NOTE(review): if an existing table at this name is not a Delta table,
    # it has to be dropped manually once before this write can succeed.
    (
        final_df.write
        .format("delta")
        .mode("overwrite")
        .option("overwriteSchema", "true")
        .saveAsTable(target_table)
    )
    logger.info(f"Data successfully written to Unity Catalog table: {target_table}")
except Exception as e:
    logger.error(f"Error writing data to Unity Catalog: {str(e)}")
    raise
