diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 6c8d74ff..25050827 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -13,9 +13,9 @@ jobs: fail-fast: false matrix: ml-deps: - - "torch==1.9.1+cpu tensorflow-cpu==2.5.3" - "torch==1.10.2+cpu tensorflow-cpu==2.6.3" - "torch==1.11.0+cpu tensorflow-cpu==2.7.1" + - "torch==1.12.0+cpu tensorflow-cpu==2.8.1" env: run_coverage: ${{ github.ref == 'refs/heads/master' }} diff --git a/setup.py b/setup.py index ac570cc4..4318fe0e 100644 --- a/setup.py +++ b/setup.py @@ -1,7 +1,7 @@ import setuptools -tensorflow = ["tensorflow>=2.5"] -pytorch = ["torch>=1.9"] +tensorflow = ["tensorflow>=2.6"] +pytorch = ["torch>=1.10"] sklearn = ["scikit-learn>=1.0"] cloud = ["tiledb-cloud"] full = sorted({"torchvision", *tensorflow, *pytorch, *sklearn, *cloud}) diff --git a/tiledb/ml/readers/pytorch.py b/tiledb/ml/readers/pytorch.py index b50eda73..f39f183b 100644 --- a/tiledb/ml/readers/pytorch.py +++ b/tiledb/ml/readers/pytorch.py @@ -11,13 +11,6 @@ import torch from torch.utils.data import DataLoader, IterableDataset, get_worker_info -try: - # torch>=1.10 - sparse_csr_tensor = torch.sparse_csr_tensor -except AttributeError: - # torch=1.9 - sparse_csr_tensor = torch._sparse_csr_tensor - import tiledb from ._tensor_schema import DenseTensorSchema, SparseTensorSchema, TensorSchema @@ -157,7 +150,7 @@ def _csr_to_coo_collate(arrays: Sequence[scipy.sparse.csr_matrix]) -> torch.Tens def _csr_collate(arrays: Sequence[scipy.sparse.csr_matrix]) -> torch.Tensor: """Collate multiple Scipy CSR matrices to a torch.Tensor with sparse_csr layout.""" stacked = scipy.sparse.vstack(arrays) - return sparse_csr_tensor( + return torch.sparse_csr_tensor( torch.from_numpy(stacked.indptr), torch.from_numpy(stacked.indices), stacked.data,