In [15]:
! pip install snowflake-connector-python

Python(53068) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.




In [16]:
import snowflake.connector
from datetime import datetime
from huggingface_hub import InferenceClient

In [17]:
# Snowflake Configuration
SNOWFLAKE_CONFIG = {
    "user": "DOLPHIN",
    "password": "Maapaa@1603",
    "account": "URB63596",
    "warehouse": "ANIMAL_TASK_WH",
    "database": "mimic_iv_medi_assist",
    "schema": "PROD_MIMIC",
}

In [18]:
# SQL Queries
UNPROCESSED_RISK_NOTES_QUERY = """
SELECT DIS_RECORD_ID, DIS_NOTE_TEXT
FROM MIMIC_IV_MEDI_ASSIST.PROD_MIMIC.DIM_DISCHARGE
WHERE DIS_NOTE_TEXT IS NOT NULL
  AND DIS_RECORD_ID NOT IN (
      SELECT DISTINCT DIS_RECORD_ID
      FROM MIMIC_IV_MEDI_ASSIST.PROD_MIMIC.FCT_RISK_STRATIFICATION
  )
"""

FETCH_DRG_QUERY = """
SELECT DRG_RECORD_ID, DRG_CODE, DRG_DESCRIPTION, DRG_SEVERITY, DRG_MORTALITY
FROM MIMIC_IV_MEDI_ASSIST.PROD_MIMIC.DIM_DRGCODES
WHERE DRG_CODE = %s
ORDER BY DRG_SEVERITY DESC, DRG_MORTALITY DESC
LIMIT 1
"""

INSERT_RISK_STRATIFICATION_QUERY = """
INSERT INTO MIMIC_IV_MEDI_ASSIST.PROD_MIMIC.FCT_RISK_STRATIFICATION
(FRS_RISK_LEVEL, DIS_RECORD_ID, DRG_RECORD_ID)
VALUES (%s, %s, %s)
"""

In [19]:
# Helper: Connect to Snowflake
def connect_to_snowflake():
    """Establish a connection to Snowflake."""
    print("Connecting to Snowflake...")
    return snowflake.connector.connect(**SNOWFLAKE_CONFIG)

In [20]:
# Helper: Fetch DRG_RECORD_ID
def get_drg_record(cursor, drg_code):
    """
    Fetch the DRG details for a given DRG_CODE from the DIM_DRGCODES table.

    Args:
        cursor (object): Snowflake cursor object.
        drg_code (int): The DRG code to look up.

    Returns:
        dict or None: A dictionary with DRG details, or None if not found.
    """
    try:
        cursor.execute(FETCH_DRG_QUERY, (drg_code,))
        result = cursor.fetchone()
        if result:
            return {
                "DRG_RECORD_ID": result[0],
                "DRG_CODE": result[1],
                "DRG_DESCRIPTION": result[2],
                "DRG_SEVERITY": result[3],
                "DRG_MORTALITY": result[4],
            }
        return None
    except Exception as e:
        print(f"Error fetching DRG details for DRG_CODE {drg_code}: {e}")
        return None

In [21]:
# Helper: Call LLM for Risk Stratification
def call_llm_for_risk_stratification(clinical_note):
    """
    Call an LLM to generate risk stratification from a clinical note.

    Args:
        clinical_note (str): Clinical text to analyze for risk stratification.

    Returns:
        dict: A dictionary containing the risk level, risk category, and DRG code.
    """
    print("Calling LLM for risk stratification...")

    # Initialize the LLM client
    client = InferenceClient(api_key="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")

    # Define the prompt for risk stratification
    prompt_risk_stratification = f"""
        You are a medical assistant specializing in risk stratification.
        Based on the following clinical note, generate:
        - A numerical risk level (0-100), where 0 is the lowest risk and 100 is the highest.
        - A risk category (e.g., "Low", "Moderate", "High", "Critical") based on the score.
        - The DRG code most relevant to this patient's condition.

        Clinical Note:
        {clinical_note}

        Return the output in this exact JSON format:
        {{
            "RiskLevel": [Numerical Risk Level],
            "RiskCategory": "[Category]",
            "DRGCode": [Numerical DRG Code]
        }}
    """

     # LLM API Call
    messages = [{"role": "user", "content": prompt_risk_stratification}]
    completion = client.chat.completions.create(
        model="meta-llama/Llama-3.3-70B-Instruct",
        messages=messages,
        max_tokens=100  # Enough tokens for structured response
    )

    # Parse and return the response
    response = completion.choices[0].message["content"]
    try:
        return eval(response)  # Convert the JSON-like string response to a Python dictionary
    except Exception as e:
        print(f"Error parsing risk stratification response: {e}")
        return {"RiskLevel": None, "RiskCategory": "Unknown", "DRGCode": None}

In [22]:
# Main Workflow
def process_risk_stratification():
    """Process discharge notes for risk stratification and load results into Snowflake."""
    try:
        # Connect to Snowflake
        conn = connect_to_snowflake()
        cursor = conn.cursor()

        print("Fetching discharge notes for risk stratification...")
        cursor.execute(UNPROCESSED_RISK_NOTES_QUERY)
        records = cursor.fetchall()

        print(f"Fetched {len(records)} records for processing.")

        for record_id, clinical_note in records:
            print(f"Processing Record ID: {record_id}")

            # Generate risk stratification details
            risk_data = call_llm_for_risk_stratification(clinical_note)
            risk_level = risk_data.get("RiskLevel")
            drg_code = risk_data.get("DRGCode")

            # Fetch the DRG details for the generated DRG code
            drg_details = get_drg_record(cursor, drg_code)
            drg_record_id = drg_details["DRG_RECORD_ID"] if drg_details else None

            if drg_record_id:
                print(f"Generated Risk Data: Level={risk_level}, DRG_RECORD_ID={drg_record_id}")

                # Insert risk stratification results into Snowflake
                try:
                    cursor.execute(
                        INSERT_RISK_STRATIFICATION_QUERY,
                        (risk_level, record_id, drg_record_id)
                    )
                    conn.commit()
                    print(f"Inserted risk data for Record ID: {record_id}")
                except Exception as e:
                    print(f"Error inserting risk data for Record ID {record_id}: {e}")
            else:
                print(f"No DRG_RECORD_ID found for DRGCode: {drg_code}. Skipping...")

        # Close connections
        cursor.close()
        conn.close()
        print("Processing completed. Connection closed.")

    except Exception as e:
        print(f"Error processing risk stratification: {e}")


In [23]:
# Run the Workflow
if __name__ == "__main__":
    print("Starting discharge note processing for risk stratification...")
    process_risk_stratification()
    print("Discharge note processing for risk stratification finished.")

Starting discharge note processing for risk stratification...
Connecting to Snowflake...
Fetching discharge notes for risk stratification...
Fetched 331793 records for processing.
Processing Record ID: 18950
Calling LLM for risk stratification...
Error parsing risk stratification response: invalid syntax (<string>, line 1)
No DRG_RECORD_ID found for DRGCode: None. Skipping...
Processing Record ID: 22027
Calling LLM for risk stratification...
Error parsing risk stratification response: invalid syntax (<string>, line 1)
No DRG_RECORD_ID found for DRGCode: None. Skipping...
Processing Record ID: 22096
Calling LLM for risk stratification...
Error parsing risk stratification response: invalid syntax (<string>, line 1)
No DRG_RECORD_ID found for DRGCode: None. Skipping...
Processing Record ID: 20910
Calling LLM for risk stratification...
Error parsing risk stratification response: invalid syntax (<string>, line 1)
No DRG_RECORD_ID found for DRGCode: None. Skipping...
Processing Record ID: 18

KeyboardInterrupt: 