diff --git a/src/GraphBLAS-sharp.Backend/Common/PrefixSum.fs b/src/GraphBLAS-sharp.Backend/Common/PrefixSum.fs index b25cd85e..3e030589 100644 --- a/src/GraphBLAS-sharp.Backend/Common/PrefixSum.fs +++ b/src/GraphBLAS-sharp.Backend/Common/PrefixSum.fs @@ -3,6 +3,8 @@ namespace GraphBLAS.FSharp.Backend.Common open Brahma.FSharp open FSharp.Quotations open GraphBLAS.FSharp.Backend.Quotes +open GraphBLAS.FSharp.Backend.Objects.ArraysExtensions +open GraphBLAS.FSharp.Backend.Objects.ClCell module PrefixSum = let private update (opAdd: Expr<'a -> 'a -> 'a>) (clContext: ClContext) workGroupSize = @@ -38,7 +40,7 @@ module PrefixSum = ) processor.Post(Msg.CreateRunMsg<_, _> kernel) - processor.Post(Msg.CreateFreeMsg(mirror)) + mirror.Free processor let private scanGeneral beforeLocalSumClear @@ -48,10 +50,8 @@ module PrefixSum = workGroupSize = - let subSum = SubSum.treeSum opAdd - let scan = - <@ fun (ndRange: Range1D) inputArrayLength verticesLength (resultBuffer: ClArray<'a>) (verticesBuffer: ClArray<'a>) (totalSumBuffer: ClCell<'a>) (zero: ClCell<'a>) (mirror: ClCell) -> + <@ fun (ndRange: Range1D) inputArrayLength verticesLength (inputArray: ClArray<'a>) (verticesBuffer: ClArray<'a>) (totalSumBuffer: ClCell<'a>) (zero: ClCell<'a>) (mirror: ClCell) -> let mirror = mirror.Value @@ -62,46 +62,34 @@ module PrefixSum = if mirror then i <- inputArrayLength - 1 - i - let localID = ndRange.LocalID0 + let lid = ndRange.LocalID0 let zero = zero.Value if gid < inputArrayLength then - resultLocalBuffer.[localID] <- resultBuffer.[i] + resultLocalBuffer.[lid] <- inputArray.[i] else - resultLocalBuffer.[localID] <- zero + resultLocalBuffer.[lid] <- zero barrierLocal () - (%subSum) workGroupSize localID resultLocalBuffer - - if localID = workGroupSize - 1 then - if verticesLength <= 1 && localID = gid then - totalSumBuffer.Value <- resultLocalBuffer.[localID] - - verticesBuffer.[gid / workGroupSize] <- resultLocalBuffer.[localID] - (%beforeLocalSumClear) resultBuffer resultLocalBuffer.[localID] inputArrayLength gid i - resultLocalBuffer.[localID] <- zero + // Local tree reduce + (%SubSum.upSweep opAdd) workGroupSize lid resultLocalBuffer - let mutable step = workGroupSize + if lid = workGroupSize - 1 then + // if last iteration + if verticesLength <= 1 && lid = gid then + totalSumBuffer.Value <- resultLocalBuffer.[lid] - while step > 1 do - barrierLocal () + verticesBuffer.[gid / workGroupSize] <- resultLocalBuffer.[lid] + (%beforeLocalSumClear) inputArray resultLocalBuffer.[lid] inputArrayLength gid i + resultLocalBuffer.[lid] <- zero - if localID < workGroupSize / step then - let i = step * (localID + 1) - 1 - let j = i - (step >>> 1) - - let tmp = resultLocalBuffer.[i] - let buff = (%opAdd) tmp resultLocalBuffer.[j] - resultLocalBuffer.[i] <- buff - resultLocalBuffer.[j] <- tmp - - step <- step >>> 1 + (%SubSum.downSweep opAdd) workGroupSize lid resultLocalBuffer barrierLocal () - (%writeData) resultBuffer resultLocalBuffer inputArrayLength workGroupSize gid i localID @> + (%writeData) inputArray resultLocalBuffer inputArrayLength workGroupSize gid i lid @> let program = clContext.Compile(scan) @@ -132,13 +120,14 @@ module PrefixSum = ) processor.Post(Msg.CreateRunMsg<_, _> kernel) - processor.Post(Msg.CreateFreeMsg(zero)) - processor.Post(Msg.CreateFreeMsg(mirror)) + + zero.Free processor + mirror.Free processor let private scanExclusive<'a when 'a: struct> = scanGeneral <@ fun (_: ClArray<'a>) (_: 'a) (_: int) (_: int) (_: int) -> () @> - <@ fun (resultBuffer: ClArray<'a>) (resultLocalBuffer: 'a []) (inputArrayLength: int) (smth: int) (gid: int) (i: int) (localID: int) -> + <@ fun (resultBuffer: ClArray<'a>) (resultLocalBuffer: 'a []) (inputArrayLength: int) (_: int) (gid: int) (i: int) (localID: int) -> if gid < inputArrayLength then resultBuffer.[i] <- resultLocalBuffer.[localID] @> @@ -206,8 +195,8 @@ module PrefixSum = verticesArrays <- swap verticesArrays verticesLength <- (verticesLength - 1) / workGroupSize + 1 - processor.Post(Msg.CreateFreeMsg(firstVertices)) - processor.Post(Msg.CreateFreeMsg(secondVertices)) + firstVertices.Free processor + secondVertices.Free processor totalSum @@ -226,7 +215,7 @@ module PrefixSum = /// /// let arr = [| 1; 1; 1; 1 |] /// let sum = [| 0 |] - /// runExcludeInplace clContext workGroupSize processor arr sum <@ (+) @> 0 + /// runExcludeInplace clContext workGroupSize processor arr sum (+) 0 /// |> ignore /// ... /// > val arr = [| 0; 1; 2; 3 |] @@ -252,7 +241,7 @@ module PrefixSum = /// /// let arr = [| 1; 1; 1; 1 |] /// let sum = [| 0 |] - /// runExcludeInplace clContext workGroupSize processor arr sum <@ (+) @> 0 + /// runExcludeInplace clContext workGroupSize processor arr sum (+) 0 /// |> ignore /// ... /// > val arr = [| 1; 2; 3; 4 |] @@ -270,3 +259,73 @@ module PrefixSum = fun (processor: MailboxProcessor<_>) (inputArray: ClArray) -> scan processor inputArray 0 + + module ByKey = + let private sequentialSegments opWrite (clContext: ClContext) workGroupSize opAdd zero = + + let kernel = + <@ fun (ndRange: Range1D) lenght uniqueKeysCount (values: ClArray<'a>) (keys: ClArray) (offsets: ClArray) -> + let gid = ndRange.GlobalID0 + + if gid < uniqueKeysCount then + let sourcePosition = offsets.[gid] + let sourceKey = keys.[sourcePosition] + + let mutable currentSum = zero + let mutable previousSum = zero + + let mutable currentPosition = sourcePosition + + while currentPosition < lenght + && keys.[currentPosition] = sourceKey do + + previousSum <- currentSum + currentSum <- (%opAdd) currentSum values.[currentPosition] + + values.[currentPosition] <- (%opWrite) previousSum currentSum + + currentPosition <- currentPosition + 1 @> + + let kernel = clContext.Compile kernel + + fun (processor: MailboxProcessor<_>) uniqueKeysCount (values: ClArray<'a>) (keys: ClArray) (offsets: ClArray) -> + + let kernel = kernel.GetKernel() + + let ndRange = + Range1D.CreateValid(values.Length, workGroupSize) + + processor.Post( + Msg.MsgSetArguments + (fun () -> kernel.KernelFunc ndRange values.Length uniqueKeysCount values keys offsets) + ) + + processor.Post(Msg.CreateRunMsg<_, _> kernel) + + /// + /// Exclude scan by key. + /// + /// + /// + /// let arr = [| 1; 1; 1; 1; 1; 1|] + /// let keys = [| 1; 2; 2; 2; 3; 3 |] + /// ... + /// > val result = [| 0; 0; 1; 2; 0; 1 |] + /// + /// + let sequentialExclude clContext = + sequentialSegments (Map.fst ()) clContext + + /// + /// Include scan by key. + /// + /// + /// + /// let arr = [| 1; 1; 1; 1; 1; 1|] + /// let keys = [| 1; 2; 2; 2; 3; 3 |] + /// ... + /// > val result = [| 1; 1; 2; 3; 1; 2 |] + /// + /// + let sequentialInclude clContext = + sequentialSegments (Map.snd ()) clContext diff --git a/src/GraphBLAS-sharp.Backend/Quotes/Map.fs b/src/GraphBLAS-sharp.Backend/Quotes/Map.fs index 2ec988d5..58ad1026 100644 --- a/src/GraphBLAS-sharp.Backend/Quotes/Map.fs +++ b/src/GraphBLAS-sharp.Backend/Quotes/Map.fs @@ -21,3 +21,7 @@ module Map = match (%map) item with | Some _ -> 1 | None -> 0 @> + + let fst () = <@ fun fst _ -> fst @> + + let snd () = <@ fun _ snd -> snd @> diff --git a/src/GraphBLAS-sharp.Backend/Quotes/SubSum.fs b/src/GraphBLAS-sharp.Backend/Quotes/SubSum.fs index 3aa5c894..b16d4ebc 100644 --- a/src/GraphBLAS-sharp.Backend/Quotes/SubSum.fs +++ b/src/GraphBLAS-sharp.Backend/Quotes/SubSum.fs @@ -31,10 +31,30 @@ module SubSum = barrierLocal () @> - let sequentialSum<'a> opAdd = - sumGeneral<'a> <| sequentialAccess<'a> opAdd + let sequentialSum<'a> = sumGeneral<'a> << sequentialAccess<'a> - let treeSum<'a> opAdd = sumGeneral<'a> <| treeAccess<'a> opAdd + let upSweep<'a> = sumGeneral<'a> << treeAccess<'a> + + let downSweep opAdd = + <@ fun wgSize lid (localBuffer: 'a []) -> + let mutable step = wgSize + + while step > 1 do + barrierLocal () + + if lid < wgSize / step then + let i = step * (lid + 1) - 1 + let j = i - (step >>> 1) + + let tmp = localBuffer.[i] + + let operand = localBuffer.[j] // brahma error + let buff = (%opAdd) tmp operand + + localBuffer.[i] <- buff + localBuffer.[j] <- tmp + + step <- step >>> 1 @> let localPrefixSum opAdd = <@ fun (lid: int) (workGroupSize: int) (array: 'a []) -> @@ -52,4 +72,6 @@ module SubSum = barrierLocal () array.[lid] <- value @> + + let localIntPrefixSum = localPrefixSum <@ (+) @> diff --git a/tests/GraphBLAS-sharp.Tests/Common/Scan/ByKey.fs b/tests/GraphBLAS-sharp.Tests/Common/Scan/ByKey.fs new file mode 100644 index 00000000..1cb81709 --- /dev/null +++ b/tests/GraphBLAS-sharp.Tests/Common/Scan/ByKey.fs @@ -0,0 +1,111 @@ +module GraphBLAS.FSharp.Tests.Backend.Common.Scan.ByKey + +open GraphBLAS.FSharp.Backend.Common +open GraphBLAS.FSharp.Backend.Objects.ClContext +open Expecto +open GraphBLAS.FSharp.Tests +open GraphBLAS.FSharp.Backend.Objects.ArraysExtensions + +let context = Context.defaultContext.ClContext + +let processor = Context.defaultContext.Queue + +let checkResult isEqual keysAndValues actual hostScan = + + let expected = + HostPrimitives.scanByKey hostScan keysAndValues + + "Results must be the same" + |> Utils.compareArrays isEqual actual expected + +let makeTestSequentialSegments isEqual scanHost scanDevice (keysAndValues: (int * 'a) []) = + if keysAndValues.Length > 0 then + let keys, values = + Array.sortBy fst keysAndValues |> Array.unzip + + let offsets = + HostPrimitives.getUniqueBitmapFirstOccurrence keys + |> HostPrimitives.getBitPositions + + let uniqueKeysCount = Array.distinct keys |> Array.length + + let clKeys = + context.CreateClArrayWithSpecificAllocationMode(HostInterop, keys) + + let clValues = + context.CreateClArrayWithSpecificAllocationMode(HostInterop, values) + + let clOffsets = + context.CreateClArrayWithSpecificAllocationMode(HostInterop, offsets) + + scanDevice processor uniqueKeysCount clValues clKeys clOffsets + + let actual = clValues.ToHostAndFree processor + clKeys.Free processor + clOffsets.Free processor + + let keysAndValues = Array.zip keys values + + checkResult isEqual keysAndValues actual scanHost + +let createTest (zero: 'a) opAddQ opAdd isEqual deviceScan hostScan = + + let hostScan = hostScan zero opAdd + + let deviceScan = + deviceScan context Utils.defaultWorkGroupSize opAddQ zero + + makeTestSequentialSegments isEqual hostScan deviceScan + |> testPropertyWithConfig Utils.defaultConfig $"test on {typeof<'a>}" + +let sequentialSegmentsTests = + let excludeTests = + [ createTest 0 <@ (+) @> (+) (=) PrefixSum.ByKey.sequentialExclude HostPrimitives.prefixSumExclude + + if Utils.isFloat64Available context.ClDevice then + createTest + 0.0 + <@ (+) @> + (+) + Utils.floatIsEqual + PrefixSum.ByKey.sequentialExclude + HostPrimitives.prefixSumExclude + + createTest + 0.0f + <@ (+) @> + (+) + Utils.float32IsEqual + PrefixSum.ByKey.sequentialExclude + HostPrimitives.prefixSumExclude + + createTest false <@ (||) @> (||) (=) PrefixSum.ByKey.sequentialExclude HostPrimitives.prefixSumExclude + createTest 0u <@ (+) @> (+) (=) PrefixSum.ByKey.sequentialExclude HostPrimitives.prefixSumExclude ] + |> testList "exclude" + + let includeTests = + [ createTest 0 <@ (+) @> (+) (=) PrefixSum.ByKey.sequentialInclude HostPrimitives.prefixSumInclude + + if Utils.isFloat64Available context.ClDevice then + createTest + 0.0 + <@ (+) @> + (+) + Utils.floatIsEqual + PrefixSum.ByKey.sequentialInclude + HostPrimitives.prefixSumInclude + + createTest + 0.0f + <@ (+) @> + (+) + Utils.float32IsEqual + PrefixSum.ByKey.sequentialInclude + HostPrimitives.prefixSumInclude + + createTest false <@ (||) @> (||) (=) PrefixSum.ByKey.sequentialInclude HostPrimitives.prefixSumInclude + createTest 0u <@ (+) @> (+) (=) PrefixSum.ByKey.sequentialInclude HostPrimitives.prefixSumInclude ] + + |> testList "include" + + testList "Sequential segments" [ excludeTests; includeTests ] diff --git a/tests/GraphBLAS-sharp.Tests/Common/ClArray/PrefixSum.fs b/tests/GraphBLAS-sharp.Tests/Common/Scan/PrefixSum.fs similarity index 95% rename from tests/GraphBLAS-sharp.Tests/Common/ClArray/PrefixSum.fs rename to tests/GraphBLAS-sharp.Tests/Common/Scan/PrefixSum.fs index 3c8bedee..c8ce588a 100644 --- a/tests/GraphBLAS-sharp.Tests/Common/ClArray/PrefixSum.fs +++ b/tests/GraphBLAS-sharp.Tests/Common/Scan/PrefixSum.fs @@ -1,4 +1,4 @@ -module GraphBLAS.FSharp.Tests.Backend.Common.ClArray.PrefixSum +module GraphBLAS.FSharp.Tests.Backend.Common.Scan.PrefixSum open Expecto open Expecto.Logging @@ -62,7 +62,7 @@ let makeTest plus zero isEqual scan (array: 'a []) = let testFixtures plus plusQ zero isEqual name = PrefixSum.runIncludeInplace plusQ context wgSize |> makeTest plus zero isEqual - |> testPropertyWithConfig config (sprintf "Correctness on %s" name) + |> testPropertyWithConfig config $"Correctness on %s{name}" let tests = q.Error.Add(fun e -> failwithf "%A" e) diff --git a/tests/GraphBLAS-sharp.Tests/GraphBLAS-sharp.Tests.fsproj b/tests/GraphBLAS-sharp.Tests/GraphBLAS-sharp.Tests.fsproj index b67154c2..234c76a1 100644 --- a/tests/GraphBLAS-sharp.Tests/GraphBLAS-sharp.Tests.fsproj +++ b/tests/GraphBLAS-sharp.Tests/GraphBLAS-sharp.Tests.fsproj @@ -24,12 +24,13 @@ - + + diff --git a/tests/GraphBLAS-sharp.Tests/Helpers.fs b/tests/GraphBLAS-sharp.Tests/Helpers.fs index d29dfe3e..c45a2674 100644 --- a/tests/GraphBLAS-sharp.Tests/Helpers.fs +++ b/tests/GraphBLAS-sharp.Tests/Helpers.fs @@ -141,13 +141,13 @@ module Utils = result module HostPrimitives = - let prefixSumInclude array = - Array.scan (+) 0 array - |> fun scanned -> scanned.[1..] + let prefixSumInclude zero add array = + Array.scan add zero array + |> fun scanned -> scanned.[1..], Array.last scanned - let prefixSumExclude sourceArray = - prefixSumInclude sourceArray - |> Array.insertAt 0 0 + let prefixSumExclude zero add sourceArray = + prefixSumInclude zero add sourceArray + |> (fst >> Array.insertAt 0 zero) |> fun array -> Array.take sourceArray.Length array, Array.last array let getUniqueBitmapLastOccurrence array = @@ -177,19 +177,20 @@ module HostPrimitives = |> Array.choose id let reduceByKey keys value reduceOp = - let zipped = Array.zip keys value - - Array.distinct keys + Array.zip keys value + |> Array.groupBy fst |> Array.map - (fun key -> - // extract elements corresponding to key - (key, - Array.map snd - <| Array.filter ((=) key << fst) zipped)) - // reduce elements - |> Array.map (fun (key, values) -> key, Array.reduce reduceOp values) + (fun (key, array) -> + Array.map snd array + |> Array.reduce reduceOp + |> fun value -> key, value) |> Array.unzip + let scanByKey scan keysAndValues = + Array.groupBy fst keysAndValues + |> Array.map (fun (_, array) -> Array.map snd array |> scan |> fst) + |> Array.concat + module Context = type TestContext = { ClContext: ClContext diff --git a/tests/GraphBLAS-sharp.Tests/Program.fs b/tests/GraphBLAS-sharp.Tests/Program.fs index b46c375c..8532df05 100644 --- a/tests/GraphBLAS-sharp.Tests/Program.fs +++ b/tests/GraphBLAS-sharp.Tests/Program.fs @@ -17,6 +17,12 @@ let matrixTests = |> testSequenced let commonTests = + let scanTests = + testList + "Scan" + [ Common.Scan.ByKey.sequentialSegmentsTests + Common.Scan.PrefixSum.tests ] + let reduceTests = testList "Reduce" @@ -29,8 +35,7 @@ let commonTests = let clArrayTests = testList "ClArray" - [ Common.ClArray.PrefixSum.tests - Common.ClArray.RemoveDuplicates.tests + [ Common.ClArray.RemoveDuplicates.tests Common.ClArray.Copy.tests Common.ClArray.Replicate.tests Common.ClArray.Exists.tests @@ -51,6 +56,7 @@ let commonTests = [ clArrayTests sortTests reduceTests + scanTests Common.Scatter.tests ] |> testSequenced