Skip to content
Merged
304 changes: 78 additions & 226 deletions cmd/tui/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,269 +3,122 @@ package main
import (
"bufio"
"context"
"errors"
"flag"
"fmt"
"io"
"os"
"strings"

"go-llm-demo/configs"
"go-llm-demo/internal/server/infra/provider"
"go-llm-demo/internal/server/infra/tools"
"go-llm-demo/internal/tui/core"
"go-llm-demo/internal/tui/infra"
"go-llm-demo/internal/tui/bootstrap"

tea "github.com/charmbracelet/bubbletea"
)

func main() {
workspaceFlag := flag.String("workspace", "", "指定工作区根目录")
flag.Parse()
const defaultConfigPath = "config.yaml"

setUTF8Mode()
var buildRunDeps = defaultRunDeps

workspaceRoot, err := tools.ResolveWorkspaceRoot(*workspaceFlag)
if err != nil {
fmt.Fprintf(os.Stderr, "解析工作区失败: %v\n", err)
os.Exit(1)
}
if err := tools.SetWorkspaceRoot(workspaceRoot); err != nil {
fmt.Fprintf(os.Stderr, "设置工作区失败: %v\n", err)
os.Exit(1)
}
type programRunner interface {
Run() (tea.Model, error)
}

scanner := bufio.NewScanner(os.Stdin)
ready, err := ensureAPIKeyInteractive(context.Background(), scanner, "config.yaml")
if err != nil {
fmt.Fprintf(os.Stderr, "初始化配置失败: %v\n", err)
os.Exit(1)
}
if !ready {
fmt.Println("已退出 NeoCode")
return
}
type runDeps struct {
stdin io.Reader
stdout io.Writer
stderr io.Writer
setUTF8Mode func()
prepareWorkspace func(string) (string, error)
ensureAPIKeyInteractive func(context.Context, *bufio.Scanner, string) (bool, error)
loadAppConfig func(string) error
loadPersonaPrompt func(string) (string, string, error)
newProgram func(string, int, string, string) (programRunner, error)
}

if err := configs.LoadAppConfig("config.yaml"); err != nil {
fmt.Fprintf(os.Stderr, "加载配置失败: %v\n", err)
os.Exit(1)
func defaultRunDeps(stdin io.Reader, stdout, stderr io.Writer) runDeps {
return runDeps{
stdin: stdin,
stdout: stdout,
stderr: stderr,
setUTF8Mode: setUTF8Mode,
prepareWorkspace: bootstrap.PrepareWorkspace,
ensureAPIKeyInteractive: bootstrap.EnsureAPIKeyInteractive,
loadAppConfig: configs.LoadAppConfig,
loadPersonaPrompt: configs.LoadPersonaPrompt,
newProgram: func(persona string, historyTurns int, configPath, workspaceRoot string) (programRunner, error) {
return bootstrap.NewProgram(persona, historyTurns, configPath, workspaceRoot)
},
}
}

persona, personaPath, err := configs.LoadPersonaPrompt(configs.GlobalAppConfig.Persona.FilePath)
if err != nil {
fmt.Fprintf(os.Stderr, "警告: 人设加载失败: %v\n", err)
} else if personaPath != "" && strings.TrimSpace(configs.GlobalAppConfig.Persona.FilePath) != personaPath {
fmt.Fprintf(os.Stderr, "提示: 人设已从 %s 回退加载\n", personaPath)
}
historyTurns := configs.GlobalAppConfig.History.ShortTermTurns

client, err := infra.NewLocalChatClient()
func main() {
workspaceFlag, err := parseWorkspaceFlag(os.Args[1:], os.Stderr)
if err != nil {
fmt.Fprintf(os.Stderr, "初始化失败: %v\n", err)
os.Exit(1)
}

model := core.NewModel(client, persona, historyTurns, "config.yaml", workspaceRoot)
p := tea.NewProgram(model,
tea.WithAltScreen(),
tea.WithMouseCellMotion(),
)
if _, err := p.Run(); err != nil {
fmt.Fprintf(os.Stderr, "运行失败: %v\n", err)
if err := run(workspaceFlag, os.Stdin, os.Stdout, os.Stderr); err != nil {
fmt.Fprintln(os.Stderr, err)
os.Exit(1)
}
}

func ensureAPIKeyInteractive(ctx context.Context, scanner *bufio.Scanner, configPath string) (bool, error) {
cfg, created, err := configs.EnsureConfigFile(configPath)
if err != nil {
return false, err
}
if created {
fmt.Printf("已创建 %s\n", configPath)
}

for {
if cfg.RuntimeAPIKey() == "" {
envName := cfg.APIKeyEnvVarName()
fmt.Printf("未检测到环境变量 %s。可使用 /apikey <env_name>、/provider <name>、/switch <model> 切换配置,或先设置该环境变量后再 /retry。\n", envName)
fmt.Printf("Windows 示例: setx %s \"your-api-key\"\n", envName)
result, handleErr := handleSetupDecision(scanner, cfg, false, configPath)
if handleErr != nil {
return false, handleErr
}
if result == setupExit {
return false, nil
}
continue
}
func parseWorkspaceFlag(args []string, stderr io.Writer) (string, error) {
fs := flag.NewFlagSet("tui", flag.ContinueOnError)
fs.SetOutput(stderr)

if err := provider.ValidateChatAPIKey(ctx, cfg); err == nil {
if saveErr := configs.WriteAppConfig(configPath, cfg); saveErr != nil {
return false, saveErr
}
configs.GlobalAppConfig = cfg
fmt.Println("API key 验证通过。")
return true, nil
} else if errors.Is(err, provider.ErrInvalidAPIKey) {
fmt.Printf("环境变量 %s 中的 API key 无效: %v\n", cfg.APIKeyEnvVarName(), err)
result, handleErr := handleSetupDecision(scanner, cfg, false, configPath)
if handleErr != nil {
return false, handleErr
}
if result == setupExit {
return false, nil
}
continue
} else if errors.Is(err, provider.ErrAPIKeyValidationSoft) {
fmt.Printf("无法确认环境变量 %s 中的 API key 有效性: %v\n", cfg.APIKeyEnvVarName(), err)
result, handleErr := handleSetupDecision(scanner, cfg, true, configPath)
if handleErr != nil {
return false, handleErr
}
if result == setupExit {
return false, nil
}
if result == setupContinue {
configs.GlobalAppConfig = cfg
return true, nil
}
continue
} else {
fmt.Printf("模型验证失败: %v\n", err)
result, handleErr := handleSetupDecision(scanner, cfg, false, configPath)
if handleErr != nil {
return false, handleErr
}
if result == setupExit {
return false, nil
}
if result == setupContinue {
configs.GlobalAppConfig = cfg
return true, nil
}
}
workspaceFlag := fs.String("workspace", "", "指定工作区根目录")
if err := fs.Parse(args); err != nil {
return "", err
}
return *workspaceFlag, nil
}

type setupDecision int

const (
setupRetry setupDecision = iota
setupContinue
setupExit
)
func run(workspaceFlag string, stdin io.Reader, stdout, stderr io.Writer) error {
return runWithDeps(workspaceFlag, buildRunDeps(stdin, stdout, stderr))
}

func handleSetupDecision(scanner *bufio.Scanner, cfg *configs.AppConfiguration, allowContinue bool, configPath string) (setupDecision, error) {
for {
prompt := "选择 /retry, /apikey <env_name>, /provider <name>, /switch <model>, 或 /exit > "
if allowContinue {
prompt = "选择 /retry, /continue, /apikey <env_name>, /provider <name>, /switch <model>, 或 /exit > "
}
decision, ok, inputErr := readInteractiveLine(scanner, prompt)
if inputErr != nil {
return setupExit, inputErr
}
if !ok {
return setupExit, nil
}
func runWithDeps(workspaceFlag string, deps runDeps) error {
if deps.setUTF8Mode != nil {
deps.setUTF8Mode()
}

fields := strings.Fields(strings.TrimSpace(decision))
if len(fields) == 0 {
continue
}
workspaceRoot, err := deps.prepareWorkspace(workspaceFlag)
if err != nil {
return fmt.Errorf("解析工作区失败: %w", err)
}

switch strings.ToLower(fields[0]) {
case "/retry":
return setupRetry, nil
case "/apikey":
if len(fields) < 2 {
fmt.Println("用法: /apikey <env_name>")
continue
}
applyAPIKeyEnvName(cfg, fields[1])
fmt.Printf("已切换 API Key 环境变量名为: %s\n", cfg.APIKeyEnvVarName())
return setupRetry, nil
case "/continue":
if !allowContinue {
fmt.Println("/continue 仅在网络或服务问题导致无法确认时可用。")
continue
}
if saveErr := configs.WriteAppConfig(configPath, cfg); saveErr != nil {
return setupExit, saveErr
}
fmt.Println("继续启动,使用当前 API key 和模型。")
return setupContinue, nil
case "/provider":
if len(fields) < 2 {
fmt.Println("用法: /provider <name>")
printSupportedProviders()
continue
}
providerName, ok := provider.NormalizeProviderName(fields[1])
if !ok {
fmt.Printf("不支持的提供商 %q\n", fields[1])
printSupportedProviders()
continue
}
cfg.AI.Provider = providerName
cfg.AI.Model = provider.DefaultModelForProvider(providerName)
fmt.Printf("已切换到提供商: %s\n", providerName)
fmt.Printf("当前模型已重置为默认值: %s\n", cfg.AI.Model)
return setupRetry, nil
case "/switch":
if len(fields) < 2 {
fmt.Println("用法: /switch <model>")
continue
}
target := strings.Join(fields[1:], " ")
cfg.AI.Model = target
fmt.Printf("已切换到模型: %s\n", target)
return setupRetry, nil
case "/exit":
return setupExit, nil
default:
if allowContinue {
fmt.Println("请输入 /retry, /continue, /apikey <env_name>, /provider <name>, /switch <model>, 或 /exit。")
} else {
fmt.Println("请输入 /retry, /apikey <env_name>, /provider <name>, /switch <model>, 或 /exit。")
}
}
scanner := bufio.NewScanner(deps.stdin)
ready, err := deps.ensureAPIKeyInteractive(context.Background(), scanner, defaultConfigPath)
if err != nil {
return fmt.Errorf("初始化配置失败: %w", err)
}
if !ready {
fmt.Fprintln(deps.stdout, "已退出 NeoCode")
return nil
}
}

func applyAPIKeyEnvName(cfg *configs.AppConfiguration, envName string) {
if cfg == nil {
return
if err := deps.loadAppConfig(defaultConfigPath); err != nil {
return fmt.Errorf("加载配置失败: %w", err)
}
cfg.AI.APIKey = strings.TrimSpace(envName)
}

func readInteractiveLine(scanner *bufio.Scanner, prompt string) (string, bool, error) {
for {
fmt.Print(prompt)
if !scanner.Scan() {
if err := scanner.Err(); err != nil {
return "", false, err
}
return "", false, nil
}
input := strings.TrimSpace(scanner.Text())
if input == "" {
fmt.Println("输入不能为空。")
continue
}
if input == "/exit" {
return "", false, nil
}
return input, true, nil
persona, personaPath, err := deps.loadPersonaPrompt(configs.GlobalAppConfig.Persona.FilePath)
if err != nil {
fmt.Fprintf(deps.stderr, "警告: 人设加载失败: %v\n", err)
} else if personaPath != "" && strings.TrimSpace(configs.GlobalAppConfig.Persona.FilePath) != personaPath {
fmt.Fprintf(deps.stderr, "提示: 人设已从 %s 回退加载\n", personaPath)
}
}

func printSupportedProviders() {
fmt.Println("可用提供商:")
for _, name := range provider.SupportedProviders() {
fmt.Printf(" %s\n", name)
historyTurns := configs.GlobalAppConfig.History.ShortTermTurns
p, err := deps.newProgram(persona, historyTurns, defaultConfigPath, workspaceRoot)
if err != nil {
return fmt.Errorf("初始化失败: %w", err)
}
if _, err := p.Run(); err != nil {
return fmt.Errorf("运行失败: %w", err)
}

return nil
}

func loadDotEnv(path string) error {
Expand Down Expand Up @@ -314,4 +167,3 @@ func loadPersonaPrompt(path string) string {

return strings.TrimSpace(string(data))
}

Loading