Skip to content

Commit

Permalink
improve error for invalid #[classmethod] receivers
Browse files Browse the repository at this point in the history
  • Loading branch information
davidhewitt committed Nov 24, 2023
1 parent aba3a35 commit 5ac56b8
Show file tree
Hide file tree
Showing 12 changed files with 223 additions and 177 deletions.
71 changes: 36 additions & 35 deletions pyo3-macros-backend/src/method.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,8 @@ pub enum FnType {
Setter(SelfType),
Fn(SelfType),
FnNew,
FnNewClass,
FnClass,
FnNewClass(Span),
FnClass(Span),
FnStatic,
FnModule(Span),
ClassAttribute,
Expand All @@ -91,8 +91,8 @@ impl FnType {
FnType::Getter(_)
| FnType::Setter(_)
| FnType::Fn(_)
| FnType::FnClass
| FnType::FnNewClass
| FnType::FnClass(_)
| FnType::FnNewClass(_)
| FnType::FnModule(_) => true,
FnType::FnNew | FnType::FnStatic | FnType::ClassAttribute => false,
}
Expand All @@ -111,10 +111,12 @@ impl FnType {
FnType::FnNew | FnType::FnStatic | FnType::ClassAttribute => {
quote!()
}
FnType::FnClass | FnType::FnNewClass => {
quote! {
FnType::FnClass(span) | FnType::FnNewClass(span) => {
let py = syn::Ident::new("py", Span::call_site());
let slf: Ident = syn::Ident::new("_slf", Span::call_site());
quote_spanned! { *span =>
#[allow(clippy::useless_conversion)]
::std::convert::Into::into(_pyo3::types::PyType::from_type_ptr(py, _slf as *mut _pyo3::ffi::PyTypeObject)),
::std::convert::Into::into(_pyo3::types::PyType::from_type_ptr(#py, #slf.cast())),
}
}
FnType::FnModule(span) => {
Expand Down Expand Up @@ -306,7 +308,7 @@ impl<'a> FnSpec<'a> {
FunctionSignature::from_arguments(arguments)?
};

let convention = if matches!(fn_type, FnType::FnNew | FnType::FnNewClass) {
let convention = if matches!(fn_type, FnType::FnNew | FnType::FnNewClass(_)) {
CallingConvention::TpNew
} else {
CallingConvention::from_signature(&signature)
Expand Down Expand Up @@ -355,36 +357,40 @@ impl<'a> FnSpec<'a> {
.map(|stripped| syn::Ident::new(stripped, name.span()))
};

let mut set_name_to_new = || {
if let Some(name) = &python_name {
bail_spanned!(name.span() => "`name` not allowed with `#[new]`");
}
*python_name = Some(syn::Ident::new("__new__", Span::call_site()));
Ok(())
};

let fn_type = match method_attributes.as_mut_slice() {
[] => FnType::Fn(parse_receiver(
"static method needs #[staticmethod] attribute",
)?),
[MethodTypeAttribute::StaticMethod(_)] => FnType::FnStatic,
[MethodTypeAttribute::ClassAttribute(_)] => FnType::ClassAttribute,
[MethodTypeAttribute::New(_)]
| [MethodTypeAttribute::New(_), MethodTypeAttribute::ClassMethod(_)]
| [MethodTypeAttribute::ClassMethod(_), MethodTypeAttribute::New(_)] => {
if let Some(name) = &python_name {
bail_spanned!(name.span() => "`name` not allowed with `#[new]`");
}
*python_name = Some(syn::Ident::new("__new__", Span::call_site()));
if matches!(method_attributes.as_slice(), [MethodTypeAttribute::New(_)]) {
FnType::FnNew
} else {
FnType::FnNewClass
}
[MethodTypeAttribute::New(_)] => {
set_name_to_new()?;
FnType::FnNew
}
[MethodTypeAttribute::New(_), MethodTypeAttribute::ClassMethod(span)]
| [MethodTypeAttribute::ClassMethod(span), MethodTypeAttribute::New(_)] => {
set_name_to_new()?;
FnType::FnNewClass(*span)
}
[MethodTypeAttribute::ClassMethod(_)] => {
// Add a helpful hint if the classmethod doesn't look like a classmethod
match sig.inputs.first() {
let span = match sig.inputs.first() {
// Don't actually bother checking the type of the first argument, the compiler
// will error on incorrect type.
Some(syn::FnArg::Typed(_)) => {}
Some(syn::FnArg::Typed(first_arg)) => first_arg.ty.span(),
Some(syn::FnArg::Receiver(_)) | None => bail_spanned!(
sig.inputs.span() => "Expected `cls: &PyType` as the first argument to `#[classmethod]`"
sig.paren_token.span.join() => "Expected `&PyType` or `Py<PyType>` as the first argument to `#[classmethod]`"
),
}
FnType::FnClass
};
FnType::FnClass(span)
}
[MethodTypeAttribute::Getter(_, name)] => {
if let Some(name) = name.take() {
Expand Down Expand Up @@ -516,17 +522,12 @@ impl<'a> FnSpec<'a> {
}
CallingConvention::TpNew => {
let (arg_convert, args) = impl_arg_params(self, cls, false)?;
let call = match &self.tp {
FnType::FnNew => quote! { #rust_name(#(#args),*) },
FnType::FnNewClass => {
quote! { #rust_name(_pyo3::types::PyType::from_type_ptr(py, subtype), #(#args),*) }
}
x => panic!("Only `FnNew` or `FnNewClass` may use the `TpNew` calling convention. Got: {:?}", x),
};
let self_arg = self.tp.self_arg(cls, ExtractErrorMode::Raise);
let call = quote! { #rust_name(#self_arg #(#args),*) };
quote! {
unsafe fn #ident(
py: _pyo3::Python<'_>,
subtype: *mut _pyo3::ffi::PyTypeObject,
_slf: *mut _pyo3::ffi::PyTypeObject,
_args: *mut _pyo3::ffi::PyObject,
_kwargs: *mut _pyo3::ffi::PyObject
) -> _pyo3::PyResult<*mut _pyo3::ffi::PyObject> {
Expand All @@ -535,7 +536,7 @@ impl<'a> FnSpec<'a> {
#arg_convert
let result = #call;
let initializer: _pyo3::PyClassInitializer::<#cls> = result.convert(py)?;
let cell = initializer.create_cell_from_subtype(py, subtype)?;
let cell = initializer.create_cell_from_subtype(py, _slf)?;
::std::result::Result::Ok(cell as *mut _pyo3::ffi::PyObject)
}
}
Expand Down Expand Up @@ -634,7 +635,7 @@ impl<'a> FnSpec<'a> {
FnType::Getter(_) | FnType::Setter(_) | FnType::ClassAttribute => return None,
FnType::Fn(_) => Some("self"),
FnType::FnModule(_) => Some("module"),
FnType::FnClass | FnType::FnNewClass => Some("cls"),
FnType::FnClass(_) | FnType::FnNewClass(_) => Some("cls"),
FnType::FnStatic | FnType::FnNew => None,
};

Expand Down
2 changes: 1 addition & 1 deletion pyo3-macros-backend/src/pyfunction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ pub fn impl_wrap_pyfunction(
let span = match func.sig.inputs.first() {
Some(syn::FnArg::Typed(first_arg)) => first_arg.ty.span(),
Some(syn::FnArg::Receiver(_)) | None => bail_spanned!(
func.span() => "expected `&PyModule` or `Py<PyModule>` as first argument with `pass_module`"
func.sig.paren_token.span.join() => "expected `&PyModule` or `Py<PyModule>` as first argument with `pass_module`"
),
};
method::FnType::FnModule(span)
Expand Down
6 changes: 3 additions & 3 deletions pyo3-macros-backend/src/pymethod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ pub fn gen_py_method(
&spec.get_doc(meth_attrs),
None,
)?),
(_, FnType::FnClass) => GeneratedPyMethod::Method(impl_py_method_def(
(_, FnType::FnClass(_)) => GeneratedPyMethod::Method(impl_py_method_def(
cls,
spec,
&spec.get_doc(meth_attrs),
Expand All @@ -237,7 +237,7 @@ pub fn gen_py_method(
Some(quote!(_pyo3::ffi::METH_STATIC)),
)?),
// special prototypes
(_, FnType::FnNew) | (_, FnType::FnNewClass) => {
(_, FnType::FnNew) | (_, FnType::FnNewClass(_)) => {
GeneratedPyMethod::Proto(impl_py_method_def_new(cls, spec)?)
}

Expand Down Expand Up @@ -311,7 +311,7 @@ pub fn impl_py_method_def(
let add_flags = flags.map(|flags| quote!(.flags(#flags)));
let methoddef_type = match spec.tp {
FnType::FnStatic => quote!(Static),
FnType::FnClass => quote!(Class),
FnType::FnClass(_) => quote!(Class),
_ => quote!(Method),
};
let methoddef = spec.get_methoddef(quote! { #cls::#wrapper_ident }, doc);
Expand Down
2 changes: 1 addition & 1 deletion tests/test_compile_error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
fn test_compile_errors() {
let t = trybuild::TestCases::new();

t.compile_fail("tests/ui/invalid_need_module_arg_position.rs");
t.compile_fail("tests/ui/invalid_property_args.rs");
t.compile_fail("tests/ui/invalid_proto_pymethods.rs");
t.compile_fail("tests/ui/invalid_pyclass_args.rs");
Expand All @@ -14,6 +13,7 @@ fn test_compile_errors() {
t.compile_fail("tests/ui/invalid_pyfunction_signatures.rs");
#[cfg(any(not(Py_LIMITED_API), Py_3_11))]
t.compile_fail("tests/ui/invalid_pymethods_buffer.rs");
t.compile_fail("tests/ui/invalid_pymethods_duplicates.rs");
t.compile_fail("tests/ui/invalid_pymethod_names.rs");
t.compile_fail("tests/ui/invalid_pymodule_args.rs");
t.compile_fail("tests/ui/reject_generics.rs");
Expand Down
12 changes: 0 additions & 12 deletions tests/ui/invalid_need_module_arg_position.rs

This file was deleted.

14 changes: 0 additions & 14 deletions tests/ui/invalid_need_module_arg_position.stderr

This file was deleted.

8 changes: 8 additions & 0 deletions tests/ui/invalid_pyfunctions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,12 @@ fn destructured_argument((a, b): (i32, i32)) {}
#[pyfunction]
fn function_with_required_after_option(_opt: Option<i32>, _x: i32) {}

#[pyfunction(pass_module)]
fn pass_module_but_no_arguments<'py>() {}

#[pyfunction(pass_module)]
fn first_argument_not_module<'py>(string: &str, module: &'py PyModule) -> PyResult<&'py str> {
module.name()
}

fn main() {}
21 changes: 21 additions & 0 deletions tests/ui/invalid_pyfunctions.stderr
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,24 @@ error: required arguments after an `Option<_>` argument are ambiguous
|
16 | fn function_with_required_after_option(_opt: Option<i32>, _x: i32) {}
| ^^^

error: expected `&PyModule` or `Py<PyModule>` as first argument with `pass_module`
--> tests/ui/invalid_pyfunctions.rs:19:37
|
19 | fn pass_module_but_no_arguments<'py>() {}
| ^^

error[E0277]: the trait bound `&str: From<&pyo3::prelude::PyModule>` is not satisfied
--> tests/ui/invalid_pyfunctions.rs:22:43
|
22 | fn first_argument_not_module<'py>(string: &str, module: &'py PyModule) -> PyResult<&'py str> {
| ^ the trait `From<&pyo3::prelude::PyModule>` is not implemented for `&str`
|
= help: the following other types implement trait `From<T>`:
<String as From<char>>
<String as From<Box<str>>>
<String as From<Cow<'a, str>>>
<String as From<&str>>
<String as From<&mut str>>
<String as From<&String>>
= note: required for `&pyo3::prelude::PyModule` to implement `Into<&str>`
42 changes: 16 additions & 26 deletions tests/ui/invalid_pymethods.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,22 @@ impl MyClass {
fn classmethod_with_receiver(&self) {}
}

#[pymethods]
impl MyClass {
#[classmethod]
fn classmethod_missing_argument() -> Self {
Self {}
}
}

#[pymethods]
impl MyClass {
#[classmethod]
fn classmethod_wrong_first_argument(_x: i32) -> Self {
Self {}
}
}

#[pymethods]
impl MyClass {
#[getter(x)]
Expand Down Expand Up @@ -172,32 +188,6 @@ impl MyClass {
fn method_self_by_value(self) {}
}

struct TwoNew {}

#[pymethods]
impl TwoNew {
#[new]
fn new_1() -> Self {
Self {}
}

#[new]
fn new_2() -> Self {
Self {}
}
}

struct DuplicateMethod {}

#[pymethods]
impl DuplicateMethod {
#[pyo3(name = "func")]
fn func_a(&self) {}

#[pyo3(name = "func")]
fn func_b(&self) {}
}

macro_rules! macro_invocation {
() => {};
}
Expand Down
Loading

0 comments on commit 5ac56b8

Please sign in to comment.