Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions pygpu/blas.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ cdef extern from "gpuarray/buffer_blas.h":
cb_conj_trans

cdef extern from "gpuarray/blas.h":
int GpuArray_rdot(_GpuArray *X, _GpuArray *Y, _GpuArray *Z, int nocopy)
int GpuArray_rgemv(cb_transpose transA, double alpha, _GpuArray *A,
_GpuArray *X, double beta, _GpuArray *Y, int nocopy)
int GpuArray_rgemm(cb_transpose transA, cb_transpose transB,
Expand All @@ -18,6 +19,13 @@ cdef extern from "gpuarray/blas.h":
int GpuArray_rger(double alpha, _GpuArray *X, _GpuArray *Y, _GpuArray *A,
int nocopy)

cdef api int pygpu_blas_rdot(GpuArray X, GpuArray Y, GpuArray Z, bint nocopy) except -1:
cdef int err
err = GpuArray_rdot(&X.ga, &Y.ga, &Z.ga, nocopy)
if err != GA_NO_ERROR:
raise GpuArrayException(GpuArray_error(&X.ga, err), err)
return 0

cdef api int pygpu_blas_rgemv(cb_transpose transA, double alpha, GpuArray A,
GpuArray X, double beta, GpuArray Y,
bint nocopy) except -1:
Expand Down Expand Up @@ -45,6 +53,16 @@ cdef api int pygpu_blas_rger(double alpha, GpuArray X, GpuArray Y, GpuArray A,
return 0


def dot(GpuArray X, GpuArray Y, GpuArray Z=None, overwrite_z=False):
if Z is None:
Z = pygpu_empty(0, NULL, X.typecode, GA_ANY_ORDER, X.context, None)
overwrite_z = True

if not overwrite_z:
Z = pygpu_copy(Z, GA_ANY_ORDER)
pygpu_blas_rdot(X, Y, Z, 0)
return Z

def gemv(double alpha, GpuArray A, GpuArray X, double beta=0.0,
GpuArray Y=None, trans_a=False, overwrite_y=False):
cdef cb_transpose transA
Expand Down
112 changes: 62 additions & 50 deletions pygpu/tests/test_blas.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import numpy
from itertools import product
import numpy
from nose.plugins.skip import SkipTest

from .support import (guard_devsup, gen_gpuarray, context)
Expand All @@ -14,25 +15,48 @@

import pygpu.blas as gblas

def test_dot():
bools = [True, False]
for N, dtype, offseted_i, sliced in product(
[1, 256, 1337], ['float32', 'float64'], bools, bools):
yield dot, N, dtype, offseted_i, sliced, True, False
for overwrite, init_z in product(bools, bools):
yield dot, 666, 'float32', False, False, overwrite, init_z

@guard_devsup
def dot(N, dtype, offseted_i, sliced, overwrite, init_z):
cX, gX = gen_gpuarray((N,), dtype, offseted_inner=offseted_i,
sliced=sliced, ctx=context)
cY, gY = gen_gpuarray((N,), dtype, offseted_inner=offseted_i,
sliced=sliced, ctx=context)
if init_z:
_, gZ = gen_gpuarray((), dtype, offseted_inner=offseted_i,
sliced=sliced, ctx=context)
else:
_, gZ = None, None

if dtype == 'float32':
cr = fblas.sdot(cX, cY)
else:
cr = fblas.ddot(cX, cY)
gr = gblas.dot(gX, gY, gZ, overwrite_z=overwrite)
numpy.testing.assert_allclose(cr, numpy.asarray(gr), rtol=1e-6)


def test_gemv():
for shape in [(100, 128), (128, 50)]:
for order in ['f', 'c']:
for trans in [False, True]:
for offseted_i in [True, False]:
for sliced in [1, 2, -1, -2]:
yield gemv, shape, 'float32', order, trans, \
offseted_i, sliced, True, False
for overwrite in [True, False]:
for init_y in [True, False]:
yield gemv, (4, 3), 'float32', 'f', False, False, 1, \
overwrite, init_y
bools = [False, True]
for shape, order, trans, offseted_i, sliced in product(
[(100, 128), (128, 50)], 'fc', bools, bools, [1, 2, -1, -2]):
yield gemv, shape, 'float32', order, trans, \
offseted_i, sliced, True, False
for overwrite, init_y in product(bools, bools):
yield gemv, (4, 3), 'float32', 'f', False, False, 1, \
overwrite, init_y
yield gemv, (32, 32), 'float64', 'f', False, False, 1, True, False
for alpha in [0, 1, -1, 0.6]:
for beta in [0, 1, -1, 0.6]:
for overwite in [True, False]:
yield gemv, (32, 32), 'float32', 'f', False, False, 1, \
overwrite, True, alpha, beta

for alpha, beta, overwrite in product(
[0, 1, -1, 0.6], [0, 1, -1, 0.6], bools):
yield gemv, (32, 32), 'float32', 'f', False, False, 1, \
overwrite, True, alpha, beta

@guard_devsup
def gemv(shp, dtype, order, trans, offseted_i, sliced,
Expand Down Expand Up @@ -65,28 +89,22 @@ def gemv(shp, dtype, order, trans, offseted_i, sliced,


def test_gemm():
for m, n, k in [(48, 15, 32), (15, 32, 48)]:
for order in [('f', 'f', 'f'), ('c', 'c', 'c'),
('f', 'f', 'c'), ('f', 'c', 'f'),
('f', 'c', 'c'), ('c', 'f', 'f'),
('c', 'f', 'c'), ('c', 'c', 'f')]:
for trans in [(False, False), (True, True),
(False, True), (True, False)]:
for offseted_o in [False, True]:
yield gemm, m, n, k, 'float32', order, trans, \
offseted_o, 1, False, False
for sliced in [1, 2, -1, -2]:
for overwrite in [True, False]:
for init_res in [True, False]:
yield gemm, 4, 3, 2, 'float32', ('f', 'f', 'f'), \
(False, False), False, sliced, overwrite, init_res
bools = [False, True]
for (m, n, k), order, trans, offseted_o in product(
[(48, 15, 32), (15, 32, 48)], list(product(*['fc']*3)),
list(product(bools, bools)), bools):
yield gemm, m, n, k, 'float32', order, trans, \
offseted_o, 1, False, False
for sliced, overwrite, init_res in product(
[1, 2, -1, -2], bools, bools):
yield gemm, 4, 3, 2, 'float32', ('f', 'f', 'f'), \
(False, False), False, sliced, overwrite, init_res
yield gemm, 32, 32, 32, 'float64', ('f', 'f', 'f'), (False, False), \
False, 1, False, False
for alpha in [0, 1, -1, 0.6]:
for beta in [0, 1, -1, 0.6]:
for overwrite in [True, False]:
yield gemm, 32, 23, 32, 'float32', ('f', 'f', 'f'), \
(False, False), False, 1, overwrite, True, alpha, beta
for alpha, beta, overwrite in product(
[0, 1, -1, 0.6], [0, 1, -1, 0.6], bools):
yield gemm, 32, 23, 32, 'float32', ('f', 'f', 'f'), \
(False, False), False, 1, overwrite, True, alpha, beta

@guard_devsup
def gemm(m, n, k, dtype, order, trans, offseted_o, sliced, overwrite,
Expand Down Expand Up @@ -124,19 +142,13 @@ def gemm(m, n, k, dtype, order, trans, offseted_o, sliced, overwrite,


def test_ger():
for m, n in [(4, 5)]:
for order in ['f', 'c']:
for sliced_x in [1, 2, -2, -1]:
for sliced_y in [1, 2, -2, -1]:
yield ger, m, n, 'float32', order, sliced_x, sliced_y, \
False

bools = [False, True]
for (m,n), order, sliced_x, sliced_y in product(
[(4,5)], 'fc', [1, 2, -2, -1], [1, 2, -2, -1]):
yield ger, m, n, 'float32', order, sliced_x, sliced_y, False
yield ger, 4, 5, 'float64', 'f', 1, 1, False

for init_res in [True, False]:
for overwrite in [True, False]:
yield ger, 4, 5, 'float32', 'f', 1, 1, init_res, overwrite

for init_res, overwrite in product(bools, bools):
yield ger, 4, 5, 'float32', 'f', 1, 1, init_res, overwrite

def ger(m, n, dtype, order, sliced_x, sliced_y, init_res, overwrite=False):
cX, gX = gen_gpuarray((m,), dtype, order, sliced=sliced_x, ctx=context)
Expand Down
Empty file modified setup.py
100644 → 100755
Empty file.
6 changes: 6 additions & 0 deletions src/gpuarray/blas.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,12 @@
extern "C" {
#endif

// only for vector-vector dot
GPUARRAY_PUBLIC int GpuArray_rdot( GpuArray *X, GpuArray *Y,
GpuArray *Z, int nocopy);
#define GpuArray_hdot GpuArray_rdot
#define GpuArray_sdot GpuArray_rdot
#define GpuArray_ddot GpuArray_rdot
GPUARRAY_PUBLIC int GpuArray_rgemv(cb_transpose transA, double alpha,
GpuArray *A, GpuArray *X, double beta,
GpuArray *Y, int nocopy);
Expand Down
18 changes: 18 additions & 0 deletions src/gpuarray/buffer_blas.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,24 @@ GPUARRAY_PUBLIC void gpublas_teardown(gpucontext *ctx);

GPUARRAY_PUBLIC const char *gpublas_error(gpucontext *ctx);

GPUARRAY_PUBLIC int gpublas_hdot(
size_t N,
gpudata *X, size_t offX, size_t incX,
gpudata *Y, size_t offY, size_t incY,
gpudata *Z, size_t offZ);

GPUARRAY_PUBLIC int gpublas_sdot(
size_t N,
gpudata *X, size_t offX, size_t incX,
gpudata *Y, size_t offY, size_t incY,
gpudata *Z, size_t offZ);

GPUARRAY_PUBLIC int gpublas_ddot(
size_t N,
gpudata *X, size_t offX, size_t incX,
gpudata *Y, size_t offY, size_t incY,
gpudata *Z, size_t offZ);

GPUARRAY_PUBLIC int gpublas_hgemv(
cb_order order, cb_transpose transA, size_t M, size_t N, float alpha,
gpudata *A, size_t offA, size_t lda, gpudata *X, size_t offX, int incX,
Expand Down
88 changes: 86 additions & 2 deletions src/gpuarray_array_blas.c
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,91 @@
#include "gpuarray/util.h"
#include "gpuarray/error.h"

int GpuArray_rdot( GpuArray *X, GpuArray *Y,
GpuArray *Z, int nocopy) {
GpuArray *Xp = X;
GpuArray copyX;
GpuArray *Yp = Y;
GpuArray copyY;
GpuArray *Zp = Z;
size_t n;
void *ctx;
size_t elsize;
int err;

if (X->typecode != GA_HALF &&
X->typecode != GA_FLOAT &&
X->typecode != GA_DOUBLE)
return GA_INVALID_ERROR;

if (X->nd != 1 || Y->nd != 1 || Z->nd != 0 ||
X->typecode != Y->typecode || X->typecode != Z->typecode)
return GA_VALUE_ERROR;
n = X->dimensions[0];
if (!(X->flags & GA_ALIGNED) || !(Y->flags & GA_ALIGNED) ||
!(Z->flags & GA_ALIGNED))
return GA_UNALIGNED_ERROR;
if (X->dimensions[0] != Y->dimensions[0])
return GA_VALUE_ERROR;

elsize = gpuarray_get_elsize(X->typecode);
if (X->strides[0] < 0) {
if (nocopy)
return GA_COPY_ERROR;
else {
err = GpuArray_copy(&copyX, X, GA_ANY_ORDER);
if (err != GA_NO_ERROR)
goto cleanup;
Xp = &copyX;
}
}
if (Y->strides[0] < 0) {
if (nocopy)
return GA_COPY_ERROR;
else {
err = GpuArray_copy(&copyY, Y, GA_ANY_ORDER);
if (err != GA_NO_ERROR)
goto cleanup;
Yp = &copyY;
}
}

ctx = gpudata_context(Xp->data);
err = gpublas_setup(ctx);
if (err != GA_NO_ERROR)
goto cleanup;

switch (Xp->typecode) {
case GA_HALF:
err = gpublas_hdot(
n,
Xp->data, Xp->offset / elsize, Xp->strides[0] / elsize,
Yp->data, Yp->offset / elsize, Yp->strides[0] / elsize,
Zp->data, Zp->offset / elsize);
break;
case GA_FLOAT:
err = gpublas_sdot(
n,
Xp->data, Xp->offset / elsize, Xp->strides[0] / elsize,
Yp->data, Yp->offset / elsize, Yp->strides[0] / elsize,
Zp->data, Zp->offset / elsize);
break;
case GA_DOUBLE:
err = gpublas_ddot(
n,
Xp->data, Xp->offset / elsize, Xp->strides[0] / elsize,
Yp->data, Yp->offset / elsize, Yp->strides[0] / elsize,
Zp->data, Zp->offset / elsize);
break;
}
cleanup:
if (Xp == &copyX)
GpuArray_clear(&copyX);
if (Yp == &copyY)
GpuArray_clear(&copyY);
return err;
}

int GpuArray_rgemv(cb_transpose transA, double alpha, GpuArray *A,
GpuArray *X, double beta, GpuArray *Y, int nocopy) {
GpuArray *Ap = A;
Expand All @@ -24,8 +109,7 @@ int GpuArray_rgemv(cb_transpose transA, double alpha, GpuArray *A,
return GA_INVALID_ERROR;

if (A->nd != 2 || X->nd != 1 || Y->nd != 1 ||
A->typecode != A->typecode || X->typecode != A->typecode ||
Y->typecode != A->typecode)
X->typecode != A->typecode || Y->typecode != A->typecode)
return GA_VALUE_ERROR;

if (!(A->flags & GA_ALIGNED) || !(X->flags & GA_ALIGNED) ||
Expand Down
Loading