diff --git a/moe/add/add_cli.py b/moe/add/add_cli.py index 7e7d5c80..909103c9 100644 --- a/moe/add/add_cli.py +++ b/moe/add/add_cli.py @@ -5,6 +5,8 @@ from pathlib import Path from typing import Optional, cast +from sqlalchemy.orm.session import Session + import moe import moe.add import moe.cli @@ -58,12 +60,13 @@ def _skip_import( raise SkipAdd() -def _parse_args(args: argparse.Namespace): +def _parse_args(session: Session, args: argparse.Namespace): """Parses the given commandline arguments. Tracks can be added as files or albums as directories. Args: + session: Library db session. args: Commandline arguments to parse. Raises: @@ -73,7 +76,7 @@ def _parse_args(args: argparse.Namespace): album: Optional[Album] = None if args.album_query: - albums = cast(Album, cli_query(args.album_query, "album")) + albums = cast(Album, cli_query(session, args.album_query, "album")) if len(albums) > 1: log.error("Query returned more than one album.") @@ -84,7 +87,7 @@ def _parse_args(args: argparse.Namespace): error_count = 0 for path in paths: try: - _add_path(path, album) + _add_path(session, path, album) except (AddError, AlbumError) as err: log.error(err) error_count += 1 @@ -95,10 +98,11 @@ def _parse_args(args: argparse.Namespace): raise SystemExit(1) -def _add_path(path: Path, album: Optional[Album]): +def _add_path(session: Session, path: Path, album: Optional[Album]): """Adds an item to the library from a given path. Args: + session: Library db session. path: Path to add. Either a directory for an Album or a file for a Track. album: If ``path`` is a file, add it to ``album`` if given. Note, this argument is required if adding an Extra. @@ -110,15 +114,15 @@ def _add_path(path: Path, album: Optional[Album]): """ if path.is_file(): try: - moe.add.add_item(Track.from_file(path, album=album)) + moe.add.add_item(session, Track.from_file(path, album=album)) except TrackError: if not album: raise AddError( f"An album query is required to add an extra. [{path=!r}]" ) from None - moe.add.add_item(Extra(album, path)) + moe.add.add_item(session, Extra(album, path)) elif path.is_dir(): - moe.add.add_item(Album.from_dir(path)) + moe.add.add_item(session, Album.from_dir(path)) else: raise AddError(f"Path not found. [{path=}]") diff --git a/moe/add/add_core.py b/moe/add/add_core.py index c9687e7d..72ad3724 100644 --- a/moe/add/add_core.py +++ b/moe/add/add_core.py @@ -6,6 +6,7 @@ import logging import pluggy +from sqlalchemy.orm.session import Session import moe from moe import config @@ -61,17 +62,17 @@ class AddAbortError(Exception): """Add process has been aborted by the user.""" -def add_item(item: LibItem): +def add_item(session: Session, item: LibItem): """Adds a LibItem to the library. Args: + session: Library db session. item: Item to be added. Raises: AddError: Unable to add the item to the library. """ log.debug(f"Adding item to the library. [{item=!r}]") - session = config.MoeSession() config.CONFIG.pm.hook.pre_add(item=item) session.add(item) diff --git a/moe/cli.py b/moe/cli.py index 46a71812..ba829774 100755 --- a/moe/cli.py +++ b/moe/cli.py @@ -16,7 +16,7 @@ import moe from moe import config -from moe.config import Config, ConfigValidationError, MoeSession +from moe.config import Config, ConfigValidationError, moe_sessionmaker __all__ = ["console"] @@ -50,7 +50,7 @@ def add_command(cmd_parsers: argparse._SubParsersAction): The function will be called like:: func( - config: moe.Config, # user config + session: sqlalchemy.orm.session.Session, # library db session args: argparse.Namespace, # parsed commandline arguments ) @@ -108,12 +108,11 @@ def _parse_args(args: list[str]): _set_log_lvl(parsed_args) # call the sub-command's handler within a single session - cli_session = MoeSession() - with cli_session.begin(): + with moe_sessionmaker.begin() as session: try: - parsed_args.func(args=parsed_args) + parsed_args.func(session=session, args=parsed_args) except SystemExit: - cli_session.commit() + session.commit() raise diff --git a/moe/config.py b/moe/config.py index 426ad559..6863edfc 100644 --- a/moe/config.py +++ b/moe/config.py @@ -7,6 +7,19 @@ from moe import config print(config.CONFIG.settings.library_path) + +Any application requiring use of the database should initiate a single sqlalchemy +'session'. This session should use ``moe_sessionmaker`` to instantiate a session to +connect to the database:: + + with moe_sessionmaker.begin() as session: + # do work + +See Also: + * `The sqlalchemy Session docs + ` + * ``moe/cli.py`` for an example on how the CLI handles creating the configuration + and database connection via the session. """ import importlib @@ -30,8 +43,7 @@ import alembic.config import moe -session_factory = sqlalchemy.orm.sessionmaker(autoflush=False) -MoeSession = sqlalchemy.orm.scoped_session(session_factory) +moe_sessionmaker = sqlalchemy.orm.sessionmaker(autoflush=False) __all__ = ["CONFIG", "Config", "ConfigValidationError", "ExtraPlugin"] @@ -145,11 +157,11 @@ def plugin_registration(): @staticmethod @moe.hookspec - def register_sa_event_listeners(session: sqlalchemy.orm.Session): + def register_sa_event_listeners(): """Registers new sqlalchemy event listeners. - Args: - session: Session to attach the listener to. + These listeners will automatically apply to all sessions globally if the + `Session` class is passed as the listener target as shown in the example. Important: This hooks is for Moe internal use only and should not be used by plugins. @@ -158,7 +170,7 @@ def register_sa_event_listeners(session: sqlalchemy.orm.Session): .. code:: python sqlalchemy.event.listen( - session, + Session, "before_flush", _my_func, ) @@ -269,7 +281,7 @@ def _init_db(self, create_tables: bool = True): if not self.engine: self.engine = sqlalchemy.create_engine("sqlite:///" + str(db_path)) - session_factory.configure(bind=self.engine) + moe_sessionmaker.configure(bind=self.engine) # create and update database tables if create_tables: @@ -282,7 +294,7 @@ def _init_db(self, create_tables: bool = True): alembic_cfg.attributes["connection"] = connection alembic.command.upgrade(alembic_cfg, "head") - self.pm.hook.register_sa_event_listeners(config=self, session=MoeSession()) + self.pm.hook.register_sa_event_listeners() # create regular expression function for sqlite queries @sqlalchemy.event.listens_for(self.engine, "begin") diff --git a/moe/duplicate/dup_cli.py b/moe/duplicate/dup_cli.py index df1175f4..6f67023d 100644 --- a/moe/duplicate/dup_cli.py +++ b/moe/duplicate/dup_cli.py @@ -9,6 +9,7 @@ from rich.panel import Panel from rich.table import Table from rich.text import Text +from sqlalchemy.orm.session import Session import moe import moe.cli @@ -24,12 +25,12 @@ @moe.hookimpl(trylast=True) -def resolve_dup_items(item_a: LibItem, item_b: LibItem): +def resolve_dup_items(session: Session, item_a: LibItem, item_b: LibItem): """Resolve any library duplicate conflicts using a user prompt.""" console.print(_fmt_item_vs(item_a, item_b)) # Each PromptChoice `func` should have the following signature: - # func(item_a, item_b) # noqa: E800 + # func(session, item_a, item_b) # noqa: E800 prompt_choices = [ PromptChoice(title="Keep item A", shortcut_key="a", func=_keep_a), PromptChoice(title="Keep item B", shortcut_key="b", func=_keep_b), @@ -48,37 +49,37 @@ def resolve_dup_items(item_a: LibItem, item_b: LibItem): prompt_choices, "Duplicate items found in the library, how would you like to resolve it?", ) - prompt_choice.func(item_a, item_b) + prompt_choice.func(session, item_a, item_b) -def _keep_a(item_a: LibItem, item_b: LibItem): +def _keep_a(session: Session, item_a: LibItem, item_b: LibItem): """Keeps `item_a`, removing `item_b` from the library.""" log.debug("Keeping item A.") - remove_item(item_b) + remove_item(session, item_b) -def _keep_b(item_a: LibItem, item_b: LibItem): +def _keep_b(session: Session, item_a: LibItem, item_b: LibItem): """Keeps `item_a`, removing `item_b` from the library.""" log.debug("Keeping item B.") - remove_item(item_a) + remove_item(session, item_a) -def _merge(item_a: LibItem, item_b: LibItem): +def _merge(session: Session, item_a: LibItem, item_b: LibItem): """Merges `item_a` into `item_b` without overwriting any conflicts.""" log.debug("Merging A -> B without overwriting any conflicts.") item_b.merge(item_a) - remove_item(item_a) + remove_item(session, item_a) -def _overwrite(item_a: LibItem, item_b: LibItem): +def _overwrite(session: Session, item_a: LibItem, item_b: LibItem): """Merges `item_a` into `item_b`, overwriting any conflicts.""" log.debug("Merging A -> B, overwriting B on conflict.") item_b.merge(item_a, overwrite=True) - remove_item(item_b) + remove_item(session, item_b) def _fmt_item_vs(item_a: LibItem, item_b: LibItem) -> Columns: diff --git a/moe/duplicate/dup_core.py b/moe/duplicate/dup_core.py index 8f4e2493..490d42f9 100644 --- a/moe/duplicate/dup_core.py +++ b/moe/duplicate/dup_core.py @@ -4,6 +4,7 @@ from typing import Optional import sqlalchemy +from sqlalchemy.orm.session import Session import moe from moe import config @@ -24,7 +25,7 @@ class Hooks: @staticmethod @moe.hookspec - def resolve_dup_items(item_a: LibItem, item_b: LibItem): + def resolve_dup_items(session: Session, item_a: LibItem, item_b: LibItem): """Resolve two duplicate items. A resolution should come in one of two forms: @@ -49,6 +50,7 @@ def resolve_dup_items(item_a: LibItem, item_b: LibItem): addition to what's offered by default. Args: + session: Library db session. item_a: First item. item_b: Second item. @@ -60,30 +62,30 @@ def resolve_dup_items(item_a: LibItem, item_b: LibItem): @moe.hookimpl(hookwrapper=True) -def edit_changed_items(items: list[LibItem]): +def edit_changed_items(session: Session, items: list[LibItem]): """Check for and resolve duplicates when items are edited.""" yield # run all `edit_changed_items` hook implementations albums = [item for item in items if isinstance(item, Album)] # resolve albums first tracks = [item for item in items if isinstance(item, Track)] extras = [item for item in items if isinstance(item, Extra)] - resolve_duplicates(albums) # type: ignore - resolve_duplicates(tracks) # type: ignore - resolve_duplicates(extras) # type: ignore + resolve_duplicates(session, albums) # type: ignore + resolve_duplicates(session, tracks) # type: ignore + resolve_duplicates(session, extras) # type: ignore @moe.hookimpl(hookwrapper=True) -def edit_new_items(items: list[LibItem]): +def edit_new_items(session: Session, items: list[LibItem]): """Check for and resolve duplicates when items are added to the library.""" yield # run all `edit_new_items` hook implementations albums = [item for item in items if isinstance(item, Album)] # resolve albums first tracks = [item for item in items if isinstance(item, Track)] extras = [item for item in items if isinstance(item, Extra)] - resolve_duplicates(albums) # type: ignore - resolve_duplicates(tracks) # type: ignore - resolve_duplicates(extras) # type: ignore + resolve_duplicates(session, albums) # type: ignore + resolve_duplicates(session, tracks) # type: ignore + resolve_duplicates(session, extras) # type: ignore -def resolve_duplicates(items: list[LibItem]): +def resolve_duplicates(session: Session, items: list[LibItem]): """Search for and resolve any duplicates of items in ``items``.""" log.debug(f"Checking for duplicate items. [{items=!r}]") @@ -92,8 +94,8 @@ def resolve_duplicates(items: list[LibItem]): if _is_removed(item): continue - dup_items = get_duplicates(item, items) - dup_items += get_duplicates(item) + dup_items = get_duplicates(session, item, items) + dup_items += get_duplicates(session, item) for dup_item in dup_items: if dup_item in resolved_items or _is_removed(item): @@ -102,7 +104,9 @@ def resolve_duplicates(items: list[LibItem]): log.debug( f"Resolving duplicate items. [item_a={item!r}, item_b={dup_item!r}]" ) - config.CONFIG.pm.hook.resolve_dup_items(item_a=item, item_b=dup_item) + config.CONFIG.pm.hook.resolve_dup_items( + session=session, item_a=item, item_b=dup_item + ) if ( not item.is_unique(dup_item) and not _is_removed(item) @@ -133,11 +137,12 @@ def _is_removed(item): def get_duplicates( - item: LibItem, others: Optional[list[LibItem]] = None + session: Session, item: LibItem, others: Optional[list[LibItem]] = None ) -> list[LibItem]: """Returns items considered duplicates of ``item``. Args: + session: Library db session. item: Library item to get duplicates of. others: Items to compare against. If not given, will query the database and compare against all items in the library. @@ -148,7 +153,7 @@ def get_duplicates( """ dup_items = [] if not others: - others = query("*", type(item).__name__.lower()) + others = query(session, "*", type(item).__name__.lower()) for other in others: if ( diff --git a/moe/edit/edit_cli.py b/moe/edit/edit_cli.py index e2aed169..a079d890 100644 --- a/moe/edit/edit_cli.py +++ b/moe/edit/edit_cli.py @@ -3,6 +3,8 @@ import argparse import logging +from sqlalchemy.orm.session import Session + import moe import moe.cli from moe import edit @@ -37,17 +39,18 @@ def add_command(cmd_parsers: argparse._SubParsersAction): edit_parser.set_defaults(func=_parse_args) -def _parse_args(args: argparse.Namespace): # noqa: C901 +def _parse_args(session: Session, args: argparse.Namespace): """Parses the given commandline arguments. Args: + session: Library db session. args: Commandline arguments to parse. Raises: SystemExit: Invalid query, no items found to edit, or invalid field or field_value term format. """ - items = cli_query(args.query, args.query_type) + items = cli_query(session, args.query, args.query_type) error_count = 0 for term in args.fv_terms: diff --git a/moe/library/lib_item.py b/moe/library/lib_item.py index 8c1c0fff..d28b800f 100644 --- a/moe/library/lib_item.py +++ b/moe/library/lib_item.py @@ -11,6 +11,7 @@ import sqlalchemy.orm from sqlalchemy import Column, Integer from sqlalchemy.orm import declarative_base +from sqlalchemy.orm.session import Session import moe from moe import config @@ -31,10 +32,11 @@ class Hooks: @staticmethod @moe.hookspec - def edit_changed_items(items: list["LibItem"]): + def edit_changed_items(session: Session, items: list["LibItem"]): """Edit items in the library that were changed in some way. Args: + session: Library db session. items: Any changed items that existed in the library prior to the current session. @@ -45,10 +47,11 @@ def edit_changed_items(items: list["LibItem"]): @staticmethod @moe.hookspec - def edit_new_items(items: list["LibItem"]): + def edit_new_items(session: Session, items: list["LibItem"]): """Edit new items in the library. Args: + session: Library db session. items: Any items being added to the library for the first time. See Also: @@ -58,20 +61,22 @@ def edit_new_items(items: list["LibItem"]): @staticmethod @moe.hookspec - def process_removed_items(items: list["LibItem"]): + def process_removed_items(session: Session, items: list["LibItem"]): """Process items that have been removed from the library. Args: + session: Library db session. items: Any items that existed in the library prior to the current session, but have now been removed from the library. """ @staticmethod @moe.hookspec - def process_changed_items(items: list["LibItem"]): + def process_changed_items(session: Session, items: list["LibItem"]): """Process items in the library that were changed in some way. Args: + session: Library db session. items: Any changed items that existed in the library prior to the current session. @@ -85,10 +90,11 @@ def process_changed_items(items: list["LibItem"]): @staticmethod @moe.hookspec - def process_new_items(items: list["LibItem"]): + def process_new_items(session: Session, items: list["LibItem"]): """Process new items in the library. Args: + session: Library db session. items: Any items being added to the library for the first time. Important: @@ -109,15 +115,15 @@ def add_hooks(pm: pluggy.manager.PluginManager): @moe.hookimpl -def register_sa_event_listeners(session: sqlalchemy.orm.Session): +def register_sa_event_listeners(): """Registers event listeners for editing and processing new items.""" sqlalchemy.event.listen( - session, + Session, "before_flush", _edit_before_flush, ) sqlalchemy.event.listen( - session, + Session, "after_flush", _process_after_flush, ) @@ -148,7 +154,7 @@ def _edit_before_flush( changed_items.append(dirty_item) if changed_items: log.debug(f"Editing changed items. [{changed_items=!r}]") - config.CONFIG.pm.hook.edit_changed_items(items=changed_items) + config.CONFIG.pm.hook.edit_changed_items(session=session, items=changed_items) log.debug(f"Edited changed items. [{changed_items=!r}]") new_items = [] @@ -157,7 +163,7 @@ def _edit_before_flush( new_items.append(new_item) if new_items: log.debug(f"Editing new items. [{new_items=!r}]") - config.CONFIG.pm.hook.edit_new_items(items=new_items) + config.CONFIG.pm.hook.edit_new_items(session=session, items=new_items) log.debug(f"Edited new items. [{new_items=!r}]") @@ -184,7 +190,9 @@ def _process_after_flush( changed_items.append(dirty_item) if changed_items: log.debug(f"Processing changed items. [{changed_items=!r}]") - config.CONFIG.pm.hook.process_changed_items(items=changed_items) + config.CONFIG.pm.hook.process_changed_items( + session=session, items=changed_items + ) log.debug(f"Processed changed items. [{changed_items=!r}]") new_items = [] @@ -193,7 +201,7 @@ def _process_after_flush( new_items.append(new_item) if new_items: log.debug(f"Processing new items. [{new_items=!r}]") - config.CONFIG.pm.hook.process_new_items(items=new_items) + config.CONFIG.pm.hook.process_new_items(session=session, items=new_items) log.debug(f"Processed new items. [{new_items=!r}]") removed_items = [] @@ -202,7 +210,9 @@ def _process_after_flush( removed_items.append(removed_item) if removed_items: log.debug(f"Processing removed items. [{removed_items=!r}]") - config.CONFIG.pm.hook.process_removed_items(items=removed_items) + config.CONFIG.pm.hook.process_removed_items( + session=session, items=removed_items + ) log.debug(f"Processed removed items. [{removed_items=!r}]") diff --git a/moe/list.py b/moe/list.py index bddf945d..475575fc 100644 --- a/moe/list.py +++ b/moe/list.py @@ -9,6 +9,8 @@ from collections import OrderedDict from typing import Any +from sqlalchemy.orm.session import Session + import moe import moe.cli from moe import config @@ -53,16 +55,17 @@ def add_command(cmd_parsers: argparse._SubParsersAction): ls_parser.set_defaults(func=_parse_args) -def _parse_args(args: argparse.Namespace): +def _parse_args(session: Session, args: argparse.Namespace): """Parses the given commandline arguments. Args: + session: Library db session. args: Commandline arguments to parse. Raises: SystemExit: Invalid query or no items found. """ - items = cli_query(args.query, query_type=args.query_type) + items = cli_query(session, args.query, query_type=args.query_type) items.sort() if args.info: diff --git a/moe/move/move_cli.py b/moe/move/move_cli.py index 32ec0070..52b29af7 100644 --- a/moe/move/move_cli.py +++ b/moe/move/move_cli.py @@ -5,6 +5,7 @@ from typing import cast import sqlalchemy.orm +from sqlalchemy.orm.session import Session import moe from moe import move as moe_move @@ -32,18 +33,19 @@ def add_command(cmd_parsers: argparse._SubParsersAction): move_parser.set_defaults(func=_parse_args) -def _parse_args(args: argparse.Namespace): +def _parse_args(session: Session, args: argparse.Namespace): """Parses the given commandline arguments. Items will be moved according to the given user configuration. Args: + session: Library db session. args: Commandline arguments to parse. Raises: SystemExit: Invalid query or no items found to move. """ - albums = cast(list[Album], cli_query("*", query_type="album")) + albums = cast(list[Album], cli_query(session, "*", query_type="album")) if args.dry_run: dry_run_str = _dry_run(albums) diff --git a/moe/query.py b/moe/query.py index 1e0e5251..2393556e 100644 --- a/moe/query.py +++ b/moe/query.py @@ -9,8 +9,8 @@ import sqlalchemy as sa import sqlalchemy.orm import sqlalchemy.sql.elements +from sqlalchemy.orm.session import Session -from moe.config import MoeSession from moe.library import Album, Extra, LibItem, Track from moe.library.lib_item import SetType @@ -66,10 +66,11 @@ class QueryError(Exception): VALUE = "value" -def query(query_str: str, query_type: str) -> list[LibItem]: +def query(session: Session, query_str: str, query_type: str) -> list[LibItem]: """Queries the database for items matching the given query string. Args: + session: Library db session. query_str: Query string to parse. See HELP_STR for more info. query_type: Type of library item to return: either 'album', 'extra', or 'track'. @@ -83,7 +84,6 @@ def query(query_str: str, query_type: str) -> list[LibItem]: `The query docs `_ """ log.debug(f"Querying library for items. [{query_str=!r}, {query_type=!r}]") - session = MoeSession() terms = shlex.split(query_str) if not terms: diff --git a/moe/read/read_cli.py b/moe/read/read_cli.py index 0e5b50fb..c2c938b2 100644 --- a/moe/read/read_cli.py +++ b/moe/read/read_cli.py @@ -3,6 +3,8 @@ import argparse import logging +from sqlalchemy.orm.session import Session + import moe import moe.cli from moe import read, remove @@ -31,18 +33,19 @@ def add_command(cmd_parsers: argparse._SubParsersAction): read_parser.set_defaults(func=_parse_args) -def _parse_args(args: argparse.Namespace): +def _parse_args(session: Session, args: argparse.Namespace): """Parses the given commandline arguments. Tracks can be added as files or albums as directories. Args: + session: Library db session. args: Commandline arguments to parse. Raises: SystemExit: Path given does not exist. """ - items = cli_query(args.query, args.query_type) + items = cli_query(session, args.query, args.query_type) error_count = 0 for item in items: @@ -50,7 +53,7 @@ def _parse_args(args: argparse.Namespace): read.read_item(item) except FileNotFoundError: if args.remove: - remove.remove_item(item) + remove.remove_item(session, item) else: log.error(f"Could not find item's path. [{item=!r}]") error_count += 1 diff --git a/moe/remove/rm_cli.py b/moe/remove/rm_cli.py index 3e9ac54f..080266f6 100644 --- a/moe/remove/rm_cli.py +++ b/moe/remove/rm_cli.py @@ -7,6 +7,8 @@ import argparse import logging +from sqlalchemy.orm.session import Session + import moe import moe.cli from moe import remove as moe_rm @@ -30,16 +32,17 @@ def add_command(cmd_parsers: argparse._SubParsersAction): rm_parser.set_defaults(func=_parse_args) -def _parse_args(args: argparse.Namespace): +def _parse_args(session: Session, args: argparse.Namespace): """Parses the given commandline arguments. Args: + session: Library db session. args: Commandline arguments to parse. Raises: SystemExit: Invalid query given, or no items to remove. """ - items = cli_query(args.query, query_type=args.query_type) + items = cli_query(session, args.query, query_type=args.query_type) for item in items: - moe_rm.remove_item(item) + moe_rm.remove_item(session, item) diff --git a/moe/remove/rm_core.py b/moe/remove/rm_core.py index b3c03080..2ed949b1 100644 --- a/moe/remove/rm_core.py +++ b/moe/remove/rm_core.py @@ -6,7 +6,6 @@ import sqlalchemy.exc from sqlalchemy.orm.session import Session -from moe.config import MoeSession from moe.library import Extra, LibItem, Track __all__ = ["remove_item"] @@ -14,10 +13,9 @@ log = logging.getLogger("moe.remove") -def remove_item(item: LibItem): +def remove_item(session: Session, item: LibItem): """Removes an item from the library.""" log.debug(f"Removing item from the library. [{item=!r}]") - session = MoeSession() insp = sqlalchemy.inspect(item) if insp.persistent: diff --git a/moe/util/cli/query.py b/moe/util/cli/query.py index 0373e60a..57112ac4 100644 --- a/moe/util/cli/query.py +++ b/moe/util/cli/query.py @@ -3,6 +3,8 @@ import argparse import logging +from sqlalchemy.orm.session import Session + from moe.library import LibItem from moe.query import QueryError, query @@ -58,10 +60,11 @@ query_parser.set_defaults(query_type="track") -def cli_query(query_str: str, query_type: str) -> list[LibItem]: +def cli_query(session: Session, query_str: str, query_type: str) -> list[LibItem]: """Wrapper around the core query call, with some added cli error handling. Args: + session: Library db session. query_str: Query string to parse. See HELP_STR for more info. query_type: Type of library item to return: either 'album', 'extra', or 'track'. @@ -72,7 +75,7 @@ def cli_query(query_str: str, query_type: str) -> list[LibItem]: SystemExit: QueryError or no items returned from the query. """ try: - items = query(query_str, query_type) + items = query(session, query_str, query_type) except QueryError as err: log.error(err) raise SystemExit(1) from err diff --git a/tests/add/test_add_cli.py b/tests/add/test_add_cli.py index f4e13cc2..7269b4c0 100644 --- a/tests/add/test_add_cli.py +++ b/tests/add/test_add_cli.py @@ -2,7 +2,7 @@ from types import FunctionType from typing import Iterator -from unittest.mock import patch +from unittest.mock import ANY, patch import pytest @@ -68,7 +68,7 @@ def test_add_choice(self): assert any(choice.shortcut_key == "s" for choice in prompt_choices) - def test_skip_item(self, tmp_config): + def test_skip_item(self, tmp_config, tmp_session): """We can skip adding items to the library.""" tmp_config( 'default_plugins = ["cli", "add", "import", "write"]', @@ -85,7 +85,7 @@ def test_skip_item(self, tmp_config): ): moe.cli.main(cli_args) - assert not config.MoeSession().query(Track).all() + assert not tmp_session.query(Track).all() @pytest.mark.usefixtures("_tmp_add_config") @@ -99,7 +99,7 @@ def test_track_file(self, mock_add): moe.cli.main(cli_args) - mock_add.assert_called_once_with(track) + mock_add.assert_called_once_with(ANY, track) def test_non_track_file(self, mock_add): """Raise SystemExit if bad track file given.""" @@ -129,8 +129,8 @@ def test_multiple_items(self, mock_add): moe.cli.main(cli_args) - mock_add.assert_any_call(track) - mock_add.assert_any_call(album) + mock_add.assert_any_call(ANY, track) + mock_add.assert_any_call(ANY, album) assert mock_add.call_count == 2 def test_single_error(self, tmp_path, mock_add): @@ -145,7 +145,7 @@ def test_single_error(self, tmp_path, mock_add): moe.cli.main(cli_args) assert error.value.code != 0 - mock_add.assert_called_once_with(track) + mock_add.assert_called_once_with(ANY, track) def test_extra_file(self, mock_add, mock_query): """Extra files are added as tracks.""" @@ -155,7 +155,7 @@ def test_extra_file(self, mock_add, mock_query): moe.cli.main(cli_args) - mock_add.assert_called_once_with(extra) + mock_add.assert_called_once_with(ANY, extra) def test_extra_no_album(self, mock_add): """Raise SystemExit if trying to add an extra but no query was given.""" diff --git a/tests/add/test_add_core.py b/tests/add/test_add_core.py index be4911f9..832182e6 100644 --- a/tests/add/test_add_core.py +++ b/tests/add/test_add_core.py @@ -34,7 +34,7 @@ class TestAddItem: def test_track(self, tmp_session): """We can add tracks to the library.""" track = track_factory() - moe.add.add_item(track) + moe.add.add_item(tmp_session, track) assert tmp_session.query(Track).one() == track @@ -42,7 +42,7 @@ def test_track(self, tmp_session): def test_album(self, tmp_session): """We can add albums to the library.""" album = album_factory() - moe.add.add_item(album) + moe.add.add_item(tmp_session, album) assert tmp_session.query(Album).one() == album @@ -50,7 +50,7 @@ def test_album(self, tmp_session): def test_extra(self, tmp_session): """We can add extras to the library.""" extra = extra_factory() - moe.add.add_item(extra) + moe.add.add_item(tmp_session, extra) assert tmp_session.query(Extra).one() == extra @@ -62,7 +62,7 @@ def test_hooks(self, tmp_config, tmp_session): tmp_db=True, ) - moe.add.add_item(track_factory()) + moe.add.add_item(tmp_session, track_factory()) db_track = tmp_session.query(Track).one() assert db_track.title == "pre_add" @@ -76,7 +76,7 @@ def test_duplicate_list_fields_album(self, tmp_session): track1.genre = "pop" track2.genre = "pop" - moe.add.add_item(album) + moe.add.add_item(tmp_session, album) db_tracks = tmp_session.query(Track).all() for track in db_tracks: @@ -88,8 +88,8 @@ def test_duplicate_list_field_tracks(self, tmp_session): track1 = track_factory(genres={"pop"}) track2 = track_factory(genres={"pop"}) - moe.add.add_item(track1) - moe.add.add_item(track2) + moe.add.add_item(tmp_session, track1) + moe.add.add_item(tmp_session, track2) db_tracks = tmp_session.query(Track).all() for track in db_tracks: diff --git a/tests/conftest.py b/tests/conftest.py index 7097b0db..96811787 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -15,7 +15,7 @@ import moe.write from moe import config -from moe.config import Config, ExtraPlugin, MoeSession, session_factory +from moe.config import Config, ExtraPlugin, moe_sessionmaker from moe.library import Album, Extra, Track __all__ = ["album_factory", "extra_factory", "track_factory"] @@ -72,9 +72,7 @@ def tmp_config( Args: settings: Settings string to use. This has the same format as a normal ``config.toml`` file. - init_db: Whether or not to initialize the database. - tmp_db: Whether or not to use a temporary (in-memory) database. If ``True``, - the database will be initialized regardless of ``init_db``. + tmp_db: Whether or not to use a temporary (in-memory) database. extra_plugins: Any additional plugins to enable. config_dir: Optionally specifiy a config directory to use. @@ -114,7 +112,7 @@ def _tmp_config( ) yield _tmp_config - session_factory.configure(bind=None) # reset the database in between tests + moe_sessionmaker.configure(bind=None) # reset the database in between tests @pytest.fixture @@ -128,22 +126,17 @@ def tmp_session(tmp_config) -> Iterator[sqlalchemy.orm.session.Session]: The temporary session. """ try: - MoeSession().get_bind() + moe_sessionmaker().get_bind() except sqlalchemy.exc.UnboundExecutionError: - MoeSession.remove() tmp_config("default_plugins = []", tmp_db=True) - session = MoeSession() - with session.begin(): + with moe_sessionmaker.begin() as session: yield session - MoeSession.remove() - @pytest.fixture(autouse=True) def _clean_moe(): - """Ensure we aren't sharing sessions or configs between tests.""" - MoeSession.remove() + """Ensure we aren't sharing configs between tests.""" config.CONFIG = MagicMock() diff --git a/tests/duplicate/test_dup_cli.py b/tests/duplicate/test_dup_cli.py index f3397822..46d0a690 100644 --- a/tests/duplicate/test_dup_cli.py +++ b/tests/duplicate/test_dup_cli.py @@ -1,7 +1,7 @@ """Test the duplicate plugin cli.""" import datetime -from unittest.mock import patch +from unittest.mock import MagicMock, patch import pytest @@ -27,13 +27,18 @@ def test_choices_called(self): track_a = track_factory() track_b = track_factory() + mock_session = MagicMock() with patch( "moe.duplicate.dup_cli.choice_prompt", autospec=True, ) as mock_prompt_choice: - config.CONFIG.pm.hook.resolve_dup_items(item_a=track_a, item_b=track_b) + config.CONFIG.pm.hook.resolve_dup_items( + session=mock_session, item_a=track_a, item_b=track_b + ) - mock_prompt_choice.return_value.func.assert_called_once_with(track_a, track_b) + mock_prompt_choice.return_value.func.assert_called_once_with( + mock_session, track_a, track_b + ) def test_keep_a(self, tmp_session): """When keeping item a.""" @@ -42,7 +47,7 @@ def test_keep_a(self, tmp_session): tmp_session.add_all([track_a, track_b]) tmp_session.flush() - dup_cli._keep_a(track_a, track_b) + dup_cli._keep_a(tmp_session, track_a, track_b) db_track = tmp_session.query(Track).one() assert db_track.title == "a" @@ -54,7 +59,7 @@ def test_keep_b(self, tmp_session): tmp_session.add_all([track_a, track_b]) tmp_session.flush() - dup_cli._keep_b(track_a, track_b) + dup_cli._keep_b(tmp_session, track_a, track_b) db_track = tmp_session.query(Track).one() assert db_track.title == "b" @@ -67,7 +72,7 @@ def test_merge(self, tmp_session): tmp_session.add_all([track_a, track_b]) tmp_session.flush() - dup_cli._merge(track_a, track_b) + dup_cli._merge(tmp_session, track_a, track_b) db_track = tmp_session.query(Track).one() assert db_track.title == "b" @@ -81,7 +86,7 @@ def test_overwrite(self, tmp_session): tmp_session.add_all([track_a, track_b]) tmp_session.flush() - dup_cli._overwrite(track_a, track_b) + dup_cli._overwrite(tmp_session, track_a, track_b) db_track = tmp_session.query(Track).one() assert db_track.title == "a" diff --git a/tests/duplicate/test_dup_core.py b/tests/duplicate/test_dup_core.py index 25f9b68b..e19ef019 100644 --- a/tests/duplicate/test_dup_core.py +++ b/tests/duplicate/test_dup_core.py @@ -6,7 +6,7 @@ import moe from moe import remove -from moe.config import ExtraPlugin, MoeSession +from moe.config import ExtraPlugin from moe.library import Album, Extra, Track from tests.conftest import album_factory, extra_factory, track_factory @@ -16,13 +16,13 @@ class DuplicatePlugin: @staticmethod @moe.hookimpl - def resolve_dup_items(item_a, item_b): + def resolve_dup_items(session, item_a, item_b): """Resolve duplicates.""" if isinstance(item_a, (Track, Album)): if item_a.title == "remove me": - remove.remove_item(item_a) + remove.remove_item(session, item_a) if item_b.title == "remove me": - remove.remove_item(item_b) + remove.remove_item(session, item_b) if item_a.title == "change me": dest = item_a.path.parent / "new.mp3" shutil.copyfile(item_a.path, dest) @@ -45,103 +45,96 @@ def _tmp_dup_config(tmp_config): class TestResolveDupItems: """Test ``resolve_dup_items()``.""" - def test_remove_a(self): + def test_remove_a(self, tmp_session): """Remove a track.""" track_a = track_factory(exists=True, title="remove me") track_b = track_factory(exists=True, path=track_a.path) - session = MoeSession() - session.add(track_a) - session.add(track_b) - session.flush() + tmp_session.add(track_a) + tmp_session.add(track_b) + tmp_session.flush() - db_track = session.query(Track).one() + db_track = tmp_session.query(Track).one() assert db_track == track_b - def test_remove_b(self): + def test_remove_b(self, tmp_session): """Remove b track.""" track_a = track_factory(exists=True) track_b = track_factory(exists=True, path=track_a.path, title="remove me") - session = MoeSession() - session.add(track_a) - session.add(track_b) - session.flush() + tmp_session.add(track_a) + tmp_session.add(track_b) + tmp_session.flush() - db_track = session.query(Track).one() + db_track = tmp_session.query(Track).one() assert db_track == track_a - def test_rm_existing_track(self): + def test_rm_existing_track(self, tmp_session): """Remove b track.""" track_a = track_factory(exists=True, title="remove me") track_b = track_factory(exists=True, path=track_a.path) - session = MoeSession() - session.add(track_a) - session.flush() - session.add(track_b) - session.flush() + tmp_session.add(track_a) + tmp_session.flush() + tmp_session.add(track_b) + tmp_session.flush() - db_track = session.query(Track).one() + db_track = tmp_session.query(Track).one() assert db_track == track_b - def test_changing_fields(self): + def test_changing_fields(self, tmp_session): """Duplicates can be avoided by changing conflicting fields.""" track_a = track_factory(exists=True, title="change me") track_b = track_factory(exists=True, path=track_a.path) - session = MoeSession() - session.add(track_a) - session.add(track_b) - session.flush() + tmp_session.add(track_a) + tmp_session.add(track_b) + tmp_session.flush() - db_tracks = session.query(Track).all() + db_tracks = tmp_session.query(Track).all() assert track_a in db_tracks assert track_b in db_tracks - def test_change_extra(self): + def test_change_extra(self, tmp_session): """Duplicate extras can be avoided.""" extra_a = extra_factory(exists=True) extra_b = extra_factory(exists=True, path=extra_a.path) - session = MoeSession() - session.add(extra_a) - session.add(extra_b) - session.flush() + tmp_session.add(extra_a) + tmp_session.add(extra_b) + tmp_session.flush() - db_extras = session.query(Extra).all() + db_extras = tmp_session.query(Extra).all() assert extra_a in db_extras assert extra_b in db_extras - def test_remove_album(self): + def test_remove_album(self, tmp_session): """Remove an album.""" album_a = album_factory(exists=True, title="remove me") album_b = album_factory(exists=True, path=album_a.path) - session = MoeSession() - session.add(album_a) - session.add(album_b) - session.flush() + tmp_session.add(album_a) + tmp_session.add(album_b) + tmp_session.flush() - db_album = session.query(Album).one() + db_album = tmp_session.query(Album).one() assert db_album == album_b - def test_album_first(self): + def test_album_first(self, tmp_session): """Albums should be processed first as they may resolve tracks or extras too.""" album_a = album_factory(exists=True, title="remove me") album_b = album_factory(exists=True, path=album_a.path) album_b.tracks[0].path = album_a.tracks[0].path album_a.tracks[0].title = "change me" # shouldn't get changed as - session = MoeSession() - session.add(album_a.tracks[0]) - session.add(album_b.tracks[0]) - session.add(album_a) - session.add(album_b) - session.flush() + tmp_session.add(album_a.tracks[0]) + tmp_session.add(album_b.tracks[0]) + tmp_session.add(album_a) + tmp_session.add(album_b) + tmp_session.flush() - db_album = session.query(Album).one() + db_album = tmp_session.query(Album).one() assert db_album == album_b - db_tracks = session.query(Track).all() + db_tracks = tmp_session.query(Track).all() for track in db_tracks: assert track.title != "changed" diff --git a/tests/edit/test_edit_cli.py b/tests/edit/test_edit_cli.py index 23de80d5..bcd7af1f 100644 --- a/tests/edit/test_edit_cli.py +++ b/tests/edit/test_edit_cli.py @@ -2,7 +2,7 @@ from types import FunctionType from typing import Iterator -from unittest.mock import patch +from unittest.mock import ANY, patch import pytest @@ -50,7 +50,7 @@ def test_track(self, mock_query, mock_edit): moe.cli.main(cli_args) - mock_query.assert_called_once_with("*", query_type="track") + mock_query.assert_called_once_with(ANY, "*", query_type="track") mock_edit.assert_called_once_with(track, "track_num", "3") def test_album(self, mock_query, mock_edit): @@ -61,7 +61,7 @@ def test_album(self, mock_query, mock_edit): moe.cli.main(cli_args) - mock_query.assert_called_once_with("*", query_type="album") + mock_query.assert_called_once_with(ANY, "*", query_type="album") mock_edit.assert_called_once_with(album, "title", "edit") def test_extra(self, mock_query, mock_edit): @@ -72,7 +72,7 @@ def test_extra(self, mock_query, mock_edit): moe.cli.main(cli_args) - mock_query.assert_called_once_with("*", query_type="extra") + mock_query.assert_called_once_with(ANY, "*", query_type="extra") mock_edit.assert_called_once_with(extra, "title", "edit") def test_multiple_items(self, mock_query, mock_edit): diff --git a/tests/library/test_lib_item.py b/tests/library/test_lib_item.py index 005c6821..c5a0e15f 100644 --- a/tests/library/test_lib_item.py +++ b/tests/library/test_lib_item.py @@ -1,7 +1,7 @@ """Test shared library functionality.""" import moe -from moe.config import ExtraPlugin, MoeSession +from moe.config import ExtraPlugin, moe_sessionmaker from moe.library import Album, Extra, Track from tests.conftest import album_factory, extra_factory, track_factory @@ -11,35 +11,35 @@ class LibItemPlugin: @staticmethod @moe.hookimpl - def edit_changed_items(items): + def edit_changed_items(session, items): """Edit changed items.""" for item in items: item.custom["changed"] = "edited" @staticmethod @moe.hookimpl - def edit_new_items(items): + def edit_new_items(session, items): """Edit new items.""" for item in items: item.custom["new"] = "edited" @staticmethod @moe.hookimpl - def process_changed_items(items): + def process_changed_items(session, items): """Process changed items.""" for item in items: item.custom["changed"] = "processed" @staticmethod @moe.hookimpl - def process_new_items(items): + def process_new_items(session, items): """Process new items.""" for item in items: item.custom["new"] = "processed" @staticmethod @moe.hookimpl - def process_removed_items(items): + def process_removed_items(session, items): """Process removed items.""" for item in items: item.custom["removed"] = "processed" @@ -59,7 +59,7 @@ def test_edit_changed_items(self, tmp_config): extra = extra_factory() track = track_factory() - session = MoeSession() + session = moe_sessionmaker() session.add(album) session.add(extra) session.add(track) @@ -85,7 +85,7 @@ def test_edit_new_items(self, tmp_config): extra = extra_factory() track = track_factory() - session = MoeSession() + session = moe_sessionmaker() session.add(album) session.add(extra) session.add(track) @@ -106,7 +106,7 @@ def test_process_changed_items(self, tmp_config): extra = extra_factory() track = track_factory() - session = MoeSession() + session = moe_sessionmaker() session.add(album) session.add(extra) session.add(track) @@ -137,7 +137,7 @@ def test_process_new_items(self, tmp_config): extra = extra_factory() track = track_factory() - session = MoeSession() + session = moe_sessionmaker() session.add(album) session.add(extra) session.add(track) @@ -164,7 +164,7 @@ def test_process_removed_items(self, tmp_config): extra = extra_factory(album=album) track = track_factory(album=album) - session = MoeSession() + session = moe_sessionmaker() session.add(album) session.add(extra) session.add(track) @@ -206,7 +206,7 @@ def test_db_changes(self, tmp_config): db="persists", my_list=["wow", "change me"], growing_list=["one"] ) - session = MoeSession() + session = moe_sessionmaker() session.add(track) session.commit() track.custom["db"] = "persisted" diff --git a/tests/move/test_move_cli.py b/tests/move/test_move_cli.py index c46c2e32..003514a9 100644 --- a/tests/move/test_move_cli.py +++ b/tests/move/test_move_cli.py @@ -2,7 +2,7 @@ from types import FunctionType from typing import Iterator -from unittest.mock import patch +from unittest.mock import ANY, patch import pytest @@ -50,7 +50,7 @@ def test_dry_run(self, tmp_path, mock_query, mock_move): moe.cli.main(cli_args) mock_move.assert_not_called() - mock_query.assert_called_once_with("*", "album") + mock_query.assert_called_once_with(ANY, "*", "album") def test_move(self, mock_query, mock_move): """Test all items in the library are moved when the command is invoked.""" @@ -63,7 +63,7 @@ def test_move(self, mock_query, mock_move): for album in albums: mock_move.assert_any_call(album) assert mock_move.call_count == len(albums) - mock_query.assert_called_once_with("*", "album") + mock_query.assert_called_once_with(ANY, "*", "album") class TestPluginRegistration: diff --git a/tests/move/test_move_core.py b/tests/move/test_move_core.py index 14d28b1c..23ec9637 100644 --- a/tests/move/test_move_core.py +++ b/tests/move/test_move_core.py @@ -1,6 +1,6 @@ """Tests the core api for moving items.""" from pathlib import Path -from unittest.mock import patch +from unittest.mock import MagicMock, patch import pytest @@ -476,21 +476,27 @@ class TestEditNewItems: def test_album(self, mock_copy): """Albums are copied after they are added to the library.""" album = album_factory() - config.CONFIG.pm.hook.edit_new_items(items=[album]) + mock_session = MagicMock() + + config.CONFIG.pm.hook.edit_new_items(session=mock_session, items=[album]) mock_copy.assert_called_once_with(album) def test_track(self, mock_copy): """Tracks are copied after they are added to the library.""" track = track_factory() - config.CONFIG.pm.hook.edit_new_items(items=[track]) + mock_session = MagicMock() + + config.CONFIG.pm.hook.edit_new_items(session=mock_session, items=[track]) mock_copy.assert_called_once_with(track) def test_extra(self, mock_copy): """Extras are copied after they are added to the library.""" extra = extra_factory() - config.CONFIG.pm.hook.edit_new_items(items=[extra]) + mock_session = MagicMock() + + config.CONFIG.pm.hook.edit_new_items(session=mock_session, items=[extra]) mock_copy.assert_called_once_with(extra) diff --git a/tests/read/test_read_cli.py b/tests/read/test_read_cli.py index b23e2dae..cba3e9a0 100644 --- a/tests/read/test_read_cli.py +++ b/tests/read/test_read_cli.py @@ -2,7 +2,7 @@ from types import FunctionType from typing import Iterator -from unittest.mock import patch +from unittest.mock import ANY, patch import pytest @@ -48,7 +48,7 @@ def test_track(self, mock_query, mock_read): moe.cli.main(cli_args) - mock_query.assert_called_once_with("*", query_type="track") + mock_query.assert_called_once_with(ANY, "*", query_type="track") mock_read.assert_called_once_with(track) def test_album(self, mock_query, mock_read): @@ -59,7 +59,7 @@ def test_album(self, mock_query, mock_read): moe.cli.main(cli_args) - mock_query.assert_called_once_with("*", query_type="album") + mock_query.assert_called_once_with(ANY, "*", query_type="album") mock_read.assert_called_once_with(album) def test_multiple_items(self, mock_query, mock_read): @@ -70,7 +70,7 @@ def test_multiple_items(self, mock_query, mock_read): moe.cli.main(cli_args) - mock_query.assert_called_once_with("*", query_type="track") + mock_query.assert_called_once_with(ANY, "*", query_type="track") for track in tracks: mock_read.assert_any_call(track) assert mock_read.call_count == 2 @@ -108,7 +108,7 @@ def test_rm_item(self, mock_query, mock_read): with patch("moe.read.read_cli.remove.remove_item") as mock_rm: moe.cli.main(cli_args) - mock_rm.assert_called_once_with(track) + mock_rm.assert_called_once_with(ANY, track) class TestPluginRegistration: diff --git a/tests/remove/test_rm_cli.py b/tests/remove/test_rm_cli.py index 7f7d3717..44dbda96 100644 --- a/tests/remove/test_rm_cli.py +++ b/tests/remove/test_rm_cli.py @@ -2,7 +2,7 @@ from types import FunctionType from typing import Iterator -from unittest.mock import patch +from unittest.mock import ANY, patch import pytest @@ -48,8 +48,8 @@ def test_track(self, mock_query, mock_rm): moe.cli.main(cli_args) - mock_query.assert_called_once_with("*", query_type="track") - mock_rm.assert_called_once_with(track) + mock_query.assert_called_once_with(ANY, "*", query_type="track") + mock_rm.assert_called_once_with(ANY, track) def test_album(self, mock_query, mock_rm): """Albums are removed from the database with valid query.""" @@ -59,8 +59,8 @@ def test_album(self, mock_query, mock_rm): moe.cli.main(cli_args) - mock_query.assert_called_once_with("*", query_type="album") - mock_rm.assert_called_once_with(album) + mock_query.assert_called_once_with(ANY, "*", query_type="album") + mock_rm.assert_called_once_with(ANY, album) def test_extra(self, mock_query, mock_rm): """Extras are removed from the database with valid query.""" @@ -70,8 +70,8 @@ def test_extra(self, mock_query, mock_rm): moe.cli.main(cli_args) - mock_query.assert_called_once_with("*", query_type="extra") - mock_rm.assert_called_once_with(extra) + mock_query.assert_called_once_with(ANY, "*", query_type="extra") + mock_rm.assert_called_once_with(ANY, extra) def test_multiple_items(self, mock_query, mock_rm): """All items returned from the query are removed.""" @@ -82,7 +82,7 @@ def test_multiple_items(self, mock_query, mock_rm): moe.cli.main(cli_args) for track in tracks: - mock_rm.assert_any_call(track) + mock_rm.assert_any_call(ANY, track) assert mock_rm.call_count == 2 diff --git a/tests/remove/test_rm_core.py b/tests/remove/test_rm_core.py index e03f7a0a..d9f873d9 100644 --- a/tests/remove/test_rm_core.py +++ b/tests/remove/test_rm_core.py @@ -7,7 +7,7 @@ import moe from moe import remove as moe_rm -from moe.config import ExtraPlugin, MoeSession +from moe.config import ExtraPlugin from moe.library import Album, Extra, Track from tests.conftest import album_factory, extra_factory, track_factory @@ -22,7 +22,7 @@ def rm_track_before_flush(session, flush_context, instances): """Remove a track while the session is already flushing.""" for item in session.new | session.dirty: if isinstance(item, Track) and item.title == "remove me": - moe_rm.remove_item(item) + moe_rm.remove_item(session, item) class RmPlugin: @@ -30,10 +30,10 @@ class RmPlugin: @staticmethod @moe.hookimpl - def register_sa_event_listeners(session): + def register_sa_event_listeners(): """Registers event listeners for editing and processing new items.""" sqlalchemy.event.listen( - session, + sqlalchemy.orm.session.Session, "before_flush", rm_track_before_flush, ) @@ -49,7 +49,7 @@ def test_track(self, tmp_session): tmp_session.add(track) tmp_session.flush() - moe_rm.remove_item(track) + moe_rm.remove_item(tmp_session, track) assert not tmp_session.query(Track).scalar() @@ -63,7 +63,7 @@ def test_album(self, tmp_session): tmp_session.add(album) tmp_session.flush() - moe_rm.remove_item(album) + moe_rm.remove_item(tmp_session, album) assert not tmp_session.query(Album).scalar() assert not tmp_session.query(Track).scalar() @@ -76,23 +76,22 @@ def test_extra(self, tmp_session): tmp_session.add(extra) tmp_session.flush() - moe_rm.remove_item(extra) + moe_rm.remove_item(tmp_session, extra) assert not tmp_session.query(Extra).scalar() @pytest.mark.usefixtures("_tmp_rm_config") - def test_pending(self): + def test_pending(self, tmp_session): """We can remove items that have not yet been flushed.""" track = track_factory() - session = MoeSession() - session.add(track) + tmp_session.add(track) - moe_rm.remove_item(track) - session.flush() + moe_rm.remove_item(tmp_session, track) + tmp_session.flush() - assert not session.query(Track).all() + assert not tmp_session.query(Track).all() - def test_in_flush(self, tmp_config): + def test_in_flush(self, tmp_config, tmp_session): """If the session is already flushing, ensure the delete happens first. This is to prevent potential duplciates from inserting into the database before @@ -103,36 +102,34 @@ def test_in_flush(self, tmp_config): extra_plugins=[ExtraPlugin(RmPlugin, "rm_test")], tmp_db=True, ) - session = MoeSession() track = track_factory() conflict_track = track_factory(path=track.path, title="remove me") - session.add(track) - session.flush() - session.add(conflict_track) - session.flush() + tmp_session.add(track) + tmp_session.flush() + tmp_session.add(conflict_track) + tmp_session.flush() - db_track = session.query(Track).one() + db_track = tmp_session.query(Track).one() assert db_track == track - def test_in_flush_rm_existing(self, tmp_config): + def test_in_flush_rm_existing(self, tmp_config, tmp_session): """Remove an already existing item while a session is flushing.""" tmp_config( "default_plugins = ['remove']", extra_plugins=[ExtraPlugin(RmPlugin, "rm_test")], tmp_db=True, ) - session = MoeSession() track = track_factory() conflict_track = track_factory(path=track.path) - session.add(track) - session.flush() + tmp_session.add(track) + tmp_session.flush() track.title = "remove me" - session.add(conflict_track) - session.flush() + tmp_session.add(conflict_track) + tmp_session.flush() - db_track = session.query(Track).one() + db_track = tmp_session.query(Track).one() assert db_track == conflict_track diff --git a/tests/test_cli.py b/tests/test_cli.py index 388dfe7b..d15a8a65 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -7,6 +7,7 @@ import moe import moe.cli from moe import config +from moe.config import moe_sessionmaker from moe.library import Track from tests.conftest import track_factory @@ -31,8 +32,7 @@ def test_commit_on_systemexit(tmp_config): assert error.value.code != 0 - session = config.MoeSession() - with session.begin(): + with moe_sessionmaker.begin() as session: session.query(Track).one() @@ -42,8 +42,7 @@ def test_default_config(tmp_config): tmp_config(settings="default_plugins = ['cli', 'write', 'list']", init_db=True) track = track_factory(exists=True) - session = config.MoeSession() - with session.begin(): + with moe_sessionmaker.begin() as session: session.add(track) moe.cli.main(cli_args) @@ -69,7 +68,7 @@ class CLIPlugin: def add_command(cmd_parsers): """Add a `cli` command to Moe.""" - def say_hello(args): + def say_hello(session, args): print("hello") cli_parser = cmd_parsers.add_parser("cli") diff --git a/tests/test_list.py b/tests/test_list.py index e2999b33..43e3292c 100644 --- a/tests/test_list.py +++ b/tests/test_list.py @@ -2,7 +2,7 @@ from types import FunctionType from typing import Iterator -from unittest.mock import patch +from unittest.mock import ANY, patch import pytest @@ -41,7 +41,7 @@ def test_track(self, capsys, mock_query): moe.cli.main(cli_args) - mock_query.assert_called_once_with("*", query_type="track") + mock_query.assert_called_once_with(ANY, "*", query_type="track") assert capsys.readouterr().out.strip("\n") == str(track) def test_album(self, capsys, mock_query): @@ -52,7 +52,7 @@ def test_album(self, capsys, mock_query): moe.cli.main(cli_args) - mock_query.assert_called_once_with("*", query_type="album") + mock_query.assert_called_once_with(ANY, "*", query_type="album") assert capsys.readouterr().out.strip("\n") == str(album) def test_extra(self, capsys, mock_query): @@ -63,7 +63,7 @@ def test_extra(self, capsys, mock_query): moe.cli.main(cli_args) - mock_query.assert_called_once_with("*", query_type="extra") + mock_query.assert_called_once_with(ANY, "*", query_type="extra") assert capsys.readouterr().out.strip("\n") == str(extra) def test_multiple_items(self, capsys, mock_query): @@ -85,7 +85,7 @@ def test_paths(self, capsys, mock_query): moe.cli.main(cli_args) - mock_query.assert_called_once_with("*", query_type="track") + mock_query.assert_called_once_with(ANY, "*", query_type="track") assert capsys.readouterr().out.strip("\n") == str(track.path) def test_info_album(self, mock_query): diff --git a/tests/test_query.py b/tests/test_query.py index 4b761986..b5ff49b5 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -2,6 +2,7 @@ from datetime import date +from unittest.mock import MagicMock import pytest @@ -91,16 +92,16 @@ class TestQueries: def test_empty_query_str(self): """Empty queries strings should raise a QueryError.""" with pytest.raises(QueryError): - query("", "track") + query(MagicMock(), "", "track") def test_empty_query(self, tmp_session): """Empty queries should return an empty list.""" - assert not query("title:nope", "track") + assert not query(tmp_session, "title:nope", "track") def test_invalid_query_str(self, tmp_session): """Invalid queries should raise a QueryError.""" with pytest.raises(QueryError): - query("invalid", "track") + query(tmp_session, "invalid", "track") def test_return_type(self, tmp_session): """Queries return the appropriate type.""" @@ -108,9 +109,9 @@ def test_return_type(self, tmp_session): tmp_session.add(album) tmp_session.flush() - albums = query(f"a:title:'{album.title}'", "album") - extras = query(f"a:title:'{album.title}'", "extra") - tracks = query(f"a:title:'{album.title}'", "track") + albums = query(tmp_session, f"a:title:'{album.title}'", "album") + extras = query(tmp_session, f"a:title:'{album.title}'", "extra") + tracks = query(tmp_session, f"a:title:'{album.title}'", "track") assert albums for album in albums: @@ -128,8 +129,8 @@ def test_years(self, tmp_session): tmp_session.add(album) tmp_session.flush() - assert query(f"a:year:{album.year}", "album") - assert query(f"a:original_year:{album.original_year}", "album") + assert query(tmp_session, f"a:year:{album.year}", "album") + assert query(tmp_session, f"a:original_year:{album.original_year}", "album") def test_multiple_terms(self, tmp_session): """We should be able to query for multiple terms at once.""" @@ -137,14 +138,14 @@ def test_multiple_terms(self, tmp_session): tmp_session.add(album) tmp_session.flush() - assert query(f"a:year:{album.year} a:title:{album.title}", "album") + assert query(tmp_session, f"a:year:{album.year} a:title:{album.title}", "album") def test_regex(self, tmp_session): """Queries can use regular expression matching.""" tmp_session.add(track_factory()) tmp_session.flush() - assert query("title::.*", "track") + assert query(tmp_session, "title::.*", "track") def test_path_query(self, tmp_config, tmp_session): """We can query for paths.""" @@ -153,22 +154,22 @@ def test_path_query(self, tmp_config, tmp_session): tmp_session.add(album) tmp_session.flush() - assert query(f"'a:path:{str(album.path.resolve())}'", "album") - assert query("'a:path::.*'", "album") + assert query(tmp_session, f"'a:path:{str(album.path.resolve())}'", "album") + assert query(tmp_session, "'a:path::.*'", "album") def test_case_insensitive_value(self, tmp_session): """Query values should be case-insensitive.""" tmp_session.add(album_factory(title="TMP")) tmp_session.flush() - assert query("a:title:tmp", "album") + assert query(tmp_session, "a:title:tmp", "album") def test_regex_non_str(self, tmp_session): """Non string fields should be converted to strings for matching.""" tmp_session.add(album_factory()) tmp_session.flush() - assert query("a:year::.*", "album") + assert query(tmp_session, "a:year::.*", "album") def test_invalid_regex(self, tmp_session): """Invalid regex queries should raise a QueryError.""" @@ -176,22 +177,22 @@ def test_invalid_regex(self, tmp_session): tmp_session.flush() with pytest.raises(QueryError): - query("title::[", "album") + query(tmp_session, "title::[", "album") def test_regex_case_insensitive(self, tmp_session): """Regex queries should be case-insensitive.""" tmp_session.add(album_factory(title="TMP")) tmp_session.flush() - assert query("a:title::tmp", "album") + assert query(tmp_session, "a:title::tmp", "album") def test_like_query(self, tmp_session): """Test sql LIKE queries. '%' and '_' are wildcard characters.""" tmp_session.add(track_factory(track_num=1)) tmp_session.flush() - assert query("track_num:_", "track") - assert query("track_num:%", "track") + assert query(tmp_session, "track_num:_", "track") + assert query(tmp_session, "track_num:%", "track") def test_like_escape_query(self, tmp_session): r"""We should be able to escape the LIKE wildcard characters with '/'. @@ -203,29 +204,29 @@ def test_like_escape_query(self, tmp_session): tmp_session.add(album_factory(title="_")) tmp_session.flush() - assert len(query("a:title:/_", "album")) == 1 + assert len(query(tmp_session, "a:title:/_", "album")) == 1 def test_track_genre_query(self, tmp_session): """Querying 'genre' should use the 'genres' field.""" tmp_session.add(track_factory(genres={"hip hop", "rock"})) tmp_session.flush() - assert query("'genre::.*'", "track") - assert query("'genre:hip hop'", "track") + assert query(tmp_session, "'genre::.*'", "track") + assert query(tmp_session, "'genre:hip hop'", "track") def test_album_catalog_num_query(self, tmp_session): """Querying 'catalog_num' should use the 'catalog_nums' field.""" tmp_session.add(album_factory(catalog_nums={"1", "2"})) tmp_session.flush() - assert query("a:catalog_num:1 a:catalog_num:2", "album") + assert query(tmp_session, "a:catalog_num:1 a:catalog_num:2", "album") def test_wildcard_query(self, tmp_session): """'*' as a query should return all items.""" tmp_session.add(album_factory()) tmp_session.flush() - assert len(query("*", "album")) == 1 + assert len(query(tmp_session, "*", "album")) == 1 def test_missing_extras_tracks(self, tmp_session): """Ensure albums without extras or tracks.""" @@ -234,7 +235,7 @@ def test_missing_extras_tracks(self, tmp_session): tmp_session.add(album) tmp_session.flush() - assert len(query("*", "album")) == 1 + assert len(query(tmp_session, "*", "album")) == 1 def test_custom_fields(self, tmp_session): """We can query a custom field.""" @@ -245,7 +246,7 @@ def test_custom_fields(self, tmp_session): tmp_session.add(album) tmp_session.flush() - assert query("a:blah:album t:blah:track e:blah:extra", "album") + assert query(tmp_session, "a:blah:album t:blah:track e:blah:extra", "album") def test_custom_field_regex(self, tmp_session): """We can regex query a custom field.""" @@ -256,7 +257,7 @@ def test_custom_field_regex(self, tmp_session): tmp_session.add(album) tmp_session.flush() - assert query("a:blah::albu. t:blah::trac. e:blah::3", "album") + assert query(tmp_session, "a:blah::albu. t:blah::trac. e:blah::3", "album") def test_custom_list_field(self, tmp_session): """We can query custom list fields.""" @@ -267,6 +268,6 @@ def test_custom_list_field(self, tmp_session): tmp_session.add(album) tmp_session.flush() - assert query("a:blah:album t:blah:track e:blah:extra", "album") - assert query("a:blah:1 e:blah:2 t:blah:3", "album") - assert query("t:blah:3 t:blah:track", "album") + assert query(tmp_session, "a:blah:album t:blah:track e:blah:extra", "album") + assert query(tmp_session, "a:blah:1 e:blah:2 t:blah:3", "album") + assert query(tmp_session, "t:blah:3 t:blah:track", "album") diff --git a/tests/test_write.py b/tests/test_write.py index 896c7188..96b6bade 100644 --- a/tests/test_write.py +++ b/tests/test_write.py @@ -1,7 +1,7 @@ """Tests the ``write`` plugin.""" import datetime -from unittest.mock import patch +from unittest.mock import MagicMock, patch import mediafile import pytest @@ -124,27 +124,38 @@ class TestProcessNewItems: def test_process_track(self, mock_write): """Any altered Tracks have their tags written.""" track = track_factory() - config.CONFIG.pm.hook.process_new_items(items=[track]) + mock_session = MagicMock() + + config.CONFIG.pm.hook.process_new_items(session=mock_session, items=[track]) mock_write.assert_called_once_with(track) def test_process_extra(self, mock_write): """Any altered extras are ignored.""" - config.CONFIG.pm.hook.process_new_items(items=[extra_factory()]) + mock_session = MagicMock() + + config.CONFIG.pm.hook.process_new_items( + session=mock_session, items=[extra_factory()] + ) mock_write.assert_not_called() def test_process_album(self, mock_write): """Any altered albums are ignored.""" - config.CONFIG.pm.hook.process_new_items(items=[album_factory()]) + mock_session = MagicMock() + + config.CONFIG.pm.hook.process_new_items( + session=mock_session, items=[album_factory()] + ) mock_write.assert_not_called() def test_process_multiple_tracks(self, mock_write): """All altered tracks are written.""" tracks = [track_factory(), track_factory()] + mock_session = MagicMock() - config.CONFIG.pm.hook.process_new_items(items=tracks) + config.CONFIG.pm.hook.process_new_items(session=mock_session, items=tracks) for track in tracks: mock_write.assert_any_call(track) @@ -158,20 +169,28 @@ class TestProcessChangedItems: def test_process_track(self, mock_write): """Any altered Tracks have their tags written.""" track = track_factory() - config.CONFIG.pm.hook.process_changed_items(items=[track]) + mock_session = MagicMock() + + config.CONFIG.pm.hook.process_changed_items(session=mock_session, items=[track]) mock_write.assert_called_once_with(track) def test_process_extra(self, mock_write): """Any altered extras are ignored.""" - config.CONFIG.pm.hook.process_changed_items(items=[extra_factory()]) + mock_session = MagicMock() + + config.CONFIG.pm.hook.process_changed_items( + session=mock_session, items=[extra_factory()] + ) mock_write.assert_not_called() def test_process_album(self, mock_write): """Any altered albums should have their tracks written.""" album = album_factory() - config.CONFIG.pm.hook.process_changed_items(items=[album]) + mock_session = MagicMock() + + config.CONFIG.pm.hook.process_changed_items(session=mock_session, items=[album]) for track in album.tracks: mock_write.assert_any_call(track) @@ -179,8 +198,9 @@ def test_process_album(self, mock_write): def test_process_multiple_tracks(self, mock_write): """All altered tracks are written.""" tracks = [track_factory(), track_factory()] + mock_session = MagicMock() - config.CONFIG.pm.hook.process_changed_items(items=tracks) + config.CONFIG.pm.hook.process_changed_items(session=mock_session, items=tracks) for track in tracks: mock_write.assert_any_call(track) @@ -189,7 +209,10 @@ def test_process_multiple_tracks(self, mock_write): def test_dont_write_tracks_twice(self, mock_write): """Don't write a track twice if it's album is also in `items`.""" track = track_factory() + mock_session = MagicMock() - config.CONFIG.pm.hook.process_changed_items(items=[track, track.album]) + config.CONFIG.pm.hook.process_changed_items( + session=mock_session, items=[track, track.album] + ) mock_write.assert_called_once_with(track) diff --git a/tests/util/cli/test_query.py b/tests/util/cli/test_query.py index 32e20a8f..671e0da0 100644 --- a/tests/util/cli/test_query.py +++ b/tests/util/cli/test_query.py @@ -2,7 +2,7 @@ from types import FunctionType from typing import Iterator -from unittest.mock import patch +from unittest.mock import MagicMock, patch import pytest @@ -29,19 +29,21 @@ class TestCLIQuery: def test_bad_query(self, mock_query): """Exit with non-zero code if bad query given.""" mock_query.side_effect = QueryError + mock_session = MagicMock() with pytest.raises(SystemExit) as error: - cli_query("bad query", "track") + cli_query(mock_session, "bad query", "track") assert error.value.code != 0 - mock_query.assert_called_once_with("bad query", "track") + mock_query.assert_called_once_with(mock_session, "bad query", "track") def test_empty_query(self, mock_query): """Exit with non-zero code if bad query given.""" mock_query.return_value = [] + mock_session = MagicMock() with pytest.raises(SystemExit) as error: - cli_query("*", "track") + cli_query(mock_session, "*", "track") assert error.value.code != 0 - mock_query.assert_called_once_with("*", "track") + mock_query.assert_called_once_with(mock_session, "*", "track")