Skip to content

Commit

Permalink
macros: clean up protocol argument extraction a bit
Browse files Browse the repository at this point in the history
  • Loading branch information
davidhewitt committed Oct 31, 2021
1 parent bfe7086 commit 6a3e1e7
Show file tree
Hide file tree
Showing 3 changed files with 113 additions and 132 deletions.
37 changes: 6 additions & 31 deletions pyo3-macros-backend/src/params.rs
Expand Up @@ -4,7 +4,7 @@ use crate::{
attributes::FromPyWithAttribute,
method::{FnArg, FnSpec},
pyfunction::Argument,
utils::unwrap_ty_group,
utils::{remove_lifetime, replace_self, unwrap_ty_group},
};
use proc_macro2::{Span, TokenStream};
use quote::{quote, quote_spanned};
Expand Down Expand Up @@ -267,7 +267,11 @@ fn impl_arg_param(
};

return if let syn::Type::Reference(tref) = unwrap_ty_group(arg.optional.unwrap_or(ty)) {
let (tref, mut_) = preprocess_tref(tref, self_);
let mut tref = remove_lifetime(tref);
if let Some(cls) = self_ {
replace_self(&mut tref.elem, cls);
}
let mut_ = tref.mutability;
let (target_ty, borrow_tmp) = if arg.optional.is_some() {
// Get Option<&T> from Option<PyRef<T>>
(
Expand Down Expand Up @@ -295,33 +299,4 @@ fn impl_arg_param(
let #arg_name = #arg_value_or_default;
})
};

/// Replace `Self`, remove lifetime and get mutability from the type
fn preprocess_tref(
tref: &syn::TypeReference,
self_: Option<&syn::Type>,
) -> (syn::TypeReference, Option<syn::token::Mut>) {
let mut tref = tref.to_owned();
if let Some(syn::Type::Path(tpath)) = self_ {
replace_self(&mut tref, &tpath.path);
}
tref.lifetime = None;
let mut_ = tref.mutability;
(tref, mut_)
}

/// Replace `Self` with the exact type name since it is used out of the impl block
fn replace_self(tref: &mut syn::TypeReference, self_path: &syn::Path) {
match &mut *tref.elem {
syn::Type::Reference(tref_inner) => replace_self(tref_inner, self_path),
syn::Type::Path(tpath) => {
if let Some(ident) = tpath.path.get_ident() {
if ident == "Self" {
tpath.path = self_path.to_owned();
}
}
}
_ => {}
}
}
}
186 changes: 85 additions & 101 deletions pyo3-macros-backend/src/pymethod.rs
Expand Up @@ -4,7 +4,9 @@ use std::borrow::Cow;

use crate::attributes::NameAttribute;
use crate::method::{CallingConvention, ExtractErrorMode};
use crate::utils::{ensure_not_async_fn, unwrap_ty_group, PythonDoc};
use crate::utils::{
ensure_not_async_fn, remove_lifetime, replace_self, unwrap_ty_group, PythonDoc,
};
use crate::{deprecations::Deprecations, utils};
use crate::{
method::{FnArg, FnSpec, FnType, SelfType},
Expand Down Expand Up @@ -424,7 +426,7 @@ const __HASH__: SlotDef = SlotDef::new("Py_tp_hash", "hashfunc")
));
const __RICHCMP__: SlotDef = SlotDef::new("Py_tp_richcompare", "richcmpfunc")
.extract_error_mode(ExtractErrorMode::NotImplemented)
.arguments(&[Ty::ObjectOrNotImplemented, Ty::CompareOp]);
.arguments(&[Ty::Object, Ty::CompareOp]);
const __GET__: SlotDef =
SlotDef::new("Py_tp_descr_get", "descrgetfunc").arguments(&[Ty::Object, Ty::Object]);
const __ITER__: SlotDef = SlotDef::new("Py_tp_iter", "getiterfunc");
Expand Down Expand Up @@ -452,55 +454,55 @@ const __FLOAT__: SlotDef = SlotDef::new("Py_nb_float", "unaryfunc");
const __BOOL__: SlotDef = SlotDef::new("Py_nb_bool", "inquiry").ret_ty(Ty::Int);

const __IADD__: SlotDef = SlotDef::new("Py_nb_inplace_add", "binaryfunc")
.arguments(&[Ty::ObjectOrNotImplemented])
.arguments(&[Ty::Object])
.extract_error_mode(ExtractErrorMode::NotImplemented)
.return_self();
const __ISUB__: SlotDef = SlotDef::new("Py_nb_inplace_subtract", "binaryfunc")
.arguments(&[Ty::ObjectOrNotImplemented])
.arguments(&[Ty::Object])
.extract_error_mode(ExtractErrorMode::NotImplemented)
.return_self();
const __IMUL__: SlotDef = SlotDef::new("Py_nb_inplace_multiply", "binaryfunc")
.arguments(&[Ty::ObjectOrNotImplemented])
.arguments(&[Ty::Object])
.extract_error_mode(ExtractErrorMode::NotImplemented)
.return_self();
const __IMATMUL__: SlotDef = SlotDef::new("Py_nb_inplace_matrix_multiply", "binaryfunc")
.arguments(&[Ty::ObjectOrNotImplemented])
.arguments(&[Ty::Object])
.extract_error_mode(ExtractErrorMode::NotImplemented)
.return_self();
const __ITRUEDIV__: SlotDef = SlotDef::new("Py_nb_inplace_true_divide", "binaryfunc")
.arguments(&[Ty::ObjectOrNotImplemented])
.arguments(&[Ty::Object])
.extract_error_mode(ExtractErrorMode::NotImplemented)
.return_self();
const __IFLOORDIV__: SlotDef = SlotDef::new("Py_nb_inplace_floor_divide", "binaryfunc")
.arguments(&[Ty::ObjectOrNotImplemented])
.arguments(&[Ty::Object])
.extract_error_mode(ExtractErrorMode::NotImplemented)
.return_self();
const __IMOD__: SlotDef = SlotDef::new("Py_nb_inplace_remainder", "binaryfunc")
.arguments(&[Ty::ObjectOrNotImplemented])
.arguments(&[Ty::Object])
.extract_error_mode(ExtractErrorMode::NotImplemented)
.return_self();
const __IPOW__: SlotDef = SlotDef::new("Py_nb_inplace_power", "ternaryfunc")
.arguments(&[Ty::ObjectOrNotImplemented, Ty::ObjectOrNotImplemented])
.arguments(&[Ty::Object, Ty::Object])
.extract_error_mode(ExtractErrorMode::NotImplemented)
.return_self();
const __ILSHIFT__: SlotDef = SlotDef::new("Py_nb_inplace_lshift", "binaryfunc")
.arguments(&[Ty::ObjectOrNotImplemented])
.arguments(&[Ty::Object])
.extract_error_mode(ExtractErrorMode::NotImplemented)
.return_self();
const __IRSHIFT__: SlotDef = SlotDef::new("Py_nb_inplace_rshift", "binaryfunc")
.arguments(&[Ty::ObjectOrNotImplemented])
.arguments(&[Ty::Object])
.extract_error_mode(ExtractErrorMode::NotImplemented)
.return_self();
const __IAND__: SlotDef = SlotDef::new("Py_nb_inplace_and", "binaryfunc")
.arguments(&[Ty::ObjectOrNotImplemented])
.arguments(&[Ty::Object])
.extract_error_mode(ExtractErrorMode::NotImplemented)
.return_self();
const __IXOR__: SlotDef = SlotDef::new("Py_nb_inplace_xor", "binaryfunc")
.arguments(&[Ty::ObjectOrNotImplemented])
.arguments(&[Ty::Object])
.extract_error_mode(ExtractErrorMode::NotImplemented)
.return_self();
const __IOR__: SlotDef = SlotDef::new("Py_nb_inplace_or", "binaryfunc")
.arguments(&[Ty::ObjectOrNotImplemented])
.arguments(&[Ty::Object])
.extract_error_mode(ExtractErrorMode::NotImplemented)
.return_self();

Expand Down Expand Up @@ -548,7 +550,6 @@ fn pyproto(method_name: &str) -> Option<&'static SlotDef> {
#[derive(Clone, Copy)]
enum Ty {
Object,
ObjectOrNotImplemented,
NonNullObject,
CompareOp,
Int,
Expand All @@ -560,7 +561,7 @@ enum Ty {
impl Ty {
fn ffi_type(self) -> TokenStream {
match self {
Ty::Object | Ty::ObjectOrNotImplemented => quote! { *mut ::pyo3::ffi::PyObject },
Ty::Object => quote! { *mut ::pyo3::ffi::PyObject },
Ty::NonNullObject => quote! { ::std::ptr::NonNull<::pyo3::ffi::PyObject> },
Ty::Int | Ty::CompareOp => quote! { ::std::os::raw::c_int },
Ty::PyHashT => quote! { ::pyo3::ffi::Py_hash_t },
Expand All @@ -574,95 +575,82 @@ impl Ty {
cls: &syn::Type,
py: &syn::Ident,
ident: &syn::Ident,
target: &syn::Type,
arg: &FnArg,
extract_error_mode: ExtractErrorMode,
) -> TokenStream {
match self {
Ty::Object => {
let extract = extract_from_any(cls, target, ident);
quote! {
let #ident: &::pyo3::PyAny = #py.from_borrowed_ptr(#ident);
#extract
}
}
Ty::ObjectOrNotImplemented => {
let extract = if let syn::Type::Reference(tref) = unwrap_ty_group(target) {
let (tref, mut_) = preprocess_tref(tref, cls);
let extract = handle_error(
extract_error_mode,
py,
quote! {
let #mut_ #ident: <#tref as ::pyo3::derive_utils::ExtractExt<'_>>::Target = match #ident.extract() {
::std::result::Result::Ok(#ident) => #ident,
::std::result::Result::Err(_) => return ::pyo3::callback::convert(#py, #py.NotImplemented()),
};
let #ident = &#mut_ *#ident;
}
} else {
quote! {
let #ident = match #ident.extract() {
::std::result::Result::Ok(#ident) => #ident,
::std::result::Result::Err(_) => return ::pyo3::callback::convert(#py, #py.NotImplemented()),
};
}
};
quote! {
let #ident: &::pyo3::PyAny = #py.from_borrowed_ptr(#ident);
#extract
}
#py.from_borrowed_ptr::<::pyo3::PyAny>(#ident).extract()
},
);
extract_object(cls, arg.ty, ident, extract)
}
Ty::NonNullObject => {
let extract = extract_from_any(cls, target, ident);
let extract = handle_error(
extract_error_mode,
py,
quote! {
#py.from_borrowed_ptr::<::pyo3::PyAny>(#ident.as_ptr()).extract()
},
);
extract_object(cls, arg.ty, ident, extract)
}
Ty::CompareOp => {
let extract = handle_error(
extract_error_mode,
py,
quote! {
::pyo3::class::basic::CompareOp::from_raw(#ident)
.ok_or_else(|| ::pyo3::exceptions::PyValueError::new_err("invalid comparison operator"))
},
);
quote! {
let #ident: &::pyo3::PyAny = #py.from_borrowed_ptr(#ident.as_ptr());
#extract
let #ident = #extract;
}
}
Ty::CompareOp => quote! {
let #ident = ::pyo3::class::basic::CompareOp::from_raw(#ident)
.ok_or_else(|| ::pyo3::exceptions::PyValueError::new_err("invalid comparison operator"))?;
},
Ty::Int | Ty::PyHashT | Ty::PySsizeT | Ty::Void => todo!(),
}
}
}

fn extract_from_any(self_: &syn::Type, target: &syn::Type, ident: &syn::Ident) -> TokenStream {
return if let syn::Type::Reference(tref) = unwrap_ty_group(target) {
let (tref, mut_) = preprocess_tref(tref, self_);
fn handle_error(
extract_error_mode: ExtractErrorMode,
py: &syn::Ident,
extract: TokenStream,
) -> TokenStream {
match extract_error_mode {
ExtractErrorMode::Raise => quote! { #extract? },
ExtractErrorMode::NotImplemented => quote! {
match #extract {
::std::result::Result::Ok(value) => value,
::std::result::Result::Err(_) => { return ::pyo3::callback::convert(#py, #py.NotImplemented()); },
}
},
}
}

fn extract_object(
cls: &syn::Type,
target: &syn::Type,
ident: &syn::Ident,
extract: TokenStream,
) -> TokenStream {
if let syn::Type::Reference(tref) = unwrap_ty_group(target) {
let mut tref = remove_lifetime(tref);
replace_self(&mut tref.elem, cls);
let mut_ = tref.mutability;
quote! {
let #mut_ #ident: <#tref as ::pyo3::derive_utils::ExtractExt<'_>>::Target = #ident.extract()?;
let #mut_ #ident: <#tref as ::pyo3::derive_utils::ExtractExt<'_>>::Target = #extract;
let #ident = &#mut_ *#ident;
}
} else {
quote! {
let #ident = #ident.extract()?;
}
};
}

/// Replace `Self`, remove lifetime and get mutability from the type
fn preprocess_tref(
tref: &syn::TypeReference,
self_: &syn::Type,
) -> (syn::TypeReference, Option<syn::token::Mut>) {
let mut tref = tref.to_owned();
if let syn::Type::Path(tpath) = self_ {
replace_self(&mut tref, &tpath.path);
}
tref.lifetime = None;
let mut_ = tref.mutability;
(tref, mut_)
}

/// Replace `Self` with the exact type name since it is used out of the impl block
fn replace_self(tref: &mut syn::TypeReference, self_path: &syn::Path) {
match &mut *tref.elem {
syn::Type::Reference(tref_inner) => replace_self(tref_inner, self_path),
syn::Type::Path(tpath) => {
if let Some(ident) = tpath.path.get_ident() {
if ident == "Self" {
tpath.path = self_path.to_owned();
}
}
let #ident = #extract;
}
_ => {}
}
}

Expand Down Expand Up @@ -800,7 +788,8 @@ fn generate_method_body(
) -> Result<TokenStream> {
let self_conversion = spec.tp.self_conversion(Some(cls), extract_error_mode);
let rust_name = spec.name;
let (arg_idents, conversions) = extract_proto_arguments(cls, py, &spec.args, arguments)?;
let (arg_idents, conversions) =
extract_proto_arguments(cls, py, &spec.args, arguments, extract_error_mode)?;
let call = quote! { ::pyo3::callback::convert(#py, #cls::#rust_name(_slf, #(#arg_idents),*)) };
let body = if let Some(return_mode) = return_mode {
return_mode.return_call_output(py, call)
Expand Down Expand Up @@ -883,7 +872,7 @@ const __DELITEM__: SlotFragmentDef = SlotFragmentDef::new("__delitem__", &[Ty::O

macro_rules! binary_num_slot_fragment_def {
($ident:ident, $name:literal) => {
const $ident: SlotFragmentDef = SlotFragmentDef::new($name, &[Ty::ObjectOrNotImplemented])
const $ident: SlotFragmentDef = SlotFragmentDef::new($name, &[Ty::Object])
.extract_error_mode(ExtractErrorMode::NotImplemented)
.ret_ty(Ty::Object);
};
Expand Down Expand Up @@ -916,18 +905,12 @@ binary_num_slot_fragment_def!(__RXOR__, "__rxor__");
binary_num_slot_fragment_def!(__OR__, "__or__");
binary_num_slot_fragment_def!(__ROR__, "__ror__");

const __POW__: SlotFragmentDef = SlotFragmentDef::new(
"__pow__",
&[Ty::ObjectOrNotImplemented, Ty::ObjectOrNotImplemented],
)
.extract_error_mode(ExtractErrorMode::NotImplemented)
.ret_ty(Ty::Object);
const __RPOW__: SlotFragmentDef = SlotFragmentDef::new(
"__rpow__",
&[Ty::ObjectOrNotImplemented, Ty::ObjectOrNotImplemented],
)
.extract_error_mode(ExtractErrorMode::NotImplemented)
.ret_ty(Ty::Object);
const __POW__: SlotFragmentDef = SlotFragmentDef::new("__pow__", &[Ty::Object, Ty::Object])
.extract_error_mode(ExtractErrorMode::NotImplemented)
.ret_ty(Ty::Object);
const __RPOW__: SlotFragmentDef = SlotFragmentDef::new("__rpow__", &[Ty::Object, Ty::Object])
.extract_error_mode(ExtractErrorMode::NotImplemented)
.ret_ty(Ty::Object);

fn pyproto_fragment(method_name: &str) -> Option<&'static SlotFragmentDef> {
match method_name {
Expand Down Expand Up @@ -974,6 +957,7 @@ fn extract_proto_arguments(
py: &syn::Ident,
method_args: &[FnArg],
proto_args: &[Ty],
extract_error_mode: ExtractErrorMode,
) -> Result<(Vec<Ident>, TokenStream)> {
let mut arg_idents = Vec::with_capacity(method_args.len());
let mut non_python_args = 0;
Expand All @@ -987,7 +971,7 @@ fn extract_proto_arguments(
let ident = syn::Ident::new(&format!("arg{}", non_python_args), Span::call_site());
let conversions = proto_args.get(non_python_args)
.ok_or_else(|| err_spanned!(arg.ty.span() => format!("Expected at most {} non-python arguments", proto_args.len())))?
.extract(cls, py, &ident, arg.ty);
.extract(cls, py, &ident, arg, extract_error_mode);
non_python_args += 1;
args_conversions.push(conversions);
arg_idents.push(ident);
Expand Down

0 comments on commit 6a3e1e7

Please sign in to comment.