From 723db74cb0511891c14f2a18682d7bfb7f53f056 Mon Sep 17 00:00:00 2001 From: Yumiue <229866007@qq.com> Date: Sat, 30 May 2026 21:37:11 -0400 Subject: [PATCH 1/4] feat: support web multimodal session assets --- docs/gateway-rpc-api.md | 52 ++- docs/reference/gateway-error-catalog.md | 2 +- docs/reference/gateway-rpc-api.md | 40 ++- internal/cli/gateway_runtime_bridge.go | 111 ++++++- internal/cli/gateway_runtime_bridge_test.go | 11 +- internal/cli/root_test.go | 8 + .../dispatcher_integration_unix_test.go | 14 + internal/gateway/bootstrap_test.go | 14 + internal/gateway/contracts.go | 49 +++ internal/gateway/contracts_test.go | 8 + internal/gateway/multi_workspace_runtime.go | 16 + .../gateway/multi_workspace_runtime_test.go | 24 ++ internal/gateway/network_server.go | 223 ++++++++++++- internal/gateway/network_server_test.go | 290 +++++++++++++++++ internal/gateway/protocol/jsonrpc.go | 2 + internal/gateway/protocol/jsonrpc_test.go | 13 +- internal/gateway/rpc_dispatch.go | 1 + internal/gateway/rpc_dispatch_test.go | 26 +- internal/gateway/server_test.go | 18 +- internal/gateway/static_files.go | 1 + internal/gateway/types.go | 2 + internal/gateway/validate.go | 6 +- internal/gateway/validate_test.go | 37 ++- internal/runtime/input_prepare.go | 1 + internal/runtime/runtime.go | 1 + internal/session/input_preparer.go | 59 ++++ internal/session/input_preparer_test.go | 86 +++++ web/src/api/gateway.test.ts | 77 ++++- web/src/api/gateway.ts | 65 +++- web/src/api/protocol.ts | 17 +- web/src/components/chat/ChatInput.test.tsx | 85 ++++- web/src/components/chat/ChatInput.tsx | 299 ++++++++++++++++-- web/src/components/chat/MessageItem.test.tsx | 83 ++++- web/src/components/chat/MessageItem.tsx | 84 ++++- web/src/context/RuntimeProvider.tsx | 2 +- web/src/stores/useChatStore.ts | 16 +- web/src/stores/useComposerStore.test.ts | 62 +++- web/src/stores/useComposerStore.ts | 58 ++++ web/src/stores/useSessionStore.test.ts | 29 ++ web/src/stores/useSessionStore.ts | 57 +++- 40 files changed, 1989 insertions(+), 60 deletions(-) diff --git a/docs/gateway-rpc-api.md b/docs/gateway-rpc-api.md index 6425bd9ed..c3013e963 100644 --- a/docs/gateway-rpc-api.md +++ b/docs/gateway-rpc-api.md @@ -155,7 +155,8 @@ type BindStreamParams struct { ```go type RunInputMedia struct { - URI string `json:"uri"` + URI string `json:"uri,omitempty"` + AssetID string `json:"asset_id,omitempty"` MimeType string `json:"mime_type"` FileName string `json:"file_name,omitempty"` } @@ -175,6 +176,12 @@ type RunParams struct { } ``` +- 多模态图片约束: + - `type=image` 时 `media.mime_type` 必填。 + - `media.uri` 与 `media.asset_id` 必须二选一,不能同时为空或同时提供。 + - `media.uri` 仅用于后端可读取的本地路径;Web 浏览器上传图片应先通过 `POST /api/session-assets` 保存,再在 `gateway.run` 中使用 `media.asset_id` 引用。 + - `asset_id` 必须属于当前 `session_id`,不存在或跨 session 引用会在 runtime 输入准备阶段失败。 + - Response Schema: - Success(受理即返回): @@ -223,6 +230,49 @@ type RunParams struct { --- +## HTTP API: session assets + +浏览器图片上传不应把本地伪路径传给 Runtime。Web 客户端需要在发送前先创建或确认 `session_id`,再通过受鉴权保护的 HTTP API 保存图片,最后在 `gateway.run.input_parts[].media.asset_id` 中引用。 + +### POST /api/session-assets + +- Auth Required: Yes(`Authorization: Bearer `) +- Headers: + - `X-NeoCode-Workspace-Hash`: 当前工作区哈希。多工作区 Web 客户端必须发送;单工作区或旧客户端可省略并回落到默认工作区。 +- Content-Type: `multipart/form-data` +- Fields: + - `session_id`: 目标会话 ID,必填。 + - `file`: 图片文件,必填。 +- Server-side validation: + - 仅接受 `image/png`、`image/jpeg`、`image/webp`。 + - MIME 以服务端文件头检测结果为准,不信任浏览器声明。 + - 空文件返回 `400`。 + - 超过 `MaxSessionAssetBytes` 返回 `413`。 + - 非图片或不支持类型返回 `415`。 + - 未认证返回 `401`,Origin/CORS 或 ACL 拒绝返回 `403`。 + - 工作区不存在返回 `404 workspace not found`;目标 session 不在该工作区返回 `404 session not found`。 +- Response: + +```json +{ + "session_id": "sess-1", + "asset_id": "asset-1", + "mime_type": "image/png", + "size": 1024 +} +``` + +### GET /api/session-assets/{session_id}/{asset_id} + +- Auth Required: Yes(`Authorization: Bearer `) +- Headers: + - `X-NeoCode-Workspace-Hash`: 当前工作区哈希。多工作区 Web 客户端必须发送;省略时回落到默认工作区。 +- 返回图片二进制,`Content-Type` 为保存时确认的 MIME。 +- 用于历史消息缩略图按需读取。 +- 工作区不存在返回 `404 workspace not found`;不存在或不可见的 asset 返回 `404 asset not found`。 + +--- + ## Method: gateway.compact - Stability: Stable diff --git a/docs/reference/gateway-error-catalog.md b/docs/reference/gateway-error-catalog.md index 1c3de61ea..c9a6712f1 100644 --- a/docs/reference/gateway-error-catalog.md +++ b/docs/reference/gateway-error-catalog.md @@ -10,7 +10,7 @@ | --- | --- | --- | --- | --- | --- | | `invalid_frame` | `200` | `-32700` / `-32600` / `-32602` | 请求帧结构或编码不合法。包括 JSON 解析失败、请求体包含多余 JSON 值、`id/jsonrpc` 非法、`params` 严格解码失败。 | 非法 JSON;`id` 为 `null`;`params` 含未知字段。 | 不要直接重试,先修复请求构造器。 | | `invalid_action` | `200` | `-32602` | 动作参数值非法,但方法本身存在。 | `params.channel` 不在 `all/ipc/ws/sse`;`params.decision` 非 `allow_once/allow_session/reject`。 | 视为调用方输入错误,修正参数后再发。 | -| `invalid_multimodal_payload` | `200` | `-32602` | `gateway.run` 的 `input_parts` 结构或字段不满足契约。 | `image` 分片缺少 `media.uri` 或 `media.mime_type`;`text` 分片文本为空。 | 校验输入分片后重试,不做盲重试。 | +| `invalid_multimodal_payload` | `200` | `-32602` | `gateway.run` 的 `input_parts` 结构或字段不满足契约。 | `image` 分片缺少 `media.mime_type`,或 `media.uri` / `media.asset_id` 未满足二选一;`text` 分片文本为空。 | 校验输入分片后重试,不做盲重试。 | | `missing_required_field` | `200` | `-32600` / `-32602` | 缺失必填字段。请求层字段缺失多映射为 `-32600`,方法参数层字段缺失多映射为 `-32602`。 | 缺失 `id`;缺失 `params`;`cancel` 缺失 `run_id`。 | 调整参数补齐必填项再重试。 | | `unsupported_action` | `200` | `-32601` | 方法未注册或不被网关识别。 | 调用不存在的方法名。 | 客户端按能力探测降级,或升级服务端版本。 | | `internal_error` | `200` | `-32603` | 网关内部异常或未分类下游异常。 | 结果编码失败;runtime port 不可用;未知运行时错误。 | 采用指数退避重试;持续失败时告警。 | diff --git a/docs/reference/gateway-rpc-api.md b/docs/reference/gateway-rpc-api.md index 0a9c8be45..82dad5784 100644 --- a/docs/reference/gateway-rpc-api.md +++ b/docs/reference/gateway-rpc-api.md @@ -306,6 +306,13 @@ type RunParams struct { Mode string `json:"mode,omitempty"` // Agent 工作模式:build|plan,可选,默认沿用 session 当前 mode } +type RunInputMedia struct { + URI string `json:"uri,omitempty"` + AssetID string `json:"asset_id,omitempty"` + MimeType string `json:"mime_type"` + FileName string `json:"file_name,omitempty"` +} + type RunInputPart struct { Type string `json:"type"` // text|image Text string `json:"text,omitempty"` // text MUST @@ -318,7 +325,7 @@ type RunInputPart struct { 1. `input_text` 与 `input_parts` 至少一项非空。 2. `input_parts` 中: 1. `type=text` 时 `text` `MUST` 非空。 -2. `type=image` 时 `media.uri` 与 `media.mime_type` `MUST` 非空。 +2. `type=image` 时 `media.mime_type` `MUST` 非空,`media.uri` 与 `media.asset_id` `MUST` 二选一且不能同时提供。Web 上传图片应先调用 `POST /api/session-assets`,再在 `gateway.run` 中用 `asset_id` 引用。 3. 未知字段会因严格解码触发 `invalid_frame`。 4. `run_id` 归一化顺序为:显式 `run_id` > `request_id` > 网关生成 `run_`。 5. `mode` 可选值为 `"build"` 或 `"plan"`,为空时默认沿用 session 当前 mode(新会话默认为 `"build"`)。切换 mode 后,后端会更新 session 并影响后续运行的工具可用性和 prompt 策略。 @@ -397,6 +404,37 @@ sequenceDiagram G-->>C: ack(cancel) ``` +### HTTP session asset API + +浏览器图片上传使用 HTTP API,不通过 JSON-RPC 传输文件内容。客户端发送图片前需要先拥有有效 `session_id`(新会话可先调用 `gateway.createSession`)。 + +`POST /api/session-assets` + +- Auth Required: `Yes`,使用 `Authorization: Bearer `。 +- Headers: `X-NeoCode-Workspace-Hash` 携带当前工作区哈希;多工作区 Web 客户端必须发送,省略时回落到默认工作区。 +- Content-Type: `multipart/form-data`。 +- 字段:`session_id`(必填)、`file`(必填)。 +- 仅接受 PNG/JPEG/WebP;服务端按文件头检测 MIME,不信任浏览器声明。 +- 空文件返回 `400`,超出 `MaxSessionAssetBytes` 返回 `413`,不支持 MIME 返回 `415`,未认证返回 `401`,Origin/CORS 或 ACL 拒绝返回 `403`。 +- 工作区不存在返回 `404 workspace not found`;目标 session 不在该工作区返回 `404 session not found`。 +- 成功返回: + +```json +{ + "session_id": "session-1", + "asset_id": "asset-1", + "mime_type": "image/png", + "size": 1024 +} +``` + +`GET /api/session-assets/{session_id}/{asset_id}` + +- Auth Required: `Yes`。 +- Headers: `X-NeoCode-Workspace-Hash` 携带当前工作区哈希;多工作区 Web 客户端必须发送。 +- 返回图片二进制,用于历史消息缩略图。 +- 工作区不存在返回 `404 workspace not found`;不存在、跨 session 或不可见的 asset 返回 `404 asset not found`。 + Observation: 1. `gateway_requests_total{method="gateway.run",status="ok|error"}`。 diff --git a/internal/cli/gateway_runtime_bridge.go b/internal/cli/gateway_runtime_bridge.go index c3f8b61f1..11404bdc7 100644 --- a/internal/cli/gateway_runtime_bridge.go +++ b/internal/cli/gateway_runtime_bridge.go @@ -697,6 +697,66 @@ func (b *gatewayRuntimePortBridge) CreateSession(ctx context.Context, input gate return strings.TrimSpace(session.ID), nil } +// SaveSessionAsset 将浏览器上传的附件保存到当前工作区的 session asset store。 +func (b *gatewayRuntimePortBridge) SaveSessionAsset( + ctx context.Context, + input gateway.SaveSessionAssetInput, +) (gateway.SessionAssetMeta, error) { + if err := b.ensureRuntimeAccess(input.SubjectID); err != nil { + return gateway.SessionAssetMeta{}, err + } + sessionID := strings.TrimSpace(input.SessionID) + if sessionID == "" { + return gateway.SessionAssetMeta{}, gateway.ErrRuntimeResourceNotFound + } + assetStore, ok := b.sessionStore.(agentsession.AssetStore) + if !ok || assetStore == nil { + return gateway.SessionAssetMeta{}, fmt.Errorf("gateway runtime bridge: session asset store is unavailable") + } + meta, err := assetStore.SaveAsset(ctx, sessionID, input.Reader, strings.TrimSpace(input.MimeType)) + if err != nil { + return gateway.SessionAssetMeta{}, err + } + return gateway.SessionAssetMeta{ + SessionID: sessionID, + AssetID: strings.TrimSpace(meta.ID), + MimeType: strings.TrimSpace(meta.MimeType), + Size: meta.Size, + }, nil +} + +// OpenSessionAsset 打开当前工作区的会话附件,供 Gateway HTTP 读取端点流式返回。 +func (b *gatewayRuntimePortBridge) OpenSessionAsset( + ctx context.Context, + input gateway.OpenSessionAssetInput, +) (gateway.OpenSessionAssetResult, error) { + if err := b.ensureRuntimeAccess(input.SubjectID); err != nil { + return gateway.OpenSessionAssetResult{}, err + } + sessionID := strings.TrimSpace(input.SessionID) + assetID := strings.TrimSpace(input.AssetID) + if sessionID == "" || assetID == "" { + return gateway.OpenSessionAssetResult{}, gateway.ErrRuntimeResourceNotFound + } + assetStore, ok := b.sessionStore.(agentsession.AssetStore) + if !ok || assetStore == nil { + return gateway.OpenSessionAssetResult{}, fmt.Errorf("gateway runtime bridge: session asset store is unavailable") + } + reader, meta, err := assetStore.Open(ctx, sessionID, assetID) + if err != nil { + return gateway.OpenSessionAssetResult{}, err + } + return gateway.OpenSessionAssetResult{ + Reader: reader, + Meta: gateway.SessionAssetMeta{ + SessionID: sessionID, + AssetID: strings.TrimSpace(meta.ID), + MimeType: strings.TrimSpace(meta.MimeType), + Size: meta.Size, + }, + }, nil +} + // DeleteSession 删除/归档指定会话。 func (b *gatewayRuntimePortBridge) DeleteSession(ctx context.Context, input gateway.DeleteSessionInput) (bool, error) { if err := b.ensureRuntimeAccess(input.SubjectID); err != nil { @@ -1684,11 +1744,13 @@ func convertGatewayRunInput(input gateway.RunInput) agentruntime.PrepareInput { continue } path := strings.TrimSpace(part.Media.URI) - if path == "" { + assetID := strings.TrimSpace(part.Media.AssetID) + if path == "" && assetID == "" { continue } images = append(images, agentruntime.UserImageInput{ Path: path, + AssetID: assetID, MimeType: strings.TrimSpace(part.Media.MimeType), }) } @@ -1867,6 +1929,7 @@ func convertSessionMessages(messages []providertypes.Message) []gateway.SessionM convertedMessage := gateway.SessionMessage{ Role: strings.TrimSpace(message.Role), Content: renderSessionMessageContent(message.Parts), + Parts: convertProviderContentParts(message.Parts), ToolCallID: strings.TrimSpace(message.ToolCallID), IsError: message.IsError, } @@ -1885,6 +1948,52 @@ func convertSessionMessages(messages []providertypes.Message) []gateway.SessionM return converted } +// convertProviderContentParts 将 provider 通用内容分片转换为 Gateway 会话快照分片。 +func convertProviderContentParts(parts []providertypes.ContentPart) []gateway.InputPart { + if len(parts) == 0 { + return nil + } + converted := make([]gateway.InputPart, 0, len(parts)) + for _, part := range parts { + switch part.Kind { + case providertypes.ContentPartText: + if text := strings.TrimSpace(part.Text); text != "" { + converted = append(converted, gateway.InputPart{ + Type: gateway.InputPartTypeText, + Text: text, + }) + } + case providertypes.ContentPartImage: + if part.Image == nil { + continue + } + switch part.Image.SourceType { + case providertypes.ImageSourceSessionAsset: + if part.Image.Asset == nil || strings.TrimSpace(part.Image.Asset.ID) == "" { + continue + } + converted = append(converted, gateway.InputPart{ + Type: gateway.InputPartTypeImage, + Media: &gateway.Media{ + AssetID: strings.TrimSpace(part.Image.Asset.ID), + MimeType: strings.TrimSpace(part.Image.Asset.MimeType), + }, + }) + case providertypes.ImageSourceRemote: + if url := strings.TrimSpace(part.Image.URL); url != "" { + converted = append(converted, gateway.InputPart{ + Type: gateway.InputPartTypeImage, + Media: &gateway.Media{ + URI: url, + }, + }) + } + } + } + } + return converted +} + // convertRuntimePlanTodoItem 将 session 计划中的 legacy todo 项映射为 gateway 展示结构。 func convertRuntimePlanTodoItem(item agentsession.TodoItem) gateway.PlanTodoItem { required := false diff --git a/internal/cli/gateway_runtime_bridge_test.go b/internal/cli/gateway_runtime_bridge_test.go index 064fc332c..0102cbc25 100644 --- a/internal/cli/gateway_runtime_bridge_test.go +++ b/internal/cli/gateway_runtime_bridge_test.go @@ -1550,6 +1550,7 @@ func TestConvertGatewayRunInputAndSessionHelpers(t *testing.T) { {Type: gateway.InputPartTypeImage, Media: nil}, {Type: gateway.InputPartTypeImage, Media: &gateway.Media{URI: " "}}, {Type: gateway.InputPartTypeImage, Media: &gateway.Media{URI: " /tmp/a.png ", MimeType: " image/png "}}, + {Type: gateway.InputPartTypeImage, Media: &gateway.Media{AssetID: " asset-1 ", MimeType: " image/webp "}}, }, Workdir: " /tmp/work ", }) @@ -1559,8 +1560,14 @@ func TestConvertGatewayRunInputAndSessionHelpers(t *testing.T) { if converted.Text != "base\ntext" { t.Fatalf("text = %q, want %q", converted.Text, "base\ntext") } - if len(converted.Images) != 1 || converted.Images[0].Path != "/tmp/a.png" { - t.Fatalf("images = %#v, want one valid image", converted.Images) + if len(converted.Images) != 2 { + t.Fatalf("images = %#v, want two valid images", converted.Images) + } + if converted.Images[0].Path != "/tmp/a.png" || converted.Images[0].MimeType != "image/png" { + t.Fatalf("local image = %#v, want normalized path/mime", converted.Images[0]) + } + if converted.Images[1].AssetID != "asset-1" || converted.Images[1].MimeType != "image/webp" { + t.Fatalf("asset image = %#v, want normalized asset_id/mime", converted.Images[1]) } if got := renderSessionMessageContent(nil); got != "" { diff --git a/internal/cli/root_test.go b/internal/cli/root_test.go index 08ca8bbdc..2f5ebb98c 100644 --- a/internal/cli/root_test.go +++ b/internal/cli/root_test.go @@ -1100,6 +1100,14 @@ func (stubRuntimePort) CreateSession(context.Context, gateway.CreateSessionInput return "", nil } +func (stubRuntimePort) SaveSessionAsset(context.Context, gateway.SaveSessionAssetInput) (gateway.SessionAssetMeta, error) { + return gateway.SessionAssetMeta{}, nil +} + +func (stubRuntimePort) OpenSessionAsset(context.Context, gateway.OpenSessionAssetInput) (gateway.OpenSessionAssetResult, error) { + return gateway.OpenSessionAssetResult{}, nil +} + func (stubRuntimePort) ListSessionTodos(context.Context, gateway.ListSessionTodosInput) (gateway.TodoSnapshot, error) { return gateway.TodoSnapshot{}, nil } diff --git a/internal/gateway/adapters/urlscheme/dispatcher_integration_unix_test.go b/internal/gateway/adapters/urlscheme/dispatcher_integration_unix_test.go index 70377cc7e..7e54f27a5 100644 --- a/internal/gateway/adapters/urlscheme/dispatcher_integration_unix_test.go +++ b/internal/gateway/adapters/urlscheme/dispatcher_integration_unix_test.go @@ -189,6 +189,20 @@ func (s *urlschemeIntegrationRuntimeStub) CreateSession( return strings.TrimSpace("session-review-integration"), nil } +func (s *urlschemeIntegrationRuntimeStub) SaveSessionAsset( + context.Context, + gateway.SaveSessionAssetInput, +) (gateway.SessionAssetMeta, error) { + return gateway.SessionAssetMeta{}, nil +} + +func (s *urlschemeIntegrationRuntimeStub) OpenSessionAsset( + context.Context, + gateway.OpenSessionAssetInput, +) (gateway.OpenSessionAssetResult, error) { + return gateway.OpenSessionAssetResult{}, nil +} + func (s *urlschemeIntegrationRuntimeStub) ListSessionTodos( context.Context, gateway.ListSessionTodosInput, diff --git a/internal/gateway/bootstrap_test.go b/internal/gateway/bootstrap_test.go index 69a788d34..945201d43 100644 --- a/internal/gateway/bootstrap_test.go +++ b/internal/gateway/bootstrap_test.go @@ -311,6 +311,14 @@ func (s *bootstrapRuntimeStub) CreateSession(ctx context.Context, input CreateSe return strings.TrimSpace(input.SessionID), nil } +func (s *bootstrapRuntimeStub) SaveSessionAsset(ctx context.Context, input SaveSessionAssetInput) (SessionAssetMeta, error) { + return SessionAssetMeta{SessionID: input.SessionID, AssetID: "asset_test", MimeType: input.MimeType}, nil +} + +func (s *bootstrapRuntimeStub) OpenSessionAsset(ctx context.Context, input OpenSessionAssetInput) (OpenSessionAssetResult, error) { + return OpenSessionAssetResult{}, nil +} + func (s *bootstrapRuntimeStub) ListCheckpoints(ctx context.Context, input ListCheckpointsInput) ([]CheckpointEntry, error) { if s != nil && s.listCheckpointsFn != nil { return s.listCheckpointsFn(ctx, input) @@ -5335,6 +5343,12 @@ func (runtimeOnlyStub) GetRuntimeSnapshot(ctx context.Context, input GetRuntimeS func (runtimeOnlyStub) CreateSession(ctx context.Context, input CreateSessionInput) (string, error) { return "", nil } +func (runtimeOnlyStub) SaveSessionAsset(context.Context, SaveSessionAssetInput) (SessionAssetMeta, error) { + return SessionAssetMeta{}, nil +} +func (runtimeOnlyStub) OpenSessionAsset(context.Context, OpenSessionAssetInput) (OpenSessionAssetResult, error) { + return OpenSessionAssetResult{}, nil +} func (runtimeOnlyStub) DeleteSession(ctx context.Context, input DeleteSessionInput) (bool, error) { return false, nil } diff --git a/internal/gateway/contracts.go b/internal/gateway/contracts.go index 18d62a61f..43eed5168 100644 --- a/internal/gateway/contracts.go +++ b/internal/gateway/contracts.go @@ -2,6 +2,7 @@ package gateway import ( "context" + "io" "time" "neo-code/internal/config" @@ -227,6 +228,48 @@ type CreateSessionInput struct { SessionID string } +// SessionAssetMeta 描述 Gateway 可见的会话附件元数据。 +type SessionAssetMeta struct { + // SessionID 是附件所属会话标识。 + SessionID string `json:"session_id"` + // AssetID 是附件标识。 + AssetID string `json:"asset_id"` + // MimeType 是服务端确认后的 MIME 类型。 + MimeType string `json:"mime_type"` + // Size 是附件原始字节数。 + Size int64 `json:"size"` +} + +// SaveSessionAssetInput 表示保存浏览器上传附件的下游输入。 +type SaveSessionAssetInput struct { + // SubjectID 是请求方身份主体标识。 + SubjectID string + // SessionID 是附件所属会话标识。 + SessionID string + // Reader 是附件二进制内容。 + Reader io.Reader + // MimeType 是服务端探测确认后的 MIME 类型。 + MimeType string +} + +// OpenSessionAssetInput 表示读取会话附件的下游输入。 +type OpenSessionAssetInput struct { + // SubjectID 是请求方身份主体标识。 + SubjectID string + // SessionID 是附件所属会话标识。 + SessionID string + // AssetID 是附件标识。 + AssetID string +} + +// OpenSessionAssetResult 表示打开会话附件后的读取结果。 +type OpenSessionAssetResult struct { + // Reader 是附件内容流,调用方负责关闭。 + Reader io.ReadCloser + // Meta 是附件元数据。 + Meta SessionAssetMeta +} + // DeleteSessionInput 表示 gateway.deleteSession 动作的下游输入。 type DeleteSessionInput struct { // SubjectID 是请求方身份主体标识。 @@ -694,6 +737,8 @@ type SessionMessage struct { Role string `json:"role"` // Content 是消息内容。 Content string `json:"content"` + // Parts 是消息的结构化多模态分片,供支持图片的客户端渲染。 + Parts []InputPart `json:"parts,omitempty"` // ToolCalls 是 assistant 发起的工具调用元数据。 ToolCalls []ToolCall `json:"tool_calls,omitempty"` // ToolCallID 是工具消息关联的调用标识。 @@ -920,6 +965,10 @@ type RuntimePort interface { GetRuntimeSnapshot(ctx context.Context, input GetRuntimeSnapshotInput) (RuntimeSnapshot, error) // CreateSession 创建并返回可用会话标识。 CreateSession(ctx context.Context, input CreateSessionInput) (string, error) + // SaveSessionAsset 保存会话附件并返回元数据。 + SaveSessionAsset(ctx context.Context, input SaveSessionAssetInput) (SessionAssetMeta, error) + // OpenSessionAsset 打开会话附件供 HTTP 读取接口返回。 + OpenSessionAsset(ctx context.Context, input OpenSessionAssetInput) (OpenSessionAssetResult, error) // DeleteSession 删除/归档指定会话。 DeleteSession(ctx context.Context, input DeleteSessionInput) (bool, error) // RenameSession 重命名指定会话。 diff --git a/internal/gateway/contracts_test.go b/internal/gateway/contracts_test.go index de1ef52cb..d67e4b57a 100644 --- a/internal/gateway/contracts_test.go +++ b/internal/gateway/contracts_test.go @@ -147,6 +147,14 @@ func (s *runtimePortCompileStub) CreateSession(_ context.Context, _ CreateSessio return "", nil } +func (s *runtimePortCompileStub) SaveSessionAsset(_ context.Context, _ SaveSessionAssetInput) (SessionAssetMeta, error) { + return SessionAssetMeta{}, nil +} + +func (s *runtimePortCompileStub) OpenSessionAsset(_ context.Context, _ OpenSessionAssetInput) (OpenSessionAssetResult, error) { + return OpenSessionAssetResult{}, nil +} + func (s *runtimePortCompileStub) ListCheckpoints(_ context.Context, _ ListCheckpointsInput) ([]CheckpointEntry, error) { return nil, nil } diff --git a/internal/gateway/multi_workspace_runtime.go b/internal/gateway/multi_workspace_runtime.go index 33c3bf52b..a30b94cf1 100644 --- a/internal/gateway/multi_workspace_runtime.go +++ b/internal/gateway/multi_workspace_runtime.go @@ -402,6 +402,22 @@ func (m *MultiWorkspaceRuntime) CreateSession(ctx context.Context, input CreateS return port.CreateSession(ctx, input) } +func (m *MultiWorkspaceRuntime) SaveSessionAsset(ctx context.Context, input SaveSessionAssetInput) (SessionAssetMeta, error) { + port, err := m.getPort(ctx) + if err != nil { + return SessionAssetMeta{}, err + } + return port.SaveSessionAsset(ctx, input) +} + +func (m *MultiWorkspaceRuntime) OpenSessionAsset(ctx context.Context, input OpenSessionAssetInput) (OpenSessionAssetResult, error) { + port, err := m.getPort(ctx) + if err != nil { + return OpenSessionAssetResult{}, err + } + return port.OpenSessionAsset(ctx, input) +} + func (m *MultiWorkspaceRuntime) DeleteSession(ctx context.Context, input DeleteSessionInput) (bool, error) { port, err := m.getPort(ctx) if err != nil { diff --git a/internal/gateway/multi_workspace_runtime_test.go b/internal/gateway/multi_workspace_runtime_test.go index f4919c7ef..4dfa084b4 100644 --- a/internal/gateway/multi_workspace_runtime_test.go +++ b/internal/gateway/multi_workspace_runtime_test.go @@ -28,6 +28,8 @@ type recordingPort struct { approvePlanCalls atomic.Int32 resolveUserCalls atomic.Int32 cancelCalls atomic.Int32 + saveAssetCalls atomic.Int32 + openAssetCalls atomic.Int32 closed atomic.Int32 closeOnce sync.Once @@ -135,6 +137,16 @@ func (p *recordingPort) CreateSession(_ context.Context, _ CreateSessionInput) ( return p.id, nil } +func (p *recordingPort) SaveSessionAsset(_ context.Context, input SaveSessionAssetInput) (SessionAssetMeta, error) { + p.saveAssetCalls.Add(1) + return SessionAssetMeta{SessionID: input.SessionID, AssetID: p.id, MimeType: input.MimeType}, nil +} + +func (p *recordingPort) OpenSessionAsset(_ context.Context, input OpenSessionAssetInput) (OpenSessionAssetResult, error) { + p.openAssetCalls.Add(1) + return OpenSessionAssetResult{Meta: SessionAssetMeta{SessionID: input.SessionID, AssetID: input.AssetID}}, nil +} + func (p *recordingPort) DeleteSession(_ context.Context, _ DeleteSessionInput) (bool, error) { return true, nil } @@ -783,6 +795,12 @@ func TestMultiWorkspaceRuntime_RoutingMatrix(t *testing.T) { if _, err := mw.ExecuteSystemTool(alphaCtx, ExecuteSystemToolInput{}); err != nil { t.Fatalf("ExecuteSystemTool alpha: %v", err) } + if _, err := mw.SaveSessionAsset(betaCtx, SaveSessionAssetInput{SessionID: "s-1", MimeType: "image/png"}); err != nil { + t.Fatalf("SaveSessionAsset beta: %v", err) + } + if _, err := mw.OpenSessionAsset(alphaCtx, OpenSessionAssetInput{SessionID: "s-1", AssetID: "asset-1"}); err != nil { + t.Fatalf("OpenSessionAsset alpha: %v", err) + } alphaPort := builder.portFor(alpha.Path) betaPort := builder.portFor(beta.Path) @@ -801,6 +819,12 @@ func TestMultiWorkspaceRuntime_RoutingMatrix(t *testing.T) { if got := alphaPort.executeSysCalls.Load(); got != 1 { t.Fatalf("alpha ExecuteSystemTool calls = %d, want 1", got) } + if got := betaPort.saveAssetCalls.Load(); got != 1 { + t.Fatalf("beta SaveSessionAsset calls = %d, want 1", got) + } + if got := alphaPort.openAssetCalls.Load(); got != 1 { + t.Fatalf("alpha OpenSessionAsset calls = %d, want 1", got) + } } func TestMultiWorkspaceRuntime_ListWorkspacesMatchesIndex(t *testing.T) { diff --git a/internal/gateway/network_server.go b/internal/gateway/network_server.go index cc78a2bf8..1f4e7a4fd 100644 --- a/internal/gateway/network_server.go +++ b/internal/gateway/network_server.go @@ -12,6 +12,7 @@ import ( "net" "net/http" "os" + "path" "strconv" "strings" "sync" @@ -21,6 +22,7 @@ import ( "golang.org/x/net/websocket" "neo-code/internal/gateway/protocol" + agentsession "neo-code/internal/session" ) const ( @@ -40,6 +42,8 @@ const ( DefaultNetworkMaxStreamConnections = 128 // DefaultWSUnauthenticatedTimeout 定义 WS 未认证连接的最大等待时间。 DefaultWSUnauthenticatedTimeout = 3 * time.Second + // SessionAssetWorkspaceHeader 定义 Web 上传/读取会话附件时携带当前工作区的 HTTP Header。 + SessionAssetWorkspaceHeader = "X-NeoCode-Workspace-Hash" ) var ( @@ -367,6 +371,12 @@ func (s *NetworkServer) buildHandler(runtimePort RuntimePort) http.Handler { mux.HandleFunc("/rpc", func(writer http.ResponseWriter, request *http.Request) { s.handleRPCRequest(writer, request, runtimePort) }) + mux.HandleFunc("/api/session-assets", func(writer http.ResponseWriter, request *http.Request) { + s.handleSessionAssetUpload(writer, request, runtimePort) + }) + mux.HandleFunc("/api/session-assets/", func(writer http.ResponseWriter, request *http.Request) { + s.handleSessionAssetRead(writer, request, runtimePort) + }) mux.Handle("/ws", websocket.Server{ Handshake: func(_ *websocket.Config, request *http.Request) error { return s.validateWebSocketOrigin(request) @@ -387,6 +397,217 @@ func (s *NetworkServer) buildHandler(runtimePort RuntimePort) http.Handler { return mux } +// handleSessionAssetUpload 接收浏览器上传图片,并保存为当前会话的 session asset。 +func (s *NetworkServer) handleSessionAssetUpload(writer http.ResponseWriter, request *http.Request, runtimePort RuntimePort) { + if request.Method != http.MethodPost { + http.Error(writer, "method not allowed", http.StatusMethodNotAllowed) + return + } + subjectID, ok := s.authenticatedHTTPSubjectID(request) + if !ok { + http.Error(writer, "unauthorized", http.StatusUnauthorized) + return + } + if runtimePort == nil { + writeJSONResponse(writer, http.StatusServiceUnavailable, map[string]string{"error": "runtime unavailable"}) + return + } + + limit := agentsession.MaxSessionAssetBytes + request.Body = http.MaxBytesReader(writer, request.Body, limit+(1<<20)) + if err := request.ParseMultipartForm(limit + 4096); err != nil { + if strings.Contains(strings.ToLower(err.Error()), "too large") { + writeJSONResponse(writer, http.StatusRequestEntityTooLarge, map[string]string{"error": "asset is too large"}) + return + } + writeJSONResponse(writer, http.StatusBadRequest, map[string]string{"error": "invalid multipart form"}) + return + } + + sessionID := strings.TrimSpace(request.FormValue("session_id")) + if sessionID == "" { + writeJSONResponse(writer, http.StatusBadRequest, map[string]string{"error": "session_id is required"}) + return + } + + file, _, err := request.FormFile("file") + if err != nil { + writeJSONResponse(writer, http.StatusBadRequest, map[string]string{"error": "file is required"}) + return + } + defer func() { + _ = file.Close() + }() + + payload, err := io.ReadAll(io.LimitReader(file, limit+1)) + if err != nil { + writeJSONResponse(writer, http.StatusBadRequest, map[string]string{"error": "read uploaded file failed"}) + return + } + if len(payload) == 0 { + writeJSONResponse(writer, http.StatusBadRequest, map[string]string{"error": "file is empty"}) + return + } + if int64(len(payload)) > limit { + writeJSONResponse(writer, http.StatusRequestEntityTooLarge, map[string]string{"error": "asset is too large"}) + return + } + + mimeType := detectAllowedUploadImageMime(payload) + if mimeType == "" { + writeJSONResponse(writer, http.StatusUnsupportedMediaType, map[string]string{"error": "unsupported image type"}) + return + } + + meta, err := runtimePort.SaveSessionAsset(sessionAssetRequestContext(request), SaveSessionAssetInput{ + SubjectID: subjectID, + SessionID: sessionID, + Reader: bytes.NewReader(payload), + MimeType: mimeType, + }) + if err != nil { + writeSessionAssetUploadHTTPError(writer, err) + return + } + writeJSONResponse(writer, http.StatusOK, meta) +} + +// handleSessionAssetRead 读取会话图片附件,供 Web 历史消息缩略图展示。 +func (s *NetworkServer) handleSessionAssetRead(writer http.ResponseWriter, request *http.Request, runtimePort RuntimePort) { + if request.Method != http.MethodGet { + http.Error(writer, "method not allowed", http.StatusMethodNotAllowed) + return + } + subjectID, ok := s.authenticatedHTTPSubjectID(request) + if !ok { + http.Error(writer, "unauthorized", http.StatusUnauthorized) + return + } + if runtimePort == nil { + writeJSONResponse(writer, http.StatusServiceUnavailable, map[string]string{"error": "runtime unavailable"}) + return + } + + sessionID, assetID, ok := parseSessionAssetPath(request.URL.Path) + if !ok { + http.NotFound(writer, request) + return + } + result, err := runtimePort.OpenSessionAsset(sessionAssetRequestContext(request), OpenSessionAssetInput{ + SubjectID: subjectID, + SessionID: sessionID, + AssetID: assetID, + }) + if err != nil { + writeSessionAssetReadHTTPError(writer, err) + return + } + defer func() { + _ = result.Reader.Close() + }() + + writer.Header().Set("Content-Type", result.Meta.MimeType) + if result.Meta.Size > 0 { + writer.Header().Set("Content-Length", strconv.FormatInt(result.Meta.Size, 10)) + } + writer.Header().Set("Cache-Control", "private, max-age=300") + _, _ = io.Copy(writer, result.Reader) +} + +// sessionAssetRequestContext 将 HTTP Header 中的工作区哈希注入请求上下文,供多工作区 Runtime 路由。 +func sessionAssetRequestContext(request *http.Request) context.Context { + if request == nil { + return context.Background() + } + workspaceHash := strings.TrimSpace(request.Header.Get(SessionAssetWorkspaceHeader)) + if workspaceHash == "" { + return request.Context() + } + state := NewConnectionWorkspaceState() + state.SetWorkspaceHash(workspaceHash) + return WithConnectionWorkspaceState(request.Context(), state) +} + +// authenticatedHTTPSubjectID 校验 HTTP Bearer Token 并返回主体标识。 +func (s *NetworkServer) authenticatedHTTPSubjectID(request *http.Request) (string, bool) { + if s.authenticator == nil { + return "", false + } + token := extractBearerToken(request.Header.Get("Authorization")) + subjectID, ok := s.authenticator.ResolveSubjectID(token) + if !ok || strings.TrimSpace(subjectID) == "" { + return "", false + } + return strings.TrimSpace(subjectID), true +} + +// detectAllowedUploadImageMime 用文件头确认上传图片类型,只允许 PNG/JPEG/WebP。 +func detectAllowedUploadImageMime(payload []byte) string { + if len(payload) == 0 { + return "" + } + probe := payload + if len(probe) > 512 { + probe = probe[:512] + } + mimeType := strings.ToLower(strings.TrimSpace(http.DetectContentType(probe))) + switch mimeType { + case "image/png", "image/jpeg", "image/webp": + return mimeType + default: + return "" + } +} + +// parseSessionAssetPath 从 /api/session-assets/{session_id}/{asset_id} 提取路径参数。 +func parseSessionAssetPath(rawPath string) (string, string, bool) { + cleanPath := path.Clean("/" + strings.TrimSpace(rawPath)) + const prefix = "/api/session-assets/" + if !strings.HasPrefix(cleanPath, prefix) { + return "", "", false + } + parts := strings.Split(strings.TrimPrefix(cleanPath, prefix), "/") + if len(parts) != 2 { + return "", "", false + } + sessionID := strings.TrimSpace(parts[0]) + assetID := strings.TrimSpace(parts[1]) + return sessionID, assetID, sessionID != "" && assetID != "" +} + +// writeSessionAssetUploadHTTPError 将上传阶段的下游错误映射为明确 HTTP 状态。 +func writeSessionAssetUploadHTTPError(writer http.ResponseWriter, err error) { + writeSessionAssetHTTPError(writer, err, "session not found") +} + +// writeSessionAssetReadHTTPError 将读取阶段的下游错误映射为明确 HTTP 状态。 +func writeSessionAssetReadHTTPError(writer http.ResponseWriter, err error) { + writeSessionAssetHTTPError(writer, err, "asset not found") +} + +// writeSessionAssetHTTPError 将下游附件错误映射为明确 HTTP 状态。 +func writeSessionAssetHTTPError(writer http.ResponseWriter, err error, notFoundMessage string) { + if err == nil { + writeJSONResponse(writer, http.StatusInternalServerError, map[string]string{"error": "unknown asset error"}) + return + } + message := strings.ToLower(err.Error()) + switch { + case strings.Contains(message, "workspace") && strings.Contains(message, "not found"): + writeJSONResponse(writer, http.StatusNotFound, map[string]string{"error": "workspace not found"}) + case errors.Is(err, os.ErrNotExist) || errors.Is(err, ErrRuntimeResourceNotFound): + writeJSONResponse(writer, http.StatusNotFound, map[string]string{"error": notFoundMessage}) + case strings.Contains(message, "asset size exceeds"): + writeJSONResponse(writer, http.StatusRequestEntityTooLarge, map[string]string{"error": err.Error()}) + case strings.Contains(message, "unsupported") || strings.Contains(message, "not an image"): + writeJSONResponse(writer, http.StatusUnsupportedMediaType, map[string]string{"error": err.Error()}) + case strings.Contains(message, "access denied"): + writeJSONResponse(writer, http.StatusForbidden, map[string]string{"error": "access denied"}) + default: + writeJSONResponse(writer, http.StatusBadRequest, map[string]string{"error": err.Error()}) + } +} + // withCORS 为网络入口注入 CORS 头,仅对白名单 Origin 回显允许值。 // WebSocket 升级请求不受 CORS 约束,直接放行交予 WS 握手阶段的 Origin 校验。 func (s *NetworkServer) withCORS(next http.Handler) http.Handler { @@ -406,7 +627,7 @@ func (s *NetworkServer) withCORS(next http.Handler) http.Handler { } writer.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS") - writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization") + writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, "+SessionAssetWorkspaceHeader) if request.Method == http.MethodOptions { writer.WriteHeader(http.StatusNoContent) return diff --git a/internal/gateway/network_server_test.go b/internal/gateway/network_server_test.go index 96301525c..e0257d33e 100644 --- a/internal/gateway/network_server_test.go +++ b/internal/gateway/network_server_test.go @@ -2,10 +2,13 @@ package gateway import ( "bufio" + "bytes" "context" "encoding/json" + "fmt" "io" "log" + "mime/multipart" "net/http" "net/http/httptest" "strings" @@ -15,6 +18,7 @@ import ( "golang.org/x/net/websocket" "neo-code/internal/gateway/protocol" + agentsession "neo-code/internal/session" ) func TestResolveNetworkListenAddress(t *testing.T) { @@ -400,6 +404,253 @@ func TestNetworkServerRPCErrorBranches(t *testing.T) { }) } +func TestNetworkServerSessionAssetUploadAndRead(t *testing.T) { + payload := gatewayMinimalPNGBytes() + var capturedUpload SaveSessionAssetInput + runtimePort := &runtimePortEventStub{ + saveAssetFn: func(_ context.Context, input SaveSessionAssetInput) (SessionAssetMeta, error) { + capturedUpload = input + got, err := io.ReadAll(input.Reader) + if err != nil { + t.Fatalf("read uploaded asset: %v", err) + } + if !bytes.Equal(got, payload) { + t.Fatalf("uploaded payload mismatch") + } + return SessionAssetMeta{ + SessionID: input.SessionID, + AssetID: "asset-1", + MimeType: input.MimeType, + Size: int64(len(got)), + }, nil + }, + openAssetFn: func(_ context.Context, input OpenSessionAssetInput) (OpenSessionAssetResult, error) { + if input.SubjectID != "local_admin" || input.SessionID != "session-1" || input.AssetID != "asset-1" { + t.Fatalf("open input = %+v, want subject/session/asset", input) + } + return OpenSessionAssetResult{ + Reader: io.NopCloser(bytes.NewReader(payload)), + Meta: SessionAssetMeta{ + SessionID: input.SessionID, + AssetID: input.AssetID, + MimeType: "image/png", + Size: int64(len(payload)), + }, + }, nil + }, + } + server := &NetworkServer{authenticator: staticTokenAuthenticator{token: "gateway-token"}} + handler := server.buildHandler(runtimePort) + + uploadRequest := newSessionAssetUploadRequest(t, "session-1", "a.png", payload) + uploadRequest.Header.Set("Authorization", "Bearer gateway-token") + uploadRecorder := httptest.NewRecorder() + handler.ServeHTTP(uploadRecorder, uploadRequest) + if uploadRecorder.Code != http.StatusOK { + t.Fatalf("upload status = %d body=%s", uploadRecorder.Code, uploadRecorder.Body.String()) + } + var uploadResponse SessionAssetMeta + if err := json.Unmarshal(uploadRecorder.Body.Bytes(), &uploadResponse); err != nil { + t.Fatalf("decode upload response: %v", err) + } + if uploadResponse.AssetID != "asset-1" || uploadResponse.MimeType != "image/png" || uploadResponse.Size != int64(len(payload)) { + t.Fatalf("upload response = %+v", uploadResponse) + } + if capturedUpload.SubjectID != "local_admin" || capturedUpload.SessionID != "session-1" || capturedUpload.MimeType != "image/png" { + t.Fatalf("captured upload = %+v", capturedUpload) + } + + readRequest := httptest.NewRequest(http.MethodGet, "/api/session-assets/session-1/asset-1", nil) + readRequest.Header.Set("Authorization", "Bearer gateway-token") + readRecorder := httptest.NewRecorder() + handler.ServeHTTP(readRecorder, readRequest) + if readRecorder.Code != http.StatusOK { + t.Fatalf("read status = %d body=%s", readRecorder.Code, readRecorder.Body.String()) + } + if got := readRecorder.Header().Get("Content-Type"); got != "image/png" { + t.Fatalf("read content-type = %q, want image/png", got) + } + if !bytes.Equal(readRecorder.Body.Bytes(), payload) { + t.Fatalf("read payload mismatch") + } +} + +func TestNetworkServerSessionAssetWorkspaceHeader(t *testing.T) { + payload := gatewayMinimalPNGBytes() + runtimePort := &runtimePortEventStub{ + saveAssetFn: func(ctx context.Context, input SaveSessionAssetInput) (SessionAssetMeta, error) { + if got := WorkspaceHashFromContext(ctx); got != "workspace-b" { + t.Fatalf("upload workspace hash = %q, want workspace-b", got) + } + return SessionAssetMeta{ + SessionID: input.SessionID, + AssetID: "asset-1", + MimeType: input.MimeType, + Size: int64(len(payload)), + }, nil + }, + openAssetFn: func(ctx context.Context, input OpenSessionAssetInput) (OpenSessionAssetResult, error) { + if got := WorkspaceHashFromContext(ctx); got != "workspace-b" { + t.Fatalf("read workspace hash = %q, want workspace-b", got) + } + return OpenSessionAssetResult{ + Reader: io.NopCloser(bytes.NewReader(payload)), + Meta: SessionAssetMeta{ + SessionID: input.SessionID, + AssetID: input.AssetID, + MimeType: "image/png", + Size: int64(len(payload)), + }, + }, nil + }, + } + server := &NetworkServer{authenticator: staticTokenAuthenticator{token: "gateway-token"}} + handler := server.buildHandler(runtimePort) + + uploadRequest := newSessionAssetUploadRequest(t, "session-1", "a.png", payload) + uploadRequest.Header.Set("Authorization", "Bearer gateway-token") + uploadRequest.Header.Set(SessionAssetWorkspaceHeader, "workspace-b") + uploadRecorder := httptest.NewRecorder() + handler.ServeHTTP(uploadRecorder, uploadRequest) + if uploadRecorder.Code != http.StatusOK { + t.Fatalf("upload status = %d body=%s", uploadRecorder.Code, uploadRecorder.Body.String()) + } + + readRequest := httptest.NewRequest(http.MethodGet, "/api/session-assets/session-1/asset-1", nil) + readRequest.Header.Set("Authorization", "Bearer gateway-token") + readRequest.Header.Set(SessionAssetWorkspaceHeader, "workspace-b") + readRecorder := httptest.NewRecorder() + handler.ServeHTTP(readRecorder, readRequest) + if readRecorder.Code != http.StatusOK { + t.Fatalf("read status = %d body=%s", readRecorder.Code, readRecorder.Body.String()) + } +} + +func TestNetworkServerSessionAssetWorkspaceHeaderEmptyFallback(t *testing.T) { + runtimePort := &runtimePortEventStub{ + saveAssetFn: func(ctx context.Context, input SaveSessionAssetInput) (SessionAssetMeta, error) { + if got := WorkspaceHashFromContext(ctx); got != "" { + t.Fatalf("workspace hash = %q, want empty fallback", got) + } + return SessionAssetMeta{ + SessionID: input.SessionID, + AssetID: "asset-1", + MimeType: input.MimeType, + Size: int64(len(gatewayMinimalPNGBytes())), + }, nil + }, + } + server := &NetworkServer{authenticator: staticTokenAuthenticator{token: "gateway-token"}} + handler := server.buildHandler(runtimePort) + + request := newSessionAssetUploadRequest(t, "session-1", "a.png", gatewayMinimalPNGBytes()) + request.Header.Set("Authorization", "Bearer gateway-token") + recorder := httptest.NewRecorder() + handler.ServeHTTP(recorder, request) + if recorder.Code != http.StatusOK { + t.Fatalf("status = %d body=%s", recorder.Code, recorder.Body.String()) + } +} + +func TestNetworkServerSessionAssetUploadErrors(t *testing.T) { + runtimePort := &runtimePortEventStub{} + server := &NetworkServer{authenticator: staticTokenAuthenticator{token: "gateway-token"}} + handler := server.withCORS(server.buildHandler(runtimePort)) + + t.Run("unauthorized", func(t *testing.T) { + request := newSessionAssetUploadRequest(t, "session-1", "a.png", gatewayMinimalPNGBytes()) + recorder := httptest.NewRecorder() + handler.ServeHTTP(recorder, request) + if recorder.Code != http.StatusUnauthorized { + t.Fatalf("status = %d, want %d", recorder.Code, http.StatusUnauthorized) + } + }) + + t.Run("forbidden origin", func(t *testing.T) { + request := newSessionAssetUploadRequest(t, "session-1", "a.png", gatewayMinimalPNGBytes()) + request.Header.Set("Authorization", "Bearer gateway-token") + request.Header.Set("Origin", "http://evil.example") + recorder := httptest.NewRecorder() + handler.ServeHTTP(recorder, request) + if recorder.Code != http.StatusForbidden { + t.Fatalf("status = %d, want %d", recorder.Code, http.StatusForbidden) + } + }) + + t.Run("non image", func(t *testing.T) { + request := newSessionAssetUploadRequest(t, "session-1", "bad.txt", []byte("not an image")) + request.Header.Set("Authorization", "Bearer gateway-token") + recorder := httptest.NewRecorder() + handler.ServeHTTP(recorder, request) + if recorder.Code != http.StatusUnsupportedMediaType { + t.Fatalf("status = %d, want %d", recorder.Code, http.StatusUnsupportedMediaType) + } + }) + + t.Run("empty file", func(t *testing.T) { + request := newSessionAssetUploadRequest(t, "session-1", "empty.png", nil) + request.Header.Set("Authorization", "Bearer gateway-token") + recorder := httptest.NewRecorder() + handler.ServeHTTP(recorder, request) + if recorder.Code != http.StatusBadRequest { + t.Fatalf("status = %d, want %d", recorder.Code, http.StatusBadRequest) + } + }) + + t.Run("oversized file", func(t *testing.T) { + request := newSessionAssetUploadRequest( + t, + "session-1", + "huge.png", + bytes.Repeat([]byte{0}, int(agentsession.MaxSessionAssetBytes)+1), + ) + request.Header.Set("Authorization", "Bearer gateway-token") + recorder := httptest.NewRecorder() + handler.ServeHTTP(recorder, request) + if recorder.Code != http.StatusRequestEntityTooLarge { + t.Fatalf("status = %d, want %d", recorder.Code, http.StatusRequestEntityTooLarge) + } + }) + + t.Run("workspace not found", func(t *testing.T) { + runtimePort := &runtimePortEventStub{ + saveAssetFn: func(context.Context, SaveSessionAssetInput) (SessionAssetMeta, error) { + return SessionAssetMeta{}, fmt.Errorf("%w: workspace missing not found", ErrRuntimeResourceNotFound) + }, + } + handler := server.withCORS(server.buildHandler(runtimePort)) + request := newSessionAssetUploadRequest(t, "session-1", "a.png", gatewayMinimalPNGBytes()) + request.Header.Set("Authorization", "Bearer gateway-token") + request.Header.Set(SessionAssetWorkspaceHeader, "missing") + recorder := httptest.NewRecorder() + handler.ServeHTTP(recorder, request) + if recorder.Code != http.StatusNotFound { + t.Fatalf("status = %d, want %d", recorder.Code, http.StatusNotFound) + } + if !strings.Contains(recorder.Body.String(), "workspace not found") { + t.Fatalf("body = %s, want workspace not found", recorder.Body.String()) + } + }) +} + +func TestNetworkServerSessionAssetReadNotFound(t *testing.T) { + runtimePort := &runtimePortEventStub{ + openAssetFn: func(context.Context, OpenSessionAssetInput) (OpenSessionAssetResult, error) { + return OpenSessionAssetResult{}, ErrRuntimeResourceNotFound + }, + } + server := &NetworkServer{authenticator: staticTokenAuthenticator{token: "gateway-token"}} + handler := server.buildHandler(runtimePort) + + request := httptest.NewRequest(http.MethodGet, "/api/session-assets/session-1/missing", nil) + request.Header.Set("Authorization", "Bearer gateway-token") + recorder := httptest.NewRecorder() + handler.ServeHTTP(recorder, request) + if recorder.Code != http.StatusNotFound { + t.Fatalf("status = %d, want %d", recorder.Code, http.StatusNotFound) + } +} + func TestNetworkServerWebSocketAndSSEPing(t *testing.T) { server := newTestNetworkServer(t, NetworkServerOptions{}) testContext, cancel := context.WithCancel(context.Background()) @@ -1322,6 +1573,45 @@ type noFlushResponseWriter struct { body strings.Builder } +func newSessionAssetUploadRequest(t *testing.T, sessionID, fileName string, payload []byte) *http.Request { + t.Helper() + + var body bytes.Buffer + writer := multipart.NewWriter(&body) + if sessionID != "" { + if err := writer.WriteField("session_id", sessionID); err != nil { + t.Fatalf("write session_id field: %v", err) + } + } + part, err := writer.CreateFormFile("file", fileName) + if err != nil { + t.Fatalf("create file part: %v", err) + } + if _, err := part.Write(payload); err != nil { + t.Fatalf("write file part: %v", err) + } + if err := writer.Close(); err != nil { + t.Fatalf("close multipart writer: %v", err) + } + request := httptest.NewRequest(http.MethodPost, "/api/session-assets", &body) + request.Header.Set("Content-Type", writer.FormDataContentType()) + return request +} + +func gatewayMinimalPNGBytes() []byte { + return []byte{ + 0x89, 0x50, 0x4e, 0x47, 0x0d, 0x0a, 0x1a, 0x0a, + 0x00, 0x00, 0x00, 0x0d, 0x49, 0x48, 0x44, 0x52, + 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, + 0x08, 0x06, 0x00, 0x00, 0x00, 0x1f, 0x15, 0xc4, + 0x89, 0x00, 0x00, 0x00, 0x0d, 0x49, 0x44, 0x41, + 0x54, 0x78, 0x9c, 0x63, 0xf8, 0xcf, 0xc0, 0x00, + 0x00, 0x03, 0x01, 0x01, 0x00, 0xc9, 0xfe, 0x92, + 0xef, 0x00, 0x00, 0x00, 0x00, 0x49, 0x45, 0x4e, + 0x44, 0xae, 0x42, 0x60, 0x82, + } +} + type staticTokenAuthenticator struct { token string } diff --git a/internal/gateway/protocol/jsonrpc.go b/internal/gateway/protocol/jsonrpc.go index ac41d62aa..71ab1d6ed 100644 --- a/internal/gateway/protocol/jsonrpc.go +++ b/internal/gateway/protocol/jsonrpc.go @@ -250,6 +250,7 @@ const ( // RunInputMedia 用于承载 gateway.run 中图片分片的媒体元数据。 type RunInputMedia struct { URI string `json:"uri"` + AssetID string `json:"asset_id,omitempty"` MimeType string `json:"mime_type"` FileName string `json:"file_name,omitempty"` } @@ -1402,6 +1403,7 @@ func decodeRunParams(raw json.RawMessage) (RunParams, *JSONRPCError) { p.InputParts[i].Text = strings.TrimSpace(p.InputParts[i].Text) if m := p.InputParts[i].Media; m != nil { m.URI = strings.TrimSpace(m.URI) + m.AssetID = strings.TrimSpace(m.AssetID) m.MimeType = strings.TrimSpace(m.MimeType) m.FileName = strings.TrimSpace(m.FileName) } diff --git a/internal/gateway/protocol/jsonrpc_test.go b/internal/gateway/protocol/jsonrpc_test.go index 3a7f7da48..c64211e6c 100644 --- a/internal/gateway/protocol/jsonrpc_test.go +++ b/internal/gateway/protocol/jsonrpc_test.go @@ -393,7 +393,8 @@ func TestNormalizeJSONRPCRequestRuntimeMethods(t *testing.T) { "workdir":" /tmp/work ", "input_parts":[ {"type":" TEXT ","text":" world "}, - {"type":" image ","media":{"uri":" /tmp/a.png ","mime_type":" image/png ","file_name":" a.png "}} + {"type":" image ","media":{"uri":" /tmp/a.png ","mime_type":" image/png ","file_name":" a.png "}}, + {"type":" image ","media":{"asset_id":" asset-1 ","mime_type":" image/webp "}} ] }`), } @@ -414,8 +415,8 @@ func TestNormalizeJSONRPCRequestRuntimeMethods(t *testing.T) { if runParams.InputText != "hello" { t.Fatalf("run input_text = %q, want %q", runParams.InputText, "hello") } - if len(runParams.InputParts) != 2 { - t.Fatalf("run input_parts len = %d, want 2", len(runParams.InputParts)) + if len(runParams.InputParts) != 3 { + t.Fatalf("run input_parts len = %d, want 3", len(runParams.InputParts)) } if runParams.InputParts[0].Type != "text" || runParams.InputParts[0].Text != "world" { t.Fatalf("run text part = %#v, want normalized text part", runParams.InputParts[0]) @@ -426,6 +427,12 @@ func TestNormalizeJSONRPCRequestRuntimeMethods(t *testing.T) { if runParams.InputParts[1].Media.MimeType != "image/png" || runParams.InputParts[1].Media.FileName != "a.png" { t.Fatalf("run image media = %#v, want trimmed mime/file_name", runParams.InputParts[1].Media) } + if runParams.InputParts[2].Type != "image" || + runParams.InputParts[2].Media == nil || + runParams.InputParts[2].Media.AssetID != "asset-1" || + runParams.InputParts[2].Media.MimeType != "image/webp" { + t.Fatalf("run image asset media = %#v, want trimmed asset_id/mime", runParams.InputParts[2]) + } compactNormalized, rpcErr := NormalizeJSONRPCRequest(JSONRPCRequest{ JSONRPC: JSONRPCVersion, diff --git a/internal/gateway/rpc_dispatch.go b/internal/gateway/rpc_dispatch.go index 3b5f38ae0..33639ff66 100644 --- a/internal/gateway/rpc_dispatch.go +++ b/internal/gateway/rpc_dispatch.go @@ -344,6 +344,7 @@ func convertProtocolRunInputParts(parts []protocol.RunInputPart) []InputPart { if part.Media != nil { convertedPart.Media = &Media{ URI: strings.TrimSpace(part.Media.URI), + AssetID: strings.TrimSpace(part.Media.AssetID), MimeType: strings.TrimSpace(part.Media.MimeType), FileName: strings.TrimSpace(part.Media.FileName), } diff --git a/internal/gateway/rpc_dispatch_test.go b/internal/gateway/rpc_dispatch_test.go index 1994e3e51..737851253 100644 --- a/internal/gateway/rpc_dispatch_test.go +++ b/internal/gateway/rpc_dispatch_test.go @@ -234,6 +234,14 @@ func (s *rpcRunCaptureRuntimeStub) CreateSession(ctx context.Context, input Crea return s.createSessionID, nil } +func (s *rpcRunCaptureRuntimeStub) SaveSessionAsset(_ context.Context, _ SaveSessionAssetInput) (SessionAssetMeta, error) { + return SessionAssetMeta{}, nil +} + +func (s *rpcRunCaptureRuntimeStub) OpenSessionAsset(_ context.Context, _ OpenSessionAssetInput) (OpenSessionAssetResult, error) { + return OpenSessionAssetResult{}, nil +} + func (s *rpcRunCaptureRuntimeStub) ListSessionTodos(_ context.Context, _ ListSessionTodosInput) (TodoSnapshot, error) { return TodoSnapshot{}, nil } @@ -1130,6 +1138,12 @@ func (s *runtimePortOnlyStub) GetRuntimeSnapshot(_ context.Context, _ GetRuntime func (s *runtimePortOnlyStub) CreateSession(_ context.Context, _ CreateSessionInput) (string, error) { return "", nil } +func (s *runtimePortOnlyStub) SaveSessionAsset(_ context.Context, _ SaveSessionAssetInput) (SessionAssetMeta, error) { + return SessionAssetMeta{}, nil +} +func (s *runtimePortOnlyStub) OpenSessionAsset(_ context.Context, _ OpenSessionAssetInput) (OpenSessionAssetResult, error) { + return OpenSessionAssetResult{}, nil +} func (s *runtimePortOnlyStub) DeleteSession(_ context.Context, _ DeleteSessionInput) (bool, error) { return false, nil } @@ -1208,7 +1222,8 @@ func TestDispatchRPCRequestRunHydratesInputPartsAndFallbackRunID(t *testing.T) { "session_id":"session-run-1", "input_parts":[ {"type":"text","text":"hello world"}, - {"type":"image","media":{"uri":"C:/tmp/pic.png","mime_type":"image/png"}} + {"type":"image","media":{"uri":"C:/tmp/pic.png","mime_type":"image/png"}}, + {"type":"image","media":{"asset_id":"asset-1","mime_type":"image/webp"}} ] }`), }, runtimeStub) @@ -1229,8 +1244,8 @@ func TestDispatchRPCRequestRunHydratesInputPartsAndFallbackRunID(t *testing.T) { if captured.RunID != "req-run-hydrate" { t.Fatalf("runtime run run_id = %q, want %q", captured.RunID, "req-run-hydrate") } - if len(captured.InputParts) != 2 { - t.Fatalf("runtime run input_parts len = %d, want %d", len(captured.InputParts), 2) + if len(captured.InputParts) != 3 { + t.Fatalf("runtime run input_parts len = %d, want %d", len(captured.InputParts), 3) } if captured.InputParts[0].Type != InputPartTypeText { t.Fatalf("runtime text part type = %q, want %q", captured.InputParts[0].Type, InputPartTypeText) @@ -1241,6 +1256,11 @@ func TestDispatchRPCRequestRunHydratesInputPartsAndFallbackRunID(t *testing.T) { if captured.InputParts[1].Media == nil || captured.InputParts[1].Media.URI != "C:/tmp/pic.png" { t.Fatalf("runtime image media = %#v, want uri %q", captured.InputParts[1].Media, "C:/tmp/pic.png") } + if captured.InputParts[2].Media == nil || + captured.InputParts[2].Media.AssetID != "asset-1" || + captured.InputParts[2].Media.MimeType != "image/webp" { + t.Fatalf("runtime image asset media = %#v, want asset_id", captured.InputParts[2].Media) + } } func TestDispatchRPCRequest_DenyCrossSubjectLoadSession(t *testing.T) { diff --git a/internal/gateway/server_test.go b/internal/gateway/server_test.go index a61d8eca9..261a37a70 100644 --- a/internal/gateway/server_test.go +++ b/internal/gateway/server_test.go @@ -367,7 +367,9 @@ func TestServerHandleConnectionAuthenticateFlow(t *testing.T) { } type runtimePortEventStub struct { - events <-chan RuntimeEvent + events <-chan RuntimeEvent + saveAssetFn func(context.Context, SaveSessionAssetInput) (SessionAssetMeta, error) + openAssetFn func(context.Context, OpenSessionAssetInput) (OpenSessionAssetResult, error) } func (s *runtimePortEventStub) Run(_ context.Context, _ RunInput) error { @@ -467,6 +469,20 @@ func (s *runtimePortEventStub) CreateSession(_ context.Context, _ CreateSessionI return "", nil } +func (s *runtimePortEventStub) SaveSessionAsset(ctx context.Context, input SaveSessionAssetInput) (SessionAssetMeta, error) { + if s.saveAssetFn != nil { + return s.saveAssetFn(ctx, input) + } + return SessionAssetMeta{}, nil +} + +func (s *runtimePortEventStub) OpenSessionAsset(ctx context.Context, input OpenSessionAssetInput) (OpenSessionAssetResult, error) { + if s.openAssetFn != nil { + return s.openAssetFn(ctx, input) + } + return OpenSessionAssetResult{}, nil +} + func (s *runtimePortEventStub) ListSessionTodos(_ context.Context, _ ListSessionTodosInput) (TodoSnapshot, error) { return TodoSnapshot{}, nil } diff --git a/internal/gateway/static_files.go b/internal/gateway/static_files.go index 936f7c78f..12aeb4c17 100644 --- a/internal/gateway/static_files.go +++ b/internal/gateway/static_files.go @@ -16,6 +16,7 @@ var knownAPIPrefixes = map[string]bool{ "/healthz": true, "/version": true, "/rpc": true, + "/api": true, "/ws": true, "/sse": true, "/metrics": true, diff --git a/internal/gateway/types.go b/internal/gateway/types.go index 16d207394..e2a98a938 100644 --- a/internal/gateway/types.go +++ b/internal/gateway/types.go @@ -136,6 +136,8 @@ const ( type Media struct { // URI 是媒体资源地址。 URI string `json:"uri"` + // AssetID 是已保存的 session asset 标识。 + AssetID string `json:"asset_id,omitempty"` // MimeType 是媒体 MIME 类型。 MimeType string `json:"mime_type"` // FileName 是媒体文件名。 diff --git a/internal/gateway/validate.go b/internal/gateway/validate.go index 985ee96e3..4684a323c 100644 --- a/internal/gateway/validate.go +++ b/internal/gateway/validate.go @@ -585,8 +585,10 @@ func validateInputPart(part InputPart, index int) *FrameError { if part.Media == nil { return NewFrameError(ErrorCodeInvalidMultimodalPayload, "input_parts[image] requires media") } - if strings.TrimSpace(part.Media.URI) == "" { - return NewFrameError(ErrorCodeInvalidMultimodalPayload, "input_parts[image] requires media.uri") + hasURI := strings.TrimSpace(part.Media.URI) != "" + hasAssetID := strings.TrimSpace(part.Media.AssetID) != "" + if hasURI == hasAssetID { + return NewFrameError(ErrorCodeInvalidMultimodalPayload, "input_parts[image] requires exactly one of media.uri or media.asset_id") } if strings.TrimSpace(part.Media.MimeType) == "" { return NewFrameError(ErrorCodeInvalidMultimodalPayload, "input_parts[image] requires media.mime_type") diff --git a/internal/gateway/validate_test.go b/internal/gateway/validate_test.go index a2db95066..958e46b07 100644 --- a/internal/gateway/validate_test.go +++ b/internal/gateway/validate_test.go @@ -829,6 +829,23 @@ func TestValidateFrame_MultimodalPayloadRules(t *testing.T) { }, wantNil: true, }, + { + name: "valid image asset part", + frame: MessageFrame{ + Type: FrameTypeRequest, + Action: FrameActionRun, + InputParts: []InputPart{ + { + Type: InputPartTypeImage, + Media: &Media{ + AssetID: "asset-1", + MimeType: "image/png", + }, + }, + }, + }, + wantNil: true, + }, { name: "text part with empty text", frame: MessageFrame{ @@ -852,7 +869,7 @@ func TestValidateFrame_MultimodalPayloadRules(t *testing.T) { wantCode: ErrorCodeInvalidMultimodalPayload.String(), }, { - name: "image part missing media.uri", + name: "image part missing media.uri and media.asset_id", frame: MessageFrame{ Type: FrameTypeRequest, Action: FrameActionRun, @@ -865,6 +882,24 @@ func TestValidateFrame_MultimodalPayloadRules(t *testing.T) { }, wantCode: ErrorCodeInvalidMultimodalPayload.String(), }, + { + name: "image part has both media.uri and media.asset_id", + frame: MessageFrame{ + Type: FrameTypeRequest, + Action: FrameActionRun, + InputParts: []InputPart{ + { + Type: InputPartTypeImage, + Media: &Media{ + URI: "file:///a.png", + AssetID: "asset-1", + MimeType: "image/png", + }, + }, + }, + }, + wantCode: ErrorCodeInvalidMultimodalPayload.String(), + }, { name: "image part missing media.mime_type", frame: MessageFrame{ diff --git a/internal/runtime/input_prepare.go b/internal/runtime/input_prepare.go index 0752a370e..8b9e70827 100644 --- a/internal/runtime/input_prepare.go +++ b/internal/runtime/input_prepare.go @@ -148,6 +148,7 @@ func (p sessionInputPreparer) Prepare( for _, image := range input.Images { sessionImages = append(sessionImages, agentsession.PrepareImageInput{ Path: strings.TrimSpace(image.Path), + AssetID: strings.TrimSpace(image.AssetID), MimeType: strings.TrimSpace(image.MimeType), }) } diff --git a/internal/runtime/runtime.go b/internal/runtime/runtime.go index 2761784e8..326fd415f 100644 --- a/internal/runtime/runtime.go +++ b/internal/runtime/runtime.go @@ -75,6 +75,7 @@ type UserInput struct { // UserImageInput 表示用户输入中附带的单个图片引用(路径 + MIME)。 type UserImageInput struct { Path string + AssetID string MimeType string } diff --git a/internal/session/input_preparer.go b/internal/session/input_preparer.go index 76b8dc6df..f5b9d9bab 100644 --- a/internal/session/input_preparer.go +++ b/internal/session/input_preparer.go @@ -20,6 +20,7 @@ const defaultSessionTitle = "New Session" // PrepareImageInput 表示一次用户输入中附带的本地图片引用。 type PrepareImageInput struct { Path string + AssetID string MimeType string } @@ -128,6 +129,32 @@ func (p *InputPreparer) Prepare(ctx context.Context, input PrepareInput) (Prepar savedAssets := make([]AssetMeta, 0, len(input.Images)) for index, image := range input.Images { path := strings.TrimSpace(image.Path) + assetID := strings.TrimSpace(image.AssetID) + if assetID != "" { + if path != "" { + p.rollbackCreatedSession(ctx, session.ID, sessionCreated) + p.cleanupSavedAssets(ctx, session.ID, savedAssets) + return PreparedInput{}, &AssetSaveError{ + SessionID: session.ID, + Index: index, + Path: path, + Err: fmt.Errorf("image input cannot contain both path and asset id"), + } + } + meta, err := p.referenceImageAsset(ctx, session.ID, assetID, image.MimeType) + if err != nil { + p.rollbackCreatedSession(ctx, session.ID, sessionCreated) + p.cleanupSavedAssets(ctx, session.ID, savedAssets) + return PreparedInput{}, &AssetSaveError{ + SessionID: session.ID, + Index: index, + Path: assetID, + Err: err, + } + } + parts = append(parts, providertypes.NewSessionAssetImagePart(meta.ID, meta.MimeType)) + continue + } if path == "" { p.rollbackCreatedSession(ctx, session.ID, sessionCreated) p.cleanupSavedAssets(ctx, session.ID, savedAssets) @@ -220,6 +247,38 @@ func (p *InputPreparer) saveImageAsset( return meta, nil } +// referenceImageAsset 校验已保存附件属于当前会话,并返回可进入 provider 的图片元数据。 +func (p *InputPreparer) referenceImageAsset( + ctx context.Context, + sessionID string, + assetID string, + mimeType string, +) (AssetMeta, error) { + if err := ctx.Err(); err != nil { + return AssetMeta{}, err + } + if p.assetStore == nil { + return AssetMeta{}, fmt.Errorf("session: asset store is not configured") + } + normalizedAssetID := strings.TrimSpace(assetID) + if normalizedAssetID == "" { + return AssetMeta{}, fmt.Errorf("image asset id is empty") + } + + meta, err := p.assetStore.Stat(ctx, sessionID, normalizedAssetID) + if err != nil { + return AssetMeta{}, fmt.Errorf("stat image asset: %w", err) + } + if !strings.HasPrefix(strings.ToLower(strings.TrimSpace(meta.MimeType)), "image/") { + return AssetMeta{}, fmt.Errorf("asset %q is not an image", normalizedAssetID) + } + declaredMime := normalizeMimeType(mimeType) + if declaredMime != "" && declaredMime != meta.MimeType { + return AssetMeta{}, fmt.Errorf("declared mime type %q mismatches saved asset %q", declaredMime, meta.MimeType) + } + return meta, nil +} + // resolveImageMimeType 解析图片 MIME 类型,仅允许 image/*,并要求声明值与文件头探测一致。 func resolveImageMimeType(ctx context.Context, path string, declared string, file *os.File) (string, error) { if err := ctx.Err(); err != nil { diff --git a/internal/session/input_preparer_test.go b/internal/session/input_preparer_test.go index d45527799..5d7fb334b 100644 --- a/internal/session/input_preparer_test.go +++ b/internal/session/input_preparer_test.go @@ -1,6 +1,7 @@ package session import ( + "bytes" "context" "errors" "io" @@ -94,6 +95,46 @@ func TestInputPreparerPrepareTextAndImage(t *testing.T) { } } +func TestInputPreparerPrepareSavedAssetReference(t *testing.T) { + t.Parallel() + + workdir := t.TempDir() + store := newInputPreparerTestStore(t, workdir) + session := NewWithWorkdir("existing", workdir) + if err := createSessionForPreparerTest(context.Background(), store, session); err != nil { + t.Fatalf("createSessionForPreparerTest() error = %v", err) + } + meta, err := store.SaveAsset(context.Background(), session.ID, bytes.NewReader(minimalPNGBytes()), "image/png") + if err != nil { + t.Fatalf("SaveAsset() error = %v", err) + } + + preparer := NewInputPreparer(store, store) + result, err := preparer.Prepare(context.Background(), PrepareInput{ + SessionID: session.ID, + Text: "describe it", + Images: []PrepareImageInput{{AssetID: meta.ID, MimeType: "image/png"}}, + DefaultWorkdir: workdir, + }) + if err != nil { + t.Fatalf("Prepare() error = %v", err) + } + if len(result.SavedAssets) != 0 { + t.Fatalf("expected no newly saved assets, got %+v", result.SavedAssets) + } + if len(result.Parts) != 2 { + t.Fatalf("expected text and image parts, got %+v", result.Parts) + } + imagePart := result.Parts[1] + if imagePart.Kind != providertypes.ContentPartImage || + imagePart.Image == nil || + imagePart.Image.Asset == nil || + imagePart.Image.Asset.ID != meta.ID || + imagePart.Image.Asset.MimeType != "image/png" { + t.Fatalf("unexpected image part: %+v", imagePart) + } +} + func TestInputPreparerPrepareImageInfersMimeWhenMissing(t *testing.T) { t.Parallel() @@ -185,6 +226,51 @@ func TestInputPreparerPrepareErrors(t *testing.T) { } }) + t.Run("missing image reference is rejected", func(t *testing.T) { + preparer := NewInputPreparer(store, store) + _, err := preparer.Prepare(context.Background(), PrepareInput{ + Text: "bad asset", + Images: []PrepareImageInput{{AssetID: " ", MimeType: "image/png"}}, + DefaultWorkdir: workdir, + }) + if err == nil { + t.Fatalf("expected missing image reference error") + } + if !strings.Contains(err.Error(), "image path is empty") { + t.Fatalf("expected image reference error, got %v", err) + } + }) + + t.Run("missing referenced asset is rejected", func(t *testing.T) { + localStore := newInputPreparerTestStore(t, workdir) + existing := NewWithWorkdir("asset-missing", workdir) + if err := createSessionForPreparerTest(context.Background(), localStore, existing); err != nil { + t.Fatalf("createSessionForPreparerTest() error = %v", err) + } + preparer := NewInputPreparer(localStore, localStore) + _, err := preparer.Prepare(context.Background(), PrepareInput{ + SessionID: existing.ID, + Text: "bad asset", + Images: []PrepareImageInput{{AssetID: "asset-missing", MimeType: "image/png"}}, + DefaultWorkdir: workdir, + }) + if err == nil { + t.Fatalf("expected missing referenced asset error") + } + }) + + t.Run("asset id and path cannot both be set", func(t *testing.T) { + preparer := NewInputPreparer(store, store) + _, err := preparer.Prepare(context.Background(), PrepareInput{ + Text: "bad asset", + Images: []PrepareImageInput{{Path: "a.png", AssetID: "asset-1", MimeType: "image/png"}}, + DefaultWorkdir: workdir, + }) + if err == nil { + t.Fatalf("expected asset id and path conflict error") + } + }) + t.Run("asset save error is structured", func(t *testing.T) { preparer := NewInputPreparer(store, store) _, err := preparer.Prepare(context.Background(), PrepareInput{ diff --git a/web/src/api/gateway.test.ts b/web/src/api/gateway.test.ts index bb4cf0eec..f022a7405 100644 --- a/web/src/api/gateway.test.ts +++ b/web/src/api/gateway.test.ts @@ -1,4 +1,4 @@ -import { describe, it, expect, vi, beforeEach } from 'vitest' +import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest' import { GatewayAPI } from './gateway' import { Method } from './protocol' @@ -13,6 +13,10 @@ describe('GatewayAPI', () => { api = new GatewayAPI(ws) }) + afterEach(() => { + vi.unstubAllGlobals() + }) + it('maps authenticate and run methods', async () => { await api.authenticate('tok') await api.run({ input_text: 'hello' }) @@ -21,6 +25,14 @@ describe('GatewayAPI', () => { expect(call).toHaveBeenNthCalledWith(2, Method.Run, { input_text: 'hello' }) }) + it('maps createSession method', async () => { + await api.createSession() + await api.createSession('s1') + + expect(call).toHaveBeenNthCalledWith(1, Method.CreateSession, {}) + expect(call).toHaveBeenNthCalledWith(2, Method.CreateSession, { session_id: 's1' }) + }) + it('maps optional session_id in listModels', async () => { await api.listModels() await api.listModels('s1') @@ -60,5 +72,68 @@ describe('GatewayAPI', () => { expect(call).toHaveBeenNthCalledWith(2, Method.ApprovePlan, { session_id: 's1', plan_id: 'p1', revision: 2 }) expect(call).toHaveBeenNthCalledWith(3, Method.UserQuestionAnswer, { request_id: 'q1', status: 'answered', message: 'ok' }) }) + + it('uploads session assets with bearer auth, workspace header, and multipart body', async () => { + const fetchMock = vi.fn().mockResolvedValue({ + ok: true, + json: () => Promise.resolve({ session_id: 's1', asset_id: 'asset-1', mime_type: 'image/png', size: 3 }), + }) + vi.stubGlobal('fetch', fetchMock) + api = new GatewayAPI(ws, 'http://localhost:1455/', ' token-1 ') + + const file = new File(['abc'], 'a.png', { type: 'image/png' }) + const result = await api.uploadSessionAsset('s1', file, 'workspace-b') + + expect(result.asset_id).toBe('asset-1') + expect(fetchMock).toHaveBeenCalledWith('http://localhost:1455/api/session-assets', expect.objectContaining({ + method: 'POST', + headers: { Authorization: 'Bearer token-1', 'X-NeoCode-Workspace-Hash': 'workspace-b' }, + })) + const init = fetchMock.mock.calls[0][1] as RequestInit + expect(init.body).toBeInstanceOf(FormData) + expect((init.body as FormData).get('session_id')).toBe('s1') + expect((init.body as FormData).get('file')).toBe(file) + }) + + it('fetches session asset blobs with bearer auth and workspace header', async () => { + const blob = new Blob(['img'], { type: 'image/png' }) + const fetchMock = vi.fn().mockResolvedValue({ + ok: true, + blob: () => Promise.resolve(blob), + }) + vi.stubGlobal('fetch', fetchMock) + api = new GatewayAPI(ws, '/gateway', 'token-1') + + await expect(api.fetchSessionAsset('s 1', 'asset/1', 'workspace-b')).resolves.toBe(blob) + expect(fetchMock).toHaveBeenCalledWith('/gateway/api/session-assets/s%201/asset%2F1', { + headers: { Authorization: 'Bearer token-1', 'X-NeoCode-Workspace-Hash': 'workspace-b' }, + }) + }) + + it('uses switched workspace as session asset HTTP fallback', async () => { + call.mockResolvedValueOnce({ type: 'ack', payload: { workspace_hash: 'workspace-c' } }) + const fetchMock = vi.fn().mockResolvedValue({ + ok: true, + blob: () => Promise.resolve(new Blob(['img'])), + }) + vi.stubGlobal('fetch', fetchMock) + api = new GatewayAPI(ws, '', 'token-1') + + await api.switchWorkspace('workspace-c') + await api.fetchSessionAsset('s1', 'asset-1') + + expect(fetchMock).toHaveBeenCalledWith('/api/session-assets/s1/asset-1', { + headers: { Authorization: 'Bearer token-1', 'X-NeoCode-Workspace-Hash': 'workspace-c' }, + }) + }) + + it('surfaces session asset HTTP errors', async () => { + vi.stubGlobal('fetch', vi.fn().mockResolvedValue({ + ok: false, + status: 415, + json: () => Promise.resolve({ error: 'unsupported image type' }), + })) + await expect(api.uploadSessionAsset('s1', new File(['x'], 'x.txt'))).rejects.toThrow('unsupported image type') + }) }) diff --git a/web/src/api/gateway.ts b/web/src/api/gateway.ts index 18357270b..45ab4f313 100644 --- a/web/src/api/gateway.ts +++ b/web/src/api/gateway.ts @@ -5,6 +5,8 @@ import { type AuthenticateParams, type BindStreamParams, type RunParams, + type CreateSessionParams, + type CreateSessionResult, type CancelParams, type LoadSessionParams, type ListSessionTodosParams, @@ -73,14 +75,20 @@ import { type RenameWorkspaceResult, type DeleteWorkspaceParams, type DeleteWorkspaceResult, + type SessionAssetUploadResult, } from './protocol' /** Gateway 业务 API 客户端,基于 WebSocket 全双工通道 */ export class GatewayAPI { private ws: WSClient + private baseURL: string + private token: string + private currentWorkspaceHash = '' - constructor(ws: WSClient) { + constructor(ws: WSClient, baseURL = '', token = '') { this.ws = ws + this.baseURL = baseURL.replace(/\/+$/, '') + this.token = token.trim() } /** 认证,返回 ack 结果 */ @@ -93,11 +101,45 @@ export class GatewayAPI { return this.ws.call(Method.BindStream, params) } + /** 显式创建一个会话,供发送图片前建立 asset 归属 */ + async createSession(sessionId?: string) { + const params: CreateSessionParams = sessionId ? { session_id: sessionId } : {} + return this.ws.call(Method.CreateSession, params) + } + /** 发起一次 run,返回 ack 含 session_id 和 run_id */ async run(params: RunParams) { return this.ws.call(Method.Run, params) } + /** 上传会话图片附件,返回可在 input_parts 中引用的 asset_id */ + async uploadSessionAsset(sessionId: string, file: File, workspaceHash = '') { + const form = new FormData() + form.append('session_id', sessionId) + form.append('file', file) + const res = await fetch(`${this.baseURL}/api/session-assets`, { + method: 'POST', + headers: this.httpHeaders(workspaceHash), + body: form, + }) + if (!res.ok) { + throw new Error(await readHTTPError(res, 'Upload failed')) + } + return res.json() as Promise + } + + /** 读取会话图片附件 Blob,用于历史消息缩略图 */ + async fetchSessionAsset(sessionId: string, assetId: string, workspaceHash = '') { + const res = await fetch( + `${this.baseURL}/api/session-assets/${encodeURIComponent(sessionId)}/${encodeURIComponent(assetId)}`, + { headers: this.httpHeaders(workspaceHash) }, + ) + if (!res.ok) { + throw new Error(await readHTTPError(res, 'Asset fetch failed')) + } + return res.blob() + } + /** 取消运行,返回取消结果 */ async cancel(params: CancelParams) { return this.ws.call(Method.Cancel, params) @@ -290,7 +332,9 @@ export class GatewayAPI { /** 切换工作区 */ async switchWorkspace(workspaceHash: string) { - return this.ws.call(Method.SwitchWorkspace, { workspace_hash: workspaceHash } satisfies SwitchWorkspaceParams) + const result = await this.ws.call(Method.SwitchWorkspace, { workspace_hash: workspaceHash } satisfies SwitchWorkspaceParams) + this.currentWorkspaceHash = workspaceHash.trim() + return result } /** 重命名工作区 */ @@ -302,4 +346,21 @@ export class GatewayAPI { async deleteWorkspace(workspaceHash: string, removeData?: boolean) { return this.ws.call(Method.DeleteWorkspace, { workspace_hash: workspaceHash, remove_data: removeData } satisfies DeleteWorkspaceParams) } + + getCurrentWorkspaceHash() { + return this.currentWorkspaceHash + } + + private httpHeaders(workspaceHash = '') { + const headers: Record = {} + if (this.token) headers.Authorization = `Bearer ${this.token}` + const resolvedWorkspaceHash = workspaceHash.trim() || this.currentWorkspaceHash + if (resolvedWorkspaceHash) headers['X-NeoCode-Workspace-Hash'] = resolvedWorkspaceHash + return Object.keys(headers).length > 0 ? headers : undefined + } +} + +async function readHTTPError(res: Response, fallback: string) { + const data = await res.json().catch(() => null) as { error?: string } | null + return data?.error || `${fallback} (HTTP ${res.status})` } diff --git a/web/src/api/protocol.ts b/web/src/api/protocol.ts index 15bad0b11..52c4e37eb 100644 --- a/web/src/api/protocol.ts +++ b/web/src/api/protocol.ts @@ -11,6 +11,7 @@ export const Method = { Ping: "gateway.ping", BindStream: "gateway.bindStream", Run: "gateway.run", + CreateSession: "gateway.createSession", Cancel: "gateway.cancel", Compact: "gateway.compact", ListSessions: "gateway.listSessions", @@ -234,9 +235,15 @@ export interface RunParams { export interface RunInputPart { type: string; text?: string; - media?: { uri: string; mime_type: string; file_name?: string }; + media?: { uri?: string; asset_id?: string; mime_type: string; file_name?: string }; } +export interface CreateSessionParams { + session_id?: string; +} + +export type CreateSessionResult = RPCResult<{ session_id: string }>; + /** gateway.cancel 参数 */ export interface CancelParams { session_id?: string; @@ -307,11 +314,19 @@ export interface SessionSummary { export interface SessionMessage { role: string; content: string; + parts?: RunInputPart[]; tool_calls?: ToolCall[]; tool_call_id?: string; is_error?: boolean; } +export interface SessionAssetUploadResult { + session_id: string; + asset_id: string; + mime_type: string; + size: number; +} + /** 工具调用 */ export interface ToolCall { id: string; diff --git a/web/src/components/chat/ChatInput.test.tsx b/web/src/components/chat/ChatInput.test.tsx index 53829db08..484fa90e8 100644 --- a/web/src/components/chat/ChatInput.test.tsx +++ b/web/src/components/chat/ChatInput.test.tsx @@ -6,10 +6,13 @@ import { useComposerStore } from '@/stores/useComposerStore' import { useSessionStore } from '@/stores/useSessionStore' import { useRuntimeInsightStore } from '@/stores/useRuntimeInsightStore' import { useGatewayStore } from '@/stores/useGatewayStore' +import { useWorkspaceStore } from '@/stores/useWorkspaceStore' const mockGatewayAPI = { listAvailableSkills: vi.fn(), listModels: vi.fn(), + createSession: vi.fn(), + uploadSessionAsset: vi.fn(), run: vi.fn(), bindStream: vi.fn(), cancel: vi.fn(), @@ -68,9 +71,23 @@ describe('ChatInput', () => { selected_model_id: '', }, }) - - useComposerStore.setState({ composerText: '' }) + mockGatewayAPI.createSession.mockResolvedValue({ payload: { session_id: 'session-created' } }) + mockGatewayAPI.uploadSessionAsset.mockResolvedValue({ + session_id: 'session-created', + asset_id: 'asset-1', + mime_type: 'image/png', + size: 3, + }) + mockGatewayAPI.run.mockResolvedValue({ session_id: 'session-created', run_id: 'run-1' }) + mockGatewayAPI.bindStream.mockResolvedValue({}) + if (typeof URL.createObjectURL !== 'function') { + Object.defineProperty(URL, 'createObjectURL', { configurable: true, value: vi.fn() }) + } + vi.spyOn(URL, 'createObjectURL').mockReturnValue('blob:preview-1') + + useComposerStore.setState({ composerText: '', attachments: [] }) useSessionStore.setState({ currentSessionId: '' } as never) + useWorkspaceStore.setState({ currentWorkspaceHash: 'workspace-b' } as never) useGatewayStore.setState({ currentRunId: '' } as never) useRuntimeInsightStore.getState().reset() useChatStore.setState({ @@ -157,12 +174,72 @@ describe('ChatInput', () => { }) }) - it('does not render the unimplemented attachment and mention buttons', () => { + it('renders the image attachment picker but keeps mention button absent', () => { render() - expect(screen.queryByTitle('附件文件')).not.toBeInTheDocument() + expect(screen.getByRole('button', { name: /添加图片/ })).toBeInTheDocument() expect(screen.queryByTitle('引用上下文')).not.toBeInTheDocument() }) + + it('uploads selected image and sends image-only input parts after creating a session', async () => { + render() + + const file = new File(['img'], 'a.png', { type: 'image/png' }) + const input = document.querySelector('input[type="file"]') as HTMLInputElement + fireEvent.change(input, { target: { files: [file] } }) + + await waitFor(() => { + expect(screen.getByAltText('a.png')).toBeInTheDocument() + }) + + fireEvent.keyDown(screen.getByRole('textbox'), { key: 'Enter' }) + + await waitFor(() => { + expect(mockGatewayAPI.createSession).toHaveBeenCalled() + expect(mockGatewayAPI.uploadSessionAsset).toHaveBeenCalledWith('session-created', file, 'workspace-b') + expect(mockGatewayAPI.run).toHaveBeenCalledWith({ + session_id: 'session-created', + input_parts: [ + { type: 'image', media: { asset_id: 'asset-1', mime_type: 'image/png', file_name: 'a.png' } }, + ], + mode: 'build', + }) + }) + + expect(useChatStore.getState().messages[0]).toMatchObject({ + role: 'user', + attachments: [{ assetId: 'asset-1', previewUrl: 'blob:preview-1', workspaceHash: 'workspace-b' }], + }) + }) + + it('treats slash text as a normal message when an image is attached', async () => { + useSessionStore.setState({ currentSessionId: 'session-1' } as never) + mockGatewayAPI.uploadSessionAsset.mockResolvedValueOnce({ + session_id: 'session-1', + asset_id: 'asset-2', + mime_type: 'image/png', + size: 3, + }) + render() + + const file = new File(['img'], 'slash.png', { type: 'image/png' }) + const fileInput = document.querySelector('input[type="file"]') as HTMLInputElement + fireEvent.change(fileInput, { target: { files: [file] } }) + fireEvent.change(screen.getByRole('textbox'), { target: { value: '/memo' } }) + fireEvent.keyDown(screen.getByRole('textbox'), { key: 'Enter' }) + + await waitFor(() => { + expect(mockGatewayAPI.executeSystemTool).not.toHaveBeenCalled() + expect(mockGatewayAPI.uploadSessionAsset).toHaveBeenCalledWith('session-1', file, 'workspace-b') + expect(mockGatewayAPI.run).toHaveBeenCalledWith(expect.objectContaining({ + session_id: 'session-1', + input_parts: [ + { type: 'text', text: '/memo' }, + { type: 'image', media: { asset_id: 'asset-2', mime_type: 'image/png', file_name: 'slash.png' } }, + ], + })) + }) + }) it('blocks normal sends while compaction is running', async () => { useChatStore.getState().startCompacting('manual', 'Compacting context...') render() diff --git a/web/src/components/chat/ChatInput.tsx b/web/src/components/chat/ChatInput.tsx index 6291a702b..abf63a858 100644 --- a/web/src/components/chat/ChatInput.tsx +++ b/web/src/components/chat/ChatInput.tsx @@ -3,8 +3,14 @@ import { useChatStore, createUserMessage } from '@/stores/useChatStore' import { useGatewayStore } from '@/stores/useGatewayStore' import { useSessionStore, isValidSessionId } from '@/stores/useSessionStore' import { useUIStore } from '@/stores/useUIStore' -import { useComposerStore } from '@/stores/useComposerStore' +import { + acceptedImageMimeTypes, + maxComposerAttachmentBytes, + useComposerStore, + type ComposerAttachment, +} from '@/stores/useComposerStore' import { useRuntimeInsightStore } from '@/stores/useRuntimeInsightStore' +import { useWorkspaceStore } from '@/stores/useWorkspaceStore' import { formatTokenCount } from '@/utils/format' import { useGatewayAPI } from '@/context/RuntimeProvider' import { @@ -19,7 +25,7 @@ import { import SlashCommandMenu from './SlashCommandMenu' import SkillPicker from './SkillPicker' import ModelSelector from './ModelSelector' -import { Send, Square } from 'lucide-react' +import { ImagePlus, Loader2, Send, Square, X } from 'lucide-react' const slashMenuAnchorStyle: React.CSSProperties = { position: 'absolute', @@ -123,14 +129,22 @@ function resolveBudgetRingState( export default function ChatInput() { const gatewayAPI = useGatewayAPI() const text = useComposerStore((state) => state.composerText) + const attachments = useComposerStore((state) => state.attachments) const setText = useComposerStore((state) => state.setComposerText) + const addAttachmentFiles = useComposerStore((state) => state.addAttachmentFiles) + const removeAttachment = useComposerStore((state) => state.removeAttachment) + const clearAttachments = useComposerStore((state) => state.clearAttachments) + const setAttachmentStatus = useComposerStore((state) => state.setAttachmentStatus) const [rows, setRows] = useState(1) + const [dragActive, setDragActive] = useState(false) const textareaRef = useRef(null) + const fileInputRef = useRef(null) const runCancelledRef = useRef(false) const composingRef = useRef(false) const isGenerating = useChatStore((state) => state.isGenerating) const isCompacting = useChatStore((state) => state.isCompacting) const addMessage = useChatStore((state) => state.addMessage) + const removeMessage = useChatStore((state) => state.removeMessage) const addSystemMessage = useChatStore((state) => state.addSystemMessage) const setGenerating = useChatStore((state) => state.setGenerating) const sessionId = useSessionStore((state) => state.currentSessionId) @@ -138,6 +152,7 @@ export default function ChatInput() { const setAgentMode = useChatStore((state) => state.setAgentMode) const permissionMode = useChatStore((state) => state.permissionMode) const setPermissionMode = useChatStore((state) => state.setPermissionMode) + const currentWorkspaceHash = useWorkspaceStore((state) => state.currentWorkspaceHash) const [showSlashMenu, setShowSlashMenu] = useState(false) const [selectedIndex, setSelectedIndex] = useState(0) @@ -302,7 +317,9 @@ export default function ChatInput() { async function handleSubmit() { const input = text.trim() - if (!input) return + const pendingAttachments = attachments + if (!input && pendingAttachments.length === 0) return + let submittedMessageId = '' if (isCompacting) { useUIStore.getState().showToast('Context compaction is still running', 'info') @@ -314,30 +331,65 @@ export default function ChatInput() { return } - if (isSlashCommand(input)) { + if (pendingAttachments.length === 0 && isSlashCommand(input)) { setText('') setShowSlashMenu(false) const handled = await executeSlashCommand(input) if (handled) return } - setText('') - const userMsg = createUserMessage(input) - addMessage(userMsg) - useRuntimeInsightStore.getState().setTodoSnapshot({ - items: [], - summary: { total: 0, required_total: 0, required_completed: 0, required_failed: 0, required_open: 0 }, - }) - setGenerating(true) - runCancelledRef.current = false - try { if (!gatewayAPI) return - const isNewSession = !isValidSessionId(sessionId) + let targetSessionId = sessionId + if (!isValidSessionId(targetSessionId)) { + const created = await gatewayAPI.createSession() + targetSessionId = created.payload?.session_id || '' + if (!isValidSessionId(targetSessionId)) throw new Error('Create session failed') + useSessionStore.getState().setCurrentSessionId(targetSessionId) + await gatewayAPI.bindStream({ session_id: targetSessionId, channel: 'all' }).catch(() => {}) + } + + const workspaceHash = currentWorkspaceHash.trim() + const uploaded = [] + for (const attachment of pendingAttachments) { + setAttachmentStatus(attachment.id, 'uploading') + try { + const meta = await gatewayAPI.uploadSessionAsset(targetSessionId, attachment.file, workspaceHash) + setAttachmentStatus(attachment.id, 'uploaded') + uploaded.push({ attachment, meta }) + } catch (err) { + const message = err instanceof Error ? err.message : 'Upload failed' + setAttachmentStatus(attachment.id, 'error', message) + throw err + } + } + + const inputParts = buildRunInputParts(input, uploaded) + const userMsg = createUserMessage(input, uploaded.map(({ attachment, meta }) => ({ + id: attachment.id, + sessionId: targetSessionId, + workspaceHash, + assetId: meta.asset_id, + mimeType: meta.mime_type, + name: attachment.file.name, + size: meta.size, + previewUrl: attachment.previewUrl, + }))) + + setText('') + clearAttachments(false) + addMessage(userMsg) + submittedMessageId = userMsg.id + useRuntimeInsightStore.getState().setTodoSnapshot({ + items: [], + summary: { total: 0, required_total: 0, required_completed: 0, required_failed: 0, required_open: 0 }, + }) + setGenerating(true) + runCancelledRef.current = false + const ack = await gatewayAPI.run({ - session_id: isNewSession ? undefined : sessionId, - new_session: isNewSession ? true : undefined, - input_text: input, + session_id: targetSessionId, + input_parts: inputParts, mode: agentMode, }) if (!runCancelledRef.current) { @@ -351,10 +403,12 @@ export default function ChatInput() { } } catch (err) { if (!runCancelledRef.current) { + if (submittedMessageId) { + removeMessage(submittedMessageId) + } setGenerating(false) - useChatStore.getState().removeMessage(userMsg.id) console.error('Run failed:', err) - useUIStore.getState().showToast('Failed to send message', 'error') + useUIStore.getState().showToast(err instanceof Error ? err.message : 'Failed to send message', 'error') } } } @@ -421,6 +475,33 @@ export default function ChatInput() { void executeSlashCommand(cmd.usage) } + function handleFilesSelected(files: FileList | File[]) { + const accepted: File[] = [] + for (const file of Array.from(files)) { + if (!acceptedImageMimeTypes.includes(file.type as any)) { + useUIStore.getState().showToast('Only PNG, JPEG, and WebP images are supported', 'error') + continue + } + if (file.size <= 0) { + useUIStore.getState().showToast('Cannot upload an empty file', 'error') + continue + } + if (file.size > maxComposerAttachmentBytes) { + useUIStore.getState().showToast('Image exceeds the 20 MiB limit', 'error') + continue + } + accepted.push(file) + } + if (accepted.length > 0) addAttachmentFiles(accepted) + } + + function handleDrop(e: React.DragEvent) { + e.preventDefault() + setDragActive(false) + if (controlsLocked) return + handleFilesSelected(e.dataTransfer.files) + } + async function handleCancel() { runCancelledRef.current = true const runId = useGatewayStore.getState().currentRunId @@ -439,7 +520,7 @@ export default function ChatInput() { } } - const isEmpty = !text.trim() + const isEmpty = !text.trim() && attachments.length === 0 const controlsLocked = isGenerating || isCompacting return ( @@ -460,7 +541,16 @@ export default function ChatInput() { /> )} -
+
{ + e.preventDefault() + if (!controlsLocked) setDragActive(true) + }} + onDragLeave={() => setDragActive(false)} + onDrop={handleDrop} + > +