From a65363f4797d83a997187e26b3597334b06b819c Mon Sep 17 00:00:00 2001 From: Remi Rampin Date: Sun, 19 Nov 2023 14:38:29 -0500 Subject: [PATCH] Get timezone-aware datetime objects from database --- reproserver/database.py | 37 +++++++++++++++++++++++++++++-------- 1 file changed, 29 insertions(+), 8 deletions(-) diff --git a/reproserver/database.py b/reproserver/database.py index b582454..4735974 100644 --- a/reproserver/database.py +++ b/reproserver/database.py @@ -2,11 +2,12 @@ from datetime import datetime, timezone import logging import os -from sqlalchemy import Column, ForeignKey, create_engine +from sqlalchemy import Column, ForeignKey, TypeDecorator, create_engine from sqlalchemy.exc import OperationalError from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import relationship, sessionmaker -from sqlalchemy.types import Boolean, DateTime, Integer, String, Text +from sqlalchemy.types import Boolean, DateTime as PlainDateTime, Integer, \ + String, Text import sys import time @@ -19,6 +20,26 @@ Base = declarative_base() +# https://docs.sqlalchemy.org/en/14/core/custom_types.html#store-timezone-aware-timestamps-as-timezone-naive-utc +class DateTimeUtc(TypeDecorator): + impl = PlainDateTime + cache_ok = True + + def process_bind_param(self, value, dialect): + if value is not None: + if not value.tzinfo: + raise TypeError("tzinfo is required") + if value.utcoffset() != 0: + logger.warning("tzinfo was not UTC") + value = value.astimezone(timezone.utc).replace(tzinfo=None) + return value + + def process_result_value(self, value, dialect): + if value is not None: + value = value.replace(tzinfo=timezone.utc) + return value + + class Experiment(Base): """Experiments available on the server. @@ -31,7 +52,7 @@ class Experiment(Base): __tablename__ = 'experiments' hash = Column(String(64), primary_key=True) - last_access = Column(DateTime, nullable=False, + last_access = Column(DateTimeUtc, nullable=False, default=lambda: datetime.now(timezone.utc)) size = Column(Integer, nullable=False) info = Column(Text, nullable=False) @@ -84,7 +105,7 @@ class Upload(Base): back_populates='uploads') submitted_ip = Column(Text, nullable=True) repository_key = Column(Text, nullable=True, index=True) - timestamp = Column(DateTime, nullable=False, + timestamp = Column(DateTimeUtc, nullable=False, default=lambda: datetime.now(timezone.utc)) @property @@ -171,10 +192,10 @@ class Run(Base): upload_id = Column(Integer, ForeignKey('uploads.id', ondelete='RESTRICT')) upload = relationship('Upload', uselist=False) - submitted = Column(DateTime, nullable=False, + submitted = Column(DateTimeUtc, nullable=False, default=lambda: datetime.now(timezone.utc)) - started = Column(DateTime, nullable=True) - done = Column(DateTime, nullable=True) + started = Column(DateTimeUtc, nullable=True) + done = Column(DateTimeUtc, nullable=True) progress_percent = Column(Integer, nullable=False, default=0) progress_text = Column(Text, nullable=False, default='') @@ -226,7 +247,7 @@ class RunLogLine(Base): id = Column(Integer, primary_key=True) run_id = Column(Integer, ForeignKey('runs.id', ondelete='CASCADE')) run = relationship('Run', uselist=False, back_populates='log') - timestamp = Column(DateTime, nullable=False, + timestamp = Column(DateTimeUtc, nullable=False, default=lambda: datetime.now(timezone.utc)) line = Column(Text, nullable=False)