# GoLLIE for Medical Entity Extraction

### Import requirements
See the requirements.txt file in the main directory to install the required dependencies.


In [2]:
import sys
sys.path.append("../") # Add the GoLLIE base directory to sys path

In [3]:
import rich 
import logging
from src.model.load_model import load_model
import black
import inspect
from jinja2 import Template as jinja2Template
import tempfile
from src.tasks.utils_typing import AnnotationList
logging.basicConfig(level=logging.INFO)
from typing import Dict, List, Type

  from .autonotebook import tqdm as notebook_tqdm


## Load GoLLIE

Load GOLLIE-7B from the huggingface-hub.
Use the AutoModelForCausalLM.from_pretrained function if you prefer it. However, the GoLLIE authors provide a handy load_model function with many functionalities already implemented, which will help you reproduce the published results.

Please note that setting use_flash_attention=True is mandatory. Our flash attention implementation has small numerical differences compared to the attention implementation in Huggingface. Using use_flash_attention=False will result in the model producing inferior results. Flash attention requires an available CUDA GPU. Running GOLLIE pre-trained models on a CPU is not supported.

- Set force_auto_device_map=True to automatically load the model on available GPUs.
- Set quantization=4 if the model doesn't fit in your GPU memory.

In [4]:
# Load GoLLIE and its tokenizer through the project's load_model helper.
model, tokenizer = load_model(
    # Which checkpoint to load and in what mode.
    model_weights_name_or_path="HiTZ/GoLLIE-7B",
    inference=True,
    use_lora=False,
    # Precision / memory: the notes above suggest quantization=4 when the
    # model does not fit in GPU memory.
    quantization=None,
    torch_dtype="bfloat16",
    # Described above as mandatory for GoLLIE checkpoints.
    use_flash_attention=True,
    # Automatically load the model on the available GPUs.
    force_auto_device_map=True,
)

INFO:root:Loading model model from HiTZ/GoLLIE-7B
INFO:root:We will load the model using the following device map: auto and max_memory: None
INFO:root:Loading model with dtype: torch.bfloat16


ImportError: Please install Flash Attention: `pip install flash-attn --no-build-isolation`

## Define the guidelines

First, we will define the labels and guidelines for the task. We will represent them as Python classes.

The following guidelines have been defined for this example. They were not part of the pre-training dataset. Therefore, we will run GOLLIE in zero-shot settings using unseen labels.

We will use the `Generic` class, which is a versatile class that allows for the implementation of any task you want. However, since the model has never seen the Generic label during training, we will rename it to Template, which is recognized by the model (as it was used in the Tacred dataset).

We will define several classes: `Illness`, `Medication`, `PatientData`, `HospitalizationData`. Each class will have a definition and a set of slots that the model needs to fill. Each slot also requires a type definition and a short description, which can include examples. For instance, for the `Illness` class, we define three slots:

- The `mention`, which is the name of the patient's illness and should be a string.
- The `treatment`, which is defined as a list of treatments or interventions used to manage the illness. GoLLIE will fill this slot with a list of strings.
- The `symptoms`, which is defined as a list of symptoms. Therefore, GoLLIE will also fill this slot with a list of strings.


In [None]:
from typing import List

from src.tasks.utils_typing import dataclass
from src.tasks.utils_typing import Generic as Template
from dataclasses import dataclass, field


"""
Entity definitions
"""


# Guideline class for medications: the name (`mention`) plus dosage, route of
# administration and the purposes it is prescribed for.
# NOTE(review): `dataclass` is imported twice in this cell; the later
# `from dataclasses import dataclass, field` wins, so this is the
# standard-library decorator, not the one from src.tasks.utils_typing.
# Confirm that is intended.
@dataclass
class Medication(Template):
    """Refers to a drug or substance used to diagnose, cure, treat, or prevent disease.
    Medications can be administered in various forms and dosages and are crucial
    in managing patient health conditions. They can be classified based on their
    therapeutic use, mechanism of action, or chemical characteristics."""

    mention: str
    """
    The name of the medication.
    Such as: "Aspirina", "Ibuprofeno", "Paracetamol".
    """
    dosage: str  # The amount and frequency at which the medication is prescribed. Such as: "100 mg al día", "200 mg dos veces al día"
    route: str  # The method of administration for the medication. Such as: "oral", "intravenoso", "tópico"
    purpose: List[str]  # List of reasons or conditions for which the medication is prescribed. Such as: ["dolor", "control de azúcar en la sangre", "inflamación"]
    

# Guideline class for illnesses: the name (`mention`), a list of `symptoms`
# and a list of `treatment` entries.
# NOTE(review): the class name is spelled `Ilness` (one "l") while the
# notebook text calls it `Illness`. ENTITY_DEFINITIONS refers to it as
# `Ilness`, so a rename has to touch both places.
@dataclass
class Ilness(Template):
    """Refers to a health condition or disease that affects the body's normal functioning.
    Illnesses can be caused by various factors, such as infections, genetic disorders,
    lifestyle choices, or environmental factors. They can affect different body systems
    and have varying degrees of severity."""
    
    mention: str
    """
    The name of the illness or health condition.
    Such as: "diabetes", "cáncer", "hipertensión".
    """
    symptoms: List[str] # List of signs or symptoms associated with the illness. Such as: ["dolor de cabeza", "fatiga", "fiebre"]
    treatment: List[str] # List of treatments or interventions used to manage the illness. Such as: ["medicamentos", "cirugía", "terapia física"]


# Guideline class for a hospital stay: admission date, discharge date and the
# reason for the stay, all kept as plain strings.
# Derives from Template like the other guideline classes in this cell, which
# the notebook text says is the base the model recognizes.
@dataclass
class HospitalizationData(Template):
    """Refers to information related to a patient's hospitalization, including the
    admission date, discharge date, and reason for hospitalization. Hospitalization
    data is essential for tracking patient health status, treatment progress, and
    healthcare resource utilization."""

    admission_date: str  # The date on which the patient was admitted to the hospital.
    discharge_date: str  # The date on which the patient was discharged from the hospital.
    reason: str  # The reason or cause for the patient's hospitalization.
    
    
# Guideline class for patient-level facts: name, age (an integer) and the
# urgency level of the patient's condition.
# Derives from Template like the other guideline classes in this cell, which
# the notebook text says is the base the model recognizes.
@dataclass
class PatientData(Template):
    """Refers to information related to a patient's medical history, including
    name, age or urgency. Patient data is essential for healthcare providers
    to provide appropriate care and make informed decisions about patient management."""

    name: str  # The name of the patient.
    age: int  # The age of the patient.
    urgency: str  # The urgency level of the patient's condition.
    
    
    
# The entity classes defined in this cell, collected in one list.
# The list holds the classes themselves (not instances), hence List[type].
# PatientData is defined alongside the others and is named in the notebook
# text, so it is registered here as well.
ENTITY_DEFINITIONS: List[type] = [
    Medication,
    Ilness,
    HospitalizationData,
    PatientData,
]

if __name__ == "__main__":
    # `In` is not defined in this file; it has to be provided by IPython.
    # NOTE(review): presumably IPython's input history, which would make
    # In[-1] the source text of the cell being executed. Confirm before
    # relying on cell_txt.
    cell_txt = In[-1]


In [None]:


# Guideline class for launch vehicles: the vehicle name (`mention`), the
# operating company and the names of the crew on board.
@dataclass
class Launcher(Template):
    """Refers to a vehicle designed primarily to transport payloads from the Earth's
    surface to space. Launchers can carry various payloads, including satellites,
    crewed spacecraft, and cargo, into various orbits or even beyond Earth's orbit.
    They are usually multi-stage vehicles that use rocket engines for propulsion."""

    mention: str
    """
    The name of the launcher vehicle.
    Such as: "Saturn V", "Atlas V", "Soyuz", "Ariane 5"
    """
    space_company: str  # The company that operates the launcher. Such as: "Blue origin", "ESA", "Boeing", "ISRO", "Northrop Grumman", "Arianespace"
    crew: List[str]  # Names of the crew members boarding the Launcher. Such as: "Neil Armstrong", "Michael Collins", "Buzz Aldrin"
    

# Guideline class for space missions: the mission name (`mention`), its start
# date, the launch site and the destination, all kept as plain strings.
@dataclass
class Mission(Template):
    """Any planned or accomplished journey beyond Earth's atmosphere with specific objectives, 
    either crewed or uncrewed. It includes missions to satellites, the International 
    Space Station (ISS), other celestial bodies, and deep space."""
    
    mention: str
    """
    The name of the mission. 
    Such as: "Apollo 11", "Artemis", "Mercury"
    """
    date: str # The start date of the mission
    departure: str # The place from which the vehicle will be launched. Such as: "Florida", "Houston", "French Guiana"
    destination: str # The place or planet to which the launcher will be sent. Such as "Moon", "low-orbit", "Saturn"


# NOTE(review): this rebinds ENTITY_DEFINITIONS. If the earlier cell that
# lists the medical classes has already run, its list is replaced by the two
# space-domain classes below. Confirm this cell is still wanted in a
# medical-entity notebook.
# The list holds the classes themselves (not instances), hence List[type].
ENTITY_DEFINITIONS: List[type] = [
    Launcher,
    Mission,
]
    
if __name__ == "__main__":
    # `In` is not defined in this file.
    # NOTE(review): presumably IPython's input history, which would make
    # In[-1] the source text of the cell being executed. Confirm before
    # relying on cell_txt.
    cell_txt = In[-1]