diff --git a/engine/language_server/src/baml_project/mod.rs b/engine/language_server/src/baml_project/mod.rs index c9266d2d31..f537c8c8c5 100644 --- a/engine/language_server/src/baml_project/mod.rs +++ b/engine/language_server/src/baml_project/mod.rs @@ -1255,84 +1255,80 @@ impl Project { /// Checks if all generators use the same major.minor version. /// Returns Ok(()) if they do, /// otherwise returns an Err with a descriptive message. - pub fn get_common_generator_version( - &self, - feature_flags: &[String], - client_version: Option<&str>, - ) -> anyhow::Result { - let runtime_version = env!("CARGO_PKG_VERSION"); - + pub fn get_common_generator_version(&self) -> anyhow::Result { // list generators. If we can't get the runtime, we'll error out. let generators = self .runtime()? .codegen_generators() .map(|gen| gen.version.as_str()); - // add runtime version on top since that's what we want to compare with. - let gen_version_strings = [runtime_version] - .into_iter() - .chain(client_version) - .chain(generators); - - let mut major_minor_versions = std::collections::HashMap::new(); - let mut highest_patch_by_major_minor = std::collections::HashMap::new(); - - // Track major.minor versions and find highest patch for each - for version_str in gen_version_strings { - if let Ok(version) = semver::Version::parse(version_str) { - let major_minor = format!("{}.{}", version.major, version.minor); - - // Track generators with this major.minor - major_minor_versions - .entry(major_minor.clone()) - .or_insert_with(Vec::new) - .push(version_str); - - // Track highest patch version for this major.minor - highest_patch_by_major_minor - .entry(major_minor) - .and_modify(|highest_patch: &mut u64| { - if version.patch > *highest_patch { - *highest_patch = version.patch; - } - }) - .or_insert(version.patch); - } else { - tracing::warn!("Invalid semver version in generator: {}", version_str); - // Consider how to handle invalid versions - for now, we ignore them for the check - } + common_version_up_to_patch(generators) + } +} + +/// Given a set of SemVer version strings, match them to the same `major.minor`, returning an error otherwise. Invalid semver strings are ignored for the check. +/// an error otherwise. +pub fn common_version_up_to_patch<'a>( + gen_version_strings: impl IntoIterator, +) -> anyhow::Result { + let mut major_minor_versions = std::collections::HashMap::new(); + let mut highest_patch_by_major_minor = std::collections::HashMap::new(); + + // Track major.minor versions and find highest patch for each + for version_str in gen_version_strings { + if let Ok(version) = semver::Version::parse(version_str) { + let major_minor = format!("{}.{}", version.major, version.minor); + + // Track generators with this major.minor + major_minor_versions + .entry(major_minor.clone()) + .or_insert_with(Vec::new) + .push(version_str); + + // Track highest patch version for this major.minor + highest_patch_by_major_minor + .entry(major_minor) + .and_modify(|highest_patch: &mut u64| { + if version.patch > *highest_patch { + *highest_patch = version.patch; + } + }) + .or_insert(version.patch); + } else { + tracing::warn!("Invalid semver version in generator: {}", version_str); + // Consider how to handle invalid versions - for now, we ignore them for the check } + } - // If there's more than one major.minor version, return an error - if major_minor_versions.len() > 1 { - let versions_str = major_minor_versions - .keys() - .map(|v| format!("'{v}'")) - .collect::>() - .join(", "); + // If there's more than one major.minor version, return an error + if major_minor_versions.len() > 1 { + let versions_str = major_minor_versions + .keys() + .map(|v| format!("'{v}'")) + .collect::>() + .join(", "); - let message = anyhow::anyhow!( - "Multiple generator major.minor versions detected: {versions_str}. Major and minor versions must match across all generators." - ); - Err(message) - // If there's only one major.minor version, return it with the highest patch - } else if let Some((version, _)) = major_minor_versions.iter().next() { - if let Some(highest_patch) = highest_patch_by_major_minor.get(version) { - // Parse the version string to create a proper semver::Version - if let Ok(mut v) = Version::parse(&format!("{version}.0")) { - // Update with the highest patch version - v.patch = *highest_patch; - Ok(v.to_string()) - } else { - Ok(format!("{version}.{highest_patch}")) - } + let message = anyhow::anyhow!( + "Multiple major.minor versions detected: {versions_str}. Major and minor versions must match across all generators." + ); + Err(message) + // If there's only one major.minor version, return it with the highest patch + } else if let Some((version, _)) = major_minor_versions.into_iter().next() { + if let Some(highest_patch) = highest_patch_by_major_minor.get(&version) { + // Parse the version string to create a proper semver::Version + if let Ok(mut v) = Version::parse(&format!("{version}.0")) { + // Update with the highest patch version + v.patch = *highest_patch; + Ok(v.to_string()) } else { - Ok(version.clone()) + Ok(format!("{version}.{highest_patch}")) } - // Fallback to the runtime version if no valid versions were found } else { - Err(anyhow::anyhow!("No valid generator versions found")) + Ok(version) } + // Fallback to the runtime version if no valid versions were found + } else { + Err(anyhow::anyhow!("No valid generator versions found")) } } diff --git a/engine/language_server/src/server/api/diagnostics.rs b/engine/language_server/src/server/api/diagnostics.rs index 0e44030b44..27144876e3 100644 --- a/engine/language_server/src/server/api/diagnostics.rs +++ b/engine/language_server/src/server/api/diagnostics.rs @@ -268,9 +268,7 @@ pub fn project_diagnostics( } // Check for generator version mismatch as well. - if let Err(message) = guard - .get_common_generator_version(feature_flags, session.baml_settings.get_client_version()) - { + if let Err(message) = guard.get_common_generator_version() { // Add the diagnostic to all generators if let Ok(generators) = guard.list_generators(feature_flags) { // Need to list generators again to get their spans diff --git a/engine/language_server/src/server/api/notifications/did_open.rs b/engine/language_server/src/server/api/notifications/did_open.rs index f533e7f060..4dd4dc6bc6 100644 --- a/engine/language_server/src/server/api/notifications/did_open.rs +++ b/engine/language_server/src/server/api/notifications/did_open.rs @@ -7,7 +7,10 @@ use crate::{ server::{ api::{ diagnostics::publish_session_lsp_diagnostics, - notifications::baml_src_version::BamlSrcVersionPayload, + notifications::{ + baml_src_version::BamlSrcVersionPayload, + did_save_text_document::send_generator_version, + }, traits::{NotificationHandler, SyncNotificationHandler}, ResultExt, }, @@ -78,22 +81,9 @@ impl SyncNotificationHandler for DidOpenTextDocumentHandler { .as_ref() .unwrap_or(&default_flags); let client_version = session.baml_settings.get_client_version(); - if let Ok(version) = - locked.get_common_generator_version(effective_flags, client_version) - { - notifier - .0 - .send(lsp_server::Message::Notification( - lsp_server::Notification::new( - "baml_src_generator_version".to_string(), - BamlSrcVersionPayload { - version, - root_path: locked.root_path().to_string_lossy().to_string(), - }, - ), - )) - .internal_error()?; - } + + let generator_version = locked.get_common_generator_version(); + send_generator_version(¬ifier, &locked, generator_version.as_ref().ok()); } else { tracing::error!("Failed to get or create project for path: {:?}", file_path); show_err_msg!("Failed to get or create project for path: {:?}", file_path); diff --git a/engine/language_server/src/server/api/notifications/did_save_text_document.rs b/engine/language_server/src/server/api/notifications/did_save_text_document.rs index 94c7d7298c..53e4200865 100644 --- a/engine/language_server/src/server/api/notifications/did_save_text_document.rs +++ b/engine/language_server/src/server/api/notifications/did_save_text_document.rs @@ -3,6 +3,7 @@ use std::borrow::Cow; use lsp_types::{self as types, notification as notif, request::Request, ConfigurationParams}; use crate::{ + baml_project::{common_version_up_to_patch, Project}, server::{ api::{self, notifications::baml_src_version::BamlSrcVersionPayload, ResultExt}, client::{Notifier, Requester}, @@ -57,22 +58,41 @@ impl super::SyncNotificationHandler for DidSaveTextDocument { .unwrap_or(&default_flags); let client_version = session.baml_settings.get_client_version(); - let version = locked - .get_common_generator_version(effective_flags, client_version) - .map_err(|msg| api::Error { - error: anyhow::anyhow!(msg), - code: lsp_server::ErrorCode::InternalError, - })?; - - let _ = notifier.0.send(lsp_server::Message::Notification( - lsp_server::Notification::new( - "baml_src_generator_version".to_string(), - BamlSrcVersionPayload { - version, - root_path: locked.root_path().to_string_lossy().to_string(), - }, - ), - )); + // There are 3 components to check version of: + // - generators -> if they don't resolve to the same major/minor, then we'll error for now. + // - LSP client (vscode extension) + // - LSP server (CLI binary) + // + // Upon baml_src_generator_version notification, LSP client will replace the server version + // with the given version. + // If there's no generation version to be used, the notification won't be sent. + // + // Independently, the three versions will be checked against each other. If a major.minor + // version can't be reached, then nothing is going to be generated. + + let generator_version = locked.get_common_generator_version(); + + let opt_version = generator_version.as_ref().ok(); + send_generator_version(¬ifier, &locked, opt_version); + + // Make sure to check all available versions againt each other, & generate only if there's + // no errors. + + { + let gen_version_iter = generator_version.as_ref().map(AsRef::as_ref); + + let runtime_version = env!("CARGO_PKG_VERSION"); + let version_iter = [runtime_version] + .into_iter() + .chain(client_version) + .chain(gen_version_iter); + + // check all versions against each other, ignoring any errors + // in common generator version. + _ = common_version_up_to_patch(version_iter).internal_error()?; + // Make sure to propagate the generator version check as well. + _ = generator_version.internal_error()?; + } let default_flags2 = vec!["beta".to_string()]; let effective_flags = session @@ -96,6 +116,27 @@ impl super::SyncNotificationHandler for DidSaveTextDocument { } } +/// Upon `baml_src_generator_version` notification, LSP client will replace the server version +/// with the given version. +/// If there's no generation version to be used, the notification won't be sent. +pub(crate) fn send_generator_version( + notifier: &Notifier, + project: &Project, + opt_version: Option<&impl ToOwned>, +) { + if let Some(version) = opt_version.map(ToOwned::to_owned) { + let _ = notifier.0.send(lsp_server::Message::Notification( + lsp_server::Notification::new( + "baml_src_generator_version".to_string(), + BamlSrcVersionPayload { + version, + root_path: project.root_path().to_string_lossy().to_string(), + }, + ), + )); + } +} + // Do not use this yet, it seems it has an outdated view of the project files and it generates // stale baml clients impl super::BackgroundDocumentNotificationHandler for DidSaveTextDocument {