# Imports, Functions, Keys

In [1]:
from openai import OpenAI
import pandas as pd
from langchain import PromptTemplate
import random

from typing_extensions import override
from openai import AssistantEventHandler

from tqdm import tqdm

In [None]:
def LoadRFF_to_Pandas(file_path):
    """Load a pipe-delimited ``.RRF`` file into a pandas DataFrame.

    Every line is stripped of surrounding whitespace and split on ``|``; the
    resulting fields form one row. Columns are labelled by integer position
    and all values are kept as strings.

    NOTE(review): ``LoadRFF_to_Pandas`` is defined twice in this file; the
    later definition replaces this one at runtime.

    Args:
        file_path: Path of the file to read.

    Returns:
        pandas.DataFrame with one row per line of the file.
    """
    # Decode explicitly as UTF-8 so accented characters are read the same way
    # on every platform instead of depending on the locale's default encoding.
    with open(file_path, 'r', encoding='utf-8') as f:
        # Iterate the handle directly rather than materialising readlines().
        rows = [line.strip().split('|') for line in f]
    return pd.DataFrame(rows)


def ComputePrompts(df, n, prompt_template):
    """Build ``n`` prompts from concepts sampled at random from ``df``.

    Rows are drawn with replacement, so the same concept can be picked more
    than once and ``n`` may exceed the number of rows.

    Args:
        df: DataFrame with a ``STR`` column holding the concept names.
        n: Number of rows to sample.
        prompt_template: Object whose ``format(CONCEPT=...)`` returns the
            prompt text for one concept.

    Returns:
        Tuple ``(prompts, concepts)`` of two lists in the same order.
    """
    sampled_rows = df.sample(n, replace=True)
    concepts = [record['STR'] for _, record in sampled_rows.iterrows()]
    prompts = [prompt_template.format(CONCEPT=name) for name in concepts]
    return prompts, concepts

def ComputePrompts2(concept_list, n, prompt_template):
    """Build ``n`` prompts from concepts drawn at random from a plain list.

    ``random.choices`` samples with replacement, so concepts may repeat and
    ``n`` may exceed ``len(concept_list)``.

    Args:
        concept_list: Sequence of concept names to draw from.
        n: Number of concepts to draw.
        prompt_template: Object whose ``format(CONCEPT=...)`` returns the
            prompt text for one concept.

    Returns:
        Tuple ``(prompts, concepts)`` of two lists in the same order.
    """
    concepts = list(random.choices(concept_list, k=n))
    prompts = [prompt_template.format(CONCEPT=name) for name in concepts]
    return prompts, concepts


def LoadRFF_to_Pandas(file_path):
    """Load a pipe-delimited ``.RRF`` file into a pandas DataFrame.

    Every line is stripped of surrounding whitespace and split on ``|``; the
    resulting fields form one row. Columns are labelled by integer position
    and all values are kept as strings.

    NOTE(review): ``LoadRFF_to_Pandas`` is defined twice in this file; this
    later definition is the one in effect at runtime.

    Args:
        file_path: Path of the file to read.

    Returns:
        pandas.DataFrame with one row per line of the file.
    """
    # Decode explicitly as UTF-8 so accented characters are read the same way
    # on every platform instead of depending on the locale's default encoding.
    with open(file_path, 'r', encoding='utf-8') as f:
        # Iterate the handle directly rather than materialising readlines().
        rows = [line.strip().split('|') for line in f]
    return pd.DataFrame(rows)


def PrepareUMLS(root_path):
    """Combine UMLS concept names, definitions and semantic types.

    Reads ``MRCONSO.RRF``, ``MRDEF.RRF`` and ``MRSTY.RRF`` from ``root_path``
    and joins them into a single DataFrame, which is also shown with
    ``display`` before being returned.

    NOTE(review): ``display`` is not imported in this file; presumably the
    Jupyter/IPython built-in. Confirm before running outside a notebook.

    Args:
        root_path: Directory that contains the three ``.RRF`` files.

    Returns:
        DataFrame with columns ``CUI``, ``AUI``, ``STR``, ``DEF`` and ``STY``.
        Each CUI is reduced to one name/definition row before the semantic
        types are joined, so a CUI appears once per matching ``STY`` row.
    """
    # Concept names: positional columns 0, 7 and 14 become CUI, AUI and STR.
    names = LoadRFF_to_Pandas(f'{root_path}/MRCONSO.RRF')[[0, 7, 14]]
    names.columns = ['CUI', 'AUI', 'STR']

    # Definitions: positional columns 0, 1 and 5 become CUI, AUI and DEF.
    definitions = LoadRFF_to_Pandas(f'{root_path}/MRDEF.RRF')[[0, 1, 5]]
    definitions.columns = ['CUI', 'AUI', 'DEF']

    # Left join on (CUI, AUI): every name row is kept, with DEF filled in only
    # where a definition exists for that exact atom.
    # Sorting by CUI then DEF before dropping duplicates makes keep='first'
    # select, for each CUI, the row whose DEF sorts first; with pandas'
    # default na_position='last', rows without a definition come last.
    one_row_per_cui = (
        names.merge(definitions, on=['CUI', 'AUI'], how='left')
        .sort_values(by=['CUI', 'DEF'])
        .drop_duplicates(subset=['CUI'], keep='first')
    )

    # Semantic types: positional columns 0 and 3 become CUI and STY.
    semantic_types = LoadRFF_to_Pandas(f'{root_path}/MRSTY.RRF')[[0, 3]]
    semantic_types.columns = ['CUI', 'STY']

    # Left join on CUI: attach the semantic type(s) to each retained concept.
    merged = one_row_per_cui.merge(semantic_types, on=['CUI'], how='left')
    display(merged)

    return merged


def CreateVectorStore(file_paths, client):
    """Create a vector store and upload the given files into it.

    The store is created with the name "Electronic Health Records", the files
    are uploaded in one batch via ``upload_and_poll``, and the batch status,
    file counts and store ID are printed.

    Args:
        file_paths: Iterable of paths to the files to upload.
        client: Client object exposing ``beta.vector_stores``.

    Returns:
        The vector store object returned by ``client.beta.vector_stores.create``.
    """
    # Imported locally so the function carries its own dependency.
    from contextlib import ExitStack

    vector_store = client.beta.vector_stores.create(name="Electronic Health Records")

    # ExitStack closes every opened handle once the upload call returns (or
    # raises); previously the handles were left open.
    with ExitStack() as stack:
        file_streams = [stack.enter_context(open(path, "rb")) for path in file_paths]
        file_batch = client.beta.vector_stores.file_batches.upload_and_poll(
            vector_store_id=vector_store.id, files=file_streams
        )

    print(f"File batch status: {file_batch.status}")
    print(f"File batch counts: {file_batch.file_counts}")
    print(f"Vector store ID: {vector_store.id}")

    return vector_store


class EventHandler(AssistantEventHandler):
    """Stream event handler that post-processes a finished assistant message.

    NOTE(review): the rewritten message text and the collected citations are
    neither returned, stored nor printed in this class.
    """

    def __init__(self, client):
        """Keep a reference to ``client`` for looking up cited files."""
        super().__init__()
        self.client = client

    @override
    def on_message_done(self, message) -> None:
        """Replace annotation text with ``[index]`` markers and resolve file citations.

        Args:
            message: Completed message; only the text of its first content
                part is processed.
        """
        text_part = message.content[0].text
        citations = []
        for position, note in enumerate(text_part.annotations):
            marker = f"[{position}]"
            # Swap the raw annotation text for its numbered marker.
            text_part.value = text_part.value.replace(note.text, marker)
            # Not every annotation carries a file citation.
            source = getattr(note, "file_citation", None)
            if source:
                cited_file = self.client.files.retrieve(source.file_id)
                citations.append(f"{marker} {cited_file.filename}")


def GenerateSyntheticEHRs(client, user_prompts, assistant_id, vector_store_id, output_path, start_id=0, benchmark='semclinbr'):
    """Generate one synthetic EHR per prompt with an assistant and save each to disk.

    For every prompt a new thread is created with the vector store attached
    for file search, the assistant run is streamed to completion, and the
    text of the first message returned by ``messages.list`` is taken as the
    reply. A trimmed copy of the reply is written to ``output_path``:

    * ``'semclinbr'``: first 7 and last 4 characters removed, saved as
      ``{1000 + index + start_id}.xml``.
    * ``'portugueseclinicalner'``: first 13 and last 4 characters removed,
      saved as ``syntheticEHR_{index + start_id}_annotatedNER.conll``.

    NOTE(review): the fixed offsets presumably strip a wrapper around the
    reply (for example a Markdown code fence). Confirm against real outputs,
    since a reply without that wrapper would lose real content.

    Args:
        client: OpenAI client used for threads, runs and messages.
        user_prompts: Sequence of prompt strings (``len`` is taken for the
            progress bar).
        assistant_id: ID of the assistant that runs each thread.
        vector_store_id: ID of the vector store attached to each thread.
        output_path: Existing directory the output files are written to.
        start_id: Offset added to the prompt index when naming files.
        benchmark: ``'semclinbr'`` or ``'portugueseclinicalner'``.

    Returns:
        Tuple ``(full_responses, responses)``: the untrimmed and the trimmed
        reply for every prompt, in prompt order.

    Raises:
        ValueError: If ``benchmark`` is not one of the two supported values.
    """
    # Validate before any API call so an unsupported benchmark fails
    # immediately instead of after a thread and a run have been created.
    if benchmark not in ('semclinbr', 'portugueseclinicalner'):
        raise ValueError('Benchmark not recognized. Please choose between "semclinbr" and "portugueseclinicalner"')

    # Number of leading characters dropped from each raw reply.
    leading_chars = 7 if benchmark == 'semclinbr' else 13

    full_responses = []
    responses = []
    pbar = tqdm(enumerate(user_prompts), total=len(user_prompts))

    for index, prompt in pbar:
        thread = client.beta.threads.create(
            messages=[{"role": "user", "content": prompt}],
            tool_resources={"file_search": {"vector_store_ids": [vector_store_id]}},
        )

        pbar.set_description(f"Thread ID: {thread.id}")

        # Block until the streamed run has finished.
        with client.beta.threads.runs.stream(
            thread_id=thread.id,
            assistant_id=assistant_id,
            event_handler=EventHandler(client),
        ) as stream:
            stream.until_done()

        # Read the reply once; the first listed message is used.
        messages = client.beta.threads.messages.list(thread_id=thread.id)
        full_response = messages.data[0].content[0].text.value
        response = full_response[leading_chars:-4]

        full_responses.append(full_response)
        responses.append(response)

        if benchmark == 'semclinbr':
            target = f"{output_path}/{1000+index+start_id}.xml"
        else:
            target = f"{output_path}/syntheticEHR_{index+start_id}_annotatedNER.conll"

        # Write as UTF-8 so accented characters are stored identically on
        # every platform rather than in the locale's default encoding.
        with open(target, "w", encoding="utf-8") as f:
            f.write(response)

    return full_responses, responses

In [None]:
# Set your API key here
api_key = "<your_api_key>"
# Set your project ID here
project_id = "<your_project_id>"

# Prepare UMLS data

In [None]:
# Load and merge the UMLS tables found under the given directory.
concepts_def_POR = PrepareUMLS(root_path="<your_umls_data_path>")

# Hand-written pool of neurological condition names in Portuguese; used as
# the concept list for ComputePrompts2 in the PortugueseClinicalNER section.
neurological_diseases = [
    "Epilepsia",
    "Esclerose Múltipla",
    "Doença de Alzheimer",
    "Doença de Parkinson",
    "Enxaqueca",
    "Acidente Vascular Cerebral (AVC)",
    "Neuropatia Diabética",
    "Esclerose Lateral Amiotrófica (ELA)",
    "Meningite",
    "Encefalite",
    "Miastenia Gravis",
    "Síndrome de Guillain-Barré",
    "Neuralgia do Trigémeo",
    "Neuropatia Periférica",
    "Síndrome das Pernas Inquietas",
    "Hidrocefalia",
    "Demência com Corpos de Lewy",
    "Paralisia de Bell",
    "Doença de Huntington",
    "Espondilose Cervical",
    "Doença de Creutzfeldt-Jakob",
    "Síndrome Pós-Poliomielite",
    "Malformação de Chiari",
    "Síndrome de Tourette",
    "Neurofibromatose",
    "Ataxia de Friedreich",
    "Neuralgia Pós-Herpética",
    "Doença de Wilson",
    "Síndrome de Arnold-Chiari",
    "Síndrome de Lennox-Gastaut"
]

# SemClinBr augmentation

## Prepare user prompts and system instructions

In [6]:
# System prompt for the SemClinBr assistant: asks for Portuguese health
# reports annotated with the five tags defined in the text below.
system_instruction = '''You are an advanced AI language model that generates and annotates synthetic health reports in Portuguese. Your task is to generate high-quality health reports and provide accurate Named Entity Recognition annotations based on the provided definitions and examples.

### Guidelines:
1. Language: Generate text in Portuguese.
2. Health Report Context: Ensure the text reflects typical health report scenarios. Follow the structure of the reports uploaded to the system instructions.
3. Annotations: Annotate entities accurately using the provided definitions and examples. Follow the format used in the file examples uploaded to the system instructions.

### Annotation Tag Definitions:
1. Disease or Syndrome: A condition that affects the body or mind, with specific symptoms and signs.
2. Sign or Symptom: Signs are objective evidence of disease perceptible to the examining physician, while symptoms are the patient's subjective experiences.
3. Quantitative Concept: A concept that represents measurable quantities. These could include physical quantities, such as length or mass.
4. Laboratory or Test Result: The outcome of a laboratory test or diagnostic procedure. This result provides information about the extent of disease, condition, or abnormality.
5. Diagnostic Procedure: A medical process to determine the presence, extent, or cause of a disease or condition. Diagnostic procedures include laboratory tests, biopsies, and other methods used to diagnose medical conditions.'''

In [7]:
# User prompt for SemClinBr; {CONCEPT} is replaced with one condition name
# per prompt, and the reply is requested as an XML file.
template = """Considering the system instructions given and the files uploaded to it, generate a complete, extensive, and high-quality electronic health report, with around 40 entities for tag annotation across all annotation tags, for a patient with the following condition: "{CONCEPT}". Generate annotations according to the annotation tag definitions, while considering all the tags mentioned in it. Also, pay close attention to how the tags were used in the uploaded file examples. Use several abbreviations just like in the file examples uploaded. Generate an XML file format like the text files I uploaded. Your output/response should only have the generated file."""

# Wrap the text so that .format(CONCEPT=...) can be called on it.
prompt_template = PromptTemplate(input_variables=["CONCEPT"],
                                 template = template
                                )

In [None]:
# Number of prompts to build. ComputePrompts samples rows with replacement,
# so the same concept may be drawn more than once.
N = 300
prompts, concepts = ComputePrompts(concepts_def_POR, N, prompt_template)

## SDG using OpenAI API

In [8]:
# Create a client, authenticated with the API key and project ID set earlier
# in this file.
client = OpenAI(project=project_id, api_key=api_key)

### Prepare EHRs to add to the assistant

In [None]:
# Example reports to upload; relative paths are resolved against the current
# working directory when the files are opened.
file_paths = ["real_ehr_example_01.txt", 
                "real_ehr_example_02.txt",
                "real_ehr_example_03.txt",
                "real_ehr_example_04.txt"
                ]

# Creates a new vector store and uploads the files above into it.
vector_store = CreateVectorStore(file_paths, client)

### Create the assistant

In [28]:
# Create an assistant
# Sampling settings and model name passed to the assistant below.
temperature = 0.5
top_p = 0.5
model = "gpt-4o-2024-08-06"

assistant = client.beta.assistants.create(name="Electronic Health Records Assistant",
                                          instructions=system_instruction,
                                          model=model,
                                          tools=[{"type": "file_search"}],
                                          # Left disabled: GenerateSyntheticEHRs passes the vector store
                                          # on every thread it creates instead.
                                          #tool_resources={"file_search": {"vector_store_ids": [vector_store.id]}},
                                          temperature=temperature,
                                          top_p=top_p
                                          )

### Generate GPT-4o SEHRs

In [None]:
# benchmark is left at its default ('semclinbr'), so files are named
# 1000 + index + start_id; with start_id=1 the first file is 1001.xml.
full_responses, responses = GenerateSyntheticEHRs(client=client,
                                                  user_prompts=prompts, 
                                                  assistant_id=assistant.id, 
                                                  vector_store_id=vector_store.id,
                                                  output_path="<your_output_path>",
                                                  start_id=1
                                                  )

# PortugueseClinicalNER augmentation

## Prepare user prompts and system instructions

In [5]:
# System prompt for the PortugueseClinicalNER assistant: asks for European
# Portuguese health reports annotated with the six tags defined in the text
# below. Rebinds the name used in the SemClinBr section.
system_instruction = '''You are an advanced AI language model that generates and annotates synthetic health reports in Portuguese. Your task is to generate high-quality health reports and provide accurate Named Entity Recognition annotations based on the provided definitions and examples.

### Guidelines:
1. Language: Generate text in European Portuguese.
2. Health Report Context: Ensure the text reflects typical health report scenarios. Follow the structure of the reports uploaded to the system instructions.
3. Annotations: Annotate entities accurately using the provided definitions and examples. Follow the format used in the file examples uploaded to the system instructions.

### Annotation Tag Definitions:
1. THER: This tag identifies references to therapeutic interventions, including medications, treatments, and procedures used for patient management. 
2. C: This tag highlights descriptions of medical conditions, symptoms, or diagnoses. 
3. OBS: This tag is used for supplementary clinical observations that do not necessarily describe a medical condition but provide additional information about the patient's state.
4. DT: This tag marks temporal information, including specific dates, timeframes, or durations relevant to the patient's history or condition.
5. CH: This tag captures descriptive details that characterize conditions, symptoms, or findings.
6. R: This tag refers to the outcomes or findings from diagnostic tests, examinations, or treatments.'''

In [6]:
# User prompt for PortugueseClinicalNER; {CONCEPT} is replaced with one
# condition name per prompt, and the reply is requested in CoNLL-2003 format.
template = """Considering the system instructions given and the files uploaded to it, generate a complete, extensive, and high-quality electronic health report, with around 40 entities for tag annotation across all annotation tags, for a patient with the following condition: "{CONCEPT}". Generate annotations according to the annotation tag definitions, while considering all the tags mentioned in it. Also, pay close attention to how the tags were used in the uploaded file examples. Use several abbreviations just like in the file examples uploaded. Generate a CoNLL-2003 file format like the text files I uploaded. Your output/response should only have the generated file."""

# Wrap the text so that .format(CONCEPT=...) can be called on it.
prompt_template = PromptTemplate(input_variables=["CONCEPT"],
                                 template = template
                                )

In [None]:
# Number of prompts to build. ComputePrompts2 draws with replacement from the
# neurological_diseases list, so names repeat across the prompts.
N = 100
prompts, concepts = ComputePrompts2(neurological_diseases, N, prompt_template)

## SDG using OpenAI API

In [8]:
# Create a client
# Same construction as in the SemClinBr section; presumably repeated so this
# section can be run on its own.
client = OpenAI(project=project_id, api_key=api_key)

### Prepare EHRs to add to the assistant

In [None]:
# Example reports to upload; relative paths are resolved against the current
# working directory when the files are opened.
file_paths = ["real_ehr_example_01.txt", 
                "real_ehr_example_02.txt",
                "real_ehr_example_03.txt",
                "real_ehr_example_04.txt",
                ]

# Creates a new vector store and uploads the files above into it; rebinds
# the name used in the SemClinBr section.
vector_store = CreateVectorStore(file_paths, client)

### Create the assistant

In [10]:
# Create an assistant
# Sampling settings and model name passed to the assistant below.
temperature = 0.5
top_p = 0.5
model = "gpt-4o-2024-08-06"

assistant = client.beta.assistants.create(name="Electronic Health Records Assistant",
                                          instructions=system_instruction,
                                          model=model,
                                          tools=[{"type": "file_search"}],
                                          # Left disabled: GenerateSyntheticEHRs passes the vector store
                                          # on every thread it creates instead.
                                          #tool_resources={"file_search": {"vector_store_ids": [vector_store.id]}},
                                          temperature=temperature,
                                          top_p=top_p
                                          )

### Generate GPT-4o v3 SEHRs

In [None]:
# With benchmark='portugueseclinicalner' files are named
# syntheticEHR_{index + start_id}_annotatedNER.conll; with start_id=1 the
# first file is syntheticEHR_1_annotatedNER.conll.
full_responses, responses = GenerateSyntheticEHRs(client=client,
                                                  user_prompts=prompts, 
                                                  assistant_id=assistant.id, 
                                                  vector_store_id=vector_store.id,
                                                  output_path="<your_output_path>",
                                                  start_id=1,
                                                  benchmark='portugueseclinicalner'
                                                  )