Skip to content

Commit

Permalink
implement a few TODO'd tests for GEMM
Browse files Browse the repository at this point in the history
  • Loading branch information
Whiteknight committed Aug 21, 2010
1 parent 89381bf commit 6d20be9
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 15 deletions.
5 changes: 2 additions & 3 deletions src/pmc/complexmatrix2d.pmc
Expand Up @@ -1321,15 +1321,14 @@ Calculates the matrix equation:
*/

METHOD gemm(PMC *alpha, PMC * A, PMC *B, PMC *beta, PMC *C) {
PMC * const c_out = VTABLE_clone(INTERP, C);
FLOATVAL alpha_r, alpha_i, beta_r, beta_i;
A = convert_to_ComplexMatrix2D(interp, A);
B = convert_to_ComplexMatrix2D(interp, B);
C = convert_to_ComplexMatrix2D(interp, C);
get_complex_value_from_pmc(interp, alpha, &alpha_r, &alpha_i);
get_complex_value_from_pmc(interp, beta, &beta_r, &beta_i);
call_gemm(INTERP, alpha_r, alpha_i, A, B, beta_r, beta_i, c_out);
RETURN(PMC* c_out);
call_gemm(INTERP, alpha_r, alpha_i, A, B, beta_r, beta_i, C);
RETURN(PMC* C);
}

/*
Expand Down
5 changes: 2 additions & 3 deletions src/pmc/nummatrix2d.pmc
Expand Up @@ -1221,12 +1221,11 @@ Calculates the matrix equation:
*/

METHOD gemm(FLOATVAL alpha, PMC * A, PMC *B, FLOATVAL beta, PMC *C) {
PMC * const c_out = VTABLE_clone(INTERP, C);
A = convert_to_NumMatrix2D(interp, A);
B = convert_to_NumMatrix2D(interp, B);
C = convert_to_NumMatrix2D(interp, C);
call_gemm(INTERP, alpha, A, B, beta, c_out);
RETURN(PMC* c_out);
call_gemm(INTERP, alpha, A, B, beta, C);
RETURN(PMC* C);
}

/*
Expand Down
64 changes: 55 additions & 9 deletions t/testlib/methods/gemm.nqp
Expand Up @@ -70,16 +70,62 @@ class Pla::Methods::Gemm is Pla::MatrixTestBase {
});
}

method test_METHOD_gemm_AUTOCONVERT_A_NumMatrix2D() { todo("Write this!"); }
method test_METHOD_gemm_AUTOCONVERT_B_NumMatrix2D() { todo("Write this!"); }
method test_METHOD_gemm_AUTOCONVERT_C_NumMatrix2D() { todo("Write this!"); }
# Tests that for the current type, when we call GEMM the values and
# results are converted to this type
method __test_gemm_autoconvert($Af, $Bf, $Cf) {
my $m := self.factory.defaultmatrix2x2();
my $A := $Af.fancymatrix2x2();
my $B := $Bf.fancymatrix2x2();
my $C := $Cf.fancymatrix2x2();
my $alpha := self.factory.fancyvalue(0);
my $beta := self.factory.fancyvalue(0);
my $D := $m.gemm($alpha, $A, $B, $beta, $C);
my $type_D := pir::typeof__SP($D);
my $type_m := pir::typeof__SP($m);
assert_equal($type_D, $type_m,
"not the right type. Found " ~ $type_D ~ " expected " ~ $type_m);
}

method test_autoconvert_A_NumMatrix2D() {
my $factory := Pla::MatrixFactory::NumMatrix2D.new();
self.__test_gemm_autoconvert($factory, self.factory, self.factory);
}

method test_autoconvert_B_NumMatrix2D() {
my $factory := Pla::MatrixFactory::NumMatrix2D.new();
self.__test_gemm_autoconvert(self.factory, $factory, self.factory);
}

method test_autoconvert_C_NumMatrix2D() {
my $factory := Pla::MatrixFactory::NumMatrix2D.new();
self.__test_gemm_autoconvert(self.factory, self.factory, $factory);
}

method test_METHOD_gemm_AUTOCONVERT_A_ComplexMatrix2D() { todo("Write this!"); }
method test_METHOD_gemm_AUTOCONVERT_B_ComplexMatrix2D() { todo("Write this!"); }
method test_METHOD_gemm_AUTOCONVERT_C_ComplexMatrix2D() { todo("Write this!"); }
method test_autoconvert_A_ComplexMatrix2D() {
my $factory := Pla::MatrixFactory::ComplexMatrix2D.new();
self.__test_gemm_autoconvert($factory, self.factory, self.factory);
}

method test_autoconvert_B_ComplexMatrix2D() {
my $factory := Pla::MatrixFactory::ComplexMatrix2D.new();
self.__test_gemm_autoconvert(self.factory, $factory, self.factory);
}

method test_METHOD_gemm_AUTOCONVERT_A_PMCMatrix2D() { todo("Write this!"); }
method test_METHOD_gemm_AUTOCONVERT_B_PMCMatrix2D() { todo("Write this!"); }
method test_METHOD_gemm_AUTOCONVERT_C_PMCMatrix2D() { todo("Write this!"); }
method test_autoconvert_C_ComplexMatrix2D() {
my $factory := Pla::MatrixFactory::ComplexMatrix2D.new();
self.__test_gemm_autoconvert(self.factory, self.factory, $factory);
}

method test_autoconvert_A_PMCMatrix2D() {
my $factory := Pla::MatrixFactory::PMCMatrix2D.new();
self.__test_gemm_autoconvert($factory, self.factory, self.factory);
}
method test_autoconvert_B_PMCMatrix2D() {
my $factory := Pla::MatrixFactory::PMCMatrix2D.new();
self.__test_gemm_autoconvert(self.factory, $factory, self.factory);
}
method test_autoconvert_C_PMCMatrix2D() {
my $factory := Pla::MatrixFactory::PMCMatrix2D.new();
self.__test_gemm_autoconvert(self.factory, self.factory, $factory);
}
}

0 comments on commit 6d20be9

Please sign in to comment.