# CPT Code Extractor

In [1]:


import os
import gradio as gr
from langchain.document_loaders import PyPDFDirectoryLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain import PromptTemplate
from langchain.chains import RetrievalQA
from langchain.vectorstores import FAISS
from langchain.chat_models import AzureChatOpenAI
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.chains.llm import LLMChain
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from reportlab.lib.pagesizes import letter
from reportlab.platypus import SimpleDocTemplate, Paragraph, Spacer
from reportlab.lib.styles import getSampleStyleSheet
import logging
import openai

# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Constants
# The Azure OpenAI endpoint and key are read from the environment so that
# credentials do not have to be written into this file. When a variable is
# unset the value falls back to the empty string, exactly as before.
AZURE_OPENAI_API_BASE = os.environ.get("AZURE_OPENAI_API_BASE", "")
AZURE_OPENAI_API_KEY = os.environ.get("AZURE_OPENAI_API_KEY", "")
VECTORSTORE_PATH = "vectorstore"  # Directory the FAISS index is saved to / loaded from
faiss_index_created = False  # Flag to track FAISS index creation (set by create_faiss_index)

# Initialize LLM and Embeddings
# Chat model used to answer queries; temperature=0 is passed to the model.
llm = AzureChatOpenAI(
    deployment_name="300-turbo",
    temperature=0,
    model_name="gpt-35-turbo-16k",
    openai_api_base=AZURE_OPENAI_API_BASE,
    openai_api_version="2023-03-15-preview",
    openai_api_type='azure',
    openai_api_key=AZURE_OPENAI_API_KEY
)

# Embedding model used both when building the FAISS index and when loading it.
embeddings = OpenAIEmbeddings(
    deployment='350-embedding',
    openai_api_base=AZURE_OPENAI_API_BASE,
    openai_api_version="2023-05-15",
    openai_api_type='azure',
    openai_api_key=AZURE_OPENAI_API_KEY,
    chunk_size=1
)

# Cell 2: Prompt Template
# Template text for the LLM chain built in get_result(). It has two
# placeholders:
#   {context}  - filled by the StuffDocumentsChain (document_variable_name="context")
#                with the documents returned by the FAISS retriever.
#   {question} - expected to carry the query chunk passed to the RetrievalQA chain.
# NOTE(review): the template labels {context} as "Medical History of Patient",
# while the user's text appears to arrive through {question} - confirm this
# labelling is intended.
prompt_template = """
Role: You are an expert in understanding medical terms and fetching the appropriate CPT codes. Your task is to understand the medical history provided below and identify the relevant CPT codes (Current Procedural Terminology codes) that describe the medical, surgical, and diagnostic services mentioned.

General Instruction: 
1. If you encounter any irrelevant text, provide a relevant query or state "Irrelevant text found".
2. Ensure that the CPT codes you select are the most specific and accurate for the procedures described.
3. Provide a brief explanation for each selected CPT code to justify its relevance.

Output Format:
CPT Codes:
[CPT code]: [CPT code description]
[CPT code]: [CPT code description]

Explanation:
[Short explanation of why each CPT code was selected]

Medical History of Patient:
{context}

Question: {question}

"""

# Cell 3: Function to Create FAISS Index
def create_faiss_index(pdf_directory):
    """Build a FAISS index from the PDF files in a directory and save it to disk.

    The PDFs are loaded, split into overlapping text chunks, embedded with the
    module-level ``embeddings`` object and stored under ``VECTORSTORE_PATH``.
    On success the module-level ``faiss_index_created`` flag is set to True.

    Args:
        pdf_directory: Path of the directory holding the PDF files. ``None``
            and surrounding whitespace are tolerated and treated as missing.

    Returns:
        A human-readable status string: a success message, or a message
        starting with "Error:" / "An error occurred" describing the failure.
        The function does not raise for failures inside the indexing step.
    """
    global faiss_index_created

    # The value comes straight from a UI textbox: normalise None and stray
    # whitespace so "   " is reported as a missing path rather than a bad one.
    pdf_directory = (pdf_directory or "").strip()
    if not pdf_directory:
        logging.error("Directory path is missing")
        return "Error: Please provide a directory path."

    if not os.path.isdir(pdf_directory):
        logging.error("The provided path is not a directory")
        return "Error: The provided path is not a directory."

    # Only the presence of at least one PDF matters here; the loader below
    # does its own file discovery.
    has_pdf = any(name.lower().endswith('.pdf') for name in os.listdir(pdf_directory))
    if not has_pdf:
        logging.error("The directory does not contain any PDF files")
        return "Error: The directory does not contain any PDF files."

    try:
        logging.info("Loading PDF documents")
        loader = PyPDFDirectoryLoader(pdf_directory)
        docs = loader.load()

        logging.info("Splitting documents into chunks")
        splitter = RecursiveCharacterTextSplitter(chunk_size=3500, chunk_overlap=100)
        chunks = splitter.split_documents(docs)

        # Image-only or empty PDFs yield no text; report that explicitly
        # instead of handing an empty document list to the vector store.
        if not chunks:
            logging.error("No extractable text found in the PDF files")
            return "Error: No extractable text found in the PDF files."

        logging.info("Initializing and saving vectorstore")
        vectorstore = FAISS.from_documents(documents=chunks, embedding=embeddings)
        vectorstore.save_local(VECTORSTORE_PATH)

        faiss_index_created = True
        logging.info("FAISS index created and saved successfully")
        return "FAISS index created and saved successfully."
    except Exception as e:
        # UI boundary: never let the exception escape to the Gradio handler,
        # but keep the traceback in the log for diagnosis.
        logging.exception("An error occurred while creating FAISS index: %s", e)
        return f"An error occurred while creating FAISS index: {str(e)}"


# Cell 4: Function to Get Result

def _openai_error_types():
    """Return the OpenAI SDK base exception class(es) available in this install.

    Older SDKs expose ``openai.error.OpenAIError``; others expose
    ``openai.OpenAIError``. Looking both up with ``getattr`` avoids an
    AttributeError while an exception is already being handled. The result
    may be an empty tuple, which an ``except`` clause simply never matches.
    """
    legacy_module = getattr(openai, "error", None)
    candidates = (
        getattr(legacy_module, "OpenAIError", None),
        getattr(openai, "OpenAIError", None),
    )
    return tuple({c for c in candidates if isinstance(c, type)})


# get_result splits long queries into chunks and answers each chunk separately.
def get_result(query):
    """Answer a query against the saved FAISS index and return the formatted text.

    The query is split into chunks of up to 1000 characters (100 overlap);
    each chunk is sent through a RetrievalQA chain that retrieves the top 4
    documents from the vector store and fills ``prompt_template``. The
    per-chunk answers are joined with newlines under a banner.

    Args:
        query: Free text entered by the user.

    Returns:
        The banner plus the combined answers on success; otherwise a message
        starting with "Error:" or "An error occurred:". No exception is
        propagated to the caller.
    """
    if not faiss_index_created:
        logging.error("FAISS index not created. Please create the FAISS index first.")
        return "Error: FAISS index not created. Please create the FAISS index first."

    # A whitespace-only query would split into zero chunks and produce an
    # empty "successful" result, so treat it as missing.
    if not query or not query.strip():
        logging.error("Query is missing")
        return "Error: Please provide a query."

    # Chunk the query if it's too long
    query_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)  # Adjust sizes as needed
    query_chunks = query_splitter.split_text(query)

    responses = []
    try:
        logging.info("Loading vectorstore")
        vectorstore = FAISS.load_local(VECTORSTORE_PATH, embeddings)

        logging.info("Initializing QA Chain Prompt Template")
        QA_CHAIN_PROMPT = PromptTemplate.from_template(prompt_template)
        llm_chain = LLMChain(llm=llm, prompt=QA_CHAIN_PROMPT, callbacks=None)

        # How each retrieved document is rendered before being stuffed into
        # the {context} slot of the prompt.
        document_prompt = PromptTemplate(
            input_variables=["page_content", "source"],
            template="Context:\ncontent:{page_content}\nsource:{source}",
        )

        combine_documents_chain = StuffDocumentsChain(
            llm_chain=llm_chain,
            document_variable_name="context",
            document_prompt=document_prompt,
            callbacks=None,
        )

        retriever = vectorstore.as_retriever(search_kwargs={"k": 4})

        # The chain does not depend on the chunk, so build it once.
        fire_query = RetrievalQA(
            combine_documents_chain=combine_documents_chain,
            callbacks=None,
            retriever=retriever
        )

        total_chunks = len(query_chunks)
        for position, chunk in enumerate(query_chunks, start=1):
            # Log progress only: the chunk text is patient medical history
            # and should not be written to the application log.
            logging.info("Firing query for chunk %d of %d (%d characters)", position, total_chunks, len(chunk))
            chunk_response = fire_query(chunk)['result']
            responses.append(chunk_response)

        # Aggregate responses - simple concatenation example
        final_response = "\n".join(responses)
        final_res = f'''
        ::::::: EXTRACTED CPT CODE ::::::: 
        
        {final_response}
        '''
        logging.info("Query successful")
        return final_res
    except _openai_error_types():
        logging.error("OpenAI API is currently unavailable. Please try again later.")
        return "Error: OpenAI API is currently unavailable. Please try again later."
    except Exception as e:
        # UI boundary: report the failure as text, keep the traceback in the log.
        logging.exception("An error occurred: %s", e)
        return f"An error occurred: {str(e)}"


# Cell 5: Function to Generate PDF
def generate_pdf(medical_history, cpt_code_output):
    """Write the submitted text and the extracted CPT codes to a PDF file.

    Args:
        medical_history: The text the user submitted as the query.
        cpt_code_output: The result text produced for that query.

    Returns:
        The path of the written PDF ("cpt_code_output.pdf", relative to the
        current working directory). The file is overwritten on every call.
    """
    # Local import so the block does not depend on a new file-level import.
    from xml.sax.saxutils import escape

    def _as_markup(text):
        # Paragraph parses its text as XML-like markup: escape &, <, > so
        # arbitrary user/model text cannot break the build, and turn
        # newlines into explicit line breaks so the layout is preserved.
        return escape(str(text or "").strip()).replace("\n", "<br/>")

    pdf_path = "cpt_code_output.pdf"
    doc = SimpleDocTemplate(pdf_path, pagesize=letter)
    styles = getSampleStyleSheet()

    story = [
        Paragraph("CPT Code Extraction Report", styles["Title"]),
        Spacer(1, 12),
        Paragraph("Medical History", styles["Heading2"]),
        Paragraph(_as_markup(medical_history), styles["BodyText"]),
        Spacer(1, 12),
        Paragraph("Extracted CPT Codes", styles["Heading2"]),
        Paragraph(_as_markup(cpt_code_output), styles["BodyText"]),
    ]

    # build() lays out the flowables and writes the file to pdf_path.
    doc.build(story)
    return pdf_path


# Cell 6: Define Gradio Interfaces

def setup_interface():
    """Lay out the "Setup FAISS Index" tab and wire its button.

    The tab holds a textbox for the PDF directory path, a button that
    triggers index creation, and a read-only status textbox that shows the
    string returned by create_faiss_index().

    Returns:
        This function object itself (callers in this file ignore it).
    """
    def on_create_index_click(pdf_directory):
        # Event handler: forward the textbox value to the index builder.
        return create_faiss_index(pdf_directory)

    with gr.Tab("Setup FAISS Index"), gr.Column():
        directory_box = gr.Textbox(label="Enter PDF Directory Path")
        build_button = gr.Button("Create FAISS Index")
        status_box = gr.Textbox(label="Status", interactive=False)

        build_button.click(
            fn=on_create_index_click,
            inputs=directory_box,
            outputs=status_box,
        )

    return setup_interface

def query_interface():
    """Lay out the "Query Interface" tab and wire its button.

    The tab holds a multi-line query textbox, a "Get Result" button, a
    read-only result textbox and a file component for the generated PDF.
    Clicking the button runs query_and_generate_pdf() on the query text and
    sends its two return values to the result textbox and file component.

    Returns:
        This function object itself (callers in this file ignore it).
    """
    def on_query_click(query):
        # Event handler: run the query, then hand back (text, pdf path).
        answer_text, pdf_file = query_and_generate_pdf(query)
        return answer_text, pdf_file

    with gr.Tab("Query Interface"), gr.Column():
        # Ten visible lines for the query box.
        query_box = gr.Textbox(lines=10, label="Enter your query", interactive=True)
        run_button = gr.Button("Get Result")
        result_box = gr.Textbox(label="Result", interactive=False)
        pdf_download = gr.File(label="Download PDF")

        run_button.click(
            fn=on_query_click,
            inputs=query_box,
            outputs=[result_box, pdf_download],
        )

    return query_interface

def query_and_generate_pdf(query):
    """Run the query and, when it succeeds, export the result to a PDF.

    Args:
        query: Free text entered by the user.

    Returns:
        A ``(result, pdf_path)`` tuple. ``pdf_path`` is None when
        get_result() reported a failure, so no PDF is produced for errors.
    """
    result = get_result(query)

    # get_result() signals failure with a message that *starts* with one of
    # these prefixes. Successful output starts with a newline and a banner,
    # so a prefix test cannot misfire on it. A plain substring test would
    # miss the "An error occurred" failures and would reject valid answers
    # that merely contain the word "Error" (e.g. "Refractive Error").
    if result.startswith(("Error", "An error occurred")):
        return result, None

    pdf_path = generate_pdf(query, result)
    return result, pdf_path


# Cell 7: Combine Interfaces into Gradio App
# Custom stylesheet handed to gr.Blocks through its css= argument below.
# The .gr-title class is applied by the <h1> in the Markdown header.
# NOTE(review): the other selectors (.gr-tab, .gr-button, .gr-textbox) are
# presumably meant to match Gradio's own element classes - confirm they
# exist in the installed Gradio version.
css = """
body {
    background-color: #e0f7fa;
    color: #006064;
}
.gradio-container {
    border-radius: 10px;
    box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);
    padding: 20px;
    background-color: #ffffff;
    border: 1px solid #006064;
}
.gr-tab {
    border-radius: 10px;
    background-color: #b2ebf2;
    padding: 10px;
    margin-bottom: 10px;
}
.gr-button {
    background-color: #00796b;
    color: white;
    font-size: 16px;
    padding: 10px 20px;
    border-radius: 5px;
    margin-top: 10px;
    border: 1px solid #004d40;
}
.gr-button:hover {
    background-color: #004d40;
}
.gr-textbox {
    border: 1px solid #004d40;
    padding: 10px;
    border-radius: 5px;
    font-size: 14px;
    background-color: #ffffff;
    margin-bottom: 10px;
}
.gr-title {
    font-size: 28px;
    color: #00796b;
    text-align: center;
    margin-bottom: 20px;
    margin-top: 20px;
}
"""

# Assemble the app: a title header followed by the two tabs.
with gr.Blocks(title="CPT CODE EXTRACTOR", css=css) as iface:
    # Page heading, styled through the .gr-title rule in the stylesheet.
    with gr.Column():
        gr.Markdown("<h1 class='gr-title'>CPT CODE EXTRACTOR</h1>")

    # Each call adds one tab to the layout.
    setup_interface()
    query_interface()

# Cell 8: Run the Gradio app
if __name__ == "__main__":
    # When run as a plain script, launch() blocks the main thread until the
    # server stops, so a message printed after it would only appear at
    # shutdown. Announce the address first, then start the server.
    print("Gradio interface running on http://127.0.0.1:7880")
    iface.launch(server_port=7880)  # Specify the port


2024-08-09 15:54:01,943 - INFO - HTTP Request: GET https://checkip.amazonaws.com/ "HTTP/1.1 200 "
2024-08-09 15:54:02,670 - INFO - HTTP Request: GET https://api.gradio.app/pkg-version "HTTP/1.1 200 OK"
2024-08-09 15:54:02,821 - INFO - HTTP Request: GET http://127.0.0.1:7880/startup-events "HTTP/1.1 200 OK"


Running on local URL:  http://127.0.0.1:7880


2024-08-09 15:54:05,555 - INFO - Found credentials in shared credentials file: ~/.aws/credentials
2024-08-09 15:54:12,326 - INFO - HTTP Request: HEAD http://127.0.0.1:7880/ "HTTP/1.1 200 OK"



To create a public link, set `share=True` in `launch()`.


Gradio interface running on http://127.0.0.1:7880


2024-08-09 15:59:21,616 - INFO - Loading PDF documents
2024-08-09 15:59:26,373 - ERROR - FAISS index not created. Please create the FAISS index first.
2024-08-09 15:59:27,152 - INFO - Splitting documents into chunks
2024-08-09 15:59:27,207 - INFO - Initializing and saving vectorstore
2024-08-09 16:00:55,810 - INFO - Loading faiss with AVX2 support.
2024-08-09 16:00:55,811 - INFO - Could not load library with AVX2 support due to:
ModuleNotFoundError("No module named 'faiss.swigfaiss_avx2'")
2024-08-09 16:00:55,812 - INFO - Loading faiss.
2024-08-09 16:00:55,854 - INFO - Successfully loaded faiss.
2024-08-09 16:00:56,088 - INFO - FAISS index created and saved successfully
