-
Notifications
You must be signed in to change notification settings - Fork 2
docs(research): complementary-architecture one-pager (three-layer handoff) #421
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
55ee738
8083ae1
77ae157
b90d547
37683db
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -90,6 +90,104 @@ async def list_trials_for_study(db: AsyncSession, study_id: str) -> Sequence[Tri | |||||||||||||
| return list((await db.execute(stmt)).scalars().all()) | ||||||||||||||
|
|
||||||||||||||
|
|
||||||||||||||
| @dataclass(frozen=True) | ||||||||||||||
| class TrialCounts: | ||||||||||||||
| """Per-study trial counts for the studies-list response. | ||||||||||||||
|
|
||||||||||||||
| ``total`` mirrors :class:`TrialsSummary.total` (non-baseline rows — | ||||||||||||||
| ``is_baseline.is_(False)``, matching :func:`aggregate_trials_summary`'s | ||||||||||||||
| parity target) so the list's ``trial_count`` equals the detail's | ||||||||||||||
| ``trials_summary.total`` exactly. | ||||||||||||||
|
|
||||||||||||||
| ``complete`` mirrors :func:`list_complete_optuna_trials_for_study`'s | ||||||||||||||
| own filter (``status == "complete" AND is_baseline.is_not(True)``), so | ||||||||||||||
| the list's count-gate decision (``< 5`` / ``< 50`` / ``≥ 50``) keys | ||||||||||||||
| off the same row set the classifier would see. Per D-17 in | ||||||||||||||
| ``feat_studies_convergence_visibility/feature_spec.md`` — | ||||||||||||||
| ``trials.is_baseline`` is ``BOOLEAN NOT NULL DEFAULT FALSE`` (model | ||||||||||||||
| ``trial.py:114``; migration ``0020``) so ``is_(False)`` ≡ | ||||||||||||||
| ``is_not(True)`` today; pinning each predicate to its parity target | ||||||||||||||
| keeps the contract unambiguous if the column ever becomes nullable. | ||||||||||||||
| """ | ||||||||||||||
|
|
||||||||||||||
| total: int | ||||||||||||||
| complete: int | ||||||||||||||
|
|
||||||||||||||
|
|
||||||||||||||
| async def count_trials_for_studies( | ||||||||||||||
| db: AsyncSession, study_ids: Sequence[str] | ||||||||||||||
| ) -> dict[str, TrialCounts]: | ||||||||||||||
| """Batched non-baseline trial counts for a page of studies. | ||||||||||||||
|
|
||||||||||||||
| One ``GROUP BY study_id`` aggregate returning ``(total, complete)`` | ||||||||||||||
| per study, both non-baseline. Powers the studies-list | ||||||||||||||
| ``trial_count`` field + the convergence verdict's count gate | ||||||||||||||
| (``feat_studies_convergence_visibility`` Story 1.1, FR-1/FR-3). | ||||||||||||||
|
|
||||||||||||||
| Studies whose ID is in the input but have zero trials yet (e.g., | ||||||||||||||
| ``queued``) are returned with ``TrialCounts(0, 0)``. Empty input | ||||||||||||||
| returns an empty dict (no query issued). | ||||||||||||||
| """ | ||||||||||||||
| if not study_ids: | ||||||||||||||
| return {} | ||||||||||||||
| # Pin BOTH predicates to ``is_(False)`` — the only divergent case | ||||||||||||||
| # would be a NULL ``is_baseline`` row, which cannot exist (NOT NULL | ||||||||||||||
| # column). ``complete`` further filters by status. This keeps the | ||||||||||||||
| # aggregate's row-set identical to the parity sources documented in | ||||||||||||||
| # the TrialCounts docstring. | ||||||||||||||
| stmt = ( | ||||||||||||||
| select( | ||||||||||||||
| Trial.study_id.label("study_id"), | ||||||||||||||
| func.count(Trial.id).filter(Trial.is_baseline.is_(False)).label("total"), | ||||||||||||||
| func.count(Trial.id) | ||||||||||||||
| .filter(Trial.is_baseline.is_(False), Trial.status == "complete") | ||||||||||||||
| .label("complete"), | ||||||||||||||
| ) | ||||||||||||||
| .where(Trial.study_id.in_(list(study_ids))) | ||||||||||||||
| .group_by(Trial.study_id) | ||||||||||||||
| ) | ||||||||||||||
| rows = (await db.execute(stmt)).all() | ||||||||||||||
| result: dict[str, TrialCounts] = { | ||||||||||||||
| row.study_id: TrialCounts(total=int(row.total), complete=int(row.complete)) for row in rows | ||||||||||||||
| } | ||||||||||||||
|
Comment on lines
+150
to
+152
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Type safety and defensive programming:
Suggested change
|
||||||||||||||
| # Backfill zero-trial studies so callers can index by id without | ||||||||||||||
| # checking for KeyError. | ||||||||||||||
| for sid in study_ids: | ||||||||||||||
| result.setdefault(sid, TrialCounts(total=0, complete=0)) | ||||||||||||||
| return result | ||||||||||||||
|
|
||||||||||||||
|
|
||||||||||||||
| async def list_complete_optuna_trials_for_studies( | ||||||||||||||
| db: AsyncSession, study_ids: Sequence[str] | ||||||||||||||
| ) -> dict[str, list[Trial]]: | ||||||||||||||
| """Batched sibling of :func:`list_complete_optuna_trials_for_study`. | ||||||||||||||
|
|
||||||||||||||
| One ``SELECT ... WHERE study_id IN (...)`` with the same filter set | ||||||||||||||
| (``status == "complete" AND is_baseline.is_not(True) AND | ||||||||||||||
| primary_metric IS NOT NULL``), ordered by ``study_id`` then | ||||||||||||||
| ``optuna_trial_number ASC``, grouped in Python. | ||||||||||||||
|
|
||||||||||||||
| Called once per studies-list request — only for the subset of | ||||||||||||||
| studies with ``complete >= STUDIES_TPE_WARMUP_FLOOR`` (50), which | ||||||||||||||
| saves us from per-study trial loads in the common low-trial case | ||||||||||||||
| (FR-3 / D-14). | ||||||||||||||
| """ | ||||||||||||||
| if not study_ids: | ||||||||||||||
| return {} | ||||||||||||||
| stmt = ( | ||||||||||||||
| select(Trial) | ||||||||||||||
| .where(Trial.study_id.in_(list(study_ids))) | ||||||||||||||
| .where(Trial.status == "complete") | ||||||||||||||
| .where(Trial.is_baseline.is_not(True)) | ||||||||||||||
| .where(Trial.primary_metric.is_not(None)) | ||||||||||||||
| .order_by(Trial.study_id, Trial.optuna_trial_number) | ||||||||||||||
| ) | ||||||||||||||
| grouped: dict[str, list[Trial]] = {sid: [] for sid in study_ids} | ||||||||||||||
| for trial in (await db.execute(stmt)).scalars().all(): | ||||||||||||||
| grouped[trial.study_id].append(trial) | ||||||||||||||
|
Comment on lines
+185
to
+187
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. KeyError prevention:
Suggested change
|
||||||||||||||
| return grouped | ||||||||||||||
|
|
||||||||||||||
|
|
||||||||||||||
| async def list_complete_optuna_trials_for_study(db: AsyncSession, study_id: str) -> Sequence[Trial]: | ||||||||||||||
| """List trials usable by the convergence classifier. | ||||||||||||||
|
|
||||||||||||||
|
|
||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Defensive programming check:
row.objectivecan beNoneor not a dictionary in degenerate cases (as handled inresolve_list_convergence_verdictsviaisinstance(study.objective, dict)). Calling.get()directly onrow.objectivewithout a guard can raise anAttributeErrorand crash the entire studies list response. We should safely default it to an empty dictionary or guard the access.