diff --git a/src/language_client.rs b/src/language_client.rs index 224df21e..f0688544 100644 --- a/src/language_client.rs +++ b/src/language_client.rs @@ -2,17 +2,42 @@ use super::*; use crate::vim::Vim; use std::ops::DerefMut; -pub struct LanguageClient(pub Arc>); +pub struct LanguageClient { + pub state_mutex: Arc>, + pub clients_mutex: Arc>>>>, +} impl LanguageClient { // NOTE: Don't expose this as public. // MutexGuard could easily halt the program when one guard is not released immediately after use. fn lock(&self) -> Fallible> { - self.0 + self.state_mutex .lock() .map_err(|err| format_err!("Failed to lock state: {:?}", err)) } + // This fetches a mutex that is unique to the provided languageId. + // + // Here, we return a mutex instead of the mutex guard because we need to satisfy the borrow + // checker. Otherwise, there is no way to guarantee that the mutex in the hash map wouldn't be + // garbage collected as a result of another modification updating the hash map, while something was holding the lock + pub fn get_client_update_mutex(&self, languageId: LanguageId) -> Fallible>> { + let map_guard = self.clients_mutex.lock(); + if map_guard.is_err() { + return Err(format_err!( + "Failed to lock client creation for languageId {:?}: {:?}", + languageId, + map_guard.unwrap_err() + )); + } + let mut map = map_guard.unwrap(); + if !map.contains_key(&languageId) { + map.insert(languageId.clone(), Arc::new(Mutex::new(()))); + } + let mutex: Arc> = map.get(&languageId).unwrap().clone(); + Ok(mutex) + } + pub fn get(&self, f: impl FnOnce(&State) -> T) -> Fallible { Ok(f(self.lock()?.deref())) } diff --git a/src/language_server_protocol.rs b/src/language_server_protocol.rs index 7c0a88aa..050344f3 100644 --- a/src/language_server_protocol.rs +++ b/src/language_server_protocol.rs @@ -25,7 +25,10 @@ impl LanguageClient { pub fn loop_call(&self, rx: &crossbeam_channel::Receiver) -> Fallible<()> { for call in rx.iter() { - let language_client = Self(self.0.clone()); + let language_client = LanguageClient { + state_mutex: self.state_mutex.clone(), + clients_mutex: self.clients_mutex.clone(), // not sure if useful to clone this + }; thread::spawn(move || { if let Err(err) = language_client.handle_call(call) { error!("Error handling request:\n{:?}", err); @@ -2827,6 +2830,29 @@ impl LanguageClient { let cmdparams = vim_cmd_args_to_value(&cmdargs)?; let params = params.combine(&cmdparams); + // When multiple buffers get opened up concurrently, + // startServer gets called concurrently. + // This lock ensures that at most one language server is starting up at a time per + // languageId. + // We keep the mutex in scope to satisfy the borrow checker. + // This ensures that the mutex isn't garbage collected while the MutexGuard is held. + // + // - e.g. prevents starting multiple servers with `vim -p`. + // - This continues to allow distinct language servers to start up concurrently + // by languageId (e.g. java and rust) + // - Revisit this when more than one server is allowed per languageId. + // (ensure that the mutex is acquired by what starts the group of servers) + // + // TODO: May want to lock other methods that update the list of clients. + let mutex_for_language_id = self.get_client_update_mutex(Some(languageId.clone()))?; + let _raii_lock: MutexGuard<()> = mutex_for_language_id.lock().map_err(|err| { + format_err!( + "Failed to lock client creation for languageId {:?}: {:?}", + languageId, + err + ) + })?; + if self.get(|state| state.clients.contains_key(&Some(languageId.clone())))? { return Ok(json!({})); } diff --git a/src/main.rs b/src/main.rs index bdcedf0a..94d49fdb 100644 --- a/src/main.rs +++ b/src/main.rs @@ -61,7 +61,10 @@ fn main() -> Fallible<()> { let _ = args.get_matches(); let (tx, rx) = crossbeam_channel::unbounded(); - let language_client = language_client::LanguageClient(Arc::new(Mutex::new(State::new(tx)?))); + let language_client = language_client::LanguageClient { + state_mutex: Arc::new(Mutex::new(State::new(tx)?)), + clients_mutex: Arc::new(Mutex::new(HashMap::new())), + }; language_client.loop_call(&rx) }