In [None]:
import matplotlib.pyplot as plt
import plot_figures
from jax import vmap, random, numpy as jnp
from met_brewer import met_brew
import sys
import os

# Add the scripts folder to sys.path
sys.path.append(os.path.abspath("../"))

# With the parent folder on sys.path, the project modules below can be imported
import spatial_embedings
import extra_initializers

In [None]:
# Grid built by the project helper spatial_embedings.twod_grid for a 10 x 10 layout.
# NOTE(review): later cells read grid[k][0] / grid[k][1] and reshape per-cell results
# to (10, 10), which suggests an array of 100 two-component positions — confirm.
grid = spatial_embedings.twod_grid(10, 10)


In [None]:
# Squared distance between two grid points; evaluated over all pairs with a
# nested vmap in the next statements.
def pairwise_sq_distance(x, y):
    """Squared distance between points ``x`` and ``y``.

    Only the first two entries of each point are read (``x[0], x[1]`` and
    ``y[0], y[1]``); they are forwarded unchanged to
    ``spatial_embedings.sq_distance``.
    """
    x_first, x_second = x[0], x[1]
    y_first, y_second = y[0], y[1]
    return spatial_embedings.sq_distance(x_first, x_second, y_first, y_second)

# Evaluate pairwise_sq_distance on every (p, q) pair of grid points: the inner
# vmap sweeps q over the grid for a fixed p, the outer vmap sweeps p.
def _sq_distances_to_all(p):
    return vmap(lambda q: pairwise_sq_distance(p, q))(grid)

sq_pairwise_distances = vmap(_sq_distances_to_all)(grid)

# Strict upper triangle (k=1): every unordered pair once, self-pairs excluded.
n_points = sq_pairwise_distances.shape[0]
upper_rows, upper_cols = jnp.triu_indices(n_points, k=1)
upper_triangle = sq_pairwise_distances[upper_rows, upper_cols]

# Distinct squared distances (jnp.unique also sorts them), printed as plain distances.
# NOTE(review): uniqueness is exact float equality, so values differing only by
# rounding would appear as separate entries.
unique_sq_distances = jnp.unique(upper_triangle)
print(jnp.sqrt(unique_sq_distances))


In [None]:

# Scale parameter of the connection-probability curve 1 / (1 + exp(d^2 / (4 * sigma))).
# NOTE(review): this value is never read before being rebound — the next cell's loop
# reassigns `sigma`, and the configuration cell sets it to 0.012 again.
sigma = 0.012



In [None]:
# Connection probability vs. distance for several sigma values:
#   P(A_ji = 1) = 1 / (1 + exp(d^2 / (4 * sigma)))
# which is 0.5 at d = 0 and decays with distance (more slowly for larger sigma).
fig, ax = plt.subplots()
plot_distances = jnp.sqrt(unique_sq_distances)  # loop-invariant x values

# A dedicated loop name keeps this cell from rebinding the global `sigma`
# that other cells rely on.
for sigma_value in [0.12, 0.012, 0.0012]:
    probs = 1 / (1 + jnp.exp(unique_sq_distances / (4 * sigma_value)))
    ax.plot(plot_distances, probs, label="$\\sigma=${}".format(sigma_value))

ax.set_title("Probability of connection between two neurons as function of distance")
ax.legend()
ax.set_xlabel("$d_{ji}$")
ax.set_ylabel("P($A_{ji}=1$)")
fig.savefig("sigma_prob.svg", format='svg')


In [None]:
# Shared configuration for the connectivity example below.
# Seeds that produced good-looking examples: 20025, 14443324584432312.
key = random.PRNGKey(14443324584432312)
n_rec = 100  # number of recurrent neurons passed to the initializers
grid_shape = (10, 10)
sigma = 0.012
# Reference cell for the probability maps: flat index 4 * 10 + 4 = 44,
# i.e. entry [4, 4] once a length-100 vector is reshaped to (10, 10).
fixed_cell = grid[4 * 10 + 4]
fixed_cell

In [None]:
from matplotlib.gridspec import GridSpec

def prob(sq_dis, sigma):
    """Connection probability for a squared distance ``sq_dis``.

    Computes ``1 / (1 + exp(sq_dis / (4 * sigma)))`` element-wise, so scalars and
    arrays are both accepted. The result is 0.5 at ``sq_dis == 0`` and, for
    ``sigma > 0``, decreases as ``sq_dis`` grows.
    """
    exponent = sq_dis / (4 * sigma)
    return 1 / (1 + jnp.exp(exponent))

# Squared distance from every grid cell to `fixed_cell`. The name `distances` is
# kept because later cells use it, but the values are SQUARED distances, which is
# what `prob` expects. vmap evaluates all cells in one call instead of a
# Python-level loop (same pattern as the pairwise matrix above).
distances = vmap(lambda cell: pairwise_sq_distance(cell, fixed_cell))(grid)



# Probability map of receiving a connection from `fixed_cell`, one panel per sigma,
# with one shared colorbar drawn in a dedicated narrow GridSpec column.
sigmas = [0.12, 0.012, 0.0012]

fig = plt.figure(figsize=(12, 4))
# Three equal columns for the maps plus a narrow fourth column for the colorbar.
gs = GridSpec(1, 4, figure=fig, width_ratios=[1, 1, 1, 0.05])
axs = [fig.add_subplot(gs[0, i]) for i in range(3)]
cbar_ax = fig.add_subplot(gs[0, 3])

# Tick positions in cell units, relabelled as 0.0 ... 1.0.
# NOTE(review): ticks go up to position 10 while cell centres run 0..9, and
# set_xticks expands the view limits to include every tick — confirm this matches
# the coordinate convention of spatial_embedings.twod_grid.
tick_positions = jnp.arange(0, 12, 2)
tick_labels = [f"{label:.1f}" for label in jnp.linspace(0, 1, 6)]

# A dedicated loop name keeps this cell from rebinding the global `sigma`.
for ax, sigma_value in zip(axs, sigmas):
    # prob is element-wise, so it is applied to the whole distance vector at once.
    probabilities = prob(distances, sigma=sigma_value)
    # 10x10 map with the fixed cell itself set to NaN (drawn in the colormap's "bad" colour).
    prob_grid = probabilities.reshape(10, 10).at[4, 4].set(jnp.nan)

    # Same colour range on every panel so the single colorbar applies to all of them.
    im = ax.imshow(prob_grid, cmap="viridis", origin="upper", vmin=0, vmax=0.48)
    ax.set_title(f"sigma = {sigma_value}")
    ax.invert_yaxis()  # row 0 at the bottom, matching usual grid orientation

    ax.set_xticks(tick_positions)
    ax.set_yticks(tick_positions)
    ax.set_xticklabels(tick_labels)
    ax.set_yticklabels(tick_labels)

    # Highlight the fixed cell.
    ax.scatter(4, 4, color="red", edgecolor="black", s=100, label="Fixed Cell")

# Shared colorbar drawn INTO the reserved last column. Passing ax=axs (as before)
# stole space from the map axes and left the reserved column empty.
cbar = fig.colorbar(im, cax=cbar_ax, orientation="vertical", label="Probability")

axs[0].legend(loc="upper right")
axs[0].set_ylabel("Y Coordinate")
fig.supxlabel("X Coordinate")
fig.suptitle("Probability map of receiving connections for a cell")
fig.savefig("sigma_probmap.svg", format="svg")




In [None]:
# Sample one network (neuron positions + local connectivity mask) for the example
# figure below. The subkeys are derived from `key` WITHOUT rebinding it, so
# re-running this cell reproduces the same network for the hand-picked seed.
# random.split(key, 3) (rather than 2) is kept so the subkeys are exactly the
# ones the original three-way split produced.
sigma = 0.012
subkey_1, subkey_2 = random.split(key, 3)[:2]
cells_loc = extra_initializers.initialize_neurons_position(
    gridshape=grid_shape, key=subkey_1, n_rec=n_rec
)()
M = extra_initializers.initialize_connectivity_mask(
    connectivity_rec_layer="local", gridshape=grid_shape,
    neuron_indices=cells_loc, key=subkey_2,
    n_rec=n_rec, sigma=sigma, sparsity=0.1,
)()


In [None]:
# Connections of one example neuron after ordering neurons by grid position.
# NOTE(review): 44 is assumed to be the (4, 4) cell of the 10x10 grid once neurons
# are sorted — this holds if the 100 neurons occupy all 100 grid positions.
index = 44
# jnp.lexsort uses the LAST key as the primary key: sort by cells_loc[:, 0],
# breaking ties with cells_loc[:, 1].
rec_cell_loc_ind = jnp.lexsort((cells_loc[:, 1], cells_loc[:, 0]))
# Apply the same permutation to rows and columns of the (n_rec, n_rec) mask.
sorted_recurrent_weights = M[jnp.ix_(rec_cell_loc_ind, rec_cell_loc_ind)]
# Column indices j where row `index` of the sorted mask equals 1.
connection = jnp.nonzero(sorted_recurrent_weights[index] == 1.)[0]
connection // 10

In [None]:
from matplotlib.colors import ListedColormap

# Larger font for this figure.
# NOTE: rcParams.update is global — it also enlarges text in every later figure.
plt.rcParams.update({'font.size': 18})

fig, ax = plt.subplots(figsize=(7, 7))

# Probability map for the sigma used to sample M (0.012), fixed cell blanked out.
# prob is element-wise, so it is applied to the whole distance vector at once.
probabilities = prob(distances, sigma=0.012)
prob_grid = probabilities.reshape(10, 10).at[4, 4].set(jnp.nan)

# Continuous 256-colour MetBrewer 'OKeeffe2' palette wrapped as a Matplotlib colormap.
met_brew_colors = met_brew('OKeeffe2', n=256, brew_type='continuous')
cmap = ListedColormap(met_brew_colors)
cax = ax.imshow(prob_grid, cmap=cmap, origin="upper", vmin=0, vmax=0.45, interpolation='nearest')
ax.invert_yaxis()  # row 0 at the bottom, matching usual grid orientation

# Tick positions in cell units, relabelled as 0.0 ... 1.0.
new_labels = jnp.linspace(0, 1, 6)
ax.set_xticks(jnp.arange(0, 12, 2))
ax.set_yticks(jnp.arange(0, 12, 2))
ax.set_xticklabels([f"{label:.1f}" for label in new_labels])
ax.set_yticklabels([f"{label:.1f}" for label in new_labels])

# Neurons the example neuron connects to (from the previous cell), then the neuron itself.
# NOTE(review): imshow draws prob_grid[r, c] at x = c, y = r, so flat index k of a
# (10, 10) reshape sits at x = k % 10, y = k // 10; the scatter below uses
# x = k // 10, y = k % 10 (a transpose). Confirm which axis cells_loc[:, 0] maps to.
ax.scatter(connection // 10, connection % 10, color="black", edgecolor="black", label="Post synaptic")
ax.scatter(4, 4, color="#D10F19", s=200, label="Pre synaptic")
ax.set_ylabel("Y Coordinate")
ax.set_xlabel("X Coordinate")
ax.legend()
fig.colorbar(cax, ax=ax, shrink=0.8, label="Probability of connection")
fig.savefig("sigma_probmap_neurons.svg", format="svg")


In [None]:

# Image view of `grid` in greyscale, with the fixed cell (4, 4) marked in red.
# NOTE(review): earlier cells index `grid` as a list of points (grid[k][0], grid[k][1])
# and reshape per-cell results to (10, 10); here `grid` itself is drawn as an image,
# and the commented-out colorbar label suggests a 0/1 mask — confirm the intended data.
# Create a figure and axis object
fig, ax = plt.subplots()

# Display the grid on the axis
cax = ax.imshow(grid, cmap='Greys', origin='upper', interpolation='nearest')

# Add a colorbar to the plot
#fig.colorbar(cax, ax=ax, label="Value (0 or 1)")

# Set the title of the plot
ax.set_title("Grid Visualization")

# Relabel tick positions 0, 2, ..., 10 as 0.0 ... 1.0 (same scheme as the probability maps).
new_labels = jnp.linspace(0, 1, 6)  # Create labels from 0 to 1
ax.set_xticks(jnp.arange(0, 12, 2))  # Original tick positions
ax.set_yticks(jnp.arange(0, 12, 2))
ax.set_xticklabels([f"{label:.1f}" for label in new_labels])  # Set new labels
ax.set_yticklabels([f"{label:.1f}" for label in new_labels])
ax.scatter(4, 4, color="red", edgecolor="black", s=100, label="Fixed Cell")
# Display the plot
plt.show()


In [None]:
# Lattice view of `grid`: one dot per entry (black where the entry equals 1, white
# otherwise) on a light-grey background, with the fixed cell (4, 4) marked in red.
# NOTE(review): earlier cells index `grid` as a list of points (grid[k][0], grid[k][1]);
# this cell treats it as a 2-D 0/1 image — confirm which layout is intended.
fig, ax = plt.subplots()
ax.set_facecolor('lightgray')

n_rows, n_cols = grid.shape
for r in range(n_rows):
    for c in range(n_cols):
        if grid[r, c] == 1:
            colour = 'black'
        else:
            colour = 'white'
        ax.scatter(c, r, color=colour, s=50, edgecolors='black', zorder=2)

ax.grid(False)

# Pixel-style limits with row 0 at the top, and square cells.
ax.set_xlim(-0.5, n_cols - 0.5)
ax.set_ylim(n_rows - 0.5, -0.5)
ax.set_aspect('equal')

ax.set_title("Grid Visualization with Dots")
ax.scatter(4, 4, color='red', s=50, edgecolors='red', zorder=2)
plt.show()

In [None]:
# Lattice view of `grid` on a white background: a pink (#fe46a5) dot at every entry
# equal to 1, with grey gridlines drawn at the integer tick positions.
# NOTE(review): the title says "Red Dots" but #fe46a5 is pink, and the gridlines sit
# at integer positions, i.e. through the dot centres rather than between cells —
# confirm intent.
fig, ax = plt.subplots()
ax.set_facecolor('white')

n_rows, n_cols = grid.shape
for r in range(n_rows):
    for c in range(n_cols):
        if grid[r, c] != 1:
            continue
        ax.scatter(c, r, color='#fe46a5', s=100, zorder=2)

# Ticks at 0..n (inclusive) determine where the gridlines are drawn.
ax.set_xticks(jnp.arange(n_cols + 1))
ax.set_yticks(jnp.arange(n_rows + 1))
ax.grid(True, which='both', axis='both', color='gray', linestyle='-', linewidth=2)

# Pixel-style limits with row 0 at the top, and square cells.
ax.set_xlim(-0.5, n_cols - 0.5)
ax.set_ylim(n_rows - 0.5, -0.5)
ax.set_aspect('equal')

ax.set_title("Grid Visualization with Red Dots and Gridlines at Every Cell")
plt.show()

In [None]:
# average_degrees = []
# for i in range(0, 10000):
#     subkey_1, subkey_2, key = random.split(key, 3)
#     cells_loc = extra_initializers.initialize_neurons_position(gridshape=grid_shape, key=subkey_1, n_rec=n_rec)()
#     M =extra_initializers.initialize_connectivity_mask(connectivity_rec_layer="local", gridshape=grid_shape,
#                                                     neuron_indices=cells_loc, key=subkey_2,
#                                                     n_rec=n_rec, sigma=sigma, sparsity=0.1
#     )()
#     average_degrees.append(M.sum()/(n_rec * n_rec))

In [None]:
# for i in average_degrees:
#     print(i)

In [None]:
# Distribution of connection density across sampled networks.
# `average_degrees` is only produced by the commented-out sampling loop above, so on
# a fresh kernel it does not exist; skip the plot instead of raising NameError.
if "average_degrees" not in globals():
    print("average_degrees is not defined - run the sampling loop above first; skipping histogram.")
else:
    # Each entry is M.sum() / (n_rec * n_rec) as a 0-d array; convert to a percentage.
    values = jnp.array([100 * float(arr) for arr in average_degrees])

    fig, ax = plt.subplots()
    # density=True normalises bar heights to a probability density, not raw counts.
    ax.hist(values, bins=10, edgecolor='black', alpha=0.7, density=True)
    ax.set_xlabel("Connection density (%)")
    ax.set_ylabel("Density")
    ax.set_title("Histogram of Distribution")
    ax.set_xlim(9, 12)
    plt.show()


In [None]:
fig.ima