Skip to content

feat: add NVIDIA rerank provider support#7227

Merged
Soulter merged 8 commits intoAstrBotDevs:masterfrom
WenqiOfficial:feat/nvidia-rerank
Apr 6, 2026
Merged

feat: add NVIDIA rerank provider support#7227
Soulter merged 8 commits intoAstrBotDevs:masterfrom
WenqiOfficial:feat/nvidia-rerank

Conversation

@WenqiOfficial
Copy link
Copy Markdown
Contributor

@WenqiOfficial WenqiOfficial commented Mar 31, 2026

Issues #7195 的代码实现。对NVIDIA NIM平台的Rerank模型API的适配

Modifications / 改动点

  • 增加 NVIDIA Rerank API 请求逻辑

    • astrbot/core/provider/sources/nvidia_rerank_source.py
    • astrbot/core/provider/manager.py
  • 针对 NVIDIA Rerank 增加WebUI上的配置支持以及多语言支持( zh-CN & en-US )

    • astrbot/core/config/default.py
    • dashboard/src/i18n/locales/en-US/features/config-metadata.json
    • dashboard/src/i18n/locales/zh-CN/features/config-metadata.json
  • This is NOT a breaking change. / 这不是一个破坏性变更。

Screenshots or Test Results / 运行截图或测试结果

已经过本地搭建使用测试

WebUI配置界面

image

可用性测试

image

知识库检索

image

日志记录

image

Checklist / 检查清单

  • 😊 If there are new features added in the PR, I have discussed it with the authors through issues/emails, etc.
    / 如果 PR 中有新加入的功能,已经通过 Issue / 邮件等方式和作者讨论过。

  • 👀 My changes have been well-tested, and "Verification Steps" and "Screenshots" have been provided above.
    / 我的更改经过了良好的测试,并已在上方提供了“验证步骤”和“运行截图”

  • 🤓 I have ensured that no new dependencies are introduced, OR if new dependencies are introduced, they have been added to the appropriate locations in requirements.txt and pyproject.toml.
    / 我确保没有引入新依赖库,或者引入了新依赖库的同时将其添加到 requirements.txtpyproject.toml 文件相应位置。

  • 😮 My changes do not introduce malicious code.
    / 我的更改没有引入恶意代码。

Summary by Sourcery

Add support for NVIDIA NIM-based rerank models as a configurable provider and wire it into the provider management system.

New Features:

  • Introduce an NVIDIA rerank provider adapter that calls the NVIDIA NIM retrieval/reranking API.
  • Expose NVIDIA rerank provider configuration (API base, key, model, endpoint, truncate policy) in the default provider templates and Web UI metadata with i18n support.

@auto-assign auto-assign bot requested review from Soulter and advent259141 March 31, 2026 08:40
@dosubot dosubot bot added size:L This PR changes 100-499 lines, ignoring generated files. area:provider The bug / feature is about AI Provider, Models, LLM Agent, LLM Agent Runner. area:webui The bug / feature is about webui(dashboard) of astrbot. labels Mar 31, 2026
Copy link
Copy Markdown
Contributor

@sourcery-ai sourcery-ai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey - I've found 1 issue, and left some high level feedback:

  • Consider lazily creating and reusing the aiohttp.ClientSession via a shared/session manager instead of constructing it directly in __init__, to avoid potential event loop binding issues and make cleanup/lifecycle more robust.
  • In _get_endpoint, the model_path handling only checks for / and replaces . with _, but NVIDIA models also commonly include : (e.g., version suffixes); you may want to normalize or strip these consistently to avoid malformed URLs across different model name formats.
Prompt for AI Agents
Please address the comments from this code review:

## Overall Comments
- Consider lazily creating and reusing the `aiohttp.ClientSession` via a shared/session manager instead of constructing it directly in `__init__`, to avoid potential event loop binding issues and make cleanup/lifecycle more robust.
- In `_get_endpoint`, the `model_path` handling only checks for `/` and replaces `.` with `_`, but NVIDIA models also commonly include `:` (e.g., version suffixes); you may want to normalize or strip these consistently to avoid malformed URLs across different model name formats.

## Individual Comments

### Comment 1
<location path="astrbot/core/provider/sources/nvidia_rerank_source.py" line_range="58-59" />
<code_context>
+
+        model_path = "nvidia"
+        logger.debug(f"[NVIDIA Rerank] Building endpoint for model: {self.model}")
+        if "/" in self.model:
+            model_path = self.model.strip("/").replace(".", "_")
+        endpoint = self.model_endpoint.lstrip("/")
+        return f"{self.base_url}/{model_path}/{endpoint}"
</code_context>
<issue_to_address>
**question (bug_risk):** Model path transformation may not match NVIDIA’s documented URL scheme and `replace('.', '_')` looks inconsistent.

This logic also couples model detection with transformation: any model containing `/` is both selected and mutated. Given the docstring’s example (`nvidia/llama-nemotron-rerank-1b-v2``.../nvidia/llama-nemotron-rerank-1b-v2/reranking`), keeping the model segment unchanged seems expected, but `replace('.', '_')` would silently alter valid model names. Consider either using the model string as-is (after trimming `/`) or applying transformations only when there is a documented NVIDIA requirement for a specific pattern.
</issue_to_address>

Sourcery is free for open source - if you like our reviews please consider sharing them ✨
Help me be more useful! Please click 👍 or 👎 on each comment and I'll use the feedback to improve your reviews.

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces the NVIDIA Rerank provider, including its configuration schema, dynamic loading logic, and a dedicated provider implementation using aiohttp. Localization support for the new configuration fields has also been added to the dashboard. Feedback was provided to improve error handling when parsing API responses, specifically to ensure that non-JSON error bodies do not mask the underlying HTTP status code.

Comment on lines +129 to +140
response_data = await response.json()
logger.debug(f"[NVIDIA Rerank] API Response: {response_data}")

if response.status != 200:
error_detail = response_data.get(
"detail", response_data.get("message", "Unknown Error")
)
raise Exception(f"HTTP {response.status} - {error_detail}")

results = self._parse_results(response_data, top_n)
self._log_usage(response_data)
return results
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The response is parsed as JSON before checking the HTTP status code. If the API returns a non-200 status with a non-JSON body (e.g., an HTML error page from a proxy or a 504 Gateway Timeout), response.json() will raise a ContentTypeError, which masks the actual HTTP status and error message. It is safer to check the status code first and handle potential non-JSON error responses gracefully.

Suggested change
response_data = await response.json()
logger.debug(f"[NVIDIA Rerank] API Response: {response_data}")
if response.status != 200:
error_detail = response_data.get(
"detail", response_data.get("message", "Unknown Error")
)
raise Exception(f"HTTP {response.status} - {error_detail}")
results = self._parse_results(response_data, top_n)
self._log_usage(response_data)
return results
if response.status != 200:
try:
err_json = await response.json()
err_msg = err_json.get("detail", err_json.get("message", "Unknown Error"))
except Exception:
err_msg = await response.text()
raise Exception(f"HTTP {response.status} - {err_msg}")
response_data = await response.json()
logger.debug(f"[NVIDIA Rerank] API Response: {response_data}")
results = self._parse_results(response_data, top_n)
self._log_usage(response_data)
return results

@WenqiOfficial
Copy link
Copy Markdown
Contributor Author

@gemini-code-assist review

@WenqiOfficial
Copy link
Copy Markdown
Contributor Author

@sourcery-ai review

Copy link
Copy Markdown
Contributor

@sourcery-ai sourcery-ai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey - I've left some high level feedback:

  • In NvidiaRerankProvider.rerank, consider raising a domain-specific error type (e.g., the existing provider error/exception class used elsewhere) instead of generic Exception, so caller-side error handling can distinguish provider failures more reliably.
  • You may want to validate nvidia_rerank_api_key (and possibly base_url) at initialization and log/raise a clear configuration error early, rather than deferring to a network call that will fail with a less specific message.
Prompt for AI Agents
Please address the comments from this code review:

## Overall Comments
- In `NvidiaRerankProvider.rerank`, consider raising a domain-specific error type (e.g., the existing provider error/exception class used elsewhere) instead of generic `Exception`, so caller-side error handling can distinguish provider failures more reliably.
- You may want to validate `nvidia_rerank_api_key` (and possibly `base_url`) at initialization and log/raise a clear configuration error early, rather than deferring to a network call that will fail with a less specific message.

Sourcery is free for open source - if you like our reviews please consider sharing them ✨
Help me be more useful! Please click 👍 or 👎 on each comment and I'll use the feedback to improve your reviews.

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces the NvidiaRerankProvider to support NVIDIA's reranking API, including configuration defaults, dynamic provider loading, and localization. Feedback was provided regarding the error handling logic in the rerank method, specifically recommending that the response body be read once to avoid potential issues when attempting to parse it as both JSON and text during error reporting.

Comment on lines +137 to +146
error_detail = response_data.get(
"detail", response_data.get("message", "Unknown Error")
)

except Exception:
error_detail = await response.text()
response_data = {"message": error_detail}

logger.error(f"[NVIDIA Rerank] API Error Response: {response_data}")
raise Exception(f"HTTP {response.status} - {error_detail}")
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The error handling logic here might fail if the response body is consumed twice. When response.status != 200, await response.json() is called. If it fails (e.g., the response is not valid JSON), the except block calls await response.text(). While aiohttp usually allows reading the body multiple times if it's cached, it's safer to read the body once and then attempt to parse it.

@Soulter Soulter merged commit b0b6816 into AstrBotDevs:master Apr 6, 2026
7 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

area:provider The bug / feature is about AI Provider, Models, LLM Agent, LLM Agent Runner. area:webui The bug / feature is about webui(dashboard) of astrbot. size:L This PR changes 100-499 lines, ignoring generated files.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants