diff --git a/src/GraphBLAS-sharp.Backend/Vector/DenseVector/DenseVector.fs b/src/GraphBLAS-sharp.Backend/Vector/DenseVector/DenseVector.fs index b66c00aa..3865dba9 100644 --- a/src/GraphBLAS-sharp.Backend/Vector/DenseVector/DenseVector.fs +++ b/src/GraphBLAS-sharp.Backend/Vector/DenseVector/DenseVector.fs @@ -7,6 +7,35 @@ open Microsoft.FSharp.Quotations open GraphBLAS.FSharp.Backend.Predefined module DenseVector = + let containsNonZero<'a when 'a: struct> (clContext: ClContext) (workGroupSize: int) = + + let containsNonZero = + <@ fun (ndRange: Range1D) length (vector: ClArray<'a option>) (result: ClCell) -> + + let gid = ndRange.GlobalID0 + + if gid < length then + match vector.[gid] with + | Some _ -> result.Value <- true + | _ -> () @> + + let kernel = clContext.Compile containsNonZero + + fun (processor: MailboxProcessor<_>) (vector: ClArray<'a option>) -> + + let result = clContext.CreateClCell false + + let ndRange = + Range1D.CreateValid(vector.Length, workGroupSize) + + let kernel = kernel.GetKernel() + + processor.Post(Msg.MsgSetArguments(fun () -> kernel.KernelFunc ndRange vector.Length vector result)) + + processor.Post(Msg.CreateRunMsg<_, _>(kernel)) + + result + let elementWise<'a, 'b, 'c when 'a: struct and 'b: struct and 'c: struct> (clContext: ClContext) (opAdd: Expr<'a option -> 'b option -> 'c option>) diff --git a/tests/GraphBLAS-sharp.Tests/GraphBLAS-sharp.Tests.fsproj b/tests/GraphBLAS-sharp.Tests/GraphBLAS-sharp.Tests.fsproj index 3e5dca50..43f08bb3 100644 --- a/tests/GraphBLAS-sharp.Tests/GraphBLAS-sharp.Tests.fsproj +++ b/tests/GraphBLAS-sharp.Tests/GraphBLAS-sharp.Tests.fsproj @@ -37,6 +37,7 @@ + diff --git a/tests/GraphBLAS-sharp.Tests/Program.fs b/tests/GraphBLAS-sharp.Tests/Program.fs index 429360b5..db9f2134 100644 --- a/tests/GraphBLAS-sharp.Tests/Program.fs +++ b/tests/GraphBLAS-sharp.Tests/Program.fs @@ -40,7 +40,8 @@ let allTests = Backend.Vector.ElementWise.mulTests Backend.Vector.FillSubVector.tests Backend.Vector.FillSubVector.complementedTests - Backend.Vector.Reduce.tests ] + Backend.Vector.Reduce.tests + Backend.Vector.ContainNonZero.tests ] |> testSequenced [] diff --git a/tests/GraphBLAS-sharp.Tests/Vector/ContainsNonZero.fs b/tests/GraphBLAS-sharp.Tests/Vector/ContainsNonZero.fs new file mode 100644 index 00000000..122304f7 --- /dev/null +++ b/tests/GraphBLAS-sharp.Tests/Vector/ContainsNonZero.fs @@ -0,0 +1,96 @@ +module Backend.Vector.ContainNonZero + +open Expecto +open Expecto.Logging +open GraphBLAS.FSharp.Backend +open GraphBLAS.FSharp.Tests +open GraphBLAS.FSharp.Tests.Utils +open Context +open Brahma.FSharp + +let logger = + Log.create "Vector.containsNonZero.Tests" + +let context = defaultContext.ClContext + +let q = defaultContext.Queue + +let correctnessGenericTest<'a when 'a: struct and 'a: equality> isZero containsNonZero (array: 'a []) = + + if array.Length > 0 then + let vector = createVectorFromArray Dense array isZero + + let result = + match vector.ToDevice context with + | ClVectorDense clArray -> + let resultCell = containsNonZero q clArray + let result = Array.zeroCreate 1 + + let res = + q.PostAndReply(fun ch -> Msg.CreateToHostMsg<_>(resultCell, result, ch)) + + q.Post(Msg.CreateFreeMsg<_>(resultCell)) + + res.[0] + + $"The results should be the same, vector : {vector}" + |> Expect.equal result (Array.exists (not << isZero) array) + +let testFixtures = + let config = defaultConfig + + let wgSize = 32 + + let getCorrectnessTestName datatype = + sprintf "Correctness on %s, %A" datatype Dense + + [ let containsNonZeroInt = + DenseVector.DenseVector.containsNonZero context wgSize + + correctnessGenericTest ((=) 0) containsNonZeroInt + |> testPropertyWithConfig config (getCorrectnessTestName "int") + + let containsNonZeroByte = + DenseVector.DenseVector.containsNonZero context wgSize + + correctnessGenericTest ((=) 0uy) containsNonZeroByte + |> testPropertyWithConfig config (getCorrectnessTestName "byte") + + let containsNonZeroFloat = + DenseVector.DenseVector.containsNonZero context wgSize + + correctnessGenericTest ((=) 0.0) containsNonZeroFloat + |> testPropertyWithConfig config (getCorrectnessTestName "float") + + let containsNonZeroBool = + DenseVector.DenseVector.containsNonZero context wgSize + + correctnessGenericTest ((=) false) containsNonZeroBool + |> testPropertyWithConfig config (getCorrectnessTestName "bool") + + let containsNonZeroInt = + DenseVector.DenseVector.containsNonZero context wgSize + + correctnessGenericTest ((=) 0) containsNonZeroInt (Array.create 1000 0) + |> testPropertyWithConfig config (getCorrectnessTestName "int zeros") + + let containsNonZeroByte = + DenseVector.DenseVector.containsNonZero context wgSize + + correctnessGenericTest ((=) 0uy) containsNonZeroByte (Array.create 1000 0uy) + |> testPropertyWithConfig config (getCorrectnessTestName "byte zeros") + + let containsNonZeroFloat = + DenseVector.DenseVector.containsNonZero context wgSize + + correctnessGenericTest ((=) 0.0) containsNonZeroFloat (Array.create 1000 0.0) + |> testPropertyWithConfig config (getCorrectnessTestName "float zeros") + + let containsNonZeroBool = + DenseVector.DenseVector.containsNonZero context wgSize + + correctnessGenericTest ((=) false) containsNonZeroBool (Array.create 1000 false) + |> testPropertyWithConfig config (getCorrectnessTestName "bool zeros") ] + +let tests = + testList "Backend.Vector.containsNonZero tests" testFixtures