Skip to content

Commit

Permalink
Merge branch 'dev_thread_local' into inferface_eager_boxing
Browse files Browse the repository at this point in the history
  • Loading branch information
lixinqi committed Aug 4, 2021
2 parents 1713683 + b4f2e2a commit 3ba9155
Show file tree
Hide file tree
Showing 4 changed files with 214 additions and 0 deletions.
22 changes: 22 additions & 0 deletions oneflow/core/common/cached_caller.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,13 @@ limitations under the License.
#ifndef ONEFLOW_CORE_COMMON_CACHED_CALLER_H_
#define ONEFLOW_CORE_COMMON_CACHED_CALLER_H_

#include <list>
#include <tuple>
#include <thread>
#include "oneflow/core/common/function_traits.h"
#include "oneflow/core/common/hash_eq_trait_ptr.h"
#include "oneflow/core/common/maybe.h"
#include "oneflow/core/common/tuple_hash.h"

namespace oneflow {

Expand Down Expand Up @@ -77,6 +82,23 @@ std::function<Ret(const Arg&)> WithResultCached(F f) {
};
}

template<typename T, typename = void>
struct ThreadLocalStruct;

template<typename T, typename... Args>
struct ThreadLocalStruct<T (*)(Args...)> final {
template<T (*func)(Args...)>
static T Call(Args... args) {
using KeyT = std::tuple<typename std::decay<Args>::type...>;
static thread_local std::unordered_map<KeyT, T> map;
const auto& key = KeyT(args...);
auto iter = map.find(key);
if (iter == map.end()) { iter = map.emplace(key, func(args...)).first; }
return iter->second;
}
};
#define THREAD_LOCAL_CACHED(fn_ptr) &ThreadLocalStruct<decltype(fn_ptr)>::template Call<fn_ptr>

} // namespace oneflow

#endif // ONEFLOW_CORE_COMMON_CACHED_CALLER_H_
59 changes: 59 additions & 0 deletions oneflow/core/common/cached_caller_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/common/cached_caller.h"
#include "oneflow/core/common/util.h"

namespace oneflow {
namespace test {

Maybe<int> Inc(int x) { return x + 1; }

Maybe<int> IncByConstRef(const int& x) { return x + 1; }

TEST(ThreadLocal, scalar) {
auto* CachedInc = THREAD_LOCAL_CACHED(&Inc);

int x = CHECK_JUST(CachedInc(0));
ASSERT_EQ(x, 1);
}

TEST(ThreadLocal, const_ref) {
auto* CachedIncByConstRef = THREAD_LOCAL_CACHED(&IncByConstRef);

int x = CHECK_JUST(CachedIncByConstRef(0));
ASSERT_EQ(x, 1);
}

namespace {

struct Foo {
static Maybe<Foo> New(int x) { return std::shared_ptr<Foo>(new Foo{x}); }

int x;
};

} // namespace

TEST(ThreadLocal, _class) {
auto* CachedFooNew = THREAD_LOCAL_CACHED(&Foo::New);
const auto& foo = CHECK_JUST(CachedFooNew(10));
const auto& bar = CHECK_JUST(CachedFooNew(10));
ASSERT_EQ(foo->x, 10);
ASSERT_TRUE(foo == bar);
}

} // namespace test
} // namespace oneflow
16 changes: 16 additions & 0 deletions oneflow/core/common/thread_local_cache.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_COMMON_
117 changes: 117 additions & 0 deletions oneflow/core/common/tuple_hash.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_COMMON_TUPLE_HASH_H_
#define ONEFLOW_CORE_COMMON_TUPLE_HASH_H_

#include <tuple>
#include <functional>

namespace std {

template<typename T>
struct hash<std::tuple<T>> final {
size_t operator()(const std::tuple<T>& val) const { return std::hash<T>()(std::get<0>(val)); }
};

template<typename T0, typename T1>
struct hash<std::tuple<T0, T1>> final {
size_t operator()(const std::tuple<T0, T1>& val) const {
return std::hash<T0>()(std::get<0>(val)) ^ std::hash<T1>()(std::get<1>(val));
}
};

template<typename T0, typename T1, typename T2>
struct hash<std::tuple<T0, T1, T2>> final {
size_t operator()(const std::tuple<T0, T1, T2>& val) const {
return std::hash<T0>()(std::get<0>(val)) ^ std::hash<T1>()(std::get<1>(val))
^ std::hash<T2>()(std::get<2>(val));
}
};

template<typename T0, typename T1, typename T2, typename T3>
struct hash<std::tuple<T0, T1, T2, T3>> final {
size_t operator()(const std::tuple<T0, T1, T2, T3>& val) const {
return std::hash<T0>()(std::get<0>(val)) ^ std::hash<T1>()(std::get<1>(val))
^ std::hash<T2>()(std::get<2>(val)) ^ std::hash<T3>()(std::get<3>(val));
}
};

template<typename T0, typename T1, typename T2, typename T3, typename T4>
struct hash<std::tuple<T0, T1, T2, T3, T4>> final {
size_t operator()(const std::tuple<T0, T1, T2, T3, T4>& val) const {
return std::hash<T0>()(std::get<0>(val)) ^ std::hash<T1>()(std::get<1>(val))
^ std::hash<T2>()(std::get<2>(val)) ^ std::hash<T3>()(std::get<3>(val))
^ std::hash<T4>()(std::get<4>(val));
}
};

template<typename T0, typename T1, typename T2, typename T3, typename T4, typename T5>
struct hash<std::tuple<T0, T1, T2, T3, T4, T5>> final {
size_t operator()(const std::tuple<T0, T1, T2, T3, T4, T5>& val) const {
return std::hash<T0>()(std::get<0>(val)) ^ std::hash<T1>()(std::get<1>(val))
^ std::hash<T2>()(std::get<2>(val)) ^ std::hash<T3>()(std::get<3>(val))
^ std::hash<T4>()(std::get<4>(val)) ^ std::hash<T5>()(std::get<5>(val));
}
};

template<typename T0, typename T1, typename T2, typename T3, typename T4, typename T5, typename T6>
struct hash<std::tuple<T0, T1, T2, T3, T4, T5, T6>> final {
size_t operator()(const std::tuple<T0, T1, T2, T3, T4, T5, T6>& val) const {
return std::hash<T0>()(std::get<0>(val)) ^ std::hash<T1>()(std::get<1>(val))
^ std::hash<T2>()(std::get<2>(val)) ^ std::hash<T3>()(std::get<3>(val))
^ std::hash<T4>()(std::get<4>(val)) ^ std::hash<T5>()(std::get<5>(val))
^ std::hash<T6>()(std::get<6>(val));
}
};

template<typename T0, typename T1, typename T2, typename T3, typename T4, typename T5, typename T6,
typename T7>
struct hash<std::tuple<T0, T1, T2, T3, T4, T5, T6, T7>> final {
size_t operator()(const std::tuple<T0, T1, T2, T3, T4, T5, T6, T7>& val) const {
return std::hash<T0>()(std::get<0>(val)) ^ std::hash<T1>()(std::get<1>(val))
^ std::hash<T2>()(std::get<2>(val)) ^ std::hash<T3>()(std::get<3>(val))
^ std::hash<T4>()(std::get<4>(val)) ^ std::hash<T5>()(std::get<5>(val))
^ std::hash<T6>()(std::get<6>(val)) ^ std::hash<T7>()(std::get<7>(val));
}
};

template<typename T0, typename T1, typename T2, typename T3, typename T4, typename T5, typename T6,
typename T7, typename T8>
struct hash<std::tuple<T0, T1, T2, T3, T4, T5, T6, T7, T8>> final {
size_t operator()(const std::tuple<T0, T1, T2, T3, T4, T5, T6, T7, T8>& val) const {
return std::hash<T0>()(std::get<0>(val)) ^ std::hash<T1>()(std::get<1>(val))
^ std::hash<T2>()(std::get<2>(val)) ^ std::hash<T3>()(std::get<3>(val))
^ std::hash<T4>()(std::get<4>(val)) ^ std::hash<T5>()(std::get<5>(val))
^ std::hash<T6>()(std::get<6>(val)) ^ std::hash<T7>()(std::get<7>(val))
^ std::hash<T8>()(std::get<8>(val));
}
};

template<typename T0, typename T1, typename T2, typename T3, typename T4, typename T5, typename T6,
typename T7, typename T8, typename T9>
struct hash<std::tuple<T0, T1, T2, T3, T4, T5, T6, T7, T8, T9>> final {
size_t operator()(const std::tuple<T0, T1, T2, T3, T4, T5, T6, T7, T8, T9>& val) const {
return std::hash<T0>()(std::get<0>(val)) ^ std::hash<T1>()(std::get<1>(val))
^ std::hash<T2>()(std::get<2>(val)) ^ std::hash<T3>()(std::get<3>(val))
^ std::hash<T4>()(std::get<4>(val)) ^ std::hash<T5>()(std::get<5>(val))
^ std::hash<T6>()(std::get<6>(val)) ^ std::hash<T7>()(std::get<7>(val))
^ std::hash<T8>()(std::get<8>(val)) ^ std::hash<T9>()(std::get<9>(val));
}
};

} // namespace std

#endif // ONEFLOW_CORE_COMMON_TUPLE_HASH_H_

0 comments on commit 3ba9155

Please sign in to comment.