In [272]:
import jax
from jax import jit
import jax.numpy as np
import jax.experimental.sparse as sparse

row  = np.array([0, 3, 1, 2, 0])
col  = np.array([0, 3, 1, 2, 2])
data = np.array([4, 5, 7, 2, 9])
A_COO = sparse.COO((data,row, col), shape=(4, 4))
print(A_COO.todense())

print(A_COO.data)
A_COO.data


[[4 0 9 0]
 [0 7 0 0]
 [0 0 2 0]
 [0 0 0 5]]
[4 5 7 2 9]


Array([4, 5, 7, 2, 9], dtype=int32)

In [359]:
def to_csr(COO):
    # Crear una lista de tuplas (fila, columna, valor)
    tuplas = list(zip(COO.row, COO.col, COO.data))

    # Ordenar las tuplas por fila
    tuplas_ordenadas = np.array(sorted(tuplas, key=lambda x: x[0]))  # Asegúrate de que tuplas_ordenadas sea un array de jax.numpy

    data_ordenado = tuplas_ordenadas[:, 2]
    col_ordenado  = tuplas_ordenadas[:, 1]

    # Para obtener indices_filas sin usar un bucle for
    indices_filas = np.where(np.concatenate([np.array([True]), tuplas_ordenadas[1:, 0] != tuplas_ordenadas[:-1, 0]]))[0]
    indices_filas = np.append(indices_filas, len(data_ordenado))

    return sparse.CSR((data_ordenado, col_ordenado, indices_filas), shape=COO.shape)


In [None]:

elements =np.array([[0, 1], [1, 2], [2, 3]])
NE = elements.shape[0]

SK = np.array([[1,-1],[-1,1]])
ke_values = np.zeros(((NE, 2, 2)))
for i in range(NE):
    ke_values = ke_values.at[i,:,:].set(SK)




[0 1 0 1 1 2 1 2 2 3 2 3]
[0 0 1 1 1 1 2 2 2 2 3 3]
[ 1. -1. -1.  1.  1. -1. -1.  1.  1. -1. -1.  1.]


In [387]:
import jax.numpy as jnp
from jax.ops import segment_sum
from jax import jit
from functools import partial


@jit
def create_COO(elements, ke_values):
    dof_mat =  np.tile(elements[:, None, :], (1, 2, 1))
    dof_rows = dof_mat.reshape(NE, -1, order='C')
    dof_cols = dof_mat.reshape(NE, -1, order='F')

    rows = dof_rows.reshape(-1)
    cols = dof_cols.reshape(-1)
    ke_values_flatten = ke_values.reshape(-1)

    return sparse.COO((ke_values_flatten, rows, cols), shape=(NE+1, NE+1))

@partial(jax.jit, static_argnames=['n_removes', 'n_init_rows'])
def to_csr(COO, n_removes, n_init_rows):
        # Crear una clave única para cada par de coordenadas (row, col)
    row  = COO.row
    col  = COO.col
    data = COO.data
    max_col = col.max() + 1  # Asegurarse de que no haya colisiones
    keys    = row * max_col + col

    # Determinar los índices únicos y asignar un índice inverso
    sort_indices = jnp.argsort(keys)
    sorted_keys  = keys[sort_indices]
    sorted_data  = data[sort_indices]
    sorted_row   = row[sort_indices]
    sorted_col   = col[sort_indices]

    # Identificar posiciones únicas manualmente sin usar `jnp.where`
    unique_mask = jnp.diff(jnp.concatenate([jnp.array([-1]), sorted_keys])) != 0
    unique_indices = jnp.nonzero(unique_mask, size = sorted_keys.shape[0]-n_removes)[0]

    # Crear un índice inverso que mapea cada clave al índice único correspondiente
    inverse_indices = jnp.cumsum(unique_mask) - 1

    # Sumar los valores de `data` para los índices únicos
    data_summed = segment_sum(sorted_data, inverse_indices, num_segments=len(unique_indices))

    # Extraer coordenadas únicas
    final_row = sorted_row[unique_indices]
    final_col = sorted_col[unique_indices]

    # indices_filas = jnp.where(jnp.concatenate([jnp.array([True]), final_row[1:] != final_row[:-1]]))[0]
    # indices_filas = jnp.append(indices_filas, COO.size)
    change_mask = jnp.concatenate([jnp.array([True]), final_row[1:] != final_row[:-1]])

    # Obtener los índices explícitos donde ocurre un cambio
    indices_filas = jnp.nonzero(change_mask, size=final_row.size, fill_value=0)[0]

    # Agregar el tamaño total como un índice adicional
    indices_filas = jnp.append(indices_filas[0:n_init_rows], COO.size)

    return sparse.CSR((data_summed, final_col, indices_filas), shape=COO.shape)
    # return data_summed, final_row, final_col, indices_filas


A_COO = create_COO(elements, ke_values)
A_CSR = to_csr(A_COO, 2, 4)
print(A_COO.todense())
print(A_CSR.todense())


[[ 1. -1.  0.  0.]
 [-1.  2. -1.  0.]
 [ 0. -1.  2. -1.]
 [ 0.  0. -1.  1.]]
[[ 1. -1.  0.  0.]
 [-1.  2. -1.  0.]
 [ 0. -1.  2. -1.]
 [ 0.  0. -1.  1.]]


In [379]:
A_indptr

Array([ 0,  2,  5,  8, 16], dtype=int32)

In [380]:
sparse.CSR((A_data, A_col, A_indptr), shape=A_COO.shape).todense()

Array([[ 1., -1.,  0.,  0.],
       [-1.,  2., -1.,  0.],
       [ 0., -1.,  2., -1.],
       [ 0.,  0., -1.,  1.]], dtype=float32)

In [360]:
def to_csr(COO):
    # Crear una lista de tuplas (fila, columna, valor)
    tuplas = list(zip(COO.row, COO.col, COO.data))

    # Ordenar las tuplas por fila
    tuplas_ordenadas = np.array(sorted(tuplas, key=lambda x: x[0]))  # Asegúrate de que tuplas_ordenadas sea un array de jax.numpy

    data_ordenado = tuplas_ordenadas[:, 2]
    col_ordenado  = tuplas_ordenadas[:, 1]

    # Para obtener indices_filas sin usar un bucle for
    indices_filas = np.where(np.concatenate([np.array([True]), tuplas_ordenadas[1:, 0] != tuplas_ordenadas[:-1, 0]]))[0]
    indices_filas = np.append(indices_filas, len(data_ordenado))

    return sparse.CSR((data_ordenado, col_ordenado, indices_filas), shape=COO.shape)

to_csr(COO_no_repetidos(A_COO, 2)).todense()

AssertionError: 

In [349]:
to_csr(A_COO, 2).todense()

ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[]
The size argument of jnp.nonzero must be statically specified to use jnp.nonzero within JAX transformations.
The error occurred while tracing the function to_csr at C:\Users\itapi\AppData\Local\Temp\ipykernel_27880\437891706.py:6 for jit. This value became a tracer due to JAX operations on these lines:

  operation a[35m:i32[][39m = convert_element_type[new_dtype=int32 weak_type=False] b
    from line C:\Users\itapi\AppData\Local\Temp\ipykernel_27880\437891706.py:12:14 (to_csr)

  operation a[35m:i32[][39m = convert_element_type[new_dtype=int32 weak_type=False] b
    from line C:\Users\itapi\AppData\Local\Temp\ipykernel_27880\437891706.py:17:19 (to_csr)

  operation a[35m:i32[12][39m = add b c
    from line C:\Users\itapi\AppData\Local\Temp\ipykernel_27880\437891706.py:17:19 (to_csr)

  operation a[35m:i32[][39m = convert_element_type[new_dtype=int32 weak_type=False] b
    from line C:\Users\itapi\AppData\Local\Temp\ipykernel_27880\437891706.py:18:19 (to_csr)

  operation a[35m:i32[][39m = convert_element_type[new_dtype=int32 weak_type=False] b
    from line C:\Users\itapi\AppData\Local\Temp\ipykernel_27880\437891706.py:19:19 (to_csr)

(Additional originating lines are not shown.)

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError

In [296]:
print('A_row =',A_COO.row)
print('A_col =',A_COO.col)
print('A_data =',A_COO.data)
print(A_COO.todense())

A_row = [0 1 0 1 1 2 1 2 2 3 2 3]
A_col = [0 0 1 1 1 1 2 2 2 2 3 3]
A_data = [ 1. -1. -1.  1.  1. -1. -1.  1.  1. -1. -1.  1.]
[[ 1. -1.  0.  0.]
 [-1.  2. -1.  0.]
 [ 0. -1.  2. -1.]
 [ 0.  0. -1.  1.]]


In [297]:
A_rows, A_cols, A_data = remove_same_pos(A_COO.row, A_COO.col, A_COO.data, 2)
print('A_rows =', A_rows)
print('A_col = ',A_cols)
print('A_data =', A_data)



A_rows = [0 0 1 1 1 2 2 2 3 3]
A_col =  [0 1 0 1 2 1 2 3 2 3]
A_data = [ 1. -1. -1.  2. -1. -1.  2. -1. -1.  1.]


In [303]:
# Para obtener indices_filas sin usar un bucle for
indices_filas = np.where(np.concatenate([np.array([True]), A_rows[1:, 0] != tuplas_ordenadas[:-1, 0]]))[0]
indices_filas = np.append(indices_filas, len(data_ordenado))


AssertionError: 

In [318]:
indices_filas

Array([ 0,  2,  5,  8, 16], dtype=int32)

In [320]:
print(sparse.CSR((A_data, A_cols, indices_filas), shape=(4, 4)).todense())
print(create_COO(elements, ke_values).todense())

[[ 1. -1.  0.  0.]
 [-1.  2. -1.  0.]
 [ 0. -1.  2. -1.]
 [ 0.  0. -1.  1.]]
[[ 1. -1.  0.  0.]
 [-1.  2. -1.  0.]
 [ 0. -1.  2. -1.]
 [ 0.  0. -1.  1.]]


In [307]:
tuplas_ordenadas[1:,0]

Array([0., 1., 1., 1., 2., 2., 2., 3., 3.], dtype=float32)

In [305]:
tuplas = list(zip(A_rows, A_cols, A_data))

# Ordenar las tuplas por fila
tuplas_ordenadas = np.array(sorted(tuplas, key=lambda x: x[0]))  # Asegúrate de que tuplas_ordenadas sea un array de jax.numpy

data_ordenado = tuplas_ordenadas[:, 2]
col_ordenado  = tuplas_ordenadas[:, 1]

# Para obtener indices_filas sin usar un bucle for
indices_filas = np.where(np.concatenate([np.array([True]), tuplas_ordenadas[1:, 0] != tuplas_ordenadas[:-1, 0]]))[0]
indices_filas = np.append(indices_filas, len(data_ordenado))

A_CSR = sparse.CSR((data_ordenado, col_ordenado, indices_filas), shape=A_COO.shape)
# A_CSR.todense()

In [261]:
import jax.numpy as jnp
from jax.ops import segment_sum
from jax import jit

@jit
def remove_same_pos(row, col, data):
    # Crear una clave única para cada par de coordenadas (row, col)
    max_col = col.max() + 1  # Evitar colisiones creando un espacio suficiente
    keys = row * max_col + col

    # Ordenar las claves para agrupar duplicados
    sort_indices = jnp.argsort(keys)
    sorted_keys = keys[sort_indices]
    sorted_data = data[sort_indices]
    sorted_row = row[sort_indices]
    sorted_col = col[sort_indices]

    # Identificar las posiciones de cambio en las claves (es decir, las únicas)
    is_unique = jnp.concatenate([jnp.array([True]), sorted_keys[1:] != sorted_keys[:-1]])
    unique_indices = jnp.nonzero(is_unique)[0]

    # Crear un índice inverso que asigna cada clave al índice único correspondiente
    inverse_indices = jnp.cumsum(is_unique) - 1

    # Sumar los valores de `data` correspondientes a las claves únicas
    data_summed = segment_sum(sorted_data, inverse_indices, num_segments=unique_indices.shape[0])

    # Extraer las coordenadas únicas correspondientes a las claves
    final_row = sorted_row[unique_indices]
    final_col = sorted_col[unique_indices]

    return final_row, final_col, data_summed

# Prueba
row = jnp.array([0, 1, 2, 2, 3, 3, 4])
col = jnp.array([1, 2, 1, 1, 0, 0, 5])
data = jnp.array([10, 20, 10, 30, 40, 50, 60], dtype=jnp.float32)

final_row, final_col, data_summed = remove_same_pos(row, col, data)

print("Processed Row:", final_row)
print("Processed Col:", final_col)
print("Processed Data:", data_summed)


ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[]
The size argument of jnp.nonzero must be statically specified to use jnp.nonzero within JAX transformations.
The error occurred while tracing the function remove_same_pos at C:\Users\itapi\AppData\Local\Temp\ipykernel_27880\634259057.py:5 for jit. This concrete value was not available in Python because it depends on the values of the arguments row and col.

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError

In [248]:
import jax.numpy as jnp

# Identificar posiciones únicas manualmente sin usar `jnp.where`
# unique_mask = jnp.diff(jnp.concatenate([jnp.array([-1]), sorted_keys])) != 0
# unique_indices = jnp.nonzero(unique_mask)[0]

inverse_indices

Array([0, 1, 2, 2, 3, 3, 4], dtype=int32)

In [232]:
jnp.diff(jnp.concatenate([jnp.array([-1]), sorted_keys]))!= 0

Array([ True,  True,  True, False,  True, False,  True], dtype=bool)