Skip to content

Commit

Permalink
fix: consolidate dependency injection
Browse files Browse the repository at this point in the history
  • Loading branch information
huenique committed Nov 8, 2021
1 parent ba83c19 commit adeda5f
Show file tree
Hide file tree
Showing 6 changed files with 87 additions and 101 deletions.
34 changes: 29 additions & 5 deletions dayong/abc.py
Expand Up @@ -10,6 +10,7 @@
from typing import Any

import tanjun
from sqlmodel import SQLModel
from sqlmodel.engine.result import ScalarResult

from dayong.core.configs import DayongConfig
Expand All @@ -25,7 +26,7 @@ async def connect(
"""Create a database connection.
Args:
config (DayongConfig, optional): [description]. Defaults to
config (DayongConfig, optional): . Defaults to
tanjun.injected(type=DayongConfig).
"""

Expand All @@ -34,29 +35,52 @@ async def create_table(self) -> None:
"""Create physical tables for all the table models stored in `Any.metadata`."""

@abstractmethod
async def add_row(self, table_model: Any) -> None:
async def add_row(self, table_model: SQLModel) -> None:
"""Add a row to the message table.
Args:
table_model (Any): A subclass of SQLModel
"""

@abstractmethod
async def remove_row(self, table_model: Any) -> None:
async def remove_row(self, table_model: SQLModel, attribute: str) -> None:
"""Remove a row from the message table.
Args:
table_model (Any): A subclass of SQLModel
attribute (str): A Table model attribute.
"""

@abstractmethod
async def get_row(self, table_model: Any) -> ScalarResult[Any]:
async def get_row(self, table_model: SQLModel, attribute: str) -> ScalarResult[Any]:
"""Get row from the message table.
Args:
table_model (Any): A subclass of SQLModel.
attribute (str): A Table model attribute.
Returns:
ScalarResult: A `ScalarResult` which contains a scalar value or sequence of
ScalarResult[Any]: A `ScalarResult` which contains a scalar value or sequence of
scalar values.
"""

@abstractmethod
async def get_all_row(self, table_model: type[SQLModel]) -> ScalarResult[Any]:
"""Fetch all records in a database table.
Args:
table_model (type[SQLModel]): Type of the class which corresponds to a
database table.
Returns:
ScalarResult[Any]: A `ScalarResult` which contains a scalar value or sequence of
scalar values.
"""

@abstractmethod
async def update_row(self, table_model: SQLModel, attribute: str) -> None:
"""Update a database record/row.
Args:
table_model (Any): A subclass of SQLModel.
attribute (str): A Table model attribute.
"""
4 changes: 2 additions & 2 deletions dayong/bot.py
Expand Up @@ -13,7 +13,7 @@
from dayong.abc import Database
from dayong.core.configs import DayongConfig, DayongDynamicLoader
from dayong.core.settings import BASE_DIR
from dayong.operations import MessageDB
from dayong.operations import DatabaseImpl


def run() -> None:
Expand All @@ -24,7 +24,7 @@ def run() -> None:
banner="dayong",
intents=hikari.Intents.ALL,
)
database = MessageDB()
database = DatabaseImpl()
(
tanjun.Client.from_gateway_bot(
bot, declare_global_commands=hikari.Snowflake(loaded_config.guild_id)
Expand Down
6 changes: 3 additions & 3 deletions dayong/components/event_component.py
Expand Up @@ -49,9 +49,9 @@ async def greet_new_member(
Args:
event (hikari.MemberCreateEvent): Instance of `hikari.MemberCreateEvent`. This
is a registered type dependency and is injected by the client.
config (DayongConfig, optional): An instance of `dayong.core.configs.DayongConfig`.
This is registered type dependency and is injected by the client. Defaults
to tanjun.injected(type=DayongConfig).
config (DayongConfig, optional): An instance of
`dayong.core.configs.DayongConfig`. This is registered type dependency and
is injected by the client. Defaults to tanjun.injected(type=DayongConfig).
"""
embeddings = config.embeddings["new_member_greetings"]
channels = await event.app.rest.fetch_guild_channels(event.guild_id)
Expand Down
32 changes: 16 additions & 16 deletions dayong/components/task_component.py
Expand Up @@ -8,9 +8,9 @@
import tanjun
from sqlalchemy.exc import NoResultFound

from dayong.abc import Database
from dayong.core.settings import CONTENT_PROVIDER
from dayong.models import ScheduledTask
from dayong.operations import ScheduledTaskDB
from dayong.tasks.manager import TaskManagerMemory

component = tanjun.Component()
Expand All @@ -20,13 +20,13 @@
RESPONSE_MESSG = {False: "Sorry, I got nothing for today 😔"}


async def start_task(context: tanjun.abc.Context, source: str, db: ScheduledTaskDB):
async def start_task(context: tanjun.abc.Context, source: str, db: Database):
"""Start a scheduled task.
Args:
context (tanjun.abc.Context): Slash command specific context.
source (str): Alias of the third-party content provider.
db (ScheduledTaskDB): An instance of `dayong.operations.ScheduledTaskDB`.
db (Database): An instance of `dayong.operations.Database`.
Raises:
NotImplementedError: Raised if alias does not exist.
Expand All @@ -46,23 +46,23 @@ async def start_task(context: tanjun.abc.Context, source: str, db: ScheduledTask
)

try:
result = await db.get_row(task_model)
result = await db.get_row(task_model, "task_name")
if bool(result.one().run) is False:
raise PermissionError
else:
await db.update_row(task_model)
await db.update_row(task_model, "task_name")
return
except NoResultFound:
await db.add_row(task_model)


async def stop_task(context: tanjun.abc.Context, source: str, db: ScheduledTaskDB):
async def stop_task(context: tanjun.abc.Context, source: str, db: Database):
"""Stop a scheduled task.
Args:
context (tanjun.abc.Context): Slash command specific context.
source (str): Alias of the third-party content provider.
db (ScheduledTaskDB): An instance of `dayong.operations.ScheduledTaskDB`.
db (Database): An instance of `dayong.operations.Database`.
Raises:
ValueError: Raised if context failed to get the name of its channel.
Expand All @@ -72,14 +72,14 @@ async def stop_task(context: tanjun.abc.Context, source: str, db: ScheduledTaskD
if channel is None:
raise ValueError

await db.remove_row(
ScheduledTask(
channel_name=channel.name if channel.name else "",
task_name=source,
run=False,
)
task_model = ScheduledTask(
channel_name=channel.name if channel.name else "",
task_name=source,
run=False,
)

await db.remove_row(task_model, "task_name")


@component.with_command
@tanjun.with_author_permission_check(128)
Expand All @@ -92,7 +92,7 @@ async def share_content(
ctx: tanjun.abc.SlashContext,
source: str,
action: str,
db: ScheduledTaskDB = tanjun.injected(type=ScheduledTaskDB),
db: Database = tanjun.injected(type=Database),
) -> None:
"""Fetch content on email subscription, from a service, or API.
Expand All @@ -103,8 +103,8 @@ async def share_content(
ctx (tanjun.abc.Context): Interface of a context.
source (str): Alias of the third-party content provider.
action (str): Start or stop the content retrival task.
db (ScheduledTaskDB): An instance of `dayong.operations.ScheduledTaskDB`.
Defaults to tanjun.injected(type=ScheduledTaskDB).
db (Database): An instance of `dayong.operations.Database`.
Defaults to tanjun.injected(type=Database).
"""
action = action.lower()

Expand Down
104 changes: 33 additions & 71 deletions dayong/operations.py
Expand Up @@ -2,10 +2,10 @@
dayong.operations
~~~~~~~~~~~~~~~~~
This module contains data model operations which include retrieval and update commands.
Data model operations which include retrieval and update commands.
"""
import asyncio
from typing import Any, Type
from typing import Any

import tanjun
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
Expand All @@ -15,25 +15,33 @@

from dayong.abc import Database
from dayong.core.configs import DayongConfig, DayongDynamicLoader
from dayong.models import Message, ScheduledTask


class MessageDB(Database):
class DatabaseImpl(Database):
"""Implementaion of a database connection for transacting and interacting with
message tables —those that derive from message table models.
database tables —those that derive from SQLModel.
"""

_conn: AsyncEngine

async def connect(
self, config: DayongConfig = tanjun.injected(type=DayongConfig)
) -> None:
"""Create a database connection.
@staticmethod
async def update(instance: Any, update: Any) -> Any:
"""Overwrite value of class attribute.
Args:
config (DayongConfig, optional): A config interface. Defaults to
tanjun.injected(type=DayongConfig).
instance (Any): A Class instance.
update (Any): A dictionary containing the attributes to be overwritten.
Returns:
Any: A class instance with updated attribute values.
"""
for key, value in update.items():
setattr(instance, key, value)
return instance

async def connect(
self, config: DayongConfig = tanjun.injected(type=DayongConfig)
) -> None:
loop = asyncio.get_running_loop()
self._conn = await loop.run_in_executor(
None,
Expand All @@ -47,102 +55,56 @@ async def create_table(self) -> None:
async with self._conn.begin() as conn:
await conn.run_sync(SQLModel.metadata.create_all)

async def add_row(self, table_model: Message) -> None:
async def add_row(self, table_model: SQLModel) -> None:
async with AsyncSession(self._conn) as session:
loop = asyncio.get_running_loop()
await loop.run_in_executor(None, session.add, table_model)
await session.commit()

async def remove_row(self, table_model: Message) -> None:
async def remove_row(self, table_model: SQLModel, attribute: str) -> None:
model = type(table_model)
async with AsyncSession(self._conn) as session:
# Temp ignore incompatible type passed to `exec()`. See:
# https://github.com/tiangolo/sqlmodel/issues/54
# https://github.com/tiangolo/sqlmodel/pull/58
row: ScalarResult[Any] = await session.exec(
select(model).where(model.id == table_model.id) # type: ignore
select(model).where(
getattr(model, attribute) == getattr(table_model, attribute)
) # type: ignore
)
await session.delete(row)
await session.delete(row.one())
await session.commit()

async def get_row(self, table_model: Message) -> ScalarResult[Any]:
async def get_row(self, table_model: SQLModel, attribute: str) -> ScalarResult[Any]:
model = type(table_model)
async with AsyncSession(self._conn) as session:
# Temp ignore incompatible type passed to `exec()`. See:
# https://github.com/tiangolo/sqlmodel/issues/54
# https://github.com/tiangolo/sqlmodel/pull/58
row: ScalarResult[Any] = await session.exec(
select(model).where(model.id == table_model.id) # type: ignore
)
return row


class ScheduledTaskDB(Database):
"""Implements a database connection for managing scheduled tasks."""

_conn: AsyncEngine

async def connect(
self, config: DayongConfig = tanjun.injected(type=DayongConfig)
) -> None:
loop = asyncio.get_running_loop()
self._conn = await loop.run_in_executor(
None,
create_async_engine,
config.database_uri
if config.database_uri
else DayongDynamicLoader().load().database_uri,
)

async def create_table(self) -> None:
async with self._conn.begin() as conn:
await conn.run_sync(SQLModel.metadata.create_all)

async def add_row(self, table_model: ScheduledTask) -> None:
async with AsyncSession(self._conn) as session:
loop = asyncio.get_running_loop()
await loop.run_in_executor(None, session.add, table_model)
await session.commit()

async def remove_row(self, table_model: ScheduledTask) -> None:
model = type(table_model)
async with AsyncSession(self._conn) as session:
row: ScalarResult[Any] = await session.exec(
select(model).where(
model.channel_name == table_model.channel_name
getattr(model, attribute) == getattr(table_model, attribute)
) # type: ignore
)
await session.delete(row.one())
await session.commit()

async def get_row(self, table_model: ScheduledTask) -> ScalarResult[Any]:
model = type(table_model)
async with AsyncSession(self._conn) as session:
row: ScalarResult[Any] = await session.exec(
select(model).where( # type: ignore
model.task_name == table_model.task_name
)
)
return row

async def get_all_row(self, table_model: Type[ScheduledTask]) -> ScalarResult[Any]:
async def get_all_row(self, table_model: type[SQLModel]) -> ScalarResult[Any]:
async with AsyncSession(self._conn) as session:
return await session.exec(select(table_model)) # type: ignore

async def update_row(self, table_model: ScheduledTask) -> None:
async def update_row(self, table_model: SQLModel, attribute: str) -> None:
loop = asyncio.get_running_loop()
model = type(table_model)
table = table_model.__dict__

async with AsyncSession(self._conn) as session:
row: ScalarResult[Any] = await session.exec(
select(model).where(
model.channel_name == table_model.channel_name
getattr(model, attribute) == getattr(table_model, attribute)
) # type: ignore
)
task = row.one()
if table_model.task_name:
task.name = table_model.task_name
if table_model.run:
task.run = table_model.run
task = await self.update(task, table)
await loop.run_in_executor(None, session.add, task)
await session.commit()
await session.refresh(task)

0 comments on commit adeda5f

Please sign in to comment.