diff --git a/pysus/api/ducklake/client.py b/pysus/api/ducklake/client.py index b94ef75..8dbb252 100644 --- a/pysus/api/ducklake/client.py +++ b/pysus/api/ducklake/client.py @@ -10,7 +10,7 @@ from pysus import CACHEPATH from pysus.api.models import BaseRemoteClient, BaseRemoteFile from sqlalchemy import create_engine -from sqlalchemy.orm import joinedload, sessionmaker +from sqlalchemy.orm import contains_eager, joinedload, sessionmaker from sqlalchemy.pool import StaticPool from .catalog import CatalogDataset, CatalogFile, DatasetGroup @@ -307,18 +307,25 @@ async def query( def _query(): with self._Session() as session: - q = session.query(CatalogFile).options( - joinedload(CatalogFile.dataset), - joinedload(CatalogFile.group), - ) + q = session.query(CatalogFile) if dataset: - q = q.join(CatalogDataset).filter( - CatalogDataset.name == dataset.lower() + q = ( + q.join(CatalogFile.dataset) + .options(contains_eager(CatalogFile.dataset)) + .filter(CatalogDataset.name == dataset.lower()) ) + else: + q = q.options(joinedload(CatalogFile.dataset)) if group: - q = q.join(DatasetGroup).filter(DatasetGroup.name == group) + q = ( + q.join(CatalogFile.group) + .options(contains_eager(CatalogFile.group)) + .filter(DatasetGroup.name.ilike(group)) + ) + else: + q = q.options(joinedload(CatalogFile.group)) if state: q = q.filter(CatalogFile.state == state.upper())