Skip to content

Commit

Permalink
Implemented add for transposed matrices, fixed some accesses to trans…
Browse files Browse the repository at this point in the history
…posed matrices, implemented multiply for transposed matrices.
  • Loading branch information
Markus Mayr committed Nov 6, 2009
1 parent 650ccbb commit 67a86d5
Show file tree
Hide file tree
Showing 2 changed files with 170 additions and 40 deletions.
174 changes: 139 additions & 35 deletions src/pmc/nummatrix2d.pmc
Expand Up @@ -98,7 +98,7 @@ pmclass NumMatrix2D dynpmc auto_attrs {
if (x >= x_size || y >= y_size)
Parrot_ex_throw_from_c_args(INTERP, NULL, EXCEPTION_OUT_OF_BOUNDS,
"NumMatrix2d: indices out of bounds");
return ITEM_XY_ROWMAJOR(attrs->storage, x_size, y_size, x, y);
return ITEM_XY(attrs->storage, attrs->flags, x_size, y_size, x, y);
}

VTABLE INTVAL get_integer_keyed(PMC * key) {
Expand Down Expand Up @@ -182,7 +182,7 @@ pmclass NumMatrix2D dynpmc auto_attrs {
x_size = attrs->x;
y_size = attrs->y;
}
ITEM_XY_ROWMAJOR(attrs->storage, x_size, y_size, x, y) = value;
ITEM_XY(attrs->storage, attrs->flags, x_size, y_size, x, y) = value;
}

VTABLE void set_integer_keyed(PMC * key, INTVAL value) {
Expand Down Expand Up @@ -235,7 +235,7 @@ pmclass NumMatrix2D dynpmc auto_attrs {
}

MULTI PMC *add(NumMatrix2D *value, PMC *dest) {
int i = 0;
int i = 0, j = 0;
INTVAL x_size, y_size;
Parrot_NumMatrix2D_attributes * const selfattr
= (Parrot_NumMatrix2D_attributes *) PARROT_NUMMATRIX2D(SELF);
Expand All @@ -246,27 +246,55 @@ pmclass NumMatrix2D dynpmc auto_attrs {
x_size = selfattr->x;
y_size = selfattr->y;

if (IS_TRANSPOSED(selfattr->flags) || IS_TRANSPOSED(valattr->flags)) {
Parrot_ex_throw_from_c_args(INTERP, NULL, EXCEPTION_UNIMPLEMENTED,
"NumMatrix2D: Transposed matrices not supported yet in add.");
}

if (x_size != valattr->x || y_size != valattr->y) {
/* XXX: Throw a better exception. */
Parrot_ex_throw_from_c_args(INTERP, NULL, EXCEPTION_OUT_OF_BOUNDS,
"NumMatrix2D: Matrix dimensions must match in add.");
}

dest = VTABLE_clone(INTERP, value);
destattr = (Parrot_NumMatrix2D_attributes *) PARROT_NUMMATRIX2D(dest);
if ((IS_TRANSPOSED(selfattr->flags) && ! IS_TRANSPOSED(valattr->flags))
|| (IS_TRANSPOSED(valattr->flags) && ! IS_TRANSPOSED(selfattr->flags))) {
FLOATVAL *sstor = selfattr->storage,
*vstor = valattr->storage;
FLOATVAL *dstor = NULL;

dest = pmc_new(interp, VTABLE_type(interp, pmc));
resize_matrix(interp, dest, x_size - 1, y_size - 1);
destattr = (Parrot_NumMatrix2D_attributes *) PARROT_NUMMATRIX2D(dest);
dstor = destattr->storage;

if (IS_TRANSPOSED(selfattr->flags)) {
for (i = 0; i < x_size; ++i) {
for (j = 0; j < y_size; ++j) {
ITEM_XY_ROWMAJOR(dstor, x_size, y_size, i, j) =
ITEM_XY_COLMAJOR(sstor, x_size, y_size, i, j)
+ ITEM_XY_ROWMAJOR(vstor, x_size, y_size, i, j);
}
}
}
else {
for (i = 0; i < x_size; ++i) {
for (j = 0; j < y_size; ++j) {
ITEM_XY_ROWMAJOR(dstor, x_size, y_size, i, j) =
ITEM_XY_ROWMAJOR(sstor, x_size, y_size, i, j)
+ ITEM_XY_COLMAJOR(vstor, x_size, y_size, i, j);
}
}
}
}
else {
dest = VTABLE_clone(INTERP, value);
destattr = (Parrot_NumMatrix2D_attributes *) PARROT_NUMMATRIX2D(dest);

cblas_daxpy(x_size*y_size, 1, selfattr->storage, 1, destattr->storage, 1);
cblas_daxpy(x_size*y_size, 1, selfattr->storage, 1, destattr->storage, 1);
}

return dest;
}

MULTI PMC *multiply(NumMatrix2D *value, PMC *dest) {
INTVAL x_size = 0, y_size = 0;
INTVAL x_size = 0, y_size = 0, sflags = 0, vflags = 0;

Parrot_NumMatrix2D_attributes * const selfattr =
(Parrot_NumMatrix2D_attributes *) PARROT_NUMMATRIX2D(SELF);
Parrot_NumMatrix2D_attributes * const valattr =
Expand All @@ -275,6 +303,8 @@ pmclass NumMatrix2D dynpmc auto_attrs {

x_size = selfattr->x;
y_size = valattr->y;
sflags = selfattr->flags;
vflags = valattr->flags;

if (selfattr->y != valattr->x) {
Parrot_ex_throw_from_c_args(INTERP, NULL, EXCEPTION_OUT_OF_BOUNDS,
Expand All @@ -285,22 +315,98 @@ pmclass NumMatrix2D dynpmc auto_attrs {
resize_matrix(INTERP, dest, x_size - 1, y_size - 1);
destattr = (Parrot_NumMatrix2D_attributes *) PARROT_NUMMATRIX2D(dest);

cblas_dgemm(CblasRowMajor,
IS_TRANSPOSED_BLAS(selfattr->flags),
IS_TRANSPOSED_BLAS(valattr->flags),
x_size,
selfattr->y,
y_size,
1.0,
selfattr->storage,
(IS_TRANSPOSED(selfattr->flags) ? y_size : x_size ),
valattr->storage,
(IS_TRANSPOSED(valattr->flags) ? valattr->y : y_size ),
1.0,
destattr->storage,
x_size
);
if (IS_TINY(sflags) || (IS_GENERAL(sflags) && IS_GENERAL(vflags))) {
cblas_dgemm(CblasRowMajor,
IS_TRANSPOSED_BLAS(selfattr->flags),
IS_TRANSPOSED_BLAS(valattr->flags),
x_size,
selfattr->y,
y_size,
1.,
selfattr->storage,
x_size,
valattr->storage,
y_size,
0.,
destattr->storage,
x_size
);
}
else if (IS_SYMMETRIC(sflags)) {
cblas_dsymm(
CblasRowMajor,
(IS_TRANSPOSED(vflags) ? CblasRight : CblasLeft),
CblasUpper,
x_size,
y_size,
1.,
selfattr->storage,
x_size,
valattr->storage,
valattr->x,
1.,
destattr->storage,
x_size
);

if (IS_TRANSPOSED(vflags)) {
/* TODO: Transpose matrix */
}
}
else if (IS_SYMMETRIC(vflags)) {
cblas_dsymm(
CblasRowMajor,
(IS_TRANSPOSED(sflags) ? CblasLeft : CblasRight),
CblasUpper,
x_size,
y_size,
1.,
valattr->storage,
valattr->x,
selfattr->storage,
selfattr->x,
1.,
destattr->storage,
x_size
);

if (IS_TRANSPOSED(sflags)) {
/* TODO: Transpose matrix */
}
}
/* else if (IS_TRIANGLE(sflags)) {
cblas_dtrmm(
CblasRowMajor,
CblasLeft,
(IS_LTRIANGLE(sflags) ? CblasLower : CblasUpper),
x_size,
y_size,
1.,
selfattr->storage,
selfattr->x,
valattr->storage,
valattr->x,
0.,
destattr->storage,
x_size
);

}
else if (IS_TRIANGLE(vflags)) {
cblas_dtrmm(
CblasRowMajor,
CblasRight,
(IS_LTRIANGLE(vflags) ? CblasLower : CblasUpper),

}
*/
else {
Parrot_ex_throw_from_c_args(INTERP, NULL, EXCEPTION_UNIMPLEMENTED,
"parrot-linear-algebra: Method multiply not implemented for "
"this combination of flags.");
}

destattr->flags = sflags & vflags;
return dest;
}

Expand Down Expand Up @@ -353,12 +459,12 @@ pmclass NumMatrix2D dynpmc auto_attrs {
if (other->vtable->base_type == SELF->vtable->base_type) {
Parrot_NumMatrix2D_attributes * const self_attrs = PARROT_NUMMATRIX2D(SELF);
Parrot_NumMatrix2D_attributes * const other_attrs = PARROT_NUMMATRIX2D(other);
const INTVAL self_is_transposed = IS_TRANSPOSED(self_attrs->flags);
const INTVAL self_x = self_attrs->x;
const INTVAL self_y = self_attrs->y;
const INTVAL other_is_transposed = IS_TRANSPOSED(other_attrs->flags);
const INTVAL self_flags = self_attrs->flags;
const INTVAL other_x = other_attrs->x;
const INTVAL other_y = other_attrs->y;
const INTVAL other_flags = other_attrs->flags;
FLOATVAL * const self_s = self_attrs->storage;
FLOATVAL * const other_s = other_attrs->storage;
INTVAL x, y;
Expand All @@ -368,12 +474,10 @@ pmclass NumMatrix2D dynpmc auto_attrs {

for (y = 0; y < self_y; y++) {
for (x = 0; x < self_x; x++) {
const FLOATVAL self_value = self_is_transposed ?
ITEM_XY_COLMAJOR(self_s, self_x, self_y, x, y) :
ITEM_XY_ROWMAJOR(self_s, self_x, self_y, x, y);
const FLOATVAL other_value = other_is_transposed ?
ITEM_XY_COLMAJOR(other_s, other_x, other_y, x, y) :
ITEM_XY_ROWMAJOR(other_s, other_x, other_y, x, y);
const FLOATVAL self_value =
ITEM_XY(self_s, self_flags, self_x, self_y, x, y);
const FLOATVAL other_value =
ITEM_XY(other_s, other_flags, other_x, other_y, x, y);
if (self_value != other_value)
return 0;
}
Expand Down
36 changes: 31 additions & 5 deletions src/pmc/pla_matrix_types.h
Expand Up @@ -16,17 +16,43 @@ do { \
#define INDEX_XY_COLMAJOR(x_max, y_max, x, y) \
(((x_max) * (y)) + (x))

#define INDEX_XY(flags, x_max, y_max, x, y) \
(((IS_TRANSPOSED(flags)) ? INDEX_XY_COLMAJOR(x_max, y_max, x, y) : \
INDEX_XY_ROWMAJOR(x_max, y_max, x, y)))

#define ITEM_XY_ROWMAJOR(s, x_max, y_max, x, y) \
(s)[((y_max) * (x)) + (y)]
(s)[INDEX_XY_ROWMAJOR(x_max, y_max, x, x)]

#define ITEM_XY_COLMAJOR(s, x_max, y_max, x, y) \
(s)[((x_max) * (y)) + (x)]
(s)[INDEX_XY_COLMAJOR(x_max, y_max, x, y)]

#define ITEM_XY(s, flags, x_max, y_max, x, y) \
(s)[INDEX_XY(flags, x_max, y_max, x, y)]

#define INDEX_MIN(a, b) (((a) <= (b))?(a):(b))
#define INDEX_MAX(a, b) (((a) >= (b))?(a):(b))

#define FLAG_TRANSPOSED 1

#define IS_TRANSPOSED(flags) (((flags) & (FLAG_TRANSPOSED)) != 0)
#define FLAG_TRANSPOSED 1
#define FLAG_SYMMETRIC 2
#define FLAG_HERMITIAN 4
#define FLAG_UTRIANGLE 8
#define FLAG_LTRIANGLE 16
#define FLAG_TRIANGLE (FLAG_UTRIANGLE) | (FLAG_LTRIANGLE)
#define FLAG_TRIDIAGONAL 32
#define FLAG_TINY 64
#define FLAG_DIAGONAL (FLAG_SYMMETRIC) | (FLAG_HERMITIAN) | (FLAG_LTRIANGLE) |\
(FLAG_UTRIANGLE) | (FLAG_TRIDIAGONAL)

#define IS_GENERAL(flags) ((! (flags)))
#define IS_TINY(flags) (((flags) & (FLAG_TINY)))
#define IS_SYMMETRIC(flags) (((flags) & (FLAG_SYMMETRIC)))
#define IS_HERMITIAN(flags) (((flags) & (FLAG_HERMITIAN)))
#define IS_UTRIANGLE(flags) (((flags) & (FLAG_UTRIANGLE)))
#define IS_LTRIANGLE(flags) (((flags) & (FLAG_LTRIANGLE)))
#define IS_TRIANGLE(flags) (((flags) & (FLAG_TRIANGLE)))
#define IS_DIAGONAL(flags) ((((flags) & (FLAG_DIAGONAL)) == FLAG_DIAGONAL))
#define IS_TRIDIAGONAL(flags) (((flags) & (FLAG_TRIDIAGONAL)))
#define IS_TRANSPOSED(flags) (((flags) & (FLAG_TRANSPOSED)))
#define IS_TRANSPOSED_BLAS(flags) (IS_TRANSPOSED(flags) ? CblasTrans : CblasNoTrans)

#endif

0 comments on commit 67a86d5

Please sign in to comment.