diff --git a/nemoguardrails/library/jailbreak_detection/request.py b/nemoguardrails/library/jailbreak_detection/request.py index 64d5a0b1a..933381457 100644 --- a/nemoguardrails/library/jailbreak_detection/request.py +++ b/nemoguardrails/library/jailbreak_detection/request.py @@ -31,12 +31,30 @@ import asyncio import logging from typing import Optional +from urllib.parse import urljoin import aiohttp log = logging.getLogger(__name__) +def join_nim_url(base_url: str, classification_path: str) -> str: + """Join NIM base URL with classification path, handling trailing/leading slashes. + + Args: + base_url: The base NIM URL (with or without trailing slash) + classification_path: The classification endpoint path (with or without leading slash) + + Returns: + Properly joined URL + """ + # Ensure base_url ends with '/' for proper urljoin behavior + normalized_base = base_url.rstrip("/") + "/" + # Remove leading slash from classification path to ensure relative joining + normalized_path = classification_path.lstrip("/") + return urljoin(normalized_base, normalized_path) + + async def jailbreak_detection_heuristics_request( prompt: str, api_url: str = "http://localhost:1337/heuristics", @@ -101,14 +119,12 @@ async def jailbreak_nim_request( nim_auth_token: Optional[str], nim_classification_path: str, ): - from urllib.parse import urljoin - headers = {"Content-Type": "application/json", "Accept": "application/json"} payload = { "input": prompt, } - endpoint = urljoin(nim_url, nim_classification_path) + endpoint = join_nim_url(nim_url, nim_classification_path) try: async with aiohttp.ClientSession() as session: try: diff --git a/tests/test_jailbreak_request.py b/tests/test_jailbreak_request.py index c5227d516..54ea997d1 100644 --- a/tests/test_jailbreak_request.py +++ b/tests/test_jailbreak_request.py @@ -25,31 +25,51 @@ class TestJailbreakRequestChanges: """Test jailbreak request function changes introduced in this PR.""" def test_url_joining_logic(self): - """Test that URL joining works correctly using urljoin.""" + """Test that URL joining works correctly with all slash combinations.""" + from nemoguardrails.library.jailbreak_detection.request import join_nim_url + test_cases = [ ( "http://localhost:8000/v1", "classify", - "http://localhost:8000/classify", - ), # v1 replaced by classify + "http://localhost:8000/v1/classify", + ), ( "http://localhost:8000/v1/", "classify", "http://localhost:8000/v1/classify", - ), # trailing slash preserves v1 + ), + ( + "http://localhost:8000/v1", + "/classify", + "http://localhost:8000/v1/classify", + ), ( - "http://localhost:8000", - "v1/classify", + "http://localhost:8000/v1/", + "/classify", "http://localhost:8000/v1/classify", ), + ("http://localhost:8000", "classify", "http://localhost:8000/classify"), + ("http://localhost:8000", "/classify", "http://localhost:8000/classify"), + ("http://localhost:8000/", "classify", "http://localhost:8000/classify"), ("http://localhost:8000/", "/classify", "http://localhost:8000/classify"), + ( + "http://localhost:8000/api/v1", + "classify", + "http://localhost:8000/api/v1/classify", + ), + ( + "http://localhost:8000/api/v1/", + "/classify", + "http://localhost:8000/api/v1/classify", + ), ] - for base_url, path, expected_url in test_cases: - result = urljoin(base_url, path) + for base_url, classification_path, expected_url in test_cases: + result = join_nim_url(base_url, classification_path) assert ( result == expected_url - ), f"urljoin({base_url}, {path}) should equal {expected_url}" + ), f"join_nim_url({base_url}, {classification_path}) should equal {expected_url}, got {result}" def test_auth_header_logic(self): """Test the authorization header logic."""