Skip to content

Commit

Permalink
matrix diff test
Browse files Browse the repository at this point in the history
  • Loading branch information
Marcel Ullrich committed Jan 23, 2023
1 parent 10f0cf2 commit 04c5f38
Showing 1 changed file with 121 additions and 0 deletions.
121 changes: 121 additions & 0 deletions lit/autodiff/matrix/prod.thorin
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
// RUN: rm -f %t.ll ; \
// RUN: %thorin -d matrix -d affine -d direct -d clos -d math -d autodiff -o - --output-ll %t.ll %s
// RUN: clang %S/lib.c %t.ll -o %t -Wno-override-module
// RUN: %t 2 3 | FileCheck %s

.import core;
.import mem;
.import matrix;
.import autodiff;

.let _32 = 4294967296;
.let I32 = .Idx _32;
.let _f64_p = 52;
.let _f64_e = 11;
.let _f64 = (_f64_p, _f64_e);
.let F64 = %math.F _f64;
.let k = 2;
.let l = 4;
.let n = 3;
.let MT1 = (2, (k,l), F64);
.let MT2 = (2, (l,n), F64);
.let MTO = (2, (k,n), F64);

.con print_int_matrix [mem: %mem.M, k: .Nat, l: .Nat, m: %matrix.Mat (2, (⊤:.Nat,⊤:.Nat), I32), return : .Cn [%mem.M]];
.con print_double_matrix [mem: %mem.M, k: .Nat, l: .Nat, m: %matrix.Mat (2, (⊤:.Nat,⊤:.Nat), F64), return : .Cn [%mem.M]];

.con print_int_matrix_wrap [mem: %mem.M, k: .Nat, l: .Nat, m: %matrix.Mat (2, (k,l), I32), return : .Cn [%mem.M]] = {
.let m2 = %core.bitcast (%matrix.Mat (2,(⊤:.Nat,⊤:.Nat),I32),%matrix.Mat (2,(k,l),I32)) m;
print_int_matrix(mem, k, l, m2, return)
};

.con print_double_matrix_wrap [mem: %mem.M, k: .Nat, l: .Nat, m: %matrix.Mat (2, (k,l), F64), return : .Cn [%mem.M]] = {
.let m2 = %core.bitcast (%matrix.Mat (2,(⊤:.Nat,⊤:.Nat),F64),%matrix.Mat (2,(k,l),F64)) m;
print_double_matrix(mem, k, l, m2, return)
};

// TODO: dependent types need memoization in autodiff_type_fun
// .con f [[mem: %mem.M,
// [k: .Nat, l: .Nat, n: .Nat,
// m1: %matrix.Mat (2, (k,l), F64),
// m2: %matrix.Mat (2, (l,n), F64)]],
// return : .Cn [%mem.M, %matrix.Mat (2, (k,n), F64)]] = {
// .let (mem1, mP) = %matrix.prod (k,l,n, _f64) (mem, m1, m2);
// return (mem1, mP)
// };

.con f2 [[mem: %mem.M,
[m1: %matrix.Mat (2, (k,l), F64),
m2: %matrix.Mat (2, (l,n), F64)]],
return : .Cn [%mem.M, %matrix.Mat (2, (k,n), F64)]] = {
.let (mem1, mP) = %matrix.prod (k,l,n, _f64) (mem, m1, m2);
return (mem1, mP)
};

.con .extern main [mem : %mem.M, argc : I32, argv : %mem.Ptr (%mem.Ptr (.Idx 256, 0:.Nat), 0:.Nat), return : .Cn [%mem.M, I32]] = {
.con return_cont [mem:%mem.M] = return (mem, 0:I32);

.let c = 3.0:F64;
.let d = 5.0:F64;
.let (mem2,m1) = %matrix.constMat MT1 (mem,c);
.let (mem3,m2) = %matrix.constMat MT2 (mem2,d);
.let (mem4,m1_2) = %matrix.insert MT1 (mem3,m1, (0:(.Idx k),2:(.Idx l)), 4.0:F64);
.let (mem5,m2_2) = %matrix.insert MT2 (mem4,m2, (1:(.Idx l),2:(.Idx n)), 6.0:F64);

.con print_cont [mem:%mem.M, m:%matrix.Mat (2, (2,3), F64)] = {
print_double_matrix_wrap (mem, 2, 3, m, return_cont)
};

// f (mem5, 2, 4, 3, m1_2, m2_2, print_cont)

.let f_diff = %autodiff.ad
(.Cn [[mem: %mem.M,
[
%matrix.Mat (2, (k,l), F64),
%matrix.Mat (2, (l,n), F64)]],
.Cn [%mem.M, %matrix.Mat (2, (k,n), F64)]
]) f2;

.con call_cont [
[mem:%mem.M, m:%matrix.Mat (2, (k,n), F64)],
pb: .Cn [
[%mem.M, %matrix.Mat (2, (k, n), %math.F (52, 11))],
.Cn [%mem.M,
[%matrix.Mat (2, (k, l), F64), %matrix.Mat (2, (l, n), F64)]
]
]
] = {
.let (mem1,ms) = %matrix.constMat MTO (mem,0.0:F64);
.let (mem2,ms_2) = %matrix.insert MTO (mem1,ms, (0:(.Idx k),2:(.Idx n)), 1.0:F64);
// .let mem2 = mem;

// print_double_matrix_wrap (mem2, k, n, m, return_cont)
.con pb_cont [mem:%mem.M,
[m1d: %matrix.Mat (2, (k, l), F64), m2d:%matrix.Mat (2, (l, n), F64)]
] = {

.con print_m2d [mem:%mem.M] = {
print_double_matrix_wrap (mem, l, n, m2d, return_cont)
};
.con print_m1d [mem:%mem.M] = {
print_double_matrix_wrap (mem, k, l, m1d, print_m2d)
};
print_double_matrix_wrap (mem, k, n, m, print_m1d)

// print_double_matrix_wrap (mem, k, n, m , .cn [mem: %mem.M] = {
// print_double_matrix_wrap (mem, k, l, m1d, .cn [mem: %mem.M] = {
// print_double_matrix_wrap (mem, l, n, m2d, return_cont)
// })
// print_double_matrix_wrap (mem, l, n, m2d, return_cont)
// })
};

pb ((mem2, ms_2), pb_cont)
};

f_diff ((mem5, (m1_2, m2_2)), call_cont)

};

// CHECK: 65.00, 65.00, 68.00,
// CHECK: 60.00, 60.00, 63.00,

0 comments on commit 04c5f38

Please sign in to comment.