Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

cpp interface improvements #1789

Merged
merged 4 commits into from
Apr 2, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 32 additions & 10 deletions enzyme/Enzyme/Clang/include_utils.td
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ Return __enzyme_fwddiff(T...);

namespace enzyme {

#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wmissing-braces"

struct nodiff{};

template<bool ReturnPrimal = false>
Expand All @@ -45,31 +48,42 @@ namespace enzyme {
template < typename T >
struct Active{
T value;
Active(T &&v) : value(v) {}
operator T&() { return value; }
};

template < typename T >
struct Duplicated{
T value;
T shadow;
Duplicated(T &&v, T&& s) : value(v), shadow(s) {}
};

template < typename T >
struct DuplicatedNoNeed{
T value;
T shadow;
DuplicatedNoNeed(T &&v, T&& s) : value(v), shadow(s) {}
};

template < typename T >
struct Const{
T value;
Const(T &&v) : value(v) {}
operator T&() { return value; }
};

// CTAD available in C++17 or later
#if __cplusplus >= 201703L
template < typename T >
Active(T) -> Active<T>;

template < typename T >
Const(T) -> Const<T>;

template < typename T >
Duplicated(T,T) -> Duplicated<T>;

template < typename T >
DuplicatedNoNeed(T,T) -> DuplicatedNoNeed<T>;
#endif

template < typename T >
struct type_info {
static constexpr bool is_active = false;
Expand Down Expand Up @@ -343,7 +357,9 @@ namespace enzyme {
using return_type = typename autodiff_return<DiffMode, RetActivity, arg_types...>::type;
return autodiff_impl<return_type, DiffMode, function, RetActivity>(impl::forward<function>(f), enzyme::tuple_cat(enzyme::tuple{detail::ret_used<DiffMode, RetActivity>::value}, expand_args(args)...));
}
}
#pragma clang diagnostic pop

} // namespace enzyme
}]>;

def : Headers<"/enzymeroot/enzyme/type_traits", [{
Expand Down Expand Up @@ -410,13 +426,17 @@ def : Headers<"/enzymeroot/enzyme/tuple", [{
// constexpr support for std::tuple). Owning the implementation lets
// us add __host__ __device__ annotations to any part of it

#include <cstddef> // for std::size_t
#include <utility> // for std::integer_sequence

#include <enzyme/type_traits>

#define _NOEXCEPT noexcept
namespace enzyme {

#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wmissing-braces"

template <int i>
struct Index {};

Expand Down Expand Up @@ -468,10 +488,10 @@ template <typename Tuple>
struct tuple_size;

template <typename... T>
struct tuple_size<tuple<T...>> : std::integral_constant<size_t, sizeof...(T)> {};
struct tuple_size<tuple<T...>> : std::integral_constant<std::size_t, sizeof...(T)> {};

template <typename Tuple>
static constexpr size_t tuple_size_v = tuple_size<Tuple>::value;
static constexpr std::size_t tuple_size_v = tuple_size<Tuple>::value;

template <typename... T>
__attribute__((always_inline))
Expand All @@ -484,7 +504,7 @@ namespace impl {
template <typename index_seq>
struct make_tuple_from_fwd_tuple;

template <size_t... indices>
template <std::size_t... indices>
struct make_tuple_from_fwd_tuple<std::index_sequence<indices...>> {
template <typename FWD_TUPLE>
__attribute__((always_inline))
Expand All @@ -499,12 +519,12 @@ struct concat_with_fwd_tuple;
template < typename Tuple >
using iseq = std::make_index_sequence<tuple_size_v< enzyme::remove_cvref_t< Tuple > > >;

template <size_t... fwd_indices, size_t... indices>
template <std::size_t... fwd_indices, std::size_t... indices>
struct concat_with_fwd_tuple<std::index_sequence<fwd_indices...>, std::index_sequence<indices...>> {
template <typename FWD_TUPLE, typename TUPLE>
__attribute__((always_inline))
static constexpr auto f(FWD_TUPLE&& fwd, TUPLE&& t) {
return forward_as_tuple(get<fwd_indices>(impl::forward<FWD_TUPLE>(fwd))..., get<indices>(impl::forward<TUPLE>(t))...);
return enzyme::forward_as_tuple(get<fwd_indices>(impl::forward<FWD_TUPLE>(fwd))..., get<indices>(impl::forward<TUPLE>(t))...);
}
};

Expand All @@ -528,6 +548,8 @@ constexpr auto tuple_cat(Tuples&&... tuples) {
return impl::tuple_cat(impl::forward<Tuples>(tuples)...);
}

#pragma clang diagnostic pop

} // namespace enzyme
#undef _NOEXCEPT
}]>;
Expand Down
1 change: 1 addition & 0 deletions enzyme/test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ configure_lit_site_cfg(
set(ENZYME_TEST_DEPS LLVMEnzyme-${LLVM_VERSION_MAJOR})

add_subdirectory(ActivityAnalysis)
add_subdirectory(CppInterface)
add_subdirectory(TypeAnalysis)
add_subdirectory(Enzyme)
if (${Clang_FOUND})
Expand Down
14 changes: 14 additions & 0 deletions enzyme/test/CppInterface/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
cmake_minimum_required(VERSION 3.16)

project(CppInterfaceTests LANGUAGES CXX)

file(GLOB_RECURSE cpp_tests ${PROJECT_SOURCE_DIR}/*.cpp)

enable_testing()

foreach(filename ${cpp_tests})
get_filename_component(testname ${filename} NAME_WE)
add_executable(${testname} ${filename})
target_link_libraries(${testname} PUBLIC ClangEnzymeFlags)
add_test(${testname} ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/${testname})
endforeach(filename ${cpp_tests})
32 changes: 32 additions & 0 deletions enzyme/test/CppInterface/enzyme_types.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
#include <utility>
#include <iostream>

#include <enzyme/enzyme>

struct move_only_type {
move_only_type() {};
move_only_type(move_only_type && obj) {
std::cout << "calling move_only_type move ctor" << std::endl;
};
};

int main() {
double z = 3.0;
double & rz = z;
move_only_type foo{};

[[maybe_unused]] enzyme::Const<double> c1{3.0};
[[maybe_unused]] enzyme::Const<double> c2{z};
[[maybe_unused]] enzyme::Const<double&> c3{rz};
[[maybe_unused]] enzyme::Const<double&> c4{z};
[[maybe_unused]] enzyme::Const<move_only_type> c5{move_only_type{}};
[[maybe_unused]] enzyme::Const<move_only_type> c6{std::move(foo)};

// CTAD examples for C++17 and later
#if __cplusplus >= 201703L
[[maybe_unused]] enzyme::Const d1{3.0};
[[maybe_unused]] enzyme::Const d2{z};
[[maybe_unused]] enzyme::Const d3{rz};
#endif

}
19 changes: 19 additions & 0 deletions enzyme/test/CppInterface/forward_mode.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#include <iostream>

#include "minimal_test_framework.hpp"

#include <enzyme/enzyme>

double square(double x) { return x * x; }

double dsquare(double x) {
return enzyme::autodiff<enzyme::Forward>((void*) square, enzyme::Duplicated{x, 1.0});
// return 1.0;
}

int main() {
for(double i=1; i<5; i++) {
EXPECT(dsquare(i) == 2 * i);
}
return any_tests_failed;
}
26 changes: 26 additions & 0 deletions enzyme/test/CppInterface/gh_issue_1785.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
#include <memory>
#include <vector>
#include <enzyme/enzyme>

int main() {
auto elasticity_kernel = [](const std::vector<double> &dudxi,
const std::vector<double> &J,
const double &w)
{
auto r = dudxi;
return r;
};

std::vector<double> dudxi(4), s_dudxi(4), J(4);
double w = 1.0;

enzyme::get<0>
(enzyme::autodiff<enzyme::Forward,
enzyme::DuplicatedNoNeed<std::vector<double>>>
(+elasticity_kernel,
enzyme::Duplicated<std::vector<double> *>(&dudxi, &s_dudxi),
enzyme::Const<std::vector<double> *>{&J},
enzyme::Const<double*>{&w}));

return 0;
}
9 changes: 9 additions & 0 deletions enzyme/test/CppInterface/minimal_test_framework.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
#pragma once

bool any_tests_failed = false;

#define EXPECT(boolean) \
if (!(boolean)) { \
std::cout << "Test failure on " << __FILE__ << ":" << __LINE__ << ", EXPECT(" << (#boolean) << ")" << std::endl; \
any_tests_failed = true; \
}
Loading