diff --git a/example-postgres/src/main.rs b/example-postgres/src/main.rs index 70bbe5f..4d7b72e 100644 --- a/example-postgres/src/main.rs +++ b/example-postgres/src/main.rs @@ -1,7 +1,7 @@ // #![feature(trace_macros)] use chrono::{NaiveDateTime, Utc}; use ormx::{Insert, Table, Delete}; -use sqlx::PgPool; +use sqlx::{PgConnection, PgPool}; // trace_macros!(true); @@ -18,6 +18,7 @@ async fn main() -> anyhow::Result<()> { .init()?; let db = PgPool::connect(&dotenv::var("DATABASE_URL")?).await?; + let mut conn = db.acquire().await?; log::info!("insert a new row into the database"); let mut new = InsertUser { @@ -28,7 +29,7 @@ async fn main() -> anyhow::Result<()> { disabled: None, role: Role::User, } - .insert(&mut *db.acquire().await?) + .insert(&mut conn) .await?; log::info!("update a single field"); @@ -61,15 +62,41 @@ async fn main() -> anyhow::Result<()> { log::info!("delete the user from the database"); new.delete(&db).await?; + log::info!("inserting 3 dummy users with ids: 1, 2 & 3"); + insert_dummy_user(&mut conn, 2).await?; + insert_dummy_user(&mut conn, 3).await?; + insert_dummy_user(&mut conn, 4).await?; + + log::info!("getting many users by any user id (using 'get_any' getter)"); + let users = User::get_by_any_user_id(&mut conn, &[2, 4]).await?; + dbg!(&users); + assert_eq!(users.len(), 2); + + log::info!("empty user table"); + sqlx::query!("DELETE FROM users").execute(&db).await?; Ok(()) } +async fn insert_dummy_user(conn: &mut PgConnection, id: i32) -> Result { + InsertUser { + user_id: id, + first_name: "Dummy".to_owned(), + last_name: "Dummy".to_owned(), + email: format!("dummy{}@mail.com", id), + disabled: None, + role: Role::User, + } + .insert(conn) + .await +} + #[derive(Debug, ormx::Table)] #[ormx(table = "users", id = user_id, insertable, deletable)] struct User { // map this field to the column "id" #[ormx(column = "id")] #[ormx(get_one = get_by_user_id)] + #[ormx(get_by_any)] user_id: i32, first_name: String, last_name: String, diff --git a/ormx-macros/src/attrs.rs b/ormx-macros/src/attrs.rs index fc99434..964b3e7 100644 --- a/ormx-macros/src/attrs.rs +++ b/ormx-macros/src/attrs.rs @@ -33,6 +33,9 @@ pub enum TableFieldAttr { GetOptional(Getter), // get_many [= ]? [()]? GetMany(Getter), + // get_by_any [= ]? [()]? + #[cfg(feature = "postgres")] + GetByAny(Getter), // set [= ]? Set(Option), // by_ref @@ -105,14 +108,16 @@ pub fn parse_attrs(attrs: &[Attribute]) -> Result> { macro_rules! impl_parse { // entry point ($i:ident { - $( $s:literal => $v:ident( $($t:tt)* ) ),* + $( $(#[cfg($cfg_attr: meta)])? $s:literal => $v:ident( $($t:tt)* ) ),* }) => { impl syn::parse::Parse for $i { #[allow(clippy::redundant_closure_call)] fn parse(input: syn::parse::ParseStream) -> syn::Result { let ident = input.parse::()?; match &*ident.to_string() { - $( $s => (impl_parse!($($t)*))(input).map(Self::$v), )* + $( + $(#[cfg($cfg_attr)])? + $s => (impl_parse!($($t)*))(input).map(Self::$v), )* _ => Err(input.error("unknown attribute")) } } @@ -152,6 +157,8 @@ impl_parse!(TableFieldAttr { "get_one" => GetOne(Getter), "get_optional" => GetOptional(Getter), "get_many" => GetMany(Getter), + #[cfg(feature = "postgres")] + "get_by_any" => GetByAny(Getter), "set" => Set((= Ident)?), "custom_type" => CustomType(), "default" => Default(), diff --git a/ormx-macros/src/backend/postgres/get_by_any.rs b/ormx-macros/src/backend/postgres/get_by_any.rs new file mode 100644 index 0000000..85537af --- /dev/null +++ b/ormx-macros/src/backend/postgres/get_by_any.rs @@ -0,0 +1,41 @@ +use crate::{ + backend::{common, postgres::PgBindings}, + table::Table, +}; + +use super::PgBackend; +use proc_macro2::TokenStream; +use quote::quote; + +pub(crate) fn impl_get_by_any_getter(table: &Table) -> TokenStream { + let column_list = table.select_column_list(); + let vis = &table.vis; + let mut getters = TokenStream::new(); + + for field in table.fields.iter() { + if let Some(getter) = &field.get_by_any { + let sql = format!( + "SELECT {} FROM {} WHERE {} = ANY({})", + column_list, + table.table, + field.column(), + PgBindings::default().next().unwrap() + ); + + let func = getter.ident_or(&field, &format!("get_by_any_{}", field.field)); + let arg = getter.arg_ty.clone().unwrap_or_else(|| { + let ty = &field.ty; + syn::parse2(quote!(&[#ty])).unwrap() + }); + + getters.extend(common::get_many(vis, &func, &arg, &sql)); + } + } + + let table_ident = &table.ident; + quote! { + impl #table_ident { + #getters + } + } +} diff --git a/ormx-macros/src/backend/postgres/mod.rs b/ormx-macros/src/backend/postgres/mod.rs index 2f6d556..ad443be 100644 --- a/ormx-macros/src/backend/postgres/mod.rs +++ b/ormx-macros/src/backend/postgres/mod.rs @@ -4,6 +4,11 @@ use proc_macro2::TokenStream; use crate::{backend::Backend, table::Table}; +use self::get_by_any::impl_get_by_any_getter; + +use super::common; + +mod get_by_any; mod insert; #[derive(Clone)] @@ -31,6 +36,12 @@ impl Backend for PgBackend { fn impl_insert(table: &Table) -> TokenStream { insert::impl_insert(table) } + + fn impl_getters(table: &Table) -> TokenStream { + let mut getters = common::getters::(table); + getters.extend(impl_get_by_any_getter(table)); + getters + } } #[derive(Default)] diff --git a/ormx-macros/src/lib.rs b/ormx-macros/src/lib.rs index 19852b4..b98ffd4 100644 --- a/ormx-macros/src/lib.rs +++ b/ormx-macros/src/lib.rs @@ -54,7 +54,7 @@ mod utils; /// /// # Accessors: Getters /// ormx will generate accessor functions for fields annotated with `#[ormx(get_one)]`, -/// `#[ormx(get_optional)]` and `#[ormx(get_many)]`. +/// `#[ormx(get_optional)]`, `#[ormx(get_many)]` and `#[ormx(get_by_any)]`. /// These functions can be used to query a row by the value of the annotated field. /// /// The generated function will have these signature: @@ -67,6 +67,9 @@ mod utils; /// **`#[ormx(get_many)]`**: /// `{pub} async fn get_by_{field_name}(&{field_type}) -> Result>` /// +/// **`#[ormx(get_by_any)]`**: +/// `{pub} async fn get_by_any_{field_name}(&[{field_type}]) -> Result>` +/// /// By default, the function will be named `get_by_{field_name)`, though this can be changed by /// supplying a custom name: `#[ormx(get_one = by_id)]`. /// By default, the function will take a reference to the type of the annotated field as an argument, diff --git a/ormx-macros/src/table/mod.rs b/ormx-macros/src/table/mod.rs index cc64475..161923d 100644 --- a/ormx-macros/src/table/mod.rs +++ b/ormx-macros/src/table/mod.rs @@ -33,6 +33,8 @@ pub struct TableField { pub get_one: Option, pub get_optional: Option, pub get_many: Option, + #[cfg(feature = "postgres")] + pub get_by_any: Option, pub set: Option, pub by_ref: bool, pub insert_attrs: Vec, @@ -104,16 +106,19 @@ impl TableField { impl Getter { pub fn or_fallback(&self, field: &TableField) -> (Ident, Type) { - let ident = self - .func - .clone() - .unwrap_or_else(|| Ident::new(&format!("by_{}", field.field), Span::call_site())); + let ident = self.ident_or(field, &format!("by_{}", field.field)); let arg = self.arg_ty.clone().unwrap_or_else(|| { let ty = &field.ty; syn::parse2(quote!(&#ty)).unwrap() }); (ident, arg) } + + pub fn ident_or(&self, field: &TableField, ident: &str) -> Ident { + self.func + .clone() + .unwrap_or_else(|| Ident::new(ident, Span::call_site())) + } } pub fn derive(input: DeriveInput) -> Result { diff --git a/ormx-macros/src/table/parse.rs b/ormx-macros/src/table/parse.rs index 4842b28..23b7f53 100644 --- a/ormx-macros/src/table/parse.rs +++ b/ormx-macros/src/table/parse.rs @@ -40,6 +40,9 @@ impl TryFrom<&syn::Field> for TableField { ); let mut insert_attrs = vec![]; + #[cfg(feature = "postgres")] + none!(get_by_any); + for attr in parse_attrs::(&value.attrs)? { match attr { TableFieldAttr::Column(c) => set_once(&mut column, c)?, @@ -47,6 +50,8 @@ impl TryFrom<&syn::Field> for TableField { TableFieldAttr::GetOne(g) => set_once(&mut get_one, g)?, TableFieldAttr::GetOptional(g) => set_once(&mut get_optional, g)?, TableFieldAttr::GetMany(g) => set_once(&mut get_many, g)?, + #[cfg(feature = "postgres")] + TableFieldAttr::GetByAny(g) => set_once(&mut get_by_any, g)?, TableFieldAttr::Set(s) => { let default = || Ident::new(&format!("set_{}", ident), Span::call_site()); set_once(&mut set, s.unwrap_or_else(default))? @@ -66,6 +71,8 @@ impl TryFrom<&syn::Field> for TableField { get_one, get_optional, get_many, + #[cfg(feature = "postgres")] + get_by_any, set, by_ref: by_ref.unwrap_or(false), insert_attrs,