diff --git a/docs/source/oneflow.rst b/docs/source/oneflow.rst
index 2ec638d9454..4061b0f0ceb 100644
--- a/docs/source/oneflow.rst
+++ b/docs/source/oneflow.rst
@@ -51,7 +51,8 @@ oneflow
         as_strided,
         div,
         dot,
-        eq,
+        eq,
+        einsum,
         equal,
         expand,
         eye,
diff --git a/oneflow/core/common/util.h b/oneflow/core/common/util.h
index ffcb4ea6030..a087050fb42 100644
--- a/oneflow/core/common/util.h
+++ b/oneflow/core/common/util.h
@@ -166,7 +166,8 @@ inline uint32_t NewRandomSeed() {
 #define DIM_SEQ           \
   OF_PP_MAKE_TUPLE_SEQ(1) \
   OF_PP_MAKE_TUPLE_SEQ(2) \
-  OF_PP_MAKE_TUPLE_SEQ(3) OF_PP_MAKE_TUPLE_SEQ(4) OF_PP_MAKE_TUPLE_SEQ(5) OF_PP_MAKE_TUPLE_SEQ(6)
+  OF_PP_MAKE_TUPLE_SEQ(3) \
+  OF_PP_MAKE_TUPLE_SEQ(4) OF_PP_MAKE_TUPLE_SEQ(5) OF_PP_MAKE_TUPLE_SEQ(6) OF_PP_MAKE_TUPLE_SEQ(7)

 #define BOOL_SEQ (true)(false)
diff --git a/oneflow/core/functional/functional_api.yaml b/oneflow/core/functional/functional_api.yaml
index 6c43aa25a48..9af8257b168 100755
--- a/oneflow/core/functional/functional_api.yaml
+++ b/oneflow/core/functional/functional_api.yaml
@@ -2002,3 +2002,7 @@
 - name: "cumprod_grad"
   signature: "Tensor (Tensor input, Tensor y, Tensor x, Int64 dim) => CumprodGrad"
   bind_python: False
+
+- name: "einsum"
+  signature: "Tensor (String equation, TensorTuple operands) => EinSum"
+  bind_python: True
diff --git a/oneflow/core/functional/impl/math_functor.cpp b/oneflow/core/functional/impl/math_functor.cpp
index fb5d502cb1e..5eb972e4478 100644
--- a/oneflow/core/functional/impl/math_functor.cpp
+++ b/oneflow/core/functional/impl/math_functor.cpp
@@ -30,6 +30,9 @@ limitations under the License.
 #include "oneflow/core/job/sbp_parallel.h"
 #include "oneflow/core/functional/tensor_processor.h"
+#include <bitset>
+#include <sstream>
+
 namespace oneflow {
 namespace one {
 namespace functional {
@@ -2158,6 +2161,576 @@ class CumProdGradFunctor : public CumGradBaseFunctor {
   }
 };

+// NOTE(Liang Depeng): The implementation of sumproduct_pair is mostly taken from pytorch.
+// For more details please refer to:
+// https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Linear.cpp#L65
+
+// sumproduct_pair computes `(left*right).sum(sumdims)` by means of permutation and
+// batch matrix multiplication; its main purpose is to provide a pairwise reduction for einsum.
+static Maybe<Tensor> sumproduct_pair(const std::shared_ptr<one::Tensor>& left_,
+                                     const std::shared_ptr<one::Tensor>& right_,
+                                     const std::vector<int32_t>& sum_dims_, bool keepdim) {
+  // assumes that tensors have been pre-unsqueezed (so that all dimensions match - after
+  // broadcasting) but makes no other assumptions on the order of dimensions
+  CHECK_OR_RETURN(left_->ndim() == right_->ndim()) << "number of dimensions must match";
+  if (sum_dims_.size() == 0) return functional::Mul(left_, right_);
+  int64_t dim = left_->ndim();
+
+  constexpr size_t dim_bitset_size = 64;
+  CHECK_OR_RETURN(dim <= (int64_t)dim_bitset_size)
+      << "only tensors with up to " << dim_bitset_size << " dims are supported";
+  std::bitset<dim_bitset_size> sum_dims;
+  for (int i = 0; i < sum_dims_.size(); ++i) {
+    size_t d = sum_dims_[i];
+    CHECK_OR_RETURN(!sum_dims[d]) << "dim " << d << " appears multiple times in the list of dims";
+    sum_dims[d] = true;
+  }
+
+  // dimensions that will be part of the output (i.e. not summed over) are collected in three
+  // vectors: dims in lro appear in left, right and output; similarly lo: left and output,
+  // ro: right and output. The sizes are also kept track of for reshaping.
+  std::vector<int32_t> lro, lo, ro;
+  int32_t lro_size = 1, lo_size = 1, ro_size = 1, sum_size = 1;
+  std::shared_ptr<one::Tensor> left = left_;
+  std::shared_ptr<one::Tensor> right = right_;
+  for (int i = 0; i < dim; ++i) {
+    auto sl = left->shape()->At(i) > 1;
+    auto sr = right->shape()->At(i) > 1;
+    if (sum_dims[i]) {  // first dimensions that will be summed over after multiplication
+      if (sl && sr) {   // dimensions nontrivially in both left and right must be of the same size
+        CHECK_OR_RETURN(left->shape()->At(i) == right->shape()->At(i))
+            << "non-broadcast dimensions must match";
+        sum_size *= left->shape()->At(i);
+      } else if (sl) {  // if it is only in one of left and right, we can sum right away
+        left = JUST(functional::ReduceSum(left, {i}, true));
+      } else if (sr) {
+        right = JUST(functional::ReduceSum(right, {i}, true));
+      }
+    } else if (sl && sr) {  // now deal with dimensions that will be in the output
+      // dimensions nontrivially in both left and right must be of the same size
+      CHECK_OR_RETURN(left->shape()->At(i) == right->shape()->At(i))
+          << "non-broadcast dimensions must match";
+      lro.push_back(i);
+      lro_size *= left->shape()->At(i);
+    } else if (sl) {  // keep track of dimensions appearing only once
+      lo.push_back(i);
+      lo_size *= left->shape()->At(i);
+    } else {
+      ro.push_back(i);
+      ro_size *= right->shape()->At(i);
+    }
+  }
+
+  // we now work with the following permutations / shapes.
+  // the pipeline is: permute inputs -> reshape inputs -> batch matrix mul -> reshape(view) output
+  // -> permute output.
+  // output: "lro, lo, 1-for-summed-dims, ro" with original shape dimensions
+  // left:   "lro, lo, summed" permuted with lpermutation and the three flattened
+  // right:  "lro, summed, ro" permuted with rpermutation and the three flattened
+  // then the permuted output is a view of bmm(left, right).
+  // finally, opermutation reverts the permutation to the original order of dimensions.
+  std::vector<int64_t> out_size;
+  for (auto& d : lro) out_size.push_back(left->shape()->At(d));
+  for (auto& d : lo) out_size.push_back(left->shape()->At(d));
+  for (auto& d : sum_dims_) {
+    out_size.push_back(1);
+    (void)(d);
+  };  // avoid warning about not using d
+  for (auto& d : ro) out_size.push_back(right->shape()->At(d));
+
+  std::vector<int32_t> lpermutation(lro);
+  lpermutation.insert(lpermutation.end(), lo.begin(), lo.end());
+  lpermutation.insert(lpermutation.end(), sum_dims_.begin(), sum_dims_.end());
+  lpermutation.insert(lpermutation.end(), ro.begin(), ro.end());
+
+  std::vector<int32_t> rpermutation(lro);
+  rpermutation.insert(rpermutation.end(), sum_dims_.begin(), sum_dims_.end());
+  rpermutation.insert(rpermutation.end(), ro.begin(), ro.end());
+  rpermutation.insert(rpermutation.end(), lo.begin(), lo.end());
+
+  std::vector<int32_t> opermutation(lro.size() + lo.size() + sum_dims_.size() + ro.size(), -1);
+  {
+    int32_t i = 0;
+
+    for (auto it = lro.cbegin(); it != lro.cend(); i++, it++) { opermutation[*it] = i; }
+    for (auto it = lo.cbegin(); it != lo.cend(); i++, it++) { opermutation[*it] = i; }
+    for (auto it = sum_dims_.cbegin(); it != sum_dims_.cend(); i++, it++) { opermutation[*it] = i; }
+    for (auto it = ro.cbegin(); it != ro.cend(); i++, it++) { opermutation[*it] = i; }
+  }
+
+  // now we can execute the operations above
+  left = JUST(functional::Permute(left, lpermutation));
+  DimVector lsv(3);
+  lsv[0] = lro_size;
+  lsv[1] = lo_size;
+  lsv[2] = sum_size;
+  const Shape ls(lsv);
+
+  left = JUST(functional::Reshape(left, ls));
+
+  right = JUST(functional::Permute(right, rpermutation));
+  DimVector rsv(3);
+  rsv[0] = lro_size;
+  rsv[1] = sum_size;
+  rsv[2] = ro_size;
+  const Shape rs(rsv);
+  right = JUST(functional::Reshape(right, rs));
+
+  std::shared_ptr<one::Tensor> result =
+      JUST(functional::BatchMatMul(left, right, false, false, 1.0));
+  DimVector osv(out_size.size());
+  for (int i = 0; i < out_size.size(); ++i) { osv[i] = out_size[i]; }
+  const Shape os(osv);
+  // TODO(Liang Depeng): change reshape to view
+  result = JUST(functional::Reshape(result, os));
+  result = JUST(functional::Permute(result, opermutation));
+
+  // finally squeeze summed dimensions if desired
+  if (!keepdim) {
+    auto sizes = result->shape()->dim_vec();
+    for (int i = dim - 1; i >= 0; i--) {
+      if (sum_dims[i]) { sizes.erase(sizes.begin() + i); }
+    }
+    // TODO(Liang Depeng): change reshape to view
+    const Shape s(sizes);
+    result = JUST(functional::Reshape(result, s));
+  }
+  return result;
+}
+
+namespace {
+
+bool einsum_check_label(unsigned char label) { return std::isalpha(label); }
+
+uint8_t einsum_label_to_index(unsigned char label) {
+  constexpr uint8_t NUM_OF_LETTERS = 'z' - 'a' + 1;
+  return std::isupper(label) ? label - 'A' : NUM_OF_LETTERS + (label - 'a');
+}
+
+unsigned char einsum_index_to_label(uint8_t index) {
+  constexpr uint8_t NUM_OF_LETTERS = 'z' - 'a' + 1;
+  return index < NUM_OF_LETTERS ? index + 'A' : index - NUM_OF_LETTERS + 'a';
+}
+
+}  // namespace
+
+// NOTE(Liang Depeng): The implementation of EinSumFunctor is mostly taken from pytorch.
+// For more details please refer to:
+// https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Linear.cpp#L190
+
+// There are roughly three parts to compute einsum:
+// 1. Parse equation to extract the labels for each input operand and output
+// 2. Unsqueeze missing dimensions from input operands and permute to align them
+// 3. Compute result by multiplying input operands and summing contraction
+//    dimensions. We do the last part by reducing to batch matmul.
+class EinSumFunctor {
+ public:
+  EinSumFunctor() {}
+  Maybe<Tensor> operator()(const std::string& equation, const one::TensorTuple& operands) const {
+    CHECK_OR_RETURN(operands.size() > 0) << "einsum(): must provide at least one input tensor.";
+    // NOTE(Liang Depeng): In order to better understand what einsum is doing,
+    // the following comments will give a detailed explanation of
+    // how the operands of equation "ik,jkl,il->ij" (bilinear)
+    // are transformed during the computation.
+    // Assume that the sizes of the operands "ik", "jkl" and "il" are
+    // [2, 3], [4, 3, 5], [2, 5] respectively.
+
+    // Code used to identify ELLIPSIS ("...")
+    constexpr uint8_t ELLIPSIS = 52;
+
+    // Find arrow (->) to split equation into lhs (input equations) and rhs (output equation)
+    const auto arrow_pos = equation.find("->");
+    const auto lhs = equation.substr(0, arrow_pos);
+
+    const auto num_ops = operands.size();
+
+    // Convert each input equation into indexes in range [0, 52] and store
+    // them in op_labels for each operand along with ELLIPSIS if present.
+    std::vector<std::vector<uint8_t>> op_labels(num_ops);
+    // NOTE(Liang Depeng): Continue explaining the equation "ik,jkl,il->ij".
+    // After running the following for loop, `op_labels` contains 3 vectors.
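+    // (For reference: einsum_label_to_index encodes 'A'-'Z' as 0-25 and 'a'-'z' as 26-51,
+    //  so 'i' becomes 34, 'j' becomes 35, 'k' becomes 36 and 'l' becomes 37.)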
+    // The contents of each vector are:
+    //     op_labels[0]: [34('i'-'a'+26), 36('k'-'a'+26)]
+    //     op_labels[1]: [35('j'-'a'+26), 36('k'-'a'+26), 37('l'-'a'+26)]
+    //     op_labels[2]: [34('i'-'a'+26), 37('l'-'a'+26)]
+    bool found_ell = false;
+    std::size_t curr_op = 0;
+    for (auto i = decltype(lhs.length()){0}; i < lhs.length(); ++i) {
+      const unsigned char label = lhs[i];
+      switch (label) {
+        case ' ':
+          // Ignore spaces
+          break;
+
+        case '.':
+          // process ellipsis
+          CHECK_OR_RETURN(
+              // Only one ellipsis per operand can be given
+              !found_ell)
+              << "einsum(): found \'.\' for operand " << curr_op
+              << " for which an ellipsis was already found";
+          CHECK_OR_RETURN(
+              // Ensure it's a valid ellipsis
+              i + 2 < lhs.length() && lhs[++i] == '.' && lhs[++i] == '.')
+              << "einsum(): found \'.\' for operand " << curr_op
+              << " that is not part of any ellipsis";
+          op_labels[curr_op].push_back(ELLIPSIS);
+          found_ell = true;
+          break;
+
+        case ',':
+          // Move onto next operand
+          ++curr_op;
+          CHECK_OR_RETURN(curr_op < num_ops)
+              << "einsum(): fewer operands were provided than specified in the equation";
+          found_ell = false;
+          break;
+
+        default:
+          // Parse label
+          CHECK_OR_RETURN(einsum_check_label(label))
+              << "einsum(): invalid subscript given at index " << i
+              << " in the equation string, subscripts must be in [a-zA-Z]";
+          op_labels[curr_op].push_back(einsum_label_to_index(label));
+      }
+    }
+
+    CHECK_OR_RETURN(curr_op == num_ops - 1)
+        << "einsum(): more operands were provided than specified in the equation";
+
+    // Labels must be within [a-zA-Z].
+    constexpr uint8_t TOTAL_LABELS = 52;
+    std::vector<int32_t> label_count(TOTAL_LABELS, 0);
+
+    // The maximum number of dimensions covered by any ellipsis, needed when
+    // unsqueezing missing dimensions from operands to permute and broadcast
+    int32_t ell_num_dim = 0;
+    // NOTE(Liang Depeng): Continue explaining the equation "ik,jkl,il->ij".
+    // After running the following for loop,
+    // the non-zero entries of `label_count` are:
+    //     label_count[34] = 2
+    //     label_count[35] = 1
+    //     label_count[36] = 2
+    //     label_count[37] = 2
+    // `ell_num_dim` equals to 0 because there is no ellipsis in the equation.
+
+    // Compute label frequency and number of dimensions covered by ellipsis
+    // We do this after parsing labels to make it more readable and simpler
+    // to compute the number of dimensions covered by ellipsis.
+    for (auto i = 0; i < num_ops; i++) {
+      const auto operand = operands[i];
+      const auto labels = op_labels[i];
+      const int ndims = operand->ndim();
+      int32_t nlabels = static_cast<int32_t>(labels.size());
+      bool has_ellipsis = false;
+
+      for (const auto& label : labels) {
+        if (label == ELLIPSIS) {
+          --nlabels;
+          has_ellipsis = true;
+          ell_num_dim = std::max(ell_num_dim, ndims - nlabels);
+        } else {
+          ++label_count[label];
+        }
+      }
+      if (has_ellipsis) {
+        CHECK_OR_RETURN(nlabels <= ndims)
+            << "einsum(): the number of subscripts in the equation (" << nlabels
+            << ") is more than the number of dimensions (" << ndims << ") for operand " << i;
+      } else {
+        CHECK_OR_RETURN(nlabels == ndims)
+            << "einsum(): the number of subscripts in the equation (" << nlabels
+            << ") does not match the number of dimensions (" << ndims << ") for operand " << i
+            << " and no ellipsis was given";
+      }
+    }
+
+    // We want to align the dimensions of every input tensor to have
+    // shape out_dims + sum_dims. For this, we create a mapping of label
+    // to index into the permuted shape.
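+    // (For "ik,jkl,il->ij" the aligned layout will be [i, j, k, l]: the output labels i and j
+    //  come first, followed by the contracted labels k and l, i.e. out_dims = [i, j] and
+    //  sum_dims = [k, l].)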
+    std::vector<int32_t> label_perm_index(TOTAL_LABELS, -1);
+
+    // Current index in the permuted shape
+    int32_t perm_index = 0;
+
+    // Start index of ellipsis dimensions in the permuted shape
+    int32_t ell_index = 0;
+    found_ell = false;
+
+    // NOTE(Liang Depeng): Continue explaining the equation "ik,jkl,il->ij".
+    // After running the following if-else code block,
+    // the entries of `label_perm_index` that are not -1 are:
+    //     label_perm_index[34] = 0
+    //     label_perm_index[35] = 1
+    // `perm_index` equals to 2
+    // `ell_index` equals to 0 because no ellipsis in equation
+    // `found_ell` equals to false because no ellipsis in equation
+    if (arrow_pos == std::string::npos) {
+      // Implicit output is ellipsis (...) + labels seen only once
+      perm_index = ell_num_dim;
+      found_ell = true;
+      for (auto label = 0; label < TOTAL_LABELS; label++) {
+        if (label_count[label] == 1) { label_perm_index[label] = perm_index++; }
+      }
+    } else {
+      // Parse explicit output
+      const auto rhs = equation.substr(arrow_pos + 2);
+      for (auto i = decltype(rhs.length()){0}; i < rhs.length(); ++i) {
+        const unsigned char label = rhs[i];
+        switch (label) {
+          case ' ':
+            // Ignore spaces
+            break;
+
+          case '.':
+            // process ellipsis
+            CHECK_OR_RETURN(
+                // There can only be one ellipsis in the output
+                !found_ell)
+                << "einsum(): found \'.\' for output but an ellipsis (...) was already found";
+            CHECK_OR_RETURN(
+                // Ensure ellipsis is correct
+                i + 2 < rhs.length() && rhs[++i] == '.' && rhs[++i] == '.')
+                << "einsum(): found \'.\' for output that is not part of any ellipsis (...)";
+            ell_index = perm_index;
+            perm_index += ell_num_dim;
+            found_ell = true;
+            break;
+
+          default:
+            CHECK_OR_RETURN(einsum_check_label(label))
+                << "einsum(): invalid subscript given at index " << lhs.size() + 2 + i
+                << " in the equation string, subscripts must be in [a-zA-Z]";
+            const auto index = einsum_label_to_index(label);
+            CHECK_OR_RETURN(
+                // Ensure label appeared at least once for some input operand
+                // and at most once for the output
+                label_count[index] > 0 && label_perm_index[index] == -1)
+                << "einsum(): output subscript " << label
+                << (label_perm_index[index] > -1
+                        ? " appears more than once in the output"
+                        : " does not appear in the equation for any input operand");
+            label_perm_index[index] = perm_index++;
+        }
+      }
+    }
+
+    // Save output size before adding contraction dims (dims to sum out)
+    const int32_t out_size = perm_index;
+
+    // If ellipsis is not part of the output, add to contraction dimensions
+    if (!found_ell) {
+      ell_index = perm_index;
+      perm_index += ell_num_dim;
+    }
+
+    // NOTE(Liang Depeng): Continue explaining the equation "ik,jkl,il->ij".
+    // After running the following for loop,
+    // the entries of `label_perm_index` that are not -1 are:
+    //     label_perm_index[34] = 0 ('i')
+    //     label_perm_index[35] = 1 ('j')
+    //     label_perm_index[36] = 2 ('k')
+    //     label_perm_index[37] = 3 ('l')
+    // `out_size` equals to 2
+    // `perm_index` equals to 4
+
+    // Add contraction labels (labels not present in output)
+    for (auto label = 0; label < TOTAL_LABELS; label++) {
+      if (label_count[label] > 0 && label_perm_index[label] == -1) {
+        label_perm_index[label] = perm_index++;
+      }
+    }
+
+    // Here we unsqueeze missing dimensions to make all operands have the same
+    // number of dimensions. We take diagonals for repeated labels within the
+    // same operand. Finally we permute the operands to align dimensions as
+    // per the label_perm_index we computed above.
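+    // (After this alignment every operand has exactly perm_index dimensions in the same label
+    //  order, so what remains is elementwise multiplication plus summation over the trailing
+    //  contraction dimensions, carried out pairwise below via Mul, Dot or sumproduct_pair.)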
+ TensorTuple permuted_operands; + for (auto i = 0; i < num_ops; i++) { + std::vector perm_shape(perm_index, -1); + std::vector label_dim(TOTAL_LABELS, -1); + std::shared_ptr operand = operands[i]; + const auto labels = op_labels[i]; + const auto original_sizes = operand->shape()->dim_vec(); + + int32_t j = 0; + for (const auto& label : labels) { + if (label == ELLIPSIS) { + // Add missing dimensions covered by the ellipsis + const auto num_missing_dim = ell_num_dim - (original_sizes.size() - labels.size() + 1); + for (auto k = 0; k < num_missing_dim; k++) { + operand = JUST(functional::Unsqueeze(operand, j)); + } + for (auto k = 0; k < ell_num_dim; k++) { perm_shape[ell_index + k] = j++; } + } else if (label_dim[label] != -1) { + // Repeated label, take diagonal + const auto dim = label_dim[label]; + CHECK_OR_RETURN(operand->dim(j) == operand->dim(dim)) + << "einsum() subscript " << einsum_index_to_label(label) + << " is repeated for operand " << i << " but the sizes don't match, " + << operand->dim(j) << " != " << operand->dim(dim); + + operand = JUST(functional::Diagonal(operand, 0, dim, j)); + operand = JUST(functional::MovedimInt(operand, -1, dim)); + } else { + // Lookup output index for label + label_dim[label] = j; + perm_shape[label_perm_index[label]] = j++; + } + } + + // Add dimensions for missing labels + for (int32_t& index : perm_shape) { + if (index == -1) { + operand = JUST(functional::Unsqueeze(operand, -1)); + index = j++; + } + } + permuted_operands.emplace_back(JUST(functional::Permute(operand, perm_shape))); + + // NOTE(Liang Depeng): Continue explaining the equation "ik,jkl,il->ij". + // What is going on within this foor loop? + // For operand "ik" size = [2, 3]: + // `perm_shape` equals to [0, 2, 1, 3] + // first unsqueeze "ik" to 4 dim, from [2, 3] to [2, 3, 1, 1] + // then permute with `perm_shape`, from [2, 3, 1, 1] to [2, 1, 3, 1] + // + // For operand "jkl" size = [4, 3, 5]: + // `perm_shape` equals to [3, 0, 1, 2] + // first unsqueeze "jkl" to 4 dim, from [4, 3, 5] to [4, 3, 5, 1] + // then permute with `perm_shape`, from [4, 3, 5, 1] to [1, 4, 3, 5] + // + // For operand "il" size = [2, 5]: + // `perm_shape` equals to [0, 2, 3, 1] + // first unsqueeze "ik" to 4 dim, from [2, 5] to [2, 5, 1, 1] + // then permute with `perm_shape`, from [2, 5, 1, 1] to [2, 1, 1, 5] + } + + // Check if operands broadcast and keep track of last operand with + // dimension size != 1 for optimizing reductions + std::vector dim_last_op(perm_index, 0); + bool has_zero_size_dim = false; + // NOTE(Liang Depeng): Continue explaining the equation "ik,jkl,il->ij". 
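+    // (dim_last_op[d] records the index of the last operand whose size along aligned dimension d
+    //  is greater than 1; the reduction passes below use it to squeeze or sum out a dimension as
+    //  soon as no later operand still needs it.)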
+ // After running the following foor loop, + // The contents of `dim_last_op` are: + // dim_last_op[0] = 2 + // dim_last_op[1] = 1 + // dim_last_op[2] = 1 + // dim_last_op[3] = 2 + // `has_zero_size_dim` equals to false + for (auto dim = 0; dim < perm_index; dim++) { + auto broadcast_size = permuted_operands[0]->dim(dim); + for (auto i = 1; i < num_ops; i++) { + const auto dim_size = permuted_operands[i]->dim(dim); + if (broadcast_size != dim_size && broadcast_size != 1 && dim_size != 1) { + std::ostringstream msg; + msg << "einsum(): operands do not broadcast with remapped shapes [original->remapped]:"; + for (auto j = 0; j < num_ops; j++) { + msg << " " << operands[j]->shape()->DebugStr() << "->" + << permuted_operands[j]->shape()->DebugStr(); + } + CHECK_OR_RETURN(false) << msg.str(); + } + if (dim_size != 1) { + broadcast_size = dim_size; + dim_last_op[dim] = i; + } + } + has_zero_size_dim |= broadcast_size == 0; + } + + // Compute result + std::shared_ptr result = permuted_operands[0]; + + // Fast path for when an operand has zero sized dim + if (has_zero_size_dim) { + DimVector out_shape(out_size); + for (auto i = 0; i < out_size; i++) { + out_shape[i] = permuted_operands[dim_last_op[i]]->dim(i); + } + + const Shape shape(out_shape); + return functional::Constant(shape, Scalar(0), *permuted_operands[0]->dtype(), NullOpt); + } + + // Sum out or squeeze dimensions that are size 1 for all later operands + int dim = out_size; + for (int i = dim; i < perm_index; ++i, ++dim) { + if (dim_last_op[i] == 0) { + if (result->dim(dim) == 1) { + std::vector dims = {dim--}; + result = JUST(functional::Squeeze(result, dims)); + } else { + result = JUST(functional::ReduceSum(result, {dim--}, false)); + } + } + } + + for (auto i = 1; i < num_ops; i++) { + auto operand = permuted_operands[i]; + std::vector sum_dims; + + // Sum out or squeeze dimensions that are size 1 for all later operands + dim = out_size; + for (int j = dim; j < perm_index; ++j, ++dim) { + if (dim_last_op[j] < i) { + std::vector dims = {dim--}; + operand = JUST(functional::Squeeze(operand, dims)); + } else if (dim_last_op[j] == i) { + if (result->dim(dim) == 1) { + operand = JUST(functional::ReduceSum(operand, {dim}, false)); + std::vector dims = {dim--}; + result = JUST(functional::Squeeze(result, dims)); + } else { + sum_dims.push_back(dim); + } + } + } + + // Multiply tensors and sum out dimensions in sum_dims + if (sum_dims.empty()) { + result = JUST(functional::Mul(result, operand)); + } else if (sum_dims.size() == result->shape()->NumAxes()) { + auto flatten_result = JUST(functional::Flatten(result, 0, -1)); + auto flatten_operand = JUST(functional::Flatten(operand, 0, -1)); + result = JUST(functional::Dot(flatten_result, flatten_operand)); + } else { + result = JUST(sumproduct_pair(result, operand, sum_dims, false)); + } + + // NOTE(Liang Depeng): Continue explaining the equation "ik,jkl,il->ij". + // What is going on within this foor loop? + // For iter i = 1: + // result = permuted_operands[0], size = [2, 1, 3, 1] + // operand = permuted_operands[1], size = [1, 4, 3, 5] + // sum_dims = [2, ] + // what happened in `sumproduct_pair` ? 
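+      //   (In general: the dimensions shared by both sides and kept in the output become the
+      //    bmm batch dimensions, the summed dimensions become the inner contraction dimension,
+      //    and the remaining dimensions of each side become the two matrix dimensions.)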
+ // result [2, 1, 3, 1] will be permuted to [2, 3, 1, 1] then + // reshaped to [1, 2, 3] + // operand [1, 4, 3, 5] will be permuted to [3, 4, 5, 1] then + // reshape to [1, 3, 4 * 5] + // perform batch_matmul(result, operand) => [1, 2, 4 * 5] + // then reshape to [2, 1, 4, 5] then permute to + // [2, 4, 1, 5], at last reshape to [2, 4, 5] + // + // For iter i = 2: + // result, size = [2, 4, 5] + // operand = permuted_operands[2], size = [2, 1, 1, 5] + // squeeze operand from [2, 1, 1, 5] to [2, 1, 5] + // sum_dims = [2,] + // what happened in `sumproduct_pair` ? + // result [2, 4, 5] will be permuted to [2, 4, 5] then + // reshaped to [2, 4, 5] + // operand [2, 1, 5] will be permuted to [2, 5, 1] then + // reshape to [2, 5, 1] + // perform batch_matmul(result, operand)=>[2, 4, 1] + // then reshape to [2, 4, 1] then permute to [2, 4, 1] + // at last reshape to [2, 4] + } + return result; + } +}; + } // namespace impl using namespace impl; @@ -2245,6 +2818,7 @@ ONEFLOW_FUNCTION_LIBRARY(m) { m.add_functor("CumsumGrad"); m.add_functor("Cumprod"); m.add_functor("CumprodGrad"); + m.add_functor("EinSum"); }; } // namespace functional diff --git a/oneflow/core/ndarray/xpu_broadcast_ndarray.h b/oneflow/core/ndarray/xpu_broadcast_ndarray.h index 5f98465e510..435d8f1205f 100644 --- a/oneflow/core/ndarray/xpu_broadcast_ndarray.h +++ b/oneflow/core/ndarray/xpu_broadcast_ndarray.h @@ -61,6 +61,7 @@ SPECIALIZE_XPU_BROADCAST_NDARRAY_UTIL(2); SPECIALIZE_XPU_BROADCAST_NDARRAY_UTIL(3); SPECIALIZE_XPU_BROADCAST_NDARRAY_UTIL(4); SPECIALIZE_XPU_BROADCAST_NDARRAY_UTIL(5); +SPECIALIZE_XPU_BROADCAST_NDARRAY_UTIL(6); #undef SPECIALIZE_XPU_BROADCAST_NDARRAY_UTIL #undef IMPLACE_SET_SRC_COORD diff --git a/oneflow/core/ndarray/xpu_shape.h b/oneflow/core/ndarray/xpu_shape.h index 863b335ef98..f422288d85e 100644 --- a/oneflow/core/ndarray/xpu_shape.h +++ b/oneflow/core/ndarray/xpu_shape.h @@ -111,6 +111,7 @@ SPECIALIZE_XPU_SHAPE_UTIL(1); SPECIALIZE_XPU_SHAPE_UTIL(2); SPECIALIZE_XPU_SHAPE_UTIL(3); SPECIALIZE_XPU_SHAPE_UTIL(4); +SPECIALIZE_XPU_SHAPE_UTIL(5); #undef SPECIALIZE_XPU_SHAPE_UTIL #undef EXTRACT_COORD #undef COORD_MUL_STRIDE diff --git a/oneflow/core/ndarray/xpu_util.h b/oneflow/core/ndarray/xpu_util.h index 3519dd72c75..667379fb59b 100644 --- a/oneflow/core/ndarray/xpu_util.h +++ b/oneflow/core/ndarray/xpu_util.h @@ -61,6 +61,7 @@ namespace oneflow { #define GET_SEQ_3 GET_SEQ_2 OF_PP_MAKE_TUPLE_SEQ(3) #define GET_SEQ_4 GET_SEQ_3 OF_PP_MAKE_TUPLE_SEQ(4) #define GET_SEQ_5 GET_SEQ_4 OF_PP_MAKE_TUPLE_SEQ(5) +#define GET_SEQ_6 GET_SEQ_5 OF_PP_MAKE_TUPLE_SEQ(6) } // namespace oneflow diff --git a/python/oneflow/__init__.py b/python/oneflow/__init__.py index 09039cd0c5c..3a6bbe6cff4 100755 --- a/python/oneflow/__init__.py +++ b/python/oneflow/__init__.py @@ -315,6 +315,7 @@ def atexit_hook(hook): adaptive_avg_pool2d, adaptive_avg_pool3d, ) +from oneflow.nn.modules.einsum import einsum_op as einsum from oneflow.nn.modules.is_tensor import is_tensor_op as is_tensor from oneflow.nn.modules.arange import arange_op as arange from oneflow.nn.modules.linspace import linspace_op as linspace diff --git a/python/oneflow/framework/docstr/__init__.py b/python/oneflow/framework/docstr/__init__.py index d7f680f61d5..db4bc3593f7 100644 --- a/python/oneflow/framework/docstr/__init__.py +++ b/python/oneflow/framework/docstr/__init__.py @@ -61,4 +61,5 @@ from .sort import * from .is_floating_point import * from .where import * +from .einsum import * from .oneflow import * diff --git a/python/oneflow/framework/docstr/einsum.py 
b/python/oneflow/framework/docstr/einsum.py new file mode 100644 index 00000000000..80e0c99c70b --- /dev/null +++ b/python/oneflow/framework/docstr/einsum.py @@ -0,0 +1,122 @@ +""" +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +import oneflow +from oneflow.framework.docstr.utils import add_docstr + +add_docstr( + oneflow.einsum, + """ + einsum(equation, *operands) -> oneflow.Tensor + + Sums the product of the elements of the input :attr:`operands` along dimensions specified using a notation + based on the Einstein summation convention. + + Einsum allows computing many common multi-dimensional linear algebraic array operations by representing them + in a short-hand format based on the Einstein summation convention, given by :attr:`equation`. The details of + this format are described below, but the general idea is to label every dimension of the input :attr:`operands` + with some subscript and define which subscripts are part of the output. The output is then computed by summing + the product of the elements of the :attr:`operands` along the dimensions whose subscripts are not part of the + output. For example, matrix multiplication can be computed using einsum as `flow.einsum("ij,jk->ik", A, B)`. + Here, j is the summation subscript and i and k the output subscripts (see section below for more details on why). + + Equation: + + The :attr:`equation` string specifies the subscripts (letters in `[a-zA-Z]`) for each dimension of + the input :attr:`operands` in the same order as the dimensions, separating subcripts for each operand by a + comma (','), e.g. `'ij,jk'` specify subscripts for two 2D operands. The dimensions labeled with the same subscript + must be broadcastable, that is, their size must either match or be `1`. The exception is if a subscript is + repeated for the same input operand, in which case the dimensions labeled with this subscript for this operand + must match in size and the operand will be replaced by its diagonal along these dimensions. The subscripts that + appear exactly once in the :attr:`equation` will be part of the output, sorted in increasing alphabetical order. + The output is computed by multiplying the input :attr:`operands` element-wise, with their dimensions aligned based + on the subscripts, and then summing out the dimensions whose subscripts are not part of the output. + + Optionally, the output subscripts can be explicitly defined by adding an arrow ('->') at the end of the equation + followed by the subscripts for the output. For instance, the following equation computes the transpose of a + matrix multiplication: 'ij,jk->ki'. The output subscripts must appear at least once for some input operand and + at most once for the output. + + Ellipsis ('...') can be used in place of subscripts to broadcast the dimensions covered by the ellipsis. + Each input operand may contain at most one ellipsis which will cover the dimensions not covered by subscripts, + e.g. 
for an input operand with 5 dimensions, the ellipsis in the equation `'ab...c'` cover the third and fourth + dimensions. The ellipsis does not need to cover the same number of dimensions across the :attr:`operands` but the + 'shape' of the ellipsis (the size of the dimensions covered by them) must broadcast together. If the output is not + explicitly defined with the arrow ('->') notation, the ellipsis will come first in the output (left-most dimensions), + before the subscript labels that appear exactly once for the input operands. e.g. the following equation implements + batch matrix multiplication `'...ij,...jk'`. + + A few final notes: the equation may contain whitespaces between the different elements (subscripts, ellipsis, + arrow and comma) but something like `'. . .'` is not valid. An empty string `''` is valid for scalar operands. + + .. note:: + + ``flow.einsum`` handles ellipsis ('...') differently from NumPy in that it allows dimensions + covered by the ellipsis to be summed over, that is, ellipsis are not required to be part of the output. + + .. note:: + + This function does not optimize the given expression, so a different formula for the same computation may + run faster or consume less memory. Projects like opt_einsum (https://optimized-einsum.readthedocs.io/en/stable/) + can optimize the formula for you. + + Args: + equation (String): The subscripts for the Einstein summation. + *operands (oneflow.Tensor): The tensors to compute the Einstein summation of. + + For example: + + .. code-block:: python + + >>> import oneflow as flow + + # trace + >>> flow.einsum('ii', flow.arange(4*4).reshape(4,4).to(flow.float32)) + tensor(30., dtype=oneflow.float32) + + # diagonal + >>> flow.einsum('ii->i', flow.arange(4*4).reshape(4,4).to(flow.float32)) + tensor([ 0., 5., 10., 15.], dtype=oneflow.float32) + + # outer product + >>> x = flow.arange(5).to(flow.float32) + >>> y = flow.arange(4).to(flow.float32) + >>> flow.einsum('i,j->ij', x, y) + tensor([[ 0., 0., 0., 0.], + [ 0., 1., 2., 3.], + [ 0., 2., 4., 6.], + [ 0., 3., 6., 9.], + [ 0., 4., 8., 12.]], dtype=oneflow.float32) + + # batch matrix multiplication + >>> As = flow.arange(3*2*5).reshape(3,2,5).to(flow.float32) + >>> Bs = flow.arange(3*5*4).reshape(3,5,4).to(flow.float32) + >>> flow.einsum('bij,bjk->bik', As, Bs).shape + oneflow.Size([3, 2, 4]) + + # batch permute + >>> A = flow.randn(2, 3, 4, 5) + >>> flow.einsum('...ij->...ji', A).shape + oneflow.Size([2, 3, 5, 4]) + + # bilinear + >>> A = flow.randn(3,5,4) + >>> l = flow.randn(2,5) + >>> r = flow.randn(2,4) + >>> flow.einsum('bn,anm,bm->ba', l, A, r).shape + oneflow.Size([2, 3]) + + """, +) diff --git a/python/oneflow/nn/modules/einsum.py b/python/oneflow/nn/modules/einsum.py new file mode 100644 index 00000000000..ac4b1f4e781 --- /dev/null +++ b/python/oneflow/nn/modules/einsum.py @@ -0,0 +1,26 @@ +""" +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" +import oneflow as flow + + +def einsum_op(equation, *operands): + return flow._C.einsum(equation, operands) + + +if __name__ == "__main__": + import doctest + + doctest.testmod(raise_on_error=True) diff --git a/python/oneflow/test/expensive/test_einsum.py b/python/oneflow/test/expensive/test_einsum.py new file mode 100644 index 00000000000..716f2a378b2 --- /dev/null +++ b/python/oneflow/test/expensive/test_einsum.py @@ -0,0 +1,625 @@ +""" +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +import unittest +import oneflow as flow + +import oneflow.unittest +from oneflow.test_utils.automated_test_util import * + + +@flow.unittest.skip_unless_1n1d() +class TestEinsum(flow.unittest.TestCase): + @autotest(n=20, check_graph=True) + def test_einsum_matrix_transpose(test_case): + device = random_device() + x = random_tensor(ndim=2, dim0=random(1, 6), dim1=random(1, 6),).to(device) + z = torch.einsum("ij->ji", x) + return z + + @autotest(n=20, check_graph=True) + def test_einsum_eltwise_multiply(test_case): + device = random_device() + dim0 = random(1, 6) + dim1 = random(1, 6) + x = random_tensor(ndim=2, dim0=dim0, dim1=dim1,).to(device) + y = random_tensor(ndim=2, dim0=dim0, dim1=dim1,).to(device) + z = torch.einsum("ij,ij->ij", x, y) + return z + + @autotest(n=20, check_graph=True) + def test_einsum_get_diagonal(test_case): + device = random_device() + dim = random(1, 6) + x = random_tensor(ndim=2, dim0=dim, dim1=dim,).to(device) + z = torch.einsum("ii->i", x) + return z + + @autotest(n=20, check_graph=True) + def test_einsum_batch_permute(test_case): + device = random_device() + x = random_tensor( + ndim=5, + dim0=random(1, 6), + dim1=random(1, 6), + dim2=random(1, 6), + dim3=random(1, 6), + dim4=random(1, 6), + ).to(device) + z = torch.einsum("...ij->...ji", x) + return z + + @autotest(n=20, check_graph=True) + def test_einsum_reduce_sum(test_case): + device = random_device() + x = random_tensor(ndim=2, dim0=random(1, 6), dim1=random(1, 6),).to(device) + z = torch.einsum("ij->", x) + return z + + @autotest(n=20, check_graph=True) + def test_einsum_matrix_column_sum(test_case): + device = random_device() + x = random_tensor(ndim=2, dim0=random(1, 6), dim1=random(1, 6),).to(device) + z = torch.einsum("ij->j", x) + return z + + @autotest(n=20, check_graph=True) + def test_einsum_matrix_vector_multiply(test_case): + device = random_device() + dim0 = random(1, 6) + dim1 = random(1, 6) + x = random_tensor(ndim=2, dim0=dim0, dim1=dim1,).to(device) + y = random_tensor(ndim=1, dim0=dim1,).to(device) + # NOTE(Liang Depeng): the same as 'ik,k->i' + z = torch.einsum("ik,k", x, y) + return z + + @autotest(n=20, check_graph=True) + def test_einsum_matmul(test_case): + device = random_device() + dim0 = random(1, 6) + dim1 = random(1, 6) + dim2 = random(1, 6) + x = random_tensor(ndim=2, dim0=dim0, dim1=dim1,).to(device) + y = random_tensor(ndim=2, dim0=dim1, dim1=dim2,).to(device) + # NOTE(Liang Depeng): the same as 'ik,kj->ij' + z = torch.einsum("ik,kj", x, y) + 
return z + + @autotest(n=20, check_graph=True) + def test_einsum_vector_inner_product(test_case): + device = random_device() + dim0 = random(1, 6) + x = random_tensor(ndim=1, dim0=dim0,).to(device) + y = random_tensor(ndim=1, dim0=dim0,).to(device) + # NOTE(Liang Depeng): the same as 'i,i->' + z = torch.einsum("i,i", x, y) + return z + + @autotest(n=20, check_graph=True) + def test_einsum_eltwise_mul_then_reduce_sum(test_case): + device = random_device() + dim0 = random(1, 6) + dim1 = random(1, 6) + x = random_tensor(ndim=2, dim0=dim0, dim1=dim1,).to(device) + y = random_tensor(ndim=2, dim0=dim0, dim1=dim1,).to(device) + # NOTE(Liang Depeng): the same as 'ij,ij->' + z = torch.einsum("ij,ij", x, y) + return z + + @autotest(n=20, check_graph=True) + def test_einsum_vector_outer_product(test_case): + device = random_device() + x = random_tensor(ndim=1, dim0=random(1, 6),).to(device) + y = random_tensor(ndim=1, dim0=random(1, 6),).to(device) + # NOTE(Liang Depeng): the same as 'i,j->ij' + z = torch.einsum("i,j", x, y) + return z + + @autotest(n=20, check_graph=True) + def test_einsum_batch_matmul(test_case): + device = random_device() + dim0 = random(1, 6) + dim1 = random(1, 6) + x = random_tensor(ndim=3, dim0=dim0, dim1=random(1, 6), dim2=dim1,).to(device) + y = random_tensor(ndim=3, dim0=dim0, dim1=dim1, dim2=random(1, 6),).to(device) + z = torch.einsum("ijk,ikl->ijl", x, y) + return z + + @autotest(n=20, check_graph=True) + def test_einsum_tensor_contraction(test_case): + device = random_device() + dim0 = random(1, 6) + dim1 = random(1, 6) + x = random_tensor( + ndim=4, dim0=random(1, 6), dim1=dim0, dim2=dim1, dim3=random(1, 6), + ).to(device) + y = random_tensor( + ndim=5, + dim0=random(1, 6), + dim1=random(1, 6), + dim2=dim0, + dim3=random(1, 6), + dim4=dim1, + ).to(device) + z = torch.einsum("pqrs,tuqvr->pstuv", x, y) + return z + + @autotest(n=20, check_graph=True) + def test_einsum_bilinear_transformation(test_case): + device = random_device() + dim0 = random(1, 6) + dim1 = random(1, 6) + dim2 = random(1, 6) + x = random_tensor(ndim=2, dim0=dim0, dim1=dim1,).to(device) + y = random_tensor(ndim=3, dim0=random(1, 6), dim1=dim1, dim2=dim2,).to(device) + w = random_tensor(ndim=2, dim0=dim0, dim1=dim2,).to(device) + z = torch.einsum("ik,jkl,il->ij", x, y, w) + return z + + @autotest(n=20, auto_backward=False, check_graph=True) + def test_einsum_0_size_tensor(test_case): + device = random_device() + x = random_tensor(ndim=3, dim0=random(1, 6), dim1=0, dim2=random(1, 6),).to( + device + ) + z = torch.einsum("ijk", x) + return z + + @autotest(n=20, check_graph=True) + def test_einsum_tensor_contraction2(test_case): + device = random_device() + dim0 = random(1, 6) + x = random_tensor( + ndim=4, dim0=random(1, 6), dim1=dim0, dim2=random(1, 6), dim3=random(1, 6), + ).to(device) + y = random_tensor(ndim=2, dim0=dim0, dim1=random(1, 6),).to(device) + z = torch.einsum("b n h w, n d -> b d h w", x, y) + return z + + @autotest(n=20, check_graph=True) + def test_einsum_eltwise_mul_sum_row(test_case): + device = random_device() + dim0 = random(1, 6) + dim1 = random(1, 6) + x = random_tensor(ndim=2, dim0=dim0, dim1=dim1,).to(device) + y = random_tensor(ndim=2, dim0=dim0, dim1=dim1,).to(device) + z = torch.einsum("n d, n d -> n", x, y) + return z + + @autotest(n=20, check_graph=True) + def test_einsum_matmul2(test_case): + device = random_device() + dim0 = random(1, 6) + x = random_tensor(ndim=2, dim0=random(1, 6), dim1=dim0,).to(device) + y = random_tensor(ndim=2, dim0=random(1, 6), 
dim1=dim0,).to(device) + z = torch.einsum("i d, j d -> i j", x, y) + return z + + @autotest(n=20, check_graph=True) + def test_einsum_attention(test_case): + device = random_device() + dim0 = random(1, 6) + dim1 = random(1, 6) + dim2 = random(1, 6) + x = random_tensor( + ndim=4, dim0=dim0, dim1=dim1, dim2=random(1, 6), dim3=dim2, + ).to(device) + y = random_tensor( + ndim=4, dim0=dim0, dim1=dim1, dim2=random(1, 6), dim3=dim2, + ).to(device) + z = torch.einsum("b h i d, b h j d -> b h i j", x, y) + return z + + @autotest(n=20, check_graph=True) + def test_einsum_batch_matmul2(test_case): + device = random_device() + dim0 = random(1, 6) + dim1 = random(1, 6) + dim2 = random(1, 6) + x = random_tensor( + ndim=4, dim0=dim0, dim1=dim1, dim2=random(1, 6), dim3=dim2 + ).to(device) + y = random_tensor( + ndim=4, dim0=dim0, dim1=dim1, dim2=dim2, dim3=random(1, 6) + ).to(device) + z = torch.einsum("b h i j, b h j d -> b h i d", x, y) + return z + + @autotest(n=20, check_graph=True) + def test_einsum_batch_matrix_vector_multiply(test_case): + device = random_device() + dim0 = random(1, 6) + dim1 = random(1, 6) + dim2 = random(1, 6) + x = random_tensor(ndim=3, dim0=dim0, dim1=dim1, dim2=dim2,).to(device) + y = random_tensor( + ndim=4, dim0=dim0, dim1=dim1, dim2=random(1, 6), dim3=dim2, + ).to(device) + z = torch.einsum("b i d, b i j d -> b i j", x, y) + return z + + @autotest(n=20, check_graph=True) + def test_einsum_batch_matmul3(test_case): + device = random_device() + dim0 = random(1, 6) + dim1 = random(1, 6) + x = random_tensor( + ndim=4, dim0=dim0, dim1=random(1, 6), dim2=random(1, 6), dim3=dim1, + ).to(device) + y = random_tensor(ndim=3, dim0=dim0, dim1=random(1, 6), dim2=dim1,).to(device) + z = torch.einsum("b x i d, b j d -> b x i j", x, y) + return z + + @autotest(n=20, check_graph=True) + def test_einsum_batch_matmul4(test_case): + device = random_device() + dim0 = random(1, 6) + dim1 = random(1, 6) + x = random_tensor( + ndim=4, dim0=dim0, dim1=random(1, 6), dim2=random(1, 6), dim3=dim1, + ).to(device) + y = random_tensor(ndim=3, dim0=dim0, dim1=dim1, dim2=random(1, 6),).to(device) + z = torch.einsum("b x i j, b j d -> b x i d", x, y) + return z + + @autotest(n=20, check_graph=True) + def test_einsum_alphaflod_usecase1(test_case): + device = random_device() + dim0 = random(1, 6) + dim1 = random(1, 6) + x = random_tensor(ndim=3, dim0=random(1, 6), dim1=dim0, dim2=dim1,).to(device) + y = random_tensor(ndim=3, dim0=dim0, dim1=dim1, dim2=random(1, 6),).to(device) + z = torch.einsum("hij, ijc->ihc", x, y) + return z + + @autotest(n=20, check_graph=True) + def test_einsum_alphaflod_usecase2(test_case): + device = random_device() + dim0 = random(1, 6) + dim1 = random(1, 6) + x = random_tensor(ndim=3, dim0=dim0, dim1=dim1, dim2=random(1, 6),).to(device) + y = random_tensor(ndim=3, dim0=dim0, dim1=dim1, dim2=random(1, 6),).to(device) + z = torch.einsum("rac,rab->rbc", x, y) + return z + + @autotest(n=20, check_graph=True) + def test_einsum_alphaflod_usecase3(test_case): + device = random_device() + dim0 = random(1, 6) + dim1 = random(1, 6) + x = random_tensor(ndim=2, dim0=dim0, dim1=dim1,).to(device) + y = random_tensor(ndim=3, dim0=dim0, dim1=dim1, dim2=random(1, 6),).to(device) + z = torch.einsum("ra,rab->rb", x, y) + return z + + @autotest(n=20, check_graph=True) + def test_einsum_alphaflod_usecase4(test_case): + device = random_device() + dim0 = random(1, 6) + dim1 = random(1, 6) + x = random_tensor(ndim=3, dim0=random(1, 6), dim1=dim0, dim2=dim1,).to(device) + y = random_tensor(ndim=3, 
dim0=random(1, 6), dim1=dim0, dim2=dim1,).to(device) + z = torch.einsum("qhc,khc->qkh", x, y) + return z + + @autotest(n=20, check_graph=True) + def test_einsum_alphaflod_usecase5(test_case): + device = random_device() + dim0 = random(1, 6) + x = random_tensor(ndim=2, dim0=random(1, 6), dim1=dim0,).to(device) + y = random_tensor(ndim=3, dim0=dim0, dim1=random(1, 6), dim2=random(1, 6),).to( + device + ) + z = torch.einsum("nm, mrc->nrc", x, y) + return z + + @autotest(n=20, check_graph=True) + def test_einsum_alphaflod_usecase6(test_case): + device = random_device() + dim0 = random(1, 6) + dim1 = random(1, 6) + x = random_tensor(ndim=3, dim0=dim0, dim1=random(1, 6), dim2=dim1,).to(device) + y = random_tensor(ndim=3, dim0=dim0, dim1=random(1, 6), dim2=dim1,).to(device) + z = torch.einsum("abc,adc->bdc", x, y) + return z + + @autotest(n=20, check_graph=True) + def test_einsum_alphaflod_usecase7(test_case): + device = random_device() + dim0 = random(1, 6) + dim1 = random(1, 6) + x = random_tensor( + ndim=4, dim0=random(1, 6), dim1=dim0, dim2=dim1, dim3=random(1, 6), + ).to(device) + y = random_tensor(ndim=3, dim0=dim0, dim1=dim1, dim2=random(1, 6),).to(device) + z = torch.einsum("dceb,cef->dbf", x, y) + return z + + @autotest(n=20, check_graph=True) + def test_einsum_alphaflod_usecase8(test_case): + device = random_device() + dim0 = random(1, 6) + x = random_tensor(ndim=3, dim0=dim0, dim1=random(1, 6), dim2=random(1, 6),).to( + device + ) + y = random_tensor(ndim=3, dim0=dim0, dim1=random(1, 6), dim2=random(1, 6),).to( + device + ) + z = torch.einsum("acb,ade->dceb", x, y) + return z + + @autotest(n=20, check_graph=True) + def test_einsum_alphaflod_usecase9(test_case): + device = random_device() + dim0 = random(1, 6) + x = random_tensor(ndim=3, dim0=random(1, 6), dim1=random(1, 6), dim2=dim0,).to( + device + ) + y = random_tensor(ndim=2, dim0=dim0, dim1=random(1, 6),).to(device) + z = torch.einsum("qkc,ch->hqk", x, y) + return z + + @autotest(n=20, check_graph=True) + def test_einsum_alphaflod_usecase10(test_case): + device = random_device() + dim0 = random(1, 6) + dim1 = random(1, 6) + dim2 = random(1, 6) + x = random_tensor( + ndim=4, dim0=dim0, dim1=dim1, dim2=random(1, 6), dim3=dim2, + ).to(device) + y = random_tensor( + ndim=4, dim0=dim0, dim1=dim2, dim2=dim1, dim3=random(1, 6) + ).to(device) + z = torch.einsum("bhqk,bkhc->bqhc", x, y) + return z + + @autotest(n=20, check_graph=True) + def test_einsum_alphaflod_usecase11(test_case): + device = random_device() + dim0 = random(1, 6) + x = random_tensor(ndim=3, dim0=random(1, 6), dim1=random(1, 6), dim2=dim0,).to( + device + ) + y = random_tensor(ndim=3, dim0=dim0, dim1=random(1, 6), dim2=random(1, 6),).to( + device + ) + z = torch.einsum("bqa,ahc->bqhc", x, y) + return z + + @autotest(n=20, check_graph=True) + def test_einsum_ellipsis_usecase1(test_case): + device = random_device() + dim0 = random(1, 6) + x = random_tensor(ndim=3, dim0=random(1, 6), dim1=random(1, 6), dim2=dim0,).to( + device + ) + y = random_tensor(ndim=3, dim0=random(1, 6), dim1=random(1, 6), dim2=dim0,).to( + device + ) + z = torch.einsum("...lc, ...c -> ...l", x, y) + return z + + @autotest(n=20, check_graph=True) + def test_einsum_ellipsis_usecase2(test_case): + device = random_device() + dim0 = random(1, 6) + dim1 = random(1, 6) + x = random_tensor(ndim=3, dim0=random(1, 6), dim1=dim0, dim2=dim1,).to(device) + y = random_tensor(ndim=3, dim0=random(1, 6), dim1=dim0, dim2=dim1).to(device) + z = torch.einsum("...lc, ...lc -> ...l", x, y) + return z + + @autotest(n=20, 
check_graph=True) + def test_einsum_ellipsis_usecase3(test_case): + device = random_device() + dim0 = random(1, 6) + x = random_tensor(ndim=3, dim0=random(1, 6), dim1=random(1, 6), dim2=dim0,).to( + device + ) + y = random_tensor(ndim=3, dim0=random(1, 6), dim1=random(1, 6), dim2=dim0).to( + device + ) + z = torch.einsum("...id,...jd->...ij", x, y) + return z + + @autotest(n=20, check_graph=True) + def test_einsum_ellipsis_usecase4(test_case): + device = random_device() + dim0 = random(1, 6) + dim1 = random(1, 6) + x = random_tensor( + ndim=4, dim0=random(1, 6), dim1=dim0, dim2=random(1, 6), dim3=dim1 + ).to(device) + y = random_tensor(ndim=3, dim0=dim0, dim1=dim1, dim2=random(1, 6)).to(device) + z = torch.einsum("...klm,kmn->...kln", x, y) + return z + + @autotest(n=20, check_graph=True) + def test_einsum_ellipsis_usecase5(test_case): + device = random_device() + dim0 = random(1, 6) + x = random_tensor( + ndim=4, dim0=random(1, 6), dim1=random(1, 6), dim2=dim0, dim3=random(1, 6) + ).to(device) + y = random_tensor(ndim=3, dim0=random(1, 6), dim1=random(1, 6), dim2=dim0).to( + device + ) + z = torch.einsum("...ikl, ...jk -> ...ijl", x, y) + return z + + @autotest(n=20, check_graph=True) + def test_einsum_ellipsis_usecase6(test_case): + device = random_device() + dim0 = random(1, 6) + x = random_tensor(ndim=3, dim0=random(1, 6), dim1=random(1, 6), dim2=dim0).to( + device + ) + y = random_tensor(ndim=3, dim0=random(1, 6), dim1=random(1, 6), dim2=dim0).to( + device + ) + z = torch.einsum("...l,...l->...", x, y) + return z + + @autotest(n=20, check_graph=True) + def test_einsum_ellipsis_usecase7(test_case): + device = random_device() + dim0 = random(1, 6) + dim1 = random(1, 6) + dim2 = random(1, 6) + x = random_tensor(ndim=3, dim0=dim0, dim1=dim1, dim2=dim2).to(device) + y = random_tensor( + ndim=4, dim0=dim0, dim1=dim1, dim2=dim2, dim3=random(1, 6) + ).to(device) + z = torch.einsum("ijk,ijk...->ij...", x, y) + return z + + @autotest(n=20, check_graph=True) + def test_einsum_other_usecase1(test_case): + device = random_device() + dim0 = random(1, 6) + dim1 = random(1, 6) + dim2 = random(1, 6) + x = random_tensor(ndim=3, dim0=dim0, dim1=random(1, 6), dim2=dim1).to(device) + y = random_tensor(ndim=3, dim0=random(1, 6), dim1=dim1, dim2=dim2).to(device) + w = random_tensor(ndim=3, dim0=dim0, dim1=random(1, 6), dim2=dim2).to(device) + z = torch.einsum("bxi,oij,byj->boxy", x, y, w) + return z + + @autotest(n=20, check_graph=True) + def test_einsum_other_usecase2(test_case): + device = random_device() + dim0 = random(1, 6) + dim1 = random(1, 6) + x = random_tensor( + ndim=4, dim0=dim0, dim1=dim1, dim2=random(1, 6), dim3=random(1, 6) + ).to(device) + y = random_tensor( + ndim=4, dim0=dim0, dim1=dim1, dim2=random(1, 6), dim3=random(1, 6) + ).to(device) + z = torch.einsum("ijac,ijkp->ijakcp", x, y) + return z + + @autotest(n=20, check_graph=True) + def test_einsum_other_usecase3(test_case): + device = random_device() + dim0 = random(1, 6) + dim1 = random(1, 6) + x = random_tensor( + ndim=4, dim0=dim0, dim1=random(1, 6), dim2=dim1, dim3=random(1, 6) + ).to(device) + y = random_tensor(ndim=3, dim0=dim0, dim1=random(1, 6), dim2=dim1).to(device) + z = torch.einsum("cdij,cbi->cdbj", x, y) + return z + + @autotest(n=20, check_graph=True) + def test_einsum_fastfold_usecase1(test_case): + device = random_device() + dim0 = random(1, 6) + dim1 = random(1, 6) + dim2 = random(1, 6) + x = random_tensor( + ndim=4, dim0=dim0, dim1=dim1, dim2=random(1, 6), dim3=dim2 + ).to(device) + y = random_tensor( + ndim=4, 
dim0=dim0, dim1=dim1, dim2=random(1, 6), dim3=dim2 + ).to(device) + z = torch.einsum("bsid,bsjd->bijd", x, y) + return z + + @autotest(n=20, check_graph=True) + def test_einsum_fastfold_usecase2(test_case): + device = random_device() + dim0 = random(1, 6) + dim1 = random(1, 6) + x = random_tensor( + ndim=4, dim0=dim0, dim1=dim1, dim2=random(1, 6), dim3=random(1, 6) + ).to(device) + y = random_tensor( + ndim=4, dim0=dim0, dim1=dim1, dim2=random(1, 6), dim3=random(1, 6) + ).to(device) + z = torch.einsum("bsid,bsje->bijde", x, y) + return z + + @autotest(n=20, check_graph=True) + def test_einsum_openfold_usecase1(test_case): + device = random_device() + dim0 = random(1, 6) + x = random_tensor( + ndim=4, dim0=random(1, 6), dim1=random(1, 6), dim2=dim0, dim3=random(1, 6) + ).to(device) + y = random_tensor( + ndim=4, dim0=random(1, 6), dim1=random(1, 6), dim2=dim0, dim3=random(1, 6) + ).to(device) + z = torch.einsum("...bac,...dae->...bdce", x, y) + return z + + @autotest(n=20, check_graph=True) + def test_einsum_openfold_usecase2(test_case): + device = random_device() + dim0 = random(1, 6) + dim1 = random(1, 6) + x = random_tensor( + ndim=4, dim0=random(1, 6), dim1=dim0, dim2=random(1, 6), dim3=dim1 + ).to(device) + y = random_tensor( + ndim=4, dim0=random(1, 6), dim1=dim0, dim2=random(1, 6), dim3=dim1 + ).to(device) + z = torch.einsum("...abc,...adc->...bdc", x, y) + return z + + @autotest(n=20, check_graph=True) + def test_einsum_openfold_usecase3(test_case): + device = random_device() + dim0 = random(1, 6) + dim1 = random(1, 6) + x = random_tensor( + ndim=4, dim0=random(1, 6), dim1=random(1, 6), dim2=dim0, dim3=dim1 + ).to(device) + y = random_tensor( + ndim=4, dim0=random(1, 6), dim1=random(1, 6), dim2=dim0, dim3=dim1 + ).to(device) + z = torch.einsum("...qhd,...khd->...hqk", x, y) + return z + + @autotest(n=20, check_graph=True) + def test_einsum_openfold_usecase4(test_case): + device = random_device() + dim0 = random(1, 6) + dim1 = random(1, 6) + x = random_tensor( + ndim=4, dim0=random(1, 6), dim1=dim0, dim2=dim1, dim3=random(1, 6) + ).to(device) + y = random_tensor( + ndim=4, dim0=random(1, 6), dim1=random(1, 6), dim2=dim1, dim3=dim0 + ).to(device) + z = torch.einsum("...vhf,...qhv->...qhf", x, y) + return z + + @autotest(n=20, check_graph=True) + def test_einsum_openfold_usecase5(test_case): + device = random_device() + dim0 = random(1, 6) + x = random_tensor( + ndim=4, dim0=random(1, 6), dim1=random(1, 6), dim2=random(1, 6), dim3=dim0 + ).to(device) + y = random_tensor(ndim=2, dim0=dim0, dim1=random(1, 6)).to(device) + z = torch.einsum("...ij,jk->ik", x, y) + return z + + +if __name__ == "__main__": + unittest.main() diff --git a/python/oneflow/test/modules/test_consistent_einsum.py b/python/oneflow/test/modules/test_consistent_einsum.py new file mode 100644 index 00000000000..9e37986a699 --- /dev/null +++ b/python/oneflow/test/modules/test_consistent_einsum.py @@ -0,0 +1,635 @@ +""" +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" +import unittest + +import numpy as np + +import oneflow as flow +import oneflow.unittest +from oneflow.test_utils.automated_test_util import * + + +@autotest(n=2, check_graph=False) +def _test_einsum_matrix_transpose(test_case, placement, sbp): + x = random_tensor(ndim=2, dim0=random(1, 3) * 8, dim1=random(1, 3) * 8) + g_x = x.to_global(placement=placement, sbp=sbp) + z = torch.einsum("ij->ji", g_x) + return z + + +@autotest(n=2, check_graph=False) +def _test_einsum_eltwise_multiply(test_case, placement, sbp): + dim0 = random(1, 3) * 8 + dim1 = random(1, 3) * 8 + x = random_tensor(ndim=2, dim0=dim0, dim1=dim1,) + y = random_tensor(ndim=2, dim0=dim0, dim1=dim1,) + g_x = x.to_global(placement=placement, sbp=sbp) + g_y = y.to_global(placement=placement, sbp=sbp) + z = torch.einsum("ij,ij->ij", g_x, g_y) + return z + + +@autotest(n=2, check_graph=False) +def _test_einsum_get_diagonal(test_case, placement, sbp): + dim = random(1, 3) * 8 + x = random_tensor(ndim=2, dim0=dim, dim1=dim,) + g_x = x.to_global(placement=placement, sbp=sbp) + z = torch.einsum("ii->i", g_x) + return z + + +@autotest(n=2, check_graph=False) +def _test_einsum_batch_permute(test_case, placement, sbp): + x = random_tensor( + ndim=5, + dim0=random(1, 3) * 8, + dim1=random(1, 3) * 8, + dim2=random(1, 3) * 8, + dim3=random(1, 3) * 8, + dim4=random(1, 3) * 8, + ) + g_x = x.to_global(placement=placement, sbp=sbp) + z = torch.einsum("...ij->...ji", g_x) + return z + + +@autotest(n=2, check_graph=False) +def _test_einsum_reduce_sum(test_case, placement, sbp): + x = random_tensor(ndim=2, dim0=random(1, 3) * 8, dim1=random(1, 3) * 8,) + g_x = x.to_global(placement=placement, sbp=sbp) + z = torch.einsum("ij->", g_x) + return z + + +@autotest(n=2, check_graph=False) +def _test_einsum_matrix_column_sum(test_case, placement, sbp): + x = random_tensor(ndim=2, dim0=random(1, 3) * 8, dim1=random(1, 3) * 8,) + g_x = x.to_global(placement=placement, sbp=sbp) + z = torch.einsum("ij->j", g_x) + return z + + +@autotest(n=2, check_graph=False) +def _test_einsum_matrix_vector_multiply(test_case, placement, sbp): + dim0 = random(1, 3) * 8 + dim1 = random(1, 3) * 8 + x = random_tensor(ndim=2, dim0=dim0, dim1=dim1,) + y = random_tensor(ndim=1, dim0=dim1,) + g_x = x.to_global(placement=placement, sbp=sbp) + g_y = y.to_global(placement=placement, sbp=sbp) + # NOTE(Liang Depeng): the same as 'ik,k->i' + z = torch.einsum("ik,k", g_x, g_y) + return z + + +@autotest(n=2, check_graph=False) +def _test_einsum_matmul(test_case, placement, sbp): + dim0 = random(1, 3) * 8 + dim1 = random(1, 3) * 8 + dim2 = random(1, 3) * 8 + x = random_tensor(ndim=2, dim0=dim0, dim1=dim1,) + y = random_tensor(ndim=2, dim0=dim1, dim1=dim2,) + g_x = x.to_global(placement=placement, sbp=sbp) + g_y = y.to_global(placement=placement, sbp=sbp) + # NOTE(Liang Depeng): the same as 'ik,kj->ij' + z = torch.einsum("ik,kj", g_x, g_y) + return z + + +@autotest(n=2, check_graph=False) +def _test_einsum_vector_inner_product(test_case, placement, sbp): + dim0 = random(1, 3) * 8 + x = random_tensor(ndim=1, dim0=dim0,) + y = random_tensor(ndim=1, dim0=dim0,) + g_x = x.to_global(placement=placement, sbp=sbp) + g_y = y.to_global(placement=placement, sbp=sbp) + # NOTE(Liang Depeng): the same as 'i,i->' + z = torch.einsum("i,i", g_x, g_y) + return z + + +@autotest(n=2, check_graph=False) +def _test_einsum_eltwise_mul_then_reduce_sum(test_case, placement, sbp): + dim0 = random(1, 3) * 8 + dim1 = random(1, 3) * 8 + x = random_tensor(ndim=2, dim0=dim0, dim1=dim1,) + y = random_tensor(ndim=2, 
dim0=dim0, dim1=dim1,) + g_x = x.to_global(placement=placement, sbp=sbp) + g_y = y.to_global(placement=placement, sbp=sbp) + # NOTE(Liang Depeng): the same as 'ij,ij->' + z = torch.einsum("ij,ij", g_x, g_y) + return z + + +@autotest(n=2, check_graph=False) +def _test_einsum_vector_outer_product(test_case, placement, sbp): + x = random_tensor(ndim=1, dim0=random(1, 3) * 8,) + y = random_tensor(ndim=1, dim0=random(1, 3) * 8,) + g_x = x.to_global(placement=placement, sbp=sbp) + g_y = y.to_global(placement=placement, sbp=sbp) + # NOTE(Liang Depeng): the same as 'i,j->ij' + z = torch.einsum("i,j", g_x, g_y) + return z + + +@autotest(n=2, check_graph=False) +def _test_einsum_batch_matmul(test_case, placement, sbp): + dim0 = random(1, 3) * 8 + dim1 = random(1, 3) * 8 + x = random_tensor(ndim=3, dim0=dim0, dim1=random(1, 3) * 8, dim2=dim1,) + y = random_tensor(ndim=3, dim0=dim0, dim1=dim1, dim2=random(1, 3) * 8,) + g_x = x.to_global(placement=placement, sbp=sbp) + g_y = y.to_global(placement=placement, sbp=sbp) + z = torch.einsum("ijk,ikl->ijl", g_x, g_y) + return z + + +@autotest(n=2, check_graph=False) +def _test_einsum_tensor_contraction(test_case, placement, sbp): + dim0 = random(1, 3) * 8 + dim1 = random(1, 3) * 8 + x = random_tensor( + ndim=4, dim0=random(1, 3) * 8, dim1=dim0, dim2=dim1, dim3=random(1, 3) * 8, + ) + y = random_tensor( + ndim=5, + dim0=random(1, 3) * 8, + dim1=random(1, 3) * 8, + dim2=dim0, + dim3=random(1, 3) * 8, + dim4=dim1, + ) + g_x = x.to_global(placement=placement, sbp=sbp) + g_y = y.to_global(placement=placement, sbp=sbp) + z = torch.einsum("pqrs,tuqvr->pstuv", g_x, g_y) + return z + + +@autotest(n=2, check_graph=False) +def _test_einsum_bilinear_transformation(test_case, placement, sbp): + dim0 = random(1, 3) * 8 + dim1 = random(1, 3) * 8 + dim2 = random(1, 3) * 8 + x = random_tensor(ndim=2, dim0=dim0, dim1=dim1,) + y = random_tensor(ndim=3, dim0=random(1, 3) * 8, dim1=dim1, dim2=dim2,) + w = random_tensor(ndim=2, dim0=dim0, dim1=dim2,) + g_x = x.to_global(placement=placement, sbp=sbp) + g_y = y.to_global(placement=placement, sbp=sbp) + g_w = w.to_global(placement=placement, sbp=sbp) + z = torch.einsum("ik,jkl,il->ij", g_x, g_y, g_w) + return z + + +@autotest(n=2, check_graph=False) +def _test_einsum_tensor_contraction2(test_case, placement, sbp): + dim0 = random(1, 3) * 8 + x = random_tensor( + ndim=4, + dim0=random(1, 3) * 8, + dim1=dim0, + dim2=random(1, 3) * 8, + dim3=random(1, 3) * 8, + ) + y = random_tensor(ndim=2, dim0=dim0, dim1=random(1, 3) * 8,) + g_x = x.to_global(placement=placement, sbp=sbp) + g_y = y.to_global(placement=placement, sbp=sbp) + z = torch.einsum("b n h w, n d -> b d h w", g_x, g_y) + return z + + +@autotest(n=2, check_graph=False) +def _test_einsum_eltwise_mul_sum_row(test_case, placement, sbp): + dim0 = random(1, 3) * 8 + dim1 = random(1, 3) * 8 + x = random_tensor(ndim=2, dim0=dim0, dim1=dim1,) + y = random_tensor(ndim=2, dim0=dim0, dim1=dim1,) + g_x = x.to_global(placement=placement, sbp=sbp) + g_y = y.to_global(placement=placement, sbp=sbp) + z = torch.einsum("n d, n d -> n", g_x, g_y) + return z + + +@autotest(n=2, check_graph=False) +def _test_einsum_matmul2(test_case, placement, sbp): + dim0 = random(1, 3) * 8 + x = random_tensor(ndim=2, dim0=random(1, 3) * 8, dim1=dim0,) + y = random_tensor(ndim=2, dim0=random(1, 3) * 8, dim1=dim0,) + g_x = x.to_global(placement=placement, sbp=sbp) + g_y = y.to_global(placement=placement, sbp=sbp) + z = torch.einsum("i d, j d -> i j", g_x, g_y) + return z + + +@autotest(n=2, check_graph=False) +def 
_test_einsum_attention(test_case, placement, sbp): + dim0 = random(1, 3) * 8 + dim1 = random(1, 3) * 8 + dim2 = random(1, 3) * 8 + x = random_tensor(ndim=4, dim0=dim0, dim1=dim1, dim2=random(1, 3) * 8, dim3=dim2,) + y = random_tensor(ndim=4, dim0=dim0, dim1=dim1, dim2=random(1, 3) * 8, dim3=dim2,) + g_x = x.to_global(placement=placement, sbp=sbp) + g_y = y.to_global(placement=placement, sbp=sbp) + z = torch.einsum("b h i d, b h j d -> b h i j", g_x, g_y) + return z + + +@autotest(n=2, check_graph=False) +def _test_einsum_batch_matmul2(test_case, placement, sbp): + dim0 = random(1, 3) * 8 + dim1 = random(1, 3) * 8 + dim2 = random(1, 3) * 8 + x = random_tensor(ndim=4, dim0=dim0, dim1=dim1, dim2=random(1, 3) * 8, dim3=dim2) + y = random_tensor(ndim=4, dim0=dim0, dim1=dim1, dim2=dim2, dim3=random(1, 3) * 8) + g_x = x.to_global(placement=placement, sbp=sbp) + g_y = y.to_global(placement=placement, sbp=sbp) + z = torch.einsum("b h i j, b h j d -> b h i d", g_x, g_y) + return z + + +@autotest(n=2, check_graph=False) +def _test_einsum_batch_matrix_vector_multiply(test_case, placement, sbp): + dim0 = random(1, 3) * 8 + dim1 = random(1, 3) * 8 + dim2 = random(1, 3) * 8 + x = random_tensor(ndim=3, dim0=dim0, dim1=dim1, dim2=dim2,) + y = random_tensor(ndim=4, dim0=dim0, dim1=dim1, dim2=random(1, 3) * 8, dim3=dim2,) + g_x = x.to_global(placement=placement, sbp=sbp) + g_y = y.to_global(placement=placement, sbp=sbp) + z = torch.einsum("b i d, b i j d -> b i j", g_x, g_y) + return z + + +@autotest(n=2, check_graph=False) +def _test_einsum_batch_matmul3(test_case, placement, sbp): + dim0 = random(1, 3) * 8 + dim1 = random(1, 3) * 8 + x = random_tensor( + ndim=4, dim0=dim0, dim1=random(1, 3) * 8, dim2=random(1, 3) * 8, dim3=dim1, + ) + y = random_tensor(ndim=3, dim0=dim0, dim1=random(1, 3) * 8, dim2=dim1,) + g_x = x.to_global(placement=placement, sbp=sbp) + g_y = y.to_global(placement=placement, sbp=sbp) + z = torch.einsum("b x i d, b j d -> b x i j", g_x, g_y) + return z + + +@autotest(n=2, check_graph=False) +def _test_einsum_batch_matmul4(test_case, placement, sbp): + dim0 = random(1, 3) * 8 + dim1 = random(1, 3) * 8 + x = random_tensor( + ndim=4, dim0=dim0, dim1=random(1, 3) * 8, dim2=random(1, 3) * 8, dim3=dim1, + ) + y = random_tensor(ndim=3, dim0=dim0, dim1=dim1, dim2=random(1, 3) * 8,) + g_x = x.to_global(placement=placement, sbp=sbp) + g_y = y.to_global(placement=placement, sbp=sbp) + z = torch.einsum("b x i j, b j d -> b x i d", g_x, g_y) + return z + + +@autotest(n=2, check_graph=False) +def _test_einsum_alphaflod_usecase1(test_case, placement, sbp): + dim0 = random(1, 3) * 8 + dim1 = random(1, 3) * 8 + x = random_tensor(ndim=3, dim0=random(1, 3) * 8, dim1=dim0, dim2=dim1,) + y = random_tensor(ndim=3, dim0=dim0, dim1=dim1, dim2=random(1, 3) * 8,) + g_x = x.to_global(placement=placement, sbp=sbp) + g_y = y.to_global(placement=placement, sbp=sbp) + z = torch.einsum("hij, ijc->ihc", g_x, g_y) + return z + + +@autotest(n=2, check_graph=False) +def _test_einsum_alphaflod_usecase2(test_case, placement, sbp): + dim0 = random(1, 3) * 8 + dim1 = random(1, 3) * 8 + x = random_tensor(ndim=3, dim0=dim0, dim1=dim1, dim2=random(1, 3) * 8,) + y = random_tensor(ndim=3, dim0=dim0, dim1=dim1, dim2=random(1, 3) * 8,) + g_x = x.to_global(placement=placement, sbp=sbp) + g_y = y.to_global(placement=placement, sbp=sbp) + z = torch.einsum("rac,rab->rbc", g_x, g_y) + return z + + +@autotest(n=2, check_graph=False) +def _test_einsum_alphaflod_usecase3(test_case, placement, sbp): + dim0 = random(1, 3) * 8 + dim1 = 
random(1, 3) * 8 + x = random_tensor(ndim=2, dim0=dim0, dim1=dim1,) + y = random_tensor(ndim=3, dim0=dim0, dim1=dim1, dim2=random(1, 3) * 8,) + g_x = x.to_global(placement=placement, sbp=sbp) + g_y = y.to_global(placement=placement, sbp=sbp) + z = torch.einsum("ra,rab->rb", g_x, g_y) + return z + + +@autotest(n=2, check_graph=False) +def _test_einsum_alphaflod_usecase4(test_case, placement, sbp): + dim0 = random(1, 3) * 8 + dim1 = random(1, 3) * 8 + x = random_tensor(ndim=3, dim0=random(1, 3) * 8, dim1=dim0, dim2=dim1,) + y = random_tensor(ndim=3, dim0=random(1, 3) * 8, dim1=dim0, dim2=dim1,) + g_x = x.to_global(placement=placement, sbp=sbp) + g_y = y.to_global(placement=placement, sbp=sbp) + z = torch.einsum("qhc,khc->qkh", g_x, g_y) + return z + + +@autotest(n=2, check_graph=False) +def _test_einsum_alphaflod_usecase5(test_case, placement, sbp): + dim0 = random(1, 3) * 8 + x = random_tensor(ndim=2, dim0=random(1, 3) * 8, dim1=dim0,) + y = random_tensor(ndim=3, dim0=dim0, dim1=random(1, 3) * 8, dim2=random(1, 3) * 8,) + g_x = x.to_global(placement=placement, sbp=sbp) + g_y = y.to_global(placement=placement, sbp=sbp) + z = torch.einsum("nm, mrc->nrc", g_x, g_y) + return z + + +@autotest(n=2, check_graph=False) +def _test_einsum_alphaflod_usecase6(test_case, placement, sbp): + dim0 = random(1, 3) * 8 + dim1 = random(1, 3) * 8 + x = random_tensor(ndim=3, dim0=dim0, dim1=random(1, 3) * 8, dim2=dim1,) + y = random_tensor(ndim=3, dim0=dim0, dim1=random(1, 3) * 8, dim2=dim1,) + g_x = x.to_global(placement=placement, sbp=sbp) + g_y = y.to_global(placement=placement, sbp=sbp) + z = torch.einsum("abc,adc->bdc", g_x, g_y) + return z + + +@autotest(n=2, check_graph=False) +def _test_einsum_alphaflod_usecase7(test_case, placement, sbp): + dim0 = random(1, 3) * 8 + dim1 = random(1, 3) * 8 + x = random_tensor( + ndim=4, dim0=random(1, 3) * 8, dim1=dim0, dim2=dim1, dim3=random(1, 3) * 8, + ) + y = random_tensor(ndim=3, dim0=dim0, dim1=dim1, dim2=random(1, 3) * 8,) + g_x = x.to_global(placement=placement, sbp=sbp) + g_y = y.to_global(placement=placement, sbp=sbp) + z = torch.einsum("dceb,cef->dbf", g_x, g_y) + return z + + +@autotest(n=2, check_graph=False) +def _test_einsum_alphaflod_usecase8(test_case, placement, sbp): + dim0 = random(1, 3) * 8 + x = random_tensor(ndim=3, dim0=dim0, dim1=random(1, 3) * 8, dim2=random(1, 3) * 8,) + y = random_tensor(ndim=3, dim0=dim0, dim1=random(1, 3) * 8, dim2=random(1, 3) * 8,) + g_x = x.to_global(placement=placement, sbp=sbp) + g_y = y.to_global(placement=placement, sbp=sbp) + z = torch.einsum("acb,ade->dceb", g_x, g_y) + return z + + +@autotest(n=2, check_graph=False) +def _test_einsum_alphaflod_usecase9(test_case, placement, sbp): + dim0 = random(1, 3) * 8 + x = random_tensor(ndim=3, dim0=random(1, 3) * 8, dim1=random(1, 3) * 8, dim2=dim0,) + y = random_tensor(ndim=2, dim0=dim0, dim1=random(1, 3) * 8,) + g_x = x.to_global(placement=placement, sbp=sbp) + g_y = y.to_global(placement=placement, sbp=sbp) + z = torch.einsum("qkc,ch->hqk", g_x, g_y) + return z + + +@autotest(n=2, check_graph=False) +def _test_einsum_alphaflod_usecase10(test_case, placement, sbp): + dim0 = random(1, 3) * 8 + dim1 = random(1, 3) * 8 + dim2 = random(1, 3) * 8 + x = random_tensor(ndim=4, dim0=dim0, dim1=dim1, dim2=random(1, 3) * 8, dim3=dim2,) + y = random_tensor(ndim=4, dim0=dim0, dim1=dim2, dim2=dim1, dim3=random(1, 3) * 8) + g_x = x.to_global(placement=placement, sbp=sbp) + g_y = y.to_global(placement=placement, sbp=sbp) + z = torch.einsum("bhqk,bkhc->bqhc", g_x, g_y) + return z + + 
+@autotest(n=2, check_graph=False) +def _test_einsum_alphaflod_usecase11(test_case, placement, sbp): + dim0 = random(1, 3) * 8 + x = random_tensor(ndim=3, dim0=random(1, 3) * 8, dim1=random(1, 3) * 8, dim2=dim0,) + y = random_tensor(ndim=3, dim0=dim0, dim1=random(1, 3) * 8, dim2=random(1, 3) * 8,) + g_x = x.to_global(placement=placement, sbp=sbp) + g_y = y.to_global(placement=placement, sbp=sbp) + z = torch.einsum("bqa,ahc->bqhc", g_x, g_y) + return z + + +class TestEinsumConsistent(flow.unittest.TestCase): + @globaltest + def test_einsum_matrix_transpose(test_case): + for placement in all_placement(): + for sbp in all_sbp(placement, max_dim=2): + _test_einsum_matrix_transpose(test_case, placement, sbp) + + @globaltest + def test_einsum_eltwise_multiply(test_case): + for placement in all_placement(): + for sbp in all_sbp(placement, max_dim=2): + _test_einsum_eltwise_multiply(test_case, placement, sbp) + + @globaltest + def test_einsum_get_diagonal(test_case): + for placement in all_placement(): + for sbp in all_sbp(placement, max_dim=2): + _test_einsum_get_diagonal(test_case, placement, sbp) + + @globaltest + def test_einsum_batch_permute(test_case): + for placement in all_placement(): + for sbp in all_sbp(placement, max_dim=5): + _test_einsum_batch_permute(test_case, placement, sbp) + + @globaltest + def test_einsum_reduce_sum(test_case): + for placement in all_placement(): + for sbp in all_sbp(placement, max_dim=2): + _test_einsum_reduce_sum(test_case, placement, sbp) + + @globaltest + def test_einsum_matrix_column_sum(test_case): + for placement in all_placement(): + for sbp in all_sbp(placement, max_dim=2): + _test_einsum_matrix_column_sum(test_case, placement, sbp) + + @globaltest + def test_einsum_matrix_vector_multiply(test_case): + for placement in all_placement(): + for sbp in all_sbp(placement, max_dim=1): + _test_einsum_matrix_vector_multiply(test_case, placement, sbp) + + @globaltest + def test_einsum_matmul(test_case): + for placement in all_placement(): + for sbp in all_sbp(placement, max_dim=2): + _test_einsum_matmul(test_case, placement, sbp) + + @globaltest + def test_einsum_vector_inner_product(test_case): + for placement in all_placement(): + for sbp in all_sbp(placement, max_dim=1): + _test_einsum_vector_inner_product(test_case, placement, sbp) + + @globaltest + def test_einsum_eltwise_mul_then_reduce_sum(test_case): + for placement in all_placement(): + for sbp in all_sbp(placement, max_dim=2): + _test_einsum_eltwise_mul_then_reduce_sum(test_case, placement, sbp) + + @globaltest + def test_einsum_vector_outer_product(test_case): + for placement in all_placement(): + for sbp in all_sbp(placement, max_dim=1): + _test_einsum_vector_outer_product(test_case, placement, sbp) + + @globaltest + def test_einsum_batch_matmul(test_case): + for placement in all_placement(): + for sbp in all_sbp(placement, max_dim=3): + _test_einsum_batch_matmul(test_case, placement, sbp) + + @globaltest + def test_einsum_tensor_contraction(test_case): + for placement in all_placement(): + for sbp in all_sbp(placement, max_dim=4): + _test_einsum_tensor_contraction(test_case, placement, sbp) + + @globaltest + def test_einsum_bilinear_transformation(test_case): + for placement in all_placement(): + for sbp in all_sbp(placement, max_dim=2): + _test_einsum_bilinear_transformation(test_case, placement, sbp) + + @globaltest + def test_einsum_tensor_contraction2(test_case): + for placement in all_placement(): + for sbp in all_sbp(placement, max_dim=2): + _test_einsum_tensor_contraction2(test_case, 
placement, sbp) + + @globaltest + def test_einsum_eltwise_mul_sum_row(test_case): + for placement in all_placement(): + for sbp in all_sbp(placement, max_dim=2): + _test_einsum_eltwise_mul_sum_row(test_case, placement, sbp) + + @globaltest + def test_einsum_matmul2(test_case): + for placement in all_placement(): + for sbp in all_sbp(placement, max_dim=2): + _test_einsum_matmul2(test_case, placement, sbp) + + @globaltest + def test_einsum_attention(test_case): + for placement in all_placement(): + for sbp in all_sbp(placement, max_dim=4): + _test_einsum_attention(test_case, placement, sbp) + + @globaltest + def test_einsum_batch_matmul2(test_case): + for placement in all_placement(): + for sbp in all_sbp(placement, max_dim=4): + _test_einsum_batch_matmul2(test_case, placement, sbp) + + @globaltest + def test_einsum_batch_matrix_vector_multiply(test_case): + for placement in all_placement(): + for sbp in all_sbp(placement, max_dim=3): + _test_einsum_batch_matrix_vector_multiply(test_case, placement, sbp) + + @globaltest + def test_einsum_batch_matmul3(test_case): + for placement in all_placement(): + for sbp in all_sbp(placement, max_dim=3): + _test_einsum_batch_matmul3(test_case, placement, sbp) + + @globaltest + def test_einsum_batch_matmul4(test_case): + for placement in all_placement(): + for sbp in all_sbp(placement, max_dim=3): + _test_einsum_batch_matmul4(test_case, placement, sbp) + + @globaltest + def test_einsum_alphaflod_usecase1(test_case): + for placement in all_placement(): + for sbp in all_sbp(placement, max_dim=3): + _test_einsum_alphaflod_usecase1(test_case, placement, sbp) + + @globaltest + def test_einsum_alphaflod_usecase2(test_case): + for placement in all_placement(): + for sbp in all_sbp(placement, max_dim=3): + _test_einsum_alphaflod_usecase2(test_case, placement, sbp) + + @globaltest + def test_einsum_alphaflod_usecase3(test_case): + for placement in all_placement(): + for sbp in all_sbp(placement, max_dim=2): + _test_einsum_alphaflod_usecase3(test_case, placement, sbp) + + @globaltest + def test_einsum_alphaflod_usecase4(test_case): + for placement in all_placement(): + for sbp in all_sbp(placement, max_dim=3): + _test_einsum_alphaflod_usecase4(test_case, placement, sbp) + + @globaltest + def test_einsum_alphaflod_usecase5(test_case): + for placement in all_placement(): + for sbp in all_sbp(placement, max_dim=2): + _test_einsum_alphaflod_usecase5(test_case, placement, sbp) + + @globaltest + def test_einsum_alphaflod_usecase6(test_case): + for placement in all_placement(): + for sbp in all_sbp(placement, max_dim=3): + _test_einsum_alphaflod_usecase6(test_case, placement, sbp) + + @globaltest + def test_einsum_alphaflod_usecase7(test_case): + for placement in all_placement(): + for sbp in all_sbp(placement, max_dim=3): + _test_einsum_alphaflod_usecase7(test_case, placement, sbp) + + @globaltest + def test_einsum_alphaflod_usecase8(test_case): + for placement in all_placement(): + for sbp in all_sbp(placement, max_dim=3): + _test_einsum_alphaflod_usecase8(test_case, placement, sbp) + + @globaltest + def test_einsum_alphaflod_usecase9(test_case): + for placement in all_placement(): + for sbp in all_sbp(placement, max_dim=2): + _test_einsum_alphaflod_usecase9(test_case, placement, sbp) + + @globaltest + def test_einsum_alphaflod_usecase10(test_case): + for placement in all_placement(): + for sbp in all_sbp(placement, max_dim=4): + _test_einsum_alphaflod_usecase10(test_case, placement, sbp) + + @globaltest + def test_einsum_alphaflod_usecase11(test_case): + for placement 
in all_placement(): + for sbp in all_sbp(placement, max_dim=3): + _test_einsum_alphaflod_usecase11(test_case, placement, sbp) + + +if __name__ == "__main__": + unittest.main()
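
Reviewer aid (not part of the patch): a minimal usage sketch of the einsum interface these tests exercise. It assumes the functional binding registered earlier in this diff is exposed to Python as oneflow.einsum with a PyTorch-compatible calling convention, which is what the autotest torch-to-oneflow mapping in the tests above relies on.

    import oneflow as flow

    a = flow.randn(2, 3)
    b = flow.randn(3, 4)

    # plain matrix multiply, equivalent to flow.matmul(a, b)
    c = flow.einsum("ik,kj->ij", a, b)

    # batched attention-style contraction, mirroring test_einsum_attention above
    q = flow.randn(2, 4, 5, 8)
    k = flow.randn(2, 4, 7, 8)
    scores = flow.einsum("bhid,bhjd->bhij", q, k)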