Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions backend/app/api/v1/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@
from backend.app.core.settings import get_settings
from backend.app.domain.study.chain_summary import ChainStopReason as ChainStopReason
from backend.app.domain.study.confidence import ConfidenceShape as ConfidenceShape
from backend.app.domain.study.convergence import (
ConvergenceVerdict as ConvergenceVerdict,
)
from backend.app.domain.study.convergence import (
StudyConvergenceShape as StudyConvergenceShape,
)
Expand Down Expand Up @@ -896,6 +899,28 @@ class StudySummary(BaseModel):
``bug_ceiling_badge_assumes_maximize_direction``."""
created_at: datetime
completed_at: datetime | None
trial_count: int = 0
"""Non-baseline trial-row count for this study, matching the detail
page's ``trials_summary.total`` exactly (both use
``is_baseline.is_(False)``). A ``max_trials=50`` study with a
completed baseline shows ``trial_count=50``. Computed per request via
one batched ``GROUP BY study_id`` aggregate
(``count_trials_for_studies``); see
``feat_studies_convergence_visibility`` Story 1.1 / FR-1. Default
``0`` for backward compatibility on hand-constructed instances in
tests; the live API always populates it."""
convergence_verdict: ConvergenceVerdict | None = None
"""Per-study convergence verdict literal (NOT the full
:class:`StudyConvergenceShape` — list payload only). Equal to
``StudyDetail.convergence.verdict`` for every case (in-flight /
invalid-direction / ``<5`` / ``5–49`` / ``≥50``) — see AC-2 + AC-3b
in ``feat_studies_convergence_visibility/feature_spec.md``. Computed
via :func:`backend.app.services.study_convergence.resolve_list_convergence_verdicts`
using the same gate order as ``fetch_study_convergence``
(in-flight → direction → count → classifier). ``None`` for in-flight
studies, invalid-direction completed studies, ``< 5`` complete
non-baseline trials, and the graceful-degrade exception path; never
raises."""


class StudyListResponse(BaseModel):
Expand Down
66 changes: 61 additions & 5 deletions backend/app/api/v1/studies.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@
derive_chain_stop_reason,
select_best_link,
)
from backend.app.domain.study.convergence import ConvergenceVerdict
from backend.app.domain.study.followups import parse_followup_list
from backend.app.domain.study.search_space import (
MissingDeclaredParamError,
Expand All @@ -77,7 +78,10 @@
)
from backend.app.services import study_state
from backend.app.services.study_confidence import fetch_study_confidence
from backend.app.services.study_convergence import fetch_study_convergence
from backend.app.services.study_convergence import (
fetch_study_convergence,
resolve_list_convergence_verdicts,
)
from backend.app.services.study_preflight import MIN_OVERLAP, probe_judgment_overlap

router = APIRouter()
Expand Down Expand Up @@ -169,11 +173,30 @@ async def _detail(db: AsyncSession, row: Study) -> StudyDetail:
)


def _summary(row: Study) -> StudySummary:
def _summary(
row: Study,
*,
trial_count: int,
convergence_verdict: ConvergenceVerdict | None,
) -> StudySummary:
# ``objective`` is a non-null JSONB dict; ``direction`` arrived with
# feat_study_baseline_trial, so older rows may lack the key — default
# to "maximize" (per bug_ceiling_badge_assumes_maximize_direction).
direction = row.objective.get("direction", "maximize")
#
# Coerce ANY value outside the {"maximize", "minimize"} Literal to
# "maximize" — not only the absent-key case. Without this guard, a
# row whose persisted ``direction`` somehow drifted to a third value
# (corrupt JSONB, a future migration that re-uses the key, a manual
# SQL edit) would crash the entire studies-list response with a
# ``ValidationError`` because ``StudySummary.direction`` is typed as
# a two-value Literal. The detail-path's
# :func:`backend.app.services.study_convergence._resolve_direction`
# already handles this case — the list path was the latent gap.
# Surfaced by ``feat_studies_convergence_visibility`` AC-3b, which
# writes ``"sideways"`` deliberately to exercise the
# invalid-direction parity path.
raw_direction = row.objective.get("direction", "maximize")
direction = raw_direction if raw_direction in ("maximize", "minimize") else "maximize"
Comment on lines +198 to +199
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Defensive programming check: row.objective can be None or not a dictionary in degenerate cases (as handled in resolve_list_convergence_verdicts via isinstance(study.objective, dict)). Calling .get() directly on row.objective without a guard can raise an AttributeError and crash the entire studies list response. We should safely default it to an empty dictionary or guard the access.

    objective = row.objective if isinstance(row.objective, dict) else {}
    raw_direction = objective.get("direction", "maximize")
    direction = raw_direction if raw_direction in ("maximize", "minimize") else "maximize"

return StudySummary(
id=row.id,
name=row.name,
Expand All @@ -183,6 +206,8 @@ def _summary(row: Study) -> StudySummary:
direction=direction,
created_at=row.created_at,
completed_at=row.completed_at,
trial_count=trial_count,
convergence_verdict=convergence_verdict,
)


Expand Down Expand Up @@ -545,8 +570,25 @@ async def list_studies(
cursor_value = getattr(last, parsed_sort.col_name)
next_cursor = _sort_encode_cursor(cursor_value, last.id)
has_more = True

# feat_studies_convergence_visibility Story 1.1 — populate per-row
# trial_count + convergence_verdict via bounded batched queries
# (FR-1/FR-2/FR-3): one GROUP BY aggregate for counts; one batched
# trial-load ONLY when the complete>=50 subset is non-empty
# (resolve_list_convergence_verdicts handles the gating).
page_ids = [str(r.id) for r in rows]
trial_counts = await repo.count_trials_for_studies(db, page_ids)
verdicts = await resolve_list_convergence_verdicts(db, rows, trial_counts)

return StudyListResponse(
data=[_summary(r) for r in rows],
data=[
_summary(
r,
trial_count=trial_counts.get(str(r.id), repo.TrialCounts(0, 0)).total,
convergence_verdict=verdicts.get(str(r.id)),
)
for r in rows
],
next_cursor=next_cursor,
has_more=has_more,
)
Expand Down Expand Up @@ -663,8 +705,22 @@ async def list_study_children(
children = await repo.list_children_of_study(db, study_id)
# Direct children of any single parent are at most 1 (linear chains in v1),
# so we never paginate this endpoint. has_more is always False.
#
# feat_studies_convergence_visibility Story 1.1: populate trial_count +
# convergence_verdict per the StudySummary contract — same bounded
# batched-query pattern as the main list_studies handler.
child_ids = [str(c.id) for c in children]
child_trial_counts = await repo.count_trials_for_studies(db, child_ids)
child_verdicts = await resolve_list_convergence_verdicts(db, children, child_trial_counts)
return StudyListResponse(
data=[_summary(c) for c in children],
data=[
_summary(
c,
trial_count=child_trial_counts.get(str(c.id), repo.TrialCounts(0, 0)).total,
convergence_verdict=child_verdicts.get(str(c.id)),
)
for c in children
],
next_cursor=None,
has_more=False,
)
Expand Down
8 changes: 8 additions & 0 deletions backend/app/db/repo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,11 +135,14 @@
list_studies,
)
from backend.app.db.repo.trial import (
TrialCounts,
TrialsSummary,
aggregate_trials_summary,
count_trials,
count_trials_for_studies,
create_trial,
get_trial,
list_complete_optuna_trials_for_studies,
list_complete_optuna_trials_for_study,
list_trials_for_study,
list_trials_paginated,
Expand Down Expand Up @@ -276,4 +279,9 @@
# feat_study_convergence_indicator Story 2.1 — read-side helper feeding
# the trailing-window-flat convergence classifier.
"list_complete_optuna_trials_for_study",
# feat_studies_convergence_visibility Story 1.1 — batched count + trial
# load for the studies-list trial_count + convergence_verdict fields.
"TrialCounts",
"count_trials_for_studies",
"list_complete_optuna_trials_for_studies",
]
98 changes: 98 additions & 0 deletions backend/app/db/repo/trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,104 @@ async def list_trials_for_study(db: AsyncSession, study_id: str) -> Sequence[Tri
return list((await db.execute(stmt)).scalars().all())


@dataclass(frozen=True)
class TrialCounts:
    """Immutable pair of non-baseline trial tallies for a single study.

    Used by the studies-list response
    (``feat_studies_convergence_visibility``).

    Attributes:
        total: Count of non-baseline trial rows (``is_baseline.is_(False)``).
            This is the same predicate :class:`TrialsSummary.total` /
            :func:`aggregate_trials_summary` use, so the list's
            ``trial_count`` equals the detail page's
            ``trials_summary.total`` exactly.
        complete: Count of non-baseline rows whose ``status`` is
            ``"complete"``. Its parity target is
            :func:`list_complete_optuna_trials_for_study`, which filters on
            ``status == "complete" AND is_baseline.is_not(True)``, so the
            list's count gate (``< 5`` / ``< 50`` / ``>= 50``) keys off the
            same row set the classifier would see.

    Note on the two predicate spellings (D-17 in
    ``feat_studies_convergence_visibility/feature_spec.md``):
    ``trials.is_baseline`` is ``BOOLEAN NOT NULL DEFAULT FALSE`` (model
    ``trial.py:114``; migration ``0020``), so ``is_(False)`` and
    ``is_not(True)`` select identical rows today. They would only diverge
    if the column ever became nullable.
    """

    total: int
    complete: int


async def count_trials_for_studies(
    db: AsyncSession, study_ids: Sequence[str]
) -> dict[str, TrialCounts]:
    """Return batched non-baseline trial counts for a page of studies.

    Issues one ``GROUP BY study_id`` aggregate that yields
    ``(total, complete)`` per study, both restricted to non-baseline rows.
    Powers the studies-list ``trial_count`` field and the convergence
    verdict's count gate (``feat_studies_convergence_visibility``
    Story 1.1, FR-1/FR-3).

    Args:
        db: Async session used to execute the aggregate.
        study_ids: IDs of the studies on the current page.

    Returns:
        A mapping keyed by the *string* form of each study ID. Every ID in
        ``study_ids`` is present: studies with no trial rows yet (e.g.
        ``queued``) map to ``TrialCounts(0, 0)``. Empty input returns an
        empty dict without issuing a query.
    """
    if not study_ids:
        return {}

    # Both aggregates use ``is_(False)``. The only row that could make
    # ``is_(False)`` and ``is_not(True)`` disagree is one with a NULL
    # ``is_baseline``, which cannot exist on a NOT NULL column.
    non_baseline = Trial.is_baseline.is_(False)
    stmt = (
        select(
            Trial.study_id.label("study_id"),
            func.count(Trial.id).filter(non_baseline).label("total"),
            func.count(Trial.id)
            .filter(non_baseline, Trial.status == "complete")
            .label("complete"),
        )
        # The bind values are passed through untouched so the column type
        # keeps handling them exactly as before; only dict keys are
        # normalised below.
        .where(Trial.study_id.in_(list(study_ids)))
        .group_by(Trial.study_id)
    )
    rows = (await db.execute(stmt)).all()

    # Stringify the key: depending on dialect / model definition the column
    # may come back as ``uuid.UUID``, while callers look up with
    # ``str(row.id)``. A raw UUID key would miss on lookup and the backfill
    # below would then report (0, 0) for a study that does have trials.
    result: dict[str, TrialCounts] = {
        str(row.study_id): TrialCounts(total=int(row.total), complete=int(row.complete))
        for row in rows
    }

    # Backfill zero-trial studies so callers can index by id without
    # handling KeyError.
    empty = TrialCounts(total=0, complete=0)
    for sid in study_ids:
        result.setdefault(str(sid), empty)
    return result


async def list_complete_optuna_trials_for_studies(
    db: AsyncSession, study_ids: Sequence[str]
) -> dict[str, list[Trial]]:
    """Load classifier-usable trials for several studies in one query.

    Batched sibling of :func:`list_complete_optuna_trials_for_study`: one
    ``SELECT ... WHERE study_id IN (...)`` with the same filter set
    (``status == "complete" AND is_baseline.is_not(True) AND
    primary_metric IS NOT NULL``), ordered by ``study_id`` then
    ``optuna_trial_number ASC``, and grouped per study in Python.

    Intended to be called once per studies-list request, and only for the
    subset of studies with ``complete >= STUDIES_TPE_WARMUP_FLOOR`` (50),
    which avoids per-study trial loads in the common low-trial case
    (FR-3 / D-14).

    Args:
        db: Async session used to execute the query.
        study_ids: IDs of the studies whose trials should be loaded.

    Returns:
        A mapping keyed by the *string* form of each study ID. Every ID in
        ``study_ids`` is present, with an empty list when no trial matched.
        Each list preserves the query's ``optuna_trial_number`` ordering.
        Empty input returns an empty dict without issuing a query.
    """
    if not study_ids:
        return {}

    stmt = (
        select(Trial)
        .where(Trial.study_id.in_(list(study_ids)))
        .where(Trial.status == "complete")
        .where(Trial.is_baseline.is_not(True))
        .where(Trial.primary_metric.is_not(None))
        .order_by(Trial.study_id, Trial.optuna_trial_number)
    )

    grouped: dict[str, list[Trial]] = {str(sid): [] for sid in study_ids}
    for trial in (await db.execute(stmt)).scalars().all():
        # ``trial.study_id`` may be a ``uuid.UUID`` (or otherwise not
        # byte-identical to the caller's string), so normalise it before
        # indexing. ``setdefault`` keeps an unexpected representation from
        # raising KeyError and failing the whole list response.
        grouped.setdefault(str(trial.study_id), []).append(trial)
    return grouped


async def list_complete_optuna_trials_for_study(db: AsyncSession, study_id: str) -> Sequence[Trial]:
"""List trials usable by the convergence classifier.

Expand Down
Loading
Loading