Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ jobs:
git config --global url."https://${GH_READ_TOKEN}:x-oauth-basic@github.com/".insteadOf "https://github.com/"

- name: Sync deps
run: uv sync --extra all
run: uv sync

- name: Lint
run: uv run tox -e lint
Expand Down Expand Up @@ -68,7 +68,7 @@ jobs:
git config --global url."https://${GH_READ_TOKEN}:x-oauth-basic@github.com/".insteadOf "https://github.com/"

- name: Sync deps
run: uv sync --extra all
run: uv sync

- name: Test
run: uv run tox -e test
7 changes: 6 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,12 @@ dev = [

[project]
authors = [{email = "mattcoul7@gmail.com", name = "Matt Coulter"}]
dependencies = []
dependencies = [
"sqlalchemy>=2.0.0",
"fastapi>=0.115.0,<0.116",
"pydantic>=2.12.0",
"ab-database>=0.2.2",
]
description = "A template package template."
name = "ab-sqlalchemy-fastapi-http-exceptions"
readme = "README.md"
Expand Down
3 changes: 3 additions & 0 deletions src/ab_core/sqlalchemy_fastapi_http_exceptions/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .fastapi_integration import register_database_exception_handlers

__all__ = ["register_database_exception_handlers"]
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from .base import DialectExceptionMapper, GenericExceptionMapper
from .postgres import PostgresExceptionMapper
from .mysql import MySQLExceptionMapper
from .sqlite import SQLiteExceptionMapper
from .mssql import MSSQLExceptionMapper
from .oracle import OracleExceptionMapper
from .db2 import DB2ExceptionMapper
from .hana import HANAExceptionMapper
from .ansi import AnsiSQLStateMapper # optional, generic

_DIALECTS: tuple[DialectExceptionMapper, ...] = (
PostgresExceptionMapper(),
MySQLExceptionMapper(),
SQLiteExceptionMapper(),
MSSQLExceptionMapper(),
OracleExceptionMapper(),
DB2ExceptionMapper(),
HANAExceptionMapper(),
AnsiSQLStateMapper(),
GenericExceptionMapper(), # final fallback
)

def get_mapper_by_name(name: str | None) -> DialectExceptionMapper:
if name:
lowered = name.lower()
for mapper in _DIALECTS:
if mapper.name == lowered:
return mapper
return _DIALECTS[-1]
71 changes: 71 additions & 0 deletions src/ab_core/sqlalchemy_fastapi_http_exceptions/dialects/ansi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
from __future__ import annotations
from typing import Mapping, Tuple
from sqlalchemy.exc import IntegrityError, OperationalError, DataError, ProgrammingError, DBAPIError
from .base import DialectExceptionMapper

def _sqlstate(exc: BaseException) -> str | None:
if hasattr(exc, "orig"):
orig = exc.orig # type: ignore[attr-defined]
if hasattr(orig, "sqlstate"):
s = orig.sqlstate # type: ignore[attr-defined]
if isinstance(s, str) and len(s) == 5:
return s
if hasattr(orig, "diag") and hasattr(orig.diag, "sqlstate"):
s2 = orig.diag.sqlstate # type: ignore[attr-defined]
if isinstance(s2, str) and len(s2) == 5:
return s2
return None

_SPECIFICS: dict[str, tuple[int, str]] = {
"23505": (409, "unique_constraint"),
"23503": (409, "foreign_key_constraint"),
"23502": (422, "not_null_violation"),
"23514": (422, "check_violation"),
"42501": (403, "insufficient_privilege"),
"42601": (500, "syntax_error"),
"40P01": (503, "deadlock_detected"), # PG-flavoured, but safe if seen
"40001": (409, "serialization_failure"),
}

def _class_default(sqlstate: str) -> tuple[int, str]:
c = sqlstate[:2]
if c == "23": # Integrity
return (409, "constraint_violation")
if c == "22": # Data
return (400, "invalid_data")
if c == "40": # Txn rollback
return (409, "transaction_rollback")
if c == "08": # Connection
return (503, "connection_exception")
if c == "57": # Operator intervention
return (503, "operator_intervention")
if c == "42": # Syntax/Access
return (500, "syntax_or_access_rule")
return (500, "db_error")

def _map(sqlstate: str | None, default_status: int, default_reason: str) -> tuple[int, Mapping[str, str]]:
if sqlstate is None:
return default_status, {"reason": default_reason}
if sqlstate in _SPECIFICS:
status, reason = _SPECIFICS[sqlstate]
return status, {"sqlstate": sqlstate, "reason": reason}
status, reason = _class_default(sqlstate)
return status, {"sqlstate": sqlstate, "reason": reason}

class AnsiSQLStateMapper(DialectExceptionMapper):
name = "ansi-sqlstate" # not a real SA dialect; treat as a generic mapper

def map_integrity_error(self, exc: IntegrityError):
return _map(_sqlstate(exc), 409, "constraint_violation")

def map_operational_error(self, exc: OperationalError):
return _map(_sqlstate(exc), 503, "db_unavailable")

def map_data_error(self, exc: DataError):
return _map(_sqlstate(exc), 400, "invalid_data")

def map_programming_error(self, exc: ProgrammingError):
return _map(_sqlstate(exc), 500, "db_programming_error")

def map_dbapi_error(self, exc: DBAPIError):
return _map(_sqlstate(exc), 503, "db_error")
46 changes: 46 additions & 0 deletions src/ab_core/sqlalchemy_fastapi_http_exceptions/dialects/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from collections.abc import Mapping

from sqlalchemy.exc import DataError, DBAPIError, IntegrityError, OperationalError, ProgrammingError


class DialectExceptionMapper(ABC):
"""Interface for dialect-specific DB -> HTTP mappings."""

name: str = "generic"

@abstractmethod
def map_integrity_error(self, exc: IntegrityError) -> tuple[int, Mapping[str, str]]: ...

@abstractmethod
def map_operational_error(self, exc: OperationalError) -> tuple[int, Mapping[str, str]]: ...

@abstractmethod
def map_data_error(self, exc: DataError) -> tuple[int, Mapping[str, str]]: ...

@abstractmethod
def map_programming_error(self, exc: ProgrammingError) -> tuple[int, Mapping[str, str]]: ...

@abstractmethod
def map_dbapi_error(self, exc: DBAPIError) -> tuple[int, Mapping[str, str]]: ...


class GenericExceptionMapper(DialectExceptionMapper):
name = "generic"

def map_integrity_error(self, _exc: IntegrityError) -> tuple[int, Mapping[str, str]]:
return 409, {"reason": "constraint_violation"}

def map_operational_error(self, _exc: OperationalError) -> tuple[int, Mapping[str, str]]:
return 503, {"reason": "db_unavailable"}

def map_data_error(self, _exc: DataError) -> tuple[int, Mapping[str, str]]:
return 400, {"reason": "invalid_data"}

def map_programming_error(self, _exc: ProgrammingError) -> tuple[int, Mapping[str, str]]:
return 500, {"reason": "db_programming_error"}

def map_dbapi_error(self, _exc: DBAPIError) -> tuple[int, Mapping[str, str]]:
return 503, {"reason": "db_error"}
75 changes: 75 additions & 0 deletions src/ab_core/sqlalchemy_fastapi_http_exceptions/dialects/db2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
from __future__ import annotations
from typing import Mapping, Tuple
from sqlalchemy.exc import IntegrityError, OperationalError, DataError, ProgrammingError, DBAPIError
from .base import DialectExceptionMapper

def _extract_sqlstate(exc: BaseException) -> str | None:
if hasattr(exc, "orig"):
orig = exc.orig # type: ignore[attr-defined]
# ibm_db_sa surfaces .sqlstate or message with SQLSTATE=xxxxx
if hasattr(orig, "sqlstate"):
s = orig.sqlstate # type: ignore[attr-defined]
if isinstance(s, str) and len(s) == 5:
return s
if hasattr(orig, "args"):
args = orig.args # type: ignore[attr-defined]
if isinstance(args, (list, tuple)) and args:
text = str(args[0])
idx = text.find("SQLSTATE=")
if idx != -1 and len(text) >= idx + 13:
candidate = text[idx + 9 : idx + 14]
if len(candidate) == 5:
return candidate
return None

_DB2_INTEGRITY: dict[str, tuple[int, str]] = {
"23505": (409, "unique_constraint"),
"23503": (409, "foreign_key_constraint"),
"23502": (422, "not_null_violation"),
"23514": (422, "check_violation"),
}

_DB2_DATA: dict[str, tuple[int, str]] = {
"22001": (400, "string_data_right_truncation"),
"22003": (400, "numeric_value_out_of_range"),
"22007": (400, "invalid_datetime_format"),
"22008": (400, "datetime_field_overflow"),
}

_DB2_COMMON: dict[str, tuple[int, str]] = {
"40001": (409, "serialization_failure"),
"57033": (503, "lock_timeout"), # common DB2 lock timeout state
"08000": (503, "connection_exception"),
"08006": (503, "connection_failure"),
"42501": (403, "insufficient_privilege"),
"42601": (500, "syntax_error"),
}

def _from_sqlstate(code: str | None, table: dict[str, tuple[int, str]], default_status: int, default_reason: str) -> tuple[int, Mapping[str, str]]:
if code is not None and code in table:
status, reason = table[code]
return status, {"sqlstate": code, "reason": reason}
if code is not None:
return default_status, {"sqlstate": code, "reason": default_reason}
return default_status, {"reason": default_reason}

class DB2ExceptionMapper(DialectExceptionMapper):
name = "ibm_db_sa" # SQLAlchemy commonly exposes this dialect name; adjust if needed.

def map_integrity_error(self, exc: IntegrityError):
code = _extract_sqlstate(exc)
if code in _DB2_INTEGRITY:
return _from_sqlstate(code, _DB2_INTEGRITY, 409, "constraint_violation")
return _from_sqlstate(code, _DB2_COMMON, 409, "constraint_violation")

def map_operational_error(self, exc: OperationalError):
return _from_sqlstate(_extract_sqlstate(exc), _DB2_COMMON, 503, "db_unavailable")

def map_data_error(self, exc: DataError):
return _from_sqlstate(_extract_sqlstate(exc), _DB2_DATA, 400, "invalid_data")

def map_programming_error(self, exc: ProgrammingError):
return _from_sqlstate(_extract_sqlstate(exc), _DB2_COMMON, 500, "db_programming_error")

def map_dbapi_error(self, exc: DBAPIError):
return _from_sqlstate(_extract_sqlstate(exc), _DB2_COMMON, 503, "db_error")
64 changes: 64 additions & 0 deletions src/ab_core/sqlalchemy_fastapi_http_exceptions/dialects/hana.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
from __future__ import annotations
from typing import Mapping, Tuple
from sqlalchemy.exc import IntegrityError, OperationalError, DataError, ProgrammingError, DBAPIError
from .base import DialectExceptionMapper

def _extract_code(exc: BaseException) -> str:
# hdbcli often exposes text like: "[<code>] <message>" in args[0]
if hasattr(exc, "orig"):
orig = exc.orig # type: ignore[attr-defined]
if hasattr(orig, "args"):
args = orig.args # type: ignore[attr-defined]
if isinstance(args, (list, tuple)) and args:
return str(args[0])
return ""

def _contains(s: str, needle: str) -> bool:
return needle.lower() in s.lower()

class HANAExceptionMapper(DialectExceptionMapper):
name = "hana"

def map_integrity_error(self, exc: IntegrityError):
msg = _extract_code(exc)
if _contains(msg, "unique constraint"):
return 409, {"reason": "unique_constraint"}
if _contains(msg, "foreign key"):
return 409, {"reason": "foreign_key_constraint"}
if _contains(msg, "not null"):
return 422, {"reason": "not_null_violation"}
if _contains(msg, "check constraint"):
return 422, {"reason": "check_violation"}
return 409, {"reason": "constraint_violation"}

def map_operational_error(self, exc: OperationalError):
msg = _extract_code(exc)
if _contains(msg, "lock timeout") or _contains(msg, "deadlock"):
return 503, {"reason": "lock_or_deadlock"}
if _contains(msg, "connection") or _contains(msg, "network"):
return 503, {"reason": "connection_exception"}
return 503, {"reason": "db_unavailable"}

def map_data_error(self, exc: DataError):
msg = _extract_code(exc)
if _contains(msg, "value too large") or _contains(msg, "overflow"):
return 400, {"reason": "numeric_value_out_of_range"}
if _contains(msg, "invalid date") or _contains(msg, "date/time"):
return 400, {"reason": "invalid_datetime_format"}
if _contains(msg, "too long"):
return 400, {"reason": "string_data_right_truncation"}
return 400, {"reason": "invalid_data"}

def map_programming_error(self, exc: ProgrammingError):
msg = _extract_code(exc)
if _contains(msg, "not authorized") or _contains(msg, "insufficient privilege"):
return 403, {"reason": "insufficient_privilege"}
if _contains(msg, "syntax error"):
return 500, {"reason": "syntax_error"}
return 500, {"reason": "db_programming_error"}

def map_dbapi_error(self, exc: DBAPIError):
msg = _extract_code(exc)
if _contains(msg, "lock") or _contains(msg, "timeout"):
return 503, {"reason": "db_error_lock_or_timeout"}
return 503, {"reason": "db_error"}
Loading
Loading