# Privacy-Preserving RAG: Clinical Assistant with Tool Calling

This notebook implements a clinical assistant powered by OpenAI's GPT-4.1 model that answers patient-specific questions from structured patient data while preserving privacy: retrieved records are de-identified locally before they are sent to the model.

In [None]:
import os
import pandas as pd

def generate_note_auto(table_name, patient_id):
    """
    Return the first clinical note stored for a patient in a table's notes file.

    Reads ``data/{table_name}_notes.csv`` and looks the patient up in the
    ``patient_id`` column. Column names are matched case-insensitively and
    with surrounding whitespace ignored.

    Parameters:
    - table_name: str (e.g., "observations", "medications")
    - patient_id: str (the patient_id to search for)

    Returns:
    - str: the clinical note, or a human-readable message when the file, the
      expected columns, or the patient cannot be found
    """
    notes_file = os.path.join("data", f"{table_name}_notes.csv")

    if not os.path.exists(notes_file):
        return f"Note file not found: {notes_file}"

    # Read every column as text. With dtype inference an all-numeric ID column
    # (e.g. "12345678") becomes integers, so comparing it with the string
    # patient_id passed by callers never matches; inference also drops
    # leading zeros from IDs.
    df = pd.read_csv(notes_file, dtype=str)

    # Normalize column names just in case
    df.columns = [col.strip().lower() for col in df.columns]

    if "patient_id" not in df.columns or "clinical_note" not in df.columns:
        return "Expected columns not found in the CSV."

    # Compare as stripped strings so stray whitespace on either side does not
    # hide a match.
    wanted_id = str(patient_id).strip()
    matching_notes = df[df["patient_id"].str.strip() == wanted_id]

    if matching_notes.empty:
        return f"No note found for patient_id: {patient_id}"

    return matching_notes.iloc[0]["clinical_note"]

In [None]:
# Install essential packages
# Run this cell first to install core dependencies

!pip install openai python-dotenv pandas numpy
!pip install transformers
!pip install peft
!pip install torch

In [None]:
# Import necessary libraries
import os
import json
from typing import Dict, List, Any, Optional, Union
from datetime import datetime, date
import pandas as pd
import numpy as np

# OpenAI imports
from openai import OpenAI

# Environment variables
from dotenv import load_dotenv

# Load configuration through python-dotenv before anything reads os.getenv
load_dotenv()

# Confirm the setup cell ran and stamp the output with the current local time
startup_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
print("✅ Libraries imported successfully!")
print(f"📅 Current date: {startup_time}")

In [None]:
# gated model login with Hugging Face CLI
# Make sure you have the Hugging Face CLI installed and authenticated
!pip install huggingface_hub
!pip install llama-cpp-python

from huggingface_hub import login
login(os.getenv("HUGGING_FACE_API_KEY"))

from llama_cpp import Llama

# Download the model to 'models/gemma-3-4b-it-qat-q4_0.gguf'
basic_deidentifier = Llama.from_pretrained(
    repo_id="google/gemma-3-4b-it-qat-q4_0-gguf",
    filename="gemma-3-4b-it-q4_0.gguf",  
    n_ctx=2048,
    n_threads=8
)

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from peft import PeftModel
import torch

# Base checkpoint on the Hub and the local directory holding the LoRA adapter
BASE_MODEL = "google/gemma-2b"
LORA_MODEL_PATH = "./gemma-deid-lora/checkpoint-60"

# The tokenizer is taken from the base checkpoint
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)

# Base weights are loaded in float16 with automatic device placement
base_model_options = {"torch_dtype": torch.float16, "device_map": "auto"}
base_model = AutoModelForCausalLM.from_pretrained(BASE_MODEL, **base_model_options)

# Wrap the base model with the LoRA adapter and switch it to evaluation mode
gemma_finetuned_model = PeftModel.from_pretrained(base_model, LORA_MODEL_PATH)
gemma_finetuned_model.eval()

In [None]:
def _build_hipaa_prompt(user_data: str) -> str:
    """
    Build the common HIPAA compliance prompt for de-identification.
    
    Args:
        user_data: Clinical text containing PHI
    
    Returns:
        Formatted prompt for de-identification
    """
    return f"""
        You are a local language model responsible for enforcing HIPAA compliance by identifying and 
        removing all Protected Health Information (PHI) from clinical text and structured data before 
        it is passed to an external system. Your task is to remove all 18 identifiers defined under 
        HIPAA's Safe Harbor rule while preserving the clinical meaning of the data.

        Redact all identifiers like names, dates, addresses, SSNs, etc. with placeholder [REDACTED], without summarizing or altering clinical facts.

        ---
        <data_with_phi>
        {user_data}
        </data_with_phi>
        <data_hipaa_compliant>
    """

def _extract_redacted_content(raw_output: str, prompt: str) -> str:
    """
    Extract the redacted content from model output.
    
    Args:
        raw_output: Raw output from the model
        prompt: Original prompt used
    
    Returns:
        Extracted redacted content
    """
    if "<data_hipaa_compliant>" in raw_output:
        redacted_part = raw_output.split("<data_hipaa_compliant>")[-1]
        redacted_part = redacted_part.split("</data_hipaa_compliant>")[0].strip()
    else:
        # Fallback: remove the prompt from the beginning
        redacted_part = raw_output[len(prompt):].strip()
    
    return redacted_part

In [None]:
def deidentify_with_basic_gemma(user_data: str) -> str:
    """
    Redact PHI from clinical text with the basic Gemma model (llama-cpp-python).

    The shared HIPAA prompt is sent as the system message and the raw text as
    the user message; the reply is then reduced to the redacted content.

    Args:
        user_data: Clinical text containing PHI

    Returns:
        De-identified text with PHI redacted, or a
        "[DEIDENTIFICATION_ERROR: ...]" marker if anything fails
    """
    hipaa_prompt = _build_hipaa_prompt(user_data)
    chat_messages = [
        {"role": "system", "content": hipaa_prompt},
        {"role": "user", "content": user_data},
    ]

    try:
        completion = basic_deidentifier.create_chat_completion(
            messages=chat_messages,
            temperature=0.2,
        )
        model_text = completion["choices"][0]["message"]["content"]
        cleaned_note = _extract_redacted_content(model_text, hipaa_prompt)

        # Logging
        print(f"📄 Original note: {user_data}")
        print("🔒 Redaction complete (Basic Gemma). PHI has been removed from the clinical note.")
        print(f"📄 Redacted note: {cleaned_note}")

        return cleaned_note

    except Exception as exc:
        # Never fall back to the raw note: callers get an error marker instead
        print(f"❌ Error during basic de-identification: {exc}")
        return f"[DEIDENTIFICATION_ERROR: {str(exc)}]"

In [None]:
def deidentify_with_finetuned_gemma(user_data: str) -> str:
    """
    De-identify clinical text using the fine-tuned Gemma model with LoRA weights.

    The shared HIPAA prompt is tokenized, completed with greedy decoding, and
    the decoded text is reduced to the redacted content.

    Args:
        user_data: Clinical text containing PHI

    Returns:
        De-identified text with PHI redacted, or a
        "[DEIDENTIFICATION_ERROR: ...]" marker if anything fails
    """
    prompt = _build_hipaa_prompt(user_data)

    try:
        # Put the inputs on the device the model's weights actually live on.
        # The model is loaded with device_map="auto", so hard-coding
        # "mps"/"cpu" breaks generation whenever the weights were placed
        # elsewhere (e.g. on a CUDA GPU).
        device = next(gemma_finetuned_model.parameters()).device
        encoded = tokenizer(prompt, return_tensors="pt").to(device)

        with torch.no_grad():
            output = gemma_finetuned_model.generate(
                input_ids=encoded["input_ids"],
                # Pass the mask explicitly instead of letting generate() guess it
                attention_mask=encoded["attention_mask"],
                max_new_tokens=300,
                # Greedy decoding keeps redaction deterministic. No temperature
                # is passed because it is ignored when do_sample=False.
                do_sample=False,
                eos_token_id=tokenizer.eos_token_id,
                pad_token_id=tokenizer.pad_token_id
            )

        # The decoded sequence includes the prompt; the extractor keeps only
        # what follows the <data_hipaa_compliant> tag.
        raw_output = tokenizer.decode(output[0], skip_special_tokens=True)
        redacted_content = _extract_redacted_content(raw_output, prompt)

        # Logging
        print(f"📄 Original note: {user_data}")
        print("🔒 Redaction complete (Fine-tuned Gemma). PHI has been removed from the clinical note.")
        print(f"📄 Redacted note: {redacted_content}")

        return redacted_content

    except Exception as e:
        # Never fall back to the raw note: callers get an error marker instead
        print(f"❌ Error during fine-tuned de-identification: {e}")
        return f"[DEIDENTIFICATION_ERROR: {str(e)}]"

In [None]:
# Self-contained Clinical Assistant with all configurations built-in
class ClinicalAssistant:
    """
    Self-contained clinical assistant with built-in configurations and tools.

    Answers clinical questions with an OpenAI chat model that can call
    patient-data tools. Every tool result is passed through a local
    de-identification model before it is added to the conversation.
    """

    # De-identification back-ends accepted by the constructor
    SUPPORTED_DEIDENTIFY_MODELS = ("basic", "finetuned")

    # Upper bound on tool-calling rounds per query, so a model that keeps
    # requesting tools cannot loop forever
    MAX_TOOL_ROUNDS = 5

    # One row per tool: (tool name, notes table read by the tool, description)
    _TOOL_SPECS = (
        ("get_patient_observations", "observations",
         "Retrieve laboratory test results and clinical observations for a patient"),
        ("get_patient_conditions", "conditions",
         "Retrieve patient diagnosis information, medical conditions, and medical history"),
        ("get_patient_medications", "medications",
         "Retrieve current and past medications for a patient"),
        ("get_patient_careplans", "careplans",
         "Retrieve care plans and treatment plans for a patient"),
        ("get_patient_procedures", "procedures",
         "Retrieve medical procedures and interventions performed on a patient"),
        ("get_patient_imaging_studies", "imaging_studies",
         "Retrieve imaging studies and radiology reports for a patient"),
        ("get_patient_immunizations", "immunizations",
         "Retrieve vaccination history and immunization records for a patient"),
        ("get_patient_allergies", "allergies",
         "Retrieve allergy information and adverse reactions for a patient"),
    )

    def __init__(self, deidentify_model: str = "finetuned"):
        """
        Initialize the Clinical Assistant with built-in configurations

        Args:
            deidentify_model: "basic" or "finetuned" for de-identification model

        Raises:
            ValueError: if deidentify_model is not supported, or if
                OPENAI_API_KEY is not set in the environment
        """
        # Reject unknown back-ends up front: a typo here must not result in
        # tool output that silently skips de-identification.
        if deidentify_model not in self.SUPPORTED_DEIDENTIFY_MODELS:
            raise ValueError(
                f"Unsupported deidentify_model: {deidentify_model!r} "
                f"(expected one of {self.SUPPORTED_DEIDENTIFY_MODELS})"
            )
        self.deidentify_model = deidentify_model

        # Initialize OpenAI configuration
        self.api_key = os.getenv("OPENAI_API_KEY")
        if not self.api_key:
            raise ValueError("OPENAI_API_KEY not found in environment variables")

        self.client = OpenAI(api_key=self.api_key)
        self.model = "gpt-4.1"
        self.max_completion_tokens = 5000
        self.temperature = 0.1

        # Tool schemas advertised to the model
        self.tool_definitions = self._build_tool_definitions()

        # Tool name -> callable(patient_id) reading the matching notes table
        self.function_map = {
            name: self._make_note_fetcher(table_name)
            for name, table_name, _ in self._TOOL_SPECS
        }

        print(f"✅ Clinical Assistant initialized with {self.deidentify_model} de-identification model!")

    @classmethod
    def _build_tool_definitions(cls) -> List[Dict[str, Any]]:
        """Build one function-tool schema per entry in _TOOL_SPECS; each takes a required patient_id string."""
        return [
            {
                "type": "function",
                "function": {
                    "name": name,
                    "description": description,
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "patient_id": {
                                "type": "string",
                                "description": "Unique patient identifier"
                            }
                        },
                        "required": ["patient_id"]
                    }
                }
            }
            for name, _, description in cls._TOOL_SPECS
        ]

    @staticmethod
    def _make_note_fetcher(table_name: str):
        """Return a callable(patient_id) that reads the patient's note from the given table via generate_note_auto."""
        def fetch_note(patient_id):
            return generate_note_auto(table_name, patient_id)
        return fetch_note

    def _build_system_prompt(self) -> str:
        """Build the system prompt for the clinical assistant"""
        return f"""You are an AI Clinical Reasoning Assistant with expertise in internal medicine. 
        Analyze clinical data and respond to patient-specific questions using structured reasoning.

        Available Tools:
        - get_patient_observations(): Laboratory test results and clinical observations
        - get_patient_conditions(): Diagnosis history, medical conditions, and medical history
        - get_patient_medications(): Current and past medications
        - get_patient_careplans(): Care plans and treatment plans
        - get_patient_procedures(): Medical procedures and interventions
        - get_patient_imaging_studies(): Imaging studies and radiology reports
        - get_patient_immunizations(): Vaccination history and immunization records
        - get_patient_allergies(): Allergy information and adverse reactions

        Clinical Reasoning Framework:
        1. Analyze the question to determine needed information
        2. Use appropriate tools to gather relevant patient data
        3. Synthesize findings from multiple data sources
        4. Provide clear, evidence-based responses with clinical reasoning

        Rules:
        - Use ONLY data provided by tool outputs
        - Reference relative timeframes when provided (e.g., "Day 0", "Post-op Day 3")
        - Acknowledge limitations if data is insufficient
        - Suggest additional information needed when applicable
        - Consider interactions between medications, conditions, and procedures
        - Always check for allergies before recommending treatments

        Current date: {datetime.now().strftime('%Y-%m-%d')}
        """

    def _de_identify_data(self, data: str) -> str:
        """
        De-identify text data using the specified deidentify model

        Raises:
            ValueError: if self.deidentify_model is not a supported back-end.
                Failing here (instead of returning None) guarantees that no
                tool result leaves this method un-redacted.
        """
        if self.deidentify_model == "basic":
            return deidentify_with_basic_gemma(data)
        if self.deidentify_model == "finetuned":
            return deidentify_with_finetuned_gemma(data)
        raise ValueError(f"Unsupported deidentify_model: {self.deidentify_model!r}")

    def _execute_tool_call(self, function_name: str, arguments: Dict[str, Any]) -> str:
        """
        Execute a tool function call with de-identification

        Args:
            function_name: Name of a tool registered in self.function_map
            arguments: Keyword arguments for the tool (from the model's tool call)

        Returns:
            The de-identified tool output, or an error message string. The raw
            tool output is never returned.
        """
        if function_name not in self.function_map:
            return f"Error: Unknown function: {function_name}"

        print(f"Executing tool call: {function_name}")
        try:
            func = self.function_map[function_name]
            raw_result = func(**arguments)
            return self._de_identify_data(raw_result)
        except Exception as e:
            return f"Error executing {function_name}: {str(e)}"

    def _run_tool_call(self, tool_call) -> str:
        """Decode one model-issued tool call and execute it; malformed arguments yield an error string instead of aborting the query."""
        function_name = tool_call.function.name
        try:
            function_args = json.loads(tool_call.function.arguments or "{}")
        except json.JSONDecodeError as e:
            return f"Error: invalid JSON arguments for {function_name}: {e}"
        if not isinstance(function_args, dict):
            return f"Error: arguments for {function_name} must be a JSON object"
        return self._execute_tool_call(function_name, function_args)

    @staticmethod
    def _serialize_tool_call(tool_call) -> Dict[str, Any]:
        """Convert an SDK tool-call object into the plain-dict form used in request messages."""
        return {
            "id": tool_call.id,
            "type": "function",
            "function": {
                "name": tool_call.function.name,
                "arguments": tool_call.function.arguments,
            },
        }

    def _request_completion(self, messages: List[Dict[str, Any]], tool_choice: str = "auto"):
        """Send the conversation to the chat model with the assistant's tools and sampling settings."""
        return self.client.chat.completions.create(
            model=self.model,
            messages=messages,
            tools=self.tool_definitions,
            tool_choice=tool_choice,
            max_completion_tokens=self.max_completion_tokens,
            temperature=self.temperature
        )

    def process_query(self, query: str, patient_id: Optional[str] = None) -> str:
        """
        Process a clinical query using OpenAI GPT with tool calling

        The model may request tools over several rounds (at most
        MAX_TOOL_ROUNDS); each tool result is de-identified before it is added
        to the conversation. If the round budget is used up, one last request
        is made with tool calling disabled so a text answer is produced.

        Args:
            query: The clinical question to answer
            patient_id: Optional patient ID if known

        Returns:
            Clinical assistant response, or an error message string if the
            request fails
        """
        try:
            # NOTE: the patient ID itself is sent to the model so that it can
            # fill in the tool arguments.
            user_content = f"Clinical query: {query}"
            if patient_id:
                user_content += f"\nPatient ID: {patient_id}"

            messages = [
                {"role": "system", "content": self._build_system_prompt()},
                {"role": "user", "content": user_content}
            ]

            for _ in range(self.MAX_TOOL_ROUNDS):
                message = self._request_completion(messages).choices[0].message

                # No tool requests: this is the final answer
                if not message.tool_calls:
                    return message.content

                # Echo the assistant turn, then answer each requested tool call
                messages.append({
                    "role": "assistant",
                    "content": message.content,
                    "tool_calls": [self._serialize_tool_call(tc) for tc in message.tool_calls]
                })
                for tool_call in message.tool_calls:
                    messages.append({
                        "role": "tool",
                        "tool_call_id": tool_call.id,
                        "content": self._run_tool_call(tool_call)
                    })

            # Round budget exhausted: force a plain-text answer
            final_response = self._request_completion(messages, tool_choice="none")
            return final_response.choices[0].message.content

        except Exception as e:
            return f"❌ Error processing query: {str(e)}"

In [None]:
# Banner for the de-identification test run
print(
    "🧪 Testing Clinical Assistant with Integrated De-identification",
    "=" * 60,
    sep="\n",
)

# Example clinical questions used to exercise the assistant's tools
sample_queries = [
    "Can you summarize this patient's medical conditions and current medications?",
    "Does this patient have any allergies I should be aware of before prescribing?",
    "What is this patient's vaccination status and immunization history?",
    "What imaging studies have been performed and what were the findings?",
]

# Enhanced interactive testing function that creates assistant instances
from IPython.display import display, Markdown

def test_query(query_text: str, patient_id: str = "12345678", deidentify_model: str = "finetuned"):
    """
    Run one clinical question through a freshly created Clinical Assistant.

    A header, the assistant's answer (in a fenced block), or an error message
    is rendered as Markdown in the notebook output.

    Args:
        query_text: The clinical question to ask
        patient_id: Patient identifier (default: "12345678")
        deidentify_model: "finetuned" or "basic" for de-identification model

    Returns:
        Assistant response, or the error message if anything raised
    """
    try:
        # A new assistant is built per call so each run uses the requested model
        clinical_assistant = ClinicalAssistant(deidentify_model=deidentify_model)

        if deidentify_model == "finetuned":
            backend_label = "Fine-tuned Gemma"
        else:
            backend_label = "Basic Gemma"

        # Header, separator, and section title shown before the answer
        banner_parts = (
            f"### 🩺 Clinical Query: {query_text}\n\n**Patient ID:** {patient_id}\n**De-identification Model:** {backend_label}",
            "---",
            "### 🤖 Assistant Response",
        )
        for part in banner_parts:
            display(Markdown(part))

        answer = clinical_assistant.process_query(query_text, patient_id)
        display(Markdown(f"```\n{answer}\n```"))
        return answer

    except Exception as exc:
        failure_text = f"❌ Error: {str(exc)}"
        display(Markdown(failure_text))
        return failure_text

In [None]:


# Test with both de-identification models separately
# The same patient and the same question are used for both runs so the two
# outputs can be compared side by side.
test_patient_id = "1329b83e-ea69-d184-b4af-0d2a8e07896e"
# sample_queries[1] is the allergy question
query = sample_queries[1]

# Run 1: fine-tuned (LoRA) Gemma as the de-identifier
print("🔍 Testing Fine-tuned Gemma model:")
test_query(query, test_patient_id, "finetuned")

# Visual separator between the two runs
print("\n" + "="*60 + "\n")

# Run 2: basic Gemma (llama-cpp) as the de-identifier
print("🔍 Testing Basic Gemma model:")
test_query(query, test_patient_id, "basic")