Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/content/docs/features/ai/backend-integration.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ const model = createOpenAICompatible({
})('model-id');

// ...
createAIExtension({
AIExtension({
transport: new ClientSideTransport({
model,
}),
Expand Down
6 changes: 3 additions & 3 deletions docs/content/docs/features/ai/getting-started.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import { createBlockNoteEditor } from "@blocknote/core";
import { BlockNoteAIExtension } from "@blocknote/xl-ai";
import { en } from "@blocknote/core/locales";
import { en as aiEn } from "@blocknote/xl-ai/locales";
import { createAIExtension } from "@blocknote/xl-ai";
import { AIExtension } from "@blocknote/xl-ai";
import "@blocknote/xl-ai/style.css"; // add the AI stylesheet

const editor = createBlockNoteEditor({
Expand All @@ -34,7 +34,7 @@ const editor = createBlockNoteEditor({
ai: aiEn, // add default translations for the AI extension
},
extensions: [
createAIExtension({
AIExtension({
transport: new DefaultChatTransport({
api: `/api/chat`,
}),
Expand All @@ -44,7 +44,7 @@ const editor = createBlockNoteEditor({
});
```

See the [API Reference](/docs/features/ai/reference) for more information on the `createAIExtension` method.
See the [API Reference](/docs/features/ai/reference) for more information on the `AIExtension` options.

## Adding AI UI elements

Expand Down
27 changes: 15 additions & 12 deletions docs/content/docs/features/ai/reference.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,17 @@ description: Reference documentation for the BlockNote AI extension
imageTitle: BlockNote AI
---

## `createAIExtension`
## `AIExtension`

Use `createAIExtension` to create a new AI Extension that can be registered to an editor when calling `useCreateBlockNote`.
Use `AIExtension` to create a new AI Extension that can be registered to an editor when calling `useCreateBlockNote`.

```typescript
// Usage:
const aiExtension = createAIExtension(opts: AIExtensionOptions);

// Definitions:
function createAIExtension(options: AIExtensionOptions): (editor: BlockNoteEditor) => AIExtension;
useCreateBlockNote({
// Register the AI extension
extensions: [AIExtension(options)],
// other editor options
});

type AIExtensionOptions = AIRequestHelpers & {
/**
Expand Down Expand Up @@ -42,7 +43,7 @@ type AIRequestHelpers = {
* Customize which stream tools are available to the LLM.
*/
streamToolsProvider?: StreamToolsProvider<any, any>;
// Provide `streamToolsProvider` in createAIExtension(options) or override per call via InvokeAIOptions.
// Provide `streamToolsProvider` in AIExtension(options) or override per call via InvokeAIOptions.
// If omitted, defaults to using `aiDocumentFormats.html.getStreamToolsProvider()`.

/**
Expand All @@ -59,12 +60,12 @@ type AIRequestHelpers = {
};
```

## `AIExtension`
## `AIExtension` extension instance

The `AIExtension` class is the main class for the AI extension. It exposes state and methods to interact with BlockNote's AI features.
The `AIExtension` extension instance returned by `editor.getExtension(AIExtension)` exposes state and methods to interact with BlockNote's AI features.

```typescript
class AIExtension {
type AIExtensionInstance = {
/**
* Execute a call to an LLM and apply the result to the editor
*/
Expand Down Expand Up @@ -113,6 +114,8 @@ class AIExtension {
rejectChanges(): void;
/** Retry the previous LLM call (only valid when status is "error") */
retry(): Promise<void>;
/** Abort the current LLM request */
abort(reason?: any): Promise<void>;
/** Advanced: manually update the status shown by the AI menu */
setAIResponseStatus(
status:
Expand All @@ -122,12 +125,12 @@ class AIExtension {
| "user-reviewing"
| { status: "error"; error: any },
): void;
}
};
```

### `InvokeAI`

Requests to an LLM are made by calling `invokeAI` on the `AIExtension` object. This takes an `InvokeAIOptions` object as an argument.
Requests to an LLM are made by calling `invokeAI` on the `AIExtension` instance. This takes an `InvokeAIOptions` object as an argument.

```typescript
type InvokeAIOptions = {
Expand Down
1 change: 1 addition & 0 deletions examples/09-ai/01-minimal/src/App.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import { en as aiEn } from "@blocknote/xl-ai/locales";
import "@blocknote/xl-ai/style.css";

import { DefaultChatTransport } from "ai";
import { useEffect } from "react";
import { getEnv } from "./getEnv";

const BASE_URL =
Expand Down
42 changes: 41 additions & 1 deletion packages/xl-ai/src/AIExtension.ts
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ export const AIExtension = createExtension(
| {
previousRequestOptions: InvokeAIOptions;
chat: Chat<UIMessage>;
abortController: AbortController;
}
| undefined;
let autoScroll = false;
Expand Down Expand Up @@ -233,6 +234,36 @@ export const AIExtension = createExtension(
this.closeAIMenu();
},

/**
* Abort the current LLM request.
*
* This will stop the ongoing request and revert any changes made by the AI.
* Only valid when there is an active AI request in progress.
*/
async abort(reason?: any) {
const { aiMenuState } = store.state;
if (aiMenuState === "closed" || !chatSession) {
return;
}

// Only abort if the request is in progress (thinking or ai-writing)
if (
aiMenuState.status !== "thinking" &&
aiMenuState.status !== "ai-writing"
) {
return;
}

const chat = chatSession.chat;
const abortController = chatSession.abortController;

// Abort the tool call operations
abortController.abort(reason);

// Stop the chat request
await chat.stop();
},

/**
* Retry the previous LLM call.
*
Expand Down Expand Up @@ -341,6 +372,9 @@ export const AIExtension = createExtension(
editor.getExtension(ForkYDocExtension)?.fork();

try {
// Create a new AbortController for this request
const abortController = new AbortController();

if (!chatSession) {
// note: in the current implementation opts.transport is only used when creating a new chat
// (so changing transport for a subsequent call in the same chat-session is not supported)
Expand All @@ -353,9 +387,11 @@ export const AIExtension = createExtension(
sendAutomaticallyWhen: () => false,
transport: opts.transport || this.options.state.transport,
}),
abortController,
};
} else {
chatSession.previousRequestOptions = opts;
chatSession.abortController = abortController;
}
const chat = chatSession.chat;

Expand Down Expand Up @@ -439,9 +475,13 @@ export const AIExtension = createExtension(
],
},
opts.chatRequestOptions || this.options.state.chatRequestOptions,
chatSession.abortController.signal,
);

if (result.ok && chat.status !== "error") {
if (
(result.ok && chat.status !== "error") ||
abortController.signal.aborted
) {
this.setAIResponseStatus("user-reviewing");
} else {
// eslint-disable-next-line no-console
Expand Down
3 changes: 3 additions & 0 deletions packages/xl-ai/src/api/aiRequest/sendMessageWithAIRequest.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import { AIRequest } from "./types.js";
* @param aiRequest - the AI request (create using {@link buildAIRequest})
* @param message - the message to send to the LLM (optional, defaults to the last message)
* @param options - the `ChatRequestOptions` to pass to the `chat.sendMessage` method (custom metadata, body, etc)
* @param abortSignal - Optional AbortSignal to cancel ongoing tool call operations
*
* @returns the result of the tool call processing. Consumer should check both `chat.status` and `result.ok`;
 * - `chat.status` indicates if the LLM request succeeded
Expand All @@ -29,6 +30,7 @@ export async function sendMessageWithAIRequest(
aiRequest: AIRequest,
message?: Parameters<Chat<UIMessage>["sendMessage"]>[0],
options?: Parameters<Chat<UIMessage>["sendMessage"]>[1],
abortSignal?: AbortSignal,
) {
const sendingMessage = message ?? chat.lastMessage;

Expand All @@ -44,6 +46,7 @@ export async function sendMessageWithAIRequest(
aiRequest.streamTools,
chat,
aiRequest.onStart,
abortSignal,
);
options = merge(options, {
metadata: {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import {
import { updateToReplaceSteps } from "../../../prosemirror/changeset.js";
import { RebaseTool } from "../../../prosemirror/rebaseTool.js";
import { Result, streamTool } from "../../../streamTool/streamTool.js";
import { AbortError } from "../../../util/AbortError.js";
import { isEmptyParagraph } from "../../../util/emptyBlock.js";
import { validateBlockArray } from "./util/validateBlockArray.js";

Expand Down Expand Up @@ -182,7 +183,7 @@ export function createAddBlocksTool<T>(config: {
const referenceIdMap: Record<string, string> = {}; // TODO: unit test

return {
execute: async (chunk) => {
execute: async (chunk, abortSignal?: AbortSignal) => {
if (!chunk.isUpdateToPreviousOperation) {
// we have a new operation, reset the added block ids
addedBlockIds = [];
Expand Down Expand Up @@ -268,6 +269,9 @@ export function createAddBlocksTool<T>(config: {
// }

for (const step of agentSteps) {
if (abortSignal?.aborted) {
throw new AbortError("Operation was aborted");
}
if (options.withDelays) {
await delayAgentStep(step);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import {
import { updateToReplaceSteps } from "../../../prosemirror/changeset.js";
import { RebaseTool } from "../../../prosemirror/rebaseTool.js";
import { Result, streamTool } from "../../../streamTool/streamTool.js";
import { AbortError } from "../../../util/AbortError.js";

export type UpdateBlockToolCall<T> = {
type: "update";
Expand Down Expand Up @@ -177,7 +178,7 @@ export function createUpdateBlockTool<T>(config: {
}
: undefined;
return {
execute: async (chunk) => {
execute: async (chunk, abortSignal?: AbortSignal) => {
if (chunk.operation.type !== "update") {
// pass through non-update operations
return false;
Expand Down Expand Up @@ -244,6 +245,9 @@ export function createUpdateBlockTool<T>(config: {
const agentSteps = getStepsAsAgent(tr);

for (const step of agentSteps) {
if (abortSignal?.aborted) {
throw new AbortError("Operation was aborted");
}
if (options.withDelays) {
await delayAgentStep(step);
}
Expand Down
5 changes: 4 additions & 1 deletion packages/xl-ai/src/streamTool/ChunkExecutionError.ts
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
export class ChunkExecutionError extends Error {
public readonly aborted: boolean;

constructor(
message: string,
public readonly chunk: any,
options?: { cause?: unknown },
options?: { cause?: unknown; aborted?: boolean },
) {
super(message, options);
this.name = "ChunkExecutionError";
this.aborted = options?.aborted ?? false;
}
}
32 changes: 22 additions & 10 deletions packages/xl-ai/src/streamTool/StreamToolExecutor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,12 @@ export class StreamToolExecutor<T extends StreamTool<any>[]> {

/**
* @param streamTools - The StreamTools to use to apply the StreamToolCalls
* @param abortSignal - Optional AbortSignal to cancel ongoing operations
*/
constructor(private streamTools: T) {
constructor(
private streamTools: T,
private abortSignal?: AbortSignal,
) {
this.stream = this.createStream();
}

Expand Down Expand Up @@ -115,27 +119,35 @@ export class StreamToolExecutor<T extends StreamTool<any>[]> {
let handled = false;
for (const executor of executors) {
try {
const result = await executor.execute(chunk);
// Pass the signal to executor - it should handle abort internally
const result = await executor.execute(chunk, this.abortSignal);
if (result) {
controller.enqueue({ status: "ok", chunk });
handled = true;
break;
}
} catch (error) {
throw new ChunkExecutionError(
`Tool execution failed: ${getErrorMessage(error)}`,
chunk,
{
cause: error,
},
controller.error(
new ChunkExecutionError(
`Tool execution failed: ${getErrorMessage(error)}`,
chunk,
{
cause: error,
aborted: this.abortSignal?.aborted ?? false,
},
),
);
return;
}
}
if (!handled) {
const operationType = (chunk.operation as any)?.type || "unknown";
throw new Error(
`No tool could handle operation of type: ${operationType}`,
controller.error(
new Error(
`No tool could handle operation of type: ${operationType}`,
),
);
return;
}
},
});
Expand Down
15 changes: 9 additions & 6 deletions packages/xl-ai/src/streamTool/streamTool.ts
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,15 @@ export type StreamTool<T extends { type: string }> = {
* @returns the stream of operations that have not been processed (and should be passed on to execute handlers of other StreamTools)
*/
executor: () => {
execute: (chunk: {
operation: StreamToolCall<StreamTool<{ type: string }>[]>;
isUpdateToPreviousOperation: boolean;
isPossiblyPartial: boolean;
metadata: any;
}) => Promise<boolean>;
execute: (
chunk: {
operation: StreamToolCall<StreamTool<{ type: string }>[]>;
isUpdateToPreviousOperation: boolean;
isPossiblyPartial: boolean;
metadata: any;
},
abortSignal?: AbortSignal,
) => Promise<boolean>;
};
};

Expand Down
Loading
Loading