Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Autodiff cleanups #138627

Merged
merged 7 commits into from
Mar 21, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
153 changes: 95 additions & 58 deletions compiler/rustc_builtin_macros/src/autodiff.rs
Original file line number Diff line number Diff line change
@@ -25,6 +25,16 @@ mod llvm_enzyme {

use crate::errors;

pub(crate) fn outer_normal_attr(
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought about putting it into the rustc_ast crate, but I couldn't find any other users and given jana's attribute rework I don't think it makes sense to add extra functions there, especially given someone is looking into moving autodiff over to the new attribute handling infra.

kind: &P<rustc_ast::NormalAttr>,
id: rustc_ast::AttrId,
span: Span,
) -> rustc_ast::Attribute {
let style = rustc_ast::AttrStyle::Outer;
let kind = rustc_ast::AttrKind::Normal(kind.clone());
rustc_ast::Attribute { kind, id, style, span }
}

// If we have a default `()` return type or explicitley `()` return type,
// then we often can skip doing some work.
fn has_ret(ty: &FnRetTy) -> bool {
@@ -223,20 +233,8 @@ mod llvm_enzyme {
.filter(|a| **a == DiffActivity::Active || **a == DiffActivity::ActiveOnly)
.count() as u32;
let (d_sig, new_args, idents, errored) = gen_enzyme_decl(ecx, &sig, &x, span);
let new_decl_span = d_sig.span;
let d_body = gen_enzyme_body(
ecx,
&x,
n_active,
&sig,
&d_sig,
primal,
&new_args,
span,
sig_span,
new_decl_span,
idents,
errored,
ecx, &x, n_active, &sig, &d_sig, primal, &new_args, span, sig_span, idents, errored,
);
let d_ident = first_ident(&meta_item_vec[0]);

@@ -268,36 +266,39 @@ mod llvm_enzyme {
};
let inline_never_attr = P(ast::NormalAttr { item: inline_item, tokens: None });
let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id();
let attr: ast::Attribute = ast::Attribute {
kind: ast::AttrKind::Normal(rustc_ad_attr.clone()),
id: new_id,
style: ast::AttrStyle::Outer,
span,
};
let attr = outer_normal_attr(&rustc_ad_attr, new_id, span);
let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id();
let inline_never: ast::Attribute = ast::Attribute {
kind: ast::AttrKind::Normal(inline_never_attr),
id: new_id,
style: ast::AttrStyle::Outer,
span,
};
let inline_never = outer_normal_attr(&inline_never_attr, new_id, span);

// We're avoid duplicating the attributes `#[rustc_autodiff]` and `#[inline(never)]`.
fn same_attribute(attr: &ast::AttrKind, item: &ast::AttrKind) -> bool {
match (attr, item) {
(ast::AttrKind::Normal(a), ast::AttrKind::Normal(b)) => {
let a = &a.item.path;
let b = &b.item.path;
a.segments.len() == b.segments.len()
&& a.segments.iter().zip(b.segments.iter()).all(|(a, b)| a.ident == b.ident)
}
_ => false,
}
}

// Don't add it multiple times:
let orig_annotatable: Annotatable = match item {
Annotatable::Item(ref mut iitem) => {
if !iitem.attrs.iter().any(|a| a.id == attr.id) {
if !iitem.attrs.iter().any(|a| same_attribute(&a.kind, &attr.kind)) {
iitem.attrs.push(attr);
}
if !iitem.attrs.iter().any(|a| a.id == inline_never.id) {
if !iitem.attrs.iter().any(|a| same_attribute(&a.kind, &inline_never.kind)) {
iitem.attrs.push(inline_never.clone());
}
Annotatable::Item(iitem.clone())
}
Annotatable::AssocItem(ref mut assoc_item, i @ Impl) => {
if !assoc_item.attrs.iter().any(|a| a.id == attr.id) {
if !assoc_item.attrs.iter().any(|a| same_attribute(&a.kind, &attr.kind)) {
assoc_item.attrs.push(attr);
}
if !assoc_item.attrs.iter().any(|a| a.id == inline_never.id) {
if !assoc_item.attrs.iter().any(|a| same_attribute(&a.kind, &inline_never.kind)) {
assoc_item.attrs.push(inline_never.clone());
}
Annotatable::AssocItem(assoc_item.clone(), i)
@@ -312,13 +313,7 @@ mod llvm_enzyme {
delim: rustc_ast::token::Delimiter::Parenthesis,
tokens: ts,
});
let d_attr: ast::Attribute = ast::Attribute {
kind: ast::AttrKind::Normal(rustc_ad_attr.clone()),
id: new_id,
style: ast::AttrStyle::Outer,
span,
};

let d_attr = outer_normal_attr(&rustc_ad_attr, new_id, span);
let d_annotatable = if is_impl {
let assoc_item: AssocItemKind = ast::AssocItemKind::Fn(asdf);
let d_fn = P(ast::AssocItem {
@@ -359,30 +354,27 @@ mod llvm_enzyme {
ty
}

/// We only want this function to type-check, since we will replace the body
/// later on llvm level. Using `loop {}` does not cover all return types anymore,
/// so instead we build something that should pass. We also add a inline_asm
/// line, as one more barrier for rustc to prevent inlining of this function.
/// FIXME(ZuseZ4): We still have cases of incorrect inlining across modules, see
/// <https://github.com/EnzymeAD/rust/issues/173>, so this isn't sufficient.
/// It also triggers an Enzyme crash if we due to a bug ever try to differentiate
/// this function (which should never happen, since it is only a placeholder).
/// Finally, we also add back_box usages of all input arguments, to prevent rustc
/// from optimizing any arguments away.
fn gen_enzyme_body(
// Will generate a body of the type:
// ```
// {
// unsafe {
// asm!("NOP");
// }
// ::core::hint::black_box(primal(args));
// ::core::hint::black_box((args, ret));
// <This part remains to be done by following function>
// }
// ```
fn init_body_helper(
ecx: &ExtCtxt<'_>,
x: &AutoDiffAttrs,
n_active: u32,
sig: &ast::FnSig,
d_sig: &ast::FnSig,
span: Span,
primal: Ident,
new_names: &[String],
span: Span,
sig_span: Span,
new_decl_span: Span,
idents: Vec<Ident>,
idents: &[Ident],
errored: bool,
) -> P<ast::Block> {
) -> (P<ast::Block>, P<ast::Expr>, P<ast::Expr>, P<ast::Expr>) {
let blackbox_path = ecx.std_path(&[sym::hint, sym::black_box]);
let noop = ast::InlineAsm {
asm_macro: ast::AsmMacro::Asm,
@@ -431,6 +423,51 @@ mod llvm_enzyme {
}
body.stmts.push(ecx.stmt_semi(black_box_remaining_args));

(body, primal_call, black_box_primal_call, blackbox_call_expr)
}

/// We only want this function to type-check, since we will replace the body
/// later on llvm level. Using `loop {}` does not cover all return types anymore,
/// so instead we manually build something that should pass the type checker.
/// We also add a inline_asm line, as one more barrier for rustc to prevent inlining
/// or const propagation. inline_asm will also triggers an Enzyme crash if due to another
/// bug would ever try to accidentially differentiate this placeholder function body.
/// Finally, we also add back_box usages of all input arguments, to prevent rustc
/// from optimizing any arguments away.
fn gen_enzyme_body(
ecx: &ExtCtxt<'_>,
x: &AutoDiffAttrs,
n_active: u32,
sig: &ast::FnSig,
d_sig: &ast::FnSig,
primal: Ident,
new_names: &[String],
span: Span,
sig_span: Span,
idents: Vec<Ident>,
errored: bool,
) -> P<ast::Block> {
let new_decl_span = d_sig.span;

// Just adding some default inline-asm and black_box usages to prevent early inlining
// and optimizations which alter the function signature.
//
// The bb_primal_call is the black_box call of the primal function. We keep it around,
// since it has the convenient property of returning the type of the primal function,
// Remember, we only care to match types here.
// No matter which return we pick, we always wrap it into a std::hint::black_box call,
// to prevent rustc from propagating it into the caller.
let (mut body, primal_call, bb_primal_call, bb_call_expr) = init_body_helper(
ecx,
span,
primal,
new_names,
sig_span,
new_decl_span,
&idents,
errored,
);

if !has_ret(&d_sig.decl.output) {
// there is no return type that we have to match, () works fine.
return body;
@@ -442,7 +479,7 @@ mod llvm_enzyme {

if primal_ret && n_active == 0 && x.mode.is_rev() {
// We only have the primal ret.
body.stmts.push(ecx.stmt_expr(black_box_primal_call));
body.stmts.push(ecx.stmt_expr(bb_primal_call));
return body;
}

@@ -534,11 +571,11 @@ mod llvm_enzyme {
return body;
}
[arg] => {
ret = ecx.expr_call(new_decl_span, blackbox_call_expr, thin_vec![arg.clone()]);
ret = ecx.expr_call(new_decl_span, bb_call_expr, thin_vec![arg.clone()]);
}
args => {
let ret_tuple: P<ast::Expr> = ecx.expr_tuple(span, args.into());
ret = ecx.expr_call(new_decl_span, blackbox_call_expr, thin_vec![ret_tuple]);
ret = ecx.expr_call(new_decl_span, bb_call_expr, thin_vec![ret_tuple]);
}
}
assert!(has_ret(&d_sig.decl.output));
@@ -551,7 +588,7 @@ mod llvm_enzyme {
ecx: &ExtCtxt<'_>,
span: Span,
primal: Ident,
idents: Vec<Ident>,
idents: &[Ident],
) -> P<ast::Expr> {
let has_self = idents.len() > 0 && idents[0].name == kw::SelfLower;
if has_self {
Loading
Loading