73 changes: 73 additions & 0 deletions lazy_tensor_core/lazy_tensor_core/csrc/init_python_bindings.cpp
@@ -7,6 +7,7 @@
#include <thread>
#include <vector>

#include "ATen/core/functional.h"
#include "lazy_tensor_core/csrc/aten_ltc_bridge.h"
#include "lazy_tensor_core/csrc/compiler/backend_impl_interface.h"
#include "lazy_tensor_core/csrc/device.h"
@@ -32,6 +33,8 @@
#include "torch/csrc/autograd/variable.h"
#include "torch/csrc/jit/python/pybind.h"
#include "torch/csrc/utils/cuda_lazy_init.h"
#include "lazy_tensor_core/csrc/ts_backend/ops/add.h"
#include "torch/torch.h"

namespace torch_lazy_tensors {
namespace {
@@ -413,6 +416,75 @@ void InitLtcModuleBindings(py::module m) {
};
return GetTensorsDump(tensors, coverter);
});
m.def("_dynamic_size",
[](at::Tensor& self) {
LazyTensor self_lazy_tensor = bridge::GetLtcTensor(self);
return bridge::AtenFromLtcTensor(
self_lazy_tensor.CreateFrom(ir::MakeNode<ir::ops::DynamicSize>(
self_lazy_tensor.GetIrValue())));
});
m.def("_dynamic_size2",
[](at::Tensor& self) {
LazyTensor self_lazy_tensor = bridge::GetLtcTensor(self);
return ir::MakeNode<ir::ops::DynamicSize>(self_lazy_tensor.GetIrValue());
});
m.def("_dynamic_expand2",
[](at::Tensor& self, std::shared_ptr<ir::Node> val) {
LazyTensor self_lazy_tensor = bridge::GetLtcTensor(self);
return bridge::AtenFromLtcTensor(
self_lazy_tensor.CreateFrom(ir::MakeNode<ir::ops::DynamicExpand>(
self_lazy_tensor.GetIrValue(), val)));
});
m.def("_add_dim",

Review comment: I didn't anticipate that add_dim was dynamic. The 'dim' constant is the dynamic part, I guess? But how come there isn't any special handling of the 'dim' argument in this API?

[](at::Tensor& self, at::Tensor& other) {
LazyTensor self_lazy_tensor = bridge::GetLtcTensor(self);
LazyTensor other_lazy_tensor =
bridge::GetOrCreateLtcTensor(other, self_lazy_tensor.GetDevice());
return bridge::AtenFromLtcTensor(
self_lazy_tensor.CreateFrom(ir::MakeNode<ir::ops::AddDim>(
self_lazy_tensor.GetIrValue(), other_lazy_tensor.GetIrValue())));
});
m.def("_dynamic_expand",
[](at::Tensor& self, at::Tensor& other) {
LazyTensor self_lazy_tensor = bridge::GetLtcTensor(self);
LazyTensor other_lazy_tensor =
bridge::GetOrCreateLtcTensor(other, self_lazy_tensor.GetDevice());
return bridge::AtenFromLtcTensor(
self_lazy_tensor.CreateFrom(ir::MakeNode<ir::ops::DynamicExpand>(
self_lazy_tensor.GetIrValue(), other_lazy_tensor.GetIrValue())));
});
m.def("_dynamic_view",
[](std::vector<at::Tensor>& self_and_dims) {
auto self_lazy_tensor = bridge::GetLtcTensor(self_and_dims[0]);
auto ir_values = c10::fmap(self_and_dims, [&self_lazy_tensor](const at::Tensor& t) {
return bridge::GetOrCreateLtcTensor(t, self_lazy_tensor.GetDevice()).GetIrValue();
});
return bridge::AtenFromLtcTensor(
self_lazy_tensor.CreateFrom(ir::MakeNode<ir::ops::DynamicView>(ir_values)));
});
m.def("_dynamic_linear",
//TODO: figure out how to do optional bias
[](at::Tensor& self, at::Tensor& weight, at::Tensor& bias) {
LazyTensor self_lazy_tensor = bridge::GetLtcTensor(self);
LazyTensor weight_lazy_tensor =
bridge::GetOrCreateLtcTensor(weight, self_lazy_tensor.GetDevice());
LazyTensor bias_lazy_tensor =
bridge::GetOrCreateLtcTensor(bias, self_lazy_tensor.GetDevice());
return bridge::AtenFromLtcTensor(
self_lazy_tensor.CreateFrom(ir::MakeNode<ir::ops::DynamicLinear>(
self_lazy_tensor.GetIrValue(), weight_lazy_tensor.GetIrValue(), bias_lazy_tensor.GetIrValue())));
});
m.def("_dynamic_getitem",
[](at::Tensor& self, int index) {
LazyTensor self_lazy_tensor = bridge::GetLtcTensor(self);
auto at_index_ten = torch::tensor({index}, c10::TensorOptions(c10::kLong));
LazyTensor other_lazy_tensor =
bridge::GetOrCreateLtcTensor(at_index_ten, self_lazy_tensor.GetDevice());
return bridge::AtenFromLtcTensor(
self_lazy_tensor.CreateFrom(ir::MakeNode<ir::ops::DynamicGetItem>(
self_lazy_tensor.GetIrValue(), other_lazy_tensor.GetIrValue())));
});
// IrValueFromScalar
m.def("_get_ltc_tensors_text",
[](const std::vector<at::Tensor>& tensors) -> std::string {
auto coverter = [](lazy_tensors::Span<const ir::Node* const> nodes) {
@@ -491,6 +563,7 @@ void InitLtcModuleBindings(py::module m) {
});

py::class_<ir::Value, std::shared_ptr<ir::Value>>(m, "IrValue");
py::class_<ir::Node, std::shared_ptr<ir::Node>>(m, "IrNode");
m.def("_ltc_create_token",
[](const std::string& device) { return CreateToken(device); });
m.def("_ltc_all_reduce_inplace", [](const std::string& reduce_type,
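
A minimal usage sketch of the new bindings from Python may help here. Only the binding names (_dynamic_size, _dynamic_getitem, _dynamic_expand, _get_ltc_tensors_text) come from this diff; the module path lazy_tensor_core._LAZYC, the _ltc_init_ts_backend() call, and the 'lazy' device string are assumptions about how the extension is exposed and may differ in a given build.

# Hypothetical driver for the dynamic-shape bindings added above; everything
# outside the binding names themselves is an assumption, not part of this PR.
import torch
import lazy_tensor_core
from lazy_tensor_core import _LAZYC

_LAZYC._ltc_init_ts_backend()               # assumed TS-backend init entry point

t = torch.randn(2, 3, device='lazy')        # assumed lazy device string
sizes = _LAZYC._dynamic_size(t)             # lazy tensor wrapping a DynamicSize IR node
dim0 = _LAZYC._dynamic_getitem(sizes, 0)    # DynamicGetItem over the traced sizes
out = _LAZYC._dynamic_expand(t, sizes)      # DynamicExpand driven by the traced sizes
print(_LAZYC._get_ltc_tensors_text([out]))  # dump the traced IR for inspection
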
110 changes: 110 additions & 0 deletions lazy_tensor_core/lazy_tensor_core/csrc/tensor.cpp
@@ -38,6 +38,28 @@
#include "lazy_tensors/str_join.h"
#include "torch/csrc/autograd/variable.h"


#include "torch/torch.h"
#include "torch/csrc/jit/runtime/custom_operator.h"

int get_current_level() {
static const auto PRINT_VLOG = std::getenv("PRINT_VLOG");
if (PRINT_VLOG) {
return std::atoi(PRINT_VLOG);
}
return 3;
}

std::ostream& get_ostream(int level) {
static std::stringstream dummy{};

static const auto cur_level = get_current_level();
if (level >= cur_level) {
return std::cerr;
}
return dummy;
}

namespace torch_lazy_tensors {
namespace {

@@ -1618,6 +1640,11 @@ std::shared_ptr<LazyTensor::Async> LazyTensor::SyncTensorsGraphInternal(
&coll.indices);

PostOrderData po_data = RunPostOrder(*tensors, coll.indices);

for (auto n: po_data.post_order) {
LTC_VLOG(5) << "node = " << *n << std::endl;
}

coll.hash = lazy_tensors::util::HashCombine(
coll.hash, lazy_tensors::util::Hash(po_data.parameter_sequence));
LTC_VLOG(4) << "Parameter sequence graph hash "
@@ -1660,3 +1687,86 @@ lazy_tensors::uint64 LazyTensor::GetRunningSeed(const Device& device) {
}

} // namespace torch_lazy_tensors


// void DynamicSize(torch::jit::Stack* stack) {
// at::Tensor t = torch::jit::pop(stack).toTensor();
// torch::jit::push(stack, t.sizes());
// }

const torch::jit::RegisterOperators DynamicSizeOp({
torch::jit::Operator(
"aten::dynamic_size(Tensor a) -> Tensor",
[](const torch::jit::Node*) -> torch::jit::Operation {
return [](torch::jit::Stack* stack) {
auto t = torch::jit::pop(stack).toTensor();
auto sz_ten = torch::tensor(t.sizes(), c10::TensorOptions(c10::kLong));
std::cerr << sz_ten;
torch::jit::push(stack, sz_ten);
};
},
c10::AliasAnalysisKind::FROM_SCHEMA),
torch::jit::Operator(
"prim::list_to_tensor(int[] a) -> Tensor",
[](const torch::jit::Node*) -> torch::jit::Operation {
return [](torch::jit::Stack* stack) {
auto sz_vec = torch::jit::pop(stack).toIntVector();
auto sz_ten = torch::tensor(sz_vec, c10::TensorOptions(c10::kLong));
std::cerr << sz_ten;
torch::jit::push(stack, sz_ten);
};
},
c10::AliasAnalysisKind::FROM_SCHEMA),
torch::jit::Operator(
"prim::dim_to_tensor(int a) -> Tensor",
[](const torch::jit::Node*) -> torch::jit::Operation {
return [](torch::jit::Stack* stack) {
auto dim = torch::jit::pop(stack).toInt();
auto sz_ten = torch::tensor({dim}, c10::TensorOptions(c10::kLong));
std::cerr << sz_ten;
torch::jit::push(stack, sz_ten);
};
},
c10::AliasAnalysisKind::FROM_SCHEMA),
torch::jit::Operator(
"prim::dim_to_tensor(...) -> int[]",
[](const torch::jit::Node*) -> torch::jit::Operation {
return [](torch::jit::Stack* stack) {
auto dim = torch::jit::pop(stack).toInt();
auto sz_ten = torch::tensor({dim}, c10::TensorOptions(c10::kLong));
std::cerr << sz_ten;
torch::jit::push(stack, sz_ten);
};
},
c10::AliasAnalysisKind::FROM_SCHEMA),
torch::jit::Operator(
"prim::tensor_to_list(Tensor a) -> int[]",
[](const torch::jit::Node*) -> torch::jit::Operation {
return [](torch::jit::Stack* stack) {
auto t = torch::jit::pop(stack).toTensor();
auto n = t.numel();
std::vector<int64_t> r;
auto t_data = t.data_ptr<int64_t>();
for (auto i : c10::irange(n)) {
r.push_back(t_data[i]);
}
torch::jit::push(stack, r);
};
},
c10::AliasAnalysisKind::FROM_SCHEMA),
torch::jit::Operator(
"prim::dim_tensors_to_list(...) -> int[]",
[](const torch::jit::Node* n) -> torch::jit::Operation {
auto num_inputs = n->inputs().size();
return [num_inputs](torch::jit::Stack* stack) {
std::vector<int64_t> dims;
auto ivals = torch::jit::last(stack, num_inputs);
for (auto iv : ivals) {
dims.push_back(iv.toTensor().item<int64_t>());
}
torch::jit::drop(stack, num_inputs);
torch::jit::push(stack, dims);
};
},
c10::AliasAnalysisKind::FROM_SCHEMA),
});
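
Since these are ordinary TorchScript operators, they only become reachable once the extension containing this translation unit has been loaded. Below is a sketch of how they might be exercised from Python, assuming the internal helpers torch._C.parse_ir and torch._C._create_function_from_graph are available as in recent PyTorch builds; the graph text is hand-written for illustration and is not produced by this PR.

# Illustrative only: chains aten::dynamic_size (Tensor -> Tensor of sizes) with
# prim::tensor_to_list (Tensor -> int[]) in a hand-written graph. parse_ir and
# _create_function_from_graph are internal PyTorch APIs and are assumptions here.
import torch
import lazy_tensor_core  # noqa: F401  (importing loads the registration above)

GRAPH = """
graph(%x : Tensor):
  %sizes : Tensor = aten::dynamic_size(%x)
  %dims : int[] = prim::tensor_to_list(%sizes)
  return (%dims)
"""

g = torch._C.parse_ir(GRAPH)
fn = torch._C._create_function_from_graph("dynamic_size_demo", g)
print(fn(torch.randn(2, 3)))  # expected: [2, 3]
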
7 changes: 7 additions & 0 deletions lazy_tensor_core/lazy_tensor_core/csrc/tensor_impl.cpp
@@ -107,6 +107,13 @@ void LTCTensorImpl::shallow_copy_from(
}

at::IntArrayRef LTCTensorImpl::sizes() const {
// Get the sizes directly from the materialized at::Tensor when one is present.
// This path is taken when the next op is a fallback (eager) op and this lazy
// tensor is one of its inputs.
auto opt_ten = tensor_.CurrentTensorData();
if (opt_ten) {
return opt_ten->sizes();
}
const_cast<LTCTensorImpl*>(this)->SetupSizeProperties();
return c10::TensorImpl::sizes();
}
73 changes: 73 additions & 0 deletions lazy_tensor_core/lazy_tensor_core/csrc/ts_backend/ops/add.cpp
@@ -16,6 +16,79 @@ NodePtr Add::Clone(OpList operands) const {
return MakeNode<Add>(operands.at(0), operands.at(1));
}

DynamicView::DynamicView(lazy_tensors::Span<const ir::Value> values)
: Node(ir::OpKind(c10::Symbol::prim("_dynamic_view")), values,
/*num_outputs=*/1, /*hash_seed=*/0x5a2d296e9) {}

NodePtr DynamicView::Clone(OpList operands) const {
return MakeNode<DynamicView>(operands);
}

AddDim::AddDim(const Value& lhs, const Value& rhs)
: Node(ir::OpKind(c10::Symbol::prim("_add_dim")), {lhs, rhs}, lazy_tensors::Shape{},
/*num_outputs=*/1, /*hash_seed=*/0x5a2d296e9) {}

NodePtr AddDim::Clone(OpList operands) const {
return MakeNode<AddDim>(operands.at(0), operands.at(1));
}

MulDim::MulDim(const Value& lhs, const Value& rhs)
: Node(ir::OpKind(c10::Symbol::prim("_mul_dim")), {lhs, rhs}, lazy_tensors::Shape{},
/*num_outputs=*/1, /*hash_seed=*/0x5a2d296e9) {}

NodePtr MulDim::Clone(OpList operands) const {
return MakeNode<MulDim>(operands.at(0), operands.at(1));
}

DynamicExpand::DynamicExpand(const Value& lhs, const Value& rhs)
: Node(ir::OpKind(c10::Symbol::prim("_dynamic_expand")), {lhs, rhs}, lhs.shape(),
/*num_outputs=*/1, /*hash_seed=*/0x5a2d296e9) {}

NodePtr DynamicExpand::Clone(OpList operands) const {
return MakeNode<DynamicExpand>(operands.at(0), operands.at(1));
}

DynamicExpand2::DynamicExpand2(const Value& lhs, const Value& rhs)
: Node(ir::OpKind(c10::Symbol::prim("_dynamic_expand2")), {lhs, rhs}, lhs.shape(),
/*num_outputs=*/1, /*hash_seed=*/0x5a2d296e9) {}

NodePtr DynamicExpand2::Clone(OpList operands) const {
return MakeNode<DynamicExpand2>(operands.at(0), operands.at(1));
}

DynamicSize::DynamicSize(const Value& lhs)
: Node(ir::OpKind(c10::Symbol::prim("_dynamic_size")), {lhs}, lazy_tensors::Shape{},
/*num_outputs=*/1, /*hash_seed=*/0x5a2d296e9) {}

NodePtr DynamicSize::Clone(OpList operands) const {
return MakeNode<DynamicSize>(operands.at(0));
}

DynamicSize2::DynamicSize2(const Value& lhs)
: Node(ir::OpKind(c10::Symbol::prim("_dynamic_size2")), {lhs}, lazy_tensors::Shape{},
/*num_outputs=*/1, /*hash_seed=*/0x5a2d296e9) {}

NodePtr DynamicSize2::Clone(OpList operands) const {
return MakeNode<DynamicSize2>(operands.at(0));
}

// TODO: figure out how to do optional in LTC IR
DynamicLinear::DynamicLinear(const Value& input, const Value& weight, const Value& bias)
: Node(ir::OpKind(c10::Symbol::prim("_dynamic_linear")), {input, weight, bias}, input.shape(),
/*num_outputs=*/1, /*hash_seed=*/0x5a2d296e9) {}

NodePtr DynamicLinear::Clone(OpList operands) const {
return MakeNode<DynamicLinear>(operands.at(0), operands.at(1), operands.at(2));
}

DynamicGetItem::DynamicGetItem(const Value& lhs, const Value& rhs)
: Node(ir::OpKind(c10::Symbol::prim("_dynamic_getitem")), {lhs, rhs}, lazy_tensors::Shape{},
/*num_outputs=*/1, /*hash_seed=*/0x5a2d296e9) {}

NodePtr DynamicGetItem::Clone(OpList operands) const {
return MakeNode<DynamicGetItem>(operands.at(0), operands.at(1));
}

} // namespace ops
} // namespace ir
} // namespace torch_lazy_tensors
64 changes: 64 additions & 0 deletions lazy_tensor_core/lazy_tensor_core/csrc/ts_backend/ops/add.h
@@ -14,6 +14,70 @@ class Add : public Node {
NodePtr Clone(OpList operands) const override;
};

class AddDim : public Node {
public:
AddDim(const Value& lhs, const Value& rhs);

NodePtr Clone(OpList operands) const override;
};

class MulDim : public Node {
public:
MulDim(const Value& lhs, const Value& rhs);

NodePtr Clone(OpList operands) const override;
};

class DynamicSize : public Node {
public:
DynamicSize(const Value& lhs);

NodePtr Clone(OpList operands) const override;
};

class DynamicSize2 : public Node {
public:
DynamicSize2(const Value& lhs);

NodePtr Clone(OpList operands) const override;
};

class DynamicExpand : public Node {
public:
DynamicExpand(const Value& lhs, const Value& sz);

NodePtr Clone(OpList operands) const override;
};

class DynamicExpand2 : public Node {
public:
DynamicExpand2(const Value& lhs, const Value& sz);

NodePtr Clone(OpList operands) const override;
};

class DynamicLinear : public Node {
public:
DynamicLinear(const Value& input, const Value& weight, const Value& bias);

NodePtr Clone(OpList operands) const override;
};

class DynamicGetItem : public Node {
public:
DynamicGetItem(const Value& lhs, const Value& rhs);

NodePtr Clone(OpList operands) const override;
};

class DynamicView : public Node {
public:
DynamicView(lazy_tensors::Span<const ir::Value> values);

NodePtr Clone(OpList operands) const override;

};

} // namespace ops
} // namespace ir
} // namespace torch_lazy_tensors