Skip to content

Commit

Permalink
Allow inline struct, enum, fn and mod inside of declarative modules (#…
Browse files Browse the repository at this point in the history
…3902)

* Inline struct, enum, fn and mod inside of declarative modules

* remove news fragment

---------

Co-authored-by: David Hewitt <mail@davidhewitt.dev>
  • Loading branch information
Tpt and davidhewitt committed Mar 6, 2024
1 parent b08ee4b commit fe84fed
Show file tree
Hide file tree
Showing 4 changed files with 169 additions and 32 deletions.
161 changes: 135 additions & 26 deletions pyo3-macros-backend/src/module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,19 +115,10 @@ pub fn pymodule_module_impl(mut module: syn::ItemMod) -> Result<TokenStream> {
for item in &mut *items {
match item {
Item::Use(item_use) => {
let mut is_pyo3 = false;
item_use.attrs.retain(|attr| {
let found = attr.path().is_ident("pymodule_export");
is_pyo3 |= found;
!found
});
if is_pyo3 {
let cfg_attrs = item_use
.attrs
.iter()
.filter(|attr| attr.path().is_ident("cfg"))
.cloned()
.collect::<Vec<_>>();
let is_pymodule_export =
find_and_remove_attribute(&mut item_use.attrs, "pymodule_export");
if is_pymodule_export {
let cfg_attrs = get_cfg_attributes(&item_use.attrs);
extract_use_items(
&item_use.tree,
&cfg_attrs,
Expand All @@ -137,23 +128,116 @@ pub fn pymodule_module_impl(mut module: syn::ItemMod) -> Result<TokenStream> {
}
}
Item::Fn(item_fn) => {
let mut is_module_init = false;
item_fn.attrs.retain(|attr| {
let found = attr.path().is_ident("pymodule_init");
is_module_init |= found;
!found
});
if is_module_init {
ensure_spanned!(pymodule_init.is_none(), item_fn.span() => "only one pymodule_init may be specified");
let ident = &item_fn.sig.ident;
ensure_spanned!(
!has_attribute(&item_fn.attrs, "pymodule_export"),
item.span() => "`#[pymodule_export]` may only be used on `use` statements"
);
let is_pymodule_init =
find_and_remove_attribute(&mut item_fn.attrs, "pymodule_init");
let ident = &item_fn.sig.ident;
if is_pymodule_init {
ensure_spanned!(
!has_attribute(&item_fn.attrs, "pyfunction"),
item_fn.span() => "`#[pyfunction]` cannot be used alongside `#[pymodule_init]`"
);
ensure_spanned!(pymodule_init.is_none(), item_fn.span() => "only one `#[pymodule_init]` may be specified");
pymodule_init = Some(quote! { #ident(module)?; });
} else {
bail_spanned!(item.span() => "only 'use' statements and and pymodule_init functions are allowed in #[pymodule]")
} else if has_attribute(&item_fn.attrs, "pyfunction") {
module_items.push(ident.clone());
module_items_cfg_attrs.push(get_cfg_attributes(&item_fn.attrs));
}
}
item => {
bail_spanned!(item.span() => "only 'use' statements and and pymodule_init functions are allowed in #[pymodule]")
Item::Struct(item_struct) => {
ensure_spanned!(
!has_attribute(&item_struct.attrs, "pymodule_export"),
item.span() => "`#[pymodule_export]` may only be used on `use` statements"
);
if has_attribute(&item_struct.attrs, "pyclass") {
module_items.push(item_struct.ident.clone());
module_items_cfg_attrs.push(get_cfg_attributes(&item_struct.attrs));
}
}
Item::Enum(item_enum) => {
ensure_spanned!(
!has_attribute(&item_enum.attrs, "pymodule_export"),
item.span() => "`#[pymodule_export]` may only be used on `use` statements"
);
if has_attribute(&item_enum.attrs, "pyclass") {
module_items.push(item_enum.ident.clone());
module_items_cfg_attrs.push(get_cfg_attributes(&item_enum.attrs));
}
}
Item::Mod(item_mod) => {
ensure_spanned!(
!has_attribute(&item_mod.attrs, "pymodule_export"),
item.span() => "`#[pymodule_export]` may only be used on `use` statements"
);
if has_attribute(&item_mod.attrs, "pymodule") {
module_items.push(item_mod.ident.clone());
module_items_cfg_attrs.push(get_cfg_attributes(&item_mod.attrs));
}
}
Item::ForeignMod(item) => {
ensure_spanned!(
!has_attribute(&item.attrs, "pymodule_export"),
item.span() => "`#[pymodule_export]` may only be used on `use` statements"
);
}
Item::Trait(item) => {
ensure_spanned!(
!has_attribute(&item.attrs, "pymodule_export"),
item.span() => "`#[pymodule_export]` may only be used on `use` statements"
);
}
Item::Const(item) => {
ensure_spanned!(
!has_attribute(&item.attrs, "pymodule_export"),
item.span() => "`#[pymodule_export]` may only be used on `use` statements"
);
}
Item::Static(item) => {
ensure_spanned!(
!has_attribute(&item.attrs, "pymodule_export"),
item.span() => "`#[pymodule_export]` may only be used on `use` statements"
);
}
Item::Macro(item) => {
ensure_spanned!(
!has_attribute(&item.attrs, "pymodule_export"),
item.span() => "`#[pymodule_export]` may only be used on `use` statements"
);
}
Item::ExternCrate(item) => {
ensure_spanned!(
!has_attribute(&item.attrs, "pymodule_export"),
item.span() => "`#[pymodule_export]` may only be used on `use` statements"
);
}
Item::Impl(item) => {
ensure_spanned!(
!has_attribute(&item.attrs, "pymodule_export"),
item.span() => "`#[pymodule_export]` may only be used on `use` statements"
);
}
Item::TraitAlias(item) => {
ensure_spanned!(
!has_attribute(&item.attrs, "pymodule_export"),
item.span() => "`#[pymodule_export]` may only be used on `use` statements"
);
}
Item::Type(item) => {
ensure_spanned!(
!has_attribute(&item.attrs, "pymodule_export"),
item.span() => "`#[pymodule_export]` may only be used on `use` statements"
);
}
Item::Union(item) => {
ensure_spanned!(
!has_attribute(&item.attrs, "pymodule_export"),
item.span() => "`#[pymodule_export]` may only be used on `use` statements"
);
}
_ => (),
}
}

Expand Down Expand Up @@ -355,6 +439,31 @@ fn get_pyfn_attr(attrs: &mut Vec<syn::Attribute>) -> syn::Result<Option<PyFnArgs
Ok(pyfn_args)
}

fn get_cfg_attributes(attrs: &[syn::Attribute]) -> Vec<syn::Attribute> {
attrs
.iter()
.filter(|attr| attr.path().is_ident("cfg"))
.cloned()
.collect()
}

fn find_and_remove_attribute(attrs: &mut Vec<syn::Attribute>, ident: &str) -> bool {
let mut found = false;
attrs.retain(|attr| {
if attr.path().is_ident(ident) {
found = true;
false
} else {
true
}
});
found
}

fn has_attribute(attrs: &[syn::Attribute], ident: &str) -> bool {
attrs.iter().any(|attr| attr.path().is_ident(ident))
}

enum PyModulePyO3Option {
Crate(CrateAttribute),
Name(NameAttribute),
Expand Down
36 changes: 32 additions & 4 deletions tests/test_declarative_module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ struct ValueClass {
#[pymethods]
impl ValueClass {
#[new]
fn new(value: usize) -> ValueClass {
ValueClass { value }
fn new(value: usize) -> Self {
Self { value }
}
}

Expand Down Expand Up @@ -48,6 +48,33 @@ mod declarative_module {
#[pymodule_export]
use super::{declarative_module2, double, MyError, ValueClass as Value};

#[pymodule]
mod inner {
use super::*;

#[pyfunction]
fn triple(x: usize) -> usize {
x * 3
}

#[pyclass]
struct Struct;

#[pymethods]
impl Struct {
#[new]
fn new() -> Self {
Self
}
}

#[pyclass]
enum Enum {
A,
B,
}
}

#[pymodule_init]
fn init(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add("double2", m.getattr("double")?)
Expand All @@ -65,7 +92,6 @@ mod declarative_submodule {
use super::{double, double_value};
}

/// A module written using declarative syntax.
#[pymodule]
#[pyo3(name = "declarative_module_renamed")]
mod declarative_module2 {
Expand All @@ -84,7 +110,7 @@ fn test_declarative_module() {
);

py_assert!(py, m, "m.double(2) == 4");
py_assert!(py, m, "m.double2(3) == 6");
py_assert!(py, m, "m.inner.triple(3) == 9");
py_assert!(py, m, "m.declarative_submodule.double(4) == 8");
py_assert!(
py,
Expand All @@ -97,5 +123,7 @@ fn test_declarative_module() {
py_assert!(py, m, "not hasattr(m, 'LocatedClass')");
#[cfg(not(Py_LIMITED_API))]
py_assert!(py, m, "hasattr(m, 'LocatedClass')");
py_assert!(py, m, "isinstance(m.inner.Struct(), m.inner.Struct)");
py_assert!(py, m, "isinstance(m.inner.Enum.A, m.inner.Enum)");
})
}
2 changes: 1 addition & 1 deletion tests/ui/invalid_pymodule_trait.stderr
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
error: only 'use' statements and and pymodule_init functions are allowed in #[pymodule]
error: `#[pymodule_export]` may only be used on `use` statements
--> tests/ui/invalid_pymodule_trait.rs:5:5
|
5 | #[pymodule_export]
Expand Down
2 changes: 1 addition & 1 deletion tests/ui/invalid_pymodule_two_pymodule_init.stderr
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
error: only one pymodule_init may be specified
error: only one `#[pymodule_init]` may be specified
--> tests/ui/invalid_pymodule_two_pymodule_init.rs:11:5
|
11 | fn init2(m: &PyModule) -> PyResult<()> {
Expand Down

0 comments on commit fe84fed

Please sign in to comment.