diff --git a/dataframely/_base_schema.py b/dataframely/_base_schema.py index f1506fd..8c64e67 100644 --- a/dataframely/_base_schema.py +++ b/dataframely/_base_schema.py @@ -6,6 +6,7 @@ import sys import textwrap from abc import ABCMeta +from collections.abc import Mapping from copy import copy from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any @@ -116,12 +117,7 @@ def __new__( *args: Any, **kwargs: Any, ) -> SchemaMeta: - result = Metadata() - for base in bases: - result.update(mcs._get_metadata_recursively(base)) - namespace_metadata = mcs._get_metadata(namespace) - mcs._remove_overridden_columns(result, namespace, bases) - result.update(namespace_metadata) + result = mcs._collect_metadata(bases, namespace) namespace[_COLUMN_ATTR] = result.columns cls = super().__new__(mcs, name, bases, namespace, *args, **kwargs) @@ -212,7 +208,7 @@ def __getattribute__(cls, name: str) -> Any: @staticmethod def _remove_overridden_columns( result: Metadata, - namespace: dict[str, Any], + namespace: Mapping[str, Any], bases: tuple[type[object], ...], ) -> None: """Remove inherited columns that the child namespace explicitly overrides. @@ -238,15 +234,23 @@ def _remove_overridden_columns( result.columns.pop(parent_key, None) @staticmethod - def _get_metadata_recursively(kls: type[object]) -> Metadata: + def _collect_metadata( + bases: tuple[type[object], ...], + namespace: Mapping[str, Any], + ) -> Metadata: result = Metadata() - for base in kls.__bases__: + for base in bases: result.update(SchemaMeta._get_metadata_recursively(base)) - result.update(SchemaMeta._get_metadata(kls.__dict__)) # type: ignore + SchemaMeta._remove_overridden_columns(result, namespace, bases) + result.update(SchemaMeta._get_metadata(namespace)) return result @staticmethod - def _get_metadata(source: dict[str, Any]) -> Metadata: + def _get_metadata_recursively(kls: type[object]) -> Metadata: + return SchemaMeta._collect_metadata(kls.__bases__, kls.__dict__) + + @staticmethod + def _get_metadata(source: Mapping[str, Any]) -> Metadata: result = Metadata() for attr, value in { k: v for k, v in source.items() if not k.startswith("__") diff --git a/tests/schema/test_inheritance.py b/tests/schema/test_inheritance.py index 083fe68..4625822 100644 --- a/tests/schema/test_inheritance.py +++ b/tests/schema/test_inheritance.py @@ -20,3 +20,27 @@ def test_columns() -> None: assert ParentSchema.column_names() == ["a"] assert ChildSchema.column_names() == ["a", "b"] assert GrandchildSchema.column_names() == ["a", "b", "c"] + + +class OverrideBase(dy.Schema): + amt = dy.Float64(nullable=True) + + +class OverrideChild(OverrideBase): + amt = dy.Float64(nullable=False) + + +class OverrideGrandchild(OverrideChild): + pass + + +class OverrideGreatGrandchild(OverrideGrandchild): + other = dy.Integer() + + +def test_column_override_propagates_to_grandchild() -> None: + assert OverrideBase.columns()["amt"].nullable is True + assert OverrideChild.columns()["amt"].nullable is False + assert OverrideGrandchild.columns()["amt"].nullable is False + assert OverrideGreatGrandchild.columns()["amt"].nullable is False + assert OverrideGreatGrandchild.column_names() == ["amt", "other"]