Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 11 additions & 7 deletions aten/src/ATen/miopen/Descriptors.h
Original file line number Diff line number Diff line change
Expand Up @@ -120,12 +120,11 @@ struct TORCH_CUDA_CPP_API ConvolutionDescriptor
}
};

// NOLINTNEXTLINE(bugprone-exception-escape)
struct TORCH_CUDA_CPP_API DropoutDescriptor
: public Descriptor<
miopenDropoutDescriptor,
&miopenCreateDropoutDescriptor,
&miopenDestroyDropoutDescriptor> {
struct DropoutDescriptor
: public Descriptor<miopenDropoutDescriptor,
&miopenCreateDropoutDescriptor,
&miopenDestroyDropoutDescriptor>
{
void set(miopenHandle_t handle, float dropout, void* states, size_t stateSizeInBytes,
unsigned long long seed, bool use_mask, bool state_evo, miopenRNGType_t rng_mode) {
MIOPEN_CHECK(miopenSetDropoutDescriptor(mut_desc(), handle, dropout, states, stateSizeInBytes, seed, use_mask, state_evo, rng_mode));
Expand All @@ -143,9 +142,14 @@ struct TORCH_CUDA_CPP_API RNNDescriptor
&miopenDestroyRNNDescriptor>
{
void set(int64_t hidden_size, int64_t num_layers, miopenRNNInputMode_t input_mode, miopenRNNDirectionMode_t direction, miopenRNNMode_t rnn_mode,
miopenRNNBiasMode_t bias_mode, miopenRNNAlgo_t algorithm, miopenDataType_t datatype) {
miopenRNNBiasMode_t bias_mode, miopenRNNAlgo_t algorithm, miopenDataType_t datatype) {
MIOPEN_CHECK(miopenSetRNNDescriptor(mut_desc(), hidden_size, num_layers, input_mode, direction, rnn_mode, bias_mode, algorithm, datatype));
}

void setWithDropout(DropoutDescriptor& dropout_desc, int64_t hidden_size, int64_t num_layers, miopenRNNInputMode_t input_mode, miopenRNNDirectionMode_t direction,
miopenRNNMode_t rnn_mode, miopenRNNBiasMode_t bias_mode, miopenRNNAlgo_t algorithm, miopenDataType_t datatype) {
MIOPEN_CHECK(miopenSetRNNDescriptor_V2(mut_desc(), hidden_size, num_layers, dropout_desc.mut_desc(), input_mode, direction, rnn_mode, bias_mode, algorithm, datatype));
}
};

union Constant
Expand Down
91 changes: 85 additions & 6 deletions aten/src/ATen/native/miopen/RNN_miopen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,10 @@ namespace at::native {

#include <ATen/TensorUtils.h>

#include <c10/hip/HIPCachingAllocator.h>

#include <rocrand/rocrand_xorwow.h>

#include <functional>
#include <iterator>
#include <sstream>
Expand All @@ -66,12 +70,35 @@ namespace at::native {
#include <stdint.h>
#include <unordered_map>

namespace at { namespace native {
namespace at::native {

namespace {

struct DropoutState {
DropoutState(size_t size) : size(size), data(NULL) {
data = c10::hip::HIPCachingAllocator::raw_alloc(size);
}
DropoutState(const DropoutState&) = delete;
DropoutState(DropoutState&&) = default;
DropoutState& operator=(DropoutState&&) = default;
~DropoutState() {
if (data) {
c10::hip::HIPCachingAllocator::raw_delete(data);
}
}

size_t size;
void* data;
};

} // anonymous

//RNNDescriptor.
struct RNNDescriptorParams {
int64_t hidden_size;
int64_t num_layers;
double dropout_rate;
uint64_t dropout_seed;
miopenRNNDirectionMode_t direction;
miopenRNNMode_t rnn_mode;
miopenDataType_t datatype;
Expand Down Expand Up @@ -114,6 +141,12 @@ struct RNNDescriptorParams {
}
}

void set_dropout(double dropout_rate, uint64_t dropout_seed = 0) {
this->dropout_rate = dropout_rate;
// TODO: Implement seed setting for RNN dropout
this->dropout_seed = dropout_seed;
}

void set(int64_t mode, int64_t hidden_size, int64_t num_layers, bool bidirectional, miopenDataType_t datatype, miopenRNNBiasMode_t bias_mode) {
this->set_mode(mode);
this->hidden_size = hidden_size;
Expand All @@ -128,12 +161,18 @@ struct RNNDescriptorParams {
rnn_desc.set(hidden_size, num_layers, input_mode, direction, rnn_mode, bias_mode, algo, datatype);
return rnn_desc;
}

RNNDescriptor descriptorWithDropout(DropoutDescriptor& dropout_desc) const {
RNNDescriptor rnn_desc;
rnn_desc.setWithDropout(dropout_desc, hidden_size, num_layers, input_mode, direction, rnn_mode, bias_mode, algo, datatype);
return rnn_desc;
}
};

//TensorDescriptor list.
std::vector<TensorDescriptor> rnn_descriptor_sequence(const Tensor& tensor, IntArrayRef batch_sizes) {
std::vector<TensorDescriptor> descriptors(batch_sizes.size());
size_t i =0;
size_t i = 0;

auto batch_tensor_size = tensor.sizes().vec();
for (auto batch_size : batch_sizes) {
Expand Down Expand Up @@ -204,6 +243,8 @@ struct RNNParams {

struct RNNDescriptors {
RNNDescriptor rnn_desc;
static thread_local DropoutDescriptor dropout_desc;
static thread_local std::unique_ptr<DropoutState> dropout_states;
std::vector<TensorDescriptor> x_descs;
std::vector<TensorDescriptor> y_descs;
TensorDescriptor hx_desc;
Expand All @@ -212,7 +253,39 @@ struct RNNDescriptors {
TensorDescriptor cy_desc;

RNNDescriptors(const RNNParams& fn, miopenHandle_t handle, Tensor x, Tensor y, Tensor hx, Tensor cx) {
rnn_desc = fn.rnn.descriptor();
if (fn.rnn.dropout_rate == 0.0) {
rnn_desc = fn.rnn.descriptor();
} else {
if (!dropout_states) {
size_t states_size_in_bytes = 0;
MIOPEN_CHECK(miopenDropoutGetStatesSize(handle, &states_size_in_bytes));
size_t states_size = states_size_in_bytes / sizeof(rocrand_state_xorwow);

dropout_states = std::make_unique<DropoutState>(states_size * sizeof(rocrand_state_xorwow));

dropout_desc.set(handle,
fn.rnn.dropout_rate,
dropout_states->data,
dropout_states->size,
fn.rnn.dropout_seed,
false,
false,
miopenRNGType_t::MIOPEN_RNG_PSEUDO_XORWOW);
} else {
dropout_desc.restore(handle,
fn.rnn.dropout_rate,
dropout_states->data,
dropout_states->size,
fn.rnn.dropout_seed,
// use_mask flag must be true in order to continue from a saved RNG state
true,
false,
miopenRNGType_t::MIOPEN_RNG_PSEUDO_XORWOW);
}

rnn_desc = fn.rnn.descriptorWithDropout(dropout_desc);
}

x_descs = fn.tensors.descriptors(x);
y_descs = fn.tensors.descriptors(y);
hx_desc.set(hx, 5);
Expand All @@ -239,6 +312,11 @@ struct RNNDescriptors {
}
};

// We need to store both the dropout descriptor and state thread locally to avoid multithreading issues
thread_local DropoutDescriptor RNNDescriptors::dropout_desc {};
// Each state is 0.75 MB so there is no problem in caching all of them for each thread
thread_local std::unique_ptr<DropoutState> RNNDescriptors::dropout_states { nullptr };

Tensor permute_wei_for_miopen(Tensor wei, int64_t mode)
{
if (mode < 2)
Expand Down Expand Up @@ -492,7 +570,7 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> miopen_rnn(
auto handle = getMiopenHandle();
miopenRNNAlgo_t algo = miopenRNNdefault;
fn.rnn.set_algo(algo);

fn.rnn.set_dropout(fn_dropout);
RNNDescriptors descs(fn, handle, x, y, hx, cx);

FilterDescriptor w_desc;
Expand Down Expand Up @@ -551,7 +629,6 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> miopen_rnn(
}

return std::make_tuple(output, hy, cy, reserve, weight_buf);

}

std::tuple<Tensor, Tensor, Tensor, Tensor> miopen_rnn_backward_input(
Expand Down Expand Up @@ -626,6 +703,7 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> miopen_rnn_backward_input(

miopenRNNAlgo_t algo = miopenRNNdefault;
fn.rnn.set_algo(algo);
fn.rnn.set_dropout(fn_dropout);
RNNDescriptors descs(fn, handle, x, y, hx, cx);

FilterDescriptor w_desc;
Expand Down Expand Up @@ -720,6 +798,7 @@ std::vector<Tensor> miopen_rnn_backward_weight(

miopenRNNAlgo_t algo = miopenRNNdefault;
fn.rnn.set_algo(algo);
fn.rnn.set_dropout(fn_dropout);
RNNDescriptors descs(fn, handle, x, y, hx, cx);

FilterDescriptor w_desc;
Expand Down Expand Up @@ -909,6 +988,6 @@ REGISTER_CUDA_DISPATCH(lstm_miopen_stub, &lstm_miopen)
REGISTER_CUDA_DISPATCH(lstm_packed_miopen_stub, &lstm_packed_miopen)

} // anonymous namespace
}} //namespace native.
} // namespace at::native

#endif