From adeda5ff90b851bb92028eb75231fcb878909eb6 Mon Sep 17 00:00:00 2001 From: Hju Kneyck Flores Date: Mon, 8 Nov 2021 11:25:38 +0800 Subject: [PATCH] fix: consolidate dependency injection --- dayong/abc.py | 34 +++++++-- dayong/bot.py | 4 +- dayong/components/event_component.py | 6 +- dayong/components/task_component.py | 32 ++++----- dayong/operations.py | 104 +++++++++------------------ dayong/tasks/aptasks.py | 8 +-- 6 files changed, 87 insertions(+), 101 deletions(-) diff --git a/dayong/abc.py b/dayong/abc.py index d89c8c9..8fef501 100644 --- a/dayong/abc.py +++ b/dayong/abc.py @@ -10,6 +10,7 @@ from typing import Any import tanjun +from sqlmodel import SQLModel from sqlmodel.engine.result import ScalarResult from dayong.core.configs import DayongConfig @@ -25,7 +26,7 @@ async def connect( """Create a database connection. Args: - config (DayongConfig, optional): [description]. Defaults to + config (DayongConfig, optional): . Defaults to tanjun.injected(type=DayongConfig). """ @@ -34,7 +35,7 @@ async def create_table(self) -> None: """Create physical tables for all the table models stored in `Any.metadata`.""" @abstractmethod - async def add_row(self, table_model: Any) -> None: + async def add_row(self, table_model: SQLModel) -> None: """Add a row to the message table. Args: @@ -42,21 +43,44 @@ async def add_row(self, table_model: Any) -> None: """ @abstractmethod - async def remove_row(self, table_model: Any) -> None: + async def remove_row(self, table_model: SQLModel, attribute: str) -> None: """Remove a row from the message table. Args: table_model (Any): A subclass of SQLModel + attribute (str): A Table model attribute. """ @abstractmethod - async def get_row(self, table_model: Any) -> ScalarResult[Any]: + async def get_row(self, table_model: SQLModel, attribute: str) -> ScalarResult[Any]: """Get row from the message table. Args: table_model (Any): A subclass of SQLModel. + attribute (str): A Table model attribute. Returns: - ScalarResult: A `ScalarResult` which contains a scalar value or sequence of + ScalarResult[Any]: A `ScalarResult` which contains a scalar value or sequence of scalar values. """ + + @abstractmethod + async def get_all_row(self, table_model: type[SQLModel]) -> ScalarResult[Any]: + """Fetch all records in a database table. + + Args: + table_model (type[SQLModel]): Type of the class which corresponds to a + database table. + Returns: + ScalarResult[Any]: A `ScalarResult` which contains a scalar value or sequence of + scalar values. + """ + + @abstractmethod + async def update_row(self, table_model: SQLModel, attribute: str) -> None: + """Update a database record/row. + + Args: + table_model (Any): A subclass of SQLModel. + attribute (str): A Table model attribute. + """ diff --git a/dayong/bot.py b/dayong/bot.py index 4975a06..2cb96d6 100644 --- a/dayong/bot.py +++ b/dayong/bot.py @@ -13,7 +13,7 @@ from dayong.abc import Database from dayong.core.configs import DayongConfig, DayongDynamicLoader from dayong.core.settings import BASE_DIR -from dayong.operations import MessageDB +from dayong.operations import DatabaseImpl def run() -> None: @@ -24,7 +24,7 @@ def run() -> None: banner="dayong", intents=hikari.Intents.ALL, ) - database = MessageDB() + database = DatabaseImpl() ( tanjun.Client.from_gateway_bot( bot, declare_global_commands=hikari.Snowflake(loaded_config.guild_id) diff --git a/dayong/components/event_component.py b/dayong/components/event_component.py index 278ca12..7a6f759 100644 --- a/dayong/components/event_component.py +++ b/dayong/components/event_component.py @@ -49,9 +49,9 @@ async def greet_new_member( Args: event (hikari.MemberCreateEvent): Instance of `hikari.MemberCreateEvent`. This is a registered type dependency and is injected by the client. - config (DayongConfig, optional): An instance of `dayong.core.configs.DayongConfig`. - This is registered type dependency and is injected by the client. Defaults - to tanjun.injected(type=DayongConfig). + config (DayongConfig, optional): An instance of + `dayong.core.configs.DayongConfig`. This is registered type dependency and + is injected by the client. Defaults to tanjun.injected(type=DayongConfig). """ embeddings = config.embeddings["new_member_greetings"] channels = await event.app.rest.fetch_guild_channels(event.guild_id) diff --git a/dayong/components/task_component.py b/dayong/components/task_component.py index a844674..6bc8514 100644 --- a/dayong/components/task_component.py +++ b/dayong/components/task_component.py @@ -8,9 +8,9 @@ import tanjun from sqlalchemy.exc import NoResultFound +from dayong.abc import Database from dayong.core.settings import CONTENT_PROVIDER from dayong.models import ScheduledTask -from dayong.operations import ScheduledTaskDB from dayong.tasks.manager import TaskManagerMemory component = tanjun.Component() @@ -20,13 +20,13 @@ RESPONSE_MESSG = {False: "Sorry, I got nothing for today đŸ˜”"} -async def start_task(context: tanjun.abc.Context, source: str, db: ScheduledTaskDB): +async def start_task(context: tanjun.abc.Context, source: str, db: Database): """Start a scheduled task. Args: context (tanjun.abc.Context): Slash command specific context. source (str): Alias of the third-party content provider. - db (ScheduledTaskDB): An instance of `dayong.operations.ScheduledTaskDB`. + db (Database): An instance of `dayong.operations.Database`. Raises: NotImplementedError: Raised if alias does not exist. @@ -46,23 +46,23 @@ async def start_task(context: tanjun.abc.Context, source: str, db: ScheduledTask ) try: - result = await db.get_row(task_model) + result = await db.get_row(task_model, "task_name") if bool(result.one().run) is False: raise PermissionError else: - await db.update_row(task_model) + await db.update_row(task_model, "task_name") return except NoResultFound: await db.add_row(task_model) -async def stop_task(context: tanjun.abc.Context, source: str, db: ScheduledTaskDB): +async def stop_task(context: tanjun.abc.Context, source: str, db: Database): """Stop a scheduled task. Args: context (tanjun.abc.Context): Slash command specific context. source (str): Alias of the third-party content provider. - db (ScheduledTaskDB): An instance of `dayong.operations.ScheduledTaskDB`. + db (Database): An instance of `dayong.operations.Database`. Raises: ValueError: Raised if context failed to get the name of its channel. @@ -72,14 +72,14 @@ async def stop_task(context: tanjun.abc.Context, source: str, db: ScheduledTaskD if channel is None: raise ValueError - await db.remove_row( - ScheduledTask( - channel_name=channel.name if channel.name else "", - task_name=source, - run=False, - ) + task_model = ScheduledTask( + channel_name=channel.name if channel.name else "", + task_name=source, + run=False, ) + await db.remove_row(task_model, "task_name") + @component.with_command @tanjun.with_author_permission_check(128) @@ -92,7 +92,7 @@ async def share_content( ctx: tanjun.abc.SlashContext, source: str, action: str, - db: ScheduledTaskDB = tanjun.injected(type=ScheduledTaskDB), + db: Database = tanjun.injected(type=Database), ) -> None: """Fetch content on email subscription, from a service, or API. @@ -103,8 +103,8 @@ async def share_content( ctx (tanjun.abc.Context): Interface of a context. source (str): Alias of the third-party content provider. action (str): Start or stop the content retrival task. - db (ScheduledTaskDB): An instance of `dayong.operations.ScheduledTaskDB`. - Defaults to tanjun.injected(type=ScheduledTaskDB). + db (Database): An instance of `dayong.operations.Database`. + Defaults to tanjun.injected(type=Database). """ action = action.lower() diff --git a/dayong/operations.py b/dayong/operations.py index f5869a1..f0fa38e 100644 --- a/dayong/operations.py +++ b/dayong/operations.py @@ -2,10 +2,10 @@ dayong.operations ~~~~~~~~~~~~~~~~~ -This module contains data model operations which include retrieval and update commands. +Data model operations which include retrieval and update commands. """ import asyncio -from typing import Any, Type +from typing import Any import tanjun from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine @@ -15,25 +15,33 @@ from dayong.abc import Database from dayong.core.configs import DayongConfig, DayongDynamicLoader -from dayong.models import Message, ScheduledTask -class MessageDB(Database): +class DatabaseImpl(Database): """Implementaion of a database connection for transacting and interacting with - message tables —those that derive from message table models. + database tables —those that derive from SQLModel. """ _conn: AsyncEngine - async def connect( - self, config: DayongConfig = tanjun.injected(type=DayongConfig) - ) -> None: - """Create a database connection. + @staticmethod + async def update(instance: Any, update: Any) -> Any: + """Overwrite value of class attribute. Args: - config (DayongConfig, optional): A config interface. Defaults to - tanjun.injected(type=DayongConfig). + instance (Any): A Class instance. + update (Any): A dictionary containing the attributes to be overwritten. + + Returns: + Any: A class instance with updated attribute values. """ + for key, value in update.items(): + setattr(instance, key, value) + return instance + + async def connect( + self, config: DayongConfig = tanjun.injected(type=DayongConfig) + ) -> None: loop = asyncio.get_running_loop() self._conn = await loop.run_in_executor( None, @@ -47,102 +55,56 @@ async def create_table(self) -> None: async with self._conn.begin() as conn: await conn.run_sync(SQLModel.metadata.create_all) - async def add_row(self, table_model: Message) -> None: + async def add_row(self, table_model: SQLModel) -> None: async with AsyncSession(self._conn) as session: loop = asyncio.get_running_loop() await loop.run_in_executor(None, session.add, table_model) await session.commit() - async def remove_row(self, table_model: Message) -> None: + async def remove_row(self, table_model: SQLModel, attribute: str) -> None: model = type(table_model) async with AsyncSession(self._conn) as session: # Temp ignore incompatible type passed to `exec()`. See: # https://github.com/tiangolo/sqlmodel/issues/54 # https://github.com/tiangolo/sqlmodel/pull/58 row: ScalarResult[Any] = await session.exec( - select(model).where(model.id == table_model.id) # type: ignore + select(model).where( + getattr(model, attribute) == getattr(table_model, attribute) + ) # type: ignore ) - await session.delete(row) + await session.delete(row.one()) await session.commit() - async def get_row(self, table_model: Message) -> ScalarResult[Any]: + async def get_row(self, table_model: SQLModel, attribute: str) -> ScalarResult[Any]: model = type(table_model) async with AsyncSession(self._conn) as session: # Temp ignore incompatible type passed to `exec()`. See: # https://github.com/tiangolo/sqlmodel/issues/54 # https://github.com/tiangolo/sqlmodel/pull/58 - row: ScalarResult[Any] = await session.exec( - select(model).where(model.id == table_model.id) # type: ignore - ) - return row - - -class ScheduledTaskDB(Database): - """Implements a database connection for managing scheduled tasks.""" - - _conn: AsyncEngine - - async def connect( - self, config: DayongConfig = tanjun.injected(type=DayongConfig) - ) -> None: - loop = asyncio.get_running_loop() - self._conn = await loop.run_in_executor( - None, - create_async_engine, - config.database_uri - if config.database_uri - else DayongDynamicLoader().load().database_uri, - ) - - async def create_table(self) -> None: - async with self._conn.begin() as conn: - await conn.run_sync(SQLModel.metadata.create_all) - - async def add_row(self, table_model: ScheduledTask) -> None: - async with AsyncSession(self._conn) as session: - loop = asyncio.get_running_loop() - await loop.run_in_executor(None, session.add, table_model) - await session.commit() - - async def remove_row(self, table_model: ScheduledTask) -> None: - model = type(table_model) - async with AsyncSession(self._conn) as session: row: ScalarResult[Any] = await session.exec( select(model).where( - model.channel_name == table_model.channel_name + getattr(model, attribute) == getattr(table_model, attribute) ) # type: ignore ) - await session.delete(row.one()) - await session.commit() - - async def get_row(self, table_model: ScheduledTask) -> ScalarResult[Any]: - model = type(table_model) - async with AsyncSession(self._conn) as session: - row: ScalarResult[Any] = await session.exec( - select(model).where( # type: ignore - model.task_name == table_model.task_name - ) - ) return row - async def get_all_row(self, table_model: Type[ScheduledTask]) -> ScalarResult[Any]: + async def get_all_row(self, table_model: type[SQLModel]) -> ScalarResult[Any]: async with AsyncSession(self._conn) as session: return await session.exec(select(table_model)) # type: ignore - async def update_row(self, table_model: ScheduledTask) -> None: + async def update_row(self, table_model: SQLModel, attribute: str) -> None: loop = asyncio.get_running_loop() model = type(table_model) + table = table_model.__dict__ + async with AsyncSession(self._conn) as session: row: ScalarResult[Any] = await session.exec( select(model).where( - model.channel_name == table_model.channel_name + getattr(model, attribute) == getattr(table_model, attribute) ) # type: ignore ) task = row.one() - if table_model.task_name: - task.name = table_model.task_name - if table_model.run: - task.run = table_model.run + task = await self.update(task, table) await loop.run_in_executor(None, session.add, task) await session.commit() await session.refresh(task) diff --git a/dayong/tasks/aptasks.py b/dayong/tasks/aptasks.py index ecd0752..8b0f0cf 100644 --- a/dayong/tasks/aptasks.py +++ b/dayong/tasks/aptasks.py @@ -12,7 +12,7 @@ from dayong.exts.apis import RESTClient from dayong.exts.emails import EmailClient from dayong.models import ScheduledTask -from dayong.operations import ScheduledTaskDB +from dayong.operations import DatabaseImpl CLIENT = discord.Client() @@ -21,10 +21,10 @@ async def get_scheduled(table_model): - db = ScheduledTaskDB() + db = DatabaseImpl() await db.connect(config) await db.create_table() - result = await db.get_row(table_model) + result = await db.get_row(table_model, "task_name") return result.one() @@ -60,7 +60,7 @@ async def get_devto_article(): await asyncio.sleep(60) -@sched.scheduled_job("interval", seconds=30) +@sched.scheduled_job("interval", days=1) async def get_medium_daily_digest(): try: result = await get_scheduled(ScheduledTask(channel_name="", task_name="medium"))