Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding Switch operation #935

Merged
merged 6 commits into from Apr 12, 2019
Merged
Changes from 5 commits
Commits
File filter...
Filter file types
Jump to…
Jump to file or symbol
Failed to load files and symbols.

Always

Just for now

@@ -19,5 +19,6 @@
#include <phylanx/plugins/keras_support/categorical_crossentropy_operation.hpp>
#include <phylanx/plugins/keras_support/softplus_operation.hpp>
#include <phylanx/plugins/keras_support/softsign_operation.hpp>
#include <phylanx/plugins/keras_support/switch_operation.hpp>

#endif
@@ -0,0 +1,76 @@
// Copyright (c) 2019 Bita Hasheminezhad
// Copyright (c) 2019 Hartmut Kaiser
//
// Distributed under the Boost Software License, Version 1.0. (See accompanying
// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)

#if !defined(PHYLANX_PRIMITIVES_KERAS_SUPPORT_SWITCH)
#define PHYLANX_PRIMITIVES_KERAS_SUPPORT_SWITCH

#include <phylanx/config.hpp>
#include <phylanx/execution_tree/primitives/base_primitive.hpp>
#include <phylanx/execution_tree/primitives/primitive_component_base.hpp>

#include <hpx/lcos/future.hpp>

#include <array>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <string>
#include <utility>
#include <vector>

namespace phylanx { namespace execution_tree { namespace primitives
{
class switch_operation
: public primitive_component_base
, public std::enable_shared_from_this<switch_operation>
{
protected:
hpx::future<primitive_argument_type> eval(
primitive_arguments_type const& operands,
primitive_arguments_type const& args,
eval_context ctx) const override;

public:
static match_pattern_type const match_data;

switch_operation() = default;

switch_operation(primitive_arguments_type&& operand,
std::string const& name, std::string const& codename);

private:
bool validate_shapes(std::size_t const& ndims_cond,
std::size_t const& ndims_then,
std::array<std::size_t, PHYLANX_MAX_DIMENSIONS>&& dims_cond,
std::array<std::size_t, PHYLANX_MAX_DIMENSIONS>&& dims_then) const;

primitive_argument_type switch0d(ir::node_data<std::uint8_t>&& cond,
ir::node_data<double>&& then_expr,
ir::node_data<double>&& else_expr) const;
primitive_argument_type switch1d(ir::node_data<std::uint8_t>&& cond,
ir::node_data<double>&& then_expr,
ir::node_data<double>&& else_expr) const;
primitive_argument_type switch2d(ir::node_data<std::uint8_t>&& cond,
ir::node_data<double>&& then_expr,
ir::node_data<double>&& else_expr) const;

#if defined(PHYLANX_HAVE_BLAZE_TENSOR)
primitive_argument_type switch3d(ir::node_data<std::uint8_t>&& cond,
ir::node_data<double>&& then_expr,
ir::node_data<double>&& else_expr) const;
#endif
};

inline primitive create_switch_operation(hpx::id_type const& locality,
primitive_arguments_type&& operands,
std::string const& name = "", std::string const& codename = "")
{
return create_primitive_component(
locality, "switch", std::move(operands), name, codename);
}
}}}

#endif
@@ -154,8 +154,8 @@ namespace phylanx { namespace execution_tree
return;
}

// tensors with just one row can be broadcast into any other
// tensor with the same number of columns
// tensors with just one page and row can be broadcast into any
// other tensor with the same number of columns
if (rhs.dimension(0) == 1 && rhs.dimension(1) == 1 &&
rhs.dimension(2) == columns)
{
@@ -174,8 +174,8 @@ namespace phylanx { namespace execution_tree
return;
}

// tensors with just one column can be broadcast into any other
// tensor with the same number of rows
// tensors with just one page and column can be broadcast into
// any other tensor with the same number of rows
if (rhs.dimension(0) == 1 && rhs.dimension(1) == rows &&
rhs.dimension(2) == 1)
{
@@ -194,6 +194,26 @@ namespace phylanx { namespace execution_tree
return;
}

// tensors with just one row and column can be broadcast into
// any other tensor with the same number of pages
if (rhs.dimension(0) == pages && rhs.dimension(1) == 1 &&
rhs.dimension(2) == 1)
{
result.resize(pages, rows, columns);
auto t = rhs.tensor();

auto row = blaze::row(blaze::rowslice(t, 0), 0);
for (std::size_t i = 0; i != rows; ++i)
{
auto slice = blaze::rowslice(result, i);
for (std::size_t j = 0; j != columns; ++j)
{
blaze::row(slice, j) = row;
}
}
return;
}

// tensors with just one page can be broadcast into any other
// tensor with the same number of columns/rows
if (rhs.dimension(0) == 1 && rhs.dimension(1) == rows &&
@@ -210,6 +230,38 @@ namespace phylanx { namespace execution_tree
return;
}

// tensors with just one row can be broadcast into any other
// tensor with the same number of pages/columns
if (rhs.dimension(0) == pages && rhs.dimension(1) == 1 &&
rhs.dimension(2) == columns)
{
result.resize(pages, rows, columns);
auto t = rhs.tensor();

auto rhs_rowslice = blaze::rowslice(t, 0);
for (std::size_t i = 0; i != rows; ++i)
{
blaze::rowslice(result, i) = rhs_rowslice;
}
return;
}

// tensors with just one column can be broadcast into any other
// tensor with the same number of pages/rows
if (rhs.dimension(0) == pages && rhs.dimension(1) == rows &&
rhs.dimension(2) == 1)
{
result.resize(pages, rows, columns);
auto t = rhs.tensor();

auto rhs_columnslice = blaze::columnslice(t, 0);
for (std::size_t j = 0; j != columns; ++j)
{
blaze::columnslice(result, j) = rhs_columnslice;
}
return;
}

if (rhs.dimension(0) != pages || rhs.dimension(1) != rows ||
rhs.dimension(2) != columns)
{
@@ -39,3 +39,5 @@ PHYLANX_REGISTER_PLUGIN_FACTORY(softplus_operation_plugin,
phylanx::execution_tree::primitives::softplus_operation::match_data);
PHYLANX_REGISTER_PLUGIN_FACTORY(softsign_operation_plugin,
phylanx::execution_tree::primitives::softsign_operation::match_data);
PHYLANX_REGISTER_PLUGIN_FACTORY(switch_operation_plugin,
phylanx::execution_tree::primitives::switch_operation::match_data);
ProTip! Use n and p to navigate between commits in a pull request.
You can’t perform that action at this time.