In [1]:
from __future__ import annotations

import os
from typing import List, Optional, Dict, Type, Union

from dotenv import load_dotenv
from pydantic import BaseModel, Field, create_model, field_serializer
from neo4j import GraphDatabase

from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
import re

In [2]:
# Load variables from a .env file (if present) before reading the settings below.
load_dotenv()

# OpenAI settings. OPENAI_API_KEY is None when the variable is not set.
OPENAI_MODEL = os.environ.get("OPENAI_MODEL", "gpt-5-mini")
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")

# Neo4j connection settings; the fallbacks point at a local instance.
# NOTE(review): the fallback password is the literal "neo4j" — confirm this is
# acceptable outside local development.
NEO4J_URI = os.environ.get("NEO4J_URI", "bolt://localhost:7687")
NEO4J_USER = os.environ.get("NEO4J_USER", "neo4j")
NEO4J_PASSWORD = os.environ.get("NEO4J_PASSWORD", "neo4j")

STRUCTURED SCHEMA (Pydantic) — only allowed entities/edges appear

In [3]:
from entities import Entity_Collector

# Lookup table: class name -> entity class, for every class yielded by
# Entity_Collector(). The class name is the key used everywhere else to select
# entity types.
ENTITY_MODEL_REGISTRY: Dict[str, Type[BaseModel]] = dict(
    (entity_model.__name__, entity_model) for entity_model in Entity_Collector()
)

len(ENTITY_MODEL_REGISTRY)

32

In [4]:
# ---------- 2) (Toy) prefilter that decides which types to include ----------
def prefilter_entity_types(
    all_types: List[str],
    allowed: Optional[List[str]] = None,
) -> List[str]:
    """
    Toy prefilter: keep only the entity type names that are on an allow-list.

    This is a placeholder for the real selection logic. By default it keeps
    "Gene", "SmallMolecule" and "CNSFunction".

    Args:
        all_types: Candidate entity type names.
        allowed: Names to keep (any iterable of strings works at runtime).
            When None, the demo default listed above is used.

    Returns:
        The entries of ``all_types`` that are allowed, in their original order.
    """
    if allowed is None:
        allowed = ("Gene", "SmallMolecule", "CNSFunction")
    allowed_names = set(allowed)  # set membership keeps the filter linear
    return [t for t in all_types if t in allowed_names]

# ---------- 3) Dynamic Entities model factory ----------
def build_entities_model(selected_types: List[str]) -> Type[BaseModel]:
    """
    Create a Pydantic model class holding one optional list field per selected type.

    For every name ``T`` in ``selected_types`` the model gets a field ``Ts``
    (the name with an "s" appended) typed ``Optional[List[<model for T>]]`` with
    default None. Fields left at None are only dropped from the output when the
    instance is dumped with ``exclude_none=True``.

    Args:
        selected_types: Entity type names; each must be a key of
            ENTITY_MODEL_REGISTRY.

    Returns:
        The dynamically created ``ExtractedEntities`` model class.

    Raises:
        KeyError: If one or more names are not registered.
    """
    # Fail early with all offending names instead of a bare KeyError on the first one.
    unknown = [t for t in selected_types if t not in ENTITY_MODEL_REGISTRY]
    if unknown:
        raise KeyError(
            f"Unknown entity type(s) {unknown}; known types: {sorted(ENTITY_MODEL_REGISTRY)}"
        )

    fields = {}
    for t in selected_types:
        item_model = ENTITY_MODEL_REGISTRY[t]
        # Optional[List[item_model]] default None
        fields[t + "s"] = (Optional[List[item_model]], Field(default=None))

    return create_model(
        "ExtractedEntities",
        __doc__="Holds the lists of extracted entities divided by type",
        **fields,
    )

In [5]:
# Test the functions above: registry names -> prefilter -> dynamic model.
all_types = list(ENTITY_MODEL_REGISTRY.keys())
selected_types = prefilter_entity_types(all_types)
EntitiesModel = build_entities_model(selected_types)

# Show the type of the generated class, then its field definitions.
print(type(EntitiesModel))
EntitiesModel.model_fields

<class 'pydantic._internal._model_construction.ModelMetaclass'>


{'Genes': FieldInfo(annotation=Union[List[Gene], NoneType], required=False, default=None),
 'SmallMolecules': FieldInfo(annotation=Union[List[SmallMolecule], NoneType], required=False, default=None),
 'CNSFunctions': FieldInfo(annotation=Union[List[CNSFunction], NoneType], required=False, default=None)}

In [6]:
EntitiesModel.model_json_schema() # also docstrings of nested classes are included !!!

{'$defs': {'CNSFunction': {'description': 'A specific function or process carried out by the central nervous system (CNS), which includes the brain and spinal cord.\nExamples: "Vision", "motor control", "temperature regulation", "emotional regulation", "language comprehension", "hearing", "balance", "breathing control".',
   'properties': {'label': {'description': 'Surface form (name) of the CNSFunction as it appears in the text.',
     'title': 'Label',
     'type': 'string'},
    'type': {'anyOf': [{'enum': ['SensoryFunction',
        'MotorFunction',
        'RegulatoryFunction',
        'HigherCognitiveFunction'],
       'type': 'string'},
      {'type': 'null'}],
     'default': None,
     'description': 'Type of CNS function, None if generic or other.',
     'title': 'Type'}},
   'required': ['label'],
   'title': 'CNSFunction',
   'type': 'object'},
  'Gene': {'description': 'A stretch of DNA (or locus) that encodes (or is associated with) a functional product, such as an RNA or

SYSTEM PROMPT

In [7]:
# ---------- 4) Prompt builder that reflects the selected types ----------
def build_system_prompt(selected_types: List[str]) -> str:
    """
    Render the NER system prompt for the given entity classes.

    The class names are joined with ", " and inserted into the
    "Return ONLY these entity classes" line; the rest of the prompt is fixed.

    NOTE(review): the prompt asks the model to "generate a unique id" — confirm
    the entity schemas actually expose an id field.
    """
    allowed_classes = ", ".join(selected_types)
    prompt_lines = (
        "You are a state-of-the-art biological information extraction model that performs Named Entity Recognition (NER);",
        "    ",  # whitespace-only line, kept as in the original prompt text
        f"Return ONLY these entity classes: {allowed_classes}.",
        "",
        "INSTRUCTIONS:",
        "- Extract entities only if they match the allowed entity class.",
        "- Use the exact surface form from the text for labels (do not normalize or invent) but generate a unique id.",
        "- If nothing is found for a class, it should be None (so it can be dropped from the JSON).",
        "- For class fields like types and other attributes, use your knowledge to fill them, if too ambiguous use None.",
        '- Keep concise entities labels, no need to specify their class in the label (e.g from "... protein pt42 ..." extract "pt42" as label not "protein pt42").',
        "",
        "Output MUST conform to the provided schema you were given via the structured output tool.",
        "",  # yields the trailing newline
    )
    return "\n".join(prompt_lines)


LLM CHAIN

In [8]:
# ---------- 5) Chain factory ----------
def build_chain(selected_types: List[str], OPENAI_MODEL: str = "gpt-5-mini"):
    """
    Assemble the extraction chain for the given entity types.

    Args:
        selected_types: Entity type names, forwarded to build_entities_model
            and build_system_prompt.
        OPENAI_MODEL: Model name passed to ChatOpenAI.

    Returns:
        A ``(chain, schema)`` tuple: the prompt piped into the model with
        structured output bound to ``schema``, and the schema class itself.
    """
    output_schema = build_entities_model(selected_types)

    # The user message carries the only placeholder, {input_text}.
    # NOTE(review): the system text goes through the same template machinery, so
    # braces in a type name would presumably be read as placeholders — confirm.
    messages = [
        ("system", build_system_prompt(selected_types)),
        ("user", "<Input Text Start>\n\n{input_text}\n\n<Input Text End>\n\nReturn the structured extraction."),
    ]
    prompt_template = ChatPromptTemplate.from_messages(messages)

    chat_model = ChatOpenAI(model=OPENAI_MODEL, temperature=0, reasoning={"effort": "minimal"})
    structured_model = chat_model.with_structured_output(output_schema)

    return prompt_template | structured_model, output_schema

In [9]:
# ---------- 6) Public API ----------
# Derived from the registry itself so this cell does not depend on a variable
# created in an earlier scratch/test cell.
ALL_TYPES = list(ENTITY_MODEL_REGISTRY.keys())

def extract(text: str, OPENAI_MODEL: str = "gpt-5-mini") -> Dict:
    """
    Run entity extraction on ``text`` and return a plain dict of the findings.

    Args:
        text: Input text to extract entities from.
        OPENAI_MODEL: Model name forwarded to build_chain. The default is the
            literal "gpt-5-mini"; the environment-derived module setting is only
            used if the caller passes it explicitly.

    Returns:
        The model output dumped to a dict, without None values and without
        top-level keys whose value is an empty list or dict.
    """
    selected_types = prefilter_entity_types(ALL_TYPES)            # <- your real prefilter goes here
    chain, _ = build_chain(selected_types, OPENAI_MODEL)
    result = chain.invoke({"input_text": text})                   # instance of the dynamic entities model
    # Dump w/ exclude_none to drop anything not present
    data = result.model_dump(exclude_none=True)

    # Also drop keys that are empty containers (in case the LLM returned [])
    return {k: v for k, v in data.items() if v not in (None, [], {})}

In [10]:
# Single-sentence example for the extraction pipeline.
sample_text = "A Hexachlorobenzene metabolite increases methylation of a mutant form of the CDKN2A promoter in the nucleus of HeLa cells"

# extract() invokes the LLM chain, so this cell performs a model call.
print("== Extracting NER/RE with schema constraints ==")
extraction = extract(sample_text)
print(extraction)

== Extracting NER/RE with schema constraints ==
{'Genes': [{'label': 'CDKN2A'}], 'SmallMolecules': [{'label': 'Hexachlorobenzene metabolite'}]}


In [11]:
# sample_text = (
#         "In mouse cells, various proteins play crucial roles. "
#         "For years, we considered insulin and EGFR as a major factor in causing diabetes."
#         "They can be inhibited by giving Imatinib to the mouse embryo within E10.5-E12.5"
#     )

# print("== Extracting NER/RE with schema constraints ==")
# extraction = extract(sample_text)
# print(extraction)

In [15]:
extraction

{'Genes': [{'label': 'CDKN2A'}],
 'SmallMolecules': [{'label': 'Hexachlorobenzene metabolite'}]}

In [13]:
from typing import Literal, Type
from pydantic import BaseModel, Field, create_model

def _literal_type(values: list[str]):

    return Literal[*values]  

_literal_type(['a', 'b', 'c', 'a', 'b', 'd'])

typing.Literal['a', 'b', 'c', 'd']