In [0]:
import dlt
from pyspark.sql.functions import *
from pyspark.sql.functions import max as spark_max
from pyspark.sql.window import Window
from datetime import datetime
from pyspark.sql.utils import AnalysisException
from pyspark.sql.types import *
from pyspark.sql import functions as F
from pyspark.sql.functions import sum as sql_sum, min as sql_min, max as sql_max

In [0]:
# --- Configuration ---
# Fully qualified names of the OMOP vocabulary lookup tables read by this notebook.
CONCEPT_TABLE = "3_lookup.omop.concept"
CONCEPT_RELATIONSHIP_TABLE = "3_lookup.omop.concept_relationship"
CONCEPT_ANCESTOR_TABLE = "3_lookup.omop.concept_ancestor"
# Lookup of source values to OMOP concept ids (columns SourceValue / OmopConceptId / OMOPField
# are read from it by the gender and race mapping tables).
BARTS_MAPS_TABLE = "3_lookup.omop.barts_new_maps"

# Schema prefixes for source layers.
# NOTE(review): not referenced in the visible part of this notebook, where table names are
# spelled out as literals — confirm whether later cells use these constants.
BRONZE_DB = "4_prod.bronze"
RAW_DB = "4_prod.raw" 

In [0]:
@dlt.table(
    name="concept_relationships_maps_to",
    comment="Pre-filtered 'Maps to' relationships for standard concepts",
    temporary=True
)
def get_maps_to_relationships():
    """Build the lookup of valid 'Maps to' links that end in a standard concept.

    A relationship row is kept only when all of the following hold:
      * its relationship_id is 'Maps to' and its own invalid_reason is NULL,
      * its source concept (concept_id_1) has a NULL invalid_reason,
      * its target concept (concept_id_2) is standard ('S') with a NULL invalid_reason.

    Returns:
        DataFrame with distinct rows of
        (source_concept_id, standard_concept_id, standard_domain_id).
    """
    concept_df = spark.table(CONCEPT_TABLE)

    # Source side: any concept that has not been invalidated
    valid_sources = (
        concept_df
        .where(F.col("invalid_reason").isNull())
        .select(F.col("concept_id").alias("source_concept_id_rel"))
    )

    # Target side: concepts that are both standard and not invalidated
    valid_targets = (
        concept_df
        .where((F.col("standard_concept") == 'S') & F.col("invalid_reason").isNull())
        .select(F.col("concept_id").alias("target_concept_id"), "domain_id")
    )

    # Relationship rows are filtered before joining so that "invalid_reason"
    # unambiguously refers to the relationship table's own column.
    active_maps_to = (
        spark.table(CONCEPT_RELATIONSHIP_TABLE)
        .where(F.col("relationship_id") == 'Maps to')
        .where(F.col("invalid_reason").isNull())
    )

    linked = (
        active_maps_to
        .join(valid_sources, F.col("concept_id_1") == F.col("source_concept_id_rel"))
        .join(valid_targets, F.col("concept_id_2") == F.col("target_concept_id"))
    )

    return linked.select(
        F.col("concept_id_1").alias("source_concept_id"),    # concept being mapped from
        F.col("concept_id_2").alias("standard_concept_id"),  # standard concept mapped to
        F.col("domain_id").alias("standard_domain_id"),      # domain of the standard concept
    ).distinct()

def _map_or_use_standard(df, source_omop_concept_id_col, target_domain_id):
    """Resolve each row's OMOP concept id to a standard concept for one target domain.

    Args:
        df: Input DataFrame containing the concept id column to resolve.
        source_omop_concept_id_col: Name of the column in ``df`` holding the OMOP concept id.
        target_domain_id: OMOP domain_id that the resulting standard concept must belong to.

    Returns:
        DataFrame with all columns of ``df`` plus:
          * ``standard_concept_id`` - the resolved concept id,
          * ``mapping_count`` - how many 'Maps to' targets the source concept has in the
            target domain (1 when no mapping was needed or none exists).

        The result is the union of three disjoint branches:
          1. rows whose concept is already standard and in the target domain (kept as-is),
          2. rows whose concept is not standard and has one or more 'Maps to' targets in the
             target domain (one output row per target),
          3. rows whose concept is not standard and has no such mapping (original id kept).

        Rows are dropped when the concept id is absent from the concept table, when the
        concept has a non-NULL invalid_reason, or when the concept is standard but belongs
        to a different domain.
    """
    concepts = spark.table(CONCEPT_TABLE).alias("c")
    maps_to = (dlt.read("concept_relationships_maps_to")
               .filter(F.col("standard_domain_id") == target_domain_id)
               .alias("m"))

    map_counts = (maps_to.groupBy("source_concept_id")
                  .agg(F.count("*").alias("mapping_count")))

    src_concept_id = F.col(f"src.{source_omop_concept_id_col}")

    # Null-safe test: OMOP stores NULL in standard_concept for non-standard concepts.
    # A plain `!= 'S'` evaluates to NULL for those rows and the filter would silently
    # discard them, so the non-standard branches must use the negated null-safe equality.
    is_standard = F.col("c.standard_concept").eqNullSafe('S')

    # Inner join: only rows whose concept exists and has not been invalidated continue.
    df1 = (df.alias("src")
           .join(concepts, src_concept_id == F.col("c.concept_id"))
           .filter(F.col("c.invalid_reason").isNull()))

    # Branch 1: already a standard concept in the requested domain.
    already_std = (df1
        .filter(is_standard & (F.col("c.domain_id") == target_domain_id))
        .select("src.*",
                src_concept_id.alias("standard_concept_id"),
                F.lit(1).alias("mapping_count"))
    )

    # Branch 2: non-standard concept with at least one mapping into the domain.
    mapped_std = (df1
        .filter(~is_standard)
        .join(maps_to,
              src_concept_id == F.col("m.source_concept_id"),
              "inner")                                     # only rows that DO map
        .join(map_counts, "source_concept_id")
        .select("src.*",
                F.col("standard_concept_id"),
                F.col("mapping_count"))
    )

    # Branch 3: non-standard concept with no mapping into the domain.
    unmapped_non_std = (df1
        .filter(~is_standard)
        .join(maps_to,
              src_concept_id == F.col("m.source_concept_id"),
              "left_anti")                                  # no map exists
        .select("src.*",
                src_concept_id.alias("standard_concept_id"),  # keep original id
                F.lit(1).alias("mapping_count"))              # behaves like 1-map
    )

    return (already_std
            .unionByName(mapped_std, allowMissingColumns=True)
            .unionByName(unmapped_non_std, allowMissingColumns=True))
    

def _validate_provider(df):
    """Null out provider_id values that are not present in the omop_provider table.

    Args:
        df: DataFrame with a ``provider_id`` column.

    Returns:
        ``df`` with the same rows; ``provider_id`` is kept when it matches a provider_id
        in ``omop_provider`` and set to NULL otherwise. If reading ``omop_provider``
        raises, a warning is printed and every ``provider_id`` is set to NULL (cast to long).
    """
    try:
        valid_prov = F.broadcast(
            dlt.read("omop_provider")
               .selectExpr("provider_id as vp_id")
               # One row per id: without this, duplicate provider ids in the lookup
               # would multiply the input rows in the left join below.
               .distinct()
        )
    except Exception:
        # Deliberate best-effort: validation is skipped rather than failing the caller.
        print("Warning: omop_provider table not found for validation, skipping provider validation.")
        return df.withColumn("provider_id", F.lit(None).cast("long"))

    return (df.join(valid_prov, df.provider_id == F.col("vp_id"), "left")
              .withColumn("provider_id",
                          F.when(F.col("vp_id").isNull(), F.lit(None))
                          .otherwise(df.provider_id))
              .drop("vp_id"))



def get_max_timestamp(table_name_string):
    """Return the maximum ADC_UPDT value of a table, or a NULL timestamp literal.

    Args:
        table_name_string: Name of the table to query; it is interpolated directly
            into the SQL text.

    Returns:
        The value of ``max(ADC_UPDT)`` as collected to the driver when the table exists
        and the maximum is not NULL. Otherwise a ``lit(None).cast("timestamp")`` Column.
        Note the two cases return different Python types (collected value vs. Column).
        Any exception is printed and also results in the NULL timestamp Column.
    """
    try:
        # Guard against the table not existing yet (e.g. on a first run)
        if not spark.catalog.tableExists(table_name_string):
            return lit(None).cast("timestamp")

        first_row = spark.sql(f"SELECT max(ADC_UPDT) from {table_name_string}").collect()[0]
        latest = first_row[0]
        if latest is None:
            # Table exists but has no non-NULL ADC_UPDT value
            return lit(None).cast("timestamp")
        return latest
    except Exception as e:
        print(f"Error getting max timestamp for {table_name_string}: {e}")
        return lit(None).cast("timestamp")

In [0]:


# Define OMOP CDM schema for location table.
# Supplied as `schema=` to the omop_location DLT table; location_id is the only
# non-nullable column. LSOA and IMD_Quintile are additional columns appended after the
# OMOP location fields.
location_schema = StructType([
    StructField("location_id", LongType(), False,
                metadata={"comment": "A unique identifier for each geographic location."}),
    StructField("address_1", StringType(), True,
                metadata={"comment": "The first line of the address."}),
    StructField("address_2", StringType(), True,
                metadata={"comment": "The second line of the address"}),
    StructField("city", StringType(), True,
                metadata={"comment": "The city field is the text name of the city."}),
    StructField("state", StringType(), True,
                metadata={"comment": "The state field contains the state name. For addresses outside the US, this field can be used for provinces or other administrative regions."}),
    StructField("zip", StringType(), True,
                metadata={"comment": "The zip or postal code. For US addresses, valid formats are 3-digit, 5-digit or 9-digit ZIP codes. For non-US addresses, the postal code should be stored in the same field."}),
    StructField("county", StringType(), True,
                metadata={"comment": "The county, if available. The county field can also be used to store other regional information."}),
    StructField("location_source_value", StringType(), True,
                metadata={"comment": "The verbatim value for the location as it appears in the source data."}),
    StructField("country_concept_id", IntegerType(), True,
                metadata={"comment": "A foreign key to the predefined Concept table for the country concept id, representing the country portion of the address."}),
    StructField("country_source_value", StringType(), True,
                metadata={"comment": "The source code for the country as it appears in the source data."}),
    StructField("latitude", FloatType(), True,
                metadata={"comment": "The latitude of the location. Must be between -90 and 90."}),
    StructField("longitude", FloatType(), True,
                metadata={"comment": "The longitude of the location. Must be between -180 and 180."}),
    # LSOA stands for Lower Layer Super Output Area (a small-area statistical geography),
    # not a local authority district.
    StructField("LSOA", StringType(), True, metadata={"comment": "Lower Layer Super Output Area (LSOA) code"}),
    StructField("IMD_Quintile", IntegerType(), True, metadata={"comment": "Index of Multiple Deprivation (IMD) quintile"})
])


@dlt.table(
    name="country_concepts",
    comment="Geography concepts from OMOP vocabulary"
)
def create_country_concepts():
    """Expose Geography-domain concepts keyed by lower-cased concept name.

    Returns:
        DataFrame with ``country_name_lower`` (lower-cased concept_name) and
        ``country_concept_id`` (concept_id) for every concept whose domain_id is
        "Geography".
    """
    geography_concepts = (
        spark.table("3_lookup.omop.concept")
        .where(F.col("domain_id") == "Geography")
    )

    return geography_concepts.select(
        F.lower(F.col("concept_name")).alias("country_name_lower"),
        F.col("concept_id").alias("country_concept_id"),
    )

@dlt.table(
    name="omop_location",
    comment="OMOP CDM Location table - Represents a generic way to capture physical location or address information",
    schema=location_schema,
    table_properties={"quality": "gold"}
)
def create_omop_location():
    """
    Creates the OMOP Location table from source address data.

    Reads the whole of 4_prod.bronze.map_address (no incremental logic is implemented
    here), left-joins it to the country_concepts table on the lower-cased country_cd,
    and keeps one row per ADDRESS_ID.

    Deduplication rule: within an ADDRESS_ID the row with the lowest non-NULL
    country_concept_id wins; a row without a country match is chosen only when no
    matched row exists.

    Returns:
        DataFrame shaped like location_schema. state, county, latitude and longitude
        are always NULL; country_concept_id defaults to 0 when no country matched.
    """
    # Read source addresses
    addresses = spark.table("4_prod.bronze.map_address")

    # Country lookup produced by the country_concepts table
    country_concepts = dlt.read("country_concepts")

    joined_data = addresses.join(
        country_concepts,
        F.lower(addresses.country_cd) == F.col("country_name_lower"),
        "left",
    )

    # Spark sorts NULLs first in ascending order, which would let an unmapped row beat a
    # mapped one for the same ADDRESS_ID. Push NULLs last so a real concept id is preferred.
    window_spec = (Window.partitionBy("ADDRESS_ID")
                   .orderBy(F.col("country_concept_id").asc_nulls_last()))
    deduplicated = (joined_data
                    .withColumn("row_number", F.row_number().over(window_spec))
                    .filter(F.col("row_number") == 1))

    street = F.col("full_street_address")
    street_parts = F.split(street, ",")
    has_comma = street.contains(",")

    # Transform to final format with all necessary type casting
    return deduplicated.select(
        F.col("ADDRESS_ID").cast("bigint").alias("location_id"),

        # Text before the first comma (or the whole address), truncated to 50 chars
        F.when(has_comma, F.substring(street_parts.getItem(0), 1, 50))
         .otherwise(F.substring(street, 1, 50))
         .cast("string").alias("address_1"),

        # Text between the first and second comma, if any, truncated to 50 chars
        F.when(has_comma, F.substring(street_parts.getItem(1), 1, 50))
         .otherwise(F.lit(None))
         .cast("string").alias("address_2"),

        F.substring(F.col("CITY"), 1, 50).cast("string").alias("city"),
        F.lit(None).cast("string").alias("state"),
        F.substring(F.col("masked_zipcode"), 1, 9).cast("string").alias("zip"),
        F.lit(None).cast("string").alias("county"),
        F.substring(street, 1, 50).cast("string").alias("location_source_value"),
        F.coalesce(F.col("country_concept_id"), F.lit(0)).cast("integer")
            .alias("country_concept_id"),
        F.substring(F.col("country_cd"), 1, 20).cast("string")
            .alias("country_source_value"),
        F.lit(None).cast("float").alias("latitude"),
        F.lit(None).cast("float").alias("longitude"),
        F.col("LSOA").cast("string").alias("LSOA"),
        F.col("IMD_Quintile").cast("integer").alias("IMD_Quintile"),
    )


In [0]:
# Define OMOP CDM schema for care_site table
# Supplied as `schema=` to the omop_care_site DLT table. care_site_id is the only
# non-nullable column; each field's metadata "comment" holds the column description.
care_site_schema = StructType([
    StructField("care_site_id", LongType(), False, 
                metadata={"comment": "A unique identifier for each Care Site"}),
    StructField("care_site_name", StringType(), True,
                metadata={"comment": "The name of the care_site as it appears in the source data"}),
    StructField("place_of_service_concept_id", IntegerType(), True,
                metadata={"comment": "A foreign key to the predefined Concept table for the place of service concept id"}),
    StructField("location_id", LongType(), True,
                metadata={"comment": "A foreign key to the Location table, where the detailed address information is stored"}),
    StructField("care_site_source_value", StringType(), True,
                metadata={"comment": "The identifier of the care_site as it appears in the source data"}),
    StructField("place_of_service_source_value", StringType(), True,
                metadata={"comment": "The source code for the place of service as it appears in the source data"})
])


@dlt.table(
    name="omop_care_site",
    comment="OMOP CDM Care Site table - Contains a list of institutional (physical or organizational) units where healthcare delivery is practiced (offices, wards, hospitals, clinics, etc.)",
    schema=care_site_schema,
    table_properties={"quality": "gold"}
)
def create_omop_care_site():
    """Project 4_prod.bronze.map_care_site into the OMOP care_site layout.

    Returns:
        DataFrame with one output row per source row. care_site_cd feeds both
        care_site_id (as bigint) and care_site_source_value (as string);
        place_of_service_concept_id is always 0.
    """
    source_sites = spark.table("4_prod.bronze.map_care_site")
    site_code = F.col("care_site_cd")

    omop_columns = [
        # Primary identifier comes from the source care site code
        site_code.cast("bigint").alias("care_site_id"),
        F.col("care_site_name").alias("care_site_name"),
        # No place-of-service mapping exists, so the concept id is fixed at 0
        F.lit(0).cast("integer").alias("place_of_service_concept_id"),
        F.col("address_id").cast("bigint").alias("location_id"),
        site_code.cast("string").alias("care_site_source_value"),
        # The facility name is carried as the place-of-service source value
        F.col("facility_name").alias("place_of_service_source_value"),
    ]

    return source_sites.select(*omop_columns)


In [0]:


# Define OMOP CDM schema for provider table
# Supplied as `schema=` to both the omop_provider_base and omop_provider DLT tables.
# provider_id is the only non-nullable column; each field's metadata "comment" holds
# the column description.
provider_schema = StructType([
    StructField("provider_id", LongType(), False,
                metadata={"comment": "A unique identifier for each Provider."}),
    StructField("provider_name", StringType(), True,
                metadata={"comment": "A description of the Provider, typically the name of the physician or facility."}),
    StructField("npi", StringType(), True,
                metadata={"comment": "The National Provider Identifier (NPI) of the provider."}),
    StructField("dea", StringType(), True,
                metadata={"comment": "The Drug Enforcement Administration (DEA) number of the provider."}),
    StructField("specialty_concept_id", IntegerType(), True,
                metadata={"comment": "A foreign key to a Standard Specialty Concept ID in the Standardized Vocabularies."}),
    StructField("care_site_id", LongType(), True,
                metadata={"comment": "A foreign key to the main Care Site where the provider is practicing."}),
    StructField("year_of_birth", IntegerType(), True,
                metadata={"comment": "The year of birth of the Provider."}),
    StructField("gender_concept_id", IntegerType(), True,
                metadata={"comment": "The gender of the Provider."}),
    StructField("provider_source_value", StringType(), True,
                metadata={"comment": "The identifier used for the Provider in the source data."}),
    StructField("specialty_source_value", StringType(), True,
                metadata={"comment": "The source code for the Provider specialty as it appears in the source data."}),
    StructField("specialty_source_concept_id", IntegerType(), True,
                metadata={"comment": "A foreign key to a Concept that refers to the code used in the source."}),
    StructField("gender_source_value", StringType(), True,
                metadata={"comment": "The source value for the Provider gender."}),
    StructField("gender_source_concept_id", IntegerType(), True,
                metadata={"comment": "A foreign key to a Concept that refers to the code used in the source."})
])



@dlt.table(
    name="omop_provider_base",
    comment="Initial provider table before care site validation",
    schema=provider_schema,
    temporary=True
)
def create_provider_base():
    """Project 4_prod.bronze.map_medical_personnel into the OMOP provider layout.

    Returns:
        DataFrame with one output row per source row. care_site_id is taken from
        primary_care_site_cd (a value of 0 becomes NULL) and is not yet checked against
        the care site table. npi, dea, year_of_birth and gender_source_value are always
        NULL; all concept id columns are fixed at 0.
    """
    personnel = spark.table("4_prod.bronze.map_medical_personnel")

    # Reusable column expressions
    person_key = F.col("PERSON_ID")
    primary_site = F.col("primary_care_site_cd").cast("bigint")
    null_string = F.lit(None).cast("string")
    zero_concept = F.lit(0).cast("integer")

    return personnel.select(
        person_key.cast("bigint").alias("provider_id"),

        # Provider name is "<position> - <medical service>"
        F.concat_ws(" - ", F.col("position_name"), F.col("MEDSERVICE"))
            .alias("provider_name"),

        # NPI and DEA do not exist in the source
        null_string.alias("npi"),
        null_string.alias("dea"),

        # No specialty mapping is available
        zero_concept.alias("specialty_concept_id"),

        # A care site code of 0 is treated as "no care site"
        F.when(primary_site == 0, F.lit(None))
            .otherwise(primary_site)
            .alias("care_site_id"),

        # Demographics do not exist in the source
        F.lit(None).cast("integer").alias("year_of_birth"),
        zero_concept.alias("gender_concept_id"),

        person_key.cast("string").alias("provider_source_value"),

        # First non-NULL of the three specialty-like source columns
        F.coalesce(
            F.col("MEDSERVICE"),
            F.col("SRVCATEGORY"),
            F.col("SURGSPEC"),
        ).alias("specialty_source_value"),

        zero_concept.alias("specialty_source_concept_id"),
        null_string.alias("gender_source_value"),
        zero_concept.alias("gender_source_concept_id"),
    )

@dlt.table(
    name="omop_provider",
    comment="OMOP CDM Provider table - Contains a list of uniquely identified healthcare providers",
    schema=provider_schema,
    table_properties={"quality": "gold"}
)
def create_omop_provider():
    """Finalize providers by keeping only care_site_id values present in omop_care_site.

    Returns:
        The rows of omop_provider_base with care_site_id replaced by NULL whenever it
        has no match in omop_care_site. The validated care_site_id ends up as the last
        column of the returned DataFrame.
    """
    base_providers = dlt.read("omop_provider_base")
    known_sites = dlt.read("omop_care_site").select("care_site_id").distinct()

    site_matches = base_providers.care_site_id == known_sites.care_site_id

    # NULL on the right side of the left join means the care site is unknown
    checked_site = (
        F.when(known_sites.care_site_id.isNotNull(), base_providers.care_site_id)
         .otherwise(F.lit(None))
         .alias("valid_care_site_id")
    )

    projected = (
        base_providers
        .join(known_sites, site_matches, "left")
        .select(base_providers["*"], checked_site)
    )

    # Swap the unvalidated column for the validated one
    return (
        projected
        .drop("care_site_id")
        .withColumnRenamed("valid_care_site_id", "care_site_id")
    )

In [0]:


# OMOP CDM schema for the person table; supplied as `schema=` to the omop_person DLT table.
# person_id is the only non-nullable column; each field's metadata "comment" holds the
# column description.
person_schema = StructType([
    StructField("person_id", LongType(), False,
                metadata={"comment": "A unique identifier for each person."}),
    StructField("gender_concept_id", IntegerType(), True,
                metadata={"comment": "A foreign key that refers to a standard concept identifier in the Vocabulary for the gender of the person."}),
    StructField("year_of_birth", IntegerType(), True,
                metadata={"comment": "The year of birth of the person."}),
    StructField("month_of_birth", IntegerType(), True,
                metadata={"comment": "The month of birth of the person."}),
    StructField("day_of_birth", IntegerType(), True,
                metadata={"comment": "The day of birth of the person."}),
    StructField("birth_datetime", TimestampType(), True,
                metadata={"comment": "The date and time of birth of the person."}),
    StructField("race_concept_id", IntegerType(), True,
                metadata={"comment": "A foreign key that refers to a standard concept identifier in the Vocabulary for the race of the person."}),
    StructField("ethnicity_concept_id", IntegerType(), True,
                metadata={"comment": "A foreign key that refers to the standard concept identifier in the Vocabulary for the ethnicity of the person."}),
    StructField("location_id", LongType(), True,
                metadata={"comment": "A foreign key to the location table that indicates where the person is located."}),
    StructField("provider_id", LongType(), True,
                metadata={"comment": "A foreign key to the provider table that indicates the primary care provider of the person."}),
    StructField("care_site_id", LongType(), True,
                metadata={"comment": "A foreign key to the care site table that indicates the primary care site of the person."}),
    StructField("person_source_value", StringType(), True,
                metadata={"comment": "The source code for the person as it appears in the source data."}),
    StructField("gender_source_value", StringType(), True,
                metadata={"comment": "The source code for the gender of the person as it appears in the source data."}),
    StructField("gender_source_concept_id", IntegerType(), True,
                metadata={"comment": "A foreign key to a concept that refers to the code used in the source."}),
    StructField("race_source_value", StringType(), True,
                metadata={"comment": "The source code for the race of the person as it appears in the source data."}),
    StructField("race_source_concept_id", IntegerType(), True,
                metadata={"comment": "A foreign key to a concept that refers to the code used in the source."}),
    StructField("ethnicity_source_value", StringType(), True,
                metadata={"comment": "The source code for the ethnicity of the person as it appears in the source data."}),
    StructField("ethnicity_source_concept_id", IntegerType(), True,
                metadata={"comment": "A foreign key to a concept that refers to the code used in the source."})
])

@dlt.table(
    name="valid_encounters",
    comment="Valid person IDs from encounters",
    temporary=True
)
def get_valid_encounters():
    """Return the distinct person_id values found in 4_prod.bronze.map_encounter."""
    encounters = spark.table("4_prod.bronze.map_encounter")
    person_ids = encounters.select("person_id")
    return person_ids.distinct()


@dlt.table(
    name="gender_concept_maps",
    comment="Gender concept mappings",
    temporary=True
)
def get_gender_maps():
    """Return the source-value to concept-id pairs registered for gender_concept_id.

    Returns:
        DataFrame with ``gender_source`` (SourceValue) and ``gender_omop_id``
        (OmopConceptId) for lookup rows whose OMOPField is "gender_concept_id".
    """
    gender_rows = (
        spark.table("3_lookup.omop.barts_new_maps")
        .where(F.col("OMOPField") == "gender_concept_id")
    )

    return gender_rows.select(
        F.col("SourceValue").alias("gender_source"),
        F.col("OmopConceptId").alias("gender_omop_id"),
    )

@dlt.table(
    name="race_concept_maps",
    comment="Race concept mappings",
    temporary=True
)
def get_race_maps():
    """Return the source-value to concept-id pairs registered for race_concept_id.

    Returns:
        DataFrame with ``race_source`` (SourceValue) and ``race_omop_id``
        (OmopConceptId) for lookup rows whose OMOPField is "race_concept_id".
    """
    race_rows = (
        spark.table("3_lookup.omop.barts_new_maps")
        .where(F.col("OMOPField") == "race_concept_id")
    )

    return race_rows.select(
        F.col("SourceValue").alias("race_source"),
        F.col("OmopConceptId").alias("race_omop_id"),
    )

@dlt.table(
    name="omop_person",
    comment="OMOP CDM Person table - Contains records that uniquely identify each person in the database",
    schema=person_schema,
    table_properties={"quality": "gold"}
)
def create_omop_person():
    """Build the OMOP person table from 4_prod.bronze.map_person.

    A source row is kept only when birth_year is present and >= 1901, gender_cd is
    present and not "0", and the person_id appears in the valid_encounters table.
    Gender is mapped from gender_cd and race from ethnicity_cd through the respective
    lookup tables; an unmapped value yields concept id 0.

    Returns:
        DataFrame shaped like person_schema. month/day of birth, birth_datetime,
        provider_id, care_site_id and ethnicity_source_value are always NULL;
        ethnicity_concept_id and all *_source_concept_id columns are always 0.
    """
    # Source rows that pass the basic demographic checks
    source_persons = spark.table("4_prod.bronze.map_person").where(
        F.col("birth_year").isNotNull()
        & F.col("gender_cd").isNotNull()
        & (F.col("gender_cd") != "0")
        & (F.col("birth_year") >= 1901)
    )

    # Reference tables
    encounter_persons = dlt.read("valid_encounters")
    gender_lookup = dlt.read("gender_concept_maps")
    race_lookup = dlt.read("race_concept_maps")

    # Only persons with an encounter survive; the concept lookups are optional
    with_encounters = source_persons.join(encounter_persons, "person_id", "inner")
    with_gender = with_encounters.join(
        gender_lookup,
        source_persons.gender_cd == gender_lookup.gender_source,
        "left",
    )
    with_race = with_gender.join(
        race_lookup,
        source_persons.ethnicity_cd == race_lookup.race_source,
        "left",
    )

    zero_concept = F.lit(0).cast("integer")
    null_integer = F.lit(None).cast("integer")
    null_bigint = F.lit(None).cast("bigint")

    return with_race.select(
        F.col("person_id").cast("bigint"),
        F.coalesce(F.col("gender_omop_id"), F.lit(0)).cast("integer")
            .alias("gender_concept_id"),
        F.col("birth_year").cast("integer").alias("year_of_birth"),
        null_integer.alias("month_of_birth"),
        null_integer.alias("day_of_birth"),
        F.lit(None).cast("timestamp").alias("birth_datetime"),
        F.coalesce(F.col("race_omop_id"), F.lit(0)).cast("integer")
            .alias("race_concept_id"),
        zero_concept.alias("ethnicity_concept_id"),
        F.col("address_id").cast("bigint").alias("location_id"),
        null_bigint.alias("provider_id"),
        null_bigint.alias("care_site_id"),
        F.col("person_id").cast("string").alias("person_source_value"),
        F.col("gender_cd").cast("string").alias("gender_source_value"),
        zero_concept.alias("gender_source_concept_id"),
        # The source ethnicity code is stored as the race source value
        F.col("ethnicity_cd").cast("string").alias("race_source_value"),
        zero_concept.alias("race_source_concept_id"),
        F.lit(None).cast("string").alias("ethnicity_source_value"),
        zero_concept.alias("ethnicity_source_concept_id"),
    )


In [0]:


# OMOP CDM schema for the visit_occurrence table.
# visit_occurrence_id, person_id and visit_concept_id are non-nullable; all other
# columns are nullable. Each field's metadata "comment" holds the column description.
# NOTE(review): the table that consumes this schema is not visible in this section —
# presumably a visit occurrence DLT table defined further down; confirm.
visit_schema = StructType([
    StructField("visit_occurrence_id", LongType(), False,
                metadata={"comment": "A unique identifier for each Person's visit or encounter at a healthcare provider."}),
    StructField("person_id", LongType(), False,
                metadata={"comment": "A foreign key identifier to the Person who is having the visit."}),
    StructField("visit_concept_id", IntegerType(), False,
                metadata={"comment": "A foreign key that refers to a visit concept identifier in the Standardized Vocabularies."}),
    StructField("visit_start_date", DateType(), True,
                metadata={"comment": "The start date of the visit."}),
    StructField("visit_start_datetime", TimestampType(), True,
                metadata={"comment": "The start date and time of the visit."}),
    StructField("visit_end_date", DateType(), True,
                metadata={"comment": "The end date of the visit."}),
    StructField("visit_end_datetime", TimestampType(), True,
                metadata={"comment": "The end date and time of the visit."}),
    StructField("visit_type_concept_id", IntegerType(), True,
                metadata={"comment": "A foreign key to the predefined concept identifier in the Standardized Vocabularies reflecting the type of source data from which the visit record is derived."}),
    StructField("provider_id", LongType(), True,
                metadata={"comment": "A foreign key to the provider in the provider table who was associated with the visit."}),
    StructField("care_site_id", LongType(), True,
                metadata={"comment": "A foreign key to the care site in the care site table that was visited."}),
    StructField("visit_source_value", StringType(), True,
                metadata={"comment": "The source code for the visit as it appears in the source data."}),
    StructField("visit_source_concept_id", IntegerType(), True,
                metadata={"comment": "A foreign key to a concept that refers to the code used in the source."}),
    StructField("admitted_from_concept_id", IntegerType(), True,
                metadata={"comment": "A foreign key to the predefined concept in the Place of Service vocabulary indicating where the person was admitted from."}),
    StructField("admitted_from_source_value", StringType(), True,
                metadata={"comment": "The source code for the admitted from concept as it appears in the source data."}),
    StructField("discharged_to_concept_id", IntegerType(), True,
                metadata={"comment": "A foreign key to the predefined concept in the Place of Service vocabulary indicating where the person was discharged to."}),
    StructField("discharged_to_source_value", StringType(), True,
                metadata={"comment": "The source code for the discharged to concept as it appears in the source data."}),
    StructField("preceding_visit_occurrence_id", LongType(), True,
                metadata={"comment": "A foreign key to the visit occurrence that immediately preceded this visit."})
])

@dlt.table(
    name="person_birth_ts",
    comment="Person Id  + birth timestamp (falls back to Jan-01 of YOB)",
    temporary=True
)
def person_birth_ts():
    """Return one birth timestamp per distinct (person_id, birth_ts) pair from omop_person.

    Returns:
        DataFrame with ``person_id`` and ``birth_ts``. ``birth_ts`` is birth_datetime
        when that is not NULL, otherwise the timestamp parsed from
        "<year_of_birth>-01-01".
    """
    persons = dlt.read("omop_person")

    # Fallback: 1 January of the year of birth
    january_first = F.to_timestamp(
        F.concat_ws('-', persons.year_of_birth, F.lit('01'), F.lit('01'))
    )
    birth_ts = F.coalesce(persons.birth_datetime, january_first).alias("birth_ts")

    return persons.select("person_id", birth_ts).distinct()

@dlt.table(
    name="person_death_ts",
    comment="Person Id  + death timestamp (NULL if still alive)",
    temporary=True
)
def person_death_ts():
    """Return distinct (person_id, death_ts) pairs from 4_prod.bronze.map_death.

    Returns:
        DataFrame with ``person_id`` and ``death_ts``, where ``death_ts`` is the first
        non-NULL of DECEASED_DT_TM and CALC_DEATH_DATE cast to timestamp.
    """
    # Prefer the recorded death date-time, fall back to the calculated death date
    death_ts = (
        F.coalesce("DECEASED_DT_TM", "CALC_DEATH_DATE")
         .cast("timestamp")
         .alias("death_ts")
    )

    deaths = spark.table("4_prod.bronze.map_death")
    return deaths.select("person_id", death_ts).distinct()
    
@dlt.table(
    name="visit_type_mapping",
    comment="Visit type concept mappings",
    temporary=True
)
def get_visit_type_mapping():
    """Return the fixed lookup from encounter type description to visit concept id.

    Returns:
        DataFrame with columns ``encntr_type_desc`` and ``concept_id``, one row per
        hard-coded encounter type below.
    """
    column_names = ["encntr_type_desc", "concept_id"]

    # (encounter type description, concept id)
    encounter_type_rows = [
        ("Inpatient Pre-Admission", 9201),
        ("Outpatient Pre-registration", 9202),
        ("Day Case", 8883),
        ("Day Case Waiting List", 8883),
        ("Outpatient Referral", 9202),
        ("Regular Day Admission", 581476),
        ("Outpatient Services", 9202),
        ("Results Only", 32036),
        ("Research", 38004259),
        ("Emergency Department", 9203),
        ("Outpatient", 9202),
        ("Inpatient Waiting List", 9201),
        ("Maternity", 8650),
        ("Day Care", 38004210),
        ("Inpatient", 9201),
        ("Newborn", 581384),
        ("Regular Night Admission", 9201),
    ]

    return spark.createDataFrame(encounter_type_rows, column_names)

@dlt.table(
    name        = "base_visit_occurrence",
    comment     = "Initial visit records – dates clamped to birth/death",
    temporary   = True
)
def create_base_visit_occurrence():
    """
    Build the unvalidated visit rows from ``map_encounter``.

    Encounters are inner-joined to ``person_birth_ts`` and left-joined to
    ``person_death_ts`` and ``visit_type_mapping``.  The visit window is then
    derived in this order:

      1. raw start = arrive, else depart; raw end = depart, else arrive
      2. start     = later of raw start and birth
      3. end       = earliest of raw end, death and the current timestamp
      4. end       = start when end falls before start
      5. start/end = 1901-01-01 when they fall before that floor

    NOTE(review): ``greatest``/``least`` skip NULL arguments, so an encounter
    with neither arrive nor depart timestamp appears to end up with
    start = birth_ts and end = death_ts/now — confirm this is intended.
    """
    floor_ts = F.to_timestamp(F.lit("1901-01-01 00:00:00"))
    run_ts = F.current_timestamp()

    def _clamp_to_floor(ts):
        # Written with when/otherwise so that a NULL timestamp stays NULL.
        return F.when(ts < floor_ts, floor_ts).otherwise(ts)

    def _zero_as_null(column_name):
        # A source id of 0 is turned into NULL.
        return F.when(F.col(column_name) == 0, None).otherwise(F.col(column_name))

    joined = (
        spark.table("4_prod.bronze.map_encounter").alias("e")
        .join(F.broadcast(dlt.read("person_birth_ts")).alias("b"), "person_id")
        .join(F.broadcast(dlt.read("person_death_ts")).alias("d"), "person_id", "left")
        .join(dlt.read("visit_type_mapping").alias("v"),
              F.col("e.encntr_type_desc") == F.col("v.encntr_type_desc"),
              "left")
    )

    # Step 1: raw boundaries, each falling back to the other timestamp.
    raw_start = F.to_timestamp(F.coalesce("e.arrive_dt_tm", "e.depart_dt_tm"))
    raw_end = F.to_timestamp(F.coalesce("e.depart_dt_tm", "e.arrive_dt_tm"))

    # Step 2: never start before birth.
    start_after_birth = F.greatest(raw_start, F.col("b.birth_ts"))

    # Step 3: cap the end at death (when known) and at "now"; least() ignores
    # a NULL death_ts, so no explicit NULL check is required.
    end_capped = F.least(raw_end, F.col("d.death_ts"), run_ts)

    # Step 4: the end may not precede the (birth-clamped) start.
    end_after_start = (
        F.when(end_capped < start_after_birth, start_after_birth)
        .otherwise(end_capped)
    )

    # Step 5: apply the 1901 floor to both boundaries.
    visit_start = _clamp_to_floor(start_after_birth)
    visit_end = _clamp_to_floor(end_after_start)

    visits = joined.select(
        F.col("e.encntr_id").cast("bigint").alias("visit_occurrence_id"),
        F.col("e.person_id").cast("bigint").alias("person_id"),
        # Unmapped encounter types get concept 0.
        F.coalesce(F.col("v.concept_id"), F.lit(0)).cast("int").alias("visit_concept_id"),

        visit_start.cast("date").alias("visit_start_date"),
        visit_start.alias("visit_start_datetime"),
        visit_end.cast("date").alias("visit_end_date"),
        visit_end.alias("visit_end_datetime"),

        F.lit(32817).cast("int").alias("visit_type_concept_id"),
        _zero_as_null("e.reg_prsnl_id").cast("bigint").alias("provider_id"),
        _zero_as_null("e.loc_nurse_unit_cd").cast("bigint").alias("care_site_id"),
        F.col("e.encntr_type_desc").alias("visit_source_value"),
        F.lit(0).cast("int").alias("visit_source_concept_id"),

        # Admission/discharge details are not populated here.
        F.lit(None).cast("int").alias("admitted_from_concept_id"),
        F.lit(None).cast("string").alias("admitted_from_source_value"),
        F.lit(None).cast("int").alias("discharged_to_concept_id"),
        F.lit(None).cast("string").alias("discharged_to_source_value"),
    )

    return visits

@dlt.table(
    name="omop_visit_occurrence",
    comment="OMOP CDM Visit Occurrence table - Contains records of Events where Persons engage with the healthcare system for a duration of time",
    schema=visit_schema,
    table_properties={"quality": "gold"}
)
def create_omop_visit_occurrence():
    """
    Creates the final visit occurrence table with all validations.

    Steps:
      * ``provider_id`` is set to NULL when it is not present in ``omop_provider``.
      * ``care_site_id`` is set to NULL when it is not present in ``omop_care_site``.
      * ``preceding_visit_occurrence_id`` is the id of the same person's previous
        visit, ordered by ``visit_start_datetime`` and then ``visit_occurrence_id``
        so that visits sharing a start timestamp are linked deterministically.
    """
    base_visits = dlt.read("base_visit_occurrence")

    # Reference keys, each tagged with a flag so that a left join can tell a
    # real match apart from a miss.
    valid_providers = (dlt.read("omop_provider")
        .select("provider_id").distinct()
        .withColumn("_provider_exists", F.lit(True)))
    valid_care_sites = (dlt.read("omop_care_site")
        .select("care_site_id").distinct()
        .withColumn("_care_site_exists", F.lit(True)))

    # One pass over the visits: keep an id only when its flag came back from
    # the join.  when() without otherwise() yields a NULL of the column's type.
    visit_df_with_valid_refs = (base_visits
        .join(valid_providers, "provider_id", "left")
        .join(valid_care_sites, "care_site_id", "left")
        .withColumn("provider_id",
                    F.when(F.col("_provider_exists"), F.col("provider_id")))
        .withColumn("care_site_id",
                    F.when(F.col("_care_site_exists"), F.col("care_site_id")))
        .drop("_provider_exists", "_care_site_exists"))

    # lag() defines its own one-row-back frame, so no explicit frame is needed.
    # visit_occurrence_id breaks ties between visits with equal start times.
    visit_window = (Window.partitionBy("person_id")
                          .orderBy("visit_start_datetime", "visit_occurrence_id"))

    return (visit_df_with_valid_refs
        .withColumn("preceding_visit_occurrence_id",
                    F.lag("visit_occurrence_id", 1).over(visit_window)))



In [0]:


# Schema of omop_condition_occurrence, built from
# (column name, data type, nullable, column comment) entries.
condition_schema = StructType([
    StructField(field_name, field_type, is_nullable, metadata={"comment": field_comment})
    for field_name, field_type, is_nullable, field_comment in [
        ("condition_occurrence_id", LongType(), False,
         "A unique identifier for each condition occurrence event."),
        ("person_id", LongType(), False,
         "A foreign key identifier to the person who is experiencing the condition."),
        ("condition_concept_id", IntegerType(), False,
         "A foreign key that refers to a standard condition concept identifier in the Vocabulary."),
        ("condition_start_date", DateType(), True,
         "The date when the instance of the condition is recorded."),
        ("condition_start_datetime", TimestampType(), True,
         "The date and time when the instance of the condition is recorded."),
        ("condition_end_date", DateType(), True,
         "The date when the instance of the condition is considered to have ended."),
        ("condition_end_datetime", TimestampType(), True,
         "The date and time when the instance of the condition is considered to have ended."),
        ("condition_type_concept_id", IntegerType(), True,
         "A foreign key to the predefined concept identifier in the Standardized Vocabularies reflecting the source data from which the condition was recorded."),
        ("condition_status_concept_id", IntegerType(), True,
         "A foreign key to the predefined concept identifier in the Standardized Vocabularies reflecting the status of the condition."),
        ("stop_reason", StringType(), True,
         "The reason that the condition was no longer present."),
        ("provider_id", LongType(), True,
         "A foreign key to the provider who was responsible for determining the condition."),
        ("visit_occurrence_id", LongType(), True,
         "A foreign key to the visit in the VISIT_OCCURRENCE table during which the condition was determined."),
        ("visit_detail_id", LongType(), True,
         "A foreign key to the visit detail record during which the condition was determined."),
        ("condition_source_value", StringType(), True,
         "The source value for the condition as it appears in the source data."),
        ("condition_source_concept_id", LongType(), True,
         "A foreign key to a condition concept that refers to the code used in the source."),
        ("condition_status_source_value", StringType(), True,
         "The source value for the condition status as it appears in the source data."),
    ]
])


def _add_condition_status(df):
    """
    Add OMOP condition-status columns to ``df``.

    The status text is the first non-null of ``confirmation_status_desc`` and
    ``diag_type_desc`` (whichever of the two columns ``df`` actually has),
    defaulting to an empty string, and upper-cased.  Two columns are added:

      - condition_status_concept_id   : integer looked up from the status text,
                                        NULL when the text is not in the lookup
      - condition_status_source_value : first 50 characters of the status text

    A temporary ``status_norm`` column is created and dropped again.
    """
    concept_id_by_status = {
        "CONFIRMED":                32893,

        # admission / discharge
        "ADMISSION DIAGNOSIS":      32890,
        "ADMITTING DIAGNOSIS":      32890,
        "DISCHARGE DIAGNOSIS":      32896,

        # peri-operative
        "PRE-OP DIAGNOSIS":         32900,
        "PRE-OPERATIVE DIAGNOSIS":  32900,
        "POST-OP DIAGNOSIS":        32898,
        "POST-OPERATIVE DIAGNOSIS": 32898,

        # referral / ranking
        "REFERRING DIAGNOSIS":      32904,
        "PRIMARY DIAGNOSIS":        32902,
        "PRINCIPAL DIAGNOSIS":      32902,
        "SECONDARY DIAGNOSIS":      32908,

        # unconfirmed wording, all mapped to the same concept
        "DIFFERENTIAL":             32899,
        "SUSPECTED":                32899,
        "PROVISIONAL":              32899,
        "PROBABLE":                 32899,
        "PROBABLE DIAGNOSIS":       32899,
        "POTENTIAL":                32899,
        "POSSIBLE":                 32899,
    }

    # create_map expects a flat key, value, key, value, ... sequence of literals.
    map_literals = []
    for status_text, concept_id in concept_id_by_status.items():
        map_literals.append(F.lit(status_text))
        map_literals.append(F.lit(concept_id))
    status_lookup = F.create_map(map_literals)

    # Only reference the columns this DataFrame has (diag_type_desc is absent
    # in map_problem); the trailing "" keeps coalesce from being called empty.
    status_sources = [
        F.col(column_name)
        for column_name in ("confirmation_status_desc", "diag_type_desc")
        if column_name in df.columns
    ]
    status_sources.append(F.lit(""))

    with_status_text = df.withColumn("status_norm", F.upper(F.coalesce(*status_sources)))
    with_concept = with_status_text.withColumn(
        "condition_status_concept_id",
        status_lookup.getItem(F.col("status_norm")).cast("integer"),
    )
    with_source_value = with_concept.withColumn(
        "condition_status_source_value",
        F.substring(F.col("status_norm"), 1, 50),
    )
    return with_source_value.drop("status_norm")

@dlt.table(
    name="condition_diagnosis_mapped",
    comment="Base condition records from diagnosis data, mapped/validated to standard concepts",
    temporary=True
)
def create_condition_diagnosis_mapped():
    """
    Build condition rows from ``map_diagnosis``.

    Rows need a non-null, non-zero ``OMOP_CONCEPT_ID`` and a person present in
    ``omop_person``.  They are passed through ``_map_or_use_standard`` for the
    "Condition" domain and ``_add_condition_status``, then projected onto the
    condition columns plus ``mapping_count``.

    NOTE(review): ``condition_status_source_value`` is re-derived here from
    ``confirmation_status_desc`` only, replacing the value computed by
    ``_add_condition_status`` — confirm that is intended.
    """
    persons_in_cdm = F.broadcast(dlt.read("omop_person").select("person_id").distinct())

    has_source_concept = (
        F.col("OMOP_CONCEPT_ID").isNotNull() & (F.col("OMOP_CONCEPT_ID") != 0)
    )
    eligible_diagnoses = (
        spark.table(f"{BRONZE_DB}.map_diagnosis")
        .filter(has_source_concept)
        .join(persons_in_cdm, "person_id", "inner")
    )

    standardised = _map_or_use_standard(
        eligible_diagnoses,
        source_omop_concept_id_col="OMOP_CONCEPT_ID",
        target_domain_id="Condition"
    )
    with_status = _add_condition_status(standardised)

    # A personnel id of 0 is turned into NULL.
    provider = F.when(F.col("diag_prsnl_id") == 0, None).otherwise(F.col("diag_prsnl_id"))
    source_text = F.coalesce(F.col("source_string"), F.col("source_identifier"))

    return with_status.select(
        F.col("person_id").cast("bigint"),
        F.col("standard_concept_id").cast("integer").alias("condition_concept_id"),
        F.col("diag_dt_tm").cast("timestamp").alias("condition_start_datetime"),
        F.lit(32817).cast("integer").alias("condition_type_concept_id"),
        # An unrecognised status falls back to concept 0.
        F.coalesce(F.col("condition_status_concept_id"), F.lit(0))
         .cast("int").alias("condition_status_concept_id"),
        F.lit(None).cast("string").alias("stop_reason"),
        provider.cast("bigint").alias("provider_id"),
        F.col("encntr_id").cast("bigint").alias("visit_occurrence_id"),
        F.lit(None).cast("bigint").alias("visit_detail_id"),
        F.substring(source_text, 1, 50).alias("condition_source_value"),
        F.col("OMOP_CONCEPT_ID").cast("bigint").alias("condition_source_concept_id"),
        F.substring(F.col("confirmation_status_desc"), 1, 50).alias("condition_status_source_value"),
        F.col("mapping_count")
    )


@dlt.table(
    name="condition_problem_mapped",
    comment="Base condition records from problem data, mapped/validated to standard concepts",
    temporary=True
)
def create_condition_problem_mapped():
    """
    Build condition rows from ``map_problem``.

    Rows are kept when ``OMOP_CONCEPT_DOMAIN`` is "Condition", ``OMOP_CONCEPT_ID``
    is non-null and non-zero, ``CALC_ENCNTR`` is non-null and the person exists
    in ``omop_person``.  They are passed through ``_map_or_use_standard`` for the
    "Condition" domain and ``_add_condition_status``, then projected onto the
    condition columns plus ``mapping_count``.
    """
    persons_in_cdm = F.broadcast(dlt.read("omop_person").select("person_id").distinct())

    is_eligible = (
        (F.col("OMOP_CONCEPT_DOMAIN") == "Condition") &
        (F.col("OMOP_CONCEPT_ID").isNotNull()) &
        (F.col("OMOP_CONCEPT_ID") != 0) &
        (F.col("CALC_ENCNTR").isNotNull())
    )
    eligible_problems = (
        spark.table(f"{BRONZE_DB}.map_problem")
        .filter(is_eligible)
        .join(persons_in_cdm, "person_id", "inner")
    )

    standardised = _map_or_use_standard(
        eligible_problems,
        source_omop_concept_id_col="OMOP_CONCEPT_ID",
        target_domain_id="Condition"
    )
    with_status = _add_condition_status(standardised)

    # Personnel ids 0 and 1 are turned into NULL.
    provider = (
        F.when(F.col("ACTIVE_STATUS_PRSNL_ID").isin([0, 1]), None)
        .otherwise(F.col("ACTIVE_STATUS_PRSNL_ID"))
    )

    return with_status.select(
        F.col("person_id").cast("bigint"),
        F.col("standard_concept_id").cast("integer").alias("condition_concept_id"),
        F.col("CALC_DT_TM").cast("timestamp").alias("condition_start_datetime"),
        F.lit(32817).cast("integer").alias("condition_type_concept_id"),
        # An unrecognised status falls back to concept 0.
        F.coalesce("condition_status_concept_id", F.lit(0)).cast("integer").alias("condition_status_concept_id"),
        F.lit(None).cast("string").alias("stop_reason"),
        provider.cast("bigint").alias("provider_id"),
        F.col("CALC_ENCNTR").cast("bigint").alias("visit_occurrence_id"),
        F.lit(None).cast("bigint").alias("visit_detail_id"),
        F.substring(F.col("SOURCE_STRING"), 1, 50).alias("condition_source_value"),
        F.col("OMOP_CONCEPT_ID").cast("bigint").alias("condition_source_concept_id"),
        F.lit(None).cast("string").alias("condition_status_source_value"),
        F.col("mapping_count")
    )


@dlt.table(
    name="condition_combined",
    comment="Combined and deduplicated conditions from diagnosis and problem data",
    temporary=True
)
def create_combined_conditions():
    """
    Union the diagnosis and problem condition rows and keep one row per
    (person, condition concept, start datetime, visit).

    Within each group the surviving row is chosen by, in order: diagnosis
    before problem, rows with a provider before rows without, rows with a
    source value before rows without.
    """
    # (upstream table, priority) — a lower priority wins a duplicate.
    prioritised_sources = [
        ("condition_diagnosis_mapped", 1),
        ("condition_problem_mapped", 2),
    ]
    diagnosis_rows, problem_rows = [
        dlt.read(table_name).withColumn("source_priority", F.lit(priority))
        for table_name, priority in prioritised_sources
    ]

    unioned = diagnosis_rows.unionByName(problem_rows, allowMissingColumns=True)

    duplicate_key = [
        "person_id",
        "condition_concept_id",
        "condition_start_datetime",
        F.coalesce(F.col("visit_occurrence_id"), F.lit(0)),
    ]
    preference = [
        F.col("source_priority"),
        F.when(F.col("provider_id").isNotNull(), 0).otherwise(1),
        F.when(F.col("condition_source_value").isNotNull(), 0).otherwise(1),
    ]
    ranking_window = Window.partitionBy(*duplicate_key).orderBy(*preference)

    ranked = unioned.withColumn("rank", F.row_number().over(ranking_window))
    best_per_key = ranked.filter(F.col("rank") == 1)

    return best_per_key.drop("rank", "source_priority")

@dlt.table(
    name="omop_condition_occurrence",
    comment="OMOP CDM Condition Occurrence table - Contains records of Events suggesting the presence of a disease or medical condition",
    schema=condition_schema,
    table_properties={"quality": "gold"}
)
def create_omop_condition_occurrence():
    """
    Creates the final condition occurrence table.

    Steps:
      1. Left-join ``condition_combined`` to ``omop_visit_occurrence`` and pull the
         condition start into the visit window when both bounds are known
         (later than visit end -> visit end; earlier than visit start -> visit start).
      2. Derive ``condition_start_date``; end date/datetime are always NULL.
      3. Drop rows without a start date.
      4. Run ``_validate_provider`` on the result.
      5. Number the rows 1..N as ``condition_occurrence_id``.
      6. Project to the ``condition_schema`` columns and repartition by person.

    NOTE(review): a ``visit_occurrence_id`` with no match in
    ``omop_visit_occurrence`` is kept as-is rather than set to NULL — confirm
    whether that is intended.
    """
    condition_combined = dlt.read("condition_combined")

    valid_visits = F.broadcast(dlt.read("omop_visit_occurrence").select(
        "visit_occurrence_id", "visit_start_datetime", "visit_end_datetime"
    ))

    condition_df_with_visits = (condition_combined
        .join(valid_visits, "visit_occurrence_id", "left")
        .withColumn(
            "adj_condition_start_datetime",
            # Branch order matters: the first matching condition wins.
            F.when(F.col("visit_occurrence_id").isNull(), F.col("condition_start_datetime"))
            .when(F.col("condition_start_datetime").isNull(), F.lit(None))
            .when(F.col("visit_end_datetime").isNull(), F.col("condition_start_datetime"))
            .when(F.col("condition_start_datetime") > F.col("visit_end_datetime"),
                 F.col("visit_end_datetime"))
            .when(F.col("visit_start_datetime").isNull(), F.col("condition_start_datetime"))
            .when(F.col("condition_start_datetime") < F.col("visit_start_datetime"),
                 F.col("visit_start_datetime"))
            .otherwise(F.col("condition_start_datetime")))
        .withColumn("condition_start_date",
                   F.col("adj_condition_start_datetime").cast("date"))
        .withColumn("condition_end_date", F.lit(None).cast("date"))
        .withColumn("condition_end_datetime", F.lit(None).cast("timestamp"))
        # Swap the raw start for the adjusted one.
        .drop("visit_start_datetime", "visit_end_datetime", "condition_start_datetime")
        .withColumnRenamed("adj_condition_start_datetime", "condition_start_datetime")
        .filter(F.col("condition_start_date").isNotNull())
        )

    condition_df_validated = _validate_provider(condition_df_with_visits)

    # Un-partitioned window: the whole dataset is ordered as a single group so
    # that the ids run 1..N.  Rows that tie on every ordering column may swap
    # ids between runs.
    window_spec = Window.orderBy(
        "person_id", "condition_start_datetime", "condition_concept_id",
        F.coalesce(F.col("visit_occurrence_id"), F.lit(0)),
        F.coalesce(F.col("provider_id"), F.lit(0))
        )

    final_df = (condition_df_validated
        .withColumn("condition_occurrence_id",
                   F.row_number().over(window_spec).cast("bigint"))
        .select(
            F.col("condition_occurrence_id"),
            F.col("person_id").cast("long"),
            F.col("condition_concept_id").cast("int"),
            F.col("condition_start_date").cast("date"),
            F.col("condition_start_datetime").cast("timestamp"),
            F.col("condition_end_date").cast("date"),
            F.col("condition_end_datetime").cast("timestamp"),
            F.col("condition_type_concept_id").cast("int"),
            F.col("condition_status_concept_id").cast("int"),
            F.col("stop_reason").cast("string"),
            F.col("provider_id").cast("long"),
            F.col("visit_occurrence_id").cast("long"),
            F.col("visit_detail_id").cast("long"),
            F.substring(F.col("condition_source_value"), 1, 50)
             .cast("string").alias("condition_source_value"),
            F.col("condition_source_concept_id").cast("bigint"),
            F.substring(F.col("condition_status_source_value"), 1, 50)
             .cast("string").alias("condition_status_source_value")
        )
    )

    # row_number() already makes condition_occurrence_id unique, so no
    # de-duplication on that key is needed before repartitioning.
    return final_df.repartition(200, "person_id")

In [0]:
# Schema of omop_drug_exposure, built from
# (column name, data type, nullable, column comment) entries.
drug_schema = StructType([
    StructField(field_name, field_type, is_nullable, metadata={"comment": field_comment})
    for field_name, field_type, is_nullable, field_comment in [
        ("drug_exposure_id", LongType(), False,
         "A unique identifier for each drug exposure event."),
        ("person_id", LongType(), False,
         "A foreign key identifier to the person who is subjected to the drug."),
        ("drug_concept_id", IntegerType(), False,
         "A foreign key that refers to a standard drug concept identifier in the Vocabulary."),
        ("drug_exposure_start_date", DateType(), True,
         "The start date for the current instance of drug exposure."),
        ("drug_exposure_start_datetime", TimestampType(), True,
         "The start date and time for the current instance of drug exposure."),
        ("drug_exposure_end_date", DateType(), True,
         "The end date for the current instance of drug exposure."),
        ("drug_exposure_end_datetime", TimestampType(), True,
         "The end date and time for the current instance of drug exposure."),
        ("verbatim_end_date", DateType(), True,
         "The end date of the drug exposure as it appears in the source data."),
        ("drug_type_concept_id", IntegerType(), True,
         "A foreign key to the predefined concept identifier in the Standardized Vocabularies reflecting the type of drug exposure."),
        ("stop_reason", StringType(), True,
         "The reason the drug exposure was stopped."),
        ("refills", IntegerType(), True,
         "The number of refills after the initial prescription."),
        ("quantity", FloatType(), True,
         "The quantity of drug as recorded in the source data."),
        ("days_supply", IntegerType(), True,
         "The number of days of supply of the medication."),
        ("sig", StringType(), True,
         "The directions (signatur) on the drug prescription as recorded in the source."),
        ("route_concept_id", IntegerType(), True,
         "A foreign key to a predefined concept in the Standardized Vocabularies reflecting the route of administration."),
        ("lot_number", StringType(), True,
         "The identifier to determine where the product originated."),
        ("provider_id", LongType(), True,
         "A foreign key to the provider in the provider table who prescribed the drug."),
        ("visit_occurrence_id", LongType(), True,
         "A foreign key to the visit in the visit table during which the drug exposure initiated."),
        ("visit_detail_id", LongType(), True,
         "A foreign key to the visit detail record during which the drug exposure initiated."),
        ("drug_source_value", StringType(), True,
         "The source code for the drug as it appears in the source data."),
        ("drug_source_concept_id", IntegerType(), True,
         "A foreign key to a drug concept that refers to the code used in the source."),
        ("route_source_value", StringType(), True,
         "The source code for the route as it appears in the source data."),
        ("dose_unit_source_value", StringType(), True,
         "The information about the dose unit as recorded in the source data."),
    ]
])


@dlt.table(
    name="route_concept_maps",
    comment="Route concept mappings from Barts maps",
    temporary=True
)
def get_route_maps():
    """
    Return the route mappings from the Barts maps table.

    Keeps only rows whose ``OMOPField`` is "route_concept_id" and exposes them
    as (``route_source``, ``route_omop_id``).
    """
    barts_maps = spark.table(BARTS_MAPS_TABLE)
    route_rows = barts_maps.filter(F.col("OMOPField") == "route_concept_id")

    return route_rows.select(
        F.col("SourceValue").alias("route_source"),
        F.col("OmopConceptId").alias("route_omop_id"),
    )

@dlt.table(
    name="base_drug_exposure_mapped",
    comment="Base drug exposure records mapped/validated to standard concepts",
    temporary=True
)
def create_base_drug_exposure_mapped():
    """
    Build drug exposure rows from ``map_med_admin``.

    Only "Administered" events with a non-null, non-zero ``omop_concept_id`` and
    a person present in ``omop_person`` are kept.  Rows are passed through
    ``_map_or_use_standard`` for the "Drug" domain, left-joined to the route
    lookup on the administration route text, and projected onto the drug
    exposure columns plus ``mapping_count``.
    """
    persons_in_cdm = F.broadcast(dlt.read("omop_person").select("person_id").distinct())
    route_lookup = dlt.read("route_concept_maps")

    was_administered = F.col("event_type_display") == "Administered"
    has_source_concept = (
        F.col("omop_concept_id").isNotNull() & (F.col("omop_concept_id") != 0)
    )
    eligible_administrations = (
        spark.table(f"{BRONZE_DB}.map_med_admin")
        .filter(was_administered)
        .filter(has_source_concept)
        .join(persons_in_cdm, "person_id", "inner")
    )

    standardised = _map_or_use_standard(
        eligible_administrations,
        source_omop_concept_id_col="omop_concept_id",
        target_domain_id="Drug"
    )

    with_route = standardised.join(
        route_lookup,
        F.col("ADMIN_ROUTE_DISPLAY") == F.col("route_source"),
        "left"
    )

    # A personnel id of 0 is turned into NULL.
    provider = F.when(F.col("prsnl_id") == 0, None).otherwise(F.col("prsnl_id"))

    return with_route.select(
        F.col("person_id").cast("bigint"),
        F.col("standard_concept_id").cast("integer").alias("drug_concept_id"),
        F.col("admin_start_dt_tm").cast("timestamp").alias("drug_exposure_start_datetime"),
        F.col("admin_end_dt_tm").cast("timestamp").alias("drug_exposure_end_datetime"),
        F.col("admin_end_dt_tm").cast("date").alias("verbatim_end_date"),
        F.lit(32817).cast("integer").alias("drug_type_concept_id"),
        # Not populated here.
        F.lit(None).cast("string").alias("stop_reason"),
        F.lit(None).cast("integer").alias("refills"),
        F.lit(None).cast("float").alias("quantity"),
        F.lit(None).cast("integer").alias("days_supply"),
        F.lit(None).cast("string").alias("sig"),
        # An unmapped route falls back to concept 0.
        F.coalesce(F.col("route_omop_id"), F.lit(0)).cast("integer").alias("route_concept_id"),
        F.lit(None).cast("string").alias("lot_number"),
        provider.cast("bigint").alias("provider_id"),
        F.col("encntr_id").cast("bigint").alias("visit_occurrence_id"),
        F.lit(None).cast("bigint").alias("visit_detail_id"),
        F.col("ORDER_MNEMONIC").alias("drug_source_value"),
        F.col("omop_concept_id").cast("integer").alias("drug_source_concept_id"),
        F.col("admin_route_display").alias("route_source_value"),
        F.col("initial_dosage_unit_display").alias("dose_unit_source_value"),
        F.col("mapping_count")
    )

@dlt.table(
    name="omop_drug_exposure",
    comment="OMOP CDM Drug Exposure table - Contains records about the exposure to a Drug through prescriptions or administration",
    schema=drug_schema,
    table_properties={"quality": "gold"}
)
def create_omop_drug_exposure():
    """
    Creates the final drug exposure table.

    Steps:
      1. Left-join ``base_drug_exposure_mapped`` to ``omop_visit_occurrence`` and
         pull the start into the visit window when both bounds are known.
      2. Derive the end: missing or earlier than the adjusted start -> adjusted
         start; later than a known visit end -> visit end; otherwise unchanged.
      3. Derive the start/end dates and drop rows without a start date.
      4. Run ``_validate_provider`` on the result.
      5. Set quantity, days_supply and dose_unit_source_value to NULL for rows
         whose ``mapping_count`` is greater than 1.
      6. Number the rows 1..N as ``drug_exposure_id``, project to the
         ``drug_schema`` columns and repartition by person.

    NOTE(review): a ``visit_occurrence_id`` with no match in
    ``omop_visit_occurrence`` is kept as-is rather than set to NULL — confirm
    whether that is intended.
    """
    base_drugs_mapped = dlt.read("base_drug_exposure_mapped")

    valid_visits = F.broadcast(dlt.read("omop_visit_occurrence").select(
        "visit_occurrence_id", "visit_start_datetime", "visit_end_datetime"
    ))

    drug_df_with_visits = (base_drugs_mapped
        .join(valid_visits, "visit_occurrence_id", "left")
        .withColumn(
            "adj_drug_exposure_start_datetime",
            # Branch order matters: the first matching condition wins.
            F.when(F.col("visit_occurrence_id").isNull(), F.col("drug_exposure_start_datetime"))
            .when(F.col("drug_exposure_start_datetime").isNull(), F.lit(None))
            .when(F.col("visit_end_datetime").isNull(), F.col("drug_exposure_start_datetime"))
            .when(F.col("drug_exposure_start_datetime") > F.col("visit_end_datetime"),
                 F.col("visit_end_datetime"))
            .when(F.col("visit_start_datetime").isNull(), F.col("drug_exposure_start_datetime"))
            .when(F.col("drug_exposure_start_datetime") < F.col("visit_start_datetime"),
                 F.col("visit_start_datetime"))
            .otherwise(F.col("drug_exposure_start_datetime")))
        .withColumn(
            "adj_drug_exposure_end_datetime",
            # The end is never NULL and never before the adjusted start.
            F.when(F.col("drug_exposure_end_datetime").isNull(),
                   F.col("adj_drug_exposure_start_datetime"))
            .when(F.col("drug_exposure_end_datetime") < F.col("adj_drug_exposure_start_datetime"),
                  F.col("adj_drug_exposure_start_datetime"))
            .when(F.col("visit_occurrence_id").isNotNull()
                  & F.col("visit_end_datetime").isNotNull()
                  & (F.col("drug_exposure_end_datetime") > F.col("visit_end_datetime")),
                  F.col("visit_end_datetime"))
            .otherwise(F.col("drug_exposure_end_datetime")))
        .withColumn("drug_exposure_start_date",
                   F.col("adj_drug_exposure_start_datetime").cast("date"))
        .withColumn("drug_exposure_end_date",
                   F.col("adj_drug_exposure_end_datetime").cast("date"))
        # Swap the raw boundaries for the adjusted ones.
        .drop("visit_start_datetime", "visit_end_datetime",
              "drug_exposure_start_datetime", "drug_exposure_end_datetime")
        .withColumnRenamed("adj_drug_exposure_start_datetime", "drug_exposure_start_datetime")
        .withColumnRenamed("adj_drug_exposure_end_datetime", "drug_exposure_end_datetime")
        .filter(F.col("drug_exposure_start_date").isNotNull())
        )

    drug_df_validated = _validate_provider(drug_df_with_visits)

    # Nullify quantity/dose fields if mapping_count > 1
    has_multiple_mappings = F.col("mapping_count") > 1
    drug_df_adjusted = drug_df_validated
    for dose_column in ("quantity", "days_supply", "dose_unit_source_value"):
        drug_df_adjusted = drug_df_adjusted.withColumn(
            dose_column,
            F.when(has_multiple_mappings, F.lit(None)).otherwise(F.col(dose_column))
        )

    # Un-partitioned window: the whole dataset is ordered as a single group so
    # that the ids run 1..N.  Rows that tie on every ordering column may swap
    # ids between runs.
    window_spec = Window.orderBy(
        "person_id", "drug_exposure_start_datetime", "drug_concept_id",
        F.coalesce(F.col("visit_occurrence_id"), F.lit(0)),
        F.coalesce(F.col("provider_id"), F.lit(0))
    )

    final_df = (drug_df_adjusted
        .withColumn("drug_exposure_id",
                   F.row_number().over(window_spec).cast("bigint"))
        .select(drug_schema.fieldNames())
    )

    # row_number() already makes drug_exposure_id unique, so no de-duplication
    # on that key is needed before repartitioning.
    return final_df.repartition(200, "person_id")


In [0]:

# Schema of the procedure occurrence table, built from
# (column name, data type, nullable, column comment) entries.
procedure_schema = StructType([
    StructField(field_name, field_type, is_nullable, metadata={"comment": field_comment})
    for field_name, field_type, is_nullable, field_comment in [
        ("procedure_occurrence_id", LongType(), False,
         "A unique identifier for each Procedure Occurrence event."),
        ("person_id", LongType(), False,
         "A foreign key identifier to the Person who is subjected to the Procedure."),
        ("procedure_concept_id", IntegerType(), False,
         "A foreign key that refers to a standard procedure Concept identifier in the Vocabulary."),
        ("procedure_date", DateType(), False,
         "The date on which the Procedure was performed."),
        ("procedure_datetime", TimestampType(), True,
         "The date and time on which the Procedure was performed."),
        ("procedure_end_date", DateType(), True,
         "The end date on which the Procedure was performed."),
        ("procedure_end_datetime", TimestampType(), True,
         "The end date and time on which the Procedure was performed."),
        ("procedure_type_concept_id", IntegerType(), False,
         "A foreign key to the predefined Concept identifier in the Standardized Vocabularies reflecting the type of source data from which the procedure record is derived."),
        ("modifier_concept_id", IntegerType(), True,
         "A foreign key to a Standard Concept identifier for a modifier to the Procedure."),
        ("quantity", IntegerType(), True,
         "The quantity of procedures ordered or administered."),
        ("provider_id", LongType(), True,
         "A foreign key to the Provider in the PROVIDER table who was responsible for carrying out the procedure."),
        ("visit_occurrence_id", LongType(), True,
         "A foreign key to the Visit in the VISIT_OCCURRENCE table during which the Procedure was carried out."),
        ("visit_detail_id", LongType(), True,
         "A foreign key to the Visit Detail in the VISIT_DETAIL table during which the Procedure was carried out."),
        ("procedure_source_value", StringType(), True,
         "The procedure as it appears in the source data."),
        ("procedure_source_concept_id", IntegerType(), True,
         "A foreign key to a Procedure Concept that refers to the code used in the source."),
        ("modifier_source_value", StringType(), True,
         "The source code for the modifier as it appears in the source data."),
    ]
])


@dlt.table(
    name="base_procedure_occurrence_mapped",
    comment="Base procedure occurrence records mapped/validated to standard concepts",
    temporary=True
)
def create_base_procedure_occurrence_mapped():
    """
    Build the intermediate procedure rows that feed omop_procedure_occurrence.

    Source rows come from map_procedure in the bronze database. Only rows with
    a non-null, non-zero omop_concept_id and a person_id present in
    omop_person are kept; the concept is then resolved through
    _map_or_use_standard with target domain "Procedure".

    Returns:
        DataFrame with OMOP-style procedure columns plus mapping_count.
    """
    concept_id = F.col("omop_concept_id")
    source_rows = spark.table(f"{BRONZE_DB}.map_procedure").filter(
        concept_id.isNotNull() & (concept_id != 0)
    )

    # Restrict to persons that exist in omop_person (small side is broadcast).
    known_person_ids = dlt.read("omop_person").select("person_id").distinct()
    person_scoped = source_rows.join(F.broadcast(known_person_ids), "person_id", "inner")

    resolved = _map_or_use_standard(
        person_scoped,
        source_omop_concept_id_col="omop_concept_id",
        target_domain_id="Procedure"
    )

    # A personnel id of 0 is replaced by NULL before it becomes provider_id.
    personnel_id = F.col("active_status_prsnl_id")
    provider_or_null = F.when(personnel_id == 0, None).otherwise(personnel_id)

    output_columns = [
        F.col("person_id").cast("bigint"),
        F.col("standard_concept_id").cast("integer").alias("procedure_concept_id"),
        F.col("proc_dt_tm").cast("timestamp").alias("procedure_datetime"),
        F.lit(None).cast("timestamp").alias("procedure_end_datetime"),
        F.lit(32817).cast("integer").alias("procedure_type_concept_id"),
        F.lit(0).cast("integer").alias("modifier_concept_id"),
        F.col("proc_minutes").cast("integer").alias("quantity"),
        provider_or_null.cast("bigint").alias("provider_id"),
        F.col("encntr_id").cast("bigint").alias("visit_occurrence_id"),
        F.lit(None).cast("bigint").alias("visit_detail_id"),
        # Source text is truncated to its first 50 characters.
        F.substring(F.col("source_string"), 1, 50).alias("procedure_source_value"),
        F.col("omop_concept_id").cast("integer").alias("procedure_source_concept_id"),
        F.lit(None).cast("string").alias("modifier_source_value"),
        F.col("mapping_count"),
    ]
    return resolved.select(*output_columns)

@dlt.table(
    name="omop_procedure_occurrence",
    comment="OMOP CDM Procedure Occurrence table",
    schema=procedure_schema,
    table_properties={"quality": "gold"}
)
def create_omop_procedure_occurrence():
    """
    Creates the final procedure occurrence table with validation and nullification.

    Steps:
      1. Left-join base_procedure_occurrence_mapped to omop_visit_occurrence and
         clamp procedure_datetime into the visit's start/end window.
      2. Derive procedure_date from the clamped timestamp and drop rows where
         it is NULL.
      3. Pass the rows through _validate_provider.
      4. Set quantity to NULL where mapping_count > 1.
      5. Assign procedure_occurrence_id with row_number() over a single global
         ordering, then project to the columns of procedure_schema.

    Returns:
        DataFrame with exactly the procedure_schema columns, repartitioned by
        person_id.
    """
    base_procedures_mapped = dlt.read("base_procedure_occurrence_mapped")

    valid_visits = F.broadcast(
               dlt.read("omop_visit_occurrence")
                  .select("visit_occurrence_id", "visit_start_datetime", "visit_end_datetime")
             )

    # Branch order matters: the first matching branch wins.
    #  - no visit id, or no visit end (including visits missing from the
    #    visit table after the left join): keep the timestamp as-is
    #  - later than visit end: pull back to visit end
    #  - earlier than visit start: push forward to visit start
    fixed_time = (base_procedures_mapped
        .join(valid_visits, "visit_occurrence_id", "left")
        .withColumn(
            "adj_procedure_datetime",
            F.when(F.col("visit_occurrence_id").isNull(), F.col("procedure_datetime"))
            .when(F.col("procedure_datetime").isNull(), F.lit(None))
            .when(F.col("visit_end_datetime").isNull(), F.col("procedure_datetime"))
            .when(F.col("procedure_datetime") > F.col("visit_end_datetime"),
                 F.col("visit_end_datetime"))
            .when(F.col("visit_start_datetime").isNull(), F.col("procedure_datetime"))
            .when(F.col("procedure_datetime") < F.col("visit_start_datetime"),
                 F.col("visit_start_datetime"))
            .otherwise(F.col("procedure_datetime"))
        )
        .withColumn("procedure_date", F.col("adj_procedure_datetime").cast("date"))
        .withColumn("procedure_end_date", F.lit(None).cast("date"))
        .withColumn("procedure_end_datetime", F.lit(None).cast("timestamp"))
        .drop("visit_start_datetime", "visit_end_datetime", "procedure_datetime")
        .withColumnRenamed("adj_procedure_datetime", "procedure_datetime")
        # Rows without a usable date are dropped.
        .filter(F.col("procedure_date").isNotNull())
    )

    validated_provider = _validate_provider(fixed_time)

    # Nullify quantity if mapping_count > 1
    adjusted_quantity = validated_provider.withColumn(
        "quantity",
        F.when(F.col("mapping_count") > 1, F.lit(None)).otherwise(F.col("quantity"))
    )

    # Un-partitioned window: one global ordering over all rows.
    w_id = Window.orderBy(
        "person_id", "procedure_datetime", "procedure_concept_id",
        F.coalesce(F.col("visit_occurrence_id"), F.lit(0)),
        F.coalesce(F.col("provider_id"), F.lit(0))
        )
    with_id = adjusted_quantity.withColumn("procedure_occurrence_id",
                              F.row_number().over(w_id).cast("bigint"))

    # Selecting the schema's field names also drops mapping_count.
    final_df = with_id.select(procedure_schema.fieldNames())

    # row_number() over a single window already yields a distinct id per row,
    # so de-duplicating on procedure_occurrence_id would only add a shuffle.
    return final_df.repartition(200, "person_id")


In [0]:


# Column layout (name, type, nullability, comment) used as the declared
# schema of the omop_device_exposure table.
device_schema = StructType([
    StructField(field_name, field_type, is_nullable, metadata={"comment": field_comment})
    for field_name, field_type, is_nullable, field_comment in [
        ("device_exposure_id", LongType(), False,
         "A unique identifier for each Device exposure event."),
        ("person_id", LongType(), False,
         "A foreign key identifier to the Person who is subjected to the Device."),
        ("device_concept_id", IntegerType(), False,
         "A foreign key that refers to a Standard Device Concept identifier in the Vocabulary."),
        ("device_exposure_start_date", DateType(), False,
         "The start date for the Device exposure."),
        ("device_exposure_start_datetime", TimestampType(), True,
         "The start date and time for the Device exposure."),
        ("device_exposure_end_date", DateType(), True,
         "The end date for the Device exposure."),
        ("device_exposure_end_datetime", TimestampType(), True,
         "The end date and time for the Device exposure."),
        ("device_type_concept_id", IntegerType(), False,
         "A foreign key to the predefined Concept identifier in the Standardized Vocabularies reflecting the type of Device exposure."),
        ("unique_device_id", StringType(), True,
         "The Unique Device Identification (UDI-DI) number for devices regulated by the FDA."),
        ("production_id", StringType(), True,
         "The Production Identifier (UDI-PI) portion of the Unique Device Identification."),
        ("quantity", IntegerType(), True,
         "The number of individual Devices used."),
        ("provider_id", LongType(), True,
         "A foreign key to the Provider in the PROVIDER table who initiated the Device exposure."),
        ("visit_occurrence_id", LongType(), True,
         "A foreign key to the Visit in the VISIT_OCCURRENCE table during which the Device exposure initiated."),
        ("visit_detail_id", LongType(), True,
         "A foreign key to the Visit Detail in the VISIT_DETAIL table during which the Device exposure initiated."),
        ("device_source_value", StringType(), True,
         "The source code for the Device as it appears in the source data."),
        ("device_source_concept_id", IntegerType(), True,
         "A foreign key to a Device Concept that refers to the code used in the source."),
        ("unit_concept_id", IntegerType(), True,
         "A foreign key to a predefined Concept in the Standardized Vocabularies reflecting the unit the Device was administered."),
        ("unit_source_value", StringType(), True,
         "The source code for the unit as it appears in the source data."),
        ("unit_source_concept_id", IntegerType(), True,
         "A foreign key to a Unit Concept that refers to the code used in the source."),
    ]
])



@dlt.table(
    name="valid_device_concepts",
    comment="Valid device concepts from OMOP vocabulary",
    temporary=True
)
def get_valid_device_concepts():
    """
    Gets valid device concepts from the OMOP vocabulary.

    Reads the concept table named by the CONCEPT_TABLE configuration constant
    (rather than repeating its path) and keeps rows whose invalid_reason is
    NULL and whose domain_id is "Device".

    Returns:
        DataFrame with a single column, concept_id, holding distinct values.
    """
    is_valid_device = F.col("invalid_reason").isNull() & (F.col("domain_id") == "Device")
    return (spark.table(CONCEPT_TABLE)
            .filter(is_valid_device)
            .select("concept_id")
            .distinct())

@dlt.table(
    name="valid_clinical_events",
    comment="Valid clinical events for device exposure",
    temporary=True
)
def get_valid_clinical_events():
    """
    Gets valid clinical events based on status and validity dates.

    Reads mill_clinical_event from the database named by the RAW_DB
    configuration constant and keeps rows whose VALID_UNTIL_DT_TM is later
    than the current date and whose RECORD_STATUS_CD equals 188.

    Returns:
        DataFrame with PERSON_ID, EVENT_CD, EVENT_TITLE_TEXT,
        CLINSIG_UPDT_DT_TM, PERFORMED_PRSNL_ID and ENCNTR_ID.
    """
    return (spark.table(f"{RAW_DB}.mill_clinical_event")
            .filter(F.col("VALID_UNTIL_DT_TM") > F.current_date())
            # NOTE(review): 188 is presumably the "active" record-status code;
            # confirm against the source system's code set.
            .filter(F.col("RECORD_STATUS_CD") == 188)
            .select(
                "PERSON_ID", "EVENT_CD", "EVENT_TITLE_TEXT",
                "CLINSIG_UPDT_DT_TM", "PERFORMED_PRSNL_ID", "ENCNTR_ID"
            ))

def _device_maps_for_source_field(source_field, source_value_alias):
    """
    Select device mappings for one SourceField from the Barts mapping table.

    Args:
        source_field: Value of the SourceField column to keep
            (e.g. "EVENT_CD").
        source_value_alias: Output name given to the SourceValue column.

    Returns:
        DataFrame with three columns: SourceValue under `source_value_alias`,
        OmopConceptId as device_concept_id, and EVENT_CD as map_event_cd.
    """
    return (spark.table(BARTS_MAPS_TABLE)
            .filter((F.col("OMOPField") == "device_concept_id") &
                    (F.col("SourceField") == source_field))
            .select(
                F.col("SourceValue").alias(source_value_alias),
                F.col("OmopConceptId").alias("device_concept_id"),
                F.col("EVENT_CD").alias("map_event_cd")
            ))


@dlt.table(
    name="device_code_maps",
    comment="Device mappings for code-based matches",
    temporary=True
)
def get_device_code_maps():
    """Gets device mappings whose SourceField is EVENT_CD (event code matches)."""
    return _device_maps_for_source_field("EVENT_CD", "event_cd_source")


@dlt.table(
    name="device_text_maps",
    comment="Device mappings for text-based matches",
    temporary=True
)
def get_device_text_maps():
    """Gets device mappings whose SourceField is EVENT_RESULT_TXT (event text matches)."""
    return _device_maps_for_source_field("EVENT_RESULT_TXT", "result_txt_source")

@dlt.table(
    name="base_device_exposure",
    comment="Combined device exposures from code and text matches",
    temporary=True
)
def create_base_device_exposure():
    """
    Assemble candidate device-exposure rows from clinical events.

    Clinical events are matched to device mappings twice: on EVENT_CD against
    device_code_maps and on EVENT_TITLE_TEXT against device_text_maps. In both
    cases a mapping row that carries its own event code only applies to events
    with that same EVENT_CD. The two match sets are stacked, restricted to
    persons in omop_person and concepts in valid_device_concepts, shaped into
    OMOP device-exposure columns and de-duplicated on a natural key.
    """
    events = dlt.read("valid_clinical_events").alias("ce")
    code_maps = dlt.read("device_code_maps").alias("cm")
    text_maps = dlt.read("device_text_maps").alias("tm")
    known_persons = dlt.read("omop_person").select(
        F.col("person_id").alias("valid_person_id")
    )
    device_concepts = dlt.read("valid_device_concepts")

    # Columns carried out of each match so the two sets can be stacked.
    carried_columns = [
        "PERSON_ID", "device_concept_id",
        "CLINSIG_UPDT_DT_TM", "PERFORMED_PRSNL_ID",
        "ENCNTR_ID", "EVENT_CD",
    ]

    def _event_code_guard(map_alias):
        """Mapping applies to any event, or only to events with its own EVENT_CD."""
        map_event_cd = F.col(f"{map_alias}.map_event_cd")
        return map_event_cd.isNull() | (map_event_cd == F.col("ce.EVENT_CD"))

    matched_on_code = (
        events
        .join(
            code_maps,
            (F.col("ce.EVENT_CD") == F.col("cm.event_cd_source")) & _event_code_guard("cm"),
            "inner",
        )
        .select(*carried_columns)
    )

    matched_on_text = (
        events
        .join(
            text_maps,
            (F.col("ce.EVENT_TITLE_TEXT") == F.col("tm.result_txt_source")) & _event_code_guard("tm"),
            "inner",
        )
        .select(*carried_columns)
    )

    stacked = matched_on_code.union(matched_on_text)

    restricted = (
        stacked
        .join(known_persons, F.col("PERSON_ID") == F.col("valid_person_id"), "inner")
        .join(device_concepts, F.col("device_concept_id") == F.col("concept_id"), "inner")
    )

    # A personnel id of 0 is replaced by NULL before it becomes provider_id.
    performer = F.col("PERFORMED_PRSNL_ID")
    provider_or_null = F.when(performer == 0, None).otherwise(performer)

    shaped = restricted.select(
        F.col("PERSON_ID").cast("bigint").alias("person_id"),
        F.col("device_concept_id").cast("integer"),
        F.col("CLINSIG_UPDT_DT_TM").cast("timestamp").alias("device_exposure_start_datetime"),
        F.lit(None).cast("timestamp").alias("device_exposure_end_datetime"),
        F.lit(32817).cast("integer").alias("device_type_concept_id"),
        F.lit(None).cast("string").alias("unique_device_id"),
        F.lit(None).cast("string").alias("production_id"),
        F.lit(1).cast("integer").alias("quantity"),
        provider_or_null.cast("bigint").alias("provider_id"),
        F.col("ENCNTR_ID").cast("bigint").alias("visit_occurrence_id"),
        F.lit(None).cast("bigint").alias("visit_detail_id"),
        F.col("EVENT_CD").cast("string").alias("device_source_value"),
        F.lit(0).cast("integer").alias("device_source_concept_id"),
        F.lit(None).cast("integer").alias("unit_concept_id"),
        F.lit(None).cast("string").alias("unit_source_value"),
        F.lit(0).cast("integer").alias("unit_source_concept_id"),
    )

    # Natural-key de-duplication: an event matched by both code and text
    # (or by several mapping rows) collapses to one row.
    natural_key = [
        "person_id", "device_concept_id",
        "device_exposure_start_datetime",
        "visit_occurrence_id", "provider_id",
    ]
    return shaped.dropDuplicates(natural_key)

@dlt.table(
    name="omop_device_exposure",
    comment="OMOP CDM Device Exposure table - Contains records about exposure to a foreign physical object or instrument used for diagnostic or therapeutic purposes",
    schema=device_schema,
    table_properties={"quality": "gold"}
)
def create_omop_device_exposure():
    """
    Creates the final device exposure table.

    Steps:
      1. Drop base rows that have no start timestamp.
      2. Assign device_exposure_id with row_number() over one global ordering.
      3. Left-join omop_visit_occurrence and clamp the start timestamp into
         the visit's start/end window.
      4. Derive the start/end date columns from the timestamps.
      5. Pass the rows through _validate_provider and project to the columns
         of device_schema, in schema order.

    Returns:
        DataFrame with exactly the device_schema columns, repartitioned by
        person_id.
    """
    # device_exposure_start_date is declared NOT NULL in device_schema and is
    # derived from this timestamp. Clamping below can only substitute a
    # non-null visit boundary when the timestamp is already non-null, so a
    # null here would stay null in the output date. The procedure and
    # measurement tables drop such rows as well.
    base = (dlt.read("base_device_exposure")
            .filter(F.col("device_exposure_start_datetime").isNotNull()))

    # Un-partitioned window: one global ordering over all rows.
    w_global = Window.orderBy(
        "person_id",
        "device_exposure_start_datetime",
        "device_concept_id",
        F.coalesce(F.col("visit_occurrence_id"), F.lit(0)),
        F.coalesce(F.col("provider_id"),         F.lit(0))
    )

    with_id = (base
        .withColumn("device_exposure_id",
                    F.row_number().over(w_global).cast("bigint")))

    visits = F.broadcast(
        dlt.read("omop_visit_occurrence")
            .select("visit_occurrence_id",
                    "visit_start_datetime",
                    "visit_end_datetime")
    )

    # Comparisons against a NULL visit boundary evaluate to NULL, so rows
    # whose visit is missing or open-ended fall through to otherwise().
    corrected = (with_id
        .join(visits, "visit_occurrence_id", "left")
        .withColumn(
            "device_exposure_start_datetime",
            F.when(F.col("visit_occurrence_id").isNotNull() &
                   (F.col("device_exposure_start_datetime") < F.col("visit_start_datetime")),
                   F.col("visit_start_datetime"))
            .when(F.col("visit_occurrence_id").isNotNull() &
                  (F.col("device_exposure_start_datetime") > F.col("visit_end_datetime")),
                  F.col("visit_end_datetime"))
            .otherwise(F.col("device_exposure_start_datetime"))
        )
        .withColumn("device_exposure_start_date",
                    F.col("device_exposure_start_datetime").cast("date"))
        .withColumn("device_exposure_end_date",
                    F.col("device_exposure_end_datetime").cast("date"))
        .drop("visit_start_datetime", "visit_end_datetime")
    )

    # Project explicitly so the output carries exactly the declared columns
    # in the declared order, as the other gold tables do.
    return (_validate_provider(corrected)
            .select(device_schema.fieldNames())
            .repartition(200, "person_id"))

In [0]:


# Column layout (name, type, nullability, comment) used as the declared
# schema of the omop_measurement table.
measurement_schema = StructType([
    StructField(field_name, field_type, is_nullable, metadata={"comment": field_comment})
    for field_name, field_type, is_nullable, field_comment in [
        ("measurement_id", LongType(), False,
         "A unique identifier for each Measurement."),
        ("person_id", LongType(), False,
         "A foreign key identifier to the Person about whom the measurement was recorded."),
        ("measurement_concept_id", IntegerType(), False,
         "A foreign key to the standard measurement concept identifier in the Vocabulary."),
        ("measurement_date", DateType(), False,
         "The date of the measurement."),
        ("measurement_datetime", TimestampType(), True,
         "The date and time of the measurement."),
        ("measurement_time", StringType(), True,
         "The time of the measurement (in the event that MEASUREMENT_DATETIME is not well defined)."),
        ("measurement_type_concept_id", IntegerType(), False,
         "A foreign key to the predefined concept identifier in the Standardized Vocabularies reflecting the type of the measurement."),
        ("operator_concept_id", IntegerType(), True,
         "A foreign key to a standard concept identifier for the mathematical operator applied to the value."),
        ("value_as_number", FloatType(), True,
         "The measurement result stored as a number. This is applicable to measurements where the result is expressed as a numeric value."),
        ("value_as_concept_id", IntegerType(), True,
         "A foreign key to a standard concept identifier for a categorical result."),
        ("unit_concept_id", IntegerType(), True,
         "A foreign key to a standard concept identifier for the unit used in the measurement."),
        ("range_low", FloatType(), True,
         "The lower limit of the normal range of the measurement."),
        ("range_high", FloatType(), True,
         "The upper limit of the normal range of the measurement."),
        ("provider_id", LongType(), True,
         "A foreign key to the provider in the PROVIDER table who was responsible for taking the measurement."),
        ("visit_occurrence_id", LongType(), True,
         "A foreign key to the visit in the VISIT_OCCURRENCE table during which the measurement was taken."),
        ("visit_detail_id", LongType(), True,
         "A foreign key to the visit detail in the VISIT_DETAIL table during which the measurement was taken."),
        ("measurement_source_value", StringType(), True,
         "The measurement name as it appears in the source data."),
        ("measurement_source_concept_id", IntegerType(), True,
         "A foreign key to a concept that refers to the code used in the source."),
        ("unit_source_value", StringType(), True,
         "The source code for the unit as it appears in the source data."),
        ("unit_source_concept_id", IntegerType(), True,
         "A foreign key to a concept that refers to the unit code used in the source."),
        ("value_source_value", StringType(), True,
         "The source value associated with the structured value stored as numeric or concept."),
        ("measurement_event_id", LongType(), True,
         "A foreign key to the MEASUREMENT_EVENT table."),
        ("meas_event_field_concept_id", IntegerType(), True,
         "A foreign key to a concept that refers to the field in the MEASUREMENT_EVENT table."),
    ]
])

@dlt.table(
    name="valid_persons_ref",
    comment="Cached valid person references",
    temporary=True
)
def get_valid_persons():
    """Return the distinct person_id values of omop_person under the name valid_person_id."""
    person_ids = dlt.read("omop_person").select(
        F.col("person_id").alias("valid_person_id")
    )
    return person_ids.distinct()



@dlt.table(
    name="valid_providers_ref",
    comment="Cached valid provider references",
    temporary=True
)
def get_valid_providers():
    """Return the distinct provider_id values found in omop_provider."""
    provider_ids = dlt.read("omop_provider").select("provider_id")
    return provider_ids.distinct()


def get_measurement_date(performed_dt, clinsig_dt):
    """
    Choose which of two timestamp columns to use for a measurement.

    Args:
        performed_dt: Column holding the "performed" timestamp.
        clinsig_dt: Column holding the "clinically significant" timestamp.

    Returns:
        Column expression: clinsig_dt when both inputs are non-null and lie
        more than 6 months apart (per months_between); otherwise performed_dt,
        falling back to clinsig_dt when performed_dt is null.
    """
    both_present = performed_dt.isNotNull() & clinsig_dt.isNotNull()
    more_than_six_months_apart = F.abs(F.months_between(performed_dt, clinsig_dt)) > 6
    first_available = F.coalesce(performed_dt, clinsig_dt)
    return F.when(both_present & more_than_six_months_apart, clinsig_dt).otherwise(first_available)



@dlt.table(
    name="combined_source_measurements_mapped",
    comment="Combined measurements from all sources, mapped to standard concepts",
    temporary=True
)
def create_combined_measurements_mapped():
    """
    Stack Measurement-domain rows from the four bronze event mapping tables.

    numeric_events supplies numeric results, units and normal ranges;
    coded_events, text_events and nomen_events supply a value concept and a
    textual source value instead. Every source is restricted to persons in
    valid_persons_ref and mapped through _map_or_use_standard before being
    shaped to a common column set. The stacked result is de-duplicated on
    (measurement_event_id, measurement_concept_id).
    """
    valid_persons = dlt.read("valid_persons_ref")

    def _load_mapped(source_name):
        """Read map_<source_name>, keep usable Measurement rows for known persons, and map them."""
        manual_concept = F.col("OMOP_MANUAL_CONCEPT")
        in_scope = (
            spark.table(f"{BRONZE_DB}.map_{source_name}")
            .filter(F.col("OMOP_MANUAL_CONCEPT_DOMAIN") == "Measurement")
            .filter(manual_concept.isNotNull() & (manual_concept != 0))
            .join(F.broadcast(valid_persons), F.col("person_id") == F.col("valid_person_id"))
        )
        return _map_or_use_standard(
            in_scope,
            source_omop_concept_id_col="OMOP_MANUAL_CONCEPT",
            target_domain_id="Measurement"
        )

    def _leading_columns():
        """Columns that are built the same way for every source."""
        performer = F.col("PERFORMED_PRSNL_ID")
        return [
            F.col("person_id").cast("bigint"),
            F.col("standard_concept_id").cast("integer").alias("measurement_concept_id"),
            get_measurement_date(F.col("PERFORMED_DT_TM"), F.col("CLINSIG_UPDT_DT_TM"))
                .cast("timestamp").alias("measurement_datetime"),
            F.lit(32817).cast("integer").alias("measurement_type_concept_id"),
            # A personnel id of 0 is replaced by NULL.
            F.when(performer == 0, None).otherwise(performer)
                .cast("bigint").alias("provider_id"),
            F.col("ENCNTR_ID").cast("bigint").alias("visit_occurrence_id"),
            F.col("EVENT_CD_DISPLAY").alias("measurement_source_value"),
            F.col("OMOP_MANUAL_CONCEPT").cast("integer").alias("measurement_source_concept_id"),
            F.col("EVENT_ID").cast("bigint").alias("measurement_event_id"),
            F.lit(None).cast("integer").alias("operator_concept_id"),
        ]

    # Numeric events: carry the number, its unit and the normal range.
    all_measurements = _load_mapped("numeric_events").select(
        *_leading_columns(),
        F.col("NUMERIC_RESULT").cast("float").alias("value_as_number"),
        F.lit(None).cast("integer").alias("value_as_concept_id"),
        F.col("OMOP_MANUAL_UNITS").cast("integer").alias("unit_concept_id"),
        F.col("NORMAL_LOW").cast("float").alias("range_low"),
        F.col("NORMAL_HIGH").cast("float").alias("range_high"),
        F.col("UNIT_OF_MEASURE_DISPLAY").alias("unit_source_value"),
        F.col("OMOP_MANUAL_UNITS").cast("integer").alias("unit_source_concept_id"),
        F.col("NUMERIC_RESULT").cast("string").alias("value_source_value"),
        F.col("mapping_count"),
    )

    # Non-numeric events: carry a value concept and a textual source value.
    for table_name in ["coded_events", "text_events", "nomen_events"]:
        mapped_events = _load_mapped(table_name)

        if table_name == "text_events" and "TEXT_RESULT" in mapped_events.columns:
            source_text = F.col("TEXT_RESULT").cast("string")
        else:
            source_text = F.coalesce(
                F.col("EVENT_TITLE_TEXT"),
                F.col("EVENT_CD_DISPLAY")
            ).cast("string")

        shaped_events = mapped_events.select(
            *_leading_columns(),
            F.lit(None).cast("float").alias("value_as_number"),
            F.col("OMOP_MANUAL_VALUE_CONCEPT").cast("integer").alias("value_as_concept_id"),
            F.lit(None).cast("integer").alias("unit_concept_id"),
            F.lit(None).cast("float").alias("range_low"),
            F.lit(None).cast("float").alias("range_high"),
            F.lit(None).cast("string").alias("unit_source_value"),
            F.lit(0).cast("integer").alias("unit_source_concept_id"),
            source_text.alias("value_source_value"),
            F.col("mapping_count"),
        )
        all_measurements = all_measurements.unionByName(shaped_events, allowMissingColumns=True)

    deduplicated = all_measurements.dropDuplicates(["measurement_event_id", "measurement_concept_id"])
    return deduplicated.repartition(200, "person_id", "measurement_datetime")


@dlt.table(
    name="omop_measurement",
    comment="OMOP CDM Measurement table",
    schema=measurement_schema,
    table_properties={"quality": "gold"}
)
def create_omop_measurement():
    """
    Creates the final OMOP measurement table.

    Steps:
      1. De-duplicate combined_source_measurements_mapped on a natural key.
      2. Left-join omop_visit_occurrence and clamp measurement_datetime into
         the visit's start/end window; derive measurement_date and
         measurement_time and drop rows whose date is NULL.
      3. Set provider_id to NULL when it is not present in valid_providers_ref.
      4. Set value_as_number, range_low and range_high to NULL where
         mapping_count > 1.
      5. Assign measurement_id with row_number() over one global ordering.
      6. Add any measurement_schema field not yet present as a typed NULL and
         project to the schema's columns.

    Returns:
        DataFrame with exactly the measurement_schema columns, repartitioned
        by person_id.
    """
    measurements_mapped = dlt.read("combined_source_measurements_mapped")

    measurements_deduped = measurements_mapped.dropDuplicates([
        "person_id", "measurement_concept_id", "measurement_datetime",
        "visit_occurrence_id", "provider_id",
        "value_as_number", "value_as_concept_id"
    ])

    # Get reference data
    valid_visits = F.broadcast(dlt.read("omop_visit_occurrence").select(
        "visit_occurrence_id", "visit_start_datetime", "visit_end_datetime"
    ))
    valid_providers = dlt.read("valid_providers_ref")

    # Branch order matters: the first matching branch wins.
    #  - no visit id, or no visit end (including visits missing from the
    #    visit table after the left join): keep the timestamp as-is
    #  - later than visit end: pull back to visit end
    #  - earlier than visit start: push forward to visit start
    measurements_with_visits = (measurements_deduped
        .join(valid_visits, "visit_occurrence_id", "left")
        .withColumn(
            "adj_measurement_datetime",
            F.when(F.col("visit_occurrence_id").isNull(), F.col("measurement_datetime"))
            .when(F.col("measurement_datetime").isNull(), F.lit(None))
            .when(F.col("visit_end_datetime").isNull(), F.col("measurement_datetime"))
            .when(F.col("measurement_datetime") > F.col("visit_end_datetime"),
                 F.col("visit_end_datetime"))
            .when(F.col("visit_start_datetime").isNull(), F.col("measurement_datetime"))
            .when(F.col("measurement_datetime") < F.col("visit_start_datetime"),
                 F.col("visit_start_datetime"))
            .otherwise(F.col("measurement_datetime")))
        .withColumn("measurement_date",
                   F.col("adj_measurement_datetime").cast("date"))
        .withColumn("measurement_time", F.date_format(F.col("adj_measurement_datetime"), "HH:mm:ss"))
        .drop("visit_start_datetime", "visit_end_datetime", "measurement_datetime")
        .withColumnRenamed("adj_measurement_datetime", "measurement_datetime")
        .filter(F.col("measurement_date").isNotNull())
        )

    # Keep provider_id only when it matches a row of valid_providers_ref.
    measurements_validated_provider = (measurements_with_visits.alias("m")
        .join(F.broadcast(valid_providers).alias("vp"),
              F.col("m.provider_id") == F.col("vp.provider_id"),
              "left")
        .select(
            F.col("m.*"),
            F.when(F.col("vp.provider_id").isNotNull(), F.col("m.provider_id"))
            .otherwise(F.lit(None)).alias("validated_provider_id")
        )
        .drop("provider_id")
        .withColumnRenamed("validated_provider_id", "provider_id")
    )

    # Nullify numeric value fields if mapping_count > 1
    has_multiple_mappings = F.col("mapping_count") > 1
    measurements_adjusted = measurements_validated_provider
    for numeric_field in ("value_as_number", "range_low", "range_high"):
        measurements_adjusted = measurements_adjusted.withColumn(
            numeric_field,
            F.when(has_multiple_mappings, F.lit(None)).otherwise(F.col(numeric_field))
        )

    # Un-partitioned window: one global ordering over all rows.
    window_spec = Window.orderBy(
        "person_id",
        "measurement_datetime",
        "measurement_concept_id",
        F.coalesce(F.col("visit_occurrence_id"), F.lit(0)),
        F.coalesce(F.col("provider_id"), F.lit(0)),
        F.coalesce(F.col("value_as_concept_id"), F.lit(0)),
        F.coalesce(F.col("value_as_number"), F.lit(0.0)),
        F.col("measurement_event_id")
    )

    # Add measurement_id
    measurements_with_id = (measurements_adjusted
        .withColumn("measurement_id",
                   F.row_number().over(window_spec).cast("bigint")))

    # Add any schema field the rows do not yet carry as a typed NULL, so the
    # final select below cannot fail on a missing column.
    present_columns = set(measurements_with_id.columns)
    for field in measurement_schema:
        if field.name not in present_columns:
            measurements_with_id = measurements_with_id.withColumn(
                field.name, F.lit(None).cast(field.dataType)
            )

    # Selecting the schema's field names also drops mapping_count.
    result = measurements_with_id.select(measurement_schema.fieldNames())

    # row_number() over a single window already yields a distinct id per row,
    # so de-duplicating on measurement_id would only add a shuffle.
    return result.repartition(200, "person_id")

In [0]:


# Column layout (name, type, nullability, comment) for OMOP observation rows.
observation_schema = StructType([
    StructField(field_name, field_type, is_nullable, metadata={"comment": field_comment})
    for field_name, field_type, is_nullable, field_comment in [
        ("observation_id", LongType(), False,
         "A unique identifier for each observation."),
        ("person_id", LongType(), False,
         "A foreign key identifier to the Person about whom the observation was recorded."),
        ("observation_concept_id", IntegerType(), False,
         "A foreign key to the standard observation concept identifier in the Vocabulary."),
        ("observation_date", DateType(), False,
         "The date of the observation."),
        ("observation_datetime", TimestampType(), True,
         "The date and time of the observation."),
        ("observation_type_concept_id", IntegerType(), False,
         "A foreign key to the predefined concept identifier in the Standardized Vocabularies reflecting the type of the observation."),
        ("value_as_number", FloatType(), True,
         "The observation result stored as a number. This is applicable to observations where the result is expressed as a numeric value."),
        ("value_as_string", StringType(), True,
         "The observation result stored as a string."),
        ("value_as_concept_id", IntegerType(), True,
         "A foreign key to an observation result stored as a Concept ID."),
        ("qualifier_concept_id", IntegerType(), True,
         "A foreign key to a standard concept identifier for a qualifier."),
        ("unit_concept_id", IntegerType(), True,
         "A foreign key to a standard concept identifier for the unit."),
        ("provider_id", LongType(), True,
         "A foreign key to the provider in the PROVIDER table who was responsible for making the observation."),
        ("visit_occurrence_id", LongType(), True,
         "A foreign key to the visit in the VISIT_OCCURRENCE table during which the observation was recorded."),
        ("visit_detail_id", LongType(), True,
         "A foreign key to the visit detail in the VISIT_DETAIL table during which the observation was recorded."),
        ("observation_source_value", StringType(), True,
         "The observation code as it appears in the source data."),
        ("observation_source_concept_id", IntegerType(), True,
         "A foreign key to a concept that refers to the code used in the source."),
        ("unit_source_value", StringType(), True,
         "The source code for the unit as it appears in the source data."),
        ("qualifier_source_value", StringType(), True,
         "The source value associated with a qualifier to characterize the observation."),
        ("value_source_value", StringType(), True,
         "The source value associated with the structured value stored as numeric, string, or concept."),
        ("observation_event_id", LongType(), True,
         "A foreign key to the event that caused this observation to be made."),
        ("obs_event_field_concept_id", IntegerType(), True,
         "A foreign key to the predefined concept identifier in the Standardized Vocabularies reflecting how the event field is used."),
    ]
])

@dlt.table(
    name="event_observations_mapped",
    comment="Observations from clinical events, mapped to standard concepts",
    temporary=True
)
def create_event_observations_mapped():
    """
    Builds observation rows from the coded, text and nomenclature event maps.

    For each of the three bronze tables (map_coded_events, map_text_events,
    map_nomen_events) the rows whose OMOP_MANUAL_CONCEPT_DOMAIN is
    "Observation" and whose OMOP_MANUAL_CONCEPT is non-null and non-zero are
    kept, restricted to persons present in omop_person, passed through
    _map_or_use_standard and projected onto the observation column layout.

    Returns:
        DataFrame: the union (by column name) of the three projections,
        de-duplicated on (observation_event_id, observation_concept_id).
        The `mapping_count` column of the mapped frame is carried along.
    """
    # Everything below is identical for every source table, so it is built
    # once instead of once per loop iteration.
    valid_persons = F.broadcast(dlt.read("omop_person").select("person_id").distinct())

    # Use CLINSIG_UPDT_DT_TM when both timestamps exist and differ by more
    # than six months; otherwise the first non-null of PERFORMED / CLINSIG.
    date_selection = F.when(
        (F.col("PERFORMED_DT_TM").isNotNull()) &
        (F.col("CLINSIG_UPDT_DT_TM").isNotNull()) &
        (F.abs(F.months_between(F.col("PERFORMED_DT_TM"), F.col("CLINSIG_UPDT_DT_TM"))) > 6),
        F.col("CLINSIG_UPDT_DT_TM")
    ).otherwise(F.coalesce(F.col("PERFORMED_DT_TM"), F.col("CLINSIG_UPDT_DT_TM")))

    default_value_string = F.coalesce(
        F.col("EVENT_TITLE_TEXT"),
        F.col("EVENT_CD_DISPLAY")
    ).cast("string")

    all_event_observations = None

    for table_name in ["coded_events", "text_events", "nomen_events"]:
        print(f"Processing observation source: {table_name}")
        events_df = spark.table(f"{BRONZE_DB}.map_{table_name}")

        df_filtered = (
            events_df
            .filter(F.col("OMOP_MANUAL_CONCEPT_DOMAIN") == "Observation")
            .filter(F.col("OMOP_MANUAL_CONCEPT").isNotNull() & (F.col("OMOP_MANUAL_CONCEPT") != 0))
            .join(valid_persons, events_df["person_id"] == valid_persons["person_id"], "inner")
            .drop(valid_persons["person_id"])  # Avoid ambiguous person_id
        )

        mapped_df = _map_or_use_standard(
            df_filtered,
            source_omop_concept_id_col="OMOP_MANUAL_CONCEPT",
            target_domain_id="Observation"
        )

        # Text events carry their value in TEXT_RESULT when that column exists.
        if table_name == "text_events" and "TEXT_RESULT" in mapped_df.columns:
            value_string_col = F.col("TEXT_RESULT").cast("string")
        else:
            value_string_col = default_value_string

        # cast() is applied before alias() so the output name is set
        # explicitly rather than re-derived from underneath the cast.
        processed_df = mapped_df.select(
            F.col("person_id").cast("bigint"),
            F.col("standard_concept_id").cast("integer").alias("observation_concept_id"),
            date_selection.cast("timestamp").alias("observation_datetime"),
            F.lit(32817).cast("integer").alias("observation_type_concept_id"),
            F.lit(None).cast("float").alias("value_as_number"),
            value_string_col.alias("value_as_string"),
            F.col("OMOP_MANUAL_VALUE_CONCEPT").cast("integer").alias("value_as_concept_id"),
            F.lit(None).cast("integer").alias("qualifier_concept_id"),
            F.lit(None).cast("integer").alias("unit_concept_id"),
            # A personnel id of 0 is stored as NULL.
            F.when(F.col("PERFORMED_PRSNL_ID") == 0, None)
                .otherwise(F.col("PERFORMED_PRSNL_ID")).cast("bigint").alias("provider_id"),
            F.col("ENCNTR_ID").cast("bigint").alias("visit_occurrence_id"),
            F.lit(None).cast("bigint").alias("visit_detail_id"),
            F.col("EVENT_CD_DISPLAY").alias("observation_source_value"),
            F.col("OMOP_MANUAL_CONCEPT").cast("integer").alias("observation_source_concept_id"),
            F.lit(None).cast("string").alias("unit_source_value"),
            F.lit(None).cast("string").alias("qualifier_source_value"),
            value_string_col.alias("value_source_value"),
            F.col("EVENT_ID").cast("bigint").alias("observation_event_id"),
            F.lit(None).cast("integer").alias("obs_event_field_concept_id"),
            F.col("mapping_count")
        )

        if all_event_observations is None:
            all_event_observations = processed_df
        else:
            all_event_observations = all_event_observations.unionByName(
                processed_df, allowMissingColumns=True
            )

    return all_event_observations.dropDuplicates(["observation_event_id", "observation_concept_id"])


@dlt.table(
    name="problem_observations_mapped", 
    comment="Observations from problems, mapped to standard concepts",
    temporary=True
)
def create_problem_observations_mapped(): 
    """
    Builds observation rows from the bronze problem map.

    Keeps problems whose OMOP_CONCEPT_DOMAIN is "Observation", that have a
    non-null, non-zero OMOP_CONCEPT_ID and a non-null CALC_ENCNTR, and whose
    person exists in omop_person. The rows are passed through
    _map_or_use_standard and projected onto the observation column layout.
    """

    def _null(dtype, name):
        # Typed NULL placeholder for a column this source cannot populate.
        return F.lit(None).cast(dtype).alias(name)

    persons_in_cdm = F.broadcast(dlt.read("omop_person").select("person_id").distinct())

    is_observation_problem = (
        (F.col("OMOP_CONCEPT_DOMAIN") == "Observation")
        & (F.col("OMOP_CONCEPT_ID").isNotNull())
        & (F.col("OMOP_CONCEPT_ID") != 0)
        & (F.col("CALC_ENCNTR").isNotNull())
    )

    eligible_problems = (
        spark.table(f"{BRONZE_DB}.map_problem").alias("p")
        .filter(is_observation_problem)
        .join(persons_in_cdm, F.col("p.person_id") == persons_in_cdm["person_id"], "inner")
        .drop(persons_in_cdm["person_id"])  # keep a single person_id column
    )

    mapped_problem = _map_or_use_standard(
        eligible_problems,
        source_omop_concept_id_col="OMOP_CONCEPT_ID",
        target_domain_id="Observation"
    )

    # Personnel ids 0 and 1 are stored as NULL.
    provider = F.when(F.col("ACTIVE_STATUS_PRSNL_ID").isin([0, 1]), None) \
        .otherwise(F.col("ACTIVE_STATUS_PRSNL_ID"))

    observation_columns = [
        F.col("person_id").cast("bigint"),
        F.col("standard_concept_id").cast("integer").alias("observation_concept_id"),
        F.col("CALC_DT_TM").cast("timestamp").alias("observation_datetime"),
        F.lit(32817).cast("integer").alias("observation_type_concept_id"),
        _null("float", "value_as_number"),
        F.col("SOURCE_STRING").alias("value_as_string"),
        _null("integer", "value_as_concept_id"),
        _null("integer", "qualifier_concept_id"),
        _null("integer", "unit_concept_id"),
        provider.cast("bigint").alias("provider_id"),
        F.col("CALC_ENCNTR").cast("bigint").alias("visit_occurrence_id"),
        _null("bigint", "visit_detail_id"),
        F.col("SOURCE_STRING").alias("observation_source_value"),
        F.col("OMOP_CONCEPT_ID").cast("integer").alias("observation_source_concept_id"),
        _null("string", "unit_source_value"),
        _null("string", "qualifier_source_value"),
        F.col("SOURCE_STRING").alias("value_source_value"),
        _null("bigint", "observation_event_id"),
        _null("integer", "obs_event_field_concept_id"),
        F.col("mapping_count"),
    ]

    return mapped_problem.select(*observation_columns)


@dlt.table(
    name="person_address_history_with_imd",
    comment="Generates historical IMD quintiles for persons based on their address start dates (BEG_EFFECTIVE_DT_TM). Uses full postcode match for LSOA.",
    temporary=True
)
def create_person_address_history_with_imd():
    """
    Links person addresses to an IMD quintile dated at the address start.

    Steps:
      1. Read active PERSON address rows with a non-blank ZIPCODE and a
         non-null BEG_EFFECTIVE_DT_TM from the raw mill_address table.
      2. Match the whitespace-stripped ZIPCODE to the whitespace-stripped
         pcd7 of the postcode lookup to obtain lsoa21cd.
      3. Look up the 2019 IMD decile for that LSOA and fold deciles into
         quintiles (1-2 -> 1, 3-4 -> 2, 5-6 -> 3, 7-8 -> 4, 9-10 -> 5).
      4. Drop rows without a quintile and de-duplicate.

    Returns:
        DataFrame with columns person_id (from PARENT_ENTITY_ID),
        observation_datetime (from BEG_EFFECTIVE_DT_TM) and value_as_number
        (the quintile).
    """
    address_table = f"{RAW_DB}.mill_address"
    max_adc_updt_val = get_max_timestamp(address_table)

    # The ADC_UPDT predicate is only applied when the helper returned a plain
    # value; None or a Column disables it.
    # NOTE(review): the watermark is read from the same table that is filtered
    # below. If get_max_timestamp returns MAX(ADC_UPDT) of that table, the
    # strict ">" would exclude every row - confirm the helper's semantics.
    if max_adc_updt_val is not None and not isinstance(max_adc_updt_val, type(F.lit(None))):
        updated_since_watermark = F.col("ADC_UPDT") > max_adc_updt_val
    else:
        updated_since_watermark = F.lit(True)

    base_addresses = (
        spark.table(address_table)
        .filter(
            (F.col("PARENT_ENTITY_NAME") == "PERSON") &
            (F.col("ACTIVE_IND") == 1) &  # Address record itself is active
            updated_since_watermark &
            (F.col("ZIPCODE").isNotNull() & (F.trim(F.col("ZIPCODE")) != "")) &
            (F.col("BEG_EFFECTIVE_DT_TM").isNotNull())
        )
        .select(
            F.col("PARENT_ENTITY_ID"),
            F.col("ZIPCODE"),
            F.col("BEG_EFFECTIVE_DT_TM"),
            F.col("ADC_UPDT")
        )
        # NOTE(review): the comparison is case-sensitive; confirm ZIPCODE and
        # pcd7 share the same letter case.
        .withColumn("clean_zipcode", F.regexp_replace(F.col("ZIPCODE"), r'\s+', ''))
    )

    postcode_maps = (
        spark.table("3_lookup.imd.postcode_maps")
        .select(
            F.col("pcd7"),  # Full 7-character postcode
            F.col("lsoa21cd")
        )
        .withColumn("clean_pcd7", F.regexp_replace(F.col("pcd7"), r'\s+', ''))
        .drop_duplicates(["clean_pcd7"])  # Ensure one LSOA per unique full postcode
    )

    address_with_lsoa = base_addresses.join(
        postcode_maps,
        base_addresses.clean_zipcode == postcode_maps.clean_pcd7,
        "inner"
    ).select(
        base_addresses.PARENT_ENTITY_ID,
        base_addresses.BEG_EFFECTIVE_DT_TM,
        postcode_maps.lsoa21cd.alias("final_lsoa21cd")
    )

    imd_table = (
        spark.table("3_lookup.imd.imd_2019")
        .filter(
            (F.col("DateCode") == 2019) &
            (F.regexp_replace(F.col("Measurement"), " ", "") == "Decile") &
            (F.col("Indices_of_Deprivation") == "a. Index of Multiple Deprivation (IMD)")
        )
        .select(
            F.col("FeatureCode").alias("lsoa_code_imd"),  # LSOA code in IMD table
            F.col("Value").alias("imd_decile_value")      # IMD Decile
        )
        .drop_duplicates(["lsoa_code_imd"])  # Ensure one decile value per LSOA
    )

    address_with_imd_decile = address_with_lsoa.join(
        F.broadcast(imd_table),  # Broadcast smaller IMD table
        address_with_lsoa.final_lsoa21cd == imd_table.lsoa_code_imd,
        "left"  # Rows without a decile are removed by the quintile filter below
    )

    # A NULL or out-of-range decile matches no branch and falls through to the
    # typed NULL, so no separate isNull() branch is needed.
    decile = F.col("imd_decile_value")
    quintile = (
        F.when(decile.isin([1, 2]), 1)
        .when(decile.isin([3, 4]), 2)
        .when(decile.isin([5, 6]), 3)
        .when(decile.isin([7, 8]), 4)
        .when(decile.isin([9, 10]), 5)
        .otherwise(F.lit(None).cast("integer"))
    )

    person_imd_history_df = (
        address_with_imd_decile
        .withColumn("IMD_Quintile_Calc", quintile)  # Intermediate column name
        .filter(F.col("IMD_Quintile_Calc").isNotNull())
        .select(
            F.col("PARENT_ENTITY_ID").alias("person_id"),
            F.col("BEG_EFFECTIVE_DT_TM").alias("observation_datetime"),
            F.col("IMD_Quintile_Calc").alias("value_as_number")
        )
        .drop_duplicates(["person_id", "observation_datetime", "value_as_number"])
    )
    return person_imd_history_df

@dlt.table(
    name="imd_quintile_observations_historical",
    comment="Transforms historical IMD quintile data into OMOP Observation format.",
    temporary=True
)
def create_imd_quintile_observations_historical():
    """
    Projects the person/IMD-quintile history onto the observation layout.

    Rows of person_address_history_with_imd are kept only when their
    person_id exists in omop_person. Concept, type and unit ids are fixed
    literals; columns this source cannot populate are typed NULLs and
    mapping_count is the constant 1.
    """
    OBSERVATION_CONCEPT_ID = 35812882
    TYPE_CONCEPT_ID = 32817
    UNIT_CONCEPT_ID = 37524288

    def _null(dtype, name):
        # Typed NULL placeholder column.
        return F.lit(None).cast(dtype).alias(name)

    imd_history = dlt.read("person_address_history_with_imd")
    known_persons = F.broadcast(dlt.read("omop_person").select("person_id").distinct())

    # Inner join enforces that every row refers to an existing person; the
    # right-hand person_id is dropped so the name stays unambiguous.
    imd_for_known_persons = (
        imd_history
        .join(known_persons, imd_history["person_id"] == known_persons["person_id"], "inner")
        .drop(known_persons["person_id"])
    )

    quintile = F.col("value_as_number")

    observation_columns = [
        F.col("person_id").cast("long"),
        F.lit(OBSERVATION_CONCEPT_ID).cast("integer").alias("observation_concept_id"),
        F.col("observation_datetime").cast("timestamp"),   # From BEG_EFFECTIVE_DT_TM
        F.lit(TYPE_CONCEPT_ID).cast("integer").alias("observation_type_concept_id"),
        quintile.cast("float"),                            # IMD Quintile
        _null("string", "value_as_string"),
        _null("integer", "value_as_concept_id"),
        _null("integer", "qualifier_concept_id"),
        F.lit(UNIT_CONCEPT_ID).cast("integer").alias("unit_concept_id"),
        _null("long", "provider_id"),
        _null("long", "visit_occurrence_id"),              # Not tied to a specific visit
        _null("long", "visit_detail_id"),
        F.lit("IMD Quintile from address history").cast("string").alias("observation_source_value"),
        F.lit(OBSERVATION_CONCEPT_ID).cast("integer").alias("observation_source_concept_id"),
        F.lit("quintile").cast("string").alias("unit_source_value"),
        _null("string", "qualifier_source_value"),
        quintile.cast("string").alias("value_source_value"),  # Quintile kept as text
        _null("long", "observation_event_id"),
        _null("integer", "obs_event_field_concept_id"),
        F.lit(1).alias("mapping_count"),                   # Direct assignment, so always 1
    ]

    return imd_for_known_persons.select(*observation_columns)

def truncate(col_name: str, max_len: int = 48):
    """
    Returns a Column holding the first `max_len` characters of `col_name`.

    Args:
        col_name: Name of the string column to shorten.
        max_len: Maximum number of characters to keep (default 48).

    Returns:
        Column: the leading substring; NULL input stays NULL because
        substring of NULL is NULL.
    """
    # Built with the column API rather than an interpolated SQL string, so
    # column names that would need quoting in SQL are handled correctly.
    return F.substring(F.col(col_name), 1, max_len)


@dlt.table(
    name="omop_observation",
    comment="OMOP CDM Observation table - Contains clinical facts about a Person obtained in the context of examination, questioning or a procedure",
    schema=observation_schema,
    table_properties={"quality": "gold"}
)
def create_omop_observation():
    """
    Assembles the final observation table.

    Pipeline:
      1. Union the event, problem and IMD-quintile observation sets by name.
      2. De-duplicate on the identifying columns.
      3. Left-join visits and clamp observation_datetime into the
         [visit_start_datetime, visit_end_datetime] window; derive
         observation_date and drop rows where it is NULL.
      4. Replace provider_id with NULL when it is not found in omop_provider.
      5. Clear value_as_number when mapping_count > 1 and truncate four
         free-text columns to 48 characters.
      6. Assign observation_id with a global row_number and project onto
         observation_schema, adding typed NULLs for any missing field.

    Returns:
        DataFrame in observation_schema column order, repartitioned into 200
        partitions by person_id.
    """

    combined_observations = (
        dlt.read("event_observations_mapped")
        .unionByName(dlt.read("problem_observations_mapped"), allowMissingColumns=True)
        .unionByName(dlt.read("imd_quintile_observations_historical"), allowMissingColumns=True)
    )

    # Deduplicate
    dedup_cols = [
        "person_id", "observation_concept_id", "observation_datetime",
        "visit_occurrence_id", "provider_id", "value_as_concept_id",
        "value_as_string", "observation_source_concept_id"
    ]
    combined_deduplicated = combined_observations.dropDuplicates(dedup_cols)

    # Best-effort lookups: when the upstream table cannot be read, an empty
    # frame with the same columns is substituted.
    try:
        valid_visits = F.broadcast(
            dlt.read("omop_visit_occurrence").select(
                "visit_occurrence_id",
                "visit_start_datetime",
                "visit_end_datetime"
            )
        )
    except Exception:
        # Dummies when running outside the full pipeline (unit-tests, etc.)
        visit_schema = StructType([
            StructField("visit_occurrence_id", LongType()),
            StructField("visit_start_datetime", TimestampType()),
            StructField("visit_end_datetime", TimestampType())
        ])
        valid_visits = F.broadcast(spark.createDataFrame([], visit_schema))

    try:
        valid_providers = F.broadcast(
            dlt.read("omop_provider").select("provider_id").distinct()
        )
    except Exception:
        provider_schema = StructType([StructField("provider_id", LongType())])
        valid_providers = F.broadcast(spark.createDataFrame([], provider_schema))

    # Clamp the observation timestamp into the visit window. A comparison
    # against NULL is never true, so a missing visit, a missing bound or a
    # missing observation timestamp falls through to the unchanged value.
    # Each bound is checked independently: a visit without an end time still
    # has observations before its start pulled up to visit_start_datetime.
    obs_dt = F.col("observation_datetime")
    visit_start = F.col("visit_start_datetime")
    visit_end = F.col("visit_end_datetime")
    clamped_obs_dt = (
        F.when(obs_dt > visit_end, visit_end)
        .when(obs_dt < visit_start, visit_start)
        .otherwise(obs_dt)
    )

    # NOTE(review): visit_occurrence_id itself is kept even when no matching
    # visit row exists; only provider_id is nulled on a failed lookup.
    observations_with_visits = (
        combined_deduplicated
        .join(valid_visits, "visit_occurrence_id", "left")
        .withColumn("adj_observation_datetime", clamped_obs_dt)
        .withColumn("observation_date", F.col("adj_observation_datetime").cast("date"))
        .drop("visit_start_datetime", "visit_end_datetime", "observation_datetime")
        .withColumnRenamed("adj_observation_datetime", "observation_datetime")
        .filter(F.col("observation_date").isNotNull())
    )

    observations_validated = (
        observations_with_visits.alias("o")
        .join(valid_providers.alias("vp"),
              F.col("o.provider_id") == F.col("vp.provider_id"), "left")
        .select(
            F.col("o.*"),
            F.when(F.col("vp.provider_id").isNotNull(), F.col("o.provider_id"))
             .otherwise(F.lit(None)).alias("validated_provider_id")
        )
        .drop("provider_id")
        .withColumnRenamed("validated_provider_id", "provider_id")
    )

    observations_adjusted = (
        observations_validated
        # Clear ambiguous value_as_number
        .withColumn(
            "value_as_number",
            F.when(F.col("mapping_count") > 1, F.lit(None))
             .otherwise(F.col("value_as_number"))
        )
        # Truncate the four free-text columns to ≤ 48 chars
        .withColumn("value_as_string",          truncate("value_as_string"))
        .withColumn("observation_source_value", truncate("observation_source_value"))
        .withColumn("qualifier_source_value",   truncate("qualifier_source_value"))
        .withColumn("value_source_value",       truncate("value_source_value"))
    )

    # Global ordering (no partitionBy) so that row_number yields one
    # table-wide sequence; NULLs are coalesced so they sort consistently.
    window_spec = Window.orderBy(
        "person_id", "observation_datetime", "observation_concept_id",
        F.coalesce(F.col("visit_occurrence_id"), F.lit(0)),
        F.coalesce(F.col("provider_id"), F.lit(0)),
        F.coalesce(F.col("value_as_concept_id"), F.lit(0)),
        F.coalesce(F.col("value_as_string"), F.lit("")),
        F.coalesce(F.col("observation_event_id"), F.lit(0))
    )

    final_df = observations_adjusted \
        .withColumn("observation_id", F.row_number().over(window_spec).cast("bigint"))

    # Add a typed NULL for every schema field the pipeline did not produce.
    for field in observation_schema.fields:
        if field.name not in final_df.columns:
            final_df = final_df.withColumn(field.name, F.lit(None).cast(field.dataType))

    # row_number over a single window is already unique, so no further
    # de-duplication on observation_id is required.
    result = final_df.select(observation_schema.fieldNames())

    return result.repartition(200, "person_id")

In [0]:


# Schema of the omop_death table: (name, type, nullable, column comment).
death_schema = StructType([
    StructField(field_name, field_type, is_nullable, metadata={"comment": description})
    for field_name, field_type, is_nullable, description in (
        ("person_id", LongType(), False,
         "A foreign key identifier to the deceased Person."),
        ("death_date", DateType(), False,
         "The date the person was deceased."),
        ("death_datetime", TimestampType(), True,
         "The date and time the person was deceased."),
        ("death_type_concept_id", IntegerType(), False,
         "A foreign key to the predefined concept identifier in the Standardized Vocabularies reflecting how the death was represented in the source data."),
        ("cause_concept_id", IntegerType(), True,
         "A foreign key to the predefined concept identifier in the Standardized Vocabularies reflecting the cause of death"),
        ("cause_source_value", StringType(), True,
         "The source code for the cause of death as it appears in the source data."),
        ("cause_source_concept_id", IntegerType(), True,
         "A foreign key to a concept that refers to the code used in the source."),
    )
])


@dlt.table(
    name="omop_death",
    comment="OMOP CDM Death table - Contains the clinical event for how and when a Person dies",
    schema=death_schema,
    table_properties={"quality": "gold"}
)
def create_omop_death():
    """
    Builds the death table from the bronze death map.

    Rows are kept when the person exists in omop_person and at least one of
    DECEASED_DT_TM / CALC_DEATH_DATE is present. death_date is the first
    non-null of those two, capped at today's date. death_datetime comes from
    CALC_DEATH_DATE; when it lies in the future it is replaced by midnight of
    the (capped) death_date. Cause concept ids are the constant 0 and
    cause_source_value joins DECEASED_SOURCE_DESC and DECEASED_METHOD_DESC
    with " - ".

    NOTE(review): death_datetime is derived from CALC_DEATH_DATE only, so it
    stays NULL when only DECEASED_DT_TM is populated - confirm this is
    intended.
    """
    death_data = spark.table(f"{BRONZE_DB}.map_death").alias("d")

    # Get valid person_ids
    valid_persons = dlt.read("omop_person").select(
        F.col("person_id").alias("valid_person_id")
    ).distinct()

    best_known_death = F.coalesce(F.col("DECEASED_DT_TM"), F.col("CALC_DEATH_DATE"))

    # The presence check runs before the projection, while the source columns
    # it reads are still part of the frame.
    omop_death = (
        death_data
        .join(valid_persons,
              F.col("d.PERSON_ID") == F.col("valid_person_id"),
              "inner")
        .where(best_known_death.isNotNull())
        .select(
            F.col("d.PERSON_ID").cast("bigint").alias("person_id"),
            best_known_death.cast("date").alias("death_date"),
            F.col("CALC_DEATH_DATE").cast("timestamp").alias("death_datetime"),
            F.lit(32817).cast("integer").alias("death_type_concept_id"),
            F.lit(0).cast("integer").alias("cause_concept_id"),
            # Store the source information
            F.concat_ws(" - ",
                        F.col("DECEASED_SOURCE_DESC"),
                        F.col("DECEASED_METHOD_DESC")
            ).alias("cause_source_value"),
            F.lit(0).cast("integer").alias("cause_source_concept_id")
        )
    )

    return (
        omop_death
        # Ensure death_date doesn't exceed current date
        .withColumn("death_date",
            F.when(F.col("death_date") > F.current_date(), F.current_date())
            .otherwise(F.col("death_date")))
        # A future death_datetime is replaced by the already capped death_date
        .withColumn("death_datetime",
            F.when(F.col("death_datetime") > F.current_timestamp(),
                   F.to_timestamp(F.col("death_date")))
            .otherwise(F.col("death_datetime")))
    )


In [0]:


# Schema of the omop_drug_era table: (name, type, nullable, column comment).
drug_era_schema = StructType([
    StructField(field_name, field_type, is_nullable, metadata={"comment": description})
    for field_name, field_type, is_nullable, description in (
        ("drug_era_id", IntegerType(), False,
         "A unique identifier for each drug era."),
        ("person_id", IntegerType(), False,
         "A foreign key identifier to the person who is subjected to the drug during the drug era."),
        ("drug_concept_id", IntegerType(), False,
         "A foreign key that refers to a standard concept identifier in the Vocabulary for the drug concept."),
        ("drug_era_start_date", DateType(), False,
         "The start date for the drug era constructed from the individual instances of drug exposures. It is the start date of the very first chronologically recorded instance of utilization of a drug."),
        ("drug_era_end_date", DateType(), False,
         "The end date for the drug era constructed from the individual instance of drug exposures. It is the end date of the final continuously recorded instance of utilization of a drug."),
        ("drug_exposure_count", IntegerType(), True,
         "The number of individual drug exposure occurrences used to construct the drug era."),
        ("gap_days", IntegerType(), True,
         "The number of days that separates two drugs that are adjacent to each other, if there is a gap of more than 30 days between two drug eras, then they are considered two separate eras."),
    )
])

@dlt.table(
    name="valid_ingredient_concepts",
    comment="Valid ingredient concepts from OMOP vocabulary",
    temporary=True
)
def get_valid_ingredient_concepts():
    """
    Lists the concept ids that count as ingredients.

    Reads the vocabulary concept table (CONCEPT_TABLE) and keeps rows whose
    invalid_reason is NULL and whose concept_class_id is "Ingredient" or
    "Multiple Ingredients".

    Returns:
        DataFrame with a single distinct `concept_id` column.
    """
    is_ingredient_class = F.col("concept_class_id").isin(["Ingredient", "Multiple Ingredients"])

    return (
        spark.table(CONCEPT_TABLE)
        .filter(F.col("invalid_reason").isNull())
        .filter(is_ingredient_class)
        .select("concept_id")
        .distinct()
    )

@dlt.table(
    name="ingredient_relationships",
    comment="Ingredient level concepts and their relationships",
    temporary=True
)
def get_ingredient_relationships():
    """
    Pairs every valid ingredient with its descendant concepts.

    Inner-joins the vocabulary ancestor table (CONCEPT_ANCESTOR_TABLE) to
    valid_ingredient_concepts on ancestor_concept_id = concept_id.

    Returns:
        DataFrame with columns ancestor_concept_id (the ingredient) and
        descendant_concept_id.
    """
    valid_ingredients = dlt.read("valid_ingredient_concepts")

    ancestor_is_ingredient = F.col("ancestor_concept_id") == F.col("concept_id")

    return (
        spark.table(CONCEPT_ANCESTOR_TABLE)
        .join(valid_ingredients, ancestor_is_ingredient)
        .select("ancestor_concept_id", "descendant_concept_id")
    )

@dlt.table(
    name="drug_exposures_with_ingredients",
    comment="Drug exposures mapped to ingredient level",
    temporary=True
)
def create_drug_exposures_with_ingredients():
    """
    Rolls each drug exposure up to an ingredient concept.

    Exposures with a positive drug_concept_id and a person present in
    omop_person are left-joined to ingredient_relationships on the
    descendant concept. The ingredient (ancestor) replaces the drug concept
    when one is found, otherwise the original concept is kept. Only rows
    whose resulting concept is a valid ingredient survive.

    Returns:
        DataFrame with person_id, drug_concept_id (ingredient level) and
        drug_exposure_start_date.
    """
    known_persons = dlt.read("omop_person").select(
        col("person_id").alias("valid_person_id")
    ).distinct()
    ingredient_links = dlt.read("ingredient_relationships")
    ingredient_ids = dlt.read("valid_ingredient_concepts")

    # Exposures that carry a usable drug concept
    has_drug_concept = col("drug_concept_id").isNotNull() & (col("drug_concept_id") > 0)
    exposures = dlt.read("omop_drug_exposure").filter(has_drug_concept).alias("de")

    exposures_for_known_persons = exposures.join(
        known_persons,
        col("de.person_id") == col("valid_person_id"),
        "inner"
    )

    # Left join so exposures without an ingredient ancestor are kept for the
    # fallback to their own concept id.
    exposures_with_links = exposures_for_known_persons.join(
        ingredient_links,
        col("de.drug_concept_id") == col("descendant_concept_id"),
        "left"
    )

    rolled_up = exposures_with_links.select(
        col("de.person_id"),
        coalesce(col("ancestor_concept_id"), col("de.drug_concept_id")).alias("drug_concept_id"),
        col("de.drug_exposure_start_date")
    )

    # Final inner join drops anything that did not resolve to an ingredient.
    ingredient_level = rolled_up.join(
        ingredient_ids,
        col("drug_concept_id") == col("concept_id")
    )

    return ingredient_level.select(
        "person_id",
        "drug_concept_id",
        "drug_exposure_start_date"
    )

@dlt.table(
    name="drug_exposure_periods",
    comment="Drug exposure periods with gap analysis",
    temporary=True
)
def create_drug_exposure_periods():
    """
    Annotates ingredient-level exposures with gaps and an era group number.

    Within each (person_id, drug_concept_id), ordered by
    drug_exposure_start_date, the following columns are added:

      next_start_date  start date of the following exposure (NULL on the last)
      gap_days         days from this exposure to the next one, 0 on the last
      era_group        running counter that increases by one on every
                       exposure starting more than 30 days after the
                       previous exposure

    The era boundary is decided by looking back at the previous exposure, so
    the exposure immediately before a long gap stays in the era it closes,
    and the first exposure after the gap opens the next era.
    """
    GAP_THRESHOLD_DAYS = 30

    exposures = dlt.read("drug_exposures_with_ingredients")

    # Window for ordering exposures by person and drug
    window_spec = Window.partitionBy("person_id", "drug_concept_id") \
                        .orderBy("drug_exposure_start_date")

    start_date = F.col("drug_exposure_start_date")

    exposures_with_gaps = (
        exposures
        .withColumn("next_start_date", F.lead("drug_exposure_start_date").over(window_spec))
        .withColumn(
            "gap_days",
            F.when(F.col("next_start_date").isNotNull(),
                   F.datediff(F.col("next_start_date"), start_date))
            .otherwise(0))
        # Helper column; window functions cannot be nested, so the previous
        # start date is materialised first and dropped again below.
        .withColumn("_prev_start_date", F.lag("drug_exposure_start_date").over(window_spec))
    )

    # 1 on the first exposure of a new era, 0 otherwise. The first exposure
    # of a partition has no predecessor; the NULL comparison yields 0.
    starts_new_era = F.when(
        F.datediff(start_date, F.col("_prev_start_date")) > GAP_THRESHOLD_DAYS, 1
    ).otherwise(0)

    return (
        exposures_with_gaps
        .withColumn("era_group", F.sum(starts_new_era).over(window_spec))
        .drop("_prev_start_date")
    )

@dlt.table(
    name="omop_drug_era",
    comment="OMOP CDM Drug Era table - Contains records of the span of time when the Person is assumed to be exposed to a particular active ingredient",
    schema=drug_era_schema,
    table_properties={"quality": "gold"}
)
def create_omop_drug_era():
    """
    Collapses drug exposure periods into one row per era.

    Rows of drug_exposure_periods are grouped by
    (person_id, drug_concept_id, era_group). For each group:

      drug_era_start_date  earliest drug_exposure_start_date
      drug_era_end_date    latest drug_exposure_start_date plus 30 days
      drug_exposure_count  number of rows in the group
      gap_days             sum of the per-row gap_days that are <= 30

    drug_era_id is a row_number over a single global ordering of
    (person_id, drug_concept_id, drug_era_start_date).

    NOTE(review): person_id is cast to integer here while other tables in
    this file use bigint/LongType - confirm ids fit in 32 bits.
    """
    drug_periods = dlt.read("drug_exposure_periods")

    # Aggregates are qualified through F so they cannot be confused with the
    # Python builtins of the same name.
    within_era_gap = F.when(F.col("gap_days") <= 30, F.col("gap_days")).otherwise(0)

    drug_eras = (
        drug_periods
        .groupBy("person_id", "drug_concept_id", "era_group")
        .agg(
            F.min("drug_exposure_start_date").alias("drug_era_start_date"),
            F.max("drug_exposure_start_date").alias("last_exposure_date"),
            F.count("*").alias("drug_exposure_count"),
            F.sum(within_era_gap).alias("gap_days")
        )
        .select(
            "person_id",
            "drug_concept_id",
            "drug_era_start_date",
            F.date_add(F.col("last_exposure_date"), 30).alias("drug_era_end_date"),
            "drug_exposure_count",
            "gap_days"
        )
    )

    # No partitionBy: the id sequence must be unique across the whole table.
    window_spec = Window.orderBy(
        "person_id",
        "drug_concept_id",
        "drug_era_start_date"
    )

    return (
        drug_eras
        .withColumn("drug_era_id", F.row_number().over(window_spec))
        .select(
            F.col("drug_era_id").cast("integer"),
            F.col("person_id").cast("integer"),
            F.col("drug_concept_id").cast("integer"),
            F.col("drug_era_start_date").cast("date"),
            F.col("drug_era_end_date").cast("date"),
            F.col("drug_exposure_count").cast("integer"),
            F.col("gap_days").cast("integer")
        )
    )


In [0]:

# Define the dose_era schema
dose_era_schema = StructType([
    StructField(field_name, field_type, is_nullable, metadata={"comment": description})
    for field_name, field_type, is_nullable, description in (
        ("dose_era_id", LongType(), False,
         "A unique identifier for each Dose Era."),
        ("person_id", LongType(), False,
         "A foreign key identifier to the Person who is subjected to the drug during the drug era."),
        ("drug_concept_id", IntegerType(), False,
         "A foreign key that refers to a Standard Concept identifier for the active Ingredient Concept."),
        ("unit_concept_id", IntegerType(), False,
         "A foreign key that refers to a Standard Concept identifier for the unit concept."),
        ("dose_value", FloatType(), False,
         "The numeric value of the daily dose."),
        ("dose_era_start_date", DateType(), False,
         "The start date for the drug era constructed from the individual instances of drug exposures."),
        ("dose_era_end_date", DateType(), False,
         "The end date for the drug era constructed from the individual instance of drug exposures."),
    )
])



@dlt.table(
    name="base_dose_data",
    comment="Extracted dosage information from medication administration",
    temporary=True
)
def create_base_dose_data():
    """
    Extracts one dose row per administered medication event.

    Reads the bronze medication administration map, keeps rows whose
    event_type_display is "Administered", and derives:

      dose_value       DOSE_IN_MG when present, otherwise DOSE_IN_ML
      unit_concept_id  8576 when the dose came from DOSE_IN_MG, 8587 when it
                       came from DOSE_IN_ML

    Rows lacking a dose value, drug concept, start date or end date are
    removed.

    Returns:
        DataFrame with person_id, drug_concept_id, start_date, end_date,
        dose_value and unit_concept_id.
    """
    # Constants for OMOP concept IDs
    MILLIGRAM_CONCEPT_ID = 8576
    MILLILITER_CONCEPT_ID = 8587

    # Load the medication administration data
    med_admin_df = (
        spark.table(f"{BRONZE_DB}.map_med_admin")
        .filter(F.col("event_type_display") == "Administered")
    )

    dose_in_mg = F.col("DOSE_IN_MG")
    dose_in_ml = F.col("DOSE_IN_ML")

    # Milligrams take precedence; the unit follows whichever value was used
    # and is NULL when neither is present.
    dose_value = F.coalesce(dose_in_mg, dose_in_ml)
    unit_concept_id = (
        F.when(dose_in_mg.isNotNull(), F.lit(MILLIGRAM_CONCEPT_ID))
        .when(dose_in_ml.isNotNull(), F.lit(MILLILITER_CONCEPT_ID))
    )

    # Extract and standardize dosage information
    return (
        med_admin_df
        .select(
            F.col("PERSON_ID").cast("long").alias("person_id"),
            F.col("OMOP_CONCEPT_ID").cast("int").alias("drug_concept_id"),
            F.to_date(F.col("ADMIN_START_DT_TM")).alias("start_date"),
            F.to_date(F.col("ADMIN_END_DT_TM")).alias("end_date"),
            dose_value.alias("dose_value"),
            unit_concept_id.alias("unit_concept_id")
        )
        .filter(F.col("dose_value").isNotNull())
        .filter(F.col("drug_concept_id").isNotNull())
        .filter(F.col("start_date").isNotNull())
        .filter(F.col("end_date").isNotNull())
    )

@dlt.table(
    name="ingredient_dose_data",
    comment="Dose information at the ingredient level",
    temporary=True
)
def get_ingredient_dose_data():
    """
    Restricts the base dose rows to drug concepts that appear in
    ``valid_ingredient_concepts``.

    The two inputs are inner-joined on ``drug_concept_id == concept_id`` and
    the ``concept_id`` column is dropped from the result.
    """
    dose_rows = dlt.read("base_dose_data")
    ingredients = dlt.read("valid_ingredient_concepts")

    # Inner join: dose rows whose concept has no match are discarded
    is_ingredient = dose_rows["drug_concept_id"] == ingredients["concept_id"]
    matched = dose_rows.join(ingredients, on=is_ingredient, how="inner")

    return matched.drop("concept_id")

@dlt.table(
    name="daily_dose_data",
    comment="Daily dose calculations",
    temporary=True
)
def calculate_daily_dose():
    """
    Adds a ``daily_dose`` column to the ingredient-level dose rows.

    ``daily_dose`` is ``dose_value`` divided by the number of days between
    ``start_date`` and ``end_date``; a span of zero or fewer days is treated
    as a single day so the division is always by at least 1.
    """
    ingredient_doses = dlt.read("ingredient_dose_data")

    span_in_days = F.datediff(F.col("end_date"), F.col("start_date"))
    # Same-day (or inverted) administrations count as one day of exposure
    exposure_days = F.when(span_in_days <= 0, F.lit(1)).otherwise(span_in_days)

    return ingredient_doses.withColumn(
        "daily_dose", F.col("dose_value") / exposure_days
    )

@dlt.table(
    name="omop_dose_era",
    comment="OMOP CDM Dose Era table - Contains records of constant dose exposure to a specific ingredient",
    schema=dose_era_schema,
    table_properties={"quality": "gold"}
)
def create_omop_dose_era():
    """
    Creates the final dose era table by identifying periods of continuous
    exposure to the same ingredient at the same daily dose.

    Rows from ``daily_dose_data`` are grouped by person, drug concept, daily
    dose and unit. Within each group, administrations are walked in start-date
    order and a new era begins whenever an administration starts more than
    ``GAP_THRESHOLD`` days after the latest end date seen so far. Each era
    spans from its earliest start date to its latest end date.

    Returns:
        DataFrame with columns dose_era_id, person_id, drug_concept_id,
        unit_concept_id, dose_value, dose_era_start_date, dose_era_end_date.
    """
    daily_dose_df = dlt.read("daily_dose_data")

    # Define parameters for era construction
    GAP_THRESHOLD = 30  # days
    # NOTE: daily_dose is the result of a division, so grouping relies on
    # exact equality of that computed value.
    partition_cols = ["person_id", "drug_concept_id", "daily_dose", "unit_concept_id"]

    # end_date is a secondary sort key so rows sharing a start date are
    # processed in a stable order.
    ordered = Window.partitionBy(*partition_cols).orderBy("start_date", "end_date")
    strictly_before = ordered.rowsBetween(Window.unboundedPreceding, -1)
    up_to_current = ordered.rowsBetween(Window.unboundedPreceding, Window.currentRow)

    # The gap is measured against the LATEST end date of all earlier rows,
    # not just the immediately preceding row: a long exposure that started
    # earlier may still be running after a short, later-starting one ended.
    # The frame is empty for the first row of a group, giving a null.
    with_gaps = (daily_dose_df
                 .withColumn("prev_end_date", sql_max("end_date").over(strictly_before))
                 .withColumn("gap",
                             when(col("prev_end_date").isNotNull(),
                                  datediff(col("start_date"), col("prev_end_date")))
                             .otherwise(lit(0)))
                 )

    # Flag the first row of every group and every row after a long gap
    era_flagged = (with_gaps
                   .withColumn("era_start",
                               when(col("prev_end_date").isNull(), lit(1))
                               .when(col("gap") > GAP_THRESHOLD, lit(1))
                               .otherwise(lit(0)))
                   )

    # Running count of era starts gives each row its era number
    era_numbered = era_flagged.withColumn(
        "era_num", sql_sum("era_start").over(up_to_current)
    )

    # Collapse each era to its date boundaries
    era_boundaries = (era_numbered
                      .groupBy(*partition_cols, "era_num")
                      .agg(
                          sql_min("start_date").alias("dose_era_start_date"),
                          sql_max("end_date").alias("dose_era_end_date")
                      )
                      )

    # The era's dose_value is the (constant) daily dose of its group
    era_with_dose = (era_boundaries
                     .withColumn("dose_value", col("daily_dose"))
                     .drop("daily_dose", "era_num")
                     )

    # Surrogate key. The window has no partitioning, so Spark evaluates it in
    # a single partition; the extra sort keys make the numbering repeatable
    # when several eras share person, drug and start date.
    window_id = Window.orderBy(
        "person_id", "drug_concept_id", "dose_era_start_date",
        "unit_concept_id", "dose_value", "dose_era_end_date"
    )
    return (era_with_dose
            .withColumn("dose_era_id", row_number().over(window_id).cast(LongType()))
            .select(
                col("dose_era_id").cast(LongType()),
                col("person_id").cast(LongType()),
                col("drug_concept_id").cast(IntegerType()),
                col("unit_concept_id").cast(IntegerType()),
                col("dose_value").cast(FloatType()),
                col("dose_era_start_date").cast(DateType()),
                col("dose_era_end_date").cast(DateType())
            )
            # Replace any nulls left by the casts with 0 in these columns
            .na.fill(0, ["dose_value", "unit_concept_id", "drug_concept_id"])
            )


In [0]:


# Output schema (types, nullability, column comments) for omop_condition_era
condition_era_schema = (
    StructType()
    .add("condition_era_id", IntegerType(), False,
         {"comment": "A unique identifier for each Condition Era."})
    .add("person_id", IntegerType(), False,
         {"comment": "A foreign key identifier to the Person who is experiencing the condition during the condition era."})
    .add("condition_concept_id", IntegerType(), False,
         {"comment": "A foreign key that refers to a Standard Condition Concept identifier in the Standardized Vocabularies."})
    .add("condition_era_start_date", DateType(), False,
         {"comment": "The start date for the Condition Era constructed from the individual instances of Condition Occurrences. It is the start date of the very first chronologically recorded instance of the Condition."})
    .add("condition_era_end_date", DateType(), False,
         {"comment": "The end date for the Condition Era constructed from the individual instances of Condition Occurrences. It is the end date of the final continuously recorded instance of the Condition."})
    # Only the occurrence count is nullable
    .add("condition_occurrence_count", IntegerType(), True,
         {"comment": "The number of individual Condition Occurrences used to construct the Condition Era."})
)



@dlt.table(
    name="condition_occurrence_ordered",
    comment="Condition occurrences with gap analysis",
    temporary=True
)
def create_condition_occurrence_ordered():
    """
    Creates ordered condition occurrences with gap analysis.

    Keeps occurrences from ``omop_condition_occurrence`` that have a positive
    ``condition_concept_id`` and whose person exists in ``omop_person``, then
    orders them by start date within each (person, condition) pair.

    Returns:
        DataFrame with columns:
          person_id, condition_concept_id, condition_start_date
          next_start_date - start date of the following occurrence, or null
          gap_days        - days until the following occurrence (0 for the last)
          era_group       - 0-based era number; increments on the first
                            occurrence that starts more than 30 days after
                            the previous one
    """
    ERA_GAP_DAYS = 30  # a longer gap between start dates begins a new era

    # Get valid condition occurrences
    condition_occurrences = (dlt.read("omop_condition_occurrence")
        .filter(F.col("condition_concept_id").isNotNull() &
                (F.col("condition_concept_id") > 0)))

    # Get valid person_ids
    valid_persons = dlt.read("omop_person").select(
        F.col("person_id").alias("valid_person_id")
    ).distinct()

    # Window for ordering conditions by person and concept. With an ORDER BY
    # and no explicit frame, aggregates run from the partition start to the
    # current row (peers with the same start date included).
    window_spec = Window.partitionBy("person_id", "condition_concept_id") \
                        .orderBy("condition_start_date")

    # Days since the previous occurrence; null for the first row of a partition
    days_since_previous = F.datediff(
        F.col("condition_start_date"),
        F.lag("condition_start_date").over(window_spec))

    return (condition_occurrences.alias("co")
        .join(valid_persons,
              F.col("co.person_id") == F.col("valid_person_id"),
              "inner")
        .select(
            F.col("co.person_id"),
            F.col("condition_concept_id"),
            F.col("condition_start_date")
        )
        .withColumn(
            "next_start_date",
            F.lead("condition_start_date").over(window_spec))
        .withColumn(
            "gap_days",
            F.when(F.col("next_start_date").isNotNull(),
                   F.datediff(F.col("next_start_date"),
                              F.col("condition_start_date")))
            .otherwise(0))
        # The era boundary belongs on the row AFTER a long gap. Flagging from
        # the forward-looking gap_days would bump the counter one row early
        # and attach the last occurrence of an era to the following era.
        .withColumn(
            "_starts_new_era",
            F.when(days_since_previous > ERA_GAP_DAYS, 1).otherwise(0))
        .withColumn(
            "era_group",
            F.sum("_starts_new_era").over(window_spec))
        .drop("_starts_new_era"))

@dlt.table(
    name="omop_condition_era",
    comment="OMOP CDM Condition Era table - Contains records that represent spans of time when a Person is assumed to have a given condition",
    schema=condition_era_schema,
    table_properties={"quality": "gold"}
)
def create_omop_condition_era():
    """
    Collapses the grouped condition occurrences into one row per era.

    For every (person_id, condition_concept_id, era_group) in
    ``condition_occurrence_ordered`` the era starts at the earliest
    ``condition_start_date`` and ends at the latest ``condition_start_date``
    (the input carries no end dates), with a count of the occurrences merged.
    A sequential ``condition_era_id`` is then assigned.
    """
    grouped_occurrences = dlt.read("condition_occurrence_ordered")

    era_keys = ["person_id", "condition_concept_id", "era_group"]
    eras = grouped_occurrences.groupBy(*era_keys).agg(
        F.min("condition_start_date").alias("condition_era_start_date"),
        F.max("condition_start_date").alias("condition_era_end_date"),
        F.count("*").alias("condition_occurrence_count"),
    )

    # Global (unpartitioned) ordering used to hand out the surrogate key
    id_order = Window.orderBy(
        "person_id", "condition_concept_id", "condition_era_start_date"
    )
    eras_with_id = eras.withColumn(
        "condition_era_id", F.row_number().over(id_order)
    )

    # Cast each output column to the type named below, in final column order
    output_types = [
        ("condition_era_id", "integer"),
        ("person_id", "integer"),
        ("condition_concept_id", "integer"),
        ("condition_era_start_date", "date"),
        ("condition_era_end_date", "date"),
        ("condition_occurrence_count", "integer"),
    ]
    return eras_with_id.select(
        *[F.col(name).cast(type_name) for name, type_name in output_types]
    )


In [0]:

# Output schema for omop_observation_period; every column is non-nullable
observation_period_schema = (
    StructType()
    .add("observation_period_id", IntegerType(), False,
         {"comment": "A unique identifier for each observation period."})
    .add("person_id", LongType(), False,
         {"comment": "A foreign key identifier to the Person for whom the observation period is defined."})
    .add("observation_period_start_date", DateType(), False,
         {"comment": "The start date of the observation period for which data are available from the data source."})
    .add("observation_period_end_date", DateType(), False,
         {"comment": "The end date of the observation period for which data are available from the data source."})
    .add("period_type_concept_id", IntegerType(), False,
         {"comment": "A foreign key to the predefined concept identifier in the Standardized Vocabularies reflecting the source of the observation period information."})
)


@dlt.table(
    name="condition_dates",
    comment="Condition event dates",
    temporary=True
)
def get_condition_dates():
    """
    Returns (person_id, start_date, end_date) for diagnoses of persons
    present in ``omop_person``; ``diag_dt_tm`` supplies both dates.
    """
    diagnoses = spark.table("4_prod.bronze.map_diagnosis").alias("c")
    persons = dlt.read("omop_person").alias("p")

    # Default (inner) join on the shared person_id column
    diagnoses_for_known_persons = diagnoses.join(persons, on="person_id")

    diagnosis_time = F.col("diag_dt_tm")
    return diagnoses_for_known_persons.select(
        "person_id",
        diagnosis_time.alias("start_date"),
        diagnosis_time.alias("end_date"),
    )

@dlt.table(
    name="drug_dates",
    comment="Drug administration dates",
    temporary=True
)
def get_drug_dates():
    """
    Returns (person_id, start_date, end_date) for medication administrations
    of persons present in ``omop_person``, taken from ``admin_start_dt_tm``
    and ``admin_end_dt_tm``.
    """
    administrations = spark.table("4_prod.bronze.map_med_admin").alias("d")
    persons = dlt.read("omop_person").alias("p")

    # Default (inner) join on the shared person_id column
    administrations_for_known_persons = administrations.join(persons, on="person_id")

    return administrations_for_known_persons.select(
        "person_id",
        F.col("admin_start_dt_tm").alias("start_date"),
        F.col("admin_end_dt_tm").alias("end_date"),
    )

@dlt.table(
    name="visit_dates",
    comment="Visit dates",
    temporary=True
)
def get_visit_dates():
    """
    Returns (person_id, start_date, end_date) for encounters of persons
    present in ``omop_person``, taken from ``arrive_dt_tm`` and
    ``depart_dt_tm``.
    """
    encounters = spark.table("4_prod.bronze.map_encounter").alias("v")
    persons = dlt.read("omop_person").alias("p")

    # Default (inner) join on the shared person_id column
    encounters_for_known_persons = encounters.join(persons, on="person_id")

    return encounters_for_known_persons.select(
        "person_id",
        F.col("arrive_dt_tm").alias("start_date"),
        F.col("depart_dt_tm").alias("end_date"),
    )

@dlt.table(
    name="combined_observation_dates",
    comment="Combined clinical event dates",
    temporary=True
)
def create_combined_observation_dates():
    """
    Stacks the condition, drug and visit date tables into one DataFrame.

    Rows are appended by column position and duplicates are kept.
    """
    source_tables = ("condition_dates", "drug_dates", "visit_dates")
    conditions, drugs, visits = (dlt.read(table) for table in source_tables)

    combined = conditions.union(drugs)
    combined = combined.union(visits)
    return combined

@dlt.table(
    name="omop_observation_period",
    comment="OMOP CDM Observation Period table - Contains records which define spans of time during which clinical events are recorded for a Person",
    schema=observation_period_schema,
    table_properties={"quality": "gold"}
)
def create_omop_observation_period():
    """
    Creates the observation period table with a single period per person,
    spanning that person's first to last recorded clinical event.

    Events come from ``omop_condition_occurrence``, ``omop_drug_exposure`` and
    ``omop_visit_occurrence``. The span is then constrained:
      * the start is not earlier than 1 January of the person's year of birth;
      * the end is not later than the death date when one exists, otherwise
        not later than the current date;
      * if the end would precede the start, the start is pulled back to the end.

    Returns:
        DataFrame with columns observation_period_id, person_id,
        observation_period_start_date, observation_period_end_date and
        period_type_concept_id (constant 32817).
    """
    N_PARTITIONS = 200

    # --- 1. collect event dates from each clinical table -----------------
    # A missing end date falls back to the start date so that every event
    # contributes to the person's last-activity date.
    cond = dlt.read("omop_condition_occurrence") \
              .select("person_id",
                      col("condition_start_date").alias("start_date"),
                      coalesce(col("condition_end_date"),
                               col("condition_start_date")).alias("end_date"))

    drug = dlt.read("omop_drug_exposure") \
              .select("person_id",
                      col("drug_exposure_start_date").alias("start_date"),
                      coalesce(col("drug_exposure_end_date"),
                               col("drug_exposure_start_date")).alias("end_date"))

    # Visits get the same fallback as conditions and drugs. Without it, a
    # person whose only events are open-ended visits has a null last_evt, and
    # least() below would silently extend the period to the death date/today.
    visit = dlt.read("omop_visit_occurrence") \
               .select("person_id",
                       col("visit_start_date").alias("start_date"),
                       coalesce(col("visit_end_date"),
                                col("visit_start_date")).alias("end_date"))

    event_dates = (cond.union(drug).union(visit)
                       .repartition(N_PARTITIONS, "person_id"))

    # --- 2. person birth & death -----------------------------------------
    # Only the year of birth is used; the birth date is taken as 1 January.
    birth = (dlt.read("omop_person")
             .select("person_id",
                     to_date(concat_ws("-", col("year_of_birth"), lit("01"), lit("01")))
                         .alias("birth_date")))

    # NOTE(review): assumes at most one omop_death row per person; duplicate
    # rows would duplicate that person's observation period - confirm.
    death = dlt.read("omop_death").select("person_id", "death_date")

    # --- 3. first / last clinical activity -------------------------------
    periods = (event_dates.groupBy("person_id")
               .agg(F.min("start_date").alias("first_evt"),
                    F.max("end_date").alias("last_evt")))

    # --- 4. apply birth- & death-date rules ------------------------------
    constrained = (periods
        .join(birth, "person_id")
        .join(death, "person_id", "left")
        .select(
            "person_id",
            when(col("first_evt") < col("birth_date"),
                 col("birth_date")).otherwise(col("first_evt"))
                 .alias("start_date"),
            when(col("death_date").isNotNull(),
                 least(col("last_evt"), col("death_date")))
            .otherwise(least(col("last_evt"), current_date()))
                 .alias("end_date"))
        # Keep start <= end after the clamping above
        .withColumn("start_date",
                    when(col("end_date") < col("start_date"),
                         col("end_date")).otherwise(col("start_date")))
        .repartition(N_PARTITIONS, "person_id")
    )

    # --- 5. add surrogate key -------------------------------------------
    w_id = Window.orderBy("person_id")
    return (constrained
            .withColumn("observation_period_id",
                        row_number().over(w_id).cast("int"))
            .selectExpr(
                "observation_period_id",
                "person_id",
                "start_date  as observation_period_start_date",
                "end_date    as observation_period_end_date",
                "cast(32817 as int) as period_type_concept_id"))
