In [0]:
%python
# Accessing the raw file from AWS

# Set environment variables for AWS access
import os

# Template values. Prefer supplying real credentials through the cluster
# environment or a secret store rather than pasting them into the notebook.
aws_settings = {
    'AWS_ACCESS_KEY_ID': 'YOUR_AWS_ACCESS_KEY_ID',
    'AWS_SECRET_ACCESS_KEY': 'YOUR_AWS_SECRET_ACCESS_KEY',
    'AWS_DEFAULT_REGION': 'us-east-1',
}

# An unedited 'YOUR_...' placeholder is never exported: writing it to
# os.environ would overwrite any genuine credential that is already configured.
skipped_placeholders = []
for setting_name, setting_value in aws_settings.items():
    if setting_value.startswith('YOUR_'):
        skipped_placeholders.append(setting_name)
    else:
        os.environ[setting_name] = setting_value

if not skipped_placeholders:
    print("AWS credentials set in environment variables")
else:
    # Placeholders that are also absent from the environment need attention
    unresolved = [name for name in skipped_placeholders if name not in os.environ]
    print(f"AWS placeholders left untouched: {', '.join(skipped_placeholders)}")
    if unresolved:
        print(f"⚠️ Not found in the environment either: {', '.join(unresolved)}")

# Try to load data from S3, falling back to generated sample rows on any error
def _register_and_preview(ehr_df, success_message):
    """Register `ehr_df` as the `ehr_data` temp view and print a short preview.

    Prints `success_message`, the row and column counts, the column names and
    the first five rows.
    """
    # Create a temporary view for SQL queries
    ehr_df.createOrReplaceTempView("ehr_data")

    print(success_message)
    print(f"Shape: {ehr_df.count()} rows, {len(ehr_df.columns)} columns")
    print(f"Columns: {ehr_df.columns}")

    # Show sample data
    print("\nSample data:")
    ehr_df.show(5)


try:
    # Load the raw EHR data from S3
    df = spark.read.csv("s3://medicodeai-ehr-data/raw/synthetic_ehr_data.csv", header=True, inferSchema=True)

    # The preview stays inside the try block so that a failure while counting
    # or showing the S3 data also triggers the fallback below.
    _register_and_preview(df, "✅ Data loaded successfully from S3!")

except Exception as e:
    print(f"❌ Error loading from S3: {e}")
    print("\nCreating sample data instead...")

    # Fallback: Create sample data
    from datetime import date, timedelta
    import random

    # Explicit names instead of a wildcard import, so the notebook namespace
    # only gains the types that are actually used here.
    from pyspark.sql.types import (
        DoubleType,
        IntegerType,
        StringType,
        StructField,
        StructType,
    )

    # Define schema for EHR data
    schema = StructType([
        StructField("patient_id", StringType(), True),
        StructField("age", IntegerType(), True),
        StructField("gender", StringType(), True),
        StructField("diagnosis_text", StringType(), True),
        StructField("icd10_code", StringType(), True),
        StructField("admission_date", StringType(), True),
        StructField("discharge_date", StringType(), True),
        StructField("length_of_stay", IntegerType(), True),
        StructField("severity_score", DoubleType(), True)
    ])

    # Private, seeded generator: the same sample rows on every run, without
    # touching the global `random` state.
    rng = random.Random(42)

    # Each condition is paired with its ICD-10 category so the diagnosis text
    # and the code of a generated row always describe the same condition.
    conditions = [
        ("diabetes", "E11"),
        ("hypertension", "I10"),
        ("asthma", "J45"),
        ("pneumonia", "J18"),
        ("heart disease", "I25"),
    ]
    admission = date(2024, 1, 1)

    # Create sample data
    sample_data = []
    for i in range(1000):
        condition, icd10_category = rng.choice(conditions)
        length_of_stay = rng.randint(1, 30)
        # discharge_date is derived so it agrees with length_of_stay
        discharge = admission + timedelta(days=length_of_stay)
        sample_data.append((
            f"P{i:04d}",  # patient_id
            rng.randint(18, 95),  # age
            rng.choice(["Male", "Female"]),  # gender
            f"Patient diagnosed with {condition}",  # diagnosis_text
            f"{icd10_category}.{rng.randint(0, 9)}",  # icd10_code
            admission.isoformat(),  # admission_date
            discharge.isoformat(),  # discharge_date
            length_of_stay,  # length_of_stay
            round(rng.uniform(1.0, 10.0), 2)  # severity_score
        ))

    # Create DataFrame
    df = spark.createDataFrame(sample_data, schema)

    _register_and_preview(df, "✅ Sample EHR data created successfully!")
    print("⚠️ 'ehr_data' now holds generated sample rows, not the S3 dataset.")

print("\n🎉 Data is ready for SQL validation!")
print("You can now run SQL queries on the 'ehr_data' table.")

In [0]:
-- Validation summary: one row with record counts, per-rule pass counts and the
-- overall quality score (percentage of records that pass every rule).
-- The result is displayed only. No CSV is written by this cell.
-- Each rule is evaluated once per row in the checked CTE, so the definition of
-- a fully valid record lives in a single place.
WITH checked AS (
  SELECT
    patient_id,
    icd10_code,
    (CAST(age AS INT) BETWEEN 0 AND 120) AS age_ok,
    (CAST(length_of_stay AS INT) >= 0) AS length_of_stay_ok,
    (LENGTH(TRIM(diagnosis_text)) >= 10) AS diagnosis_text_ok,
    (icd10_code REGEXP '^[A-Z][0-9]{2}\\.[0-9X]{1,2}$') AS icd10_code_ok
  FROM ehr_data
),
flagged AS (
  SELECT
    *,
    -- A NULL diagnosis_text or icd10_code leaves its flag NULL, so such a row
    -- is never counted as fully valid.
    (patient_id IS NOT NULL
      AND age_ok
      AND length_of_stay_ok
      AND diagnosis_text_ok
      AND icd10_code_ok) AS record_ok
  FROM checked
)
SELECT
  'Data Quality Report' AS report_type,
  CURRENT_TIMESTAMP() AS report_timestamp,
  COUNT(*) AS total_records,
  COUNT(DISTINCT patient_id) AS unique_patients,
  COUNT(DISTINCT icd10_code) AS unique_icd10_codes,
  COUNT(patient_id) AS valid_patient_ids,
  COUNT(CASE WHEN age_ok THEN 1 END) AS valid_ages,
  COUNT(CASE WHEN length_of_stay_ok THEN 1 END) AS valid_length_of_stay,
  COUNT(CASE WHEN diagnosis_text_ok THEN 1 END) AS valid_diagnosis_text,
  COUNT(CASE WHEN icd10_code_ok THEN 1 END) AS valid_icd10_codes,
  COUNT(CASE WHEN record_ok THEN 1 END) AS fully_valid_records,
  -- NULLIF yields a NULL score for an empty view instead of a divide-by-zero
  -- error when ANSI mode is enabled.
  ROUND(COUNT(CASE WHEN record_ok THEN 1 END) * 100.0 / NULLIF(COUNT(*), 0), 2) AS overall_quality_score
FROM flagged;

In [0]:
-- Missing-values report: NULL count and NULL percentage for each key field.
-- The result is displayed only. No CSV is written by this cell.
-- The counts are computed by one aggregate (null_counts) and each UNION ALL
-- branch turns one of them into a report row.
-- NULLIF yields a NULL percentage for an empty view instead of a
-- divide-by-zero error when ANSI mode is enabled.
WITH null_counts AS (
  SELECT
    COUNT(*) AS total_rows,
    COUNT(*) - COUNT(patient_id) AS patient_id_nulls,
    COUNT(*) - COUNT(age) AS age_nulls,
    COUNT(*) - COUNT(gender) AS gender_nulls,
    COUNT(*) - COUNT(diagnosis_text) AS diagnosis_text_nulls,
    COUNT(*) - COUNT(icd10_code) AS icd10_code_nulls
  FROM ehr_data
)
SELECT
  'Missing Values' AS metric_category,
  'patient_id' AS field_name,
  patient_id_nulls AS invalid_count,
  ROUND(patient_id_nulls * 100.0 / NULLIF(total_rows, 0), 2) AS error_percentage
FROM null_counts

UNION ALL

SELECT
  'Missing Values' AS metric_category,
  'age' AS field_name,
  age_nulls AS invalid_count,
  ROUND(age_nulls * 100.0 / NULLIF(total_rows, 0), 2) AS error_percentage
FROM null_counts

UNION ALL

SELECT
  'Missing Values' AS metric_category,
  'gender' AS field_name,
  gender_nulls AS invalid_count,
  ROUND(gender_nulls * 100.0 / NULLIF(total_rows, 0), 2) AS error_percentage
FROM null_counts

UNION ALL

SELECT
  'Missing Values' AS metric_category,
  'diagnosis_text' AS field_name,
  diagnosis_text_nulls AS invalid_count,
  ROUND(diagnosis_text_nulls * 100.0 / NULLIF(total_rows, 0), 2) AS error_percentage
FROM null_counts

UNION ALL

SELECT
  'Missing Values' AS metric_category,
  'icd10_code' AS field_name,
  icd10_code_nulls AS invalid_count,
  ROUND(icd10_code_nulls * 100.0 / NULLIF(total_rows, 0), 2) AS error_percentage
FROM null_counts;

In [0]:
-- Data-validation report: number and percentage of rows that break each rule.
-- The result is displayed only. No CSV is written by this cell.
-- A NULL field satisfies none of these conditions, so NULLs are not counted
-- here (they are covered by the missing-values report).
-- NULLIF yields a NULL percentage for an empty view instead of a
-- divide-by-zero error when ANSI mode is enabled.
WITH rule_failures AS (
  SELECT
    COUNT(*) AS total_rows,
    COUNT(CASE WHEN CAST(age AS INT) < 0 OR CAST(age AS INT) > 120 THEN 1 END) AS invalid_age,
    COUNT(CASE WHEN CAST(length_of_stay AS INT) < 0 THEN 1 END) AS negative_length_of_stay,
    COUNT(CASE WHEN LENGTH(TRIM(diagnosis_text)) < 10 THEN 1 END) AS short_diagnosis_text,
    COUNT(CASE WHEN icd10_code NOT REGEXP '^[A-Z][0-9]{2}\\.[0-9X]{1,2}$' THEN 1 END) AS invalid_icd10_format
  FROM ehr_data
)
SELECT
  'Data Validation' AS metric_category,
  'invalid_age' AS field_name,
  invalid_age AS invalid_count,
  ROUND(invalid_age * 100.0 / NULLIF(total_rows, 0), 2) AS error_percentage
FROM rule_failures

UNION ALL

SELECT
  'Data Validation' AS metric_category,
  'negative_length_of_stay' AS field_name,
  negative_length_of_stay AS invalid_count,
  ROUND(negative_length_of_stay * 100.0 / NULLIF(total_rows, 0), 2) AS error_percentage
FROM rule_failures

UNION ALL

SELECT
  'Data Validation' AS metric_category,
  'short_diagnosis_text' AS field_name,
  short_diagnosis_text AS invalid_count,
  ROUND(short_diagnosis_text * 100.0 / NULLIF(total_rows, 0), 2) AS error_percentage
FROM rule_failures

UNION ALL

SELECT
  'Data Validation' AS metric_category,
  'invalid_icd10_format' AS field_name,
  invalid_icd10_format AS invalid_count,
  ROUND(invalid_icd10_format * 100.0 / NULLIF(total_rows, 0), 2) AS error_percentage
FROM rule_failures;

In [0]:
-- Gender distribution: row count and share of all rows per gender value,
-- largest group first. The result is displayed only.
WITH gender_counts AS (
  SELECT
    gender,
    COUNT(*) AS row_count
  FROM ehr_data
  GROUP BY gender
)
SELECT
  'Gender Distribution' AS report_section,
  gender,
  row_count AS count,
  -- the window sum over every group equals the row count of ehr_data
  ROUND(row_count * 100.0 / SUM(row_count) OVER (), 2) AS percentage
FROM gender_counts
ORDER BY count DESC;

In [0]:
-- Age distribution: row count and share of all rows per age band, listed from
-- youngest band to oldest with Unknown last. The result is displayed only.
WITH banded AS (
  SELECT
    -- WHEN branches are tested in order, so each upper bound implies the
    -- lower bound of its band. A NULL age matches none and falls to ELSE.
    CASE
      WHEN CAST(age AS INT) < 18 THEN 'Under 18'
      WHEN CAST(age AS INT) <= 30 THEN '18-30'
      WHEN CAST(age AS INT) <= 50 THEN '31-50'
      WHEN CAST(age AS INT) <= 70 THEN '51-70'
      WHEN CAST(age AS INT) > 70 THEN 'Over 70'
      ELSE 'Unknown'
    END AS age_group
  FROM ehr_data
),
band_counts AS (
  SELECT
    age_group,
    COUNT(*) AS row_count
  FROM banded
  GROUP BY age_group
)
SELECT
  'Age Distribution' AS report_section,
  age_group,
  row_count AS count,
  -- the window sum over every band equals the row count of ehr_data
  ROUND(row_count * 100.0 / SUM(row_count) OVER (), 2) AS percentage
FROM band_counts
ORDER BY
  CASE age_group
    WHEN 'Under 18' THEN 1
    WHEN '18-30' THEN 2
    WHEN '31-50' THEN 3
    WHEN '51-70' THEN 4
    WHEN 'Over 70' THEN 5
    ELSE 6
  END;

In [0]:
-- ICD-10 distribution: row count and share of all rows per three-character
-- code prefix, most frequent first. The result is displayed only.
WITH category_counts AS (
  SELECT
    SUBSTRING(icd10_code, 1, 3) AS icd10_category,
    COUNT(*) AS row_count
  FROM ehr_data
  GROUP BY SUBSTRING(icd10_code, 1, 3)
)
SELECT
  'ICD-10 Distribution' AS report_section,
  icd10_category,
  row_count AS count,
  -- the window sum over every category equals the row count of ehr_data
  ROUND(row_count * 100.0 / SUM(row_count) OVER (), 2) AS percentage
FROM category_counts
ORDER BY count DESC;

In [0]:
-- Length-of-stay distribution: row count and share of all rows per stay band,
-- shortest band first with Unknown last. The result is displayed only.
WITH banded AS (
  SELECT
    -- WHEN branches are tested in order, so each upper bound implies the
    -- lower bound of its band. A NULL value matches none and falls to ELSE.
    CASE
      WHEN CAST(length_of_stay AS INT) <= 1 THEN '1 day or less'
      WHEN CAST(length_of_stay AS INT) <= 7 THEN '2-7 days'
      WHEN CAST(length_of_stay AS INT) <= 14 THEN '8-14 days'
      WHEN CAST(length_of_stay AS INT) <= 30 THEN '15-30 days'
      WHEN CAST(length_of_stay AS INT) > 30 THEN 'Over 30 days'
      ELSE 'Unknown'
    END AS los_group
  FROM ehr_data
),
band_counts AS (
  SELECT
    los_group,
    COUNT(*) AS row_count
  FROM banded
  GROUP BY los_group
)
SELECT
  'Length of Stay Distribution' AS report_section,
  los_group,
  row_count AS count,
  -- the window sum over every band equals the row count of ehr_data
  ROUND(row_count * 100.0 / SUM(row_count) OVER (), 2) AS percentage
FROM band_counts
ORDER BY
  CASE los_group
    WHEN '1 day or less' THEN 1
    WHEN '2-7 days' THEN 2
    WHEN '8-14 days' THEN 3
    WHEN '15-30 days' THEN 4
    WHEN 'Over 30 days' THEN 5
    ELSE 6
  END;

In [0]:
-- Severity-score distribution: row count and share of all rows per severity
-- band, lowest band first with Unknown last. The result is displayed only.
-- The bands are contiguous: each WHEN is an upper bound and the branches are
-- tested in order, so a score such as 2.05 or 5.07 lands in Medium or High
-- rather than slipping between two bands into Unknown. Only a NULL score
-- reaches the ELSE branch. Band labels are unchanged.
WITH banded AS (
  SELECT
    CASE
      WHEN CAST(severity_score AS DOUBLE) <= 2.0 THEN 'Low (1-2)'
      WHEN CAST(severity_score AS DOUBLE) <= 5.0 THEN 'Medium (2.1-5)'
      WHEN CAST(severity_score AS DOUBLE) <= 8.0 THEN 'High (5.1-8)'
      WHEN CAST(severity_score AS DOUBLE) > 8.0 THEN 'Critical (>8)'
      ELSE 'Unknown'
    END AS severity_group
  FROM ehr_data
)
SELECT
  'Severity Score Distribution' AS report_section,
  severity_group,
  COUNT(*) AS count,
  ROUND(COUNT(*) * 100.0 / (SELECT COUNT(*) FROM ehr_data), 2) AS percentage
FROM banded
GROUP BY severity_group
ORDER BY
  CASE severity_group
    WHEN 'Low (1-2)' THEN 1
    WHEN 'Medium (2.1-5)' THEN 2
    WHEN 'High (5.1-8)' THEN 3
    WHEN 'Critical (>8)' THEN 4
    ELSE 5
  END;

In [0]:
-- Invalid-record details: every row that fails at least one validation rule,
-- labelled with the first rule it fails. The result is displayed only.
-- The label is computed once and the filter is derived from it, so the listed
-- rows are exactly the rows that are not fully valid. This includes rows whose
-- diagnosis_text, icd10_code or length_of_stay is NULL, which a plain
-- comparison in a WHERE clause silently skips.
WITH labelled AS (
  SELECT
    patient_id,
    age,
    gender,
    diagnosis_text,
    icd10_code,
    length_of_stay,
    severity_score,
    CASE
      WHEN patient_id IS NULL THEN 'Missing patient_id'
      WHEN age IS NULL THEN 'Missing age'
      -- with ANSI mode off, a value that cannot be read as an integer casts to NULL
      WHEN CAST(age AS INT) IS NULL
        OR CAST(age AS INT) < 0
        OR CAST(age AS INT) > 120 THEN 'Invalid age'
      WHEN length_of_stay IS NULL THEN 'Missing length of stay'
      WHEN CAST(length_of_stay AS INT) IS NULL THEN 'Invalid length of stay'
      WHEN CAST(length_of_stay AS INT) < 0 THEN 'Negative length of stay'
      WHEN diagnosis_text IS NULL THEN 'Missing diagnosis text'
      WHEN LENGTH(TRIM(diagnosis_text)) < 10 THEN 'Short diagnosis text'
      WHEN icd10_code IS NULL THEN 'Missing icd10_code'
      WHEN icd10_code NOT REGEXP '^[A-Z][0-9]{2}\\.[0-9X]{1,2}$' THEN 'Invalid ICD-10 format'
      ELSE 'Valid'
    END AS validation_issue
  FROM ehr_data
)
SELECT
  patient_id,
  age,
  gender,
  diagnosis_text,
  icd10_code,
  length_of_stay,
  severity_score,
  validation_issue
FROM labelled
WHERE validation_issue <> 'Valid'
ORDER BY validation_issue;

In [0]:
-- Report metadata: title, generation time, record and distinct-value counts
-- and the overall data quality score (percentage of records that pass every
-- rule). The result is displayed only. No CSV is written by this cell.
WITH checked AS (
  SELECT
    patient_id,
    icd10_code,
    gender,
    -- NULL when any operand is NULL and none is FALSE, so such rows are not
    -- counted as passing.
    (patient_id IS NOT NULL
      AND diagnosis_text IS NOT NULL
      AND icd10_code IS NOT NULL
      AND CAST(age AS INT) BETWEEN 0 AND 120
      AND CAST(length_of_stay AS INT) >= 0
      AND LENGTH(TRIM(diagnosis_text)) >= 10
      AND icd10_code REGEXP '^[A-Z][0-9]{2}\\.[0-9X]{1,2}$') AS record_ok
  FROM ehr_data
)
SELECT
  'MediCodeAI Data Validation Report' AS report_title,
  CURRENT_TIMESTAMP() AS generated_at,
  COUNT(*) AS total_records_processed,
  COUNT(DISTINCT patient_id) AS unique_patients,
  COUNT(DISTINCT icd10_code) AS unique_icd10_codes,
  COUNT(DISTINCT gender) AS unique_genders,
  -- NULLIF yields a NULL score for an empty view instead of a divide-by-zero
  -- error when ANSI mode is enabled.
  ROUND(COUNT(CASE WHEN record_ok THEN 1 END) * 100.0 / NULLIF(COUNT(*), 0), 2) AS overall_data_quality_score
FROM checked;

In [0]:
%python

# Local CSV Export with S3 Upload
# This approach avoids Delta format issues in Databricks free edition

import os
import pandas as pd
import boto3
from datetime import datetime

# Set AWS credentials
# Template values. Prefer supplying real credentials through the cluster
# environment or a secret store rather than pasting them into the notebook.
aws_settings = {
    'AWS_ACCESS_KEY_ID': 'YOUR_AWS_ACCESS_KEY_ID',
    'AWS_SECRET_ACCESS_KEY': 'YOUR_AWS_SECRET_ACCESS_KEY',
    'AWS_DEFAULT_REGION': 'us-east-1',
}

# An unedited 'YOUR_...' placeholder is never exported: writing it to
# os.environ would overwrite any genuine credential that is already configured.
skipped_placeholders = []
for setting_name, setting_value in aws_settings.items():
    if setting_value.startswith('YOUR_'):
        skipped_placeholders.append(setting_name)
    else:
        os.environ[setting_name] = setting_value

if not skipped_placeholders:
    print("AWS credentials set in environment variables")
else:
    # Placeholders that are also absent from the environment need attention
    unresolved = [name for name in skipped_placeholders if name not in os.environ]
    print(f"AWS placeholders left untouched: {', '.join(skipped_placeholders)}")
    if unresolved:
        print(f"⚠️ Not found in the environment either: {', '.join(unresolved)}")

# Destination bucket and the S3 client used for every upload in this cell
bucket_name = 'medicodeai-ehr-data'
s3_client = boto3.client('s3')

def upload_to_s3(file_path, s3_key):
    """Upload the local file at `file_path` to `bucket_name` under `s3_key`.

    Prints the outcome and returns True on success or False when the upload
    raises. The exception itself is not propagated to the caller.
    """
    uploaded = False
    try:
        s3_client.upload_file(file_path, bucket_name, s3_key)
        print(f"✅ Uploaded {file_path} to s3://{bucket_name}/{s3_key}")
        uploaded = True
    except Exception as e:
        failure_reason = str(e)
        print(f"❌ Failed to upload {file_path}: {failure_reason}")
    return uploaded

# Create local directory for CSV files
local_dir = "/tmp/validation_reports"
os.makedirs(local_dir, exist_ok=True)

print(f"📁 Created local directory: {local_dir}")

# Key prefix shared by every report uploaded to the bucket
s3_report_prefix = "raw/databricks_validation_reports"

# Upload outcome per report file; consulted before the local copies are deleted
uploaded_reports = []
failed_uploads = []


def export_report(report_df, file_name):
    """Save a Spark report as a local CSV and upload it to S3.

    Args:
        report_df: Spark DataFrame holding the report rows.
        file_name: CSV file name, used both under `local_dir` and under
            `s3_report_prefix` in the bucket.

    Returns:
        The report converted to a pandas DataFrame.

    The file name is appended to `uploaded_reports` or `failed_uploads`
    depending on what `upload_to_s3` returns.
    """
    report_pdf = report_df.toPandas()
    local_path = f"{local_dir}/{file_name}"
    report_pdf.to_csv(local_path, index=False)
    if upload_to_s3(local_path, f"{s3_report_prefix}/{file_name}"):
        uploaded_reports.append(file_name)
    else:
        failed_uploads.append(file_name)
    return report_pdf


# NOTE: the ICD-10 pattern below uses `[.]` for the literal dot. Writing `\\.`
# in a regular (non-raw) Python string hands Spark `\.`, which the SQL parser
# unescapes to a bare `.` that matches any character.
#
# NULLIF(..., 0) in the percentage columns yields NULL for an empty view
# instead of a divide-by-zero error when ANSI mode is enabled.

# 1. Generate validation summary
print("📊 Generating validation summary...")
validation_summary = spark.sql("""
WITH checked AS (
  SELECT
    patient_id,
    icd10_code,
    (CAST(age AS INT) BETWEEN 0 AND 120) AS age_ok,
    (CAST(length_of_stay AS INT) >= 0) AS length_of_stay_ok,
    (LENGTH(TRIM(diagnosis_text)) >= 10) AS diagnosis_text_ok,
    (icd10_code REGEXP '^[A-Z][0-9]{2}[.][0-9X]{1,2}$') AS icd10_code_ok
  FROM ehr_data
),
flagged AS (
  SELECT
    *,
    (patient_id IS NOT NULL
      AND age_ok
      AND length_of_stay_ok
      AND diagnosis_text_ok
      AND icd10_code_ok) AS record_ok
  FROM checked
)
SELECT
  'Data Quality Report' AS report_type,
  CURRENT_TIMESTAMP() AS report_timestamp,
  COUNT(*) AS total_records,
  COUNT(DISTINCT patient_id) AS unique_patients,
  COUNT(DISTINCT icd10_code) AS unique_icd10_codes,
  COUNT(patient_id) AS valid_patient_ids,
  COUNT(CASE WHEN age_ok THEN 1 END) AS valid_ages,
  COUNT(CASE WHEN length_of_stay_ok THEN 1 END) AS valid_length_of_stay,
  COUNT(CASE WHEN diagnosis_text_ok THEN 1 END) AS valid_diagnosis_text,
  COUNT(CASE WHEN icd10_code_ok THEN 1 END) AS valid_icd10_codes,
  COUNT(CASE WHEN record_ok THEN 1 END) AS fully_valid_records,
  ROUND(COUNT(CASE WHEN record_ok THEN 1 END) * 100.0 / NULLIF(COUNT(*), 0), 2) AS overall_quality_score
FROM flagged
""")

# Convert to pandas, save locally and upload
validation_summary_pdf = export_report(validation_summary, "validation_summary.csv")

# 2. Generate missing values report
print("📊 Generating missing values report...")
missing_values = spark.sql("""
WITH null_counts AS (
  SELECT
    COUNT(*) AS total_rows,
    COUNT(*) - COUNT(patient_id) AS patient_id_nulls,
    COUNT(*) - COUNT(age) AS age_nulls,
    COUNT(*) - COUNT(gender) AS gender_nulls,
    COUNT(*) - COUNT(diagnosis_text) AS diagnosis_text_nulls,
    COUNT(*) - COUNT(icd10_code) AS icd10_code_nulls
  FROM ehr_data
)
SELECT
  'Missing Values' AS metric_category,
  'patient_id' AS field_name,
  patient_id_nulls AS invalid_count,
  ROUND(patient_id_nulls * 100.0 / NULLIF(total_rows, 0), 2) AS error_percentage
FROM null_counts
UNION ALL
SELECT
  'Missing Values' AS metric_category,
  'age' AS field_name,
  age_nulls AS invalid_count,
  ROUND(age_nulls * 100.0 / NULLIF(total_rows, 0), 2) AS error_percentage
FROM null_counts
UNION ALL
SELECT
  'Missing Values' AS metric_category,
  'gender' AS field_name,
  gender_nulls AS invalid_count,
  ROUND(gender_nulls * 100.0 / NULLIF(total_rows, 0), 2) AS error_percentage
FROM null_counts
UNION ALL
SELECT
  'Missing Values' AS metric_category,
  'diagnosis_text' AS field_name,
  diagnosis_text_nulls AS invalid_count,
  ROUND(diagnosis_text_nulls * 100.0 / NULLIF(total_rows, 0), 2) AS error_percentage
FROM null_counts
UNION ALL
SELECT
  'Missing Values' AS metric_category,
  'icd10_code' AS field_name,
  icd10_code_nulls AS invalid_count,
  ROUND(icd10_code_nulls * 100.0 / NULLIF(total_rows, 0), 2) AS error_percentage
FROM null_counts
""")

missing_values_pdf = export_report(missing_values, "missing_values.csv")

# 3. Generate data validation report
print("📊 Generating data validation report...")
data_validation = spark.sql("""
WITH rule_failures AS (
  SELECT
    COUNT(*) AS total_rows,
    COUNT(CASE WHEN CAST(age AS INT) < 0 OR CAST(age AS INT) > 120 THEN 1 END) AS invalid_age,
    COUNT(CASE WHEN CAST(length_of_stay AS INT) < 0 THEN 1 END) AS negative_length_of_stay,
    COUNT(CASE WHEN LENGTH(TRIM(diagnosis_text)) < 10 THEN 1 END) AS short_diagnosis_text,
    COUNT(CASE WHEN icd10_code NOT REGEXP '^[A-Z][0-9]{2}[.][0-9X]{1,2}$' THEN 1 END) AS invalid_icd10_format
  FROM ehr_data
)
SELECT
  'Data Validation' AS metric_category,
  'invalid_age' AS field_name,
  invalid_age AS invalid_count,
  ROUND(invalid_age * 100.0 / NULLIF(total_rows, 0), 2) AS error_percentage
FROM rule_failures
UNION ALL
SELECT
  'Data Validation' AS metric_category,
  'negative_length_of_stay' AS field_name,
  negative_length_of_stay AS invalid_count,
  ROUND(negative_length_of_stay * 100.0 / NULLIF(total_rows, 0), 2) AS error_percentage
FROM rule_failures
UNION ALL
SELECT
  'Data Validation' AS metric_category,
  'short_diagnosis_text' AS field_name,
  short_diagnosis_text AS invalid_count,
  ROUND(short_diagnosis_text * 100.0 / NULLIF(total_rows, 0), 2) AS error_percentage
FROM rule_failures
UNION ALL
SELECT
  'Data Validation' AS metric_category,
  'invalid_icd10_format' AS field_name,
  invalid_icd10_format AS invalid_count,
  ROUND(invalid_icd10_format * 100.0 / NULLIF(total_rows, 0), 2) AS error_percentage
FROM rule_failures
""")

data_validation_pdf = export_report(data_validation, "data_validation.csv")

# 4. Generate gender distribution
print("📊 Generating gender distribution...")
gender_dist = spark.sql("""
SELECT 
  'Gender Distribution' as report_section,
  gender,
  COUNT(*) as count,
  ROUND(COUNT(*) * 100.0 / (SELECT COUNT(*) FROM ehr_data), 2) as percentage
FROM ehr_data
GROUP BY gender
ORDER BY count DESC
""")

gender_dist_pdf = export_report(gender_dist, "gender_distribution.csv")

# 5. Generate age distribution
print("📊 Generating age distribution...")
age_dist = spark.sql("""
SELECT 
  'Age Distribution' as report_section,
  CASE 
    WHEN CAST(age AS INT) < 18 THEN 'Under 18'
    WHEN CAST(age AS INT) BETWEEN 18 AND 30 THEN '18-30'
    WHEN CAST(age AS INT) BETWEEN 31 AND 50 THEN '31-50'
    WHEN CAST(age AS INT) BETWEEN 51 AND 70 THEN '51-70'
    WHEN CAST(age AS INT) > 70 THEN 'Over 70'
    ELSE 'Unknown'
  END as age_group,
  COUNT(*) as count,
  ROUND(COUNT(*) * 100.0 / (SELECT COUNT(*) FROM ehr_data), 2) as percentage
FROM ehr_data
GROUP BY 
  CASE 
    WHEN CAST(age AS INT) < 18 THEN 'Under 18'
    WHEN CAST(age AS INT) BETWEEN 18 AND 30 THEN '18-30'
    WHEN CAST(age AS INT) BETWEEN 31 AND 50 THEN '31-50'
    WHEN CAST(age AS INT) BETWEEN 51 AND 70 THEN '51-70'
    WHEN CAST(age AS INT) > 70 THEN 'Over 70'
    ELSE 'Unknown'
  END
ORDER BY 
  CASE age_group
    WHEN 'Under 18' THEN 1
    WHEN '18-30' THEN 2
    WHEN '31-50' THEN 3
    WHEN '51-70' THEN 4
    WHEN 'Over 70' THEN 5
    ELSE 6
  END
""")

age_dist_pdf = export_report(age_dist, "age_distribution.csv")

# 6. Generate ICD-10 distribution
print("📊 Generating ICD-10 distribution...")
icd10_dist = spark.sql("""
SELECT 
  'ICD-10 Distribution' as report_section,
  SUBSTRING(icd10_code, 1, 3) as icd10_category,
  COUNT(*) as count,
  ROUND(COUNT(*) * 100.0 / (SELECT COUNT(*) FROM ehr_data), 2) as percentage
FROM ehr_data
GROUP BY SUBSTRING(icd10_code, 1, 3)
ORDER BY count DESC
""")

icd10_dist_pdf = export_report(icd10_dist, "icd10_distribution.csv")

# 7. Generate length of stay distribution
print("📊 Generating length of stay distribution...")
los_dist = spark.sql("""
SELECT 
  'Length of Stay Distribution' as report_section,
  CASE 
    WHEN CAST(length_of_stay AS INT) <= 1 THEN '1 day or less'
    WHEN CAST(length_of_stay AS INT) BETWEEN 2 AND 7 THEN '2-7 days'
    WHEN CAST(length_of_stay AS INT) BETWEEN 8 AND 14 THEN '8-14 days'
    WHEN CAST(length_of_stay AS INT) BETWEEN 15 AND 30 THEN '15-30 days'
    WHEN CAST(length_of_stay AS INT) > 30 THEN 'Over 30 days'
    ELSE 'Unknown'
  END as los_group,
  COUNT(*) as count,
  ROUND(COUNT(*) * 100.0 / (SELECT COUNT(*) FROM ehr_data), 2) as percentage
FROM ehr_data
GROUP BY 
  CASE 
    WHEN CAST(length_of_stay AS INT) <= 1 THEN '1 day or less'
    WHEN CAST(length_of_stay AS INT) BETWEEN 2 AND 7 THEN '2-7 days'
    WHEN CAST(length_of_stay AS INT) BETWEEN 8 AND 14 THEN '8-14 days'
    WHEN CAST(length_of_stay AS INT) BETWEEN 15 AND 30 THEN '15-30 days'
    WHEN CAST(length_of_stay AS INT) > 30 THEN 'Over 30 days'
    ELSE 'Unknown'
  END
ORDER BY 
  CASE los_group
    WHEN '1 day or less' THEN 1
    WHEN '2-7 days' THEN 2
    WHEN '8-14 days' THEN 3
    WHEN '15-30 days' THEN 4
    WHEN 'Over 30 days' THEN 5
    ELSE 6
  END
""")

los_dist_pdf = export_report(los_dist, "length_of_stay_distribution.csv")

# 8. Generate report metadata
print("📊 Generating report metadata...")
metadata = spark.sql("""
WITH checked AS (
  SELECT
    patient_id,
    icd10_code,
    gender,
    (patient_id IS NOT NULL
      AND diagnosis_text IS NOT NULL
      AND icd10_code IS NOT NULL
      AND CAST(age AS INT) BETWEEN 0 AND 120
      AND CAST(length_of_stay AS INT) >= 0
      AND LENGTH(TRIM(diagnosis_text)) >= 10
      AND icd10_code REGEXP '^[A-Z][0-9]{2}[.][0-9X]{1,2}$') AS record_ok
  FROM ehr_data
)
SELECT
  'MediCodeAI Data Validation Report' AS report_title,
  CURRENT_TIMESTAMP() AS generated_at,
  COUNT(*) AS total_records_processed,
  COUNT(DISTINCT patient_id) AS unique_patients,
  COUNT(DISTINCT icd10_code) AS unique_icd10_codes,
  COUNT(DISTINCT gender) AS unique_genders,
  ROUND(COUNT(CASE WHEN record_ok THEN 1 END) * 100.0 / NULLIF(COUNT(*), 0), 2) AS overall_data_quality_score
FROM checked
""")

metadata_pdf = export_report(metadata, "report_metadata.csv")

# Clean up local files, but only when every report reached S3. After a failed
# upload the local CSVs are the only copy, so they are kept for a retry.
import shutil

if failed_uploads:
    print(f"⚠️ {len(failed_uploads)} report(s) could not be uploaded: {', '.join(failed_uploads)}")
    print(f"📁 Local copies kept in {local_dir}")
    if uploaded_reports:
        print(f"📊 Uploaded to s3://{bucket_name}/{s3_report_prefix}/:")
        for report_name in uploaded_reports:
            print(f"   - {report_name}")
else:
    shutil.rmtree(local_dir)

    print("✅ All validation reports generated and uploaded to S3!")
    print(f"📁 Location: s3://{bucket_name}/{s3_report_prefix}/")
    print("📊 Files created:")
    for report_name in uploaded_reports:
        print(f"   - {report_name}")
    print("🧹 Local temporary files cleaned up")