From 04c5f385f898c8d506eb936e4c61f124f04276d7 Mon Sep 17 00:00:00 2001 From: Marcel Ullrich Date: Mon, 23 Jan 2023 09:32:52 +0100 Subject: [PATCH] matrix diff test --- lit/autodiff/matrix/prod.thorin | 121 ++++++++++++++++++++++++++++++++ 1 file changed, 121 insertions(+) create mode 100644 lit/autodiff/matrix/prod.thorin diff --git a/lit/autodiff/matrix/prod.thorin b/lit/autodiff/matrix/prod.thorin new file mode 100644 index 0000000000..9dccc5106c --- /dev/null +++ b/lit/autodiff/matrix/prod.thorin @@ -0,0 +1,121 @@ +// RUN: rm -f %t.ll ; \ +// RUN: %thorin -d matrix -d affine -d direct -d clos -d math -d autodiff -o - --output-ll %t.ll %s +// RUN: clang %S/lib.c %t.ll -o %t -Wno-override-module +// RUN: %t 2 3 | FileCheck %s + +.import core; +.import mem; +.import matrix; +.import autodiff; + +.let _32 = 4294967296; +.let I32 = .Idx _32; +.let _f64_p = 52; +.let _f64_e = 11; +.let _f64 = (_f64_p, _f64_e); +.let F64 = %math.F _f64; +.let k = 2; +.let l = 4; +.let n = 3; +.let MT1 = (2, (k,l), F64); +.let MT2 = (2, (l,n), F64); +.let MTO = (2, (k,n), F64); + +.con print_int_matrix [mem: %mem.M, k: .Nat, l: .Nat, m: %matrix.Mat (2, (⊤:.Nat,⊤:.Nat), I32), return : .Cn [%mem.M]]; +.con print_double_matrix [mem: %mem.M, k: .Nat, l: .Nat, m: %matrix.Mat (2, (⊤:.Nat,⊤:.Nat), F64), return : .Cn [%mem.M]]; + +.con print_int_matrix_wrap [mem: %mem.M, k: .Nat, l: .Nat, m: %matrix.Mat (2, (k,l), I32), return : .Cn [%mem.M]] = { + .let m2 = %core.bitcast (%matrix.Mat (2,(⊤:.Nat,⊤:.Nat),I32),%matrix.Mat (2,(k,l),I32)) m; + print_int_matrix(mem, k, l, m2, return) +}; + +.con print_double_matrix_wrap [mem: %mem.M, k: .Nat, l: .Nat, m: %matrix.Mat (2, (k,l), F64), return : .Cn [%mem.M]] = { + .let m2 = %core.bitcast (%matrix.Mat (2,(⊤:.Nat,⊤:.Nat),F64),%matrix.Mat (2,(k,l),F64)) m; + print_double_matrix(mem, k, l, m2, return) +}; + +// TODO: dependent types need memoization in autodiff_type_fun +// .con f [[mem: %mem.M, +// [k: .Nat, l: .Nat, n: .Nat, +// m1: %matrix.Mat (2, (k,l), F64), +// m2: %matrix.Mat (2, (l,n), F64)]], +// return : .Cn [%mem.M, %matrix.Mat (2, (k,n), F64)]] = { +// .let (mem1, mP) = %matrix.prod (k,l,n, _f64) (mem, m1, m2); +// return (mem1, mP) +// }; + +.con f2 [[mem: %mem.M, + [m1: %matrix.Mat (2, (k,l), F64), + m2: %matrix.Mat (2, (l,n), F64)]], + return : .Cn [%mem.M, %matrix.Mat (2, (k,n), F64)]] = { + .let (mem1, mP) = %matrix.prod (k,l,n, _f64) (mem, m1, m2); + return (mem1, mP) +}; + +.con .extern main [mem : %mem.M, argc : I32, argv : %mem.Ptr (%mem.Ptr (.Idx 256, 0:.Nat), 0:.Nat), return : .Cn [%mem.M, I32]] = { + .con return_cont [mem:%mem.M] = return (mem, 0:I32); + + .let c = 3.0:F64; + .let d = 5.0:F64; + .let (mem2,m1) = %matrix.constMat MT1 (mem,c); + .let (mem3,m2) = %matrix.constMat MT2 (mem2,d); + .let (mem4,m1_2) = %matrix.insert MT1 (mem3,m1, (0:(.Idx k),2:(.Idx l)), 4.0:F64); + .let (mem5,m2_2) = %matrix.insert MT2 (mem4,m2, (1:(.Idx l),2:(.Idx n)), 6.0:F64); + + .con print_cont [mem:%mem.M, m:%matrix.Mat (2, (2,3), F64)] = { + print_double_matrix_wrap (mem, 2, 3, m, return_cont) + }; + + // f (mem5, 2, 4, 3, m1_2, m2_2, print_cont) + + .let f_diff = %autodiff.ad + (.Cn [[mem: %mem.M, + [ + %matrix.Mat (2, (k,l), F64), + %matrix.Mat (2, (l,n), F64)]], + .Cn [%mem.M, %matrix.Mat (2, (k,n), F64)] + ]) f2; + + .con call_cont [ + [mem:%mem.M, m:%matrix.Mat (2, (k,n), F64)], + pb: .Cn [ + [%mem.M, %matrix.Mat (2, (k, n), %math.F (52, 11))], + .Cn [%mem.M, + [%matrix.Mat (2, (k, l), F64), %matrix.Mat (2, (l, n), F64)] + ] + ] + ] = { + .let (mem1,ms) = %matrix.constMat MTO (mem,0.0:F64); + .let (mem2,ms_2) = %matrix.insert MTO (mem1,ms, (0:(.Idx k),2:(.Idx n)), 1.0:F64); + // .let mem2 = mem; + + // print_double_matrix_wrap (mem2, k, n, m, return_cont) + .con pb_cont [mem:%mem.M, + [m1d: %matrix.Mat (2, (k, l), F64), m2d:%matrix.Mat (2, (l, n), F64)] + ] = { + + .con print_m2d [mem:%mem.M] = { + print_double_matrix_wrap (mem, l, n, m2d, return_cont) + }; + .con print_m1d [mem:%mem.M] = { + print_double_matrix_wrap (mem, k, l, m1d, print_m2d) + }; + print_double_matrix_wrap (mem, k, n, m, print_m1d) + + // print_double_matrix_wrap (mem, k, n, m , .cn [mem: %mem.M] = { + // print_double_matrix_wrap (mem, k, l, m1d, .cn [mem: %mem.M] = { + // print_double_matrix_wrap (mem, l, n, m2d, return_cont) + // }) + // print_double_matrix_wrap (mem, l, n, m2d, return_cont) + // }) + }; + + pb ((mem2, ms_2), pb_cont) + }; + + f_diff ((mem5, (m1_2, m2_2)), call_cont) + +}; + +// CHECK: 65.00, 65.00, 68.00, +// CHECK: 60.00, 60.00, 63.00,