Skip to content

Commit

Permalink
Add optional fields and cleanup optional functions
Browse files Browse the repository at this point in the history
  • Loading branch information
OpenByteDev committed Jul 29, 2023
1 parent b561556 commit 897f79e
Show file tree
Hide file tree
Showing 3 changed files with 110 additions and 59 deletions.
131 changes: 85 additions & 46 deletions dlopen2-derive/src/wrapper.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use super::common::{get_fields, get_non_marker_attrs, has_marker_attr, symbol_name};
use quote::quote;
use syn::{self, BareFnArg, DeriveInput, Field, Type, TypePtr, Visibility};
use syn::{self, BareFnArg, DeriveInput, Field, GenericArgument, Type, TypePtr, Visibility};

const ALLOW_NULL: &str = "dlopen2_allow_null";
const TRAIT_NAME: &str = "WrapperApi";
Expand Down Expand Up @@ -66,13 +66,22 @@ fn field_to_tokens(field: &Field) -> proc_macro2::TokenStream {
}
}
Type::Path(ref path) => {
assert!(path.qself.is_none());
let path = &path.path;
assert!(path.leading_colon.is_none());
let segments = &path.segments;
let segment = segments.first().unwrap();
assert!(segment.ident.to_string() == "Option", "Only bare functions, optional bare functions, references and pointers are allowed in structures implementing WrapperApi trait");
optional_field(field)
let segments_string: Vec<String> = path
.segments
.iter()
.map(|segment| segment.ident.to_string())
.collect();
let segments_str: Vec<&str> = segments_string
.iter()
.map(|segment| segment.as_str())
.collect();
match (path.leading_colon.is_some(), segments_str.as_slice()) {
(_, ["core" | "std", "option", "Option"]) | (false, ["option", "Option"]) | (false, ["Option"]) => {
optional_field(field)
}
_ => panic!("Only bare functions, optional bare functions, references and pointers are allowed in structures implementing WrapperApi trait")
}
}
_ => {
// dbg!();
Expand Down Expand Up @@ -132,7 +141,10 @@ fn optional_field(field: &Field) -> proc_macro2::TokenStream {
}

fn field_to_wrapper(field: &Field) -> Option<proc_macro2::TokenStream> {
let ident = field.ident.as_ref().expect("field must have ident");
let ident = field
.ident
.as_ref()
.expect("Fields must have idents (tuple structs are not supported)");
let attrs = get_non_marker_attrs(field);

match field.ty {
Expand All @@ -148,7 +160,7 @@ fn field_to_wrapper(field: &Field) -> Option<proc_macro2::TokenStream> {
.map(|a| fun_arg_to_tokens(a, &ident.to_string()));
let arg_names = fun.inputs.iter().map(|a| match a.name {
::std::option::Option::Some((ref arg_name, _)) => arg_name,
::std::option::Option::None => panic!("This should never happen"),
::std::option::Option::None => unreachable!(),
});
Some(quote! {
#(#attrs)*
Expand Down Expand Up @@ -191,49 +203,76 @@ fn field_to_wrapper(field: &Field) -> Option<proc_macro2::TokenStream> {
Type::Path(ref path) => {
let path = &path.path;
let segments = &path.segments;
let segment = segments.first().unwrap();
let segment = segments
.iter()
.filter(|segment| segment.ident == "Option")
.next()
.unwrap();
let args = &segment.arguments;
match args {
syn::PathArguments::AngleBracketed(args) => {
let args_inner = &args.args;
let token = quote!(# args_inner);
// panic!("{}", token);
let fun = syn::parse::<syn::TypeBareFn>(token.into()).unwrap();

if fun.variadic.is_some() {
return None;
} else {
let output = &fun.output;
let output = match output {
syn::ReturnType::Default => quote!(-> Option<()>),
syn::ReturnType::Type(_, ty) => quote!( -> Option<#ty>),
};
let unsafety = &fun.unsafety;
let arg_iter = fun
.inputs
.iter()
.map(|a| fun_arg_to_tokens(a, &ident.to_string()));
let arg_names = fun.inputs.iter().map(|a| match a.name {
::std::option::Option::Some((ref arg_name, _)) => arg_name,
::std::option::Option::None => panic!("This should never happen"),
});
let has_ident = quote::format_ident!("has_{}", ident);
return Some(quote! {
#(#attrs)*
pub #unsafety fn #ident (&self, #(#arg_iter),* ) #output {
self.#ident.map(|f| (f)(#(#arg_names),*))
}
#(#attrs)*
pub fn #has_ident (&self) -> bool {
self.#ident.is_some()
syn::PathArguments::AngleBracketed(args) => match args.args.first().unwrap() {
GenericArgument::Type(Type::BareFn(fun)) => {
if fun.variadic.is_some() {
None
} else {
let output = &fun.output;
let output = match output {
syn::ReturnType::Default => quote!(-> Option<()>),
syn::ReturnType::Type(_, ty) => quote!( -> Option<#ty>),
};
let unsafety = &fun.unsafety;
let arg_iter = fun
.inputs
.iter()
.map(|a| fun_arg_to_tokens(a, &ident.to_string()));
let arg_names = fun.inputs.iter().map(|a| match a.name {
::std::option::Option::Some((ref arg_name, _)) => arg_name,
::std::option::Option::None => unreachable!(),
});
let has_ident = quote::format_ident!("has_{}", ident);
Some(quote! {
#(#attrs)*
pub #unsafety fn #ident (&self, #(#arg_iter),* ) #output {
self.#ident.map(|f| (f)(#(#arg_names),*))
}
#(#attrs)*
pub fn #has_ident (&self) -> bool {
self.#ident.is_some()
}
})
}
}
GenericArgument::Type(Type::Reference(ref_ty)) => {
let ty = &ref_ty.elem;
match ref_ty.mutability {
Some(_token) => {
let mut_ident = &format!("{}", ident);
let method_name = syn::Ident::new(mut_ident, ident.span());
Some(quote! {
#(#attrs)*
pub fn #method_name (&mut self) -> ::core::option::Option<&mut #ty> {
if let Some(&mut ref mut val) = self.#ident {
Some(val)
} else {
None
}
}
})
}
});
None => Some(quote! {
#(#attrs)*
pub fn #ident (&self) -> ::core::option::Option<& #ty> {
self.#ident
}
}),
}
}
}
_ => panic!("Unknown optional type, this should not happen!"),
_ => panic!("Unsupported field type"),
},
_ => panic!("Unknown optional type!"),
}
}
_ => panic!("Unknown field type, this should not happen!"),
_ => panic!("Unsupported field type"),
}
}

Expand Down
19 changes: 12 additions & 7 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,25 +12,30 @@ I hope that this library will help you to quickly get what you need and avoid er
```no_run
use dlopen2::wrapper::{Container, WrapperApi};
use std::os::raw::c_int;
#[derive(WrapperApi)]
struct Api<'a> {
example_rust_fun: fn(arg: i32) -> u32,
example_c_fun: unsafe extern "C" fn(),
// A function may not exist in the library.
example_c_option_fun: Option<unsafe extern "C" fn() -> c_int>,
example_reference: &'a mut i32,
// A function or field may not always exist in the library.
example_c_fun_option: Option<unsafe extern "C" fn()>,
example_reference_option: Option<&'a mut i32>,
}
fn main(){
fn main() {
let mut cont: Container<Api> =
unsafe { Container::load("libexample.so") }.expect("Could not open library or load symbols");
cont.example_rust_fun(5);
unsafe{cont.example_c_fun()};
// option function returns Option<fn_return_type>, it's Option<c_int> here.
unsafe{cont.example_c_option_fun().map(|i| i == 0)};
unsafe { cont.example_c_fun() };
*cont.example_reference_mut() = 5;
// Optional functions return Some(result) if the function is present or None if absent.
unsafe { cont.example_c_fun_option() };
// Optional fields are Some(value) if present and None if absent.
if let Some(example_reference) = cont.example_reference_option() {
*example_reference = 5;
}
}
```
Expand Down
19 changes: 13 additions & 6 deletions tests/wrapper_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ struct Api<'a> {
rust_i32_mut: &'a mut i32,
#[dlopen2_name = "rust_i32_mut"]
rust_i32_ptr: *const i32,
#[dlopen2_name = "rust_i32"]
rust_i32_optional: Option<&'a i32>,
rust_i32_not_found: Option<&'a i32>,
c_int: &'a c_int,
c_struct: &'a SomeData,
rust_str: &'a &'static str,
Expand All @@ -41,25 +44,29 @@ fn open_play_close_wrapper_api() {
let mut cont: Container<Api> =
unsafe { Container::load(lib_path) }.expect("Could not open library or load symbols");

cont.rust_fun_print_something(); //should not crash
cont.rust_fun_print_something(); // should not crash
assert_eq!(cont.rust_fun_add_one(5), 6);
unsafe { cont.c_fun_print_something_else() }; //should not crash
unsafe { cont.c_fun_print_something_else() }; // should not crash
unsafe { cont.c_fun_print_something_else_optional() };
assert!(cont.has_c_fun_print_something_else_optional());
assert_eq!(unsafe { cont.c_fun_add_two(2) }, Some(4));
assert!(!cont.has_c_fun_add_two_not_found());
assert_eq!(unsafe { cont.c_fun_add_two_not_found(2) }, None);
assert_eq!(43, *cont.rust_i32());
assert_eq!(42, *cont.rust_i32_mut_mut());
*cont.rust_i32_mut_mut() = 55; //should not crash
*cont.rust_i32_mut_mut() = 55; // should not crash
assert_eq!(55, unsafe { *cont.rust_i32_ptr() });
//the same with C
assert_eq!(cont.rust_i32_optional(), Some(&43));
assert_eq!(cont.rust_i32_not_found(), None);

// the same with C
assert_eq!(45, *cont.c_int());
//now static c struct

// now static c struct
assert_eq!(1, cont.c_struct().first);
assert_eq!(2, cont.c_struct().second);
//let's play with strings

// let's play with strings
assert_eq!("Hello!", *cont.rust_str());
let converted = cont.c_const_str().to_str().unwrap();
assert_eq!(converted, "Hi!");
Expand Down

0 comments on commit 897f79e

Please sign in to comment.