Skip to content

Commit

Permalink
Replace deprecated KeyArray with Array (#135)
Browse files Browse the repository at this point in the history
JAX 0.4.24 removed `random.KeyArray`. According to the deprecation
warning, it can be replaced with `jax.Array`:
```
DeprecationWarning: jax.random.KeyArray is deprecated. Use jax.Array for annotations, and jax.dtypes.issubdtype(arr.dtype, jax.dtypes.prng_key) for runtime detection of typed prng keys (i.e. keys created with jax.random.key).
```
  • Loading branch information
martinkim0 committed Feb 8, 2024
1 parent a5b7c1e commit 6cf394e
Show file tree
Hide file tree
Showing 5 changed files with 13 additions and 11 deletions.
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ xfail_strict = true
[tool.ruff]
src = ["src"]
line-length = 120
select = [
lint.select = [
"F", # Errors detected by Pyflakes
"E", # Error detected by Pycodestyle
"W", # Warning detected by Pycodestyle
Expand All @@ -107,7 +107,7 @@ select = [
"UP", # pyupgrade
"RUF100", # Report unused noqa directives
]
ignore = [
lint.ignore = [
# line too long -> we accept long comment lines; formatter gets rid of long code lines
"E501",
# Do not assign a lambda expression, use a def -> lambda expression assignments are convenient
Expand Down
4 changes: 2 additions & 2 deletions src/scib_metrics/_types.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from typing import Union

import jax
import jax.numpy as jnp
import numpy as np
import scipy.sparse as sp
from jax import Array

NdArray = Union[np.ndarray, jnp.ndarray]
IntOrKey = Union[int, jax.random.KeyArray]
IntOrKey = Union[int, Array]
ArrayLike = Union[np.ndarray, sp.spmatrix, jnp.ndarray]
7 changes: 4 additions & 3 deletions src/scib_metrics/utils/_kmeans.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import jax
import jax.numpy as jnp
import numpy as np
from jax import Array
from sklearn.utils import check_array

from scib_metrics._types import IntOrKey
Expand All @@ -18,7 +19,7 @@ def _tolerance(X: jnp.ndarray, tol: float) -> float:
return np.mean(variances) * tol


def _initialize_random(X: jnp.ndarray, n_clusters: int, key: jax.random.KeyArray) -> jnp.ndarray:
def _initialize_random(X: jnp.ndarray, n_clusters: int, key: Array) -> jnp.ndarray:
"""Initialize cluster centroids randomly."""
n_obs = X.shape[0]
key, subkey = jax.random.split(key)
Expand All @@ -28,7 +29,7 @@ def _initialize_random(X: jnp.ndarray, n_clusters: int, key: jax.random.KeyArray


@partial(jax.jit, static_argnums=1)
def _initialize_plus_plus(X: jnp.ndarray, n_clusters: int, key: jax.random.KeyArray) -> jnp.ndarray:
def _initialize_plus_plus(X: jnp.ndarray, n_clusters: int, key: Array) -> jnp.ndarray:
"""Initialize cluster centroids with k-means++ algorithm."""
n_obs = X.shape[0]
key, subkey = jax.random.split(key)
Expand Down Expand Up @@ -111,7 +112,7 @@ def __init__(
self.n_init = n_init
self.max_iter = max_iter
self.tol_scale = tol
self.seed: jax.random.KeyArray = validate_seed(seed)
self.seed: jax.Array = validate_seed(seed)

if init not in ["k-means++", "random"]:
raise ValueError("Invalid init method, must be one of ['k-means++' or 'random'].")
Expand Down
4 changes: 2 additions & 2 deletions src/scib_metrics/utils/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import jax.numpy as jnp
import numpy as np
from chex import ArrayDevice
from jax import nn
from jax import Array, nn
from scipy.sparse import csr_matrix
from sklearn.neighbors import NearestNeighbors
from sklearn.utils import check_array
Expand Down Expand Up @@ -37,7 +37,7 @@ def one_hot(y: NdArray, n_classes: Optional[int] = None) -> jnp.ndarray:
return nn.one_hot(jnp.ravel(y), n_classes)


def validate_seed(seed: IntOrKey) -> jax.random.KeyArray:
def validate_seed(seed: IntOrKey) -> Array:
"""Validate a seed and return a Jax random key."""
return jax.random.PRNGKey(seed) if isinstance(seed, int) else seed

Expand Down
5 changes: 3 additions & 2 deletions tests/utils/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@

import jax
import jax.numpy as jnp
from jax import Array

IntOrKey = Union[int, jax.random.KeyArray]
IntOrKey = Union[int, Array]


def _validate_seed(seed: IntOrKey) -> jax.random.KeyArray:
def _validate_seed(seed: IntOrKey) -> Array:
return jax.random.PRNGKey(seed) if isinstance(seed, int) else seed


Expand Down

0 comments on commit 6cf394e

Please sign in to comment.