/
madd.go
43 lines (39 loc) · 1.32 KB
/
madd.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
package cuda
import (
"github.com/mumax/3/data"
"github.com/mumax/3/util"
)
// Mul computes the point-wise product dst[i] = a[i] * b[i].
//
// a and b are asserted to have the same length and the same number of
// components as dst. The k_mul_async kernel wrapper is then invoked once
// per component, every call sharing the configuration returned by
// make1DConf for the slice length.
func Mul(dst, a, b *data.Slice) {
	n, comps := dst.Len(), dst.NComp()

	// Both operands must match the destination in length and component count.
	shapesMatch := a.Len() == n && a.NComp() == comps &&
		b.Len() == n && b.NComp() == comps
	util.Assert(shapesMatch)

	launch := make1DConf(n)
	for i := 0; i < comps; i++ {
		out, lhs, rhs := dst.DevPtr(i), a.DevPtr(i), b.DevPtr(i)
		k_mul_async(out, lhs, rhs, n, launch)
	}
}
// Madd2 computes the two-term multiply-add
//
//	dst[i] = src1[i] * factor1 + src2[i] * factor2
//
// src1 and src2 are asserted to have the same length and the same number
// of components as dst. The k_madd2_async kernel wrapper is invoked once
// per component, every call sharing the configuration returned by
// make1DConf for the slice length.
func Madd2(dst, src1, src2 *data.Slice, factor1, factor2 float32) {
	n, comps := dst.Len(), dst.NComp()

	// Lengths are checked first, component counts second.
	sameLen := src1.Len() == n && src2.Len() == n
	util.Assert(sameLen)
	sameComps := src1.NComp() == comps && src2.NComp() == comps
	util.Assert(sameComps)

	launch := make1DConf(n)
	for i := 0; i < comps; i++ {
		out, in1, in2 := dst.DevPtr(i), src1.DevPtr(i), src2.DevPtr(i)
		k_madd2_async(out, in1, factor1, in2, factor2, n, launch)
	}
}
// Madd3 computes the three-term multiply-add
//
//	dst[i] = src1[i] * factor1 + src2[i] * factor2 + src3[i] * factor3
//
// src1, src2 and src3 are asserted to have the same length and the same
// number of components as dst. The k_madd3_async kernel wrapper is invoked
// once per component, every call sharing the configuration returned by
// make1DConf for the slice length.
func Madd3(dst, src1, src2, src3 *data.Slice, factor1, factor2, factor3 float32) {
	n, comps := dst.Len(), dst.NComp()

	// Lengths are checked first, component counts second.
	sameLen := src1.Len() == n && src2.Len() == n && src3.Len() == n
	util.Assert(sameLen)
	sameComps := src1.NComp() == comps && src2.NComp() == comps && src3.NComp() == comps
	util.Assert(sameComps)

	launch := make1DConf(n)
	for i := 0; i < comps; i++ {
		out, in1, in2 := dst.DevPtr(i), src1.DevPtr(i), src2.DevPtr(i)
		in3 := src3.DevPtr(i)
		k_madd3_async(out, in1, factor1, in2, factor2, in3, factor3, n, launch)
	}
}