########################################################################################
# Workshop: Coding for PDEs with Neural Networks
# Date: 2025-24-01
# Author: Danilo Aballay, Vicente Iligaray, Ignacio Tapia y Manuel Sánchez
########################################################################################

In [7]:
import jax.numpy as jnp
from jax.ops import segment_sum
import jax.experimental.sparse as sparse
from jax import jit
import jax
from functools import partial

elements =jnp.array([[0, 1], [1, 2], [2, 3]])
NE = elements.shape[0]

SK = jnp.array([[1,-1],[-1,1]])
ke_values = jnp.zeros(((NE, 2, 2)))
for i in range(NE):
    ke_values = ke_values.at[i,:,:].set(SK)

In [9]:
@jit
def create_COO(elements, ke_values):
    NE       = elements.shape[0]
    dof_mat  = jnp.tile(elements[:, None, :], (1, 2, 1))
    dof_rows = dof_mat.reshape(NE, -1, order='C')
    dof_cols = dof_mat.reshape(NE, -1, order='F')

    rows = dof_rows.reshape(-1)
    cols = dof_cols.reshape(-1)
    ke_values_flatten = ke_values.reshape(-1)
    ke_values_flatten = ke_values_flatten.at[1:3].set(0)
    ke_values_flatten = ke_values_flatten.at[0].set(1)

    return sparse.COO((ke_values_flatten, rows, cols), shape=(NE+1, NE+1))

@partial(jax.jit, static_argnames=['n_removes', 'n_init_rows'])
def to_csr(COO, n_removes, n_init_rows):
        # Crear una clave única para cada par de coordenadas (row, col)
    row  = COO.row
    col  = COO.col
    data = COO.data
    max_col = col.max() + 1  # Asegurarse de que no haya colisiones
    keys    = row * max_col + col

    # Determinar los índices únicos y asignar un índice inverso
    sort_indices = jnp.argsort(keys)
    sorted_keys  = keys[sort_indices]
    sorted_data  = data[sort_indices]
    sorted_row   = row[sort_indices]
    sorted_col   = col[sort_indices]

    # Identificar posiciones únicas manualmente sin usar `jnp.where`
    unique_mask = jnp.diff(jnp.concatenate([jnp.array([-1]), sorted_keys])) != 0
    unique_indices = jnp.nonzero(unique_mask, size = sorted_keys.shape[0]-n_removes)[0]

    # Crear un índice inverso que mapea cada clave al índice único correspondiente
    inverse_indices = jnp.cumsum(unique_mask) - 1

    # Sumar los valores de `data` para los índices únicos
    data_summed = segment_sum(sorted_data, inverse_indices, num_segments=len(unique_indices))

    # Extraer coordenadas únicas
    final_row = sorted_row[unique_indices]
    final_col = sorted_col[unique_indices]

    # indices_filas = jnp.where(jnp.concatenate([jnp.array([True]), final_row[1:] != final_row[:-1]]))[0]
    # indices_filas = jnp.append(indices_filas, COO.size)
    change_mask = jnp.concatenate([jnp.array([True]), final_row[1:] != final_row[:-1]])

    # Obtener los índices explícitos donde ocurre un cambio
    indices_filas = jnp.nonzero(change_mask, size=final_row.size, fill_value=0)[0]

    # Agregar el tamaño total como un índice adicional
    indices_filas = jnp.append(indices_filas[0:n_init_rows], len(final_col))

    return sparse.CSR((data_summed, final_col, indices_filas), shape=COO.shape)
    # return data_summed, final_row, final_col, indices_filas


A_COO = create_COO(elements, ke_values)
A_CSR = to_csr(A_COO, A_COO.shape[0]-2, A_COO.shape[0])

print(A_COO.todense())
print(A_CSR.todense())


[[ 1.  0.  0.  0.]
 [ 0.  2. -1.  0.]
 [ 0. -1.  2. -1.]
 [ 0.  0. -1.  1.]]
[[ 1.  0.  0.  0.]
 [ 0.  2. -1.  0.]
 [ 0. -1.  2. -1.]
 [ 0.  0. -1.  1.]]


In [1]:
import jax.numpy as jnp
# from jax.ops import segment_sum
import jax.experimental.sparse as sparse
from jax.ops import segment_sum
from jax import jit
import jax
from functools import partial

from r_adaptivity_sparse import make_loss_model, make_model


In [2]:
#  Number of neurons per hidden layer in the neural network
nn = int(4) # Two times the number of neurons 

# Initialize the neural network model for the approximate solution
model = make_model(nn)

theta = model(jnp.array([1]))

In [3]:
from Laplace_JAXSparse2D import softmax_nodes, generate_mesh, element_stiffness, create_COO, to_csr, solve
values_phi0_ = jnp.array([[0.9083804012656871,  0.7331497981296533,  0.47654496148466596, 0.21994012483967862, 0.04470952170364481],
                         [0.7331497981296533,  0.591721954534264, 0.38461732752642075, 0.17751270051857745,  0.036084856923188136],
                         [0.47654496148466596, 0.38461732752642075, 0.25,       0.11538267247357925, 0.02345503851533401],
                         [0.21994012483967862, 0.17751270051857745,  0.11538267247357925, 0.053252644428581054, 0.010825220107479883],
                         [0.04470952170364481, 0.036084856923188136, 0.02345503851533401, 0.010825220107479883, 0.002200555327023207]])
nx = int(theta.shape[1]/2) +1
ny = nx

node_coords_x, node_coords_y  = softmax_nodes(theta)
coords, elements = generate_mesh(nx, ny, node_coords_x, node_coords_y)
n_elements = elements.shape[0]
n_nodes = coords.shape[0]

dirichlet_nodes = jnp.append(jnp.arange(nx),nx*jnp.arange(1,ny))
neumann_nodes = jnp.append(nx*jnp.arange(2,ny)-1, jnp.arange((ny-1)*nx-1, ny*nx))

dirichlet_nodes = jnp.append(dirichlet_nodes, neumann_nodes)

# Extract the coordinates for the start and end points of each element
start_coords = coords[elements[:, 0], :]
end_coords = coords[elements[:, 2], :]

# Compute element lengths in a vectorized manner
element_length = end_coords - start_coords

n_nodes = coords.shape[0]
### ASSEMBLE CSR

ke_values = jax.vmap(element_stiffness)(element_length)

A_COO = create_COO(elements, ke_values, n_nodes)
print(A_COO.todense())


[[ 0.66666667 -0.16666667  0.         -0.16666667 -0.33333333  0.
   0.          0.          0.        ]
 [-0.16666667  1.33333333 -0.16666667 -0.33333333 -0.33333333 -0.33333333
   0.          0.          0.        ]
 [ 0.         -0.16666667  0.66666667  0.         -0.33333333 -0.16666667
   0.          0.          0.        ]
 [-0.16666667 -0.33333333  0.          1.33333333 -0.33333333  0.
  -0.16666667 -0.33333333  0.        ]
 [-0.33333333 -0.33333333 -0.33333333 -0.33333333  2.66666667 -0.33333333
  -0.33333333 -0.33333333 -0.33333333]
 [ 0.         -0.33333333 -0.16666667  0.         -0.33333333  1.33333333
   0.         -0.33333333 -0.16666667]
 [ 0.          0.          0.         -0.16666667 -0.33333333  0.
   0.66666667 -0.16666667  0.        ]
 [ 0.          0.          0.         -0.33333333 -0.33333333 -0.33333333
  -0.16666667  1.33333333 -0.16666667]
 [ 0.          0.          0.          0.         -0.33333333 -0.16666667
   0.         -0.16666667  0.66666667]]


In [4]:
COO = A_COO
n_removes = 0
n_init_rows = n_nodes
row  = COO.row
col  = COO.col
data = COO.data
max_col = col.max() + 1  
keys    = row * max_col + col

sort_indices = jnp.argsort(keys)
sorted_keys  = keys[sort_indices]
sorted_data  = data[sort_indices]
sorted_row   = row[sort_indices]
sorted_col   = col[sort_indices]

unique_mask = jnp.diff(jnp.concatenate([jnp.array([-1]), sorted_keys])) != 0
unique_indices = jnp.nonzero(unique_mask, size = sorted_keys.shape[0]-n_removes)[0]

inverse_indices = jnp.cumsum(unique_mask) - 1

data_summed = segment_sum(sorted_data, inverse_indices, num_segments=len(unique_indices))

final_row = sorted_row[unique_indices]
final_col = sorted_col[unique_indices]
change_mask = jnp.concatenate([jnp.array([True]), final_row[1:] != final_row[:-1]])

indices_filas = jnp.nonzero(change_mask, size=final_row.size, fill_value=0)[0]

indices_filas = jnp.append(indices_filas[0:n_init_rows], len(final_col))

# return sparse.CSR((data_summed, final_col, indices_filas), shape=COO.shape)


In [36]:
final_col.shape

(1024,)

In [13]:
to_remove

49

In [10]:
to_remove+24

49

In [14]:
to_remove = 16 + 2*(nx-2)*6 + 2*(ny-2)*6 + 9*(nx-2)*(ny-2) 
A_val_length = 16*elements.shape[0]
final_col[-(A_val_length-to_remove):]

Array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=int64)

In [15]:
A_CSR = to_csr(A_COO, A_val_length -  to_remove, n_nodes)
print(A_CSR.todense())

[[ 0.66666667 -0.16666667  0.         -0.16666667 -0.33333333  0.
   0.          0.          0.        ]
 [-0.16666667  1.33333333 -0.16666667 -0.33333333 -0.33333333 -0.33333333
   0.          0.          0.        ]
 [ 0.         -0.16666667  0.66666667  0.         -0.33333333 -0.16666667
   0.          0.          0.        ]
 [-0.16666667 -0.33333333  0.          1.33333333 -0.33333333  0.
  -0.16666667 -0.33333333  0.        ]
 [-0.33333333 -0.33333333 -0.33333333 -0.33333333  2.66666667 -0.33333333
  -0.33333333 -0.33333333 -0.33333333]
 [ 0.         -0.33333333 -0.16666667  0.         -0.33333333  1.33333333
   0.         -0.33333333 -0.16666667]
 [ 0.          0.          0.         -0.16666667 -0.33333333  0.
   0.66666667 -0.16666667  0.        ]
 [ 0.          0.          0.         -0.33333333 -0.33333333 -0.33333333
  -0.16666667  1.33333333 -0.16666667]
 [ 0.          0.          0.          0.         -0.33333333 -0.16666667
   0.         -0.16666667  0.66666667]]


In [23]:
print(jnp.all(A_CSR.todense() == A_COO.todense()))

True


In [28]:
row

Array([0, 1, 4, 3, 0, 1, 4, 3, 0, 1, 4, 3, 0, 1, 4, 3, 1, 2, 5, 4, 1, 2,
       5, 4, 1, 2, 5, 4, 1, 2, 5, 4, 3, 4, 7, 6, 3, 4, 7, 6, 3, 4, 7, 6,
       3, 4, 7, 6, 4, 5, 8, 7, 4, 5, 8, 7, 4, 5, 8, 7, 4, 5, 8, 7],      dtype=int64)

In [29]:
data = A_COO.data
row = A_COO.row
col = A_COO.col

A_BCOO = sparse.BCOO((data, jnp.stack(row,col)), shape=(n_nodes, n_nodes))

TypeError: Only integer scalar arrays can be converted to a scalar index.

In [80]:
data = jnp.array([1., 3., 5., 2, -7])
row = jnp.array([0, 1, 2, 1, 2])
col = jnp.array([0, 1, 2, 2, 0])
A = sparse.COO((data, row, col), shape=(3, 3))
B = (A.T)
print(A.todense(),'\n', B.todense())

[[ 1.  0.  0.]
 [ 0.  3.  2.]
 [-7.  0.  5.]] 
 [[ 1.  0. -7.]
 [ 0.  3.  0.]
 [ 0.  2.  5.]]


In [98]:
A.data -0.1*B.data

Array([ 0.9,  2.7,  4.5,  1.8, -6.3], dtype=float64)

In [None]:
print(sum_COO(A, B).todense())

[[ 2.  0. -7.]
 [ 0.  6.  2.]
 [-7.  2. 10.]]


In [82]:
A_data = jnp.concatenate([A.data, -0.1*B.data])
A_row = jnp.concatenate([A.row, B.row])
A_col = jnp.concatenate([A.col, B.col])

A_COO = sparse.COO((A_data, A_row, A_col), shape=(3, 3))
print(A_COO.todense())

[[ 0.9  0.   0.7]
 [ 0.   2.7  2. ]
 [-7.  -0.2  4.5]]


In [94]:
A_COO.data

Array([ 1. ,  3. ,  5. ,  2. , -7. , -0.1, -0.3, -0.5, -0.2,  0.7],      dtype=float64)

In [97]:
A_CSR = to_csr(A_COO, 5, 3)
print(A_CSR.todense())

[[ 0.9  0.   0.7]
 [ 0.   2.7  2. ]
 [-7.   0.   0. ]]
