-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(sql): implementing a SqlTaskRepository
- Loading branch information
1 parent
4c9a2e6
commit 3ed6d8d
Showing
7 changed files
with
159 additions
and
11 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
import os | ||
|
||
from petisco import ApplicationConfigurer, Databases | ||
from petisco.extra.sqlalchemy import MySqlConnection, SqlDatabase, SqliteConnection | ||
|
||
DATABASE_NAME = "sql-tasks" | ||
ROOT_PATH = os.path.abspath(os.path.dirname(__file__)) | ||
SQL_SERVER = os.getenv("SQL_SERVER", "sqlite") | ||
|
||
|
||
class DatabasesConfigurer(ApplicationConfigurer): | ||
def execute(self, testing: bool = True) -> None: | ||
if testing or (SQL_SERVER == "sqlite"): | ||
test_db_filename = "tasks.db" | ||
connection = SqliteConnection.create("sqlite", test_db_filename) | ||
else: | ||
connection = MySqlConnection.from_environ() | ||
|
||
sql_database = SqlDatabase(name=DATABASE_NAME, connection=connection) | ||
|
||
databases = Databases() | ||
databases.add(sql_database) | ||
databases.initialize() | ||
|
||
|
||
configurers = [DatabasesConfigurer()] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
from petisco import SqlBase | ||
from sqlalchemy import Column, Integer, String | ||
from sqlalchemy.orm import Mapped | ||
|
||
from app.src.task.shared.domain.task import Task | ||
|
||
|
||
class SqlTask(SqlBase[Task]): | ||
|
||
__tablename__ = "Task" | ||
|
||
id: Mapped[int] = Column(Integer, primary_key=True) | ||
|
||
aggregate_id: Mapped[str] = Column(String(36)) | ||
name: Mapped[str] = Column(String(50)) | ||
description: Mapped[str] = Column(String(200)) | ||
# created_at: Mapped[datetime] = Column(String(200))) | ||
# labels: list[str] | None = list() | ||
|
||
def to_domain(self) -> Task: | ||
return Task( | ||
name=self.name, description=self.description, aggregate_id=self.aggregate_id | ||
) | ||
|
||
@staticmethod | ||
def from_domain(task: Task) -> "SqlTask": | ||
return SqlTask( | ||
name=task.name, description=task.description, aggregate_id=task.aggregate_id | ||
) |
91 changes: 91 additions & 0 deletions
91
app/src/task/shared/infrastructure/sql/sql_task_repository.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
from collections.abc import Callable | ||
from typing import ContextManager | ||
|
||
from meiga import BoolResult, Error, Failure, Result, Success, isSuccess | ||
from meiga.decorators import meiga | ||
from petisco import Databases | ||
from petisco.base.application.patterns.crud_repository import CrudRepository | ||
from petisco.base.domain.errors.defaults.already_exists import ( | ||
AggregateAlreadyExistError, | ||
) | ||
from petisco.base.domain.errors.defaults.not_found import AggregateNotFoundError | ||
from petisco.base.domain.model.uuid import Uuid | ||
from sqlalchemy import select | ||
from sqlalchemy.orm import Session | ||
|
||
from app.src.task.shared.domain.task import Task | ||
from app.src.task.shared.infrastructure.sql.sql_task import SqlTask | ||
|
||
|
||
class SqlTaskRepository(CrudRepository[Task]): | ||
session_scope: Callable[..., ContextManager[Session]] | ||
|
||
def __init__(self): | ||
self.session_scope = Databases.get_session_scope("sql-tasks") | ||
|
||
@meiga | ||
def save(self, task: Task) -> BoolResult: | ||
|
||
with self.session_scope() as session: | ||
query = select(SqlTask).where( | ||
SqlTask.aggregate_id == task.aggregate_id.value | ||
) | ||
sql_task = session.execute(query).first() | ||
|
||
if sql_task: | ||
return Failure(AggregateAlreadyExistError(task.aggregate_id)) | ||
|
||
sql_task = SqlTask.from_domain(task) | ||
session.add(sql_task) | ||
|
||
return isSuccess | ||
|
||
@meiga | ||
def retrieve(self, aggregate_id: Uuid) -> Result[Task, Error]: | ||
with self.session_scope() as session: | ||
query = select(SqlTask).where(SqlTask.aggregate_id == aggregate_id.value) | ||
sql_task = session.execute(query).first() | ||
|
||
if sql_task: | ||
return Failure(AggregateNotFoundError(aggregate_id)) | ||
task = sql_task.to_domain() | ||
|
||
return Success(task) | ||
|
||
def update(self, task: Task) -> BoolResult: | ||
with self.session_scope() as session: | ||
query = select(SqlTask).where( | ||
SqlTask.aggregate_id == task.aggregate_id.value | ||
) | ||
sql_task = session.execute(query).first() | ||
|
||
if sql_task: | ||
return Failure(AggregateNotFoundError(task.aggregate_id)) | ||
|
||
sql_task = SqlTask.from_domain(task) | ||
session.add(sql_task) | ||
|
||
return isSuccess | ||
|
||
def remove(self, aggregate_id: Uuid) -> BoolResult: | ||
with self.session_scope() as session: | ||
query = select(SqlTask).where(SqlTask.aggregate_id == aggregate_id.value) | ||
sql_task = session.execute(query).first() | ||
|
||
if sql_task: | ||
return Failure(AggregateNotFoundError(aggregate_id)) | ||
|
||
session.remove(sql_task) | ||
|
||
return isSuccess | ||
|
||
def retrieve_all(self) -> Result[list[Task], Error]: | ||
with self.session_scope() as session: | ||
query = select(SqlTask) | ||
sql_tasks = session.execute(query).all() | ||
tasks = [sql_task.to_domain().values() for sql_task in sql_tasks] | ||
|
||
return Success(tasks) | ||
|
||
def clear(self): | ||
Databases().remove("sql-tasks") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters