In [0]:
import dlt
from pyspark.sql.functions import *
from pyspark.sql.functions import max as spark_max
from pyspark.sql.window import Window
from datetime import datetime
from pyspark.sql.utils import AnalysisException
from pyspark.sql.types import *
from pyspark.sql import functions as F
from pyspark.sql.functions import sum as sql_sum, min as sql_min, max as sql_max

In [0]:


# OMOP CDM location table layout. Each tuple is
# (column name, Spark type, nullable, column comment); the comment is stored
# as StructField metadata so it travels with the table definition.
location_schema = StructType([
    StructField(field_name, field_type, is_nullable, metadata={"comment": field_comment})
    for field_name, field_type, is_nullable, field_comment in (
        ("location_id", LongType(), False,
         "A unique identifier for each geographic location."),
        ("address_1", StringType(), True,
         "The first line of the address."),
        ("address_2", StringType(), True,
         "The second line of the address"),
        ("city", StringType(), True,
         "The city field is the text name of the city."),
        ("state", StringType(), True,
         "The state field contains the state name. For addresses outside the US, this field can be used for provinces or other administrative regions."),
        ("zip", StringType(), True,
         "The zip or postal code. For US addresses, valid formats are 3-digit, 5-digit or 9-digit ZIP codes. For non-US addresses, the postal code should be stored in the same field."),
        ("county", StringType(), True,
         "The county, if available. The county field can also be used to store other regional information."),
        ("location_source_value", StringType(), True,
         "The verbatim value for the location as it appears in the source data."),
        ("country_concept_id", IntegerType(), True,
         "A foreign key to the predefined Concept table for the country concept id, representing the country portion of the address."),
        ("country_source_value", StringType(), True,
         "The source code for the country as it appears in the source data."),
        ("latitude", FloatType(), True,
         "The latitude of the location. Must be between -90 and 90."),
        ("longitude", FloatType(), True,
         "The longitude of the location. Must be between -180 and 180."),
    )
])

# Hard rules for omop_location: a row failing any of these is dropped
# (applied with dlt.expect_all_or_drop).
mandatory_location_rules = dict(
    valid_location_id="location_id IS NOT NULL",
)

# Soft rules for omop_location: violations are counted in the pipeline's
# expectation metrics (dlt.expect_all) but the rows are kept.
advisory_location_rules = dict(
    # Maximum column lengths / formats
    valid_zip_format="zip IS NULL OR LENGTH(zip) <= 9",
    valid_address_length="address_1 IS NULL OR LENGTH(address_1) <= 50",
    valid_address2_length="address_2 IS NULL OR LENGTH(address_2) <= 50",
    valid_city_length="city IS NULL OR LENGTH(city) <= 50",
    valid_state_length="state IS NULL OR LENGTH(state) <= 2",
    valid_county_length="county IS NULL OR LENGTH(county) <= 20",
    valid_location_source_length="location_source_value IS NULL OR LENGTH(location_source_value) <= 50",
    valid_country_source_length="country_source_value IS NULL OR LENGTH(country_source_value) <= 80",
    # Coordinate ranges
    valid_latitude="latitude IS NULL OR (latitude >= -90 AND latitude <= 90)",
    valid_longitude="longitude IS NULL OR (longitude >= -180 AND longitude <= 180)",
    # Concept ids must not be negative
    valid_country_concept="country_concept_id IS NULL OR country_concept_id >= 0",
)

@dlt.table(
    name="country_concepts",
    comment="Geography concepts from OMOP vocabulary"
)
@dlt.expect_or_fail("valid_domain", "domain_id = 'Geography'")
def create_country_concepts():
    """Creates a table of geography concepts for country mapping.

    Reads ``3_lookup.omop.concept``, keeps rows whose ``domain_id`` is
    ``'Geography'`` and exposes:
      * ``country_name_lower`` - lower-cased ``concept_name`` (join key for the
        source ``country_cd``);
      * ``country_concept_id`` - the concept id;
      * ``domain_id`` - kept because the ``valid_domain`` expectation on this
        table is evaluated against the table's own output columns. Without it
        the expectation would reference a column that does not exist.
    """
    return (spark.table("3_lookup.omop.concept")
           .filter(col("domain_id") == "Geography")
           .select(
               lower(col("concept_name")).alias("country_name_lower"),
               col("concept_id").alias("country_concept_id"),
               col("domain_id")
           ))

@dlt.table(
    name="omop_location",
    comment="OMOP CDM Location table - Represents a generic way to capture physical location or address information",
    schema=location_schema,
    table_properties={"quality": "gold"}
)
@dlt.expect_all_or_drop(mandatory_location_rules)
@dlt.expect_all(advisory_location_rules)
def create_omop_location():
    """
    Creates the OMOP Location table from source address data.

    Reads ``4_prod.bronze.map_address`` in full on every update (this is a
    batch query, not incremental processing). It then left-joins the
    ``country_concepts`` lookup on the lower-cased ``country_cd`` and keeps
    one row per ``ADDRESS_ID``. Rows failing ``mandatory_location_rules`` are
    dropped; ``advisory_location_rules`` are only tracked.

    Column mapping:
      * ``full_street_address`` is split on commas. The text before the first
        comma becomes ``address_1`` (the whole address when there is no
        comma). The text between the first and second comma becomes
        ``address_2``. Both are trimmed and cut to 50 characters; text after
        a second comma is not stored in either column.
      * ``country_concept_id`` is 0 when the country could not be mapped.
      * ``state``, ``county``, ``latitude`` and ``longitude`` are always NULL.
    """
    addresses = spark.table("4_prod.bronze.map_address")
    country_concepts = dlt.read("country_concepts")

    joined_data = addresses.join(
        country_concepts,
        lower(addresses.country_cd) == col("country_name_lower"),
        "left",
    )

    # Keep one row per ADDRESS_ID. NULLs are ordered last so that, when the
    # source holds several rows for one address, a row whose country mapped
    # wins over an unmapped one (Spark's default ascending order puts NULLs
    # first). Among mapped rows the lowest concept id wins.
    window_spec = Window.partitionBy("ADDRESS_ID").orderBy(
        col("country_concept_id").asc_nulls_last()
    )
    joined_data_with_rn = joined_data.withColumn(
        "row_number", F.row_number().over(window_spec)
    )

    street = col("full_street_address")
    has_comma = street.contains(",")
    street_parts = split(street, ",")

    return (joined_data_with_rn
        .filter(col("row_number") == 1)
        .select(
            col("ADDRESS_ID").cast("bigint").alias("location_id"),

            # Trim before truncating so the space that usually follows a
            # comma is not stored and does not consume the 50-char budget.
            substring(
                trim(when(has_comma, street_parts.getItem(0)).otherwise(street)),
                1, 50
            ).cast("string").alias("address_1"),

            when(has_comma, substring(trim(street_parts.getItem(1)), 1, 50))
            .otherwise(lit(None))
            .cast("string").alias("address_2"),

            substring(col("CITY"), 1, 50).cast("string").alias("city"),
            lit(None).cast("string").alias("state"),
            substring(col("masked_zipcode"), 1, 9).cast("string").alias("zip"),
            lit(None).cast("string").alias("county"),
            substring(street, 1, 50).cast("string")
                .alias("location_source_value"),
            coalesce(col("country_concept_id"), lit(0)).cast("integer")
                .alias("country_concept_id"),
            # 80 characters matches the valid_country_source_length rule.
            substring(col("country_cd"), 1, 80).cast("string")
                .alias("country_source_value"),
            lit(None).cast("float").alias("latitude"),
            lit(None).cast("float").alias("longitude")
        ))

@dlt.table(
    name="qual_omop_location",
    comment="Quality metrics for the location table"
)
def qual_omop_location():
    """Tracks quality metrics for the location table.

    ``omop_location`` writes unmapped countries as ``country_concept_id = 0``
    via ``coalesce(..., 0)`` rather than NULL. A plain non-NULL count would
    therefore always equal the row count. A record only counts as having a
    country when its concept id is > 0.

    Produces one row with:
      * total_records - number of rows;
      * records_with_country - rows with ``country_concept_id > 0``;
      * records_missing_country - all other rows (NULL or 0);
      * unmapped_countries - rows with ``country_concept_id == 0``.
    Counts are 0 (not NULL) on an empty table.
    """
    location_data = dlt.read("omop_location")
    is_mapped = col("country_concept_id") > 0
    return (location_data
            .agg(
                F.count("*").alias("total_records"),
                F.count(when(is_mapped, 1)).alias("records_with_country"),
                (F.count("*") - F.count(when(is_mapped, 1)))
                    .alias("records_missing_country"),
                F.count(when(col("country_concept_id") == 0, 1))
                    .alias("unmapped_countries")
            ))

In [0]:
# OMOP CDM care_site table layout as (name, type, nullable, comment) tuples;
# the comment is attached as StructField metadata.
care_site_schema = StructType([
    StructField(field_name, field_type, is_nullable, metadata={"comment": field_comment})
    for field_name, field_type, is_nullable, field_comment in (
        ("care_site_id", LongType(), False,
         "A unique identifier for each Care Site"),
        ("care_site_name", StringType(), True,
         "The name of the care_site as it appears in the source data"),
        ("place_of_service_concept_id", IntegerType(), True,
         "A foreign key to the predefined Concept table for the place of service concept id"),
        ("location_id", LongType(), True,
         "A foreign key to the Location table, where the detailed address information is stored"),
        ("care_site_source_value", StringType(), True,
         "The identifier of the care_site as it appears in the source data"),
        ("place_of_service_source_value", StringType(), True,
         "The source code for the place of service as it appears in the source data"),
    )
])

# Row-level rules for omop_care_site. All of them are applied with
# dlt.expect_all_or_drop, so a row failing any rule is dropped.
care_site_rules = dict(
    # Primary key must be present (OMOP CDM requirement)
    valid_care_site_id="care_site_id IS NOT NULL",
    # Identifiers, when present, must be non-negative
    valid_concept_id_format="place_of_service_concept_id IS NULL OR place_of_service_concept_id >= 0",
    valid_location_id_format="location_id IS NULL OR location_id >= 0",
)

@dlt.table(
    name="omop_care_site",
    comment="OMOP CDM Care Site table - Contains a list of institutional (physical or organizational) units where healthcare delivery is practiced (offices, wards, hospitals, clinics, etc.)",
    schema=care_site_schema,
    table_properties={"quality": "gold"}
)
@dlt.expect_all_or_drop(care_site_rules)
def create_omop_care_site():
    """
    Map ``4_prod.bronze.map_care_site`` rows onto the OMOP care_site layout.

    ``care_site_cd`` serves both as the numeric ``care_site_id`` and, as text,
    as ``care_site_source_value``. Every rule in ``care_site_rules`` is
    enforced with drop semantics.
    """
    source = spark.table("4_prod.bronze.map_care_site")
    site_code = col("care_site_cd")

    output_columns = [
        site_code.cast("bigint").alias("care_site_id"),
        col("care_site_name").alias("care_site_name"),
        # No place-of-service mapping exists yet, so the concept is always 0.
        lit(0).cast("integer").alias("place_of_service_concept_id"),
        col("address_id").cast("bigint").alias("location_id"),
        site_code.cast("string").alias("care_site_source_value"),
        # The facility name stands in for the place-of-service source value.
        col("facility_name").alias("place_of_service_source_value"),
    ]
    return source.select(*output_columns)

@dlt.table(
    name="qual_omop_care_site",
    comment="Quality metrics for the care site table"
)
def qual_omop_care_site():
    """
    Summarise ``location_id`` completeness for ``omop_care_site`` in one row.

    Reports the total row count, how many rows have / lack a ``location_id``
    (``null_location_count`` repeats the "lacking" figure), and the fill rate
    as the mean of a 1/0 presence flag.
    """
    location = col("location_id")
    metrics = [
        count("*").alias("total_records"),
        count(location).alias("records_with_location"),
        (count("*") - count(location)).alias("records_missing_location"),
        count(when(location.isNull(), 1)).alias("null_location_count"),
        avg(when(location.isNotNull(), 1).otherwise(0)).alias("location_fill_rate"),
    ]
    return dlt.read("omop_care_site").agg(*metrics)

@dlt.table(
    name="valid_care_site_locations",
    comment="Validates location references in care site table"
)
@dlt.expect_or_fail("valid_location_references", "invalid_location_count = 0")
def valid_care_site_locations():
    """
    Guard the care_site -> location foreign key.

    Returns a one-row frame whose ``invalid_location_count`` is the number of
    ``omop_care_site`` rows with a non-NULL ``location_id`` absent from
    ``omop_location``. The ``valid_location_references`` expectation fails
    the update unless that count is 0.
    """
    locations = dlt.read("omop_location")
    # NULL location_ids are not references, so they are excluded up front.
    referencing = dlt.read("omop_care_site").filter(col("location_id").isNotNull())

    # left_anti keeps only the care sites for which no location row matches.
    orphans = referencing.join(
        locations,
        referencing.location_id == locations.location_id,
        "left_anti",
    )
    return orphans.agg(count("*").alias("invalid_location_count"))

In [0]:


# OMOP CDM provider table layout as (name, type, nullable, comment) tuples;
# the comment is attached as StructField metadata.
provider_schema = StructType([
    StructField(field_name, field_type, is_nullable, metadata={"comment": field_comment})
    for field_name, field_type, is_nullable, field_comment in (
        ("provider_id", LongType(), False,
         "A unique identifier for each Provider."),
        ("provider_name", StringType(), True,
         "A description of the Provider, typically the name of the physician or facility."),
        ("npi", StringType(), True,
         "The National Provider Identifier (NPI) of the provider."),
        ("dea", StringType(), True,
         "The Drug Enforcement Administration (DEA) number of the provider."),
        ("specialty_concept_id", IntegerType(), True,
         "A foreign key to a Standard Specialty Concept ID in the Standardized Vocabularies."),
        ("care_site_id", LongType(), True,
         "A foreign key to the main Care Site where the provider is practicing."),
        ("year_of_birth", IntegerType(), True,
         "The year of birth of the Provider."),
        ("gender_concept_id", IntegerType(), True,
         "The gender of the Provider."),
        ("provider_source_value", StringType(), True,
         "The identifier used for the Provider in the source data."),
        ("specialty_source_value", StringType(), True,
         "The source code for the Provider specialty as it appears in the source data."),
        ("specialty_source_concept_id", IntegerType(), True,
         "A foreign key to a Concept that refers to the code used in the source."),
        ("gender_source_value", StringType(), True,
         "The source value for the Provider gender."),
        ("gender_source_concept_id", IntegerType(), True,
         "A foreign key to a Concept that refers to the code used in the source."),
    )
])

# Hard rules for the provider tables: failing rows are dropped
# (dlt.expect_all_or_drop).
mandatory_provider_rules = dict(
    valid_provider_id="provider_id IS NOT NULL",
)

# Soft rules: violations are counted (dlt.expect_all) but rows are kept.
advisory_provider_rules = dict(
    valid_provider_name_length="provider_name IS NULL OR LENGTH(provider_name) <= 255",
    valid_npi_length="npi IS NULL OR LENGTH(npi) <= 20",
    valid_dea_length="dea IS NULL OR LENGTH(dea) <= 20",
    valid_specialty_concept="specialty_concept_id IS NULL OR specialty_concept_id >= 0",
    valid_gender_concept="gender_concept_id IS NULL OR gender_concept_id >= 0",
    valid_source_value_length="provider_source_value IS NULL OR LENGTH(provider_source_value) <= 50",
    valid_specialty_source_length="specialty_source_value IS NULL OR LENGTH(specialty_source_value) <= 50",
    valid_gender_source_length="gender_source_value IS NULL OR LENGTH(gender_source_value) <= 50",
)

@dlt.table(
    name="omop_provider_base",
    comment="Initial provider table before care site validation",
    schema=provider_schema,
    temporary=True
)
@dlt.expect_all_or_drop(mandatory_provider_rules)
@dlt.expect_all(advisory_provider_rules)
def create_provider_base():
    """
    Project ``4_prod.bronze.map_medical_personnel`` into the provider layout.

    ``PERSON_ID`` supplies both ``provider_id`` and ``provider_source_value``.
    Fields the source lacks (NPI, DEA, year of birth, gender) are NULL, and
    every concept id is fixed at 0. The care site is not checked against
    ``omop_care_site`` here.
    """
    personnel = spark.table("4_prod.bronze.map_medical_personnel")

    person_key = col("PERSON_ID")
    site_key = col("primary_care_site_cd").cast("bigint")
    null_text = lit(None).cast("string")
    zero_concept = lit(0).cast("integer")

    return personnel.select(
        person_key.cast("bigint").alias("provider_id"),
        # "<position_name> - <MEDSERVICE>"; concat_ws skips NULL parts.
        concat_ws(" - ", col("position_name"), col("MEDSERVICE")).alias("provider_name"),
        null_text.alias("npi"),
        null_text.alias("dea"),
        zero_concept.alias("specialty_concept_id"),
        # A care site code of 0 is stored as NULL.
        when(site_key == 0, lit(None)).otherwise(site_key).alias("care_site_id"),
        lit(None).cast("integer").alias("year_of_birth"),
        zero_concept.alias("gender_concept_id"),
        person_key.cast("string").alias("provider_source_value"),
        # First non-NULL of MEDSERVICE, SRVCATEGORY, SURGSPEC.
        coalesce(col("MEDSERVICE"), col("SRVCATEGORY"), col("SURGSPEC"))
            .alias("specialty_source_value"),
        zero_concept.alias("specialty_source_concept_id"),
        null_text.alias("gender_source_value"),
        zero_concept.alias("gender_source_concept_id"),
    )

@dlt.table(
    name="omop_provider",
    comment="OMOP CDM Provider table - Contains a list of uniquely identified healthcare providers",
    schema=provider_schema,
    table_properties={"quality": "gold"}
)
def create_omop_provider():
    """
    Creates the final provider table with validated care site references.

    Every ``omop_provider_base`` row is kept. Its ``care_site_id`` is
    retained only when that id exists in ``omop_care_site``; otherwise it is
    set to NULL.

    The output keeps the base table's column order, so ``care_site_id`` stays
    in the position declared by ``provider_schema``. It is not moved to the
    end by a drop-and-rename.
    """
    providers = dlt.read("omop_provider_base")
    # Distinct ids so the left join below cannot multiply provider rows.
    care_sites = (dlt.read("omop_care_site")
        .select("care_site_id")
        .distinct())

    validated_care_site = (
        when(care_sites.care_site_id.isNotNull(), providers.care_site_id)
        .otherwise(lit(None))
        .alias("care_site_id")
    )
    # Replace care_site_id in place; all other columns pass through unchanged.
    output_columns = [
        validated_care_site if name == "care_site_id" else providers[name]
        for name in providers.columns
    ]

    return (providers
        .join(care_sites, providers.care_site_id == care_sites.care_site_id, "left")
        .select(*output_columns))

@dlt.table(
    name="qual_omop_provider", 
    comment="Quality metrics for the provider table"
)
def qual_omop_provider():
    """
    One-row summary of ``omop_provider``: totals, care-site coverage and how
    many providers still carry the placeholder 0 specialty / gender concept.
    """
    total = count("*")
    with_site = count("care_site_id")

    return dlt.read("omop_provider").agg(
        total.alias("total_providers"),
        with_site.alias("providers_with_care_site"),
        count(when(col("specialty_concept_id") == 0, 1)).alias("unmapped_specialties"),
        count(when(col("gender_concept_id") == 0, 1)).alias("unmapped_genders"),
        (total - with_site).alias("providers_without_care_site"),
    )

@dlt.view(
    name="qual_omop_provider_duplicates",
    comment="Identifies potential duplicate provider records"
)
def qual_omop_provider_duplicates():
    """
    List every ``provider_source_value`` that occurs on more than one
    ``omop_provider`` row, together with the distinct ids and names involved.
    """
    per_source_value = dlt.read("omop_provider").groupBy("provider_source_value")
    summary = per_source_value.agg(
        count("*").alias("record_count"),
        collect_set("provider_id").alias("provider_ids"),
        collect_set("provider_name").alias("provider_names"),
    )
    return summary.filter(col("record_count") > 1)

In [0]:


# OMOP CDM person table layout as (name, type, nullable, comment) tuples;
# the comment is attached as StructField metadata.
person_schema = StructType([
    StructField(field_name, field_type, is_nullable, metadata={"comment": field_comment})
    for field_name, field_type, is_nullable, field_comment in (
        ("person_id", LongType(), False,
         "A unique identifier for each person."),
        ("gender_concept_id", IntegerType(), True,
         "A foreign key that refers to a standard concept identifier in the Vocabulary for the gender of the person."),
        ("year_of_birth", IntegerType(), True,
         "The year of birth of the person."),
        ("month_of_birth", IntegerType(), True,
         "The month of birth of the person."),
        ("day_of_birth", IntegerType(), True,
         "The day of birth of the person."),
        ("birth_datetime", TimestampType(), True,
         "The date and time of birth of the person."),
        ("race_concept_id", IntegerType(), True,
         "A foreign key that refers to a standard concept identifier in the Vocabulary for the race of the person."),
        ("ethnicity_concept_id", IntegerType(), True,
         "A foreign key that refers to the standard concept identifier in the Vocabulary for the ethnicity of the person."),
        ("location_id", LongType(), True,
         "A foreign key to the location table that indicates where the person is located."),
        ("provider_id", LongType(), True,
         "A foreign key to the provider table that indicates the primary care provider of the person."),
        ("care_site_id", LongType(), True,
         "A foreign key to the care site table that indicates the primary care site of the person."),
        ("person_source_value", StringType(), True,
         "The source code for the person as it appears in the source data."),
        ("gender_source_value", StringType(), True,
         "The source code for the gender of the person as it appears in the source data."),
        ("gender_source_concept_id", IntegerType(), True,
         "A foreign key to a concept that refers to the code used in the source."),
        ("race_source_value", StringType(), True,
         "The source code for the race of the person as it appears in the source data."),
        ("race_source_concept_id", IntegerType(), True,
         "A foreign key to a concept that refers to the code used in the source."),
        ("ethnicity_source_value", StringType(), True,
         "The source code for the ethnicity of the person as it appears in the source data."),
        ("ethnicity_source_concept_id", IntegerType(), True,
         "A foreign key to a concept that refers to the code used in the source."),
    )
])

# Hard rules for omop_person: failing rows are dropped
# (dlt.expect_all_or_drop).
mandatory_person_rules = dict(
    valid_person_id="person_id IS NOT NULL",
    valid_gender_concept="gender_concept_id IS NOT NULL",
    valid_year_of_birth="year_of_birth IS NOT NULL",
    valid_race_concept="race_concept_id IS NOT NULL",
    valid_ethnicity_concept="ethnicity_concept_id IS NOT NULL",
)

# Soft rules: violations are counted (dlt.expect_all) but rows are kept.
advisory_person_rules = dict(
    reasonable_birth_year="year_of_birth >= 1901",
    valid_birth_month="month_of_birth IS NULL OR (month_of_birth >= 1 AND month_of_birth <= 12)",
    valid_birth_day="day_of_birth IS NULL OR (day_of_birth >= 1 AND day_of_birth <= 31)",
    valid_gender_value="gender_concept_id > 0",
    valid_race_value="race_concept_id >= 0",
    valid_ethnicity_value="ethnicity_concept_id >= 0",
    valid_source_values="""
        person_source_value IS NULL OR LENGTH(person_source_value) <= 50
    """,
)

@dlt.table(
    name="valid_encounters",
    comment="Valid person IDs from encounters",
    temporary=True
)
def get_valid_encounters():
    """Return each ``person_id`` found in ``4_prod.bronze.map_encounter`` exactly once."""
    encounters = spark.table("4_prod.bronze.map_encounter")
    return encounters.select("person_id").dropDuplicates()


@dlt.table(
    name="gender_concept_maps",
    comment="Gender concept mappings",
    temporary=True
)
def get_gender_maps():
    """
    Source-value -> OMOP concept pairs for ``gender_concept_id`` taken from
    ``3_lookup.omop.barts_new_maps``.

    NOTE(review): no de-duplication is done here. If a SourceValue appears
    more than once, joining persons to this table yields one row per match.
    Confirm the lookup is unique per SourceValue.
    """
    all_maps = spark.table("3_lookup.omop.barts_new_maps")
    gender_rows = all_maps.where(col("OMOPField") == "gender_concept_id")
    return gender_rows.select(
        col("SourceValue").alias("gender_source"),
        col("OmopConceptId").alias("gender_omop_id"),
    )

@dlt.table(
    name="race_concept_maps",
    comment="Race concept mappings",
    temporary=True
)
def get_race_maps():
    """
    Source-value -> OMOP concept pairs for ``race_concept_id`` taken from
    ``3_lookup.omop.barts_new_maps``.

    NOTE(review): as with the gender lookup, duplicate SourceValues would fan
    out person rows on join. Confirm uniqueness.
    """
    all_maps = spark.table("3_lookup.omop.barts_new_maps")
    race_rows = all_maps.where(col("OMOPField") == "race_concept_id")
    return race_rows.select(
        col("SourceValue").alias("race_source"),
        col("OmopConceptId").alias("race_omop_id"),
    )

@dlt.table(
    name="omop_person",
    comment="OMOP CDM Person table - Contains records that uniquely identify each person in the database",
    schema=person_schema,
    table_properties={"quality": "gold"}
)
@dlt.expect_all_or_drop(mandatory_person_rules)
@dlt.expect_all(advisory_person_rules)
def create_omop_person():
    """
    Build OMOP person rows from ``4_prod.bronze.map_person``.

    A source row is used only when all of the following hold:
      * ``birth_year`` is present and >= 1901;
      * ``gender_cd`` is present and not "0";
      * the person appears in ``valid_encounters``.

    ``gender_cd`` is translated through ``gender_concept_maps``.
    ``ethnicity_cd`` is translated through ``race_concept_maps`` into the
    race columns. Unmapped codes become concept 0, and ethnicity is always 0.

    NOTE(review): the lookups are left-joined without de-duplication, so a
    duplicated source value in either map would duplicate the person.
    """
    has_birth_year = col("birth_year").isNotNull()
    has_gender = col("gender_cd").isNotNull() & (col("gender_cd") != "0")
    born_1901_or_later = col("birth_year") >= 1901

    source = spark.table("4_prod.bronze.map_person")
    person_df = source.filter(has_birth_year & has_gender & born_1901_or_later)

    encountered = dlt.read("valid_encounters")
    gender_maps = dlt.read("gender_concept_maps")
    race_maps = dlt.read("race_concept_maps")

    joined = (person_df
        .join(encountered, "person_id", "inner")
        .join(gender_maps, person_df.gender_cd == gender_maps.gender_source, "left")
        .join(race_maps, person_df.ethnicity_cd == race_maps.race_source, "left"))

    zero_concept = lit(0).cast("integer")

    return joined.select(
        col("person_id").cast("bigint").alias("person_id"),
        coalesce(col("gender_omop_id"), lit(0)).cast("integer").alias("gender_concept_id"),
        col("birth_year").cast("integer").alias("year_of_birth"),
        lit(None).cast("integer").alias("month_of_birth"),
        lit(None).cast("integer").alias("day_of_birth"),
        lit(None).cast("timestamp").alias("birth_datetime"),
        coalesce(col("race_omop_id"), lit(0)).cast("integer").alias("race_concept_id"),
        zero_concept.alias("ethnicity_concept_id"),
        col("address_id").cast("bigint").alias("location_id"),
        lit(None).cast("bigint").alias("provider_id"),
        lit(None).cast("bigint").alias("care_site_id"),
        col("person_id").cast("string").alias("person_source_value"),
        col("gender_cd").cast("string").alias("gender_source_value"),
        zero_concept.alias("gender_source_concept_id"),
        col("ethnicity_cd").cast("string").alias("race_source_value"),
        zero_concept.alias("race_source_concept_id"),
        lit(None).cast("string").alias("ethnicity_source_value"),
        zero_concept.alias("ethnicity_source_concept_id"),
    )

@dlt.table(
    name="valid_omop_person_location",
    comment="Validates location references in person table"
)
@dlt.expect_or_fail("valid_location_references", "invalid_location_count = 0")
def valid_omop_person_location():
    """
    Guard the person -> location foreign key.

    Returns a one-row frame whose ``invalid_location_count`` is the number of
    ``omop_person`` rows with a non-NULL ``location_id`` absent from
    ``omop_location``. The ``valid_location_references`` expectation fails
    the update unless that count is 0.
    """
    locations = dlt.read("omop_location")
    # NULL location_ids are not references, so they are excluded up front.
    referencing = dlt.read("omop_person").filter(col("location_id").isNotNull())

    # left_anti keeps only the persons for which no location row matches.
    orphans = referencing.join(
        locations,
        referencing.location_id == locations.location_id,
        "left_anti",
    )
    return orphans.agg(count("*").alias("invalid_location_count"))

@dlt.table(
    name="qual_omop_person",
    comment="Quality metrics for person table"
)
def qual_omop_person():
    """Tracks quality metrics for the person table.

    Produces one row with:
      * total_persons - row count;
      * persons_with_location - rows with a non-NULL ``location_id``;
      * gender_mapping_rate / race_mapping_rate - fraction of rows whose
        concept id is > 0;
      * invalid_birth_years - rows whose ``year_of_birth`` is before 1901.

    The 1901 cut-off matches the ``reasonable_birth_year`` advisory rule and
    the ``birth_year >= 1901`` source filter in ``omop_person``, so persons
    born in 1900 are counted as invalid here too.
    """
    earliest_valid_birth_year = 1901
    person_data = dlt.read("omop_person")

    return (person_data.agg(
        count("*").alias("total_persons"),
        count("location_id").alias("persons_with_location"),
        avg(when(col("gender_concept_id") > 0, 1).otherwise(0))
            .alias("gender_mapping_rate"),
        avg(when(col("race_concept_id") > 0, 1).otherwise(0))
            .alias("race_mapping_rate"),
        count(when(col("year_of_birth") < earliest_valid_birth_year, 1))
            .alias("invalid_birth_years")
    ))

@dlt.table(
    name="summ_omop_person_demographics",
    comment="Demographic summary statistics"
)
def summ_omop_person_demographics():
    """Calculates demographic summary statistics.

    Groups ``omop_person`` by (gender_concept_id, race_concept_id) and reports
    the person count plus the mean, earliest and latest year of birth for
    each group.

    Aggregates are referenced through the ``F`` namespace. The bare names
    ``min``/``max`` only resolve to Spark functions because the wildcard
    import shadows Python's builtins; being explicit avoids depending on that.
    """
    person_data = dlt.read("omop_person")

    return (person_data.groupBy("gender_concept_id", "race_concept_id")
            .agg(
                F.count("*").alias("person_count"),
                F.avg("year_of_birth").alias("avg_birth_year"),
                F.min("year_of_birth").alias("min_birth_year"),
                F.max("year_of_birth").alias("max_birth_year")
            ))

In [0]:


# OMOP CDM visit_occurrence table layout as (name, type, nullable, comment)
# tuples; the comment is attached as StructField metadata.
visit_schema = StructType([
    StructField(field_name, field_type, is_nullable, metadata={"comment": field_comment})
    for field_name, field_type, is_nullable, field_comment in (
        ("visit_occurrence_id", LongType(), False,
         "A unique identifier for each Person's visit or encounter at a healthcare provider."),
        ("person_id", LongType(), False,
         "A foreign key identifier to the Person who is having the visit."),
        ("visit_concept_id", IntegerType(), False,
         "A foreign key that refers to a visit concept identifier in the Standardized Vocabularies."),
        ("visit_start_date", DateType(), True,
         "The start date of the visit."),
        ("visit_start_datetime", TimestampType(), True,
         "The start date and time of the visit."),
        ("visit_end_date", DateType(), True,
         "The end date of the visit."),
        ("visit_end_datetime", TimestampType(), True,
         "The end date and time of the visit."),
        ("visit_type_concept_id", IntegerType(), True,
         "A foreign key to the predefined concept identifier in the Standardized Vocabularies reflecting the type of source data from which the visit record is derived."),
        ("provider_id", LongType(), True,
         "A foreign key to the provider in the provider table who was associated with the visit."),
        ("care_site_id", LongType(), True,
         "A foreign key to the care site in the care site table that was visited."),
        ("visit_source_value", StringType(), True,
         "The source code for the visit as it appears in the source data."),
        ("visit_source_concept_id", IntegerType(), True,
         "A foreign key to a concept that refers to the code used in the source."),
        ("admitted_from_concept_id", IntegerType(), True,
         "A foreign key to the predefined concept in the Place of Service vocabulary indicating where the person was admitted from."),
        ("admitted_from_source_value", StringType(), True,
         "The source code for the admitted from concept as it appears in the source data."),
        ("discharged_to_concept_id", IntegerType(), True,
         "A foreign key to the predefined concept in the Place of Service vocabulary indicating where the person was discharged to."),
        ("discharged_to_source_value", StringType(), True,
         "The source code for the discharged to concept as it appears in the source data."),
        ("preceding_visit_occurrence_id", LongType(), True,
         "A foreign key to the visit occurrence that immediately preceded this visit."),
    )
])

# Hard rules for visit_occurrence: failing rows are dropped when applied
# with a drop-style expectation.
mandatory_visit_rules = dict(
    valid_visit_id="visit_occurrence_id IS NOT NULL",
    valid_person_id="person_id IS NOT NULL",
    valid_visit_concept="visit_concept_id IS NOT NULL",
    valid_start_date="visit_start_date IS NOT NULL",
    valid_end_date="visit_end_date IS NOT NULL",
    valid_type_concept="visit_type_concept_id IS NOT NULL",
    valid_dates="visit_start_date <= visit_end_date",
)

# Soft rules: intended to be tracked without dropping rows.
advisory_visit_rules = dict(
    valid_visit_concept_value="visit_concept_id > 0",
    valid_type_concept_value="visit_type_concept_id > 0",
    valid_source_values="""
        visit_source_value IS NULL OR LENGTH(visit_source_value) <= 50
    """,
)

@dlt.table(
    name="valid_persons_with_birth",
    comment="Valid person IDs and birth dates",
    temporary=True
)
def get_valid_persons_with_birth():
    """Gets valid person IDs with an approximate birth date for date validation.

    ``birth_date`` is January 1st of ``year_of_birth`` from ``omop_person``,
    typed as a DATE. Using a DATE (instead of the formatted STRING produced
    previously) keeps the visit-start comparison and the CASE WHEN in
    ``base_visit_occurrence`` date/timestamp-typed rather than relying on
    implicit string promotion. A NULL ``year_of_birth`` gives a NULL
    ``birth_date``.
    """
    birth_date = concat(
        col("year_of_birth").cast("string"),
        lit("-01-01"),
    ).cast("date")
    return (dlt.read("omop_person")
            .select("person_id", birth_date.alias("birth_date")))

@dlt.table(
    name="death_dates",
    comment="Death dates for visit validation",
    temporary=True
)
def get_death_dates():
    """Return ``person_id`` and a single ``death_date`` per map_death row.

    ``death_date`` is ``DECEASED_DT_TM`` when present, otherwise
    ``CALC_DEATH_DATE``.

    NOTE(review): rows are not de-duplicated by person_id. If map_death can
    hold several rows for one person, the left join in base_visit_occurrence
    repeats that person's visits — confirm the source is one row per person.
    """
    death_date = coalesce(col("DECEASED_DT_TM"), col("CALC_DEATH_DATE"))
    deaths = spark.table("4_prod.bronze.map_death")
    return deaths.select("person_id", death_date.alias("death_date"))

@dlt.table(
    name="visit_type_mapping",
    comment="Visit type concept mappings",
    temporary=True
)
def get_visit_type_mapping():
    """Return a lookup of encounter type description -> visit concept id.

    Several descriptions share a concept id (e.g. every outpatient variant maps
    to 9202). Descriptions missing from this table end up with
    visit_concept_id 0 in base_visit_occurrence (left join + coalesce).
    """
    concept_by_encounter_type = {
        "Inpatient Pre-Admission": 9201,
        "Outpatient Pre-registration": 9202,
        "Day Case": 8883,
        "Day Case Waiting List": 8883,
        "Outpatient Referral": 9202,
        "Regular Day Admission": 581476,
        "Outpatient Services": 9202,
        "Results Only": 32036,
        "Research": 38004259,
        "Emergency Department": 9203,
        "Outpatient": 9202,
        "Inpatient Waiting List": 9201,
        "Maternity": 8650,
        "Day Care": 38004210,
        "Inpatient": 9201,
        "Newborn": 581384,
        "Regular Night Admission": 9201,
    }
    rows = list(concept_by_encounter_type.items())
    return spark.createDataFrame(rows, ["encntr_type_desc", "concept_id"])

@dlt.table(
    name="base_visit_occurrence",
    comment="Initial visit records before reference validation",
    temporary=True
)
def create_base_visit_occurrence():
    """Creates initial visit occurrence records with date validations.

    Encounters come from ``4_prod.bronze.map_encounter`` and are kept only when
    their person is in ``valid_persons_with_birth`` (inner join). Dates are
    derived in this order:

      1. start = ``arrive_dt_tm``, raised to ``birth_date`` if it is earlier.
      2. end = ``depart_dt_tm``, lowered to ``death_date`` when one exists.
      3. start is clamped into [1901-01-01, today]; end is clamped into the same
         range and then raised to the *clamped* start if it still precedes it,
         so a non-NULL end is never before the start and never after today.

    NULL start/end values stay NULL (they are then dropped by the
    ``valid_start_date`` / ``valid_end_date`` rules of the final visit table).
    """
    # Source and reference data
    encounters = spark.table("4_prod.bronze.map_encounter")
    valid_persons = dlt.read("valid_persons_with_birth")
    death_dates = dlt.read("death_dates")
    visit_types = dlt.read("visit_type_mapping")

    # Start never precedes the (approximate) birth date.
    raw_start = (when(col("e.arrive_dt_tm") < col("p.birth_date"), col("p.birth_date"))
                 .otherwise(col("e.arrive_dt_tm")))
    # End never follows the recorded death date, when there is one.
    raw_end = (when(col("d.death_date").isNotNull(),
                    least(col("e.depart_dt_tm"), col("d.death_date")))
               .otherwise(col("e.depart_dt_tm")))

    visit_df = (encounters.alias("e")
        .join(valid_persons.alias("p"), "person_id")
        .join(death_dates.alias("d"), "person_id", "left")
        .join(visit_types.alias("v"),
              col("e.encntr_type_desc") == col("v.encntr_type_desc"),
              "left")
        .select(
            col("e.encntr_id").cast("bigint").alias("visit_occurrence_id"),
            col("e.person_id").cast("bigint").alias("person_id"),
            # Unmapped encounter types get concept 0 rather than NULL.
            coalesce(col("v.concept_id"), lit(0)).cast("integer").alias("visit_concept_id"),
            raw_start.cast("date").alias("visit_start_date"),
            raw_start.cast("timestamp").alias("visit_start_datetime"),
            raw_end.cast("date").alias("visit_end_date"),
            raw_end.cast("timestamp").alias("visit_end_datetime"),
            lit(32817).cast("integer").alias("visit_type_concept_id"),
            lit(None).cast("integer").alias("admitted_from_concept_id"),
            lit(None).cast("string").alias("admitted_from_source_value"),
            lit(None).cast("integer").alias("discharged_to_concept_id"),
            lit(None).cast("string").alias("discharged_to_source_value"),
            # A personnel id / nurse unit code of 0 means "not recorded".
            when(col("e.reg_prsnl_id") == 0, None)
                .otherwise(col("e.reg_prsnl_id")).cast("bigint").alias("provider_id"),
            when(col("e.loc_nurse_unit_cd") == 0, None)
                .otherwise(col("e.loc_nurse_unit_cd")).cast("bigint").alias("care_site_id"),
            col("e.encntr_type_desc").alias("visit_source_value"),
            lit(0).cast("integer").alias("visit_source_concept_id")
        ))

    min_date = lit("1901-01-01").cast("date")
    min_ts = to_timestamp(min_date)

    def _clamp(value, lower, upper):
        """Limit column `value` to [lower, upper]; NULL stays NULL."""
        return when(value < lower, lower).when(value > upper, upper).otherwise(value)

    start_date = _clamp(col("visit_start_date"), min_date, current_date())
    start_ts = _clamp(col("visit_start_datetime"), min_ts, current_timestamp())
    end_date = _clamp(col("visit_end_date"), min_date, current_date())
    end_ts = _clamp(col("visit_end_datetime"), min_ts, current_timestamp())

    return (visit_df
        .select(
            "*",
            start_date.alias("valid_start_date"),
            start_ts.alias("valid_start_datetime"),
            # Compare with the *clamped* start: comparing with the raw start
            # could push the end past today or leave it before the start.
            when(end_date < start_date, start_date)
                .otherwise(end_date).alias("valid_end_date"),
            when(end_ts < start_ts, start_ts)
                .otherwise(end_ts).alias("valid_end_datetime")
        )
        .drop("visit_start_date", "visit_start_datetime",
              "visit_end_date", "visit_end_datetime")
        .withColumnRenamed("valid_start_date", "visit_start_date")
        .withColumnRenamed("valid_start_datetime", "visit_start_datetime")
        .withColumnRenamed("valid_end_date", "visit_end_date")
        .withColumnRenamed("valid_end_datetime", "visit_end_datetime"))

@dlt.table(
    name="omop_visit_occurrence",
    comment="OMOP CDM Visit Occurrence table - Contains records of Events where Persons engage with the healthcare system for a duration of time",
    schema=visit_schema,
    table_properties={"quality": "gold"}
)
@dlt.expect_all_or_drop(mandatory_visit_rules)
@dlt.expect_all(advisory_visit_rules)
def create_omop_visit_occurrence():
    """
    Creates the final visit occurrence table with all validations.

    * provider_id / care_site_id values without a matching row in
      omop_provider / omop_care_site are set to NULL; the visit itself is kept.
    * preceding_visit_occurrence_id is the previous visit of the same person
      ordered by visit_start_datetime, with visit_occurrence_id as tie-breaker
      so visits sharing a start time get a deterministic predecessor.
    """
    base_visits = dlt.read("base_visit_occurrence")

    # Distinct reference keys, so the left joins below cannot duplicate visits.
    valid_providers = dlt.read("omop_provider").select("provider_id").distinct()
    valid_care_sites = dlt.read("omop_care_site").select("care_site_id").distinct()

    def _null_unknown_refs(df, refs, key):
        """Return `df` with `key` set to NULL where it has no match in `refs`."""
        marker = f"_{key}_is_known"
        return (df
            .join(refs.withColumn(marker, lit(True)), key, "left")
            .withColumn(key, when(col(marker), col(key)).cast("bigint"))
            .drop(marker))

    visits = _null_unknown_refs(base_visits, valid_providers, "provider_id")
    visits = _null_unknown_refs(visits, valid_care_sites, "care_site_id")

    previous_visit = (Window.partitionBy("person_id")
                      .orderBy("visit_start_datetime", "visit_occurrence_id"))

    return visits.withColumn(
        "preceding_visit_occurrence_id",
        lag("visit_occurrence_id", 1).over(previous_visit))

@dlt.table(
    name="qual_omop_visit",
    comment="Quality metrics for visit occurrence table"
)
def qual_omop_visit():
    """Single-row completeness summary of omop_visit_occurrence."""
    visits = dlt.read("omop_visit_occurrence")

    is_unmapped = col("visit_concept_id") == 0
    has_predecessor = col("preceding_visit_occurrence_id").isNotNull()

    metrics = [
        count("*").alias("total_visits"),
        count("provider_id").alias("visits_with_provider"),
        count("care_site_id").alias("visits_with_care_site"),
        count(when(is_unmapped, 1)).alias("unmapped_visit_types"),
        count(when(has_predecessor, 1)).alias("visits_with_preceding"),
    ]
    return visits.agg(*metrics)

@dlt.table(
    name="summ_omop_visit_length_metrics",
    comment="Visit length of stay analysis"
)
def analyze_visit_lengths():
    """Length-of-stay statistics (whole days) per visit concept and source value."""
    stay_days = datediff(col("visit_end_date"), col("visit_start_date"))
    visits = (dlt.read("omop_visit_occurrence")
              .withColumn("length_of_stay_days", stay_days))

    # sql_min / sql_max are the explicit Spark aggregate aliases imported at the
    # top of the notebook (the bare min/max names are shadowed by a * import).
    stats = [
        count("*").alias("visit_count"),
        avg("length_of_stay_days").alias("avg_length_of_stay"),
        sql_min("length_of_stay_days").alias("min_length_of_stay"),
        sql_max("length_of_stay_days").alias("max_length_of_stay"),
        expr("percentile_approx(length_of_stay_days, 0.5)")
            .alias("median_length_of_stay"),
    ]
    return visits.groupBy("visit_concept_id", "visit_source_value").agg(*stats)

In [0]:


# Declared schema of omop_condition_occurrence, built from
# (column name, Spark type, nullable, column comment) specs.
condition_schema = StructType([
    StructField(column_name, column_type, is_nullable,
                metadata={"comment": description})
    for column_name, column_type, is_nullable, description in (
        ("condition_occurrence_id", LongType(), False,
         "A unique identifier for each condition occurrence event."),
        ("person_id", LongType(), False,
         "A foreign key identifier to the person who is experiencing the condition."),
        ("condition_concept_id", IntegerType(), False,
         "A foreign key that refers to a standard condition concept identifier in the Vocabulary."),
        ("condition_start_date", DateType(), True,
         "The date when the instance of the condition is recorded."),
        ("condition_start_datetime", TimestampType(), True,
         "The date and time when the instance of the condition is recorded."),
        ("condition_end_date", DateType(), True,
         "The date when the instance of the condition is considered to have ended."),
        ("condition_end_datetime", TimestampType(), True,
         "The date and time when the instance of the condition is considered to have ended."),
        ("condition_type_concept_id", IntegerType(), True,
         "A foreign key to the predefined concept identifier in the Standardized Vocabularies reflecting the source data from which the condition was recorded."),
        ("condition_status_concept_id", IntegerType(), True,
         "A foreign key to the predefined concept identifier in the Standardized Vocabularies reflecting the status of the condition."),
        ("stop_reason", StringType(), True,
         "The reason that the condition was no longer present."),
        ("provider_id", LongType(), True,
         "A foreign key to the provider who was responsible for determining the condition."),
        ("visit_occurrence_id", LongType(), True,
         "A foreign key to the visit in the VISIT_OCCURRENCE table during which the condition was determined."),
        ("visit_detail_id", LongType(), True,
         "A foreign key to the visit detail record during which the condition was determined."),
        ("condition_source_value", StringType(), True,
         "The source value for the condition as it appears in the source data."),
        ("condition_source_concept_id", LongType(), True,
         "A foreign key to a condition concept that refers to the code used in the source."),
        ("condition_status_source_value", StringType(), True,
         "The source value for the condition status as it appears in the source data."),
    )
])

# Hard requirements for omop_condition_occurrence: rows failing any expression
# are removed by @dlt.expect_all_or_drop on that table.
mandatory_condition_rules = dict((
    ("valid_condition_id", "condition_occurrence_id IS NOT NULL"),
    ("valid_person", "person_id IS NOT NULL"),
    ("valid_concept", "condition_concept_id IS NOT NULL"),
    ("valid_start_date", "condition_start_date IS NOT NULL"),
    ("valid_type_concept", "condition_type_concept_id IS NOT NULL"),
))

# Soft checks for omop_condition_occurrence: violations are recorded by
# @dlt.expect_all on that table, rows are kept.
advisory_condition_rules = dict(
    valid_concept_value="condition_concept_id > 0",
    valid_type_concept_value="condition_type_concept_id > 0",
    valid_status_concept=(
        "condition_status_concept_id IS NULL OR condition_status_concept_id >= 0"
    ),
    valid_date_range=(
        "condition_end_date IS NULL OR condition_end_date >= condition_start_date"
    ),
    valid_stop_reason="stop_reason IS NULL OR LENGTH(stop_reason) <= 20",
)

@dlt.table(
    name="valid_condition_concepts",
    comment="Valid condition concepts from OMOP vocabulary",
    temporary=True
)
@dlt.expect_or_fail("valid_domain", "domain_id = 'Condition'")
def get_valid_condition_concepts():
    """Gets valid condition concepts from the OMOP vocabulary.

    Returns the distinct ``concept_id`` values of ``3_lookup.omop.concept`` rows
    in the Condition domain whose ``invalid_reason`` is NULL.

    ``domain_id`` is kept in the output because the ``valid_domain`` expectation
    on this table is evaluated against the returned rows; selecting only
    ``concept_id`` leaves that expression with no column to resolve. Consumers
    join on ``concept_id`` and select their own columns, so the extra column is
    not propagated.
    """
    return (spark.table("3_lookup.omop.concept")
            .filter(col("invalid_reason").isNull() &
                    (col("domain_id") == "Condition"))
            .select("concept_id", "domain_id")
            .distinct())

@dlt.table(
    name="condition_diagnosis_base",
    comment="Base condition records from diagnosis data",
    temporary=True
)
def create_condition_diagnosis():
    """Map ``4_prod.bronze.map_diagnosis`` rows onto condition_occurrence columns.

    Only rows whose person exists in ``omop_person`` and whose
    ``omop_concept_id`` is in ``valid_condition_concepts`` survive (inner joins).
    The id, start date and end columns are added later by
    ``omop_condition_occurrence``. Column order matters: it must match
    ``condition_problem_base`` for the union downstream.
    """
    source_diagnoses = spark.table("4_prod.bronze.map_diagnosis")
    known_people = dlt.read("omop_person").select("person_id").distinct()
    condition_concepts = dlt.read("valid_condition_concepts")

    joined = (source_diagnoses
              .join(known_people, "person_id", "inner")
              .join(condition_concepts,
                    col("omop_concept_id") == col("concept_id"),
                    "inner"))

    # Status 2 for "Confirmed", otherwise 0 (2 is also what qual_omop_condition
    # counts as confirmed).
    status_expr = when(col("confirmation_status_desc") == "Confirmed", 2).otherwise(0)
    # A personnel id of 0 means "no provider recorded".
    provider_expr = when(col("diag_prsnl_id") == 0, None).otherwise(col("diag_prsnl_id"))
    # NOTE(review): if source_identifier holds ICD-10 codes (e.g. "J45.9") rather
    # than numeric concept ids, the bigint cast below yields NULL — TODO confirm.
    icd10_identifier = (when(col("source_vocabulary_cd") == "ICD10",
                             col("source_identifier"))
                        .otherwise(lit(None)))

    return joined.select(
        col("person_id").cast("bigint"),
        col("omop_concept_id").cast("integer").alias("condition_concept_id"),
        col("diag_dt_tm").cast("timestamp").alias("condition_start_datetime"),
        lit(32817).cast("integer").alias("condition_type_concept_id"),
        status_expr.cast("integer").alias("condition_status_concept_id"),
        lit(None).cast("string").alias("stop_reason"),
        provider_expr.cast("bigint").alias("provider_id"),
        col("encntr_id").cast("bigint").alias("visit_occurrence_id"),
        lit(None).cast("bigint").alias("visit_detail_id"),
        substring(col("source_string"), 1, 50).alias("condition_source_value"),
        icd10_identifier.cast("bigint").alias("condition_source_concept_id"),
        substring(col("confirmation_status_desc"), 1, 50)
            .alias("condition_status_source_value"),
    )

@dlt.table(
    name="condition_problem_base",
    comment="Base condition records from problem data",
    temporary=True
)
def create_condition_problem():
    """Map ``4_prod.bronze.map_problem`` rows onto condition_occurrence columns.

    Keeps problems tagged with the Condition domain that carry both an OMOP
    concept id and an encounter (``CALC_ENCNTR``), whose person exists in
    ``omop_person`` and whose concept is in ``valid_condition_concepts``.
    Column order must match ``condition_diagnosis_base`` for the downstream union.
    """
    source_problems = spark.table("4_prod.bronze.map_problem")
    known_people = dlt.read("omop_person").select("person_id").distinct()
    condition_concepts = dlt.read("valid_condition_concepts")

    is_usable = ((col("OMOP_CONCEPT_DOMAIN") == "Condition") &
                 (col("OMOP_CONCEPT_ID").isNotNull()) &
                 (col("CALC_ENCNTR").isNotNull()))
    # Personnel ids 0 and 1 are treated as "no provider".
    provider_expr = (when(col("ACTIVE_STATUS_PRSNL_ID").isin([0, 1]), None)
                     .otherwise(col("ACTIVE_STATUS_PRSNL_ID")))

    joined = (source_problems
              .filter(is_usable)
              .join(known_people, "person_id", "inner")
              .join(condition_concepts,
                    col("OMOP_CONCEPT_ID") == col("concept_id"),
                    "inner"))

    return joined.select(
        col("person_id").cast("bigint"),
        col("OMOP_CONCEPT_ID").cast("integer").alias("condition_concept_id"),
        col("CALC_DT_TM").cast("timestamp").alias("condition_start_datetime"),
        lit(32817).cast("integer").alias("condition_type_concept_id"),
        lit(0).cast("integer").alias("condition_status_concept_id"),
        lit(None).cast("string").alias("stop_reason"),
        provider_expr.cast("bigint").alias("provider_id"),
        col("CALC_ENCNTR").cast("bigint").alias("visit_occurrence_id"),
        lit(None).cast("bigint").alias("visit_detail_id"),
        lit(None).cast("string").alias("condition_source_value"),
        lit(None).cast("bigint").alias("condition_source_concept_id"),
        lit(None).cast("string").alias("condition_status_source_value"),
    )

@dlt.table(
    name="omop_condition_occurrence",
    comment="OMOP CDM Condition Occurrence table - Contains records of Events suggesting the presence of a disease or medical condition",
    schema=condition_schema,
    table_properties={"quality": "gold"}
)
@dlt.expect_all_or_drop(mandatory_condition_rules)
@dlt.expect_all(advisory_condition_rules)
def create_omop_condition_occurrence():
    """
    Creates the final condition occurrence table with proper validation
    and referential integrity checks.

    Steps:
      1. Union diagnosis- and problem-derived records *by column name* and drop
         exact duplicates.
      2. Assign condition_occurrence_id with row_number() over an ordering of
         every column. After distinct() this is a total order, so ids do not
         depend on how Spark breaks ties between rows sharing person_id and
         condition_start_datetime.
      3. Clamp condition_start_datetime into the linked visit's
         [visit_start_datetime, visit_end_datetime] window and derive
         condition_start_date; end date/datetime are NULL.
      4. Set visit_occurrence_id / provider_id to NULL when they have no match
         in omop_visit_occurrence / omop_provider (the row is kept).
      5. Truncate source-value strings to 50 characters.
    """
    combined_conditions = (dlt.read("condition_diagnosis_base")
        .unionByName(dlt.read("condition_problem_base"))
        .distinct())

    id_order = Window.orderBy(
        "person_id", "condition_start_datetime", "condition_concept_id",
        "visit_occurrence_id", "provider_id", "condition_type_concept_id",
        "condition_status_concept_id", "condition_source_value",
        "condition_source_concept_id", "condition_status_source_value",
        "stop_reason", "visit_detail_id")
    conditions_with_id = combined_conditions.withColumn(
        "condition_occurrence_id", row_number().over(id_order).cast("bigint"))

    # Reference data; the marker columns tell matched rows from unmatched ones
    # after the left joins. valid_providers is distinct so it cannot fan out.
    valid_visits = (dlt.read("omop_visit_occurrence")
        .select("visit_occurrence_id", "visit_start_datetime", "visit_end_datetime")
        .withColumn("_visit_exists", lit(True)))
    valid_providers = (dlt.read("omop_provider")
        .select("provider_id").distinct()
        .withColumn("_provider_exists", lit(True)))

    has_visit = col("visit_occurrence_id").isNotNull()
    start = col("condition_start_datetime")

    return (conditions_with_id
        .join(valid_visits, "visit_occurrence_id", "left")
        .withColumn(
            "condition_start_datetime",
            when(has_visit & (start > col("visit_end_datetime")),
                 col("visit_end_datetime"))
            .when(has_visit & (start < col("visit_start_datetime")),
                  col("visit_start_datetime"))
            .otherwise(start))
        .withColumn("condition_start_date",
                    col("condition_start_datetime").cast("date"))
        .withColumn("condition_end_date", lit(None).cast("date"))
        .withColumn("condition_end_datetime", lit(None).cast("timestamp"))
        # Drop dangling visit references (e.g. visits removed by the visit
        # table's expectations).
        .withColumn("visit_occurrence_id",
                    when(col("_visit_exists"), col("visit_occurrence_id")))
        .join(broadcast(valid_providers), "provider_id", "left")
        .withColumn("provider_id",
                    when(col("_provider_exists"), col("provider_id")).cast("bigint"))
        .withColumn("condition_source_value",
                    substring(col("condition_source_value"), 1, 50))
        .withColumn("condition_status_source_value",
                    substring(col("condition_status_source_value"), 1, 50))
        .drop("visit_start_datetime", "visit_end_datetime",
              "_visit_exists", "_provider_exists"))

@dlt.table(
    name="qual_omop_condition",
    comment="Quality metrics for condition occurrence table"
)
def qual_omop_condition():
    """Single-row completeness summary of omop_condition_occurrence."""
    conditions = dlt.read("omop_condition_occurrence")

    is_confirmed = col("condition_status_concept_id") == 2
    has_source_concept = col("condition_source_concept_id").isNotNull()

    metrics = [
        count("*").alias("total_conditions"),
        count("provider_id").alias("conditions_with_provider"),
        count("visit_occurrence_id").alias("conditions_with_visit"),
        count(when(is_confirmed, 1)).alias("confirmed_conditions"),
        count(when(has_source_concept, 1)).alias("mapped_source_concepts"),
    ]
    return conditions.agg(*metrics)

@dlt.table(
    name="summ_omop_condition_by_visit",
    comment="Analyzes conditions per visit"
)
def analyze_conditions_per_visit():
    """Per-visit counts of condition rows and of distinct condition concepts."""
    conditions = dlt.read("omop_condition_occurrence")
    linked_to_visit = conditions.where(col("visit_occurrence_id").isNotNull())

    per_visit = [
        count("*").alias("condition_count"),
        count_distinct("condition_concept_id").alias("unique_condition_count"),
        collect_set("condition_concept_id").alias("condition_concepts"),
    ]
    return linked_to_visit.groupBy("visit_occurrence_id").agg(*per_visit)

In [0]:
# Declared schema of omop_drug_exposure. Column comments follow the OMOP CDM
# field descriptions; the "sig" comment uses the CDM's spelling "signetur"
# (previously misspelled "signatur").
drug_schema = StructType([
    StructField("drug_exposure_id", LongType(), False,
                metadata={"comment": "A unique identifier for each drug exposure event."}),
    StructField("person_id", LongType(), False,
                metadata={"comment": "A foreign key identifier to the person who is subjected to the drug."}),
    StructField("drug_concept_id", IntegerType(), False,
                metadata={"comment": "A foreign key that refers to a standard drug concept identifier in the Vocabulary."}),
    StructField("drug_exposure_start_date", DateType(), True,
                metadata={"comment": "The start date for the current instance of drug exposure."}),
    StructField("drug_exposure_start_datetime", TimestampType(), True,
                metadata={"comment": "The start date and time for the current instance of drug exposure."}),
    StructField("drug_exposure_end_date", DateType(), True,
                metadata={"comment": "The end date for the current instance of drug exposure."}),
    StructField("drug_exposure_end_datetime", TimestampType(), True,
                metadata={"comment": "The end date and time for the current instance of drug exposure."}),
    StructField("verbatim_end_date", DateType(), True,
                metadata={"comment": "The end date of the drug exposure as it appears in the source data."}),
    StructField("drug_type_concept_id", IntegerType(), True,
                metadata={"comment": "A foreign key to the predefined concept identifier in the Standardized Vocabularies reflecting the type of drug exposure."}),
    StructField("stop_reason", StringType(), True,
                metadata={"comment": "The reason the drug exposure was stopped."}),
    StructField("refills", IntegerType(), True,
                metadata={"comment": "The number of refills after the initial prescription."}),
    StructField("quantity", FloatType(), True,
                metadata={"comment": "The quantity of drug as recorded in the source data."}),
    StructField("days_supply", IntegerType(), True,
                metadata={"comment": "The number of days of supply of the medication."}),
    StructField("sig", StringType(), True,
                metadata={"comment": "The directions (signetur) on the drug prescription as recorded in the source."}),
    StructField("route_concept_id", IntegerType(), True,
                metadata={"comment": "A foreign key to a predefined concept in the Standardized Vocabularies reflecting the route of administration."}),
    StructField("lot_number", StringType(), True,
                metadata={"comment": "The identifier to determine where the product originated."}),
    StructField("provider_id", LongType(), True,
                metadata={"comment": "A foreign key to the provider in the provider table who prescribed the drug."}),
    StructField("visit_occurrence_id", LongType(), True,
                metadata={"comment": "A foreign key to the visit in the visit table during which the drug exposure initiated."}),
    StructField("visit_detail_id", LongType(), True,
                metadata={"comment": "A foreign key to the visit detail record during which the drug exposure initiated."}),
    StructField("drug_source_value", StringType(), True,
                metadata={"comment": "The source code for the drug as it appears in the source data."}),
    StructField("drug_source_concept_id", IntegerType(), True,
                metadata={"comment": "A foreign key to a drug concept that refers to the code used in the source."}),
    StructField("route_source_value", StringType(), True,
                metadata={"comment": "The source code for the route as it appears in the source data."}),
    StructField("dose_unit_source_value", StringType(), True,
                metadata={"comment": "The information about the dose unit as recorded in the source data."})
])

# Hard requirements for omop_drug_exposure: rows failing any expression are
# removed by @dlt.expect_all_or_drop on that table.
mandatory_drug_rules = dict(
    valid_drug_id="drug_exposure_id IS NOT NULL",
    valid_person="person_id IS NOT NULL",
    valid_concept="drug_concept_id IS NOT NULL",
    valid_start_date="drug_exposure_start_date IS NOT NULL",
    valid_end_date="drug_exposure_end_date IS NOT NULL",
    valid_type_concept="drug_type_concept_id IS NOT NULL",
)

# Soft checks for omop_drug_exposure: violations are recorded by
# @dlt.expect_all on that table, rows are kept.
advisory_drug_rules = dict((
    ("valid_concept_value", "drug_concept_id > 0"),
    ("valid_type_concept_value", "drug_type_concept_id > 0"),
    ("valid_route", "route_concept_id IS NULL OR route_concept_id >= 0"),
    ("valid_dates", "drug_exposure_end_date >= drug_exposure_start_date"),
    ("valid_quantity", "quantity IS NULL OR quantity > 0"),
    ("valid_days_supply", "days_supply IS NULL OR days_supply > 0"),
    ("valid_stop_reason", "stop_reason IS NULL OR LENGTH(stop_reason) <= 20"),
    ("valid_source_values", "drug_source_value IS NOT NULL"),
))

@dlt.table(
    name="valid_drug_concepts",
    comment="Valid drug concepts from OMOP vocabulary",
    temporary=True
)
@dlt.expect_or_fail("valid_domain", "domain_id = 'Drug'")
def get_valid_drug_concepts():
    """Gets valid drug concepts from the OMOP vocabulary.

    Returns the distinct ``concept_id`` values of ``3_lookup.omop.concept`` rows
    in the Drug domain whose ``invalid_reason`` is NULL.

    ``domain_id`` is kept in the output because the ``valid_domain`` expectation
    on this table is evaluated against the returned rows; selecting only
    ``concept_id`` leaves that expression with no column to resolve. Consumers
    join on ``concept_id`` and select their own columns, so the extra column is
    not propagated.
    """
    return (spark.table("3_lookup.omop.concept")
            .filter(col("invalid_reason").isNull() &
                    (col("domain_id") == "Drug"))
            .select("concept_id", "domain_id")
            .distinct())

@dlt.table(
    name="route_concept_maps",
    comment="Route concept mappings",
    temporary=True
)
def get_route_maps():
    """Return (route_source, route_omop_id) pairs from the local mapping table.

    Only mapping rows whose ``OMOPField`` is ``route_concept_id`` are kept.
    NOTE(review): rows are not de-duplicated on SourceValue; a duplicated source
    value would repeat drug rows in the left join in base_drug_exposure — confirm
    the mapping is unique per SourceValue.
    """
    route_rows = (spark.table("3_lookup.omop.barts_new_maps")
                  .where(col("OMOPField") == "route_concept_id"))
    return route_rows.select(
        col("SourceValue").alias("route_source"),
        col("OmopConceptId").alias("route_omop_id"),
    )

@dlt.table(
    name="base_drug_exposure",
    comment="Base drug exposure records before validation",
    temporary=True
)
def create_base_drug_exposure():
    """Map administered medications onto drug_exposure columns.

    Source: ``4_prod.bronze.map_med_admin`` rows with event type "Administered"
    and a non-NULL ``omop_concept_id``. Rows are kept only when the person is in
    ``omop_person`` and the concept is in ``valid_drug_concepts`` (inner joins);
    the route is looked up in ``route_concept_maps`` (left join, unmapped -> 0).
    """
    administered = (spark.table("4_prod.bronze.map_med_admin")
                    .filter((col("event_type_display") == "Administered") &
                            col("omop_concept_id").isNotNull()))

    known_people = dlt.read("omop_person").select("person_id").distinct()
    drug_concepts = dlt.read("valid_drug_concepts")
    routes = dlt.read("route_concept_maps")

    joined = (administered
              .join(known_people, "person_id", "inner")
              .join(drug_concepts,
                    col("omop_concept_id") == col("concept_id"),
                    "inner")
              .join(routes,
                    col("ADMIN_ROUTE_DISPLAY") == col("route_source"),
                    "left"))

    # A personnel id of 0 means "no provider recorded".
    provider_expr = when(col("prsnl_id") == 0, None).otherwise(col("prsnl_id"))
    route_expr = coalesce(col("route_omop_id"), lit(0))

    return joined.select(
        col("person_id").cast("bigint"),
        col("omop_concept_id").cast("integer").alias("drug_concept_id"),
        col("admin_start_dt_tm").cast("timestamp").alias("drug_exposure_start_datetime"),
        col("admin_end_dt_tm").cast("timestamp").alias("drug_exposure_end_datetime"),
        lit(32817).cast("integer").alias("drug_type_concept_id"),
        lit(None).cast("string").alias("stop_reason"),
        lit(None).cast("integer").alias("refills"),
        lit(None).cast("float").alias("quantity"),
        lit(None).cast("integer").alias("days_supply"),
        lit(None).cast("string").alias("sig"),
        route_expr.cast("integer").alias("route_concept_id"),
        lit(None).cast("string").alias("lot_number"),
        provider_expr.cast("bigint").alias("provider_id"),
        col("encntr_id").cast("bigint").alias("visit_occurrence_id"),
        lit(None).cast("bigint").alias("visit_detail_id"),
        col("ORDER_MNEMONIC").alias("drug_source_value"),
        lit(0).cast("integer").alias("drug_source_concept_id"),
        col("admin_route_display").alias("route_source_value"),
        col("initial_dosage_unit_display").alias("dose_unit_source_value"),
    )

@dlt.table(
    name="omop_drug_exposure",
    comment="OMOP CDM Drug Exposure table - Contains records about the exposure to a Drug through prescriptions or administration",
    schema=drug_schema,
    table_properties={"quality": "gold"}
)
@dlt.expect_all_or_drop(mandatory_drug_rules)
@dlt.expect_all(advisory_drug_rules)
def create_omop_drug_exposure():
    """
    Creates the final drug exposure table with proper validation
    and referential integrity checks.

    Steps:
      1. Assign drug_exposure_id with row_number() over an ordering of every
         base column, so ids are reproducible between runs (rows that tie on
         all columns are indistinguishable anyway).
      2. Clamp start/end datetimes into the linked visit's
         [visit_start_datetime, visit_end_datetime] window.
      3. Derive the date columns. When there is no end datetime the end *date*
         falls back to the start date (OMOP convention for a missing end), so
         the row satisfies the mandatory ``valid_end_date`` rule; the end
         datetime stays NULL and is still counted as missing in
         qual_omop_drug_exposure.
      4. Emit verbatim_end_date as NULL: it is declared in drug_schema but the
         source provides no value for it.
      5. Set visit_occurrence_id / provider_id to NULL when they have no match
         in omop_visit_occurrence / omop_provider (the row is kept).
    """
    base_drugs = dlt.read("base_drug_exposure")

    id_order = Window.orderBy(
        "person_id", "drug_exposure_start_datetime", "drug_concept_id",
        "drug_exposure_end_datetime", "visit_occurrence_id", "provider_id",
        "route_concept_id", "drug_source_value", "route_source_value",
        "dose_unit_source_value", "drug_type_concept_id",
        "drug_source_concept_id", "stop_reason", "refills", "quantity",
        "days_supply", "sig", "lot_number", "visit_detail_id")
    drugs_with_id = base_drugs.withColumn(
        "drug_exposure_id", row_number().over(id_order).cast("bigint"))

    # Reference data; marker columns distinguish matched from unmatched rows
    # after the left joins. valid_providers is distinct so it cannot fan out.
    valid_visits = (dlt.read("omop_visit_occurrence")
        .select("visit_occurrence_id", "visit_start_datetime", "visit_end_datetime")
        .withColumn("_visit_exists", lit(True)))
    valid_providers = (dlt.read("omop_provider")
        .select("provider_id").distinct()
        .withColumn("_provider_exists", lit(True)))

    has_visit = col("visit_occurrence_id").isNotNull()
    visit_start = col("visit_start_datetime")
    visit_end = col("visit_end_datetime")

    def _within_visit(ts):
        """Clamp timestamp column `ts` into the joined visit window (no-op without one)."""
        return (when(has_visit & (ts > visit_end), visit_end)
                .when(has_visit & (ts < visit_start), visit_start)
                .otherwise(ts))

    return (drugs_with_id
        .join(valid_visits, "visit_occurrence_id", "left")
        .withColumn("drug_exposure_start_datetime",
                    _within_visit(col("drug_exposure_start_datetime")))
        .withColumn("drug_exposure_end_datetime",
                    _within_visit(col("drug_exposure_end_datetime")))
        .withColumn("drug_exposure_start_date",
                    col("drug_exposure_start_datetime").cast("date"))
        .withColumn("drug_exposure_end_date",
                    coalesce(col("drug_exposure_end_datetime").cast("date"),
                             col("drug_exposure_start_date")))
        .withColumn("verbatim_end_date", lit(None).cast("date"))
        # Drop dangling visit references (e.g. visits removed by the visit
        # table's expectations).
        .withColumn("visit_occurrence_id",
                    when(col("_visit_exists"), col("visit_occurrence_id")))
        .join(broadcast(valid_providers), "provider_id", "left")
        .withColumn("provider_id",
                    when(col("_provider_exists"), col("provider_id")).cast("bigint"))
        .drop("visit_start_datetime", "visit_end_datetime",
              "_visit_exists", "_provider_exists"))

@dlt.table(
    name="qual_omop_drug_exposure",
    comment="Quality metrics for drug exposure table"
)
def qual_omop_drug_exposure():
    """Produce a one-row completeness/mapping summary of omop_drug_exposure.

    Columns:
      total_exposures          -- all rows
      exposures_with_provider  -- rows with a non-null provider_id
      exposures_with_visit     -- rows with a non-null visit_occurrence_id
      mapped_routes            -- rows with route_concept_id > 0
      avg_exposure_days        -- mean(end_date - start_date) over rows that
                                  have an end datetime
      missing_end_dates        -- rows whose end datetime is null
    """
    has_end = col("drug_exposure_end_datetime").isNotNull()
    lacks_end = col("drug_exposure_end_datetime").isNull()
    exposure_days = datediff(col("drug_exposure_end_date"),
                             col("drug_exposure_start_date"))

    metrics = [
        count("*").alias("total_exposures"),
        count("provider_id").alias("exposures_with_provider"),
        count("visit_occurrence_id").alias("exposures_with_visit"),
        count(when(col("route_concept_id") > 0, 1)).alias("mapped_routes"),
        avg(when(has_end, exposure_days)).alias("avg_exposure_days"),
        count(when(lacks_end, 1)).alias("missing_end_dates"),
    ]
    return dlt.read("omop_drug_exposure").agg(*metrics)

@dlt.table(
    name="summ_omop_drug_routes",
    comment="Summary of drug administration routes"
)
def analyze_drug_routes():
    """Summarise drug exposures per administration route.

    Groups omop_drug_exposure by (route_concept_id, route_source_value) and
    reports the exposure count plus distinct drug concepts and distinct
    patients per route, ordered with the most-used route first.
    """
    route_keys = ["route_concept_id", "route_source_value"]
    per_route = (
        dlt.read("omop_drug_exposure")
        .groupBy(*route_keys)
        .agg(
            count("*").alias("exposure_count"),
            count_distinct("drug_concept_id").alias("unique_drugs"),
            count_distinct("person_id").alias("unique_patients"),
        )
    )
    return per_route.orderBy(desc("exposure_count"))

@dlt.table(
    name="summ_omop_drug_frequency",
    comment="Analysis of drug administration frequency"
)
def analyze_drug_frequency():
    """Count administrations per (visit, drug concept) and their time span.

    Only exposures linked to a visit are considered. For each visit/drug pair
    the result holds the exposure count, the earliest exposure start datetime
    (``first_admin``) and the latest exposure end datetime (``last_admin``).
    """
    linked = (dlt.read("omop_drug_exposure")
              .where(col("visit_occurrence_id").isNotNull()))
    return linked.groupBy("visit_occurrence_id", "drug_concept_id").agg(
        count("*").alias("administration_count"),
        sql_min("drug_exposure_start_datetime").alias("first_admin"),
        sql_max("drug_exposure_end_datetime").alias("last_admin"),
    )

In [0]:

# OMOP CDM schema for the procedure_occurrence table. Each tuple below is
# (column name, Spark type, nullable, column comment); the comment is stored
# in the StructField metadata under the "comment" key.
procedure_schema = StructType([
    StructField(field_name, field_type, is_nullable,
                metadata={"comment": field_comment})
    for field_name, field_type, is_nullable, field_comment in (
        ("procedure_occurrence_id", LongType(), False,
         "A unique identifier for each Procedure Occurrence event."),
        ("person_id", LongType(), False,
         "A foreign key identifier to the Person who is subjected to the Procedure."),
        ("procedure_concept_id", IntegerType(), False,
         "A foreign key that refers to a standard procedure Concept identifier in the Vocabulary."),
        ("procedure_date", DateType(), False,
         "The date on which the Procedure was performed."),
        ("procedure_datetime", TimestampType(), True,
         "The date and time on which the Procedure was performed."),
        ("procedure_end_date", DateType(), True,
         "The end date on which the Procedure was performed."),
        ("procedure_end_datetime", TimestampType(), True,
         "The end date and time on which the Procedure was performed."),
        ("procedure_type_concept_id", IntegerType(), False,
         "A foreign key to the predefined Concept identifier in the Standardized Vocabularies reflecting the type of source data from which the procedure record is derived."),
        ("modifier_concept_id", IntegerType(), True,
         "A foreign key to a Standard Concept identifier for a modifier to the Procedure."),
        ("quantity", IntegerType(), True,
         "The quantity of procedures ordered or administered."),
        ("provider_id", LongType(), True,
         "A foreign key to the Provider in the PROVIDER table who was responsible for carrying out the procedure."),
        ("visit_occurrence_id", LongType(), True,
         "A foreign key to the Visit in the VISIT_OCCURRENCE table during which the Procedure was carried out."),
        ("visit_detail_id", LongType(), True,
         "A foreign key to the Visit Detail in the VISIT_DETAIL table during which the Procedure was carried out."),
        ("procedure_source_value", StringType(), True,
         "The procedure as it appears in the source data."),
        ("procedure_source_concept_id", IntegerType(), True,
         "A foreign key to a Procedure Concept that refers to the code used in the source."),
        ("modifier_source_value", StringType(), True,
         "The source code for the modifier as it appears in the source data."),
    )
])

# Mandatory procedure rules: a record failing any of these is dropped
# (passed to dlt.expect_all_or_drop on omop_procedure_occurrence).
mandatory_procedure_rules = dict(
    valid_procedure_id="procedure_occurrence_id IS NOT NULL",
    valid_person="person_id IS NOT NULL",
    valid_procedure="procedure_concept_id IS NOT NULL",
    valid_date="procedure_date IS NOT NULL",
    valid_type_concept="procedure_type_concept_id IS NOT NULL",
)

# Advisory procedure rules: violations are recorded in the expectation
# metrics but the records are kept (passed to dlt.expect_all).
advisory_procedure_rules = dict(
    valid_concept_value="procedure_concept_id > 0",
    valid_type_concept_value="procedure_type_concept_id > 0",
    valid_modifier="modifier_concept_id IS NULL OR modifier_concept_id >= 0",
    valid_quantity="quantity IS NULL OR quantity > 0",
    valid_dates="procedure_end_date IS NULL OR procedure_end_date >= procedure_date",
    # base_procedure_occurrence truncates source values to 50 characters.
    valid_source_values="""
        procedure_source_value IS NULL OR LENGTH(procedure_source_value) <= 50
    """,
)

@dlt.table(
    name="valid_procedure_concepts",
    comment="Valid procedure concepts from OMOP vocabulary",
    temporary=True
)
@dlt.expect_or_fail("valid_domain", "domain_id = 'Procedure'")
def get_valid_procedure_concepts():
    """Return non-invalidated concepts from the Procedure domain.

    Reads ``3_lookup.omop.concept`` and keeps rows whose ``invalid_reason`` is
    null and whose ``domain_id`` is 'Procedure'.

    ``domain_id`` is kept in the output because the ``valid_domain``
    expectation above is evaluated against this table's rows. Projecting only
    ``concept_id`` would leave that expectation referencing a missing column.
    ``concept_id`` remains the join key consumed by base_procedure_occurrence.
    Because ``domain_id`` is constant after the filter, distinct over both
    columns yields the same rows as distinct over ``concept_id`` alone.
    """
    return (spark.table("3_lookup.omop.concept")
            .filter((col("invalid_reason").isNull()) &
                    (col("domain_id") == "Procedure"))
            .select("concept_id", "domain_id")
            .distinct())

@dlt.table(
    name="base_procedure_occurrence",
    comment="Base procedure occurrence records before validation",
    temporary=True
)
def create_base_procedure_occurrence():
    """Project bronze procedure rows onto the OMOP procedure_occurrence columns.

    Source: ``4_prod.bronze.map_procedure``. A row is kept only if:
      * its ``person_id`` exists in omop_person, and
      * its ``omop_concept_id`` is non-null and present in
        valid_procedure_concepts.

    ``procedure_occurrence_id``, ``procedure_date`` and ``procedure_end_date``
    are not produced here; omop_procedure_occurrence adds them.
    """
    source_rows = spark.table("4_prod.bronze.map_procedure")
    known_persons = dlt.read("omop_person").select("person_id").distinct()
    procedure_concepts = dlt.read("valid_procedure_concepts")

    # A source provider id of 0 is mapped to NULL (presumably a "no provider"
    # placeholder in the source system - confirm).
    provider_id = (when(col("active_status_prsnl_id") == 0, None)
                   .otherwise(col("active_status_prsnl_id")))

    output_columns = [
        col("person_id").cast("bigint"),
        col("omop_concept_id").cast("integer").alias("procedure_concept_id"),
        col("proc_dt_tm").cast("timestamp").alias("procedure_datetime"),
        lit(None).cast("timestamp").alias("procedure_end_datetime"),
        # Every row gets type concept 32817 (believed to be the OMOP "EHR"
        # type concept - verify against the vocabulary).
        lit(32817).cast("integer").alias("procedure_type_concept_id"),
        lit(0).cast("integer").alias("modifier_concept_id"),
        # quantity is populated from the source proc_minutes column.
        col("proc_minutes").cast("integer").alias("quantity"),
        provider_id.cast("bigint").alias("provider_id"),
        col("encntr_id").cast("bigint").alias("visit_occurrence_id"),
        lit(None).cast("bigint").alias("visit_detail_id"),
        # Truncated to 50 characters, the limit the advisory rules check.
        substring(col("source_string"), 1, 50).alias("procedure_source_value"),
        lit(0).cast("integer").alias("procedure_source_concept_id"),
        # Always NULL (string-typed).
        substring(lit(None).cast("string"), 1, 50).alias("modifier_source_value"),
    ]

    with_person = source_rows.join(known_persons, "person_id", "inner")
    has_concept = with_person.filter(col("omop_concept_id").isNotNull())
    return (has_concept
            .join(procedure_concepts,
                  col("omop_concept_id") == col("concept_id"),
                  "inner")
            .select(*output_columns))

@dlt.table(
    name="omop_procedure_occurrence",
    comment="OMOP CDM Procedure Occurrence table - Contains records of activities or processes ordered by, or carried out by, a healthcare provider on the patient with a diagnostic or therapeutic purpose",
    schema=procedure_schema,
    table_properties={"quality": "gold"}
)
@dlt.expect_all_or_drop(mandatory_procedure_rules)
@dlt.expect_all(advisory_procedure_rules)
def create_omop_procedure_occurrence():
    """
    Build the gold procedure_occurrence table from base_procedure_occurrence.

    1. Number rows with row_number() ordered by person, datetime, concept,
       provider and visit to produce ``procedure_occurrence_id``.
    2. Left-join omop_visit_occurrence. When ``procedure_datetime`` falls
       outside [visit_start_datetime, visit_end_datetime], move it to the
       nearest bound. Derive ``procedure_date`` from it and set
       ``procedure_end_date`` to NULL.
    3. Set ``provider_id`` to NULL when it is not found in omop_provider.

    Rows failing mandatory_procedure_rules are dropped by the expectations;
    advisory_procedure_rules are only tracked.
    """
    id_order = Window.orderBy(
        "person_id",
        "procedure_datetime",
        "procedure_concept_id",
        "provider_id",
        "visit_occurrence_id",
    )
    numbered = dlt.read("base_procedure_occurrence").withColumn(
        "procedure_occurrence_id", row_number().over(id_order).cast("bigint")
    )

    visits = dlt.read("omop_visit_occurrence").select(
        "visit_occurrence_id", "visit_start_datetime", "visit_end_datetime"
    )
    providers = dlt.read("omop_provider").select("provider_id").distinct()

    # Clamp the procedure time into its visit window. If the visit id has no
    # matching visit, the bounds are NULL and both comparisons are NULL, so
    # the original timestamp is kept.
    linked = col("visit_occurrence_id").isNotNull()
    performed = col("procedure_datetime")
    visit_start = col("visit_start_datetime")
    visit_end = col("visit_end_datetime")
    clamped = (when(linked & (performed > visit_end), visit_end)
               .when(linked & (performed < visit_start), visit_start)
               .otherwise(performed))

    aligned = (numbered
               .join(visits, "visit_occurrence_id", "left")
               .withColumn("procedure_datetime", clamped)
               .withColumn("procedure_date", col("procedure_datetime").cast("date"))
               .withColumn("procedure_end_date", lit(None).cast("date")))

    # Rows whose provider is absent from omop_provider (NULL ids never match)
    # keep their data with provider_id set to NULL; matched rows pass through.
    unmatched = (aligned
                 .join(broadcast(providers), "provider_id", "left_anti")
                 .withColumn("provider_id", lit(None)))
    matched = aligned.join(providers, "provider_id", "inner")
    return (unmatched
            .unionByName(matched)
            .drop("visit_start_datetime", "visit_end_datetime"))

@dlt.table(
    name="qual_omop_procedure",
    comment="Quality metrics for procedure occurrence table"
)
def qual_omop_procedure():
    """Produce a one-row completeness summary of omop_procedure_occurrence.

    ``quantity`` is filled from source procedure minutes in
    base_procedure_occurrence, which is why the quantity-based metrics are
    named as durations.
    """
    quantity = col("quantity")
    mapped_source = col("procedure_source_concept_id") > 0
    return dlt.read("omop_procedure_occurrence").agg(
        count("*").alias("total_procedures"),
        count("provider_id").alias("procedures_with_provider"),
        count("visit_occurrence_id").alias("procedures_with_visit"),
        count(when(quantity.isNotNull(), 1)).alias("procedures_with_duration"),
        avg(quantity).alias("avg_procedure_minutes"),
        count(when(mapped_source, 1)).alias("mapped_source_concepts"),
    )

@dlt.table(
    name="summ_omop_procedure_by_visit",
    comment="Analysis of procedures per visit"
)
def analyze_procedures_per_visit():
    """Summarise procedures per visit.

    Only procedures linked to a visit are considered. For each visit the
    result holds the procedure count, the number of distinct procedure
    concepts, and the sum of ``quantity`` with NULL quantities counted as 0.
    """
    minutes_or_zero = coalesce(col("quantity"), lit(0))
    return (dlt.read("omop_procedure_occurrence")
            .where(col("visit_occurrence_id").isNotNull())
            .groupBy("visit_occurrence_id")
            .agg(count("*").alias("procedure_count"),
                 count_distinct("procedure_concept_id")
                 .alias("unique_procedure_count"),
                 sql_sum(minutes_or_zero).alias("total_procedure_minutes")))

@dlt.table(
    name="summ_omop_procedure_frequency",
    comment="Analysis of most common procedures"
)
def analyze_procedure_frequency():
    """Rank procedures by how often they occur.

    Groups by (procedure_concept_id, procedure_source_value) and reports
    occurrences, distinct patients, distinct visits and mean ``quantity``,
    ordered with the most frequent procedure first.
    """
    summary = (dlt.read("omop_procedure_occurrence")
               .groupBy("procedure_concept_id", "procedure_source_value")
               .agg(count("*").alias("occurrence_count"),
                    count_distinct("person_id").alias("unique_patients"),
                    count_distinct("visit_occurrence_id").alias("unique_visits"),
                    avg("quantity").alias("avg_duration_minutes")))
    return summary.orderBy(desc("occurrence_count"))

In [0]:


# OMOP CDM schema for the device_exposure table. Each tuple below is
# (column name, Spark type, nullable, column comment); the comment is stored
# in the StructField metadata under the "comment" key.
device_schema = StructType([
    StructField(field_name, field_type, is_nullable,
                metadata={"comment": field_comment})
    for field_name, field_type, is_nullable, field_comment in (
        ("device_exposure_id", LongType(), False,
         "A unique identifier for each Device exposure event."),
        ("person_id", LongType(), False,
         "A foreign key identifier to the Person who is subjected to the Device."),
        ("device_concept_id", IntegerType(), False,
         "A foreign key that refers to a Standard Device Concept identifier in the Vocabulary."),
        ("device_exposure_start_date", DateType(), False,
         "The start date for the Device exposure."),
        ("device_exposure_start_datetime", TimestampType(), True,
         "The start date and time for the Device exposure."),
        ("device_exposure_end_date", DateType(), True,
         "The end date for the Device exposure."),
        ("device_exposure_end_datetime", TimestampType(), True,
         "The end date and time for the Device exposure."),
        ("device_type_concept_id", IntegerType(), False,
         "A foreign key to the predefined Concept identifier in the Standardized Vocabularies reflecting the type of Device exposure."),
        ("unique_device_id", StringType(), True,
         "The Unique Device Identification (UDI-DI) number for devices regulated by the FDA."),
        ("production_id", StringType(), True,
         "The Production Identifier (UDI-PI) portion of the Unique Device Identification."),
        ("quantity", IntegerType(), True,
         "The number of individual Devices used."),
        ("provider_id", LongType(), True,
         "A foreign key to the Provider in the PROVIDER table who initiated the Device exposure."),
        ("visit_occurrence_id", LongType(), True,
         "A foreign key to the Visit in the VISIT_OCCURRENCE table during which the Device exposure initiated."),
        ("visit_detail_id", LongType(), True,
         "A foreign key to the Visit Detail in the VISIT_DETAIL table during which the Device exposure initiated."),
        ("device_source_value", StringType(), True,
         "The source code for the Device as it appears in the source data."),
        ("device_source_concept_id", IntegerType(), True,
         "A foreign key to a Device Concept that refers to the code used in the source."),
        ("unit_concept_id", IntegerType(), True,
         "A foreign key to a predefined Concept in the Standardized Vocabularies reflecting the unit the Device was administered."),
        ("unit_source_value", StringType(), True,
         "The source code for the unit as it appears in the source data."),
        ("unit_source_concept_id", IntegerType(), True,
         "A foreign key to a Unit Concept that refers to the code used in the source."),
    )
])

# Mandatory device rules: a record failing any of these is dropped
# (passed to dlt.expect_all_or_drop on omop_device_exposure).
mandatory_device_rules = dict(
    valid_device_id="device_exposure_id IS NOT NULL",
    valid_person="person_id IS NOT NULL",
    valid_device="device_concept_id IS NOT NULL",
    valid_start_date="device_exposure_start_date IS NOT NULL",
    valid_type_concept="device_type_concept_id IS NOT NULL",
)

# Advisory device rules: violations are recorded in the expectation metrics
# but the records are kept (passed to dlt.expect_all).
advisory_device_rules = dict(
    valid_concept_value="device_concept_id > 0",
    valid_type_concept_value="device_type_concept_id > 0",
    valid_quantity="quantity IS NULL OR quantity > 0",
    valid_unit="unit_concept_id IS NULL OR unit_concept_id >= 0",
    valid_dates="device_exposure_end_date IS NULL OR device_exposure_end_date >= device_exposure_start_date",
    valid_device_id_advis="unique_device_id IS NULL OR LENGTH(unique_device_id) <= 255",
    valid_production_id="production_id IS NULL OR LENGTH(production_id) <= 255",
)


@dlt.table(
    name="valid_device_concepts",
    comment="Valid device concepts from OMOP vocabulary",
    temporary=True
)
@dlt.expect_or_fail("valid_domain", "domain_id = 'Device'")
def get_valid_device_concepts():
    """Return non-invalidated concepts from the Device domain.

    Reads ``3_lookup.omop.concept`` and keeps rows whose ``invalid_reason`` is
    null and whose ``domain_id`` is 'Device'.

    ``domain_id`` is kept in the output because the ``valid_domain``
    expectation above is evaluated against this table's rows. Projecting only
    ``concept_id`` would leave that expectation referencing a missing column.
    ``concept_id`` remains the join key consumed by base_device_exposure.
    Because ``domain_id`` is constant after the filter, distinct over both
    columns yields the same rows as distinct over ``concept_id`` alone.
    """
    return (spark.table("3_lookup.omop.concept")
            .filter((col("invalid_reason").isNull()) &
                    (col("domain_id") == "Device"))
            .select("concept_id", "domain_id")
            .distinct())

@dlt.table(
    name="valid_clinical_events",
    comment="Valid clinical events for device exposure",
    temporary=True
)
def get_valid_clinical_events():
    """Return clinical events that are still valid and have record status 188.

    Keeps rows of ``4_prod.raw.mill_clinical_event`` whose VALID_UNTIL_DT_TM
    is later than the current date and whose RECORD_STATUS_CD equals 188
    (presumably the "active" status code - confirm). The rows are projected
    to the columns used for device matching.
    """
    still_valid = col("VALID_UNTIL_DT_TM") > current_date()
    status_ok = col("RECORD_STATUS_CD") == 188
    wanted = ["PERSON_ID", "EVENT_CD", "EVENT_TITLE_TEXT",
              "CLINSIG_UPDT_DT_TM", "PERFORMED_PRSNL_ID", "ENCNTR_ID"]
    events = spark.table("4_prod.raw.mill_clinical_event")
    return events.where(still_valid & status_ok).select(*wanted)

@dlt.table(
    name="device_code_maps",
    comment="Device mappings for code-based matches",
    temporary=True
)
def get_device_code_maps():
    """Return device concept mappings keyed on the clinical event code.

    Selects rows of ``3_lookup.omop.barts_new_maps`` with
    OMOPField='device_concept_id' and SourceField='EVENT_CD'. It exposes:
      event_cd_source   -- SourceValue, compared with the event code
      device_concept_id -- OmopConceptId
      map_event_cd      -- EVENT_CD; when non-null, restricts the mapping
                           to that event code
    """
    is_code_map = ((col("OMOPField") == "device_concept_id")
                   & (col("SourceField") == "EVENT_CD"))
    maps = spark.table("3_lookup.omop.barts_new_maps")
    return maps.where(is_code_map).select(
        col("SourceValue").alias("event_cd_source"),
        col("OmopConceptId").alias("device_concept_id"),
        col("EVENT_CD").alias("map_event_cd"),
    )

@dlt.table(
    name="device_text_maps",
    comment="Device mappings for text-based matches",
    temporary=True
)
def get_device_text_maps():
    """Return device concept mappings keyed on event text.

    Selects rows of ``3_lookup.omop.barts_new_maps`` with
    OMOPField='device_concept_id' and SourceField='EVENT_RESULT_TXT'. It
    exposes ``result_txt_source`` (SourceValue), ``device_concept_id``
    (OmopConceptId) and ``map_event_cd`` (EVENT_CD).

    NOTE(review): base_device_exposure compares ``result_txt_source`` with the
    event's EVENT_TITLE_TEXT, while these rows are tagged EVENT_RESULT_TXT.
    Confirm that the title/result-text pairing is intended.
    """
    is_text_map = ((col("OMOPField") == "device_concept_id")
                   & (col("SourceField") == "EVENT_RESULT_TXT"))
    maps = spark.table("3_lookup.omop.barts_new_maps")
    return maps.where(is_text_map).select(
        col("SourceValue").alias("result_txt_source"),
        col("OmopConceptId").alias("device_concept_id"),
        col("EVENT_CD").alias("map_event_cd"),
    )

@dlt.table(
    name="base_device_exposure",
    comment="Combined device exposures from code and text matches",
    temporary=True
)
def create_base_device_exposure():
    """Build device exposure candidates from mapped clinical events.

    Each valid clinical event is matched to device concepts twice:
      * by event code (device_code_maps: EVENT_CD = event_cd_source), and
      * by event title text (device_text_maps: EVENT_TITLE_TEXT =
        result_txt_source).
    A mapping row with a non-null ``map_event_cd`` applies only to that event
    code. The two match sets are concatenated without de-duplication, so an
    event matched both ways yields two rows. The result is then restricted to
    persons in omop_person and to valid Device concepts, and projected onto
    the OMOP device_exposure columns. ``device_exposure_id`` and the date
    columns are added later by omop_device_exposure.
    """
    events = dlt.read("valid_clinical_events").alias("ce")
    people = (dlt.read("omop_person")
              .select(col("person_id").alias("valid_person_id"))
              .distinct())
    device_concepts = dlt.read("valid_device_concepts")

    def _match(maps, tag, event_field, source_field):
        # Join events to one mapping table aliased ``tag``. A NULL
        # map_event_cd means the mapping applies to any event code.
        scoped_cd = col(f"{tag}.map_event_cd")
        condition = ((col(f"ce.{event_field}") == col(f"{tag}.{source_field}"))
                     & (scoped_cd.isNull() | (scoped_cd == col("ce.EVENT_CD"))))
        return events.join(maps, condition, "inner").select(
            col("ce.PERSON_ID"),
            col(f"{tag}.device_concept_id"),
            col("ce.CLINSIG_UPDT_DT_TM"),
            col("ce.PERFORMED_PRSNL_ID"),
            col("ce.ENCNTR_ID"),
            col("ce.EVENT_CD"),
        )

    by_code = _match(dlt.read("device_code_maps").alias("cm"),
                     "cm", "EVENT_CD", "event_cd_source")
    by_text = _match(dlt.read("device_text_maps").alias("tm"),
                     "tm", "EVENT_TITLE_TEXT", "result_txt_source")
    matched = by_code.union(by_text).alias("cm")

    # A performer id of 0 is mapped to NULL.
    provider_id = (when(col("cm.PERFORMED_PRSNL_ID") == 0, None)
                   .otherwise(col("cm.PERFORMED_PRSNL_ID")))

    return (matched
            .join(people, col("cm.PERSON_ID") == col("valid_person_id"), "inner")
            .join(device_concepts, col("device_concept_id") == col("concept_id"), "inner")
            .select(
                col("cm.PERSON_ID").cast("bigint").alias("person_id"),
                col("cm.device_concept_id").cast("integer"),
                col("cm.CLINSIG_UPDT_DT_TM").cast("timestamp")
                .alias("device_exposure_start_datetime"),
                lit(None).cast("timestamp").alias("device_exposure_end_datetime"),
                lit(32817).cast("integer").alias("device_type_concept_id"),
                lit(None).cast("string").alias("unique_device_id"),
                lit(None).cast("string").alias("production_id"),
                lit(1).cast("integer").alias("quantity"),
                provider_id.cast("bigint").alias("provider_id"),
                col("cm.ENCNTR_ID").cast("bigint").alias("visit_occurrence_id"),
                lit(None).cast("bigint").alias("visit_detail_id"),
                col("cm.EVENT_CD").cast("string").alias("device_source_value"),
                lit(0).cast("integer").alias("device_source_concept_id"),
                lit(None).cast("integer").alias("unit_concept_id"),
                lit(None).cast("string").alias("unit_source_value"),
                lit(0).cast("integer").alias("unit_source_concept_id"),
            ))

@dlt.table(
    name="omop_device_exposure",
    comment="OMOP CDM Device Exposure table - Contains records about exposure to a foreign physical object or instrument used for diagnostic or therapeutic purposes",
    schema=device_schema,
    table_properties={"quality": "gold"}
)
@dlt.expect_all_or_drop(mandatory_device_rules)
@dlt.expect_all(advisory_device_rules)
def create_omop_device_exposure():
    """
    Build the gold device_exposure table from base_device_exposure.

    1. Number rows with row_number() to produce ``device_exposure_id``. The
       ordering covers every column that varies in base_device_exposure, so
       only fully identical rows can tie. The id-to-content assignment is
       therefore reproducible for the same input.
    2. Left-join omop_visit_occurrence. When ``device_exposure_start_datetime``
       falls outside [visit_start_datetime, visit_end_datetime], move it to
       the nearest bound. Derive the start and end dates from the datetimes.
    3. Set ``provider_id`` to NULL when it is not found in omop_provider.

    Rows failing mandatory_device_rules are dropped by the expectations;
    advisory_device_rules are only tracked.
    """
    device_df = dlt.read("base_device_exposure")

    # Ordering by person and start time alone left ties in arbitrary order.
    # The remaining base columns are constants or NULLs, so these six keys
    # form a total order over the row content.
    window_spec = Window.orderBy(
        "person_id",
        "device_exposure_start_datetime",
        "device_concept_id",
        "provider_id",
        "visit_occurrence_id",
        "device_source_value",
    )
    device_df_with_id = device_df.withColumn(
        "device_exposure_id", row_number().over(window_spec).cast("bigint"))

    valid_visits = dlt.read("omop_visit_occurrence").select(
        "visit_occurrence_id",
        "visit_start_datetime",
        "visit_end_datetime"
    )
    valid_providers = dlt.read("omop_provider").select("provider_id").distinct()

    # Clamp the start time into the linked visit's window. An unmatched visit
    # id yields NULL bounds, so the original start time is kept.
    device_df_with_visits = (device_df_with_id
        .join(valid_visits, "visit_occurrence_id", "left")
        .withColumn(
            "device_exposure_start_datetime",
            when((col("visit_occurrence_id").isNotNull()) &
                 (col("device_exposure_start_datetime") > col("visit_end_datetime")),
                 col("visit_end_datetime"))
            .when((col("visit_occurrence_id").isNotNull()) &
                  (col("device_exposure_start_datetime") < col("visit_start_datetime")),
                  col("visit_start_datetime"))
            .otherwise(col("device_exposure_start_datetime")))
        .withColumn("device_exposure_start_date",
                    col("device_exposure_start_datetime").cast("date"))
        .withColumn("device_exposure_end_date",
                    col("device_exposure_end_datetime").cast("date")))

    # Rows whose provider is absent from omop_provider (NULL ids never match)
    # are kept with provider_id set to a bigint NULL; matched rows pass
    # through unchanged.
    return (device_df_with_visits
        .join(broadcast(valid_providers), "provider_id", "left_anti")
        .withColumn("provider_id", lit(None).cast("bigint"))
        .unionByName(
            device_df_with_visits
            .join(valid_providers, "provider_id", "inner")
        )
        .drop("visit_start_datetime", "visit_end_datetime"))

@dlt.table(
    name="qual_omop_device_exposure",
    comment="Quality metrics for device exposure table"
)
def qual_omop_device_exposure():
    """Produce a one-row completeness summary of omop_device_exposure.

    Reports the total row count; how many rows carry a provider, a visit and
    a unique device id; and the number of distinct device concepts and
    distinct device source values.
    """
    non_null_counts = [
        count(source).alias(label)
        for source, label in (
            ("provider_id", "exposures_with_provider"),
            ("visit_occurrence_id", "exposures_with_visit"),
            ("unique_device_id", "exposures_with_unique_id"),
        )
    ]
    distinct_counts = [
        count_distinct("device_concept_id").alias("unique_devices"),
        count_distinct("device_source_value").alias("unique_source_values"),
    ]
    return dlt.read("omop_device_exposure").agg(
        count("*").alias("total_exposures"), *non_null_counts, *distinct_counts)

@dlt.table(
    name="summ_omop_device_by_visit",
    comment="Analysis of devices per visit"
)
def analyze_devices_per_visit():
    """Summarise device usage per visit.

    Exposures without a visit are excluded. Per visit the result holds the
    row count, the number of distinct device concepts and the set of those
    concepts.
    """
    concept = "device_concept_id"
    per_visit = (dlt.read("omop_device_exposure")
                 .where(col("visit_occurrence_id").isNotNull())
                 .groupBy("visit_occurrence_id"))
    return per_visit.agg(
        count("*").alias("device_count"),
        count_distinct(concept).alias("unique_device_count"),
        collect_set(concept).alias("device_concepts"),
    )

@dlt.table(
    name="summ_omop_device_frequency",
    comment="Analysis of device usage frequency"
)
def analyze_device_frequency():
    """Rank devices by how often they are used.

    Groups by (device_concept_id, device_source_value) and reports exposures,
    distinct patients, distinct visits and distinct providers, ordered with
    the most used device first.
    """
    distinct_of = {
        "unique_patients": "person_id",
        "unique_visits": "visit_occurrence_id",
        "unique_providers": "provider_id",
    }
    aggregates = [count("*").alias("exposure_count")] + [
        count_distinct(source).alias(label) for label, source in distinct_of.items()
    ]
    summary = (dlt.read("omop_device_exposure")
               .groupBy("device_concept_id", "device_source_value")
               .agg(*aggregates))
    return summary.orderBy(desc("exposure_count"))

In [0]:


# OMOP CDM schema for the measurement table. Each tuple below is
# (column name, Spark type, nullable, column comment); the comment is stored
# in the StructField metadata under the "comment" key.
measurement_schema = StructType([
    StructField(field_name, field_type, is_nullable,
                metadata={"comment": field_comment})
    for field_name, field_type, is_nullable, field_comment in (
        ("measurement_id", LongType(), False,
         "A unique identifier for each Measurement."),
        ("person_id", LongType(), False,
         "A foreign key identifier to the Person about whom the measurement was recorded."),
        ("measurement_concept_id", IntegerType(), False,
         "A foreign key to the standard measurement concept identifier in the Vocabulary."),
        ("measurement_date", DateType(), False,
         "The date of the measurement."),
        ("measurement_datetime", TimestampType(), True,
         "The date and time of the measurement."),
        ("measurement_time", StringType(), True,
         "The time of the measurement (in the event that MEASUREMENT_DATETIME is not well defined)."),
        ("measurement_type_concept_id", IntegerType(), False,
         "A foreign key to the predefined concept identifier in the Standardized Vocabularies reflecting the type of the measurement."),
        ("operator_concept_id", IntegerType(), True,
         "A foreign key to a standard concept identifier for the mathematical operator applied to the value."),
        ("value_as_number", FloatType(), True,
         "The measurement result stored as a number. This is applicable to measurements where the result is expressed as a numeric value."),
        ("value_as_concept_id", IntegerType(), True,
         "A foreign key to a standard concept identifier for a categorical result."),
        ("unit_concept_id", IntegerType(), True,
         "A foreign key to a standard concept identifier for the unit used in the measurement."),
        ("range_low", FloatType(), True,
         "The lower limit of the normal range of the measurement."),
        ("range_high", FloatType(), True,
         "The upper limit of the normal range of the measurement."),
        ("provider_id", LongType(), True,
         "A foreign key to the provider in the PROVIDER table who was responsible for taking the measurement."),
        ("visit_occurrence_id", LongType(), True,
         "A foreign key to the visit in the VISIT_OCCURRENCE table during which the measurement was taken."),
        ("visit_detail_id", LongType(), True,
         "A foreign key to the visit detail in the VISIT_DETAIL table during which the measurement was taken."),
        ("measurement_source_value", StringType(), True,
         "The measurement name as it appears in the source data."),
        ("measurement_source_concept_id", IntegerType(), True,
         "A foreign key to a concept that refers to the code used in the source."),
        ("unit_source_value", StringType(), True,
         "The source code for the unit as it appears in the source data."),
        ("unit_source_concept_id", IntegerType(), True,
         "A foreign key to a concept that refers to the unit code used in the source."),
        ("value_source_value", StringType(), True,
         "The source value associated with the structured value stored as numeric or concept."),
        ("measurement_event_id", LongType(), True,
         "A foreign key to the MEASUREMENT_EVENT table."),
        ("meas_event_field_concept_id", IntegerType(), True,
         "A foreign key to a concept that refers to the field in the MEASUREMENT_EVENT table."),
    )
])

# Mandatory measurement rules: a record failing any of these is meant to be
# dropped (intended for use with dlt.expect_all_or_drop).
mandatory_measurement_rules = dict(
    valid_measurement_id="measurement_id IS NOT NULL",
    valid_person="person_id IS NOT NULL",
    valid_concept="measurement_concept_id IS NOT NULL",
    valid_date="measurement_date IS NOT NULL",
    valid_type_concept="measurement_type_concept_id IS NOT NULL",
)

# Advisory measurement rules: violations are meant to be tracked without
# dropping records (intended for use with dlt.expect_all).
advisory_measurement_rules = dict(
    valid_concept_value="measurement_concept_id > 0",
    valid_type_concept_value="measurement_type_concept_id > 0",
    valid_operator="operator_concept_id IS NULL OR operator_concept_id >= 0",
    # Passes unless both value_as_number and value_as_concept_id are set.
    valid_value="""
        (value_as_number IS NOT NULL AND value_as_concept_id IS NULL) OR
        (value_as_number IS NULL AND value_as_concept_id IS NOT NULL) OR
        (value_as_number IS NULL AND value_as_concept_id IS NULL)
    """,
    valid_unit="unit_concept_id IS NULL OR unit_concept_id >= 0",
    valid_range="range_high IS NULL OR range_low IS NULL OR range_high >= range_low",
)

# First, create cached reference tables
@dlt.table(
    name="valid_persons_ref",
    comment="Cached valid person references",
    temporary=True
)
def get_valid_persons():
    """Return the distinct person_ids present in omop_person.

    The column is renamed to ``valid_person_id`` so that joins against source
    tables, which carry their own ``person_id``, stay unambiguous.

    No ``.persist()``: DLT materialises the returned query as this temporary
    table itself. Pinning the DataFrame in executor storage only costs memory,
    is never released, and is not supported on serverless compute.
    """
    return (dlt.read("omop_person")
            .select(col("person_id").alias("valid_person_id"))
            .distinct())

@dlt.table(
    name="valid_measurement_concepts_ref",
    comment="Cached valid measurement concepts",
    temporary=True
)
def get_valid_measurement_concepts():
    """Return the distinct, non-invalidated concept_ids in the Measurement domain.

    No domain expectation is attached to this table. DLT evaluates expectations
    against the columns this function returns, and only ``concept_id`` is
    returned. A constraint on ``domain_id`` therefore cannot be resolved and
    fails the update. The domain is already enforced by the filter below.

    No ``.persist()``: DLT materialises the returned query itself.
    """
    return (spark.table("3_lookup.omop.concept")
            .filter((col("invalid_reason").isNull()) &
                    (col("domain_id") == "Measurement"))
            .select("concept_id")
            .distinct())

@dlt.table(
    name="valid_visits_ref",
    comment="Cached valid visit references",
    temporary=True
)
def get_valid_visits():
    """Return visit keys and their time bounds from omop_visit_occurrence.

    Columns: visit_occurrence_id, visit_start_datetime, visit_end_datetime.
    These are used to clamp measurement datetimes into the visit window. No
    de-duplication is applied here; presumably visit_occurrence_id is already
    unique in omop_visit_occurrence (TODO confirm).

    No ``.persist()``: DLT materialises the returned query itself.
    """
    return (dlt.read("omop_visit_occurrence")
            .select("visit_occurrence_id",
                    "visit_start_datetime",
                    "visit_end_datetime"))

@dlt.table(
    name="valid_providers_ref",
    comment="Cached valid provider references",
    temporary=True
)
def get_valid_providers():
    """Return the distinct provider_ids present in omop_provider.

    Because the result is distinct, a left join against it never multiplies
    rows.

    No ``.persist()``: DLT materialises the returned query itself.
    """
    return (dlt.read("omop_provider")
            .select("provider_id")
            .distinct())

# Helper functions
def get_measurement_date(performed_dt, clinsig_dt):
    """Build a Column choosing the event datetime from two source timestamps.

    When both timestamps exist and lie more than six months apart (absolute
    ``months_between``), the clinically-significant update time wins.
    Otherwise the performed time is used, falling back to the
    clinically-significant time when the performed time is NULL.

    Args:
        performed_dt: Column with the performed date/time.
        clinsig_dt: Column with the clinically-significant update date/time.

    Returns:
        A Column expression (not evaluated here).
    """
    both_present = performed_dt.isNotNull() & clinsig_dt.isNotNull()
    far_apart = abs(months_between(performed_dt, clinsig_dt)) > 6
    fallback = coalesce(performed_dt, clinsig_dt)
    return when(both_present & far_apart, clinsig_dt).otherwise(fallback)

def process_common_columns(df, include_numeric=False):
    """Project a mapped event DataFrame onto the measurement column layout.

    Args:
        df: A ``map_*_events`` DataFrame that has already been validated
            against persons and measurement concepts.
        include_numeric: Truthy for numeric events, which supply the number,
            unit and normal-range columns. Falsy for coded/text/nomenclature
            events, which supply ``value_as_concept_id`` instead; their other
            value columns are typed NULLs.

    Returns:
        DataFrame with 16 columns in a fixed order: eight shared columns, then
        eight value columns. The order is the same in both modes so callers
        can union the results positionally.
    """
    numeric = bool(include_numeric)

    # A performer id of 0 means "no provider".
    performed_by = (when(col("PERFORMED_PRSNL_ID") == 0, None)
                    .otherwise(col("PERFORMED_PRSNL_ID")))
    measured_at = get_measurement_date(col("PERFORMED_DT_TM"),
                                       col("CLINSIG_UPDT_DT_TM"))

    shared_columns = [
        col("person_id").cast("bigint").alias("person_id"),
        col("OMOP_MANUAL_CONCEPT").cast("integer").alias("measurement_concept_id"),
        measured_at.cast("timestamp").alias("measurement_datetime"),
        # 32817 is the same type concept the death table labels "EHR".
        lit(32817).cast("integer").alias("measurement_type_concept_id"),
        performed_by.cast("bigint").alias("provider_id"),
        col("ENCNTR_ID").cast("bigint").alias("visit_occurrence_id"),
        col("EVENT_CD_DISPLAY").alias("measurement_source_value"),
        col("EVENT_ID").cast("bigint").alias("measurement_event_id"),
    ]

    # Each value column picks its numeric-source expression or a typed NULL.
    value_columns = [
        lit(None).cast("integer").alias("operator_concept_id"),
        (col("NUMERIC_RESULT") if numeric else lit(None))
            .cast("float").alias("value_as_number"),
        (lit(None) if numeric else col("OMOP_MANUAL_VALUE_CONCEPT"))
            .cast("integer").alias("value_as_concept_id"),
        (col("OMOP_MANUAL_UNITS") if numeric else lit(None))
            .cast("integer").alias("unit_concept_id"),
        (col("NORMAL_LOW") if numeric else lit(None))
            .cast("float").alias("range_low"),
        (col("NORMAL_HIGH") if numeric else lit(None))
            .cast("float").alias("range_high"),
        (col("UNIT_OF_MEASURE_DISPLAY") if numeric
         else lit(None).cast("string")).alias("unit_source_value"),
        (col("NUMERIC_RESULT").cast("string") if numeric
         else lit(None).cast("string")).alias("value_source_value"),
    ]

    return df.select(*shared_columns, *value_columns)

@dlt.table(
    name="combined_source_measurements",
    comment="Combined measurements from all sources",
    temporary=True
)
def create_combined_measurements():
    """Union the Measurement-domain rows of every mapped event table.

    Every bronze source (numeric, coded, text and nomenclature events) is
    validated the same way. A row is kept only when:
      - its person is in ``valid_persons_ref``,
      - its manual mapping targets the Measurement domain, and
      - its mapped concept is in ``valid_measurement_concepts_ref``.

    Kept rows are projected with ``process_common_columns``. Numeric events
    use ``include_numeric=True``; the others use the coded layout.

    Returns:
        DataFrame in the ``process_common_columns`` layout, hash-repartitioned
        into 200 partitions on (person_id, measurement_datetime).
    """
    valid_persons = dlt.read("valid_persons_ref")
    valid_measurements = dlt.read("valid_measurement_concepts_ref")

    def _measurement_rows(table_name):
        # Shared person / domain / concept validation for one bronze source.
        return (spark.table(f"4_prod.bronze.map_{table_name}")
            .join(broadcast(valid_persons),
                  col("person_id") == col("valid_person_id"))
            .filter(col("OMOP_MANUAL_CONCEPT_DOMAIN") == "Measurement")
            .join(broadcast(valid_measurements),
                  col("OMOP_MANUAL_CONCEPT") == col("concept_id")))

    combined = process_common_columns(_measurement_rows("numeric_events"),
                                      include_numeric=True)
    for table_name in ("coded_events", "text_events", "nomen_events"):
        # process_common_columns emits identical names for every source.
        # Matching by name guards against a future column reorder silently
        # misaligning a positional union.
        combined = combined.unionByName(
            process_common_columns(_measurement_rows(table_name)))

    # No .persist(): DLT materialises this temporary table itself.
    return combined.repartition(200, "person_id", "measurement_datetime")

@dlt.table(
    name="omop_measurement",
    comment="OMOP CDM Measurement table",
    schema=measurement_schema,
    table_properties={"quality": "gold"}
)
@dlt.expect_all_or_drop(mandatory_measurement_rules)
@dlt.expect_all(advisory_measurement_rules)
def create_omop_measurement():
    """Build the gold OMOP MEASUREMENT table from combined_source_measurements.

    Steps:
      1. Drop duplicate source rows on
         (person, concept, datetime, visit, provider).
      2. Left-join valid_visits_ref. For rows with a visit_occurrence_id,
         clamp measurement_datetime into
         [visit_start_datetime, visit_end_datetime]. Then derive
         measurement_date from the clamped value.
      3. Assign measurement_id with a global row_number.
      4. Replace provider_id with NULL when it is absent from
         valid_providers_ref.

    Rows failing ``mandatory_measurement_rules`` are dropped by DLT.
    ``advisory_measurement_rules`` violations are only recorded.
    """
    measurements = dlt.read("combined_source_measurements")

    measurements_deduped = measurements.dropDuplicates([
        "person_id", "measurement_concept_id", "measurement_datetime",
        "visit_occurrence_id", "provider_id"
    ])

    valid_visits = dlt.read("valid_visits_ref")
    valid_providers = dlt.read("valid_providers_ref")

    # No broadcast hint on visits. A hint forces a broadcast regardless of
    # table size, and the visit table grows with the data. The observation
    # table joins visits unhinted as well.
    # NOTE(review): visit_occurrence_id is kept even when no visit matches
    # (the visit bounds are then NULL and no clamping happens). Only
    # provider_id is checked for referential integrity below. Confirm intent.
    measurements_with_visits = (measurements_deduped
        .join(valid_visits, "visit_occurrence_id", "left")
        .withColumn(
            "measurement_datetime",
            when((col("visit_occurrence_id").isNotNull()) &
                 (col("measurement_datetime") > col("visit_end_datetime")),
                 col("visit_end_datetime"))
            .when((col("visit_occurrence_id").isNotNull()) &
                  (col("measurement_datetime") < col("visit_start_datetime")),
                  col("visit_start_datetime"))
            .otherwise(col("measurement_datetime")))
        .withColumn("measurement_date",
                    col("measurement_datetime").cast("date")))

    # The window has no partitionBy, so Spark evaluates it on a single
    # partition; that keeps ids dense and ordered by person.
    # provider_id and measurement_event_id are trailing tie-breakers. Clamping
    # can make formerly distinct rows share a datetime, and without them the
    # ids of tied rows depend on shuffle order and change between runs.
    window_spec = Window.orderBy(
        "person_id",
        "measurement_datetime",
        "measurement_concept_id",
        "visit_occurrence_id",
        "provider_id",
        "measurement_event_id"
    )

    measurements_with_id = measurements_with_visits.withColumn(
        "measurement_id", row_number().over(window_spec).cast("bigint"))

    m = measurements_with_id.alias("m")
    vp = broadcast(valid_providers).alias("vp")

    # valid_providers_ref is distinct, so this left join cannot duplicate rows.
    # measurement_id therefore stays unique without a trailing dropDuplicates.
    return (m
        .join(vp, col("m.provider_id") == col("vp.provider_id"), "left")
        .select(
            col("m.measurement_id"),
            col("m.person_id"),
            col("m.measurement_concept_id"),
            col("m.measurement_date"),
            col("m.measurement_datetime"),
            lit(None).cast("string").alias("measurement_time"),
            col("m.measurement_type_concept_id"),
            col("m.operator_concept_id"),
            col("m.value_as_number"),
            col("m.value_as_concept_id"),
            col("m.unit_concept_id"),
            col("m.range_low"),
            col("m.range_high"),
            # provider_id becomes NULL when it is not in valid_providers_ref.
            when(col("vp.provider_id").isNotNull(), col("m.provider_id"))
            .otherwise(lit(None)).alias("provider_id"),
            col("m.visit_occurrence_id"),
            lit(None).cast("bigint").alias("visit_detail_id"),
            col("m.measurement_source_value"),
            lit(0).cast("integer").alias("measurement_source_concept_id"),
            col("m.unit_source_value"),
            lit(0).cast("integer").alias("unit_source_concept_id"),
            col("m.value_source_value"),
            col("m.measurement_event_id"),
            lit(None).cast("integer").alias("meas_event_field_concept_id")
        ))



@dlt.table(
    name="summ_omop_measurement_by_type",
    comment="Analysis of measurements by concept"
)
def analyze_measurement_types():
    """Summarise omop_measurement per (measurement_concept_id, measurement_source_value).

    Reports for each group:
      - the row count,
      - distinct persons and distinct visits,
      - the mean and standard deviation of value_as_number,
      - the set of distinct value_as_concept_id values.

    Groups are ordered by row count, descending.
    """
    metrics = [
        F.count("*").alias("measurement_count"),
        F.count_distinct("person_id").alias("unique_patients"),
        F.count_distinct("visit_occurrence_id").alias("unique_visits"),
        F.mean("value_as_number").alias("avg_numeric_value"),
        F.stddev("value_as_number").alias("std_numeric_value"),
        F.collect_set("value_as_concept_id").alias("observed_values"),
    ]
    grouped = dlt.read("omop_measurement").groupBy(
        "measurement_concept_id", "measurement_source_value")
    return grouped.agg(*metrics).orderBy(F.desc("measurement_count"))

@dlt.table(
    name="summ_omop_measurement_ranges",
    comment="Analysis of measurement ranges and outliers"
)
def analyze_measurement_ranges():
    """Summarise measurements that carry at least one normal-range bound.

    Per (measurement_concept_id, measurement_source_value), reports:
      - how many rows have a range,
      - the lowest range_low and the highest range_high,
      - how many numeric values fall outside their own row's range.
    A NULL bound never counts as a violation.
    """
    value = F.col("value_as_number")
    has_range = F.col("range_low").isNotNull() | F.col("range_high").isNotNull()
    outside_range = value.isNotNull() & (
        (value < F.col("range_low")) | (value > F.col("range_high")))

    return (dlt.read("omop_measurement")
            .where(has_range)
            .groupBy("measurement_concept_id", "measurement_source_value")
            .agg(
                F.count("*").alias("measurements_with_ranges"),
                F.min("range_low").alias("min_range_low"),
                F.max("range_high").alias("max_range_high"),
                # count() ignores the NULL that when() yields for non-matches.
                F.count(F.when(outside_range, 1)).alias("out_of_range_count"),
            ))

In [0]:


# OMOP CDM schema for the observation table
observation_schema = StructType([
    StructField(field_name, field_type, nullable, metadata={"comment": description})
    # Rows are (column, Spark type, nullable, column comment), in table order.
    for field_name, field_type, nullable, description in [
        ("observation_id", LongType(), False,
         "A unique identifier for each observation."),
        ("person_id", LongType(), False,
         "A foreign key identifier to the Person about whom the observation was recorded."),
        ("observation_concept_id", IntegerType(), False,
         "A foreign key to the standard observation concept identifier in the Vocabulary."),
        ("observation_date", DateType(), False,
         "The date of the observation."),
        ("observation_datetime", TimestampType(), True,
         "The date and time of the observation."),
        ("observation_type_concept_id", IntegerType(), False,
         "A foreign key to the predefined concept identifier in the Standardized Vocabularies reflecting the type of the observation."),
        ("value_as_number", FloatType(), True,
         "The observation result stored as a number. This is applicable to observations where the result is expressed as a numeric value."),
        ("value_as_string", StringType(), True,
         "The observation result stored as a string."),
        ("value_as_concept_id", IntegerType(), True,
         "A foreign key to an observation result stored as a Concept ID."),
        ("qualifier_concept_id", IntegerType(), True,
         "A foreign key to a standard concept identifier for a qualifier."),
        ("unit_concept_id", IntegerType(), True,
         "A foreign key to a standard concept identifier for the unit."),
        ("provider_id", LongType(), True,
         "A foreign key to the provider in the PROVIDER table who was responsible for making the observation."),
        ("visit_occurrence_id", LongType(), True,
         "A foreign key to the visit in the VISIT_OCCURRENCE table during which the observation was recorded."),
        ("visit_detail_id", LongType(), True,
         "A foreign key to the visit detail in the VISIT_DETAIL table during which the observation was recorded."),
        ("observation_source_value", StringType(), True,
         "The observation code as it appears in the source data."),
        ("observation_source_concept_id", IntegerType(), True,
         "A foreign key to a concept that refers to the code used in the source."),
        ("unit_source_value", StringType(), True,
         "The source code for the unit as it appears in the source data."),
        ("qualifier_source_value", StringType(), True,
         "The source value associated with a qualifier to characterize the observation."),
        ("value_source_value", StringType(), True,
         "The source value associated with the structured value stored as numeric, string, or concept."),
        ("observation_event_id", LongType(), True,
         "A foreign key to the event that caused this observation to be made."),
        ("obs_event_field_concept_id", IntegerType(), True,
         "A foreign key to the predefined concept identifier in the Standardized Vocabularies reflecting how the event field is used."),
    ]
])

# Mandatory rules - these must be met or the record is dropped
# Expectation name -> Spark SQL predicate; failing rows are dropped by
# @dlt.expect_all_or_drop on omop_observation.
mandatory_observation_rules = dict(
    valid_observation_id="observation_id IS NOT NULL",
    valid_person="person_id IS NOT NULL",
    valid_concept="observation_concept_id IS NOT NULL",
    valid_date="observation_date IS NOT NULL",
    valid_type_concept="observation_type_concept_id IS NOT NULL",
)

# Advisory data quality rules - these are tracked but don't cause record drops
advisory_observation_rules = {
    "valid_concept_value": "observation_concept_id > 0",
    "valid_type_concept_value": "observation_type_concept_id > 0",
    "valid_qualifier": "qualifier_concept_id IS NULL OR qualifier_concept_id >= 0",
    "valid_unit": "unit_concept_id IS NULL OR unit_concept_id >= 0",
    # At most one of the three value representations may be populated; a row
    # with none of them is fine. This mirrors "valid_value" in
    # advisory_measurement_rules. The rule is written as a count because
    # enumerating "number OR string OR concept OR none" is always TRUE and
    # so could never flag a row.
    "valid_value": """
        (CASE WHEN value_as_number IS NOT NULL THEN 1 ELSE 0 END +
         CASE WHEN value_as_string IS NOT NULL THEN 1 ELSE 0 END +
         CASE WHEN value_as_concept_id IS NOT NULL THEN 1 ELSE 0 END) <= 1
    """,
    # Source values longer than 50 characters are flagged.
    "valid_source_values": """
        observation_source_value IS NULL OR LENGTH(observation_source_value) <= 50
    """
}

@dlt.table(
    name="valid_observation_concepts",
    comment="Valid observation concepts from OMOP vocabulary",
    temporary=True
)
def get_valid_observation_concepts():
    """Return the distinct, non-invalidated concept_ids in the Observation domain.

    No domain expectation is attached to this table. DLT evaluates expectations
    against the columns this function returns, and only ``concept_id`` is
    returned. A constraint on ``domain_id`` therefore cannot be resolved and
    fails the update. The domain is already enforced by the filter below.
    """
    return (spark.table("3_lookup.omop.concept")
            .filter((col("invalid_reason").isNull()) &
                    (col("domain_id") == "Observation"))
            .select("concept_id")
            .distinct())

@dlt.table(
    name="event_observations",
    comment="Observations from clinical events",
    temporary=True
)
def create_event_observations():
    """Build observation rows from the coded, text and nomenclature event tables.

    For each source, a row is kept only when:
      - its person exists in omop_person,
      - its manual mapping targets the Observation domain, and
      - its mapped concept is in valid_observation_concepts.

    The three projected results are unioned positionally, in the order coded,
    text, nomen, and then passed through distinct().
    """
    valid_persons = (dlt.read("omop_person")
                     .select(col("person_id").alias("valid_person_id"))
                     .distinct())
    valid_observations = dlt.read("valid_observation_concepts")

    def process_event_data(events_df, alias_prefix):
        """Validate one event source and project it onto the observation columns."""
        src = f"{alias_prefix}_e"
        performed = col("PERFORMED_DT_TM")
        clinsig = col("CLINSIG_UPDT_DT_TM")

        # Prefer the clinically-significant update time when the two
        # timestamps are more than six months apart; otherwise prefer the
        # performed time, falling back to the update time.
        observed_at = (when(performed.isNotNull() & clinsig.isNotNull()
                            & (abs(months_between(performed, clinsig)) > 6),
                            clinsig)
                       .otherwise(coalesce(performed, clinsig)))
        # A performer id of 0 means "no provider".
        performed_by = (when(col("PERFORMED_PRSNL_ID") == 0, None)
                        .otherwise(col("PERFORMED_PRSNL_ID")))

        validated = (events_df.alias(src)
            .join(valid_persons,
                  col(f"{src}.person_id") == col("valid_person_id"),
                  "inner")
            .filter(col("OMOP_MANUAL_CONCEPT_DOMAIN") == "Observation")
            .join(valid_observations,
                  col("OMOP_MANUAL_CONCEPT") == col("concept_id"),
                  "inner"))

        return validated.select(
            col(f"{src}.person_id").cast("bigint").alias("person_id"),
            col("OMOP_MANUAL_CONCEPT").cast("integer").alias("observation_concept_id"),
            observed_at.cast("timestamp").alias("observation_datetime"),
            lit(32817).cast("integer").alias("observation_type_concept_id"),
            lit(None).cast("float").alias("value_as_number"),
            lit(None).cast("string").alias("value_as_string"),
            col("OMOP_MANUAL_VALUE_CONCEPT").cast("integer").alias("value_as_concept_id"),
            lit(None).cast("integer").alias("qualifier_concept_id"),
            lit(None).cast("integer").alias("unit_concept_id"),
            performed_by.cast("bigint").alias("provider_id"),
            col("ENCNTR_ID").cast("bigint").alias("visit_occurrence_id"),
            lit(None).cast("bigint").alias("visit_detail_id"),
            col("EVENT_CD_DISPLAY").alias("observation_source_value"),
            # NOTE(review): EVENT_CD looks like a source-system code rather
            # than an OMOP concept_id; the measurement table uses 0 here.
            # Confirm intent.
            col("EVENT_CD").cast("integer").alias("observation_source_concept_id"),
            lit(None).cast("string").alias("unit_source_value"),
            lit(None).cast("string").alias("qualifier_source_value"),
            lit(None).cast("string").alias("value_source_value"),
            col("EVENT_ID").cast("bigint").alias("observation_event_id"),
            lit(None).cast("integer").alias("obs_event_field_concept_id"),
        )

    sources = [
        ("4_prod.bronze.map_coded_events", "coded"),
        ("4_prod.bronze.map_text_events", "text"),
        ("4_prod.bronze.map_nomen_events", "nomen"),
    ]
    per_source = [process_event_data(spark.table(table), prefix)
                  for table, prefix in sources]

    combined = per_source[0]
    for frame in per_source[1:]:
        combined = combined.unionAll(frame)
    return combined.distinct()

@dlt.table(
    name="problem_observations",
    comment="Observations from problems",
    temporary=True
)
def create_problem_observations():
    """Build observation rows from the mapped problem list.

    A problem is kept only when:
      - it maps to the Observation domain with a non-NULL concept,
      - it has a calculated encounter (CALC_ENCNTR),
      - its person exists in omop_person, and
      - its concept is in valid_observation_concepts.

    Column order and types match ``event_observations`` so the two can be
    unioned positionally.
    """
    problems = spark.table("4_prod.bronze.map_problem")
    valid_persons = dlt.read("omop_person").select(
        col("person_id").alias("valid_person_id")
    ).distinct()
    valid_observations = dlt.read("valid_observation_concepts")

    return (problems.alias("p")
        .filter((col("OMOP_CONCEPT_DOMAIN") == "Observation") &
                (col("OMOP_CONCEPT_ID").isNotNull()) &
                (col("CALC_ENCNTR").isNotNull()))
        .join(valid_persons,
              col("p.person_id") == col("valid_person_id"),
              "inner")
        .join(valid_observations,
              col("OMOP_CONCEPT_ID") == col("concept_id"),
              "inner")
        .select(
            col("p.person_id").cast("bigint"),
            col("OMOP_CONCEPT_ID").alias("observation_concept_id").cast("integer"),
            col("CALC_DT_TM").alias("observation_datetime").cast("timestamp"),
            lit(32817).cast("integer").alias("observation_type_concept_id"),
            lit(None).cast("float").alias("value_as_number"),
            lit(None).cast("string").alias("value_as_string"),
            lit(None).cast("integer").alias("value_as_concept_id"),
            lit(None).cast("integer").alias("qualifier_concept_id"),
            lit(None).cast("integer").alias("unit_concept_id"),
            # Status-personnel ids 0 and 1 are treated as "no provider".
            when(col("ACTIVE_STATUS_PRSNL_ID").isin([0, 1]), None)
            .otherwise(col("ACTIVE_STATUS_PRSNL_ID"))
            .alias("provider_id").cast("bigint"),
            col("CALC_ENCNTR").alias("visit_occurrence_id").cast("bigint"),
            # bigint, matching observation_schema (LongType) and event_observations.
            lit(None).cast("bigint").alias("visit_detail_id"),
            lit(None).cast("string").alias("observation_source_value"),
            lit(None).cast("integer").alias("observation_source_concept_id"),
            lit(None).cast("string").alias("unit_source_value"),
            lit(None).cast("string").alias("qualifier_source_value"),
            lit(None).cast("string").alias("value_source_value"),
            lit(None).cast("bigint").alias("observation_event_id"),
            lit(None).cast("integer").alias("obs_event_field_concept_id")
        ))

@dlt.table(
    name="omop_observation",
    comment="OMOP CDM Observation table - Contains clinical facts about a Person obtained in the context of examination, questioning or a procedure",
    schema=observation_schema,
    table_properties={"quality": "gold"}
)
@dlt.expect_all_or_drop(mandatory_observation_rules)
@dlt.expect_all(advisory_observation_rules)
def create_omop_observation():
    """Build the gold OMOP OBSERVATION table from event and problem observations.

    Steps:
      1. Union event_observations and problem_observations and drop exact
         duplicates.
      2. Left-join omop_visit_occurrence. For rows with a visit_occurrence_id,
         clamp observation_datetime into
         [visit_start_datetime, visit_end_datetime]. Then derive
         observation_date from the clamped value.
      3. Drop duplicates on
         (person, datetime, concept, visit, value_as_concept_id).
      4. Assign observation_id with a global row_number.
      5. Replace provider_id with NULL when it is absent from omop_provider.

    Rows failing ``mandatory_observation_rules`` are dropped by DLT.
    ``advisory_observation_rules`` violations are only recorded.
    """
    combined_observations = (dlt.read("event_observations")
                             .unionAll(dlt.read("problem_observations"))
                             .distinct())

    valid_visits = dlt.read("omop_visit_occurrence").select(
        "visit_occurrence_id",
        "visit_start_datetime",
        "visit_end_datetime"
    )
    valid_providers = dlt.read("omop_provider").select("provider_id").distinct()

    # NOTE(review): visit_occurrence_id is kept even when no visit matches
    # (the bounds are then NULL and no clamping happens). Only provider_id is
    # checked for referential integrity below. Confirm intent.
    observations_with_visits = (combined_observations
        .join(valid_visits, "visit_occurrence_id", "left")
        .withColumn(
            "observation_datetime",
            when((col("visit_occurrence_id").isNotNull()) &
                 (col("observation_datetime") > col("visit_end_datetime")),
                 col("visit_end_datetime"))
            .when((col("visit_occurrence_id").isNotNull()) &
                  (col("observation_datetime") < col("visit_start_datetime")),
                  col("visit_start_datetime"))
            .otherwise(col("observation_datetime")))
        .withColumn("observation_date",
                    col("observation_datetime").cast("date")))

    observations_deduplicated = observations_with_visits.dropDuplicates([
        "person_id", "observation_datetime", "observation_concept_id",
        "visit_occurrence_id", "value_as_concept_id"
    ])

    # The window has no partitionBy, so Spark evaluates it on a single
    # partition; that keeps ids dense and ordered by person.
    # value_as_concept_id completes the dedup key, which makes the ordering
    # total. Ids are then reproducible instead of depending on shuffle order
    # for tied rows.
    window_spec = Window.orderBy(
        "person_id",
        "observation_datetime",
        "observation_concept_id",
        "visit_occurrence_id",
        "provider_id",
        "value_as_concept_id"
    )

    observations_with_id = observations_deduplicated.withColumn(
        "observation_id", row_number().over(window_spec).cast("bigint"))

    o = observations_with_id.alias("o")
    vp = broadcast(valid_providers).alias("vp")

    # valid_providers is distinct, so this left join cannot duplicate rows.
    # observation_id therefore stays unique without a trailing dropDuplicates.
    return (o
        .join(vp, col("o.provider_id") == col("vp.provider_id"), "left")
        .select(
            col("o.observation_id"),
            col("o.person_id"),
            col("o.observation_concept_id"),
            col("o.observation_date"),
            col("o.observation_datetime"),
            col("o.observation_type_concept_id"),
            col("o.value_as_number"),
            col("o.value_as_string"),
            col("o.value_as_concept_id"),
            col("o.qualifier_concept_id"),
            col("o.unit_concept_id"),
            # provider_id becomes NULL when it is not in omop_provider.
            when(col("vp.provider_id").isNotNull(), col("o.provider_id"))
            .otherwise(lit(None)).alias("provider_id"),
            col("o.visit_occurrence_id"),
            col("o.visit_detail_id"),
            col("o.observation_source_value"),
            col("o.observation_source_concept_id"),
            col("o.unit_source_value"),
            col("o.qualifier_source_value"),
            col("o.value_source_value"),
            col("o.observation_event_id"),
            col("o.obs_event_field_concept_id")
        ))

@dlt.table(
    name="qual_omop_observation",
    comment="Quality metrics for observation table"
)
def qual_omop_observation():
    """Return a one-row summary of how often key observation columns are populated.

    The first column is the total row count. Each later column counts the
    non-NULL values of one source column.
    """
    non_null_counts = [
        ("provider_id", "observations_with_provider"),
        ("visit_occurrence_id", "observations_with_visit"),
        ("value_as_number", "numeric_observations"),
        ("value_as_string", "string_observations"),
        ("value_as_concept_id", "coded_observations"),
        ("qualifier_concept_id", "qualified_observations"),
        ("observation_source_value", "observations_with_source"),
    ]
    return dlt.read("omop_observation").agg(
        F.count("*").alias("total_observations"),
        *(F.count(column).alias(label) for column, label in non_null_counts),
    )

@dlt.table(
    name="summ_omop_observation_by_concept",
    comment="Analysis of observations by concept"
)
def analyze_observation_concepts():
    """Summarise omop_observation per (observation_concept_id, observation_source_value).

    Reports for each group:
      - the row count,
      - distinct persons and distinct visits,
      - non-NULL counts for each of the three value columns.

    Groups are ordered by row count, descending.
    """
    populated = [
        ("value_as_number", "numeric_values"),
        ("value_as_string", "string_values"),
        ("value_as_concept_id", "coded_values"),
    ]
    aggregates = [
        F.count("*").alias("observation_count"),
        F.count_distinct("person_id").alias("unique_patients"),
        F.count_distinct("visit_occurrence_id").alias("unique_visits"),
    ] + [F.count(column).alias(label) for column, label in populated]

    return (dlt.read("omop_observation")
            .groupBy("observation_concept_id", "observation_source_value")
            .agg(*aggregates)
            .orderBy(F.desc("observation_count")))

@dlt.table(
    name="summ_omop_observation_timeline",
    comment="Analysis of observation timing patterns"
)
def analyze_observation_timeline():
    """Summarise observations per visit.

    Rows without a visit are excluded. For each visit_occurrence_id, reports:
      - the observation count,
      - the number of distinct observation concepts,
      - the earliest and latest observation_datetime.
    """
    with_visit = dlt.read("omop_observation").where(
        F.col("visit_occurrence_id").isNotNull())
    return with_visit.groupBy("visit_occurrence_id").agg(
        F.count("*").alias("observation_count"),
        F.count_distinct("observation_concept_id").alias("unique_observation_count"),
        F.min("observation_datetime").alias("first_observation"),
        F.max("observation_datetime").alias("last_observation"),
    )

In [0]:


# OMOP CDM schema for the death table.
death_schema = StructType([
    StructField(field_name, field_type, nullable, metadata={"comment": description})
    # Rows are (column, Spark type, nullable, column comment), in table order.
    for field_name, field_type, nullable, description in [
        ("person_id", LongType(), False,
         "A foreign key identifier to the deceased Person."),
        ("death_date", DateType(), False,
         "The date the person was deceased."),
        ("death_datetime", TimestampType(), True,
         "The date and time the person was deceased."),
        ("death_type_concept_id", IntegerType(), False,
         "A foreign key to the predefined concept identifier in the Standardized Vocabularies reflecting how the death was represented in the source data."),
        ("cause_concept_id", IntegerType(), True,
         "A foreign key to the predefined concept identifier in the Standardized Vocabularies reflecting the cause of death"),
        ("cause_source_value", StringType(), True,
         "The source code for the cause of death as it appears in the source data."),
        ("cause_source_concept_id", IntegerType(), True,
         "A foreign key to a concept that refers to the code used in the source."),
    ]
])

# Mandatory rules - these must be met or the record is dropped
# Expectation name -> Spark SQL predicate; failing rows are dropped by
# @dlt.expect_all_or_drop on omop_death.
mandatory_death_rules = dict(
    valid_person="person_id IS NOT NULL",
    valid_death_date="death_date IS NOT NULL",
    valid_type_concept="death_type_concept_id IS NOT NULL",
)

# Advisory data quality rules - these are tracked but don't cause record drops
# Expectation name -> Spark SQL predicate. Applied with @dlt.expect_all, so
# violations are only counted; valid_dates requires the date part of
# death_datetime to equal death_date.
advisory_death_rules = dict(
    valid_cause_concept="cause_concept_id IS NULL OR cause_concept_id >= 0",
    valid_source_concept="cause_source_concept_id IS NULL OR cause_source_concept_id >= 0",
    valid_dates="death_datetime IS NULL OR CAST(death_date AS DATE) = CAST(death_datetime AS DATE)",
    valid_source_length="cause_source_value IS NULL OR LENGTH(cause_source_value) <= 50",
)

@dlt.table(
    name="omop_death",
    comment="OMOP CDM Death table - Contains the clinical event for how and when a Person dies",
    schema=death_schema,
    table_properties={"quality": "gold"}
)
@dlt.expect_all_or_drop(mandatory_death_rules)
@dlt.expect_all(advisory_death_rules)
def create_omop_death():
    """
    Creates the OMOP Death table from 4_prod.bronze.map_death.

    Only persons present in omop_person are kept.

    The moment of death is DECEASED_DT_TM when present, otherwise
    CALC_DEATH_DATE. Both death_date and death_datetime derive from that one
    expression, so they describe the same day (advisory rule "valid_dates").

    Future dates are clamped:
      - death_date is capped at current_date();
      - a death_datetime later than now is reset to midnight of the clamped
        death_date.

    NOTE(review): no de-duplication per person is done here. If map_death
    can hold several rows for one person, several death rows are emitted.
    Confirm against the source.
    """
    death_data = spark.table("4_prod.bronze.map_death").alias("d")

    valid_persons = dlt.read("omop_person").select(
        col("person_id").alias("valid_person_id")
    ).distinct()

    # Best available moment of death. Both output date columns use it so they
    # cannot disagree. Previously death_datetime came only from
    # CALC_DEATH_DATE, even when DECEASED_DT_TM was present.
    death_moment = coalesce(col("DECEASED_DT_TM"), col("CALC_DEATH_DATE"))

    # concat_ws skips NULL inputs and returns "" when every input is NULL.
    # Map "" back to NULL so an absent description is not counted as source
    # information.
    source_desc = concat_ws(" - ",
                            col("DECEASED_SOURCE_DESC"),
                            col("DECEASED_METHOD_DESC"))

    omop_death = (death_data
        .join(valid_persons,
              col("d.PERSON_ID") == col("valid_person_id"),
              "inner")
        .select(
            col("d.PERSON_ID").cast("bigint").alias("person_id"),
            death_moment.cast("date").alias("death_date"),
            death_moment.cast("timestamp").alias("death_datetime"),
            lit(32817).cast("integer")
                .alias("death_type_concept_id"),  # EHR
            # No direct cause of death in source
            lit(0).cast("integer").alias("cause_concept_id"),
            when(source_desc != "", source_desc).alias("cause_source_value"),
            lit(0).cast("integer").alias("cause_source_concept_id")
        ))

    # Filter on the projected death_date rather than the raw source columns,
    # which the select above has already dropped.
    return (omop_death
        .where(col("death_date").isNotNull())
        # Ensure death_date doesn't exceed current date
        .withColumn("death_date",
            when(col("death_date") > current_date(), current_date())
            .otherwise(col("death_date")))
        # A future death_datetime is reset to midnight of the clamped date.
        .withColumn("death_datetime",
            when(col("death_datetime") > current_timestamp(),
                 to_timestamp(col("death_date")))
            .otherwise(col("death_datetime"))))

@dlt.table(
    name="qual_omop_death",
    comment="Quality metrics for death table"
)
def qual_omop_death():
    """Tracks quality metrics for death records.

    Returns a single row with:
      total_deaths             -- number of rows in omop_death
      deaths_with_precise_time -- rows with a non-null death_datetime
      deaths_with_cause        -- rows with cause_concept_id > 0
      deaths_with_source_info  -- rows with a non-empty cause_source_value
      earliest_death / latest_death -- min / max death_date
    """
    death_data = dlt.read("omop_death")

    # cause_source_value is produced with concat_ws in the death transform,
    # which returns "" (not NULL) when all of its inputs are NULL. Counting
    # non-null values alone would therefore report every row as having
    # source information, so empty strings are treated as missing too.
    has_source_info = (F.col("cause_source_value").isNotNull()
                       & (F.col("cause_source_value") != ""))

    return death_data.agg(
        F.count("*").alias("total_deaths"),
        F.count("death_datetime").alias("deaths_with_precise_time"),
        F.count(F.when(F.col("cause_concept_id") > 0, 1))
            .alias("deaths_with_cause"),
        F.count(F.when(has_source_info, 1))
            .alias("deaths_with_source_info"),
        F.min("death_date").alias("earliest_death"),
        F.max("death_date").alias("latest_death"),
    )

@dlt.table(
    name="summ_omop_death_by_year",
    comment="Death count analysis by year"
)
def analyze_deaths_by_year():
    """Aggregate omop_death by calendar year of death_date.

    One row per death_year (ascending) with the number of deaths, how many
    carry a death_datetime, how many have a positive cause_concept_id, and
    the distinct cause_source_value strings observed that year.
    """
    deaths = dlt.read("omop_death")
    deaths_with_year = deaths.withColumn("death_year", F.year("death_date"))

    yearly_metrics = [
        F.count("*").alias("death_count"),
        F.count(F.when(F.col("death_datetime").isNotNull(), 1))
            .alias("timestamped_deaths"),
        F.count(F.when(F.col("cause_concept_id") > 0, 1))
            .alias("deaths_with_cause"),
        F.collect_set("cause_source_value").alias("recorded_causes"),
    ]

    per_year = deaths_with_year.groupBy("death_year").agg(*yearly_metrics)
    return per_year.orderBy("death_year")

In [0]:


# OMOP CDM schema for the drug_era table; passed as `schema=` to the
# omop_drug_era DLT table. Each field's "comment" metadata entry holds its
# column description.
# NOTE(review): drug_era_id and person_id are IntegerType here, while
# dose_era_schema and observation_period_schema declare person_id as
# LongType. Confirm source person ids fit in 32 bits.
drug_era_schema = StructType([
    StructField("drug_era_id", IntegerType(), False,
                metadata={"comment": "A unique identifier for each drug era."}),
    StructField("person_id", IntegerType(), False,
                metadata={"comment": "A foreign key identifier to the person who is subjected to the drug during the drug era."}),
    StructField("drug_concept_id", IntegerType(), False, 
                metadata={"comment": "A foreign key that refers to a standard concept identifier in the Vocabulary for the drug concept."}),
    StructField("drug_era_start_date", DateType(), False,
                metadata={"comment": "The start date for the drug era constructed from the individual instances of drug exposures. It is the start date of the very first chronologically recorded instance of utilization of a drug."}),
    # NOTE(review): omop_drug_era currently derives this as the last exposure
    # START date + 30 days, not from exposure end dates as the description
    # says. Confirm which is intended.
    StructField("drug_era_end_date", DateType(), False,
                metadata={"comment": "The end date for the drug era constructed from the individual instance of drug exposures. It is the end date of the final continuously recorded instance of utilization of a drug."}),  
    StructField("drug_exposure_count", IntegerType(), True,
                metadata={"comment": "The number of individual drug exposure occurrences used to construct the drug era."}),
    # NOTE(review): omop_drug_era fills this with the sum of the <= 30 day
    # gaps between consecutive exposure start dates inside the era. That
    # differs from the wording of the description below.
    StructField("gap_days", IntegerType(), True, 
                metadata={"comment": "The number of days that separates two drugs that are adjacent to each other, if there is a gap of more than 30 days between two drug eras, then they are considered two separate eras."})
])

# Mandatory rules - these must be met or the record is dropped
# (applied to omop_drug_era via dlt.expect_all_or_drop)
mandatory_drug_era_rules = {
    "valid_drug_era_id": "drug_era_id IS NOT NULL",
    "valid_person": "person_id IS NOT NULL",  
    "valid_drug": "drug_concept_id IS NOT NULL",
    "valid_start_date": "drug_era_start_date IS NOT NULL",
    "valid_end_date": "drug_era_end_date IS NOT NULL"
}

# Advisory data quality rules - these are tracked but don't cause record drops
# (applied to omop_drug_era via dlt.expect_all)
advisory_drug_era_rules = {
    "valid_concept_value": "drug_concept_id > 0",
    "valid_dates": "drug_era_end_date >= drug_era_start_date",
    "valid_exposure_count": "drug_exposure_count IS NULL OR drug_exposure_count > 0",
    "valid_gap_days": "gap_days IS NULL OR gap_days >= 0"
}

@dlt.table(
    name="valid_ingredient_concepts",
    comment="Valid ingredient concepts from OMOP vocabulary",
    temporary=True
)
# DLT evaluates expectations against the rows this function returns, and the
# result carries only concept_id, so a constraint on concept_class_id cannot
# be resolved there. The Ingredient restriction is enforced by the filter in
# the query itself.
@dlt.expect_or_fail("valid_ingredients", "concept_id IS NOT NULL")
def get_valid_ingredient_concepts():
    """Gets valid ingredient concepts from the OMOP vocabulary.

    Returns the distinct concept_id of every concept in
    3_lookup.omop.concept whose concept_class_id is 'Ingredient' and whose
    invalid_reason is NULL.

    Output column: concept_id.
    """
    return (spark.table("3_lookup.omop.concept")
            .filter(F.col("invalid_reason").isNull())
            .filter(F.col("concept_class_id") == "Ingredient")
            .select("concept_id")
            .distinct())

@dlt.table(
    name="ingredient_relationships",
    comment="Ingredient level concepts and their relationships",
    temporary=True
)
def get_ingredient_relationships():
    """Return ancestor/descendant concept pairs rooted at a valid ingredient.

    Rows of 3_lookup.omop.concept_ancestor are kept when their
    ancestor_concept_id appears in valid_ingredient_concepts.

    Output columns: ancestor_concept_id, descendant_concept_id.
    """
    ingredient_ids = dlt.read("valid_ingredient_concepts")
    ancestry = spark.table("3_lookup.omop.concept_ancestor")

    ancestor_is_ingredient = F.col("ancestor_concept_id") == F.col("concept_id")
    matched = ancestry.join(ingredient_ids, ancestor_is_ingredient)
    return matched.select("ancestor_concept_id", "descendant_concept_id")

@dlt.table(
    name="drug_exposures_with_ingredients",
    comment="Drug exposures mapped to ingredient level",
    temporary=True
)
def create_drug_exposures_with_ingredients():
    """Roll each drug exposure up to its ingredient concept(s).

    Steps:
      1. keep exposures whose drug_concept_id is non-null and positive;
      2. keep only exposures whose person_id exists in omop_person;
      3. left-join ingredient_relationships on the exposure's concept as
         descendant, replacing drug_concept_id with each matching ancestor
         (or keeping the original concept when there is no match);
      4. keep only rows whose resulting concept is in
         valid_ingredient_concepts.

    Output columns: person_id, drug_concept_id, drug_exposure_start_date.
    """
    has_drug = (F.col("drug_concept_id").isNotNull()
                & (F.col("drug_concept_id") > 0))
    exposures = dlt.read("omop_drug_exposure").filter(has_drug)

    known_people = (dlt.read("omop_person")
                    .select(F.col("person_id").alias("valid_person_id"))
                    .distinct())

    ancestry = dlt.read("ingredient_relationships")
    ingredient_ids = dlt.read("valid_ingredient_concepts")

    de = exposures.alias("de")
    person_matched = de.join(
        known_people,
        F.col("de.person_id") == F.col("valid_person_id"),
        "inner")
    with_ancestor = person_matched.join(
        ancestry,
        F.col("de.drug_concept_id") == F.col("descendant_concept_id"),
        "left")

    rolled_up = with_ancestor.select(
        F.col("de.person_id"),
        F.coalesce(F.col("ancestor_concept_id"), F.col("de.drug_concept_id"))
            .alias("drug_concept_id"),
        F.col("de.drug_exposure_start_date"),
    )

    ingredient_only = rolled_up.join(
        ingredient_ids, F.col("drug_concept_id") == F.col("concept_id"))
    return ingredient_only.select(
        "person_id", "drug_concept_id", "drug_exposure_start_date")

@dlt.table(
    name="drug_exposure_periods",
    comment="Drug exposure periods with gap analysis",
    temporary=True
)
def create_drug_exposure_periods():
    """
    Creates drug exposure periods by analyzing gaps between exposures.

    For every (person_id, drug_concept_id), exposures are ordered by
    drug_exposure_start_date and each row gets:

    * prev_start_date -- start date of the previous exposure (NULL for the first)
    * gap_days        -- days since the previous exposure's start (0 for the first)
    * era_group       -- running count of gaps > 30 days up to this row's
                         start date; rows sharing a value form one drug era
    """
    exposures = dlt.read("drug_exposures_with_ingredients")

    # An ordered window uses the default RANGE UNBOUNDED PRECEDING .. CURRENT
    # ROW frame, so exposures tied on start date always share an era_group.
    window_spec = (Window
                   .partitionBy("person_id", "drug_concept_id")
                   .orderBy("drug_exposure_start_date"))

    # Measure the gap BACKWARDS (lag), so a gap > 30 days opens a new era at
    # the current row. Measuring forwards (lead) flags the last row of an era
    # instead, and the running sum then moves that row into the following era.
    exposures_with_gaps = (exposures
        .withColumn(
            "prev_start_date",
            F.lag("drug_exposure_start_date").over(window_spec))
        .withColumn(
            "gap_days",
            F.when(F.col("prev_start_date").isNotNull(),
                   F.datediff(F.col("drug_exposure_start_date"),
                              F.col("prev_start_date")))
            .otherwise(0)))

    # Running count of era breaks = era number within the person/drug group.
    return (exposures_with_gaps
        .withColumn(
            "era_group",
            F.sum(F.when(F.col("gap_days") > 30, 1).otherwise(0))
            .over(window_spec)))

@dlt.table(
    name="omop_drug_era",
    comment="OMOP CDM Drug Era table - Contains records of the span of time when the Person is assumed to be exposed to a particular active ingredient",
    schema=drug_era_schema,
    table_properties={"quality": "gold"}
)
@dlt.expect_all_or_drop(mandatory_drug_era_rules)
@dlt.expect_all(advisory_drug_era_rules)
def create_omop_drug_era():
    """
    Build the OMOP drug_era table from era-grouped exposure periods.

    Each (person_id, drug_concept_id, era_group) becomes one era:
      * drug_era_start_date  -- earliest exposure start date in the group
      * drug_era_end_date    -- latest exposure start date + 30 days
      * drug_exposure_count  -- number of exposure rows in the group
      * gap_days             -- sum of the row gap_days values that are <= 30
    drug_era_id is a 1-based row number ordered by person, drug and era start.
    NOTE(review): ids are narrowed to 32-bit integers to match
    drug_era_schema. Confirm source person ids fit.
    """
    periods = dlt.read("drug_exposure_periods")

    in_window_gap = (F.when(F.col("gap_days") <= 30, F.col("gap_days"))
                     .otherwise(0))
    per_era = periods.groupBy("person_id", "drug_concept_id", "era_group").agg(
        F.min("drug_exposure_start_date").alias("drug_era_start_date"),
        F.max("drug_exposure_start_date").alias("last_exposure_date"),
        F.count("*").alias("drug_exposure_count"),
        F.sum(in_window_gap).alias("gap_days"),
    )

    eras = per_era.select(
        "person_id",
        "drug_concept_id",
        "drug_era_start_date",
        F.date_add(F.col("last_exposure_date"), 30).alias("drug_era_end_date"),
        "drug_exposure_count",
        "gap_days",
    )

    id_order = Window.orderBy("person_id", "drug_concept_id", "drug_era_start_date")
    numbered = eras.withColumn("drug_era_id", F.row_number().over(id_order))

    typed_columns = (
        ("drug_era_id", "integer"),
        ("person_id", "integer"),
        ("drug_concept_id", "integer"),
        ("drug_era_start_date", "date"),
        ("drug_era_end_date", "date"),
        ("drug_exposure_count", "integer"),
        ("gap_days", "integer"),
    )
    return numbered.select(*(F.col(name).cast(dtype) for name, dtype in typed_columns))

@dlt.table(
    name="qual_omop_drug_era",
    comment="Quality metrics for drug era table"
)
def qual_omop_drug_era():
    """Produce a one-row summary of omop_drug_era.

    The row holds era, patient and drug counts, plus the average exposures
    per era, average gap_days and average era length in days (end minus
    start).
    """
    eras = dlt.read("omop_drug_era")
    era_length_days = F.datediff(F.col("drug_era_end_date"),
                                 F.col("drug_era_start_date"))

    metrics = [
        F.count("*").alias("total_eras"),
        F.count_distinct("person_id").alias("unique_patients"),
        F.count_distinct("drug_concept_id").alias("unique_drugs"),
        F.avg("drug_exposure_count").alias("avg_exposures_per_era"),
        F.avg("gap_days").alias("avg_gap_days"),
        F.avg(era_length_days).alias("avg_era_length_days"),
    ]
    return eras.agg(*metrics)

@dlt.table(
    name="summ_omop_drug_era_patterns",
    comment="Analysis of drug era patterns"
)
def analyze_drug_era_patterns():
    """Analyzes patterns in drug eras, one output row per drug_concept_id.

    Eras are first rolled up per (drug_concept_id, person_id), so every
    patient contributes exactly one row to the per-patient averages.
    Averaging a per-patient window value over era rows would weight each
    patient by their number of eras; e.g. patients with 3 and 1 eras would
    give 2.5 instead of 2 eras per patient.

    Output columns:
      total_eras, unique_patients, avg_eras_per_patient,
      avg_total_exposure_days (per patient), avg_exposures_per_era
    Ordered by unique_patients descending.
    """
    drug_era_data = dlt.read("omop_drug_era")

    per_patient = (drug_era_data
        .groupBy("drug_concept_id", "person_id")
        .agg(
            F.count("*").alias("era_count"),
            F.sum(F.datediff(F.col("drug_era_end_date"),
                             F.col("drug_era_start_date")))
                .alias("total_exposure_days"),
            # Sum plus non-null count, so the per-era average below equals
            # avg(drug_exposure_count) over the original era rows.
            F.sum("drug_exposure_count").alias("exposure_sum"),
            F.count("drug_exposure_count").alias("exposure_n")
        ))

    return (per_patient
        .groupBy("drug_concept_id")
        .agg(
            F.sum("era_count").alias("total_eras"),
            F.count_distinct("person_id").alias("unique_patients"),
            F.avg("era_count").alias("avg_eras_per_patient"),
            F.avg("total_exposure_days").alias("avg_total_exposure_days"),
            (F.sum("exposure_sum") / F.sum("exposure_n"))
                .alias("avg_exposures_per_era")
        )
        .orderBy(F.desc("unique_patients")))

In [0]:

# OMOP CDM schema for the dose_era table; passed as `schema=` to the
# omop_dose_era DLT table. Each field's "comment" metadata entry holds its
# column description.
dose_era_schema = StructType([
    StructField("dose_era_id", LongType(), False,
                metadata={"comment": "A unique identifier for each Dose Era."}),
    StructField("person_id", LongType(), False,
                metadata={"comment": "A foreign key identifier to the Person who is subjected to the drug during the drug era."}),
    StructField("drug_concept_id", IntegerType(), False,
                metadata={"comment": "A foreign key that refers to a Standard Concept identifier for the active Ingredient Concept."}),
    StructField("unit_concept_id", IntegerType(), False,
                metadata={"comment": "A foreign key that refers to a Standard Concept identifier for the unit concept."}),
    # Populated from the per-day dose computed in daily_dose_data.
    StructField("dose_value", FloatType(), False,
                metadata={"comment": "The numeric value of the daily dose."}),
    StructField("dose_era_start_date", DateType(), False,
                metadata={"comment": "The start date for the drug era constructed from the individual instances of drug exposures."}),
    StructField("dose_era_end_date", DateType(), False,
                metadata={"comment": "The end date for the drug era constructed from the individual instance of drug exposures."})
])

# Mandatory rules for dose_era validation: applied to omop_dose_era with
# dlt.expect_all_or_drop, so a row violating any of them is dropped
mandatory_dose_era_rules = {
    "valid_dose_era_id": "dose_era_id IS NOT NULL",
    "valid_person": "person_id IS NOT NULL",
    "valid_drug_concept": "drug_concept_id IS NOT NULL",
    "valid_unit_concept": "unit_concept_id IS NOT NULL",
    "valid_dose_value": "dose_value IS NOT NULL",
    "valid_start_date": "dose_era_start_date IS NOT NULL",
    "valid_end_date": "dose_era_end_date IS NOT NULL" 
}

# Advisory data quality rules: applied to omop_dose_era with dlt.expect_all,
# so violations are tracked but the row is kept
advisory_dose_era_rules = {
    "valid_dose_value_positive": "dose_value > 0",
    "valid_dates": "dose_era_end_date >= dose_era_start_date",
    "valid_concept_values": "drug_concept_id > 0 AND unit_concept_id > 0"
}

@dlt.table(
    name="base_dose_data",
    comment="Extracted dosage information from medication administration",
    temporary=True
)
def create_base_dose_data():
    """
    Extract one dose row per administered medication event.

    Source: 4_prod.bronze.map_med_admin rows whose event_type_display is
    "Administered". DOSE_IN_MG takes precedence over DOSE_IN_ML, and the
    matching unit concept id (8576 for mg, 8587 for mL) goes into
    unit_concept_id. Rows missing a dose, drug concept, start date or end
    date are discarded.

    Output columns: person_id, drug_concept_id, start_date, end_date,
    dose_value, unit_concept_id.
    """
    MILLIGRAM_CONCEPT_ID = 8576
    MILLILITER_CONCEPT_ID = 8587

    administered = (spark.table("4_prod.bronze.map_med_admin")
                    .filter(F.col("event_type_display") == "Administered"))

    mg_dose, ml_dose = F.col("DOSE_IN_MG"), F.col("DOSE_IN_ML")
    has_mg, has_ml = mg_dose.isNotNull(), ml_dose.isNotNull()

    dose_amount = F.when(has_mg, mg_dose).when(has_ml, ml_dose).otherwise(None)
    dose_unit = (F.when(has_mg, F.lit(MILLIGRAM_CONCEPT_ID))
                 .when(has_ml, F.lit(MILLILITER_CONCEPT_ID))
                 .otherwise(None))

    doses = administered.select(
        F.col("PERSON_ID").cast("long").alias("person_id"),
        F.col("OMOP_CONCEPT_ID").cast("int").alias("drug_concept_id"),
        F.to_date(F.col("ADMIN_START_DT_TM")).alias("start_date"),
        F.to_date(F.col("ADMIN_END_DT_TM")).alias("end_date"),
        dose_amount.alias("dose_value"),
        dose_unit.alias("unit_concept_id"),
    )

    for required in ("dose_value", "drug_concept_id", "start_date", "end_date"):
        doses = doses.filter(F.col(required).isNotNull())
    return doses

@dlt.table(
    name="ingredient_dose_data",
    comment="Dose information at the ingredient level",
    temporary=True
)
def get_ingredient_dose_data():
    """
    Keep only dose rows whose drug_concept_id is itself a valid ingredient.

    base_dose_data is inner-joined to valid_ingredient_concepts on
    drug_concept_id == concept_id, and the join key concept_id is dropped.
    NOTE(review): unlike the drug-era path, no concept_ancestor roll-up
    happens here, so administrations coded to non-ingredient drug concepts
    are discarded. Confirm this is intended.
    """
    doses = dlt.read("base_dose_data")
    ingredients = dlt.read("valid_ingredient_concepts")

    is_ingredient = doses["drug_concept_id"] == ingredients["concept_id"]
    matched = doses.join(ingredients, is_ingredient, "inner")
    return matched.drop("concept_id")

@dlt.table(
    name="daily_dose_data",
    comment="Daily dose calculations",
    temporary=True
)
def calculate_daily_dose():
    """
    Append daily_dose = dose_value / days of exposure to every row.

    Days of exposure is datediff(end_date, start_date). Spans of zero or
    fewer days count as a single day, so same-day administrations keep
    their full dose.
    """
    ingredient_df = dlt.read("ingredient_dose_data")

    span_days = F.datediff(F.col("end_date"), F.col("start_date"))
    exposure_days = F.when(span_days <= 0, F.lit(1)).otherwise(span_days)

    return ingredient_df.withColumn(
        "daily_dose", F.col("dose_value") / exposure_days)

@dlt.table(
    name="omop_dose_era",
    comment="OMOP CDM Dose Era table - Contains records of constant dose exposure to a specific ingredient",
    schema=dose_era_schema,
    table_properties={"quality": "gold"}
)
@dlt.expect_all_or_drop(mandatory_dose_era_rules)
@dlt.expect_all(advisory_dose_era_rules)
def create_omop_dose_era():
    """
    Creates the final dose era table by identifying periods of continuous
    exposure to the same ingredient at the same daily dose.

    Administrations are grouped by (person_id, drug_concept_id, daily_dose,
    unit_concept_id) and ordered by start_date. A new era begins with the
    first administration of a group, and whenever an administration starts
    more than GAP_THRESHOLD days after the latest end_date of all earlier
    administrations in that group. Each era spans the minimum start_date to
    the maximum end_date of its rows; its dose_value is the shared
    daily_dose. Null dose_value / unit_concept_id / drug_concept_id are
    filled with 0.
    """
    daily_dose_df = dlt.read("daily_dose_data")

    # Parameters for era construction
    GAP_THRESHOLD = 30  # days
    partition_cols = ["person_id", "drug_concept_id", "daily_dose", "unit_concept_id"]
    window_spec = Window.partitionBy(*partition_cols).orderBy("start_date")

    # Latest end_date among ALL earlier rows of the group. Comparing only with
    # the immediately preceding row (lag) lets a short administration nested
    # inside a long one hide the long one's coverage, which splits a single
    # exposure into overlapping eras.
    earlier_rows = window_spec.rowsBetween(Window.unboundedPreceding, -1)
    with_gaps = (daily_dose_df
        .withColumn("prev_end_date", F.max("end_date").over(earlier_rows))
        .withColumn("gap",
                    F.when(F.col("prev_end_date").isNotNull(),
                           F.datediff(F.col("start_date"), F.col("prev_end_date")))
                    .otherwise(F.lit(0))))

    # Flag the first row of a group and any row that follows a gap above the threshold
    era_flagged = with_gaps.withColumn(
        "era_start",
        F.when(F.col("prev_end_date").isNull(), F.lit(1))
        .when(F.col("gap") > GAP_THRESHOLD, F.lit(1))
        .otherwise(F.lit(0)))

    # Running count of era starts = era number within the group
    window_era = window_spec.rowsBetween(Window.unboundedPreceding, Window.currentRow)
    era_numbered = era_flagged.withColumn(
        "era_num", F.sum("era_start").over(window_era))

    # Boundaries of each era
    era_boundaries = (era_numbered
        .groupBy(*partition_cols, "era_num")
        .agg(
            F.min("start_date").alias("dose_era_start_date"),
            F.max("end_date").alias("dose_era_end_date")
        ))

    # The era's dose_value is the (constant) daily dose shared by its rows
    era_with_dose = (era_boundaries
        .withColumnRenamed("daily_dose", "dose_value")
        .drop("era_num"))

    # Generate dose_era_id and shape the output to dose_era_schema
    window_id = Window.orderBy("person_id", "drug_concept_id", "dose_era_start_date")
    return (era_with_dose
        .withColumn("dose_era_id", F.row_number().over(window_id))
        .select(
            F.col("dose_era_id").cast(LongType()),
            F.col("person_id").cast(LongType()),
            F.col("drug_concept_id").cast(IntegerType()),
            F.col("unit_concept_id").cast(IntegerType()),
            F.col("dose_value").cast(FloatType()),
            F.col("dose_era_start_date").cast(DateType()),
            F.col("dose_era_end_date").cast(DateType())
        )
        .na.fill(0, ["dose_value", "unit_concept_id", "drug_concept_id"]))

@dlt.table(
    name="qual_omop_dose_era",
    comment="Quality metrics for dose era table"
)
def qual_omop_dose_era():
    """Produce a one-row summary of omop_dose_era.

    The row holds the era count, distinct patients and ingredients, and the
    average and maximum era length in days (end minus start).
    """
    eras = dlt.read("omop_dose_era")
    era_days = F.datediff(F.col("dose_era_end_date"), F.col("dose_era_start_date"))

    return eras.agg(
        F.count("*").alias("total_dose_eras"),
        F.countDistinct("person_id").alias("unique_patients"),
        F.countDistinct("drug_concept_id").alias("unique_ingredients"),
        F.avg(era_days).alias("avg_era_days"),
        F.max(era_days).alias("max_era_days"),
    )

@dlt.table(
    name="summ_omop_dose_units",
    comment="Summary of dose units used"
)
def analyze_dose_units():
    """Summarise dose eras per unit_concept_id.

    Each unit gets its era count, min/max/avg dose_value, and distinct
    ingredient and patient counts. Units are ordered by era count,
    descending.
    """
    eras = dlt.read("omop_dose_era")

    per_unit = eras.groupBy("unit_concept_id").agg(
        F.count("*").alias("era_count"),
        F.min("dose_value").alias("min_dose"),
        F.max("dose_value").alias("max_dose"),
        F.avg("dose_value").alias("avg_dose"),
        F.countDistinct("drug_concept_id").alias("unique_ingredients"),
        F.countDistinct("person_id").alias("unique_patients"),
    )
    return per_unit.orderBy(F.desc("era_count"))

In [0]:


# OMOP CDM schema for the condition_era table; passed as `schema=` to the
# omop_condition_era DLT table. Each field's "comment" metadata entry holds
# its column description.
# NOTE(review): condition_era_id and person_id are IntegerType here, while
# dose_era_schema and observation_period_schema declare person_id as
# LongType. Confirm source person ids fit in 32 bits.
condition_era_schema = StructType([
    StructField("condition_era_id", IntegerType(), False,
                metadata={"comment": "A unique identifier for each Condition Era."}),
    StructField("person_id", IntegerType(), False,
                metadata={"comment": "A foreign key identifier to the Person who is experiencing the condition during the condition era."}),
    StructField("condition_concept_id", IntegerType(), False,
                metadata={"comment": "A foreign key that refers to a Standard Condition Concept identifier in the Standardized Vocabularies."}),
    StructField("condition_era_start_date", DateType(), False,
                metadata={"comment": "The start date for the Condition Era constructed from the individual instances of Condition Occurrences. It is the start date of the very first chronologically recorded instance of the Condition."}),
    # NOTE(review): omop_condition_era currently derives this from the latest
    # condition_start_date in the era, not from occurrence end dates as the
    # description says. Confirm which is intended.
    StructField("condition_era_end_date", DateType(), False,
                metadata={"comment": "The end date for the Condition Era constructed from the individual instances of Condition Occurrences. It is the end date of the final continuously recorded instance of the Condition."}),
    StructField("condition_occurrence_count", IntegerType(), True,
                metadata={"comment": "The number of individual Condition Occurrences used to construct the Condition Era."})
])

# Mandatory rules - these must be met or the record is dropped
# (applied to omop_condition_era via dlt.expect_all_or_drop)
mandatory_condition_era_rules = {
    "valid_condition_era_id": "condition_era_id IS NOT NULL",
    "valid_person": "person_id IS NOT NULL",
    "valid_condition": "condition_concept_id IS NOT NULL",
    "valid_start_date": "condition_era_start_date IS NOT NULL",
    "valid_end_date": "condition_era_end_date IS NOT NULL"
}

# Advisory data quality rules - these are tracked but don't cause record drops
# (applied to omop_condition_era via dlt.expect_all)
advisory_condition_era_rules = {
    "valid_concept_value": "condition_concept_id > 0",
    "valid_dates": "condition_era_end_date >= condition_era_start_date",
    "valid_count": "condition_occurrence_count IS NULL OR condition_occurrence_count > 0"
}

@dlt.table(
    name="condition_occurrence_ordered",
    comment="Condition occurrences with gap analysis",
    temporary=True
)
def create_condition_occurrence_ordered():
    """
    Creates ordered condition occurrences with gap analysis.

    Keeps occurrences with a positive condition_concept_id whose person_id
    exists in omop_person. Then, per (person_id, condition_concept_id)
    ordered by condition_start_date, adds:

    * prev_start_date -- start date of the previous occurrence (NULL for the first)
    * gap_days        -- days since the previous occurrence (0 for the first)
    * era_group       -- running count of gaps > 30 days up to this row's
                         start date; rows sharing a value form one condition era
    """
    condition_occurrences = (dlt.read("omop_condition_occurrence")
        .filter(F.col("condition_concept_id").isNotNull() &
                (F.col("condition_concept_id") > 0)))

    valid_persons = dlt.read("omop_person").select(
        F.col("person_id").alias("valid_person_id")
    ).distinct()

    # An ordered window uses the default RANGE UNBOUNDED PRECEDING .. CURRENT
    # ROW frame, so occurrences tied on start date always share an era_group.
    window_spec = (Window
                   .partitionBy("person_id", "condition_concept_id")
                   .orderBy("condition_start_date"))

    # Measure the gap to the PREVIOUS occurrence (lag), so a gap > 30 days
    # opens a new era at the current row. Measuring to the next occurrence
    # (lead) flags the last row of an era instead, and the running sum then
    # moves that row into the following era.
    return (condition_occurrences.alias("co")
        .join(valid_persons,
              F.col("co.person_id") == F.col("valid_person_id"),
              "inner")
        .select(
            F.col("co.person_id"),
            F.col("condition_concept_id"),
            F.col("condition_start_date")
        )
        .withColumn(
            "prev_start_date",
            F.lag("condition_start_date").over(window_spec))
        .withColumn(
            "gap_days",
            F.when(F.col("prev_start_date").isNotNull(),
                   F.datediff(F.col("condition_start_date"),
                              F.col("prev_start_date")))
            .otherwise(0))
        .withColumn(
            "era_group",
            F.sum(F.when(F.col("gap_days") > 30, 1).otherwise(0))
            .over(window_spec)))

@dlt.table(
    name="omop_condition_era",
    comment="OMOP CDM Condition Era table - Contains records that represent spans of time when a Person is assumed to have a given condition",
    schema=condition_era_schema,
    table_properties={"quality": "gold"}
)
@dlt.expect_all_or_drop(mandatory_condition_era_rules)
@dlt.expect_all(advisory_condition_era_rules)
def create_omop_condition_era():
    """
    Build the OMOP condition_era table from era-grouped condition occurrences.

    Each (person_id, condition_concept_id, era_group) becomes one era that
    runs from the earliest to the latest condition_start_date in the group.
    condition_occurrence_count is the number of rows in the group.
    condition_era_id is a 1-based row number ordered by person, condition
    and era start.
    NOTE(review): ids are narrowed to 32-bit integers to match
    condition_era_schema. Confirm source person ids fit.
    """
    grouped = dlt.read("condition_occurrence_ordered").groupBy(
        "person_id", "condition_concept_id", "era_group")

    eras = grouped.agg(
        F.min("condition_start_date").alias("condition_era_start_date"),
        F.max("condition_start_date").alias("condition_era_end_date"),
        F.count("*").alias("condition_occurrence_count"),
    )

    id_order = Window.orderBy(
        "person_id", "condition_concept_id", "condition_era_start_date")
    numbered = eras.withColumn("condition_era_id", F.row_number().over(id_order))

    typed_columns = (
        ("condition_era_id", "integer"),
        ("person_id", "integer"),
        ("condition_concept_id", "integer"),
        ("condition_era_start_date", "date"),
        ("condition_era_end_date", "date"),
        ("condition_occurrence_count", "integer"),
    )
    return numbered.select(*(F.col(name).cast(dtype) for name, dtype in typed_columns))

@dlt.table(
    name="qual_omop_condition_era",
    comment="Quality metrics for condition era table"
)
def qual_omop_condition_era():
    """Produce a one-row summary of omop_condition_era.

    The row holds era, patient and condition counts, the average
    occurrences per era, and the average era length in days (end minus
    start).
    """
    eras = dlt.read("omop_condition_era")
    era_length_days = F.datediff(F.col("condition_era_end_date"),
                                 F.col("condition_era_start_date"))

    metrics = [
        F.count("*").alias("total_eras"),
        F.count_distinct("person_id").alias("unique_patients"),
        F.count_distinct("condition_concept_id").alias("unique_conditions"),
        F.avg("condition_occurrence_count").alias("avg_occurrences_per_era"),
        F.avg(era_length_days).alias("avg_era_length_days"),
    ]
    return eras.agg(*metrics)

@dlt.table(
    name="summ_omop_condition_era_patterns",
    comment="Analysis of condition era patterns"
)
def analyze_condition_era_patterns():
    """Analyzes patterns in condition eras, one output row per condition_concept_id.

    Eras are first rolled up per (condition_concept_id, person_id), so every
    patient contributes exactly one row to the per-patient averages.
    Averaging a per-patient window value over era rows would weight each
    patient by their number of eras.

    Output columns:
      total_eras, unique_patients, avg_eras_per_patient,
      avg_total_days (per patient), avg_occurrences_per_era
    Ordered by unique_patients descending.
    """
    condition_era_data = dlt.read("omop_condition_era")

    per_patient = (condition_era_data
        .groupBy("condition_concept_id", "person_id")
        .agg(
            F.count("*").alias("era_count"),
            F.sum(F.datediff(F.col("condition_era_end_date"),
                             F.col("condition_era_start_date")))
                .alias("total_days"),
            # Sum plus non-null count, so the per-era average below equals
            # avg(condition_occurrence_count) over the original era rows.
            F.sum("condition_occurrence_count").alias("occurrence_sum"),
            F.count("condition_occurrence_count").alias("occurrence_n")
        ))

    return (per_patient
        .groupBy("condition_concept_id")
        .agg(
            F.sum("era_count").alias("total_eras"),
            F.count_distinct("person_id").alias("unique_patients"),
            F.avg("era_count").alias("avg_eras_per_patient"),
            F.avg("total_days").alias("avg_total_days"),
            (F.sum("occurrence_sum") / F.sum("occurrence_n"))
                .alias("avg_occurrences_per_era")
        )
        .orderBy(F.desc("unique_patients")))

In [0]:

# OMOP CDM schema for the observation_period table; passed as `schema=` to
# the omop_observation_period DLT table. Each field's "comment" metadata
# entry holds its column description.
# NOTE(review): person_id is LongType here but IntegerType in drug_era_schema
# and condition_era_schema. Confirm which width the target CDM uses.
observation_period_schema = StructType([
    StructField("observation_period_id", IntegerType(), False, 
                metadata={"comment": "A unique identifier for each observation period."}),
    StructField("person_id", LongType(), False,
                metadata={"comment": "A foreign key identifier to the Person for whom the observation period is defined."}),
    StructField("observation_period_start_date", DateType(), False,
                metadata={"comment": "The start date of the observation period for which data are available from the data source."}),
    StructField("observation_period_end_date", DateType(), False,
                metadata={"comment": "The end date of the observation period for which data are available from the data source."}),
    StructField("period_type_concept_id", IntegerType(), False,
                metadata={"comment": "A foreign key to the predefined concept identifier in the Standardized Vocabularies reflecting the source of the observation period information."})
])

# Mandatory rules - these must be met or the record is dropped
# (applied to omop_observation_period via dlt.expect_all_or_drop)
mandatory_observation_period_rules = {
    "valid_observation_id": "observation_period_id IS NOT NULL",
    "valid_person": "person_id IS NOT NULL",
    "valid_start_date": "observation_period_start_date IS NOT NULL",
    "valid_end_date": "observation_period_end_date IS NOT NULL",
    "valid_type_concept": "period_type_concept_id IS NOT NULL"
}

# Advisory data quality rules - these are tracked but don't cause record drops
# (applied to omop_observation_period via dlt.expect_all)
advisory_observation_period_rules = {
    "valid_dates": "observation_period_end_date >= observation_period_start_date",
    "valid_type_concept_value": "period_type_concept_id > 0"
}

@dlt.table(
    name="condition_dates",
    comment="Condition event dates",
    temporary=True
)
def get_condition_dates():
    """Return (person_id, start_date, end_date) rows from map_diagnosis.

    Only people present in omop_person are kept. Both start_date and
    end_date are taken from diag_dt_tm.
    """
    diagnoses = spark.table("4_prod.bronze.map_diagnosis").alias("c")
    people = dlt.read("omop_person").alias("p")

    diagnosed_at = F.col("diag_dt_tm")
    return diagnoses.join(people, "person_id").select(
        "person_id",
        diagnosed_at.alias("start_date"),
        diagnosed_at.alias("end_date"),
    )

@dlt.table(
    name="drug_dates",
    comment="Drug administration dates",
    temporary=True
)
def get_drug_dates():
    """Return (person_id, start_date, end_date) rows from map_med_admin.

    Only people present in omop_person are kept. The dates come from
    admin_start_dt_tm and admin_end_dt_tm.
    """
    administrations = spark.table("4_prod.bronze.map_med_admin").alias("d")
    people = dlt.read("omop_person").alias("p")

    return administrations.join(people, "person_id").select(
        "person_id",
        F.col("admin_start_dt_tm").alias("start_date"),
        F.col("admin_end_dt_tm").alias("end_date"),
    )

@dlt.table(
    name="visit_dates",
    comment="Visit dates",
    temporary=True
)
def get_visit_dates():
    """Return (person_id, start_date, end_date) rows from map_encounter.

    Only people present in omop_person are kept. The dates come from
    arrive_dt_tm and depart_dt_tm.
    """
    encounters = spark.table("4_prod.bronze.map_encounter").alias("v")
    people = dlt.read("omop_person").alias("p")

    return encounters.join(people, "person_id").select(
        "person_id",
        F.col("arrive_dt_tm").alias("start_date"),
        F.col("depart_dt_tm").alias("end_date"),
    )

@dlt.table(
    name="combined_observation_dates",
    comment="Combined clinical event dates",
    temporary=True
)
def create_combined_observation_dates():
    """Stack condition, drug and visit date rows into one dataset.

    The union is positional with UNION ALL semantics (duplicates are kept).
    All three sources emit person_id, start_date, end_date in that order.
    """
    sources = [dlt.read(name)
               for name in ("condition_dates", "drug_dates", "visit_dates")]

    combined = sources[0]
    for extra in sources[1:]:
        combined = combined.union(extra)
    return combined

@dlt.table(
    name="omop_observation_period",
    comment="OMOP CDM Observation Period table - Contains records which define spans of time during which clinical events are recorded for a Person",
    schema=observation_period_schema,
    table_properties={"quality": "gold"}
)
@dlt.expect_all_or_drop(mandatory_observation_period_rules)
@dlt.expect_all(advisory_observation_period_rules)
def create_omop_observation_period():
    """
    Creates the observation period table identifying continuous
    periods of clinical activity for each person.

    Produces one period per person: from the earliest start_date to the
    latest end_date in ``combined_observation_dates``. The period is then
    constrained so that it:
      * does not start before January 1 of the person's year_of_birth;
      * does not end after the person's death date, or today's date when
        no death is recorded;
      * never has its start after its end. If clamping inverted the
        period, the start is moved to the end.

    Returns:
        DataFrame with observation_period_id (row_number ordered by
        person_id), person_id, observation_period_start_date,
        observation_period_end_date and period_type_concept_id (32817, EHR).
    """
    # Get all observation dates
    observation_dates = dlt.read("combined_observation_dates")

    # Only year_of_birth is used, so January 1 of that year is the earliest
    # possible birth date. Cast to a real date rather than date_format,
    # which returns a string and leaves the comparisons below to rely on
    # implicit string/timestamp coercion.
    birth_dates = (dlt.read("omop_person")
        .select(
            "person_id",
            concat(
                col("year_of_birth").cast("string"),
                lit("-01-01")
            ).cast("date").alias("birth_date")
        ))

    # Collapse to a single (earliest) death date per person. Otherwise
    # duplicate death records would fan out the left join below into
    # duplicate observation periods.
    death_dates = (dlt.read("omop_death")
        .groupBy("person_id")
        .agg(min("death_date").alias("death_date")))

    # Raw span of clinical activity per person. Spark's min/max ignore nulls.
    observation_periods = (observation_dates
        .groupBy("person_id")
        .agg(
            min("start_date").alias("observation_period_start_date"),
            max("end_date").alias("observation_period_end_date")
        ))

    # Apply birth date and death date constraints
    observation_periods_constrained = (observation_periods
        .join(birth_dates, "person_id")
        .join(death_dates, "person_id", "left")
        .select(
            "person_id",
            # Do not start before birth. A null start stays null, because the
            # null comparison falls through to otherwise().
            when(col("observation_period_start_date") < col("birth_date"),
                 col("birth_date"))
            .otherwise(col("observation_period_start_date"))
            .alias("observation_period_start_date"),
            # Do not end after death or, when no death is recorded, after
            # today. least() skips nulls, so a null end becomes the cap.
            least(col("observation_period_end_date"),
                  coalesce(col("death_date"), current_date()))
            .alias("observation_period_end_date")
        ))

    # Additional pass to fix any remaining inconsistencies
    observation_periods_fixed = (observation_periods_constrained
        .select(
            "person_id",
            # If end_date is before start_date, move start_date to end_date
            when(col("observation_period_end_date") <
                 col("observation_period_start_date"),
                 col("observation_period_end_date"))
            .otherwise(col("observation_period_start_date"))
            .alias("observation_period_start_date"),
            # Keep the end_date as is since it's already been constrained
            col("observation_period_end_date")
        ))

    # Sequential ids. Note: an un-partitioned window pulls all rows into a
    # single partition to number them.
    window_spec = Window.orderBy("person_id")

    return (observation_periods_fixed
        .withColumn("observation_period_id",
                    row_number().over(window_spec))
        .select(
            col("observation_period_id").cast("integer"),
            col("person_id").cast("bigint"),
            col("observation_period_start_date").cast("date"),
            col("observation_period_end_date").cast("date"),
            lit(32817).cast("integer").alias("period_type_concept_id")  # EHR
        ))

@dlt.table(
    name="qual_omop_observation_period",
    comment="Quality metrics for observation periods"
)
def qual_omop_observation_period():
    """
    Tracks quality metrics for observation periods.

    Aggregates the whole ``omop_observation_period`` table into a single
    row containing:
      total_periods, the number of period rows;
      unique_patients, the number of distinct person_id values;
      earliest_observation, the minimum start date;
      latest_observation, the maximum end date;
      avg_observation_days, the mean of datediff(end, start).
    """
    periods = dlt.read("omop_observation_period")

    period_length_days = datediff(col("observation_period_end_date"),
                                  col("observation_period_start_date"))

    metrics = [
        count("*").alias("total_periods"),
        count_distinct("person_id").alias("unique_patients"),
        min("observation_period_start_date").alias("earliest_observation"),
        max("observation_period_end_date").alias("latest_observation"),
        avg(period_length_days).alias("avg_observation_days"),
    ]
    return periods.agg(*metrics)

@dlt.table(
    name="summ_omop_observation_length",
    comment="Analysis of observation period lengths"
)
def analyze_observation_lengths():
    """
    Analyzes observation period lengths.

    Each period's length is datediff(end, start) in days. Periods are
    bucketed into ≤ 30 days, 1-6 months, 6-12 months, 1-2 years and
    2+ years. Each bucket reports its row count and its mean length in days
    and in years (days / 365.25). Periods whose length cannot be computed
    (null start or end date) go to an "Unknown" bucket instead of being
    counted as "2+ years".

    Returns:
        DataFrame with length_category, patient_count, avg_days and
        avg_years, ordered from the shortest bucket to the longest, with
        "Unknown" last.
    """
    period_data = dlt.read("omop_observation_period")
    days = col("observation_days")

    # The null check must come first. A null comparison would otherwise
    # fall through every when() into the "2+ years" otherwise() branch.
    length_category = (
        when(days.isNull(), "Unknown")
        .when(days <= 30, "≤ 30 days")
        .when(days <= 180, "1-6 months")
        .when(days <= 365, "6-12 months")
        .when(days <= 730, "1-2 years")
        .otherwise("2+ years"))

    # Sorting on the label text would give "1-2 years", "1-6 months",
    # "2+ years", "6-12 months", "≤ 30 days". Rank buckets by duration instead.
    category = col("length_category")
    category_rank = (
        when(category == "≤ 30 days", 1)
        .when(category == "1-6 months", 2)
        .when(category == "6-12 months", 3)
        .when(category == "1-2 years", 4)
        .when(category == "2+ years", 5)
        .otherwise(6))

    return (period_data
        .withColumn(
            "observation_days",
            datediff(col("observation_period_end_date"),
                     col("observation_period_start_date")))
        .withColumn(
            "observation_years",
            round(days / 365.25, 2))
        .withColumn("length_category", length_category)
        .groupBy("length_category")
        .agg(
            count("*").alias("patient_count"),
            round(avg("observation_days"), 1).alias("avg_days"),
            round(avg("observation_years"), 1).alias("avg_years")
        )
        .orderBy(category_rank))