Skip to content

Commit

Permalink
Register operators, remove unused operator
Browse files Browse the repository at this point in the history
  • Loading branch information
ZhiweiYan-96 committed Jan 5, 2024
1 parent fbdbb18 commit fdbd33b
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 192 deletions.
138 changes: 22 additions & 116 deletions aten/src/ATen/native/mkldnn/xpu/Blas.cpp
Expand Up @@ -3,7 +3,8 @@
#include "BlasImpl.h"

namespace at {
namespace native::xpu {
namespace native {
namespace xpu {

using namespace impl;

Expand Down Expand Up @@ -362,120 +363,6 @@ Tensor& addmv_out(
return out;
}

// In-place matrix-vector multiply-add: self = beta * self + alpha * (mat @ vec).
// Shapes are validated, both self and vec are viewed as 2-D column matrices,
// and the work is delegated to the in-place addmm_ gemm path.
Tensor& addmv_(
    Tensor& self,
    const Tensor& mat,
    const Tensor& vec,
    at::Scalar beta,
    at::Scalar alpha) {
  TORCH_CHECK(
      (mat.dim() == 2 && vec.dim() == 1 && self.dim() <= 1),
      "vector + matrix @ vector expected, got ",
      self.dim(),
      ", ",
      mat.dim(),
      ", ",
      vec.dim());

  // When self is a genuine length-n vector, view it as an n x 1 column so
  // the accumulation happens directly into self's storage; otherwise (scalar
  // or length-1) pass it through and let addmm_ broadcast.
  Tensor self_2d;
  const bool self_is_vector = (self.dim() == 1 && self.size(0) != 1);
  if (self_is_vector) {
    TORCH_CHECK(
        (mat.size(1) == vec.size(0) && mat.size(0) == self.size(0)),
        "size mismatch, get ",
        self.size(0),
        ", ",
        mat.size(0),
        "x",
        mat.size(1),
        ",",
        vec.size(0));
    self_2d = self.view({self.size(0), 1});
  } else {
    TORCH_CHECK(
        (mat.size(1) == vec.size(0)),
        "size mismatch, get ",
        mat.size(0),
        "x",
        mat.size(1),
        ",",
        vec.size(0));
    self_2d = self;
  }

  // Treat vec as an n x 1 matrix so the whole operation is one gemm.
  Tensor vec_2d = vec.view({vec.size(0), 1});
  self_2d.addmm_(mat, vec_2d, beta, alpha);
  return self;
}

// Tensor contraction of input1 and input2 over the paired dimensions
// (dims1[i], dims2[i]).
//
// Size-1 dimensions on one side are handled as broadcasts: the opposite
// operand is summed over the paired dimension up front. The remaining
// contraction is lowered to a single matrix multiply by permuting the
// contracted dims to the back of input1 / front of input2 and flattening.
Tensor tensordot(
    const Tensor& input1,
    const Tensor& input2,
    IntArrayRef dims1,
    IntArrayRef dims2) {
  TORCH_CHECK(
      dims1.size() == dims2.size(),
      "both dimension lists should have same length");
  int64_t csize = 1; // total size of the contracted dimensions
  Tensor t1 = input1;
  Tensor t2 = input2;
  for (const auto i : c10::irange(dims1.size())) {
    // Tensor sizes are int64_t; narrowing them to int could overflow for
    // very large dimensions.
    int64_t s1 = input1.size(dims1[i]);
    int64_t s2 = input2.size(dims2[i]);
    if (s2 == 1) { // broadcasted dimensions can be summed right away
      t1 = t1.sum(dims1[i], true);
    } else if (s1 == 1) {
      t2 = t2.sum(dims2[i], true);
    } else {
      TORCH_CHECK(
          s1 == s2,
          "contracted dimensions need to match, but first has size ",
          s1,
          " in dim ",
          dims1[i],
          " and second has size ",
          s2,
          " in dim ",
          dims2[i]);
      csize *= s1;
    }
  }
  auto cdims1 = at::dim_list_to_bitset(dims1, input1.dim());
  auto cdims2 = at::dim_list_to_bitset(dims2, input2.dim());
  std::vector<int64_t> p1, p2,
      rsizes; // p1, p2: input permutations, rsizes: sizes of the result
  p1.reserve(input1.dim());
  p2.reserve(input2.dim());
  rsizes.reserve(
      input1.dim() + input2.dim() - static_cast<int64_t>(dims1.size()));
  int64_t size1 = 1; // number of non-contracted elements in input1
  int64_t size2 = 1; // number of non-contracted elements in input2

  // fill the permutations and compute sizes
  for (const auto i : c10::irange(input1.dim())) {
    if (!cdims1[i]) {
      p1.emplace_back(i);
      size1 *= t1.size(i);
      rsizes.emplace_back(t1.size(i));
    }
  }
  for (const auto x : dims1) {
    p1.emplace_back(x);
  }
  for (const auto x : dims2) {
    p2.emplace_back(x);
  }
  for (const auto i : c10::irange(input2.dim())) {
    if (!cdims2[i]) {
      p2.emplace_back(i);
      size2 *= t2.size(i);
      rsizes.emplace_back(t2.size(i));
    }
  }
  // permute and reshape for matrix multiplication
  t1 = t1.permute(p1).reshape({size1, csize});
  t2 = t2.permute(p2).reshape({csize, size2});
  // multiply and reshape to target size
  return at::mm(t1, t2).reshape(rsizes);
}

Tensor& tensordot_out(
const Tensor& input1,
const Tensor& input2,
Expand Down Expand Up @@ -512,5 +399,24 @@ Tensor& tensordot_out(
return result;
}

} // namespace native::xpu

} // namespace xpu

// Register the XPU (oneDNN) kernels for the aten BLAS-family operators.
// Schema names must exactly match the entries in native_functions.yaml;
// a misspelled name registers against a nonexistent operator and the
// intended kernel is never dispatched.
TORCH_LIBRARY_IMPL(aten, XPU, m) {
  // Fixed: was misspelled "admm.out", which would fail to register the
  // addmm.out kernel.
  m.impl("addmm.out", TORCH_FN(addmm_out));
  m.impl("_addmm_activation.out", TORCH_FN(_addmm_activation_out));
  m.impl("mm.out", TORCH_FN(mm_out));
  m.impl("mm", TORCH_FN(mm));
  m.impl("baddbmm.out", TORCH_FN(baddbmm_out));
  m.impl("baddbmm_", TORCH_FN(baddbmm_));
  m.impl("baddbmm", TORCH_FN(baddbmm));
  m.impl("addbmm.out", TORCH_FN(addbmm_out));
  m.impl("addbmm_", TORCH_FN(addbmm_));
  m.impl("addbmm", TORCH_FN(addbmm));
  m.impl("bmm.out", TORCH_FN(bmm_out));
  m.impl("bmm", TORCH_FN(bmm));
  m.impl("addmv.out", TORCH_FN(addmv_out));
  m.impl("tensordot.out", TORCH_FN(tensordot_out));
}
} // namespace native
} // namespace at
22 changes: 0 additions & 22 deletions aten/src/ATen/native/mkldnn/xpu/Linear.cpp

This file was deleted.

54 changes: 0 additions & 54 deletions aten/src/ATen/native/mkldnn/xpu/Linear.h

This file was deleted.

0 comments on commit fdbd33b

Please sign in to comment.