From 2dd08e268b4ffa9d7a5406fbbe9a96709367d39a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?H=C3=A9lo=C3=AFse=20Joffe?= Date: Thu, 22 Jan 2026 16:39:25 +0100 Subject: [PATCH 1/8] Refactor: Change the DeclarativeBase sqlalchemy style --- diracx-db/src/diracx/db/sql/auth/schema.py | 6 +++-- diracx-db/src/diracx/db/sql/dummy/schema.py | 6 +++-- diracx-db/src/diracx/db/sql/job/db.py | 24 +++++++++++-------- diracx-db/src/diracx/db/sql/job/schema.py | 6 +++-- diracx-db/src/diracx/db/sql/job_logging/db.py | 6 ++--- .../src/diracx/db/sql/job_logging/schema.py | 6 +++-- .../src/diracx/db/sql/pilot_agents/schema.py | 6 +++-- .../diracx/db/sql/sandbox_metadata/schema.py | 6 +++-- .../src/diracx/db/sql/task_queue/schema.py | 6 +++-- 9 files changed, 45 insertions(+), 27 deletions(-) diff --git a/diracx-db/src/diracx/db/sql/auth/schema.py b/diracx-db/src/diracx/db/sql/auth/schema.py index 82fc701bd..17dd27503 100644 --- a/diracx-db/src/diracx/db/sql/auth/schema.py +++ b/diracx-db/src/diracx/db/sql/auth/schema.py @@ -8,7 +8,7 @@ String, Uuid, ) -from sqlalchemy.orm import declarative_base +from sqlalchemy.orm import DeclarativeBase from diracx.db.sql.utils import ( Column, @@ -19,7 +19,9 @@ USER_CODE_LENGTH = 8 -Base = declarative_base() + +class Base(DeclarativeBase): + pass class FlowStatus(Enum): diff --git a/diracx-db/src/diracx/db/sql/dummy/schema.py b/diracx-db/src/diracx/db/sql/dummy/schema.py index 5379de94d..33debcb89 100644 --- a/diracx-db/src/diracx/db/sql/dummy/schema.py +++ b/diracx-db/src/diracx/db/sql/dummy/schema.py @@ -3,11 +3,13 @@ from __future__ import annotations from sqlalchemy import ForeignKey, Integer, String, Uuid -from sqlalchemy.orm import declarative_base +from sqlalchemy.orm import DeclarativeBase from diracx.db.sql.utils import Column, DateNowColumn -Base = declarative_base() + +class Base(DeclarativeBase): + pass class Owners(Base): diff --git a/diracx-db/src/diracx/db/sql/job/db.py b/diracx-db/src/diracx/db/sql/job/db.py index aea3e4d48..06df129e2 100644 --- a/diracx-db/src/diracx/db/sql/job/db.py +++ b/diracx-db/src/diracx/db/sql/job/db.py @@ -3,9 +3,9 @@ __all__ = ["JobDB"] from datetime import datetime, timezone -from typing import TYPE_CHECKING, Any, Iterable +from typing import TYPE_CHECKING, Any, Iterable, cast -from sqlalchemy import bindparam, case, delete, literal, select, update +from sqlalchemy import Table, bindparam, case, delete, literal, select, update if TYPE_CHECKING: from sqlalchemy.sql.elements import BindParameter @@ -75,7 +75,9 @@ async def search( async def create_job(self, compressed_original_jdl: str): """Insert a new job with original JDL. Returns inserted job id.""" result = await self.conn.execute( - JobJDLs.__table__.insert().values( + cast(Table, JobJDLs.__table__) + .insert() + .values( JDL="", JobRequirements="", OriginalJDL=compressed_original_jdl, @@ -91,7 +93,7 @@ async def delete_jobs(self, job_ids: list[int]): async def insert_input_data(self, lfns: dict[int, list[str]]): """Insert input data for jobs.""" await self.conn.execute( - InputData.__table__.insert(), + cast(Table, InputData.__table__).insert(), [ { "JobID": job_id, @@ -105,7 +107,7 @@ async def insert_input_data(self, lfns: dict[int, list[str]]): async def insert_job_attributes(self, jobs_to_update: dict[int, dict]): """Insert the job attributes.""" await self.conn.execute( - Jobs.__table__.insert(), + cast(Table, Jobs.__table__).insert(), [ { "JobID": job_id, @@ -118,9 +120,9 @@ async def insert_job_attributes(self, jobs_to_update: dict[int, dict]): async def update_job_jdls(self, jdls_to_update: dict[int, str]): """Update the JDL, typically just after inserting the original JDL, or rescheduling, for example.""" await self.conn.execute( - JobJDLs.__table__.update().where( - JobJDLs.__table__.c.JobID == bindparam("b_JobID") - ), + cast(Table, JobJDLs.__table__) + .update() + .where(JobJDLs.__table__.c.JobID == bindparam("b_JobID")), [ { "b_JobID": job_id, @@ -186,7 +188,7 @@ async def get_job_jdls(self, job_ids, original: bool = False) -> dict[int, str]: async def set_job_commands(self, commands: list[tuple[int, str, str]]) -> None: """Store a command to be passed to the job together with the next heart beat.""" await self.conn.execute( - JobCommands.__table__.insert(), + cast(Table, JobCommands.__table__).insert(), [ { "JobID": job_id, @@ -261,7 +263,9 @@ async def add_heartbeat_data( } for key, value in dynamic_data.items() ] - await self.conn.execute(HeartBeatLoggingInfo.__table__.insert().values(values)) + await self.conn.execute( + cast(Table, HeartBeatLoggingInfo.__table__).insert().values(values) + ) async def get_job_commands(self, job_ids: Iterable[int]) -> list[JobCommand]: """Get a command to be passed to the job together with the next heartbeat. diff --git a/diracx-db/src/diracx/db/sql/job/schema.py b/diracx-db/src/diracx/db/sql/job/schema.py index c2a286a7c..841312a73 100644 --- a/diracx-db/src/diracx/db/sql/job/schema.py +++ b/diracx-db/src/diracx/db/sql/job/schema.py @@ -8,13 +8,15 @@ String, Text, ) -from sqlalchemy.orm import declarative_base +from sqlalchemy.orm import DeclarativeBase from diracx.db.sql.utils.types import SmarterDateTime from ..utils import Column, EnumBackedBool, NullColumn -JobDBBase = declarative_base() + +class JobDBBase(DeclarativeBase): + pass class AccountedFlagEnum(types.TypeDecorator): diff --git a/diracx-db/src/diracx/db/sql/job_logging/db.py b/diracx-db/src/diracx/db/sql/job_logging/db.py index 28b124b35..a1fcf8632 100644 --- a/diracx-db/src/diracx/db/sql/job_logging/db.py +++ b/diracx-db/src/diracx/db/sql/job_logging/db.py @@ -2,9 +2,9 @@ from collections import defaultdict from datetime import datetime, timezone -from typing import Iterable +from typing import Iterable, cast -from sqlalchemy import delete, func, select +from sqlalchemy import Table, delete, func, select from diracx.core.models.job import JobLoggingRecord, JobStatusReturn @@ -56,7 +56,7 @@ async def insert_records( seqnums[record.job_id] = seqnums[record.job_id] + 1 await self.conn.execute( - LoggingInfo.__table__.insert(), + cast(Table, LoggingInfo.__table__).insert(), values, ) diff --git a/diracx-db/src/diracx/db/sql/job_logging/schema.py b/diracx-db/src/diracx/db/sql/job_logging/schema.py index df4ba7e8f..0a6e7363c 100644 --- a/diracx-db/src/diracx/db/sql/job_logging/schema.py +++ b/diracx-db/src/diracx/db/sql/job_logging/schema.py @@ -3,11 +3,13 @@ from datetime import UTC, datetime from sqlalchemy import Integer, Numeric, PrimaryKeyConstraint, String, TypeDecorator -from sqlalchemy.orm import declarative_base +from sqlalchemy.orm import DeclarativeBase from ..utils import Column, DateNowColumn -JobLoggingDBBase = declarative_base() + +class JobLoggingDBBase(DeclarativeBase): + pass class MagicEpochDateTime(TypeDecorator): diff --git a/diracx-db/src/diracx/db/sql/pilot_agents/schema.py b/diracx-db/src/diracx/db/sql/pilot_agents/schema.py index 809862ed9..523639ac1 100644 --- a/diracx-db/src/diracx/db/sql/pilot_agents/schema.py +++ b/diracx-db/src/diracx/db/sql/pilot_agents/schema.py @@ -7,13 +7,15 @@ String, Text, ) -from sqlalchemy.orm import declarative_base +from sqlalchemy.orm import DeclarativeBase from diracx.db.sql.utils.types import SmarterDateTime from ..utils import Column, EnumBackedBool, NullColumn -PilotAgentsDBBase = declarative_base() + +class PilotAgentsDBBase(DeclarativeBase): + pass class PilotAgents(PilotAgentsDBBase): diff --git a/diracx-db/src/diracx/db/sql/sandbox_metadata/schema.py b/diracx-db/src/diracx/db/sql/sandbox_metadata/schema.py index 4cf9a2a7d..049dd7f7c 100644 --- a/diracx-db/src/diracx/db/sql/sandbox_metadata/schema.py +++ b/diracx-db/src/diracx/db/sql/sandbox_metadata/schema.py @@ -9,11 +9,13 @@ String, UniqueConstraint, ) -from sqlalchemy.orm import declarative_base +from sqlalchemy.orm import DeclarativeBase from diracx.db.sql.utils import Column, DateNowColumn -Base = declarative_base() + +class Base(DeclarativeBase): + pass class SBOwners(Base): diff --git a/diracx-db/src/diracx/db/sql/task_queue/schema.py b/diracx-db/src/diracx/db/sql/task_queue/schema.py index 0a3c0f033..88ffdfdc9 100644 --- a/diracx-db/src/diracx/db/sql/task_queue/schema.py +++ b/diracx-db/src/diracx/db/sql/task_queue/schema.py @@ -9,11 +9,13 @@ Integer, String, ) -from sqlalchemy.orm import declarative_base +from sqlalchemy.orm import DeclarativeBase from ..utils import Column -TaskQueueDBBase = declarative_base() + +class TaskQueueDBBase(DeclarativeBase): + pass class TaskQueues(TaskQueueDBBase): From 8058793bbdda16dbccd4345f72cdd6405d8a2d95 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?H=C3=A9lo=C3=AFse=20Joffe?= Date: Tue, 3 Feb 2026 10:07:11 +0100 Subject: [PATCH 2/8] Refactor: apply SQLAlchemy 2.0 migration for model mappings --- diracx-db/src/diracx/db/sql/auth/schema.py | 70 +++++---- diracx-db/src/diracx/db/sql/dummy/schema.py | 26 ++-- diracx-db/src/diracx/db/sql/job/db.py | 22 +-- diracx-db/src/diracx/db/sql/job/schema.py | 140 ++++++++++-------- .../src/diracx/db/sql/job_logging/schema.py | 30 ++-- .../src/diracx/db/sql/pilot_agents/schema.py | 71 +++++---- .../diracx/db/sql/sandbox_metadata/schema.py | 45 +++--- .../src/diracx/db/sql/task_queue/schema.py | 91 ++++++------ diracx-db/src/diracx/db/sql/utils/__init__.py | 24 ++- diracx-db/src/diracx/db/sql/utils/types.py | 17 +++ docs/dev/reference/coding-conventions.md | 14 +- .../src/gubbins/db/sql/jobs/schema.py | 16 +- .../src/gubbins/db/sql/lollygag/schema.py | 28 ++-- 13 files changed, 346 insertions(+), 248 deletions(-) diff --git a/diracx-db/src/diracx/db/sql/auth/schema.py b/diracx-db/src/diracx/db/sql/auth/schema.py index 17dd27503..0c7554318 100644 --- a/diracx-db/src/diracx/db/sql/auth/schema.py +++ b/diracx-db/src/diracx/db/sql/auth/schema.py @@ -1,6 +1,8 @@ from __future__ import annotations from enum import Enum, auto +from typing import Any, Optional +from uuid import UUID from sqlalchemy import ( JSON, @@ -8,20 +10,26 @@ String, Uuid, ) -from sqlalchemy.orm import DeclarativeBase +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column from diracx.db.sql.utils import ( - Column, - DateNowColumn, - EnumColumn, - NullColumn, + datetime_now, + enum_column, + str128, + str255, + str1024, ) USER_CODE_LENGTH = 8 class Base(DeclarativeBase): - pass + type_annotation_map = { + str128: String(128), + str255: String(255), + str1024: String(1024), + dict[str, Any]: JSON, + } class FlowStatus(Enum): @@ -49,27 +57,35 @@ class FlowStatus(Enum): class DeviceFlows(Base): __tablename__ = "DeviceFlows" - user_code = Column("UserCode", String(USER_CODE_LENGTH), primary_key=True) - status = EnumColumn("Status", FlowStatus, server_default=FlowStatus.PENDING.name) - creation_time = DateNowColumn("CreationTime") - client_id = Column("ClientID", String(255)) - scope = Column("Scope", String(1024)) - device_code = Column("DeviceCode", String(128), unique=True) # Should be a hash - id_token = NullColumn("IDToken", JSON()) + user_code: Mapped[str] = mapped_column( + "UserCode", String(USER_CODE_LENGTH), primary_key=True + ) + status: Mapped[FlowStatus] = enum_column( + "Status", FlowStatus, server_default=FlowStatus.PENDING.name + ) + creation_time: Mapped[datetime_now] = mapped_column("CreationTime") + client_id: Mapped[str255] = mapped_column("ClientID") + scope: Mapped[str1024] = mapped_column("Scope") + device_code: Mapped[str128] = mapped_column( + "DeviceCode", unique=True + ) # Should be a hash + id_token: Mapped[Optional[dict[str, Any]]] = mapped_column("IDToken") class AuthorizationFlows(Base): __tablename__ = "AuthorizationFlows" - uuid = Column("UUID", Uuid(as_uuid=False), primary_key=True) - status = EnumColumn("Status", FlowStatus, server_default=FlowStatus.PENDING.name) - client_id = Column("ClientID", String(255)) - creation_time = DateNowColumn("CreationTime") - scope = Column("Scope", String(1024)) - code_challenge = Column("CodeChallenge", String(255)) - code_challenge_method = Column("CodeChallengeMethod", String(8)) - redirect_uri = Column("RedirectURI", String(255)) - code = NullColumn("Code", String(255)) # Should be a hash - id_token = NullColumn("IDToken", JSON()) + uuid: Mapped[UUID] = mapped_column("UUID", Uuid(as_uuid=False), primary_key=True) + status: Mapped[FlowStatus] = enum_column( + "Status", FlowStatus, server_default=FlowStatus.PENDING.name + ) + client_id: Mapped[str255] = mapped_column("ClientID") + creation_time: Mapped[datetime_now] = mapped_column("CreationTime") + scope: Mapped[str1024] = mapped_column("Scope") + code_challenge: Mapped[str255] = mapped_column("CodeChallenge") + code_challenge_method: Mapped[str] = mapped_column("CodeChallengeMethod", String(8)) + redirect_uri: Mapped[str255] = mapped_column("RedirectURI") + code: Mapped[Optional[str255]] = mapped_column("Code") # Should be a hash + id_token: Mapped[Optional[dict[str, Any]]] = mapped_column("IDToken") class RefreshTokenStatus(Enum): @@ -95,13 +111,13 @@ class RefreshTokens(Base): __tablename__ = "RefreshTokens" # Refresh token attributes - jti = Column("JTI", Uuid(as_uuid=False), primary_key=True) - status = EnumColumn( + jti: Mapped[UUID] = mapped_column("JTI", Uuid(as_uuid=False), primary_key=True) + status: Mapped[RefreshTokenStatus] = enum_column( "Status", RefreshTokenStatus, server_default=RefreshTokenStatus.CREATED.name ) - scope = Column("Scope", String(1024)) + scope: Mapped[str1024] = mapped_column("Scope") # User attributes bound to the refresh token - sub = Column("Sub", String(256), index=True) + sub: Mapped[str] = mapped_column("Sub", String(256), index=True) __table_args__ = (Index("index_status_sub", status, sub),) diff --git a/diracx-db/src/diracx/db/sql/dummy/schema.py b/diracx-db/src/diracx/db/sql/dummy/schema.py index 33debcb89..388534e1d 100644 --- a/diracx-db/src/diracx/db/sql/dummy/schema.py +++ b/diracx-db/src/diracx/db/sql/dummy/schema.py @@ -2,25 +2,31 @@ # in place of the SQLAlchemy one. Have a look at them from __future__ import annotations -from sqlalchemy import ForeignKey, Integer, String, Uuid -from sqlalchemy.orm import DeclarativeBase +from uuid import UUID -from diracx.db.sql.utils import Column, DateNowColumn +from sqlalchemy import ForeignKey, String +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column + +from diracx.db.sql.utils import datetime_now, str255 class Base(DeclarativeBase): - pass + type_annotation_map = { + str255: String(255), + } class Owners(Base): __tablename__ = "Owners" - owner_id = Column("OwnerID", Integer, primary_key=True, autoincrement=True) - creation_time = DateNowColumn("CreationTime") - name = Column("Name", String(255)) + owner_id: Mapped[int] = mapped_column( + "OwnerID", primary_key=True, autoincrement=True + ) + creation_time: Mapped[datetime_now] = mapped_column("CreationTime") + name: Mapped[str255] = mapped_column("Name") class Cars(Base): __tablename__ = "Cars" - license_plate = Column("LicensePlate", Uuid(), primary_key=True) - model = Column("Model", String(255)) - owner_id = Column("OwnerID", Integer, ForeignKey(Owners.owner_id)) + license_plate: Mapped[UUID] = mapped_column("LicensePlate", primary_key=True) + model: Mapped[str255] = mapped_column("Model") + owner_id: Mapped[int] = mapped_column("OwnerID", ForeignKey(Owners.owner_id)) diff --git a/diracx-db/src/diracx/db/sql/job/db.py b/diracx-db/src/diracx/db/sql/job/db.py index 06df129e2..bb28aa5cf 100644 --- a/diracx-db/src/diracx/db/sql/job/db.py +++ b/diracx-db/src/diracx/db/sql/job/db.py @@ -3,9 +3,9 @@ __all__ = ["JobDB"] from datetime import datetime, timezone -from typing import TYPE_CHECKING, Any, Iterable, cast +from typing import TYPE_CHECKING, Any, Iterable -from sqlalchemy import Table, bindparam, case, delete, literal, select, update +from sqlalchemy import bindparam, case, delete, insert, literal, select, update if TYPE_CHECKING: from sqlalchemy.sql.elements import BindParameter @@ -75,9 +75,7 @@ async def search( async def create_job(self, compressed_original_jdl: str): """Insert a new job with original JDL. Returns inserted job id.""" result = await self.conn.execute( - cast(Table, JobJDLs.__table__) - .insert() - .values( + insert(JobJDLs).values( JDL="", JobRequirements="", OriginalJDL=compressed_original_jdl, @@ -93,7 +91,7 @@ async def delete_jobs(self, job_ids: list[int]): async def insert_input_data(self, lfns: dict[int, list[str]]): """Insert input data for jobs.""" await self.conn.execute( - cast(Table, InputData.__table__).insert(), + insert(InputData), [ { "JobID": job_id, @@ -107,7 +105,7 @@ async def insert_input_data(self, lfns: dict[int, list[str]]): async def insert_job_attributes(self, jobs_to_update: dict[int, dict]): """Insert the job attributes.""" await self.conn.execute( - cast(Table, Jobs.__table__).insert(), + insert(Jobs), [ { "JobID": job_id, @@ -120,9 +118,7 @@ async def insert_job_attributes(self, jobs_to_update: dict[int, dict]): async def update_job_jdls(self, jdls_to_update: dict[int, str]): """Update the JDL, typically just after inserting the original JDL, or rescheduling, for example.""" await self.conn.execute( - cast(Table, JobJDLs.__table__) - .update() - .where(JobJDLs.__table__.c.JobID == bindparam("b_JobID")), + update(JobJDLs).where(JobJDLs.__table__.c.JobID == bindparam("b_JobID")), [ { "b_JobID": job_id, @@ -188,7 +184,7 @@ async def get_job_jdls(self, job_ids, original: bool = False) -> dict[int, str]: async def set_job_commands(self, commands: list[tuple[int, str, str]]) -> None: """Store a command to be passed to the job together with the next heart beat.""" await self.conn.execute( - cast(Table, JobCommands.__table__).insert(), + insert(JobCommands), [ { "JobID": job_id, @@ -263,9 +259,7 @@ async def add_heartbeat_data( } for key, value in dynamic_data.items() ] - await self.conn.execute( - cast(Table, HeartBeatLoggingInfo.__table__).insert().values(values) - ) + await self.conn.execute(insert(HeartBeatLoggingInfo).values(values)) async def get_job_commands(self, job_ids: Iterable[int]) -> list[JobCommand]: """Get a command to be passed to the job together with the next heartbeat. diff --git a/diracx-db/src/diracx/db/sql/job/schema.py b/diracx-db/src/diracx/db/sql/job/schema.py index 841312a73..0bbfd8394 100644 --- a/diracx-db/src/diracx/db/sql/job/schema.py +++ b/diracx-db/src/diracx/db/sql/job/schema.py @@ -1,22 +1,39 @@ from __future__ import annotations +from datetime import datetime +from typing import Optional + import sqlalchemy.types as types from sqlalchemy import ( ForeignKey, Index, - Integer, String, Text, ) -from sqlalchemy.orm import DeclarativeBase +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column +from typing_extensions import Annotated from diracx.db.sql.utils.types import SmarterDateTime -from ..utils import Column, EnumBackedBool, NullColumn +from ..utils import EnumBackedBool, str32, str64, str128, str255 + +str100 = Annotated[str, 100] +jobid_type = Annotated[ + int, + mapped_column( + "JobID", ForeignKey("Jobs.JobID", ondelete="CASCADE"), primary_key=True + ), +] class JobDBBase(DeclarativeBase): - pass + type_annotation_map = { + str32: String(32), + str64: String(64), + str100: String(100), + str128: String(128), + str255: String(255), + } class AccountedFlagEnum(types.TypeDecorator): @@ -49,51 +66,56 @@ def process_result_value(self, value, dialect) -> bool | str: class Jobs(JobDBBase): __tablename__ = "Jobs" - job_id = Column( + job_id: Mapped[int] = mapped_column( "JobID", - Integer, ForeignKey("JobJDLs.JobID", ondelete="CASCADE"), primary_key=True, default=0, ) - job_type = Column("JobType", String(32), default="user") - job_group = Column("JobGroup", String(32), default="00000000") - site = Column("Site", String(100), default="ANY") - job_name = Column("JobName", String(128), default="Unknown") - owner = Column("Owner", String(64), default="Unknown") - owner_group = Column("OwnerGroup", String(128), default="Unknown") - vo = Column("VO", String(32)) - submission_time = NullColumn( + job_type: Mapped[str32] = mapped_column("JobType", default="user") + job_group: Mapped[str32] = mapped_column("JobGroup", default="00000000") + site: Mapped[str100] = mapped_column("Site", default="ANY") + job_name: Mapped[str128] = mapped_column("JobName", default="Unknown") + owner: Mapped[str64] = mapped_column("Owner", default="Unknown") + owner_group: Mapped[str128] = mapped_column("OwnerGroup", default="Unknown") + vo: Mapped[str32] = mapped_column("VO") + submission_time: Mapped[Optional[datetime]] = mapped_column( "SubmissionTime", SmarterDateTime(), ) - reschedule_time = NullColumn( + reschedule_time: Mapped[Optional[datetime]] = mapped_column( "RescheduleTime", SmarterDateTime(), ) - last_update_time = NullColumn( + last_update_time: Mapped[Optional[datetime]] = mapped_column( "LastUpdateTime", SmarterDateTime(), ) - start_exec_time = NullColumn( + start_exec_time: Mapped[Optional[datetime]] = mapped_column( "StartExecTime", SmarterDateTime(), ) - heart_beat_time = NullColumn( + heart_beat_time: Mapped[Optional[datetime]] = mapped_column( "HeartBeatTime", SmarterDateTime(), ) - end_exec_time = NullColumn( + end_exec_time: Mapped[Optional[datetime]] = mapped_column( "EndExecTime", SmarterDateTime(), ) - status = Column("Status", String(32), default="Received") - minor_status = Column("MinorStatus", String(128), default="Unknown") - application_status = Column("ApplicationStatus", String(255), default="Unknown") - user_priority = Column("UserPriority", Integer, default=0) - reschedule_counter = Column("RescheduleCounter", Integer, default=0) - verified_flag = Column("VerifiedFlag", EnumBackedBool(), default=False) - accounted_flag = Column("AccountedFlag", AccountedFlagEnum(), default=False) + status: Mapped[str32] = mapped_column("Status", default="Received") + minor_status: Mapped[str128] = mapped_column("MinorStatus", default="Unknown") + application_status: Mapped[str255] = mapped_column( + "ApplicationStatus", default="Unknown" + ) + user_priority: Mapped[int] = mapped_column("UserPriority", default=0) + reschedule_counter: Mapped[int] = mapped_column("RescheduleCounter", default=0) + verified_flag: Mapped[bool] = mapped_column( + "VerifiedFlag", EnumBackedBool(), default=False + ) + accounted_flag: Mapped[bool | str] = mapped_column( + "AccountedFlag", AccountedFlagEnum(), default=False + ) __table_args__ = ( Index("JobType", "JobType"), @@ -111,57 +133,47 @@ class Jobs(JobDBBase): class JobJDLs(JobDBBase): __tablename__ = "JobJDLs" - job_id = Column("JobID", Integer, autoincrement=True, primary_key=True) - jdl = Column("JDL", Text) - job_requirements = Column("JobRequirements", Text) - original_jdl = Column("OriginalJDL", Text) + job_id: Mapped[int] = mapped_column("JobID", autoincrement=True, primary_key=True) + jdl: Mapped[str] = mapped_column("JDL", Text) + job_requirements: Mapped[str] = mapped_column("JobRequirements", Text) + original_jdl: Mapped[str] = mapped_column("OriginalJDL", Text) class InputData(JobDBBase): __tablename__ = "InputData" - job_id = Column( - "JobID", Integer, ForeignKey("Jobs.JobID", ondelete="CASCADE"), primary_key=True - ) - lfn = Column("LFN", String(255), default="", primary_key=True) - status = Column("Status", String(32), default="AprioriGood") + job_id: Mapped[jobid_type] + lfn: Mapped[str255] = mapped_column("LFN", default="", primary_key=True) + status: Mapped[str32] = mapped_column("Status", default="AprioriGood") class JobParameters(JobDBBase): __tablename__ = "JobParameters" - job_id = Column( - "JobID", Integer, ForeignKey("Jobs.JobID", ondelete="CASCADE"), primary_key=True - ) - name = Column("Name", String(100), primary_key=True) - value = Column("Value", Text) + job_id: Mapped[jobid_type] + name: Mapped[str100] = mapped_column("Name", primary_key=True) + value: Mapped[str] = mapped_column("Value", Text) class OptimizerParameters(JobDBBase): __tablename__ = "OptimizerParameters" - job_id = Column( - "JobID", Integer, ForeignKey("Jobs.JobID", ondelete="CASCADE"), primary_key=True - ) - name = Column("Name", String(100), primary_key=True) - value = Column("Value", Text) + job_id: Mapped[jobid_type] + name: Mapped[str100] = mapped_column("Name", primary_key=True) + value: Mapped[str] = mapped_column("Value", Text) class AtticJobParameters(JobDBBase): __tablename__ = "AtticJobParameters" - job_id = Column( - "JobID", Integer, ForeignKey("Jobs.JobID", ondelete="CASCADE"), primary_key=True - ) - name = Column("Name", String(100), primary_key=True) - value = Column("Value", Text) - reschedule_cycle = Column("RescheduleCycle", Integer) + job_id: Mapped[jobid_type] + name: Mapped[str100] = mapped_column("Name", primary_key=True) + value: Mapped[str] = mapped_column("Value", Text) + reschedule_cycle: Mapped[int] = mapped_column("RescheduleCycle") class HeartBeatLoggingInfo(JobDBBase): __tablename__ = "HeartBeatLoggingInfo" - job_id = Column( - "JobID", Integer, ForeignKey("Jobs.JobID", ondelete="CASCADE"), primary_key=True - ) - name = Column("Name", String(100), primary_key=True) - value = Column("Value", Text) - heart_beat_time = Column( + job_id: Mapped[jobid_type] + name: Mapped[str100] = mapped_column("Name", primary_key=True) + value: Mapped[str] = mapped_column("Value", Text) + heart_beat_time: Mapped[datetime] = mapped_column( "HeartBeatTime", SmarterDateTime(), primary_key=True, @@ -170,18 +182,16 @@ class HeartBeatLoggingInfo(JobDBBase): class JobCommands(JobDBBase): __tablename__ = "JobCommands" - job_id = Column( - "JobID", Integer, ForeignKey("Jobs.JobID", ondelete="CASCADE"), primary_key=True - ) - command = Column("Command", String(100)) - arguments = Column("Arguments", String(100)) - status = Column("Status", String(64), default="Received") - reception_time = Column( + job_id: Mapped[jobid_type] + command: Mapped[str100] = mapped_column("Command") + arguments: Mapped[str100] = mapped_column("Arguments") + status: Mapped[str64] = mapped_column("Status", default="Received") + reception_time: Mapped[datetime] = mapped_column( "ReceptionTime", SmarterDateTime(), primary_key=True, ) - execution_time = NullColumn( + execution_time: Mapped[Optional[datetime]] = mapped_column( "ExecutionTime", SmarterDateTime(), ) diff --git a/diracx-db/src/diracx/db/sql/job_logging/schema.py b/diracx-db/src/diracx/db/sql/job_logging/schema.py index 0a6e7363c..e0b5ba41d 100644 --- a/diracx-db/src/diracx/db/sql/job_logging/schema.py +++ b/diracx-db/src/diracx/db/sql/job_logging/schema.py @@ -2,14 +2,18 @@ from datetime import UTC, datetime -from sqlalchemy import Integer, Numeric, PrimaryKeyConstraint, String, TypeDecorator -from sqlalchemy.orm import DeclarativeBase +from sqlalchemy import Numeric, PrimaryKeyConstraint, String, TypeDecorator +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column -from ..utils import Column, DateNowColumn +from ..utils import datetime_now, str32, str128, str255 class JobLoggingDBBase(DeclarativeBase): - pass + type_annotation_map = { + str32: String(32), + str128: String(128), + str255: String(255), + } class MagicEpochDateTime(TypeDecorator): @@ -58,12 +62,14 @@ def process_result_value(self, value, dialect): class LoggingInfo(JobLoggingDBBase): __tablename__ = "LoggingInfo" - job_id = Column("JobID", Integer) - seq_num = Column("SeqNum", Integer) - status = Column("Status", String(32), default="") - minor_status = Column("MinorStatus", String(128), default="") - application_status = Column("ApplicationStatus", String(255), default="") - status_time = DateNowColumn("StatusTime") - status_time_order = Column("StatusTimeOrder", MagicEpochDateTime, default=0) - source = Column("StatusSource", String(32), default="Unknown") + job_id: Mapped[int] = mapped_column("JobID") + seq_num: Mapped[int] = mapped_column("SeqNum") + status: Mapped[str32] = mapped_column("Status", default="") + minor_status: Mapped[str128] = mapped_column("MinorStatus", default="") + application_status: Mapped[str255] = mapped_column("ApplicationStatus", default="") + status_time: Mapped[datetime_now] = mapped_column("StatusTime") + status_time_order: Mapped[datetime] = mapped_column( + "StatusTimeOrder", MagicEpochDateTime(), default=0 + ) + source: Mapped[str32] = mapped_column("StatusSource", default="Unknown") __table_args__ = (PrimaryKeyConstraint("JobID", "SeqNum"),) diff --git a/diracx-db/src/diracx/db/sql/pilot_agents/schema.py b/diracx-db/src/diracx/db/sql/pilot_agents/schema.py index 523639ac1..b7e83567c 100644 --- a/diracx-db/src/diracx/db/sql/pilot_agents/schema.py +++ b/diracx-db/src/diracx/db/sql/pilot_agents/schema.py @@ -1,18 +1,23 @@ from __future__ import annotations +from datetime import datetime +from typing import Optional + from sqlalchemy import ( Double, Index, - Integer, - String, Text, ) -from sqlalchemy.orm import DeclarativeBase +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column +from diracx.db.sql.utils import ( + EnumBackedBool, + str32, + str128, + str255, +) from diracx.db.sql.utils.types import SmarterDateTime -from ..utils import Column, EnumBackedBool, NullColumn - class PilotAgentsDBBase(DeclarativeBase): pass @@ -21,22 +26,34 @@ class PilotAgentsDBBase(DeclarativeBase): class PilotAgents(PilotAgentsDBBase): __tablename__ = "PilotAgents" - pilot_id = Column("PilotID", Integer, autoincrement=True, primary_key=True) - initial_job_id = Column("InitialJobID", Integer, default=0) - current_job_id = Column("CurrentJobID", Integer, default=0) - pilot_job_reference = Column("PilotJobReference", String(255), default="Unknown") - pilot_stamp = Column("PilotStamp", String(32), default="") - destination_site = Column("DestinationSite", String(128), default="NotAssigned") - queue = Column("Queue", String(128), default="Unknown") - grid_site = Column("GridSite", String(128), default="Unknown") - vo = Column("VO", String(128)) - grid_type = Column("GridType", String(32), default="LCG") - benchmark = Column("BenchMark", Double, default=0.0) - submission_time = NullColumn("SubmissionTime", SmarterDateTime) - last_update_time = NullColumn("LastUpdateTime", SmarterDateTime) - status = Column("Status", String(32), default="Unknown") - status_reason = Column("StatusReason", String(255), default="Unknown") - accounting_sent = Column("AccountingSent", EnumBackedBool(), default=False) + pilot_id: Mapped[int] = mapped_column( + "PilotID", autoincrement=True, primary_key=True + ) + initial_job_id: Mapped[int] = mapped_column("InitialJobID", default=0) + current_job_id: Mapped[int] = mapped_column("CurrentJobID", default=0) + pilot_job_reference: Mapped[str255] = mapped_column( + "PilotJobReference", default="Unknown" + ) + pilot_stamp: Mapped[str32] = mapped_column("PilotStamp", default="") + destination_site: Mapped[str128] = mapped_column( + "DestinationSite", default="NotAssigned" + ) + queue: Mapped[str128] = mapped_column("Queue", default="Unknown") + grid_site: Mapped[str128] = mapped_column("GridSite", default="Unknown") + vo: Mapped[str128] = mapped_column("VO") + grid_type: Mapped[str32] = mapped_column("GridType", default="LCG") + benchmark: Mapped[float] = mapped_column("BenchMark", Double, default=0.0) + submission_time: Mapped[Optional[datetime]] = mapped_column( + "SubmissionTime", SmarterDateTime + ) + last_update_time: Mapped[Optional[datetime]] = mapped_column( + "LastUpdateTime", SmarterDateTime + ) + status: Mapped[str32] = mapped_column("Status", default="Unknown") + status_reason: Mapped[str255] = mapped_column("StatusReason", default="Unknown") + accounting_sent: Mapped[bool] = mapped_column( + "AccountingSent", EnumBackedBool(), default=False + ) __table_args__ = ( Index("PilotJobReference", "PilotJobReference"), @@ -48,9 +65,9 @@ class PilotAgents(PilotAgentsDBBase): class JobToPilotMapping(PilotAgentsDBBase): __tablename__ = "JobToPilotMapping" - pilot_id = Column("PilotID", Integer, primary_key=True) - job_id = Column("JobID", Integer, primary_key=True) - start_time = Column("StartTime", SmarterDateTime) + pilot_id: Mapped[int] = mapped_column("PilotID", primary_key=True) + job_id: Mapped[int] = mapped_column("JobID", primary_key=True) + start_time: Mapped[datetime] = mapped_column("StartTime", SmarterDateTime) __table_args__ = (Index("JobID", "JobID"), Index("PilotID", "PilotID")) @@ -58,6 +75,6 @@ class JobToPilotMapping(PilotAgentsDBBase): class PilotOutput(PilotAgentsDBBase): __tablename__ = "PilotOutput" - pilot_id = Column("PilotID", Integer, primary_key=True) - std_output = Column("StdOutput", Text) - std_error = Column("StdError", Text) + pilot_id: Mapped[int] = mapped_column("PilotID", primary_key=True) + std_output: Mapped[str] = mapped_column("StdOutput", Text) + std_error: Mapped[str] = mapped_column("StdError", Text) diff --git a/diracx-db/src/diracx/db/sql/sandbox_metadata/schema.py b/diracx-db/src/diracx/db/sql/sandbox_metadata/schema.py index 049dd7f7c..473966ed6 100644 --- a/diracx-db/src/diracx/db/sql/sandbox_metadata/schema.py +++ b/diracx-db/src/diracx/db/sql/sandbox_metadata/schema.py @@ -2,28 +2,31 @@ from sqlalchemy import ( BigInteger, - Boolean, Index, - Integer, PrimaryKeyConstraint, String, UniqueConstraint, ) -from sqlalchemy.orm import DeclarativeBase +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column -from diracx.db.sql.utils import Column, DateNowColumn +from diracx.db.sql.utils import datetime_now, str32, str64, str128, str512 class Base(DeclarativeBase): - pass + type_annotation_map = { + str32: String(32), + str64: String(64), + str128: String(128), + str512: String(512), + } class SBOwners(Base): __tablename__ = "sb_Owners" - OwnerID = Column(Integer, autoincrement=True) - Owner = Column(String(32)) - OwnerGroup = Column(String(32)) - VO = Column(String(64)) + OwnerID: Mapped[int] = mapped_column(autoincrement=True) + Owner: Mapped[str32] + OwnerGroup: Mapped[str32] + VO: Mapped[str64] __table_args__ = ( PrimaryKeyConstraint("OwnerID"), UniqueConstraint("Owner", "OwnerGroup", "VO", name="unique_owner_group_vo"), @@ -32,26 +35,26 @@ class SBOwners(Base): class SandBoxes(Base): __tablename__ = "sb_SandBoxes" - SBId = Column(Integer, autoincrement=True) - OwnerId = Column(Integer) - SEName = Column(String(64)) - SEPFN = Column(String(512)) - Bytes = Column(BigInteger) - RegistrationTime = DateNowColumn() - LastAccessTime = DateNowColumn() - Assigned = Column(Boolean, default=False) + SBId: Mapped[int] = mapped_column(autoincrement=True) + OwnerId: Mapped[int] + SEName: Mapped[str64] + SEPFN: Mapped[str512] + Bytes: Mapped[int] = mapped_column(BigInteger) + RegistrationTime: Mapped[datetime_now] + LastAccessTime: Mapped[datetime_now] + Assigned: Mapped[bool] = mapped_column(default=False) __table_args__ = ( PrimaryKeyConstraint("SBId"), - Index("OwnerId", OwnerId), + Index("OwnerId", "OwnerId"), UniqueConstraint("SEName", "SEPFN", name="Location"), ) class SBEntityMapping(Base): __tablename__ = "sb_EntityMapping" - SBId = Column(Integer) - EntityId = Column(String(128)) - Type = Column(String(64)) + SBId: Mapped[int] + EntityId: Mapped[str128] + Type: Mapped[str64] __table_args__ = ( PrimaryKeyConstraint("SBId", "EntityId", "Type"), Index("SBId", "EntityId"), diff --git a/diracx-db/src/diracx/db/sql/task_queue/schema.py b/diracx-db/src/diracx/db/sql/task_queue/schema.py index 88ffdfdc9..78e0595e1 100644 --- a/diracx-db/src/diracx/db/sql/task_queue/schema.py +++ b/diracx-db/src/diracx/db/sql/task_queue/schema.py @@ -2,51 +2,64 @@ from sqlalchemy import ( BigInteger, - Boolean, - Float, ForeignKey, Index, - Integer, String, ) -from sqlalchemy.orm import DeclarativeBase - -from ..utils import Column +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column +from typing_extensions import Annotated + +from diracx.db.sql.utils import ( + str32, + str64, + str128, + str255, +) class TaskQueueDBBase(DeclarativeBase): - pass + type_annotation_map = { + str32: String(32), + str64: String(64), + str128: String(128), + str255: String(255), + } + + +tqid_type = Annotated[ + int, + mapped_column( + ForeignKey("tq_TaskQueues.TQId", ondelete="CASCADE"), primary_key=True + ), +] +value_type = Annotated[str64, mapped_column(primary_key=True)] class TaskQueues(TaskQueueDBBase): __tablename__ = "tq_TaskQueues" - TQId = Column(Integer, primary_key=True) - Owner = Column(String(255), nullable=False) - OwnerGroup = Column(String(32), nullable=False) - VO = Column(String(32), nullable=False) - CPUTime = Column(BigInteger, nullable=False) - Priority = Column(Float, nullable=False) - Enabled = Column(Boolean, nullable=False, default=0) + TQId: Mapped[int] = mapped_column(primary_key=True) + Owner: Mapped[str255] + OwnerGroup: Mapped[str32] + VO: Mapped[str32] + CPUTime: Mapped[int] = mapped_column(BigInteger) + Priority: Mapped[float] + Enabled: Mapped[bool] = mapped_column(default=0) __table_args__ = (Index("TQOwner", "Owner", "OwnerGroup", "CPUTime"),) class JobsQueue(TaskQueueDBBase): __tablename__ = "tq_Jobs" - TQId = Column( - Integer, ForeignKey("tq_TaskQueues.TQId", ondelete="CASCADE"), primary_key=True - ) - JobId = Column(Integer, primary_key=True) - Priority = Column(Integer, nullable=False) - RealPriority = Column(Float, nullable=False) + TQId: Mapped[tqid_type] + JobId: Mapped[int] = mapped_column(primary_key=True) + Priority: Mapped[int] + RealPriority: Mapped[float] __table_args__ = (Index("TaskIndex", "TQId"),) class SitesQueue(TaskQueueDBBase): __tablename__ = "tq_TQToSites" - TQId = Column( - Integer, ForeignKey("tq_TaskQueues.TQId", ondelete="CASCADE"), primary_key=True - ) - Value = Column(String(64), primary_key=True) + TQId: Mapped[tqid_type] + Value: Mapped[value_type] __table_args__ = ( Index("SitesTaskIndex", "TQId"), Index("SitesIndex", "Value"), @@ -55,10 +68,8 @@ class SitesQueue(TaskQueueDBBase): class GridCEsQueue(TaskQueueDBBase): __tablename__ = "tq_TQToGridCEs" - TQId = Column( - Integer, ForeignKey("tq_TaskQueues.TQId", ondelete="CASCADE"), primary_key=True - ) - Value = Column(String(64), primary_key=True) + TQId: Mapped[tqid_type] + Value: Mapped[value_type] __table_args__ = ( Index("GridCEsTaskIndex", "TQId"), Index("GridCEsValueIndex", "Value"), @@ -67,10 +78,8 @@ class GridCEsQueue(TaskQueueDBBase): class BannedSitesQueue(TaskQueueDBBase): __tablename__ = "tq_TQToBannedSites" - TQId = Column( - Integer, ForeignKey("tq_TaskQueues.TQId", ondelete="CASCADE"), primary_key=True - ) - Value = Column(String(64), primary_key=True) + TQId: Mapped[tqid_type] + Value: Mapped[value_type] __table_args__ = ( Index("BannedSitesTaskIndex", "TQId"), Index("BannedSitesValueIndex", "Value"), @@ -79,10 +88,8 @@ class BannedSitesQueue(TaskQueueDBBase): class PlatformsQueue(TaskQueueDBBase): __tablename__ = "tq_TQToPlatforms" - TQId = Column( - Integer, ForeignKey("tq_TaskQueues.TQId", ondelete="CASCADE"), primary_key=True - ) - Value = Column(String(64), primary_key=True) + TQId: Mapped[tqid_type] + Value: Mapped[value_type] __table_args__ = ( Index("PlatformsTaskIndex", "TQId"), Index("PlatformsValueIndex", "Value"), @@ -91,10 +98,8 @@ class PlatformsQueue(TaskQueueDBBase): class JobTypesQueue(TaskQueueDBBase): __tablename__ = "tq_TQToJobTypes" - TQId = Column( - Integer, ForeignKey("tq_TaskQueues.TQId", ondelete="CASCADE"), primary_key=True - ) - Value = Column(String(64), primary_key=True) + TQId: Mapped[tqid_type] + Value: Mapped[value_type] __table_args__ = ( Index("JobTypesTaskIndex", "TQId"), Index("JobTypesValueIndex", "Value"), @@ -103,10 +108,8 @@ class JobTypesQueue(TaskQueueDBBase): class TagsQueue(TaskQueueDBBase): __tablename__ = "tq_TQToTags" - TQId = Column( - Integer, ForeignKey("tq_TaskQueues.TQId", ondelete="CASCADE"), primary_key=True - ) - Value = Column(String(64), primary_key=True) + TQId: Mapped[tqid_type] + Value: Mapped[value_type] __table_args__ = ( Index("TagsTaskIndex", "TQId"), Index("TagsValueIndex", "Value"), diff --git a/diracx-db/src/diracx/db/sql/utils/__init__.py b/diracx-db/src/diracx/db/sql/utils/__init__.py index 750b3de99..efdf8ef93 100644 --- a/diracx-db/src/diracx/db/sql/utils/__init__.py +++ b/diracx-db/src/diracx/db/sql/utils/__init__.py @@ -9,6 +9,7 @@ "BaseSQLDB", "EnumBackedBool", "EnumColumn", + "enum_column", "apply_search_filters", "apply_sort_constraints", "substract_date", @@ -16,6 +17,13 @@ "SQLDBUnavailableError", "uuid7_from_datetime", "uuid7_to_datetime", + "datetime_now", + "str32", + "str64", + "str128", + "str255", + "str512", + "str1024", ] from .base import ( @@ -28,4 +36,18 @@ uuid7_to_datetime, ) from .functions import hash, substract_date, utcnow -from .types import Column, DateNowColumn, EnumBackedBool, EnumColumn, NullColumn +from .types import ( + Column, + DateNowColumn, + EnumBackedBool, + EnumColumn, + NullColumn, + datetime_now, + enum_column, + str32, + str64, + str128, + str255, + str512, + str1024, +) diff --git a/diracx-db/src/diracx/db/sql/utils/types.py b/diracx-db/src/diracx/db/sql/utils/types.py index 9dc5f09f7..49e9d4732 100644 --- a/diracx-db/src/diracx/db/sql/utils/types.py +++ b/diracx-db/src/diracx/db/sql/utils/types.py @@ -7,6 +7,8 @@ import sqlalchemy.types as types from sqlalchemy import Column as RawColumn from sqlalchemy import DateTime, Enum +from sqlalchemy.orm import mapped_column +from typing_extensions import Annotated from .functions import utcnow @@ -17,11 +19,26 @@ # Module-level constants for default timezone values _DEFAULT_UTC = ZoneInfo("UTC") +datetime_now = Annotated[ + datetime, mapped_column(DateTime(timezone=True), server_default=utcnow()) +] + +str32 = Annotated[str, 32] +str64 = Annotated[str, 64] +str128 = Annotated[str, 128] +str255 = Annotated[str, 255] +str512 = Annotated[str, 512] +str1024 = Annotated[str, 1024] + def EnumColumn(name, enum_type, **kwargs): # noqa: N802 return Column(name, Enum(enum_type, native_enum=False, length=16), **kwargs) +def enum_column(name, enum_type, **kwargs): + return mapped_column(name, Enum(enum_type, native_enum=False, length=16), **kwargs) + + class EnumBackedBool(types.TypeDecorator): """Maps a ``EnumBackedBool()`` column to True/False in Python.""" diff --git a/docs/dev/reference/coding-conventions.md b/docs/dev/reference/coding-conventions.md index aa5194380..9417cca5e 100644 --- a/docs/dev/reference/coding-conventions.md +++ b/docs/dev/reference/coding-conventions.md @@ -92,9 +92,11 @@ delay = datetime.datetime.now() + datetime.timedelta(hours=1) ```python class Owners(Base): __tablename__ = "Owners" - owner_id = Column("OwnerID", Integer, primary_key=True, autoincrement=True) - creation_time = DateNowColumn("CreationTime") - name = Column("Name", String(255)) + owner_id: Mapped[int] = mapped_column( + "OwnerID", Integer, primary_key=True, autoincrement=True + ) + creation_time: Mapped[datetime_now] = mapped_column("CreationTime") + name: Mapped[str255] = mapped_column("Name") ``` @@ -104,9 +106,9 @@ class Owners(Base): ```python class Owners(Base): __tablename__ = "Owners" - OwnerID = Column(Integer, primary_key=True, autoincrement=True) - CreationTime = DateNowColumn() - Name = Column(String(255)) + OwnerID: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) + creation_time: Mapped[datetime_now] + name: Mapped[str255] ``` diff --git a/extensions/gubbins/gubbins-db/src/gubbins/db/sql/jobs/schema.py b/extensions/gubbins/gubbins-db/src/gubbins/db/sql/jobs/schema.py index eee922d4e..5bf132194 100644 --- a/extensions/gubbins/gubbins-db/src/gubbins/db/sql/jobs/schema.py +++ b/extensions/gubbins/gubbins-db/src/gubbins/db/sql/jobs/schema.py @@ -1,19 +1,13 @@ from diracx.db.sql.job.db import JobDBBase -from diracx.db.sql.utils import Column -from sqlalchemy import ( - ForeignKey, - Integer, - String, -) +from diracx.db.sql.job.schema import jobid_type, str255 +from sqlalchemy.orm import Mapped, mapped_column -# You need to inherit from the declarative_base of the parent DB +# You need to inherit from the DeclarativeBase of the parent DB class GubbinsInfo(JobDBBase): """An extra table with respect to Vanilla diracx JobDB""" __tablename__ = "GubbinsJobs" - job_id = Column( - "JobID", Integer, ForeignKey("Jobs.JobID", ondelete="CASCADE"), primary_key=True - ) - info = Column("Info", String(255), default="", primary_key=True) + job_id: Mapped[jobid_type] + info: Mapped[str255] = mapped_column("Info", default="", primary_key=True) diff --git a/extensions/gubbins/gubbins-db/src/gubbins/db/sql/lollygag/schema.py b/extensions/gubbins/gubbins-db/src/gubbins/db/sql/lollygag/schema.py index 9b80e5133..223378a4d 100644 --- a/extensions/gubbins/gubbins-db/src/gubbins/db/sql/lollygag/schema.py +++ b/extensions/gubbins/gubbins-db/src/gubbins/db/sql/lollygag/schema.py @@ -1,21 +1,29 @@ # The utils class define some boilerplate types that should be used # in place of the SQLAlchemy one. Have a look at them -from diracx.db.sql.utils import Column, DateNowColumn -from sqlalchemy import ForeignKey, Integer, String, Uuid -from sqlalchemy.orm import declarative_base +from uuid import UUID -Base = declarative_base() +from diracx.db.sql.utils import datetime_now, str255 +from sqlalchemy import ForeignKey, String +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column + + +class Base(DeclarativeBase): + type_annotation_map = { + str255: String(255), + } class Owners(Base): __tablename__ = "Owners" - owner_id = Column("OwnerID", Integer, primary_key=True, autoincrement=True) - creation_time = DateNowColumn("CreationTime") - name = Column("Name", String(255)) + owner_id: Mapped[int] = mapped_column( + "OwnerID", primary_key=True, autoincrement=True + ) + creation_time: Mapped[datetime_now] = mapped_column("CreationTime") + name: Mapped[str255] = mapped_column("Name") class Cars(Base): __tablename__ = "Cars" - license_plate = Column("LicensePlate", Uuid(), primary_key=True) - model = Column("Model", String(255)) - owner_id = Column("OwnerID", Integer, ForeignKey(Owners.owner_id)) + license_plate: Mapped[UUID] = mapped_column("LicensePlate", primary_key=True) + model: Mapped[str255] = mapped_column("Model") + owner_id: Mapped[int] = mapped_column("OwnerID", ForeignKey(Owners.owner_id)) From 21d5a742de446652636059ab12edc8ab2a4dd1ee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?H=C3=A9lo=C3=AFse=20Joffe?= Date: Tue, 3 Feb 2026 11:04:16 +0100 Subject: [PATCH 3/8] fix: remove deprecated Annotated usage with mapped_column --- diracx-db/src/diracx/db/sql/job/schema.py | 30 ++++++----- .../src/diracx/db/sql/task_queue/schema.py | 50 ++++++++++--------- .../src/gubbins/db/sql/jobs/schema.py | 7 ++- 3 files changed, 50 insertions(+), 37 deletions(-) diff --git a/diracx-db/src/diracx/db/sql/job/schema.py b/diracx-db/src/diracx/db/sql/job/schema.py index 0bbfd8394..a1e3fac82 100644 --- a/diracx-db/src/diracx/db/sql/job/schema.py +++ b/diracx-db/src/diracx/db/sql/job/schema.py @@ -18,12 +18,6 @@ from ..utils import EnumBackedBool, str32, str64, str128, str255 str100 = Annotated[str, 100] -jobid_type = Annotated[ - int, - mapped_column( - "JobID", ForeignKey("Jobs.JobID", ondelete="CASCADE"), primary_key=True - ), -] class JobDBBase(DeclarativeBase): @@ -141,28 +135,36 @@ class JobJDLs(JobDBBase): class InputData(JobDBBase): __tablename__ = "InputData" - job_id: Mapped[jobid_type] + job_id: Mapped[int] = mapped_column( + "JobID", ForeignKey("Jobs.JobID", ondelete="CASCADE"), primary_key=True + ) lfn: Mapped[str255] = mapped_column("LFN", default="", primary_key=True) status: Mapped[str32] = mapped_column("Status", default="AprioriGood") class JobParameters(JobDBBase): __tablename__ = "JobParameters" - job_id: Mapped[jobid_type] + job_id: Mapped[int] = mapped_column( + "JobID", ForeignKey("Jobs.JobID", ondelete="CASCADE"), primary_key=True + ) name: Mapped[str100] = mapped_column("Name", primary_key=True) value: Mapped[str] = mapped_column("Value", Text) class OptimizerParameters(JobDBBase): __tablename__ = "OptimizerParameters" - job_id: Mapped[jobid_type] + job_id: Mapped[int] = mapped_column( + "JobID", ForeignKey("Jobs.JobID", ondelete="CASCADE"), primary_key=True + ) name: Mapped[str100] = mapped_column("Name", primary_key=True) value: Mapped[str] = mapped_column("Value", Text) class AtticJobParameters(JobDBBase): __tablename__ = "AtticJobParameters" - job_id: Mapped[jobid_type] + job_id: Mapped[int] = mapped_column( + "JobID", ForeignKey("Jobs.JobID", ondelete="CASCADE"), primary_key=True + ) name: Mapped[str100] = mapped_column("Name", primary_key=True) value: Mapped[str] = mapped_column("Value", Text) reschedule_cycle: Mapped[int] = mapped_column("RescheduleCycle") @@ -170,7 +172,9 @@ class AtticJobParameters(JobDBBase): class HeartBeatLoggingInfo(JobDBBase): __tablename__ = "HeartBeatLoggingInfo" - job_id: Mapped[jobid_type] + job_id: Mapped[int] = mapped_column( + "JobID", ForeignKey("Jobs.JobID", ondelete="CASCADE"), primary_key=True + ) name: Mapped[str100] = mapped_column("Name", primary_key=True) value: Mapped[str] = mapped_column("Value", Text) heart_beat_time: Mapped[datetime] = mapped_column( @@ -182,7 +186,9 @@ class HeartBeatLoggingInfo(JobDBBase): class JobCommands(JobDBBase): __tablename__ = "JobCommands" - job_id: Mapped[jobid_type] + job_id: Mapped[int] = mapped_column( + "JobID", ForeignKey("Jobs.JobID", ondelete="CASCADE"), primary_key=True + ) command: Mapped[str100] = mapped_column("Command") arguments: Mapped[str100] = mapped_column("Arguments") status: Mapped[str64] = mapped_column("Status", default="Received") diff --git a/diracx-db/src/diracx/db/sql/task_queue/schema.py b/diracx-db/src/diracx/db/sql/task_queue/schema.py index 78e0595e1..342a46e45 100644 --- a/diracx-db/src/diracx/db/sql/task_queue/schema.py +++ b/diracx-db/src/diracx/db/sql/task_queue/schema.py @@ -7,7 +7,6 @@ String, ) from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column -from typing_extensions import Annotated from diracx.db.sql.utils import ( str32, @@ -26,15 +25,6 @@ class TaskQueueDBBase(DeclarativeBase): } -tqid_type = Annotated[ - int, - mapped_column( - ForeignKey("tq_TaskQueues.TQId", ondelete="CASCADE"), primary_key=True - ), -] -value_type = Annotated[str64, mapped_column(primary_key=True)] - - class TaskQueues(TaskQueueDBBase): __tablename__ = "tq_TaskQueues" TQId: Mapped[int] = mapped_column(primary_key=True) @@ -49,7 +39,9 @@ class TaskQueues(TaskQueueDBBase): class JobsQueue(TaskQueueDBBase): __tablename__ = "tq_Jobs" - TQId: Mapped[tqid_type] + TQId: Mapped[int] = mapped_column( + ForeignKey("tq_TaskQueues.TQId", ondelete="CASCADE"), primary_key=True + ) JobId: Mapped[int] = mapped_column(primary_key=True) Priority: Mapped[int] RealPriority: Mapped[float] @@ -58,8 +50,10 @@ class JobsQueue(TaskQueueDBBase): class SitesQueue(TaskQueueDBBase): __tablename__ = "tq_TQToSites" - TQId: Mapped[tqid_type] - Value: Mapped[value_type] + TQId: Mapped[int] = mapped_column( + ForeignKey("tq_TaskQueues.TQId", ondelete="CASCADE"), primary_key=True + ) + Value: Mapped[str64] = mapped_column(primary_key=True) __table_args__ = ( Index("SitesTaskIndex", "TQId"), Index("SitesIndex", "Value"), @@ -68,8 +62,10 @@ class SitesQueue(TaskQueueDBBase): class GridCEsQueue(TaskQueueDBBase): __tablename__ = "tq_TQToGridCEs" - TQId: Mapped[tqid_type] - Value: Mapped[value_type] + TQId: Mapped[int] = mapped_column( + ForeignKey("tq_TaskQueues.TQId", ondelete="CASCADE"), primary_key=True + ) + Value: Mapped[str64] = mapped_column(primary_key=True) __table_args__ = ( Index("GridCEsTaskIndex", "TQId"), Index("GridCEsValueIndex", "Value"), @@ -78,8 +74,10 @@ class GridCEsQueue(TaskQueueDBBase): class BannedSitesQueue(TaskQueueDBBase): __tablename__ = "tq_TQToBannedSites" - TQId: Mapped[tqid_type] - Value: Mapped[value_type] + TQId: Mapped[int] = mapped_column( + ForeignKey("tq_TaskQueues.TQId", ondelete="CASCADE"), primary_key=True + ) + Value: Mapped[str64] = mapped_column(primary_key=True) __table_args__ = ( Index("BannedSitesTaskIndex", "TQId"), Index("BannedSitesValueIndex", "Value"), @@ -88,8 +86,10 @@ class BannedSitesQueue(TaskQueueDBBase): class PlatformsQueue(TaskQueueDBBase): __tablename__ = "tq_TQToPlatforms" - TQId: Mapped[tqid_type] - Value: Mapped[value_type] + TQId: Mapped[int] = mapped_column( + ForeignKey("tq_TaskQueues.TQId", ondelete="CASCADE"), primary_key=True + ) + Value: Mapped[str64] = mapped_column(primary_key=True) __table_args__ = ( Index("PlatformsTaskIndex", "TQId"), Index("PlatformsValueIndex", "Value"), @@ -98,8 +98,10 @@ class PlatformsQueue(TaskQueueDBBase): class JobTypesQueue(TaskQueueDBBase): __tablename__ = "tq_TQToJobTypes" - TQId: Mapped[tqid_type] - Value: Mapped[value_type] + TQId: Mapped[int] = mapped_column( + ForeignKey("tq_TaskQueues.TQId", ondelete="CASCADE"), primary_key=True + ) + Value: Mapped[str64] = mapped_column(primary_key=True) __table_args__ = ( Index("JobTypesTaskIndex", "TQId"), Index("JobTypesValueIndex", "Value"), @@ -108,8 +110,10 @@ class JobTypesQueue(TaskQueueDBBase): class TagsQueue(TaskQueueDBBase): __tablename__ = "tq_TQToTags" - TQId: Mapped[tqid_type] - Value: Mapped[value_type] + TQId: Mapped[int] = mapped_column( + ForeignKey("tq_TaskQueues.TQId", ondelete="CASCADE"), primary_key=True + ) + Value: Mapped[str64] = mapped_column(primary_key=True) __table_args__ = ( Index("TagsTaskIndex", "TQId"), Index("TagsValueIndex", "Value"), diff --git a/extensions/gubbins/gubbins-db/src/gubbins/db/sql/jobs/schema.py b/extensions/gubbins/gubbins-db/src/gubbins/db/sql/jobs/schema.py index 5bf132194..8fabb17f4 100644 --- a/extensions/gubbins/gubbins-db/src/gubbins/db/sql/jobs/schema.py +++ b/extensions/gubbins/gubbins-db/src/gubbins/db/sql/jobs/schema.py @@ -1,5 +1,6 @@ from diracx.db.sql.job.db import JobDBBase -from diracx.db.sql.job.schema import jobid_type, str255 +from diracx.db.sql.job.schema import str255 +from sqlalchemy import ForeignKey from sqlalchemy.orm import Mapped, mapped_column @@ -9,5 +10,7 @@ class GubbinsInfo(JobDBBase): __tablename__ = "GubbinsJobs" - job_id: Mapped[jobid_type] + job_id: Mapped[int] = mapped_column( + "JobID", ForeignKey("Jobs.JobID", ondelete="CASCADE"), primary_key=True + ) info: Mapped[str255] = mapped_column("Info", default="", primary_key=True) From 6f73a005cb32e2b0e136e86c39b235bbdf0d0a1a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?H=C3=A9lo=C3=AFse=20Joffe?= Date: Tue, 3 Feb 2026 13:21:06 +0100 Subject: [PATCH 4/8] fix: add missing changes to migration --- diracx-db/src/diracx/db/sql/job_logging/db.py | 6 +++--- diracx-db/src/diracx/db/sql/pilot_agents/schema.py | 7 ++++++- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/diracx-db/src/diracx/db/sql/job_logging/db.py b/diracx-db/src/diracx/db/sql/job_logging/db.py index a1fcf8632..8f42e157a 100644 --- a/diracx-db/src/diracx/db/sql/job_logging/db.py +++ b/diracx-db/src/diracx/db/sql/job_logging/db.py @@ -2,9 +2,9 @@ from collections import defaultdict from datetime import datetime, timezone -from typing import Iterable, cast +from typing import Iterable -from sqlalchemy import Table, delete, func, select +from sqlalchemy import delete, func, insert, select from diracx.core.models.job import JobLoggingRecord, JobStatusReturn @@ -56,7 +56,7 @@ async def insert_records( seqnums[record.job_id] = seqnums[record.job_id] + 1 await self.conn.execute( - cast(Table, LoggingInfo.__table__).insert(), + insert(LoggingInfo), values, ) diff --git a/diracx-db/src/diracx/db/sql/pilot_agents/schema.py b/diracx-db/src/diracx/db/sql/pilot_agents/schema.py index b7e83567c..770b62b79 100644 --- a/diracx-db/src/diracx/db/sql/pilot_agents/schema.py +++ b/diracx-db/src/diracx/db/sql/pilot_agents/schema.py @@ -6,6 +6,7 @@ from sqlalchemy import ( Double, Index, + String, Text, ) from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column @@ -20,7 +21,11 @@ class PilotAgentsDBBase(DeclarativeBase): - pass + type_annotation_map = { + str32: String(32), + str128: String(128), + str255: String(255), + } class PilotAgents(PilotAgentsDBBase): From 6ec3af2d3ae366a33d2113d8cbe478151efc0522 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?H=C3=A9lo=C3=AFse=20Joffe?= Date: Thu, 5 Feb 2026 08:21:00 +0100 Subject: [PATCH 5/8] docs: fix typos and DeclarativeBase reference --- docs/dev/explanations/components/db.md | 2 +- docs/dev/reference/coding-conventions.md | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/dev/explanations/components/db.md b/docs/dev/explanations/components/db.md index 8f6aabdd1..05fbb57ff 100644 --- a/docs/dev/explanations/components/db.md +++ b/docs/dev/explanations/components/db.md @@ -40,7 +40,7 @@ See MySQL docs: https://dev.mysql.com/doc/refman/8.4/en/datetime.html ### On SQLAlchemy usage -The database schemas are defined using SQLAlchemy's ORM declarative models (e.g., classes inheriting from [`declarative_base`](https://docs.sqlalchemy.org/en/20/orm/mapping_api.html#sqlalchemy.orm.declarative_base)) +The database schemas are defined using SQLAlchemy's ORM declarative models (e.g., classes inheriting from [`DeclarativeBase`](https://docs.sqlalchemy.org/en/20/orm/mapping_api.html#sqlalchemy.orm.DeclarativeBase)) All database interactions (queries, inserts, updates, deletes) are performed using SQLAlchemy Core constructs like [`select()`](https://docs.sqlalchemy.org/en/20/core/selectable.html#sqlalchemy.sql.expression.select), `insert()`, `update()`, and `delete()`. There is no usage of the ORM `Session` or `sessionmaker`. Instead, operations are executed directly on a connection object (e.g., `self.conn.execute(...)`). This pattern is referred within the SQLAlchemy documentation as a ["2.0-style"](https://docs.sqlalchemy.org/en/20/tutorial/index.html#sqlalchemy-unified-tutorial) of working. This is practice for asynchronous applications and when used together with FastAPI. diff --git a/docs/dev/reference/coding-conventions.md b/docs/dev/reference/coding-conventions.md index 9417cca5e..fdfd533c0 100644 --- a/docs/dev/reference/coding-conventions.md +++ b/docs/dev/reference/coding-conventions.md @@ -107,8 +107,8 @@ class Owners(Base): class Owners(Base): __tablename__ = "Owners" OwnerID: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) - creation_time: Mapped[datetime_now] - name: Mapped[str255] + CreationTime: Mapped[datetime_now] + Name: Mapped[str255] ``` From 3c4c93f80cf87373ad4f2fd7e4d71299cc6ae28a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?H=C3=A9lo=C3=AFse=20Joffe?= Date: Tue, 10 Feb 2026 10:41:52 +0100 Subject: [PATCH 6/8] chore: add deprecation warnings for Column helpers --- diracx-db/src/diracx/db/sql/utils/types.py | 38 +++++++++++++++++-- .../src/diracx/testing/mock_osdb.py | 10 +++-- 2 files changed, 41 insertions(+), 7 deletions(-) diff --git a/diracx-db/src/diracx/db/sql/utils/types.py b/diracx-db/src/diracx/db/sql/utils/types.py index 49e9d4732..5df81d16c 100644 --- a/diracx-db/src/diracx/db/sql/utils/types.py +++ b/diracx-db/src/diracx/db/sql/utils/types.py @@ -1,5 +1,6 @@ from __future__ import annotations +import warnings from datetime import datetime from functools import partial from zoneinfo import ZoneInfo @@ -12,9 +13,39 @@ from .functions import utcnow -Column: partial[RawColumn] = partial(RawColumn, nullable=False) -NullColumn: partial[RawColumn] = partial(RawColumn, nullable=True) -DateNowColumn = partial(Column, type_=DateTime(timezone=True), server_default=utcnow()) + +def _deprecated(name: str, replacement: str): + warnings.warn( + f"{name} is deprecated and will be removed in a future major release. " + f"Use {replacement} instead.", + DeprecationWarning, + stacklevel=3, + ) + + +_Column: partial[RawColumn] = partial(RawColumn, nullable=False) + + +def Column(*args, **kwargs): # noqa: N802 + _deprecated("Column", "Mapped[...] + mapped_column(...)") + return _Column(*args, **kwargs) + + +_NullColumn: partial[RawColumn] = partial(RawColumn, nullable=True) + + +def NullColumn(*args, **kwargs): # noqa: N802 + _deprecated("NullColumn", "Mapped[Optional[...]] + mapped_column(...)") + return _NullColumn(*args, **kwargs) + + +_DateNowColumn = partial(Column, type_=DateTime(timezone=True), server_default=utcnow()) + + +def DateNowColumn(*args, **kwargs): # noqa: N802 + _deprecated("DateNowColumn", "Mapped[datetime_now] + mapped_column(...)") + return _DateNowColumn(*args, **kwargs) + # Module-level constants for default timezone values _DEFAULT_UTC = ZoneInfo("UTC") @@ -32,6 +63,7 @@ def EnumColumn(name, enum_type, **kwargs): # noqa: N802 + _deprecated("EnumColumn", "Mapped[...] + enum_column(...)") return Column(name, Enum(enum_type, native_enum=False, length=16), **kwargs) diff --git a/diracx-testing/src/diracx/testing/mock_osdb.py b/diracx-testing/src/diracx/testing/mock_osdb.py index 97de0d895..a656a9e6b 100644 --- a/diracx-testing/src/diracx/testing/mock_osdb.py +++ b/diracx-testing/src/diracx/testing/mock_osdb.py @@ -37,9 +37,7 @@ class JobParametersDB(MockOSDBMixin, JobParametersDB): """ def __init__(self, connection_kwargs: dict[str, Any]) -> None: - from sqlalchemy import JSON, Column, Integer, MetaData, String, Table - - from diracx.db.sql.utils import DateNowColumn + from sqlalchemy import JSON, Column, DateTime, Integer, MetaData, String, Table # Dynamically create a subclass of BaseSQLDB so we get clearer errors mocked_db = type(f"Mocked{self.__class__.__name__}", (sql_utils.BaseSQLDB,), {}) @@ -53,7 +51,11 @@ def __init__(self, connection_kwargs: dict[str, Any]) -> None: for field, field_type in self.fields.items(): match field_type["type"]: case "date": - column_type = DateNowColumn + column_type = partial( + Column, + type_=DateTime(timezone=True), + server_default=sql_utils.utcnow(), + ) case "long": column_type = partial(Column, type_=Integer) case "keyword": From 69868b41545b4db1b393aefde5a2958ac7973a38 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?H=C3=A9lo=C3=AFse=20Joffe?= Date: Wed, 11 Feb 2026 10:45:12 +0100 Subject: [PATCH 7/8] refactor: improve table typing --- diracx-db/src/diracx/db/sql/utils/base.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/diracx-db/src/diracx/db/sql/utils/base.py b/diracx-db/src/diracx/db/sql/utils/base.py index d9845c1a2..143943f95 100644 --- a/diracx-db/src/diracx/db/sql/utils/base.py +++ b/diracx-db/src/diracx/db/sql/utils/base.py @@ -12,9 +12,10 @@ from uuid import UUID as StdUUID # noqa: N811 from pydantic import TypeAdapter -from sqlalchemy import DateTime, MetaData, func, select +from sqlalchemy import DateTime, MetaData, func, inspect, select from sqlalchemy.exc import OperationalError from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine, create_async_engine +from sqlalchemy.orm import DeclarativeBase from uuid_utils import UUID, uuid7 from diracx.core.exceptions import InvalidQueryError @@ -250,7 +251,7 @@ async def ping(self): async def _search( self, - table: Any, + table: type[DeclarativeBase], parameters: list[str] | None, search: list[SearchSpec], sorts: list[SortSpec], @@ -290,12 +291,15 @@ async def _search( ] async def _summary( - self, table: Any, group_by: list[str], search: list[SearchSpec] + self, + table: type[DeclarativeBase], + group_by: list[str], + search: list[SearchSpec], ) -> list[dict[str, str | int]]: """Get a summary of the elements of a table.""" columns = _get_columns(table.__table__, group_by) - pk_columns = list(table.__table__.primary_key.columns) + pk_columns = list(inspect(table).primary_key) if not pk_columns: raise ValueError( "Model has no primary key and no count_column was provided." From a7e2adaa175af58310f832445a428b76c12dc21e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?H=C3=A9lo=C3=AFse=20Joffe?= Date: Wed, 11 Feb 2026 11:11:20 +0100 Subject: [PATCH 8/8] refactor: remove type ignore comments --- diracx-db/src/diracx/db/sql/auth/db.py | 2 +- diracx-db/src/diracx/db/sql/utils/base.py | 2 +- diracx-db/tests/utils/test_uuid.py | 8 ++++---- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/diracx-db/src/diracx/db/sql/auth/db.py b/diracx-db/src/diracx/db/sql/auth/db.py index c6cb2653f..bcbe814dc 100644 --- a/diracx-db/src/diracx/db/sql/auth/db.py +++ b/diracx-db/src/diracx/db/sql/auth/db.py @@ -163,7 +163,7 @@ async def insert_device_flow( for _ in range(MAX_RETRY): user_code = "".join( secrets.choice(USER_CODE_ALPHABET) - for _ in range(DeviceFlows.user_code.type.length) # type: ignore + for _ in range(DeviceFlows.user_code.type.length) ) device_code = secrets.token_urlsafe() diff --git a/diracx-db/src/diracx/db/sql/utils/base.py b/diracx-db/src/diracx/db/sql/utils/base.py index 143943f95..bfd1b89f4 100644 --- a/diracx-db/src/diracx/db/sql/utils/base.py +++ b/diracx-db/src/diracx/db/sql/utils/base.py @@ -314,7 +314,7 @@ async def _summary( return [ dict(row._mapping) async for row in (await self.conn.stream(stmt)) - if row.count > 0 # type: ignore + if row._mapping["count"] > 0 ] diff --git a/diracx-db/tests/utils/test_uuid.py b/diracx-db/tests/utils/test_uuid.py index c396af35a..01d8c193f 100644 --- a/diracx-db/tests/utils/test_uuid.py +++ b/diracx-db/tests/utils/test_uuid.py @@ -131,16 +131,16 @@ def test_uuid7_to_datetime_different_input_types_same_result(self): def test_uuid7_to_datetime_invalid_input_type(self): """Test that invalid input types raise appropriate errors.""" with pytest.raises(TypeError): - uuid7_to_datetime(123) # type: ignore + uuid7_to_datetime(123) with pytest.raises(TypeError): - uuid7_to_datetime(123.45) # type: ignore + uuid7_to_datetime(123.45) with pytest.raises(TypeError): - uuid7_to_datetime([]) # type: ignore + uuid7_to_datetime([]) with pytest.raises(TypeError): - uuid7_to_datetime({}) # type: ignore + uuid7_to_datetime({}) def test_uuid7_to_datetime_invalid_uuid_string(self): """Test that invalid UUID strings raise appropriate errors."""