diff --git a/.gitignore b/.gitignore index 7c7065d..4ba28c4 100644 --- a/.gitignore +++ b/.gitignore @@ -7,4 +7,6 @@ bench/cuda/matrix_vector_mult tests/all tests/cublas *.ndb -nimsuggest.log \ No newline at end of file +nimsuggest.log +bin/ +.DS_Store \ No newline at end of file diff --git a/linalg/private/ops.nim b/linalg/private/ops.nim index 82190d2..fdabaa1 100644 --- a/linalg/private/ops.nim +++ b/linalg/private/ops.nim @@ -166,10 +166,7 @@ template maxIndexPrivate(N, v: untyped): auto = m = val (j, m) -proc maxIndex*[N: static[int]](v: Vector32[N]): tuple[i: int, val: float32] = - maxIndexPrivate(N, v) - -proc maxIndex*[N: static[int]](v: Vector64[N]): tuple[i: int, val: float64] = +proc maxIndex*[N: static[int], T](v: Vector[N, T]): tuple[i: int, val: T] = maxIndexPrivate(N, v) proc maxIndex*(v: DVector32): tuple[i: int, val: float32] = @@ -190,10 +187,7 @@ template minIndexPrivate(N, v: untyped): auto = m = val return (j, m) -proc minIndex*[N: static[int]](v: Vector32[N]): tuple[i: int, val: float32] = - minIndexPrivate(N, v) - -proc minIndex*[N: static[int]](v: Vector64[N]): tuple[i: int, val: float64] = +proc minIndex*[N: static[int], T](v: Vector[N, T]): tuple[i: int, val: T] = minIndexPrivate(N, v) proc minIndex*(v: DVector32): tuple[i: int, val: float32] = diff --git a/linalg/private/types.nim b/linalg/private/types.nim index b65553e..5af3c6f 100644 --- a/linalg/private/types.nim +++ b/linalg/private/types.nim @@ -13,35 +13,31 @@ # limitations under the License. type - Vector32*[N: static[int]] = ref array[N, float32] - Vector64*[N: static[int]] = ref array[N, float64] - Matrix32*[M, N: static[int]] = object + Vector*[N: static[int], T: SomeReal] = ref array[N, T] + Matrix*[M, N: static[int], T: SomeReal] = object order: OrderType - data: ref array[N * M, float32] - Matrix64*[M, N: static[int]] = object - order: OrderType - data: ref array[M * N, float64] - DVector32* = seq[float32] - DVector64* = seq[float64] - DMatrix32* = ref object - order: OrderType - M, N: int - data: seq[float32] - DMatrix64* = ref object + data: ref array[N * M, T] + DVector*[T: SomeReal] = seq[T] + DMatrix*[T: SomeReal] = ref object order: OrderType M, N: int - data: seq[float64] - AnyVector = Vector32 or Vector64 or DVector32 or DVector64 - AnyMatrix = Matrix32 or Matrix64 or DMatrix32 or DMatrix64 + data: seq[T] + + Vector32*[N: static[int]] = Vector[N, float32] + Vector64*[N: static[int]] = Vector[N, float64] + Matrix32*[M, N: static[int]] = Matrix[M, N, float32] + Matrix64*[M, N: static[int]] = Matrix[M, N, float64] + DVector32* = DVector[float32] + DVector64* = DVector[float64] + DMatrix32* = DMatrix[float32] + DMatrix64* = DMatrix[float64] + AnyVector = Vector32 or Vector64 or DVector32 or DVector64 or Vector + AnyMatrix = Matrix32 or Matrix64 or DMatrix32 or DMatrix64 or Matrix # Float pointers -template fp(v: Vector32): ptr float32 = cast[ptr float32](addr(v[])) - -template fp(v: Vector64): ptr float64 = cast[ptr float64](addr(v[])) - -template fp(m: Matrix32): ptr float32 = cast[ptr float32](addr(m.data[])) +template fp[N,T](v: Vector[N,T]): ptr T = cast[ptr T](addr(v[])) -template fp(m: Matrix64): ptr float64 = cast[ptr float64](addr(m.data[])) +template fp[M,N,T](m: Matrix[M,N,T]): ptr T = cast[ptr T](addr(m.data[])) template fp(v: DVector32): ptr float32 = cast[ptr float32](unsafeAddr(v[0]))