diff --git a/cmd/tui/main.go b/cmd/tui/main.go index 55b44f4b..86ce0ce6 100644 --- a/cmd/tui/main.go +++ b/cmd/tui/main.go @@ -3,269 +3,122 @@ package main import ( "bufio" "context" - "errors" "flag" "fmt" + "io" "os" "strings" "go-llm-demo/configs" - "go-llm-demo/internal/server/infra/provider" - "go-llm-demo/internal/server/infra/tools" - "go-llm-demo/internal/tui/core" - "go-llm-demo/internal/tui/infra" + "go-llm-demo/internal/tui/bootstrap" tea "github.com/charmbracelet/bubbletea" ) -func main() { - workspaceFlag := flag.String("workspace", "", "指定工作区根目录") - flag.Parse() +const defaultConfigPath = "config.yaml" - setUTF8Mode() +var buildRunDeps = defaultRunDeps - workspaceRoot, err := tools.ResolveWorkspaceRoot(*workspaceFlag) - if err != nil { - fmt.Fprintf(os.Stderr, "解析工作区失败: %v\n", err) - os.Exit(1) - } - if err := tools.SetWorkspaceRoot(workspaceRoot); err != nil { - fmt.Fprintf(os.Stderr, "设置工作区失败: %v\n", err) - os.Exit(1) - } +type programRunner interface { + Run() (tea.Model, error) +} - scanner := bufio.NewScanner(os.Stdin) - ready, err := ensureAPIKeyInteractive(context.Background(), scanner, "config.yaml") - if err != nil { - fmt.Fprintf(os.Stderr, "初始化配置失败: %v\n", err) - os.Exit(1) - } - if !ready { - fmt.Println("已退出 NeoCode") - return - } +type runDeps struct { + stdin io.Reader + stdout io.Writer + stderr io.Writer + setUTF8Mode func() + prepareWorkspace func(string) (string, error) + ensureAPIKeyInteractive func(context.Context, *bufio.Scanner, string) (bool, error) + loadAppConfig func(string) error + loadPersonaPrompt func(string) (string, string, error) + newProgram func(string, int, string, string) (programRunner, error) +} - if err := configs.LoadAppConfig("config.yaml"); err != nil { - fmt.Fprintf(os.Stderr, "加载配置失败: %v\n", err) - os.Exit(1) +func defaultRunDeps(stdin io.Reader, stdout, stderr io.Writer) runDeps { + return runDeps{ + stdin: stdin, + stdout: stdout, + stderr: stderr, + setUTF8Mode: setUTF8Mode, + prepareWorkspace: bootstrap.PrepareWorkspace, + ensureAPIKeyInteractive: bootstrap.EnsureAPIKeyInteractive, + loadAppConfig: configs.LoadAppConfig, + loadPersonaPrompt: configs.LoadPersonaPrompt, + newProgram: func(persona string, historyTurns int, configPath, workspaceRoot string) (programRunner, error) { + return bootstrap.NewProgram(persona, historyTurns, configPath, workspaceRoot) + }, } +} - persona, personaPath, err := configs.LoadPersonaPrompt(configs.GlobalAppConfig.Persona.FilePath) - if err != nil { - fmt.Fprintf(os.Stderr, "警告: 人设加载失败: %v\n", err) - } else if personaPath != "" && strings.TrimSpace(configs.GlobalAppConfig.Persona.FilePath) != personaPath { - fmt.Fprintf(os.Stderr, "提示: 人设已从 %s 回退加载\n", personaPath) - } - historyTurns := configs.GlobalAppConfig.History.ShortTermTurns - - client, err := infra.NewLocalChatClient() +func main() { + workspaceFlag, err := parseWorkspaceFlag(os.Args[1:], os.Stderr) if err != nil { - fmt.Fprintf(os.Stderr, "初始化失败: %v\n", err) os.Exit(1) } - model := core.NewModel(client, persona, historyTurns, "config.yaml", workspaceRoot) - p := tea.NewProgram(model, - tea.WithAltScreen(), - tea.WithMouseCellMotion(), - ) - if _, err := p.Run(); err != nil { - fmt.Fprintf(os.Stderr, "运行失败: %v\n", err) + if err := run(workspaceFlag, os.Stdin, os.Stdout, os.Stderr); err != nil { + fmt.Fprintln(os.Stderr, err) os.Exit(1) } } -func ensureAPIKeyInteractive(ctx context.Context, scanner *bufio.Scanner, configPath string) (bool, error) { - cfg, created, err := configs.EnsureConfigFile(configPath) - if err != nil { - return false, err - } - if created { - fmt.Printf("已创建 %s\n", configPath) - } - - for { - if cfg.RuntimeAPIKey() == "" { - envName := cfg.APIKeyEnvVarName() - fmt.Printf("未检测到环境变量 %s。可使用 /apikey 、/provider 、/switch 切换配置,或先设置该环境变量后再 /retry。\n", envName) - fmt.Printf("Windows 示例: setx %s \"your-api-key\"\n", envName) - result, handleErr := handleSetupDecision(scanner, cfg, false, configPath) - if handleErr != nil { - return false, handleErr - } - if result == setupExit { - return false, nil - } - continue - } +func parseWorkspaceFlag(args []string, stderr io.Writer) (string, error) { + fs := flag.NewFlagSet("tui", flag.ContinueOnError) + fs.SetOutput(stderr) - if err := provider.ValidateChatAPIKey(ctx, cfg); err == nil { - if saveErr := configs.WriteAppConfig(configPath, cfg); saveErr != nil { - return false, saveErr - } - configs.GlobalAppConfig = cfg - fmt.Println("API key 验证通过。") - return true, nil - } else if errors.Is(err, provider.ErrInvalidAPIKey) { - fmt.Printf("环境变量 %s 中的 API key 无效: %v\n", cfg.APIKeyEnvVarName(), err) - result, handleErr := handleSetupDecision(scanner, cfg, false, configPath) - if handleErr != nil { - return false, handleErr - } - if result == setupExit { - return false, nil - } - continue - } else if errors.Is(err, provider.ErrAPIKeyValidationSoft) { - fmt.Printf("无法确认环境变量 %s 中的 API key 有效性: %v\n", cfg.APIKeyEnvVarName(), err) - result, handleErr := handleSetupDecision(scanner, cfg, true, configPath) - if handleErr != nil { - return false, handleErr - } - if result == setupExit { - return false, nil - } - if result == setupContinue { - configs.GlobalAppConfig = cfg - return true, nil - } - continue - } else { - fmt.Printf("模型验证失败: %v\n", err) - result, handleErr := handleSetupDecision(scanner, cfg, false, configPath) - if handleErr != nil { - return false, handleErr - } - if result == setupExit { - return false, nil - } - if result == setupContinue { - configs.GlobalAppConfig = cfg - return true, nil - } - } + workspaceFlag := fs.String("workspace", "", "指定工作区根目录") + if err := fs.Parse(args); err != nil { + return "", err } + return *workspaceFlag, nil } -type setupDecision int - -const ( - setupRetry setupDecision = iota - setupContinue - setupExit -) +func run(workspaceFlag string, stdin io.Reader, stdout, stderr io.Writer) error { + return runWithDeps(workspaceFlag, buildRunDeps(stdin, stdout, stderr)) +} -func handleSetupDecision(scanner *bufio.Scanner, cfg *configs.AppConfiguration, allowContinue bool, configPath string) (setupDecision, error) { - for { - prompt := "选择 /retry, /apikey , /provider , /switch , 或 /exit > " - if allowContinue { - prompt = "选择 /retry, /continue, /apikey , /provider , /switch , 或 /exit > " - } - decision, ok, inputErr := readInteractiveLine(scanner, prompt) - if inputErr != nil { - return setupExit, inputErr - } - if !ok { - return setupExit, nil - } +func runWithDeps(workspaceFlag string, deps runDeps) error { + if deps.setUTF8Mode != nil { + deps.setUTF8Mode() + } - fields := strings.Fields(strings.TrimSpace(decision)) - if len(fields) == 0 { - continue - } + workspaceRoot, err := deps.prepareWorkspace(workspaceFlag) + if err != nil { + return fmt.Errorf("解析工作区失败: %w", err) + } - switch strings.ToLower(fields[0]) { - case "/retry": - return setupRetry, nil - case "/apikey": - if len(fields) < 2 { - fmt.Println("用法: /apikey ") - continue - } - applyAPIKeyEnvName(cfg, fields[1]) - fmt.Printf("已切换 API Key 环境变量名为: %s\n", cfg.APIKeyEnvVarName()) - return setupRetry, nil - case "/continue": - if !allowContinue { - fmt.Println("/continue 仅在网络或服务问题导致无法确认时可用。") - continue - } - if saveErr := configs.WriteAppConfig(configPath, cfg); saveErr != nil { - return setupExit, saveErr - } - fmt.Println("继续启动,使用当前 API key 和模型。") - return setupContinue, nil - case "/provider": - if len(fields) < 2 { - fmt.Println("用法: /provider ") - printSupportedProviders() - continue - } - providerName, ok := provider.NormalizeProviderName(fields[1]) - if !ok { - fmt.Printf("不支持的提供商 %q\n", fields[1]) - printSupportedProviders() - continue - } - cfg.AI.Provider = providerName - cfg.AI.Model = provider.DefaultModelForProvider(providerName) - fmt.Printf("已切换到提供商: %s\n", providerName) - fmt.Printf("当前模型已重置为默认值: %s\n", cfg.AI.Model) - return setupRetry, nil - case "/switch": - if len(fields) < 2 { - fmt.Println("用法: /switch ") - continue - } - target := strings.Join(fields[1:], " ") - cfg.AI.Model = target - fmt.Printf("已切换到模型: %s\n", target) - return setupRetry, nil - case "/exit": - return setupExit, nil - default: - if allowContinue { - fmt.Println("请输入 /retry, /continue, /apikey , /provider , /switch , 或 /exit。") - } else { - fmt.Println("请输入 /retry, /apikey , /provider , /switch , 或 /exit。") - } - } + scanner := bufio.NewScanner(deps.stdin) + ready, err := deps.ensureAPIKeyInteractive(context.Background(), scanner, defaultConfigPath) + if err != nil { + return fmt.Errorf("初始化配置失败: %w", err) + } + if !ready { + fmt.Fprintln(deps.stdout, "已退出 NeoCode") + return nil } -} -func applyAPIKeyEnvName(cfg *configs.AppConfiguration, envName string) { - if cfg == nil { - return + if err := deps.loadAppConfig(defaultConfigPath); err != nil { + return fmt.Errorf("加载配置失败: %w", err) } - cfg.AI.APIKey = strings.TrimSpace(envName) -} -func readInteractiveLine(scanner *bufio.Scanner, prompt string) (string, bool, error) { - for { - fmt.Print(prompt) - if !scanner.Scan() { - if err := scanner.Err(); err != nil { - return "", false, err - } - return "", false, nil - } - input := strings.TrimSpace(scanner.Text()) - if input == "" { - fmt.Println("输入不能为空。") - continue - } - if input == "/exit" { - return "", false, nil - } - return input, true, nil + persona, personaPath, err := deps.loadPersonaPrompt(configs.GlobalAppConfig.Persona.FilePath) + if err != nil { + fmt.Fprintf(deps.stderr, "警告: 人设加载失败: %v\n", err) + } else if personaPath != "" && strings.TrimSpace(configs.GlobalAppConfig.Persona.FilePath) != personaPath { + fmt.Fprintf(deps.stderr, "提示: 人设已从 %s 回退加载\n", personaPath) } -} -func printSupportedProviders() { - fmt.Println("可用提供商:") - for _, name := range provider.SupportedProviders() { - fmt.Printf(" %s\n", name) + historyTurns := configs.GlobalAppConfig.History.ShortTermTurns + p, err := deps.newProgram(persona, historyTurns, defaultConfigPath, workspaceRoot) + if err != nil { + return fmt.Errorf("初始化失败: %w", err) + } + if _, err := p.Run(); err != nil { + return fmt.Errorf("运行失败: %w", err) } + + return nil } func loadDotEnv(path string) error { @@ -314,4 +167,3 @@ func loadPersonaPrompt(path string) string { return strings.TrimSpace(string(data)) } - diff --git a/cmd/tui/main_test.go b/cmd/tui/main_test.go new file mode 100644 index 00000000..368663bf --- /dev/null +++ b/cmd/tui/main_test.go @@ -0,0 +1,419 @@ +package main + +import ( + "bufio" + "bytes" + "context" + "errors" + "io" + "os" + "path/filepath" + "strings" + "testing" + + "go-llm-demo/configs" + + tea "github.com/charmbracelet/bubbletea" +) + +type fakeProgram struct { + runErr error + called bool +} + +func (p *fakeProgram) Run() (tea.Model, error) { + p.called = true + return nil, p.runErr +} + +func TestLoadDotEnvSetsMissingVarsOnly(t *testing.T) { + t.Setenv("EXISTING_KEY", "keep-me") + t.Setenv("NEW_KEY", "") + + dir := t.TempDir() + path := filepath.Join(dir, ".env") + content := "EXISTING_KEY=override\nNEW_KEY= new-value \n# comment\nINVALID_LINE\n" + if err := os.WriteFile(path, []byte(content), 0o644); err != nil { + t.Fatalf("write env file: %v", err) + } + + if err := loadDotEnv(path); err != nil { + t.Fatalf("expected no error, got %v", err) + } + if got := os.Getenv("EXISTING_KEY"); got != "keep-me" { + t.Fatalf("expected existing env var to be preserved, got %q", got) + } + if got := os.Getenv("NEW_KEY"); got != "new-value" { + t.Fatalf("expected new env var to load, got %q", got) + } +} + +func TestLoadDotEnvIgnoresMissingFile(t *testing.T) { + missing := filepath.Join(t.TempDir(), "missing.env") + if err := loadDotEnv(missing); err != nil { + t.Fatalf("expected missing file to be ignored, got %v", err) + } +} + +func TestLoadDotEnvReturnsNonENOENTError(t *testing.T) { + if err := loadDotEnv(t.TempDir()); err == nil { + t.Fatal("expected non-ENOENT read error") + } +} + +func TestLoadDotEnvTrimsQuotedValuesAndSkipsEmptyKeys(t *testing.T) { + t.Setenv("QUOTED_KEY", "") + + dir := t.TempDir() + path := filepath.Join(dir, ".env") + content := "QUOTED_KEY=' spaced value '\n =ignored\n" + if err := os.WriteFile(path, []byte(content), 0o644); err != nil { + t.Fatalf("write env file: %v", err) + } + + if err := loadDotEnv(path); err != nil { + t.Fatalf("expected no error, got %v", err) + } + if got := os.Getenv("QUOTED_KEY"); got != " spaced value " { + t.Fatalf("expected quoted value to be trimmed, got %q", got) + } +} + +func TestLoadPersonaPromptReturnsTrimmedContent(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "persona.txt") + if err := os.WriteFile(path, []byte("\n hello persona \n"), 0o644); err != nil { + t.Fatalf("write persona file: %v", err) + } + + if got := loadPersonaPrompt(path); got != "hello persona" { + t.Fatalf("expected trimmed persona prompt, got %q", got) + } +} + +func TestLoadPersonaPromptReturnsEmptyForMissingFile(t *testing.T) { + missing := filepath.Join(t.TempDir(), "missing.txt") + if got := loadPersonaPrompt(missing); got != "" { + t.Fatalf("expected empty string for missing file, got %q", got) + } +} + +func TestLoadPersonaPromptReturnsEmptyForBlankPath(t *testing.T) { + if got := loadPersonaPrompt(" "); got != "" { + t.Fatalf("expected empty string for blank path, got %q", got) + } +} + +func TestParseWorkspaceFlagParsesWorkspaceValue(t *testing.T) { + stderr := &bytes.Buffer{} + + got, err := parseWorkspaceFlag([]string{"-workspace", "D:/neo-code/workspace"}, stderr) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if got != "D:/neo-code/workspace" { + t.Fatalf("unexpected workspace flag value %q", got) + } +} + +func TestDefaultRunDepsWiresStandardStreamsAndFunctions(t *testing.T) { + deps := defaultRunDeps(strings.NewReader("in"), &bytes.Buffer{}, &bytes.Buffer{}) + + if deps.stdin == nil || deps.stdout == nil || deps.stderr == nil { + t.Fatal("expected stdio to be preserved in deps") + } + if deps.setUTF8Mode == nil || deps.prepareWorkspace == nil || deps.ensureAPIKeyInteractive == nil || deps.loadAppConfig == nil || deps.loadPersonaPrompt == nil || deps.newProgram == nil { + t.Fatal("expected default dependencies to be populated") + } +} + +func TestRunUsesInjectableDepBuilder(t *testing.T) { + origBuildRunDeps := buildRunDeps + t.Cleanup(func() { buildRunDeps = origBuildRunDeps }) + + called := false + buildRunDeps = func(stdin io.Reader, stdout, stderr io.Writer) runDeps { + called = true + return runDeps{ + stdin: stdin, + stdout: stdout, + stderr: stderr, + setUTF8Mode: func() {}, + prepareWorkspace: func(string) (string, error) { return "D:/neo-code", nil }, + ensureAPIKeyInteractive: func(context.Context, *bufio.Scanner, string) (bool, error) { + return false, nil + }, + } + } + + err := run("D:/neo-code", strings.NewReader(""), &bytes.Buffer{}, &bytes.Buffer{}) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if !called { + t.Fatal("expected run to use buildRunDeps") + } +} + +func TestRunWithDepsReturnsWorkspacePreparationError(t *testing.T) { + stderr := &bytes.Buffer{} + + err := runWithDeps("", runDeps{ + stdin: strings.NewReader(""), + stdout: &bytes.Buffer{}, + stderr: stderr, + setUTF8Mode: func() {}, + prepareWorkspace: func(string) (string, error) { return "", errors.New("workspace failed") }, + }) + if err == nil || !strings.Contains(err.Error(), "workspace failed") { + t.Fatalf("expected workspace error, got %v", err) + } +} + +func TestRunWithDepsStopsCleanlyWhenSetupNotReady(t *testing.T) { + stdout := &bytes.Buffer{} + loadCalled := false + + err := runWithDeps("D:/neo-code", runDeps{ + stdin: strings.NewReader(""), + stdout: stdout, + stderr: &bytes.Buffer{}, + setUTF8Mode: func() {}, + prepareWorkspace: func(string) (string, error) { return "D:/neo-code", nil }, + ensureAPIKeyInteractive: func(_ context.Context, _ *bufio.Scanner, path string) (bool, error) { + if path != defaultConfigPath { + t.Fatalf("expected config path %q, got %q", defaultConfigPath, path) + } + return false, nil + }, + loadAppConfig: func(string) error { + loadCalled = true + return nil + }, + }) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if loadCalled { + t.Fatal("loadAppConfig should not run when setup is not ready") + } + if !strings.Contains(stdout.String(), "NeoCode") { + t.Fatalf("expected exit message in stdout, got %q", stdout.String()) + } +} + +func TestRunWithDepsReturnsBootstrapError(t *testing.T) { + err := runWithDeps("D:/neo-code", runDeps{ + stdin: strings.NewReader(""), + stdout: &bytes.Buffer{}, + stderr: &bytes.Buffer{}, + setUTF8Mode: func() {}, + prepareWorkspace: func(string) (string, error) { + return "D:/neo-code", nil + }, + ensureAPIKeyInteractive: func(context.Context, *bufio.Scanner, string) (bool, error) { + return false, errors.New("bootstrap failed") + }, + }) + if err == nil || !strings.Contains(err.Error(), "bootstrap failed") { + t.Fatalf("expected bootstrap error, got %v", err) + } +} + +func TestRunWithDepsReturnsLoadAppConfigError(t *testing.T) { + err := runWithDeps("D:/neo-code", runDeps{ + stdin: strings.NewReader(""), + stdout: &bytes.Buffer{}, + stderr: &bytes.Buffer{}, + setUTF8Mode: func() {}, + prepareWorkspace: func(string) (string, error) { + return "D:/neo-code", nil + }, + ensureAPIKeyInteractive: func(context.Context, *bufio.Scanner, string) (bool, error) { + return true, nil + }, + loadAppConfig: func(string) error { return errors.New("load failed") }, + }) + if err == nil || !strings.Contains(err.Error(), "load failed") { + t.Fatalf("expected load error, got %v", err) + } +} + +func TestRunWithDepsPrintsPersonaFallbackHintAndRunsProgram(t *testing.T) { + origGlobalConfig := configs.GlobalAppConfig + t.Cleanup(func() { configs.GlobalAppConfig = origGlobalConfig }) + + stdout := &bytes.Buffer{} + stderr := &bytes.Buffer{} + cfg := configs.DefaultAppConfig() + cfg.Persona.FilePath = "./persona.txt" + + program := &fakeProgram{} + newProgramCalled := false + err := runWithDeps("D:/neo-code", runDeps{ + stdin: strings.NewReader(""), + stdout: stdout, + stderr: stderr, + setUTF8Mode: func() {}, + prepareWorkspace: func(string) (string, error) { return "D:/neo-code", nil }, + ensureAPIKeyInteractive: func(context.Context, *bufio.Scanner, string) (bool, error) { return true, nil }, + loadAppConfig: func(string) error { + configs.GlobalAppConfig = cfg + return nil + }, + loadPersonaPrompt: func(path string) (string, string, error) { + if path != "./persona.txt" { + t.Fatalf("expected configured persona path, got %q", path) + } + return "persona text", "./configs/persona.txt", nil + }, + newProgram: func(persona string, historyTurns int, configPath, workspaceRoot string) (programRunner, error) { + newProgramCalled = true + if persona != "persona text" { + t.Fatalf("unexpected persona %q", persona) + } + if historyTurns != cfg.History.ShortTermTurns { + t.Fatalf("unexpected history turns %d", historyTurns) + } + if configPath != defaultConfigPath || workspaceRoot != "D:/neo-code" { + t.Fatalf("unexpected program args: %q %q", configPath, workspaceRoot) + } + return program, nil + }, + }) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if !newProgramCalled || !program.called { + t.Fatal("expected program to be created and run") + } + if !strings.Contains(stderr.String(), "./configs/persona.txt") { + t.Fatalf("expected fallback persona hint, got %q", stderr.String()) + } +} + +func TestRunWithDepsContinuesWhenPersonaLoadFails(t *testing.T) { + origGlobalConfig := configs.GlobalAppConfig + t.Cleanup(func() { configs.GlobalAppConfig = origGlobalConfig }) + + cfg := configs.DefaultAppConfig() + stderr := &bytes.Buffer{} + program := &fakeProgram{} + + err := runWithDeps("D:/neo-code", runDeps{ + stdin: strings.NewReader(""), + stdout: &bytes.Buffer{}, + stderr: stderr, + setUTF8Mode: func() {}, + prepareWorkspace: func(string) (string, error) { return "D:/neo-code", nil }, + ensureAPIKeyInteractive: func(context.Context, *bufio.Scanner, string) (bool, error) { return true, nil }, + loadAppConfig: func(string) error { + configs.GlobalAppConfig = cfg + return nil + }, + loadPersonaPrompt: func(string) (string, string, error) { + return "", "", errors.New("persona failed") + }, + newProgram: func(persona string, historyTurns int, configPath, workspaceRoot string) (programRunner, error) { + if persona != "" { + t.Fatalf("expected empty persona on load failure, got %q", persona) + } + return program, nil + }, + }) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if !program.called { + t.Fatal("expected program to still run") + } + if !strings.Contains(stderr.String(), "persona failed") { + t.Fatalf("expected persona warning, got %q", stderr.String()) + } +} + +func TestRunWithDepsReturnsNewProgramError(t *testing.T) { + origGlobalConfig := configs.GlobalAppConfig + t.Cleanup(func() { configs.GlobalAppConfig = origGlobalConfig }) + + cfg := configs.DefaultAppConfig() + + err := runWithDeps("D:/neo-code", runDeps{ + stdin: strings.NewReader(""), + stdout: &bytes.Buffer{}, + stderr: &bytes.Buffer{}, + setUTF8Mode: func() {}, + prepareWorkspace: func(string) (string, error) { return "D:/neo-code", nil }, + ensureAPIKeyInteractive: func(context.Context, *bufio.Scanner, string) (bool, error) { return true, nil }, + loadAppConfig: func(string) error { + configs.GlobalAppConfig = cfg + return nil + }, + loadPersonaPrompt: func(string) (string, string, error) { return "persona", "", nil }, + newProgram: func(string, int, string, string) (programRunner, error) { return nil, errors.New("new program failed") }, + }) + if err == nil || !strings.Contains(err.Error(), "new program failed") { + t.Fatalf("expected new program error, got %v", err) + } +} + +func TestRunWithDepsReturnsProgramRunError(t *testing.T) { + origGlobalConfig := configs.GlobalAppConfig + t.Cleanup(func() { configs.GlobalAppConfig = origGlobalConfig }) + + cfg := configs.DefaultAppConfig() + program := &fakeProgram{runErr: errors.New("program failed")} + + err := runWithDeps("D:/neo-code", runDeps{ + stdin: strings.NewReader(""), + stdout: &bytes.Buffer{}, + stderr: &bytes.Buffer{}, + setUTF8Mode: func() {}, + prepareWorkspace: func(string) (string, error) { return "D:/neo-code", nil }, + ensureAPIKeyInteractive: func(context.Context, *bufio.Scanner, string) (bool, error) { return true, nil }, + loadAppConfig: func(string) error { + configs.GlobalAppConfig = cfg + return nil + }, + loadPersonaPrompt: func(string) (string, string, error) { return "", "", nil }, + newProgram: func(string, int, string, string) (programRunner, error) { return program, nil }, + }) + if err == nil || !strings.Contains(err.Error(), "program failed") { + t.Fatalf("expected run error, got %v", err) + } +} + +func TestRunWithDepsHappyPathCallsUTF8Hook(t *testing.T) { + origGlobalConfig := configs.GlobalAppConfig + t.Cleanup(func() { configs.GlobalAppConfig = origGlobalConfig }) + + cfg := configs.DefaultAppConfig() + utf8Called := false + program := &fakeProgram{} + + err := runWithDeps("D:/neo-code", runDeps{ + stdin: strings.NewReader(""), + stdout: &bytes.Buffer{}, + stderr: &bytes.Buffer{}, + setUTF8Mode: func() { + utf8Called = true + }, + prepareWorkspace: func(string) (string, error) { return "D:/neo-code", nil }, + ensureAPIKeyInteractive: func(context.Context, *bufio.Scanner, string) (bool, error) { return true, nil }, + loadAppConfig: func(string) error { + configs.GlobalAppConfig = cfg + return nil + }, + loadPersonaPrompt: func(string) (string, string, error) { return "persona", "", nil }, + newProgram: func(string, int, string, string) (programRunner, error) { return program, nil }, + }) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if !utf8Called { + t.Fatal("expected UTF8 hook to be called") + } + if !program.called { + t.Fatal("expected program to run") + } +} diff --git a/docs/TUI_REFINED_ARCHITECTURE.md b/docs/TUI_REFINED_ARCHITECTURE.md new file mode 100644 index 00000000..a94d1db6 --- /dev/null +++ b/docs/TUI_REFINED_ARCHITECTURE.md @@ -0,0 +1,149 @@ +# NeoCode TUI 增强版架构设计指南 (细化版) + +## 核心设计 + +本架构基于 **TEA (The Elm Architecture)** 模式,遵循“**单向数据流**”与“**逻辑/视图/通信彻底解耦**”的原则。 +- **状态驱动**:界面是状态(State)的函数。 +- **异步解耦**:所有耗时操作(网络、I/O)必须通过 `tea.Cmd` 异步执行。 +- **物理隔离**:TUI 层严禁引用 `internal/server` 目录下的任何非 API 结构。 + +--- + +## 细化后的六层架构模型 + +我们将整个 TUI 客户端细化为六个逻辑层,每个层级有严格的边界限制: + +### 入口层 (Entry) - `cmd/tui/` + +* **职责**:程序的物理起点。 +* **具体任务**: + * 解析命令行启动参数(如 `--debug`, `--config`)。 + * 调用 `bootstrap` 层获取初始化好的 `Program` 实例。 + * 启动 `tea.NewProgram` 并处理最终的退出错误。 +* **禁止**:严禁编写任何业务逻辑或具体的 UI 布局代码。 + +### 启动与注入层 (Bootstrap) - `internal/tui/bootstrap/` + +* **职责**:系统的“总装车间”,负责**依赖注入 (DI)**。 +* **具体任务**: + * 读取 `config.yaml` 配置文件。 + * 实例化 `services` 层(如 `APIClient`, `Logger`)。 + * 将这些服务注入到 `app` 层中。 + * **关键作用**:通过在这一层注入不同的实现,可以轻松实现“离线测试模式”或“Mock 测试”。 + +### 应用逻辑层 (App/Core) - `internal/tui/app/` + +* **职责**:状态机中心,负责调度 `Update` 和 `View`。 +* **具体文件**: + * `model.go`: 定义顶层 `Model`,聚合各子模块状态。 + * `update.go`: 核心业务逻辑路由器。根据收到的 `tea.Msg` 决定调用哪个 Service 或更新哪个 State。 + * `view.go`: 顶层布局管理器。决定 `Header`, `Content`, `Footer` 的排版位置(使用 Lipgloss)。 + * `msg.go`: 定义所有自定义消息类型(如 `AIGeneratingMsg`, `SocketErrorMsg`)。 + +### 纯状态层 (State) - `internal/tui/state/` + +* **职责**:**数据容器**。仅存放纯粹的 Go 结构体。 +* **具体任务**: + * `ui_state.go`: 记录 UI 细节(如:窗口宽高、当前焦点在哪个输入框、滚动条位置)。 + * `chat_state.go`: 存放当前的聊天历史、AI 思考中的临时文本。 +* **准则**:这一层**不含任何方法**,只存放数据,确保状态的可序列化和易测试性。 + +### 视图组件层 (Components) - `internal/tui/components/` + +* **职责**:**原子级 UI 渲染器**(“傻瓜组件”)。 +* **具体任务**: + * `code_block.go`: 负责代码高亮渲染。 + * `status_bar.go`: 负责底部状态栏的样式。 +* **原则**: + * **输入**:仅接收基础数据或 State 结构体。 + * **输出**:返回渲染好的字符串(`string`)。 + * **禁止**:组件内严禁发起任何网络请求或修改全局状态。 + +### 服务对接层 (Services) - `internal/tui/services/` + +* **职责**:**外交部**。负责与后端 Server 或系统环境通信。 +* **具体任务**: + * `api_client.go`: 封装对 `internal/server/transport` 的调用。 + * `file_service.go`: 处理本地文件的临时读取。 +* **原则**:所有方法必须返回 `tea.Cmd` 或在回调中触发 `tea.Msg`。 + +--- + +## 标准数据流向 (Lifecycle) + +以“用户发送消息”为例: +1. **用户按下回车**:`app/update.go` 捕获到按键事件。 +2. **更新本地状态**:`update.go` 将用户输入追加到 `state/chat_state.go`,并返回一个 `tea.Cmd` 触发发送请求。 +3. **服务调用**:`services/api_client.go` 执行异步 API 调用。 +4. **结果反馈**:API 返回结果后,封装成 `APIResponseMsg` 发回给 `app/update.go`。 +5. **界面重绘**:`app/view.go` 根据更新后的 `state` 重新生成字符串,渲染到屏幕。 + + + +#### 用户视角:一个“发送消息”动作的全层级演变 + + + 假设用户在终端输入了 "你好" 并按下 回车键。以下是各层级像齿轮一样咬合转动的过程: + 第一阶段:输入捕获 (Entry -> App) + * 用户看到:手指按下回车。 + * 层级变化:入口层 (Entry) 的底层的 tea.Program 捕获到操作系统发来的按键信号,并将其包装成一个 KeyMsg + 发送给 应用逻辑层 (App) 的 Update 函数。 + + + 第二阶段:本地状态更新 (App -> State -> View) + * 层级变化: + 1. App (Update):识别出这是回车键。 + 2. State:Update 把 state.ChatState.InputBuffer 里的 "你好" 提取出来,清空输入框,并把这条消息塞进 + state.ChatState.History 数组。 + 3. View:Runtime 立即调用 View。View 发现 History 里多了一条消息,于是命令 组件层 (Components) 渲染一个新的气泡。 + * 用户看到:屏幕上自己发送的 "你好" 瞬间出现在了聊天记录区域,且输入框变空了。(此时 AI + 还没说话,但用户感觉响应非常快) + + + 第三阶段:发起异步请求 (App -> Services) + * 层级变化: + 1. App (Update):在更新完本地状态的同时,返回一个 tea.Cmd。这个命令指向 服务层 (Services) 的 SendToAI 函数。 + 2. Services:在后台偷偷发起网络请求,把 "你好" 发给后端的 Go Server。 + + + 第四阶段:AI 响应回流 (Services -> App -> State -> View) + * 层级变化: + 1. Services:收到后端返回的 AI 回复(如 "你好!我是 NeoCode"),将其包装成一个 AIResponseMsg 投递回App。 + 2. App (Update):收到这个消息,再次修改 State,把 AI 的话加入 History。 + 3. View:Runtime 再次触发 View 重绘。 + * 用户看到:屏幕上刷新出了 AI 的回复。 + +--- + +## 细化后的目录结构 + +```text +internal/tui/ +├── bootstrap/ # 依赖装配 (Runtime 构造器) +│ └── runtime.go +├── app/ # 状态机核心 (TEA 循环) +│ ├── model.go # 顶层模型定义 +│ ├── update.go # 消息分发逻辑 +│ ├── view.go # 顶层布局 (Layout) +│ ├── msg.go # 消息类型定义 +│ └── keymap.go # 快捷键配置 +├── state/ # 纯状态定义 (数据结构) +│ ├── ui_state.go # 窗口、焦点等状态 +│ └── chat_state.go # 聊天记录等数据 +├── components/ # 纯 UI 组件 (Lipgloss 渲染) +│ ├── code_block.go # 代码块组件 +│ ├── input_box.go # 输入框增强 +│ └── status_bar.go # 状态栏组件 +└── services/ # 外部适配器 (API/I/O) + ├── api_client.go # 后端通信 + └── config_svc.go # 配置读取 +``` + +--- + +## 开发守则 (Constraints) + +1. **禁止跨层修改**:`components` 里的代码绝对不能修改 `state` 里的数据,必须通过 `Update` 函数统一处理。 +2. **样式与逻辑分离**:所有的颜色、边距(Lipgloss 样式)应在 `components` 中定义,`app/view.go` 仅负责大框架的拼装。 +3. **零后端业务依赖**:TUI 只能依赖 `api/proto` 中定义的结构。如果后端修改了业务逻辑,只要 API 不变,TUI 代码应保持一行不改。 +4. **异步命令化**:任何可能超过 10ms 的操作(读取大文件、调用 AI)必须封装在 `services` 层返回 `tea.Cmd`。 diff --git a/internal/tui/bootstrap/runtime.go b/internal/tui/bootstrap/runtime.go new file mode 100644 index 00000000..beed9c46 --- /dev/null +++ b/internal/tui/bootstrap/runtime.go @@ -0,0 +1,21 @@ +package bootstrap + +import ( + "go-llm-demo/internal/tui/core" + "go-llm-demo/internal/tui/services" + + tea "github.com/charmbracelet/bubbletea" +) + +func NewProgram(persona string, historyTurns int, configPath, workspaceRoot string) (*tea.Program, error) { + client, err := services.NewLocalChatClient() + if err != nil { + return nil, err + } + + model := core.NewModel(client, persona, historyTurns, configPath, workspaceRoot) + return tea.NewProgram(model, + tea.WithAltScreen(), + tea.WithMouseCellMotion(), + ), nil +} diff --git a/internal/tui/bootstrap/runtime_test.go b/internal/tui/bootstrap/runtime_test.go new file mode 100644 index 00000000..44fe795e --- /dev/null +++ b/internal/tui/bootstrap/runtime_test.go @@ -0,0 +1,38 @@ +package bootstrap + +import ( + "path/filepath" + "testing" + + "go-llm-demo/configs" +) + +func TestNewProgramReturnsErrorWhenGlobalConfigMissing(t *testing.T) { + origGlobalConfig := configs.GlobalAppConfig + t.Cleanup(func() { configs.GlobalAppConfig = origGlobalConfig }) + + configs.GlobalAppConfig = nil + + p, err := NewProgram("persona", 4, "config.yaml", "D:/neo-code") + if err == nil { + t.Fatalf("expected error, got program %+v", p) + } +} + +func TestNewProgramBuildsBubbleTeaProgram(t *testing.T) { + origGlobalConfig := configs.GlobalAppConfig + t.Cleanup(func() { configs.GlobalAppConfig = origGlobalConfig }) + + cfg := configs.DefaultAppConfig() + cfg.Memory.StoragePath = filepath.Join(t.TempDir(), "memory.json") + configs.GlobalAppConfig = cfg + t.Setenv(cfg.APIKeyEnvVarName(), "secret") + + p, err := NewProgram("persona", 4, "config.yaml", "D:/neo-code") + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if p == nil { + t.Fatal("expected non-nil program") + } +} diff --git a/internal/tui/bootstrap/setup.go b/internal/tui/bootstrap/setup.go new file mode 100644 index 00000000..994eafc0 --- /dev/null +++ b/internal/tui/bootstrap/setup.go @@ -0,0 +1,224 @@ +package bootstrap + +import ( + "bufio" + "context" + "errors" + "fmt" + "strings" + + "go-llm-demo/configs" + "go-llm-demo/internal/tui/services" +) + +type setupDecision int + +const ( + setupRetry setupDecision = iota + setupContinue + setupExit +) + +var ( + resolveWorkspaceRoot = services.ResolveWorkspaceRoot + setWorkspaceRoot = services.SetWorkspaceRoot + ensureConfigFile = configs.EnsureConfigFile + validateChatAPIKey = services.ValidateChatAPIKey + writeAppConfig = configs.WriteAppConfig +) + +func PrepareWorkspace(workspaceFlag string) (string, error) { + workspaceRoot, err := resolveWorkspaceRoot(workspaceFlag) + if err != nil { + return "", err + } + if err := setWorkspaceRoot(workspaceRoot); err != nil { + return "", err + } + return workspaceRoot, nil +} + +func EnsureAPIKeyInteractive(ctx context.Context, scanner *bufio.Scanner, configPath string) (bool, error) { + cfg, created, err := ensureConfigFile(configPath) + if err != nil { + return false, err + } + if created { + fmt.Printf("已创建 %s\n", configPath) + } + + for { + if cfg.RuntimeAPIKey() == "" { + envName := cfg.APIKeyEnvVarName() + fmt.Printf("未检测到环境变量 %s。可使用 /apikey 、/provider 、/switch 切换配置,或先设置该环境变量后再 /retry。\n", envName) + fmt.Printf("Windows 示例: setx %s \"your-api-key\"\n", envName) + result, handleErr := handleSetupDecision(scanner, cfg, false, configPath) + if handleErr != nil { + return false, handleErr + } + if result == setupExit { + return false, nil + } + continue + } + + if err := validateChatAPIKey(ctx, cfg); err == nil { + if saveErr := writeAppConfig(configPath, cfg); saveErr != nil { + return false, saveErr + } + configs.GlobalAppConfig = cfg + fmt.Println("API key 验证通过。") + return true, nil + } else if errors.Is(err, services.ErrInvalidAPIKey) { + fmt.Printf("环境变量 %s 中的 API key 无效: %v\n", cfg.APIKeyEnvVarName(), err) + result, handleErr := handleSetupDecision(scanner, cfg, false, configPath) + if handleErr != nil { + return false, handleErr + } + if result == setupExit { + return false, nil + } + continue + } else if errors.Is(err, services.ErrAPIKeyValidationSoft) { + fmt.Printf("无法确认环境变量 %s 中的 API key 有效性: %v\n", cfg.APIKeyEnvVarName(), err) + result, handleErr := handleSetupDecision(scanner, cfg, true, configPath) + if handleErr != nil { + return false, handleErr + } + if result == setupExit { + return false, nil + } + if result == setupContinue { + configs.GlobalAppConfig = cfg + return true, nil + } + continue + } else { + fmt.Printf("模型验证失败: %v\n", err) + result, handleErr := handleSetupDecision(scanner, cfg, false, configPath) + if handleErr != nil { + return false, handleErr + } + if result == setupExit { + return false, nil + } + if result == setupContinue { + configs.GlobalAppConfig = cfg + return true, nil + } + } + } +} + +func handleSetupDecision(scanner *bufio.Scanner, cfg *configs.AppConfiguration, allowContinue bool, configPath string) (setupDecision, error) { + for { + prompt := "选择 /retry, /apikey , /provider , /switch , 或 /exit > " + if allowContinue { + prompt = "选择 /retry, /continue, /apikey , /provider , /switch , 或 /exit > " + } + decision, ok, inputErr := readInteractiveLine(scanner, prompt) + if inputErr != nil { + return setupExit, inputErr + } + if !ok { + return setupExit, nil + } + + fields := strings.Fields(strings.TrimSpace(decision)) + if len(fields) == 0 { + continue + } + + switch strings.ToLower(fields[0]) { + case "/retry": + return setupRetry, nil + case "/apikey": + if len(fields) < 2 { + fmt.Println("用法: /apikey ") + continue + } + applyAPIKeyEnvName(cfg, fields[1]) + fmt.Printf("已切换 API Key 环境变量名为: %s\n", cfg.APIKeyEnvVarName()) + return setupRetry, nil + case "/continue": + if !allowContinue { + fmt.Println("/continue 仅在网络或服务问题导致无法确认时可用。") + continue + } + if saveErr := writeAppConfig(configPath, cfg); saveErr != nil { + return setupExit, saveErr + } + fmt.Println("继续启动,使用当前 API key 和模型。") + return setupContinue, nil + case "/provider": + if len(fields) < 2 { + fmt.Println("用法: /provider ") + printSupportedProviders() + continue + } + providerName, ok := services.NormalizeProviderName(fields[1]) + if !ok { + fmt.Printf("不支持的提供商 %q\n", fields[1]) + printSupportedProviders() + continue + } + cfg.AI.Provider = providerName + cfg.AI.Model = services.DefaultModelForProvider(providerName) + fmt.Printf("已切换到提供商: %s\n", providerName) + fmt.Printf("当前模型已重置为默认值: %s\n", cfg.AI.Model) + return setupRetry, nil + case "/switch": + if len(fields) < 2 { + fmt.Println("用法: /switch ") + continue + } + target := strings.Join(fields[1:], " ") + cfg.AI.Model = target + fmt.Printf("已切换到模型: %s\n", target) + return setupRetry, nil + case "/exit": + return setupExit, nil + default: + if allowContinue { + fmt.Println("请输入 /retry, /continue, /apikey , /provider , /switch , 或 /exit。") + } else { + fmt.Println("请输入 /retry, /apikey , /provider , /switch , 或 /exit。") + } + } + } +} + +func applyAPIKeyEnvName(cfg *configs.AppConfiguration, envName string) { + if cfg == nil { + return + } + cfg.AI.APIKey = strings.TrimSpace(envName) +} + +func readInteractiveLine(scanner *bufio.Scanner, prompt string) (string, bool, error) { + for { + fmt.Print(prompt) + if !scanner.Scan() { + if err := scanner.Err(); err != nil { + return "", false, err + } + return "", false, nil + } + input := strings.TrimSpace(scanner.Text()) + if input == "" { + fmt.Println("输入不能为空。") + continue + } + if input == "/exit" { + return "", false, nil + } + return input, true, nil + } +} + +func printSupportedProviders() { + fmt.Println("可用提供商:") + for _, name := range services.SupportedProviders() { + fmt.Printf(" %s\n", name) + } +} diff --git a/internal/tui/bootstrap/setup_test.go b/internal/tui/bootstrap/setup_test.go new file mode 100644 index 00000000..a827fca6 --- /dev/null +++ b/internal/tui/bootstrap/setup_test.go @@ -0,0 +1,503 @@ +package bootstrap + +import ( + "bufio" + "context" + "errors" + "io" + "strings" + "testing" + + "go-llm-demo/configs" + "go-llm-demo/internal/tui/services" +) + +type errReader struct{} + +func (errReader) Read([]byte) (int, error) { + return 0, io.ErrUnexpectedEOF +} + +func restoreSetupGlobals(t *testing.T) { + t.Helper() + + origResolveWorkspaceRoot := resolveWorkspaceRoot + origSetWorkspaceRoot := setWorkspaceRoot + origEnsureConfigFile := ensureConfigFile + origValidateChatAPIKey := validateChatAPIKey + origWriteAppConfig := writeAppConfig + origGlobalConfig := configs.GlobalAppConfig + + t.Cleanup(func() { + resolveWorkspaceRoot = origResolveWorkspaceRoot + setWorkspaceRoot = origSetWorkspaceRoot + ensureConfigFile = origEnsureConfigFile + validateChatAPIKey = origValidateChatAPIKey + writeAppConfig = origWriteAppConfig + configs.GlobalAppConfig = origGlobalConfig + }) +} + +func TestApplyAPIKeyEnvNameUpdatesConfig(t *testing.T) { + cfg := configs.DefaultAppConfig() + applyAPIKeyEnvName(cfg, " TEST_KEY_ENV ") + + if got := cfg.AI.APIKey; got != "TEST_KEY_ENV" { + t.Fatalf("expected API key env name to be trimmed, got %q", got) + } +} + +func TestReadInteractiveLineRejectsEmptyInputThenReadsValue(t *testing.T) { + scanner := bufio.NewScanner(strings.NewReader("\n /retry \n")) + + got, ok, err := readInteractiveLine(scanner, "> ") + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if !ok { + t.Fatal("expected ok=true") + } + if got != "/retry" { + t.Fatalf("expected trimmed input, got %q", got) + } +} + +func TestReadInteractiveLineTreatsExitAsStop(t *testing.T) { + scanner := bufio.NewScanner(strings.NewReader("/exit\n")) + + got, ok, err := readInteractiveLine(scanner, "> ") + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if ok { + t.Fatal("expected ok=false for /exit") + } + if got != "" { + t.Fatalf("expected empty value, got %q", got) + } +} + +func TestHandleSetupDecisionHandlesProviderSwitch(t *testing.T) { + cfg := configs.DefaultAppConfig() + scanner := bufio.NewScanner(strings.NewReader("/provider openai\n")) + + decision, err := handleSetupDecision(scanner, cfg, false, "config.yaml") + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if decision != setupRetry { + t.Fatalf("expected setupRetry, got %v", decision) + } + if cfg.AI.Provider != "openai" { + t.Fatalf("expected provider to switch, got %q", cfg.AI.Provider) + } + if cfg.AI.Model == "" { + t.Fatal("expected provider switch to set a default model") + } +} + +func TestHandleSetupDecisionRejectsContinueWhenNotAllowed(t *testing.T) { + cfg := configs.DefaultAppConfig() + scanner := bufio.NewScanner(strings.NewReader("/continue\n/retry\n")) + + decision, err := handleSetupDecision(scanner, cfg, false, "config.yaml") + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if decision != setupRetry { + t.Fatalf("expected setupRetry after rejecting continue, got %v", decision) + } +} + +func TestPrepareWorkspaceResolvesAndSetsWorkspaceRoot(t *testing.T) { + restoreSetupGlobals(t) + + var setRoot string + resolveWorkspaceRoot = func(workspaceFlag string) (string, error) { + if workspaceFlag != "./workspace" { + t.Fatalf("expected workspace flag to flow through, got %q", workspaceFlag) + } + return "D:/neo-code/workspace", nil + } + setWorkspaceRoot = func(root string) error { + setRoot = root + return nil + } + + root, err := PrepareWorkspace("./workspace") + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if root != "D:/neo-code/workspace" { + t.Fatalf("unexpected workspace root %q", root) + } + if setRoot != root { + t.Fatalf("expected SetWorkspaceRoot to receive %q, got %q", root, setRoot) + } +} + +func TestPrepareWorkspaceReturnsSetWorkspaceRootError(t *testing.T) { + restoreSetupGlobals(t) + + resolveWorkspaceRoot = func(string) (string, error) { return "D:/neo-code/workspace", nil } + setWorkspaceRoot = func(string) error { return errors.New("set failed") } + + _, err := PrepareWorkspace("./workspace") + if err == nil || !strings.Contains(err.Error(), "set failed") { + t.Fatalf("expected SetWorkspaceRoot error, got %v", err) + } +} + +func TestPrepareWorkspaceReturnsResolveError(t *testing.T) { + restoreSetupGlobals(t) + + resolveWorkspaceRoot = func(string) (string, error) { return "", errors.New("resolve failed") } + + _, err := PrepareWorkspace("./workspace") + if err == nil || !strings.Contains(err.Error(), "resolve failed") { + t.Fatalf("expected resolve error, got %v", err) + } +} + +func TestReadInteractiveLineReturnsEOF(t *testing.T) { + scanner := bufio.NewScanner(strings.NewReader("")) + + got, ok, err := readInteractiveLine(scanner, "> ") + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if ok || got != "" { + t.Fatalf("expected EOF stop, got value=%q ok=%v", got, ok) + } +} + +func TestReadInteractiveLineReturnsScannerError(t *testing.T) { + scanner := bufio.NewScanner(errReader{}) + + _, _, err := readInteractiveLine(scanner, "> ") + if err == nil { + t.Fatal("expected scanner error") + } +} + +func TestHandleSetupDecisionAPIKeyRequiresArgument(t *testing.T) { + cfg := configs.DefaultAppConfig() + scanner := bufio.NewScanner(strings.NewReader("/apikey\n/apikey TEST_ENV\n")) + + decision, err := handleSetupDecision(scanner, cfg, false, "config.yaml") + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if decision != setupRetry { + t.Fatalf("expected setupRetry, got %v", decision) + } + if cfg.AI.APIKey != "TEST_ENV" { + t.Fatalf("expected API key env to switch, got %q", cfg.AI.APIKey) + } +} + +func TestHandleSetupDecisionAllowsContinue(t *testing.T) { + restoreSetupGlobals(t) + + cfg := configs.DefaultAppConfig() + writeCalled := false + writeAppConfig = func(string, *configs.AppConfiguration) error { + writeCalled = true + return nil + } + + decision, err := handleSetupDecision(bufio.NewScanner(strings.NewReader("/continue\n")), cfg, true, "config.yaml") + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if decision != setupContinue { + t.Fatalf("expected setupContinue, got %v", decision) + } + if !writeCalled { + t.Fatal("expected config write on continue") + } +} + +func TestHandleSetupDecisionContinueWriteFailure(t *testing.T) { + restoreSetupGlobals(t) + + cfg := configs.DefaultAppConfig() + writeAppConfig = func(string, *configs.AppConfiguration) error { return errors.New("write failed") } + + decision, err := handleSetupDecision(bufio.NewScanner(strings.NewReader("/continue\n")), cfg, true, "config.yaml") + if err == nil || !strings.Contains(err.Error(), "write failed") { + t.Fatalf("expected write failure, got decision=%v err=%v", decision, err) + } +} + +func TestHandleSetupDecisionProviderRequiresArgumentAndRejectsUnknownProvider(t *testing.T) { + cfg := configs.DefaultAppConfig() + scanner := bufio.NewScanner(strings.NewReader("/provider\n/provider invalid\n/provider openai\n")) + + decision, err := handleSetupDecision(scanner, cfg, false, "config.yaml") + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if decision != setupRetry { + t.Fatalf("expected setupRetry, got %v", decision) + } + if cfg.AI.Provider != "openai" { + t.Fatalf("expected provider to switch after retries, got %q", cfg.AI.Provider) + } +} + +func TestHandleSetupDecisionSwitchRequiresArgumentThenSucceeds(t *testing.T) { + cfg := configs.DefaultAppConfig() + scanner := bufio.NewScanner(strings.NewReader("/switch\n/switch gpt-5.4-mini\n")) + + decision, err := handleSetupDecision(scanner, cfg, false, "config.yaml") + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if decision != setupRetry { + t.Fatalf("expected setupRetry, got %v", decision) + } + if cfg.AI.Model != "gpt-5.4-mini" { + t.Fatalf("expected model switch, got %q", cfg.AI.Model) + } +} + +func TestHandleSetupDecisionUnknownCommandThenExit(t *testing.T) { + cfg := configs.DefaultAppConfig() + scanner := bufio.NewScanner(strings.NewReader("/unknown\n/exit\n")) + + decision, err := handleSetupDecision(scanner, cfg, false, "config.yaml") + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if decision != setupExit { + t.Fatalf("expected setupExit, got %v", decision) + } +} + +func TestEnsureAPIKeyInteractiveReturnsConfigError(t *testing.T) { + restoreSetupGlobals(t) + + ensureConfigFile = func(string) (*configs.AppConfiguration, bool, error) { + return nil, false, errors.New("config failed") + } + + ready, err := EnsureAPIKeyInteractive(context.Background(), bufio.NewScanner(strings.NewReader("")), "config.yaml") + if err == nil || !strings.Contains(err.Error(), "config failed") { + t.Fatalf("expected config error, got ready=%v err=%v", ready, err) + } +} + +func TestEnsureAPIKeyInteractiveExitsWhenAPIKeyMissing(t *testing.T) { + restoreSetupGlobals(t) + + cfg := configs.DefaultAppConfig() + cfg.AI.APIKey = "MISSING_ENV" + ensureConfigFile = func(string) (*configs.AppConfiguration, bool, error) { + return cfg, false, nil + } + validateCalled := false + validateChatAPIKey = func(context.Context, *configs.AppConfiguration) error { + validateCalled = true + return nil + } + + ready, err := EnsureAPIKeyInteractive(context.Background(), bufio.NewScanner(strings.NewReader("/exit\n")), "config.yaml") + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if ready { + t.Fatal("expected setup to stop without becoming ready") + } + if validateCalled { + t.Fatal("validation should not run when runtime API key is missing") + } +} + +func TestEnsureAPIKeyInteractiveWritesConfigAfterSuccessfulValidation(t *testing.T) { + restoreSetupGlobals(t) + + cfg := configs.DefaultAppConfig() + cfg.AI.APIKey = "READY_ENV" + t.Setenv("READY_ENV", "secret") + + ensureConfigFile = func(string) (*configs.AppConfiguration, bool, error) { + return cfg, false, nil + } + validateChatAPIKey = func(context.Context, *configs.AppConfiguration) error { return nil } + var writePath string + writeAppConfig = func(path string, gotCfg *configs.AppConfiguration) error { + writePath = path + if gotCfg != cfg { + t.Fatal("expected the same config instance to be written") + } + return nil + } + + ready, err := EnsureAPIKeyInteractive(context.Background(), bufio.NewScanner(strings.NewReader("")), "config.yaml") + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if !ready { + t.Fatal("expected setup to become ready") + } + if writePath != "config.yaml" { + t.Fatalf("expected config write path config.yaml, got %q", writePath) + } + if configs.GlobalAppConfig != cfg { + t.Fatal("expected global config to be updated after successful validation") + } +} + +func TestEnsureAPIKeyInteractiveAllowsContinueOnSoftValidationError(t *testing.T) { + restoreSetupGlobals(t) + + cfg := configs.DefaultAppConfig() + cfg.AI.APIKey = "READY_ENV" + t.Setenv("READY_ENV", "secret") + + ensureConfigFile = func(string) (*configs.AppConfiguration, bool, error) { + return cfg, false, nil + } + validateChatAPIKey = func(context.Context, *configs.AppConfiguration) error { + return services.ErrAPIKeyValidationSoft + } + writeCount := 0 + writeAppConfig = func(string, *configs.AppConfiguration) error { + writeCount++ + return nil + } + + ready, err := EnsureAPIKeyInteractive(context.Background(), bufio.NewScanner(strings.NewReader("/continue\n")), "config.yaml") + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if !ready { + t.Fatal("expected continue to allow startup") + } + if writeCount != 1 { + t.Fatalf("expected config to be written once, got %d", writeCount) + } + if configs.GlobalAppConfig != cfg { + t.Fatal("expected global config to be updated on continue") + } +} + +func TestEnsureAPIKeyInteractiveReportsCreatedConfigThenSucceeds(t *testing.T) { + restoreSetupGlobals(t) + + cfg := configs.DefaultAppConfig() + cfg.AI.APIKey = "READY_ENV" + t.Setenv("READY_ENV", "secret") + + ensureConfigFile = func(string) (*configs.AppConfiguration, bool, error) { + return cfg, true, nil + } + validateChatAPIKey = func(context.Context, *configs.AppConfiguration) error { return nil } + writeAppConfig = func(string, *configs.AppConfiguration) error { return nil } + + ready, err := EnsureAPIKeyInteractive(context.Background(), bufio.NewScanner(strings.NewReader("")), "config.yaml") + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if !ready { + t.Fatal("expected setup to become ready") + } +} + +func TestEnsureAPIKeyInteractiveRetriesAfterChangingAPIKeyEnv(t *testing.T) { + restoreSetupGlobals(t) + + cfg := configs.DefaultAppConfig() + cfg.AI.APIKey = "MISSING_ENV" + t.Setenv("RECOVERED_ENV", "secret") + + ensureConfigFile = func(string) (*configs.AppConfiguration, bool, error) { + return cfg, false, nil + } + validateCount := 0 + validateChatAPIKey = func(context.Context, *configs.AppConfiguration) error { + validateCount++ + return nil + } + writeAppConfig = func(string, *configs.AppConfiguration) error { return nil } + + ready, err := EnsureAPIKeyInteractive(context.Background(), bufio.NewScanner(strings.NewReader("/apikey RECOVERED_ENV\n/retry\n")), "config.yaml") + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if !ready { + t.Fatal("expected setup to become ready after retry") + } + if cfg.AI.APIKey != "RECOVERED_ENV" { + t.Fatalf("expected env name update, got %q", cfg.AI.APIKey) + } + if validateCount != 1 { + t.Fatalf("expected one validation call after retry, got %d", validateCount) + } +} + +func TestEnsureAPIKeyInteractiveHandlesInvalidAPIKeyAndExit(t *testing.T) { + restoreSetupGlobals(t) + + cfg := configs.DefaultAppConfig() + cfg.AI.APIKey = "READY_ENV" + t.Setenv("READY_ENV", "secret") + + ensureConfigFile = func(string) (*configs.AppConfiguration, bool, error) { + return cfg, false, nil + } + validateChatAPIKey = func(context.Context, *configs.AppConfiguration) error { + return services.ErrInvalidAPIKey + } + + ready, err := EnsureAPIKeyInteractive(context.Background(), bufio.NewScanner(strings.NewReader("/exit\n")), "config.yaml") + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if ready { + t.Fatal("expected setup to stop on invalid key + exit") + } +} + +func TestEnsureAPIKeyInteractiveHandlesGenericValidationErrorAndExit(t *testing.T) { + restoreSetupGlobals(t) + + cfg := configs.DefaultAppConfig() + cfg.AI.APIKey = "READY_ENV" + t.Setenv("READY_ENV", "secret") + + ensureConfigFile = func(string) (*configs.AppConfiguration, bool, error) { + return cfg, false, nil + } + validateChatAPIKey = func(context.Context, *configs.AppConfiguration) error { + return errors.New("validation failed") + } + + ready, err := EnsureAPIKeyInteractive(context.Background(), bufio.NewScanner(strings.NewReader("/exit\n")), "config.yaml") + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if ready { + t.Fatal("expected setup to stop on generic validation failure") + } +} + +func TestEnsureAPIKeyInteractiveReturnsWriteErrorAfterValidationSuccess(t *testing.T) { + restoreSetupGlobals(t) + + cfg := configs.DefaultAppConfig() + cfg.AI.APIKey = "READY_ENV" + t.Setenv("READY_ENV", "secret") + + ensureConfigFile = func(string) (*configs.AppConfiguration, bool, error) { + return cfg, false, nil + } + validateChatAPIKey = func(context.Context, *configs.AppConfiguration) error { return nil } + writeAppConfig = func(string, *configs.AppConfiguration) error { return errors.New("write failed") } + + ready, err := EnsureAPIKeyInteractive(context.Background(), bufio.NewScanner(strings.NewReader("")), "config.yaml") + if err == nil || !strings.Contains(err.Error(), "write failed") { + t.Fatalf("expected write error, got ready=%v err=%v", ready, err) + } +} diff --git a/internal/tui/components/help.go b/internal/tui/components/help.go new file mode 100644 index 00000000..c5fc1159 --- /dev/null +++ b/internal/tui/components/help.go @@ -0,0 +1,67 @@ +package components + +import ( + "strings" + + "github.com/charmbracelet/lipgloss" +) + +func RenderHelp(width int) string { + var b strings.Builder + + title := lipgloss.NewStyle(). + Foreground(lipgloss.Color("#61AFEF")). + Bold(true). + Render("NeoCode 帮助") + + b.WriteString(title) + b.WriteString("\n\n") + + commands := []struct { + cmd string + desc string + }{ + {"/help", "显示帮助"}, + {"/pwd | /workspace", "显示当前工作区目录"}, + {"/apikey ", "切换 API Key 变量名"}, + {"/provider ", "切换模型提供商"}, + {"/switch ", "切换模型"}, + {"/run ", "执行代码"}, + {"/explain ", "解释代码"}, + {"/memory", "显示记忆统计"}, + {"/clear-memory confirm", "清空长期记忆"}, + {"/clear-context", "清空会话上下文"}, + {"/exit", "退出程序"}, + } + + cmdStyle := lipgloss.NewStyle(). + Foreground(lipgloss.Color("#98C379")). + Width(22) + + descStyle := lipgloss.NewStyle(). + Foreground(lipgloss.Color("#ABB2BF")) + + dimStyle := lipgloss.NewStyle(). + Foreground(lipgloss.Color("#5C6370")) + + helpStyle := lipgloss.NewStyle(). + Foreground(lipgloss.Color("#61AFEF")) + + for _, c := range commands { + b.WriteString(cmdStyle.Render(c.cmd)) + b.WriteString(descStyle.Render(c.desc)) + b.WriteString("\n") + } + + b.WriteString("\n") + b.WriteString(helpStyle.Render("输入框支持光标、粘贴、滚动,F5/F8 发送")) + b.WriteString("\n") + b.WriteString(helpStyle.Render("聊天区支持 PgUp/PgDn 和鼠标滚轮")) + b.WriteString("\n") + b.WriteString(helpStyle.Render("取消: Ctrl+C")) + + b.WriteString("\n\n") + b.WriteString(dimStyle.Render("按 Esc 或 /help 关闭")) + + return lipgloss.NewStyle().MaxWidth(width).Render(b.String()) +} diff --git a/internal/tui/components/input.go b/internal/tui/components/input.go deleted file mode 100644 index a14db52a..00000000 --- a/internal/tui/components/input.go +++ /dev/null @@ -1,69 +0,0 @@ -package components - -import ( - "fmt" - "strings" - - "github.com/charmbracelet/lipgloss" -) - -type Input struct { - Buffer string - Multiline bool - CursorLine int - CursorCol int -} - -func (i Input) Render() string { - cleanBuffer := strings.ReplaceAll(i.Buffer, "\r", "") - cleanBuffer = strings.ReplaceAll(cleanBuffer, "\t", " ") - lines := strings.Split(cleanBuffer, "\n") - lang := DetectLanguage(i.Buffer) - - var b strings.Builder - - for idx, line := range lines { - lineNumStyle := lipgloss.NewStyle().Foreground(lipgloss.Color("#5C6370")) - lineNum := lineNumStyle.Render(fmt.Sprintf(" %2d: ", idx+1)) - b.WriteString(lineNum) - - runes := []rune(line) - if i.Multiline && idx == i.CursorLine && i.CursorCol <= len(runes) { - cursorStyle := lipgloss.NewStyle(). - Background(lipgloss.Color("#3E4451")). - Foreground(lipgloss.Color("#ABB2BF")) - - var before, after string - var char string - if i.CursorCol < len(runes) { - char = string(runes[i.CursorCol]) - before = string(runes[:i.CursorCol]) - after = string(runes[i.CursorCol+1:]) - } else { - before = string(runes) - char = " " - } - - if before != "" { - b.WriteString(HighlightCodeInline(before, lang)) - } - b.WriteString(cursorStyle.Render(char)) - if after != "" { - b.WriteString(HighlightCodeInline(after, lang)) - } - } else { - if line != "" { - b.WriteString(HighlightCodeInline(line, lang)) - } - } - b.WriteString("\n") - } - - if i.Multiline { - b.WriteString("[方向键移动 Enter换行 F5/F8发送 Del删除]") - } else { - b.WriteString("[Enter换行 F5/F8发送]") - } - - return b.String() -} diff --git a/internal/tui/components/input_box.go b/internal/tui/components/input_box.go new file mode 100644 index 00000000..f8016f48 --- /dev/null +++ b/internal/tui/components/input_box.go @@ -0,0 +1,21 @@ +package components + +import "github.com/charmbracelet/lipgloss" + +type InputBox struct { + Body string + Generating bool +} + +func (i InputBox) Render() string { + helpText := "[Enter换行 F5/F8发送 PgUp/PgDn滚动]" + if !i.Generating { + helpText = "[Enter换行 F5/F8发送 Ctrl+V粘贴 PgUp/PgDn滚动]" + } + + footer := lipgloss.NewStyle(). + Foreground(lipgloss.Color("#5C6370")). + Render(helpText) + + return i.Body + "\n" + footer +} diff --git a/internal/tui/components/layout_helpers_test.go b/internal/tui/components/layout_helpers_test.go new file mode 100644 index 00000000..21a8fe78 --- /dev/null +++ b/internal/tui/components/layout_helpers_test.go @@ -0,0 +1,55 @@ +package components + +import ( + "strings" + "testing" + "time" +) + +func TestRenderHelpContainsKeyCommands(t *testing.T) { + rendered := RenderHelp(80) + + for _, want := range []string{"NeoCode 帮助", "/help", "/provider ", "按 Esc 或 /help 关闭"} { + if !strings.Contains(rendered, want) { + t.Fatalf("expected help to contain %q, got %q", want, rendered) + } + } +} + +func TestInputBoxRenderChangesFooterByGeneratingState(t *testing.T) { + idle := InputBox{Body: "body", Generating: false}.Render() + if !strings.Contains(idle, "Ctrl+V粘贴") { + t.Fatalf("expected idle footer to mention paste, got %q", idle) + } + + busy := InputBox{Body: "body", Generating: true}.Render() + if strings.Contains(busy, "Ctrl+V粘贴") { + t.Fatalf("expected generating footer to omit paste hint, got %q", busy) + } + if !strings.Contains(busy, "F5/F8发送") { + t.Fatalf("expected busy footer to keep send hint, got %q", busy) + } +} + +func TestMessageListRenderIncludesRoleSpecificLabels(t *testing.T) { + rendered := MessageList{ + Width: 60, + Messages: []Message{ + {Role: "user", Content: "hello", Timestamp: time.Unix(1, 0)}, + {Role: "assistant", Content: "world", Timestamp: time.Unix(2, 0)}, + {Role: "system", Content: "note", Timestamp: time.Unix(3, 0)}, + }, + }.Render() + + for _, want := range []string{"你 [1]:", "Neo [2]:", "[系统]", "hello", "world", "note"} { + if !strings.Contains(rendered, want) { + t.Fatalf("expected rendered list to contain %q, got %q", want, rendered) + } + } +} + +func TestMessageListRenderReturnsEmptyForNoMessages(t *testing.T) { + if got := (MessageList{Width: 40}).Render(); got != "" { + t.Fatalf("expected empty render, got %q", got) + } +} diff --git a/internal/tui/core/model.go b/internal/tui/core/model.go index b0ce2951..88d54e60 100644 --- a/internal/tui/core/model.go +++ b/internal/tui/core/model.go @@ -6,7 +6,8 @@ import ( "time" "go-llm-demo/configs" - "go-llm-demo/internal/tui/infra" + "go-llm-demo/internal/tui/services" + "go-llm-demo/internal/tui/state" "github.com/charmbracelet/bubbles/cursor" "github.com/charmbracelet/bubbles/textarea" @@ -15,62 +16,26 @@ import ( "github.com/charmbracelet/lipgloss" ) -type Mode int - -const ( - ModeChat Mode = iota - ModeCodeInput - ModeHelp - ModeMemory -) - type Model struct { - width int - height int - mode Mode - focused string - - messages []Message - historyTurns int - - generating bool - activeModel string - - memoryStats infra.MemoryStats + ui state.UIState + chat state.ChatState - commandHistory []string - cmdHistIndex int - - client infra.ChatClient + client services.ChatClient persona string - workspaceRoot string - - toolExecuting bool - apiKeyReady bool - configPath string - streamChan <-chan string textarea textarea.Model viewport viewport.Model - autoScroll bool mu *sync.Mutex } -type Message struct { - Role string - Content string - Timestamp time.Time - Streaming bool -} - // NewModel 创建 TUI 状态模型。 // historyTurns 用于限制发送给后端的短期对话轮数,避免原始消息无限增长。 -func NewModel(client infra.ChatClient, persona string, historyTurns int, configPath, workspaceRoot string) Model { +func NewModel(client services.ChatClient, persona string, historyTurns int, configPath, workspaceRoot string) Model { stats, _ := client.GetMemoryStats(context.Background()) if stats == nil { - stats = &infra.MemoryStats{} + stats = &services.MemoryStats{} } if historyTurns <= 0 { historyTurns = 6 @@ -99,23 +64,27 @@ func NewModel(client infra.ChatClient, persona string, historyTurns int, configP vp.SetContent("") return Model{ - mode: ModeChat, - focused: "input", - messages: make([]Message, 0), - historyTurns: historyTurns, - activeModel: client.DefaultModel(), - memoryStats: *stats, - commandHistory: make([]string, 0), - cmdHistIndex: -1, - client: client, - persona: persona, - workspaceRoot: workspaceRoot, - apiKeyReady: configs.RuntimeAPIKey() != "", - configPath: configPath, - textarea: input, - viewport: vp, - autoScroll: true, - mu: &sync.Mutex{}, + ui: state.UIState{ + Mode: state.ModeChat, + Focused: "input", + AutoScroll: true, + }, + chat: state.ChatState{ + Messages: make([]state.Message, 0), + HistoryTurns: historyTurns, + ActiveModel: client.DefaultModel(), + MemoryStats: *stats, + CommandHistory: make([]string, 0), + CmdHistIndex: -1, + WorkspaceRoot: workspaceRoot, + APIKeyReady: configs.RuntimeAPIKey() != "", + ConfigPath: configPath, + }, + client: client, + persona: persona, + textarea: input, + viewport: vp, + mu: &sync.Mutex{}, } } @@ -133,12 +102,12 @@ func (m Model) Init() tea.Cmd { // SetWidth 更新当前视口宽度。 func (m *Model) SetWidth(w int) { - m.width = w + m.ui.Width = w } // SetHeight 更新当前视口高度。 func (m *Model) SetHeight(h int) { - m.height = h + m.ui.Height = h } // AddMessage 向聊天历史追加一条带时间戳的消息。 @@ -146,7 +115,7 @@ func (m *Model) AddMessage(role, content string) { mu := m.mutex() mu.Lock() defer mu.Unlock() - m.messages = append(m.messages, Message{ + m.chat.Messages = append(m.chat.Messages, state.Message{ Role: role, Content: content, Timestamp: time.Now(), @@ -158,8 +127,8 @@ func (m *Model) AppendLastMessage(content string) { mu := m.mutex() mu.Lock() defer mu.Unlock() - if len(m.messages) > 0 { - m.messages[len(m.messages)-1].Content += content + if len(m.chat.Messages) > 0 { + m.chat.Messages[len(m.chat.Messages)-1].Content += content } } @@ -168,8 +137,8 @@ func (m *Model) FinishLastMessage() { mu := m.mutex() mu.Lock() defer mu.Unlock() - if len(m.messages) > 0 { - m.messages[len(m.messages)-1].Streaming = false + if len(m.chat.Messages) > 0 { + m.chat.Messages[len(m.chat.Messages)-1].Streaming = false } } @@ -178,14 +147,14 @@ func (m *Model) TrimHistory(maxTurns int) { mu := m.mutex() mu.Lock() defer mu.Unlock() - if len(m.messages) <= maxTurns*2 { + if len(m.chat.Messages) <= maxTurns*2 { return } - var system []Message - var others []Message + var system []state.Message + var others []state.Message - for _, msg := range m.messages { + for _, msg := range m.chat.Messages { if msg.Role == "system" { system = append(system, msg) } else { @@ -197,5 +166,5 @@ func (m *Model) TrimHistory(maxTurns int) { others = others[len(others)-maxTurns*2:] } - m.messages = append(system, others...) + m.chat.Messages = append(system, others...) } diff --git a/internal/tui/core/model_test.go b/internal/tui/core/model_test.go new file mode 100644 index 00000000..f5550964 --- /dev/null +++ b/internal/tui/core/model_test.go @@ -0,0 +1,91 @@ +package core + +import ( + "testing" + + "go-llm-demo/configs" + "go-llm-demo/internal/tui/state" +) + +func TestNewModelAppliesDefaultsAndRuntimeFlags(t *testing.T) { + restoreCoreGlobals(t) + + client := &fakeChatClient{defaultModelName: "demo-model"} + t.Setenv(configs.DefaultAPIKeyEnvVar, "secret") + configs.GlobalAppConfig = nil + + m := NewModel(client, "persona", 0, "config.yaml", "D:/neo-code") + + if m.chat.HistoryTurns != 6 { + t.Fatalf("expected default history turns 6, got %d", m.chat.HistoryTurns) + } + if m.chat.ActiveModel != "demo-model" { + t.Fatalf("expected default model from client, got %q", m.chat.ActiveModel) + } + if !m.chat.APIKeyReady { + t.Fatal("expected API key readiness to reflect runtime env var") + } + if m.chat.WorkspaceRoot != "D:/neo-code" { + t.Fatalf("expected workspace root to be stored, got %q", m.chat.WorkspaceRoot) + } +} + +func TestNewModelUsesEmptyStatsWhenClientReturnsNil(t *testing.T) { + restoreCoreGlobals(t) + + client := &fakeChatClient{nilMemoryStats: true} + + m := NewModel(client, "persona", 4, "config.yaml", "D:/neo-code") + if m.chat.MemoryStats.TotalItems != 0 { + t.Fatalf("expected zero-value stats, got %+v", m.chat.MemoryStats) + } +} + +func TestAppendAndFinishLastMessage(t *testing.T) { + m := Model{} + m.chat.Messages = []state.Message{{Role: "assistant", Content: "hello", Streaming: true}} + + m.AppendLastMessage(" world") + m.FinishLastMessage() + + if m.chat.Messages[0].Content != "hello world" { + t.Fatalf("expected appended content, got %q", m.chat.Messages[0].Content) + } + if m.chat.Messages[0].Streaming { + t.Fatal("expected last message streaming to be cleared") + } +} + +func TestInitReturnsNonNilCmd(t *testing.T) { + restoreCoreGlobals(t) + + m := NewModel(&fakeChatClient{}, "persona", 4, "config.yaml", "D:/neo-code") + if cmd := m.Init(); cmd == nil { + t.Fatal("expected non-nil init cmd") + } +} + +func TestTrimHistoryKeepsSystemMessagesAndLatestTurns(t *testing.T) { + m := Model{} + m.chat.Messages = []state.Message{ + {Role: "system", Content: "persona"}, + {Role: "user", Content: "u1"}, + {Role: "assistant", Content: "a1"}, + {Role: "user", Content: "u2"}, + {Role: "assistant", Content: "a2"}, + {Role: "user", Content: "u3"}, + {Role: "assistant", Content: "a3"}, + } + + m.TrimHistory(2) + + if len(m.chat.Messages) != 5 { + t.Fatalf("expected system message plus last two turns, got %d messages", len(m.chat.Messages)) + } + if m.chat.Messages[0].Role != "system" || m.chat.Messages[0].Content != "persona" { + t.Fatalf("expected system message to be preserved, got %+v", m.chat.Messages[0]) + } + if m.chat.Messages[1].Content != "u2" || m.chat.Messages[4].Content != "a3" { + t.Fatalf("expected only latest turns to remain, got %+v", m.chat.Messages) + } +} diff --git a/internal/tui/core/msg.go b/internal/tui/core/msg.go index e4a9c560..acbf0e77 100644 --- a/internal/tui/core/msg.go +++ b/internal/tui/core/msg.go @@ -1,27 +1,6 @@ package core -import ( - tea "github.com/charmbracelet/bubbletea" - "go-llm-demo/internal/server/infra/tools" -) - -func Chunk(content string) tea.Cmd { - return func() tea.Msg { - return StreamChunkMsg{Content: content} - } -} - -func Done() tea.Cmd { - return func() tea.Msg { - return StreamDoneMsg{} - } -} - -func CmdErr(err error) tea.Cmd { - return func() tea.Msg { - return StreamErrorMsg{Err: err} - } -} +import "go-llm-demo/internal/tui/services" type StreamChunkMsg struct { Content string @@ -40,7 +19,7 @@ type StreamErrorMsg struct { func (StreamErrorMsg) isMsg() {} type ToolResultMsg struct { - Result *tools.ToolResult + Result *services.ToolResult } func (ToolResultMsg) isMsg() {} @@ -66,11 +45,3 @@ func (HideHelpMsg) isMsg() {} type RefreshMemoryMsg struct{} func (RefreshMemoryMsg) isMsg() {} - -type streamNextChunk struct { - stream <-chan string -} - -func (streamNextChunk) isMsg() {} - -var StreamDone = func() tea.Msg { return StreamDoneMsg{} } diff --git a/internal/tui/core/update.go b/internal/tui/core/update.go index cc6e05ab..7deaee7c 100644 --- a/internal/tui/core/update.go +++ b/internal/tui/core/update.go @@ -10,10 +10,8 @@ import ( "strings" "go-llm-demo/configs" - "go-llm-demo/internal/server/domain" - "go-llm-demo/internal/server/infra/provider" - "go-llm-demo/internal/server/infra/tools" - "go-llm-demo/internal/tui/infra" + "go-llm-demo/internal/tui/services" + "go-llm-demo/internal/tui/state" tea "github.com/charmbracelet/bubbletea" ) @@ -25,6 +23,13 @@ const ( maxToolContextMessages = 3 ) +var ( + validateChatAPIKey = services.ValidateChatAPIKey + writeAppConfig = configs.WriteAppConfig + getWorkspaceRoot = services.GetWorkspaceRoot + executeToolCall = services.ExecuteToolCall +) + // Update 处理 Bubble Tea 事件并驱动聊天状态更新。 func (m Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { var cmd tea.Cmd @@ -44,11 +49,11 @@ func (m Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { case tea.MouseMsg: var vpCmd tea.Cmd m.viewport, vpCmd = m.viewport.Update(msg) - m.autoScroll = m.viewport.AtBottom() + m.ui.AutoScroll = m.viewport.AtBottom() return m, vpCmd case StreamChunkMsg: - if m.generating { + if m.chat.Generating { m.AppendLastMessage(msg.Content) m.refreshViewport() } @@ -57,13 +62,13 @@ func (m Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { case StreamDoneMsg: mu := m.mutex() mu.Lock() - m.generating = false + m.chat.Generating = false m.streamChan = nil var lastContent string - shouldCheckToolCall := !m.toolExecuting && len(m.messages) > 0 - if len(m.messages) > 0 { - lastMsg := &m.messages[len(m.messages)-1] + shouldCheckToolCall := !m.chat.ToolExecuting && len(m.chat.Messages) > 0 + if len(m.chat.Messages) > 0 { + lastMsg := &m.chat.Messages[len(m.chat.Messages)-1] lastMsg.Streaming = false if lastMsg.Role == "assistant" { lastContent = lastMsg.Content @@ -81,16 +86,16 @@ func (m Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { if toolName, ok := jsonData["tool"].(string); ok && toolName != "" { mu := m.mutex() mu.Lock() - if m.toolExecuting { + if m.chat.ToolExecuting { mu.Unlock() return m, nil } - m.toolExecuting = true + m.chat.ToolExecuting = true mu.Unlock() paramsMap := map[string]interface{}{} if toolParams, ok := jsonData["params"].(map[string]interface{}); ok { - paramsMap = tools.NormalizeParams(toolParams) + paramsMap = services.NormalizeToolParams(toolParams) } // 显示工具执行中提示(仅用于 UI,不参与模型上下文) @@ -98,12 +103,12 @@ func (m Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { // 在goroutine中执行工具调用 return m, func() tea.Msg { - call := domain.ToolCall{Tool: toolName, Params: paramsMap} - result := tools.GlobalRegistry.Execute(call) + call := services.ToolCall{Tool: toolName, Params: paramsMap} + result := executeToolCall(call) if result == nil { mu := m.mutex() mu.Lock() - m.toolExecuting = false + m.chat.ToolExecuting = false mu.Unlock() return ToolErrorMsg{Err: fmt.Errorf("工具执行失败: 空返回")} } @@ -119,11 +124,11 @@ func (m Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { case StreamErrorMsg: mu := m.mutex() mu.Lock() - m.generating = false + m.chat.Generating = false m.streamChan = nil replacedPlaceholder := false - if len(m.messages) > 0 { - lastMsg := &m.messages[len(m.messages)-1] + if len(m.chat.Messages) > 0 { + lastMsg := &m.chat.Messages[len(m.chat.Messages)-1] if lastMsg.Role == "assistant" && strings.TrimSpace(lastMsg.Content) == "" { lastMsg.Content = fmt.Sprintf("错误: %v", msg.Err) lastMsg.Streaming = false @@ -134,24 +139,24 @@ func (m Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { if !replacedPlaceholder { m.AddMessage("assistant", fmt.Sprintf("错误: %v", msg.Err)) } - m.TrimHistory(m.historyTurns) + m.TrimHistory(m.chat.HistoryTurns) m.refreshViewport() return m, nil case ShowHelpMsg: - m.mode = ModeHelp + m.ui.Mode = state.ModeHelp m.refreshViewport() return m, nil case HideHelpMsg: - m.mode = ModeChat + m.ui.Mode = state.ModeChat m.refreshViewport() return m, nil case RefreshMemoryMsg: stats, err := m.client.GetMemoryStats(context.Background()) if err == nil && stats != nil { - m.memoryStats = *stats + m.chat.MemoryStats = *stats } m.refreshViewport() return m, nil @@ -162,12 +167,12 @@ func (m Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { case ToolResultMsg: mu := m.mutex() mu.Lock() - m.toolExecuting = false + m.chat.ToolExecuting = false mu.Unlock() // 将结构化工具上下文添加为系统消息,然后重新获取AI响应 m.AddMessage("system", formatToolContextMessage(msg.Result)) m.AddMessage("assistant", "") - m.generating = true + m.chat.Generating = true m.refreshViewport() // 构建包含工具结果的消息并重新请求AI @@ -177,12 +182,12 @@ func (m Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { case ToolErrorMsg: mu := m.mutex() mu.Lock() - m.toolExecuting = false + m.chat.ToolExecuting = false mu.Unlock() // 将工具执行错误添加为结构化系统上下文 m.AddMessage("system", formatToolErrorContext(msg.Err)) m.AddMessage("assistant", "") - m.generating = true + m.chat.Generating = true m.refreshViewport() // 构建包含错误信息的消息并重新请求AI @@ -194,8 +199,8 @@ func (m Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { } func (m *Model) handleKey(msg tea.KeyMsg) (tea.Model, tea.Cmd) { - if msg.Type == tea.KeyEsc && m.mode == ModeHelp { - m.mode = ModeChat + if msg.Type == tea.KeyEsc && m.ui.Mode == state.ModeHelp { + m.ui.Mode = state.ModeChat m.refreshViewport() return *m, nil } @@ -208,46 +213,46 @@ func (m *Model) handleKey(msg tea.KeyMsg) (tea.Model, tea.Cmd) { return m.handleSubmit() case tea.KeyPgUp: - m.autoScroll = false + m.ui.AutoScroll = false m.viewport.HalfViewUp() return *m, nil case tea.KeyPgDown: m.viewport.HalfViewDown() - m.autoScroll = m.viewport.AtBottom() + m.ui.AutoScroll = m.viewport.AtBottom() return *m, nil case tea.KeyUp: - if strings.TrimSpace(m.textarea.Value()) == "" && len(m.commandHistory) > 0 { - if m.cmdHistIndex < len(m.commandHistory)-1 { - m.cmdHistIndex++ + if strings.TrimSpace(m.textarea.Value()) == "" && len(m.chat.CommandHistory) > 0 { + if m.chat.CmdHistIndex < len(m.chat.CommandHistory)-1 { + m.chat.CmdHistIndex++ } - if m.cmdHistIndex >= 0 && m.cmdHistIndex < len(m.commandHistory) { - m.textarea.SetValue(m.commandHistory[len(m.commandHistory)-1-m.cmdHistIndex]) + if m.chat.CmdHistIndex >= 0 && m.chat.CmdHistIndex < len(m.chat.CommandHistory) { + m.textarea.SetValue(m.chat.CommandHistory[len(m.chat.CommandHistory)-1-m.chat.CmdHistIndex]) m.textarea.CursorEnd() return *m, nil } } case tea.KeyDown: - if m.cmdHistIndex > 0 { - m.cmdHistIndex-- - m.textarea.SetValue(m.commandHistory[len(m.commandHistory)-1-m.cmdHistIndex]) + if m.chat.CmdHistIndex > 0 { + m.chat.CmdHistIndex-- + m.textarea.SetValue(m.chat.CommandHistory[len(m.chat.CommandHistory)-1-m.chat.CmdHistIndex]) m.textarea.CursorEnd() return *m, nil } - if m.cmdHistIndex == 0 { - m.cmdHistIndex = -1 + if m.chat.CmdHistIndex == 0 { + m.chat.CmdHistIndex = -1 m.textarea.Reset() return *m, nil } } - m.cmdHistIndex = -1 + m.chat.CmdHistIndex = -1 var inputCmd tea.Cmd m.textarea, inputCmd = m.textarea.Update(msg) m.refreshViewport() if m.viewport.AtBottom() { - m.autoScroll = true + m.ui.AutoScroll = true } return *m, inputCmd } @@ -262,16 +267,16 @@ func (m *Model) handleSubmit() (tea.Model, tea.Cmd) { return *m, nil } - switch m.mode { - case ModeHelp: - m.mode = ModeChat + switch m.ui.Mode { + case state.ModeHelp: + m.ui.Mode = state.ModeChat return *m, nil } if strings.HasPrefix(input, "/") { return m.handleCommand(input) } - if !m.apiKeyReady { + if !m.chat.APIKeyReady { m.AddMessage("assistant", "当前 API Key 未通过校验,请使用 /apikey 、/provider 、/switch 调整配置,或 /exit 退出。") return *m, nil } @@ -279,13 +284,13 @@ func (m *Model) handleSubmit() (tea.Model, tea.Cmd) { m.AddMessage("user", input) m.AddMessage("assistant", "") // 在请求发出前先裁剪原始消息,避免 UI 历史无限扩张并影响短期上下文质量。 - m.TrimHistory(m.historyTurns) - m.generating = true - m.autoScroll = true + m.TrimHistory(m.chat.HistoryTurns) + m.chat.Generating = true + m.ui.AutoScroll = true m.refreshViewport() - m.commandHistory = append(m.commandHistory, input) - m.cmdHistIndex = -1 + m.chat.CommandHistory = append(m.chat.CommandHistory, input) + m.chat.CmdHistIndex = -1 messages := m.buildMessages() return *m, m.streamResponse(messages) @@ -299,14 +304,14 @@ func (m *Model) handleCommand(input string) (tea.Model, tea.Cmd) { cmd := fields[0] args := fields[1:] - if !m.apiKeyReady && !isAPIKeyRecoveryCommand(cmd) { + if !m.chat.APIKeyReady && !isAPIKeyRecoveryCommand(cmd) { m.AddMessage("assistant", "当前 API Key 未通过校验,仅支持 /apikey 、/provider 、/help、/switch 、/pwd(/workspace)或 /exit。") return *m, nil } switch cmd { case "/help": - m.mode = ModeHelp + m.ui.Mode = state.ModeHelp case "/exit", "/quit", "/q": return *m, tea.Quit case "/apikey": @@ -323,24 +328,24 @@ func (m *Model) handleCommand(input string) (tea.Model, tea.Cmd) { cfg.AI.APIKey = strings.TrimSpace(args[0]) envName := cfg.APIKeyEnvVarName() if cfg.RuntimeAPIKey() == "" { - m.apiKeyReady = false + m.chat.APIKeyReady = false m.AddMessage("assistant", fmt.Sprintf("环境变量 %s 未设置。请继续使用 /apikey 切换,或 /exit 退出。", envName)) return *m, nil } - err := provider.ValidateChatAPIKey(context.Background(), cfg) + err := validateChatAPIKey(context.Background(), cfg) if err == nil { - if writeErr := configs.WriteAppConfig(m.configPath, cfg); writeErr != nil { + if writeErr := writeAppConfig(m.chat.ConfigPath, cfg); writeErr != nil { cfg.AI.APIKey = previousEnvName - m.apiKeyReady = configs.RuntimeAPIKey() != "" + m.chat.APIKeyReady = configs.RuntimeAPIKey() != "" m.AddMessage("assistant", fmt.Sprintf("切换 API Key 环境变量名失败: %v", writeErr)) return *m, nil } - m.apiKeyReady = true + m.chat.APIKeyReady = true m.AddMessage("assistant", fmt.Sprintf("已切换 API Key 环境变量名为 %s,并通过校验。", envName)) return *m, nil } - m.apiKeyReady = false - if errors.Is(err, provider.ErrInvalidAPIKey) { + m.chat.APIKeyReady = false + if errors.Is(err, services.ErrInvalidAPIKey) { m.AddMessage("assistant", fmt.Sprintf("环境变量 %s 中的 API Key 无效:%v。请继续使用 /apikey 、/provider 、/switch 调整配置,或 /exit 退出。", envName, err)) return *m, nil } @@ -348,7 +353,7 @@ func (m *Model) handleCommand(input string) (tea.Model, tea.Cmd) { return *m, nil case "/provider": if len(args) == 0 { - m.AddMessage("assistant", fmt.Sprintf("用法: /provider \n可用提供商:\n - %s", strings.Join(provider.SupportedProviders(), "\n - "))) + m.AddMessage("assistant", fmt.Sprintf("用法: /provider \n可用提供商:\n - %s", strings.Join(services.SupportedProviders(), "\n - "))) return *m, nil } cfg := configs.GlobalAppConfig @@ -356,29 +361,29 @@ func (m *Model) handleCommand(input string) (tea.Model, tea.Cmd) { m.AddMessage("assistant", "当前配置未加载,无法切换提供商") return *m, nil } - providerName, ok := provider.NormalizeProviderName(strings.Join(args, " ")) + providerName, ok := services.NormalizeProviderName(strings.Join(args, " ")) if !ok { - m.AddMessage("assistant", fmt.Sprintf("不支持的提供商: %s\n可用提供商:\n - %s", strings.Join(args, " "), strings.Join(provider.SupportedProviders(), "\n - "))) + m.AddMessage("assistant", fmt.Sprintf("不支持的提供商: %s\n可用提供商:\n - %s", strings.Join(args, " "), strings.Join(services.SupportedProviders(), "\n - "))) return *m, nil } cfg.AI.Provider = providerName - cfg.AI.Model = provider.DefaultModelForProvider(providerName) - m.activeModel = cfg.AI.Model - if writeErr := configs.WriteAppConfig(m.configPath, cfg); writeErr != nil { + cfg.AI.Model = services.DefaultModelForProvider(providerName) + m.chat.ActiveModel = cfg.AI.Model + if writeErr := writeAppConfig(m.chat.ConfigPath, cfg); writeErr != nil { m.AddMessage("assistant", fmt.Sprintf("切换提供商失败: %v", writeErr)) return *m, nil } if cfg.RuntimeAPIKey() == "" { - m.apiKeyReady = false + m.chat.APIKeyReady = false m.AddMessage("assistant", fmt.Sprintf("已切换到提供商 %s,但当前环境变量 %s 未设置。请使用 /apikey 或设置该环境变量。", providerName, cfg.APIKeyEnvVarName())) return *m, nil } - if err := provider.ValidateChatAPIKey(context.Background(), cfg); err == nil { - m.apiKeyReady = true + if err := validateChatAPIKey(context.Background(), cfg); err == nil { + m.chat.APIKeyReady = true m.AddMessage("assistant", fmt.Sprintf("已切换到提供商 %s,当前模型已重置为默认值: %s。", providerName, cfg.AI.Model)) return *m, nil } else { - m.apiKeyReady = false + m.chat.APIKeyReady = false m.AddMessage("assistant", fmt.Sprintf("已切换到提供商 %s,但 API Key 未通过校验:%v。可继续使用 /apikey 、/provider 、/switch 调整配置。", providerName, err)) return *m, nil } @@ -394,22 +399,22 @@ func (m *Model) handleCommand(input string) (tea.Model, tea.Cmd) { } target := strings.Join(args, " ") cfg.AI.Model = target - if writeErr := configs.WriteAppConfig(m.configPath, cfg); writeErr != nil { + if writeErr := writeAppConfig(m.chat.ConfigPath, cfg); writeErr != nil { m.AddMessage("assistant", fmt.Sprintf("切换模型失败: %v", writeErr)) return *m, nil } - m.activeModel = target + m.chat.ActiveModel = target if cfg.RuntimeAPIKey() == "" { - m.apiKeyReady = false + m.chat.APIKeyReady = false m.AddMessage("assistant", fmt.Sprintf("已切换到模型: %s,但当前环境变量 %s 未设置。", target, cfg.APIKeyEnvVarName())) return *m, nil } - if err := provider.ValidateChatAPIKey(context.Background(), cfg); err == nil { - m.apiKeyReady = true + if err := validateChatAPIKey(context.Background(), cfg); err == nil { + m.chat.APIKeyReady = true m.AddMessage("assistant", fmt.Sprintf("已切换到模型: %s", target)) return *m, nil } else { - m.apiKeyReady = false + m.chat.APIKeyReady = false m.AddMessage("assistant", fmt.Sprintf("已切换到模型 %s,但 API Key 未通过校验:%v。", target, err)) return *m, nil } @@ -418,9 +423,9 @@ func (m *Model) handleCommand(input string) (tea.Model, tea.Cmd) { m.AddMessage("assistant", "用法: /pwd 或 /workspace") return *m, nil } - root := strings.TrimSpace(m.workspaceRoot) + root := strings.TrimSpace(m.chat.WorkspaceRoot) if root == "" { - root = tools.GetWorkspaceRoot() + root = getWorkspaceRoot() } if strings.TrimSpace(root) == "" { m.AddMessage("assistant", "当前工作区: 未知") @@ -433,7 +438,7 @@ func (m *Model) handleCommand(input string) (tea.Model, tea.Cmd) { m.AddMessage("assistant", fmt.Sprintf("读取记忆统计失败: %v", err)) return *m, nil } - m.memoryStats = *stats + m.chat.MemoryStats = *stats m.AddMessage("assistant", fmt.Sprintf( "记忆统计:\n 长期: %d\n 会话: %d\n 总计: %d\n TopK: %d\n 最小分数: %.2f\n 文件: %s\n 类型: %s", stats.PersistentItems, stats.SessionItems, stats.TotalItems, stats.TopK, stats.MinScore, stats.Path, formatTypeStats(stats.ByType), @@ -449,7 +454,7 @@ func (m *Model) handleCommand(input string) (tea.Model, tea.Cmd) { } stats, _ := m.client.GetMemoryStats(context.Background()) if stats != nil { - m.memoryStats = *stats + m.chat.MemoryStats = *stats } m.AddMessage("assistant", "已清空本地长期记忆") case "/clear-context": @@ -457,10 +462,10 @@ func (m *Model) handleCommand(input string) (tea.Model, tea.Cmd) { m.AddMessage("assistant", fmt.Sprintf("清空会话记忆失败: %v", err)) return *m, nil } - m.messages = nil + m.chat.Messages = nil stats, _ := m.client.GetMemoryStats(context.Background()) if stats != nil { - m.memoryStats = *stats + m.chat.MemoryStats = *stats } m.AddMessage("assistant", "已清空当前会话上下文") case "/run": @@ -508,11 +513,11 @@ func formatTypeStats(byType map[string]int) string { return "无" } ordered := []string{ - domain.TypeUserPreference, - domain.TypeProjectRule, - domain.TypeCodeFact, - domain.TypeFixRecipe, - domain.TypeSessionMemory, + services.TypeUserPreference, + services.TypeProjectRule, + services.TypeCodeFact, + services.TypeFixRecipe, + services.TypeSessionMemory, } parts := make([]string, 0, len(byType)) for _, key := range ordered { @@ -526,17 +531,17 @@ func formatTypeStats(byType map[string]int) string { return strings.Join(parts, ", ") } -func (m *Model) buildMessages() []infra.Message { +func (m *Model) buildMessages() []services.Message { mu := m.mutex() mu.Lock() defer mu.Unlock() - result := make([]infra.Message, 0, len(m.messages)) + result := make([]services.Message, 0, len(m.chat.Messages)) // 工具结果会被注入成 system 上下文,但只保留最近几条, // 否则连续工具链很容易把真正的对话历史挤出上下文窗口。 - keepToolContextIndex := recentToolContextIndexes(m.messages, maxToolContextMessages) + keepToolContextIndex := recentToolContextIndexes(m.chat.Messages, maxToolContextMessages) // 按照消息的原始时间顺序进行迭代 - for idx, msg := range m.messages { + for idx, msg := range m.chat.Messages { if msg.Role == "system" && isTransientToolStatusMessage(msg.Content) { continue } @@ -550,7 +555,7 @@ func (m *Model) buildMessages() []infra.Message { continue } // 将非空消息按其原始角色和内容添加到结果中 - result = append(result, infra.Message{ + result = append(result, services.Message{ Role: msg.Role, Content: msg.Content, }) @@ -559,8 +564,8 @@ func (m *Model) buildMessages() []infra.Message { return result } -func (m *Model) streamResponse(messages []infra.Message) tea.Cmd { - stream, err := m.client.Chat(context.Background(), messages, m.activeModel) +func (m *Model) streamResponse(messages []services.Message) tea.Cmd { + stream, err := m.client.Chat(context.Background(), messages, m.chat.ActiveModel) if err != nil { return func() tea.Msg { return StreamErrorMsg{Err: err} } } @@ -592,9 +597,9 @@ func (m *Model) sendCodeToAI(code string) tea.Cmd { prompt := fmt.Sprintf("请解释以下代码:\n```\n%s\n```", code) m.AddMessage("user", prompt) m.AddMessage("assistant", "") - m.TrimHistory(m.historyTurns) - m.generating = true - m.autoScroll = true + m.TrimHistory(m.chat.HistoryTurns) + m.chat.Generating = true + m.ui.AutoScroll = true m.refreshViewport() messages := m.buildMessages() @@ -609,7 +614,7 @@ func isToolContextMessage(content string) bool { return strings.HasPrefix(strings.TrimSpace(content), toolContextPrefix) } -func recentToolContextIndexes(messages []Message, keep int) map[int]struct{} { +func recentToolContextIndexes(messages []state.Message, keep int) map[int]struct{} { result := map[int]struct{}{} if keep <= 0 || len(messages) == 0 { return result @@ -635,7 +640,7 @@ func formatToolStatusMessage(toolName string, params map[string]interface{}) str return fmt.Sprintf("%s tool=%s%s", toolStatusPrefix, strings.TrimSpace(toolName), detail) } -func formatToolContextMessage(result *tools.ToolResult) string { +func formatToolContextMessage(result *services.ToolResult) string { if result == nil { return toolContextPrefix + "\n" + "tool=unknown\n" + "success=false\n" + "error:\n工具返回为空" } @@ -767,10 +772,10 @@ func (m *Model) calculateInputHeight() int { } func (m *Model) syncLayout() { - if m.width <= 0 || m.height <= 0 { + if m.ui.Width <= 0 || m.ui.Height <= 0 { return } - inputWidth := m.width + inputWidth := m.ui.Width if inputWidth < 20 { inputWidth = 20 } @@ -781,14 +786,14 @@ func (m *Model) syncLayout() { statusHeight := 1 inputHeight := m.textarea.Height() + 2 helpHeight := 0 - if m.mode == ModeHelp { - helpHeight = minInt(20, m.height-statusHeight-3) + if m.ui.Mode == state.ModeHelp { + helpHeight = minInt(20, m.ui.Height-statusHeight-3) } - contentHeight := m.height - statusHeight - inputHeight - helpHeight + contentHeight := m.ui.Height - statusHeight - inputHeight - helpHeight if contentHeight < 3 { contentHeight = 3 } - m.viewport.Width = m.width + m.viewport.Width = m.ui.Width m.viewport.Height = contentHeight } @@ -796,8 +801,8 @@ func (m *Model) refreshViewport() { m.syncLayout() content := m.renderChatContent() m.viewport.SetContent(content) - if m.autoScroll || m.viewport.AtBottom() { + if m.ui.AutoScroll || m.viewport.AtBottom() { m.viewport.GotoBottom() - m.autoScroll = true + m.ui.AutoScroll = true } } diff --git a/internal/tui/core/update_test.go b/internal/tui/core/update_test.go index d4375402..9215d45e 100644 --- a/internal/tui/core/update_test.go +++ b/internal/tui/core/update_test.go @@ -6,38 +6,1049 @@ import ( "strings" "testing" - "go-llm-demo/internal/tui/infra" + "go-llm-demo/configs" + "go-llm-demo/internal/tui/services" + "go-llm-demo/internal/tui/state" + + tea "github.com/charmbracelet/bubbletea" ) -type fakeChatClient struct{} +type fakeChatClient struct { + chatChunks []string + chatErr error + lastMessages []services.Message + lastModel string + memoryStats *services.MemoryStats + nilMemoryStats bool + memoryErr error + clearMemoryErr error + clearSessionErr error + defaultModelName string +} + +func (f *fakeChatClient) Chat(_ context.Context, messages []services.Message, model string) (<-chan string, error) { + f.lastMessages = append([]services.Message(nil), messages...) + f.lastModel = model + if f.chatErr != nil { + return nil, f.chatErr + } -func (fakeChatClient) Chat(context.Context, []infra.Message, string) (<-chan string, error) { - return nil, errors.New("not implemented") + ch := make(chan string, len(f.chatChunks)) + for _, chunk := range f.chatChunks { + ch <- chunk + } + close(ch) + return ch, nil } -func (fakeChatClient) GetMemoryStats(context.Context) (*infra.MemoryStats, error) { - return &infra.MemoryStats{}, nil +func (f *fakeChatClient) GetMemoryStats(context.Context) (*services.MemoryStats, error) { + if f.memoryErr != nil { + return nil, f.memoryErr + } + if f.nilMemoryStats { + return nil, nil + } + if f.memoryStats != nil { + statsCopy := *f.memoryStats + return &statsCopy, nil + } + return &services.MemoryStats{}, nil } -func (fakeChatClient) ClearMemory(context.Context) error { - return nil +func (f *fakeChatClient) ClearMemory(context.Context) error { + return f.clearMemoryErr } -func (fakeChatClient) ClearSessionMemory(context.Context) error { - return nil +func (f *fakeChatClient) ClearSessionMemory(context.Context) error { + return f.clearSessionErr } -func (fakeChatClient) DefaultModel() string { +func (f *fakeChatClient) DefaultModel() string { + if f.defaultModelName != "" { + return f.defaultModelName + } return "test-model" } +func newTestModel(t *testing.T, client *fakeChatClient) *Model { + t.Helper() + + restoreCoreGlobals(t) + cfg := configs.DefaultAppConfig() + configs.GlobalAppConfig = cfg + + m := NewModel(client, "persona", 4, "config.yaml", "") + m.ui.Width = 80 + m.ui.Height = 24 + m.syncLayout() + return &m +} + +func restoreCoreGlobals(t *testing.T) { + t.Helper() + + origValidateChatAPIKey := validateChatAPIKey + origWriteAppConfig := writeAppConfig + origGetWorkspaceRoot := getWorkspaceRoot + origExecuteToolCall := executeToolCall + origGlobalConfig := configs.GlobalAppConfig + + t.Cleanup(func() { + validateChatAPIKey = origValidateChatAPIKey + writeAppConfig = origWriteAppConfig + getWorkspaceRoot = origGetWorkspaceRoot + executeToolCall = origExecuteToolCall + configs.GlobalAppConfig = origGlobalConfig + }) +} + +func lastMessageContent(t *testing.T, m Model) string { + t.Helper() + if len(m.chat.Messages) == 0 { + t.Fatal("expected at least one message") + } + return m.chat.Messages[len(m.chat.Messages)-1].Content +} + +func assertLastMessageContains(t *testing.T, m Model, want string) { + t.Helper() + if !strings.Contains(lastMessageContent(t, m), want) { + t.Fatalf("expected last message to contain %q, got %q", want, lastMessageContent(t, m)) + } +} + +func TestHandleSubmitEmptyInputNoOp(t *testing.T) { + client := &fakeChatClient{} + m := newTestModel(t, client) + m.textarea.SetValue(" ") + + updated, cmd := m.handleSubmit() + got := updated.(Model) + + if cmd != nil { + t.Fatal("expected no command for empty input") + } + if len(got.chat.Messages) != 0 { + t.Fatalf("expected no messages, got %d", len(got.chat.Messages)) + } +} + +func TestHandleSubmitFromHelpModeReturnsToChat(t *testing.T) { + client := &fakeChatClient{} + m := newTestModel(t, client) + m.ui.Mode = state.ModeHelp + m.textarea.SetValue("help") + + updated, cmd := m.handleSubmit() + got := updated.(Model) + + if cmd != nil { + t.Fatal("expected no command when leaving help mode") + } + if got.ui.Mode != state.ModeChat { + t.Fatalf("expected chat mode, got %v", got.ui.Mode) + } + if len(got.chat.Messages) != 0 { + t.Fatalf("expected no messages while leaving help mode, got %d", len(got.chat.Messages)) + } +} + +func TestHandleSubmitRequiresReadyAPIKey(t *testing.T) { + client := &fakeChatClient{} + m := newTestModel(t, client) + m.chat.APIKeyReady = false + m.textarea.SetValue("hello") + + updated, cmd := m.handleSubmit() + got := updated.(Model) + + if cmd != nil { + t.Fatal("expected no command when API key is not ready") + } + if len(got.chat.Messages) != 1 { + t.Fatalf("expected one assistant warning, got %d messages", len(got.chat.Messages)) + } + if got.chat.Messages[0].Role != "assistant" { + t.Fatalf("expected assistant warning, got %+v", got.chat.Messages[0]) + } + if !strings.Contains(got.chat.Messages[0].Content, "API Key") { + t.Fatalf("expected API key warning, got %q", got.chat.Messages[0].Content) + } +} + +func TestHandleSubmitStartsStreamingConversation(t *testing.T) { + client := &fakeChatClient{chatChunks: []string{"hello back"}} + m := newTestModel(t, client) + m.chat.APIKeyReady = true + m.textarea.SetValue("hello") + + updated, cmd := m.handleSubmit() + got := updated.(Model) + + if cmd == nil { + t.Fatal("expected streaming command") + } + if !got.chat.Generating { + t.Fatal("expected generating=true") + } + if len(got.chat.Messages) != 2 { + t.Fatalf("expected user and assistant placeholder, got %d messages", len(got.chat.Messages)) + } + if got.chat.Messages[0].Role != "user" || got.chat.Messages[0].Content != "hello" { + t.Fatalf("unexpected user message: %+v", got.chat.Messages[0]) + } + if got.chat.Messages[1].Role != "assistant" || got.chat.Messages[1].Content != "" { + t.Fatalf("expected assistant placeholder, got %+v", got.chat.Messages[1]) + } + if len(got.chat.CommandHistory) != 1 || got.chat.CommandHistory[0] != "hello" { + t.Fatalf("expected command history to record input, got %+v", got.chat.CommandHistory) + } + + msg := cmd() + chunk, ok := msg.(StreamChunkMsg) + if !ok { + t.Fatalf("expected StreamChunkMsg, got %T", msg) + } + if chunk.Content != "hello back" { + t.Fatalf("expected first stream chunk, got %q", chunk.Content) + } + if len(client.lastMessages) != 1 || client.lastMessages[0].Role != "user" || client.lastMessages[0].Content != "hello" { + t.Fatalf("expected streamed context to contain only the user message, got %+v", client.lastMessages) + } +} + +func TestHandleCommandRejectsNonRecoveryCommandWithoutAPIKey(t *testing.T) { + client := &fakeChatClient{} + m := newTestModel(t, client) + m.chat.APIKeyReady = false + + updated, cmd := m.handleCommand("/memory") + got := updated.(Model) + + if cmd != nil { + t.Fatal("expected no command") + } + if len(got.chat.Messages) != 1 { + t.Fatalf("expected one warning message, got %d", len(got.chat.Messages)) + } + if !strings.Contains(got.chat.Messages[0].Content, "API Key") { + t.Fatalf("expected API key guidance, got %q", got.chat.Messages[0].Content) + } +} + +func TestHandleCommandAPIKeyRequiresArgument(t *testing.T) { + client := &fakeChatClient{} + m := newTestModel(t, client) + + updated, _ := m.handleCommand("/apikey") + got := updated.(Model) + assertLastMessageContains(t, got, "/apikey ") +} + +func TestHandleCommandAPIKeyRequiresLoadedConfig(t *testing.T) { + client := &fakeChatClient{} + m := newTestModel(t, client) + configs.GlobalAppConfig = nil + + updated, _ := m.handleCommand("/apikey TEST_ENV") + got := updated.(Model) + assertLastMessageContains(t, got, "配置") +} + +func TestHandleCommandAPIKeyEnvStillMissing(t *testing.T) { + client := &fakeChatClient{} + m := newTestModel(t, client) + cfg := configs.DefaultAppConfig() + configs.GlobalAppConfig = cfg + + updated, _ := m.handleCommand("/apikey MISSING_ENV") + got := updated.(Model) + + if got.chat.APIKeyReady { + t.Fatal("expected API key to remain not ready") + } + if cfg.AI.APIKey != "MISSING_ENV" { + t.Fatalf("expected env name to switch, got %q", cfg.AI.APIKey) + } + assertLastMessageContains(t, got, "MISSING_ENV") +} + +func TestHandleCommandAPIKeyInvalidKey(t *testing.T) { + client := &fakeChatClient{} + m := newTestModel(t, client) + cfg := configs.DefaultAppConfig() + configs.GlobalAppConfig = cfg + t.Setenv("BAD_ENV", "secret") + + validateChatAPIKey = func(context.Context, *configs.AppConfiguration) error { + return services.ErrInvalidAPIKey + } + + updated, _ := m.handleCommand("/apikey BAD_ENV") + got := updated.(Model) + + if got.chat.APIKeyReady { + t.Fatal("expected invalid key to mark API key as not ready") + } + assertLastMessageContains(t, got, "BAD_ENV") +} + +func TestHandleCommandAPIKeyGenericValidationError(t *testing.T) { + client := &fakeChatClient{} + m := newTestModel(t, client) + cfg := configs.DefaultAppConfig() + configs.GlobalAppConfig = cfg + t.Setenv("GENERIC_ENV", "secret") + + validateChatAPIKey = func(context.Context, *configs.AppConfiguration) error { + return errors.New("validation failed") + } + + updated, _ := m.handleCommand("/apikey GENERIC_ENV") + got := updated.(Model) + + if got.chat.APIKeyReady { + t.Fatal("expected generic validation failure to mark API key as not ready") + } + assertLastMessageContains(t, got, "validation failed") +} + +func TestHandleCommandAPIKeySuccessWritesConfig(t *testing.T) { + client := &fakeChatClient{} + m := newTestModel(t, client) + cfg := configs.DefaultAppConfig() + cfg.AI.APIKey = "ORIGINAL_ENV" + configs.GlobalAppConfig = cfg + t.Setenv("TEST_API_KEY_ENV", "secret") + + var writePath string + validateChatAPIKey = func(context.Context, *configs.AppConfiguration) error { return nil } + writeAppConfig = func(path string, cfg *configs.AppConfiguration) error { + writePath = path + if cfg.AI.APIKey != "TEST_API_KEY_ENV" { + t.Fatalf("expected config to be updated before write, got %q", cfg.AI.APIKey) + } + return nil + } + + updated, cmd := m.handleCommand("/apikey TEST_API_KEY_ENV") + got := updated.(Model) + + if cmd != nil { + t.Fatal("expected no command") + } + if !got.chat.APIKeyReady { + t.Fatal("expected API key to be ready after validation") + } + if cfg.AI.APIKey != "TEST_API_KEY_ENV" { + t.Fatalf("expected config env name to change, got %q", cfg.AI.APIKey) + } + if writePath != "config.yaml" { + t.Fatalf("expected config write path config.yaml, got %q", writePath) + } + if !strings.Contains(lastMessageContent(t, got), "TEST_API_KEY_ENV") { + t.Fatalf("expected success message to mention env name, got %q", lastMessageContent(t, got)) + } +} + +func TestHandleCommandAPIKeyWriteFailureRestoresPreviousEnvName(t *testing.T) { + client := &fakeChatClient{} + m := newTestModel(t, client) + cfg := configs.DefaultAppConfig() + cfg.AI.APIKey = "PREVIOUS_ENV" + configs.GlobalAppConfig = cfg + t.Setenv("PREVIOUS_ENV", "old-secret") + t.Setenv("NEW_ENV", "new-secret") + + validateChatAPIKey = func(context.Context, *configs.AppConfiguration) error { return nil } + writeAppConfig = func(string, *configs.AppConfiguration) error { return errors.New("write failed") } + + updated, _ := m.handleCommand("/apikey NEW_ENV") + got := updated.(Model) + + if cfg.AI.APIKey != "PREVIOUS_ENV" { + t.Fatalf("expected previous env name to be restored, got %q", cfg.AI.APIKey) + } + if !got.chat.APIKeyReady { + t.Fatal("expected API key readiness to match restored environment variable") + } + if !strings.Contains(lastMessageContent(t, got), "write failed") { + t.Fatalf("expected write failure message, got %q", lastMessageContent(t, got)) + } +} + +func TestHandleCommandProviderWithoutRuntimeKeyMarksAPIKeyNotReady(t *testing.T) { + client := &fakeChatClient{} + m := newTestModel(t, client) + cfg := configs.DefaultAppConfig() + cfg.AI.Provider = "openll" + cfg.AI.APIKey = "MISSING_ENV" + configs.GlobalAppConfig = cfg + + writeAppConfig = func(string, *configs.AppConfiguration) error { return nil } + + updated, _ := m.handleCommand("/provider openai") + got := updated.(Model) + + if got.chat.APIKeyReady { + t.Fatal("expected API key to become not ready when provider env var is missing") + } + if cfg.AI.Provider != "openai" { + t.Fatalf("expected provider to switch, got %q", cfg.AI.Provider) + } + if got.chat.ActiveModel == "" { + t.Fatal("expected provider switch to reset active model") + } + if !strings.Contains(lastMessageContent(t, got), "openai") { + t.Fatalf("expected provider switch message, got %q", lastMessageContent(t, got)) + } +} + +func TestHandleCommandProviderRequiresArgument(t *testing.T) { + client := &fakeChatClient{} + m := newTestModel(t, client) + + updated, _ := m.handleCommand("/provider") + got := updated.(Model) + assertLastMessageContains(t, got, "/provider ") +} + +func TestHandleCommandProviderRejectsUnknownProvider(t *testing.T) { + client := &fakeChatClient{} + m := newTestModel(t, client) + configs.GlobalAppConfig = configs.DefaultAppConfig() + + updated, _ := m.handleCommand("/provider unknown") + got := updated.(Model) + assertLastMessageContains(t, got, "unknown") +} + +func TestHandleCommandProviderRequiresLoadedConfig(t *testing.T) { + client := &fakeChatClient{} + m := newTestModel(t, client) + configs.GlobalAppConfig = nil + + updated, _ := m.handleCommand("/provider openai") + got := updated.(Model) + assertLastMessageContains(t, got, "配置") +} + +func TestHandleCommandProviderWriteFailure(t *testing.T) { + client := &fakeChatClient{} + m := newTestModel(t, client) + cfg := configs.DefaultAppConfig() + configs.GlobalAppConfig = cfg + writeAppConfig = func(string, *configs.AppConfiguration) error { return errors.New("write failed") } + + updated, _ := m.handleCommand("/provider openai") + got := updated.(Model) + assertLastMessageContains(t, got, "write failed") +} + +func TestHandleCommandProviderValidationFailure(t *testing.T) { + client := &fakeChatClient{} + m := newTestModel(t, client) + cfg := configs.DefaultAppConfig() + cfg.AI.APIKey = "READY_ENV" + configs.GlobalAppConfig = cfg + t.Setenv("READY_ENV", "secret") + + writeAppConfig = func(string, *configs.AppConfiguration) error { return nil } + validateChatAPIKey = func(context.Context, *configs.AppConfiguration) error { + return errors.New("validation failed") + } + + updated, _ := m.handleCommand("/provider openai") + got := updated.(Model) + + if got.chat.APIKeyReady { + t.Fatal("expected validation failure to mark API key as not ready") + } + assertLastMessageContains(t, got, "validation failed") +} + +func TestHandleCommandProviderSuccess(t *testing.T) { + client := &fakeChatClient{} + m := newTestModel(t, client) + cfg := configs.DefaultAppConfig() + cfg.AI.APIKey = "READY_ENV" + configs.GlobalAppConfig = cfg + t.Setenv("READY_ENV", "secret") + + writeAppConfig = func(string, *configs.AppConfiguration) error { return nil } + validateChatAPIKey = func(context.Context, *configs.AppConfiguration) error { return nil } + + updated, _ := m.handleCommand("/provider openai") + got := updated.(Model) + + if !got.chat.APIKeyReady { + t.Fatal("expected successful provider switch to keep API key ready") + } + if got.chat.ActiveModel == "" { + t.Fatal("expected active model to be set") + } + assertLastMessageContains(t, got, "openai") +} + +func TestHandleCommandSwitchModelValidationSuccess(t *testing.T) { + client := &fakeChatClient{} + m := newTestModel(t, client) + cfg := configs.DefaultAppConfig() + cfg.AI.APIKey = "READY_ENV" + configs.GlobalAppConfig = cfg + t.Setenv("READY_ENV", "secret") + + validateChatAPIKey = func(context.Context, *configs.AppConfiguration) error { return nil } + writeAppConfig = func(string, *configs.AppConfiguration) error { return nil } + + updated, _ := m.handleCommand("/switch gpt-5.4-mini") + got := updated.(Model) + + if !got.chat.APIKeyReady { + t.Fatal("expected API key to stay ready") + } + if got.chat.ActiveModel != "gpt-5.4-mini" { + t.Fatalf("expected active model to switch, got %q", got.chat.ActiveModel) + } + if cfg.AI.Model != "gpt-5.4-mini" { + t.Fatalf("expected config model to switch, got %q", cfg.AI.Model) + } +} + +func TestHandleCommandSwitchRequiresArgument(t *testing.T) { + client := &fakeChatClient{} + m := newTestModel(t, client) + + updated, _ := m.handleCommand("/switch") + got := updated.(Model) + assertLastMessageContains(t, got, "/switch ") +} + +func TestHandleCommandSwitchRequiresLoadedConfig(t *testing.T) { + client := &fakeChatClient{} + m := newTestModel(t, client) + configs.GlobalAppConfig = nil + + updated, _ := m.handleCommand("/switch gpt-5.4") + got := updated.(Model) + assertLastMessageContains(t, got, "配置") +} + +func TestHandleCommandSwitchWriteFailure(t *testing.T) { + client := &fakeChatClient{} + m := newTestModel(t, client) + cfg := configs.DefaultAppConfig() + configs.GlobalAppConfig = cfg + writeAppConfig = func(string, *configs.AppConfiguration) error { return errors.New("write failed") } + + updated, _ := m.handleCommand("/switch gpt-5.4") + got := updated.(Model) + assertLastMessageContains(t, got, "write failed") +} + +func TestHandleCommandSwitchMissingRuntimeKey(t *testing.T) { + client := &fakeChatClient{} + m := newTestModel(t, client) + cfg := configs.DefaultAppConfig() + cfg.AI.APIKey = "MISSING_ENV" + configs.GlobalAppConfig = cfg + writeAppConfig = func(string, *configs.AppConfiguration) error { return nil } + + updated, _ := m.handleCommand("/switch gpt-5.4") + got := updated.(Model) + + if got.chat.APIKeyReady { + t.Fatal("expected API key to be not ready when runtime key is missing") + } + assertLastMessageContains(t, got, "MISSING_ENV") +} + +func TestHandleCommandSwitchValidationFailure(t *testing.T) { + client := &fakeChatClient{} + m := newTestModel(t, client) + cfg := configs.DefaultAppConfig() + cfg.AI.APIKey = "READY_ENV" + configs.GlobalAppConfig = cfg + t.Setenv("READY_ENV", "secret") + + writeAppConfig = func(string, *configs.AppConfiguration) error { return nil } + validateChatAPIKey = func(context.Context, *configs.AppConfiguration) error { return errors.New("validation failed") } + + updated, _ := m.handleCommand("/switch gpt-5.4") + got := updated.(Model) + + if got.chat.APIKeyReady { + t.Fatal("expected validation failure to mark API key not ready") + } + assertLastMessageContains(t, got, "validation failed") +} + +func TestHandleCommandWorkspaceFallsBackToGlobalRoot(t *testing.T) { + client := &fakeChatClient{} + m := newTestModel(t, client) + m.chat.WorkspaceRoot = "" + getWorkspaceRoot = func() string { return `D:/neo-code/workspace` } + + updated, _ := m.handleCommand("/workspace") + got := updated.(Model) + + if !strings.Contains(lastMessageContent(t, got), `D:/neo-code/workspace`) { + t.Fatalf("expected workspace fallback path, got %q", lastMessageContent(t, got)) + } +} + +func TestHandleCommandWorkspaceRejectsExtraArgs(t *testing.T) { + client := &fakeChatClient{} + m := newTestModel(t, client) + + updated, _ := m.handleCommand("/pwd extra") + got := updated.(Model) + assertLastMessageContains(t, got, "/pwd") +} + +func TestHandleCommandMemoryFailure(t *testing.T) { + client := &fakeChatClient{memoryErr: errors.New("stats failed")} + m := newTestModel(t, client) + m.chat.APIKeyReady = true + + updated, _ := m.handleCommand("/memory") + got := updated.(Model) + assertLastMessageContains(t, got, "stats failed") +} + +func TestHandleCommandMemorySuccess(t *testing.T) { + client := &fakeChatClient{memoryStats: &services.MemoryStats{ + PersistentItems: 1, + SessionItems: 2, + TotalItems: 3, + TopK: 4, + MinScore: 1.5, + Path: "memory.json", + ByType: map[string]int{ + services.TypeUserPreference: 1, + }, + }} + m := newTestModel(t, client) + m.chat.APIKeyReady = true + + updated, _ := m.handleCommand("/memory") + got := updated.(Model) + assertLastMessageContains(t, got, "memory.json") +} + +func TestHandleCommandClearMemoryRequiresConfirm(t *testing.T) { + client := &fakeChatClient{} + m := newTestModel(t, client) + m.chat.APIKeyReady = true + + updated, _ := m.handleCommand("/clear-memory") + got := updated.(Model) + assertLastMessageContains(t, got, "confirm") +} + +func TestHandleCommandClearMemoryFailure(t *testing.T) { + client := &fakeChatClient{clearMemoryErr: errors.New("clear failed")} + m := newTestModel(t, client) + m.chat.APIKeyReady = true + + updated, _ := m.handleCommand("/clear-memory confirm") + got := updated.(Model) + assertLastMessageContains(t, got, "clear failed") +} + +func TestHandleCommandClearMemorySuccessRefreshesStats(t *testing.T) { + client := &fakeChatClient{memoryStats: &services.MemoryStats{TotalItems: 9}} + m := newTestModel(t, client) + m.chat.APIKeyReady = true + + updated, _ := m.handleCommand("/clear-memory confirm") + got := updated.(Model) + + if got.chat.MemoryStats.TotalItems != 9 { + t.Fatalf("expected stats refresh, got %+v", got.chat.MemoryStats) + } +} + +func TestHandleCommandClearContextFailure(t *testing.T) { + client := &fakeChatClient{clearSessionErr: errors.New("clear session failed")} + m := newTestModel(t, client) + m.chat.APIKeyReady = true + + updated, _ := m.handleCommand("/clear-context") + got := updated.(Model) + assertLastMessageContains(t, got, "clear session failed") +} + +func TestHandleCommandRunReturnsBatchCommand(t *testing.T) { + client := &fakeChatClient{} + m := newTestModel(t, client) + m.chat.APIKeyReady = true + + _, cmd := m.handleCommand("/run package main") + if cmd == nil { + t.Fatal("expected batch command") + } +} + +func TestHandleCommandRunWithoutArgsIsNoOp(t *testing.T) { + client := &fakeChatClient{} + m := newTestModel(t, client) + m.chat.APIKeyReady = true + + updated, cmd := m.handleCommand("/run") + got := updated.(Model) + if cmd != nil { + t.Fatal("expected no command") + } + if len(got.chat.Messages) != 0 { + t.Fatalf("expected no messages, got %d", len(got.chat.Messages)) + } +} + +func TestHandleCommandExplainStartsStreaming(t *testing.T) { + client := &fakeChatClient{chatChunks: []string{"explained"}} + m := newTestModel(t, client) + m.chat.APIKeyReady = true + + updated, cmd := m.handleCommand("/explain package main") + got := updated.(Model) + if cmd == nil { + t.Fatal("expected stream command") + } + if !got.chat.Generating { + t.Fatal("expected generating=true") + } + if len(got.chat.Messages) != 2 { + t.Fatalf("expected user and assistant messages, got %d", len(got.chat.Messages)) + } + if !strings.Contains(got.chat.Messages[0].Content, "package main") { + t.Fatalf("expected explain prompt to include code, got %q", got.chat.Messages[0].Content) + } + msg := cmd() + if chunk, ok := msg.(StreamChunkMsg); !ok || chunk.Content != "explained" { + t.Fatalf("expected explain stream chunk, got %#v", msg) + } +} + +func TestHandleCommandExplainWithoutArgsIsNoOp(t *testing.T) { + client := &fakeChatClient{} + m := newTestModel(t, client) + m.chat.APIKeyReady = true + + updated, cmd := m.handleCommand("/explain") + got := updated.(Model) + if cmd != nil { + t.Fatal("expected no command") + } + if len(got.chat.Messages) != 0 { + t.Fatalf("expected no messages, got %d", len(got.chat.Messages)) + } +} + +func TestHandleCommandUnknownCommand(t *testing.T) { + client := &fakeChatClient{} + m := newTestModel(t, client) + m.chat.APIKeyReady = true + + updated, _ := m.handleCommand("/unknown") + got := updated.(Model) + assertLastMessageContains(t, got, "/unknown") +} + +func TestStreamChunkMsgAppendsContentAndSchedulesNextChunk(t *testing.T) { + client := &fakeChatClient{} + m := newTestModel(t, client) + m.chat.Generating = true + m.chat.Messages = []state.Message{{Role: "assistant", Content: ""}} + + ch := make(chan string, 1) + ch <- "second" + close(ch) + m.streamChan = ch + + updated, cmd := m.Update(StreamChunkMsg{Content: "first"}) + got := updated.(Model) + + if got.chat.Messages[0].Content != "first" { + t.Fatalf("expected first chunk to append, got %q", got.chat.Messages[0].Content) + } + if cmd == nil { + t.Fatal("expected follow-up command") + } + msg := cmd() + chunk, ok := msg.(StreamChunkMsg) + if !ok { + t.Fatalf("expected StreamChunkMsg, got %T", msg) + } + if chunk.Content != "second" { + t.Fatalf("expected second chunk, got %q", chunk.Content) + } +} + +func TestWindowSizeMsgUpdatesLayout(t *testing.T) { + client := &fakeChatClient{} + m := newTestModel(t, client) + + updated, cmd := m.Update(tea.WindowSizeMsg{Width: 100, Height: 40}) + got := updated.(Model) + + if cmd != nil { + t.Fatal("expected no command") + } + if got.ui.Width != 100 || got.ui.Height != 40 { + t.Fatalf("expected updated size, got %dx%d", got.ui.Width, got.ui.Height) + } +} + +func TestMouseMsgUpdatesViewport(t *testing.T) { + client := &fakeChatClient{} + m := newTestModel(t, client) + m.viewport.SetContent("line1\nline2\nline3\nline4") + + updated, _ := m.Update(tea.MouseMsg{Type: tea.MouseWheelDown}) + _ = updated.(Model) +} + +func TestStreamChunkMsgNoOpWhenNotGenerating(t *testing.T) { + client := &fakeChatClient{} + m := newTestModel(t, client) + m.chat.Generating = false + m.chat.Messages = []state.Message{{Role: "assistant", Content: "start"}} + + updated, _ := m.Update(StreamChunkMsg{Content: "ignored"}) + got := updated.(Model) + + if got.chat.Messages[0].Content != "start" { + t.Fatalf("expected content unchanged, got %q", got.chat.Messages[0].Content) + } +} + +func TestStreamDoneMsgCompletesWithoutToolCall(t *testing.T) { + client := &fakeChatClient{} + m := newTestModel(t, client) + m.chat.Generating = true + ch := make(chan string) + close(ch) + m.streamChan = ch + m.chat.Messages = []state.Message{{Role: "assistant", Content: "done", Streaming: true}} + + updated, cmd := m.Update(StreamDoneMsg{}) + got := updated.(Model) + + if cmd != nil { + t.Fatal("expected no command") + } + if got.chat.Generating { + t.Fatal("expected generation to stop") + } + if got.streamChan != nil { + t.Fatal("expected stream channel to be cleared") + } + if got.chat.Messages[0].Streaming { + t.Fatal("expected last message streaming flag to clear") + } +} + +func TestStreamDoneMsgDoesNotReexecuteWhenToolAlreadyExecuting(t *testing.T) { + client := &fakeChatClient{} + m := newTestModel(t, client) + m.chat.Generating = true + m.chat.ToolExecuting = true + m.chat.Messages = []state.Message{{Role: "assistant", Content: `{"tool":"read","params":{"path":"README.md"}}`, Streaming: true}} + + called := false + executeToolCall = func(services.ToolCall) *services.ToolResult { + called = true + return nil + } + + updated, cmd := m.Update(StreamDoneMsg{}) + got := updated.(Model) + if cmd != nil { + t.Fatal("expected no command") + } + if !got.chat.ToolExecuting { + t.Fatal("expected tool executing flag to remain true") + } + if called { + t.Fatal("expected no duplicate tool execution") + } +} + +func TestShowHideHelpRefreshMemoryAndExitMsgs(t *testing.T) { + client := &fakeChatClient{memoryStats: &services.MemoryStats{TotalItems: 7}} + m := newTestModel(t, client) + + updated, _ := m.Update(ShowHelpMsg{}) + got := updated.(Model) + if got.ui.Mode != state.ModeHelp { + t.Fatalf("expected help mode, got %v", got.ui.Mode) + } + + updated, _ = got.Update(HideHelpMsg{}) + got = updated.(Model) + if got.ui.Mode != state.ModeChat { + t.Fatalf("expected chat mode, got %v", got.ui.Mode) + } + + updated, _ = got.Update(RefreshMemoryMsg{}) + got = updated.(Model) + if got.chat.MemoryStats.TotalItems != 7 { + t.Fatalf("expected refreshed stats, got %+v", got.chat.MemoryStats) + } + + _, cmd := got.Update(ExitMsg{}) + if cmd == nil { + t.Fatal("expected quit command") + } +} + +func TestRefreshMemoryMsgIgnoresClientError(t *testing.T) { + client := &fakeChatClient{memoryErr: errors.New("stats failed")} + m := newTestModel(t, client) + m.chat.MemoryStats.TotalItems = 5 + + updated, _ := m.Update(RefreshMemoryMsg{}) + got := updated.(Model) + if got.chat.MemoryStats.TotalItems != 5 { + t.Fatalf("expected previous stats to be preserved, got %+v", got.chat.MemoryStats) + } +} + +func TestStreamDoneMsgExecutesToolCallFromAssistantJSON(t *testing.T) { + client := &fakeChatClient{} + m := newTestModel(t, client) + m.chat.Generating = true + m.chat.Messages = []state.Message{{Role: "assistant", Content: `{"tool":"read","params":{"path":"README.md"}}`, Streaming: true}} + + expected := &services.ToolResult{ToolName: "read", Success: true, Output: "ok"} + executeToolCall = func(call services.ToolCall) *services.ToolResult { + if call.Tool != "read" { + t.Fatalf("expected read tool, got %q", call.Tool) + } + if got, _ := call.Params["path"].(string); got != "README.md" { + t.Fatalf("expected normalized path param, got %+v", call.Params) + } + return expected + } + + updated, cmd := m.Update(StreamDoneMsg{}) + got := updated.(Model) + + if !got.chat.ToolExecuting { + t.Fatal("expected tool execution flag to be set") + } + if len(got.chat.Messages) != 2 { + t.Fatalf("expected tool status message to be appended, got %d messages", len(got.chat.Messages)) + } + if !strings.HasPrefix(got.chat.Messages[1].Content, toolStatusPrefix) { + t.Fatalf("expected transient tool status, got %q", got.chat.Messages[1].Content) + } + if cmd == nil { + t.Fatal("expected tool execution command") + } + msg := cmd() + resultMsg, ok := msg.(ToolResultMsg) + if !ok { + t.Fatalf("expected ToolResultMsg, got %T", msg) + } + if resultMsg.Result != expected { + t.Fatalf("expected tool result to round-trip, got %+v", resultMsg.Result) + } +} + +func TestStreamDoneMsgReturnsToolErrorWhenToolResultIsNil(t *testing.T) { + client := &fakeChatClient{} + m := newTestModel(t, client) + m.chat.Generating = true + m.chat.Messages = []state.Message{{Role: "assistant", Content: `{"tool":"read","params":{"path":"README.md"}}`, Streaming: true}} + + executeToolCall = func(services.ToolCall) *services.ToolResult { return nil } + + _, cmd := m.Update(StreamDoneMsg{}) + if cmd == nil { + t.Fatal("expected tool execution command") + } + msg := cmd() + if _, ok := msg.(ToolErrorMsg); !ok { + t.Fatalf("expected ToolErrorMsg, got %T", msg) + } +} + +func TestToolResultMsgAddsContextAndRestartsStreaming(t *testing.T) { + client := &fakeChatClient{chatChunks: []string{"tool follow-up"}} + m := newTestModel(t, client) + m.chat.Messages = []state.Message{{Role: "user", Content: "hello"}} + m.chat.ToolExecuting = true + + result := &services.ToolResult{ToolName: "read", Success: true, Output: "README"} + updated, cmd := m.Update(ToolResultMsg{Result: result}) + got := updated.(Model) + + if got.chat.ToolExecuting { + t.Fatal("expected tool execution flag to be cleared") + } + if !got.chat.Generating { + t.Fatal("expected follow-up generation to start") + } + if len(got.chat.Messages) != 3 { + t.Fatalf("expected tool context and placeholder messages, got %d", len(got.chat.Messages)) + } + if !strings.HasPrefix(got.chat.Messages[1].Content, toolContextPrefix) { + t.Fatalf("expected tool context message, got %q", got.chat.Messages[1].Content) + } + if got.chat.Messages[2].Role != "assistant" || got.chat.Messages[2].Content != "" { + t.Fatalf("expected assistant placeholder, got %+v", got.chat.Messages[2]) + } + if cmd == nil { + t.Fatal("expected streaming command") + } + msg := cmd() + chunk, ok := msg.(StreamChunkMsg) + if !ok { + t.Fatalf("expected StreamChunkMsg, got %T", msg) + } + if chunk.Content != "tool follow-up" { + t.Fatalf("expected tool follow-up chunk, got %q", chunk.Content) + } +} + +func TestToolErrorMsgAddsErrorContextAndRestartsStreaming(t *testing.T) { + client := &fakeChatClient{chatChunks: []string{"error recovery"}} + m := newTestModel(t, client) + m.chat.ToolExecuting = true + + updated, cmd := m.Update(ToolErrorMsg{Err: errors.New("tool failed")}) + got := updated.(Model) + + if got.chat.ToolExecuting { + t.Fatal("expected tool execution flag to be cleared") + } + if !got.chat.Generating { + t.Fatal("expected generation restart after tool error") + } + if len(got.chat.Messages) != 2 { + t.Fatalf("expected tool error context and placeholder, got %d messages", len(got.chat.Messages)) + } + if !strings.Contains(got.chat.Messages[0].Content, "tool failed") { + t.Fatalf("expected tool error context, got %q", got.chat.Messages[0].Content) + } + if cmd == nil { + t.Fatal("expected follow-up stream command") + } + msg := cmd() + if _, ok := msg.(StreamChunkMsg); !ok { + t.Fatalf("expected StreamChunkMsg, got %T", msg) + } +} + func TestBuildMessagesSkipsEmptyAssistantPlaceholder(t *testing.T) { m := Model{ - messages: []Message{ + chat: state.ChatState{Messages: []state.Message{ {Role: "system", Content: "persona"}, {Role: "user", Content: "hello"}, {Role: "assistant", Content: ""}, - }, + }}, } got := m.buildMessages() @@ -52,53 +1063,190 @@ func TestBuildMessagesSkipsEmptyAssistantPlaceholder(t *testing.T) { } } +func TestFormatTypeStats(t *testing.T) { + if got := formatTypeStats(nil); got == "" { + t.Fatal("expected non-empty placeholder") + } + got := formatTypeStats(map[string]int{ + services.TypeUserPreference: 2, + services.TypeCodeFact: 1, + }) + if !strings.Contains(got, services.TypeUserPreference+"=2") || !strings.Contains(got, services.TypeCodeFact+"=1") { + t.Fatalf("unexpected formatted stats %q", got) + } +} + +func TestRecentToolContextIndexes(t *testing.T) { + messages := []state.Message{ + {Role: "system", Content: "[TOOL_CONTEXT]\na"}, + {Role: "assistant", Content: "x"}, + {Role: "system", Content: "[TOOL_CONTEXT]\nb"}, + } + got := recentToolContextIndexes(messages, 1) + if len(got) != 1 { + t.Fatalf("expected one index, got %+v", got) + } + if _, ok := got[2]; !ok { + t.Fatalf("expected newest index to be kept, got %+v", got) + } +} + +func TestFormatToolStatusMessage(t *testing.T) { + got := formatToolStatusMessage("read", map[string]interface{}{"filePath": "README.md"}) + if !strings.Contains(got, "tool=read") || !strings.Contains(got, "README.md") { + t.Fatalf("unexpected tool status %q", got) + } +} + +func TestFormatToolContextMessage(t *testing.T) { + got := formatToolContextMessage(&services.ToolResult{ + ToolName: "read", + Success: true, + Output: "hello", + Metadata: map[string]interface{}{"k": "v"}, + }) + if !strings.Contains(got, "tool=read") || !strings.Contains(got, "metadata=") || !strings.Contains(got, "output:") { + t.Fatalf("unexpected tool context %q", got) + } + + got = formatToolContextMessage(&services.ToolResult{ToolName: "read", Success: false, Error: "boom"}) + if !strings.Contains(got, "error:") || !strings.Contains(got, "boom") { + t.Fatalf("unexpected error context %q", got) + } +} + +func TestFormatToolErrorContext(t *testing.T) { + got := formatToolErrorContext(errors.New("boom")) + if !strings.Contains(got, "boom") { + t.Fatalf("unexpected tool error context %q", got) + } +} + +func TestTruncateForContext(t *testing.T) { + if got := truncateForContext(" hi ", 10); got != "hi" { + t.Fatalf("expected trimmed content, got %q", got) + } + got := truncateForContext(strings.Repeat("a", 20), 10) + if !strings.Contains(got, "truncated") { + t.Fatalf("expected truncation marker, got %q", got) + } +} + +func TestDetectLanguage(t *testing.T) { + tests := []struct { + code string + ext string + run string + }{ + {"#!/bin/bash\necho hi", "sh", "bash"}, + {"package main\nfunc main(){}", "go", ""}, + {"def hi():\n pass", "py", "python"}, + {"fn main() {}", "rs", "rustc"}, + {"console.log('x')", "js", "node"}, + {"unknown", "", ""}, + } + for _, tt := range tests { + ext, run := detectLanguage(tt.code) + if ext != tt.ext || run != tt.run { + t.Fatalf("detectLanguage(%q) = (%q,%q), want (%q,%q)", tt.code, ext, run, tt.ext, tt.run) + } + } +} + +func TestCalculateInputHeight(t *testing.T) { + client := &fakeChatClient{} + m := newTestModel(t, client) + + m.textarea.SetValue("one") + if got := m.calculateInputHeight(); got != 3 { + t.Fatalf("expected minimum height 3, got %d", got) + } + m.textarea.SetValue(strings.Repeat("a\n", 10)) + if got := m.calculateInputHeight(); got != 8 { + t.Fatalf("expected capped height 8, got %d", got) + } +} + +func TestStreamResponseReturnsErrorMsg(t *testing.T) { + client := &fakeChatClient{chatErr: errors.New("chat failed")} + m := newTestModel(t, client) + + cmd := m.streamResponse([]services.Message{{Role: "user", Content: "hi"}}) + if cmd == nil { + t.Fatal("expected command") + } + msg := cmd() + if _, ok := msg.(StreamErrorMsg); !ok { + t.Fatalf("expected StreamErrorMsg, got %T", msg) + } +} + +func TestStreamResponseAndStreamResponseFromChannelDone(t *testing.T) { + client := &fakeChatClient{chatChunks: nil} + m := newTestModel(t, client) + + cmd := m.streamResponse([]services.Message{{Role: "user", Content: "hi"}}) + if cmd == nil { + t.Fatal("expected command") + } + msg := cmd() + if _, ok := msg.(StreamDoneMsg); !ok { + t.Fatalf("expected StreamDoneMsg, got %T", msg) + } + + m.streamChan = nil + if cmd := m.streamResponseFromChannel(); cmd != nil { + t.Fatal("expected nil command when stream channel is nil") + } +} + func TestStreamErrorReplacesTrailingPlaceholder(t *testing.T) { m := Model{ - historyTurns: 6, - messages: []Message{ - {Role: "user", Content: "hello"}, - {Role: "assistant", Content: ""}, + chat: state.ChatState{ + HistoryTurns: 6, + Messages: []state.Message{ + {Role: "user", Content: "hello"}, + {Role: "assistant", Content: ""}, + }, }, } updated, _ := m.Update(StreamErrorMsg{Err: errors.New("boom")}) got := updated.(Model) - if len(got.messages) != 2 { - t.Fatalf("expected placeholder replacement without extra message, got %d messages", len(got.messages)) + if len(got.chat.Messages) != 2 { + t.Fatalf("expected placeholder replacement without extra message, got %d messages", len(got.chat.Messages)) } - if got.messages[1].Content != "错误: boom" { - t.Fatalf("expected trailing placeholder to become error, got %q", got.messages[1].Content) + if !strings.Contains(got.chat.Messages[1].Content, "boom") { + t.Fatalf("expected trailing placeholder to become error, got %q", got.chat.Messages[1].Content) } } func TestClearContextDoesNotReinjectStalePersonaMessage(t *testing.T) { - m := Model{ - client: fakeChatClient{}, - persona: "stale persona", - apiKeyReady: true, - messages: []Message{ - {Role: "system", Content: "stale persona"}, - {Role: "user", Content: "hello"}, - }, + client := &fakeChatClient{} + m := newTestModel(t, client) + m.chat.APIKeyReady = true + m.chat.Messages = []state.Message{ + {Role: "system", Content: "stale persona"}, + {Role: "user", Content: "hello"}, } updated, _ := m.handleCommand("/clear-context") got := updated.(Model) - if len(got.messages) != 1 { - t.Fatalf("expected only confirmation message after clear-context, got %d messages", len(got.messages)) + if len(got.chat.Messages) != 1 { + t.Fatalf("expected only confirmation message after clear-context, got %d messages", len(got.chat.Messages)) } - if got.messages[0].Role != "assistant" { - t.Fatalf("expected confirmation assistant message, got %+v", got.messages[0]) + if got.chat.Messages[0].Role != "assistant" { + t.Fatalf("expected confirmation assistant message, got %+v", got.chat.Messages[0]) } } func TestBuildMessagesSkipsTransientToolStatusMessage(t *testing.T) { m := Model{ - messages: []Message{ + chat: state.ChatState{Messages: []state.Message{ {Role: "user", Content: "hello"}, {Role: "system", Content: "[TOOL_STATUS] tool=read file=README.md"}, {Role: "assistant", Content: "ok"}, - }, + }}, } got := m.buildMessages() @@ -114,11 +1262,11 @@ func TestBuildMessagesSkipsTransientToolStatusMessage(t *testing.T) { func TestBuildMessagesKeepsOnlyRecentToolContextMessages(t *testing.T) { m := Model{} - m.messages = append(m.messages, Message{Role: "user", Content: "step 1"}) + m.chat.Messages = append(m.chat.Messages, state.Message{Role: "user", Content: "step 1"}) for i := 1; i <= 5; i++ { - m.messages = append(m.messages, Message{Role: "system", Content: "[TOOL_CONTEXT]\ntool=read\nsuccess=true\noutput:\nchunk " + string(rune('0'+i))}) + m.chat.Messages = append(m.chat.Messages, state.Message{Role: "system", Content: "[TOOL_CONTEXT]\ntool=read\nsuccess=true\noutput:\nchunk " + string(rune('0'+i))}) } - m.messages = append(m.messages, Message{Role: "assistant", Content: "done"}) + m.chat.Messages = append(m.chat.Messages, state.Message{Role: "assistant", Content: "done"}) got := m.buildMessages() toolCtxCount := 0 @@ -144,21 +1292,20 @@ func TestBuildMessagesKeepsOnlyRecentToolContextMessages(t *testing.T) { } func TestWorkspaceCommandShowsWorkspaceRoot(t *testing.T) { - m := Model{ - client: fakeChatClient{}, - apiKeyReady: true, - workspaceRoot: `F:/Qiniu/test1`, - } + client := &fakeChatClient{} + m := newTestModel(t, client) + m.chat.APIKeyReady = true + m.chat.WorkspaceRoot = `F:/Qiniu/test1` updated, _ := m.handleCommand("/pwd") got := updated.(Model) - if len(got.messages) != 1 { - t.Fatalf("expected exactly 1 message, got %d", len(got.messages)) + if len(got.chat.Messages) != 1 { + t.Fatalf("expected exactly 1 message, got %d", len(got.chat.Messages)) } - if got.messages[0].Role != "assistant" { - t.Fatalf("expected assistant message, got %+v", got.messages[0]) + if got.chat.Messages[0].Role != "assistant" { + t.Fatalf("expected assistant message, got %+v", got.chat.Messages[0]) } - if !strings.Contains(got.messages[0].Content, `F:/Qiniu/test1`) { - t.Fatalf("expected workspace path in response, got %q", got.messages[0].Content) + if !strings.Contains(got.chat.Messages[0].Content, `F:/Qiniu/test1`) { + t.Fatalf("expected workspace path in response, got %q", got.chat.Messages[0].Content) } } diff --git a/internal/tui/core/view.go b/internal/tui/core/view.go index 60d2a206..35269b01 100644 --- a/internal/tui/core/view.go +++ b/internal/tui/core/view.go @@ -4,19 +4,20 @@ import ( "strings" "go-llm-demo/internal/tui/components" + "go-llm-demo/internal/tui/state" "github.com/charmbracelet/lipgloss" ) func (m Model) View() string { - if m.width < 20 || m.height < 6 { + if m.ui.Width < 20 || m.ui.Height < 6 { return "窗口太小" } statusHeight := 1 helpHeight := 0 - if m.mode == ModeHelp { - helpHeight = minInt(20, m.height-statusHeight-3) + if m.ui.Mode == state.ModeHelp { + helpHeight = minInt(20, m.ui.Height-statusHeight-3) } inputContent := m.renderInputArea() @@ -25,37 +26,37 @@ func (m Model) View() string { inputHeight = 4 } - contentHeight := m.height - statusHeight - inputHeight - helpHeight + contentHeight := m.ui.Height - statusHeight - inputHeight - helpHeight if contentHeight < 3 { contentHeight = 3 } statusBar := lipgloss.NewStyle(). Height(statusHeight). - Width(m.width). + Width(m.ui.Width). Render(components.StatusBar{ - Model: m.activeModel, - MemoryCnt: m.memoryStats.TotalItems, - Generating: m.generating, - Width: m.width, + Model: m.chat.ActiveModel, + MemoryCnt: m.chat.MemoryStats.TotalItems, + Generating: m.chat.Generating, + Width: m.ui.Width, }.Render()) viewportView := m.viewport viewportView.SetContent(m.renderChatContent()) chatArea := lipgloss.NewStyle(). - Width(m.width). + Width(m.ui.Width). Height(contentHeight). Render(viewportView.View()) inputArea := lipgloss.NewStyle(). - Width(m.width). + Width(m.ui.Width). Render(inputContent) - if m.mode == ModeHelp { + if m.ui.Mode == state.ModeHelp { help := lipgloss.NewStyle(). - Width(m.width). + Width(m.ui.Width). Height(helpHeight). - Render(RenderHelp(m.width)) + Render(components.RenderHelp(m.ui.Width)) return lipgloss.JoinVertical(lipgloss.Left, statusBar, chatArea, help, inputArea) } @@ -70,16 +71,10 @@ func countLines(s string) int { } func (m Model) renderInputArea() string { - helpText := "[Enter换行 F5/F8发送 PgUp/PgDn滚动]" - if !m.generating { - helpText = "[Enter换行 F5/F8发送 Ctrl+V粘贴 PgUp/PgDn滚动]" - } - - footer := lipgloss.NewStyle(). - Foreground(lipgloss.Color("#5C6370")). - Render(helpText) - - return m.textarea.View() + "\n" + footer + return components.InputBox{ + Body: m.textarea.View(), + Generating: m.chat.Generating, + }.Render() } func (m Model) renderChatContent() string { @@ -87,8 +82,8 @@ func (m Model) renderChatContent() string { } func (m Model) toComponentMessages() []components.Message { - messages := make([]components.Message, len(m.messages)) - for i, msg := range m.messages { + messages := make([]components.Message, len(m.chat.Messages)) + for i, msg := range m.chat.Messages { messages[i] = components.Message{ Role: msg.Role, Content: msg.Content, @@ -105,63 +100,3 @@ func minInt(a, b int) int { } return b } - -func RenderHelp(width int) string { - var b strings.Builder - - title := lipgloss.NewStyle(). - Foreground(lipgloss.Color("#61AFEF")). - Bold(true). - Render("NeoCode 帮助") - - b.WriteString(title) - b.WriteString("\n\n") - - commands := []struct { - cmd string - desc string - }{ - {"/help", "显示帮助"}, - {"/pwd | /workspace", "显示当前工作区目录"}, - {"/apikey ", "切换 API Key 变量名"}, - {"/provider ", "切换模型提供商"}, - {"/switch ", "切换模型"}, - {"/run ", "执行代码"}, - {"/explain ", "解释代码"}, - {"/memory", "显示记忆统计"}, - {"/clear-memory confirm", "清空长期记忆"}, - {"/clear-context", "清空会话上下文"}, - {"/exit", "退出程序"}, - } - - cmdStyle := lipgloss.NewStyle(). - Foreground(lipgloss.Color("#98C379")). - Width(22) - - descStyle := lipgloss.NewStyle(). - Foreground(lipgloss.Color("#ABB2BF")) - - dimStyle := lipgloss.NewStyle(). - Foreground(lipgloss.Color("#5C6370")) - - helpStyle := lipgloss.NewStyle(). - Foreground(lipgloss.Color("#61AFEF")) - - for _, c := range commands { - b.WriteString(cmdStyle.Render(c.cmd)) - b.WriteString(descStyle.Render(c.desc)) - b.WriteString("\n") - } - - b.WriteString("\n") - b.WriteString(helpStyle.Render("输入框支持光标、粘贴、滚动,F5/F8 发送")) - b.WriteString("\n") - b.WriteString(helpStyle.Render("聊天区支持 PgUp/PgDn 和鼠标滚轮")) - b.WriteString("\n") - b.WriteString(helpStyle.Render("取消: Ctrl+C")) - - b.WriteString("\n\n") - b.WriteString(dimStyle.Render("按 Esc 或 /help 关闭")) - - return lipgloss.NewStyle().MaxWidth(width).Render(b.String()) -} diff --git a/internal/tui/core/view_test.go b/internal/tui/core/view_test.go new file mode 100644 index 00000000..48816def --- /dev/null +++ b/internal/tui/core/view_test.go @@ -0,0 +1,74 @@ +package core + +import ( + "strings" + "testing" + "time" + + "go-llm-demo/internal/tui/state" +) + +func TestCountLines(t *testing.T) { + tests := []struct { + name string + in string + want int + }{ + {name: "empty", in: "", want: 0}, + {name: "single", in: "hello", want: 1}, + {name: "multi", in: "a\nb\nc", want: 3}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := countLines(tt.in); got != tt.want { + t.Fatalf("expected %d, got %d", tt.want, got) + } + }) + } +} + +func TestToComponentMessagesPreservesFields(t *testing.T) { + ts := time.Unix(123, 0) + m := Model{ + chat: state.ChatState{Messages: []state.Message{{ + Role: "assistant", + Content: "hello", + Timestamp: ts, + Streaming: true, + }}}, + } + + got := m.toComponentMessages() + if len(got) != 1 { + t.Fatalf("expected 1 message, got %d", len(got)) + } + if got[0].Role != "assistant" || got[0].Content != "hello" || !got[0].Timestamp.Equal(ts) || !got[0].Streaming { + t.Fatalf("unexpected converted message: %+v", got[0]) + } +} + +func TestViewShowsSmallWindowMessage(t *testing.T) { + m := Model{} + m.ui.Width = 10 + m.ui.Height = 5 + + if got := m.View(); got != "窗口太小" { + t.Fatalf("expected small window warning, got %q", got) + } +} + +func TestViewRendersHelpPanelInHelpMode(t *testing.T) { + m := NewModel(&fakeChatClient{}, "persona", 6, "config.yaml", "workspace") + m.ui.Width = 80 + m.ui.Height = 30 + m.ui.Mode = state.ModeHelp + + rendered := m.View() + if !strings.Contains(rendered, "NeoCode 帮助") { + t.Fatalf("expected help panel in view, got %q", rendered) + } + if !strings.Contains(rendered, "/help") { + t.Fatalf("expected help commands in view, got %q", rendered) + } +} diff --git a/internal/tui/infra/api_client.go b/internal/tui/services/api_client.go similarity index 99% rename from internal/tui/infra/api_client.go rename to internal/tui/services/api_client.go index 263e2a8b..7bca50b3 100644 --- a/internal/tui/infra/api_client.go +++ b/internal/tui/services/api_client.go @@ -1,4 +1,4 @@ -package infra +package services import ( "context" diff --git a/internal/tui/services/runtime_services.go b/internal/tui/services/runtime_services.go new file mode 100644 index 00000000..864b95f6 --- /dev/null +++ b/internal/tui/services/runtime_services.go @@ -0,0 +1,62 @@ +package services + +import ( + "context" + + "go-llm-demo/configs" + "go-llm-demo/internal/server/domain" + serverprovider "go-llm-demo/internal/server/infra/provider" + servertools "go-llm-demo/internal/server/infra/tools" +) + +type ToolCall = domain.ToolCall +type ToolResult = servertools.ToolResult + +const ( + TypeUserPreference = domain.TypeUserPreference + TypeProjectRule = domain.TypeProjectRule + TypeCodeFact = domain.TypeCodeFact + TypeFixRecipe = domain.TypeFixRecipe + TypeSessionMemory = domain.TypeSessionMemory +) + +var ( + ErrInvalidAPIKey = serverprovider.ErrInvalidAPIKey + ErrAPIKeyValidationSoft = serverprovider.ErrAPIKeyValidationSoft +) + +func ResolveWorkspaceRoot(workspaceFlag string) (string, error) { + return servertools.ResolveWorkspaceRoot(workspaceFlag) +} + +func SetWorkspaceRoot(root string) error { + return servertools.SetWorkspaceRoot(root) +} + +func GetWorkspaceRoot() string { + return servertools.GetWorkspaceRoot() +} + +func NormalizeToolParams(params map[string]interface{}) map[string]interface{} { + return servertools.NormalizeParams(params) +} + +func ExecuteToolCall(call ToolCall) *ToolResult { + return servertools.GlobalRegistry.Execute(call) +} + +func ValidateChatAPIKey(ctx context.Context, cfg *configs.AppConfiguration) error { + return serverprovider.ValidateChatAPIKey(ctx, cfg) +} + +func NormalizeProviderName(name string) (string, bool) { + return serverprovider.NormalizeProviderName(name) +} + +func SupportedProviders() []string { + return serverprovider.SupportedProviders() +} + +func DefaultModelForProvider(name string) string { + return serverprovider.DefaultModelForProvider(name) +} diff --git a/internal/tui/services/runtime_services_test.go b/internal/tui/services/runtime_services_test.go new file mode 100644 index 00000000..15622053 --- /dev/null +++ b/internal/tui/services/runtime_services_test.go @@ -0,0 +1,106 @@ +package services + +import ( + "context" + "strings" + "testing" + + servertools "go-llm-demo/internal/server/infra/tools" +) + +func TestResolveWorkspaceRootUsesEnvOverride(t *testing.T) { + dir := t.TempDir() + t.Setenv(servertools.WorkspaceEnvVar, dir) + + got, err := ResolveWorkspaceRoot("") + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if got != dir { + t.Fatalf("expected env workspace root %q, got %q", dir, got) + } +} + +func TestSetAndGetWorkspaceRoot(t *testing.T) { + origRoot := GetWorkspaceRoot() + t.Cleanup(func() { + _ = SetWorkspaceRoot(origRoot) + }) + + dir := t.TempDir() + if err := SetWorkspaceRoot(dir); err != nil { + t.Fatalf("expected no error, got %v", err) + } + if got := GetWorkspaceRoot(); got != dir { + t.Fatalf("expected workspace root %q, got %q", dir, got) + } +} + +func TestNormalizeToolParamsRecursivelyConvertsSnakeCase(t *testing.T) { + got := NormalizeToolParams(map[string]interface{}{ + "file_path": "README.md", + "nested_map": map[string]interface{}{ + "line_number": 12, + }, + }) + + if got["filePath"] != "README.md" { + t.Fatalf("expected filePath key, got %+v", got) + } + nested, ok := got["nestedMap"].(map[string]interface{}) + if !ok { + t.Fatalf("expected nestedMap, got %+v", got["nestedMap"]) + } + if nested["lineNumber"] != 12 { + t.Fatalf("expected nested camelCase key, got %+v", nested) + } +} + +func TestExecuteToolCallReturnsUnknownToolError(t *testing.T) { + result := ExecuteToolCall(ToolCall{Tool: "unknown-tool", Params: map[string]interface{}{}}) + if result == nil { + t.Fatal("expected tool result") + } + if result.Success { + t.Fatalf("expected failure for unknown tool, got %+v", result) + } + if result.ToolName != "unknown-tool" { + t.Fatalf("expected tool name to round-trip, got %q", result.ToolName) + } +} + +func TestValidateChatAPIKeyRejectsNilConfig(t *testing.T) { + err := ValidateChatAPIKey(context.Background(), nil) + if err == nil || !strings.Contains(strings.ToLower(err.Error()), "config") { + t.Fatalf("expected nil-config error, got %v", err) + } +} + +func TestNormalizeProviderNameSupportedProvidersAndDefaultModel(t *testing.T) { + name, ok := NormalizeProviderName("openai") + if !ok || name != "openai" { + t.Fatalf("expected normalized openai provider, got %q ok=%v", name, ok) + } + if _, ok := NormalizeProviderName("unknown-provider"); ok { + t.Fatal("expected unknown provider to be rejected") + } + + providers := SupportedProviders() + if len(providers) == 0 { + t.Fatal("expected supported providers") + } + foundOpenAI := false + for _, provider := range providers { + if provider == "openai" { + foundOpenAI = true + break + } + } + if !foundOpenAI { + t.Fatalf("expected openai in supported providers, got %+v", providers) + } + + if model := DefaultModelForProvider("openai"); strings.TrimSpace(model) == "" { + t.Fatal("expected default model for openai") + } +} diff --git a/internal/tui/state/chat_state.go b/internal/tui/state/chat_state.go new file mode 100644 index 00000000..cdd146f9 --- /dev/null +++ b/internal/tui/state/chat_state.go @@ -0,0 +1,28 @@ +package state + +import ( + "time" + + "go-llm-demo/internal/tui/services" +) + +type Message struct { + Role string + Content string + Timestamp time.Time + Streaming bool +} + +type ChatState struct { + Messages []Message + HistoryTurns int + Generating bool + ActiveModel string + MemoryStats services.MemoryStats + CommandHistory []string + CmdHistIndex int + WorkspaceRoot string + ToolExecuting bool + APIKeyReady bool + ConfigPath string +} diff --git a/internal/tui/state/ui_state.go b/internal/tui/state/ui_state.go new file mode 100644 index 00000000..4958a434 --- /dev/null +++ b/internal/tui/state/ui_state.go @@ -0,0 +1,18 @@ +package state + +type Mode int + +const ( + ModeChat Mode = iota + ModeCodeInput + ModeHelp + ModeMemory +) + +type UIState struct { + Width int + Height int + Mode Mode + Focused string + AutoScroll bool +}