Skip to content
This repository was archived by the owner on Jan 2, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 60 additions & 33 deletions server/bleep/src/webserver.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::{env::Feature, snippet, state, Application};

use axum::middleware;
use axum::{response::IntoResponse, routing::get, Extension, Json};
use axum::{http::StatusCode, response::IntoResponse, routing::get, Extension, Json};
use std::{borrow::Cow, net::SocketAddr};
use tower::Service;
use tower_http::services::{ServeDir, ServeFile};
Expand All @@ -26,7 +26,7 @@ pub type Router<S = Application> = axum::Router<S>;

#[allow(unused)]
pub(in crate::webserver) mod prelude {
pub(in crate::webserver) use super::{error, json, EndpointError, ErrorKind};
pub(in crate::webserver) use super::{json, EndpointError, Error, ErrorKind, Result};
pub(in crate::webserver) use crate::indexes::Indexes;
pub(in crate::webserver) use axum::{
extract::Query, http::StatusCode, response::IntoResponse, Extension,
Expand Down Expand Up @@ -127,54 +127,81 @@ where
Json(Response::from(val))
}

pub(in crate::webserver) type Result<T, E = Error> = std::result::Result<T, E>;
pub(in crate::webserver) type Error = Json<Response<'static>>;
type Result<T, E = Error> = std::result::Result<T, E>;

pub(in crate::webserver) fn error(kind: ErrorKind, message: impl Into<Cow<'static, str>>) -> Error {
Json(Response::from(EndpointError {
kind,
message: message.into(),
}))
struct Error {
status: StatusCode,
body: Json<Response<'static>>,
}

pub(in crate::webserver) fn internal_error<S: std::fmt::Display>(message: S) -> Error {
Json(Response::from(EndpointError {
kind: ErrorKind::Internal,
message: message.to_string().into(),
}))
impl Error {
fn new(kind: ErrorKind, message: impl Into<Cow<'static, str>>) -> Error {
let status = match kind {
ErrorKind::Configuration
| ErrorKind::Unknown
| ErrorKind::UpstreamService
| ErrorKind::Internal
| ErrorKind::Custom => StatusCode::INTERNAL_SERVER_ERROR,
ErrorKind::User => StatusCode::BAD_REQUEST,
ErrorKind::NotFound => StatusCode::NOT_FOUND,
};

let body = Json(Response::from(EndpointError {
kind,
message: message.into(),
}));

Error { status, body }
}

fn with_status(mut self, status_code: StatusCode) -> Self {
self.status = status_code;
self
}

fn internal<S: std::fmt::Display>(message: S) -> Self {
Error {
status: StatusCode::INTERNAL_SERVER_ERROR,
body: Json(Response::from(EndpointError {
kind: ErrorKind::Internal,
message: message.to_string().into(),
})),
}
}

fn user<S: std::fmt::Display>(message: S) -> Self {
Error {
status: StatusCode::BAD_REQUEST,
body: Json(Response::from(EndpointError {
kind: ErrorKind::User,
message: message.to_string().into(),
})),
}
}
}

impl IntoResponse for Error {
fn into_response(self) -> axum::response::Response {
(self.status, self.body).into_response()
}
}

/// The response upon encountering an error
#[derive(serde::Serialize, PartialEq, Eq, ToSchema, Debug)]
pub(in crate::webserver) struct EndpointError<'a> {
struct EndpointError<'a> {
/// The kind of this error
pub kind: ErrorKind,
kind: ErrorKind,

/// A context aware message describing the error
pub message: Cow<'a, str>,
}

impl<'a> EndpointError<'a> {
fn user(message: Cow<'a, str>) -> Self {
Self {
kind: ErrorKind::User,
message,
}
}
fn internal(message: Cow<'a, str>) -> Self {
Self {
kind: ErrorKind::Internal,
message,
}
}
message: Cow<'a, str>,
}

/// The kind of an error
#[allow(unused)]
#[derive(serde::Serialize, PartialEq, Eq, ToSchema, Debug)]
#[serde(rename_all = "snake_case")]
#[non_exhaustive]
pub(in crate::webserver) enum ErrorKind {
enum ErrorKind {
User,
Unknown,
NotFound,
Expand Down
91 changes: 33 additions & 58 deletions server/bleep/src/webserver/answer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,8 @@ use std::{

use axum::{
extract::Query,
http::StatusCode,
response::{sse::Event, IntoResponse, Sse},
Extension, Json,
Extension,
};
use futures::{Stream, StreamExt, TryStreamExt};
use thiserror::Error;
Expand All @@ -25,7 +24,7 @@ use crate::{
Application,
};

use super::ErrorKind;
use super::prelude::*;

/// Mirrored from `answer_api/lib.rs` to avoid private dependency.
pub mod api {
Expand Down Expand Up @@ -104,7 +103,7 @@ const SNIPPET_COUNT: usize = 13;
pub(super) async fn handle(
Query(params): Query<Params>,
Extension(app): Extension<Application>,
) -> Result<impl IntoResponse, (StatusCode, Json<super::Response<'static>>)> {
) -> Result<impl IntoResponse> {
// create a new analytics event for this query
let event = Arc::new(RwLock::new(QueryEvent::default()));

Expand All @@ -126,16 +125,14 @@ async fn _handle(
params: Params,
app: Application,
event: Arc<RwLock<QueryEvent>>,
) -> Result<impl IntoResponse, (StatusCode, Json<super::Response<'static>>)> {
) -> Result<impl IntoResponse> {
let query_id = uuid::Uuid::new_v4();
let mut stop_watch = StopWatch::start();

let semantic = app.semantic.clone().ok_or_else(|| {
(
StatusCode::INTERNAL_SERVER_ERROR,
super::error(ErrorKind::Configuration, "Qdrant not configured"),
)
})?;
let semantic = app
.semantic
.clone()
.ok_or_else(|| Error::new(ErrorKind::Configuration, "Qdrant not configured"))?;

let mut analytics_event = event.write().await;

Expand All @@ -147,28 +144,15 @@ async fn _handle(
.stages
.push(Stage::new("user query", &params.q).with_time(stop_watch.lap()));

let query = parser::parse_nl(&params.q).map_err(|e| {
(
StatusCode::BAD_REQUEST,
super::error(ErrorKind::User, e.to_string()),
)
})?;
let target = query.target().ok_or_else(|| {
(
StatusCode::BAD_REQUEST,
super::error(ErrorKind::User, "missing search target".to_owned()),
)
})?;
let query = parser::parse_nl(&params.q).map_err(Error::user)?;
let target = query
.target()
.ok_or_else(|| Error::user("missing search target"))?;

let all_snippets: Vec<Snippet> = semantic
.search(&query, 4 * SNIPPET_COUNT as u64) // heuristic
.await
.map_err(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
super::error(ErrorKind::Internal, e.to_string()),
)
})?
.map_err(Error::internal)?
.into_iter()
.map(|r| {
use qdrant_client::qdrant::{value::Kind, Value};
Expand Down Expand Up @@ -253,10 +237,7 @@ async fn _handle(

if snippets.is_empty() {
warn!("Semantic search returned no snippets");
return Err((
StatusCode::INTERNAL_SERVER_ERROR,
(super::internal_error("semantic search returned no snippets")),
));
return Err(Error::internal("semantic search returned no snippets"));
} else {
info!("Semantic search returned {} snippets", snippets.len());
}
Expand All @@ -272,8 +253,7 @@ async fn _handle(

let relevant_snippet_index = answer_api_client
.select_snippet(&select_prompt)
.await
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e))?
.await?
.trim()
.to_string()
.clone();
Expand All @@ -282,40 +262,36 @@ async fn _handle(

let mut relevant_snippet_index = relevant_snippet_index
.parse::<usize>()
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, super::internal_error(e)))?;
.map_err(Error::internal)?;

analytics_event.stages.push(
Stage::new("relevant snippet index", &relevant_snippet_index).with_time(stop_watch.lap()),
);

if relevant_snippet_index == 0 {
return Err((
StatusCode::INTERNAL_SERVER_ERROR,
super::internal_error("None of the snippets help answer the question"),
return Err(Error::internal(
"None of the snippets help answer the question",
));
}

relevant_snippet_index -= 1; // return to 0-indexing
let relevant_snippet = snippets.get(relevant_snippet_index).ok_or_else(|| {
(
StatusCode::INTERNAL_SERVER_ERROR,
super::internal_error("answer-api returned out-of-bounds index"),
)
})?;
let relevant_snippet = snippets
.get(relevant_snippet_index)
.ok_or_else(|| Error::internal("answer-api returned out-of-bounds index"))?;

// grow the snippet by 60 lines above and below, we have sufficient space
// to grow this snippet by 10 times its original size (15 to 150)
let processed_snippet = {
let repo_ref = &relevant_snippet
.repo_ref
.parse::<RepoRef>()
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, super::internal_error(e)))?;
.map_err(Error::internal)?;
let doc = app
.indexes
.file
.by_path(repo_ref, &relevant_snippet.relative_path)
.await
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, super::internal_error(e)))?;
.map_err(Error::internal)?;

let mut grow_size = 40;
let grown_text = loop {
Expand Down Expand Up @@ -355,14 +331,14 @@ async fn _handle(
user_id: params.user_id.clone(),
answer_path,
}))
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, super::internal_error(e)))?;
.map_err(Error::internal)?;

let explain_prompt = answer_api_client.build_explain_prompt(&processed_snippet);

let mut snippet_explanation = answer_api_client
.explain_snippet(&explain_prompt)
.await
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, super::internal_error(e)))
.map_err(Error::internal)
.map(Box::pin)?;

drop(analytics_event);
Expand Down Expand Up @@ -443,13 +419,13 @@ enum AnswerAPIError {
BadRequest(#[from] api::Error),
}

impl From<AnswerAPIError> for super::Error {
fn from(e: AnswerAPIError) -> super::Error {
impl From<AnswerAPIError> for Error {
fn from(e: AnswerAPIError) -> Error {
sentry::capture_message(
format!("answer-api failed to respond: {e}").as_str(),
sentry::Level::Error,
);
super::error(ErrorKind::UpstreamService, e.to_string())
Error::new(ErrorKind::UpstreamService, e.to_string())
}
}

Expand Down Expand Up @@ -497,8 +473,8 @@ impl<'s> AnswerAPIClient<'s> {

match stream.next().await {
Some(Ok(reqwest_eventsource::Event::Open)) => {}
Some(Err(e)) => Err(AnswerAPIError::EventSource(e))?,
_ => Err(AnswerAPIError::StreamFail)?,
Some(Err(e)) => return Err(AnswerAPIError::EventSource(e)),
_ => return Err(AnswerAPIError::StreamFail),
}

Ok(stream
Expand All @@ -513,8 +489,7 @@ impl<'s> AnswerAPIClient<'s> {
.map(|result| match result {
Ok(s) => Ok(serde_json::from_str::<api::Result>(&s)??),
Err(e) => Err(AnswerAPIError::EventSource(e)),
})
.map(|result: Result<String, AnswerAPIError>| result))
}))
}

async fn send_until_success(
Expand Down Expand Up @@ -596,13 +571,13 @@ Answer in GitHub Markdown:",
prompt
}

async fn select_snippet(&self, prompt: &str) -> super::Result<String> {
async fn select_snippet(&self, prompt: &str) -> Result<String> {
self.send_until_success(prompt, 1, 0.0).await.map_err(|e| {
sentry::capture_message(
format!("answer-api failed to respond: {e}").as_str(),
sentry::Level::Error,
);
super::error(ErrorKind::UpstreamService, e.to_string())
Error::new(ErrorKind::UpstreamService, e.to_string())
})
}

Expand Down
Loading