# Healthcare Data Processing: Medical Coding and Extraction of ICD-10 Codes

# Import packages

In [None]:
# Import python packages
import streamlit as st
import pandas as pd

# We can also use Snowpark for our analyses!
from snowflake.snowpark.context import get_active_session
# Obtain the currently active Snowpark session.
session = get_active_session()

# Tag the session's queries so they can be attributed to this notebook.
session.query_tag = {
    "origin": "sf_sit-is",
    "name": "Healthcare_Code_Extraction",
    "version": {"major": 1, "minor": 0},
}

# DATA ENGINEERING
* A file format custom_PDF is created, and an external stage REPORTS_DATA is defined, pointing to an S3 bucket. This will store PDFs to be analyzed.

In [None]:
-- File format referenced by the stage below; no TYPE or options are specified.
CREATE OR REPLACE FILE FORMAT custom_PDF;

-- External stage over the S3 location holding the PDFs, with a directory
-- table enabled so the files can be listed via DIRECTORY(@REPORTS_DATA).
CREATE OR REPLACE STAGE REPORTS_DATA
    URL = 's3://sfquickstarts/sfguide_llm_assisted_medical_coding_extraction_for_healthcare_in_snowflake/'
    DIRECTORY = (ENABLE = TRUE)
    FILE_FORMAT = custom_PDF;

-- Inspect content of stage
LIST @REPORTS_DATA;

### A pdf_text_chunker extracts text from PDFs in REPORTS_DATA using PyPDF2 and langchain's RecursiveCharacterTextSplitter.
* Text is chunked into sections for easier processing, allowing overlap for context.

In [None]:

create or replace function pdf_text_chunker(file_url string)
returns table (chunk varchar)
language python
runtime_version = '3.9'
handler = 'pdf_text_chunker'
packages = ('snowflake-snowpark-python','PyPDF2', 'langchain')
as
$$
from snowflake.snowpark.types import StringType, StructField, StructType
from langchain.text_splitter import RecursiveCharacterTextSplitter
from snowflake.snowpark.files import SnowflakeFile
import PyPDF2, io
import logging
import pandas as pd
class pdf_text_chunker:
    """Table-function handler: read a staged PDF and emit its text in chunks.

    ``process`` yields one 1-tuple per chunk, matching the single ``chunk``
    output column declared by the enclosing ``create function`` statement.
    """

    # Returned when at least one page failed and no text at all was recovered.
    EXTRACTION_FAILED = "Unable to Extract"

    @staticmethod
    def _clean_page_text(raw_text):
        """Flatten one page's text: newlines and NUL characters become spaces.

        ``None`` or an empty string (a page without a text layer) yields ``""``
        instead of raising.
        """
        if not raw_text:
            return ""
        return raw_text.replace('\n', ' ').replace('\0', ' ')

    def read_pdf(self, file_url: str) -> str:
        """Return the concatenated text of every readable page of the PDF.

        Args:
            file_url: URL of the staged file, opened through ``SnowflakeFile``.

        Returns:
            The cleaned text of all pages, concatenated in page order. A page
            that cannot be parsed is logged and skipped, so text from the
            other pages is kept. If pages failed and nothing was recovered,
            ``EXTRACTION_FAILED`` is returned.
        """
        logger = logging.getLogger("udf_logger")
        logger.info("Opening file %s", file_url)
        with SnowflakeFile.open(file_url, 'rb') as f:
            buffer = io.BytesIO(f.readall())
        reader = PyPDF2.PdfReader(buffer)

        page_texts = []
        failed_pages = 0
        for page_number, page in enumerate(reader.pages, start=1):
            try:
                page_texts.append(self._clean_page_text(page.extract_text()))
            except Exception:
                # Best effort: one unreadable page must not discard the text
                # already extracted from the other pages.
                failed_pages += 1
                logger.warning(
                    "Unable to extract from file %s, page %s",
                    file_url,
                    page_number,
                    exc_info=True,
                )

        text = "".join(page_texts)
        if failed_pages and not text:
            return self.EXTRACTION_FAILED
        return text

    def process(self, file_url: str):
        """Yield ``(chunk,)`` tuples for the text of the PDF at ``file_url``.

        Args:
            file_url: URL of the staged PDF to read.

        Yields:
            One 1-tuple per chunk produced by ``RecursiveCharacterTextSplitter``.
        """
        text = self.read_pdf(file_url)
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=4000,  # Adjust this as you see fit
            chunk_overlap=400,  # Overlap keeps some shared context between neighbouring chunks
            length_function=len,
        )
        for chunk in text_splitter.split_text(text):
            yield (chunk,)
$$;



In [None]:
-- List the staged files through the stage's directory table.
SELECT *
FROM DIRECTORY(@REPORTS_DATA);

In [None]:
-- One row per text chunk: source-file metadata, the chunk text and a
-- 768-element float vector for the chunk.
CREATE OR REPLACE TABLE DOCS_CHUNKS_TABLE (
    RELATIVE_PATH   VARCHAR(16777216),
    SIZE            NUMBER(38,0),
    FILE_URL        VARCHAR(16777216),
    SCOPED_FILE_URL VARCHAR(16777216),
    CHUNK           VARCHAR(16777216),
    CHUNK_VEC       VECTOR(FLOAT, 768)
);

In [None]:
-- Table holding the path, URL and size of each file listed in the stage.
CREATE OR REPLACE TABLE REPORTS_DATA_table (
    path    STRING,
    fileurl STRING,
    size    NUMBER(38,0)
);

-- Populate it from the stage's directory table.
INSERT INTO REPORTS_DATA_table (path, fileurl, size)
SELECT
    RELATIVE_PATH,
    FILE_URL,
    SIZE
FROM DIRECTORY(@REPORTS_DATA);


## The pdf_text_chunker function is invoked using the scoped file URL, and the output is stored in docs_chunks_table, with each chunk vectorized by snowflake.cortex.embed_text_768 (768-dimensional vector).


In [None]:
-- Chunk every listed file with pdf_text_chunker and store each chunk together
-- with its embedding and the originating file's metadata.
INSERT INTO docs_chunks_table (relative_path, size, file_url, scoped_file_url, chunk, chunk_vec)
SELECT
    src.path,
    src.size,
    src.fileurl,
    BUILD_SCOPED_FILE_URL(@REPORTS_DATA, src.path) AS scoped_file_url,
    func.chunk AS chunk,
    SNOWFLAKE.CORTEX.EMBED_TEXT_768('e5-base-v2', func.chunk) AS chunk_vec
FROM REPORTS_DATA_table AS src,
     TABLE(pdf_text_chunker(BUILD_SCOPED_FILE_URL(@REPORTS_DATA, src.path))) AS func;


In [None]:
-- Inspect the stored chunks and their vectors.
SELECT *
FROM docs_chunks_table;

In [None]:
-- Number of chunks produced per file, largest first.
SELECT
    relative_path,
    COUNT(*) AS num_chunks
FROM docs_chunks_table
GROUP BY relative_path
ORDER BY num_chunks DESC;

In [None]:
-- Collapse the chunks of each PDF back into a single text column and add
-- empty report / specialty columns.
-- NOTE(review): LISTAGG has no WITHIN GROUP (ORDER BY ...), so the order in
-- which chunks are concatenated is not pinned down here - confirm this is acceptable.
CREATE OR REPLACE TRANSIENT TABLE docs_and_text AS
SELECT
    relative_path,
    LISTAGG(chunk, ' ') AS doc_text,
    NULL AS report,
    NULL AS specialty
FROM docs_chunks_table
WHERE relative_path LIKE '%.pdf'
GROUP BY ALL;

In [None]:
-- Preview a few rows of the per-document text table.
SELECT *
FROM docs_and_text
LIMIT 5;

In [None]:
-- Number of documents in the per-document text table.
SELECT COUNT(*)
FROM docs_and_text;

### Extract the specialty and the report summary

In [None]:
-- Fill in report (a short LLM description of the document) and specialty
-- (an extracted answer) for every row, joining back on relative_path.
UPDATE docs_and_text AS l
SET
    l.report = r.report,
    l.specialty = r.specialty
FROM (
    SELECT
        relative_path,
        SNOWFLAKE.CORTEX.COMPLETE(
            'mixtral-8x7b',
            doc_text || ' In less than 5 words, how would you best describe the type of the document content? 
        Do not provide explanation. Remove special characters'
        ) AS report,
        SNOWFLAKE.CORTEX.EXTRACT_ANSWER(
            doc_text,
            'What is the medical specialty?'
        )[0]:answer::varchar AS specialty
    FROM docs_and_text
) AS r
WHERE l.relative_path = r.relative_path;


In [None]:
-- Preview a few rows, including the report and specialty columns.
SELECT *
FROM docs_and_text
LIMIT 5;

# DISTILLATION FLOW - USING LLAMA3.1-405B

#### The model llama3.1-405b is used to extract ICD-10 codes from medical documents by prompting Snowflake Cortex to identify relevant codes. Outputs are stored in a table called LLAMA_OUTPUT_ICD.

In [None]:
-- Ask llama3.1-405b for the ICD-10 codes of every document and keep the raw
-- completion alongside the document columns.
CREATE OR REPLACE TABLE LLAMA_OUTPUT_ICD AS
SELECT
    relative_path,
    doc_text,
    report,
    specialty,
    SNOWFLAKE.CORTEX.COMPLETE(
        'llama3.1-405b',
        doc_text || 'Given this medical transcript, list the unique major ICD10-CM diagnosis code in this format ONLY: X##.#. Please provide the response in the form of a python list. Do not provide explanation'
    ) AS AI_ICD10_Code
FROM docs_and_text;


In [None]:
-- Preview a few rows of the llama3.1-405b output.
SELECT *
FROM LLAMA_OUTPUT_ICD
LIMIT 5;

* Carry out a 70:15:15 train/validation/test split for model fine-tuning

In [None]:
-- Shuffle the labelled documents and tag each row as 'train', 'val' or 'test'
-- (roughly 70% / 15% / 15% of the rows).
CREATE OR REPLACE TEMPORARY TABLE temp_split_table AS
WITH numbered_rows AS (
  SELECT *,
         -- Random ordering: the assignment changes on every run.
         ROW_NUMBER() OVER (ORDER BY RANDOM()) AS row_num,
         COUNT(*) OVER() AS total_rows
  FROM LLAMA_OUTPUT_ICD
)
SELECT *,
       CASE
         -- Cumulative, inclusive upper bounds. The previous "< 0.7" / "> 0.7"
         -- pair matched neither branch when row_num = total_rows * 0.7, so
         -- that row silently fell through to 'test'.
         WHEN row_num <= total_rows * 0.7  THEN 'train'
         WHEN row_num <= total_rows * 0.85 THEN 'val'
         ELSE 'test'
       END AS split
FROM numbered_rows;

-- Materialise each split of temp_split_table as its own table.

-- Rows tagged 'train'.
CREATE OR REPLACE TABLE codeextraction_training AS
SELECT
    relative_path,
    doc_text,
    report,
    specialty,
    AI_ICD10_Code
FROM temp_split_table
WHERE split = 'train';

-- Rows tagged 'test'.
CREATE OR REPLACE TABLE codeextraction_test AS
SELECT
    relative_path,
    doc_text,
    report,
    specialty,
    AI_ICD10_Code
FROM temp_split_table
WHERE split = 'test';

-- Rows tagged 'val'.
CREATE OR REPLACE TABLE codeextraction_val AS
SELECT
    relative_path,
    doc_text,
    report,
    specialty,
    AI_ICD10_Code
FROM temp_split_table
WHERE split = 'val';



In [None]:
-- Preview two rows of the training split.
SELECT *
FROM codeextraction_training
LIMIT 2;

# BASELINE OUTPUT FROM THE SMALLER MODEL llama3-8b (Optional)

In [None]:
-- Baseline: run the un-tuned llama3-8b model with the same prompt and store
-- its answer next to the llama3.1-405b answer (AI_ICD10_Code).
CREATE OR REPLACE TABLE llama38b_ICDOutput AS
SELECT
    relative_path,
    doc_text,
    report,
    specialty,
    AI_ICD10_Code,
    SNOWFLAKE.CORTEX.COMPLETE(
        'llama3-8b',
        doc_text || 'Given this medical transcript, list the unique major ICD10-CM diagnosis code in this format ONLY: X##.#. Please provide the response in the form of a python list. Do not provide explanation'
    ) AS llama38b_ICD10_Code
FROM LLAMA_OUTPUT_ICD;

In [None]:
-- Preview a few rows of the baseline llama3-8b output.
SELECT *
FROM llama38b_ICDOutput
LIMIT 5;

# FINE TUNE llama3-8b

Cortex Fine-tuning allows users to leverage parameter-efficient fine-tuning (PEFT) to create customized adaptors for use with pre-trained models on more specialized tasks. If you don’t want the high cost of training a large model from scratch but need better latency and results than you’re getting from prompt engineering or even retrieval augmented generation (RAG) methods, fine-tuning an existing large model is an option. Fine-tuning allows you to use examples to adjust the behavior of the model and improve the model’s knowledge of domain-specific tasks.

In [None]:
-- Start a Cortex fine-tuning job for llama3-8b, using the llama3.1-405b
-- answers (AI_ICD10_Code) as the target completions.
SELECT SNOWFLAKE.CORTEX.FINETUNE(
    'CREATE', 
    -- Custom model name; make sure the name below is unique
    'FINETUNE_llama38b_ICDCODES',
    -- Base model name
    'llama3-8b',
    -- Training data query: must return PROMPT and COMPLETION columns
    'SELECT doc_text || '' Given this medical transcript, list the unique major ICD10-CM diagnosis code in this format ONLY: X##.#. Please provide the response in the form of a python list. Do not provide explanation '' AS PROMPT, AI_ICD10_Code AS COMPLETION FROM codeextraction_training',
    -- Validation data query (reads the 'val' split, not the 'test' split)
    'SELECT doc_text || '' Given this medical transcript, list the unique major ICD10-CM diagnosis code in this format ONLY: X##.#. Please provide the response in the form of a python list. Do not provide explanation '' AS PROMPT, AI_ICD10_Code AS COMPLETION FROM codeextraction_val'
);


In [None]:
-- The output is the job ID of the fine-tuning job:
-- Check the status of the fine-tuning job.
-- Replace the ID below with the job ID returned by the FINETUNE('CREATE', ...) call.
SELECT SNOWFLAKE.CORTEX.FINETUNE(
    'DESCRIBE',
    'CortexFineTuningWorkflow_5f63e53e-0c49-4af7-93c5-82675b433629'
);

# STOP -  PROCEED ONLY WHEN THE STATUS FIELD FOR THE JOB CHANGES TO SUCCESS IN THE PREVIOUS CELL

# INFERENCE USING FINE TUNED MODEL (AFTER ENSURING THE MODEL TUNING WAS SUCCESSFUL)

In [None]:
-- Run the fine-tuned model and store its answer in FT_ICD10_Code.
--
-- The prompt suffix is exactly the one used to build the fine-tuning PROMPT
-- column (leading and trailing space included), so inference sees the same
-- prompt format the model was tuned on.
--
-- Inference reads the held-out codeextraction_test split rather than
-- codeextraction_training, so the output is not produced on rows the model
-- was tuned on.
CREATE OR REPLACE TABLE llama38b_ICD_Codes AS
SELECT
    relative_path,
    doc_text,
    report,
    specialty,
    SNOWFLAKE.CORTEX.COMPLETE(
        'FINETUNE_llama38b_ICDCODES',
        doc_text || ' Given this medical transcript, list the unique major ICD10-CM diagnosis code in this format ONLY: X##.#. Please provide the response in the form of a python list. Do not provide explanation '
    ) AS FT_ICD10_Code
FROM codeextraction_test;

# Benefits of using the Fine Tuned Model:

## 1. Accuracy comparable to that of the Larger Model -> HIGH ACCURACY
## 2. Smaller Model -> LOWER COST
## 3. Higher Throughput -> Smaller Model -> HIGH THROUGHPUT

## Extract codes from Fine Tuned Model

In [None]:
# Load the fine-tuned model's output table into pandas and preview up to 50 rows.
ft_output_table = session.table('llama38b_ICD_Codes')
llama38b_ICD_Code_FT_df = ft_output_table.to_pandas()
llama38b_ICD_Code_FT_df.head(50)

# END OF NOTEBOOK