Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Random Generator] Part1: Dev random generator #5360

Merged
merged 15 commits into from
Jul 4, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
38 changes: 38 additions & 0 deletions oneflow/api/python/framework/random_generator.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "oneflow/api/python/of_api_registry.h"
#include "oneflow/core/framework/random_generator.h"

namespace py = pybind11;

namespace oneflow {

// Python bindings for the random generator API: the `Generator` class plus the
// module-level `manual_seed` and `create_generator` functions.
ONEFLOW_API_PYBIND11_MODULE("", m) {
  // Generator is held by std::shared_ptr on the Python side.
  py::class_<one::Generator, std::shared_ptr<one::Generator>> generator_class(m, "Generator");
  generator_class.def("manual_seed", &one::Generator::set_current_seed);
  generator_class.def("initial_seed", &one::Generator::current_seed);

  // Module-level seeding entry point; forwards to one::ManualSeed.
  m.def("manual_seed", [](uint64_t seed) { one::ManualSeed(seed); });

  // Two overloads of create_generator: device only, and device plus seed.
  // Registration order is kept (device-only first).
  m.def("create_generator", [](const std::string& device) {
    auto maybe_generator = one::Generator::New(device);
    return maybe_generator.GetPtrOrThrow();
  });
  m.def("create_generator", [](const std::string& device, uint64_t seed) {
    auto maybe_generator = one::Generator::New(device, seed);
    return maybe_generator.GetPtrOrThrow();
  });
}

} // namespace oneflow
100 changes: 100 additions & 0 deletions oneflow/core/framework/random_generator.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/random_generator.h"

namespace oneflow {
namespace one {

// Builds a seed from two draws of std::random_device and keeps only the low
// 53 bits, so the value has a unique representation as a double.
uint64_t getNonDeterministicRandom() {
  std::random_device entropy;
  const uint64_t upper = static_cast<uint64_t>(entropy()) << 32;
  const uint64_t lower = entropy();
  const uint64_t mask_53_bits = 0x1FFFFFFFFFFFFF;
  return (upper + lower) & mask_53_bits;
}

// Factory: creates a Generator and initializes it for `device` with `seed`.
// Any error reported by Init() is propagated to the caller through JUST.
/*static*/ Maybe<Generator> Generator::New(const std::string& device, uint64_t seed) {
  auto instance = std::make_shared<Generator>();
  JUST(instance->Init(device, seed));
  return instance;
}

// Selects the backing implementation from the device name.
// Accepted names: "cpu", "cuda" (only when built WITH_CUDA) and "auto".
// Any other name yields an "unimplemented" error carrying the device name.
Maybe<void> Generator::Init(const std::string& device, uint64_t seed) {
  if (device == "cpu") {
    gen_impl_ = std::make_shared<DeviceGeneratorImpl<DeviceType::kCPU>>(seed);
    return Maybe<void>::Ok();
  }
#ifdef WITH_CUDA
  if (device == "cuda") {
    gen_impl_ = std::make_shared<DeviceGeneratorImpl<DeviceType::kGPU>>(seed);
    return Maybe<void>::Ok();
  }
#endif  // WITH_CUDA
  if (device == "auto") {
    gen_impl_ = std::make_shared<AutoGeneratorImpl>(seed);
    return Maybe<void>::Ok();
  }
  UNIMPLEMENTED_THEN_RETURN() << " device unimplemented, device name: " << device;
}

// Draws a fresh non-deterministic seed, installs it as the current seed and
// returns it.
uint64_t Generator::seed() {
  const uint64_t fresh_seed = getNonDeterministicRandom();
  set_current_seed(fresh_seed);
  return fresh_seed;
}

// Applies `seed` to every process-wide default generator: the CUDA one (only
// when built WITH_CUDA), then the CPU one, then the auto generator.
void ManualSeed(uint64_t seed) {
#ifdef WITH_CUDA
  GetDefaultDeviceGenerator<DeviceType::kGPU>()->set_current_seed(seed);
#endif  // WITH_CUDA
  GetDefaultDeviceGenerator<DeviceType::kCPU>()->set_current_seed(seed);
  GetDefaultAutoGenerator()->set_current_seed(seed);
}

std::shared_ptr<AutoGeneratorImpl> CreateAutoGenerator(uint64_t seed) {
return std::make_shared<AutoGeneratorImpl>(seed);
}

// Allocates a new DeviceGeneratorImpl for `device_type` constructed with `seed`.
template<DeviceType device_type>
std::shared_ptr<DeviceGeneratorImpl<device_type>> CreateDeviceGenerator(uint64_t seed) {
  using Impl = DeviceGeneratorImpl<device_type>;
  auto device_generator = std::make_shared<Impl>(seed);
  return device_generator;
}

const std::shared_ptr<AutoGeneratorImpl>& GetDefaultAutoGenerator() {
static auto generator = CreateAutoGenerator(getNonDeterministicRandom());
return generator;
}

// Returns the default generator for `device_type`. Each instantiation owns one
// function-local static, created on first call with a non-deterministic seed.
template<DeviceType device_type>
const std::shared_ptr<DeviceGeneratorImpl<device_type>>& GetDefaultDeviceGenerator() {
  static const std::shared_ptr<DeviceGeneratorImpl<device_type>> default_generator =
      CreateDeviceGenerator<device_type>(getNonDeterministicRandom());
  return default_generator;
}

template<DeviceType device_type>
Maybe<DeviceGeneratorImpl<device_type>> TryGetDeviceGenerator(
const std::shared_ptr<GeneratorImpl>& generator) {
if (auto auto_gen = std::dynamic_pointer_cast<AutoGeneratorImpl>(generator)) {
return auto_gen->template GetDeviceGenerator<device_type>();
}
auto device_gen = std::dynamic_pointer_cast<DeviceGeneratorImpl<device_type>>(generator);
CHECK_NOTNULL_OR_RETURN(device_gen);
return device_gen;
}

} // namespace one
} // namespace oneflow
72 changes: 72 additions & 0 deletions oneflow/core/framework/random_generator.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

#include "oneflow/core/framework/random_generator.h"

namespace oneflow {
namespace one {

namespace {

// Picks the number of threads per block from the device's compute capability.
// Architecture names in the comments follow the major version numbers.
int GetThreadNum(const cudaDeviceProp& prop) {
  const int cc_major = prop.major;
  if (cc_major == 3) { return 2 * 192; }  // Kepler
  if (cc_major == 5) { return 2 * 128; }  // Maxwell
  if (cc_major == 6) {                    // Pascal
    const bool is_minor_1_or_2 = (prop.minor == 1) || (prop.minor == 2);
    return is_minor_1_or_2 ? 2 * 128 : 2 * 64;
  }
  // Volta and Turing (major 7) and every other version share the same value.
  return 2 * 64;
}

// Initializes one curandState per thread. Each thread mixes its global index
// into `seed` so that every state starts from a different seed value.
__global__ void SetupKernel(uint64_t seed, curandState* state) {
  const int thread_index = blockIdx.x * blockDim.x + threadIdx.x;
  const size_t index_bits = static_cast<size_t>(thread_index);

  size_t per_thread_seed = static_cast<size_t>(seed);
  per_thread_seed += 0x9e3779b9U;
  per_thread_seed += index_bits << 6U;
  per_thread_seed += index_bits >> 2U;

  curand_init(per_thread_seed, 0, 0, &state[thread_index]);
}

} // namespace

// (Re)initializes every curand state from `seed` by launching SetupKernel with
// block_num_ blocks of thread_num_ threads.
void DeviceGeneratorImpl<DeviceType::kGPU>::CudaRandInit(uint64_t seed) {
  const int32_t grid_size = block_num_;
  const int32_t block_size = thread_num_;
  SetupKernel<<<grid_size, block_size>>>(seed, curand_states_);
}

// Sizes the launch configuration from the properties of the *current* CUDA
// device, allocates one curandState per thread and initializes them from `seed`.
//
// The properties were previously read from device 0 unconditionally, while
// cudaMalloc and the setup kernel act on the current device; on a multi-GPU
// host the configuration could therefore describe a different GPU.
DeviceGeneratorImpl<DeviceType::kGPU>::DeviceGeneratorImpl(uint64_t seed)
    : GeneratorImpl(seed, "cuda") {
  int device_id = 0;
  OF_CUDA_CHECK(cudaGetDevice(&device_id));
  cudaDeviceProp prop;
  OF_CUDA_CHECK(cudaGetDeviceProperties(&prop, device_id));
  block_num_ = prop.multiProcessorCount;
  thread_num_ = GetThreadNum(prop);
  // Compute the element count in size_t so the product is not formed in int.
  const size_t state_count = static_cast<size_t>(block_num_) * static_cast<size_t>(thread_num_);
  OF_CUDA_CHECK(cudaMalloc(&curand_states_, state_count * sizeof(curandState)));
  CudaRandInit(seed);
}

// Frees the curand state buffer that the constructor allocated with cudaMalloc.
// The cudaFree result is passed through OF_CUDA_CHECK.
DeviceGeneratorImpl<DeviceType::kGPU>::~DeviceGeneratorImpl() {
OF_CUDA_CHECK(cudaFree(curand_states_));
}

// Explicit instantiation request for the CUDA generator class.
// NOTE(review): random_generator.h declares DeviceGeneratorImpl<DeviceType::kGPU>
// as an explicit specialization (`template<>`), so this line presumably has no
// effect — confirm before relying on it.
template class DeviceGeneratorImpl<DeviceType::kGPU>;

} // namespace one
} // namespace oneflow
170 changes: 170 additions & 0 deletions oneflow/core/framework/random_generator.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_FRAMEWORK_RANDOM_GENERATOR_H_
#define ONEFLOW_CORE_FRAMEWORK_RANDOM_GENERATOR_H_

#include "oneflow/core/common/data_type.h"
#include "oneflow/core/device/device_context.h"
#include <fcntl.h>
#include <unistd.h>
#include <unordered_map>

#ifdef WITH_CUDA
#include <curand.h>
#include <curand_kernel.h>
#endif // WITH_CUDA

namespace oneflow {
namespace one {

// Abstract base of all generator implementations. It stores the current seed
// and the name of the device the implementation targets; subclasses decide how
// a new seed is applied to their underlying random state.
class GeneratorImpl {
 public:
  // Leaves the seed at 0 and the device name empty.
  GeneratorImpl() = default;
  // Records `seed` and `device_type` without any further initialization.
  GeneratorImpl(const uint64_t& seed, const std::string& device_type)
      : seed_(seed), device_type_(device_type) {}

  virtual ~GeneratorImpl() = default;

  // Installs `seed` as the current seed and reseeds the underlying state.
  virtual void set_current_seed(uint64_t seed) = 0;
  // Returns the most recently stored seed.
  uint64_t current_seed() const { return seed_; }

  // Returns the device name given at construction.
  const std::string& device_type() const { return device_type_; }

 protected:
  // Initialized in-class so a default-constructed object never exposes an
  // indeterminate value through current_seed().
  uint64_t seed_ = 0;
  std::string device_type_;
};

// Primary template for per-device generator implementations. It is only
// declared here; this header defines explicit specializations for
// DeviceType::kCPU and, when built WITH_CUDA, DeviceType::kGPU.
template<DeviceType device_type>
class DeviceGeneratorImpl;

template<>
class DeviceGeneratorImpl<DeviceType::kCPU> : public GeneratorImpl {
public:
DeviceGeneratorImpl(uint64_t seed) : GeneratorImpl(seed, "cpu"), mt19937_generator_(seed) {}

virtual ~DeviceGeneratorImpl() = default;

VertexC marked this conversation as resolved.
Show resolved Hide resolved
void set_current_seed(uint64_t seed) override {
seed_ = seed;
mt19937_generator_.seed(seed_);
}

std::mt19937& generator() { return mt19937_generator_; }

public:
std::mt19937 mt19937_generator_;
};

#ifdef WITH_CUDA
// CUDA generator implementation. It owns a device buffer of curandState
// objects, one per thread of a block_num() x thread_num() launch.
template<>
class DeviceGeneratorImpl<DeviceType::kGPU> : public GeneratorImpl {
 public:
  // Allocates the curand states and initializes them from `seed`.
  DeviceGeneratorImpl(uint64_t seed);

  // Frees the curand state buffer.
  virtual ~DeviceGeneratorImpl();

  // The object owns a raw device allocation that the destructor frees, so a
  // copy would free the same buffer twice. Copying is therefore disabled
  // (which also suppresses the implicit move operations).
  DeviceGeneratorImpl(const DeviceGeneratorImpl&) = delete;
  DeviceGeneratorImpl& operator=(const DeviceGeneratorImpl&) = delete;

  // Launch configuration matching the size of the state buffer.
  const int32_t& block_num() const { return block_num_; }
  const int32_t& thread_num() const { return thread_num_; }

  // Device pointer to the curand states; ownership stays with this object.
  curandState* curand_states() const { return curand_states_; }

  // Stores `seed` and re-initializes every curand state from it.
  void set_current_seed(uint64_t seed) override {
    seed_ = seed;
    CudaRandInit(seed_);
  }

 private:
  // Launches the setup kernel over all states (defined in the .cu file).
  void CudaRandInit(uint64_t seed);

  int32_t block_num_ = 0;
  int32_t thread_num_ = 0;
  curandState* curand_states_ = nullptr;
};
#endif // WITH_CUDA

class AutoGeneratorImpl : public GeneratorImpl {
public:
AutoGeneratorImpl(uint64_t seed) : GeneratorImpl(seed, "auto") {}

void set_current_seed(uint64_t seed) override {
seed_ = seed;
for (const auto& it : generators_) { it.second->set_current_seed(seed); }
}

template<DeviceType device_type>
Maybe<DeviceGeneratorImpl<device_type>> GetDeviceGenerator() {
CHECK_OR_RETURN(device_type != DeviceType::kInvalidDevice);
auto it = generators_.find(device_type);
if (it == generators_.end()) {
it = generators_
.emplace(device_type, std::make_shared<DeviceGeneratorImpl<device_type>>(seed_))
.first;
}
return std::dynamic_pointer_cast<DeviceGeneratorImpl<device_type>>(it->second);
}
VertexC marked this conversation as resolved.
Show resolved Hide resolved

private:
std::unordered_map<DeviceType, std::shared_ptr<GeneratorImpl>, std::hash<int>> generators_;
};

// User-facing generator handle. It forwards every operation to a GeneratorImpl
// that Init() selects from a device name.
class Generator final {
 public:
  // Seed used by New(device) when the caller supplies none. The value is
  // selected to be a large number with a good distribution of 0s and 1s in its
  // bit representation.
  static constexpr uint64_t default_rng_seed_val = 67280421310721;

  Generator() = default;

  // Creates and initializes a generator for `device` with an explicit seed.
  static Maybe<Generator> New(const std::string& device, uint64_t seed);
  // Same as above, using default_rng_seed_val as the seed.
  static Maybe<Generator> New(const std::string& device) {
    return New(device, default_rng_seed_val);
  }

  // Chooses the implementation for `device` and seeds it with `seed`.
  Maybe<void> Init(const std::string& device, uint64_t seed);

  // Seed accessors; both forward to the implementation chosen by Init().
  uint64_t current_seed() const { return gen_impl_->current_seed(); }
  void set_current_seed(uint64_t seed) { gen_impl_->set_current_seed(seed); }

  // Replaces the current seed with a freshly drawn non-deterministic value and
  // returns that value.
  uint64_t seed();

 private:
  std::shared_ptr<GeneratorImpl> gen_impl_;
};

// Applies `seed` to the default generators: the CPU one, the auto one and,
// when built WITH_CUDA, the CUDA one.
void ManualSeed(uint64_t seed);

// Returns the default auto generator, created on first use with a
// non-deterministic seed.
const std::shared_ptr<AutoGeneratorImpl>& GetDefaultAutoGenerator();

// Creates a new auto generator constructed with `seed`.
std::shared_ptr<AutoGeneratorImpl> CreateAutoGenerator(uint64_t seed);

// Creates a new generator for `device_type` constructed with `seed`.
// Defined in random_generator.cpp.
template<DeviceType device_type>
std::shared_ptr<DeviceGeneratorImpl<device_type>> CreateDeviceGenerator(uint64_t seed);

// Returns the default generator for `device_type`, created on first use with a
// non-deterministic seed. Defined in random_generator.cpp.
template<DeviceType device_type>
const std::shared_ptr<DeviceGeneratorImpl<device_type>>& GetDefaultDeviceGenerator();

// Resolves `generator` to a DeviceGeneratorImpl<device_type>: an auto generator
// yields its per-device generator, anything else must already be of that type,
// otherwise an error is returned. Defined in random_generator.cpp.
template<DeviceType device_type>
Maybe<DeviceGeneratorImpl<device_type>> TryGetDeviceGenerator(
const std::shared_ptr<GeneratorImpl>& generator);

} // namespace one
} // namespace oneflow

#endif // ONEFLOW_CORE_FRAMEWORK_RANDOM_GENERATOR_H_