# Healthcare Data Processing - Medical Coding and Extraction of ICD-10 Codes

# Import packages

In [None]:
# Import python packages
import streamlit as st
import pandas as pd

# We can also use Snowpark for our analyses!
from snowflake.snowpark.context import get_active_session
session = get_active_session()

# Add a query tag to the session.
session.query_tag = {"origin":"sf_sit-is", "name":"Healthcare_Code_Extraction", "version":{"major":1, "minor":0}}

# DATA ENGINEERING
* A file format custom_PDF is created, and an external stage REPORTS_DATA is defined, pointing to an S3 bucket. This will store PDFs to be analyzed.

In [None]:
create or replace file format custom_PDF;
-- Create external stage

CREATE or replace STAGE REPORTS_DATA
    URL='s3://sfquickstarts/sfguide_llm_assisted_medical_coding_extraction_for_healthcare_in_snowflake/'
    DIRECTORY = ( ENABLE = true )
    FILE_FORMAT = custom_PDF;

-- Inspect content of stage
LS @REPORTS_DATA;

### The `pdf_text_chunker` table function extracts text from the PDFs in REPORTS_DATA using PyPDF2 and splits it with langchain's RecursiveCharacterTextSplitter.
* Text is chunked into sections for easier processing, allowing overlap for context.

In [None]:

create or replace function pdf_text_chunker(file_url string)
returns table (chunk varchar)
language python
runtime_version = '3.9'
handler = 'pdf_text_chunker'
packages = ('snowflake-snowpark-python','PyPDF2', 'langchain')
as
$$
from snowflake.snowpark.types import StringType, StructField, StructType
from langchain.text_splitter import RecursiveCharacterTextSplitter
from snowflake.snowpark.files import SnowflakeFile
import PyPDF2, io
import logging
import pandas as pd
class pdf_text_chunker:
    """UDTF handler: read one staged PDF and emit its text as overlapping chunks.

    ``process`` is invoked once per input row with the file URL and yields
    one single-column row per chunk.
    """

    def read_pdf(self, file_url: str) -> str:
        """Return the concatenated text of every page of the PDF at *file_url*.

        Newlines and NUL characters in the page text are replaced by spaces.
        Pages whose text cannot be extracted are skipped with a warning, so a
        single bad page no longer discards the text of the pages before it.
        If extraction failed and no text at all was recovered, the marker
        string "Unable to Extract" is returned.
        """
        logger = logging.getLogger("udf_logger")
        logger.info(f"Opening file {file_url}")
        with SnowflakeFile.open(file_url, 'rb') as f:
            buffer = io.BytesIO(f.readall())
        reader = PyPDF2.PdfReader(buffer)

        page_texts = []
        failed_pages = 0
        for page_number, page in enumerate(reader.pages, start=1):
            try:
                # Guard against an extractor that returns None for empty pages.
                page_text = page.extract_text() or ""
            except Exception:
                # Best effort: keep going, but do not swallow SystemExit /
                # KeyboardInterrupt the way a bare ``except:`` would.
                failed_pages += 1
                logger.warning(
                    f"Unable to extract from file {file_url}, page {page_number}"
                )
                continue
            page_texts.append(page_text.replace('\n', ' ').replace('\0', ' '))

        text = "".join(page_texts)
        if failed_pages and not text:
            return "Unable to Extract"
        return text

    def process(self, file_url: str):
        """Yield ``(chunk,)`` tuples for the text of the PDF at *file_url*."""
        text = self.read_pdf(file_url)
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=4000,  # Adjust this as you see fit
            chunk_overlap=400,  # Overlap keeps neighbouring chunks contextual
            length_function=len,
        )
        # Each output row is a 1-tuple matching ``returns table (chunk varchar)``.
        for chunk in text_splitter.split_text(text):
            yield (chunk,)
$$;



In [None]:
select * from directory(@REPORTS_DATA);

In [None]:
create or replace TABLE DOCS_CHUNKS_TABLE (
    RELATIVE_PATH VARCHAR(16777216),
    SIZE NUMBER(38,0),
    FILE_URL VARCHAR(16777216),
    SCOPED_FILE_URL VARCHAR(16777216),
    CHUNK VARCHAR(16777216),
    CHUNK_VEC VECTOR(FLOAT, 768) ); 

In [None]:
CREATE OR REPLACE TABLE REPORTS_DATA_table (
  path string,
  fileurl string,
  SIZE NUMBER(38,0)
);

insert into REPORTS_DATA_table(path,FILEURL,size) select RELATIVE_PATH,FILE_URL,SIZE from directory(@REPORTS_DATA);


## The pdf_text_chunker function is invoked with the scoped file URL, and its output is stored in docs_chunks_table, with each chunk vectorized by snowflake.cortex.embed_text_768 (768-dimensional vector).


In [None]:
INSERT INTO docs_chunks_table (relative_path,SIZE, file_url, scoped_file_url, chunk, chunk_vec)
SELECT path,
SIZE,
       fileurl,
       BUILD_SCOPED_FILE_URL(@REPORTS_DATA, path) AS scoped_file_url, 
       func.chunk AS chunk,
      snowflake.cortex.embed_text_768('e5-base-v2', func.chunk) AS chunk_vec
FROM REPORTS_DATA_table,TABLE(pdf_text_chunker(BUILD_SCOPED_FILE_URL(@REPORTS_DATA, path))) AS func;


In [None]:
SELECT * FROM docs_chunks_table;

In [None]:
select relative_path, count(*) as num_chunks
    from docs_chunks_table
    group by relative_path
    order by num_chunks desc;

In [None]:
create or replace transient table docs_and_text as
select
    relative_path ,
    listagg(chunk ,' ') as doc_text,
    null as report,
    null as specialty
from docs_chunks_table
where relative_path like '%.pdf'
group by all;

In [None]:
SELECT * FROM docs_and_text LIMIT 5;

In [None]:
select count(*) from docs_and_text;

### Extract the specialty and the report summary

In [None]:
UPDATE docs_and_text as l
SET l.report = r.report
    ,l.specialty = r.specialty
FROM (
    select relative_path,
    SNOWFLAKE.CORTEX.Complete ('mixtral-8x7b', concat(doc_text||
        ' In less than 5 words, how would you best describe the type of the document content? 
        Do not provide explanation. Remove special characters')) as report,
    SNOWFLAKE.CORTEX.EXTRACT_ANSWER (doc_text,
        'What is the medical specialty?')[0]:answer::varchar as specialty,
    from docs_and_text) as r
where l.relative_path = r.relative_path;


In [None]:
SELECT * FROM docs_and_text LIMIT 5;

# DISTILLATION FLOW - USING LLAMA3.1-405B

#### The model llama3.1-405b is used to extract ICD-10 codes from medical documents by prompting Snowflake Cortex to identify relevant codes. Outputs are stored in a table called LLAMA_OUTPUT_ICD.

In [None]:
CREATE or replace table LLAMA_OUTPUT_ICD as
select relative_path,
doc_text,
report,
specialty,
SNOWFLAKE.CORTEX.COMPLETE('llama3.1-405b', concat(doc_text||'Given this medical transcript, list the unique major ICD10-CM diagnosis code in this format ONLY: X##.#. Please provide the response in the form of a python list. Do not provide explanation')) as AI_ICD10_Code
from docs_and_text;


In [None]:
select * from LLAMA_OUTPUT_ICD limit 5;

* Carry out a 70/15/15 train/validation/test split for model fine-tuning

In [None]:
-- Randomly assign every labelled document to train (70%), val (15%) or
-- test (15%), then materialize one table per split.
CREATE OR REPLACE TEMPORARY TABLE temp_split_table AS
WITH numbered_rows AS (
  SELECT *,
         ROW_NUMBER() OVER (ORDER BY RANDOM()) AS row_num,
         COUNT(*) OVER() AS total_rows
  FROM LLAMA_OUTPUT_ICD
)


SELECT *,
       -- Cumulative upper bounds: the CASE branches are evaluated in order,
       -- so a row exactly on the 70% boundary lands in 'train' instead of
       -- falling through to 'test'.
       CASE 
         WHEN row_num <= total_rows * 0.7 THEN 'train'
         WHEN row_num <= total_rows * 0.85 THEN 'val'
         ELSE 'test'
       END AS split
FROM numbered_rows;

-- Training set: used as the training query of the fine-tuning job.
CREATE or replace TABLE  codeextraction_training AS
SELECT relative_path,
doc_text,
report,
specialty,AI_ICD10_Code
FROM temp_split_table
WHERE split = 'train';

-- Held-out test set.
CREATE  or replace TABLE   codeextraction_test AS
SELECT relative_path,
doc_text,
report,
specialty,AI_ICD10_Code
FROM temp_split_table
WHERE split = 'test';

-- Validation set: used as the validation query of the fine-tuning job.
CREATE  or replace TABLE codeextraction_val AS
SELECT relative_path,
doc_text,
report,
specialty,AI_ICD10_Code
FROM temp_split_table
WHERE split = 'val';



In [None]:
select * from codeextraction_training limit 2;

# BASELINE OUTPUT FROM THE SMALLER MODEL llama3-8b (Optional)

In [None]:
create or replace table llama38b_ICDOutput as select relative_path,
doc_text,
report,
specialty,
AI_ICD10_Code,
SNOWFLAKE.CORTEX.COMPLETE('llama3-8b', concat(doc_text||'Given this medical transcript, list the unique major ICD10-CM diagnosis code in this format ONLY: X##.#. Please provide the response in the form of a python list. Do not provide explanation')) as llama38b_ICD10_Code
from LLAMA_OUTPUT_ICD;

In [None]:
SELECT * FROM llama38b_ICDOutput limit 5;

# FINE TUNE llama3-8b

Cortex Fine-tuning allows users to leverage parameter-efficient fine-tuning (PEFT) to create customized adaptors for use with pre-trained models on more specialized tasks. If you don’t want the high cost of training a large model from scratch but need better latency and results than you’re getting from prompt engineering or even retrieval augmented generation (RAG) methods, fine-tuning an existing large model is an option. Fine-tuning allows you to use examples to adjust the behavior of the model and improve the model’s knowledge of domain-specific tasks.

In [None]:
SELECT SNOWFLAKE.CORTEX.FINETUNE(
    'CREATE', 
    -- Custom model name, make sure name below is unique
    'FINETUNE_llama38b_ICDCODES',
    -- Base model name
    'llama3-8b',
    -- Training data query
    'SELECT doc_text || '' Given this medical transcript, list the unique major ICD10-CM diagnosis code in this format ONLY: X##.#. Please provide the response in the form of a python list. Do not provide explanation '' AS PROMPT, AI_ICD10_Code AS COMPLETION FROM codeextraction_training',
    -- Test data query 
    'SELECT doc_text || '' Given this medical transcript, list the unique major ICD10-CM diagnosis code in this format ONLY: X##.#. Please provide the response in the form of a python list. Do not provide explanation '' AS PROMPT, AI_ICD10_Code AS COMPLETION FROM codeextraction_val'
);


In [None]:
-- The output is the job ID of the fine-tuning job:
Select SNOWFLAKE.CORTEX.FINETUNE(
  'DESCRIBE',
'CortexFineTuningWorkflow_5f63e53e-0c49-4af7-93c5-82675b433629');--replace <> with the workflow id returned from the execution of last cell

# STOP -  PROCEED ONLY WHEN THE STATUS FIELD FOR THE JOB CHANGES TO SUCCESS IN THE PREVIOUS CELL

# INFERENCE USING FINE TUNED MODEL (AFTER ENSURING THE MODEL TUNING WAS SUCCESSFUL)

In [None]:
-- Run the fine-tuned model over the documents and store its ICD-10 output.
-- The instruction text (including the leading and trailing space) is kept
-- identical to the PROMPT used in the FINETUNE training/validation queries,
-- so inference sees the same prompt format the adapter was trained on.
-- NOTE(review): this scores the training split; codeextraction_test may be
-- the better source for an unbiased evaluation - confirm intent.
create or replace TABLE  llama38b_ICD_Codes as select relative_path,
doc_text,
report,
specialty,
SNOWFLAKE.CORTEX.COMPLETE('FINETUNE_llama38b_ICDCODES', doc_text || ' Given this medical transcript, list the unique major ICD10-CM diagnosis code in this format ONLY: X##.#. Please provide the response in the form of a python list. Do not provide explanation ') as FT_ICD10_Code
from codeextraction_training;

# Benefits of using the Fine Tuned Model:

## 1. Accuracy comparable to that of the Larger Model -> HIGH ACCURACY
## 2. Smaller Model -> LOWER COST
## 3. Higher Throughput -> Smaller Model -> HIGH THROUGHPUT

## Extract codes from Fine Tuned Model

In [None]:
llama38b_ICD_Code_FT_df=session.table('llama38b_ICD_Codes').to_pandas()
llama38b_ICD_Code_FT_df.head(50)

### SIMPLE INTERACTIVE APP FOR USING THE ICD CODE FOR CALCULATING PATIENT RISK SCORE

In [None]:
# Import python packages
import streamlit as st
from snowflake.snowpark.context import get_active_session
import pandas as pd
import re
import ast




st.title('❄️ Medical Coding Assistant ❄️')
st.subheader(
    """Calculate the risk score accurately by leveraging ICD10 Codes extracted by Fine Tuning a Llama3 with Cortex AI
    """
)

# Get the current credentials
session = get_active_session()



patient_data = {
    'patient_id': [1, 2, 3, 4, 5, 6, 7],
    'name': ['John Doe', 'Jane Smith', 'Alice Johnson', 'Bob Brown', 
             'Charlie Williams', 'Emily Davis', 'Frank Harris'],
    'age': [45, 62, 30, 50, 40, 55, 38],
    'icd_code': [
        'D66.0',
    'C11.9',
    'H04.89',
    'N62.9',
    'N63.9',
    'I25.5', 
    'Z95.1'
    ]
}


patient_df = pd.DataFrame(patient_data)
# Human-readable description for each ICD code shown in the app; codes that
# are missing here are rendered as 'Explanation not available'.
# NOTE(review): the descriptions for 'M21.9' and 'M25.5' do not look like
# ICD-10-CM chapter M (musculoskeletal) entries - verify against the official
# code set before relying on them.
icd_explanations = {
    'M21.9': 'Malignant neoplasm of the upper lobe of the bronchus.',
    'M25.5': 'Other specified diseases of the kidney.',
    'C11.9': 'unspecified malignant neoplasm of the nasopharynx',
    'H04.89': 'disorders of the lacrimal system',
    'N62.9': 'Hypertrophy of breast',
    'N63.9': 'unspecified lump in the breast',
    'I25.5': 'ischemic cardiomyopathy', 
    'Z95.1': 'presence of an aortocoronary bypass graft', 
    'Z95.2': 'presence of a prosthetic heart valve', 
    'D66.0':'hereditary factor VIII deficiency'
    
}
reports_df = session.table("llama38b_ICD_Codes")
reports_df =reports_df.to_pandas()

def display_icd_code_with_explanation(icd_code):
    """Return a display line pairing *icd_code* with its description.

    The description is looked up in ``icd_explanations``; codes without an
    entry are shown with the text 'Explanation not available'.
    """
    if icd_code in icd_explanations:
        explanation = icd_explanations[icd_code]
    else:
        explanation = 'Explanation not available'
    return f"ICD Code: {icd_code} - {explanation}"

def calculate_risk_score(icd_code):
    """Return the risk score for one ICD code or the mean score of several.

    Args:
        icd_code: A single ICD code string, or a list/tuple of code strings.

    Returns:
        The score of the code (``5.0`` for codes not in the table), the
        average score for a list/tuple of codes, ``5.0`` for an empty
        list/tuple, or ``None`` for any other input type.
    """
    default_score = 5.0  # Score assigned to codes missing from the table.
    risk_scores = {
        'M21.9': 8.0,
        'M25.5': 5.5,
        'C11.9': 6.8,
        'H04.89': 9.0,
        'N62.9': 8.4,
        'N63.9': 9.0,
        'I25.5': 8.6,
        'Z95.1': 8.9,
        'Z95.2': 9.5,
        'D66.0': 5.0,
    }

    # If it's a single code, return its score
    if isinstance(icd_code, str):
        return risk_scores.get(icd_code, default_score)

    if isinstance(icd_code, (list, tuple)):
        # An empty collection has no mean; fall back to the default instead
        # of raising ZeroDivisionError.
        if not icd_code:
            return default_score
        scores = [risk_scores.get(code, default_score) for code in icd_code]
        return sum(scores) / len(scores)  # Average score

    # Unsupported input (e.g. None / NaN): no score can be computed.
    return None
    

# Pre-compute a risk score for every patient row from its ICD code(s).
patient_df['risk_score'] = patient_df['icd_code'].apply(calculate_risk_score)



# Select a patient by name
patient_name = st.selectbox("Select Patient", patient_df['name'])

# Get the selected patient's data (first row whose name matches the selection)
selected_patient = patient_df[patient_df['name'] == patient_name].iloc[0]
# NOTE: patient_id and patient_risk_score are not referenced again in the
# visible part of this cell; the score shown below is recomputed per code.
patient_id = selected_patient['patient_id']
patient_icd_code = selected_patient['icd_code']
patient_risk_score = selected_patient['risk_score']
icd_info_list = []
risk_scores = []

# Check if the icd_code is a list (multiple conditions); collect one display
# line and one risk score per code
if isinstance(patient_icd_code, list):
    for icd in patient_icd_code:
        icd_info = display_icd_code_with_explanation(icd)
        icd_info_list.append(icd_info)
        risk_scores.append(calculate_risk_score(icd))
else:
    icd_info = display_icd_code_with_explanation(patient_icd_code)
    icd_info_list.append(icd_info)
    risk_scores.append(calculate_risk_score(patient_icd_code))

# Join ICD information and calculate average risk score
icd_info_str = ", ".join(icd_info_list)
# NOTE(review): an empty list of codes would leave risk_scores empty and make
# this division fail; the sample data above only contains single string codes.
average_risk_score = sum(risk_scores) / len(risk_scores)
st.write("---")
# Show the code description(s) as a heading, then the averaged score
st.subheader(f"{icd_info_str}")

st.write(f"**Average Risk Score:** {average_risk_score:.2f}")

st.write("---")

# Associated Medical Reports section
def extract_icd_codes(icd_code_str):
    """Parse a model completion and return the ICD codes it contains.

    The text after the last "codes:" marker (or the whole text when the
    marker is absent) is parsed as a Python literal. If that fails, the first
    bracketed ``[...]`` span in the text is tried, so completions that wrap
    the list in extra words are still understood.

    Args:
        icd_code_str: The raw completion text; non-string values (for
            example ``None`` or NaN from a NULL column) are tolerated.

    Returns:
        A list of code strings with surrounding whitespace and single quotes
        removed. Returns ``[]`` when nothing parseable is found.
    """
    if not isinstance(icd_code_str, str):
        return []

    icd_codes_part = icd_code_str.split("codes:")[-1].strip()

    # Try the whole remainder first, then fall back to a bracketed list
    # embedded in surrounding prose.
    candidates = [icd_codes_part]
    bracketed = re.search(r"\[.*?\]", icd_codes_part, re.DOTALL)
    if bracketed is not None:
        candidates.append(bracketed.group(0))

    for candidate in candidates:
        try:
            parsed = ast.literal_eval(candidate)
        except (SyntaxError, ValueError, TypeError, MemoryError, RecursionError):
            continue
        break
    else:
        return []

    # A bare string literal is a single code, not a sequence of characters.
    if isinstance(parsed, str):
        parsed = [parsed]
    if not isinstance(parsed, (list, tuple, set)):
        return []

    # Skip non-string items instead of failing on ``.strip``.
    return [code.strip().strip("'") for code in parsed if isinstance(code, str)]

# Extract ICD codes from all reports (NULL completions are skipped so a
# missing model answer cannot break parsing)
all_icd_codes = []
for icd_code_str in reports_df['FT_ICD10_CODE'].dropna():
    all_icd_codes.extend(extract_icd_codes(icd_code_str))

# Get unique ICD codes in reports
unique_icd_codes = list(set(all_icd_codes))
st.subheader("Associated Medical Reports")

# Completion text with NULLs replaced by '' so every row yields a plain
# True/False when matched; the index still lines up with reports_df.
report_code_text = reports_df['FT_ICD10_CODE'].fillna('').astype(str)

# Check if patient_icd_code is a list and filter accordingly
if isinstance(patient_icd_code, list):
    st.write(f"Patient ICD Codes: {patient_icd_code}")

    # Retrieve matching codes by checking if any patient ICD code exists in the reports
    matching_codes = set(patient_icd_code) & set(unique_icd_codes)

    # Filter reports using the matching codes
    filtered_reports = reports_df[
        report_code_text.apply(lambda x: any(code in x for code in matching_codes))
    ]
else:
    st.write(f"Patient ICD Code: {patient_icd_code}")
    # regex=False: the '.' in an ICD code must match literally, not as the
    # regular-expression wildcard.
    filtered_reports = reports_df[
        report_code_text.str.contains(patient_icd_code, regex=False)
    ]

# Render the matching reports inside a collapsible expander; when nothing
# matched, show an expander whose label says so.
if not filtered_reports.empty:
    with st.expander("View Reports"):
        st.write(f"Found {len(filtered_reports)} associated reports:")
        # One group of fields per matching report row, followed by a divider.
        for idx, report in filtered_reports.iterrows():
            st.write(f"**Report Name:** {report['RELATIVE_PATH']}")
            st.write(f"**Report Description:** {report['REPORT']}")
            st.write(f"**Speciality:** {report['SPECIALTY']}")
            st.write(f"**Extracted Text:** {report['DOC_TEXT']}")
            st.write("---")
else:
    with st.expander("No Reports Found"):
        st.write("No associated reports found for this ICD code(s).")


# END OF NOTEBOOK