Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Random Generator] Part1: Dev random generator #5360

Merged
merged 15 commits into from
Jul 4, 2021
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
34 changes: 34 additions & 0 deletions oneflow/api/python/framework/random_generator.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "oneflow/api/python/of_api_registry.h"
#include "oneflow/core/framework/random_generator.h"

namespace py = pybind11;

namespace oneflow {

// Python bindings for the random generator API.
//  - manual_seed(seed): forwards to one::manual_seed.
//  - Generator(device[, seed]): wraps one::Generator, exposing set_seed as
//    "manual_seed" and get_seed as "initial_seed".
ONEFLOW_API_PYBIND11_MODULE("", m) {
  m.def("manual_seed", [](uint64_t seed) { return one::manual_seed(seed); });
  py::class_<one::Generator, std::shared_ptr<one::Generator>>(m, "Generator")
      .def(py::init<std::string>())
      // Seeds are uint64_t in one::Generator and in manual_seed above; binding the
      // constructor with the same type keeps the accepted range consistent.
      .def(py::init<std::string, uint64_t>())
      .def("manual_seed", &one::Generator::set_seed)
      .def("initial_seed", &one::Generator::get_seed);
}

} // namespace oneflow
64 changes: 64 additions & 0 deletions oneflow/core/framework/random_generator.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/random_generator.h"

namespace oneflow {
namespace one {

/**
* [Copy from pytorch]
* Gets a non-deterministic random number from either the
* /dev/urandom or the current time. For CUDA, gets random from
* std::random_device and adds a transformation on it.
*
* FIXME: The behavior in this function is from legacy code
* (THRandom_seed/THCRandom_seed) and is probably not the right thing to do,
* even though our tests pass. Figure out if tests get perturbed
* - when the same algorithm is used for all backends. Note that the current
* behavior is different for CPU, CUDA and Windows CPU.
* - when using C++11 std objects, such as std::random_device
* - when constructing a 64 bit seed properly, rather than static casting
* a 32 bit number to 64 bit.
*/
uint64_t getNonDeterministicRandom() {
  // Only the low 53 bits are kept so the value has a unique representation in a double.
  constexpr uint64_t kSeedMask = 0x1FFFFFFFFFFFFF;
  std::random_device entropy;
  // Two draws are combined: the first forms the upper 32 bits, the second is added below it.
  const uint64_t upper = static_cast<uint64_t>(entropy()) << 32;
  const uint64_t lower = entropy();
  return (upper + lower) & kSeedMask;
}

// Applies `seed` to every process-wide default generator: the GPU one (CUDA builds only),
// then the CPU one, then the "auto" one.
void manual_seed(uint64_t seed) {
#ifdef WITH_CUDA
  GetDefaultDeviceGenerator<DeviceType::kGPU>()->set_seed(seed);
#endif
  GetDefaultDeviceGenerator<DeviceType::kCPU>()->set_seed(seed);
  GetDefaultAutoGenerator()->set_seed(seed);
}

// Allocates a new AutoGeneratorImpl initialised with `seed` and hands back shared ownership.
const std::shared_ptr<AutoGeneratorImpl> CreateAutoGenerator(uint64_t seed) {
  auto created = std::make_shared<AutoGeneratorImpl>(seed);
  return created;
}

const std::shared_ptr<AutoGeneratorImpl>& GetDefaultAutoGenerator() {
static auto generator = CreateAutoGenerator(getNonDeterministicRandom());
return generator;
}

} // namespace one
} // namespace oneflow
74 changes: 74 additions & 0 deletions oneflow/core/framework/random_generator.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

#include "oneflow/core/framework/random_generator.h"

namespace oneflow {
namespace one {

namespace {

// Chooses the thread count from the device's compute capability (prop.major / prop.minor).
// Any capability not listed explicitly gets 2 * 64.
int GetThreadNum(const cudaDeviceProp& prop) {
  const int major = prop.major;
  // Kepler
  if (major == 3) { return 2 * 192; }
  // Maxwell
  if (major == 5) { return 2 * 128; }
  // Pascal: only the x.1 and x.2 variants use the larger count.
  if (major == 6 && (prop.minor == 1 || prop.minor == 2)) { return 2 * 128; }
  // Remaining Pascal parts, Volta, Turing and everything else.
  return 2 * 64;
}

// Each thread initialises state[<global thread index>] with its own seed, derived by mixing
// `seed` with a constant and shifted copies of the thread index; the two middle arguments
// passed to curand_init are both 0.
__global__ void SetupKernel(uint64_t seed, curandState* state) {
  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
  const size_t tid_bits = static_cast<size_t>(tid);
  size_t thread_seed = static_cast<size_t>(seed);
  thread_seed += 0x9e3779b9U;
  thread_seed += tid_bits << 6U;
  thread_seed += tid_bits >> 2U;
  curand_init(thread_seed, 0, 0, &state[tid]);
}

} // namespace

// (Re)initialises every curand state in curand_states_ from `seed` by launching SetupKernel
// with block_num_ blocks of thread_num_ threads.
void DeviceGeneratorImpl<DeviceType::kGPU>::CudaRandInit(uint64_t seed) {
  SetupKernel<<<block_num_, thread_num_>>>(seed, curand_states_);
  // A kernel launch does not return a status; query it explicitly so a failed launch
  // is reported instead of leaving the states silently uninitialised.
  OF_CUDA_CHECK(cudaGetLastError());
}

// Builds the CUDA generator: records the seed, derives the launch shape from the device
// properties, allocates one curandState per launched thread and seeds those states.
DeviceGeneratorImpl<DeviceType::kGPU>::DeviceGeneratorImpl(uint64_t seed) {
  seed_ = seed;
  // device_type_ is a std::string, so it must hold a device name rather than a DeviceType
  // enum value (which at best would be narrowed to a single character). "cuda" is the
  // name Generator::init uses for this backend.
  device_type_ = "cuda";
  cudaDeviceProp prop;
  // FIXME: will this cause issue by always using cuda:0's property?
  OF_CUDA_CHECK(cudaGetDeviceProperties(&prop, 0));
  block_num_ = prop.multiProcessorCount;
  thread_num_ = GetThreadNum(prop);
  OF_CUDA_CHECK(cudaMalloc(&curand_states_, block_num_ * thread_num_ * sizeof(curandState)));
  CudaRandInit(seed);
}

// Frees the curand state buffer that the constructor obtained with cudaMalloc.
DeviceGeneratorImpl<DeviceType::kGPU>::~DeviceGeneratorImpl() {
OF_CUDA_CHECK(cudaFree(curand_states_));
}

template class DeviceGeneratorImpl<DeviceType::kGPU>;

} // namespace one
} // namespace oneflow
191 changes: 191 additions & 0 deletions oneflow/core/framework/random_generator.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_FRAMEWORK_RANDOM_GENERATOR_H_
#define ONEFLOW_CORE_FRAMEWORK_RANDOM_GENERATOR_H_

#include "oneflow/core/common/data_type.h"
#include "oneflow/core/device/device_context.h"
#include <fcntl.h>
#include <unistd.h>

#ifdef WITH_CUDA
#include <curand.h>
#include <curand_kernel.h>
#endif

namespace oneflow {
namespace one {

// Returns a seed drawn from std::random_device, masked to 53 bits.
uint64_t getNonDeterministicRandom();

// Seeds the default CPU generator, the default auto generator and, in CUDA builds,
// the default GPU generator with `seed`.
void manual_seed(uint64_t seed);

// The default seed is selected to be a large number
// with good distribution of 0s and 1s in bit representation.
// Used by Generator when it is constructed without an explicit seed.
constexpr uint64_t default_rng_seed_val = 67280421310721;

// Abstract base of all generator implementations. It stores the seed and a device name;
// concrete generators decide what re-seeding means by overriding set_seed.
class GeneratorImpl {
 public:
  GeneratorImpl() = default;
  virtual ~GeneratorImpl() = default;

  // Re-seeds the generator with `seed`.
  virtual void set_seed(const uint64_t seed) = 0;
  // Returns the stored seed.
  virtual uint64_t get_seed() const { return seed_; }
  // Returns the stored device name (empty until a subclass sets it).
  virtual const std::string& device_type() const { return device_type_; }

 protected:
  std::string device_type_;
  // Zero-initialised so get_seed() never reads an indeterminate value.
  uint64_t seed_ = 0;
};

template<DeviceType device_type>
class DeviceGeneratorImpl;

class AutoGeneratorImpl : public GeneratorImpl {
public:
AutoGeneratorImpl(uint64_t seed) {
VertexC marked this conversation as resolved.
Show resolved Hide resolved
seed_ = seed;
device_type_ = "auto";
}

void set_seed(const uint64_t seed) override {
seed_ = seed;
for (const auto& it : generators_) { it.second->set_seed(seed); }
}

VertexC marked this conversation as resolved.
Show resolved Hide resolved
template<DeviceType device_type>
Maybe<DeviceGeneratorImpl<device_type>> GetDeviceGenerator() {
CHECK_OR_RETURN(device_type != DeviceType::kInvalidDevice);
auto it = generators_.find(device_type);
if (it == generators_.end()) {
it = generators_
.emplace(device_type, std::make_shared<DeviceGeneratorImpl<device_type>>(seed_))
.first;
}
return std::dynamic_pointer_cast<DeviceGeneratorImpl<device_type>>(it->second);
}

private:
std::map<DeviceType, std::shared_ptr<GeneratorImpl>> generators_;
};

template<>
class DeviceGeneratorImpl<DeviceType::kCPU> : public GeneratorImpl {
public:
DeviceGeneratorImpl(uint64_t seed) : mt19937_generator_(seed) {
VertexC marked this conversation as resolved.
Show resolved Hide resolved
seed_ = seed;
device_type_ = DeviceType::kCPU;
}
virtual ~DeviceGeneratorImpl() = default;

VertexC marked this conversation as resolved.
Show resolved Hide resolved
void set_seed(const uint64_t seed) override {
seed_ = seed;
mt19937_generator_.seed(seed_);
}

std::mt19937& generator() { return mt19937_generator_; }

public:
std::mt19937 mt19937_generator_;
};

#ifdef WITH_CUDA
// CUDA generator: owns a device buffer of curandState objects plus the launch shape
// (block_num_ x thread_num_) used when initialising them.
template<>
class DeviceGeneratorImpl<DeviceType::kGPU> : public GeneratorImpl {
 public:
  DeviceGeneratorImpl(uint64_t seed);
  virtual ~DeviceGeneratorImpl();

  // The destructor frees curand_states_, so a copy would free the same buffer twice.
  DeviceGeneratorImpl(const DeviceGeneratorImpl&) = delete;
  DeviceGeneratorImpl& operator=(const DeviceGeneratorImpl&) = delete;

  const int32_t& block_num() const { return block_num_; }
  const int32_t& thread_num() const { return thread_num_; }
  curandState* curand_states() const { return curand_states_; }
  // Re-initialises the curand states from `seed`.
  void CudaRandInit(uint64_t seed);

  // Stores the new seed and re-initialises the device states with it.
  void set_seed(const uint64_t seed) override {
    seed_ = seed;
    CudaRandInit(seed_);
  }

 protected:
  curandState* curand_states_ = nullptr;
  int32_t block_num_ = 0;
  int32_t thread_num_ = 0;
};
#endif
VertexC marked this conversation as resolved.
Show resolved Hide resolved

class Generator final {
public:
explicit Generator(std::string device, uint64_t seed) { init(device, seed); }
VertexC marked this conversation as resolved.
Show resolved Hide resolved
explicit Generator(std::string device) { init(device, default_rng_seed_val); }

void init(std::string device, uint64_t seed) {
VertexC marked this conversation as resolved.
Show resolved Hide resolved
if (device == "cpu") {
gen_impl_ = std::make_shared<DeviceGeneratorImpl<DeviceType::kCPU>>(seed);
} else if (device == "cuda") {
gen_impl_ = std::make_shared<DeviceGeneratorImpl<DeviceType::kGPU>>(seed);
} else if (device == "auto") {
gen_impl_ = std::make_shared<AutoGeneratorImpl>(seed);
} else {
UNIMPLEMENTED() << " device unimplemented, device name: " << device;
VertexC marked this conversation as resolved.
Show resolved Hide resolved
}
}
VertexC marked this conversation as resolved.
Show resolved Hide resolved

void set_seed(const uint64_t seed) { gen_impl_->set_seed(seed); }

uint64_t get_seed() const { return gen_impl_->get_seed(); }
VertexC marked this conversation as resolved.
Show resolved Hide resolved

uint64_t seed() {
uint64_t seed = getNonDeterministicRandom();
set_seed(seed);
return seed;
}

private:
std::shared_ptr<GeneratorImpl> gen_impl_;
};

const std::shared_ptr<AutoGeneratorImpl> CreateAutoGenerator(uint64_t seed);
VertexC marked this conversation as resolved.
Show resolved Hide resolved

const std::shared_ptr<AutoGeneratorImpl>& GetDefaultAutoGenerator();

// Allocates a new generator for `device_type`, initialised with `seed`.
template<DeviceType device_type>
const std::shared_ptr<DeviceGeneratorImpl<device_type>> CreateDeviceGenerator(uint64_t seed) {
  return std::make_shared<DeviceGeneratorImpl<device_type>>(seed);
}

// Returns the default generator for `device_type`. One instance exists per device type;
// it is created on first use with a seed from getNonDeterministicRandom().
template<DeviceType device_type>
const std::shared_ptr<DeviceGeneratorImpl<device_type>>& GetDefaultDeviceGenerator() {
  static const std::shared_ptr<DeviceGeneratorImpl<device_type>> default_generator =
      CreateDeviceGenerator<device_type>(getNonDeterministicRandom());
  return default_generator;
}

// Resolves `generator` to a generator for `device_type`.
//  - If its device name is "auto", the per-device generator is fetched (and created if
//    needed) from the AutoGeneratorImpl.
//  - Otherwise `generator` itself must already be a DeviceGeneratorImpl<device_type>.
// Returns an error if `generator` is null or any cast fails.
template<DeviceType device_type>
const Maybe<DeviceGeneratorImpl<device_type>> TryGetDeviceGenerator(
    const std::shared_ptr<GeneratorImpl>& generator) {
  // Guard the dereference below against an empty pointer.
  CHECK_NOTNULL_OR_RETURN(generator);
  if (generator->device_type() == "auto") {
    const auto auto_gen = std::dynamic_pointer_cast<AutoGeneratorImpl>(generator);
    CHECK_NOTNULL_OR_RETURN(auto_gen);
    return auto_gen->template GetDeviceGenerator<device_type>();
  }
  const auto device_gen = std::dynamic_pointer_cast<DeviceGeneratorImpl<device_type>>(generator);
  CHECK_NOTNULL_OR_RETURN(device_gen);
  return device_gen;
}

} // namespace one
} // namespace oneflow

#endif // ONEFLOW_CORE_FRAMEWORK_RANDOM_GENERATOR_H_