Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
171 changes: 169 additions & 2 deletions src/daemon.rs
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,8 @@ pub struct DaemonHandshake {
pub scope_prefix: Option<String>,
pub timings: bool,
pub allow_init: bool,
#[serde(default)]
pub allow_initialize_root_routing: bool,
pub client_identity: DaemonClientIdentity,
/// Version of the tracedecay binary that opened this connection.
///
Expand All @@ -219,6 +221,7 @@ impl DaemonHandshake {
scope_prefix,
timings,
allow_init,
allow_initialize_root_routing: false,
client_identity: DaemonClientIdentity::current()?,
client_version: binary_version().to_string(),
})
Expand Down Expand Up @@ -803,16 +806,49 @@ pub async fn proxy_transport_to_daemon(
replay_line: Option<String>,
transport: &mut impl McpTransport,
) -> Result<()> {
let mut routed_handshake = handshake.clone();
if let Some(line) = replay_line {
proxy_request_line_to_daemon(socket_path, handshake, &line, transport).await?;
update_proxy_handshake_from_initialize(&mut routed_handshake, &line).await;
proxy_request_line_to_daemon(socket_path, &routed_handshake, &line, transport).await?;
}

while let Some(line) = transport.read_line().await? {
proxy_request_line_to_daemon(socket_path, handshake, &line, transport).await?;
update_proxy_handshake_from_initialize(&mut routed_handshake, &line).await;
proxy_request_line_to_daemon(socket_path, &routed_handshake, &line, transport).await?;
}
Ok(())
}

#[cfg(unix)]
async fn update_proxy_handshake_from_initialize(handshake: &mut DaemonHandshake, line: &str) {
if !handshake.allow_initialize_root_routing {
return;
}
let Ok(request) = serde_json::from_str::<JsonRpcRequest>(line.trim()) else {
return;
};
if request.method != "initialize" {
return;
}
let Some(registry) =
crate::global_db::GlobalDb::open_at(&handshake.client_identity.global_db_path).await
else {
return;
};
let Some(project_path) = crate::mcp::server::resolve_initialize_roots_project_path(
request.params.as_ref(),
Some(&registry),
)
.await
else {
return;
};
if handshake.project_path.as_deref() != Some(project_path.as_path()) {
handshake.scope_prefix = None;
}
handshake.project_path = Some(project_path);
}

#[cfg(unix)]
async fn proxy_request_line_to_daemon(
socket_path: &Path,
Expand Down Expand Up @@ -1946,6 +1982,7 @@ mod tests {
scope_prefix: None,
timings: false,
allow_init: false,
allow_initialize_root_routing: false,
client_identity: test_client_identity(),
client_version: super::binary_version().to_string(),
}
Expand Down Expand Up @@ -2238,6 +2275,136 @@ mod tests {
daemon.await.expect("fake daemon task");
}

#[cfg(unix)]
#[tokio::test]
async fn proxy_transport_carries_initialize_root_into_followup_handshake() {
let dir = TempDir::new().expect("temp dir");
let temp_root = dir.path().canonicalize().expect("canonical temp dir");
let active_root = temp_root.join("active");
let target_root = temp_root.join("target");
std::fs::create_dir_all(active_root.join("src")).expect("active src");
std::fs::create_dir_all(target_root.join("src")).expect("target src");
let active = active_root.canonicalize().expect("active root");
let target = target_root.canonicalize().expect("target root");
let socket = temp_root.join("daemon.sock");
let client_identity = test_client_identity_for(temp_root.join("profile"));
let open_options = crate::tracedecay::TraceDecayOpenOptions {
profile_root: Some(client_identity.profile_root.clone()),
global_db_path: Some(client_identity.global_db_path.clone()),
};

std::fs::write(active.join("src/active.rs"), "pub fn active_marker() {}\n")
.expect("active source");
std::fs::write(target.join("src/target.rs"), "pub fn target_marker() {}\n")
.expect("target source");

let active_cg =
crate::tracedecay::TraceDecay::init_with_options(&active, open_options.clone())
.await
.expect("active init");
active_cg.index_all().await.expect("active index");
let target_cg =
crate::tracedecay::TraceDecay::init_with_options(&target, open_options.clone())
.await
.expect("target init");
target_cg.index_all().await.expect("target index");
let registry = crate::global_db::GlobalDb::open_at(&client_identity.global_db_path)
.await
.expect("registry");
registry
.upsert_code_project("proj_active_proxy", &active, None, None, Some("main"))
.await
.expect("active project registry");
registry
.upsert_code_project("proj_target_proxy", &target, None, None, Some("main"))
.await
.expect("target project registry");

let listener = tokio::net::UnixListener::bind(&socket).expect("daemon socket");
let engine = super::DaemonEngine::default();
let accept_task = tokio::spawn(async move {
let mut tasks = Vec::new();
for _ in 0..2 {
let (stream, _addr) = listener.accept().await.expect("accept daemon client");
let engine = engine.clone();
tasks.push(tokio::spawn(async move {
super::serve_socket_client(stream, engine)
.await
.expect("serve proxied client");
}));
}
for task in tasks {
task.await.expect("client task");
}
});

let (mut transport, sender, mut receiver) = crate::mcp::transport::ChannelTransport::new();
sender
.send(
serde_json::to_string(&json!({
"jsonrpc": "2.0",
"id": 1,
"method": "initialize",
"params": {
"clientInfo": {"name": "codex", "version": "test"},
"roots": [{"uri": format!("file://{}", target.display()), "name": "target"}]
}
}))
.expect("initialize json"),
)
.expect("send initialize");
sender
.send(
serde_json::to_string(&json!({
"jsonrpc": "2.0",
"id": 2,
"method": "tools/call",
"params": {
"name": "tracedecay_files",
"arguments": {"format": "flat"}
}
}))
.expect("tools/call json"),
)
.expect("send tools/call");
drop(sender);

let handshake = DaemonHandshake {
project_path: Some(active.clone()),
allow_initialize_root_routing: true,
client_identity,
..test_handshake_defaults()
};
super::proxy_transport_to_daemon(&socket, &handshake, None, &mut transport)
.await
.expect("proxy transport");

let mut responses = Vec::new();
while let Ok(Some(line)) =
tokio::time::timeout(std::time::Duration::from_millis(100), receiver.recv()).await
{
responses.push(line);
}
let files_response = responses
.iter()
.map(|line| serde_json::from_str::<Value>(line.trim()).expect("response json"))
.find(|response| response["id"] == json!(2))
.expect("files response");
let text = files_response["result"]["content"][0]["text"]
.as_str()
.expect("files text");
assert!(
text.contains("src/target.rs"),
"daemon proxy should route follow-up tools/call to initialize root, got {text}"
);
assert!(
!text.contains("src/active.rs"),
"daemon proxy should not keep using the original handshake project: {text}"
);

accept_task.await.expect("daemon accept task");
}

#[cfg(unix)]
#[test]
fn scheduler_task_start_log_uses_task_key_and_project() {
Expand Down
Loading
Loading