Skip to content

Commit

Permalink
refine
Browse files Browse the repository at this point in the history
  • Loading branch information
VertexC committed Jul 5, 2021
1 parent 60c1bd2 commit f0481a9
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions oneflow/core/framework/random_generator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,11 +87,11 @@ const std::shared_ptr<DeviceGeneratorImpl<device_type>>& GetDefaultDeviceGenerat

template<DeviceType device_type>
Maybe<DeviceGeneratorImpl<device_type>> TryGetDeviceGenerator(
const std::shared_ptr<GeneratorImpl>& generator) {
if (auto auto_gen = std::dynamic_pointer_cast<AutoGeneratorImpl>(generator)) {
const std::shared_ptr<GeneratorImpl>& gen_impl) {
if (auto auto_gen = std::dynamic_pointer_cast<AutoGeneratorImpl>(gen_impl)) {
return auto_gen->template GetDeviceGenerator<device_type>();
}
auto device_gen = std::dynamic_pointer_cast<DeviceGeneratorImpl<device_type>>(generator);
auto device_gen = std::dynamic_pointer_cast<DeviceGeneratorImpl<device_type>>(gen_impl);
CHECK_NOTNULL_OR_RETURN(device_gen);
return device_gen;
}
Expand All @@ -113,5 +113,8 @@ template Maybe<DeviceGeneratorImpl<DeviceType::kCPU>> TryGetDeviceGenerator(
template Maybe<DeviceGeneratorImpl<DeviceType::kGPU>> TryGetDeviceGenerator(
const std::shared_ptr<Generator>& generator);

template class DeviceGeneratorImpl<DeviceType::kCPU>;
template class DeviceGeneratorImpl<DeviceType::kGPU>;

} // namespace one
} // namespace oneflow

0 comments on commit f0481a9

Please sign in to comment.