Skip to content

Commit

Permalink
handle exeception given invalid device
Browse files Browse the repository at this point in the history
  • Loading branch information
VertexC committed Jul 2, 2021
1 parent bf58b61 commit 04a9797
Show file tree
Hide file tree
Showing 6 changed files with 51 additions and 27 deletions.
7 changes: 5 additions & 2 deletions oneflow/api/python/framework/random_generator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,12 @@ namespace oneflow {

ONEFLOW_API_PYBIND11_MODULE("", m) {
m.def("manual_seed", [](uint64_t seed) { return one::manual_seed(seed); });
m.def("create_generator",
[](const std::string& device) { return one::Generator::New(device).GetPtrOrThrow(); });
m.def("create_generator", [](const std::string& device, uint64_t seed) {
return one::Generator::New(device, seed).GetPtrOrThrow();
});
py::class_<one::Generator, std::shared_ptr<one::Generator>>(m, "Generator")
.def(py::init<std::string>())
.def(py::init<std::string, const int64_t>())
.def("manual_seed", &one::Generator::set_seed)
.def("initial_seed", &one::Generator::get_seed);
}
Expand Down
24 changes: 24 additions & 0 deletions oneflow/core/framework/random_generator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,5 +60,29 @@ const std::shared_ptr<AutoGeneratorImpl>& GetDefaultAutoGenerator() {
return generator;
}

template<DeviceType device_type>
std::shared_ptr<DeviceGeneratorImpl<device_type>> CreateDeviceGenerator(uint64_t seed) {
return std::make_shared<DeviceGeneratorImpl<device_type>>(seed);
}

template<DeviceType device_type>
const std::shared_ptr<DeviceGeneratorImpl<device_type>>& GetDefaultDeviceGenerator() {
static auto generator = CreateDeviceGenerator<device_type>(getNonDeterministicRandom());
return generator;
}

template<DeviceType device_type>
const Maybe<DeviceGeneratorImpl<device_type>> TryGetDeviceGenerator(
const std::shared_ptr<GeneratorImpl>& generator) {
if (generator->device_type() == "auto") {
const auto auto_gen = std::dynamic_pointer_cast<AutoGeneratorImpl>(generator);
CHECK_NOTNULL_OR_RETURN(auto_gen);
return auto_gen->template GetDeviceGenerator<device_type>();
}
const auto device_gen = std::dynamic_pointer_cast<DeviceGeneratorImpl<device_type>>(generator);
CHECK_NOTNULL_OR_RETURN(device_gen);
return device_gen;
}

} // namespace one
} // namespace oneflow
1 change: 0 additions & 1 deletion oneflow/core/framework/random_generator.cu
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ void DeviceGeneratorImpl<DeviceType::kGPU>::CudaRandInit(uint64_t seed) {
DeviceGeneratorImpl<DeviceType::kGPU>::DeviceGeneratorImpl(uint64_t seed)
: GeneratorImpl(seed, "cuda") {
cudaDeviceProp prop;
// FIXME: will this cause issue by always using cuda:0's property?
OF_CUDA_CHECK(cudaGetDeviceProperties(&prop, 0));
block_num_ = prop.multiProcessorCount;
thread_num_ = GetThreadNum(prop);
Expand Down
39 changes: 17 additions & 22 deletions oneflow/core/framework/random_generator.h
Original file line number Diff line number Diff line change
Expand Up @@ -124,10 +124,18 @@ class DeviceGeneratorImpl<DeviceType::kGPU> : public GeneratorImpl {

class Generator final {
public:
explicit Generator(std::string device, uint64_t seed) { init(device, seed); }
explicit Generator(std::string device) { init(device, default_rng_seed_val); }

void init(std::string device, uint64_t seed) {
Generator() = default;
static Maybe<Generator> New(const std::string& device) {
std::shared_ptr<Generator> generator(new Generator);
JUST(generator->Init(device, default_rng_seed_val));
return generator;
}
static Maybe<Generator> New(const std::string& device, uint64_t seed) {
std::shared_ptr<Generator> generator(new Generator);
JUST(generator->Init(device, seed));
return generator;
}
Maybe<void> Init(const std::string& device, uint64_t seed) {
if (device == "cpu") {
gen_impl_ = std::make_shared<DeviceGeneratorImpl<DeviceType::kCPU>>(seed);
}
Expand All @@ -139,8 +147,9 @@ class Generator final {
else if (device == "auto") {
gen_impl_ = std::make_shared<AutoGeneratorImpl>(seed);
} else {
UNIMPLEMENTED() << " device unimplemented, device name: " << device;
UNIMPLEMENTED_THEN_RETURN() << " device unimplemented, device name: " << device;
}
return Maybe<void>::Ok();
}

void set_seed(uint64_t seed) { gen_impl_->set_seed(seed); }
Expand All @@ -162,28 +171,14 @@ std::shared_ptr<AutoGeneratorImpl> CreateAutoGenerator(uint64_t seed);
const std::shared_ptr<AutoGeneratorImpl>& GetDefaultAutoGenerator();

template<DeviceType device_type>
std::shared_ptr<DeviceGeneratorImpl<device_type>> CreateDeviceGenerator(uint64_t seed) {
return std::make_shared<DeviceGeneratorImpl<device_type>>(seed);
}
std::shared_ptr<DeviceGeneratorImpl<device_type>> CreateDeviceGenerator(uint64_t seed);

template<DeviceType device_type>
const std::shared_ptr<DeviceGeneratorImpl<device_type>>& GetDefaultDeviceGenerator() {
static auto generator = CreateDeviceGenerator<device_type>(getNonDeterministicRandom());
return generator;
}
const std::shared_ptr<DeviceGeneratorImpl<device_type>>& GetDefaultDeviceGenerator();

template<DeviceType device_type>
const Maybe<DeviceGeneratorImpl<device_type>> TryGetDeviceGenerator(
const std::shared_ptr<GeneratorImpl>& generator) {
if (generator->device_type() == "auto") {
const auto auto_gen = std::dynamic_pointer_cast<AutoGeneratorImpl>(generator);
CHECK_NOTNULL_OR_RETURN(auto_gen);
return auto_gen->template GetDeviceGenerator<device_type>();
}
const auto device_gen = std::dynamic_pointer_cast<DeviceGeneratorImpl<device_type>>(generator);
CHECK_NOTNULL_OR_RETURN(device_gen);
return device_gen;
}
const std::shared_ptr<GeneratorImpl>& generator);

} // namespace one
} // namespace oneflow
Expand Down
4 changes: 2 additions & 2 deletions oneflow/python/framework/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@ def MakeGenerator(device=None, seed=None):
if device is None:
device = "auto"
if seed is None:
return oneflow._oneflow_internal.Generator(device)
return oneflow._oneflow_internal.create_generator(device)
else:
return oneflow._oneflow_internal.Generator(device, seed)
return oneflow._oneflow_internal.create_generator(device, seed)


@oneflow_export("manual_seed")
Expand Down
3 changes: 3 additions & 0 deletions oneflow/python/test/generator/test_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ def test_different_devices(test_case):
cpu_gen = flow.Generator(device="cpu")
test_case.assertTrue(auto_gen.initial_seed(), cuda_gen.initial_seed())
test_case.assertTrue(auto_gen.initial_seed(), cpu_gen.initial_seed())
with test_case.assertRaises(Exception) as context:
flow.Generator(device="invalid")
test_case.assertTrue("unimplemented" in str(context.exception))

def test_global_manual_seed(test_case):
flow.manual_seed(10)
Expand Down

0 comments on commit 04a9797

Please sign in to comment.