Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/GraphBLAS-sharp.Backend/Matrices.fs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,11 @@ type Matrix<'a when 'a: struct> =
| MatrixCSR matrix -> matrix.ColumnCount
| MatrixCOO matrix -> matrix.ColumnCount

member this.Dispose() =
match this with
| MatrixCSR matrix -> (matrix :> IDeviceMemObject).Dispose()
| MatrixCOO matrix -> (matrix :> IDeviceMemObject).Dispose()

and CSRMatrix<'elem when 'elem: struct> =
{ Context: ClContext
RowCount: int
Expand Down
13 changes: 10 additions & 3 deletions src/GraphBLAS-sharp.Backend/Matrix/CSRMatrix/CSRMatrix.fs
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,19 @@ open Microsoft.FSharp.Quotations
module CSRMatrix =
let private expandRows (clContext: ClContext) =
let expandRows =
<@ fun (range: Range1D) workGroupSize (rowPointers: ClArray<int>) (rowIndices: ClArray<int>) ->
<@ fun (range: Range1D) workGroupSize (rowPointers: ClArray<int>) (rowIndices: ClArray<int>) rowCount nnz ->

let lid = range.LocalID0
let groupId = range.GlobalID0 / workGroupSize

let rowStart = rowPointers.[groupId]
let rowEnd = rowPointers.[groupId + 1]

let rowEnd =
if groupId <> rowCount - 1 then
rowPointers.[groupId + 1]
else
nnz

let rowLength = rowEnd - rowStart

let mutable i = lid
Expand All @@ -36,7 +42,8 @@ module CSRMatrix =
)

processor.Post(
Msg.MsgSetArguments(fun () -> kernel.SetArguments ndRange workGroupSize rowPointers rowIndices)
Msg.MsgSetArguments
(fun () -> kernel.SetArguments ndRange workGroupSize rowPointers rowIndices rowCount nnz)
)

processor.Post(Msg.CreateRunMsg<_, _> kernel)
Expand Down
40 changes: 36 additions & 4 deletions src/GraphBLAS-sharp.Backend/Matrix/Matrix.fs
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,53 @@ open Brahma.FSharp.OpenCL
open Microsoft.FSharp.Quotations

module Matrix =
let copy (clContext: ClContext) =
let copy =
GraphBLAS.FSharp.Backend.ClArray.copy clContext

let copyData =
GraphBLAS.FSharp.Backend.ClArray.copy clContext

fun (processor: MailboxProcessor<_>) workGroupSize (matrix: Matrix<'a>) ->
match matrix with
| MatrixCOO m ->
let res =
{ Context = clContext
RowCount = m.RowCount
ColumnCount = m.ColumnCount
Rows = copy processor workGroupSize m.Rows
Columns = copy processor workGroupSize m.Columns
Values = copyData processor workGroupSize m.Values }

MatrixCOO res
| MatrixCSR m ->
let res =
{ Context = clContext
RowCount = m.RowCount
ColumnCount = m.ColumnCount
RowPointers = copy processor workGroupSize m.RowPointers
Columns = copy processor workGroupSize m.Columns
Values = copyData processor workGroupSize m.Values }

MatrixCSR res

let toCSR (clContext: ClContext) workGroupSize =
let toCSR = COOMatrix.toCSR clContext workGroupSize
let copy = copy clContext

fun (processor: MailboxProcessor<_>) (matrix: Matrix<'a>) ->
match matrix with
| MatrixCOO m -> toCSR processor m |> MatrixCSR
| MatrixCSR _ -> matrix
| MatrixCSR _ -> copy processor workGroupSize matrix

let toCOO (clContext: ClContext) workGroupSize =
let toCOO = CSRMatrix.toCOO clContext
let toCOO = CSRMatrix.toCOO clContext workGroupSize
let copy = copy clContext

fun (processor: MailboxProcessor<_>) (matrix: Matrix<'a>) ->
match matrix with
| MatrixCOO _ -> matrix
| MatrixCSR m -> toCOO workGroupSize processor m |> MatrixCOO
| MatrixCOO _ -> copy processor workGroupSize matrix
| MatrixCSR m -> toCOO processor m |> MatrixCOO

let eWiseAdd (clContext: ClContext) (opAdd: Expr<'a -> 'a -> 'a>) workGroupSize =
let COOeWiseAdd =
Expand Down
86 changes: 86 additions & 0 deletions src/GraphBLAS-sharp/Objects/Matrix.fs
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
namespace GraphBLAS.FSharp

open Brahma.FSharp.OpenCL
open GraphBLAS.FSharp.Backend

type MatrixFromat =
| CSR
| COO
Expand All @@ -18,6 +21,89 @@ type Matrix<'a when 'a: struct> =
| MatrixCSR matrix -> matrix.ColumnCount
| MatrixCOO matrix -> matrix.ColumnCount

member this.NNZCount =
match this with
| MatrixCOO m -> m.Values.Length
| MatrixCSR m -> m.Values.Length

member this.ToBackend(context: ClContext) =
match this with
| MatrixCOO m ->
let rows = context.CreateClArray m.Rows
let columns = context.CreateClArray m.Columns
let values = context.CreateClArray m.Values

let result =
{ Backend.COOMatrix.Context = context
RowCount = m.RowCount
ColumnCount = m.ColumnCount
Rows = rows
Columns = columns
Values = values }

Backend.MatrixCOO result
| MatrixCSR m ->
let rows = context.CreateClArray m.RowPointers
let columns = context.CreateClArray m.ColumnIndices
let values = context.CreateClArray m.Values

let result =
{ Backend.CSRMatrix.Context = context
RowCount = m.RowCount
ColumnCount = m.ColumnCount
RowPointers = rows
Columns = columns
Values = values }

Backend.MatrixCSR result

static member FromBackend (q: MailboxProcessor<_>) matrix =
match matrix with
| Backend.MatrixCOO m ->
let rows = Array.zeroCreate m.Rows.Length
let columns = Array.zeroCreate m.Columns.Length
let values = Array.zeroCreate m.Values.Length

let _ =
q.Post(Msg.CreateToHostMsg(m.Rows, rows))

let _ =
q.Post(Msg.CreateToHostMsg(m.Columns, columns))

let _ =
q.PostAndReply(fun ch -> Msg.CreateToHostMsg(m.Values, values, ch))

let result =
{ RowCount = m.RowCount
ColumnCount = m.ColumnCount
Rows = rows
Columns = columns
Values = values }

MatrixCOO result
| Backend.MatrixCSR m ->
let rows = Array.zeroCreate m.RowPointers.Length
let columns = Array.zeroCreate m.Columns.Length
let values = Array.zeroCreate m.Values.Length

let _ =
q.Post(Msg.CreateToHostMsg(m.RowPointers, rows))

let _ =
q.Post(Msg.CreateToHostMsg(m.Columns, columns))

let _ =
q.PostAndReply(fun ch -> Msg.CreateToHostMsg(m.Values, values, ch))

let result =
{ RowCount = m.RowCount
ColumnCount = m.ColumnCount
RowPointers = rows
ColumnIndices = columns
Values = values }

MatrixCSR result

and CSRMatrix<'a> =
{ RowCount: int
ColumnCount: int
Expand Down
Loading