-
Notifications
You must be signed in to change notification settings - Fork 9
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Marcel Ullrich
committed
Jan 23, 2023
1 parent
10f0cf2
commit 04c5f38
Showing
1 changed file
with
121 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,121 @@ | ||
// RUN: rm -f %t.ll ; \ | ||
// RUN: %thorin -d matrix -d affine -d direct -d clos -d math -d autodiff -o - --output-ll %t.ll %s | ||
// RUN: clang %S/lib.c %t.ll -o %t -Wno-override-module | ||
// RUN: %t 2 3 | FileCheck %s | ||
|
||
.import core; | ||
.import mem; | ||
.import matrix; | ||
.import autodiff; | ||
|
||
.let _32 = 4294967296; | ||
.let I32 = .Idx _32; | ||
.let _f64_p = 52; | ||
.let _f64_e = 11; | ||
.let _f64 = (_f64_p, _f64_e); | ||
.let F64 = %math.F _f64; | ||
.let k = 2; | ||
.let l = 4; | ||
.let n = 3; | ||
.let MT1 = (2, (k,l), F64); | ||
.let MT2 = (2, (l,n), F64); | ||
.let MTO = (2, (k,n), F64); | ||
|
||
.con print_int_matrix [mem: %mem.M, k: .Nat, l: .Nat, m: %matrix.Mat (2, (⊤:.Nat,⊤:.Nat), I32), return : .Cn [%mem.M]]; | ||
.con print_double_matrix [mem: %mem.M, k: .Nat, l: .Nat, m: %matrix.Mat (2, (⊤:.Nat,⊤:.Nat), F64), return : .Cn [%mem.M]]; | ||
|
||
.con print_int_matrix_wrap [mem: %mem.M, k: .Nat, l: .Nat, m: %matrix.Mat (2, (k,l), I32), return : .Cn [%mem.M]] = { | ||
.let m2 = %core.bitcast (%matrix.Mat (2,(⊤:.Nat,⊤:.Nat),I32),%matrix.Mat (2,(k,l),I32)) m; | ||
print_int_matrix(mem, k, l, m2, return) | ||
}; | ||
|
||
.con print_double_matrix_wrap [mem: %mem.M, k: .Nat, l: .Nat, m: %matrix.Mat (2, (k,l), F64), return : .Cn [%mem.M]] = { | ||
.let m2 = %core.bitcast (%matrix.Mat (2,(⊤:.Nat,⊤:.Nat),F64),%matrix.Mat (2,(k,l),F64)) m; | ||
print_double_matrix(mem, k, l, m2, return) | ||
}; | ||
|
||
// TODO: dependent types need memoization in autodiff_type_fun | ||
// .con f [[mem: %mem.M, | ||
// [k: .Nat, l: .Nat, n: .Nat, | ||
// m1: %matrix.Mat (2, (k,l), F64), | ||
// m2: %matrix.Mat (2, (l,n), F64)]], | ||
// return : .Cn [%mem.M, %matrix.Mat (2, (k,n), F64)]] = { | ||
// .let (mem1, mP) = %matrix.prod (k,l,n, _f64) (mem, m1, m2); | ||
// return (mem1, mP) | ||
// }; | ||
|
||
.con f2 [[mem: %mem.M, | ||
[m1: %matrix.Mat (2, (k,l), F64), | ||
m2: %matrix.Mat (2, (l,n), F64)]], | ||
return : .Cn [%mem.M, %matrix.Mat (2, (k,n), F64)]] = { | ||
.let (mem1, mP) = %matrix.prod (k,l,n, _f64) (mem, m1, m2); | ||
return (mem1, mP) | ||
}; | ||
|
||
.con .extern main [mem : %mem.M, argc : I32, argv : %mem.Ptr (%mem.Ptr (.Idx 256, 0:.Nat), 0:.Nat), return : .Cn [%mem.M, I32]] = { | ||
.con return_cont [mem:%mem.M] = return (mem, 0:I32); | ||
|
||
.let c = 3.0:F64; | ||
.let d = 5.0:F64; | ||
.let (mem2,m1) = %matrix.constMat MT1 (mem,c); | ||
.let (mem3,m2) = %matrix.constMat MT2 (mem2,d); | ||
.let (mem4,m1_2) = %matrix.insert MT1 (mem3,m1, (0:(.Idx k),2:(.Idx l)), 4.0:F64); | ||
.let (mem5,m2_2) = %matrix.insert MT2 (mem4,m2, (1:(.Idx l),2:(.Idx n)), 6.0:F64); | ||
|
||
.con print_cont [mem:%mem.M, m:%matrix.Mat (2, (2,3), F64)] = { | ||
print_double_matrix_wrap (mem, 2, 3, m, return_cont) | ||
}; | ||
|
||
// f (mem5, 2, 4, 3, m1_2, m2_2, print_cont) | ||
|
||
.let f_diff = %autodiff.ad | ||
(.Cn [[mem: %mem.M, | ||
[ | ||
%matrix.Mat (2, (k,l), F64), | ||
%matrix.Mat (2, (l,n), F64)]], | ||
.Cn [%mem.M, %matrix.Mat (2, (k,n), F64)] | ||
]) f2; | ||
|
||
.con call_cont [ | ||
[mem:%mem.M, m:%matrix.Mat (2, (k,n), F64)], | ||
pb: .Cn [ | ||
[%mem.M, %matrix.Mat (2, (k, n), %math.F (52, 11))], | ||
.Cn [%mem.M, | ||
[%matrix.Mat (2, (k, l), F64), %matrix.Mat (2, (l, n), F64)] | ||
] | ||
] | ||
] = { | ||
.let (mem1,ms) = %matrix.constMat MTO (mem,0.0:F64); | ||
.let (mem2,ms_2) = %matrix.insert MTO (mem1,ms, (0:(.Idx k),2:(.Idx n)), 1.0:F64); | ||
// .let mem2 = mem; | ||
|
||
// print_double_matrix_wrap (mem2, k, n, m, return_cont) | ||
.con pb_cont [mem:%mem.M, | ||
[m1d: %matrix.Mat (2, (k, l), F64), m2d:%matrix.Mat (2, (l, n), F64)] | ||
] = { | ||
|
||
.con print_m2d [mem:%mem.M] = { | ||
print_double_matrix_wrap (mem, l, n, m2d, return_cont) | ||
}; | ||
.con print_m1d [mem:%mem.M] = { | ||
print_double_matrix_wrap (mem, k, l, m1d, print_m2d) | ||
}; | ||
print_double_matrix_wrap (mem, k, n, m, print_m1d) | ||
|
||
// print_double_matrix_wrap (mem, k, n, m , .cn [mem: %mem.M] = { | ||
// print_double_matrix_wrap (mem, k, l, m1d, .cn [mem: %mem.M] = { | ||
// print_double_matrix_wrap (mem, l, n, m2d, return_cont) | ||
// }) | ||
// print_double_matrix_wrap (mem, l, n, m2d, return_cont) | ||
// }) | ||
}; | ||
|
||
pb ((mem2, ms_2), pb_cont) | ||
}; | ||
|
||
f_diff ((mem5, (m1_2, m2_2)), call_cont) | ||
|
||
}; | ||
|
||
// CHECK: 65.00, 65.00, 68.00, | ||
// CHECK: 60.00, 60.00, 63.00, |