/
pla_blas.c
102 lines (97 loc) · 1.95 KB
/
pla_blas.c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
#include "pla.h"
/* Wrapper to call the dgemm function from BLAS on raw matrix storage.
   BLAS dgemm computes C = alpha * op(A) * op(B) + beta * C, where op(X) is
   X or its transpose. Assumes A, B, and C are all storage buffers taken
   from NumMatrix2D.

   alpha, beta       scalar multipliers
   flags_a, flags_b  matrix flags; IS_TRANSPOSED_BLAS turns them into the
                     transpose argument handed to BLAS
   A, B              input buffers
   rows_a, cols_a    used as M and K, i.e. the dimensions of op(A)
   cols_b            used as N, i.e. the column count of op(B) and of C
   C                 input/output buffer, M x N

   NOTE(review): the CBLAS leading dimensions below assume densely packed
   row-major buffers and that rows_a/cols_a/cols_b describe the matrices
   after any flagged transposition -- confirm against the callers. */
void
call_dgemm(FLOATVAL alpha,
           INTVAL flags_a, FLOATVAL * A, INTVAL rows_a, INTVAL cols_a,
           INTVAL flags_b, FLOATVAL * B, INTVAL cols_b,
           FLOATVAL beta, FLOATVAL * C)
{
    const INTVAL M = rows_a;
    const INTVAL N = cols_b;
    const INTVAL K = cols_a;
#ifdef PLA_HAVE_CBLAS
    /* With CblasRowMajor a leading dimension is the length of one stored
       row. op(A) is M x K, so A is stored M x K (row length K) when not
       transposed and K x M (row length M) when transposed; likewise B is
       stored K x N or N x K. C is always M x N, so its row length is N. */
    const INTVAL lda = (IS_TRANSPOSED_BLAS(flags_a) == CblasNoTrans) ? K : M;
    const INTVAL ldb = (IS_TRANSPOSED_BLAS(flags_b) == CblasNoTrans) ? N : K;
    const INTVAL ldc = N;

    cblas_dgemm(CblasRowMajor,
        IS_TRANSPOSED_BLAS(flags_a),
        IS_TRANSPOSED_BLAS(flags_b),
        M,
        N,
        K,
        alpha,
        A,
        lda,
        B,
        ldb,
        beta,
        C,
        ldc
    );
#else
    /* NOTE(review): the Fortran dgemm_ interface is column-major and takes
       every scalar by reference. The leading dimensions here (M, N, M) are
       left exactly as they were because neither the dgemm_ prototype nor
       the non-CBLAS definition of IS_TRANSPOSED_BLAS is visible from this
       file; they likely need the same audit as the CBLAS branch. */
    dgemm_(
        IS_TRANSPOSED_BLAS(flags_a),
        IS_TRANSPOSED_BLAS(flags_b),
        &M,
        &N,
        &K,
        &alpha,
        A,
        &M,
        B,
        &N,
        &beta,
        C,
        &M
    );
#endif
}
/* Wrapper to call the zgemm function from BLAS on raw matrix storage.
   BLAS zgemm computes C = alpha * op(A) * op(B) + beta * C over complex
   numbers. Assumes A, B, and C are all storage buffers taken from
   ComplexMatrix2D.

   alpha_r, alpha_i  real and imaginary parts of alpha
   beta_r, beta_i    real and imaginary parts of beta
   flags_a, flags_b  matrix flags; IS_TRANSPOSED_BLAS turns them into the
                     transpose argument handed to BLAS
   A, B              input buffers
   rows_a, cols_a    used as M and K, i.e. the dimensions of op(A)
   cols_b            used as N, i.e. the column count of op(B) and of C
   C                 input/output buffer, M x N

   NOTE(review): the CBLAS leading dimensions below assume densely packed
   row-major buffers and that rows_a/cols_a/cols_b describe the matrices
   after any flagged transposition -- confirm against the callers. */
void
call_zgemm(FLOATVAL alpha_r, FLOATVAL alpha_i,
           INTVAL flags_a, FLOATVAL * A, INTVAL rows_a, INTVAL cols_a,
           INTVAL flags_b, FLOATVAL * B, INTVAL cols_b,
           FLOATVAL beta_r, FLOATVAL beta_i, FLOATVAL * C)
{
    const INTVAL M = rows_a;
    const INTVAL N = cols_b;
    const INTVAL K = cols_a;
#ifdef PLA_HAVE_CBLAS
    /* With CblasRowMajor a leading dimension is the length of one stored
       row, counted in complex elements. op(A) is M x K, so A is stored
       M x K (row length K) when not transposed and K x M (row length M)
       when transposed; likewise B is stored K x N or N x K. C is always
       M x N, so its row length is N. */
    const INTVAL lda = (IS_TRANSPOSED_BLAS(flags_a) == CblasNoTrans) ? K : M;
    const INTVAL ldb = (IS_TRANSPOSED_BLAS(flags_b) == CblasNoTrans) ? N : K;
    const INTVAL ldc = N;
#endif
    /* BLAS takes each complex scalar as a pointer to a {real, imaginary}
       pair, so pack the separate parts into two-element arrays. */
    FLOATVAL alpha_p[2];
    FLOATVAL beta_p[2];
    alpha_p[0] = alpha_r;
    alpha_p[1] = alpha_i;
    beta_p[0] = beta_r;
    beta_p[1] = beta_i;
#ifdef PLA_HAVE_CBLAS
    cblas_zgemm(CblasRowMajor,
        IS_TRANSPOSED_BLAS(flags_a),
        IS_TRANSPOSED_BLAS(flags_b),
        M,
        N,
        K,
        alpha_p,
        A,
        lda,
        B,
        ldb,
        beta_p,
        C,
        ldc
    );
#else
    /* NOTE(review): the Fortran zgemm_ interface is column-major and takes
       every scalar by reference. The leading dimensions here (M, N, M) are
       left exactly as they were because neither the zgemm_ prototype nor
       the non-CBLAS definition of IS_TRANSPOSED_BLAS is visible from this
       file; they likely need the same audit as the CBLAS branch. */
    zgemm_(
        IS_TRANSPOSED_BLAS(flags_a),
        IS_TRANSPOSED_BLAS(flags_b),
        &M,
        &N,
        &K,
        alpha_p,
        A,
        &M,
        B,
        &N,
        beta_p,
        C,
        &M
    );
#endif
}