Skip to content

Commit

Permalink
refactor!: db sessions are now explicitly passed as arguments
Browse files Browse the repository at this point in the history
This helps clarify exactly which functions require the database to be initialized. Also, it helps avoid any potential future issues with relying entirely on a global/thread-local session factory.

The following functions/hooks are affected by this change:
* `add.add_item` - new `session` parameter
* `cli.Hooks.add_command` hook - the sub-command functions are now passed a `session` parameter
* `config.Hooks.register_sa_event_listeners` hook - `session` parameter removed
* `duplicate.resolve_dup_items` - new `session` parameter
* `duplicate.Hooks.resolve_dup_items` hook - new `session` parameter
* `duplicate.resolve_duplicates` - new `session` parameter
* `duplicate.get_duplicates` - new `session` parameter
* `library.lib_item.Hooks.edit_changed_items` hook - new `session` parameter
* `library.lib_item.Hooks.edit_new_items` hook - new `session` parameter
* `library.lib_item.Hooks.process_removed_items` hook - new `session` parameter
* `library.lib_item.Hooks.process_changed_items` hook - new `session` parameter
* `library.lib_item.Hooks.process_new_items` hook - new `session` parameter
* `query.query` - new `session` parameter
* `remove.remove_item` - new `session` parameter
* `util.cli.query.cli_query` - new `session` parameter

`config.MoeSession` has been replaced with `config.moe_sessionmaker`. Sessions should no longer be created by importing MoeSession, and instead should use a session parameter that is created at the top-level of an application. Refer to the `config.py` docstring as well as `cli.py` on how this effects instantiating a session.
  • Loading branch information
jtpavlock committed Dec 20, 2022
1 parent 9cc69db commit 228a017
Show file tree
Hide file tree
Showing 32 changed files with 350 additions and 284 deletions.
18 changes: 11 additions & 7 deletions moe/add/add_cli.py
Expand Up @@ -5,6 +5,8 @@
from pathlib import Path
from typing import Optional, cast

from sqlalchemy.orm.session import Session

import moe
import moe.add
import moe.cli
Expand Down Expand Up @@ -58,12 +60,13 @@ def _skip_import(
raise SkipAdd()


def _parse_args(args: argparse.Namespace):
def _parse_args(session: Session, args: argparse.Namespace):
"""Parses the given commandline arguments.
Tracks can be added as files or albums as directories.
Args:
session: Library db session.
args: Commandline arguments to parse.
Raises:
Expand All @@ -73,7 +76,7 @@ def _parse_args(args: argparse.Namespace):

album: Optional[Album] = None
if args.album_query:
albums = cast(Album, cli_query(args.album_query, "album"))
albums = cast(Album, cli_query(session, args.album_query, "album"))

if len(albums) > 1:
log.error("Query returned more than one album.")
Expand All @@ -84,7 +87,7 @@ def _parse_args(args: argparse.Namespace):
error_count = 0
for path in paths:
try:
_add_path(path, album)
_add_path(session, path, album)
except (AddError, AlbumError) as err:
log.error(err)
error_count += 1
Expand All @@ -95,10 +98,11 @@ def _parse_args(args: argparse.Namespace):
raise SystemExit(1)


def _add_path(path: Path, album: Optional[Album]):
def _add_path(session: Session, path: Path, album: Optional[Album]):
"""Adds an item to the library from a given path.
Args:
session: Library db session.
path: Path to add. Either a directory for an Album or a file for a Track.
album: If ``path`` is a file, add it to ``album`` if given. Note, this
argument is required if adding an Extra.
Expand All @@ -110,15 +114,15 @@ def _add_path(path: Path, album: Optional[Album]):
"""
if path.is_file():
try:
moe.add.add_item(Track.from_file(path, album=album))
moe.add.add_item(session, Track.from_file(path, album=album))
except TrackError:
if not album:
raise AddError(
f"An album query is required to add an extra. [{path=!r}]"
) from None

moe.add.add_item(Extra(album, path))
moe.add.add_item(session, Extra(album, path))
elif path.is_dir():
moe.add.add_item(Album.from_dir(path))
moe.add.add_item(session, Album.from_dir(path))
else:
raise AddError(f"Path not found. [{path=}]")
5 changes: 3 additions & 2 deletions moe/add/add_core.py
Expand Up @@ -6,6 +6,7 @@
import logging

import pluggy
from sqlalchemy.orm.session import Session

import moe
from moe import config
Expand Down Expand Up @@ -61,17 +62,17 @@ class AddAbortError(Exception):
"""Add process has been aborted by the user."""


def add_item(item: LibItem):
def add_item(session: Session, item: LibItem):
"""Adds a LibItem to the library.
Args:
session: Library db session.
item: Item to be added.
Raises:
AddError: Unable to add the item to the library.
"""
log.debug(f"Adding item to the library. [{item=!r}]")
session = config.MoeSession()

config.CONFIG.pm.hook.pre_add(item=item)
session.add(item)
Expand Down
11 changes: 5 additions & 6 deletions moe/cli.py
Expand Up @@ -16,7 +16,7 @@

import moe
from moe import config
from moe.config import Config, ConfigValidationError, MoeSession
from moe.config import Config, ConfigValidationError, moe_sessionmaker

__all__ = ["console"]

Expand Down Expand Up @@ -50,7 +50,7 @@ def add_command(cmd_parsers: argparse._SubParsersAction):
The function will be called like::
func(
config: moe.Config, # user config
session: sqlalchemy.orm.session.Session, # library db session
args: argparse.Namespace, # parsed commandline arguments
)
Expand Down Expand Up @@ -108,12 +108,11 @@ def _parse_args(args: list[str]):
_set_log_lvl(parsed_args)

# call the sub-command's handler within a single session
cli_session = MoeSession()
with cli_session.begin():
with moe_sessionmaker.begin() as session:
try:
parsed_args.func(args=parsed_args)
parsed_args.func(session=session, args=parsed_args)
except SystemExit:
cli_session.commit()
session.commit()
raise


Expand Down
28 changes: 20 additions & 8 deletions moe/config.py
Expand Up @@ -7,6 +7,19 @@
from moe import config
print(config.CONFIG.settings.library_path)
Any application requiring use of the database should initiate a single sqlalchemy
'session'. This session should use ``moe_sessionmaker`` to instantiate a session to
connect to the database::
with moe_sessionmaker.begin() as session:
# do work
See Also:
* `The sqlalchemy Session docs
<https://docs.sqlalchemy.org/en/20/orm/session_basics.html#session-basics>`
* ``moe/cli.py`` for an example on how the CLI handles creating the configuration
and database connection via the session.
"""

import importlib
Expand All @@ -30,8 +43,7 @@
import alembic.config
import moe

session_factory = sqlalchemy.orm.sessionmaker(autoflush=False)
MoeSession = sqlalchemy.orm.scoped_session(session_factory)
moe_sessionmaker = sqlalchemy.orm.sessionmaker(autoflush=False)

__all__ = ["CONFIG", "Config", "ConfigValidationError", "ExtraPlugin"]

Expand Down Expand Up @@ -145,11 +157,11 @@ def plugin_registration():

@staticmethod
@moe.hookspec
def register_sa_event_listeners(session: sqlalchemy.orm.Session):
def register_sa_event_listeners():
"""Registers new sqlalchemy event listeners.
Args:
session: Session to attach the listener to.
These listeners will automatically apply to all sessions globally if the
`Session` class is passed as the listener target as shown in the example.
Important:
This hooks is for Moe internal use only and should not be used by plugins.
Expand All @@ -158,7 +170,7 @@ def register_sa_event_listeners(session: sqlalchemy.orm.Session):
.. code:: python
sqlalchemy.event.listen(
session,
Session,
"before_flush",
_my_func,
)
Expand Down Expand Up @@ -269,7 +281,7 @@ def _init_db(self, create_tables: bool = True):
if not self.engine:
self.engine = sqlalchemy.create_engine("sqlite:///" + str(db_path))

session_factory.configure(bind=self.engine)
moe_sessionmaker.configure(bind=self.engine)

# create and update database tables
if create_tables:
Expand All @@ -282,7 +294,7 @@ def _init_db(self, create_tables: bool = True):
alembic_cfg.attributes["connection"] = connection
alembic.command.upgrade(alembic_cfg, "head")

self.pm.hook.register_sa_event_listeners(config=self, session=MoeSession())
self.pm.hook.register_sa_event_listeners()

# create regular expression function for sqlite queries
@sqlalchemy.event.listens_for(self.engine, "begin")
Expand Down
23 changes: 12 additions & 11 deletions moe/duplicate/dup_cli.py
Expand Up @@ -9,6 +9,7 @@
from rich.panel import Panel
from rich.table import Table
from rich.text import Text
from sqlalchemy.orm.session import Session

import moe
import moe.cli
Expand All @@ -24,12 +25,12 @@


@moe.hookimpl(trylast=True)
def resolve_dup_items(item_a: LibItem, item_b: LibItem):
def resolve_dup_items(session: Session, item_a: LibItem, item_b: LibItem):
"""Resolve any library duplicate conflicts using a user prompt."""
console.print(_fmt_item_vs(item_a, item_b))

# Each PromptChoice `func` should have the following signature:
# func(item_a, item_b) # noqa: E800
# func(session, item_a, item_b) # noqa: E800
prompt_choices = [
PromptChoice(title="Keep item A", shortcut_key="a", func=_keep_a),
PromptChoice(title="Keep item B", shortcut_key="b", func=_keep_b),
Expand All @@ -48,37 +49,37 @@ def resolve_dup_items(item_a: LibItem, item_b: LibItem):
prompt_choices,
"Duplicate items found in the library, how would you like to resolve it?",
)
prompt_choice.func(item_a, item_b)
prompt_choice.func(session, item_a, item_b)


def _keep_a(item_a: LibItem, item_b: LibItem):
def _keep_a(session: Session, item_a: LibItem, item_b: LibItem):
"""Keeps `item_a`, removing `item_b` from the library."""
log.debug("Keeping item A.")

remove_item(item_b)
remove_item(session, item_b)


def _keep_b(item_a: LibItem, item_b: LibItem):
def _keep_b(session: Session, item_a: LibItem, item_b: LibItem):
"""Keeps `item_a`, removing `item_b` from the library."""
log.debug("Keeping item B.")

remove_item(item_a)
remove_item(session, item_a)


def _merge(item_a: LibItem, item_b: LibItem):
def _merge(session: Session, item_a: LibItem, item_b: LibItem):
"""Merges `item_a` into `item_b` without overwriting any conflicts."""
log.debug("Merging A -> B without overwriting any conflicts.")

item_b.merge(item_a)
remove_item(item_a)
remove_item(session, item_a)


def _overwrite(item_a: LibItem, item_b: LibItem):
def _overwrite(session: Session, item_a: LibItem, item_b: LibItem):
"""Merges `item_a` into `item_b`, overwriting any conflicts."""
log.debug("Merging A -> B, overwriting B on conflict.")

item_b.merge(item_a, overwrite=True)
remove_item(item_b)
remove_item(session, item_b)


def _fmt_item_vs(item_a: LibItem, item_b: LibItem) -> Columns:
Expand Down
35 changes: 20 additions & 15 deletions moe/duplicate/dup_core.py
Expand Up @@ -4,6 +4,7 @@
from typing import Optional

import sqlalchemy
from sqlalchemy.orm.session import Session

import moe
from moe import config
Expand All @@ -24,7 +25,7 @@ class Hooks:

@staticmethod
@moe.hookspec
def resolve_dup_items(item_a: LibItem, item_b: LibItem):
def resolve_dup_items(session: Session, item_a: LibItem, item_b: LibItem):
"""Resolve two duplicate items.
A resolution should come in one of two forms:
Expand All @@ -49,6 +50,7 @@ def resolve_dup_items(item_a: LibItem, item_b: LibItem):
addition to what's offered by default.
Args:
session: Library db session.
item_a: First item.
item_b: Second item.
Expand All @@ -60,30 +62,30 @@ def resolve_dup_items(item_a: LibItem, item_b: LibItem):


@moe.hookimpl(hookwrapper=True)
def edit_changed_items(items: list[LibItem]):
def edit_changed_items(session: Session, items: list[LibItem]):
"""Check for and resolve duplicates when items are edited."""
yield # run all `edit_changed_items` hook implementations
albums = [item for item in items if isinstance(item, Album)] # resolve albums first
tracks = [item for item in items if isinstance(item, Track)]
extras = [item for item in items if isinstance(item, Extra)]
resolve_duplicates(albums) # type: ignore
resolve_duplicates(tracks) # type: ignore
resolve_duplicates(extras) # type: ignore
resolve_duplicates(session, albums) # type: ignore
resolve_duplicates(session, tracks) # type: ignore
resolve_duplicates(session, extras) # type: ignore


@moe.hookimpl(hookwrapper=True)
def edit_new_items(items: list[LibItem]):
def edit_new_items(session: Session, items: list[LibItem]):
"""Check for and resolve duplicates when items are added to the library."""
yield # run all `edit_new_items` hook implementations
albums = [item for item in items if isinstance(item, Album)] # resolve albums first
tracks = [item for item in items if isinstance(item, Track)]
extras = [item for item in items if isinstance(item, Extra)]
resolve_duplicates(albums) # type: ignore
resolve_duplicates(tracks) # type: ignore
resolve_duplicates(extras) # type: ignore
resolve_duplicates(session, albums) # type: ignore
resolve_duplicates(session, tracks) # type: ignore
resolve_duplicates(session, extras) # type: ignore


def resolve_duplicates(items: list[LibItem]):
def resolve_duplicates(session: Session, items: list[LibItem]):
"""Search for and resolve any duplicates of items in ``items``."""
log.debug(f"Checking for duplicate items. [{items=!r}]")

Expand All @@ -92,8 +94,8 @@ def resolve_duplicates(items: list[LibItem]):
if _is_removed(item):
continue

dup_items = get_duplicates(item, items)
dup_items += get_duplicates(item)
dup_items = get_duplicates(session, item, items)
dup_items += get_duplicates(session, item)

for dup_item in dup_items:
if dup_item in resolved_items or _is_removed(item):
Expand All @@ -102,7 +104,9 @@ def resolve_duplicates(items: list[LibItem]):
log.debug(
f"Resolving duplicate items. [item_a={item!r}, item_b={dup_item!r}]"
)
config.CONFIG.pm.hook.resolve_dup_items(item_a=item, item_b=dup_item)
config.CONFIG.pm.hook.resolve_dup_items(
session=session, item_a=item, item_b=dup_item
)
if (
not item.is_unique(dup_item)
and not _is_removed(item)
Expand Down Expand Up @@ -133,11 +137,12 @@ def _is_removed(item):


def get_duplicates(
item: LibItem, others: Optional[list[LibItem]] = None
session: Session, item: LibItem, others: Optional[list[LibItem]] = None
) -> list[LibItem]:
"""Returns items considered duplicates of ``item``.
Args:
session: Library db session.
item: Library item to get duplicates of.
others: Items to compare against. If not given, will query the database
and compare against all items in the library.
Expand All @@ -148,7 +153,7 @@ def get_duplicates(
"""
dup_items = []
if not others:
others = query("*", type(item).__name__.lower())
others = query(session, "*", type(item).__name__.lower())

for other in others:
if (
Expand Down

0 comments on commit 228a017

Please sign in to comment.