-
Notifications
You must be signed in to change notification settings - Fork 679
FEAT Add supports_multi_turn property to targets and adapt attacks accordingly #1433
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
3fe8b18
8848c45
4adf22c
079751e
9afa84a
66f3646
0bf3fbf
f3a7b7a
2f0f8c2
17b56c3
4a590df
5dd2238
7150643
71f8a6b
2db50fd
c293cdb
17483f0
4e8ea0c
99e9595
a1fc0ec
6485fc6
0a55c5d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
Large diffs are not rendered by default.
Large diffs are not rendered by default.
Large diffs are not rendered by default.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -777,9 +777,22 @@ def duplicate(self) -> "_TreeOfAttacksNode": | |
| ) | ||
|
|
||
| # Duplicate the conversations to preserve history | ||
| duplicate_node.objective_target_conversation_id = self._memory.duplicate_conversation( | ||
| conversation_id=self.objective_target_conversation_id | ||
| ) | ||
| # For single-turn targets, duplicate only the system messages (e.g., system prompt | ||
| # from prepended conversation) so the target retains its configuration without | ||
| # carrying over attack turn history that would cause validation errors. | ||
| if self._objective_target.supports_multi_turn: | ||
| duplicate_node.objective_target_conversation_id = self._memory.duplicate_conversation( | ||
| conversation_id=self.objective_target_conversation_id | ||
| ) | ||
| else: | ||
| messages = self._memory.get_conversation(conversation_id=self.objective_target_conversation_id) | ||
| system_messages = [m for m in messages if m.api_role == "system"] | ||
| if system_messages: | ||
| new_id, pieces = self._memory.duplicate_messages(messages=system_messages) | ||
| self._memory.add_message_pieces_to_memory(message_pieces=pieces) | ||
| duplicate_node.objective_target_conversation_id = new_id | ||
|
Comment on lines
+790
to
+793
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. TODO @romanlutz to check correctness |
||
| else: | ||
| duplicate_node.objective_target_conversation_id = str(uuid.uuid4()) | ||
|
|
||
| duplicate_node.adversarial_chat_conversation_id = self._memory.duplicate_conversation( | ||
| conversation_id=self.adversarial_chat_conversation_id | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -8,6 +8,7 @@ | |||||||||||||
| from pyrit.models import MessagePiece | ||||||||||||||
| from pyrit.models.json_response_config import _JsonResponseConfig | ||||||||||||||
| from pyrit.prompt_target.common.prompt_target import PromptTarget | ||||||||||||||
| from pyrit.prompt_target.common.target_capabilities import TargetCapabilities | ||||||||||||||
|
|
||||||||||||||
|
|
||||||||||||||
| class PromptChatTarget(PromptTarget): | ||||||||||||||
|
|
@@ -21,6 +22,8 @@ class PromptChatTarget(PromptTarget): | |||||||||||||
| Realtime chat targets or OpenAI completions are NOT PromptChatTargets. You don't send the conversation history. | ||||||||||||||
| """ | ||||||||||||||
|
|
||||||||||||||
| _DEFAULT_CAPABILITIES: TargetCapabilities = TargetCapabilities(supports_multi_turn=True) | ||||||||||||||
|
|
||||||||||||||
| def __init__( | ||||||||||||||
| self, | ||||||||||||||
| *, | ||||||||||||||
|
|
@@ -47,6 +50,19 @@ def __init__( | |||||||||||||
| underlying_model=underlying_model, | ||||||||||||||
| ) | ||||||||||||||
|
|
||||||||||||||
| @property | ||||||||||||||
| def supports_multi_turn(self) -> bool: | ||||||||||||||
| """ | ||||||||||||||
| Whether this target supports multi-turn conversations. | ||||||||||||||
|
|
||||||||||||||
| Chat targets retrieve conversation history from memory and send it | ||||||||||||||
| with each request, supporting true multi-turn conversations. | ||||||||||||||
|
|
||||||||||||||
| Returns: | ||||||||||||||
| bool: True for chat targets. | ||||||||||||||
| """ | ||||||||||||||
| return True | ||||||||||||||
|
Comment on lines
+62
to
+64
|
||||||||||||||
| bool: True for chat targets. | |
| """ | |
| return True | |
| bool: True for chat targets by default, unless overridden via capabilities. | |
| """ | |
| return self.capabilities.supports_multi_turn |
| Original file line number | Diff line number | Diff line change | ||||||||
|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -8,6 +8,7 @@ | |||||||||
| from pyrit.identifiers import ComponentIdentifier, Identifiable | ||||||||||
| from pyrit.memory import CentralMemory, MemoryInterface | ||||||||||
| from pyrit.models import Message | ||||||||||
| from pyrit.prompt_target.common.target_capabilities import TargetCapabilities | ||||||||||
|
|
||||||||||
| logger = logging.getLogger(__name__) | ||||||||||
|
|
||||||||||
|
|
@@ -28,13 +29,16 @@ class PromptTarget(Identifiable): | |||||||||
|
|
||||||||||
| _identifier: Optional[ComponentIdentifier] = None | ||||||||||
|
|
||||||||||
| _DEFAULT_CAPABILITIES: TargetCapabilities = TargetCapabilities() | ||||||||||
|
|
||||||||||
| def __init__( | ||||||||||
| self, | ||||||||||
| verbose: bool = False, | ||||||||||
| max_requests_per_minute: Optional[int] = None, | ||||||||||
| endpoint: str = "", | ||||||||||
| model_name: str = "", | ||||||||||
| underlying_model: Optional[str] = None, | ||||||||||
| capabilities: Optional[TargetCapabilities] = None, | ||||||||||
| ) -> None: | ||||||||||
|
Comment on lines
34
to
42
|
||||||||||
| """ | ||||||||||
| Initialize the PromptTarget. | ||||||||||
|
|
@@ -48,13 +52,18 @@ def __init__( | |||||||||
| identification purposes. This is useful when the deployment name in Azure differs | ||||||||||
| from the actual model. If not provided, `model_name` will be used for the identifier. | ||||||||||
| Defaults to None. | ||||||||||
| capabilities (TargetCapabilities, Optional): Override the default capabilities for | ||||||||||
| this target instance. Useful for targets whose capabilities depend on deployment | ||||||||||
| configuration (e.g., Playwright, HTTP). If None, uses the class-level | ||||||||||
| ``_DEFAULT_CAPABILITIES``. Defaults to None. | ||||||||||
| """ | ||||||||||
| self._memory = CentralMemory.get_memory_instance() | ||||||||||
| self._verbose = verbose | ||||||||||
| self._max_requests_per_minute = max_requests_per_minute | ||||||||||
| self._endpoint = endpoint | ||||||||||
| self._model_name = model_name | ||||||||||
| self._underlying_model = underlying_model | ||||||||||
| self._capabilities = capabilities if capabilities is not None else type(self)._DEFAULT_CAPABILITIES | ||||||||||
|
|
||||||||||
| if self._verbose: | ||||||||||
| logging.basicConfig(level=logging.INFO) | ||||||||||
|
|
@@ -128,12 +137,43 @@ def _create_identifier( | |||||||||
| "model_name": model_name, | ||||||||||
| "max_requests_per_minute": self._max_requests_per_minute, | ||||||||||
| "supports_conversation_history": isinstance(self, PromptChatTarget), | ||||||||||
| "supports_multi_turn": self.supports_multi_turn, | ||||||||||
| } | ||||||||||
| if params: | ||||||||||
| all_params.update(params) | ||||||||||
|
|
||||||||||
| return ComponentIdentifier.of(self, params=all_params, children=children) | ||||||||||
|
|
||||||||||
| @property | ||||||||||
| def capabilities(self) -> TargetCapabilities: | ||||||||||
| """ | ||||||||||
| The capabilities of this target instance. | ||||||||||
|
|
||||||||||
| Defaults to the class-level ``_DEFAULT_CAPABILITIES``. Can be overridden | ||||||||||
| per instance by setting this property, which is useful for targets whose | ||||||||||
| capabilities depend on deployment configuration (e.g., Playwright, HTTP). | ||||||||||
|
|
||||||||||
| Returns: | ||||||||||
| TargetCapabilities: The capabilities for this target. | ||||||||||
| """ | ||||||||||
| return self._capabilities | ||||||||||
|
|
||||||||||
| @capabilities.setter | ||||||||||
| def capabilities(self, value: TargetCapabilities) -> None: | ||||||||||
| self._capabilities = value | ||||||||||
|
||||||||||
| self._capabilities = value | |
| self._capabilities = value | |
| # Invalidate cached identifier so it can be rebuilt with updated capabilities. | |
| self._identifier = None |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,22 @@ | ||
| # Copyright (c) Microsoft Corporation. | ||
| # Licensed under the MIT license. | ||
|
|
||
| from dataclasses import dataclass | ||
|
|
||
|
|
||
| @dataclass(frozen=True) | ||
| class TargetCapabilities: | ||
| """ | ||
| Describes the capabilities of a PromptTarget so that attacks | ||
| and other components can adapt their behavior accordingly. | ||
| Each target class defines default capabilities via the _DEFAULT_CAPABILITIES | ||
| class attribute. Users can override individual capabilities per instance | ||
| through constructor parameters, which is useful for targets whose | ||
| capabilities depend on deployment configuration (e.g., Playwright, HTTP). | ||
| """ | ||
|
|
||
| # Whether the target natively supports multi-turn conversations | ||
| # (i.e., it accepts and uses conversation history or maintains state | ||
| # across turns via external mechanisms like WebSocket connections). | ||
| supports_multi_turn: bool = False |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -507,7 +507,7 @@ def _is_text_message_format(self, conversation: MutableSequence[Message]) -> boo | |
| for turn in conversation: | ||
| if len(turn.message_pieces) != 1: | ||
| return False | ||
| if turn.message_pieces[0].converted_value_data_type != "text": | ||
| if turn.message_pieces[0].converted_value_data_type not in ("text", "error"): | ||
| return False | ||
| return True | ||
|
|
||
|
|
@@ -535,7 +535,7 @@ def _build_chat_messages_for_text(self, conversation: MutableSequence[Message]) | |
|
|
||
| message_piece = message.message_pieces[0] | ||
|
|
||
| if message_piece.converted_value_data_type != "text": | ||
| if message_piece.converted_value_data_type not in ("text", "error"): | ||
| raise ValueError("_build_chat_messages_for_text only supports text.") | ||
|
|
||
|
Comment on lines
+538
to
540
|
||
| chat_message = ChatMessage(role=message_piece.api_role, content=message_piece.converted_value) | ||
|
|
@@ -581,7 +581,7 @@ async def _build_chat_messages_for_multi_modal_async( | |
| ): | ||
| continue | ||
|
|
||
| if message_piece.converted_value_data_type == "text": | ||
| if message_piece.converted_value_data_type in ("text", "error"): | ||
| entry = {"type": "text", "text": message_piece.converted_value} | ||
| content.append(entry) | ||
| elif message_piece.converted_value_data_type == "image_path": | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
_rotate_conversation_for_single_turn_targetrelies on an inlinefrom pyrit.memory import CentralMemoryimport. This module doesn't appear to have a circular dependency withpyrit.memory, so the import should be moved to the top of the file to match the project's import-organization convention and avoid repeated imports on every rotation call.