## Assignment 3
Group Members: Janel Joson (jmjoson) and Yu Fang Ma (yfm)

## Final Assignment Overview: Working with Patient Records and Encounter Notes

In this final assignment, we’ll focus on patient records related to COVID-19 encounters. Our task is to analyze, process, and transform the data while applying the concepts we’ve covered throughout this course. Here's a detailed breakdown of the assignment:

What Are Encounter Notes?
An encounter note is a record that captures details about a patient’s visit with a doctor. It includes both structured and semi-structured information that is crucial for understanding the context of the visit. Here’s what an encounter note typically looks like:

```
AMBULATORY ENCOUNTER NOTE
Date of Service: March 2, 2020 15:45-16:30

DEMOGRAPHICS:
Name: Jeffrey Greenfelder
DOB: 1/16/2005
Gender: Male
Address: 428 Wiza Glen Unit 91, Springfield, Massachusetts 01104
Insurance: Guardian
MRN: 055ae6fc-7e18-4a39-8058-64082ca6d515

PERTINENT MEDICAL HISTORY:
- Obesity 

Recent Visit: Well child visit (2/23/2020)
Immunizations: Influenza vaccine (2/23/2020)

Recent Baseline (2/23/2020):
Height: 155.0 cm
Weight: 81.2 kg
BMI: 33.8 kg/m² (99.1th percentile)
BP: 123/80 mmHg
HR: 92/min
RR: 13/min

SUBJECTIVE:
Adolescent patient presents with multiple symptoms including:
- Cough
- Sore throat
- Severe fatigue
- Muscle pain
- Joint pain
- Fever
Never smoker. Symptoms began recently.

OBJECTIVE:
Vitals:
Temperature: 39.3°C (102.7°F)
Heart Rate: 131.1/min
Blood Pressure: 120/73 mmHg
Respiratory Rate: 27.6/min
O2 Saturation: 75.8% on room air
Weight: 81.2 kg

Laboratory/Testing:
Comprehensive Respiratory Panel:
- Influenza A RNA: Negative
- Influenza B RNA: Negative
- RSV RNA: Negative
- Parainfluenza virus 1,2,3 RNA: Negative
- Rhinovirus RNA: Negative
- Human metapneumovirus RNA: Negative
- Adenovirus DNA: Negative
- SARS-CoV-2 RNA: Positive

ASSESSMENT:
1. Suspected COVID-19 with severe symptoms
2. Severe hypoxemia requiring immediate intervention
3. Tachycardia (HR 131)
4. High-grade fever
5. Risk factors:
   - Obesity (BMI 33.8)
   - Adolescent age

PLAN:
1. Face mask provided for immediate oxygen support
2. Infectious disease care plan initiated
3. Close monitoring required due to:
   - Severe hypoxemia
   - Tachycardia
   - Age and obesity risk factors
4. Parent/patient education on:
   - Home isolation protocols
   - Warning signs requiring emergency care
   - Return precautions
5. Follow-up plan:
   - Daily monitoring during acute phase
   - Virtual check-ins as needed

Encounter Duration: 45 minutes
Encounter Type: Ambulatory
Provider: ID# e2c226c2-3e1e-3d0b-b997-ce9544c10528
Facility: 5103c940-0c08-392f-95cd-446e0cea042a
```


The encounter note contains:

* General encounter information: 

  * When the encounter took place: Date and time of the visit.
  * Demographics: Patient’s age, gender, and unique medical record identifier.
  * Encounter details: The reason for the visit, diagnosis, and any associated costs.


* Semi-Structured Notes:

These notes mirror how doctors organize their thoughts and observations during an encounter. They generally follow a SOAP format:

* Subjective: The patient’s subjective description of their symptoms, feelings, and medical concerns.
* Objective: The doctor’s objective findings, including test results, measurements, or physical examination outcomes.
* Assessment: The doctor’s evaluation or diagnosis based on subjective and objective information.
* Plan: The proposed treatment plan, including medications, follow-ups, or other interventions.

While some encounter notes might include additional details, the majority conform to this semi-structured format, making them ideal for analysis and transformation.
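To see why this semi-structured layout helps, note that the section headings in the sample note can be located with a regular expression before any LLM is involved. The sketch below is a minimal illustration, not part of the assignment pipeline. It assumes each section starts with an all-caps heading line ending in a colon, as in the example above. `split_sections` is a hypothetical helper.

```python
import re
from typing import Dict

# Assumption: a section heading is an all-caps line ending in ":" (e.g. "SUBJECTIVE:")
HEADING = re.compile(r"^([A-Z][A-Z /]+):[ \t]*$", re.MULTILINE)

def split_sections(note: str) -> Dict[str, str]:
    """Split a note into {heading: body} using all-caps heading lines."""
    matches = list(HEADING.finditer(note))
    sections = {}
    for i, m in enumerate(matches):
        # Each body runs until the next heading (or the end of the note)
        end = matches[i + 1].start() if i + 1 < len(matches) else len(note)
        sections[m.group(1)] = note[m.end():end].strip()
    return sections

sample = """SUBJECTIVE:
Adolescent patient presents with multiple symptoms including:
- Cough
- Fever

OBJECTIVE:
Vitals:
Temperature: 39.3°C (102.7°F)

ASSESSMENT:
1. Suspected COVID-19 with severe symptoms
"""

sections = split_sections(sample)
print(list(sections))  # ['SUBJECTIVE', 'OBJECTIVE', 'ASSESSMENT']
```

Lines such as `Vitals:` are not treated as headings because they contain lowercase letters. The sub-structure inside each SOAP section is still left for the LLM step.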

* Goals for the Assignment

1. Transforming Encounter Notes:

Use an LLM to convert the semi-structured encounter notes into JSON that organizes the information into structured fields, such as demographics, encounter specifics, and the SOAP components of the note. Then transform the JSON data into a Parquet file, a columnar format that Spark can analyze efficiently and that is also well suited to long-term storage.
Here we will use ML classification to map the semi-structured Objective and Assessment fields onto standardized, structured fields. The medical taxonomy for this task is the one provided by the CDC, which defines standard codes for diagnoses, symptoms, procedures, and treatments. Aligning the structured data with these domain-wide standards makes it interoperable and ready for deeper analysis.

The JSON format should capture the hierarchies described in the structure below.
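The classification step maps free-text findings onto standard codes. The sketch below shows only the shape of that mapping. It swaps the ML classifier for a much simpler technique: nearest-label matching by Jaccard token overlap. The label table and its `CODE_...` values are hypothetical placeholders, not CDC (or SNOMED/ICD) codes. The 0.1 threshold is also an assumption.

```python
import re
from typing import Optional, Set

# Hypothetical label table: placeholder codes, NOT real CDC/SNOMED/ICD codes
TAXONOMY_LABELS = {
    "CODE_COVID19": "covid-19 sars-cov-2 infection",
    "CODE_HYPOXEMIA": "hypoxemia low oxygen saturation",
    "CODE_FEVER": "fever high temperature",
    "CODE_TACHYCARDIA": "tachycardia fast heart rate",
}

def tokens(text: str) -> Set[str]:
    """Lower-case word tokens; hyphenated terms such as covid-19 stay intact."""
    return set(re.findall(r"[a-z0-9]+(?:-[a-z0-9]+)*", text.lower()))

def classify(text: str, min_score: float = 0.1) -> Optional[str]:
    """Return the code whose description has the highest Jaccard overlap with text."""
    query = tokens(text)
    best_code, best_score = None, 0.0
    for code, description in TAXONOMY_LABELS.items():
        label_tokens = tokens(description)
        # Jaccard similarity: |intersection| / |union|
        score = len(query & label_tokens) / len(query | label_tokens)
        if score > best_score:
            best_code, best_score = code, score
    return best_code if best_score >= min_score else None

print(classify("Suspected COVID-19 with severe symptoms"))  # CODE_COVID19
```

A trained classifier would replace `classify`, but the input (an Assessment line) and the output (a code, or nothing when no label is close enough) keep the same shape.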




2. Basic Analytics and Visualizations:
Using Apache Spark, perform comprehensive data analysis on the encounter data and create visualizations that reveal meaningful patterns. Your analysis must include:
- COVID-19 Case Demographics: Case breakdown by age ranges ([0-5], [6-10], [11-17], [18-30], [31-50], [51-70], [71+])
- Cumulative case count of COVID-19 between the earliest and the latest case observed in the dataset
- Symptoms of all COVID-19 patients versus those of patients admitted to the intensive care unit due to COVID-19
- Rank medications by frequency of prescription
- Analyze medication patterns across different demographic groups (e.g., top 3 per age group)
- Identify and plot co-morbidity information from the patient records (e.g., hypertension, obesity, prediabetes, etc.) provided in the dataset. 
- An independent group analysis: You need to develop and execute THREE original analyses that provide meaningful insights about COVID-19 patterns in this dataset. For each analysis:
  - Clearly state your analytical question/hypothesis
  - Justify why this analysis is valuable
  - Show your Spark code and methodology
  - Present results with appropriate visualizations
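The cumulative case count in the second bullet is a running total of cases per day. The plain-Python sketch below illustrates that logic and can be used to check a Spark result on a handful of rows. It assumes encounter dates are ISO `YYYY-MM-DD` strings. `cumulative_cases` is a hypothetical helper.

```python
from collections import Counter
from datetime import date, timedelta
from itertools import accumulate
from typing import List, Tuple

def cumulative_cases(encounter_dates: List[str]) -> List[Tuple[str, int]]:
    """Running total of cases for every day from the first case to the last.

    Days with no cases are included so the series has no gaps.
    Assumes ISO 'YYYY-MM-DD' date strings.
    """
    if not encounter_dates:
        return []
    days = [date.fromisoformat(d) for d in encounter_dates]
    per_day = Counter(days)  # missing days count as 0
    first, last = min(days), max(days)
    span = [first + timedelta(days=i) for i in range((last - first).days + 1)]
    totals = accumulate(per_day[d] for d in span)
    return [(d.isoformat(), total) for d, total in zip(span, totals)]

print(cumulative_cases(["2020-03-02", "2020-03-04", "2020-03-02"]))
# [('2020-03-02', 2), ('2020-03-03', 2), ('2020-03-04', 3)]
```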


EncounterType:
    code
    description

Encounter:
    -- id --
    date
    time
    -- type: EncounterType -- 
    provider_id
    facility_id

Address:
    city
    state

Demographics:
    -- id -- 
    name
    date_of_birth
    age
    gender
    address: Address
    insurance

Condition:
    code
    description

Medication:
    code
    description

Immunization:
    code
    description
    date: date

VitalMeasurement:
    code
    value: float
    unit

BloodPressure:
    systolic: VitalMeasurement
    diastolic: VitalMeasurement

CurrentVitals:
    temperature: VitalMeasurement
    heart_rate: VitalMeasurement
    blood_pressure: BloodPressure
    respiratory_rate: VitalMeasurement
    oxygen_saturation: VitalMeasurement
    weight: VitalMeasurement

BaselineVitals:
    date: date
    height: VitalMeasurement
    weight: VitalMeasurement
    bmi: VitalMeasurement
    bmi_percentile: VitalMeasurement

Vitals:
    current: CurrentVitals
    baseline: BaselineVitals

RespiratoryTest:
    code
    result

RespiratoryPanel:
    influenza_a: RespiratoryTest
    influenza_b: RespiratoryTest
    rsv: RespiratoryTest
    parainfluenza_1: RespiratoryTest
    parainfluenza_2: RespiratoryTest
    parainfluenza_3: RespiratoryTest
    rhinovirus: RespiratoryTest
    metapneumovirus: RespiratoryTest
    adenovirus: RespiratoryTest

Covid19Test:
    code
    description
    result

Laboratory:
    covid19: Covid19Test
    respiratory_panel: RespiratoryPanel

Procedure:
    code
    description
    date: date
    reasonCode
    reasonDescription

CarePlan:
    -- id -- 
    code
    description
    start: date
    stop: date
    reasonCode
    reasonDescription

PatientRecord:
    encounter: Encounter
    demographics: Demographics
    conditions: List[Condition]
    medications: List[Medication]
    immunizations: List[Immunization]
    vitals: Vitals
    laboratory: Laboratory
    procedures: List[Procedure]
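Some values in a note pack several schema fields into one string. For example, "Blood Pressure: 120/73 mmHg" must become a `BloodPressure` containing two `VitalMeasurement`s. The sketch below shows that step and produces plain dicts in the shape of the schema. `parse_blood_pressure` is a hypothetical helper, and the `code` values are placeholders rather than real taxonomy codes.

```python
import json
import re

# Matches strings like "120/73 mmHg": systolic/diastolic plus an optional unit
BP_PATTERN = re.compile(r"^\s*(\d+(?:\.\d+)?)\s*/\s*(\d+(?:\.\d+)?)\s*(\S+)?\s*$")

def parse_blood_pressure(text: str) -> dict:
    """Turn '120/73 mmHg' into the nested BloodPressure shape from the schema."""
    m = BP_PATTERN.match(text)
    if m is None:
        raise ValueError(f"unrecognised blood pressure: {text!r}")
    systolic, diastolic = float(m.group(1)), float(m.group(2))
    unit = m.group(3) or "mmHg"  # assume mmHg when no unit is given
    return {
        # Placeholder codes: a real pipeline would use the taxonomy's codes
        "systolic": {"code": "SYSTOLIC_PLACEHOLDER", "value": systolic, "unit": unit},
        "diastolic": {"code": "DIASTOLIC_PLACEHOLDER", "value": diastolic, "unit": unit},
    }

record = {"vitals": {"current": {"blood_pressure": parse_blood_pressure("120/73 mmHg")}}}
print(json.dumps(record["vitals"]["current"]["blood_pressure"]["systolic"]))
# {"code": "SYSTOLIC_PLACEHOLDER", "value": 120.0, "unit": "mmHg"}
```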


In [None]:
%pip install pydantic


In [None]:
# Using provided schema
from typing import List, Optional
from pydantic import BaseModel
from datetime import date

class EncounterType(BaseModel):
    code: str
    description: str

class Encounter(BaseModel):
    date: str
    time: str
    provider_id: str
    facility_id: str

class Address(BaseModel):
    city: str
    state: str

class Demographics(BaseModel):
    name: str
    date_of_birth: str
    age: int
    gender: str
    address: Address
    insurance: str

class Condition(BaseModel):
    code: str
    description: str


class Medication(BaseModel):
    code: str
    description: str

class Immunization(BaseModel):
    code: str
    description: str
    date: str

class VitalMeasurement(BaseModel):
    code: str
    value: float
    unit: str

class BloodPressure(BaseModel):
    systolic: Optional[VitalMeasurement] = None
    diastolic: Optional[VitalMeasurement] = None

class CurrentVitals(BaseModel):
    temperature: Optional[VitalMeasurement] = None
    heart_rate: Optional[VitalMeasurement] = None
    blood_pressure: Optional[BloodPressure] = None
    respiratory_rate: Optional[VitalMeasurement] = None
    oxygen_saturation: Optional[VitalMeasurement] = None
    weight: Optional[VitalMeasurement] = None

class BaselineVitals(BaseModel):
    date: str
    height: Optional[VitalMeasurement] = None
    weight: Optional[VitalMeasurement] = None
    bmi: Optional[VitalMeasurement] = None
    bmi_percentile: Optional[VitalMeasurement] = None

class Vitals(BaseModel):
    current: Optional[CurrentVitals] = None
    baseline: Optional[BaselineVitals] = None

class RespiratoryTest(BaseModel):
    code: str
    result: str

class RespiratoryPanel(BaseModel):
    influenza_a: Optional[RespiratoryTest] = None
    influenza_b: Optional[RespiratoryTest] = None
    rsv: Optional[RespiratoryTest] = None
    parainfluenza_1: Optional[RespiratoryTest] = None
    parainfluenza_2: Optional[RespiratoryTest] = None
    parainfluenza_3: Optional[RespiratoryTest] = None
    rhinovirus: Optional[RespiratoryTest] = None
    metapneumovirus: Optional[RespiratoryTest] = None
    adenovirus: Optional[RespiratoryTest] = None

class Covid19Test(BaseModel):
    code: str
    description: str
    result: str

class Laboratory(BaseModel):
    covid19: Optional[Covid19Test] = None
    respiratory_panel: Optional[RespiratoryPanel] = None

class Procedure(BaseModel):
    code: str
    description: str
    date: str
    reasonCode: str
    reasonDescription: str

class CarePlan(BaseModel):
    code: str
    description: str
    start: str
    stop: str
    reasonCode: str
    reasonDescription: str

class PatientRecord(BaseModel):
    demographics: Demographics
    encounter: Optional[Encounter] = None
    conditions: Optional[List[Condition]] = None
    medications: Optional[List[Medication]] = None
    immunizations: Optional[List[Immunization]] = None
    vitals: Optional[Vitals] = None
    laboratory: Optional[Laboratory] = None
    procedures: Optional[List[Procedure]] = None

## Part 1: Transforming Encounter Notes
### Transform JSON data (parsed_notes.jsonl.gz) into a Parquet file to use for analysis in Spark

In [None]:
from pyspark.sql import SparkSession, Window

from pyspark.sql.types import (
    StructType, StructField, 
    StringType, IntegerType, FloatType, 
    ArrayType, MapType
)

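# `max` is deliberately not imported by name: it would shadow Python's built-in max(); use F.max instead
from pyspark.sql.functions import (
  col, count, desc, explode, array_contains, 
  dense_rank, when, struct, first,
  collect_list, size, array_intersect)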

import pyspark.sql.functions as F

import matplotlib.pyplot as plt


In [None]:
from typing import Union, get_args, get_origin
from datetime import date
from pydantic import BaseModel

def convert_pydantic_to_spark_schema(model_class):
    """
    Dynamically convert Pydantic model to Spark StructType
    
    Args:
        model_class (Type[BaseModel]): Pydantic BaseModel class to convert
    
    Returns:
        StructType: Corresponding Spark SQL schema
    """
    # Primitive type mapping
    type_map = {
        str: StringType(),
        int: IntegerType(),
        float: FloatType(),
        date: StringType(),  # Convert date to string in Spark
    }

    def _unwrap_optional(field_type):
        """
        Return (inner_type, is_optional); Optional[X] is Union[X, None]
        """
        if get_origin(field_type) is Union:
            args = [arg for arg in get_args(field_type) if arg is not type(None)]
            if len(args) == 1:
                return args[0], True
        return field_type, False

    def _get_field_type(field_type):
        """
        Convert Python type annotations to Spark SQL types
        """
        field_type, _ = _unwrap_optional(field_type)

        # List[X] -> ArrayType(X)
        if get_origin(field_type) is list:
            return ArrayType(_get_field_type(get_args(field_type)[0]))

        # Nested Pydantic model (any BaseModel subclass) -> StructType
        if isinstance(field_type, type) and issubclass(field_type, BaseModel):
            return _convert_model_to_struct(field_type)

        # Primitive types; default to string for unknown types
        return type_map.get(field_type, StringType())

    def _convert_model_to_struct(model):
        """
        Convert a Pydantic model to a Spark StructType
        """
        fields = []
        for field_name, field_type in model.__annotations__.items():
            inner_type, is_optional = _unwrap_optional(field_type)
            spark_type = _get_field_type(inner_type)
            fields.append(StructField(field_name, spark_type, nullable=is_optional))
        return StructType(fields)

    return _convert_model_to_struct(model_class)

In [None]:
def get_patient_record_schema():
    return convert_pydantic_to_spark_schema(PatientRecord)

In [None]:
def convert_json_to_parquet(input_path, output_path):
    """
    Convert JSON to Parquet using dynamically generated schema
    
    Args:
        input_path (str): Path to input JSON file
        output_path (str): Path to output Parquet file
    """
    # Create Spark session
    spark = SparkSession.builder \
        .appName("PatientRecord JSON to Parquet") \
        .getOrCreate()
    
    # Get the dynamically generated schema
    patient_record_schema = get_patient_record_schema()
    
    # Read JSON with explicit schema
    df = spark.read.schema(patient_record_schema) \
        .json(input_path)
    
    # Write to Parquet
    df.write.mode("overwrite").parquet(output_path)
    
    # Optional: Show schema and first few rows for verification
    df.printSchema()
    df.show(5)
    
    return df

In [None]:
df = convert_json_to_parquet("dbfs:/FileStore/parsed_notes_jsonl.gz", "dbfs:/FileStore/notes.parquet")

## Part 2: Basic Analytics and Visualizations

- Rank medications by frequency of prescription
- Analyze medication patterns across different demographic groups (e.g., top 3 per age group)
- Identify and plot co-morbidity information from the patient records (e.g., hypertension, obesity, prediabetes, etc.) provided in the dataset.

In [None]:
# Rank medications by frequency of prescription

def analyze_medication_frequency(df):
    # Explode medications to create one row per prescribed medication
    medication_freq = df.filter(col("medications").isNotNull()) \
        .select(explode(col("medications")).alias("medication"))
    
    # Count and rank medications
    medication_ranking = medication_freq \
        .groupBy("medication.code", "medication.description") \
        .agg(count("*").alias("prescription_count")) \
        .orderBy(desc("prescription_count"))
    
    # Show top 20 medications
    print("Top 20 Prescribed Medications:")
    medication_ranking.show(20, truncate=False)

    # Return the ranking so `results` holds a DataFrame instead of None
    return medication_ranking

results = analyze_medication_frequency(df)


In [None]:
# Additional advanced analysis
def advanced_medication_analysis(df):
    # Medications by demographic
    medication_by_demographics = df \
        .filter(col("medications").isNotNull()) \
        .select(
            explode(col("medications")).alias("medication"),
            col("demographics.gender"),
            col("demographics.age")
        )
    
    # Medication frequency by gender
    gender_medication_freq = medication_by_demographics \
        .groupBy("medication.description", "gender") \
        .agg(count("*").alias("prescription_count")) \
        .orderBy(desc("prescription_count"))
    
    print("\nMedication Prescription by Gender:")
    gender_medication_freq.show(20, truncate=False)
    
    # Medication frequency by age range
    def categorize_age_range(age):
        return (
            when((age >= 0) & (age <= 17), "0-17")
            .when((age >= 18) & (age <= 50), "18-50")
            .when(age >= 51, "51+")
            .otherwise("Unknown")
        )
    
    age_medication_freq = medication_by_demographics \
        .withColumn("age_range", categorize_age_range(col("age"))) \
        .groupBy("medication.description", "age_range") \
        .agg(count("*").alias("prescription_count")) \
        .orderBy(desc("prescription_count"))
    
    print("\nMedication Prescription by Age Range:")
    age_medication_freq.show(20, truncate=False)
    
    return {
        "gender_medication_freq": gender_medication_freq.collect(),
        "age_medication_freq": age_medication_freq.collect()
    }

additional_analysis = advanced_medication_analysis(df)
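The requirement asks for the top 3 medications *per age group*, but `age_medication_freq` is sorted globally. The missing step is a dense rank within each age range. In Spark this would typically use `Window.partitionBy("age_range").orderBy(desc("prescription_count"))` with `dense_rank().over(...)` and a `<= 3` filter. The plain-Python sketch below shows the same logic on rows shaped like `age_medication_freq.collect()`, that is, `(description, age_range, prescription_count)`. `top_n_per_group` is a hypothetical helper, and "Drug A" and the other names are made-up example values.

```python
from collections import defaultdict

def top_n_per_group(rows, n=3):
    """Keep medications whose dense rank within their age range is <= n.

    rows: iterable of (description, age_range, prescription_count) tuples.
    Ties share a rank (like Spark's dense_rank), so a group can return more than n rows.
    """
    by_group = defaultdict(list)
    for description, age_range, count in rows:
        by_group[age_range].append((description, count))

    result = {}
    for age_range, meds in by_group.items():
        meds.sort(key=lambda m: (-m[1], m[0]))  # highest count first, then by name
        distinct_counts = sorted({c for _, c in meds}, reverse=True)
        cutoff = distinct_counts[min(n, len(distinct_counts)) - 1]  # n-th highest distinct count
        result[age_range] = [m for m in meds if m[1] >= cutoff]
    return result

rows = [
    ("Drug A", "0-17", 5), ("Drug B", "0-17", 5), ("Drug C", "0-17", 3),
    ("Drug D", "0-17", 2), ("Drug E", "0-17", 1), ("Drug A", "51+", 7),
]
print(top_n_per_group(rows)["51+"])  # [('Drug A', 7)]
```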

In [None]:
# Identify and plot co-morbidity information from the patient records (e.g., hypertension, obesity, prediabetes, etc.) provided in the dataset.

# Prepare conditions data
conditions_df = df.filter(col("conditions").isNotNull())

In [None]:
# Count how many records share each exact combination (list) of conditions
comorbidity_counts = conditions_df.groupBy("conditions") \
    .agg(count("*").alias("count")) \
    .orderBy(desc("count"))

# Collect results
comorbidity_results = comorbidity_counts.collect()

In [None]:
# Explode the conditions list to flatten the DataFrame
flattened_df = comorbidity_counts.withColumn("condition", explode(col("conditions"))).select(
    col("condition.code").alias("code"),
    col("condition.description").alias("description"),
    col("count")
)

# Group by condition code and description, and sum the counts
result_df = flattened_df.groupBy("code", "description").sum("count").withColumnRenamed("sum(count)", "total_count")

# Sort the results in descending order of total_count
sorted_df = result_df.orderBy(col("total_count").desc())

# Display the sorted results
sorted_df.show(truncate=False)
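The cells above count how often each condition appears by itself. Co-morbidity, however, is about conditions occurring *together* in the same patient, so counting condition pairs per record is a natural complement. The sketch below is a minimal plain-Python version. It assumes each record's conditions arrive as a list of descriptions, for example lists collected from `conditions_df.select("conditions.description")`. `count_condition_pairs` is a hypothetical helper, and the example records are illustrative.

```python
from collections import Counter
from itertools import combinations

def count_condition_pairs(records):
    """Count unordered pairs of distinct conditions that appear in the same record."""
    pair_counts = Counter()
    for conditions in records:
        unique = sorted(set(conditions))  # de-duplicate and fix order so (a, b) == (b, a)
        pair_counts.update(combinations(unique, 2))
    return pair_counts

records = [
    ["Obesity", "Hypertension"],
    ["Hypertension", "Prediabetes", "Obesity"],
    ["Obesity"],
]
print(count_condition_pairs(records).most_common(1))  # [(('Hypertension', 'Obesity'), 2)]
```

The resulting pair counts can be plotted as a bar chart of the most common pairs, or pivoted into a condition-by-condition heatmap.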

In [None]:
# COVID-19 Case Demographics: Case breakdown by age ranges ([0-5], [6-10], [11-17], [18-30], [31-50], [51-70], [71+])

from pyspark.sql.functions import when, col

# Reload the Parquet file written in Part 1 (via Databricks' predefined `spark` session) instead of re-running the JSON conversion
df = spark.read.parquet("dbfs:/FileStore/notes.parquet")

new_df = df.select("demographics.age").withColumn("Age Ranges", when(col("age").between(0, 5), "[0-5]").when(col("age").between(6, 10), "[6-10]").when(col("age").between(11, 17), "[11-17]").when(col("age").between(18, 30), "[18-30]").when(col("age").between(31, 50), "[31-50]").when(col("age").between(51, 70), "[51-70]").when(col("age") > 70, "[71+]")).groupBy("Age Ranges").count()

final_df = new_df.withColumn("sort", when(col("Age Ranges") == "[0-5]", 1).when(col("Age Ranges") == "[6-10]", 2).when(col("Age Ranges") == "[11-17]", 3).when(col("Age Ranges") == "[18-30]", 4).when(col("Age Ranges") == "[31-50]", 5).when(col("Age Ranges") == "[51-70]", 6).when(col("Age Ranges") == "[71+]", 7)).orderBy("sort").select("Age Ranges", "count")

final_df.show()
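The same seven age ranges may be reused in other analyses, so it helps to check their boundaries outside Spark. The sketch below mirrors the `when(...).between(...)` chain above in plain Python, using `bisect` over each range's lower bound. `age_range` is a hypothetical helper. Ages are assumed to be integers, as in the `Demographics` model.

```python
from bisect import bisect_right
from typing import Optional

# Lower bound of each range from the assignment, in ascending order
AGE_LOWER_BOUNDS = [0, 6, 11, 18, 31, 51, 71]
AGE_LABELS = ["[0-5]", "[6-10]", "[11-17]", "[18-30]", "[31-50]", "[51-70]", "[71+]"]

def age_range(age: Optional[int]) -> Optional[str]:
    """Map an integer age to its bucket label; None for missing or negative ages."""
    if age is None or age < 0:
        return None
    # bisect_right counts the lower bounds <= age; the last of those is the bucket
    return AGE_LABELS[bisect_right(AGE_LOWER_BOUNDS, age) - 1]

print(age_range(15))  # [11-17]
```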

In [None]:
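# Cumulative case count of COVID-19 between the earliest and the latest case observed in the dataset

# Reload the Parquet file written in Part 1 (via Databricks' predefined `spark` session) instead of re-running the JSON conversion
notes_df = spark.read.parquet("dbfs:/FileStore/notes.parquet")

# `!= "null"` drops both real nulls and the literal string "null"
new_df_2 = notes_df.select("encounter.date").filter(col("date") != "null").sort("date")
print(new_df_2.count())

# Cases per date plus a running total (assumes ISO "YYYY-MM-DD" strings, so text order = date order)
running = Window.orderBy("date").rowsBetween(Window.unboundedPreceding, Window.currentRow)
cumulative_df = new_df_2.groupBy("date").count() \
    .withColumn("cumulative_count", F.sum("count").over(running)).orderBy("date")
cumulative_df.show()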

In [None]:
## Oxygen saturation (SpO2) of COVID-19 patients: bucket each patient's current SpO2 into standard clinical ranges
## and count the patients in each range. Because every record is a COVID-19 encounter, this describes the SpO2
## distribution within the cohort; without a non-COVID comparison group it cannot, on its own, show that
## COVID-19 lowers oxygen saturation.

from pyspark.sql.functions import when, col

# Reload the Parquet file written in Part 1 (via Databricks' predefined `spark` session) instead of re-running the JSON conversion
notes_df = spark.read.parquet("dbfs:/FileStore/notes.parquet")

# SpO2 is a float, so use descending lower bounds: inclusive ranges such as between(95, 97) and
# between(98, 100) would silently drop readings like 97.5. The 0-100 filter also removes nulls.
o2_df = notes_df.select("vitals.current.oxygen_saturation.value") \
    .filter(col("value").between(0, 100)) \
    .withColumn("Oxygen Saturation Ranges",
                when(col("value") >= 98, "Normal")
                .when(col("value") >= 95, "Insufficient")
                .when(col("value") >= 90, "Decreased")
                .when(col("value") >= 80, "Critical")
                .when(col("value") >= 70, "Severe hypoxia")
                .otherwise("Acute danger to life")) \
    .groupBy("Oxygen Saturation Ranges").count()

o2_df = o2_df.withColumn("sort", when(col("Oxygen Saturation Ranges") == "Normal", 1).when(col("Oxygen Saturation Ranges") == "Insufficient", 2).when(col("Oxygen Saturation Ranges") == "Decreased", 3).when(col("Oxygen Saturation Ranges") == "Critical", 4).when(col("Oxygen Saturation Ranges") == "Severe hypoxia", 5).when(col("Oxygen Saturation Ranges") == "Acute danger to life", 6)).orderBy("sort").select("Oxygen Saturation Ranges", "count")

o2_df.show()
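Range checks on a float need to be gap-free. With inclusive bounds like `between(95, 97)` and `between(98, 100)`, a reading such as 97.5 matches neither range and ends up with no label. The plain-Python sketch below shows gap-free bucketing with a table of descending lower bounds, using the labels from the cell above. `spo2_range` is a hypothetical helper.

```python
from typing import Optional

# (lower bound, label), checked from highest to lowest so every value in [0, 100] gets a label
SPO2_RANGES = [
    (98, "Normal"),
    (95, "Insufficient"),
    (90, "Decreased"),
    (80, "Critical"),
    (70, "Severe hypoxia"),
    (0, "Acute danger to life"),
]

def spo2_range(value: Optional[float]) -> Optional[str]:
    """Label an SpO2 percentage; None if it is missing or outside 0-100."""
    if value is None or not 0 <= value <= 100:
        return None
    for lower, label in SPO2_RANGES:
        if value >= lower:
            return label
    return None  # unreachable for values in [0, 100]

print(spo2_range(75.8))  # Severe hypoxia  (the sample note's reading)
print(spo2_range(97.5))  # Insufficient
```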

In [None]:
## Weight of COVID-19 patients: bucket each patient's current weight into common ranges and count patients per range.
## This describes the cohort's weight distribution; showing a relationship with COVID-19 would need a comparison group.

# Reload the Parquet file written in Part 1 (via Databricks' predefined `spark` session) instead of re-running the JSON conversion
notes_df = spark.read.parquet("dbfs:/FileStore/notes.parquet")

# Weights are floats, so inclusive upper bounds leave no gaps (e.g. 10.5 kg lands in "11 - 20");
# filtering on >= 0 first drops missing weights, which `otherwise` would otherwise label "121+"
weight_df = notes_df.select("vitals.current.weight.value").filter(col("value") >= 0) \
    .withColumn("Common Weight Ranges (kg)", when(col("value") <= 2.5, "0 - 2.5")
                .when(col("value") <= 10, "2.6 - 10").when(col("value") <= 20, "11 - 20")
                .when(col("value") <= 40, "21 - 40").when(col("value") <= 60, "41 - 60")
                .when(col("value") <= 80, "61 - 80").when(col("value") <= 100, "81 - 100")
                .when(col("value") <= 120, "101 - 120").otherwise("121+")) \
    .groupBy("Common Weight Ranges (kg)").count()

weight_df = weight_df.withColumn("sort", when(col("Common Weight Ranges (kg)") == "0 - 2.5", 1).when(col("Common Weight Ranges (kg)") == "2.6 - 10", 2).when(col("Common Weight Ranges (kg)") == "11 - 20", 3).when(col("Common Weight Ranges (kg)") == "21 - 40", 4).when(col("Common Weight Ranges (kg)") == "41 - 60", 5).when(col("Common Weight Ranges (kg)") == "61 - 80", 6).when(col("Common Weight Ranges (kg)") == "81 - 100", 7).when(col("Common Weight Ranges (kg)") == "101 - 120", 8).when(col("Common Weight Ranges (kg)") == "121+", 9)).orderBy("sort").select("Common Weight Ranges (kg)", "count")

display(weight_df)
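Weight buckets ignore height, which matters in a cohort that includes both children and adults. BMI (weight in kg divided by the square of height in m) normalizes for height and is already stored under `vitals.baseline.bmi`. As a worked check on the sample note's baseline: 81.2 / 1.55² ≈ 33.8 kg/m², which matches the recorded BMI. The sketch below implements that consistency check, which could flag LLM extraction errors. `bmi` and `bmi_matches` are hypothetical helpers. The 0.1 tolerance is an assumption chosen to absorb rounding to one decimal.

```python
def bmi(weight_kg: float, height_cm: float) -> float:
    """Body-mass index in kg/m^2, rounded to one decimal as in the notes."""
    if height_cm <= 0:
        raise ValueError("height must be positive")
    height_m = height_cm / 100
    return round(weight_kg / height_m ** 2, 1)

def bmi_matches(weight_kg: float, height_cm: float, recorded_bmi: float,
                tolerance: float = 0.1) -> bool:
    """True if the recorded BMI agrees with weight and height (assumed tolerance)."""
    return abs(bmi(weight_kg, height_cm) - recorded_bmi) <= tolerance

# Baseline values from the sample encounter note
print(bmi(81.2, 155.0))                # 33.8
print(bmi_matches(81.2, 155.0, 33.8))  # True
```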