In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
from pyspark.sql.functions import col, countDistinct, explode, split, trim

# Data lake locations: the Silver Delta table that is read below, and the
# base folder used when building Gold table paths.
silver_path = "/mnt/silver/clinical_trials"
gold_path = "/mnt/gold/"

# Read the Silver Delta table and register it as the temp view
# `clinical_trials_silver` so it can be queried with Spark SQL.
# Any failure is passed to dbutils.notebook.exit together with the error text.
try:
    delta_reader = spark.read.format("delta")
    silver_df = delta_reader.load(silver_path)
    silver_df.createOrReplaceTempView("clinical_trials_silver")
except Exception as e:
    dbutils.notebook.exit(f"Error loading silver data: {e}")
else:
    print("Successfully loaded silver data and created temporary view.")

# Create a cleaned temporary view for aggregation.
# - phases: NULL becomes 'Unknown'. The value is lower-cased and stripped of
#   every non-alphanumeric character before matching, so 'Phase 1', 'PHASE1'
#   and 'phase_1' are all recognised. When several phases are listed, the
#   highest one wins because the CASE branches are tested from Phase 4 down.
#   NOTE(review): the exact spelling of `phases` in Silver is not visible
#   here; the normalisation keeps every previously matched value matching and
#   additionally treats 'NA' as 'Not Applicable' -- confirm against the data.
# - funder_type: NULL becomes 'Unknown'.
# - country: the trimmed text after the last comma of `locations`
#   (a NULL `locations` yields 'Unknown'). ELEMENT_AT(..., -1) picks the last
#   array element without evaluating the SPLIT expression twice.
spark.sql("""
CREATE OR REPLACE TEMPORARY VIEW trials_cleaned AS
WITH silver_data AS (
  SELECT
    *,
    REGEXP_REPLACE(LOWER(COALESCE(phases, 'Unknown')), '[^a-z0-9]', '') AS phases_normalized,
    COALESCE(funder_type, 'Unknown') AS funder_type_cleaned
  FROM clinical_trials_silver
)
SELECT
  nct_number,
  study_title,
  study_status,
  conditions,
  interventions,
  sponsor,
  funder_type_cleaned AS funder_type,
  study_type,
  enrollment,
  start_date,
  completion_date,
  CASE
    WHEN phases_normalized LIKE '%phase4%' THEN 'Phase 4'
    WHEN phases_normalized LIKE '%phase3%' THEN 'Phase 3'
    WHEN phases_normalized LIKE '%phase2%' THEN 'Phase 2'
    WHEN phases_normalized LIKE '%phase1%' THEN 'Phase 1'
    WHEN phases_normalized IN ('notapplicable', 'na') THEN 'Not Applicable'
    ELSE 'Other/Unknown'
  END AS phase_cleaned,
  TRIM(ELEMENT_AT(SPLIT(COALESCE(locations, 'Unknown, Unknown'), ','), -1)) AS country
FROM
  silver_data
""")
print("Successfully created cleaned temporary view 'trials_cleaned'.")

# Build the three single-dimension Gold aggregates (trials per phase, per
# status and per funder type) and persist each one as a Delta table.
def _publish_gold_table(table_df, table_name):
    """Overwrite the Delta table `table_name` under gold_path with `table_df`."""
    target_path = f"{gold_path}{table_name}"
    table_df.write.format("delta").mode("overwrite").save(target_path)
    print(f"Successfully created Gold table: {table_name}")


# Distinct trials per cleaned phase, ordered by phase label.
trials_by_phase = spark.sql("""
    SELECT phase_cleaned, COUNT(DISTINCT nct_number) as number_of_trials
    FROM trials_cleaned GROUP BY phase_cleaned ORDER BY phase_cleaned
""")
_publish_gold_table(trials_by_phase, "trials_by_phase")

# Distinct trials per study status, most frequent status first.
trials_by_status = spark.sql("""
    SELECT study_status, COUNT(DISTINCT nct_number) as number_of_trials
    FROM trials_cleaned GROUP BY study_status ORDER BY number_of_trials DESC
""")
_publish_gold_table(trials_by_status, "trials_by_status")

# Distinct trials per funder type, most frequent funder first.
trials_by_funder = spark.sql("""
    SELECT funder_type, COUNT(DISTINCT nct_number) as number_of_trials
    FROM trials_cleaned GROUP BY funder_type ORDER BY number_of_trials DESC
""")
_publish_gold_table(trials_by_funder, "trials_by_funder")

# Explode the pipe-separated `conditions` column to analyze the most studied
# medical conditions: one output row per (trial, condition) pair.
trials_df = spark.table("trials_cleaned")
conditions_df = (
    trials_df
    # explode() must be the outermost expression of its column, so the trim
    # is applied in a separate step.
    .withColumn("condition", explode(split(trials_df.conditions, "\\|")))
    # Trim before grouping so that 'Asthma' and ' Asthma' are counted as the
    # same condition instead of two separate groups.
    .withColumn("condition", trim(col("condition")))
    .filter("condition IS NOT NULL AND condition != ''")
)

# Count distinct trials per condition, matching the COUNT(DISTINCT nct_number)
# used by the other Gold tables. The secondary sort on `condition` makes the
# top-20 cut deterministic when counts tie; without it the rows written to
# Delta and the rows later collected for plotting could differ.
top_20_conditions = (
    conditions_df
    .groupBy("condition")
    .agg(countDistinct("nct_number").alias("number_of_trials"))
    .orderBy(col("number_of_trials").desc(), col("condition").asc())
    .limit(20)
)

top_20_conditions.write.format("delta").mode("overwrite").save(f"{gold_path}top_20_conditions")
print("Successfully created Gold table: top_20_conditions")


# --- Gold Layer Visualizations ---
sns.set_theme(style="whitegrid")

# Top 10 Clinical Trial Statuses by Count
# Only the first ten rows of trials_by_status are collected to the driver
# and drawn as a horizontal bar chart.
trials_by_status_pd = trials_by_status.limit(10).toPandas()
status_fig, status_ax = plt.subplots(figsize=(12, 7))
status_plot = sns.barplot(
    data=trials_by_status_pd,
    x="number_of_trials",
    y="study_status",
    palette="plasma",
    orient="h",
    ax=status_ax,
)
status_plot.set_title("Top 10 Clinical Trial Statuses by Count", fontsize=16)
status_plot.set(xlabel="Number of Trials", ylabel="Study Status")
plt.show()

# Top 10 Countries Hosting Clinical Trials (Excluding USA)
# The query drops the 'Unknown' placeholder and 'United States', then keeps
# the ten countries with the most distinct trials.
top_countries_sql = """
    SELECT country, COUNT(DISTINCT nct_number) as number_of_trials
    FROM trials_cleaned
    WHERE country != 'Unknown' AND country != 'United States'
    GROUP BY country ORDER BY number_of_trials DESC LIMIT 10
"""
trials_by_country = spark.sql(top_countries_sql)
trials_by_country_pd = trials_by_country.toPandas()

country_fig, country_ax = plt.subplots(figsize=(12, 8))
country_plot = sns.barplot(
    data=trials_by_country_pd,
    x="number_of_trials",
    y="country",
    palette="cubehelix",
    orient="h",
    ax=country_ax,
)
country_plot.set_title("Top 10 Countries by Number of Trials (Excluding USA)", fontsize=16)
country_plot.set(xlabel="Number of Trials", ylabel="Country")
plt.show()

# Status Distribution by Phase
# Restrict to the four numbered phases and five selected statuses, then count
# distinct trials per (phase, status) pair. Using countDistinct keeps this in
# line with the COUNT(DISTINCT nct_number) used by the Gold tables, and the
# alias avoids renaming Spark's auto-generated "count(nct_number)" column.
phase_order = ["Phase 1", "Phase 2", "Phase 3", "Phase 4"]
status_order = ["COMPLETED", "RECRUITING", "TERMINATED", "ACTIVE_NOT_RECRUITING", "WITHDRAWN"]
status_by_phase = (
    spark.table("trials_cleaned")
    .filter(col("phase_cleaned").isin(phase_order))
    .filter(col("study_status").isin(status_order))
    .groupBy("phase_cleaned", "study_status")
    .agg(countDistinct("nct_number").alias("number_of_trials"))
)

# Collect once and test the pandas frame for emptiness; calling isEmpty() and
# then toPandas() would run the same Spark job twice.
status_by_phase_pd = status_by_phase.toPandas()
if status_by_phase_pd.empty:
    print("No data available to plot for 'Status Distribution by Phase'. The filtered DataFrame was empty.")
else:
    plt.figure(figsize=(14, 8))
    # Fixed bar and hue orders keep positions and colours stable between
    # runs, since the row order returned by the aggregation is not fixed.
    status_phase_plot = sns.barplot(
        x="phase_cleaned",
        y="number_of_trials",
        hue="study_status",
        data=status_by_phase_pd,
        order=phase_order,
        hue_order=status_order,
    )
    status_phase_plot.set_title("Distribution of Trial Statuses within Each Phase", fontsize=16)
    status_phase_plot.set_xlabel("Trial Phase")
    status_phase_plot.set_ylabel("Number of Trials")
    plt.legend(title='Study Status')
    plt.show()

# Top 20 Most Studied Medical Conditions
# Collect the top_20_conditions frame to pandas and draw it as a horizontal
# bar chart, one bar per condition.
top_20_conditions_pd = top_20_conditions.toPandas()
conditions_fig, conditions_ax = plt.subplots(figsize=(12, 10))
conditions_plot = sns.barplot(
    data=top_20_conditions_pd,
    x="number_of_trials",
    y="condition",
    palette="coolwarm",
    orient="h",
    ax=conditions_ax,
)
conditions_plot.set_title("Top 20 Most Studied Medical Conditions", fontsize=16)
conditions_plot.set(xlabel="Number of Trials", ylabel="Condition")
plt.show()

