[DRAFT] Support Add

Signed-off-by: YongHyun An <yonghyunz.an@samsung.com>
YongHyun An committed Jan 26, 2024
1 parent d8d8837 commit ccb639d
Showing 6 changed files with 281 additions and 1 deletion.
87 changes: 87 additions & 0 deletions compute/cker/include/cker/train/operation/BinaryArithmetic.h
@@ -0,0 +1,87 @@
/*
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef __NNFW_CKER_TRAIN_OPERATION_BINARYARITHMETIC_H__
#define __NNFW_CKER_TRAIN_OPERATION_BINARYARITHMETIC_H__

#include "cker/Shape.h"
#include "cker/eigen/Utils.h"
#include "cker/operation/BroadcastTo.h"

namespace nnfw
{
namespace cker
{
namespace train
{
enum class ArithmeticType
{
  kAdd,
  kSub,
  kMul,
  kDiv,
};

template <typename T>
void BinaryArithmeticGrad(const Shape &lhs_shape, const T *lhs_data, const Shape &rhs_shape,
                          const T *rhs_data, const Shape &incoming_shape, const T *incoming_data,
                          const Shape &lhs_grad_shape, T *lhs_grad_data,
                          const Shape &rhs_grad_shape, T *rhs_grad_data,
                          ArithmeticType arithmetic_type)
{
  switch (arithmetic_type)
  {
    case ArithmeticType::kAdd:
    {
      // out = lhs + rhs: both operands receive the incoming gradient as-is,
      // broadcast to each operand's shape
      BroadcastTo(incoming_shape, const_cast<T *>(incoming_data), lhs_grad_shape, lhs_grad_data);
      BroadcastTo(incoming_shape, const_cast<T *>(incoming_data), rhs_grad_shape, rhs_grad_data);
    }
    break;

    case ArithmeticType::kSub:
    {
      // out = lhs - rhs: lhs receives the incoming gradient, rhs its negation
      BroadcastTo(incoming_shape, const_cast<T *>(incoming_data), lhs_grad_shape, lhs_grad_data);

      auto const in_map = MapAsMatrixWithLastDimAsRows(incoming_data, incoming_shape);
      auto rhs_grad_map = MapAsMatrixWithLastDimAsRows(rhs_grad_data, rhs_grad_shape);
      rhs_grad_map = -in_map;
    }
    break;

    case ArithmeticType::kMul:
    {
      // out = lhs * rhs: each operand receives the incoming gradient scaled
      // element-wise by the other operand
      auto const in_map = MapAsMatrixWithLastDimAsRows(incoming_data, incoming_shape);
      auto const lhs_map = MapAsMatrixWithLastDimAsRows(lhs_data, lhs_shape);
      auto const rhs_map = MapAsMatrixWithLastDimAsRows(rhs_data, rhs_shape);
      auto lhs_grad_map = MapAsMatrixWithLastDimAsRows(lhs_grad_data, lhs_grad_shape);
      auto rhs_grad_map = MapAsMatrixWithLastDimAsRows(rhs_grad_data, rhs_grad_shape);

      lhs_grad_map = rhs_map.cwiseProduct(in_map);
      rhs_grad_map = lhs_map.cwiseProduct(in_map);
    }
    break;

    case ArithmeticType::kDiv:
    default:
      throw std::runtime_error{"Unsupported Binary Arithmetic Operation"};
  }
}

} // namespace train
} // namespace cker
} // namespace nnfw

#endif // __NNFW_CKER_TRAIN_OPERATION_BINARYARITHMETIC_H__
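
For reference, a minimal usage sketch of the new gradient kernel (hypothetical, not part of this commit; it assumes cker::Shape's initializer-list constructor and calls only the function added above):

// Hypothetical example: Mul gradient on two 1-D tensors of four elements.
#include <cker/train/operation/BinaryArithmetic.h>

int main()
{
  using namespace nnfw::cker;

  const Shape shape{4}; // identical shapes, so no broadcasting is involved
  const float lhs[] = {1.f, 2.f, 3.f, 4.f};
  const float rhs[] = {5.f, 6.f, 7.f, 8.f};
  const float incoming[] = {1.f, 1.f, 1.f, 1.f}; // dL/dout from the next layer
  float lhs_grad[4];
  float rhs_grad[4];

  // For out = lhs * rhs: dL/dlhs = rhs (*) dL/dout and dL/drhs = lhs (*) dL/dout
  train::BinaryArithmeticGrad(shape, lhs, shape, rhs, shape, incoming, shape, lhs_grad, shape,
                              rhs_grad, train::ArithmeticType::kMul);
  // lhs_grad now holds {5, 6, 7, 8} and rhs_grad holds {1, 2, 3, 4}
  return 0;
}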
2 changes: 1 addition & 1 deletion runtime/onert/backend/cpu/ops/BinaryArithmeticLayer.h
@@ -53,7 +53,7 @@ class BinaryArithmeticLayer : public ::onert::exec::IFunction

  void run() override;

-private:
+protected:
  const IPortableTensor *_lhs;
  const IPortableTensor *_rhs;
  IPortableTensor *_output;
27 changes: 27 additions & 0 deletions runtime/onert/backend/train/KernelGenerator.cc
@@ -16,6 +16,7 @@

#include "KernelGenerator.h"

#include "ops/BinaryArithmeticLayer.h"
#include "ops/ConvolutionLayer.h"
#include "ops/DepthwiseConvolutionLayer.h"
#include "ops/ElementwiseActivationLayer.h"
@@ -120,6 +121,32 @@ KernelGenerator::KernelGenerator(const ir::train::TrainableGraph &tgraph,
  // DO NOTHING
}

void KernelGenerator::visit(const ir::train::operation::BinaryArithmetic &node)
{
  using ir::train::operation::BinaryArithmetic;

  const auto output_index{node.getOutputs().at(0)};
  const auto lhs_index{node.getInputs().at(BinaryArithmetic::Input::LHS)};
  const auto rhs_index{node.getInputs().at(BinaryArithmetic::Input::RHS)};

  const auto arithmetic_type = node.param().arithmetic_type;
  const auto activation = node.param().activation;

  auto output_tensor = _tensor_reg->getPortableTensor(output_index);
  auto lhs_tensor = _tensor_reg->getPortableTensor(lhs_index);
  auto rhs_tensor = _tensor_reg->getPortableTensor(rhs_index);

  auto back_prop_output_tensor = _tensor_reg->getBackPropTensor(output_index);
  auto back_prop_lhs_tensor = _tensor_reg->getBackPropTensor(lhs_index);
  auto back_prop_rhs_tensor = _tensor_reg->getBackPropTensor(rhs_index);

  auto fn = std::make_unique<ops::BinaryArithmeticLayer>();
  fn->configure(lhs_tensor, rhs_tensor, output_tensor, back_prop_lhs_tensor, back_prop_rhs_tensor,
                back_prop_output_tensor, activation,
                static_cast<train::ops::ArithmeticType>(arithmetic_type));
  _return_fn = std::move(fn);
}

void KernelGenerator::visit(const ir::train::operation::Conv2D &node)
{
using ir::train::operation::Conv2D;
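
The static_cast in the visit above (and those in the layer below) hand the arithmetic type across separately defined enums, which silently relies on the enumerators being declared in the same order. A hypothetical compile-time guard, not part of this commit, could pin that contract down:

#include "ops/BinaryArithmeticLayer.h"
#include <cker/train/operation/BinaryArithmetic.h>

// Hypothetical check: both ArithmeticType enums must keep the same order.
static_assert(static_cast<int>(onert::backend::train::ops::ArithmeticType::kAdd) ==
                  static_cast<int>(nnfw::cker::train::ArithmeticType::kAdd) &&
                static_cast<int>(onert::backend::train::ops::ArithmeticType::kDiv) ==
                  static_cast<int>(nnfw::cker::train::ArithmeticType::kDiv),
              "train::ops::ArithmeticType and cker's ArithmeticType must stay in sync");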
1 change: 1 addition & 0 deletions runtime/onert/backend/train/KernelGenerator.h
@@ -46,6 +46,7 @@ class KernelGenerator : public backend::train::KernelGeneratorBase

  std::unique_ptr<exec::train::TrainableFnSequence> generate(ir::OperationIndex op_ind) override;

  void visit(const ir::train::operation::BinaryArithmetic &) override;
  void visit(const ir::train::operation::Conv2D &) override;
  void visit(const ir::train::operation::DepthwiseConv2D &) override;
  void visit(const ir::train::operation::ElementwiseActivation &) override;
93 changes: 93 additions & 0 deletions runtime/onert/backend/train/ops/BinaryArithmeticLayer.cc
@@ -0,0 +1,93 @@
/*
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "BinaryArithmeticLayer.h"

#include "OperationUtils.h"

#include <cker/Shape.h>
#include <cker/operation/BinaryArithmeticOps.h>
#include <cker/train/operation/BinaryArithmetic.h>
#include <cker/train/operation/ReLU.h>

#include <cassert>

namespace onert
{
namespace backend
{
namespace train
{
namespace ops
{

BinaryArithmeticLayer::BinaryArithmeticLayer()
  : cpu::ops::BinaryArithmeticLayer(), _back_prop_lhs{nullptr}, _back_prop_rhs{nullptr},
    _back_prop_output{nullptr}
{
  // DO NOTHING
}

void BinaryArithmeticLayer::configure(const IPortableTensor *lhs, const IPortableTensor *rhs,
                                      IPortableTensor *output, IPortableTensor *back_prop_lhs,
                                      IPortableTensor *back_prop_rhs,
                                      const IPortableTensor *back_prop_output,
                                      const ir::Activation activation,
                                      const ArithmeticType arithmetic_type)
{
  if (arithmetic_type != ArithmeticType::kAdd && arithmetic_type != ArithmeticType::kSub &&
      arithmetic_type != ArithmeticType::kMul)
  {
    throw std::runtime_error{"Unsupported binary operation"};
  }

  cpu::ops::BinaryArithmeticLayer::configure(
    lhs, rhs, output, activation, static_cast<cpu::ops::ArithmeticType>(arithmetic_type));

  _back_prop_lhs = back_prop_lhs;
  _back_prop_rhs = back_prop_rhs;
  _back_prop_output = back_prop_output;
  _arithmetic_type = arithmetic_type;
  _activation = activation;
}

void BinaryArithmeticLayer::forward(bool) { cpu::ops::BinaryArithmeticLayer::run(); }

void BinaryArithmeticLayer::backward()
{
  // Calculate gradient for activation
  const IPortableTensor *backprop_act;
  try
  {
    backprop_act =
      backpropActivation(_activation, _output, _back_prop_output, _act_back_prop_output.get());
  }
  catch (const std::exception &e)
  {
    throw std::runtime_error{"train BinaryArithmeticLayer: " + std::string(e.what())};
  }
  assert(backprop_act != nullptr);

  // Propagate the post-activation gradient to both operands
  nnfw::cker::train::BinaryArithmeticGrad(
    getShape(_lhs), getBuffer<float>(_lhs), getShape(_rhs), getBuffer<float>(_rhs),
    getShape(backprop_act), getBuffer<float>(backprop_act), getShape(_back_prop_lhs),
    getBuffer<float>(_back_prop_lhs), getShape(_back_prop_rhs), getBuffer<float>(_back_prop_rhs),
    static_cast<nnfw::cker::train::ArithmeticType>(_arithmetic_type));
}

} // namespace ops
} // namespace train
} // namespace backend
} // namespace onert
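
In LaTeX form, backward() composes two steps: backpropActivation maps the output gradient through the activation, then BinaryArithmeticGrad applies the rules the cker kernel above implements, with $g$ the post-activation incoming gradient:

\begin{aligned}
\text{Add:} &\quad \frac{\partial L}{\partial \mathrm{lhs}} = g, &\quad \frac{\partial L}{\partial \mathrm{rhs}} &= g \\
\text{Sub:} &\quad \frac{\partial L}{\partial \mathrm{lhs}} = g, &\quad \frac{\partial L}{\partial \mathrm{rhs}} &= -g \\
\text{Mul:} &\quad \frac{\partial L}{\partial \mathrm{lhs}} = \mathrm{rhs} \odot g, &\quad \frac{\partial L}{\partial \mathrm{rhs}} &= \mathrm{lhs} \odot g
\end{aligned}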
72 changes: 72 additions & 0 deletions runtime/onert/backend/train/ops/BinaryArithmeticLayer.h
@@ -0,0 +1,72 @@
/*
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef __ONERT_BACKEND_TRAIN_OPS_BINARYARITHMETICLAYER_H__
#define __ONERT_BACKEND_TRAIN_OPS_BINARYARITHMETICLAYER_H__

#include <ops/BinaryArithmeticLayer.h>
#include <backend/IPortableTensor.h>

#include "../Tensor.h"
#include <exec/train/ITrainableFunction.h>

namespace onert
{
namespace backend
{
namespace train
{
namespace ops
{

enum class ArithmeticType
{
  kAdd,
  kSub,
  kMul,
  kDiv,
};

class BinaryArithmeticLayer : public ::onert::exec::train::ITrainableFunction,
                              public cpu::ops::BinaryArithmeticLayer
{
public:
  BinaryArithmeticLayer();

public:
  void configure(const IPortableTensor *lhs, const IPortableTensor *rhs, IPortableTensor *output,
                 IPortableTensor *back_prop_lhs, IPortableTensor *back_prop_rhs,
                 const IPortableTensor *back_prop_output, const ir::Activation activation,
                 const ArithmeticType arithmetic_type);
  void forward(bool training) override;
  void backward() override;

private:
  IPortableTensor *_back_prop_lhs;
  IPortableTensor *_back_prop_rhs;
  const IPortableTensor *_back_prop_output;

  ArithmeticType _arithmetic_type;
  ir::Activation _activation;
  std::unique_ptr<BackPropTensor> _act_back_prop_output;
};

} // namespace ops
} // namespace train
} // namespace backend
} // namespace onert

#endif // __ONERT_BACKEND_TRAIN_OPS_BINARYARITHMETICLAYER_H__
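
A hypothetical wiring sketch (not part of this commit) showing the lifecycle the KernelGenerator sets up, assuming valid IPortableTensor instances from a tensor registry: configure once, then forward/backward once per training step.

#include "BinaryArithmeticLayer.h"

namespace sketch
{
using onert::backend::IPortableTensor;

void add_train_step(const IPortableTensor *lhs, const IPortableTensor *rhs,
                    IPortableTensor *output, IPortableTensor *back_prop_lhs,
                    IPortableTensor *back_prop_rhs, const IPortableTensor *back_prop_output)
{
  using namespace onert::backend::train::ops;

  BinaryArithmeticLayer layer;
  layer.configure(lhs, rhs, output, back_prop_lhs, back_prop_rhs, back_prop_output,
                  onert::ir::Activation::NONE, ArithmeticType::kAdd);
  layer.forward(true); // delegates to cpu::ops::BinaryArithmeticLayer::run()
  layer.backward();    // fills back_prop_lhs / back_prop_rhs from back_prop_output
}
} // namespace sketch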
