Skip to content

Commit

Permalink
new(project) add update functionality
Browse files Browse the repository at this point in the history
  • Loading branch information
RedmanPlus committed Jan 28, 2023
1 parent 83ac1c2 commit 6bd0413
Show file tree
Hide file tree
Showing 9 changed files with 97 additions and 8 deletions.
4 changes: 4 additions & 0 deletions noio_db/core/abstract_syntax_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ def to_dict(self) -> dict:
newargs = reformat_dict(v)
result_dict[k] = newargs[0]

if k == "set" and isinstance(v, dict):
newargs = reformat_dict(v)
result_dict[k] = newargs

if k == "insert" and isinstance(v, list):
v[2] = list_to_insert_vals(*v[2])

Expand Down
1 change: 1 addition & 0 deletions noio_db/core/sql_object_factories/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,4 @@
AsSQLObjectFactory,
OrSQLObjectFactory,
)
from .update_sql_object_factory import SetSQLObjectFactory, UpdateSQLObjectFactory
27 changes: 27 additions & 0 deletions noio_db/core/sql_object_factories/update_sql_object_factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from noio_db.core.sql_object_factories import AbstractSQLObjectFactory
from noio_db.core.sql_objects import BaseSQLObject, SetSQLObject, UpdateSQLObject
from noio_db.utils import list_into_comma_sql_object


class UpdateSQLObjectFactory(AbstractSQLObjectFactory):
def get_object(self, *args) -> BaseSQLObject:

if len(args) > 1:
raise Exception(
f"Update query must accept only 1 parameter, got {len(args)}"
)

name = args[0]

return UpdateSQLObject(name)


class SetSQLObjectFactory(AbstractSQLObjectFactory):
def get_object(self, *args) -> BaseSQLObject:

if len(args) < 1:
raise Exception("Set query must accept at least one parameter, got 0")

params = list_into_comma_sql_object(*args)

return SetSQLObject(params)
1 change: 1 addition & 0 deletions noio_db/core/sql_objects/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
SelectSQLObject,
WhereSQLObject,
)
from .update import SetSQLObject, UpdateSQLObject
from .utils import (
AndSQLObject,
ArgInBracesSQLObject,
Expand Down
15 changes: 15 additions & 0 deletions noio_db/core/sql_objects/update.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from string import Template

from noio_db.core.sql_objects import BaseSQLObject


class UpdateSQLObject(BaseSQLObject):

template = Template("UPDATE $what")
template_keys = ["what"]


class SetSQLObject(BaseSQLObject):

template = Template("SET $what")
template_keys = ["what"]
4 changes: 4 additions & 0 deletions noio_db/core/sql_query_constructors.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
InsertSQLObjectFactory,
OrderBySQLObjectFactory,
SelectSQLObjectFactory,
SetSQLObjectFactory,
UpdateSQLObjectFactory,
WhereSQLObjectFactory,
)

Expand All @@ -32,6 +34,8 @@ class SelectSQLQueryConstructor(AbstractSQLQueryConstructor):
"having": HavingSQLObjectFactory(),
"order_by": OrderBySQLObjectFactory(),
"insert": InsertSQLObjectFactory(),
"update": UpdateSQLObjectFactory(),
"set": SetSQLObjectFactory(),
}

def compile(self, query: dict) -> str:
Expand Down
48 changes: 41 additions & 7 deletions noio_db/models/model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# pylint: disable=E0611
from pydantic import BaseModel
from pydantic import BaseConfig, BaseModel

from noio_db.core import AST, CreateTableSQLObjectFactory, SelectSQLQueryConstructor
from noio_db.query import Query
Expand Down Expand Up @@ -93,10 +93,32 @@ def insert(self):
driver(query)


class Model(BaseModel, SelectMixin, CreateModelMixin, InsertMixin, ObjectCounter):
__from_orm__: bool = False
class UpdateMixin:
def update(self):
table_name = self.table_name
updated_fields = self.__dict__
id_field = updated_fields.pop("id")

# pylint: disable=W0212
ast = AST()
ast._update(table_name)
ast._set(**updated_fields)
ast._where(**{"id": id_field})
# pylint: enable=W0212

query = SelectSQLQueryConstructor().compile(ast.to_dict())
driver = get_current_settings(self)

driver(query)


class Model(BaseModel, SelectMixin, CreateModelMixin, InsertMixin, UpdateMixin):

id: int = None

class Config(BaseConfig):
is_from_orm: bool = False

@property
def table_name(self):
return self.__class__.__name__.lower()
Expand All @@ -105,8 +127,20 @@ def table_name(self):
def table_fields(self):
return list(self.__annotations__.keys())

@property
def is_from_orm(self):
return self.Config.is_from_orm

@is_from_orm.setter
def is_from_orm(self, value: bool):
if not isinstance(value, bool):
raise Exception()

self.Config.is_from_orm = value

def save(self):
if self.__from_orm__:
pass
else:
self.insert()

if self.Config.is_from_orm:
self.update()

self.insert()
3 changes: 2 additions & 1 deletion noio_db/query/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ def _fill_query(self):

for i, result in enumerate(self.objects):
kwargs = zip_into_dict(model_field_names, result)
self.objects[i] = self.model_class(__from_orm__=True, **kwargs)
self.objects[i] = self.model_class(**kwargs)
self.objects[i].Config.is_from_orm = True

self.called = True

Expand Down
2 changes: 2 additions & 0 deletions noio_db/utils/consts.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@
"_group_by",
"_order_by",
"_insert",
"_update",
}

KWARG_METHOD_NAMES = {
"_where",
"_having",
"_and",
"_or",
"_set",
}

0 comments on commit 6bd0413

Please sign in to comment.