Skip to content

Commit

Permalink
Add support for PyData/Sparse
Browse files Browse the repository at this point in the history
  • Loading branch information
alugowski committed Aug 31, 2023
1 parent 62a933a commit 0678df1
Show file tree
Hide file tree
Showing 7 changed files with 262 additions and 0 deletions.
5 changes: 5 additions & 0 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,11 @@ jobs:
pip install suitesparse-graphblas==7.4.4.1a1
pip install python-graphblas
- name: Install pydata sparse
if: ${{ !contains(matrix.python-version, 'pypy') && matrix.python-version != '3.7' }} # no wheels for pypy and old python
run: |
pip install sparse
- name: Test without Jupyter
if: ${{ !contains(matrix.python-version, 'pypy') }} # no scipy wheels for pypy
run: pytest
Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ Sparse matrix spy plot and sparkline renderer. Supports:
* **SciPy** - sparse matrices and arrays like `csr_matrix` and `coo_array` [(demo)](demo.ipynb)
* **NumPy** - `ndarray` [(demo)](demo-numpy.ipynb)
* **[Python-graphblas](https://github.com/python-graphblas/python-graphblas)** - `gb.Matrix` [(demo)](demo-python-graphblas.ipynb)
* **[PyData/Sparse](https://sparse.pydata.org/)** - `COO`, `DOK`, `GCXS` [(demo)](demo-pydata-sparse.ipynb)

Features:
* Simple `spy()` method, similar to MATLAB's `spy`.
Expand Down
111 changes: 111 additions & 0 deletions demo-pydata-sparse.ipynb

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions matspy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,9 @@ def _register_bundled():
from .adapters.graphblas_driver import GraphBLASDriver
register_driver(GraphBLASDriver)

from .adapters.sparse_driver import PyDataSparseDriver
register_driver(PyDataSparseDriver)


_register_bundled()

Expand Down
18 changes: 18 additions & 0 deletions matspy/adapters/sparse_driver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Copyright (C) 2023 Adam Lugowski.
# Use of this source code is governed by the BSD 2-clause license found in the LICENSE.txt file.
# SPDX-License-Identifier: BSD-2-Clause

from typing import Any, Iterable

from . import Driver, MatrixSpyAdapter


class PyDataSparseDriver(Driver):
    """Driver that hooks PyData/Sparse arrays into matspy."""

    @staticmethod
    def get_supported_type_prefixes() -> Iterable[str]:
        """Return the type prefixes this driver declares support for.

        NOTE(review): presumably these are matched against the fully-qualified
        type name of the object being plotted -- confirm in the driver registry.
        """
        prefixes = ["sparse."]
        return prefixes

    @staticmethod
    def adapt_spy(mat: Any) -> MatrixSpyAdapter:
        """Wrap ``mat`` in a ``PyDataSparseSpy`` adapter."""
        # sparse_impl imports the `sparse` package at module load, so it is
        # imported here rather than at the top of the file: `sparse` is then
        # only required once an adapter is actually requested.
        from . import sparse_impl

        return sparse_impl.PyDataSparseSpy(mat)
66 changes: 66 additions & 0 deletions matspy/adapters/sparse_impl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# Copyright (C) 2023 Adam Lugowski.
# Use of this source code is governed by the BSD 2-clause license found in the LICENSE.txt file.
# SPDX-License-Identifier: BSD-2-Clause

from typing import Tuple

import numpy as np
import sparse

from . import describe, generate_spy_triple_product, MatrixSpyAdapter


def generate_spy_triple_product_sparse(matrix_shape, spy_shape) -> Tuple[sparse.SparseArray, sparse.SparseArray]:
    """Build the left and right factors of the spy triple product as ``sparse.COO`` arrays.

    ``generate_spy_triple_product`` describes each factor as
    ``(shape, (rows, cols))``; this materializes each description as a COO
    array holding a one at every listed coordinate.

    :param matrix_shape: shape of the matrix that will be scaled
    :param spy_shape: shape of the desired spy result
    :return: ``(left, right)`` such that ``left @ matrix @ right`` scales the matrix
    """
    def _ones_coo(factor):
        # One stored value (1.0) per (row, col) coordinate pair.
        shape, (rows, cols) = factor
        return sparse.COO(coords=(rows, cols), data=np.ones(len(rows)), shape=shape)

    left_factor, right_factor = generate_spy_triple_product(matrix_shape, spy_shape)
    return _ones_coo(left_factor), _ones_coo(right_factor)


class PyDataSparseSpy(MatrixSpyAdapter):
    """Spy adapter for PyData/Sparse arrays."""

    def __init__(self, mat):
        """Store the array to be plotted.

        :param mat: a PyData/Sparse array
        """
        super().__init__()
        self.mat = mat

    def get_shape(self) -> tuple:
        """Return the shape of the wrapped array."""
        return self.mat.shape

    def describe(self) -> str:
        """Return a description built from the array's shape, nnz, dtype and format."""
        parts = [
            self.mat.format,
        ]

        return describe(shape=self.mat.shape,
                        nnz=self.mat.nnz, nz_type=self.mat.dtype,
                        notes=", ".join(parts))

    def get_spy(self, spy_shape: tuple) -> np.ndarray:
        """Compute the dense spy array of shape-reducing product ``left @ mat @ right``.

        The stored values are temporarily replaced with ones, so the result
        depends only on where entries are stored and not on their values.
        NOTE(review): presumably each output cell counts the stored entries that
        fall in its bucket -- confirm against ``generate_spy_triple_product``.

        :param spy_shape: shape of the desired spy result
        :return: the triple product as a dense NumPy array
        """
        # DOK is converted to COO before the `.data` swap below.
        # NOTE(review): presumably because DOK does not expose its values as an
        # array -- confirm. The adapter keeps the converted copy.
        if isinstance(self.mat, sparse.DOK):
            self.mat = self.mat.asformat("coo")

        # construct a triple product that will scale the matrix
        left, right = generate_spy_triple_product_sparse(self.mat.shape, spy_shape)

        # save existing matrix data
        mat_data_save = self.mat.data

        # replace with all ones
        self.mat.data = np.ones(mat_data_save.shape)

        try:
            # triple product
            try:
                spy = left @ self.mat @ right
            except ValueError:
                # broken matmul on some types
                temp = self.mat.asformat("coo")
                spy = left @ temp @ right
        finally:
            # Always restore the original values: self.mat may be the caller's
            # own array, and it must not be left filled with ones if the
            # product (or the fallback) raises.
            self.mat.data = mat_data_save

        return np.array(spy.todense())
58 changes: 58 additions & 0 deletions tests/test_sparse.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# Copyright (C) 2023 Adam Lugowski.
# Use of this source code is governed by the BSD 2-clause license found in the LICENSE.txt file.
# SPDX-License-Identifier: BSD-2-Clause

import unittest

try:
import sparse
except ImportError:
sparse = None

import numpy as np
import scipy.sparse

from matspy import spy_to_mpl, to_sparkline, to_spy_heatmap

# Seed NumPy's global RNG so the scipy.sparse.random() matrices built in setUp
# are reproducible between runs (no explicit random_state is passed there).
np.random.seed(123)


@unittest.skipIf(sparse is None, "pydata/sparse not installed")
class PyDataSparseTests(unittest.TestCase):
    """Tests for plotting PyData/Sparse arrays; skipped when `sparse` is not installed."""

    def setUp(self):
        """Build the fixture matrices: square, wide, single-column and empty."""
        scipy_mats = [
            scipy.sparse.random(10, 10, density=0.4),
            scipy.sparse.random(5, 10, density=0.4),
            scipy.sparse.random(5, 1, density=0.4),
            scipy.sparse.coo_matrix(([], ([], [])), shape=(10, 10)),
        ]
        self.mats = [sparse.COO.from_scipy_sparse(m) for m in scipy_mats]

    def test_no_crash(self):
        """Every fixture, in every format, renders to a figure and to a sparkline."""
        import matplotlib.pyplot as plt

        formats = ("coo", "gcxs", "dok", "csr", "csc")
        for fmt in formats:
            for original in self.mats:
                converted = original.asformat(fmt)

                figure, _axes = spy_to_mpl(converted)
                plt.close(figure)

                sparkline = to_sparkline(converted)
                self.assertGreater(len(sparkline), 10)

    def test_count(self):
        """A single-bucket heatmap holds the fraction of cells that are nonzero."""
        dense_cases = [
            (0, [[0]]),
            (1, [[1]]),
            (0, [[0, 0], [0, 0]]),
            (1, [[1, 0], [0, 0]]),
        ]
        arrs = [(nonzeros, sparse.COO(np.array(dense))) for nonzeros, dense in dense_cases]

        for nonzeros, arr in arrs:
            expected_fraction = nonzeros / np.prod(arr.shape)
            heatmap = to_spy_heatmap(arr, buckets=1, shading="absolute")
            self.assertEqual(len(heatmap), 1)
            self.assertAlmostEqual(expected_fraction, heatmap[0][0], places=2)


# Allow running this test file directly as a script.
if __name__ == "__main__":
    unittest.main()

0 comments on commit 0678df1

Please sign in to comment.