# Zero-shot and One-shot Learning using Large Language Models for De-Identification of Medical Records

### Summary

All medical records are now digitalized and the old records are also being converted to digital records. These medical records have sensitive personal data related to the patient and even medical professionals, which is a threat to their pirvacy. Removal or censoring of personal identifiable information embedded within the medical records is call de-identification. 

Dataset used - https://portal.dbmi.hms.harvard.edu/projects/n2c2-nlp/ ( Have to request for permission for this data, if you want to access this dataset) 

Here, we use the recent Large Lanaguage Models like GPT-3.5, GPT-4, PaLM, Llama, Bard to evaluate how they perform on de-identification task.

#### Model Accuracies for Zero Shot Learning

|  | Brief Prompt |  Detailed Prompt | Test Sample Size |
|----------|----------|----------|----------|
| GPT-3.5 | 69% | 96% | 514 |
| GPT-4 | 96% | 99% | 100 |
| PaLM | 71% | 74% | 514 |
| Bard | * | * | 100 |
| Llama-7B | ** | ** | 100 |

\* Llama-7B model produced very inconsistent results. For some, medical records it returns says that it is not allowed to de-identify medical records and sometimes it produced different text from what was given as part of the prompt 

** Same with Bard. Bard official API is not publically available. But there is python package available and the experiment was based on it. So, bard experiment is not included in this notebook

### Model Training and Accuracies for One Shot Learning

Both GPT-3.5 and PaLM LLMs were training and fine tuned as per our needs. GPT-3.5 outforms GPT-4 when fined tuned with just 30 sample data files and trained for 3 epochs. There is considerable increase in PaLM model also when fine tuned

|                                | Training Set Size | Validation Set Size | Approx Cost   |
|--------------------------------|-------------------|---------------------|---------------|
| GPT-3.5 with brief prompts     | 30                | 10                  | `$1.4`        |
| GPT-3.5 with detailed prompts  | 30                | 10                  | `$2.3`        |
| PaLM with brief prompts        | 50                | NA                  | Free          |
| PaLM with detailed prompts     | 50                | NA                  | Free          |



|                                | Brief Prompts | Detailed Prompts | Test Set Size   |
|--------------------------------|-------------------|---------------------|---------------|
| GPT-3.5      | 96%                | 98%                 | 100        |
| PaLM         | 87%                | 93%                 | 100          |

In [1]:
%%capture --no-display
!pip3 install openai
!pip3 install requests
!pip3 install tiktoken
!pip3 install bardapi
!pip3 install -U google-generativeai

In [14]:
#importing required libraries
import os
import requests
import openai
import shutil
import configparser
import pandas as pd
import xml.etree.ElementTree as ET
from concurrent.futures import ThreadPoolExecutor
from bs4 import BeautifulSoup
from statistics import mean
import google.generativeai as palm
from llama_cpp import Llama
from llama_index.llms import LlamaCPP
from llama_index.llms.base import ChatMessage

In [3]:
#read API keys from config file
config = configparser.ConfigParser()
config.read('/Users/yashwanthys/Desktop/config.ini')
openai.api_key = config['API_KEYS']['OPENAI_API_KEY']
palm.configure(api_key = config['API_KEYS']['PALM_API_KEY'])

In [4]:
# fetching all avaliable models of OpenAI 
models = openai.Model.list()
data = pd.DataFrame(models["data"])
data.head()

Unnamed: 0,id,object,created,owned_by
0,text-search-babbage-doc-001,model,1651172509,openai-dev
1,gpt-3.5-turbo-16k-0613,model,1685474247,openai
2,curie-search-query,model,1651172509,openai-dev
3,gpt-3.5-turbo-16k,model,1683758102,openai-internal
4,text-search-babbage-query-001,model,1651172509,openai-dev


### Prompt
This the prompt or the common instruction given to all the large language models. Detailed prompt is the ideal prompt in terms of accuracy and from cost perspective. Detailed prompt is designed in such a way that it is of ideal length and proper examples are included for each instruction. Brief prompt is good in terms of cost but performs poorly in terms of accuracy

In [5]:
brief_prompt = '''Task: Please anonymize the following clinical note. Replace all the Protected health information (PHI) text with the '[censored]'.'''


detailed_prompt = '''Task: Please anonymize the following clinical note.

Specific Instructions: Replace all the following Protected health information (PHI) text with the '[censored]'.

1) Censor any string or substring that has name, including patients, doctors, any acronyms, initials, and medical titles

2) Censor any string or substring that indicate profession with any mentions of job titles, like medical staff professional names, such as 'M.D.' and 'Dr.'.

3) Censor any string or substring with location, including addresses, clinic names, hospital names, and any other possible location indicators, such as '920 River Street'.

4) Censor any string or substring that indicate age, such as "Over 80 years" or "Aged 70".

5) Censor any string or substring that indicate dates, including record dates, admit dates, decharge dates etc, such as '27/09/2090' or '07/06' or '2090-08-25'

6) Censor any string or substring with contact information, including phone numbers, email, fax, URLs and IP Addresses'''

In [18]:
def parse_single_xml(file_path: str) -> str:
    tree = ET.parse(file_path)
    root = tree.getroot()
    text = root.find('TEXT').text
    return text

In [6]:
# gpt_request function makes a API call to particular gpt model and fetches the result for the prompt provided
def gpt_request(model, system_prompt, user_report, temperature):
    Chat_Completion = openai.ChatCompletion.create(
        model=model,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_report}
        ],
        temperature=temperature,
    )
    return Chat_Completion

In [23]:
# process xml files which have the medical records using LLM models
def process_single_file(filename, input_folder, output_folder, model, system_prompt, temperature):
    # Construct file paths
    file_path = os.path.join(input_folder, filename)
    censored_filename = filename.replace('.xml', '_censored.txt')
    censored_file_path = os.path.join(output_folder, censored_filename)

    try:
        # Parse the XML file and extract the text
        tree = ET.parse(file_path)
        root = tree.getroot()
        text_element = root.find('TEXT')
        if text_element is not None and text_element.text is not None:
            text_content = text_element.text.strip()
        else:
            print(f"No text found in {filename}")
            return

        # Make the API call
        if model in ["gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4", "ft:gpt-3.5-turbo-0613:personal:briefprompt-deid2:8Hr8yYi1", "ft:gpt-3.5-turbo-0613:personal:detailedpromptdeid:8I0NZ50z"]:
            gpt_reponse = gpt_request(model=model, system_prompt=system_prompt, user_report=text_content, temperature=temperature)
            censored_text = gpt_reponse.choices[0].message.content
        elif model == "PaLM":
            palm_response = palm.chat(context=system_prompt, messages=text_content, temperature=temperature)
            censored_text = palm_response.last
        else:
            raise ValueError(f"Model {model} is not supported.")

        # Write the response to a file in the output folder
        with open(censored_file_path, 'w') as censored_file:
            censored_file.write(censored_text)

        print(f"Processed {filename} and saved response to {censored_file_path}")

    except ET.ParseError as e:
        print(f"Error parsing {filename}: {str(e)}")
    except Exception as e:
        print(f"An error occurred while processing {filename}: {str(e)}")


In [8]:
#process files in parallel for gpt models
def process_xml_files(input_folder, output_folder, model, system_prompt, temperature):
    if not os.path.exists(output_folder):
        os.makedirs(output_folder)

    xml_files = [f for f in sorted(os.listdir(input_folder)) if f.endswith(".xml")]

    # Use ThreadPoolExecutor to process files in parallel
    with ThreadPoolExecutor(max_workers=1) as executor:
        # Submit tasks to the executor
        futures = [executor.submit(process_single_file, filename, input_folder, output_folder, model, system_prompt, temperature) for filename in xml_files]

In [9]:
# Script taken from - https://github.com/yhydhx/ChatGPT-API/blob/main/process_xml_public.py and modified
# Uncomment print statements if detailed analysis of each report is required
def words_in_string(word_list, a_string):
    return set(word_list).intersection(a_string.split())

# Calculates accuracy of the models
def check_deidentification_accuracy(rewrite_directory, original_directory):
    list_of_files_to_check = []
    list_of_anonymized_reports = []

    for filename in os.listdir(rewrite_directory):
        f = os.path.join(rewrite_directory, filename)
        if os.path.isfile(f):
            target_file = os.path.basename(os.path.normpath(f))[:-13]
            list_of_files_to_check.append(target_file)

            text_file = open(f, "r")
            data = text_file.read()
            list_of_anonymized_reports.append(data)
            text_file.close()

    list_of_accuracies = []

    for i in range(len(list_of_files_to_check)):
        names = []
        professions = []
        locations = []
        ages = []
        dates = []
        contacts = []
        ids = []

        names_count = 0
        professions_count = 0
        locations_count = 0
        ages_count = 0
        dates_count = 0
        contacts_count = 0
        ids_count = 0

        with open(os.path.join(original_directory, list_of_files_to_check[i] + ".xml")) as fp:
            soup = BeautifulSoup(fp, features="xml")

            tags = soup.find("TAGS")

            for name_tag in tags.find_all('NAME'):
                names.append(name_tag.get('text'))

            for profession_tag in tags.find_all('PROFESSION'):
                professions.append(profession_tag.get('text'))

            for location_tag in tags.find_all('LOCATION'):
                locations.append(location_tag.get('text'))

            for age_tag in tags.find_all('AGE'):
                ages.append(age_tag.get('text'))

            for date_tag in tags.find_all('DATE'):
                dates.append(date_tag.get('text'))

            for contact_tag in tags.find_all('CONTACT'):
                contacts.append(contact_tag.get('text'))

            for id_tag in tags.find_all('ID'):
                ids.append(id_tag.get('text'))

            #print("==========================")
            #print(list_of_files_to_check[i])
            a_string = list_of_anonymized_reports[i]

            for word in words_in_string(names, a_string):
                names_count += 1

            for word in words_in_string(professions, a_string):
                professions_count += 1

            for word in words_in_string(locations, a_string):
                locations_count += 1

            for word in words_in_string(ages, a_string):
                ages_count += 1

            for word in words_in_string(dates, a_string):
                dates_count += 1

            for word in words_in_string(contacts, a_string):
                contacts_count += 1

            for word in words_in_string(ids, a_string):
                ids_count += 1

            total_remaining = names_count + professions_count + locations_count + ages_count + dates_count + contacts_count + ids_count
            total_length = len(names) + len(professions) + len(locations) + len(ages) + len(dates) + len(contacts) + len(ids)

            accuracy = 1 - (total_remaining / total_length)
            list_of_accuracies.append(accuracy)
            #print("Remaining number of strings and Accuracy: ", total_remaining, accuracy)
            #print("==========================\n")

    #print(len(list_of_files_to_check))
    average_accuracy = mean(list_of_accuracies)
    print("Average accuracy =", round(average_accuracy, 3))
    return average_accuracy

### GPT-3.5-turbo model with brief prompt 

In [None]:
model = 'gpt-3.5-turbo'
temperature = 0.05
test_folder_path = '/Users/yashwanthys/PersonalProjects/ML_Proj/De-Identification/testing-PHI-Gold-fixed'
output_folder_path = '/Users/yashwanthys/PersonalProjects/ML_Proj/De-Identification/breif_gpt_3.5_test_results2'
process_xml_files(test_folder_path, output_folder_path, model, brief_prompt, temperature)

In [10]:
rewrite_directory = '/Users/yashwanthys/PersonalProjects/ML_Proj/De-Identification/breif_gpt_3.5_test_results2'
original_directory = '/Users/yashwanthys/PersonalProjects/ML_Proj/De-Identification/testing-PHI-Gold-fixed'
accuracy_detailed_prompt = check_deidentification_accuracy(rewrite_directory,original_directory)

Average accuracy = 0.692


### GPT-3.5-turbo with detailed prompt 

In [None]:
model = 'gpt-3.5-turbo-16k'
temperature = 0.05
test_folder_path = '/Users/yashwanthys/PersonalProjects/ML_Proj/De-Identification/testing-PHI-Gold-fixed'
output_folder_path = '/Users/yashwanthys/PersonalProjects/ML_Proj/De-Identification/detailed_gpt_3.5_test_result2'
process_xml_files(test_folder_path, output_folder_path, model, detailed_prompt, temperature)

In [11]:
rewrite_directory = '/Users/yashwanthys/PersonalProjects/ML_Proj/De-Identification/detailed_gpt_3.5_test_result2'
original_directory = '/Users/yashwanthys/PersonalProjects/ML_Proj/De-Identification/testing-PHI-Gold-fixed'
accuracy_detailed_prompt = check_deidentification_accuracy(rewrite_directory,original_directory)

Average accuracy = 0.955


### GPT-4 with brief prompt

In [None]:
model = 'gpt-4'
temperature = 0.05
test_folder_path = '/Users/yashwanthys/PersonalProjects/ML_Proj/De-Identification/testing-PHI-Gold-fixed-short'
output_folder_path = '/Users/yashwanthys/PersonalProjects/ML_Proj/De-Identification/breif_gpt_4_test_results2'
process_xml_files(test_folder_path, output_folder_path, model, brief_prompt, temperature)

In [12]:
rewrite_directory = '/Users/yashwanthys/PersonalProjects/ML_Proj/De-Identification/breif_gpt_4_test_results2'
original_directory = '/Users/yashwanthys/PersonalProjects/ML_Proj/De-Identification/testing-PHI-Gold-fixed-short'
accuracy_detailed_prompt = check_deidentification_accuracy(rewrite_directory,original_directory)

Average accuracy = 0.956


### GPT-4 with detailed prompt

In [None]:
model = 'gpt-4'
temperature = 0.05
test_folder_path = '/Users/yashwanthys/PersonalProjects/ML_Proj/De-Identification/testing-PHI-Gold-fixed-short'
output_folder_path = '/Users/yashwanthys/PersonalProjects/ML_Proj/De-Identification/detailed_gpt_4_test_result2'
process_xml_files(test_folder_path, output_folder_path, model, detailed_prompt, temperature)

In [13]:
rewrite_directory = '/Users/yashwanthys/PersonalProjects/ML_Proj/De-Identification/detailed_gpt_4_test_result2'
original_directory = '/Users/yashwanthys/PersonalProjects/ML_Proj/De-Identification/testing-PHI-Gold-fixed-short'
accuracy_detailed_prompt = check_deidentification_accuracy(rewrite_directory,original_directory)

Average accuracy = 0.985


### PaLM with brief prompt

In [None]:
model = "PaLM"
temperature = 0.1
test_folder_path = '/Users/yashwanthys/PersonalProjects/ML_Proj/De-Identification/testing-PHI-Gold-fixed-short'
output_folder_path = '/Users/yashwanthys/PersonalProjects/ML_Proj/De-Identification/breif_palm_test_results2'
process_xml_files(test_folder_path, output_folder_path, model, brief_prompt, temperature)

In [11]:
rewrite_directory = '/Users/yashwanthys/PersonalProjects/ML_Proj/De-Identification/breif_palm_test_results2'
original_directory = '/Users/yashwanthys/PersonalProjects/ML_Proj/De-Identification/testing-PHI-Gold-fixed-short'
accuracy_detailed_prompt = check_deidentification_accuracy(rewrite_directory,original_directory)

Average accuracy = 0.71


### PaLM with detailed prompt

In [None]:
model = 'PaLM'
temperature = 0.1
test_folder_path = '/Users/yashwanthys/PersonalProjects/ML_Proj/De-Identification/testing-PHI-Gold-fixed-short'
output_folder_path = '/Users/yashwanthys/PersonalProjects/ML_Proj/De-Identification/detailed_palm_test_result2'
process_xml_files(test_folder_path, output_folder_path, model, detailed_prompt, temperature)

In [13]:
rewrite_directory = '/Users/yashwanthys/PersonalProjects/ML_Proj/De-Identification/detailed_palm_test_result2'
original_directory = '/Users/yashwanthys/PersonalProjects/ML_Proj/De-Identification/testing-PHI-Gold-fixed-short'
accuracy_detailed_prompt = check_deidentification_accuracy(rewrite_directory,original_directory)

Average accuracy = 0.739


### Llama model

In [None]:
# load the locally downloaded Llama model
llama = LlamaCPP(model_path="/Users/yashwanthys/PersonalProjects/llama2/llama.cpp/models/7B/ggml-model-q4_0.bin")

In [19]:
# test with a sindle sample xml file whoch has a medical record
# Run this cell multiple times and check the results. Sometimes it does not de identify info and 
# tells that it is not allowed do this particular task or it performs very poorly
medical_report = parse_single_xml('/Users/yashwanthys/Downloads/testing-PHI-Gold-fixed/110-03.xml')
message1 = ChatMessage(role='system',content=detailed_prompt)
message2 = ChatMessage(role="user",content=medical_report)
llama.chat([message1,message2])

Llama.generate: prefix-match hit

llama_print_timings:        load time = 52729.91 ms
llama_print_timings:      sample time =   172.33 ms /   256 runs   (    0.67 ms per token,  1485.49 tokens per second)
llama_print_timings: prompt eval time =     0.00 ms /     1 tokens (    0.00 ms per token,      inf tokens per second)
llama_print_timings:        eval time = 52396.62 ms /   256 runs   (  204.67 ms per token,     4.89 tokens per second)
llama_print_timings:       total time = 53464.00 ms


ChatResponse(message=ChatMessage(role=<MessageRole.ASSISTANT: 'assistant'>, content='\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n', additional_kwargs={}), raw={'id': 'cmpl-dd6bbe46-13e6-45b7-b4d6-e11630561501', 'object': 'text_completion', 'created': 1699121346, 'model': '/Users/yashwanthys/PersonalProjects/llama2/llama.cpp/models/7B/ggml-model-q4_0.bin', 'choices': [{'text': '\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n

# Fine Tuned Models

### GPT-3.5 model fine tuned with brief prompts

In [21]:
model = 'ft:gpt-3.5-turbo-0613:personal:briefprompt-deid2:8Hr8yYi1'
temperature = 0.05
test_folder_path = '/Users/yashwanthys/PersonalProjects/ML_Proj/De-Identification/testing-PHI-Gold-fixed-short'
output_folder_path = '/Users/yashwanthys/PersonalProjects/ML_Proj/De-Identification/breif_gpt_fune_tine_results'
process_xml_files(test_folder_path, output_folder_path, model, brief_prompt, temperature)

Processed 110-01.xml and saved response to /Users/yashwanthys/PersonalProjects/ML_Proj/De-Identification/breif_gpt_fune_tine_results/110-01_censored.txt
Processed 110-02.xml and saved response to /Users/yashwanthys/PersonalProjects/ML_Proj/De-Identification/breif_gpt_fune_tine_results/110-02_censored.txt
Processed 110-03.xml and saved response to /Users/yashwanthys/PersonalProjects/ML_Proj/De-Identification/breif_gpt_fune_tine_results/110-03_censored.txt
Processed 110-04.xml and saved response to /Users/yashwanthys/PersonalProjects/ML_Proj/De-Identification/breif_gpt_fune_tine_results/110-04_censored.txt
Processed 111-01.xml and saved response to /Users/yashwanthys/PersonalProjects/ML_Proj/De-Identification/breif_gpt_fune_tine_results/111-01_censored.txt
Processed 111-02.xml and saved response to /Users/yashwanthys/PersonalProjects/ML_Proj/De-Identification/breif_gpt_fune_tine_results/111-02_censored.txt
Processed 111-03.xml and saved response to /Users/yashwanthys/PersonalProjects/ML_

In [22]:
rewrite_directory = '/Users/yashwanthys/PersonalProjects/ML_Proj/De-Identification/breif_gpt_fune_tine_results'
original_directory = '/Users/yashwanthys/PersonalProjects/ML_Proj/De-Identification/testing-PHI-Gold-fixed-short'
accuracy_detailed_prompt = check_deidentification_accuracy(rewrite_directory,original_directory)

Average accuracy = 0.955


### Fine tuned GPT-3.5 model with detailed prompts

In [24]:
model = 'ft:gpt-3.5-turbo-0613:personal:detailedpromptdeid:8I0NZ50z'
temperature = 0.05
test_folder_path = '/Users/yashwanthys/PersonalProjects/ML_Proj/De-Identification/testing-PHI-Gold-fixed-short'
output_folder_path = '/Users/yashwanthys/PersonalProjects/ML_Proj/De-Identification/detailed_gpt_fune_tine_results'
process_xml_files(test_folder_path, output_folder_path, model, brief_prompt, temperature)

Processed 110-01.xml and saved response to /Users/yashwanthys/PersonalProjects/ML_Proj/De-Identification/detailed_gpt_fune_tine_results/110-01_censored.txt
Processed 110-02.xml and saved response to /Users/yashwanthys/PersonalProjects/ML_Proj/De-Identification/detailed_gpt_fune_tine_results/110-02_censored.txt
Processed 110-03.xml and saved response to /Users/yashwanthys/PersonalProjects/ML_Proj/De-Identification/detailed_gpt_fune_tine_results/110-03_censored.txt
Processed 110-04.xml and saved response to /Users/yashwanthys/PersonalProjects/ML_Proj/De-Identification/detailed_gpt_fune_tine_results/110-04_censored.txt
Processed 111-01.xml and saved response to /Users/yashwanthys/PersonalProjects/ML_Proj/De-Identification/detailed_gpt_fune_tine_results/111-01_censored.txt
Processed 111-02.xml and saved response to /Users/yashwanthys/PersonalProjects/ML_Proj/De-Identification/detailed_gpt_fune_tine_results/111-02_censored.txt
Processed 111-03.xml and saved response to /Users/yashwanthys/Pe

In [25]:
rewrite_directory = '/Users/yashwanthys/PersonalProjects/ML_Proj/De-Identification/detailed_gpt_fune_tine_results'
original_directory = '/Users/yashwanthys/PersonalProjects/ML_Proj/De-Identification/testing-PHI-Gold-fixed-short'
accuracy_detailed_prompt = check_deidentification_accuracy(rewrite_directory,original_directory)

Average accuracy = 0.977
