From 4f68487e00b81354a44b261093f20feb88984bbf Mon Sep 17 00:00:00 2001 From: Animesh404 Date: Thu, 4 Jul 2024 02:42:30 +0530 Subject: [PATCH 1/7] add joinedload to eagerly load related entities --- openadapt/db/crud.py | 26 ++++++++++++++++++++++---- 1 file changed, 22 insertions(+), 4 deletions(-) diff --git a/openadapt/db/crud.py b/openadapt/db/crud.py index 96de81333..c7c7f533d 100644 --- a/openadapt/db/crud.py +++ b/openadapt/db/crud.py @@ -11,6 +11,7 @@ from loguru import logger from sqlalchemy.orm import Session as SaSession +from sqlalchemy.orm import joinedload import psutil import sqlalchemy as sa @@ -60,7 +61,7 @@ def _insert( sa.engine.Result | None: The SQLAlchemy Result object if a buffer is not provided. None if a buffer is provided. """ - db_obj = {column.name: None for column in table.__table__.columns} + db_obj = {column.name: None for column in table._table_.columns} for key in db_obj: if key in event_data: val = event_data[key] @@ -391,7 +392,13 @@ def get_action_events( list[ActionEvent]: A list of action events for the recording. """ assert recording, "Invalid recording." - action_events = _get(session, ActionEvent, recording.id) + # Using joinedload to eagerly load related entities + action_events = ( + session.query(ActionEvent) + .options(joinedload(ActionEvent.screenshot)) + .filter(ActionEvent.recording_id == recording.id) + .all() + ) action_events = filter_disabled_action_events(action_events) # filter out stop sequences listed in STOP_SEQUENCES and Ctrl + C filter_stop_sequences(action_events) @@ -530,7 +537,13 @@ def get_screenshots( Returns: list[Screenshot]: A list of screenshots for the recording. """ - screenshots = _get(session, Screenshot, recording.id) + screenshots = screenshots = ( + session.query(Screenshot) + .filter(Screenshot.recording_id == recording.id) + .order_by(Screenshot.timestamp) + .options(joinedload(Screenshot.action_event)) + .all() + ) for prev, cur in zip(screenshots, screenshots[1:]): cur.prev = prev @@ -555,7 +568,12 @@ def get_window_events( Returns: list[WindowEvent]: A list of window events for the recording. """ - return _get(session, WindowEvent, recording.id) + return ( + session.query(WindowEvent) + .options(joinedload(WindowEvent.action_events)) + .filter(WindowEvent.recording_id == recording.id) + .all() + ) def disable_action_event(session: SaSession, event_id: int) -> None: From 85d01005df02c1378a36e1fee76bc8cade22b954 Mon Sep 17 00:00:00 2001 From: Animesh404 Date: Thu, 4 Jul 2024 13:49:16 +0530 Subject: [PATCH 2/7] change get_active_window_data to return empty dict when state is none --- openadapt/db/crud.py | 15 +++++++++------ openadapt/window/__init__.py | 2 +- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/openadapt/db/crud.py b/openadapt/db/crud.py index c7c7f533d..e3fa20e1f 100644 --- a/openadapt/db/crud.py +++ b/openadapt/db/crud.py @@ -15,6 +15,7 @@ import psutil import sqlalchemy as sa + from openadapt import utils from openadapt.config import DATABASE_LOCK_FILE_PATH, config from openadapt.db.db import Session, get_read_only_session_maker @@ -61,7 +62,7 @@ def _insert( sa.engine.Result | None: The SQLAlchemy Result object if a buffer is not provided. None if a buffer is provided. """ - db_obj = {column.name: None for column in table._table_.columns} + db_obj = {column.name: None for column in table.__table__.columns} for key in db_obj: if key in event_data: val = event_data[key] @@ -392,11 +393,12 @@ def get_action_events( list[ActionEvent]: A list of action events for the recording. """ assert recording, "Invalid recording." - # Using joinedload to eagerly load related entities + assert recording, "Invalid recording." action_events = ( session.query(ActionEvent) - .options(joinedload(ActionEvent.screenshot)) .filter(ActionEvent.recording_id == recording.id) + .order_by(ActionEvent.timestamp) + .options(joinedload(ActionEvent.recording)) .all() ) action_events = filter_disabled_action_events(action_events) @@ -537,11 +539,11 @@ def get_screenshots( Returns: list[Screenshot]: A list of screenshots for the recording. """ - screenshots = screenshots = ( + screenshots = ( session.query(Screenshot) .filter(Screenshot.recording_id == recording.id) .order_by(Screenshot.timestamp) - .options(joinedload(Screenshot.action_event)) + .options(joinedload(Screenshot.recording), joinedload(Screenshot.action_event)) .all() ) @@ -570,8 +572,9 @@ def get_window_events( """ return ( session.query(WindowEvent) - .options(joinedload(WindowEvent.action_events)) .filter(WindowEvent.recording_id == recording.id) + .order_by(WindowEvent.timestamp) + .options(joinedload(WindowEvent.recording)) .all() ) diff --git a/openadapt/window/__init__.py b/openadapt/window/__init__.py index 99b42ec02..ebde66d3f 100644 --- a/openadapt/window/__init__.py +++ b/openadapt/window/__init__.py @@ -33,7 +33,7 @@ def get_active_window_data( """ state = get_active_window_state(include_window_data) if not state: - return None + return {} title = state["title"] left = state["left"] top = state["top"] From 1feea90f244b44adef3433ea90e284fca8f2e127 Mon Sep 17 00:00:00 2001 From: Animesh404 Date: Thu, 4 Jul 2024 21:11:48 +0530 Subject: [PATCH 3/7] add eager as parameter in _get --- openadapt/db/crud.py | 36 ++++++++++++++---------------------- 1 file changed, 14 insertions(+), 22 deletions(-) diff --git a/openadapt/db/crud.py b/openadapt/db/crud.py index e3fa20e1f..3b521d3f7 100644 --- a/openadapt/db/crud.py +++ b/openadapt/db/crud.py @@ -359,6 +359,7 @@ def _get( session: SaSession, table: BaseModelType, recording_id: int, + eager: bool = False, ) -> list[BaseModelType]: """Retrieve records from the database table based on the recording timestamp. @@ -366,11 +367,21 @@ def _get( session (sa.orm.Session): The database session. table (BaseModel): The database table to query. recording_id (int): The recording id. + eager (bool): if true, implement eagerloading. Returns: list[BaseModel]: A list of records retrieved from the database table, ordered by timestamp. """ + if eager: + return ( + session.query(table) + .filter(table.recording_id == recording_id) + .order_by(table.timestamp) + .options(joinedload(table.recording)) + .all() + ) + return ( session.query(table) .filter(table.recording_id == recording_id) @@ -393,14 +404,7 @@ def get_action_events( list[ActionEvent]: A list of action events for the recording. """ assert recording, "Invalid recording." - assert recording, "Invalid recording." - action_events = ( - session.query(ActionEvent) - .filter(ActionEvent.recording_id == recording.id) - .order_by(ActionEvent.timestamp) - .options(joinedload(ActionEvent.recording)) - .all() - ) + action_events = _get(session, ActionEvent, recording.id, eager=True) action_events = filter_disabled_action_events(action_events) # filter out stop sequences listed in STOP_SEQUENCES and Ctrl + C filter_stop_sequences(action_events) @@ -539,13 +543,7 @@ def get_screenshots( Returns: list[Screenshot]: A list of screenshots for the recording. """ - screenshots = ( - session.query(Screenshot) - .filter(Screenshot.recording_id == recording.id) - .order_by(Screenshot.timestamp) - .options(joinedload(Screenshot.recording), joinedload(Screenshot.action_event)) - .all() - ) + screenshots = _get(session, Screenshot, recording.id, eager=True) for prev, cur in zip(screenshots, screenshots[1:]): cur.prev = prev @@ -570,13 +568,7 @@ def get_window_events( Returns: list[WindowEvent]: A list of window events for the recording. """ - return ( - session.query(WindowEvent) - .filter(WindowEvent.recording_id == recording.id) - .order_by(WindowEvent.timestamp) - .options(joinedload(WindowEvent.recording)) - .all() - ) + return _get(session, WindowEvent, recording.id, eager=True) def disable_action_event(session: SaSession, event_id: int) -> None: From 3cc144e17889071e8adb3cd7b84aea298dabf2b2 Mon Sep 17 00:00:00 2001 From: Animesh404 Date: Fri, 5 Jul 2024 01:21:01 +0530 Subject: [PATCH 4/7] add a default parameter for relationships --- openadapt/db/crud.py | 38 +++++++++++++++++++++++--------------- 1 file changed, 23 insertions(+), 15 deletions(-) diff --git a/openadapt/db/crud.py b/openadapt/db/crud.py index 3b521d3f7..7ecbff8a5 100644 --- a/openadapt/db/crud.py +++ b/openadapt/db/crud.py @@ -360,6 +360,7 @@ def _get( table: BaseModelType, recording_id: int, eager: bool = False, + relationships: list[str] = None, ) -> list[BaseModelType]: """Retrieve records from the database table based on the recording timestamp. @@ -367,28 +368,25 @@ def _get( session (sa.orm.Session): The database session. table (BaseModel): The database table to query. recording_id (int): The recording id. - eager (bool): if true, implement eagerloading. + eager (bool, optional): If true, implement eager loading. Defaults to False. + relationships (list[str], optional): List of relationships to eagerly load. Defaults to None. Returns: list[BaseModel]: A list of records retrieved from the database table, ordered by timestamp. """ - if eager: - return ( - session.query(table) - .filter(table.recording_id == recording_id) - .order_by(table.timestamp) - .options(joinedload(table.recording)) - .all() - ) - - return ( + query = ( session.query(table) .filter(table.recording_id == recording_id) .order_by(table.timestamp) - .all() ) + if eager and relationships: + for rel in relationships: + query = query.options(joinedload(getattr(table, rel))) + + return query.all() + def get_action_events( session: SaSession, @@ -404,7 +402,9 @@ def get_action_events( list[ActionEvent]: A list of action events for the recording. """ assert recording, "Invalid recording." - action_events = _get(session, ActionEvent, recording.id, eager=True) + action_events = _get( + session, ActionEvent, recording.id, eager=True, relationships=["screenshot"] + ) action_events = filter_disabled_action_events(action_events) # filter out stop sequences listed in STOP_SEQUENCES and Ctrl + C filter_stop_sequences(action_events) @@ -543,7 +543,13 @@ def get_screenshots( Returns: list[Screenshot]: A list of screenshots for the recording. """ - screenshots = _get(session, Screenshot, recording.id, eager=True) + screenshots = _get( + session, + Screenshot, + recording.id, + eager=True, + relationships=["action_event", "recording"], + ) for prev, cur in zip(screenshots, screenshots[1:]): cur.prev = prev @@ -568,7 +574,9 @@ def get_window_events( Returns: list[WindowEvent]: A list of window events for the recording. """ - return _get(session, WindowEvent, recording.id, eager=True) + return _get( + session, WindowEvent, recording.id, eager=True, relationships=["action_events"] + ) def disable_action_event(session: SaSession, event_id: int) -> None: From 7597cb1d92645a838aa458d01b1896fa8a61c030 Mon Sep 17 00:00:00 2001 From: Animesh404 Date: Fri, 5 Jul 2024 01:44:20 +0530 Subject: [PATCH 5/7] flake8 error fix --- openadapt/db/crud.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/openadapt/db/crud.py b/openadapt/db/crud.py index 7ecbff8a5..4ba92e3be 100644 --- a/openadapt/db/crud.py +++ b/openadapt/db/crud.py @@ -369,7 +369,8 @@ def _get( table (BaseModel): The database table to query. recording_id (int): The recording id. eager (bool, optional): If true, implement eager loading. Defaults to False. - relationships (list[str], optional): List of relationships to eagerly load. Defaults to None. + relationships (list[str], optional): Relationships to load eagerly. + Defaults to None. Returns: list[BaseModel]: A list of records retrieved from the database table, From e975620810c48d86a9f305d9b3041185cf2ee95a Mon Sep 17 00:00:00 2001 From: Animesh404 Date: Fri, 5 Jul 2024 19:05:13 +0530 Subject: [PATCH 6/7] use direct and indirect relationships --- openadapt/db/crud.py | 54 ++++++++++++++++++++++++++++---------------- 1 file changed, 35 insertions(+), 19 deletions(-) diff --git a/openadapt/db/crud.py b/openadapt/db/crud.py index 4ba92e3be..04327364f 100644 --- a/openadapt/db/crud.py +++ b/openadapt/db/crud.py @@ -11,7 +11,7 @@ from loguru import logger from sqlalchemy.orm import Session as SaSession -from sqlalchemy.orm import joinedload +from sqlalchemy.orm import joinedload, subqueryload import psutil import sqlalchemy as sa @@ -370,24 +370,19 @@ def _get( recording_id (int): The recording id. eager (bool, optional): If true, implement eager loading. Defaults to False. relationships (list[str], optional): Relationships to load eagerly. - Defaults to None. + Defaults to None. Returns: list[BaseModel]: A list of records retrieved from the database table, ordered by timestamp. """ - query = ( + return ( session.query(table) .filter(table.recording_id == recording_id) .order_by(table.timestamp) + .all() ) - if eager and relationships: - for rel in relationships: - query = query.options(joinedload(getattr(table, rel))) - - return query.all() - def get_action_events( session: SaSession, @@ -403,8 +398,18 @@ def get_action_events( list[ActionEvent]: A list of action events for the recording. """ assert recording, "Invalid recording." - action_events = _get( - session, ActionEvent, recording.id, eager=True, relationships=["screenshot"] + action_events = ( + session.query(ActionEvent) + .filter(ActionEvent.recording_id == recording.id) + .options( + joinedload(ActionEvent.recording), + joinedload(ActionEvent.screenshot), + subqueryload(ActionEvent.window_event).joinedload( + WindowEvent.action_events + ), + ) + .order_by(ActionEvent.timestamp) + .all() ) action_events = filter_disabled_action_events(action_events) # filter out stop sequences listed in STOP_SEQUENCES and Ctrl + C @@ -544,12 +549,16 @@ def get_screenshots( Returns: list[Screenshot]: A list of screenshots for the recording. """ - screenshots = _get( - session, - Screenshot, - recording.id, - eager=True, - relationships=["action_event", "recording"], + screenshots = ( + session.query(Screenshot) + .filter(Screenshot.recording_id == recording.id) + .options( + joinedload(Screenshot.action_event).joinedload(ActionEvent.recording), + subqueryload(Screenshot.action_event).joinedload(ActionEvent.screenshot), + subqueryload(Screenshot.recording), + ) + .order_by(Screenshot.timestamp) + .all() ) for prev, cur in zip(screenshots, screenshots[1:]): @@ -575,8 +584,15 @@ def get_window_events( Returns: list[WindowEvent]: A list of window events for the recording. """ - return _get( - session, WindowEvent, recording.id, eager=True, relationships=["action_events"] + return ( + session.query(WindowEvent) + .filter(WindowEvent.recording_id == recording.id) + .options( + joinedload(WindowEvent.recording), + subqueryload(WindowEvent.action_events).joinedload(ActionEvent.screenshot), + ) + .order_by(WindowEvent.timestamp) + .all() ) From bbb7f0fd5d00347bed4b34e62e29aabfddcd675e Mon Sep 17 00:00:00 2001 From: Animesh404 Date: Sat, 6 Jul 2024 00:49:10 +0530 Subject: [PATCH 7/7] chore: refactor code --- openadapt/db/crud.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/openadapt/db/crud.py b/openadapt/db/crud.py index 04327364f..1e6bc2649 100644 --- a/openadapt/db/crud.py +++ b/openadapt/db/crud.py @@ -359,8 +359,6 @@ def _get( session: SaSession, table: BaseModelType, recording_id: int, - eager: bool = False, - relationships: list[str] = None, ) -> list[BaseModelType]: """Retrieve records from the database table based on the recording timestamp. @@ -368,9 +366,6 @@ def _get( session (sa.orm.Session): The database session. table (BaseModel): The database table to query. recording_id (int): The recording id. - eager (bool, optional): If true, implement eager loading. Defaults to False. - relationships (list[str], optional): Relationships to load eagerly. - Defaults to None. Returns: list[BaseModel]: A list of records retrieved from the database table,