Skip to content

Commit

Permalink
[DRAFT] Support Add
Browse files Browse the repository at this point in the history
.

Signed-off-by: YongHyun An <yonghyunz.an@samsung.com>
  • Loading branch information
YongHyun An committed Jan 11, 2024
1 parent f165eaf commit 12ccdba
Show file tree
Hide file tree
Showing 8 changed files with 134 additions and 0 deletions.
25 changes: 25 additions & 0 deletions runtime/onert/backend/train/KernelGenerator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,31 @@ KernelGenerator::KernelGenerator(const ir::train::TrainableGraph &tgraph,
// DO NOTHING
}

void KernelGenerator::Visit(const ir::train::operation::BinaryArithmetic &node)
{
using ir::train::operation::BinaryArithmetic;

const auto output_index{node.getOutputs().at(0)};
const auto lhs_index{node.getInputs().at(Reduce::Input::LHS)};
const auto rhs_index{node.getInputs().at(Reduce::Input::RHS)};

const auto arithmetic_type = node.param().arithmetic_type;
const auto activation = node.param().activation;

auto output_tensor = _tensor_reg->getPortableTensor(output_index);
auto lhs_tensor = _tensor_reg->getPortableTensor(lhs_index);
auto rhs_tensor = _tensor_reg->getPortableTensor(rhs_index);

auto back_prop_output_tensor = _tensor_reg->getBackPropTensor(output_index);
auto back_prop_lhs_tensor = _tensor_reg->getBackPropTensor(lhs_index);
auto back_prop_rhs_tensor = _tensor_reg->getBackPropTensor(rhs_index);

auto fn = std::make_unique<ops::BinaryArithmeticLayer>();
fn->configure(lhs_tensor, rhs_tensor, output_tensor, back_prop_lhs_tensor, back_prop_rhs_tensor,
back_prop_output_tensor, arithmetic_type, activation);
_return_fn = std::move(fn);
}

void KernelGenerator::visit(const ir::train::operation::Conv2D &node)
{
using ir::train::operation::Conv2D;
Expand Down
1 change: 1 addition & 0 deletions runtime/onert/backend/train/KernelGenerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ class KernelGenerator : public backend::train::KernelGeneratorBase

std::unique_ptr<exec::train::TrainableFnSequence> generate(ir::OperationIndex op_ind) override;

void visit(const ir::train::operation::BinaryArithmetic &) override;
void visit(const ir::train::operation::Conv2D &) override;
void visit(const ir::train::operation::ElementwiseActivation &) override;
void visit(const ir::train::operation::FullyConnected &) override;
Expand Down
1 change: 1 addition & 0 deletions runtime/onert/core/include/ir/train/Operations.Include.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#ifndef __ONERT_IR_TRAIN_OPERATIONS_OPERATION_INCLUDE_H__
#define __ONERT_IR_TRAIN_OPERATIONS_OPERATION_INCLUDE_H__

#include "ir/train/operation/BinaryArithmetic.h"
#include "ir/train/operation/Conv2D.h"
#include "ir/train/operation/DepthwiseConv2D.h"
#include "ir/train/operation/ElementwiseActivation.h"
Expand Down
1 change: 1 addition & 0 deletions runtime/onert/core/include/ir/train/Operations.lst
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#error Define OP before including this file
#endif

OP(BinaryArithmetic)
OP(Conv2D)
OP(DepthwiseConv2D)
OP(ElementwiseActivation)
Expand Down
51 changes: 51 additions & 0 deletions runtime/onert/core/include/ir/train/operation/BinaryArithmetic.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
/*
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef __ONERT_IR_TRAIN_OPERATION_BinaryArithmetic_H__
#define __ONERT_IR_TRAIN_OPERATION_BinaryArithmetic_H__

#include "ir/operation/BinaryArithmetic.h"
#include "ir/train/ITrainableOperation.h"

namespace onert
{
namespace ir
{
namespace train
{
namespace operation
{

class BinaryArithmetic : public ir::operation::BinaryArithmetic, public ITrainableOperation
{
private:
using OperationType = ir::operation::BinaryArithmetic;

public:
BinaryArithmetic(const OperationType &operation);

public:
std::unique_ptr<ITrainableOperation> clone() const override;
void accept(OperationVisitor &v) const override;
void accept(TrainableOperationVisitor &v) const override;
};

} // namespace operation
} // namespace train
} // namespace ir
} // namespace onert

#endif // __ONERT_IR_TRAIN_OPERATION_BinaryArithmetic_H__
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,11 @@ TrainableOperationConverter::TrainableOperationConverter(
UNUSED_RELEASE(_training_info);
}

void TrainableOperationConverter::visit(const ir::operation::BinaryArithemtic &node)
{
_return_op = std::make_unique<ir::train::operation::BinaryArithemtic>(node);
}

void TrainableOperationConverter::visit(const ir::operation::Conv2D &node)
{
_return_op = std::make_unique<ir::train::operation::Conv2D>(node);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ class TrainableOperationConverter : public UntrainableOperationConverter
using UntrainableOperationConverter::operator();

private:
void visit(const ir::operation::BinaryArithemic &) override;
void visit(const ir::operation::Conv2D &) override;
void visit(const ir::operation::DepthwiseConv2D &) override;
void visit(const ir::operation::ElementwiseActivation &) override;
Expand Down
49 changes: 49 additions & 0 deletions runtime/onert/core/src/ir/train/operation/BinaryArithemtic.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
/*
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "ir/train/operation/BinaryArithemtic.h"

#include "ir/OperationVisitor.h"
#include "ir/train/TrainableOperationVisitor.h"

namespace onert
{
namespace ir
{
namespace train
{
namespace operation
{

std::unique_ptr<ITrainableOperation> BinaryArithemtic::clone() const
{
return std::make_unique<BinaryArithemtic>(*this);
}

void BinaryArithemtic::accept(OperationVisitor &v) const { v.visit(*this); }

void BinaryArithemtic::accept(TrainableOperationVisitor &v) const { v.visit(*this); }

BinaryArithemtic::BinaryArithemtic(const OperationType &operation)
: OperationType{operation.getInputs(), operation.getOutputs(), operation.param()}
{
// DO NOTHING
}

} // namespace operation
} // namespace train
} // namespace ir
} // namespace onert

0 comments on commit 12ccdba

Please sign in to comment.