In [None]:
from scipy.signal import decimate
from scipy.spatial import cKDTree
import time
import copy
import numpy as np
from scipy.signal import decimate
from scipy.spatial import cKDTree
from tqdm import tqdm

In [None]:
def find_N_nearest_nodes_2D(mesh, loc, N):
    """Return the indices of the ``N`` mesh nodes closest to ``loc``, nearest first.

    Parameters
    ----------
    mesh : object with a ``p`` attribute
        ``mesh.p`` holds node coordinates as rows (x row, y row), i.e. shape
        (2, n_nodes).
    loc : array-like of length 2
        Query point (x, y).
    N : int
        Number of nodes to return. If ``N`` exceeds the node count, all node
        indices are returned (plain slicing semantics).

    Returns
    -------
    numpy.ndarray of int
        Node indices ordered by increasing Euclidean distance to ``loc``. The
        order among equidistant nodes is not guaranteed (default, non-stable sort).
    """
    offsets = mesh.p.T - loc  # displacement of every node from the query point
    squared_distance = (offsets * offsets).sum(axis=1)  # no sqrt needed for ordering
    ranking = np.argsort(squared_distance)
    return ranking[:N]
import copy


def initialize_2D_labyrinth(refined, mask, neur_loc, K_base, K_o, sites, bbox, gap_size, dt,
                            size=None, D_coef=None):
    """
    Build the masked labyrinth mesh, pick each source's gap nodes and take the first diffusion step.

    A symmetric unit-square triangle mesh is refined ``refined`` times and scaled by
    ``size``. Every element whose centre is the nearest element centre to an obstacle
    pixel (``mask == 0``) is removed. For each row of ``neur_loc``, the nearest mesh
    nodes are taken as that source's "gap" nodes. Those nodes are set to the source
    excess ``K_o - K_base``, and one implicit (backward-Euler) step
    ``(M + D_coef*dt*K) u_next = M u_prev`` is solved.

    NOTE(review): ``MeshTri``, ``ElementTriP1``, ``Basis``, ``BilinearForm``, ``dot``,
    ``grad`` and ``solve`` are not imported anywhere in this notebook. They are
    presumably scikit-fem (``from skfem import ...``, ``from skfem.helpers import dot, grad``).
    TODO: add them to the imports cell.

    Parameters
    ----------
    refined : int
        Number of uniform refinements of ``MeshTri.init_symmetric()``.
    mask : 2-D array
        Labyrinth mask. 0 marks obstacles (removed from the mesh); anything else is
        diffusible space. Column index maps to x, row index maps to y.
    neur_loc : array, shape (n_sources, 2)
        Physical (x, y) positions of the sources.
    K_base : float
        Baseline concentration subtracted from ``K_o`` to get the source excess.
    K_o : array, one row per source
        Source concentrations (baseline included).
    sites, bbox :
        Accepted for interface compatibility but currently unused. The original
        rescaled ``sites`` and computed nearest-site indices that were never used.
    gap_size : float
        Physical length, converted into a number of nearest nodes per source
        (``gap_size / side_len``, at least 1).
    dt : float
        Time step of the implicit step.
    size : array, optional
        Domain scaling, indexed as ``size[0]`` (x) and ``size[1]`` (y). Defaults to the
        notebook-level global ``size`` (the original behaviour).
    D_coef : float, optional
        Diffusion coefficient. Defaults to the notebook-level global ``D_coef``.

    Returns
    -------
    u_next, K_o_updated, A_mat, M, mesh, gap_inds, basis, K
        ``gap_inds`` has one row of node indices per source (float dtype, as before).
    """
    # Legacy behaviour: fall back to the notebook globals the original relied on.
    if size is None:
        size = globals()["size"]
    if D_coef is None:
        D_coef = globals()["D_coef"]

    K_o_start = copy.copy(K_o)
    K_o_source = K_o_start - K_base

    # --- mesh -------------------------------------------------------------------
    mesh = MeshTri.init_symmetric().refined(refined)
    # Unit square -> centred domain. With size[0] < 0 the x axis is mirrored; the
    # obstacle mapping below uses the same factor, so mesh and mask stay aligned.
    mesh.p[0, :] = (mesh.p[0, :] - 0.5) * size[0]
    mesh.p[1, :] = (mesh.p[1, :] - 0.5) * size[1]
    element = ElementTriP1()

    # Average element area of the full (unmasked) mesh, measured from the mesh itself.
    # The original used (size[1]-size[0])**2, which does not match the scaling above
    # (the mesh spans |size[0]| x |size[1]|).
    extent = mesh.p.max(axis=1) - mesh.p.min(axis=1)
    avg_element_area = float(np.prod(extent)) / mesh.t.shape[1]
    # Side of an equilateral triangle with that area: area = sqrt(3)/4 * side**2.
    side_len = np.sqrt(avg_element_area * 4 / np.sqrt(3))
    # At least one node per source; zero nodes would make the later averaging divide by 0.
    n_gap_nodes = max(1, int(gap_size / side_len))

    # --- remove obstacle elements -------------------------------------------------
    ny, nx = mask.shape  # rows -> y, columns -> x
    rows, cols = np.nonzero(mask == 0)
    obstacle_xy = np.column_stack((
        (cols - nx / 2) * size[0] / nx,
        (rows - ny / 2) * size[1] / ny,
    ))
    print('precomputing masked space')
    t1 = time.perf_counter()
    elems_to_remove = None
    if len(obstacle_xy):
        centers_of_elems = mesh.p[:, mesh.t].mean(axis=1).T
        _, nearest_elem = cKDTree(centers_of_elems).query(obstacle_xy)
        elems_to_remove = np.unique(nearest_elem)
    print('precomputing complete, took ', round(time.perf_counter()-t1,3), ' s')
    if elems_to_remove is not None:
        mesh = mesh.remove_elements(elems_to_remove)
    basis = Basis(mesh, element)

    # --- gap nodes of every source ------------------------------------------------
    gap_inds = np.array(
        [find_N_nearest_nodes_2D(mesh, loc, n_gap_nodes) for loc in neur_loc],
        dtype=float,  # float kept for compatibility with the original return value
    )

    # --- assembly -----------------------------------------------------------------
    @BilinearForm
    def stiffness(u, v, w):
        return dot(grad(u), grad(v))

    @BilinearForm
    def mass(u, v, w):
        return u * v

    K = stiffness.assemble(basis)
    M = mass.assemble(basis)

    # --- first implicit step --------------------------------------------------------
    u_prev = np.zeros(basis.N)
    for count, nodes in enumerate(gap_inds):
        # Each source writes its own excess. The original indexed with a stale loop
        # variable, so every source received the last source's value.
        u_prev[nodes.astype(int)] = K_o_source[count]

    A_mat = M + D_coef * dt * K
    u_next = solve(A_mat, M @ u_prev)

    # NOTE(review): K_o_start already includes K_base (K_o_source = K_o_start - K_base),
    # so this adds the baseline a second time. diff_neumann_2D_fast_labyrinth instead
    # returns mean(u_next over gap nodes) + K_base. Kept unchanged; confirm intent.
    K_o_updated = K_o_start + K_base

    return u_next, K_o_updated, A_mat, M, mesh, gap_inds, basis, K

def diff_neumann_2D_fast_labyrinth(K_base, K_o_prev, u_prev, A, M, neur_ind):
    """
    Solves the 2D diffusion equation with zero-flux Neumann boundary conditions.

    Takes one implicit step:
    1. Each source's gap nodes are clamped to its excess concentration
       ``K_o_prev[count] - K_base``.
    2. ``A u_next = M u`` is solved. ``A`` is presumably ``M + D_coef*dt*K`` as built by
       ``initialize_2D_labyrinth``; its weak form has no boundary term, hence zero flux.
    3. Each source's new concentration is the mean of ``u_next`` over its gap nodes,
       plus ``K_base``.

    NOTE(review): ``solve`` is not imported in this notebook; it is presumably
    scikit-fem's ``solve``.

    Parameters
    ----------
    K_base : float
        Baseline concentration.
    K_o_prev : array, one row per source
        Source concentrations (baseline included). Not modified.
    u_prev : 1-D array
        Field from the previous step. Not modified (the original overwrote the
        gap-node entries of the caller's array).
    A, M : matrices
        System and mass matrices; returned unchanged.
    neur_ind : iterable of index arrays
        Gap-node indices of each source (e.g. ``gap_inds`` from the initializer).

    Returns
    -------
    u_next, K_o_updated, A, M, neur_ind

    Raises
    ------
    ValueError
        If a source has no gap nodes (its mean would be undefined).
    """
    # Float copies: the caller's arrays stay untouched, and averaged values are not
    # truncated when K_o_prev has an integer dtype.
    K_o_source = np.array(K_o_prev, dtype=float) - K_base
    u_src = np.array(u_prev, dtype=float)
    node_groups = [np.asarray(nodes).astype(int) for nodes in neur_ind]

    for count, nodes in enumerate(node_groups):
        if nodes.size == 0:
            raise ValueError(f"source {count} has no gap nodes to average over")
        u_src[nodes] = K_o_source[count]  # later sources win on shared nodes, as before

    u_next = solve(A, M @ u_src)

    for count, nodes in enumerate(node_groups):
        K_o_source[count] = u_next[nodes].mean()

    K_o_updated = K_o_source + K_base
    return u_next, K_o_updated, A, M, neur_ind

In [None]:
if __name__ == "__main__":
    # --- configuration -----------------------------------------------------------
    bbox = (0, 250, 0, 250)  # bbox[1] is used by initialize_2D_labyrinth to rescale `sites`
    n_steps = 10
    K_base = 5.0  # baseline concentration subtracted from K_o to get the source excess
    K_o_init = np.array([[0], [0]])  # one row per source; only row 0 is driven below
    neur_loc = np.array([[0.001, 0.0]])  # (x, y) of the stimulated source
    size = np.array([[-0.025], [0.025]])  # domain scaling read by initialize_2D_labyrinth
    D_coef = 9e-9  # diffusion coefficient used in A = M + D_coef * dt * K
    # TODO(review): `dt` was used but never defined in the original cell (NameError).
    # 1e-3 is a placeholder; set the intended time step.
    dt = 1e-3
    refined = 7
    gap_size = 0.0005

    # --- data --------------------------------------------------------------------
    # K concentrations around the stimulated neuron. Each step's value is added at the
    # stated location and diffused.
    concentration = np.load('concentrations.npy')
    # Inverted labyrinth mask: 1s denote diffusible space, 0s denote obstacles.
    mask = np.load('mask.npy')
    # Sites of the cells in the labyrinth.
    sites = np.load('sites.npy')
    assert len(concentration) >= n_steps, "concentrations.npy has fewer entries than n_steps"

    # --- simulation --------------------------------------------------------------
    u_prev, K_o, A, M, mesh, neur_ind, basis, K = initialize_2D_labyrinth(
        refined, mask, neur_loc, K_base, K_o_init, sites, bbox, gap_size=gap_size, dt=dt)
    res = np.zeros((n_steps, mesh.p.shape[1]), dtype='float32')
    for i in tqdm(range(n_steps)):
        K_o[0] = concentration[i]
        u_prev, K_o, A, M, neur_ind = diff_neumann_2D_fast_labyrinth(
            K_base, K_o, u_prev=u_prev, A=A, M=M, neur_ind=neur_ind)
        res[i] = u_prev
###plotting
# global_vmin = 0
# global_vmax = 0.015
# from skfem.visuals.matplotlib import plot
# fig, ax = plt.subplots(figsize=(10,10))
# plot(basis, res[1001],ax = ax,colorbar={'orientation': 'horizontal'}, nrefs=2,animated = True,vmin=global_vmin, vmax=global_vmax,cmap='magma')
# artist = ax.collections[0]