Skip to content

Commit

Permalink
WIP Proof of concept
Browse files Browse the repository at this point in the history
  • Loading branch information
Ailing Zhang committed Jul 11, 2021
1 parent 8c4e781 commit 83f647e
Show file tree
Hide file tree
Showing 16 changed files with 344 additions and 11 deletions.
8 changes: 8 additions & 0 deletions aten/src/ATen/core/LegacyTypeDispatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,14 @@ struct TORCH_API AutoNonVariableTypeMode {
c10::impl::ExcludeDispatchKeyGuard autograd_guard_;
};

// RAII guard that excludes c10::autograd_dispatch_keyset_with_Func2 from
// dispatch for its lifetime, so ops invoked inside a Func2 kernel are routed
// to the keys below Func2 instead of re-entering it.
struct TORCH_API AutoDispatchBelowFunc2{
  AutoDispatchBelowFunc2() :
    dispatch_key_guard_(c10::autograd_dispatch_keyset_with_Func2) {
  }
  // disable dispatch keys >= Func2 while this guard is alive
  c10::impl::ExcludeDispatchKeyGuard dispatch_key_guard_;
};

/* Note [AutoDispatchBelowADInplaceOrView]
* AutoDispatchBelowADInplaceOrView is equivalent to AutoNonVariableTypeMode
* before we split inplace & view ops out of VariableType kernel.
Expand Down
79 changes: 79 additions & 0 deletions aten/src/ATen/core/Tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,42 @@ const Tensor& Tensor::_base() const {
return impl::GetVariableHooks()->base(*this);
}

// A tensor with no Alias is trivially current; a tracked tensor is current
// only while its generation matches the alias's generation counter.
bool Tensor::is_up_to_date() const {
  if (!alias_) {
    return true;
  }
  return generation_ == alias_->generation();
}

// Queues `updated_val` (together with the view chain that produced it) on
// this tensor's shared Alias, so the base can replay it later via sync_().
void Tensor::add_update(const at::Tensor& updated_val, std::vector<at::ViewMeta> metas) {
  // Fail loudly instead of dereferencing a null alias_: only tensors that
  // participate in alias tracking may queue updates.
  TORCH_CHECK(alias_, "add_update called on a tensor without an Alias");
  // `metas` is taken by value; move it onward instead of copying again.
  alias_->add_update(updated_val, std::move(metas));
}

// Brings a stale view tensor up to date: replays all pending updates on the
// shared Alias's base, re-derives this view from the refreshed base, and
// stamps the current generation. No-op when already up to date.
void Tensor::sync_() {
  if (is_up_to_date()) {
    return;
  }
  // Apply all updates queued on alias_ to its base tensor.
  alias_->SyncUpdateOperations();
  // Reapply this tensor's view chain to get the viewed tensor from the
  // updated base held in alias_.
  auto t = alias_->base();
  for (auto& view_meta: view_metas_) {
    switch (view_meta.view_type) {
      case ViewMeta::Type::kReshape:
        // Re-view the base at the recorded (post-view) size.
        t = t.view_copy(view_meta.size);
        break;
      case ViewMeta::Type::kNoOp:
        // The base itself carries a no-op meta; nothing to replay.
        break;
      default:
        TORCH_CHECK(false, "Other types are not supported yet.");
    }
  }
  // Note this goes back to the dispatcher, but set_ is simply redispatched
  // at Func2. (The fallback kernel materializes tensors before redispatch,
  // and the Func2 set_ kernel skips materialization to break the loop.)
  this->set_(t);
  generation_ = alias_->generation();
}

// Returns this variable's name via the variable-hooks indirection.
const std::string& Tensor::name() const {
  return impl::GetVariableHooks()->name(*this);
}
Expand All @@ -115,8 +151,51 @@ void Tensor::remove_hook(unsigned pos) const {
impl::GetVariableHooks()->remove_hook(*this, pos);
}

// True when `this` and `other` may share memory: identical TensorImpl,
// same functionalization alias base, or overlapping storage.
bool Tensor::is_alias_of(const at::Tensor& other) const {
  // If self and other are the same tensor, they trivially alias.
  if (unsafeGetTensorImpl() == other.unsafeGetTensorImpl()) {
    return true;
  }
  // For functionalized view tensors, compare against the alias base.
  // Guard on alias_ itself rather than has_view_meta(): has_view_meta() is
  // also true when only view_metas_ is non-empty, in which case the original
  // code dereferenced a null alias_.
  if (alias_) {
    return alias_->base().unsafeGetTensorImpl() == other.unsafeGetTensorImpl();
  }
  return impl_->storage().is_alias_of(other.storage());
}

// Registers a backward hook via the variable-hooks indirection; returns the
// hook's position (usable with remove_hook).
unsigned Tensor::_register_hook(std::function<Tensor(const Tensor&)> hook) const {
  return impl::GetVariableHooks()->_register_hook(*this, std::move(hook));
}

// Returns the base tensor this Alias wraps.
// NOTE(review): returning `const at::Tensor` by value forces a refcount bump
// and blocks moves at call sites — consider `const at::Tensor&` instead.
const at::Tensor Alias::base() const {
  return base_;
}

// Records a pending update (new value plus the view chain it came through)
// and bumps the generation counter so outstanding views detect staleness.
void Alias::add_update(const at::Tensor& updated_val, std::vector<at::ViewMeta> metas) {
  // `metas` is already a by-value copy; move it into the queued Update
  // instead of copying the whole vector a second time.
  updates_.push_back({updated_val, std::move(metas)});
  generation_++;
}

// Replays one queued update onto base_: walks the update's view chain from
// innermost to outermost, undoing each view (only reshape is handled) to
// recover a base-shaped value, then writes it into base_ with set_.
void Alias::apply_update(const Update& update) {
  // TODO: Should handle more kinds of view ops. Only do kReshape now.
  at::Tensor t = update.new_val;
  // Use int64_t, not int: assigning size()-1 to int narrows size_t and can
  // overflow for (pathologically) large chains.
  for (int64_t i = static_cast<int64_t>(update.view_metas.size()) - 1; i >= 0; --i) {
    const auto& meta = update.view_metas[i];
    switch (meta.view_type) {
      case ViewMeta::Type::kReshape:
        // Undo the reshape by viewing back at the source (pre-view) size.
        t = t.view_copy(meta.source_size);
        break;
      case ViewMeta::Type::kNoOp:
        break;
      default:
        TORCH_CHECK(false, "Other types are not supported yet.");
    }
  }
  base_.set_(t);
}

// Drains the pending-update queue: replays every queued update onto base_
// in arrival order, then leaves the queue empty.
void Alias::SyncUpdateOperations() {
  for (const auto& pending : updates_) {
    apply_update(pending);
  }
  updates_.clear();
}

} // namespace at
39 changes: 39 additions & 0 deletions aten/src/ATen/core/VariableFallbackKernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,45 @@ TORCH_LIBRARY_IMPL(_, AutogradMLC, m) {
m.fallback(torch::CppFunction::makeFallthrough());
}

namespace {
// Boxed fallback for the Func2 key: materializes any stale view tensors in
// the op's arguments (so kernels below see up-to-date values), then
// redispatches the op to the keys below Func2.
void func2Fallback(const c10::OperatorHandle& op, c10::DispatchKeySet dispatchKeySet, torch::jit::Stack* stack) {
  const auto& schema = op.schema();
  const auto num_arguments = schema.arguments().size();
  const auto arguments_begin = stack->size() - num_arguments;
  auto arguments = torch::jit::last(stack, num_arguments);
  // num_arguments is size_t; keep the index unsigned so the loop condition
  // does not mix signed and unsigned operands.
  for (size_t idx = 0; idx < num_arguments; ++idx) {
    const auto& ivalue = arguments[idx];
    if (ivalue.isTensor()) {
      at::Tensor t = ivalue.toTensor();
      if (t.has_view_meta() && !t.is_up_to_date()) {
        t.sync_();
      }
      (*stack)[arguments_begin + idx] = c10::IValue(t);
    } else if (ivalue.isTensorList()) {
      std::vector<at::Tensor> tensors = ivalue.toTensorList().vec();
      for (auto& t : tensors) {
        // Same staleness test as the single-tensor path above (sync_ would
        // early-return on an up-to-date tensor, but skip the call entirely).
        if (t.has_view_meta() && !t.is_up_to_date()) {
          t.sync_();
        }
      }
      (*stack)[arguments_begin + idx] = c10::IValue(c10::List<at::Tensor>(tensors));
    }
  }
  {
    at::AutoDispatchBelowFunc2 guard;
    // redispatchBoxed with a masked dispatchKeySet cannot prevent composite
    // kernels called inside from going back up the dispatcher. We still need
    // the RAII guard here.
    op.redispatchBoxed(dispatchKeySet & c10::after_ADInplaceOrView_keyset, stack);
  }
}
} // namespace

// Install the materialize-then-redispatch fallback for every operator under
// the Func2 dispatch key.
TORCH_LIBRARY_IMPL(_, Func2, m) {
  m.fallback(torch::CppFunction::makeFromBoxedFunction<&func2Fallback>());
}

// see Note [ADInplaceOrView key]
TORCH_LIBRARY_IMPL(_, ADInplaceOrView, m) {
m.fallback(torch::CppFunction::makeFallthrough());
Expand Down
64 changes: 64 additions & 0 deletions aten/src/ATen/native/TensorShape.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <ATen/MemoryOverlap.h>
#include <ATen/NamedTensorUtils.h>
#include <ATen/core/DimVector.h>
#include <ATen/core/LegacyTypeDispatch.h>
#include <ATen/native/Copy.h>
#include <ATen/native/cpu/CatKernel.h>
#include <ATen/native/Resize.h>
Expand All @@ -21,6 +22,7 @@
#include <c10/util/irange.h>
#include <c10/util/Optional.h>
#include <c10/util/SmallVector.h>
#include <torch/library.h>

#include <algorithm>
#include <cstdint>
Expand Down Expand Up @@ -2165,6 +2167,10 @@ Tensor view(const Tensor& self, IntArrayRef size) {
return alias_with_sizes_and_strides(self, inferred_size, stride_value);
}

// Out-of-place counterpart of view(): materializes the reshaped view into a
// freshly allocated tensor.
Tensor view_copy(const Tensor& self, IntArrayRef size) {
  auto viewed = self.view(size);
  return viewed.clone();
}

// Returns a new tensor aliasing self's storage with identical sizes and
// strides (a pure aliasing view).
Tensor alias(const Tensor& self) {
  return alias_with_sizes_and_strides(self, self.sizes(), self.strides());
}
Expand Down Expand Up @@ -2425,3 +2431,61 @@ std::vector<Tensor> unflatten_dense_tensors(const Tensor& flat, TensorList tenso
}

}} // at::native

namespace at{
namespace Func {
// The following should be codegened for every **view** op.
// Func2 kernel for view: eagerly produces a *copy* with the viewed shape and
// records ViewMeta/Alias bookkeeping so the result can later be re-synced
// with its base (see Tensor::sync_ and Alias).
at::Tensor view(const at::Tensor& self, at::IntArrayRef size) {
  at::Tensor out;
  {
    // Drop below Func2 so the underlying copy does not re-enter this kernel.
    at::AutoDispatchBelowFunc2 guard;
    out = at::native::view_copy(self, size);
  }
  ViewMeta view_meta = ViewMeta(ViewMeta::Type::kReshape, size.vec(), self.sizes().vec());
  // if self is already a view, copy its ViewMeta vector and push the current one.
  if (self.has_view_meta()) {
    auto metas = self.view_metas();
    metas.push_back(view_meta);
    out.set_view_meta(metas, self.get_alias());
  } else {
    // First view of this tensor: create the shared Alias and tag the base
    // itself with a no-op ViewMeta so it participates in alias tracking.
    // NOTE(review): const_cast mutates `self` through a const ref — appears
    // intentional for bookkeeping, but worth confirming.
    std::shared_ptr<Alias> alias = std::make_shared<Alias>(const_cast<Tensor&>(self));
    ViewMeta base_view_info(ViewMeta::Type::kNoOp, self.sizes().vec(), self.sizes().vec());
    const_cast<Tensor&>(self).set_view_meta(std::move(base_view_info), alias);
    out.set_view_meta(view_meta, alias);
  }
  return out;
}

// The following should be codegened for every **inplace** op.
// Func2 kernel for add_: performs the in-place add below Func2, then, if
// self is alias-tracked, queues the new value on the shared Alias so the
// base and sibling views can be brought up to date lazily.
at::Tensor& add_(at::Tensor& self, const at::Tensor& other, const at::Scalar& alpha) {
  {
    at::AutoDispatchBelowFunc2 guard;
    self.add_(other, alpha);
  }
  // if self is a view, add this update to its alias's update queue.
  if (self.has_view_meta()) {
    // TODO: if view_metas are the same, just replace the tensor.
    self.add_update(self.clone(), self.view_metas());
  }
  return self;
}

// set_ is used to install an already-materialized tensor into self, so it
// needs to skip materializing the self tensor — otherwise sync_() (which
// calls set_) would loop back through the materializing fallback.
at::Tensor& set_(at::Tensor& self, const at::Tensor& other) {
  {
    at::AutoDispatchBelowFunc2 guard;
    self.set_(other);
  }
  return self;
}
} // namespace Func
} // namespace at

namespace {
// Route the PoC's view / inplace / set_ ops to the at::Func kernels above
// under the Func2 dispatch key.
TORCH_LIBRARY_IMPL(aten, Func2, m) {
  m.impl("view", TORCH_FN(&at::Func::view));
  m.impl("add_.Tensor", TORCH_FN(&at::Func::add_));
  m.impl("set_.source_Tensor", TORCH_FN(&at::Func::set_));
}
}
7 changes: 7 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5464,6 +5464,13 @@
CPU, CUDA, Meta, QuantizedCPU, QuantizedCUDA: view
MkldnnCPU: mkldnn_view

# Out-of-place companion of `view`: returns a fresh tensor with the
# requested size (at::native::view_copy). Dispatch is CPU-only in this PoC.
- func: view_copy(Tensor self, int[] size) -> Tensor
  variants: method
  device_check: NoCheck
  device_guard: False
  dispatch:
    CPU: view_copy

# Warning: If you want to change the name or overload name of this
# operator, you might also want to change the `isBlockListedSchema`
# function in `torch/csrc/jit/frontend/schema_catching.cpp`.
Expand Down
79 changes: 76 additions & 3 deletions aten/src/ATen/templates/TensorBody.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ struct Generator;
struct Type;
class DeprecatedTypeProperties;
class Tensor;
class Alias;
} // namespace at
namespace at {
namespace indexing {
Expand Down Expand Up @@ -68,6 +69,28 @@ inline bool variable_excluded_from_dispatch() {
}
}

// Describes one step in a chain of view operations: the kind of view, the
// size it produced, and the size of the tensor it was viewed from (needed
// to replay or undo the view when syncing aliases).
struct ViewMeta {
  enum class Type {
    kReshape,
    kNoOp,      // marker used on the base tensor itself
    kInvalid,
  };

  ViewMeta() = default;
  ViewMeta(Type view_type, std::vector<int64_t> size, std::vector<int64_t> source_size):
    view_type(view_type),
    size(std::move(size)),
    source_size(std::move(source_size)) {}
  // Two ViewMetas are equal when all three fields match.
  bool operator==(const ViewMeta& ref) const {
    return view_type == ref.view_type && size == ref.size && source_size == ref.source_size;
  }

  Type view_type = Type::kInvalid;
  std::vector<int64_t> size;         // size produced by the view
  std::vector<int64_t> source_size;  // size of the source (pre-view) tensor
};


// Tensor is a "generic" object holding a pointer to the underlying TensorImpl object, which
// has an embedded reference count. In this way, Tensor is similar to boost::intrusive_ptr.
//
Expand Down Expand Up @@ -247,10 +270,16 @@ class TORCH_API Tensor {

Tensor& operator=(const Tensor& x) & {
  impl_ = x.impl_;
  generation_ = x.generation_;
  // NOTE(review): view_metas_/alias_ are deliberately left un-copied (WIP),
  // yet generation_ IS copied — after assignment this tensor's sync state
  // may be inconsistent with x. Confirm whether that is intended.
  //view_metas_ = x.view_metas_;
  //alias_ = x.alias_;
  return *this;
}
Tensor& operator=(Tensor&& x) & {
  impl_ = std::move(x.impl_);
  generation_ = x.generation_;  // size_t: plain copy, nothing to move
  // NOTE(review): view_metas_/alias_ are deliberately not transferred (WIP),
  // yet generation_ IS copied — the moved-to tensor's sync state may be
  // inconsistent. Confirm whether that is intended.
  //view_metas_ = std::move(x.view_metas_);
  //alias_ = std::move(x.alias_);
  return *this;
}

Expand Down Expand Up @@ -363,9 +392,8 @@ class TORCH_API Tensor {
// Returns the underlying storage of this tensor's TensorImpl.
const Storage& storage() const {
  return impl_->storage();
}
bool is_alias_of(const at::Tensor& other) const{
return impl_->storage().is_alias_of(other.storage());
}
bool is_alias_of(const at::Tensor& other) const;

Tensor toType(ScalarType t) const;
Tensor toBackend(Backend b) const;

Expand Down Expand Up @@ -922,6 +950,29 @@ class TORCH_API Tensor {
/// `Variable` is not a view, throw a `std::runtime_error`.
const Tensor& _base() const;

// We already have a Tensor method called alias(), hence the get_ prefix.
// Returns the shared Alias object (null when this tensor is not tracked).
const std::shared_ptr<Alias>& get_alias() const { return alias_; }
// Appends one ViewMeta to this tensor's view chain and (re)binds its Alias.
void set_view_meta(at::ViewMeta meta, std::shared_ptr<Alias> alias) {
  view_metas_.push_back(std::move(meta));
  alias_ = std::move(alias);
}
// Replaces this tensor's entire view chain and (re)binds its Alias.
void set_view_meta(std::vector<at::ViewMeta> meta, std::shared_ptr<Alias> alias) {
  view_metas_ = std::move(meta);
  alias_ = std::move(alias);
}


// True when this tensor participates in alias tracking: it either carries a
// view chain or points at a shared Alias.
bool has_view_meta() const {
  return !view_metas_.empty() || alias_ != nullptr;
}
// Returns a copy of the recorded view chain (outermost view last).
std::vector<ViewMeta> view_metas() const {
  return view_metas_;
}
bool is_up_to_date() const ;
void sync_();
void add_update(const at::Tensor& updated_val, std::vector<at::ViewMeta> metas);

// Miscellaneous
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Expand All @@ -932,8 +983,30 @@ class TORCH_API Tensor {

void enforce_invariants();
c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl> impl_;
size_t generation_ = 0;
std::vector<ViewMeta> view_metas_;
std::shared_ptr<Alias> alias_;
};

// Shared bookkeeping object for a family of tensors aliasing one base.
// Collects pending in-place updates (each tagged with the view chain it was
// made through) and a generation counter that views compare against their
// own generation to detect staleness.
class Alias {
public:
  // A pending write: the new value plus the view chain it was made through.
  struct Update {
    at::Tensor new_val;
    std::vector<ViewMeta> view_metas;
  };
  explicit Alias(at::Tensor& base) : base_(base) {}
  // The base tensor the aliases derive from (returned by value).
  const at::Tensor base() const;
  // Monotonic counter, bumped on every add_update.
  size_t generation() const { return generation_; }
  // Queues an update and bumps the generation counter.
  void add_update(const at::Tensor& updated_val, std::vector<at::ViewMeta> metas);
  // Replays a single queued update onto base_.
  void apply_update(const Update& update);
  // Replays all queued updates in order, then clears the queue.
  void SyncUpdateOperations();
private:
  at::Tensor base_;              // NOTE(review): holds a Tensor copy of the ctor arg, not a reference
  std::vector<Update> updates_;  // pending updates, in arrival order
  size_t generation_ = 0;
};


// For "multiple ... operators specified" warnings, closing brace of class
// declaration must be included between pragma push & pop
#ifdef _MSC_VER
Expand Down
3 changes: 3 additions & 0 deletions c10/core/DispatchKey.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,9 @@ const char* toString(DispatchKey t) {
case DispatchKey::Named:
return "Named";

case DispatchKey::Func2:
return "Func2";

case DispatchKey::Tracer:
return "Tracer";

Expand Down
Loading

0 comments on commit 83f647e

Please sign in to comment.