In [77]:
import os
import json
import csv
import xml.etree.ElementTree as ET
import openai
import numpy as np
from collections import defaultdict
import tiktoken
import gradio as gr

In [103]:
# json_template = '{{"messages": [{{"role": "system", "content": "{system}"}}, {{"role": "user", "content": "{user}"}}, {{"role": "assistant", "content": "{assistant}"}}]}}'

# def create_csv_entry(xml_content, prompt):
#     root = ET.fromstring(xml_content)
#     text_element = root.find('TEXT')
#     if text_element is None or text_element.text is None:
#         raise ValueError("The TEXT element is missing or empty.")
#     original_text = text_element.text
#     censored_text = original_text
#     tags = sorted(root.findall('.//TAGS/*'), key=lambda x: int(x.get('start')), reverse=True)
#     for tag in tags:
#         start = int(tag.get('start'))
#         end = int(tag.get('end'))
#         censored_text = censored_text[:start] + "[censored]" + censored_text[end:]
#     return json_template.format(system=prompt, user=original_text, assistant=censored_text)
def create_csv_entry(xml_content, prompt):
    root = ET.fromstring(xml_content)
    text_element = root.find('TEXT')
    if text_element is None or text_element.text is None:
        raise ValueError("The TEXT element is missing or empty.")
    original_text = text_element.text.strip()
    censored_text = original_text
    tags = sorted(root.findall('.//TAGS/*'), key=lambda x: int(x.get('start')), reverse=True)
    for tag in tags:
        start = int(tag.get('start'))
        end = int(tag.get('end'))
        censored_text = censored_text[:start] + "[censored]" + censored_text[end:]

    # Use a dictionary to form the JSON structure.
    json_structure = {
        "messages": [
            {"role": "system", "content": prompt},
            {"role": "user", "content": original_text},
            {"role": "assistant", "content": censored_text}
        ]
    }
    
    # Convert the dictionary to a JSON string.
    json_string = json.dumps(json_structure, ensure_ascii=False, separators=(',', ':'))
    
    return json_string

In [104]:
def process_xml_directory(input_dir, output_csv_path, prompt):
    with open(output_csv_path, mode='w', newline='', encoding='utf-8') as file:
        for filename in os.listdir(input_dir):
            if filename.endswith('.xml'):
                file_path = os.path.join(input_dir, filename)
                with open(file_path, 'r', encoding='UTF-8') as xml_file:
                    xml_content = xml_file.read()
                
                try:
                    writer = csv.writer(file)
                    writer.writerow([create_csv_entry(xml_content, brief_prompt)]) 
                    print(f'Processed and wrote results for {filename}')
                except ET.ParseError as e:
                    print(f"Error parsing {filename}: {str(e)}")
                except Exception as e:
                    print(f"An error occurred with {filename}: {str(e)}")

In [105]:
brief_prompt = '''Task: Please anonymize the following clinical note. Replace all the Protected health information (PHI) text with the "[censored]".'''


detailed_prompt = '''Task: Please anonymize the following clinical note.

Specific Instructions: Replace all the following Protected health information (PHI) text with the '[censored]'.

1) Censor any string or substring that has name, including patients, doctors, any acronyms, initials, and medical titles

2) Censor any string or substring that indicate profession with any mentions of job titles, like medical staff professional names, such as 'M.D.' and 'Dr.'.

3) Censor any string or substring with location, including addresses, clinic names, hospital names, and any other possible location indicators, such as '920 River Street'.

4) Censor any string or substring that indicate age, such as "Over 80 years" or "Aged 70".

5) Censor any string or substring that indicate dates, including record dates, admit dates, decharge dates etc, such as '27/09/2090' or '07/06' or '2090-08-25'

6) Censor any string or substring with contact information, including phone numbers, email, fax, URLs and IP Addresses'''

In [106]:
# Use the function like this:
input_directory = '/Users/yashwanthys/PersonalProjects/ML_Proj/De-Identification/training-PHI-Gold-Set1'
output_csv_file = '/Users/yashwanthys/PersonalProjects/ML_Proj/De-Identification/fine_tuning/fine_tuning_brief_dataset.csv'

# Process the XML files and write results to a CSV file in JSON format
process_xml_directory(input_directory, output_csv_file, brief_prompt)

Processed and wrote results for 279-03.xml
Processed and wrote results for 304-03.xml
Processed and wrote results for 251-02.xml
Processed and wrote results for 351-02.xml
Processed and wrote results for 400-02.xml
Processed and wrote results for 288-05.xml
Processed and wrote results for 332-02.xml
Processed and wrote results for 367-03.xml
Processed and wrote results for 296-05.xml
Processed and wrote results for 243-02.xml
Processed and wrote results for 320-02.xml
Processed and wrote results for 220-02.xml
Processed and wrote results for 275-03.xml
Processed and wrote results for 284-05.xml
Processed and wrote results for 308-03.xml
Processed and wrote results for 308-02.xml
Processed and wrote results for 284-04.xml
Processed and wrote results for 275-02.xml
Processed and wrote results for 220-03.xml
Processed and wrote results for 320-03.xml
Processed and wrote results for 243-03.xml
Processed and wrote results for 294-01.xml
Processed and wrote results for 394-01.xml
Processed a

In [141]:
#Load CSV data in
csv_file_path = '/Users/yashwanthys/PersonalProjects/ML_Proj/De-Identification/fine_tuning/TestFineTunePalmBreif.csv'
cleaned_data = []

with open(csv_file_path, 'r', encoding='utf-8-sig') as file:
    csv_reader = csv.reader(file)
    for row in csv_reader:
        for cell in row:
            try:
                # Replace square brackets and inner double quotes that are problematic
                cell = cell.replace('["', '').replace('"]', '').replace('\\"', '"')

                # Load each cell as a JSON object
                cell_json = json.loads(cell)

                # Now that the content is clean, append to cleaned_data list
                cleaned_data.append(cell_json)
            except json.JSONDecodeError as e:
                print(f"JSON decode error for cell '{cell}': {e}")

jsonl_file_path = '/Users/yashwanthys/PersonalProjects/ML_Proj/De-Identification/fine_tuning/TestFineTunePalmBreif.jsonl'
# Write cleaned data to a JSONL file
with open(jsonl_file_path, 'w', encoding='utf-8') as jsonl_file:
    for item in cleaned_data:
        jsonl_file.write(json.dumps(item) + '\n')

In [134]:
#from OpenAI website to format data;  https://platform.openai.com/docs/guides/fine-tuning/preparing-your-dataset

# Next, we specify the data path and open the JSONL file

data_path = '/Users/yashwanthys/PersonalProjects/ML_Proj/De-Identification/fine_tuning/TestFineTune_detailed.jsonl'

# Load dataset
with open(data_path) as f:
    dataset = [json.loads(line) for line in f]

# We can inspect the data quickly by checking the number of examples and the first item

# Initial dataset stats
print("Num examples:", len(dataset))
print("First example:")
for message in dataset[0]["messages"]:
    print(message)

# Now that we have a sense of the data, we need to go through all the different examples and check to make sure the formatting is correct and matches the Chat completions message structure

# Format error checks
format_errors = defaultdict(int)

for ex in dataset:
    if not isinstance(ex, dict):
        format_errors["data_type"] += 1
        continue

    messages = ex.get("messages", None)
    if not messages:
        format_errors["missing_messages_list"] += 1
        continue

    for message in messages:
        if "role" not in message or "content" not in message:
            format_errors["message_missing_key"] += 1

        if any(k not in ("role", "content", "name") for k in message):
            format_errors["message_unrecognized_key"] += 1

        if message.get("role", None) not in ("system", "user", "assistant"):
            format_errors["unrecognized_role"] += 1

        content = message.get("content", None)
        if not content or not isinstance(content, str):
            format_errors["missing_content"] += 1

    if not any(message.get("role", None) == "assistant" for message in messages):
        format_errors["example_missing_assistant_message"] += 1

if format_errors:
    print("Found errors:")
    for k, v in format_errors.items():
        print(f"{k}: {v}")
else:
    print("No errors found")

# Beyond the structure of the message, we also need to ensure that the length does not exceed the 4096 token limit.

# Token counting functions
encoding = tiktoken.get_encoding("cl100k_base")

# not exact!
# simplified from https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
def num_tokens_from_messages(messages, tokens_per_message=3, tokens_per_name=1):
    num_tokens = 0
    for message in messages:
        num_tokens += tokens_per_message
        for key, value in message.items():
            num_tokens += len(encoding.encode(value))
            if key == "name":
                num_tokens += tokens_per_name
    num_tokens += 3
    return num_tokens

def num_assistant_tokens_from_messages(messages):
    num_tokens = 0
    for message in messages:
        if message["role"] == "assistant":
            num_tokens += len(encoding.encode(message["content"]))
    return num_tokens

def print_distribution(values, name):
    print(f"\n#### Distribution of {name}:")
    print(f"min / max: {min(values)}, {max(values)}")
    print(f"mean / median: {np.mean(values)}, {np.median(values)}")
    print(f"p5 / p95: {np.quantile(values, 0.1)}, {np.quantile(values, 0.9)}")

# Last, we can look at the results of the different formatting operations before proceeding with creating a fine-tuning job:

# Warnings and tokens counts
n_missing_system = 0
n_missing_user = 0
n_messages = []
convo_lens = []
assistant_message_lens = []

for ex in dataset:
    messages = ex["messages"]
    if not any(message["role"] == "system" for message in messages):
        n_missing_system += 1
    if not any(message["role"] == "user" for message in messages):
        n_missing_user += 1
    n_messages.append(len(messages))
    convo_lens.append(num_tokens_from_messages(messages))
    assistant_message_lens.append(num_assistant_tokens_from_messages(messages))

print("Num examples missing system message:", n_missing_system)
print("Num examples missing user message:", n_missing_user)
print_distribution(n_messages, "num_messages_per_example")
print_distribution(convo_lens, "num_total_tokens_per_example")
print_distribution(assistant_message_lens, "num_assistant_tokens_per_example")
n_too_long = sum(l > 4096 for l in convo_lens)
print(f"\n{n_too_long} examples may be over the 4096 token limit, they will be truncated during fine-tuning")

# Pricing and default n_epochs estimate
MAX_TOKENS_PER_EXAMPLE = 4096

MIN_TARGET_EXAMPLES = 100
MAX_TARGET_EXAMPLES = 25000
TARGET_EPOCHS = 3
MIN_EPOCHS = 1
MAX_EPOCHS = 25

n_epochs = TARGET_EPOCHS
n_train_examples = len(dataset)
if n_train_examples * TARGET_EPOCHS < MIN_TARGET_EXAMPLES:
    n_epochs = min(MAX_EPOCHS, MIN_TARGET_EXAMPLES // n_train_examples)
elif n_train_examples * TARGET_EPOCHS > MAX_TARGET_EXAMPLES:
    n_epochs = max(MIN_EPOCHS, MAX_TARGET_EXAMPLES // n_train_examples)

n_billing_tokens_in_dataset = sum(min(MAX_TOKENS_PER_EXAMPLE, length) for length in convo_lens)
print(f"Dataset has ~{n_billing_tokens_in_dataset} tokens that will be charged for during training")
print(f"By default, you'll train for {n_epochs} epochs on this dataset")
print(f"By default, you'll be charged for ~{n_epochs * n_billing_tokens_in_dataset} tokens")

# Calculate the estimated cost for fine-tuning
cost_per_100k_tokens = 0.80  # Cost for every 100,000 tokens
estimated_cost = ((n_epochs * n_billing_tokens_in_dataset) / 100000) * cost_per_100k_tokens
print(f"Estimated cost for fine-tuning: approximately ${estimated_cost:.2f}") #I added this for actual cost based on current pricing

Num examples: 30
First example:
{'role': 'system', 'content': "Task: Please anonymize the following clinical note.\n\nSpecific Instructions: Replace all the following Protected health information (PHI) text with the '[censored]'.\n\n1) Censor any string or substring that has name, including patients, doctors, any acr"}
{'role': 'user', 'content': "Record date: \n\n\nRecord date: 2068-12-05\n\nNarrative History\t\n\n Patient presents for uri.  Walks -in .  Overdue for follow-up by 3 months.\n\n\n\nStarted last week.   Sinus pressure, post nasal drip , headache , ears blocked.  \n\nNo fevers.\n\nSlight white nasal cdischarge.\n\n\n\nTaking advil sinus only helps transiently.\n\nNot getting better or worse.\n\nNo chest symptoms - coughing etc..\n\n\n\nGot flu shot already.\n\n\n\n\n\nProblems\n\nFH breast cancer : 37 yo s\n\nFH myocardial infarction : mother died 66 yo\n\nHypertension \n\nUterine fibroids : u/s 2062\n\nSmoking : quit 2/67 s/p MI\n\nHyperlipidemia : CRF mild chol, cigs, HT