Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 92 additions & 7 deletions dev/breeze/src/airflow_breeze/commands/pr_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,7 @@ class PRData:
commits_behind: int # how many commits behind the base branch
mergeable: str # MERGEABLE, CONFLICTING, or UNKNOWN
labels: list[str] # label names attached to this PR
unresolved_review_comments: int # count of unresolved review threads from maintainers


@click.group(cls=BreezeGroup, name="pr", help="Tools for managing GitHub pull requests.")
Expand Down Expand Up @@ -485,6 +486,70 @@ def _fetch_check_details_batch(token: str, github_repository: str, prs: list[PRD
pr.failed_checks = failed


_REVIEW_THREADS_BATCH_SIZE = 10


def _fetch_unresolved_comments_batch(token: str, github_repository: str, prs: list[PRData]) -> None:
"""Fetch unresolved review thread counts for PRs in chunked GraphQL queries.

Counts only threads started by collaborators/members/owners (i.e. maintainers).
Updates each PR's unresolved_review_comments in-place.
"""
owner, repo = github_repository.split("/", 1)
if not prs:
return

for chunk_start in range(0, len(prs), _REVIEW_THREADS_BATCH_SIZE):
chunk = prs[chunk_start : chunk_start + _REVIEW_THREADS_BATCH_SIZE]

pr_fields = []
for pr in chunk:
alias = f"pr{pr.number}"
pr_fields.append(
f" {alias}: pullRequest(number: {pr.number}) {{\n"
f" reviewThreads(first: 100) {{\n"
f" nodes {{\n"
f" isResolved\n"
f" comments(first: 1) {{\n"
f" nodes {{\n"
f" author {{ login }}\n"
f" authorAssociation\n"
f" }}\n"
f" }}\n"
f" }}\n"
f" }}\n"
f" }}"
)

query = (
f'query {{\n repository(owner: "{owner}", name: "{repo}") {{\n'
+ "\n".join(pr_fields)
+ "\n }\n}"
)

try:
data = _graphql_request(token, query, {})
except SystemExit:
continue

repo_data = data.get("repository", {})
for pr in chunk:
alias = f"pr{pr.number}"
pr_data = repo_data.get(alias) or {}
threads = pr_data.get("reviewThreads", {}).get("nodes", [])
unresolved = 0
for thread in threads:
if thread.get("isResolved"):
continue
# Only count threads started by maintainers (collaborators/members/owners)
comments = thread.get("comments", {}).get("nodes", [])
if comments:
assoc = comments[0].get("authorAssociation", "NONE")
if assoc in _COLLABORATOR_ASSOCIATIONS:
unresolved += 1
pr.unresolved_review_comments = unresolved


def _fetch_commits_behind_batch(token: str, github_repository: str, prs: list[PRData]) -> dict[int, int]:
"""Fetch how many commits each PR is behind its base branch in chunked GraphQL queries.

Expand Down Expand Up @@ -586,6 +651,7 @@ def _fetch_prs_graphql(
commits_behind=0,
mergeable=node.get("mergeable", "UNKNOWN"),
labels=[lbl["name"] for lbl in (node.get("labels") or {}).get("nodes", []) if lbl],
unresolved_review_comments=0,
)
)

Expand Down Expand Up @@ -623,6 +689,7 @@ def _fetch_single_pr_graphql(token: str, github_repository: str, pr_number: int)
commits_behind=0,
mergeable=node.get("mergeable", "UNKNOWN"),
labels=[lbl["name"] for lbl in (node.get("labels") or {}).get("nodes", []) if lbl],
unresolved_review_comments=0,
)


Expand Down Expand Up @@ -1248,7 +1315,12 @@ def auto_triage(
llm_model: str,
answer_triage: str | None,
):
from airflow_breeze.utils.github import PRAssessment, assess_pr_checks, assess_pr_conflicts
from airflow_breeze.utils.github import (
PRAssessment,
assess_pr_checks,
assess_pr_conflicts,
assess_pr_unresolved_comments,
)
from airflow_breeze.utils.llm_utils import (
_check_cli_available,
_resolve_cli_provider,
Expand Down Expand Up @@ -1400,7 +1472,16 @@ def auto_triage(
)
pr.failed_checks = _fetch_failed_checks(token, github_repository, pr.head_sha)

# Phase 3: Deterministic checks (CI failures + merge conflicts), then LLM for the rest
# Phase 2c: Fetch unresolved review comment counts for candidate PRs
if candidate_prs and run_ci:
get_console().print(
f"[info]Fetching review thread details for {len(candidate_prs)} "
f"candidate {'PRs' if len(candidate_prs) != 1 else 'PR'}...[/]"
)
_fetch_unresolved_comments_batch(token, github_repository, candidate_prs)

# Phase 3: Deterministic checks (CI failures + merge conflicts + unresolved comments),
# then LLM for the rest
# PRs with NOT_RUN checks are separated for workflow approval instead of LLM assessment.
assessments: dict[int, PRAssessment] = {}
llm_candidates: list[PRData] = []
Expand All @@ -1411,9 +1492,10 @@ def auto_triage(
for pr in candidate_prs:
ci_assessment = assess_pr_checks(pr.number, pr.checks_state, pr.failed_checks)
conflict_assessment = assess_pr_conflicts(pr.number, pr.mergeable, pr.base_ref, pr.commits_behind)
comments_assessment = assess_pr_unresolved_comments(pr.number, pr.unresolved_review_comments)

# Merge violations from both deterministic checks
if ci_assessment or conflict_assessment:
# Merge violations from all deterministic checks
if ci_assessment or conflict_assessment or comments_assessment:
total_deterministic_flags += 1
violations = []
summaries = []
Expand All @@ -1423,6 +1505,9 @@ def auto_triage(
if ci_assessment:
violations.extend(ci_assessment.violations)
summaries.append(ci_assessment.summary)
if comments_assessment:
violations.extend(comments_assessment.violations)
summaries.append(comments_assessment.summary)
assessments[pr.number] = PRAssessment(
should_flag=True,
violations=violations,
Expand All @@ -1449,7 +1534,7 @@ def auto_triage(
f"{'PRs' if len(llm_candidates) != 1 else 'PR'}.[/]\n"
)
elif llm_candidates:
skipped_detail = f"{total_deterministic_flags} CI/conflicts"
skipped_detail = f"{total_deterministic_flags} CI/conflicts/comments"
if pending_approval:
skipped_detail += f", {len(pending_approval)} awaiting workflow approval"
get_console().print(
Expand Down Expand Up @@ -1481,7 +1566,7 @@ def auto_triage(

total_flagged = len(assessments)
summary_parts = [
f"{total_deterministic_flags} CI/conflicts",
f"{total_deterministic_flags} CI/conflicts/comments",
f"{total_flagged - total_deterministic_flags} LLM-flagged",
]
if pending_approval:
Expand Down Expand Up @@ -1795,7 +1880,7 @@ def auto_triage(
summary_table.add_row("Ready-for-review skipped", str(total_skipped_accepted))
summary_table.add_row("PRs skipped (filtered)", str(total_skipped))
summary_table.add_row("PRs assessed", str(len(candidate_prs)))
summary_table.add_row("Flagged by CI/conflicts", str(total_deterministic_flags))
summary_table.add_row("Flagged by CI/conflicts/comments", str(total_deterministic_flags))
summary_table.add_row("Flagged by LLM", str(total_flagged - total_deterministic_flags))
summary_table.add_row("LLM errors (skipped)", str(total_llm_errors))
summary_table.add_row("Total flagged", str(total_flagged))
Expand Down
32 changes: 32 additions & 0 deletions dev/breeze/src/airflow_breeze/utils/github.py
Original file line number Diff line number Diff line change
Expand Up @@ -587,3 +587,35 @@ def assess_pr_conflicts(
],
summary=f"PR #{pr_number} has merge conflicts.",
)


def assess_pr_unresolved_comments(pr_number: int, unresolved_review_comments: int) -> PRAssessment | None:
"""Deterministically flag a PR if it has unresolved review comments from maintainers.

Returns None if there are no unresolved comments.
"""
if unresolved_review_comments <= 0:
return None

thread_word = "thread" if unresolved_review_comments == 1 else "threads"
return PRAssessment(
should_flag=True,
violations=[
Violation(
category="Unresolved review comments",
explanation=(
f"This PR has {unresolved_review_comments} unresolved review "
f"{thread_word} from maintainers."
),
severity="warning",
details=(
"Please review and resolve all inline review comments before requesting "
"another review. You can resolve a conversation by clicking 'Resolve conversation' "
"on each thread after addressing the feedback. "
f"See [pull request guidelines]"
f"({_CONTRIBUTING_DOCS_URL}/05_pull_requests.rst)."
),
)
],
summary=f"PR #{pr_number} has {unresolved_review_comments} unresolved review {thread_word}.",
)
Loading