Skip to content

Commit

Permalink
Formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
pfultz2 committed Sep 24, 2020
1 parent 4b9e79d commit b970cdf
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 44 deletions.
7 changes: 4 additions & 3 deletions src/include/migraphx/matcher.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ void find_matches(program& p, instruction_ref ins, Ms&&... ms)
const
#endif
int trace = value_of(MIGRAPHX_TRACE_MATCHES{});
bool match = false;
bool match = false;
each_args(
[&](auto&& m) {
if(match)
Expand All @@ -233,7 +233,7 @@ void find_matches(program& p, instruction_ref ins, Ms&&... ms)
return;
if(trace > 0)
{
if (trace > 1)
if(trace > 1)
p.debug_print();
std::cout << "Matched by " << get_type_name(m) << std::endl;
p.debug_print(ins);
Expand All @@ -244,7 +244,8 @@ void find_matches(program& p, instruction_ref ins, Ms&&... ms)
if(invalid != p.end())
{
auto index = std::distance(p.begin(), invalid);
MIGRAPHX_THROW(get_type_name(m) + " matcher produces invalid program at instruction " +
MIGRAPHX_THROW(get_type_name(m) +
" matcher produces invalid program at instruction " +
std::to_string(index) + ": " + invalid->name());
}
#endif
Expand Down
64 changes: 36 additions & 28 deletions src/simplify_algebra.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -357,20 +357,21 @@ struct find_concat_reshape_op
{
auto matcher() const
{
return match::name("concat")(match::any_of[match::inputs()](match::has_attribute("reshape")(match::arg(0)(pointwise(), match::used_once()))));
return match::name("concat")(match::any_of[match::inputs()](
match::has_attribute("reshape")(match::arg(0)(pointwise(), match::used_once()))));
}

void apply(program& p, const match::matcher_result& r) const
{
for(auto rshp:r.result->inputs())
for(auto rshp : r.result->inputs())
{
if (not rshp->get_operator().attributes().contains("reshape"))
if(not rshp->get_operator().attributes().contains("reshape"))
continue;
auto pw = rshp->inputs().front();
if (not pw->get_operator().attributes().contains("pointwise"))
if(not pw->get_operator().attributes().contains("pointwise"))
continue;
std::vector<instruction_ref> args;
for(auto arg:pw->inputs())
for(auto arg : pw->inputs())
{
args.push_back(p.insert_instruction(std::next(arg), rshp->get_operator(), arg));
}
Expand Down Expand Up @@ -413,33 +414,34 @@ std::vector<instruction_ref> get_splits(instruction_ref ins)
return result;
}

static std::pair<instruction_ref, instruction_ref> order(instruction_ref end, instruction_ref a, instruction_ref b)
static std::pair<instruction_ref, instruction_ref>
order(instruction_ref end, instruction_ref a, instruction_ref b)
{
auto a1 = a;
auto b1 = b;
for(;;)
{
a1++;
b1++;
if (a1 == b)
if(a1 == b)
return std::make_pair(a, b);
else if (b1 == a)
else if(b1 == a)
return std::make_pair(b, a);
else if (a1 == end)
else if(a1 == end)
return std::make_pair(b, a);
else if (b1 == end)
else if(b1 == end)
return std::make_pair(a, b);

}
}

static instruction_ref find_last(instruction_ref end, std::vector<instruction_ref> inss)
{
if (inss.empty())
if(inss.empty())
return end;
return std::accumulate(inss.begin(), inss.end(), inss.front(), [&](instruction_ref a, instruction_ref b) {
return order(end, a, b).second;
});
return std::accumulate(
inss.begin(), inss.end(), inss.front(), [&](instruction_ref a, instruction_ref b) {
return order(end, a, b).second;
});
}

struct find_splits
Expand All @@ -456,8 +458,7 @@ struct find_splits
std::unordered_map<instruction_ref, instruction_ref> split_arg;
};

static split_group
get_split_groups(const std::vector<instruction_ref>& splits)
static split_group get_split_groups(const std::vector<instruction_ref>& splits)
{
split_group result;
for(auto out : splits.front()->outputs())
Expand Down Expand Up @@ -503,7 +504,9 @@ struct find_splits
continue;

assert(std::all_of(group.begin(), group.end(), [](auto i) {
return std::any_of(i->inputs().begin(), i->inputs().end(), [](auto ii) { return ii->name() == "slice"; });
return std::any_of(i->inputs().begin(), i->inputs().end(), [](auto ii) {
return ii->name() == "slice";
});
}));

// Make sure there is no duplicates
Expand All @@ -520,23 +523,27 @@ struct find_splits
assert(not std::none_of(start->inputs().begin(), start->inputs().end(), [](auto i) {
return i->name() == "slice";
}) && "one argument must be a split");
auto split_idx = 0;
auto data_idx = 1;
auto split_idx = 0;
auto data_idx = 1;
if(start->inputs().back() == sg.split_arg.at(start))
{
split_idx = 1;
data_idx = 0;
}

std::vector<std::size_t> data_indices;
std::transform(group.begin(), group.end(), std::back_inserter(data_indices), [&](auto i) {
if(i->inputs().back() == sg.split_arg.at(i))
return 0;
else
return 1;
});
std::transform(
group.begin(), group.end(), std::back_inserter(data_indices), [&](auto i) {
if(i->inputs().back() == sg.split_arg.at(i))
return 0;
else
return 1;
});
// If arguments are flipped then make sure the op is commutative
if (not std::all_of(data_indices.begin(), data_indices.end(), [&](auto i) { return i == data_idx; }) and not op.attributes().contains("commutative"))
if(not std::all_of(data_indices.begin(),
data_indices.end(),
[&](auto i) { return i == data_idx; }) and
not op.attributes().contains("commutative"))
continue;

std::vector<instruction_ref> data_args;
Expand All @@ -553,7 +560,8 @@ struct find_splits
auto concat_axis = slice_op.axes.front();
// TODO: Check if axises match
auto last_arg = find_last(p.end(), data_args);
auto concat = p.insert_instruction(std::next(last_arg), op::concat{concat_axis}, data_args);
auto concat =
p.insert_instruction(std::next(last_arg), op::concat{concat_axis}, data_args);
auto final_ins = order(p.end(), concat, ins).second;

std::vector<instruction_ref> args;
Expand Down
26 changes: 13 additions & 13 deletions test/simplify_algebra_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -969,20 +969,20 @@ TEST_CASE(simplify_split_sub_relu_flipped_args)

migraphx::program p2;
{
auto b = migraphx::op::broadcast{1, {3, 1, 4}};
auto input = p2.add_parameter("input", s);
auto x = p2.add_instruction(migraphx::op::slice{{1}, {0}, {1}}, input);
auto y = p2.add_instruction(migraphx::op::slice{{1}, {1}, {2}}, input);
auto one = p2.add_literal(1);
auto oneb = p2.add_instruction(b, one);
auto two = p2.add_literal(2);
auto twob = p2.add_instruction(b, two);
auto b = migraphx::op::broadcast{1, {3, 1, 4}};
auto input = p2.add_parameter("input", s);
auto x = p2.add_instruction(migraphx::op::slice{{1}, {0}, {1}}, input);
auto y = p2.add_instruction(migraphx::op::slice{{1}, {1}, {2}}, input);
auto one = p2.add_literal(1);
auto oneb = p2.add_instruction(b, one);
auto two = p2.add_literal(2);
auto twob = p2.add_instruction(b, two);
auto oneneg = p2.add_instruction(migraphx::op::neg{}, oneb);
auto sum1 = p2.add_instruction(migraphx::op::add{}, x, oneneg);
auto relu1 = p2.add_instruction(migraphx::op::relu{}, sum1);
auto sum2 = p2.add_instruction(migraphx::op::sub{}, twob, y);
auto relu2 = p2.add_instruction(migraphx::op::relu{}, sum2);
auto add = p2.add_instruction(migraphx::op::add{}, relu1, relu2);
auto sum1 = p2.add_instruction(migraphx::op::add{}, x, oneneg);
auto relu1 = p2.add_instruction(migraphx::op::relu{}, sum1);
auto sum2 = p2.add_instruction(migraphx::op::sub{}, twob, y);
auto relu2 = p2.add_instruction(migraphx::op::relu{}, sum2);
auto add = p2.add_instruction(migraphx::op::add{}, relu1, relu2);
p2.add_instruction(pass_op{}, add);
}

Expand Down

0 comments on commit b970cdf

Please sign in to comment.