In [1]:
import numpy as np
from tqdm import tqdm
from skimage.transform import resize
import matplotlib.image
from matplotlib import pyplot as plt
from matplotlib import colors
from PIL import Image
from scipy.stats import pearsonr

# V4 digital twin image preference map

In [2]:
# Visualize the top-9 images of each grid in the V4 digital twin
features = np.load("/Users/dunhan/Desktop/topoV4/Tianye/Analysis/S4_preference_map/PRsp.npy")
features = np.transpose(features, (2, 1, 0)) # (128, 128, 50000)
features = np.swapaxes(features, 0, 1)
grid_num = int(features.shape[0])
roi = np.load("/Users/dunhan/Desktop/topoV4/Tianye/Analysis/S4_preference_map/ROI.npy").T # 3048 roi voxels
# define the path to the folder containing all images
folder_path = "/Users/dunhan/Desktop/topoV4/50K_Imgset/"
# define the size of a single image
img_size = 30 # 4
line_width = 5 # 1
top_img_num = 9
# create a blank map of black color (R=0, G=0, B=0)
map = np.zeros((grid_num * (img_size*3 + line_width) + line_width,
                grid_num * (img_size*3 + line_width) + line_width, 
                3))
grid_top_images = np.zeros((grid_num, grid_num, top_img_num)) # store the top 9 images of each grid
# fill the map with the images
for i in tqdm(range(grid_num), desc="map initialization...", disable=False):
    for j in range(grid_num):
            if roi[i, j] == 1:
                # 1-indexed image names
                image_label = np.arange(50000) + 1 # or "labels + 1" if features were selected from the top-3 responsive images
                # sort the mean responses (from small to large) and the image_label according to the order of mean responses
                _, image_label = zip(*sorted(zip(features[i, j, :],image_label)))
                # take the top nine images with largest mean response
                image_label = np.flip(image_label[-top_img_num:])
                # store the 0-indexed 9 imgs of each grid
                grid_top_images[i, j, :] = (image_label - 1).astype(int)

                # locate the top left corner of the current grid in the map
                x = i * (img_size*3 + line_width) + line_width
                y = j * (img_size*3 + line_width) + line_width
                
                # fill the map's current grid with the selected nine images
                for row in range(3):
                    for col in range(3):    
                        # load the image
                        path = folder_path + str(int(image_label[row*3+col])) + ".bmp" # the image name is 1-indexed
                        img = np.array(Image.open(path))[20:80, 20:80, :] # obtain the non-blurred central part of the image
                        img = resize(img, (img_size, img_size, 3), anti_aliasing=True) # resize the image
                        # put the image onto the map
                        map[x + row * img_size : x + (row + 1) * img_size, 
                            y + col * img_size : y + (col + 1) * img_size, 
                            :] = img
            else:
                # fill the map's current grid with pure white color
                x = i * (img_size*3 + line_width)
                y = j * (img_size*3 + line_width)
                map[x : x + img_size*3 + line_width*2, y : y + img_size*3 + line_width*2, :] = 1.0
map = np.fliplr(np.flipud(map)) # top-bottom flip and then left-right flip
size = (img_size*3 + line_width)
map = map[38*size:114*size, 38*size:105*size, :] # only keep the roi part
map_save_path = "/Users/dunhan/Desktop/topoV4/som/Figures/V4_DT_full.bmp"
np.save("/Users/dunhan/Desktop/topoV4/som/V4_benchmark/rsptop_0index", grid_top_images)
matplotlib.image.imsave(map_save_path, map)
del features, map

map initialization...: 100%|██████████| 128/128 [01:14<00:00,  1.72it/s]


In [None]:
# visualization of RSOM training from V4 digital twin neuronal columns' tuning curves + estimated retinotopic positions
# visualize the V4 data ROI shape
roi = np.load("/Users/dunhan/Desktop/topoV4/Tianye/Analysis/S4_preference_map/ROI.npy").T # (128, 128) <class 'numpy.ndarray'>
roi = np.flip(roi)
roi = roi[37:115, 37:106]

fig, ax = plt.subplots(figsize=(6, 6))
cmap = colors.ListedColormap(['white', 'lightyellow'])  # 0 for white, 1 for lightyellow
ax.imshow(roi, cmap=cmap, interpolation='none')
roi_grid = np.ma.masked_where(roi == 0, roi) # Create a masked array to only apply grid where roi is 1
# Add grid lines only for the region of interest
for i in range(roi.shape[0]):
    for j in range(roi.shape[1]):
        if roi[i, j] == 1:  # Only show grid for ROI entries with value 1
            ax.plot([j-0.5, j+0.5], [i-0.5, i-0.5], color='black', linestyle='--', linewidth=0.5)  # Top line
            ax.plot([j-0.5, j+0.5], [i+0.5, i+0.5], color='black', linestyle='--', linewidth=0.5)  # Bottom line
            ax.plot([j-0.5, j-0.5], [i-0.5, i+0.5], color='black', linestyle='--', linewidth=0.5)  # Left line
            ax.plot([j+0.5, j+0.5], [i-0.5, i+0.5], color='black', linestyle='--', linewidth=0.5)  # Right line
# Add contour lines for the region of interest with a 3D effect (using multiple levels and shadowing)
contour = ax.contour(roi, levels=[0.5], colors='black', linewidths=2, alpha=0.9)
# Adding additional contour for 3D effect (shadow)
ax.contour(roi, levels=[0.5], colors='gray', linewidths=6, alpha=0.75, linestyles='solid')
# Remove axis labels and ticks
ax.set_xticks([])
ax.set_yticks([])
# Remove the axis box
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.spines['left'].set_visible(False)
ax.spines['bottom'].set_visible(False)
# Show the plot
plt.tight_layout()
# plt.show()
plt.savefig("/Users/dunhan/Desktop/topoV4/som/Figures/ROI.png", dpi=1000)
plt.close()

# Create a 60 by 60 grid to visualize the SOM map
rows, cols = 60, 60
# Create figure and axes
fig, ax = plt.subplots(figsize=(4, 4))  # Adjust figsize for higher resolution
# Plot the grid
for x in range(rows + 1):
    ax.plot([x, x], [0, cols], color='black', linewidth=0.5)
for y in range(cols + 1):
    ax.plot([0, rows], [y, y], color='black', linewidth=0.5)
# Set the aspect of the plot to be equal
ax.set_aspect('equal')
# Set limits to match grid size
ax.set_xlim(0, rows)
ax.set_ylim(0, cols)
# Turn off axes ticks
ax.set_xticks([])
ax.set_yticks([])
# Save as high resolution
plt.savefig('/Users/dunhan/Desktop/grids.png', dpi=3000)  # Save the image in high resolution

# RSOM image preference map

In [None]:
name = "som16e_250tr_05t" # and "som16e" for the SOM
features = np.load("/Users/dunhan/Desktop/topoV4/som/weight_as_units/" + name + "/weights.npy") # (60, 60, 50000(+))
if features.shape[2] > 50000:
    features = features[:, :, :-int(features.shape[2] - 50000)] # remove positional information from the data
    assert features.shape[2] == 50000

# define the path to the folder containing all images
folder_path = "/Users/dunhan/Desktop/topoV4/50K_Imgset/"
# define the size of a single image
img_size = 30
line_width = 5
# create a blank map of black color (R=0, G=0, B=0)
map = np.zeros((grid_num * (img_size*3 + line_width) + line_width,
                grid_num * (img_size*3 + line_width) + line_width, 
                3))
grid_top_images = np.zeros((grid_num, grid_num, top_img_num)) # store the top 9 images of each grid
# fill the map with the images
for i in tqdm(range(grid_num), desc="map initialization...", disable=False):
    for j in range(grid_num):
            if roi[i, j] == 1:
                # 1-indexed image names
                image_label = np.arange(50000) + 1 # or "labels + 1" if features were selected from the top-3 responsive images
                # sort the mean responses (from small to large) and the image_label according to the order of mean responses
                _, image_label = zip(*sorted(zip(features[i, j, :],image_label)))
                # take the top nine images with largest mean response
                image_label = np.flip(image_label[-top_img_num:])
                # store the 0-indexed 9 imgs of each grid
                grid_top_images[i, j, :] = (image_label - 1).astype(int)

                # locate the top left corner of the current grid in the map
                x = i * (img_size*3 + line_width) + line_width
                y = j * (img_size*3 + line_width) + line_width
                
                # fill the map's current grid with the selected nine images
                for row in range(3):
                    for col in range(3):
                                
                        # load the image
                        path = folder_path + str(int(image_label[row*3+col])) + ".bmp" # the image name is 1-indexed
                        img = np.array(Image.open(path))[20:80, 20:80, :] # obtain the non-blurred central part of the image
                        img = resize(img, (img_size, img_size, 3), anti_aliasing=True) # resize the image

                        # put the image onto the map
                        map[x + row * img_size : x + (row + 1) * img_size, 
                            y + col * img_size : y + (col + 1) * img_size, 
                            :] = img

map_save_path = "/Users/dunhan/Desktop/topoV4/som/weight_as_units/" + name + "/weights.bmp"
np.save("/Users/dunhan/Desktop/topoV4/som/weight_as_units/" + name + "/rsptop_0index", grid_top_images)
matplotlib.image.imsave(map_save_path, map)

# Tuning curve shape

In [3]:
# tuning curve comparison figure
# Compare the tuning curves of V4 digital twin benchmark and the learned SOM
# Load the V4 digital twin responses to 50k images
response = np.load("/Users/dunhan/Desktop/topoV4/Tianye/Analysis/S4_preference_map/Prsp.npy")
roi = np.load("/Users/dunhan/Desktop/topoV4/Tianye/Analysis/S4_preference_map/ROI.npy").T # (128, 128) <class 'numpy.ndarray'>
v4_benchmark = response[:, roi == 1] # (50000, 128, 128) into (50000, 3048)
print(v4_benchmark.shape) # (50000, 3048)
v4_benchmark = np.sort(v4_benchmark, axis=0)[::-1]  # Sort each column along rows (flip for descending order)
print(v4_benchmark.shape) # (50000, 3048)
v4_mean = np.mean(v4_benchmark, axis=1) # mean tuning curve
v4_std = np.std(v4_benchmark, axis=1) # tuning curve std
index = []
for i in range(v4_benchmark.shape[1]):
    assert len(v4_benchmark[:, i]) == 50000
    half = (np.max(v4_benchmark[:, i]) - np.min(v4_benchmark[:, i])) / 2
    for j in range(50000):
        if v4_benchmark[j, i] <= half:
            index.append(j)
            break
print("V4 half:", np.mean(index), np.std(index))

name_all = ["som16e_250tr_05t"]
# name_all = ["som16e_250tr_1020"]
for name in name_all:
    rsp = np.load("/Users/dunhan/Desktop/topoV4/som/weight_as_units/" + name + "/weights.npy")
    if rsp.shape[2] > 50000: rsp = rsp[:, :, :-int(rsp.shape[2] - 50000)] # now rsp / som weights has shape (60, 60, 50000)
    rsp = rsp.reshape(-1, rsp.shape[2]) # (3600, 50000)
    rsp = np.sort(rsp.T, axis=0)[::-1]  # Sort each column along rows (flip for descending order), (50000, 3600)
    rsp_mean = np.mean(rsp, axis=1) # mean tuning curve
    rsp_std = np.std(rsp, axis=1) # tuning curve std

    # normalize every tuning curve
    index = []
    for i in range(rsp.shape[1]):
        assert len(rsp[:, i]) == 50000
        half = (np.max(rsp[:, i]) - np.min(rsp[:, i])) / 2
        for j in range(50000):
            if rsp[j, i] <= half:
                index.append(j)
                break

    fig, axes = plt.subplots(1, 2, figsize=(6, 3))
    print("V4 and RSOM tuning curve correlation: {:.3f}".format(pearsonr(np.concatenate((v4_mean, v4_std)), np.concatenate((rsp_mean, rsp_std)))[0]))
    axes[0].plot(v4_mean, color='black')
    axes[0].fill_between(range(len(v4_mean)), v4_mean - v4_std, v4_mean + v4_std, color='black', alpha=0.3)
    axes[0].set_xticks([])
    axes[0].set_yticks([])
    axes[0].set_xlabel('50K imgs', fontsize=16)
    axes[0].set_ylabel('response', fontsize=16)
    axes[0].set_title('V4', fontsize=18)
    axes[1].plot(rsp_mean, color='gray')
    axes[1].fill_between(range(len(rsp_mean)), rsp_mean - rsp_std, rsp_mean + rsp_std, color='gray', alpha=0.3)
    axes[1].set_xticks([])
    axes[1].set_yticks([])
    axes[1].set_xlabel('50K imgs', fontsize=16)
    # axes[1].set_ylabel('response', fontsize=18)
    axes[1].set_title("RSOM", fontsize=18)
    fig = plt.gcf()
    plt.tight_layout()
    plt.close()
    # fig.savefig("/Users/dunhan/Desktop/topoV4/som/Figures/tunings.png", dpi=100)
    fig.savefig("/Users/dunhan/Desktop/topoV4/som/weight_as_units/" + name + "/tunings.png", dpi=1000)
    del rsp, rsp_mean, rsp_std
print("V4_RSOM half", np.mean(index), np.std(index))
del v4_benchmark, v4_mean, v4_std, name_all, name

# V4 half: 842.0524934383202 1297.4226355648088
# V4 and RSOM tuning curve correlation: 0.996
# V4_RSOM half 1157.6261111111112 1806.055900970211

(50000, 3048)
(50000, 3048)
V4 half: 842.0524934383202 1297.4226355648088
V4 and RSOM tuning curve correlation: 0.996
V4_RSOM half 1157.6261111111112 1806.055900970211


# Pairwise columns / grids tuning correlation as a function of map distance

In [2]:
# pairwise grid tuning curve coorelation as a function of exactly the map physical distance between them
def cordis(rsp):
    # check input response
    if rsp.shape[1] == rsp.shape[2] == 128: # (50000, 128, 128) V4 benchmark
        roi = np.load("/Users/dunhan/Desktop/topoV4/Tianye/Analysis/S4_preference_map/ROI.npy").T # (128, 128)
    else: roi = np.ones((60, 60)).astype(int) # artificial som map weight response
    size = roi.shape[0]
    assert size == roi.shape[1]
    num_roi = int(np.sum(roi))
    # response, roi data preparation
    response = np.zeros((num_roi, 50000))
    position = np.zeros((num_roi, 2))
    voxel_index = 0
    for i in range(size):
        for j in range(size):
            if roi[i, j] == 1:
                response[voxel_index, :] = rsp[:, i, j]
                position[voxel_index, 0] = i
                position[voxel_index, 1] = j
                voxel_index += 1
    assert voxel_index == num_roi
    # calculate the correlation matrix
    cordis_matrix = np.zeros((num_roi, num_roi, 2)) # 0th entry for correlation, 1st entry for pairwise map distance
    for i in tqdm(range(num_roi), desc="calculating correlation matrix...", disable=False):
        for j in range(i, num_roi):
            # correlation calculation
            if i != j:
                cordis_matrix[i, j, 0] = pearsonr(response[i, :], response[j, :])[0]
                cordis_matrix[j, i, 0] = cordis_matrix[i, j, 0]
            else: cordis_matrix[i, j, 0] = 1.0
            # distance calculation
            cordis_matrix[i, j, 1] = np.sqrt((position[i, 0] - position[j, 0]) ** 2 + (position[i, 1] - position[j, 1]) ** 2)
            cordis_matrix[j, i, 1] = cordis_matrix[i, j, 1]
    return cordis_matrix
# cordis_v4 = cordis(response)
# np.save("/Users/dunhan/Desktop/topoV4/som/V4_benchmark/cordis_v4.npy", cordis_v4)
# print(cordis_v4.shape)

def cordis_avg(cordis_matrix, segment_num=100):
    num_roi = cordis_matrix.shape[0]
    assert num_roi == cordis_matrix.shape[1]
    pd = cordis_matrix[:, :, 1] # distance (3048, 3048)
    pc = cordis_matrix[:, :, 0] # correlation (3048, 3048)
    dismax = np.max(pd) # maximum distance
    dismin = np.min(pd) # minimum distance
    segment_len = (dismax - dismin) / segment_num # length of one single segment
    segments = np.zeros((num_roi, segment_num, 2)) # store the mean correlation and standard deviation of all correlation estimates within all samples of each segment
    for i in tqdm(range(num_roi), desc="sorting...", disable=True):
        sorted_indices = np.argsort(pd[i, :]) # sort the distance from smallest to largest, get its indices
        pd[i, :] = pd[i, sorted_indices] # distance matrix row vector sorted, from smallest to largest
        pc[i, :] = pc[i, sorted_indices] # correlation matrix row vector sorted, from smallest to largest
        for s in range(segment_num):
            # indices of all entries within the current pairwise distance segment
            segment_indices = np.where((pd[i, :] >= dismin + s * segment_len) & (pd[i, :] < dismin + (s + 1) * segment_len))[0]
            if len(segment_indices) > 0:
                selected_cor = pc[i, segment_indices] # current segment samples' correlation estimates
                segments[i, s, 0] = np.mean(selected_cor) # average
                segments[i, s, 1] = np.std(selected_cor) # standard deviation
            else:
                segments[i, s, 0] = -1
                segments[i, s, 1] = -1
    # compute the average of all ROIs' correlation estimates (mean and std) within each segment
    mean_std_rois = np.zeros((segment_num, 3))
    for i in range(num_roi):
        for s in range(segment_num):
            if segments[i, s, 0] != -1 and segments[i, s, 1] != -1:
                mean_std_rois[s, 0] += segments[i, s, 0] # mean
                mean_std_rois[s, 1] += segments[i, s, 1] # std
                mean_std_rois[s, 2] += 1 # count
    mean_std_rois[:, 0] /= mean_std_rois[:, 2] # average mean
    mean_std_rois[:, 1] /= mean_std_rois[:, 2] # average std
    mean_std_rois = mean_std_rois[:, :2] # discard the count
    return mean_std_rois


# V4 digital twin as benchmark
segment_num = 100
cordis_v4 = np.load("/Users/dunhan/Desktop/topoV4/som/V4_benchmark/cordis_v4_benchmark.npy")
mean_std_rois = cordis_avg(cordis_v4, segment_num) # V4 benchmark
modes = ["TDANN", "V4_RSOM", "V4_SOM"]
modes = ["V4_RSOM"]
for mode in modes:
    # RSOM
    if mode == "TDANN":
        # cordis_rsom = np.load("/Users/dunhan/Desktop/topoV4/som/weight_as_units/som16e_250tr_TDANN41/cordis_resnet_rsom.npy")
        # cordis_rsom = np.load("/Users/dunhan/Desktop/topoV4/som/weight_as_units/som16e_250_TDANNrfc/cordis_TDANN_rsom.npy")
        cordis_rsom = np.load("/Users/dunhan/Desktop/topoV4/ResNet/cordis_TDANN31.npy")
        cordis_rsom = np.nan_to_num(cordis_rsom, nan=0.0, posinf=0.0, neginf=0.0)
    elif mode == "V4_RSOM":
        cordis_rsom = np.load("/Users/dunhan/Desktop/topoV4/som/weight_as_units/som16e_250tr_05t/cordis_V4_rsom.npy")
    elif mode == "V4_SOM":
        cordis_rsom = np.load("/Users/dunhan/Desktop/topoV4/som/weight_as_units/som16e/cordis_V4_som.npy")
    mean_std_rois_rsom = cordis_avg(cordis_rsom, segment_num) # V4 RSOM or ResNet TDANN babynet RSOM
    # visualization, as a comparison
    fig, axes = plt.subplots(1, 2, figsize=(8, 4))
    axes[0].plot(mean_std_rois[:, 0], color='black')
    axes[0].fill_between(np.arange((len(mean_std_rois[:, 0]))), 
                        mean_std_rois[:, 0]-mean_std_rois[:, 1], mean_std_rois[:, 0]+mean_std_rois[:, 1], color='black', alpha=0.3)
    axes[0].set_xticks([])
    axes[0].set_yticks([])
    axes[0].set_xlabel('Pairwise distance', fontsize=20)
    axes[0].set_ylabel('Tuning correlation', fontsize=20)
    axes[0].set_title("V4", fontsize=28)
    axes[1].plot(mean_std_rois_rsom[:, 0], color='gray')
    axes[1].fill_between(np.arange((len(mean_std_rois_rsom[:, 0]))), 
                        mean_std_rois_rsom[:, 0]-mean_std_rois_rsom[:, 1], mean_std_rois_rsom[:, 0]+mean_std_rois_rsom[:, 1], color='gray', alpha=0.3)
    axes[1].set_xticks([])
    axes[1].set_yticks([])
    axes[1].set_xlabel('Pairwise distance', fontsize=20)
    # axes[1].set_ylabel('Tuning correlation', fontsize=20)
    if mode == "TDANN":
        axes[1].set_title("TDANN31", fontsize=28) # ResNet RSOM
        plt.tight_layout()
        fig = plt.gcf()
        fig.savefig("/Users/dunhan/Desktop/topoV4/ResNet/TDANN31_cor_dis.png", dpi=1000)
        print("pearson correlation of cordis function between V4 and TDANN31: {:.3f}".format( # ResNet RSOM
            pearsonr(np.concatenate((mean_std_rois[:, 0], mean_std_rois[:, 1])), np.concatenate((mean_std_rois_rsom[:, 0], mean_std_rois_rsom[:, 1])))[0]))
        print(pearsonr(mean_std_rois[:, 0], mean_std_rois_rsom[:, 0]))
    elif mode == "V4_RSOM":
        axes[1].set_title("RSOM", fontsize=28)
        plt.tight_layout()
        fig = plt.gcf()
        fig.savefig("/Users/dunhan/Desktop/topoV4/som/Figures/cor_dis.png", dpi=1000)
        print("\npearson correlation of cordis function between V4 and V4_RSOM: {:.3f}".format(
            pearsonr(np.concatenate((mean_std_rois[:, 0], mean_std_rois[:, 1])), np.concatenate((mean_std_rois_rsom[:, 0], mean_std_rois_rsom[:, 1])))[0]))
        print(pearsonr(mean_std_rois[:, 0], mean_std_rois_rsom[:, 0]))
    elif mode == "V4_SOM":
        axes[1].set_title("SOM", fontsize=28)
        plt.tight_layout()
        fig = plt.gcf()
        fig.savefig("/Users/dunhan/Desktop/topoV4/som/Figures/cor_dis_som.png", dpi=1000)
        print("\npearson correlation of cordis function between V4 and V4_SOM: {:.3f}".format(
            pearsonr(np.concatenate((mean_std_rois[:, 0], mean_std_rois[:, 1])), np.concatenate((mean_std_rois_rsom[:, 0], mean_std_rois_rsom[:, 1])))[0]))
        print(pearsonr(mean_std_rois[:, 0], mean_std_rois_rsom[:, 0]))
    plt.close()

# pearson correlation of cordis function between V4 and TDANN31: 0.838
# PearsonRResult(statistic=0.8288380838226055, pvalue=1.8549385524249295e-26)


pearson correlation of cordis function between V4 and V4_RSOM: 0.821
PearsonRResult(statistic=0.9546374024653328, pvalue=2.3195355350045804e-53)


In [4]:
image = Image.open("/Users/dunhan/Desktop/Picture4.png")
width, height = image.size
resized_image = image.resize((int(np.round(width/4)), int(np.round(height/4))), Image.Resampling.LANCZOS)
resized_image.save("/Users/dunhan/Desktop/Fig4.png")