In [None]:
import numpy as np
from tqdm import tqdm
import scipy.io as sio
from scipy import stats
from scipy.cluster.hierarchy import linkage, fcluster
from scipy.ndimage import label, sum as ndi_sum
from matplotlib import pyplot as plt
from matplotlib.colors import ListedColormap
from matplotlib.patches import Rectangle
from scipy.ndimage import gaussian_filter
from scipy.ndimage import binary_erosion
from scipy.stats import pearsonr
from sklearn.cluster import AgglomerativeClustering
from sklearn.preprocessing import StandardScaler

# Root folder holding the V4DT/, SOM/ and RSOM/ sub-folders used below. Keep the trailing "/":
# every path in this notebook is built by plain string concatenation (folder_path + "V4DT/...").
folder_path = "/path/to/your/data/"

# V4 digital twin raw data

In [None]:
# Load the V4 digital twin (V4DT) maps and show them side by side.
# Dispersity results
dispersity = sio.loadmat(folder_path + "V4DT/Dispersity_results/dispersity.mat")['cmap'] # (128, 128, 3) <class 'numpy.ndarray'>
dispersity_raw = sio.loadmat(folder_path + "V4DT/Dispersity_results/FDraw.mat")["FD"] # (128, 128) <class 'numpy.ndarray'>
# ROI mask for visual inspection, transposed to line up with the other maps
roi = np.load(folder_path + "V4DT/ROI.npy").T # (128, 128) <class 'numpy.ndarray'>
# RF results
polar_angle = sio.loadmat(folder_path + "V4DT/RF_results/polar_angle.mat")['map'] # (128, 128, 3) <class 'numpy.ndarray'>
# theta_map_raw: [0.0 60.1088362]; mean=31.2365; std=13.07726
theta = sio.loadmat(folder_path + "V4DT/RF_results/theta_map_raw.mat")['theta_map_raw'] # (128, 128); or theta_map.mat['theta_map'] after color coding
eccentricity = sio.loadmat(folder_path + "V4DT/RF_results/eccentricity.mat")['map'] # (128, 128, 3) <class 'numpy.ndarray'>
# r_map_raw: [12.7849 36.81318]; mean=24.77; std=5.7943; # r_map: [233, 767]; mean=501; std=129
r = sio.loadmat(folder_path + "V4DT/RF_results/r_map_raw.mat")['r_map_raw'] # (128, 128); or r_map.mat['r_map'] after color coding
# shift ROI values so the smallest one is 0 -> adjusted r_map_raw: [0, 24.0283]; mean=11.985257; std=5.7943
r[roi != 0] -= np.min(r[roi != 0])
# V4 digital twin responses to 50k images
response = np.load(folder_path + "V4DT/PRsp.npy") # (50000, 128, 128) <class 'numpy.ndarray'>
# Mean response of every column over all images: one vectorised reduction instead of 128*128 strided
# np.mean calls; accumulate in float64 so a lower-precision response array does not lose accuracy.
mean_rsp = response.mean(axis=0, dtype=np.float64) # (128, 128)

# Every panel is flipped on both axes and cropped to the same window before display.
view_window = (slice(30, 120), slice(35, 110))
panels = [
    (polar_angle, 'Polar Angle'),
    (theta, 'Theta'),            # polar angle range: 0.0 60.10883624961076
    (eccentricity, 'Eccentricity'),
    (r, 'R'),                    # eccentricity range: 0.0 24.028273811974355
    (dispersity, 'Dispersity'),
    (dispersity_raw, 'dispersity_raw'),
    (roi, 'ROI'),
    (mean_rsp, 'V4 Mean Rsp'),
]
fig, axes = plt.subplots(1, len(panels), figsize=(16, 2))
for ax, (image, title) in zip(axes, panels):
    ax.imshow(np.flip(image, axis=(0, 1))[view_window])
    ax.axis('off')
    ax.set_title(title)
plt.tight_layout()
plt.show()

# Assign V4 neuronal column onto simulated map grid

In [None]:
# Calculate the correlation between V4 columnar responses and the learned SOM weights
def V4_assignment(folder_path, name):
    """Correlate every (R)SOM unit's weight vector with every V4 ROI column's responses and save it.

    Parameters
    ----------
    folder_path : str
        Data root; must end with a path separator (paths are built by string concatenation).
    name : str
        "RSOM" or "SOM": sub-folder containing ``weights.npy``; also selects the output file name.

    Saves ``folder_path + name + "/assign_rsom.npy"`` (RSOM) or ``"/assign_som.npy"`` (SOM): an array of
    shape (som rows, som cols, 3048) whose [i, j, k] entry is the Pearson r between unit (i, j)'s first
    50000 weights and the responses of the k-th ROI column (ROI columns in row-major order of the
    transposed ROI mask). A constant vector yields NaN, as with ``np.corrcoef``.

    Raises
    ------
    ValueError
        If ``name`` is neither "RSOM" nor "SOM" (checked before any file is read or written).
    """
    if name not in ("RSOM", "SOM"):
        raise ValueError("name must be 'RSOM' or 'SOM', got {!r}".format(name))

    def _center_and_normalize_rows(x):
        """Center each row of a float array and scale it to unit L2 norm, in place; constant rows become NaN."""
        x -= x.mean(axis=1, keepdims=True)
        with np.errstate(invalid="ignore", divide="ignore"):
            x /= np.linalg.norm(x, axis=1, keepdims=True)
        return x

    # load essential data
    som = np.load(folder_path + name + "/weights.npy") # (60, 60, 50000(+))
    som = som[:, :, :50000] # e.g. (60, 60, 50250) to (60, 60, 50000)
    v4_rsp = np.load(folder_path + "V4DT/PRsp.npy") # (50000, 128, 128)
    roi = np.load(folder_path + "V4DT/ROI.npy").T # (128, 128) <class 'numpy.ndarray'>
    # Boolean-mask indexing visits ROI columns in the same row-major (i, j) order as a nested i/j loop.
    rsp = np.array(v4_rsp[:, roi == 1].T, dtype=np.float64) # (3048, 50000), own float64 copy
    assert rsp.shape[0] == 3048
    del v4_rsp, roi

    # Pearson r of two vectors == dot product of their centered, unit-norm versions, so the whole
    # (units x columns) table is a matrix product instead of 60*60*3048 separate corrcoef calls.
    rsp = _center_and_normalize_rows(rsp)
    assign = np.zeros((som.shape[0], som.shape[1], rsp.shape[0]))
    for i in tqdm(range(som.shape[0]), desc="cor..."):
        unit_rows = _center_and_normalize_rows(np.array(som[i], dtype=np.float64)) # (60, 50000) copy
        # clip like np.corrcoef does, absorbing rounding just outside [-1, 1]; NaN rows pass through
        assign[i] = np.clip(unit_rows @ rsp.T, -1.0, 1.0)
    if name == "RSOM":
        np.save(folder_path + name + "/assign_rsom.npy", assign)
    else:
        np.save(folder_path + name + "/assign_som.npy", assign)

# Build the unit-by-column correlation table for both models (RSOM first, then SOM).
for model_name in ("RSOM", "SOM"):
    V4_assignment(folder_path, name=model_name)

In [None]:
# Based on the most correlated V4 columnar responses, assign the V4 columnar properties to the (R)SOM units
v4idx = np.load(folder_path + "V4DT/idx.npy") # (128, 128), preserved only with 16 domains of 2875 neuronal columns, not the full 3048 neuronal columns
roi = np.load(folder_path + "V4DT/ROI.npy").T # (128, 128), of which 3048 neuronal columns are rois
roi_mask = roi != 0
assert int(np.sum(roi_mask)) == 3048
# selected: 0-based position of every domain-labelled column among the 3048 ROI columns (row-major order),
#           i.e. the last-axis index into assign_*.npy
# selected_ij: the (i, j) map position of that column. Both are sized from ROI AND domain label, so no
#              unfilled rows can appear even if a labelled pixel lies outside the ROI.
selected = np.flatnonzero(v4idx[roi_mask] != 0) # (2875,)
selected_ij = np.argwhere(roi_mask & (v4idx != 0)) # (2875, 2)
assert len(selected) == len(selected_ij) == 2875


def _assign_v4_properties(model_name):
    """Give each (R)SOM unit the V4 properties of its best-correlated domain-labelled V4 column.

    Loads ``folder_path + model_name + "/assign_<model>.npy"`` (60, 60, 3048), keeps the 2875 selected
    columns, takes the per-unit argmax, saves the looked-up maps to ``.../assigned.npz`` under keys
    suffixed with the lower-case model name, and returns them as
    (matched_idx, polar_angle, eccentricity, dispersity, dispersity_raw).
    """
    suffix = model_name.lower()
    correlation = np.load(folder_path + model_name + "/assign_" + suffix + ".npy")[:, :, selected] # (60, 60, 2875)
    best = np.argmax(correlation, axis=2) # (60, 60) index of the maximum correlation
    xs, ys = selected_ij[best, 0], selected_ij[best, 1] # map position of each unit's best match
    # float64 keeps the dtypes the np.zeros-filled arrays had before
    assigned = {
        "matched_idx_" + suffix: v4idx[xs, ys].astype(np.float64),
        "polar_angle_" + suffix: polar_angle[xs, ys, :].astype(np.float64),
        "eccentricity_" + suffix: eccentricity[xs, ys, :].astype(np.float64),
        "dispersity_" + suffix: dispersity[xs, ys, :].astype(np.float64),
        "dispersity_raw_" + suffix: dispersity_raw[xs, ys].astype(np.float64),
    }
    np.savez(folder_path + model_name + "/assigned.npz", **assigned)
    return tuple(assigned.values())


name = "RSOM"
matched_idx_rsom, polar_angle_rsom, eccentricity_rsom, dispersity_rsom, dispersity_raw_rsom = _assign_v4_properties(name)

name = "SOM"
matched_idx_som, polar_angle_som, eccentricity_som, dispersity_som, dispersity_raw_som = _assign_v4_properties(name)

# Domain map

In [None]:
# given a matched som grid corresponding domain number, remove the unconnected components smaller than a threshold
def connected_components(matched, threshold, return_largest=False):
    """Erase small 8-connected patches of every non-zero domain label in a label map.

    Parameters
    ----------
    matched : 2-D array
        Domain label per grid; 0 is background and is left untouched.
    threshold : int or float
        Components with fewer than ``threshold`` grids are set to 0.
    return_largest : bool, default False
        If True, also return a map in which, for every label, only its largest component survives
        (all components tied for the largest size are kept).

    Returns
    -------
    connected : array, same shape and dtype as ``matched``, small components zeroed.
    largest : array, only when ``return_largest`` is True.
    ``matched`` itself is not modified.
    """
    structure = np.ones((3, 3), dtype=int)  # 8-connectivity: edge and diagonal neighbours both count
    connected = np.copy(matched)
    largest = np.copy(matched) if return_largest else None
    for class_num in np.unique(matched):
        if class_num == 0:
            continue  # background
        class_mask = (matched == class_num)
        labeled_array, num_features = label(class_mask, structure=structure)
        component_ids = np.arange(1, num_features + 1)
        component_sizes = np.asarray(ndi_sum(class_mask, labeled_array, index=component_ids))
        # one map pass per label instead of one per component
        connected[np.isin(labeled_array, component_ids[component_sizes < threshold])] = 0
        if return_largest:
            # max is loop-invariant: compute it once per label
            not_largest = component_ids[component_sizes != component_sizes.max()]
            largest[np.isin(labeled_array, not_largest)] = 0
    if return_largest:
        return connected, largest
    return connected

# Domain colour maps. For V4, labels are shifted by +1 so that erased (small) components inside the ROI
# become label 1 and stay distinguishable from the non-ROI background, which is then set to 0.
v4idx = np.load(folder_path + "V4DT/idx.npy")
roi = np.load(folder_path + "V4DT/ROI.npy").T # transpose to match V4idx
v4idx = 1 + connected_components(v4idx, threshold=10, return_largest=False)
v4idx[roi != 1] = 0 # non roi area
# zoom in to the ROI window, then flip both top-bottom and left-right
cc_color_v4 = np.flip(v4idx[14:90, 23:90], axis=(0, 1))
# (R)SOM domain maps with components smaller than 10 units erased
cc_color_som = connected_components(matched_idx_som.T, 10)
cc_color_rsom = connected_components(matched_idx_rsom.T, 10)

# Domain size & adjacency matrix evaluation

In [None]:
# given a matched som grid map detailing grid domain label, collect every connected component's domain label, size and neighbours
def connected_components_visual(matched, num_domains=16):
    """Summarise every 8-connected component of every domain label in a label map.

    Parameters
    ----------
    matched : 2-D array
        Domain label per grid (0 = background / erased, domains numbered 1..num_domains).
    num_domains : int, default 16
        Number of domain labels; sets the length of ``domain_sizes``.

    Returns
    -------
    sizes : int array
        Grid count of every component, grouped by label in ascending label order and sorted
        largest-first within a label.
    labels : int array
        Domain label of each entry of ``sizes``.
    adjacency : list of arrays
        For each entry of ``sizes``, the sorted distinct non-zero labels of other domains touching that
        component (8-connectivity; outside the map counts as background).
    domain_sizes : (num_domains,) float array
        Mean component size of domain d+1 divided by the total size of all components. A domain with no
        component gets 0 (it used to be NaN, which turned every downstream correlation into NaN).
    """
    from scipy.ndimage import binary_dilation

    structure = np.ones((3, 3), dtype=int)  # 8-connectivity: edge and diagonal neighbours both count
    labels_int = np.asarray(matched).astype(int)  # neighbour labels are compared as ints
    sizes, labels, adjacency = [], [], []
    for class_num in np.unique(matched):
        if class_num == 0:
            continue  # background
        class_mask = (matched == class_num)
        labeled_array, num_features = label(class_mask, structure=structure)
        component_sizes = np.asarray(ndi_sum(class_mask, labeled_array, index=range(1, num_features + 1)))
        adjacency_domain = []
        for k in range(num_features):
            # every grid within one step (incl. diagonals) of the component; the map border acts as background
            ring = binary_dilation(labeled_array == (k + 1), structure=structure)
            touching = labels_int[ring]
            touching = touching[(touching != 0) & (touching != int(class_num))]
            adjacency_domain.append(np.unique(touching))
        # order this label's components from largest to smallest
        sorted_indices = np.argsort(component_sizes)[::-1]
        sizes.extend(component_sizes[sorted_indices])
        labels.extend([class_num] * num_features)
        adjacency.extend(adjacency_domain[i] for i in sorted_indices)

    sizes = np.array(sizes).astype(int)
    labels = np.array(labels).astype(int)
    total_sizes = np.sum(sizes)
    domain_sizes = np.zeros(num_domains)
    for d in range(num_domains):
        members = sizes[labels == d + 1]
        if members.size:  # absent domains keep 0 instead of np.mean([]) -> NaN
            domain_sizes[d] = np.mean(members) / total_sizes
    return sizes, labels, adjacency, domain_sizes

# domain adjacency matrix correlation against V4 benchmark
def adjacency_matrix(idx, sizes, num_domains=16):
    """Count 8-neighbour contacts between different domains and normalise the counts to a maximum of 1.

    Parameters
    ----------
    idx : 2-D square array
        Domain label per grid (0 = background, domains 1..num_domains; values are truncated to int).
    sizes : sequence of length ``num_domains``
        Only its length is checked; it does not affect the result.
    num_domains : int, default 16

    Returns
    -------
    (num_domains, num_domains) symmetric float array. Entry [a-1, b-1] counts, over every grid of
    domain a and each of its 8 neighbours of domain b != a, one contact in each direction, divided by
    the largest entry. A map without any contact yields all zeros (previously NaN from 0/0).
    """
    labels_map = np.asarray(idx)
    size = labels_map.shape[0]
    assert size == labels_map.shape[1]
    assert len(sizes) == num_domains
    # zero padding so border grids see background beyond the map
    padded = np.pad(labels_map, ((1, 1), (1, 1)), 'constant', constant_values=0).astype(int)
    center = padded[1:size + 1, 1:size + 1]
    adj = np.zeros((num_domains, num_domains))
    for dm in (-1, 0, 1):
        for dn in (-1, 0, 1):
            if dm == 0 and dn == 0:
                continue
            neighbour = padded[1 + dm:1 + dm + size, 1 + dn:1 + dn + size]
            contact = (center != 0) & (neighbour != 0) & (neighbour != center)
            a = center[contact] - 1
            b = neighbour[contact] - 1
            # np.add.at accumulates repeated (a, b) pairs, unlike fancy-index +=
            np.add.at(adj, (a, b), 1)
            np.add.at(adj, (b, a), 1)
    peak = adj.max()
    if peak > 0:
        adj /= peak
    return adj

# V4 benchmark: domain map with components smaller than 10 columns erased
v4idx = connected_components(np.load(folder_path + "V4DT/idx.npy"), 10, return_largest=False) # (128, 128), preserved only with 16 domains, not the full 3048 neurons
sizes_v4, labels_v4, adjacency_v4, domain_sizes_v4 = connected_components_visual(v4idx)
adj_matrix_v4 = adjacency_matrix(v4idx, domain_sizes_v4)
print("V4: {} of components".format(len(sizes_v4)))


def _domain_scores(cc_map):
    """Component statistics of a cleaned (R)SOM domain map, its domain adjacency matrix, and the Pearson r
    of that matrix against the V4 benchmark's adj_matrix_v4."""
    comp_sizes, comp_labels, comp_adjacency, comp_domain_sizes = connected_components_visual(cc_map)
    adj = adjacency_matrix(cc_map, comp_domain_sizes)
    adj_r = stats.pearsonr(adj.flatten(), adj_matrix_v4.flatten())[0]
    return comp_sizes, comp_labels, comp_adjacency, comp_domain_sizes, adj, adj_r


# SOM: component statistics and adjacency score, then domain-size correlation in the report
sizes_som, labels_som, adjacency_som, domain_sizes_som, adj_matrix_som, adj_cor = _domain_scores(cc_color_som)
print("SOM: {} of components, with a domain size cor: {:.3f}; adjacency matrix cor: {:.3f}".format(len(sizes_som), stats.pearsonr(domain_sizes_v4, domain_sizes_som)[0], adj_cor))

# RSOM: same scores
sizes_rsom, labels_rsom, adjacency_rsom, domain_sizes_rsom, adj_matrix_rsom, adj_cor = _domain_scores(cc_color_rsom)
print("RSOM: {} of components, with a domain size cor: {:.3f}; adjacency matrix cor: {:.3f}".format(len(sizes_rsom), stats.pearsonr(domain_sizes_v4, domain_sizes_rsom)[0], adj_cor))

In [None]:
# given a matched som grid map detailing grid domain label, construct a relative positioning matrix
# num_domains by num_domains matrix, each entry detailing the averaged distance between two domains
def relative_positioning_matrix(matched, roi=None, num_domains=16):
    """Average map distance between grids of every pair of different domains, normalised to max 1.

    Parameters
    ----------
    matched : 2-D square array
        Domain label per grid (0 = erased, domains 1..num_domains; values are used as ints).
    roi : 2-D array or None, default None
        Grids with ``roi == 1`` (and a non-zero label) are the partner grids paired with each grid of the
        domain being processed; that grid itself need not be in the ROI. If None, the V4 ROI from
        ``folder_path`` is used for a 128 x 128 map and an all-ones ROI otherwise.
    num_domains : int, default 16

    Returns
    -------
    (num_domains, num_domains) symmetric float array: summed pair distances divided by the number of pairs,
    then divided by the largest entry. Domain pairs without grid pairs stay 0, and a map without any pair
    yields all zeros (previously NaN from 0/0).
    """
    matched = np.asarray(matched).astype(int)
    size = matched.shape[0]
    assert size == matched.shape[1] # squared map size
    if roi is None:
        if size == 128: # V4 benchmark
            roi = np.load(folder_path + "V4DT/ROI.npy").T # (128, 128), of which 3048 neuronal columns are rois
        else: # SOM or RSOM simulation
            roi = np.ones((size, size))
    roi = np.asarray(roi)

    labelled = matched[matched != 0]
    assert np.all((labelled >= 1) & (labelled <= num_domains)), "domain labels must lie in 1..num_domains"

    # partner grids: inside the ROI and carrying a domain label
    target_mask = (roi == 1) & (matched != 0)
    target_xy = np.argwhere(target_mask).astype(float) # (T, 2)
    target_onehot = np.zeros((len(target_xy), num_domains))
    target_onehot[np.arange(len(target_xy)), matched[target_mask] - 1] = 1.0
    targets_per_domain = target_onehot.sum(axis=0)

    position_matrix = np.zeros((num_domains, num_domains))
    count_matrix = np.zeros((num_domains, num_domains))
    for class_num in np.unique(labelled):
        d = class_num - 1
        source_xy = np.argwhere(matched == class_num).astype(float) # (S, 2)
        # summed distance from all grids of this domain to each partner grid
        dist_to_targets = np.hypot(source_xy[:, :1] - target_xy[:, 0], source_xy[:, 1:] - target_xy[:, 1]).sum(axis=0)
        dist_sum = dist_to_targets @ target_onehot # (num_domains,)
        pair_count = len(source_xy) * targets_per_domain
        dist_sum[d] = 0.0 # same-domain pairs are not counted
        pair_count[d] = 0.0
        # every pair contributes to both [a, b] and [b, a]
        position_matrix[d, :] += dist_sum
        position_matrix[:, d] += dist_sum
        count_matrix[d, :] += pair_count
        count_matrix[:, d] += pair_count
    # average the position matrix
    has_pairs = count_matrix != 0
    position_matrix[has_pairs] /= count_matrix[has_pairs]
    # normalize the position matrix
    peak = position_matrix.max()
    if peak > 0:
        position_matrix /= peak
    return position_matrix

# Relative positioning matrices (integer domain labels) for SOM, RSOM and the V4 benchmark
position_som, position_rsom, position_v4 = (
    relative_positioning_matrix(cc_color_som.astype(int)),
    relative_positioning_matrix(cc_color_rsom.astype(int)),
    relative_positioning_matrix(v4idx.astype(int)),
)
# relative positioning score: Pearson r of each model's flattened matrix against V4's
for label_text, position in (("relative positioning matrix correlation between SOM and V4", position_som),
                             ("relative positioning matrix correlation between RSOM and V4", position_rsom)):
    print(label_text, stats.pearsonr(position.flatten(), position_v4.flatten())[0])

# Feature dispersity measurements

In [None]:
# Dispersion index of high-dispersity grids:
# 1. Split each feature-dispersity (FD) map into high / low dispersity with Ward clustering.
# 2. Sample many k x k patches and take the fraction of high-dispersity grids in each.
# 3. Report var / mean of those fractions.
names = ["V4", "SOM", "RSOM"]
fd_raw = False # only the clustered (binary) FD map is implemented; True is rejected below
k = 20 # patch size
num_patches = 1000000 # random patches per map
rng = np.random.default_rng(0) # fixed seed so the dispersion indices are reproducible


def _high_dispersity_clusters(name):
    """Return (clusters, roi) where clusters is 1 on ROI grids of the high-dispersity Ward cluster, 0 elsewhere."""
    if name == "V4":
        FD_map = sio.loadmat(folder_path + "V4DT/Dispersity_results/FDraw.mat")["FD"] # (128, 128) <class 'numpy.ndarray'>
        roi = np.load(folder_path + "V4DT/ROI.npy").T # (128, 128) <class 'numpy.ndarray'>
    elif name in ("SOM", "RSOM"):
        FD_map = np.load(folder_path + name + "/assigned.npz")["dispersity_raw_" + name.lower()] # (60, 60)
        roi = np.ones((60, 60)).astype(int) # every unit of the simulated map is in the ROI
    else:
        raise ValueError("unknown map name: {!r}".format(name))
    roi_indices = np.where(roi == 1)
    roi_values = FD_map[roi_indices]
    # Ward hierarchical clustering of the 1-D dispersity values into 2 clusters
    Z = linkage(roi_values.reshape(-1, 1), method='ward')
    cluster_labels = fcluster(Z, t=2, criterion='maxclust')
    high = 1 if roi_values[cluster_labels == 1].mean() > roi_values[cluster_labels == 2].mean() else 2
    clusters = np.zeros_like(FD_map, dtype=int) # 0 = out-of-ROI or low dispersity
    clusters[roi_indices] = (cluster_labels == high).astype(int)
    return clusters, roi


def _random_patch_means(clusters, roi, k, x_bounds, y_bounds, n, rng, min_roi_fraction=0.0):
    """Mean of `clusters` over n random k x k patches lying inside rows [x_bounds) and cols [y_bounds).

    Only patches whose ROI sum is >= min_roi_fraction * k * k are kept. Top-left corners are drawn
    uniformly from lo .. hi - k inclusive, so a patch can reach the last row/column of the window (the
    former randint(lo, hi - k) excluded it).
    """
    def _integral(grid):
        # 2-D prefix sums with a zero first row/column: any window sum in O(1)
        return np.pad(np.asarray(grid), ((1, 0), (1, 0))).cumsum(axis=0).cumsum(axis=1)

    def _window_sum(integral, xs, ys):
        return integral[xs + k, ys + k] - integral[xs, ys + k] - integral[xs + k, ys] + integral[xs, ys]

    cluster_integral = _integral(clusters)
    roi_integral = _integral(roi)
    batches, collected = [], 0
    while collected < n:
        xs = rng.integers(x_bounds[0], x_bounds[1] - k + 1, size=n)
        ys = rng.integers(y_bounds[0], y_bounds[1] - k + 1, size=n)
        keep = _window_sum(roi_integral, xs, ys) >= min_roi_fraction * k * k
        if not keep.any():
            raise RuntimeError("no sampled patch satisfies the ROI coverage requirement")
        batches.append(_window_sum(cluster_integral, xs[keep], ys[keep]) / (k * k))
        collected += int(keep.sum())
    return np.concatenate(batches)[:n]


if fd_raw:
    raise NotImplementedError("fd_raw=True is not implemented: only the clustered (high / low) dispersity map is supported")
for name in names:
    clusters, roi = _high_dispersity_clusters(name)
    if name == "V4":
        # patches inside the V4 crop [14:90, 23:90], at least half covered by the ROI
        fd = _random_patch_means(clusters, roi, k, (14, 90), (23, 90), num_patches, rng, min_roi_fraction=0.5)
        dispersion = np.var(fd) / np.mean(fd)
        print("dispersion index of V4 benchmark : ", dispersion)
    else:
        fd = _random_patch_means(clusters, roi, k, (0, 60), (0, 60), num_patches, rng)
        dispersion = np.var(fd) / np.mean(fd)
        print("dispersion index of", name, ": ", dispersion)

# Pairwise columns / grids tuning correlation as a function of map distance

In [None]:
# pairwise grid tuning curve correlation as a function of exactly the map physical distance between them
def cordis(rsp):
    """Pairwise tuning correlation and map distance for every pair of ROI columns / map units.

    Parameters
    ----------
    rsp : ndarray
        Either V4 benchmark responses of shape (n_images, 128, 128), detected by
        ``shape[1] == shape[2] == 128``, where only the V4 ROI columns loaded from ``folder_path`` are
        used; or (R)SOM weights of shape (rows, cols, n_images), where every unit is used.

    Returns
    -------
    cordis_matrix : (num_roi, num_roi, 2) float array
        [..., 0] Pearson r of the two response vectors: 1.0 on the diagonal, NaN if a vector is constant.
        [..., 1] Euclidean distance between the two (i, j) map positions.
        ROI entries are ordered row-major over the map.
    """
    if rsp.shape[1] == rsp.shape[2] == 128: # (50000, 128, 128) V4 benchmark
        roi = np.load(folder_path + "V4DT/ROI.npy").T # (128, 128) <class 'numpy.ndarray'>
        mask = roi == 1
        response = np.array(rsp[:, mask].T, dtype=np.float64) # (num_roi, n_images)
    else: # artificial som map weight response: every unit is a ROI
        # roi = np.load("/TDANN41_roi.npy")
        mask = np.ones(rsp.shape[:2], dtype=bool)
        response = np.array(rsp[mask], dtype=np.float64) # (num_roi, n_images)
    position = np.argwhere(mask).astype(float) # (num_roi, 2), same row-major order as response
    num_roi = response.shape[0]

    # Pearson r == dot product of centered, unit-norm vectors: one matrix product replaces
    # num_roi^2 / 2 separate pearsonr calls. Constant rows become NaN (0 / 0), as with pearsonr.
    response -= response.mean(axis=1, keepdims=True)
    with np.errstate(invalid="ignore", divide="ignore"):
        response /= np.linalg.norm(response, axis=1, keepdims=True)
    corr = response @ response.T
    corr = np.clip((corr + corr.T) / 2, -1.0, 1.0) # exactly symmetric, clipped like pearsonr

    cordis_matrix = np.zeros((num_roi, num_roi, 2)) # 0th entry for correlation, 1st entry for pairwise map distance
    cordis_matrix[:, :, 0] = corr
    diag = np.arange(num_roi)
    cordis_matrix[diag, diag, 0] = 1.0
    diff = position[:, None, :] - position[None, :, :]
    cordis_matrix[:, :, 1] = np.sqrt(diff[..., 0] ** 2 + diff[..., 1] ** 2)
    return cordis_matrix

# Pairwise correlation / distance matrices (N, N, 2) for SOM, the V4 benchmark and RSOM.
def _load_weights(model_name):
    """Load an (R)SOM weight tensor trimmed to the 50000 images shared with the V4 digital twin."""
    weights = np.load(folder_path + model_name + "/weights.npy") # (60, 60, 50000(+))
    # same trimming for SOM and RSOM (as in V4_assignment); a no-op when there are exactly 50000 columns
    weights = weights[:, :, :50000]
    assert weights.shape[2] == 50000
    return weights

# SOM
name = "SOM"
cordis_matrix = cordis(_load_weights(name))
np.save(folder_path + name + "/cordis_V4_som.npy", cordis_matrix)

# V4 benchmark (responses: (50000, 128, 128))
cordis_matrix = cordis(np.load(folder_path + "V4DT/PRsp.npy"))
np.save(folder_path + "V4DT/cordis_v4_benchmark.npy", cordis_matrix)

# RSOM
name = "RSOM"
cordis_matrix = cordis(_load_weights(name))
np.save(folder_path + name + "/cordis_V4_rsom.npy", cordis_matrix)

del cordis_matrix # clear the memory (the loaded inputs were temporaries and are already released)

In [None]:
# Split the pairwise-distance range into equal-width segments and summarise the correlations
# falling into each segment (mean and standard deviation, averaged over all ROIs).
def cordis_avg(cordis_matrix, segment_num=100):
    """Average tuning correlation as a function of pairwise map distance.

    The range [min, max] of the distances ``cordis_matrix[:, :, 1]`` is split into
    ``segment_num`` equal-width segments. Segments are half-open ``[lo, hi)``, except the
    last one, which is closed so that the maximum-distance pairs are included. For every
    ROI (row ``i``), the correlations ``cordis_matrix[i, :, 0]`` whose distance falls into
    a segment are summarised by their mean and population standard deviation. These
    per-ROI statistics are then averaged over all ROIs that have at least one entry in
    that segment. Self-pairs on the diagonal are included like any other entry.

    Parameters
    ----------
    cordis_matrix : ndarray, shape (num_roi, num_roi, 2)
        ``[..., 0]`` holds pairwise correlations, ``[..., 1]`` pairwise distances
        (the layout written by ``cordis``). The array is not modified.
    segment_num : int, default 100
        Number of equal-width distance segments; must be >= 1.

    Returns
    -------
    ndarray, shape (segment_num, 2)
        Column 0: ROI-averaged mean correlation of each segment.
        Column 1: ROI-averaged standard deviation of each segment.
        Segments that contain no entry for any ROI are NaN.

    Raises
    ------
    ValueError
        If ``segment_num < 1``.
    """
    num_roi = cordis_matrix.shape[0]
    assert num_roi == cordis_matrix.shape[1]
    if segment_num < 1:
        raise ValueError("segment_num must be at least 1")
    # Views into the input; nothing below writes to them.
    distance = cordis_matrix[:, :, 1]
    correlation = cordis_matrix[:, :, 0]
    dismin = np.min(distance)
    dismax = np.max(distance)
    segment_len = (dismax - dismin) / segment_num # length of one single segment

    if segment_len > 0:
        # edges[s] = dismin + s * segment_len; seg_idx[i, j] = s with edges[s] <= d < edges[s + 1]
        edges = dismin + np.arange(segment_num + 1) * segment_len
        seg_idx = np.searchsorted(edges, distance, side='right') - 1
        # Close the last segment on the right so pairs at exactly dismax are kept
        # (also absorbs rounding that leaves edges[-1] slightly below dismax).
        seg_idx = np.minimum(seg_idx, segment_num - 1)
    else:
        # All distances are identical: put every pair into the first segment.
        seg_idx = np.zeros(distance.shape, dtype=np.intp)

    # Flatten (roi, segment) into one bin index so a single bincount covers all ROIs.
    n_bins = num_roi * segment_num
    flat_idx = (seg_idx + segment_num * np.arange(num_roi)[:, None]).ravel()
    counts = np.bincount(flat_idx, minlength=n_bins).reshape(num_roi, segment_num)
    sums = np.bincount(flat_idx, weights=correlation.ravel(), minlength=n_bins).reshape(num_roi, segment_num)
    # Emptiness is tracked by count, not by a sentinel value that a real correlation could take.
    occupied = counts > 0
    roi_mean = np.divide(sums, counts, out=np.zeros_like(sums), where=occupied)
    # Two-pass variance (squared deviations from each segment's mean) for numerical stability.
    deviation = correlation - np.take_along_axis(roi_mean, seg_idx, axis=1)
    sq_sums = np.bincount(flat_idx, weights=(deviation ** 2).ravel(), minlength=n_bins).reshape(num_roi, segment_num)
    roi_std = np.sqrt(np.divide(sq_sums, counts, out=np.zeros_like(sq_sums), where=occupied))

    # Average the per-ROI statistics over the ROIs that actually populate each segment.
    n_rois = occupied.sum(axis=0)
    has_data = n_rois > 0
    mean_std_rois = np.full((segment_num, 2), np.nan)
    mean_std_rois[has_data, 0] = roi_mean[:, has_data].sum(axis=0) / n_rois[has_data] # average mean
    mean_std_rois[has_data, 1] = roi_std[:, has_data].sum(axis=0) / n_rois[has_data] # average std
    return mean_std_rois

# Distance-binned tuning-correlation curves (per-segment mean, std) for each map.
curve_sources = (
    ("SOM/cordis_V4_som.npy", 110), # SOM
    ("V4DT/cordis_v4_benchmark.npy", 100), # V4 benchmark
    ("RSOM/cordis_V4_rsom.npy", 110), # RSOM
)
cordis_som, cordis_v4, cordis_rsom = [
    cordis_avg(np.load(folder_path + rel_path), segment_num=n_segments)
    for rel_path, n_segments in curve_sources
]
# Pearson r between each model curve (first 100 segments) and the V4 curve: SOM first, then RSOM.
for model_curve in (cordis_som, cordis_rsom):
    print(pearsonr(model_curve[:100, 0], cordis_v4[:, 0])[0])

# Visualizations

In [None]:
# SOM: polar_angle_som, eccentricity_som, dispersity_som
# Load the SOM's assigned RGB maps and swap their first two (spatial) axes for display.
som_matched = np.load(folder_path + "SOM/assigned.npz")
polar_angle_som = np.swapaxes(som_matched['polar_angle_som'], 0, 1) # (60, 60, 3)
eccentricity_som = np.swapaxes(som_matched['eccentricity_som'], 0, 1) # (60, 60, 3)
dispersity_som = np.swapaxes(som_matched['dispersity_som'], 0, 1) # (60, 60, 3)
# Smoothed luminance (BT.601 luma weights) of the RGB maps; only used to draw contour lines.
# Adjust sigma for more or less smoothing.
polar_angle_grey = gaussian_filter(
    0.2989 * polar_angle_som[:, :, 0] + 0.5870 * polar_angle_som[:, :, 1] + 0.1140 * polar_angle_som[:, :, 2],
    sigma=4,
) # (60, 60)
eccentricity_grey = gaussian_filter(
    0.2989 * eccentricity_som[:, :, 0] + 0.5870 * eccentricity_som[:, :, 1] + 0.1140 * eccentricity_som[:, :, 2],
    sigma=4,
) # (60, 60)
# domain_colors is also used by the V4 figure cell below.
domain_colors = ['red', 'green', 'blue', 'yellow', 'magenta', 'cyan', 'orange', 'purple', 'brown', 'pink', 'lime', 'teal', 'navy', 'gold', 'silver', 'coral']
# Colormap for the label image: 0 (unassigned) is black, labels 1.. follow domain_colors.
cmap_black = ListedColormap(['black'] + domain_colors)

fig, axes = plt.subplots(1, 7, figsize=(21, 3))
plt.subplots_adjust(wspace=1) # Adjust the horizontal distance between subplots
fs_title = 16 # set font size for the title
offset = 0.3 # white space between the color-strip rectangles and the heatmap

# Panel 0: tuning correlation vs. distance (mean ± std over ROIs, first 100 segments)
cordis_curve = cordis_som[:100]
axes[0].plot(cordis_curve[:, 0], color='black')
axes[0].fill_between(np.arange(len(cordis_curve)), cordis_curve[:, 0] - cordis_curve[:, 1], cordis_curve[:, 0] + cordis_curve[:, 1], color='black', alpha=0.3)
axes[0].set_xticks([])
axes[0].set_yticks([])
axes[0].set_xlabel("Distance", fontsize=fs_title)
axes[0].set_ylabel('Tuning Correlation', fontsize=fs_title)
axes[0].spines['top'].set_visible(False)
axes[0].spines['right'].set_visible(False)

# Panel 1: label map
axes[1].imshow(cc_color_som, cmap=cmap_black)
axes[1].axis('off')
axes[1].set_title('SOM', fontsize=fs_title, pad=8)

# Panels 2-3: distance and adjacency matrices, each with its own colorbar and a strip of
# domain-colored squares along the top and left edges.
for ax, matrix, title in ((axes[2], position_som, "Distance Matrix"), (axes[3], adj_matrix_som, "Adjacency Matrix")):
    heatmap = ax.imshow(matrix, cmap='viridis') # perceptually uniform colormap for values
    # The colorbar must be built from the image drawn on this same panel.
    cbar = plt.colorbar(heatmap, ax=ax, shrink=0.9, aspect=20)
    cbar.set_ticks([0, 1]) # Set ticks to only 0 and 1
    cbar.set_ticklabels(['0', '1'])
    for i, color in enumerate(domain_colors):
        # With imshow's default origin='upper', y < -0.5 lies above the heatmap: top strip (x index)
        ax.add_patch(Rectangle((i - 0.5, -1.5 - offset), 1, 1, color=color, transform=ax.transData, clip_on=False))
        # x < -0.5 lies to the left of the heatmap: left strip (y index)
        ax.add_patch(Rectangle((-1.5 - offset, i - 0.5), 1, 1, color=color, transform=ax.transData, clip_on=False))
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_title(title, fontsize=fs_title, pad=20) # larger pad leaves room for the top color strip

# Panels 4-6: retinotopy with smoothed-luminance contours, and dispersity
axes[4].imshow(polar_angle_som)
axes[4].set_title("Polar Angle", fontsize=fs_title, pad=8)
polar_contours = axes[4].contour(polar_angle_grey, levels=1, colors='white', linewidths=1.5)
axes[4].axis('off')

axes[5].imshow(eccentricity_som)
axes[5].set_title("Eccentricity", fontsize=fs_title, pad=8)
eccentricity_contours = axes[5].contour(eccentricity_grey, levels=2, colors='white', linewidths=1.5)
axes[5].axis('off')

axes[6].imshow(dispersity_som)
axes[6].set_title("Dispersity", fontsize=fs_title, pad=8)
axes[6].axis('off')
plt.tight_layout()
fig.savefig(folder_path + "Fig2/SOM.png", dpi=1000)
del fig, axes

In [None]:
# Two-panel view (not saved) of the V4 polar-angle and eccentricity maps with contour lines.
roi = np.load(folder_path + "V4DT/ROI.npy").T # (128, 128) <class 'numpy.ndarray'>
polar_angle = sio.loadmat(folder_path + "V4DT/RF_results/polar_angle.mat")['map'] # (128, 128, 3) <class 'numpy.ndarray'>
eccentricity = sio.loadmat(folder_path + "V4DT/RF_results/eccentricity.mat")['map'] # (128, 128, 3) <class 'numpy.ndarray'>
# White out all non-ROI voxels (boolean mask over the two spatial axes, broadcast over RGB).
white = np.ones((3))
outside_roi = roi != 1
polar_angle[outside_roi] = white
eccentricity[outside_roi] = white
# Crop window shown in the figure; crops are flipped along both axes (180° rotation).
crop_rows, crop_cols = slice(14, 90), slice(23, 90)
# polar angle
polar_angle = np.flip(polar_angle[crop_rows, crop_cols, :], axis=(0, 1))
polar_angle_grey = 0.2989 * polar_angle[:, :, 0] + 0.5870 * polar_angle[:, :, 1] + 0.1140 * polar_angle[:, :, 2]
polar_angle_grey = gaussian_filter(polar_angle_grey, sigma=2.95) # Apply Gaussian smoothing
# Erode the ROI mask so contours stay away from the white border blurred in by the smoothing.
roi_eroded = binary_erosion(roi, structure=np.ones((8, 8)))
roi_eroded = np.flip(roi_eroded[crop_rows, crop_cols], axis=(0, 1)) # same crop and 180° flip as the image
polar_angle_grey = np.where(roi_eroded, polar_angle_grey, np.nan) # Mask non-ROI regions with NaN
# eccentricity
eccentricity = np.flip(eccentricity[crop_rows, crop_cols, :], axis=(0, 1))
eccentricity_grey = 0.2989 * eccentricity[:, :, 0] + 0.5870 * eccentricity[:, :, 1] + 0.1140 * eccentricity[:, :, 2]
eccentricity_grey = gaussian_filter(eccentricity_grey, sigma=2) # Apply Gaussian smoothing
roi_eroded = binary_erosion(roi, structure=np.ones((6, 6)))
roi_eroded = np.flip(roi_eroded[crop_rows, crop_cols], axis=(0, 1)) # same crop and 180° flip as the image
eccentricity_grey = np.where(roi_eroded, eccentricity_grey, np.nan) # Mask non-ROI regions with NaN

fig, axes = plt.subplots(1, 2, figsize=(6, 3))
axes[0].imshow(polar_angle)
polar_contours = axes[0].contour(polar_angle_grey, levels=1, colors='white', linewidths=1.5)
axes[0].axis('off')

axes[1].imshow(eccentricity)
eccentricity_contours = axes[1].contour(eccentricity_grey, levels=2, colors='white', linewidths=1.5)
axes[1].axis('off'); # trailing ';' keeps the returned axis limits from being echoed as cell output

In [None]:
# visualize the retinotopy and dispersity for the V4 benchmark
# Reload the raw maps and paint every non-ROI voxel white.
dispersity = sio.loadmat(folder_path + "V4DT/Dispersity_results/dispersity.mat")['cmap'] # (128, 128, 3) <class 'numpy.ndarray'>
roi = np.load(folder_path + "V4DT/ROI.npy").T # (128, 128) <class 'numpy.ndarray'>
polar_angle = sio.loadmat(folder_path + "V4DT/RF_results/polar_angle.mat")['map'] # (128, 128, 3) <class 'numpy.ndarray'>
eccentricity = sio.loadmat(folder_path + "V4DT/RF_results/eccentricity.mat")['map'] # (128, 128, 3) <class 'numpy.ndarray'>
white = np.ones((3))
outside_roi = roi != 1 # boolean mask over the two spatial axes, broadcast over RGB
polar_angle[outside_roi] = white
eccentricity[outside_roi] = white
dispersity[outside_roi] = white
# Crop window shown in the figure; crops are flipped along both axes (180° rotation).
crop_rows, crop_cols = slice(14, 90), slice(23, 90)
# polar angle
polar_angle = np.flip(polar_angle[crop_rows, crop_cols, :], axis=(0, 1))
polar_angle_grey = 0.2989 * polar_angle[:, :, 0] + 0.5870 * polar_angle[:, :, 1] + 0.1140 * polar_angle[:, :, 2]
polar_angle_grey = gaussian_filter(polar_angle_grey, sigma=2.95) # Apply Gaussian smoothing
# Erode the ROI mask so contours stay away from the white border blurred in by the smoothing.
roi_eroded = binary_erosion(roi, structure=np.ones((8, 8)))
roi_eroded = np.flip(roi_eroded[crop_rows, crop_cols], axis=(0, 1)) # same crop and 180° flip as the image
polar_angle_grey = np.where(roi_eroded, polar_angle_grey, np.nan) # Mask non-ROI regions with NaN
# eccentricity
eccentricity = np.flip(eccentricity[crop_rows, crop_cols, :], axis=(0, 1))
eccentricity_grey = 0.2989 * eccentricity[:, :, 0] + 0.5870 * eccentricity[:, :, 1] + 0.1140 * eccentricity[:, :, 2]
eccentricity_grey = gaussian_filter(eccentricity_grey, sigma=2) # Apply Gaussian smoothing
roi_eroded = binary_erosion(roi, structure=np.ones((6, 6)))
roi_eroded = np.flip(roi_eroded[crop_rows, crop_cols], axis=(0, 1)) # same crop and 180° flip as the image
eccentricity_grey = np.where(roi_eroded, eccentricity_grey, np.nan) # Mask non-ROI regions with NaN

cmap_wb = ListedColormap(['white', 'black', 'red', 'green', 'blue', 'yellow', 'magenta', 'cyan', 'orange', 'purple',
                             'brown', 'pink', 'lime', 'teal', 'navy', 'gold', 'silver', 'coral']) # white for non-ROIs, black for erased

fig, axes = plt.subplots(1, 7, figsize=(21, 3))
fs_title = 16 # set font size for the title
offset = 0.3 # white space between the color-strip rectangles and the heatmap

# Panel 0: tuning correlation vs. distance (mean ± std over ROIs)
cordis_curve = cordis_v4[:100]
axes[0].plot(cordis_curve[:, 0], color='black')
axes[0].fill_between(np.arange(len(cordis_curve)), cordis_curve[:, 0] - cordis_curve[:, 1], cordis_curve[:, 0] + cordis_curve[:, 1], color='black', alpha=0.3)
axes[0].set_xticks([])
axes[0].set_yticks([])
axes[0].set_xlabel("Distance", fontsize=fs_title)
axes[0].set_ylabel('Tuning Correlation', fontsize=fs_title)
axes[0].spines['top'].set_visible(False)
axes[0].spines['right'].set_visible(False)

# Panel 1: label map
axes[1].imshow(cc_color_v4, cmap=cmap_wb)
axes[1].axis('off')
axes[1].set_title('V4', fontsize=fs_title)

# Panels 2-3: distance and adjacency matrices, each with its own colorbar and a strip of
# domain-colored squares along the top and left edges (domain_colors comes from the SOM cell).
for ax, matrix in ((axes[2], position_v4), (axes[3], adj_matrix_v4)):
    heatmap = ax.imshow(matrix, cmap='viridis') # perceptually uniform colormap for values
    # The colorbar must be built from the image drawn on this same panel.
    cbar = plt.colorbar(heatmap, ax=ax, shrink=0.9, aspect=20)
    cbar.set_ticks([0, 1]) # Set ticks to only 0 and 1
    cbar.set_ticklabels(['0', '1'])
    for i, color in enumerate(domain_colors):
        # With imshow's default origin='upper', y < -0.5 lies above the heatmap: top strip (x index)
        ax.add_patch(Rectangle((i - 0.5, -1.5 - offset), 1, 1, color=color, transform=ax.transData, clip_on=False))
        # x < -0.5 lies to the left of the heatmap: left strip (y index)
        ax.add_patch(Rectangle((-1.5 - offset, i - 0.5), 1, 1, color=color, transform=ax.transData, clip_on=False))
    ax.set_xticks([])
    ax.set_yticks([])

# Panels 4-6: cropped/flipped retinotopy with smoothed-luminance contours, and dispersity
axes[4].imshow(polar_angle)
polar_contours = axes[4].contour(polar_angle_grey, levels=1, colors='white', linewidths=1.5)
axes[4].axis('off')

axes[5].imshow(eccentricity)
eccentricity_contours = axes[5].contour(eccentricity_grey, levels=2, colors='white', linewidths=1.5)
axes[5].axis('off')

axes[6].imshow(np.flip(dispersity[crop_rows, crop_cols, :], axis=(0, 1)))
axes[6].axis('off')
plt.tight_layout()
fig.savefig(folder_path + "Fig2/V4.png", dpi=1000)
del fig, axes

In [None]:
# RSOM: polar_angle_rsom, eccentricity_rsom, dispersity_rsom
# Load the RSOM's assigned RGB maps and swap their first two (spatial) axes for display.
rsom_matched = np.load(folder_path + "RSOM/assigned.npz")
polar_angle_rsom = np.swapaxes(rsom_matched['polar_angle_rsom'], 0, 1) # (60, 60, 3)
eccentricity_rsom = np.swapaxes(rsom_matched['eccentricity_rsom'], 0, 1) # (60, 60, 3)
dispersity_rsom = np.swapaxes(rsom_matched['dispersity_rsom'], 0, 1) # (60, 60, 3)
# Smoothed luminance (BT.601 luma weights) of the RGB maps; only used to draw contour lines.
# Adjust sigma for more or less smoothing.
polar_angle_grey = gaussian_filter(
    0.2989 * polar_angle_rsom[:, :, 0] + 0.5870 * polar_angle_rsom[:, :, 1] + 0.1140 * polar_angle_rsom[:, :, 2],
    sigma=4,
) # (60, 60)
eccentricity_grey = gaussian_filter(
    0.2989 * eccentricity_rsom[:, :, 0] + 0.5870 * eccentricity_rsom[:, :, 1] + 0.1140 * eccentricity_rsom[:, :, 2],
    sigma=4,
) # (60, 60)
domain_colors = ['red', 'green', 'blue', 'yellow', 'magenta', 'cyan', 'orange', 'purple', 'brown', 'pink', 'lime', 'teal', 'navy', 'gold', 'silver', 'coral']
# Colormap for the label image: 0 (unassigned) is black, labels 1.. follow domain_colors.
cmap_black = ListedColormap(['black'] + domain_colors)

fig, axes = plt.subplots(1, 7, figsize=(21, 3))
fs_title = 16 # set font size for the title
offset = 0.3 # white space between the color-strip rectangles and the heatmap

# Panel 0: tuning correlation vs. distance (mean ± std over ROIs, first 100 segments)
cordis_curve = cordis_rsom[:100]
axes[0].plot(cordis_curve[:, 0], color='black')
axes[0].fill_between(np.arange(len(cordis_curve)), cordis_curve[:, 0] - cordis_curve[:, 1], cordis_curve[:, 0] + cordis_curve[:, 1], color='black', alpha=0.3)
axes[0].set_xticks([])
axes[0].set_yticks([])
axes[0].set_xlabel("Distance", fontsize=fs_title)
axes[0].set_ylabel('Tuning Correlation', fontsize=fs_title)
axes[0].spines['top'].set_visible(False)
axes[0].spines['right'].set_visible(False)

# Panel 1: label map
axes[1].imshow(cc_color_rsom, cmap=cmap_black)
axes[1].axis('off')
axes[1].set_title('RSOM', fontsize=fs_title)

# Panels 2-3: distance and adjacency matrices, each with its own colorbar and a strip of
# domain-colored squares along the top and left edges.
for ax, matrix in ((axes[2], position_rsom), (axes[3], adj_matrix_rsom)):
    heatmap = ax.imshow(matrix, cmap='viridis') # perceptually uniform colormap for values
    # The colorbar must be built from the image drawn on this same panel.
    cbar = plt.colorbar(heatmap, ax=ax, shrink=0.9, aspect=20)
    cbar.set_ticks([0, 1]) # Set ticks to only 0 and 1
    cbar.set_ticklabels(['0', '1'])
    for i, color in enumerate(domain_colors):
        # With imshow's default origin='upper', y < -0.5 lies above the heatmap: top strip (x index)
        ax.add_patch(Rectangle((i - 0.5, -1.5 - offset), 1, 1, color=color, transform=ax.transData, clip_on=False))
        # x < -0.5 lies to the left of the heatmap: left strip (y index)
        ax.add_patch(Rectangle((-1.5 - offset, i - 0.5), 1, 1, color=color, transform=ax.transData, clip_on=False))
    ax.set_xticks([])
    ax.set_yticks([])

# Panels 4-6: retinotopy with smoothed-luminance contours, and dispersity
axes[4].imshow(polar_angle_rsom)
polar_contours = axes[4].contour(polar_angle_grey, levels=1, colors='white', linewidths=1.5)
axes[4].axis('off')

axes[5].imshow(eccentricity_rsom)
eccentricity_contours = axes[5].contour(eccentricity_grey, levels=2, colors='white', linewidths=1.5)
axes[5].axis('off')

axes[6].imshow(dispersity_rsom)
axes[6].axis('off')
plt.tight_layout()
fig.savefig(folder_path + "Fig2/RSOM.png", dpi=1000)
del fig, axes