In [6]:
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Annotated, Optional, Sequence, Union, Literal, Type, Iterable
from pydantic import BaseModel, Field, create_model
from constants.my_enums import *

# ----------------------------
# Utilities
# ----------------------------

def dedup_preserve(seq: Iterable[str]) -> list[str]:
    """Return the non-empty strings of ``seq`` with duplicates removed.

    Items that are not ``str`` instances, and empty strings, are dropped.
    Each remaining value stays at the position of its first occurrence.
    """
    seen: set[str] = set()
    unique: list[str] = []
    for item in seq:
        # Type check comes first so unhashable items are never hashed.
        if not isinstance(item, str) or not item:
            continue
        if item in seen:
            continue
        seen.add(item)
        unique.append(item)
    return unique

def Lit(values: Sequence[str]):
    """Build a ``typing.Literal`` type whose members are exactly ``values``.

    ``Literal[tuple(values)]`` is what ``Literal[*values]`` evaluates to: a
    starred subscript hands ``Literal`` one tuple holding the items. Writing
    it without the star keeps this module importable on interpreters older
    than Python 3.11, where a starred subscript is a syntax error.

    Args:
        values: The allowed literal values, in the order they should appear.

    Returns:
        ``Literal[v0, v1, ...]`` built from ``values``.
    """
    return Literal[tuple(values)]

# ----------------------------
# NER ingestion (assumes canonical class names already)
# ----------------------------

def extract_labels(ner_output: dict) -> dict[str, list[str]]:
    """
    Collect the entity labels found under each class name of an NER result.

    Keys of ``ner_output`` are used as-is as class names. Only values that are
    lists are read; a list element contributes a label when it is a plain
    string or a dict whose 'label' entry is a string. Labels are passed
    through ``dedup_preserve``; classes left with no labels are omitted.
    """
    def _label_from(entry):
        # A usable entry is either the label itself or a dict carrying it.
        if isinstance(entry, str):
            return entry
        if isinstance(entry, dict):
            candidate = entry.get("label")
            if isinstance(candidate, str):
                return candidate
        return None

    collected: dict[str, list[str]] = {}
    for class_name, entries in ner_output.items():
        if not isinstance(entries, list):
            continue
        candidates = [_label_from(entry) for entry in entries]
        unique = dedup_preserve([c for c in candidates if c is not None])
        if unique:
            collected[class_name] = unique
    return collected

In [7]:
# Smoke check: each class with one labelled entity should map to a one-item label list.
ner_example_output = {
    "Gene": [{"label": "CDKN2A"}],
    "SmallMolecule": [{"label": "Hexachlorobenzene metabolite"}],
}
print(extract_labels(ner_example_output))

{'Gene': ['CDKN2A'], 'SmallMolecule': ['Hexachlorobenzene metabolite']}


In [23]:

# ----------------------------
# Relationship spec DSL
# ----------------------------

@dataclass(frozen=True)
class FixedChoiceField:
    """Qualifier field whose value must be one of a predefined list of strings."""

    # Name the field gets on the generated relationship model.
    name: str
    # Allowed values; the model builder turns them into a Literal type.
    choices: Sequence[str]
    # A `description: str = ""` attribute is deliberately left out: descriptions
    # are fixed for every qualifier/field, so they go in the system prompt.
    # When True, the generated field also accepts None and defaults to None.
    optional: bool = field(default=True)

@dataclass(frozen=True)
class FreeTextField:
    """Qualifier field with no fixed choice list; its value is typed as ``typ``."""

    # Name the field gets on the generated relationship model.
    name: str
    # A `description: str = ""` attribute is deliberately left out: descriptions
    # are fixed for every qualifier/field, so they go in the system prompt.
    # When True, the generated field also accepts None and defaults to None.
    optional: bool = field(default=True)
    # Python type used as the field annotation.
    typ: type = field(default=str)

@dataclass(frozen=True)
class DynamicEntityField:
    """
    Field whose permitted strings are the combined labels of the listed NER classes.
    When those classes yield no labels at all, the field is left out of the
    generated model entirely.
    """

    # Name the field gets on the generated relationship model.
    name: str
    # NER class names whose labels are pooled into the allowed values.
    classes: Sequence[str]
    # A `description: str = ""` attribute is deliberately left out: descriptions
    # are fixed for every qualifier/field, so they go in the system prompt.
    # Applies only when the field is present: True makes it Optional[...] with a None default.
    optional: bool = field(default=True)

@dataclass(frozen=True)
class RelationshipSpec:
    """Declarative description of one relationship type to build a model for."""

    # Used both as the generated model's name and as its `rel_type` value.
    name: str
    # NER classes whose labels may appear as the subject / object.
    subject_classes: Sequence[str]
    object_classes: Sequence[str]
    # Allowed predicate strings.
    predicate_choices: Sequence[str]
    # Extra fields that are always emitted (choice-list or free-text).
    fixed_fields: Sequence[FixedChoiceField | FreeTextField] = field(
        default_factory=list,
    )
    # Extra fields emitted only when the NER output supplies labels for them.
    dynamic_fields: Sequence[DynamicEntityField] = field(
        default_factory=list,
    )

In [24]:
# Chemical -> gene / gene-product relationship.
# The *_ENUM choice lists are presumably supplied by the `constants.my_enums` star import.
CHEMICAL_AFFECTS_GENE = RelationshipSpec(
    name="ChemicalAffectsGene",
    subject_classes=["SmallMolecule"],
    object_classes=["Gene", "Protein", "RnaTranscript"],
    predicate_choices=["affects", "causes"],
    fixed_fields=[
        FixedChoiceField(qualifier, choices, optional=True)
        for qualifier, choices in (
            ("subject_form_or_variant", CHEMICAL_OR_GENE_OR_GENE_PRODUCT_FORM_OR_VARIANT_ENUM),
            ("subject_part", GENE_OR_GENE_PRODUCT_OR_CHEMICAL_PART_QUALIFIER_ENUM),
            ("subject_derivative", CHEMICAL_ENTITY_DERIVATIVE_ENUM),
            ("subject_aspect", GENE_OR_GENE_PRODUCT_OR_CHEMICAL_ENTITY_ASPECT_ENUM),
            ("subject_direction", DIRECTION_QUALIFIER_ENUM),
            ("object_form_or_variant", CHEMICAL_OR_GENE_OR_GENE_PRODUCT_FORM_OR_VARIANT_ENUM),
            ("object_part", GENE_OR_GENE_PRODUCT_OR_CHEMICAL_PART_QUALIFIER_ENUM),
            ("object_aspect", GENE_OR_GENE_PRODUCT_OR_CHEMICAL_ENTITY_ASPECT_ENUM),
            ("object_direction", DIRECTION_QUALIFIER_ENUM),
            ("causal_mechanism", CAUSAL_MECHANISM_QUALIFIER_ENUM),
        )
    ],
    dynamic_fields=[
        # The three context qualifiers all draw on the same NER classes.
        *(
            DynamicEntityField(
                context,
                classes=["CellType", "CellularComponent", "TissueOrOrgan"],
                optional=True,
            )
            for context in ("subject_context", "object_context", "anatomical_context")
        ),
        DynamicEntityField("species_context", classes=["CellularOrganism"], optional=True),
    ],
)

# Present mainly to demonstrate filtering: build_relationship_models skips this
# spec when the NER output provides no TissueOrOrgan labels.
# NOTE(review): CHEMICAL_AFFECTS_GENE names its analogous field "subject_context";
# confirm the "_qualifier" suffix used below is intentional.
CHEMICAL_AFFECTS_TISSUE = RelationshipSpec(
    name="ChemicalAffectsTissue",
    subject_classes=["SmallMolecule"],
    object_classes=["TissueOrOrgan"],
    predicate_choices=["affects"],
    fixed_fields=[
        FixedChoiceField(name="subject_direction", choices=DIRECTION_QUALIFIER_ENUM, optional=True),
        FixedChoiceField(name="object_direction", choices=DIRECTION_QUALIFIER_ENUM, optional=True),
    ],
    dynamic_fields=[
        DynamicEntityField(
            name="subject_context_qualifier",
            classes=["CellType", "CellularComponent", "TissueOrOrgan"],
            optional=True,
        ),
    ],
)

# Specs that build_relationship_models falls back to when the caller passes none.
DEFAULT_SPECS: list[RelationshipSpec] = [CHEMICAL_AFFECTS_GENE, CHEMICAL_AFFECTS_TISSUE]

In [41]:
# ----------------------------
# Model builder
# ----------------------------

def _optional(typ):
    from typing import Optional as _Opt
    return _Opt[typ]

def _union_labels(classes: Sequence[str], labels_by_class: dict[str, list[str]]) -> list[str]:
    """Pool the labels of every listed class into one de-duplicated list.

    Classes are visited in the order given; a class missing from
    ``labels_by_class`` contributes nothing.
    """
    pooled = [
        label
        for class_name in classes
        for label in labels_by_class.get(class_name, [])
    ]
    return dedup_preserve(pooled)

def _add_fixed_fields(fields: dict, spec: RelationshipSpec):
    """Add one ``(annotation, Field)`` entry to ``fields`` per fixed field of ``spec``.

    A ``FixedChoiceField`` is annotated with a Literal of its choices and a
    ``FreeTextField`` with its ``typ``. Optional fields become ``Optional[...]``
    with a ``None`` default; the others are required. ``fields`` is mutated in
    place. Raises ``TypeError`` for any other kind of field spec.
    """
    for field_spec in spec.fixed_fields:
        # Step 1: the bare annotation depends on the kind of field spec.
        if isinstance(field_spec, FixedChoiceField):
            annotation = Lit(field_spec.choices)
        elif isinstance(field_spec, FreeTextField):
            annotation = field_spec.typ
        else:
            raise TypeError(f"Unsupported fixed field spec: {field_spec}")

        # Step 2: optional/required handling is the same for both kinds.
        if field_spec.optional:
            fields[field_spec.name] = (_optional(annotation), Field(default=None))
        else:
            fields[field_spec.name] = (annotation, Field(...))

def _add_dynamic_fields(fields: dict, spec: RelationshipSpec, labels_by_class: dict[str, list[str]]):
    """Add the dynamic fields of ``spec`` that the NER labels can populate.

    Each dynamic field is annotated with a Literal of the labels pooled from
    its classes. A field with no available labels is not added at all.
    ``fields`` is mutated in place.
    """
    for dynamic in spec.dynamic_fields:
        allowed = _union_labels(dynamic.classes, labels_by_class)
        if allowed:
            annotation = Lit(allowed)
            if dynamic.optional:
                fields[dynamic.name] = (_optional(annotation), Field(default=None))
            else:
                fields[dynamic.name] = (annotation, Field(...))

def build_relationship_models(
    ner_output: dict,
    specs: Sequence[RelationshipSpec] = DEFAULT_SPECS,
) -> tuple[dict[str, Type[BaseModel]], type, Type[BaseModel]]:
    """
    Build one pydantic model per relationship spec that the NER output can populate.

    ``ner_output`` keys are taken to be canonical class names. A spec is skipped
    when the NER output offers no label for its subject classes, or none for
    its object classes.

    Returns a 3-tuple:
      - models_by_name: {rel_name: PydanticModel} for every spec that was kept
      - RelationshipUnion: Annotated[Union[...], Field(discriminator="rel_type")]
      - RelationshipsContainer: BaseModel with `relationships: list[RelationshipUnion]`

    Raises ValueError when no spec is applicable.
    """
    labels_by_class = extract_labels(ner_output)
    models_by_name: dict[str, Type[BaseModel]] = {}

    for spec in specs:
        subject_candidates = _union_labels(spec.subject_classes, labels_by_class)
        object_candidates = _union_labels(spec.object_classes, labels_by_class)
        if not (subject_candidates and object_candidates):
            # At least one side has nothing to choose from: leave this spec out.
            continue

        # Insertion order is the field order of the generated model:
        # discriminator, then subject / predicate / object, then the extras.
        field_defs = {
            "rel_type": (Lit([spec.name]), Field(default=spec.name)),
            "subject_label": (Lit(subject_candidates), Field(...)),
            "predicate": (Lit(spec.predicate_choices), Field(...)),
            "object_label": (Lit(object_candidates), Field(...)),
        }
        _add_fixed_fields(field_defs, spec)
        _add_dynamic_fields(field_defs, spec, labels_by_class)

        models_by_name[spec.name] = create_model(spec.name, __base__=BaseModel, **field_defs)

    if not models_by_name:
        raise ValueError("No relationships are applicable for this NER output.")

    # Union over every kept model, discriminated on its fixed `rel_type` literal.
    member_models = tuple(models_by_name.values())
    DiscriminatedUnion = Annotated[Union[member_models], Field(discriminator="rel_type")]

    class RelationshipsContainer(BaseModel):
        relationships: list[DiscriminatedUnion]

    return models_by_name, DiscriminatedUnion, RelationshipsContainer

# ----------------------------
# Example usage
# ----------------------------
if __name__ == "__main__":
    # Demo NER result: (class name, labels) pairs expanded into the
    # {class: [{"label": ...}, ...]} shape that extract_labels reads.
    ner = {
        class_name: [{"label": label} for label in labels]
        for class_name, labels in (
            ("Gene", ("CDKN2A", "GN123")),
            ("Protein", ("TP53 protein",)),
            ("RnaTranscript", ("BRCA1 mRNA",)),
            ("SmallMolecule", ("Hexachlorobenzene",)),
            ("TissueOrOrgan", ("liver", "brain")),
            ("CellType", ("neuron",)),
        )
    }

    models_by_name, RelationshipUnion, Relationships = build_relationship_models(ner)

    ChemAffectsGene = models_by_name["ChemicalAffectsGene"]

    # Disabled instantiation examples, kept for reference:
    # ex1 = ChemAffectsGene(
    #     rel_type="ChemicalAffectsGene",
    #     subject_label="Hexachlorobenzene",
    #     predicate="affects",
    #     object_label="CDKN2A",          # from Gene
    #     object_context="neuron",        # dynamic field present
    #     anatomical_context="liver",     # dynamic field present
    # )
    # ex2 = ChemAffectsGene(
    #     rel_type="ChemicalAffectsGene",
    #     subject_label="Hexachlorobenzene",
    #     predicate="causes",
    #     object_label="TP53 protein",    # from Protein (multi-class object)
    # )

    #payload = Relationships(relationships=[ex1, ex2])
    #print(payload.model_dump())

In [42]:
# Inspect the container model's fields: `relationships` is a list of the rel_type-discriminated union.
Relationships.model_fields

{'relationships': FieldInfo(annotation=list[Annotated[Union[ChemicalAffectsGene, ChemicalAffectsTissue], FieldInfo(annotation=NoneType, required=True, discriminator='rel_type')]], required=True)}

In [54]:
# Inspect the JSON schema generated for the ChemicalAffectsGene model.
models_by_name['ChemicalAffectsGene'].model_json_schema()

{'properties': {'rel_type': {'const': 'ChemicalAffectsGene',
   'default': 'ChemicalAffectsGene',
   'title': 'Rel Type',
   'type': 'string'},
  'subject_label': {'const': 'Hexachlorobenzene',
   'title': 'Subject Label',
   'type': 'string'},
  'predicate': {'enum': ['affects', 'causes'],
   'title': 'Predicate',
   'type': 'string'},
  'object_label': {'enum': ['CDKN2A', 'GN123', 'TP53 protein', 'BRCA1 mRNA'],
   'title': 'Object Label',
   'type': 'string'},
  'subject_form_or_variant': {'anyOf': [{'enum': ['genetic_variant_form',
      'modified_form',
      'loss_of_function_variant_form',
      'non_loss_of_function_variant_form',
      'gain_of_function_variant_form',
      'dominant_negative_variant_form',
      'polymorphic_form',
      'snp_form',
      'analog_form'],
     'type': 'string'},
    {'type': 'null'}],
   'default': None,
   'title': 'Subject Form Or Variant'},
  'subject_part': {'anyOf': [{'enum': ['3_prime_utr',
      '5_prime_utr',
      'polya_tail',
      '

{'rel_type': FieldInfo(annotation=Literal['ChemicalAffectsGene'], required=False, default='ChemicalAffectsGene'),
 'subject_label': FieldInfo(annotation=Literal['Hexachlorobenzene'], required=True),
 'predicate': FieldInfo(annotation=Literal['affects', 'causes'], required=True),
 'object_label': FieldInfo(annotation=Literal['CDKN2A', 'GN123', 'TP53 protein', 'BRCA1 mRNA'], required=True),
 'subject_form_or_variant': FieldInfo(annotation=Union[Literal['genetic_variant_form', 'modified_form', 'loss_of_function_variant_form', 'non_loss_of_function_variant_form', 'gain_of_function_variant_form', 'dominant_negative_variant_form', 'polymorphic_form', 'snp_form', 'analog_form'], NoneType], required=False, default=None),
 'subject_part': FieldInfo(annotation=Union[Literal['3_prime_utr', '5_prime_utr', 'polya_tail', 'promoter', 'enhancer', 'exon', 'intron'], NoneType], required=False, default=None),
 'subject_derivative': FieldInfo(annotation=Union[Literal['metabolite'], NoneType], required=Fal