Skip to content

Commit

Permalink
minimal rewrite phase
Browse files Browse the repository at this point in the history
  • Loading branch information
Marcel Ullrich committed Dec 2, 2022
1 parent b5dbb13 commit 277d500
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 44 deletions.
88 changes: 44 additions & 44 deletions dialects/matrix/matrix.thorin
Expand Up @@ -367,50 +367,50 @@
// ///
// // TODO: check code for 1-matrix edge case
// // TODO: would this automatically be handled by read(transpose) ?
.lam .extern internal_mapRed_matrix_transpose
![[k: .Nat, l: .Nat], T:*] ->
(.Cn[
[%mem.M,%matrix.Mat (2,(k, l),T)],
.Cn[%mem.M,%matrix.Mat (2,(l, k),T)]
])
= {
.con transpose_comb [[mem:%mem.M, acc:T, [a:T]], ret:.Cn[%mem.M,T]] = {
// TODO: or use generalized addition function
// ignore acc
.let new_acc = a;
ret (mem, new_acc)
};
.con inner_matrix_transpose
![
[
mem:%mem.M,
M:%matrix.Mat (2,(k, l),T),
],
ret: .Cn[%mem.M,%matrix.Mat (2,(l, k),T)]
]
= {
// TODO: use generalized zero
.let zero = (⊥:T);
ret (
%matrix.mapReduce
(2, (l, k), T,
1,
2,
T,
(k,l)
)
(
mem,
zero,
transpose_comb,
(
((1,0), M)
)
)
)
};
inner_matrix_transpose
};
// .lam .extern internal_mapRed_matrix_transpose
// ![[k: .Nat, l: .Nat], T:*] ->
// (.Cn[
// [%mem.M,%matrix.Mat (2,(k, l),T)],
// .Cn[%mem.M,%matrix.Mat (2,(l, k),T)]
// ])
// = {
// .con transpose_comb [[mem:%mem.M, acc:T, [a:T]], ret:.Cn[%mem.M,T]] = {
// // TODO: or use generalized addition function
// // ignore acc
// .let new_acc = a;
// ret (mem, new_acc)
// };
// .con inner_matrix_transpose
// ![
// [
// mem:%mem.M,
// M:%matrix.Mat (2,(k, l),T),
// ],
// ret: .Cn[%mem.M,%matrix.Mat (2,(l, k),T)]
// ]
// = {
// // TODO: use generalized zero
// .let zero = (⊥:T);
// ret (
// %matrix.mapReduce
// (2, (l, k), T,
// 1,
// 2,
// T,
// (k,l)
// )
// (
// mem,
// zero,
// transpose_comb,
// (
// ((1,0), M)
// )
// )
// )
// };
// inner_matrix_transpose
// };
// ///
// /// ### sum
// ///
Expand Down
1 change: 1 addition & 0 deletions dialects/matrix/passes/lower_matrix_lowlevel.cpp
Expand Up @@ -71,6 +71,7 @@ const Def* arrTyOfMatrixTy(const Def* Mat) {

const Def* LowerMatrixLowLevel::rewrite_structural(const Def* def) {
auto& world = def->world();
return Rewriter::rewrite_structural(def); // continue recursive rewriting with everything else

assert(!match<matrix::mapReduce>(def) && "mapReduce should have been lowered to for loops by now");
assert(!match<matrix::shape>(def) && "high level operations should have been lowered to for loops by now");
Expand Down

0 comments on commit 277d500

Please sign in to comment.