Skip to content

Commit

Permalink
feat(class): Add a filter param for avoid duplicate insert or update …
Browse files Browse the repository at this point in the history
…or unnecessary query execute
  • Loading branch information
JackTheMico committed Sep 2, 2022
1 parent c2b3a02 commit 7aac3b4
Show file tree
Hide file tree
Showing 9 changed files with 327 additions and 56 deletions.
6 changes: 0 additions & 6 deletions .github/workflows/codecov.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,6 @@ name: Codecov

on:
workflow_dispatch:
push:
branches:
- "*"
paths-ignore:
- '**/README.md'
- '**/pyproject.toml'
pull_request:
branches:
- "*"
Expand Down
7 changes: 2 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ from ruia_peewee_async import (
after_start,
)


class DoubanItem(Item):
target_item = TextField(css_select="tr.item")
title = AttrField(css_select="a.nbg", attr="title")
Expand All @@ -53,18 +52,17 @@ class DoubanItem(Item):
async def clean_title(self, value):
return value.strip()


class DoubanSpider(Spider):
start_urls = ["https://movie.douban.com/chart"]
# aiohttp_kwargs = {"proxy": "http://127.0.0.1:7890"}

async def parse(self, response: Response):
async for item in DoubanItem.get_items(html=await response.text()):
yield RuiaPeeweeInsert(item.results) # default is MySQL
# yield RuiaPeeweeInsert(item.results, filters="url") # use url field(column) to deduplicate, avoid unnecessary insert query executed.
# yield RuiaPeeweeInsert(item.results, database=TargetDB.POSTGRES) # save to Postgresql
# yield RuiaPeeweeInsert(item.results, database=TargetDB.BOTH) # save to both MySQL and Postgresql


class DoubanUpdateSpider(Spider):
start_urls = ["https://movie.douban.com/chart"]

Expand All @@ -83,11 +81,10 @@ class DoubanUpdateSpider(Spider):
# data: A dict that's going to be updated in the database.
# query: A peewee's query or a dict to search for the target data in database.
# database: The target database type.
# filters: A str or List[str] of columns to avoid duplicate data and avoid unnecessary query execute.
# create_when_not_exists: Default is True. If True, will create a record when query can't get the record.
# not_update_when_exists: Default is True. If True and record exists, won't update data to the records.
# only: A list or tuple of fields that should be updated only.


mysql = {
"host": "127.0.0.1",
"port": 3306,
Expand Down
4 changes: 4 additions & 0 deletions examples/douban.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ class DoubanSpider(Spider):
async def parse(self, response: Response):
async for item in DoubanItem.get_items(html=await response.text()):
yield RuiaPeeweeInsert(item.results) # default is MySQL
# use url field(column) to deduplicate, avoid unnecessary insert query executed.
# yield RuiaPeeweeInsert(item.results, filters="url")

# yield RuiaPeeweeInsert(item.results, database=TargetDB.POSTGRES) # save to Postgresql
# yield RuiaPeeweeInsert(item.results, database=TargetDB.BOTH) # save to both MySQL and Postgresql

Expand All @@ -49,6 +52,7 @@ async def parse(self, response: Response):
# data: A dict that's going to be updated in the database.
# query: A peewee's query or a dict to search for the target data in database.
# database: The target database type.
# filters: A str or List[str] of columns to avoid duplicate data and avoid unnecessary query execute.
# create_when_not_exists: Default is True. If True, will create a record when query can't get the record.
# not_update_when_exists: Default is True. If True and record exists, won't update data to the records.
# only: A list or tuple of fields that should be updated only.
Expand Down
146 changes: 109 additions & 37 deletions ruia_peewee_async/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,24 +4,35 @@
from functools import wraps
from ssl import SSLContext
from types import MethodType
from typing import Dict
from typing import Dict, Callable
from typing import Optional as TOptional
from typing import Sequence, Tuple, Union

from peewee import DoesNotExist, Model, Query
from peewee_async import Manager, MySQLDatabase, PostgresqlDatabase
from peewee_async import (
AsyncQueryWrapper,
Manager,
MySQLDatabase,
PooledMySQLDatabase,
PooledPostgresqlDatabase,
PostgresqlDatabase,
)
from pymysql import OperationalError
from ruia import Spider as RuiaSpider
from schema import And, Optional, Or, Schema, SchemaError, Use


class Spider(RuiaSpider):
mysql_model: Union[Model, Dict]
mysql_model: Model
mysql_manager: Manager
postgres_model: Union[Model, Dict]
postgres_model: Model
postgres_manager: Manager
mysql_db: MySQLDatabase
postgres_db: PostgresqlDatabase
mysql_db: Union[MySQLDatabase, PooledMySQLDatabase]
postgres_db: Union[PostgresqlDatabase, PooledPostgresqlDatabase]
mysql_filters: TOptional[AsyncQueryWrapper]
postgres_filters: TOptional[AsyncQueryWrapper]
process_insert_callback_result: Callable
process_update_callback_result: Callable


class TargetDB(Enum):
Expand All @@ -35,20 +46,19 @@ def logging(func):
async def decorator(spider_ins: Spider, callback_result):
data = callback_result.data
database = callback_result.database
msg_pre = f"<RuiaPeeweeAsync: Success insert data: {data} into "
query = getattr(callback_result, "query", None)
try:
result = await func(spider_ins, callback_result)
except OperationalError as ope: # pragma: no cover
method = "insert" if not query else "update"
spider_ins.logger.error(
f"<RuiaPeeweeAsync: {database.name} insert error: {ope}>"
f"<RuiaPeeweeAsync: {database.name} {method} data: {data} error: {ope}>"
)
except SchemaError as pae:
spider_ins.logger.error(pae)
raise pae
else:
msg = "".join([msg_pre, database.name, ">"])
spider_ins.logger.info(msg)
return result
spider_ins.logger.info(result)

return decorator

Expand Down Expand Up @@ -78,8 +88,26 @@ def _check_result(data: Tuple):
result_validator = Schema(Use(_check_result))


async def filter_func(data, spider_ins, database, manager, model, filters) -> bool:
if not hasattr(spider_ins, f"{database}_filters"):
conditions = [getattr(model, fil) for fil in filters]
filter_res = await manager.execute(model.select(*conditions).distinct())
setattr(spider_ins, f"{database}_filters", filter_res)
filtered = False
filter_res = getattr(spider_ins, f"{database}_filters")
for fil in filters:
if data[fil] in [getattr(x, fil) for x in filter_res]:
filtered = True
return filtered


class RuiaPeeweeInsert:
def __init__(self, data: Dict, database: TargetDB = TargetDB.MYSQL) -> None:
def __init__(
self,
data: Dict,
database: TargetDB = TargetDB.MYSQL,
filters: TOptional[Union[Sequence[str], str]] = None,
) -> None:
"""
Args:
Expand All @@ -90,25 +118,49 @@ def __init__(self, data: Dict, database: TargetDB = TargetDB.MYSQL) -> None:

self.data = data
self.database = database
self.filters = filters

@staticmethod
@logging
async def process(spider_ins: Spider, callback_result):
needs_check = (
callback_result,
{"data": dict, "database": TargetDB},
{"data": dict, "database": TargetDB, "filters": (str, type(None), list)},
"RuiaPeeweeAsync: insert process",
)
result_validator.validate(needs_check)
data = callback_result.data
database = callback_result.database
if database == TargetDB.MYSQL:
await spider_ins.mysql_manager.create(spider_ins.mysql_model, **data)
elif database == TargetDB.POSTGRES:
await spider_ins.postgres_manager.create(spider_ins.postgres_model, **data)
filters = callback_result.filters
if database == TargetDB.BOTH:
databases = [TargetDB.MYSQL.name, TargetDB.POSTGRES.name]
else:
await spider_ins.mysql_manager.create(spider_ins.mysql_model, **data)
await spider_ins.postgres_manager.create(spider_ins.postgres_model, **data)
databases = [database.name]
msg = ""
if isinstance(filters, str):
filters = [filters]
for database in databases:
database = database.lower()
manager: Manager = getattr(spider_ins, f"{database}_manager")
model: Model = getattr(spider_ins, f"{database}_model")
if filters:
filtered = await filter_func(
data, spider_ins, database, manager, model, filters
)
if filtered:
msg += (
f"<RuiaPeeweeAsync: data: {data} was filtered by filters: {filters},"
f" won't insert into {database.upper()}>\n"
)
continue
msg += (
f"<RuiaPeeweeAsync: data: {data} wasn't filtered by filters: {filters}, "
f"success insert into {database.upper()}>\n"
)
await manager.create(model, **data)
if msg:
return msg
return f"<RuiaPeeweeAsync: Success insert {data} into database: {databases}>"


class RuiaPeeweeUpdate:
Expand All @@ -119,6 +171,7 @@ def __init__(
data: Dict,
query: Union[Query, Dict],
database: TargetDB = TargetDB.MYSQL,
filters: TOptional[Union[Sequence[str], str]] = None,
create_when_not_exists: bool = True,
not_update_when_exists: bool = True,
only: TOptional[Sequence[str]] = None,
Expand All @@ -129,6 +182,7 @@ def __init__(
data: A dict that's going to be updated in the database.
query: A peewee's query or a dict to search for the target data in database.
database: The target database type.
filters: A str or List[str] of columns to avoid duplicate data and avoid unnecessary query execute.
create_when_not_exists: Default is True. If True, will create a record when query can't get the record.
not_update_when_exists: Default is True. If True and record exists, won't update data to the records.
only: A list or tuple of fields that should be updated only.
Expand All @@ -138,6 +192,7 @@ def __init__(
self.data = data
self.query = query
self.database = database
self.filters = filters
self.create_when_not_exists = create_when_not_exists
self.not_update_when_exists = not_update_when_exists
self.only = only
Expand All @@ -147,45 +202,56 @@ async def _deal_update(
spider_ins,
data,
query,
filters,
create_when_not_exists,
not_update_when_exists,
only,
databases,
):
): # pylint: disable=too-many-locals
msg = ""
if isinstance(filters, str):
filters = [filters]
for database in databases:
database = database.lower()
manager: Manager = getattr(spider_ins, f"{database}_manager")
model: Model = getattr(spider_ins, f"{database}_model")
if filters:
filtered = await filter_func(
data, spider_ins, database, manager, model, filters
)
if filtered:
msg += f"<RuiaPeeweeAsync: data: {data} was filtered by filters: {filters}\n"
continue
msg += f"<RuiaPeeweeAsync: data: {data} wasn't filtered by filters: {filters}\n"
try:
model_ins = await manager.get(model, **query)
except DoesNotExist:
if create_when_not_exists:
await manager.create(model, **data)
spider_ins.logger.info(
f"<RuiaPeeweeAsync: data: {data} not exists in {database}, but success created>"
)
else:
spider_ins.logger.warning(
f"<RuiaPeeweeAsync: data: {data} not exists in {database}, \
won't create it because create_when_not_exists is False>"
)
msg += f"<RuiaPeeweeAsync: data: {data} not exists in {database.upper()}, but success created>\n"
msg += (
f"<RuiaPeeweeAsync: data: {data} not exists in {database.upper()}, "
"won't create it because create_when_not_exists is False>\n"
)
else:
if not_update_when_exists:
spider_ins.logger.info(
f"<RuiaPeeweeAsync: {data} won't updated in {database}>"
msg += (
f"<RuiaPeeweeAsync: Won't update {data} in {database.upper()} "
"because not_update_when_exists is True>\n"
)
continue
model_ins.__data__.update(data)
await manager.update(model_ins, only=only)
spider_ins.logger.info(
f"<RuiaPeeweeAsync: {data} was updated in {database}>"
)
if msg:
return msg
return f"<RuiaPeeweeAsync: Updated {data} in {databases}>"

@staticmethod
async def _update(
spider_ins,
data,
query,
filters,
database,
create_when_not_exists,
not_update_when_exists,
Expand All @@ -195,22 +261,25 @@ async def _update(
databases = [TargetDB.MYSQL.name, TargetDB.POSTGRES.name]
else:
databases = [database.name]
await RuiaPeeweeUpdate._deal_update(
result = await RuiaPeeweeUpdate._deal_update(
spider_ins,
data,
query,
filters,
create_when_not_exists,
not_update_when_exists,
only,
databases,
)
return result

@staticmethod
@logging
async def process(spider_ins, callback_result):
data = callback_result.data
database = callback_result.database
query = callback_result.query
filters = callback_result.filters
create_when_not_exists = callback_result.create_when_not_exists
not_update_when_exists = callback_result.not_update_when_exists
only = callback_result.only
Expand All @@ -220,22 +289,25 @@ async def process(spider_ins, callback_result):
"data": dict,
"database": TargetDB,
"query": (Query, dict),
"filters": (str, type(None), list),
"create_when_not_exists": bool,
"not_update_when_exists": bool,
"only": (list, tuple, type(None)),
},
"RuiaPeeweeAsync: update process",
)
result_validator.validate(needs_check)
await RuiaPeeweeUpdate._update(
result = await RuiaPeeweeUpdate._update(
spider_ins,
data,
query,
filters,
database,
create_when_not_exists,
not_update_when_exists,
only,
)
return result


def init_spider(*, spider_ins: Spider):
Expand Down Expand Up @@ -323,10 +395,10 @@ async def init_after_start(spider_ins):

if mysql and mysql_model:
spider_ins.mysql_config = mysql
spider_ins.mysql_model = mysql_model
# spider_ins.mysql_model = mysql_model
if postgres and postgres_model:
spider_ins.postgres_config = postgres
spider_ins.postgres_model = postgres_model
# spider_ins.postgres_model = postgres_model
init_spider(spider_ins=spider_ins)

return init_after_start
Expand Down
Loading

0 comments on commit 7aac3b4

Please sign in to comment.