Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
159 changes: 159 additions & 0 deletions src/aish/cli.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
"""CLI entry point for AI Shell."""

import os
import shutil
import subprocess
import sys
from enum import Enum
from typing import Optional

import anyio
Expand All @@ -14,6 +17,12 @@
from .i18n import t
from .i18n.typer import I18nTyperCommand, I18nTyperGroup
from .logging_utils import init_logging
from .openai_codex import (OPENAI_CODEX_DEFAULT_CALLBACK_PORT,
OPENAI_CODEX_DEFAULT_MODEL,
OpenAICodexAuthError,
load_openai_codex_auth,
login_openai_codex_with_browser,
login_openai_codex_with_device_code)
from .shell import AIShell
from .skills import SkillManager
from .wizard.setup_wizard import (needs_interactive_setup,
Expand All @@ -29,6 +38,16 @@
)

console = Console()
models_app = typer.Typer(help="Manage models and provider auth", cls=I18nTyperGroup)
models_auth_app = typer.Typer(help="Manage provider login state", cls=I18nTyperGroup)
models_app.add_typer(models_auth_app, name="auth")
app.add_typer(models_app, name="models")


class OpenAICodexAuthFlow(str, Enum):
    """Authentication flows accepted by ``aish models auth login --auth-flow``."""

    # Built-in flow: runs a local callback server and opens the browser.
    BROWSER = "browser"
    # Built-in device-code flow, intended for headless servers.
    DEVICE_CODE = "device-code"
    # Delegates the login to the external `codex` CLI binary.
    CODEX_CLI = "codex-cli"


def _load_raw_yaml_config(config_file: str | os.PathLike[str]) -> dict:
Expand Down Expand Up @@ -71,6 +90,10 @@ def get_effective_config(
if api_key_env:
config_data["api_key"] = api_key_env

codex_auth_path_env = os.getenv("AISH_CODEX_AUTH_PATH")
if codex_auth_path_env:
config_data["codex_auth_path"] = codex_auth_path_env

# Override with command line arguments (highest priority)
if model is not None:
config_data["model"] = model
Expand Down Expand Up @@ -222,6 +245,136 @@ def setup(
sys.exit(1)


@models_auth_app.command("login", cls=I18nTyperCommand)
def models_auth_login(
    provider: str = typer.Option(
        ...,
        "--provider",
        help="Provider id to log in (currently only openai-codex).",
    ),
    model: str = typer.Option(
        OPENAI_CODEX_DEFAULT_MODEL,
        "--model",
        help="Default OpenAI Codex model to store in config after login.",
    ),
    set_default: bool = typer.Option(
        True,
        "--set-default/--no-set-default",
        help="Update the config model to the OpenAI Codex model after login.",
    ),
    auth_flow: OpenAICodexAuthFlow = typer.Option(
        OpenAICodexAuthFlow.BROWSER,
        "--auth-flow",
        help="Auth flow to use: browser, device-code, or codex-cli.",
    ),
    force: bool = typer.Option(
        False,
        "--force/--no-force",
        help="Force a fresh OpenAI Codex login even if local auth already exists.",
    ),
    open_browser: bool = typer.Option(
        True,
        "--open-browser/--no-open-browser",
        help="Open the browser automatically for browser auth.",
    ),
    callback_port: int = typer.Option(
        OPENAI_CODEX_DEFAULT_CALLBACK_PORT,
        "--callback-port",
        min=0,
        max=65535,
        help="Local callback port for browser auth. Use 0 for an ephemeral port.",
    ),
    config_file: Optional[str] = typer.Option(
        None,
        "--config",
        "-c",
        help=t("cli.option.config"),
    ),
):
    """Log in to a model provider and persist the auth state in config.

    Only the ``openai-codex`` provider is supported. Depending on
    ``--auth-flow`` the login runs the built-in browser flow, the
    built-in device-code flow, or shells out to the external ``codex``
    CLI, then stores the resulting auth.json path (and optionally the
    default model) in the aish config file.

    Exits non-zero on an unsupported provider, a missing config file,
    a failed login, or an auth error.
    """
    # Accept variants like "openai_codex" / "OpenAI-Codex" as the same id.
    normalized_provider = provider.strip().lower().replace("_", "-")
    if normalized_provider != "openai-codex":
        console.print(
            "Only `--provider openai-codex` is supported right now.",
            style="red",
        )
        raise typer.Exit(1)

    try:
        config = Config(config_file_path=config_file)
    except FileNotFoundError as exc:
        console.print(t("cli.startup.config_file_error", error=str(exc)), style="red")
        console.print(t("cli.startup.config_file_hint"), style="dim")
        raise typer.Exit(1) from exc

    # Reuse existing local auth unless --force was given; a load failure
    # simply means a fresh login is required.
    auth_path = getattr(config.model_config, "codex_auth_path", None)
    auth_state = None
    if not force:
        try:
            auth_state = load_openai_codex_auth(auth_path)
        except OpenAICodexAuthError:
            auth_state = None

    if auth_state is None:
        try:
            if auth_flow == OpenAICodexAuthFlow.BROWSER:
                # Built-in browser flow with a local OAuth callback port.
                auth_state = login_openai_codex_with_browser(
                    auth_path=auth_path,
                    open_browser=open_browser,
                    callback_port=callback_port,
                    notify=lambda message: console.print(message, style="dim"),
                )
            elif auth_flow == OpenAICodexAuthFlow.DEVICE_CODE:
                # Built-in device-code flow for headless environments.
                auth_state = login_openai_codex_with_device_code(
                    auth_path=auth_path,
                    notify=lambda message: console.print(message, style="dim"),
                )
            else:
                # codex-cli flow: delegate the login to the external `codex` binary.
                codex_bin = shutil.which("codex")
                if not codex_bin:
                    console.print(
                        "The `codex` CLI is not installed. Install `@openai/codex` or use "
                        "`--auth-flow browser` / `--auth-flow device-code`.",
                        style="red",
                    )
                    raise typer.Exit(1)

                try:
                    subprocess.run([codex_bin, "login"], check=True)
                except subprocess.CalledProcessError as exc:
                    console.print(
                        f"`codex login` failed with exit code {exc.returncode}.",
                        style="red",
                    )
                    raise typer.Exit(exc.returncode or 1) from exc
                except KeyboardInterrupt as exc:
                    # Treat Ctrl-C during the external login as a normal failure.
                    raise typer.Exit(1) from exc

                # Presumably `codex login` wrote auth.json; reload to pick it up.
                auth_state = load_openai_codex_auth(auth_path)
        except OpenAICodexAuthError as exc:
            console.print(str(exc), style="red")
            raise typer.Exit(1) from exc

    # Persist the resolved auth.json path (and optionally the model) to config.
    # NOTE(review): `config.model_config` is read above but `config.config_model`
    # is assigned below — confirm Config exposes both attributes as intended
    # (`model_config` is also a reserved name on Pydantic v2 models).
    config_data = config.model_config.model_dump()
    config_data["codex_auth_path"] = str(auth_state.auth_path)
    if set_default:
        config_data["model"] = f"openai-codex/{model.strip() or OPENAI_CODEX_DEFAULT_MODEL}"
        # NOTE(review): clearing api_key assumes the codex auth.json supersedes
        # any stored key — confirm against the LLM request path.
        config_data["api_key"] = None
    config.config_model = ConfigModel.model_validate(config_data)
    config.save_config()

    console.print(
        f"OpenAI Codex auth ready: {auth_state.auth_path}",
        style="green",
    )
    if set_default:
        console.print(f"Default model set to {config.config_model.model}", style="green")
    else:
        console.print(
            f"OpenAI Codex model available: openai-codex/{model.strip() or OPENAI_CODEX_DEFAULT_MODEL}",
            style="dim",
        )


@app.command(help=t("cli.check_tool_support_command_help"), cls=I18nTyperCommand)
def check_tool_support(
model: str = typer.Option(
Expand Down Expand Up @@ -338,6 +491,12 @@ def info():
# Check Langfuse configuration
aish check-langfuse

# Log into OpenAI Codex account auth
aish models auth login --provider openai-codex

# Use built-in device-code auth on headless servers
aish models auth login --provider openai-codex --auth-flow device-code

# Use config file
cat > ~/.config/aish/config.yaml << EOF
model: gpt-4
Expand Down
7 changes: 7 additions & 0 deletions src/aish/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,13 @@ class ConfigModel(BaseModel):
default=None, description="Custom API base URL (e.g., for OpenRouter)"
)
api_key: Optional[str] = Field(default=None, description="API key for the service")
codex_auth_path: Optional[str] = Field(
default=None,
description=(
"Path to OpenAI Codex auth.json. Defaults to $AISH_CODEX_AUTH_PATH, "
"$CODEX_HOME/auth.json, or ~/.codex/auth.json"
),
)
temperature: float = Field(
default=0.7, ge=0.0, le=2.0, description="Temperature for LLM responses"
)
Expand Down
78 changes: 58 additions & 20 deletions src/aish/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
from aish.i18n import t
from aish.interruption import ShellState
from aish.litellm_loader import load_litellm
from aish.openai_codex import (create_openai_codex_chat_completion,
is_openai_codex_model)
from aish.prompts import PromptManager
from aish.skills import SkillManager
from aish.tools.base import ToolBase
Expand Down Expand Up @@ -445,6 +447,39 @@ def _get_stream_chunk_builder(self):
self._stream_chunk_builder_func = litellm.stream_chunk_builder
return self._stream_chunk_builder_func

def _uses_openai_codex(self) -> bool:
return is_openai_codex_model(self.model)

async def _create_completion_response(
self,
*,
messages: list[dict],
stream: bool,
tools: Optional[list[dict]] = None,
tool_choice: str = "auto",
**kwargs,
):
if self._uses_openai_codex():
return await create_openai_codex_chat_completion(
model=self.model,
messages=messages,
tools=tools,
tool_choice=tool_choice,
api_base=self.api_base,
auth_path=getattr(self.config, "codex_auth_path", None),
timeout=float(kwargs.get("timeout", 300)),
)

acompletion = self._get_acompletion()
return await acompletion(
model=self.model,
api_base=self.api_base,
api_key=self.api_key,
messages=messages,
stream=stream,
**kwargs,
)

def update_model(
self,
model: str,
Expand All @@ -469,6 +504,9 @@ def update_model(

async def _background_initialize(self):
"""后台初始化 litellm 模块,使用独立线程避免阻塞事件循环"""
if self._uses_openai_codex():
return

async with self._init_lock:
if self._initialized:
return
Expand Down Expand Up @@ -533,6 +571,9 @@ def _sync_initialize(self):

async def _ensure_initialized(self):
"""确保 litellm 已初始化,等待后台初始化完成"""
if self._uses_openai_codex():
return

# 如果已经初始化完成,直接返回
if self._initialized:
return
Expand Down Expand Up @@ -575,7 +616,7 @@ async def _ensure_initialized_with_retry(
max_retries: 最大重试次数,默认 5 次
retry_delay: 重试间隔(秒),默认 0.5 秒
"""
if self._initialized:
if self._uses_openai_codex() or self._initialized:
return

last_error = None
Expand Down Expand Up @@ -1064,7 +1105,8 @@ async def pre_execute_tool(

def _trim_messages(self, messages: list[dict]) -> list[dict]:
"""Trim messages to keep under token limit"""
# return messages
if self._uses_openai_codex():
return messages
old_size = len(messages)
trim_messages = self._get_trim_messages()
new_messages = trim_messages(messages, model=self.model)
Expand Down Expand Up @@ -1320,9 +1362,10 @@ async def process_input(
stream = bool(merged_kwargs.pop("stream"))
except Exception:
merged_kwargs.pop("stream")
actual_stream = stream and not self._uses_openai_codex()

events.emit_generation_start(
generation_type=generation_type, stream=stream
generation_type=generation_type, stream=actual_stream
)

# Get Langfuse metadata
Expand All @@ -1347,19 +1390,15 @@ async def process_input(
):
raise anyio.get_cancelled_exc_class()

acompletion = self._get_acompletion()
response = await acompletion(
model=self.model,
api_base=self.api_base,
api_key=self.api_key,
response = await self._create_completion_response(
messages=messages,
tools=tools_spec,
tool_choice="auto",
stream=stream,
stream=actual_stream,
**merged_kwargs,
)

if stream:
if actual_stream:
content_acc = ""
reasoning_acc = ""
stream_chunks: list[object] = []
Expand Down Expand Up @@ -1543,7 +1582,7 @@ async def process_input(

content = msg.get("content")
if content:
if stream:
if actual_stream:
if has_tool_calls and not content_preview_started:
events.emit_content_delta(
delta=content, accumulated=content, is_final=False
Expand Down Expand Up @@ -1573,7 +1612,7 @@ async def process_input(
tool_calls, context_manager, system_message, output
)

if not stream:
if not actual_stream:
events.emit_generation_end(
status="success", finish_reason=finish_reason
)
Expand Down Expand Up @@ -1643,25 +1682,24 @@ async def completion(
elif langfuse_metadata:
merged_kwargs["metadata"] = langfuse_metadata

events.emit_generation_start(generation_type=generation_type, stream=stream)
actual_stream = stream and not self._uses_openai_codex()
events.emit_generation_start(
generation_type=generation_type, stream=actual_stream
)

result = ""
try:
# 检查取消令牌,在开始 LLM 请求前
if self.cancellation_token and self.cancellation_token.is_cancelled():
raise anyio.get_cancelled_exc_class()

acompletion = self._get_acompletion()
response = await acompletion(
model=self.model,
api_base=self.api_base,
api_key=self.api_key,
response = await self._create_completion_response(
messages=messages,
stream=stream,
stream=actual_stream,
**merged_kwargs,
)

if stream:
if actual_stream:
reasoning_acc = ""
finish_reason = None
generation_status = "success"
Expand Down
Loading
Loading