Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion diracx-db/src/diracx/db/sql/auth/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ async def insert_device_flow(
for _ in range(MAX_RETRY):
user_code = "".join(
secrets.choice(USER_CODE_ALPHABET)
for _ in range(DeviceFlows.user_code.type.length) # type: ignore
for _ in range(DeviceFlows.user_code.type.length)
)
device_code = secrets.token_urlsafe()

Expand Down
72 changes: 45 additions & 27 deletions diracx-db/src/diracx/db/sql/auth/schema.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,35 @@
from __future__ import annotations

from enum import Enum, auto
from typing import Any, Optional
from uuid import UUID

from sqlalchemy import (
JSON,
Index,
String,
Uuid,
)
from sqlalchemy.orm import declarative_base
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

from diracx.db.sql.utils import (
Column,
DateNowColumn,
EnumColumn,
NullColumn,
datetime_now,
enum_column,
str128,
str255,
str1024,
)

USER_CODE_LENGTH = 8

Base = declarative_base()

class Base(DeclarativeBase):
type_annotation_map = {
str128: String(128),
str255: String(255),
str1024: String(1024),
dict[str, Any]: JSON,
}


class FlowStatus(Enum):
Expand Down Expand Up @@ -47,27 +57,35 @@ class FlowStatus(Enum):

class DeviceFlows(Base):
__tablename__ = "DeviceFlows"
user_code = Column("UserCode", String(USER_CODE_LENGTH), primary_key=True)
status = EnumColumn("Status", FlowStatus, server_default=FlowStatus.PENDING.name)
creation_time = DateNowColumn("CreationTime")
client_id = Column("ClientID", String(255))
scope = Column("Scope", String(1024))
device_code = Column("DeviceCode", String(128), unique=True) # Should be a hash
id_token = NullColumn("IDToken", JSON())
user_code: Mapped[str] = mapped_column(
"UserCode", String(USER_CODE_LENGTH), primary_key=True
)
status: Mapped[FlowStatus] = enum_column(
"Status", FlowStatus, server_default=FlowStatus.PENDING.name
)
creation_time: Mapped[datetime_now] = mapped_column("CreationTime")
client_id: Mapped[str255] = mapped_column("ClientID")
scope: Mapped[str1024] = mapped_column("Scope")
device_code: Mapped[str128] = mapped_column(
"DeviceCode", unique=True
) # Should be a hash
id_token: Mapped[Optional[dict[str, Any]]] = mapped_column("IDToken")


class AuthorizationFlows(Base):
__tablename__ = "AuthorizationFlows"
uuid = Column("UUID", Uuid(as_uuid=False), primary_key=True)
status = EnumColumn("Status", FlowStatus, server_default=FlowStatus.PENDING.name)
client_id = Column("ClientID", String(255))
creation_time = DateNowColumn("CreationTime")
scope = Column("Scope", String(1024))
code_challenge = Column("CodeChallenge", String(255))
code_challenge_method = Column("CodeChallengeMethod", String(8))
redirect_uri = Column("RedirectURI", String(255))
code = NullColumn("Code", String(255)) # Should be a hash
id_token = NullColumn("IDToken", JSON())
uuid: Mapped[UUID] = mapped_column("UUID", Uuid(as_uuid=False), primary_key=True)
status: Mapped[FlowStatus] = enum_column(
"Status", FlowStatus, server_default=FlowStatus.PENDING.name
)
client_id: Mapped[str255] = mapped_column("ClientID")
creation_time: Mapped[datetime_now] = mapped_column("CreationTime")
scope: Mapped[str1024] = mapped_column("Scope")
code_challenge: Mapped[str255] = mapped_column("CodeChallenge")
code_challenge_method: Mapped[str] = mapped_column("CodeChallengeMethod", String(8))
redirect_uri: Mapped[str255] = mapped_column("RedirectURI")
code: Mapped[Optional[str255]] = mapped_column("Code") # Should be a hash
id_token: Mapped[Optional[dict[str, Any]]] = mapped_column("IDToken")


class RefreshTokenStatus(Enum):
Expand All @@ -93,13 +111,13 @@ class RefreshTokens(Base):

__tablename__ = "RefreshTokens"
# Refresh token attributes
jti = Column("JTI", Uuid(as_uuid=False), primary_key=True)
status = EnumColumn(
jti: Mapped[UUID] = mapped_column("JTI", Uuid(as_uuid=False), primary_key=True)
status: Mapped[RefreshTokenStatus] = enum_column(
"Status", RefreshTokenStatus, server_default=RefreshTokenStatus.CREATED.name
)
scope = Column("Scope", String(1024))
scope: Mapped[str1024] = mapped_column("Scope")

# User attributes bound to the refresh token
sub = Column("Sub", String(256), index=True)
sub: Mapped[str] = mapped_column("Sub", String(256), index=True)

__table_args__ = (Index("index_status_sub", status, sub),)
28 changes: 18 additions & 10 deletions diracx-db/src/diracx/db/sql/dummy/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,31 @@
# in place of the SQLAlchemy one. Have a look at them
from __future__ import annotations

from sqlalchemy import ForeignKey, Integer, String, Uuid
from sqlalchemy.orm import declarative_base
from uuid import UUID

from diracx.db.sql.utils import Column, DateNowColumn
from sqlalchemy import ForeignKey, String
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

Base = declarative_base()
from diracx.db.sql.utils import datetime_now, str255


class Base(DeclarativeBase):
type_annotation_map = {
str255: String(255),
}


class Owners(Base):
__tablename__ = "Owners"
owner_id = Column("OwnerID", Integer, primary_key=True, autoincrement=True)
creation_time = DateNowColumn("CreationTime")
name = Column("Name", String(255))
owner_id: Mapped[int] = mapped_column(
"OwnerID", primary_key=True, autoincrement=True
)
creation_time: Mapped[datetime_now] = mapped_column("CreationTime")
name: Mapped[str255] = mapped_column("Name")


class Cars(Base):
__tablename__ = "Cars"
license_plate = Column("LicensePlate", Uuid(), primary_key=True)
model = Column("Model", String(255))
owner_id = Column("OwnerID", Integer, ForeignKey(Owners.owner_id))
license_plate: Mapped[UUID] = mapped_column("LicensePlate", primary_key=True)
model: Mapped[str255] = mapped_column("Model")
owner_id: Mapped[int] = mapped_column("OwnerID", ForeignKey(Owners.owner_id))
16 changes: 7 additions & 9 deletions diracx-db/src/diracx/db/sql/job/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any, Iterable

from sqlalchemy import bindparam, case, delete, literal, select, update
from sqlalchemy import bindparam, case, delete, insert, literal, select, update

if TYPE_CHECKING:
from sqlalchemy.sql.elements import BindParameter
Expand Down Expand Up @@ -75,7 +75,7 @@ async def search(
async def create_job(self, compressed_original_jdl: str):
"""Insert a new job with original JDL. Returns inserted job id."""
result = await self.conn.execute(
JobJDLs.__table__.insert().values(
insert(JobJDLs).values(
JDL="",
JobRequirements="",
OriginalJDL=compressed_original_jdl,
Expand All @@ -91,7 +91,7 @@ async def delete_jobs(self, job_ids: list[int]):
async def insert_input_data(self, lfns: dict[int, list[str]]):
"""Insert input data for jobs."""
await self.conn.execute(
InputData.__table__.insert(),
insert(InputData),
[
{
"JobID": job_id,
Expand All @@ -105,7 +105,7 @@ async def insert_input_data(self, lfns: dict[int, list[str]]):
async def insert_job_attributes(self, jobs_to_update: dict[int, dict]):
"""Insert the job attributes."""
await self.conn.execute(
Jobs.__table__.insert(),
insert(Jobs),
[
{
"JobID": job_id,
Expand All @@ -118,9 +118,7 @@ async def insert_job_attributes(self, jobs_to_update: dict[int, dict]):
async def update_job_jdls(self, jdls_to_update: dict[int, str]):
"""Update the JDL, typically just after inserting the original JDL, or rescheduling, for example."""
await self.conn.execute(
JobJDLs.__table__.update().where(
JobJDLs.__table__.c.JobID == bindparam("b_JobID")
),
update(JobJDLs).where(JobJDLs.__table__.c.JobID == bindparam("b_JobID")),
[
{
"b_JobID": job_id,
Expand Down Expand Up @@ -186,7 +184,7 @@ async def get_job_jdls(self, job_ids, original: bool = False) -> dict[int, str]:
async def set_job_commands(self, commands: list[tuple[int, str, str]]) -> None:
"""Store a command to be passed to the job together with the next heart beat."""
await self.conn.execute(
JobCommands.__table__.insert(),
insert(JobCommands),
[
{
"JobID": job_id,
Expand Down Expand Up @@ -261,7 +259,7 @@ async def add_heartbeat_data(
}
for key, value in dynamic_data.items()
]
await self.conn.execute(HeartBeatLoggingInfo.__table__.insert().values(values))
await self.conn.execute(insert(HeartBeatLoggingInfo).values(values))

async def get_job_commands(self, job_ids: Iterable[int]) -> list[JobCommand]:
"""Get a command to be passed to the job together with the next heartbeat.
Expand Down
Loading
Loading