Skip to content

Commit

Permalink
fix sneaky use-after-free
Browse files Browse the repository at this point in the history
  • Loading branch information
ericniebler committed May 26, 2024
1 parent 5f9b064 commit 61f3d08
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 15 deletions.
2 changes: 1 addition & 1 deletion include/stdexec/__detail/__schedule_from.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ namespace stdexec {
using receiver_concept = receiver_t;

void set_value() noexcept {
__state_->__data_.visit(__make_visitor_fn(__state_), __state_->__data_);
__state_->__data_.visit(__schfr::__make_visitor_fn(__state_), __state_->__data_);
}

template <class _Error>
Expand Down
34 changes: 20 additions & 14 deletions include/stdexec/__detail/__variant.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,19 @@ namespace stdexec {
STDEXEC_ATTRIBUTE((host, device))
void
__destroy() noexcept {
if (__index != __variant_npos) {
auto __local_index = std::exchange(__index, __variant_npos);
if (__variant_npos != __local_index) {
#if STDEXEC_NVHPC()
// Unknown nvc++ name lookup bug
((_Idx == __index ? get<_Idx>()._Ts::~_Ts() : void(0)), ...);
((_Idx == __local_index ? reinterpret_cast<const __at<_Idx> *>(__storage)->_Ts::~_Ts()
: void(0)),
...);
#else
// casting the destructor expression to void is necessary for MSVC in
// /permissive- mode.
((_Idx == __index ? void(get<_Idx>().~_Ts()) : void(0)), ...);
((_Idx == __local_index ? void(reinterpret_cast<const __at<_Idx> *>(__storage)->~_Ts())
: void(0)),
...);
#endif
}
}
Expand Down Expand Up @@ -94,54 +99,55 @@ namespace stdexec {
template <class _Ty, class... _As>
STDEXEC_ATTRIBUTE((host, device))
_Ty &
emplace(_As &&...as) //
emplace(_As &&...__as) //
noexcept(__nothrow_constructible_from<_Ty, _As...>) {
constexpr std::size_t __new_index = stdexec::__index_of<_Ty, _Ts...>();
static_assert(__new_index != __variant_npos, "Type not in variant");

__destroy();
::new (__storage) _Ty{static_cast<_As &&>(as)...};
::new (__storage) _Ty{static_cast<_As &&>(__as)...};
__index = __new_index;
return *reinterpret_cast<_Ty *>(__storage);
}

template <std::size_t _Ny, class... _As>
STDEXEC_ATTRIBUTE((host, device))
__at<_Ny> &
emplace(_As &&...as) //
emplace(_As &&...__as) //
noexcept(__nothrow_constructible_from<__at<_Ny>, _As...>) {
static_assert(_Ny < sizeof...(_Ts), "variant index is too large");

__destroy();
::new (__storage) __at<_Ny>{static_cast<_As &&>(as)...};
::new (__storage) __at<_Ny>{static_cast<_As &&>(__as)...};
__index = _Ny;
return *reinterpret_cast<__at<_Ny> *>(__storage);
}

template <class _Fn, class... _As>
STDEXEC_ATTRIBUTE((host, device))
auto
emplace_from(_Fn &&fn, _As &&...as) //
emplace_from(_Fn &&__fn, _As &&...__as) //
noexcept(__nothrow_callable<_Fn, _As...>) -> __call_result_t<_Fn, _As...> & {
using __result_t = __call_result_t<_Fn, _As...>;
constexpr std::size_t __new_index = stdexec::__index_of<__result_t, _Ts...>();
static_assert(__new_index != __variant_npos, "Type not in variant");

__destroy();
::new (__storage) __result_t(static_cast<_Fn &&>(fn)(static_cast<_As &&>(as)...));
::new (__storage) __result_t(static_cast<_Fn &&>(__fn)(static_cast<_As &&>(__as)...));
__index = __new_index;
return *reinterpret_cast<__result_t *>(__storage);
}

template <class _Fn, class _Self, class... _As>
STDEXEC_ATTRIBUTE((host, device))
static void
visit(_Fn &&fn, _Self &&self, _As &&...as) //
visit(_Fn &&__fn, _Self &&__self, _As &&...__as) //
noexcept((__nothrow_callable<_Fn, _As..., __copy_cvref_t<_Self, _Ts>> && ...)) {
STDEXEC_ASSERT(self.__index != __variant_npos);
((_Idx == self.__index ? static_cast<_Fn &&>(fn)(
static_cast<_As &&>(as)..., static_cast<_Self &&>(self).template get<_Idx>())
: void()),
STDEXEC_ASSERT(__self.__index != __variant_npos);
auto __index = __self.__index; // make it local so we don't access it after it's deleted.
((_Idx == __index ? static_cast<_Fn &&>(__fn)(
static_cast<_As &&>(__as)..., static_cast<_Self &&>(__self).template get<_Idx>())
: void()),
...);
}

Expand Down

0 comments on commit 61f3d08

Please sign in to comment.