From b5018a5563bca6ee4276a532000cab0e32b8c266 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABlle=20Huisman?= Date: Sat, 14 Feb 2026 21:20:34 +0100 Subject: [PATCH] feat: add utoipa --- Cargo.lock | 18 ++ Cargo.toml | 2 + examples/axum/Cargo.toml | 4 +- examples/axum/src/error.rs | 5 +- examples/axum/src/user/errors.rs | 7 +- examples/axum/src/user/routes.rs | 20 +- examples/basic/src/main.rs | 16 +- packages/breach-macros/Cargo.toml | 1 + packages/breach-macros/src/http.rs | 33 +++- packages/breach-macros/src/http/attribute.rs | 57 +++++- packages/breach-macros/src/http/data.rs | 8 + packages/breach-macros/src/http/enum.rs | 132 +++++++++---- packages/breach-macros/src/http/struct.rs | 5 + packages/breach-macros/src/http/union.rs | 4 + packages/breach-macros/src/lib.rs | 2 +- packages/breach-macros/src/status.rs | 193 +++++++++++++++++++ packages/breach-macros/src/util.rs | 21 -- packages/breach/Cargo.toml | 5 +- packages/breach/src/lib.rs | 2 + packages/breach/src/utoipa.rs | 87 +++++++++ 20 files changed, 522 insertions(+), 100 deletions(-) create mode 100644 packages/breach-macros/src/status.rs delete mode 100644 packages/breach-macros/src/util.rs create mode 100644 packages/breach/src/utoipa.rs diff --git a/Cargo.lock b/Cargo.lock index 0df90ea..d4158a5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -81,6 +81,8 @@ version = "0.0.1" dependencies = [ "breach-macros", "http", + "itertools", + "utoipa", ] [[package]] @@ -112,6 +114,7 @@ dependencies = [ name = "breach-macros" version = "0.0.1" dependencies = [ + "http", "proc-macro2", "quote", "syn", @@ -135,6 +138,12 @@ version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801" +[[package]] +name = "either" +version = "1.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" + [[package]] name = "equivalent" version = "1.0.2" @@ -294,6 +303,15 @@ dependencies = [ "serde_core", ] +[[package]] +name = "itertools" +version = "0.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b192c782037fadd9cfa75548310488aabdbf3d2da73885b31bd0abd03351285" +dependencies = [ + "either", +] + [[package]] name = "itoa" version = "1.0.17" diff --git a/Cargo.toml b/Cargo.toml index c010fb3..ceb0c45 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,9 +12,11 @@ version = "0.0.1" [workspace.dependencies] breach = { path = "./packages/breach", version = "0.0.1" } breach-macros = { path = "./packages/breach-macros", version = "0.0.1" } +http = "1.4.0" serde = "1.0.228" serde_json = "1.0.149" tokio = "1.49.0" +utoipa = "5.4.0" [workspace.lints.rust] unsafe_code = "deny" diff --git a/examples/axum/Cargo.toml b/examples/axum/Cargo.toml index 2d91c79..42a9fbe 100644 --- a/examples/axum/Cargo.toml +++ b/examples/axum/Cargo.toml @@ -12,10 +12,10 @@ version.workspace = true [dependencies] anyhow = "1.0.101" axum = "0.8.8" -breach.workspace = true +breach = { workspace = true, features = ["utoipa"] } serde = { workspace = true, features = ["derive"] } tokio = { workspace = true, features = ["macros", "rt-multi-thread"] } -utoipa = { version = "5.4.0", features = ["axum_extras", "uuid"] } +utoipa = { workspace = true, features = ["axum_extras", "uuid"] } utoipa-axum = "0.2.0" utoipa-scalar = { version = "0.3.0", features = ["axum"] } uuid = { version = "1.20.0", features = ["serde", "v7"] } diff --git a/examples/axum/src/error.rs b/examples/axum/src/error.rs index d8197d2..f5535b6 100644 --- a/examples/axum/src/error.rs +++ b/examples/axum/src/error.rs @@ -1,9 +1,10 @@ use breach::HttpError; use serde::Serialize; +use utoipa::ToSchema; use uuid::Uuid; -#[derive(HttpError, Serialize)] -#[http(status = NOT_FOUND)] +#[derive(HttpError, Serialize, ToSchema)] +#[http(status = NOT_FOUND, utoipa)] #[serde(rename_all = "camelCase")] pub struct NotFoundError { id: Uuid, diff --git a/examples/axum/src/user/errors.rs b/examples/axum/src/user/errors.rs index 498a770..70484bc 100644 --- a/examples/axum/src/user/errors.rs +++ b/examples/axum/src/user/errors.rs @@ -1,9 +1,10 @@ use breach::HttpError; use serde::Serialize; +use utoipa::ToSchema; use crate::error::NotFoundError; -#[derive(Serialize)] +#[derive(Serialize, ToSchema)] #[serde( tag = "code", rename_all = "camelCase", @@ -14,6 +15,7 @@ pub enum UserValidationError { } #[derive(HttpError, Serialize)] +#[http(utoipa)] #[serde( tag = "code", rename_all = "camelCase", @@ -34,6 +36,7 @@ impl From for CreateUserError { } #[derive(HttpError, Serialize)] +#[http(utoipa)] #[serde( tag = "code", rename_all = "camelCase", @@ -47,6 +50,7 @@ pub enum GetUserByIdError { } #[derive(HttpError, Serialize)] +#[http(utoipa)] #[serde( tag = "code", rename_all = "camelCase", @@ -67,6 +71,7 @@ impl From for UpdateUserError { } #[derive(HttpError, Serialize)] +#[http(utoipa)] #[serde( tag = "code", rename_all = "camelCase", diff --git a/examples/axum/src/user/routes.rs b/examples/axum/src/user/routes.rs index 9d4f157..a4cc344 100644 --- a/examples/axum/src/user/routes.rs +++ b/examples/axum/src/user/routes.rs @@ -27,7 +27,7 @@ impl UserRoutes { } #[derive(HttpError, Serialize)] -#[http(axum)] +#[http(axum, utoipa)] pub enum CreateUserRouteError { CreateUser(CreateUserError), } @@ -47,7 +47,8 @@ impl From for CreateUserRouteError { tags = ["User"], request_body = CreateUser, responses( - (status = CREATED, description = "The user has been created.", body = User) + (status = CREATED, description = "The user has been created.", body = User), + CreateUserRouteError, ) )] async fn create_user( @@ -60,7 +61,7 @@ async fn create_user( } #[derive(HttpError, Serialize)] -#[http(axum)] +#[http(axum, utoipa)] pub enum GetUserRouteError { GetUserById(GetUserByIdError), } @@ -80,7 +81,8 @@ impl From for GetUserRouteError { tags = ["User"], params(UserPathParams), responses( - (status = OK, description = "The user.", body = User) + (status = OK, description = "The user.", body = User), + GetUserRouteError, ) )] async fn user( @@ -93,7 +95,7 @@ async fn user( } #[derive(HttpError, Serialize)] -#[http(axum)] +#[http(axum, utoipa)] pub enum UpdateUserRouteError { GetUserById(GetUserByIdError), @@ -122,7 +124,8 @@ impl From for UpdateUserRouteError { params(UserPathParams), request_body = UpdateUser, responses( - (status = OK, description = "The user has been updated.", body = User) + (status = OK, description = "The user has been updated.", body = User), + UpdateUserRouteError, ) )] async fn update_user( @@ -138,7 +141,7 @@ async fn update_user( } #[derive(HttpError, Serialize)] -#[http(axum)] +#[http(axum, utoipa)] pub enum DeleteUserRouteError { GetUserById(GetUserByIdError), @@ -166,7 +169,8 @@ impl From for DeleteUserRouteError { tags = ["User"], params(UserPathParams), responses( - (status = NO_CONTENT, description = "The user has been deleted.") + (status = NO_CONTENT, description = "The user has been deleted."), + DeleteUserRouteError, ) )] async fn delete_user( diff --git a/examples/basic/src/main.rs b/examples/basic/src/main.rs index 6a630af..53aa5c6 100644 --- a/examples/basic/src/main.rs +++ b/examples/basic/src/main.rs @@ -3,6 +3,12 @@ use breach::{HttpError, http::StatusCode}; use serde::Serialize; use serde_json::json; +#[derive(Serialize)] +#[serde(rename_all = "camelCase")] +struct ForbiddenError { + id: String, +} + #[derive(HttpError, Serialize)] #[http(status = NOT_FOUND)] #[serde(rename_all = "camelCase")] @@ -18,9 +24,7 @@ struct NotFoundError { )] enum GetUserByIdError { #[http(status = FORBIDDEN)] - Forbidden { - id: String, - }, + Forbidden(ForbiddenError), NotFound(NotFoundError), @@ -37,7 +41,7 @@ enum GetUserByIdError { enum UpdateUserError { GetUserById(GetUserByIdError), - #[http(status = UNPROCESSABLE_CONTENT)] + #[http(status = UNPROCESSABLE_ENTITY)] Validation, #[http(status = INTERNAL_SERVER_ERROR)] @@ -45,7 +49,9 @@ enum UpdateUserError { } fn main() { - let error = UpdateUserError::GetUserById(GetUserByIdError::Forbidden { id: "1".to_owned() }); + let error = UpdateUserError::GetUserById(GetUserByIdError::Forbidden(ForbiddenError { + id: "1".to_owned(), + })); assert_eq!(StatusCode::FORBIDDEN, error.status()); assert_eq!( json!({ diff --git a/packages/breach-macros/Cargo.toml b/packages/breach-macros/Cargo.toml index f2fd390..be123fa 100644 --- a/packages/breach-macros/Cargo.toml +++ b/packages/breach-macros/Cargo.toml @@ -20,6 +20,7 @@ serde = [] utoipa = [] [dependencies] +http.workspace = true proc-macro2 = "1.0.103" quote = "1.0.42" syn = "2.0.110" diff --git a/packages/breach-macros/src/http.rs b/packages/breach-macros/src/http.rs index 941a1da..5e779c3 100644 --- a/packages/breach-macros/src/http.rs +++ b/packages/breach-macros/src/http.rs @@ -42,17 +42,30 @@ impl<'a> ToTokens for HttpError<'a> { } }); - if let Some(attribute) = self.data.attribute() - && attribute.axum - { - tokens.append_all(quote! { - #[automatically_derived] - impl #impl_generics ::axum::response::IntoResponse for #ident #type_generics #where_clause { - fn into_response(self) -> ::axum::response::Response { - (self.status(), Json(self)).into_response() + if let Some(attribute) = self.data.attribute() { + if attribute.axum { + tokens.append_all(quote! { + #[automatically_derived] + impl #impl_generics ::axum::response::IntoResponse for #ident #type_generics #where_clause { + fn into_response(self) -> ::axum::response::Response { + (self.status(), Json(self)).into_response() + } } - } - }); + }); + } + + if attribute.utoipa { + let responses = self.data.responses(); + + tokens.append_all(quote! { + #[automatically_derived] + impl #impl_generics ::utoipa::IntoResponses for #ident #type_generics #where_clause { + fn responses() -> ::std::collections::BTreeMap> { + #responses + } + } + }); + } } } } diff --git a/packages/breach-macros/src/http/attribute.rs b/packages/breach-macros/src/http/attribute.rs index b7e0316..56c1f40 100644 --- a/packages/breach-macros/src/http/attribute.rs +++ b/packages/breach-macros/src/http/attribute.rs @@ -1,10 +1,13 @@ use proc_macro2::TokenStream; use quote::quote; -use syn::{Attribute, Error, Ident, Result, spanned::Spanned}; +use syn::{Attribute, Error, Result, spanned::Spanned}; + +use crate::status::Status; pub struct HttpErrorAttribute { - status: Option, + pub status: Option, pub axum: bool, + pub utoipa: bool, } impl<'a> HttpErrorAttribute { @@ -32,6 +35,7 @@ impl<'a> HttpErrorAttribute { pub fn parse(attribute: &'a Attribute) -> Result { let mut status = None; let mut axum = false; + let mut utoipa = false; attribute.parse_nested_meta(|meta| { if meta.path.is_ident("status") { @@ -41,21 +45,60 @@ impl<'a> HttpErrorAttribute { } else if meta.path.is_ident("axum") { axum = true; + Ok(()) + } else if meta.path.is_ident("utoipa") { + utoipa = true; + Ok(()) } else { Err(meta.error("unknown parameter")) } })?; - Ok(Self { status, axum }) + Ok(Self { + status, + axum, + utoipa, + }) } pub fn status(&self) -> TokenStream { if let Some(status) = &self.status { - if status == "UNPROCESSABLE_CONTENT" { - quote!(::breach::http::StatusCode::UNPROCESSABLE_ENTITY) - } else { - quote!(::breach::http::StatusCode::#status) + let status = status.as_ident(); + + quote!(::breach::http::StatusCode::#status) + } else { + quote!(compile_error!("missing `#[http(status = ..)]` attribute")) + } + } + + pub fn responses(&self, r#type: Option) -> TokenStream { + if let Some(status) = &self.status { + let code = status.code.as_str(); + + let content = r#type.map(|r#type| { + // TODO: Attempt to infer content type from schema? + quote! { + .content( + "application/json", + ::utoipa::openapi::content::ContentBuilder::new() + .schema(Some(<#r#type as ::utoipa::PartialSchema>::schema())) + .build() + ) + } + }); + + quote! { + ::std::collections::BTreeMap::from_iter([ + ( + #code.to_owned(), + ::utoipa::openapi::RefOr::T( + ::utoipa::openapi::response::ResponseBuilder::new() + #content + .build() + ), + ), + ]) } } else { quote!(compile_error!("missing `#[http(status = ..)]` attribute")) diff --git a/packages/breach-macros/src/http/data.rs b/packages/breach-macros/src/http/data.rs index 90b21e3..34fe477 100644 --- a/packages/breach-macros/src/http/data.rs +++ b/packages/breach-macros/src/http/data.rs @@ -36,4 +36,12 @@ impl<'a> HttpErrorData<'a> { HttpErrorData::Union(r#union) => r#union.status(), } } + + pub fn responses(&self) -> TokenStream { + match self { + HttpErrorData::Struct(r#struct) => r#struct.responses(), + HttpErrorData::Enum(r#enum) => r#enum.responses(), + HttpErrorData::Union(r#union) => r#union.responses(), + } + } } diff --git a/packages/breach-macros/src/http/enum.rs b/packages/breach-macros/src/http/enum.rs index 1ea4d61..42bb876 100644 --- a/packages/breach-macros/src/http/enum.rs +++ b/packages/breach-macros/src/http/enum.rs @@ -1,8 +1,8 @@ use proc_macro2::TokenStream; -use quote::quote; +use quote::{ToTokens, quote}; use syn::{DataEnum, DeriveInput, Error, Field, Fields, Ident, Result, Variant, spanned::Spanned}; -use crate::{http::attribute::HttpErrorAttribute, util::Either}; +use crate::http::attribute::HttpErrorAttribute; pub struct HttpErrorEnum<'a> { ident: &'a Ident, @@ -36,7 +36,27 @@ impl<'a> HttpErrorEnum<'a> { quote! { match &self { - #(#arms),* + #( #arms ),* + } + } + } + + pub fn responses(&self) -> TokenStream { + let mut responses = self + .variants + .iter() + .map(|variant| variant.responses()) + .collect::>(); + + if responses.is_empty() { + quote!(::std::collections::BTreeMap::default()) + } else if responses.len() == 1 { + responses.remove(0) + } else { + quote! { + ::breach::utoipa::merge_responses([ + #( #responses ),* + ].into_iter()) } } } @@ -46,78 +66,106 @@ pub struct HttpErrorEnumVariant<'a> { enum_ident: &'a Ident, ident: &'a Ident, fields: &'a Fields, - attribute_or_field: Either, + field: Option<&'a Field>, + attribute: Option, } impl<'a> HttpErrorEnumVariant<'a> { pub fn parse(enum_ident: &'a Ident, variant: &'a Variant) -> Result { - let attribute_or_field = - if let Some(attribute) = HttpErrorAttribute::parse_slice(&variant.attrs)? { - Either::Left(attribute) - } else { - Either::Right(match &variant.fields { - Fields::Named(fields) => { - return Err(Error::new(fields.span(), "named fields are not supported")); - } - Fields::Unnamed(fields) => { - if fields.unnamed.len() > 1 { - return Err(Error::new( - fields.unnamed.span(), - "multiple unnamed fields are not supported", - )); + let field = match &variant.fields { + Fields::Named(fields) => { + return Err(Error::new(fields.span(), "named fields are not supported")); + } + Fields::Unnamed(fields) => { + if fields.unnamed.len() > 1 { + return Err(Error::new( + fields.unnamed.span(), + "multiple unnamed fields are not supported", + )); + } + + fields.unnamed.first().and_then(|field| { + if field.attrs.iter().any(|attribute| { + if attribute.meta.path().is_ident("serde") { + let mut skip = false; + + _ = attribute.parse_nested_meta(|meta| { + if meta.path.is_ident("skip") { + skip = true; + } + + Ok(()) + }); + + skip + } else { + false } - let Some(field) = fields.unnamed.first() else { - return Err(Error::new( - fields.unnamed.span(), - "no unnamed fields are not supported", - )); - }; - - field - } - Fields::Unit => { - return Err(Error::new(variant.span(), "unit fields are not supported")); + }) { + None + } else { + Some(field) } }) - }; + } + Fields::Unit => None, + }; Ok(HttpErrorEnumVariant { enum_ident, ident: &variant.ident, fields: &variant.fields, - attribute_or_field, + field, + attribute: HttpErrorAttribute::parse_slice(&variant.attrs)?, }) } pub fn status(&self) -> TokenStream { + self.arm(if let Some(attribute) = &self.attribute { + attribute.status() + } else if self.field.is_some() { + quote!(value.status()) + } else { + quote!(compile_error!("missing `#[http(status = ..)]` attribute")) + }) + } + + pub fn responses(&self) -> TokenStream { + if let Some(attribute) = &self.attribute { + attribute.responses(self.field.as_ref().map(|field| field.ty.to_token_stream())) + } else if let Some(field) = &self.field { + let r#type = &field.ty; + + quote!(<#r#type as ::utoipa::IntoResponses>::responses()) + } else { + quote!(compile_error!("missing `#[http(status = ..)]` attribute")) + } + } + + fn arm(&self, tokens: TokenStream) -> TokenStream { let enum_ident = self.enum_ident; let ident = self.ident; - let status = match &self.attribute_or_field { - Either::Left(attribute) => attribute.status(), - Either::Right(_field) => quote!(value.status()), - }; - match self.fields { Fields::Named(_) => { quote! { - #enum_ident::#ident { .. } => #status + #enum_ident::#ident { .. } => #tokens } } Fields::Unnamed(fields) => { - let idents: Vec = if self.attribute_or_field.is_right() { - fields.unnamed.iter().map(|_| quote!(value)).collect() - } else { + let idents: Vec = if self.attribute.is_some() { fields.unnamed.iter().map(|_| quote!(_)).collect() + } else { + fields.unnamed.iter().map(|_| quote!(value)).collect() }; quote! { - #enum_ident::#ident( #(#idents),* ) => #status + #enum_ident::#ident( #(#idents),* ) => #tokens } } Fields::Unit => { quote! { - #enum_ident::#ident => #status + #enum_ident::#ident => #tokens } } } diff --git a/packages/breach-macros/src/http/struct.rs b/packages/breach-macros/src/http/struct.rs index 2a16ae9..310f1e2 100644 --- a/packages/breach-macros/src/http/struct.rs +++ b/packages/breach-macros/src/http/struct.rs @@ -1,4 +1,5 @@ use proc_macro2::TokenStream; +use quote::quote; use syn::{DataStruct, DeriveInput, Error, Result, spanned::Spanned}; use crate::http::attribute::HttpErrorAttribute; @@ -23,4 +24,8 @@ impl<'a> HttpErrorStruct { pub fn status(&self) -> TokenStream { self.attribute.status() } + + pub fn responses(&self) -> TokenStream { + self.attribute.responses(Some(quote!(Self))) + } } diff --git a/packages/breach-macros/src/http/union.rs b/packages/breach-macros/src/http/union.rs index e15030d..cde338c 100644 --- a/packages/breach-macros/src/http/union.rs +++ b/packages/breach-macros/src/http/union.rs @@ -17,4 +17,8 @@ impl HttpErrorUnion { pub fn status(&self) -> TokenStream { todo!() } + + pub fn responses(&self) -> TokenStream { + todo!() + } } diff --git a/packages/breach-macros/src/lib.rs b/packages/breach-macros/src/lib.rs index d1d47d4..4215a0a 100644 --- a/packages/breach-macros/src/lib.rs +++ b/packages/breach-macros/src/lib.rs @@ -3,7 +3,7 @@ //! Breach macros. mod http; -mod util; +mod status; use proc_macro::TokenStream; use quote::ToTokens; diff --git a/packages/breach-macros/src/status.rs b/packages/breach-macros/src/status.rs new file mode 100644 index 0000000..b9157aa --- /dev/null +++ b/packages/breach-macros/src/status.rs @@ -0,0 +1,193 @@ +use http::StatusCode; +use proc_macro2::Span; +use syn::{ + Error, Ident, LitInt, Result, + parse::{Parse, ParseStream}, +}; + +#[derive(Clone)] +pub struct Status { + pub code: StatusCode, + raw: RawStatusCode, +} + +impl Status { + fn as_text(&self) -> &'static str { + match self.code { + StatusCode::CONTINUE => "CONTINUE", + StatusCode::SWITCHING_PROTOCOLS => "SWITCHING_PROTOCOLS", + StatusCode::PROCESSING => "PROCESSING", + StatusCode::EARLY_HINTS => "EARLY_HINTS", + StatusCode::OK => "OK", + StatusCode::CREATED => "CREATED", + StatusCode::ACCEPTED => "ACCEPTED", + StatusCode::NON_AUTHORITATIVE_INFORMATION => "NON_AUTHORITATIVE_INFORMATION", + StatusCode::NO_CONTENT => "NO_CONTENT", + StatusCode::RESET_CONTENT => "RESET_CONTENT", + StatusCode::PARTIAL_CONTENT => "PARTIAL_CONTENT", + StatusCode::MULTI_STATUS => "MULTI_STATUS", + StatusCode::ALREADY_REPORTED => "ALREADY_REPORTED", + StatusCode::IM_USED => "IM_USED", + StatusCode::MULTIPLE_CHOICES => "MULTIPLE_CHOICES", + StatusCode::MOVED_PERMANENTLY => "MOVED_PERMANENTLY", + StatusCode::FOUND => "FOUND", + StatusCode::SEE_OTHER => "SEE_OTHER", + StatusCode::NOT_MODIFIED => "NOT_MODIFIED", + StatusCode::USE_PROXY => "USE_PROXY", + StatusCode::TEMPORARY_REDIRECT => "TEMPORARY_REDIRECT", + StatusCode::PERMANENT_REDIRECT => "PERMANENT_REDIRECT", + StatusCode::BAD_REQUEST => "BAD_REQUEST", + StatusCode::UNAUTHORIZED => "UNAUTHORIZED", + StatusCode::PAYMENT_REQUIRED => "PAYMENT_REQUIRED", + StatusCode::FORBIDDEN => "FORBIDDEN", + StatusCode::NOT_FOUND => "NOT_FOUND", + StatusCode::METHOD_NOT_ALLOWED => "METHOD_NOT_ALLOWED", + StatusCode::NOT_ACCEPTABLE => "NOT_ACCEPTABLE", + StatusCode::PROXY_AUTHENTICATION_REQUIRED => "PROXY_AUTHENTICATION_REQUIRED", + StatusCode::REQUEST_TIMEOUT => "REQUEST_TIMEOUT", + StatusCode::CONFLICT => "CONFLICT", + StatusCode::GONE => "GONE", + StatusCode::LENGTH_REQUIRED => "LENGTH_REQUIRED", + StatusCode::PRECONDITION_FAILED => "PRECONDITION_FAILED", + StatusCode::PAYLOAD_TOO_LARGE => "PAYLOAD_TOO_LARGE", + StatusCode::URI_TOO_LONG => "URI_TOO_LONG", + StatusCode::UNSUPPORTED_MEDIA_TYPE => "UNSUPPORTED_MEDIA_TYPE", + StatusCode::RANGE_NOT_SATISFIABLE => "RANGE_NOT_SATISFIABLE", + StatusCode::EXPECTATION_FAILED => "EXPECTATION_FAILED", + StatusCode::IM_A_TEAPOT => "IM_A_TEAPOT", + StatusCode::MISDIRECTED_REQUEST => "MISDIRECTED_REQUEST", + StatusCode::UNPROCESSABLE_ENTITY => "UNPROCESSABLE_ENTITY", + StatusCode::LOCKED => "LOCKED", + StatusCode::FAILED_DEPENDENCY => "FAILED_DEPENDENCY", + StatusCode::TOO_EARLY => "TOO_EARLY", + StatusCode::UPGRADE_REQUIRED => "UPGRADE_REQUIRED", + StatusCode::PRECONDITION_REQUIRED => "PRECONDITION_REQUIRED", + StatusCode::TOO_MANY_REQUESTS => "TOO_MANY_REQUESTS", + StatusCode::REQUEST_HEADER_FIELDS_TOO_LARGE => "REQUEST_HEADER_FIELDS_TOO_LARGE", + StatusCode::UNAVAILABLE_FOR_LEGAL_REASONS => "UNAVAILABLE_FOR_LEGAL_REASONS", + StatusCode::INTERNAL_SERVER_ERROR => "INTERNAL_SERVER_ERROR", + StatusCode::NOT_IMPLEMENTED => "NOT_IMPLEMENTED", + StatusCode::BAD_GATEWAY => "BAD_GATEWAY", + StatusCode::SERVICE_UNAVAILABLE => "SERVICE_UNAVAILABLE", + StatusCode::GATEWAY_TIMEOUT => "GATEWAY_TIMEOUT", + StatusCode::HTTP_VERSION_NOT_SUPPORTED => "HTTP_VERSION_NOT_SUPPORTED", + StatusCode::VARIANT_ALSO_NEGOTIATES => "VARIANT_ALSO_NEGOTIATES", + StatusCode::INSUFFICIENT_STORAGE => "INSUFFICIENT_STORAGE", + StatusCode::LOOP_DETECTED => "LOOP_DETECTED", + StatusCode::NOT_EXTENDED => "NOT_EXTENDED", + StatusCode::NETWORK_AUTHENTICATION_REQUIRED => "NETWORK_AUTHENTICATION_REQUIRED", + _ => unimplemented!("unknown HTTP status code"), + } + } + + pub fn as_ident(&self) -> Ident { + match &self.raw { + RawStatusCode::Ident(ident) => ident.clone(), + RawStatusCode::Lit(_) => Ident::new(self.as_text(), Span::call_site()), + } + } +} + +impl Parse for Status { + fn parse(input: ParseStream) -> Result { + let raw: RawStatusCode = input.parse()?; + let code: StatusCode = (&raw).try_into()?; + + Ok(Self { raw, code }) + } +} + +#[derive(Clone)] +pub enum RawStatusCode { + Ident(Ident), + Lit(LitInt), +} + +impl TryFrom<&RawStatusCode> for StatusCode { + type Error = Error; + + fn try_from(value: &RawStatusCode) -> std::result::Result { + match value { + RawStatusCode::Ident(ident) => Ok(match ident.to_string().as_str() { + "CONTINUE" => StatusCode::CONTINUE, + "SWITCHING_PROTOCOLS" => StatusCode::SWITCHING_PROTOCOLS, + "PROCESSING" => StatusCode::PROCESSING, + "EARLY_HINTS" => StatusCode::EARLY_HINTS, + "OK" => StatusCode::OK, + "CREATED" => StatusCode::CREATED, + "ACCEPTED" => StatusCode::ACCEPTED, + "NON_AUTHORITATIVE_INFORMATION" => StatusCode::NON_AUTHORITATIVE_INFORMATION, + "NO_CONTENT" => StatusCode::NO_CONTENT, + "RESET_CONTENT" => StatusCode::RESET_CONTENT, + "PARTIAL_CONTENT" => StatusCode::PARTIAL_CONTENT, + "MULTI_STATUS" => StatusCode::MULTI_STATUS, + "ALREADY_REPORTED" => StatusCode::ALREADY_REPORTED, + "IM_USED" => StatusCode::IM_USED, + "MULTIPLE_CHOICES" => StatusCode::MULTIPLE_CHOICES, + "MOVED_PERMANENTLY" => StatusCode::MOVED_PERMANENTLY, + "FOUND" => StatusCode::FOUND, + "SEE_OTHER" => StatusCode::SEE_OTHER, + "NOT_MODIFIED" => StatusCode::NOT_MODIFIED, + "USE_PROXY" => StatusCode::USE_PROXY, + "TEMPORARY_REDIRECT" => StatusCode::TEMPORARY_REDIRECT, + "PERMANENT_REDIRECT" => StatusCode::PERMANENT_REDIRECT, + "BAD_REQUEST" => StatusCode::BAD_REQUEST, + "UNAUTHORIZED" => StatusCode::UNAUTHORIZED, + "PAYMENT_REQUIRED" => StatusCode::PAYMENT_REQUIRED, + "FORBIDDEN" => StatusCode::FORBIDDEN, + "NOT_FOUND" => StatusCode::NOT_FOUND, + "METHOD_NOT_ALLOWED" => StatusCode::METHOD_NOT_ALLOWED, + "NOT_ACCEPTABLE" => StatusCode::NOT_ACCEPTABLE, + "PROXY_AUTHENTICATION_REQUIRED" => StatusCode::PROXY_AUTHENTICATION_REQUIRED, + "REQUEST_TIMEOUT" => StatusCode::REQUEST_TIMEOUT, + "CONFLICT" => StatusCode::CONFLICT, + "GONE" => StatusCode::GONE, + "LENGTH_REQUIRED" => StatusCode::LENGTH_REQUIRED, + "PRECONDITION_FAILED" => StatusCode::PRECONDITION_FAILED, + "PAYLOAD_TOO_LARGE" => StatusCode::PAYLOAD_TOO_LARGE, + "URI_TOO_LONG" => StatusCode::URI_TOO_LONG, + "UNSUPPORTED_MEDIA_TYPE" => StatusCode::UNSUPPORTED_MEDIA_TYPE, + "RANGE_NOT_SATISFIABLE" => StatusCode::RANGE_NOT_SATISFIABLE, + "EXPECTATION_FAILED" => StatusCode::EXPECTATION_FAILED, + "IM_A_TEAPOT" => StatusCode::IM_A_TEAPOT, + "MISDIRECTED_REQUEST" => StatusCode::MISDIRECTED_REQUEST, + "UNPROCESSABLE_ENTITY" => StatusCode::UNPROCESSABLE_ENTITY, + "LOCKED" => StatusCode::LOCKED, + "FAILED_DEPENDENCY" => StatusCode::FAILED_DEPENDENCY, + "TOO_EARLY" => StatusCode::TOO_EARLY, + "UPGRADE_REQUIRED" => StatusCode::UPGRADE_REQUIRED, + "PRECONDITION_REQUIRED" => StatusCode::PRECONDITION_REQUIRED, + "TOO_MANY_REQUESTS" => StatusCode::TOO_MANY_REQUESTS, + "REQUEST_HEADER_FIELDS_TOO_LARGE" => StatusCode::REQUEST_HEADER_FIELDS_TOO_LARGE, + "UNAVAILABLE_FOR_LEGAL_REASONS" => StatusCode::UNAVAILABLE_FOR_LEGAL_REASONS, + "INTERNAL_SERVER_ERROR" => StatusCode::INTERNAL_SERVER_ERROR, + "NOT_IMPLEMENTED" => StatusCode::NOT_IMPLEMENTED, + "BAD_GATEWAY" => StatusCode::BAD_GATEWAY, + "SERVICE_UNAVAILABLE" => StatusCode::SERVICE_UNAVAILABLE, + "GATEWAY_TIMEOUT" => StatusCode::GATEWAY_TIMEOUT, + "HTTP_VERSION_NOT_SUPPORTED" => StatusCode::HTTP_VERSION_NOT_SUPPORTED, + "VARIANT_ALSO_NEGOTIATES" => StatusCode::VARIANT_ALSO_NEGOTIATES, + "INSUFFICIENT_STORAGE" => StatusCode::INSUFFICIENT_STORAGE, + "LOOP_DETECTED" => StatusCode::LOOP_DETECTED, + "NOT_EXTENDED" => StatusCode::NOT_EXTENDED, + "NETWORK_AUTHENTICATION_REQUIRED" => StatusCode::NETWORK_AUTHENTICATION_REQUIRED, + _ => return Err(Error::new(ident.span(), "invalid HTTP status code")), + }), + RawStatusCode::Lit(lit) => StatusCode::from_u16(lit.base10_parse()?) + .map_err(|_| Error::new(lit.span(), "invalid HTTP status code")), + } + } +} + +impl Parse for RawStatusCode { + fn parse(input: ParseStream) -> Result { + let lookahead = input.lookahead1(); + if lookahead.peek(Ident) { + input.parse().map(RawStatusCode::Ident) + } else if lookahead.peek(LitInt) { + input.parse().map(RawStatusCode::Lit) + } else { + Err(lookahead.error()) + } + } +} diff --git a/packages/breach-macros/src/util.rs b/packages/breach-macros/src/util.rs deleted file mode 100644 index 6fb7ed6..0000000 --- a/packages/breach-macros/src/util.rs +++ /dev/null @@ -1,21 +0,0 @@ -pub enum Either { - Left(L), - Right(R), -} - -impl Either { - #[expect(unused)] - pub fn is_left(&self) -> bool { - match &self { - Either::Left(_) => true, - Either::Right(_) => false, - } - } - - pub fn is_right(&self) -> bool { - match &self { - Either::Left(_) => false, - Either::Right(_) => true, - } - } -} diff --git a/packages/breach/Cargo.toml b/packages/breach/Cargo.toml index 28d92a6..332b2c7 100644 --- a/packages/breach/Cargo.toml +++ b/packages/breach/Cargo.toml @@ -14,10 +14,13 @@ all-features = true [features] default = ["macros"] macros = ["dep:breach-macros"] +utoipa = ["dep:itertools", "dep:utoipa"] [dependencies] breach-macros = { workspace = true, optional = true } -http = "1.4.0" +http.workspace = true +itertools = { version = "0.14.0", optional = true } +utoipa = { workspace = true, optional = true } [dev-dependencies] diff --git a/packages/breach/src/lib.rs b/packages/breach/src/lib.rs index 7f8ea10..c39fc86 100644 --- a/packages/breach/src/lib.rs +++ b/packages/breach/src/lib.rs @@ -3,6 +3,8 @@ //! Breach. mod error; +#[cfg(feature = "utoipa")] +pub mod utoipa; pub use error::*; diff --git a/packages/breach/src/utoipa.rs b/packages/breach/src/utoipa.rs new file mode 100644 index 0000000..14a8ebd --- /dev/null +++ b/packages/breach/src/utoipa.rs @@ -0,0 +1,87 @@ +//! Utoipa utilities. + +use std::collections::BTreeMap; + +use http::StatusCode; +use itertools::Itertools; +use utoipa::openapi::{ + Content, ContentBuilder, OneOfBuilder, RefOr, Response, ResponseBuilder, Schema, +}; + +/// Merge multiple [`BTreeMap>`] into a single [`BTreeMap>`]. +pub fn merge_responses( + responses: impl Iterator>>, +) -> BTreeMap> { + responses + .flatten() + .chunk_by(|(code, _)| code.clone()) + .into_iter() + .map(|(code, chunk)| { + let response = merge_response( + StatusCode::from_bytes(code.as_bytes()).expect("valid status code"), + chunk.map(|(_, response)| response), + ); + + (code, RefOr::T(response)) + }) + .collect() +} + +/// Merge multiple [`RefOr`] into a single [`Response`]. +fn merge_response(code: StatusCode, responses: impl Iterator>) -> Response { + let responses = responses.filter_map(|response| match response { + RefOr::Ref(_) => None, + RefOr::T(response) => Some(response), + }); + + let mut builder = ResponseBuilder::new(); + + if let Some(canonical_reason) = code.canonical_reason() { + builder = builder.description(canonical_reason) + } + + builder = responses + .flat_map(|response| response.content) + .chunk_by(|(content_type, _)| content_type.clone()) + .into_iter() + .fold(builder, |builder, (content_type, chunk)| { + let content = merge_content(chunk.map(|(_, content)| content)); + + // TODO: Merge content. + builder.content(content_type, content) + }); + + // TODO: Merge headers, extensions, links. + + builder.build() +} + +fn merge_content(contents: impl Iterator) -> Content { + let mut builder = ContentBuilder::new(); + let mut one_of_builder = OneOfBuilder::new(); + + for content in contents { + if content.example.is_some() { + // TODO: Error that this is unsupported. + } + + builder = builder.examples_from_iter(content.examples); + + for (name, encoding) in content.encoding { + builder = builder.encoding(name, encoding); + } + + if let Some(schema) = content.schema { + one_of_builder = one_of_builder.item(schema); + } + + // TODO: Merge extensions. + } + + let one_of = one_of_builder.build(); + if !one_of.items.is_empty() { + builder = builder.schema(Some(Schema::from(one_of))); + } + + builder.build() +}