Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import java.util.Set;
import org.w3c.dom.Document;
import org.w3c.dom.Element;
import org.w3c.dom.Node;
import org.w3c.dom.NodeList;
import com.ibm.watsonx.ai.chat.model.AssistantMessage;
import com.ibm.watsonx.ai.chat.model.ChatUsage;
Expand Down Expand Up @@ -221,20 +222,33 @@ public String toText() {
}

/**
* Extracts the textual content enclosed within the specified XML-like tags from the assistant's response.
* Extracts textual content enclosed within the specified XML-like tags from the assistant's response.
* <p>
* This method is particularly useful when working with models that output segmented content using tags such as {@code <think>} or
* {@code <response>}. The input should contain only the tag names (e.g., {@code "think"}, {@code "response"}), not the angle brackets.
* This method parses the assistant's output as XML (wrapped in a synthetic {@code <root>} tag) and returns the textual content associated with
* each requested tag.
* <ul>
* <li>For normal tags, the returned value is the full text inside the tag, including nested elements' text.</li>
* <li>For the synthetic {@code root} tag (i.e., the top-level element), only the <b>direct text nodes</b> outside of any child elements are
* included.</li>
* </ul>
* <p>
* This behavior is particularly useful when working with models that output segmented content using tags such as {@code <think>} or
* {@code <response>}, and when you need to distinguish between top-level text and text inside nested tags.
* <p>
* The input should contain only the tag names (e.g., {@code "think"}, {@code "response"}), not the angle brackets.
*
* <p>
* <b>Example usage:</b>
* </p>
*
* <pre>{@code
* var tags = Set.of("think", "response");
* var parts = instance.toTextByTags(tags);
* String think = parts.get("think");
* var tags = Set.of("think", "response", "root");
* var parts = chatResponse.toTextByTags(tags);
* String think = parts.get("think"); // text inside <think>...</think>
* String resp = parts.get("response"); // text inside <response>...</response>
* String rootText = parts.get("root"); // only direct text outside child tags
* }</pre>
*
*
* @param tags a set of tag names to extract content from, without angle brackets
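* @return a map where each key is a tag name and its value is the corresponding extracted text; tags that do not occur in the response are absent from the map, and {@code root} maps to {@code null} when there is no direct top-level text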
*/
Expand All @@ -244,13 +258,34 @@ public Map<String, String> toTextByTags(Set<String> tags) {
var wrappedXml = "<root>" + toText() + "</root>";

Document doc = XmlUtils.parse(wrappedXml);
Element root = doc.getDocumentElement();
Map<String, String> result = new HashMap<>();

for (String tag : tags) {

NodeList nodes = doc.getElementsByTagName(tag);

for (int i = 0; i < nodes.getLength(); i++) {
Element element = (Element) nodes.item(i);
String textContent = element.getTextContent().trim();
String textContent;
if (element == root) {
StringBuilder sb = new StringBuilder();
NodeList children = element.getChildNodes();
for (int j = 0; j < children.getLength(); j++) {
Node child = children.item(j);

if (child.getNodeType() != Node.TEXT_NODE)
continue;

String text = child.getTextContent().trim();
if (!text.isEmpty())
sb.append(text);
}
textContent = sb.isEmpty() ? null : sb.toString();
} else {
textContent = element.getTextContent().trim();
}

result.put(tag, textContent);
}
}
Expand All @@ -261,18 +296,21 @@ public Map<String, String> toTextByTags(Set<String> tags) {
/**
* Extracts the textual content enclosed within a single specified XML-like tag from the assistant's response.
* <p>
* This method is particularly useful when working with models that output segmented content using tags such as {@code <think>} or
* {@code <response>}. The input should contain only the tag names (e.g., {@code "think"}, {@code "response"}), not the angle brackets.
* Behaves like {@link #toTextByTags(Set)} but for a single tag. If the specified tag is the synthetic {@code root} element, only the direct text
* nodes outside of child tags are included.
* <p>
* The input should contain only the tag name (e.g., {@code "think"}), not the angle brackets.
*
* <p>
* <b>Example usage:</b>
* </p>
*
* <pre>
* String response = instance.toTextByTag("response");
* </pre>
* <pre>{@code
* String think = instance.toTextByTag("think");
* }</pre>
*
* @param tag the tag name to extract content from, without angle brackets
* @return the textual content inside the specified tag, or {@code null} if not present
* @throws RuntimeException if the underlying text is not valid XML or parsing fails
* @return the textual content inside the specified tag
*/
public String toTextByTag(String tag) {
return toTextByTags(Set.of(tag)).get(tag);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1308,10 +1308,95 @@ void test_control_message() throws Exception {
var parts = chatResponse.toTextByTags(Set.of("think", "response"));
assertEquals("Think", parts.get("think"));
assertEquals("Result", parts.get("response"));
assertNull(parts.get("root"));
assertEquals("Result", chatResponse.toTextByTag("response"));
assertNull(chatResponse.toTextByTag("root"));
assertEquals(captor.getValue().headers().firstValue(TRANSACTION_ID_HEADER).orElse(null), "my-transaction-id");
}

/**
 * Verifies tag extraction when the assistant reply mixes one tagged section with untagged top-level text.
 * <p>
 * The mocked reply content is {@code <think>Think</think>Result}. The test expects the {@code "think"} tag to resolve to {@code "Think"} and the
 * synthetic {@code "root"} tag to resolve to {@code "Result"}, through both {@code toTextByTags} and {@code toTextByTag}. It also checks the
 * serialized request/response bodies and the transaction-id header of the captured request.
 *
 * @throws Exception if the mocked HTTP call or the JSON comparison fails
 */
@Test
void test_control_message_with_single_tag() throws Exception {

    // Expected request body; compared in non-strict mode below.
    final String REQUEST = """
        {
        "model_id": "ibm/granite-3-3-8b-instruct",
        "project_id": "63dc4cf1-252f-424b-b52d-5cdd9814987f",
        "messages": [
        {
        "role": "control",
        "content": "thinking"
        },
        {
        "role": "user",
        "content": [
        {
        "type": "text",
        "text": "What is the result of 1 + 1"
        }
        ]
        }
        ]
        }""";

    // Mocked response: a <think> section followed by plain text that is not wrapped in any tag.
    final String RESPONSE =
        """
        {
        "id": "chatcmpl-326dc89051c826dd8d3c690a2d716e77",
        "object": "chat.completion",
        "model_id": "ibm/granite-3-3-8b-instruct",
        "model": "ibm/granite-3-3-8b-instruct",
        "choices": [
        {
        "index": 0,
        "message": {
        "role": "assistant",
        "content": "<think>Think</think>Result"
        },
        "finish_reason": "stop"
        }
        ],
        "created": 1749323488,
        "model_version": "3.3.0",
        "created_at": "2025-06-07T19:11:29.419Z",
        "usage": {
        "completion_tokens": 162,
        "prompt_tokens": 198,
        "total_tokens": 360
        }
        }""";

    when(mockAuthenticationProvider.getToken()).thenReturn("my-super-token");
    when(mockHttpResponse.statusCode()).thenReturn(200);
    when(mockHttpResponse.body()).thenReturn(RESPONSE);

    var chatService = ChatService.builder()
        .authenticationProvider(mockAuthenticationProvider)
        .httpClient(mockHttpClient)
        .modelId("ibm/granite-3-3-8b-instruct")
        .projectId("63dc4cf1-252f-424b-b52d-5cdd9814987f")
        .url(CloudRegion.SYDNEY)
        .build();

    // Capture the outgoing request so its body and headers can be inspected afterwards.
    ArgumentCaptor<HttpRequest> captor = ArgumentCaptor.forClass(HttpRequest.class);
    when(mockHttpClient.send(captor.capture(), any(BodyHandler.class))).thenReturn(mockHttpResponse);

    var messages = List.<ChatMessage>of(
        ControlMessage.of("thinking"),
        UserMessage.text("What is the result of 1 + 1")
    );

    var chatResponse = chatService.chat(messages, ChatParameters.builder().transactionId("my-transaction-id").build());
    JSONAssert.assertEquals(REQUEST, bodyPublisherToString(captor), false);
    JSONAssert.assertEquals(RESPONSE, Json.toJson(chatResponse), false);

    // "root" is expected to yield only the text that sits outside the <think> element.
    var parts = chatResponse.toTextByTags(Set.of("think", "root"));
    assertEquals("Think", parts.get("think"));
    assertEquals("Result", parts.get("root"));
    assertEquals("Result", chatResponse.toTextByTag("root"));

    // Expected value first, actual second, so a failure message reports the two values the right way round.
    assertEquals("my-transaction-id", captor.getValue().headers().firstValue(TRANSACTION_ID_HEADER).orElse(null));
}


@Test
void chat_streaming_test() throws Exception {

Expand Down