In [1]:
from __future__ import annotations

import os
import re
from typing import Dict, Iterable, List, Optional, Type, Union

from dotenv import load_dotenv
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
from neo4j import GraphDatabase
from pydantic import BaseModel, Field, create_model, field_serializer

In [2]:
# Load variables from a .env file (if present) into the process environment
# before any of the settings below are read.
load_dotenv()

# --- OpenAI settings ---
OPENAI_MODEL = os.environ.get("OPENAI_MODEL", "gpt-5-mini")
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")

# --- Neo4j connection settings (defaults point at a local instance) ---
NEO4J_URI = os.environ.get("NEO4J_URI", "bolt://localhost:7687")
NEO4J_USER = os.environ.get("NEO4J_USER", "neo4j")
NEO4J_PASSWORD = os.environ.get("NEO4J_PASSWORD", "neo4j")

STRUCTURED SCHEMA (Pydantic) — only the allowed entity types appear in the output model

In [3]:
from entities import Entity_Collector

# Map each entity model's class name (e.g. "Gene") to the model class itself;
# if two classes share a name, the one yielded last wins.
ENTITY_MODEL_REGISTRY: Dict[str, Type[BaseModel]] = dict(
    (model_cls.__name__, model_cls) for model_cls in Entity_Collector()
)

len(ENTITY_MODEL_REGISTRY)

32

In [4]:
# ---------- 2) (Toy) prefilter that decides which types to include ----------
# Entity types the toy prefilter keeps when no explicit allow-list is given.
DEFAULT_ALLOWED_ENTITY_TYPES = frozenset(
    {
        "Gene",
        "Drug",
        "Cell",
        "TissueOrOrgan",
        "MouseDevelopmentalTimepoint",
        "BiologicalProcess",
        "Disease",
        "CNSFunction",
    }
)


def prefilter_entity_types(
    all_types: List[str],
    allowed: Optional[Iterable[str]] = None,
) -> List[str]:
    """Select which entity types the extractor should ask for.

    Placeholder for the real prefilter logic: it simply keeps the types that
    appear in an allow-list.

    Args:
        all_types: candidate entity class names (e.g. the registry keys).
        allowed: names to keep. Defaults to ``DEFAULT_ALLOWED_ENTITY_TYPES``
            (Gene, Drug, Cell, TissueOrOrgan, MouseDevelopmentalTimepoint,
            BiologicalProcess, Disease, CNSFunction).

    Returns:
        The members of ``all_types`` that are allowed, in their original order.
    """
    keep = DEFAULT_ALLOWED_ENTITY_TYPES if allowed is None else frozenset(allowed)
    return [t for t in all_types if t in keep]

# ---------- 3) Dynamic Entities model factory ----------
def build_entities_model(selected_types: List[str]) -> Type[BaseModel]:
    """
    Create a Pydantic model class with only the selected types as optional List fields.
    Each field is omitted from the final JSON if None.
    """
    fields = {}
    for t in selected_types:
        item_model = ENTITY_MODEL_REGISTRY[t]
        # Optional[List[item_model]] default None
        fields[t + "s"] = (Optional[List[item_model]], Field(default=None))

    model_name = "ExtractedEntities" 
    # Create the model
    Entities = create_model(model_name, **fields, __doc__="Holds the lists of extracted entities divided by type")

    return Entities

In [5]:
# Sanity-check the helpers above: build the entities model for the prefiltered types.
all_types = [*ENTITY_MODEL_REGISTRY]
selected_types = prefilter_entity_types(all_types)
EntitiesModel = build_entities_model(selected_types)

print(type(EntitiesModel))
EntitiesModel.model_fields

<class 'pydantic._internal._model_construction.ModelMetaclass'>


{'Genes': FieldInfo(annotation=Union[List[Gene], NoneType], required=False, default=None),
 'Diseases': FieldInfo(annotation=Union[List[Disease], NoneType], required=False, default=None),
 'Drugs': FieldInfo(annotation=Union[List[Drug], NoneType], required=False, default=None),
 'Cells': FieldInfo(annotation=Union[List[Cell], NoneType], required=False, default=None),
 'TissueOrOrgans': FieldInfo(annotation=Union[List[TissueOrOrgan], NoneType], required=False, default=None),
 'BiologicalProcesss': FieldInfo(annotation=Union[List[BiologicalProcess], NoneType], required=False, default=None),
 'CNSFunctions': FieldInfo(annotation=Union[List[CNSFunction], NoneType], required=False, default=None),
 'MouseDevelopmentalTimepoints': FieldInfo(annotation=Union[List[MouseDevelopmentalTimepoint], NoneType], required=False, default=None)}

In [6]:
# Docstrings of the nested entity classes show up as "description" entries in the schema.
EntitiesModel.model_json_schema(mode="validation")

{'$defs': {'BiologicalProcess': {'description': 'One or more causally connected executions of molecular functions.\nExamples: "cell cycle", "apoptosis", "DNA repair", "signal transduction", "angiogenesis", "autophagy", "immune response".',
   'properties': {'label': {'description': 'Surface form (name) of the BiologicalProcess as it appears in the text.',
     'title': 'Label',
     'type': 'string'}},
   'required': ['label'],
   'title': 'BiologicalProcess',
   'type': 'object'},
  'CNSFunction': {'description': 'A specific function or process carried out by the central nervous system (CNS), which includes the brain and spinal cord.\nExamples: "Vision", "motor control", "temperature regulation", "emotional regulation", "language comprehension", "hearing", "balance", "breathing control".',
   'properties': {'label': {'description': 'Surface form (name) of the CNSFunction as it appears in the text.',
     'title': 'Label',
     'type': 'string'},
    'type': {'anyOf': [{'enum': ['Senso

SYSTEM PROMPT

In [7]:
# ---------- 4) Prompt builder that reflects the selected types ----------
def build_system_prompt(selected_types: List[str]) -> str:
    """Render the NER system prompt restricted to ``selected_types``.

    Args:
        selected_types: entity class names, listed comma-separated and in the
            given order as the only classes the model may return.

    Returns:
        The system prompt text. It refers to a structured-output schema that
        is supplied to the model separately.
    """
    type_list = ", ".join(selected_types)
    # The prompt body stays at column 0 inside the triple-quoted f-string so no
    # stray indentation leaks into the text sent to the model.
    return f"""You are a state-of-the-art biological information extraction model that performs Named Entity Recognition (NER).

Return ONLY these entity classes: {type_list}.

INSTRUCTIONS:
- Extract entities only if they match one of the allowed entity classes.
- Use the exact surface form from the text for labels (do not normalize or invent).
- Do not repeat/duplicate entities even if they appear multiple times in the text.
- If nothing is found for a class, it should be None (so it can be dropped from the JSON).
- For class fields like types and other attributes, use your knowledge to fill them; if too ambiguous, use None.
- Keep entity labels concise; do not include their class in the label (e.g. from "... protein pt42 ..." extract "pt42" as the label, not "protein pt42").

Output MUST conform to the schema provided via the structured output tool.
"""


LLM CHAIN

In [8]:
# ---------- 5) Chain factory ----------
def build_chain(selected_types: List[str], OPENAI_MODEL: Optional[str] = None):
    """Build the prompt -> LLM -> structured-output chain for ``selected_types``.

    Args:
        selected_types: entity class names to extract (keys of ENTITY_MODEL_REGISTRY).
        OPENAI_MODEL: OpenAI model name. When omitted (or empty), falls back to
            the ``OPENAI_MODEL`` environment variable and then to
            ``"gpt-5-mini"``, so the notebook's configuration is honored.

    Returns:
        Tuple ``(chain, EntitiesExtracted)``: the runnable chain, whose
        ``invoke`` parses the LLM answer into ``EntitiesExtracted``, and that
        dynamically built Pydantic model.
    """
    # The parameter name shadows the module-level OPENAI_MODEL, so resolve the
    # configured value from the environment explicitly.
    model_name = OPENAI_MODEL or os.getenv("OPENAI_MODEL", "gpt-5-mini")

    EntitiesExtracted = build_entities_model(selected_types)
    # ChatPromptTemplate parses message strings as f-string templates; escape any
    # literal braces so only {input_text} below acts as a template variable.
    system_prompt = build_system_prompt(selected_types).replace("{", "{{").replace("}", "}}")

    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_prompt),
            ("user", "<Input Text Start>\n\n{input_text}\n\n<Input Text End>\n\nReturn the structured extraction."),
        ]
    )

    llm = ChatOpenAI(model=model_name, temperature=0, reasoning={"effort": "low"})
    chain = prompt | llm.with_structured_output(EntitiesExtracted)
    return chain, EntitiesExtracted

In [9]:
# ---------- 6) Public API ----------
# Derived from the registry directly rather than from the demo cell's `all_types`,
# so the public API does not depend on that cell having run.
ALL_TYPES = list(ENTITY_MODEL_REGISTRY)


def extract(text: str, OPENAI_MODEL: Optional[str] = None) -> Dict:
    """Extract entities of the prefiltered types from ``text`` via the LLM chain.

    Args:
        text: raw input text to annotate.
        OPENAI_MODEL: OpenAI model name. When omitted (or empty), falls back to
            the ``OPENAI_MODEL`` environment variable and then to ``"gpt-5-mini"``.

    Returns:
        Dict mapping each entity field name (e.g. ``"Cells"``) to the list of
        extracted entities as plain dicts; types with no hits are omitted.
    """
    # Resolve here (not only in build_chain) so a concrete model name is always passed on.
    model_name = OPENAI_MODEL or os.getenv("OPENAI_MODEL", "gpt-5-mini")
    selected_types = prefilter_entity_types(ALL_TYPES)  # <- your real prefilter goes here
    chain, _entities_model = build_chain(selected_types, model_name)
    result = chain.invoke({"input_text": text})  # an instance of the dynamic ExtractedEntities model

    # exclude_none drops the fields the model left at their None default...
    data = result.model_dump(exclude_none=True)
    # ...and this additionally drops empty containers in case the LLM returned [].
    return {key: value for key, value in data.items() if value not in (None, [], {})}

In [12]:
# Longer excerpt of scientific prose (figure/table references included) to
# exercise the extractor on realistic text.
sample_text = "Beginning as early as the 16-somite stage, most neuronal diversity derives from direct neurogenesis (Fig. 4d), including motor neurons, cerebellar Purkinje cells, Cajal-Retzius cells and many other subtypes (CNS neurons sub-panel of Extended Data Fig. 3). Indirect neurogenesis52 has a later start, with intermediate neuronal progenitors first detected at E10.25, later giving rise to deep-layer neurons, upper-layer neurons, subplate neurons, and cortical interneurons (Fig. 4d and Extended Data Fig. 10a,b). Although many subtypes deriving from direct neurogenesis are easily distinguished, the majority (55%) of these 2.1 million cells could initially only be coarsely annotated as glutamatergic or GABAergic (γ-aminobutyric acid-producing) neurons or dorsal or ventral spinal cord progenitors. To leverage the greater heterogeneity evident at early stages as these trajectories ‘launch’ from the patterned neuroectoderm, we re-analysed the pre-E13 subset. This facilitated much more granular annotation, while also highlighting sources of heterogeneity—for example, anterior versus posterior or inhibitory versus excitatory (Fig. 4e, Extended Data Fig. 10c,d and Supplementary Table 12)."

# Runs the full prompt -> ChatOpenAI -> structured-output pipeline (an API call)
# and prints the resulting dict of non-empty entity lists.
print("== Extracting NER/RE with schema constraints ==")
extraction = extract(sample_text)
print(extraction)

== Extracting NER/RE with schema constraints ==
{'Cells': [{'label': 'motor neurons', 'type': 'Neuron'}, {'label': 'cerebellar Purkinje cells', 'type': 'Neuron'}, {'label': 'Cajal-Retzius cells', 'type': 'Neuron'}, {'label': 'intermediate neuronal progenitors', 'type': 'ProgenitorCell'}, {'label': 'deep-layer neurons', 'type': 'Neuron'}, {'label': 'upper-layer neurons', 'type': 'Neuron'}, {'label': 'subplate neurons', 'type': 'Neuron'}, {'label': 'cortical interneurons', 'type': 'Neuron'}, {'label': 'glutamatergic', 'type': 'Neuron'}, {'label': 'GABAergic (γ-aminobutyric acid-producing) neurons', 'type': 'Neuron'}, {'label': 'dorsal or ventral spinal cord progenitors', 'type': 'ProgenitorCell'}], 'TissueOrOrgans': [{'label': 'spinal cord'}, {'label': 'neuroectoderm'}], 'BiologicalProcesss': [{'label': 'direct neurogenesis'}, {'label': 'Indirect neurogenesis'}, {'label': 'neurogenesis'}], 'CNSFunctions': [{'label': 'inhibitory'}, {'label': 'excitatory'}], 'MouseDevelopmentalTimepoints':

In [10]:
# Small toy input mentioning proteins/genes, a disease, a drug and a developmental window.
# Each implicitly concatenated piece ends with a space so the sentences do not run together.
sample_text = (
    "In mouse cells, various proteins play crucial roles. "
    "For years, we considered insulin and EGFR as a major factor in causing diabetes. "
    "They can be inhibited by giving Imatinib to the mouse embryo within E10.5-E12.5."
)

# Runs the full prompt -> ChatOpenAI -> structured-output pipeline (an API call).
print("== Extracting NER/RE with schema constraints ==")
extraction = extract(sample_text)
print(extraction)

== Extracting NER/RE with schema constraints ==
{'Proteins': [{'label': 'insulin', 'type': 'hormone'}, {'label': 'EGFR', 'type': 'Receptor'}], 'Diseases': [{'label': 'diabetes'}], 'Drugs': [{'label': 'Imatinib'}], 'MouseDevelopmentalTimepoints': [{'label': 'E10.5-E12.5', 'type': 'Embryonic', 'scale': 'E', 'start_value': 10.5, 'end_value': 12.5}]}
