diff --git a/packages/cw-orch-fns-derive/src/execute_fns.rs b/packages/cw-orch-fns-derive/src/execute_fns.rs index 6a181af13..ab9e1213b 100644 --- a/packages/cw-orch-fns-derive/src/execute_fns.rs +++ b/packages/cw-orch-fns-derive/src/execute_fns.rs @@ -1,15 +1,4 @@ -extern crate proc_macro; -use crate::helpers::{ - impl_into_deprecation, process_fn_name, process_sorting, to_generic_arguments, - LexiographicMatching, -}; -use convert_case::{Case, Casing}; -use proc_macro::TokenStream; -use proc_macro2::Span; -use quote::{format_ident, quote}; -use syn::{parse_quote, visit_mut::VisitMut, DeriveInput, Fields, Ident}; - -fn payable(v: &syn::Variant) -> bool { +pub fn payable(v: &syn::Variant) -> bool { for attr in &v.attrs { if attr.path.segments.len() == 1 && attr.path.segments[0].ident == "payable" { return true; @@ -17,169 +6,3 @@ fn payable(v: &syn::Variant) -> bool { } false } - -pub fn execute_fns_derive(input: DeriveInput) -> TokenStream { - let name = &input.ident; - let bname = Ident::new(&format!("{name}Fns"), name.span()); - - let generics = input.generics.clone(); - let (_impl_generics, _ty_generics, where_clause) = generics.split_for_impl().clone(); - let type_generics = to_generic_arguments(&generics); - - let is_attributes_sorted = process_sorting(&input.attrs); - - let syn::Data::Enum(syn::DataEnum { variants, .. }) = input.data else { - unimplemented!(); - }; - - let variant_fns = variants.into_iter().map( |mut variant|{ - let variant_name = variant.ident.clone(); - - // We rename the variant if it has a fn_name attribute associated with it - let mut variant_func_name = - format_ident!("{}", process_fn_name(&variant).to_case(Case::Snake)); - variant_func_name.set_span(variant_name.span()); - - let is_payable = payable(&variant); - - let variant_doc: syn::Attribute = { - let doc = format!("Automatically generated wrapper around {}::{} variant", name, variant_name); - parse_quote!( - #[doc=#doc] - ) - }; - - let (maybe_coins_attr, passed_coins) = if is_payable { - (quote!(coins: &[::cosmwasm_std::Coin]),quote!(Some(coins))) - } else { - (quote!(),quote!(None)) - }; - - match &mut variant.fields { - Fields::Unnamed(variant_fields) => { - - let mut variant_idents = variant_fields.unnamed.clone(); - // remove any attributes for use in fn arguments - variant_idents.iter_mut().for_each(|f: &mut syn::Field| f.attrs = vec![]); - - - // We need to figure out a parameter name for all fields associated to their types - // They will be numbered from 0 to n-1 - let variant_ident_content_names = variant_idents - .iter() - .enumerate() - .map(|(i, _)| Ident::new(&format!("arg{}", i), Span::call_site())); - - let variant_attr = variant_idents.clone().into_iter() - .enumerate() - .map(|(i, mut id)| { - id.ident = Some(Ident::new(&format!("arg{}", i), Span::call_site())); - id - }); - - quote!( - #variant_doc - #[allow(clippy::too_many_arguments)] - fn #variant_func_name(&self, #(#variant_attr,)* #maybe_coins_attr) -> Result<::cw_orch::core::environment::TxResponse, ::cw_orch::core::CwEnvError> { - let msg = #name::#variant_name ( - #(#variant_ident_content_names,)* - ); - >::execute(self, &msg.into(),#passed_coins) - } - ) - }, - Fields::Unit => { - - quote!( - #variant_doc - fn #variant_func_name(&self, #maybe_coins_attr) -> Result<::cw_orch::core::environment::TxResponse, ::cw_orch::core::CwEnvError> { - let msg = #name::#variant_name; - >::execute(self, &msg.into(),#passed_coins) - } - ) - } - Fields::Named(variant_fields) => { - if is_attributes_sorted{ - // sort fields on field name - LexiographicMatching::default().visit_fields_named_mut(variant_fields); - } - - // parse these fields as arguments to function - let mut variant_idents = variant_fields.named.clone(); - - // remove any attributes for use in fn arguments - variant_idents.iter_mut().for_each(|f: &mut syn::Field| f.attrs = vec![]); - - let variant_ident_content_names = variant_idents.iter().map(|f|f.ident.clone().unwrap()); - - let variant_attr = variant_idents.iter(); - quote!( - #variant_doc - #[allow(clippy::too_many_arguments)] - fn #variant_func_name(&self, #(#variant_attr,)* #maybe_coins_attr) -> Result<::cw_orch::core::environment::TxResponse, ::cw_orch::core::CwEnvError> { - let msg = #name::#variant_name { - #(#variant_ident_content_names,)* - }; - >::execute(self, &msg.into(),#passed_coins) - } - ) - } - } - }); - let necessary_trait_where = quote!(#name<#type_generics>: Into); - let combined_trait_where_clause = where_clause - .map(|w| { - quote!( - #w #necessary_trait_where - ) - }) - .unwrap_or(quote!( - where - #necessary_trait_where - )); - - let impl_into_depr = impl_into_deprecation(&input.attrs); - let derived_trait = quote!( - #[cfg(not(target_arch = "wasm32"))] - #impl_into_depr - /// Automatically derived trait that allows you to call the variants of the message directly without the need to construct the struct yourself. - pub trait #bname: ::cw_orch::core::contract::interface_traits::CwOrchExecute #combined_trait_where_clause { - #(#variant_fns)* - } - - #[cfg(target_arch = "wasm32")] - /// Automatically derived trait that allows you to call the variants of the message directly without the need to construct the struct yourself. - pub trait #bname{ - - } - ); - - // We need to merge the where clauses (rust doesn't support 2 wheres) - // If there is no where clause, we simply add the necessary where - let necessary_where = quote!(SupportedContract: ::cw_orch::core::contract::interface_traits::CwOrchExecute, #necessary_trait_where); - let combined_where_clause = where_clause - .map(|w| { - quote!( - #w #necessary_where - ) - }) - .unwrap_or(quote!( - where - #necessary_where - )); - - let derived_trait_impl = quote!( - #[automatically_derived] - impl #bname for SupportedContract - #combined_where_clause {} - ); - - let expand = quote!( - #derived_trait - - #[cfg(not(target_arch = "wasm32"))] - #derived_trait_impl - ); - - expand.into() -} diff --git a/packages/cw-orch-fns-derive/src/fns_derive.rs b/packages/cw-orch-fns-derive/src/fns_derive.rs new file mode 100644 index 000000000..4d4bf4999 --- /dev/null +++ b/packages/cw-orch-fns-derive/src/fns_derive.rs @@ -0,0 +1,215 @@ +extern crate proc_macro; +use crate::{ + execute_fns::payable, + helpers::{ + impl_into_deprecation, process_fn_name, process_sorting, LexiographicMatching, MsgType, + }, + query_fns::parse_query_type, +}; +use convert_case::{Case, Casing}; +use proc_macro::TokenStream; +use proc_macro2::Span; +use quote::{format_ident, quote}; +use syn::{parse_quote, visit_mut::VisitMut, Fields, Generics, Ident, ItemEnum, WhereClause}; + +pub fn fns_derive(msg_type: MsgType, input: ItemEnum) -> TokenStream { + let name = &input.ident; + + let (trait_name, func_name, trait_msg_type, generic_msg_type, chain_trait) = match msg_type { + MsgType::Execute => ( + quote!(CwOrchExecute), + quote!(execute), + quote!(ExecuteMsg), + quote!(CwOrchExecuteMsgType), + quote!(::cw_orch::core::environment::TxHandler), + ), + MsgType::Query => ( + quote!(CwOrchQuery), + quote!(query), + quote!(QueryMsg), + quote!(CwOrchQueryMsgType), + quote!( + ::cw_orch::core::environment::QueryHandler + + ::cw_orch::core::environment::ChainState + ), + ), + }; + + let variant_fns = input.variants.into_iter().map( |mut variant|{ + let variant_name = variant.ident.clone(); + + // We rename the variant if it has a fn_name attribute associated with it + let mut variant_func_name = + format_ident!("{}", process_fn_name(&variant).to_case(Case::Snake)); + variant_func_name.set_span(variant_name.span()); + + + let variant_doc: syn::Attribute = { + let doc = format!("Automatically generated wrapper around {}::{} variant", name, variant_name); + parse_quote!( + #[doc=#doc] + ) + }; + + // TODO + // Execute Specific + let (maybe_coins_attr,passed_coins) = match msg_type{ + MsgType::Execute => { + let is_payable = payable(&variant); + if is_payable { + (quote!(coins: &[::cosmwasm_std::Coin]),quote!(Some(coins))) + } else { + (quote!(),quote!(None)) + } + } + MsgType::Query => { + (quote!(), quote!()) + } + }; + + + let response = match msg_type{ + MsgType::Execute => quote!(::cw_orch::core::environment::TxResponse), + MsgType::Query => parse_query_type(&variant) + }; + + match &mut variant.fields { + Fields::Unnamed(variant_fields) => { + let mut variant_idents = variant_fields.unnamed.clone(); + + // remove any attributes for use in fn arguments + variant_idents.iter_mut().for_each(|f| f.attrs = vec![]); + + // We need to figure out a parameter name for all fields associated to their types + // They will be numbered from 0 to n-1 + let variant_ident_content_names = variant_idents + .iter() + .enumerate() + .map(|(i, _)| Ident::new(&format!("arg{}", i), Span::call_site())); + + let variant_attr = variant_idents.clone().into_iter() + .enumerate() + .map(|(i, mut id)| { + id.ident = Some(Ident::new(&format!("arg{}", i), Span::call_site())); + id + }); + + quote!( + #variant_doc + #[allow(clippy::too_many_arguments)] + fn #variant_func_name(&self, #(#variant_attr,)* #maybe_coins_attr) -> Result<#response, ::cw_orch::core::CwEnvError> { + let msg = #name::#variant_name ( + #(#variant_ident_content_names,)* + ); + >::#func_name(self, &msg.into(),#passed_coins) + } + ) + }, + Fields::Unit => { + + quote!( + #variant_doc + fn #variant_func_name(&self, #maybe_coins_attr) -> Result<#response, ::cw_orch::core::CwEnvError> { + let msg = #name::#variant_name; + >::#func_name(self, &msg.into(),#passed_coins) + } + ) + } + Fields::Named(variant_fields) => { + let is_attributes_sorted = process_sorting(&input.attrs); + if is_attributes_sorted{ + // sort fields on field name + LexiographicMatching::default().visit_fields_named_mut(variant_fields); + } + + // remove attributes from fields + variant_fields.named.iter_mut().for_each(|f| f.attrs = vec![]); + + // Parse these fields as arguments to function + let variant_fields = variant_fields.named.clone(); + let variant_idents = variant_fields.iter().map(|f|f.ident.clone().unwrap()); + + let variant_attr = variant_fields.iter(); + quote!( + #variant_doc + #[allow(clippy::too_many_arguments)] + fn #variant_func_name(&self, #(#variant_attr,)* #maybe_coins_attr) -> Result<#response, ::cw_orch::core::CwEnvError> { + let msg = #name::#variant_name { + #(#variant_idents,)* + }; + >::#func_name(self, &msg.into(),#passed_coins) + } + ) + } + } + }); + + // Generics for the Trait + let mut cw_orch_generics: Generics = parse_quote!(); + cw_orch_generics + .params + .extend(input.generics.params.clone()); + + // Where clause for the Trait + let mut combined_trait_where_clause = { + let (_, ty_generics, where_clause) = input.generics.split_for_impl().clone(); + + // Adding a where clause for the derive message type to implement into the contract message type + let mut clause: WhereClause = + parse_quote!(where #name #ty_generics: Into<#generic_msg_type>); + + // Adding eventual where clauses that were present on the original QueryMsg + if let Some(w) = where_clause { + clause.predicates.extend(w.predicates.clone()); + } + clause + }; + + let bname = Ident::new(&format!("{name}Fns"), name.span()); + let trait_condition = quote!(::cw_orch::core::contract::interface_traits::#trait_name); + + let impl_into_depr = impl_into_deprecation(&input.attrs); + let derived_trait = quote!( + #[cfg(not(target_arch = "wasm32"))] + #impl_into_depr + /// Automatically derived trait that allows you to call the variants of the message directly without the need to construct the struct yourself. + pub trait #bname #cw_orch_generics : #trait_condition #combined_trait_where_clause { + #(#variant_fns)* + } + + #[cfg(target_arch = "wasm32")] + /// Automatically derived trait that allows you to call the variants of the message directly without the need to construct the struct yourself. + pub trait #bname{ + + } + ); + + // Generating the generics for the blanket implementation + let mut supported_contract_generics = cw_orch_generics.clone(); + supported_contract_generics + .params + .push(parse_quote!(SupportedContract)); + + // Generating the where clause for the blanket implementation + combined_trait_where_clause + .predicates + .push(parse_quote!(SupportedContract: #trait_condition)); + + let (support_contract_impl, _, _) = supported_contract_generics.split_for_impl(); + let (_, cw_orch_generics, _) = cw_orch_generics.split_for_impl(); + + let derived_trait_blanket_impl = quote!( + #[automatically_derived] + impl #support_contract_impl #bname #cw_orch_generics for SupportedContract + #combined_trait_where_clause {} + ); + + let expand = quote!( + #derived_trait + + #[cfg(not(target_arch = "wasm32"))] + #derived_trait_blanket_impl + ); + + expand.into() +} diff --git a/packages/cw-orch-fns-derive/src/helpers.rs b/packages/cw-orch-fns-derive/src/helpers.rs index 20d4a5641..931dc3895 100644 --- a/packages/cw-orch-fns-derive/src/helpers.rs +++ b/packages/cw-orch-fns-derive/src/helpers.rs @@ -2,10 +2,14 @@ use proc_macro2::TokenStream; use quote::quote; use std::cmp::Ordering; use syn::{ - parse_quote, punctuated::Punctuated, token::Comma, Attribute, Field, FieldsNamed, - GenericArgument, GenericParam, Generics, Lit, Meta, NestedMeta, + punctuated::Punctuated, token::Comma, Attribute, Field, FieldsNamed, Lit, Meta, NestedMeta, }; +pub enum MsgType { + Execute, + Query, +} + pub(crate) fn process_fn_name(v: &syn::Variant) -> String { for attr in &v.attrs { if let Ok(Meta::List(list)) = attr.parse_meta() { @@ -21,21 +25,6 @@ pub(crate) fn process_fn_name(v: &syn::Variant) -> String { v.ident.to_string() } -pub fn to_generic_arguments(generics: &Generics) -> Punctuated { - generics.params.iter().map(to_generic_argument).collect() -} - -pub fn to_generic_argument(p: &GenericParam) -> GenericArgument { - match p { - GenericParam::Type(t) => { - let ident = &t.ident; - GenericArgument::Type(parse_quote!(#ident)) - } - GenericParam::Lifetime(l) => GenericArgument::Lifetime(l.lifetime.clone()), - GenericParam::Const(c) => GenericArgument::Const(parse_quote!(#c)), - } -} - pub(crate) fn process_sorting(attrs: &Vec) -> bool { // If the disable_fields_sorting attribute is enabled, we return false, no sorting should be done for attr in attrs { diff --git a/packages/cw-orch-fns-derive/src/lib.rs b/packages/cw-orch-fns-derive/src/lib.rs index b9de0ee00..6a4f14ce8 100644 --- a/packages/cw-orch-fns-derive/src/lib.rs +++ b/packages/cw-orch-fns-derive/src/lib.rs @@ -1,13 +1,15 @@ #![recursion_limit = "128"] mod execute_fns; +mod fns_derive; mod helpers; mod query_fns; extern crate proc_macro; +use helpers::MsgType; use proc_macro::TokenStream; -use syn::{parse_macro_input, DeriveInput, ItemEnum}; +use syn::{parse_macro_input, ItemEnum}; #[proc_macro_derive( ExecuteFns, @@ -15,8 +17,8 @@ use syn::{parse_macro_input, DeriveInput, ItemEnum}; )] pub fn cw_orch_execute(input: TokenStream) -> TokenStream { // We only parse and return the modified code if the flag is activated - let ast = parse_macro_input!(input as DeriveInput); - execute_fns::execute_fns_derive(ast) + let ast = parse_macro_input!(input as ItemEnum); + fns_derive::fns_derive(MsgType::Execute, ast) } #[proc_macro_derive( @@ -25,5 +27,5 @@ pub fn cw_orch_execute(input: TokenStream) -> TokenStream { )] pub fn cw_orch_query(input: TokenStream) -> TokenStream { let ast = parse_macro_input!(input as ItemEnum); - query_fns::query_fns_derive(ast) + fns_derive::fns_derive(MsgType::Query, ast) } diff --git a/packages/cw-orch-fns-derive/src/query_fns.rs b/packages/cw-orch-fns-derive/src/query_fns.rs index 8fa2c91d7..43602a71f 100644 --- a/packages/cw-orch-fns-derive/src/query_fns.rs +++ b/packages/cw-orch-fns-derive/src/query_fns.rs @@ -1,18 +1,9 @@ -extern crate proc_macro; -use crate::helpers::{ - impl_into_deprecation, process_fn_name, process_sorting, to_generic_arguments, - LexiographicMatching, -}; -use convert_case::{Case, Casing}; -use proc_macro::TokenStream; -use proc_macro2::Span; -use quote::{format_ident, quote}; -use syn::{visit_mut::VisitMut, Fields, Ident, ItemEnum, Type}; +use quote::quote; const RETURNS: &str = "returns"; /// Extract the query -> response mapping out of an enum variant. -fn parse_query_type(v: &syn::Variant) -> Type { +pub fn parse_query_type(v: &syn::Variant) -> proc_macro2::TokenStream { let response_ty: syn::Type = v .attrs .iter() @@ -20,159 +11,5 @@ fn parse_query_type(v: &syn::Variant) -> Type { .unwrap_or_else(|| panic!("missing return type for query: {}", v.ident)) .parse_args() .unwrap_or_else(|_| panic!("return for {} must be a type", v.ident)); - response_ty -} - -pub fn query_fns_derive(input: ItemEnum) -> TokenStream { - let name = &input.ident; - let bname = Ident::new(&format!("{name}Fns"), name.span()); - - let generics = input.generics.clone(); - let (_impl_generics, _ty_generics, where_clause) = generics.split_for_impl().clone(); - - let type_generics = to_generic_arguments(&generics); - - let is_attributes_sorted = process_sorting(&input.attrs); - - let variants = input.variants; - - let variant_fns = variants.into_iter().map( |mut variant|{ - let variant_name = variant.ident.clone(); - let response = parse_query_type(&variant); - let mut variant_func_name = - format_ident!("{}", process_fn_name(&variant).to_case(Case::Snake)); - variant_func_name.set_span(variant_name.span()); - - let variant_doc: syn::Attribute = { - let doc = format!("Automatically generated wrapper around {}::{} variant", name, variant_name); - syn::parse_quote!( - #[doc=#doc] - ) - }; - - match &mut variant.fields { - Fields::Unnamed(variant_fields) => { - let mut variant_idents = variant_fields.unnamed.clone(); - - // remove attributes from fields - variant_idents.iter_mut().for_each(|f| f.attrs = vec![]); - - // Parse these fields as arguments to function - - // We need to figure out a parameter name for all fields associated to their types - // They will be numbered from 0 to n-1 - let variant_ident_content_names = variant_idents - .iter() - .enumerate() - .map(|(i, _)| Ident::new(&format!("arg{}", i), Span::call_site())); - - let variant_attr = variant_idents.clone().into_iter() - .enumerate() - .map(|(i, mut id)| { - id.ident = Some(Ident::new(&format!("arg{}", i), Span::call_site())); - id - }); - - quote!( - #variant_doc - #[allow(clippy::too_many_arguments)] - fn #variant_func_name(&self, #(#variant_attr,)*) -> ::core::result::Result<#response, ::cw_orch::core::CwEnvError> { - let msg = #name::#variant_name (#(#variant_ident_content_names,)*); - >::query(self, &msg.into()) - } - ) - } - Fields::Unit => { - quote!( - #variant_doc - fn #variant_func_name(&self) -> ::core::result::Result<#response, ::cw_orch::core::CwEnvError> { - let msg = #name::#variant_name; - >::query(self, &msg.into()) - } - ) - }, - Fields::Named(variant_fields) => { - if is_attributes_sorted{ - // sort fields on field name - LexiographicMatching::default().visit_fields_named_mut(variant_fields); - } - - // remove attributes from fields - variant_fields.named.iter_mut().for_each(|f| f.attrs = vec![]); - - // Parse these fields as arguments to function - let variant_fields = variant_fields.named.clone(); - let variant_idents = variant_fields.iter().map(|f|f.ident.clone().unwrap()); - - let variant_attr = variant_fields.iter(); - quote!( - #variant_doc - #[allow(clippy::too_many_arguments)] - fn #variant_func_name(&self, #(#variant_attr,)*) -> ::core::result::Result<#response, ::cw_orch::core::CwEnvError> { - let msg = #name::#variant_name { - #(#variant_idents,)* - }; - >::query(self, &msg.into()) - } - ) - } - } - }); - - let necessary_trait_where = quote!(#name<#type_generics>: Into); - let combined_trait_where_clause = where_clause - .map(|w| { - quote!( - #w #necessary_trait_where - ) - }) - .unwrap_or(quote!( - where - #necessary_trait_where - )); - - let impl_into_depr = impl_into_deprecation(&input.attrs); - let derived_trait = quote!( - #[cfg(not(target_arch = "wasm32"))] - #impl_into_depr - /// Automatically derived trait that allows you to call the variants of the message directly without the need to construct the struct yourself. - pub trait #bname: ::cw_orch::core::contract::interface_traits::CwOrchQuery #combined_trait_where_clause { - #(#variant_fns)* - } - - #[cfg(target_arch = "wasm32")] - /// Automatically derived trait that allows you to call the variants of the message directly without the need to construct the struct yourself. - pub trait #bname{ - - } - ); - - // We need to merge the where clauses (rust doesn't support 2 wheres) - // If there is no where clause, we simply add the necessary where - let necessary_where = quote!(SupportedContract: ::cw_orch::core::contract::interface_traits::CwOrchQuery, #necessary_trait_where); - let combined_where_clause = where_clause - .map(|w| { - quote!( - #w #necessary_where - ) - }) - .unwrap_or(quote!( - where - #necessary_where - )); - - let derived_trait_impl = quote!( - #[automatically_derived] - impl #bname for SupportedContract - #combined_where_clause {} - ); - - let expand = quote!( - #derived_trait - - #[cfg(not(target_arch = "wasm32"))] - #derived_trait_impl - ); - - expand.into() + quote!(#response_ty) }