In [1]:
import torch
import torch.nn as nn
import numpy as np
from torch.autograd import Variable
import torch.optim as optim
from tqdm import tqdm
import scipy.io
from skimage.transform import resize
from scipy.ndimage import label, sum as ndi_sum
import matplotlib.image
from matplotlib import pyplot as plt
from matplotlib.colors import hsv_to_rgb
from PIL import Image
from scipy.stats import pearsonr
import math

In [47]:
# resize TDANN41 image preference map
image = Image.open("/Users/dunhan/Desktop/topoV4/ResNet/TDANN41_IT.bmp")
width, height = image.size
resized_image = image.resize((int(np.round(width/16)), int(np.round(height/16))), Image.Resampling.LANCZOS)
resized_image.save("/Users/dunhan/Desktop/topoV4/som/Figures/TDANN41.png")

# TDANN map visualization

In [None]:
# prepare TDANN later layer responses to 50K images
file_names_all = ['1_5kfeatures.npz', '2_5kfeatures.npz', '3_5kfeatures.npz', '4_5kfeatures.npz', 
                '5_5kfeatures.npz', '6_5kfeatures.npz', '7_5kfeatures.npz', '8_5kfeatures.npz', 
                '9_5kfeatures.npz', '10_5kfeatures.npz']
feature_folder_path = '/Volumes/dunhanSSD/topoV4/' # the path to the folder of pre-computed responses
data = np.zeros((1, 50176))
for file_name in tqdm(range(len(file_names_all)), desc="10 files to go", disable=False): # iterate through all the files
    features = np.load(feature_folder_path + file_names_all[file_name])["layer30"] # targeted at V4, of shape (5000, 256, 14, 14)
    features = features.reshape(features.shape[0], -1) # of shape (5000, n_units)
    data = np.concatenate((data, features), axis=0) # of shape(50K_images, n_units)
data = data[1:, :] # remove the first row of zeros, of shape (50000, 50176)
print(data.shape)
np.save("/Volumes/dunhanSSD/topoV4/TDANN31_V2.npy", data) # save the data
del data, features # release memory

In [None]:
# Visualization of TDANN final network & positions, into a 2D 60 by 60 gridded map
# combined_features = np.load("/Volumes/dunhanSSD/topoV4/TDANN_final/responses/TDANN31_V4.npy") # (50000, 50176)
combined_features = np.load("/Volumes/dunhanSSD/topoV4/TDANN_final/responses/layer41.npy") # (50000, 25088)
positions = np.load("/Users/dunhan/Desktop/topoV4/ResNet/TDANNfinal_positions/layer4.1.npz")["coordinates"] # (25088, 2)
folder_path = "/Users/dunhan/Desktop/topoV4/50K_Imgset/"
cortical_size = max(positions[:, 0]) - min(positions[:, 0]) # define the length of the 2D plane
grid_num = 60 # each grid contains the most-preferred 9 images by the mean within-grid-units' response
grids_count = int(grid_num ** 2)
num_imgs_each_side = 3 # number of images on each side of the grid, total number of images in a grid should be squared
# define the size of a single image
img_size = 30
line_width = 5
# create a blank map of black color (R=0, G=0, B=0)
map = np.zeros((grid_num * (img_size*3 + line_width) + line_width, 
                grid_num * (img_size*3 + line_width) + line_width, 
                3))
# To store each grid's agregated response to all 50K images, with an additional roi
TDANN41_weight = np.zeros((grid_num, grid_num, 50000))
roi = np.zeros((grid_num, grid_num))
for i in tqdm(range(grid_num), desc="map initialization..."):
    for j in range(grid_num):
        # first find all units in this current grid
        xmin_cortex = cortical_size / grid_num * i
        xmax_cortex = cortical_size / grid_num * (i + 1)
        ymin_cortex = cortical_size / grid_num * j
        ymax_cortex = cortical_size / grid_num * (j + 1)
        # find all units in this current grid
        units_within_grid_indeices = np.where((positions[:, 0] >= xmin_cortex) & (positions[:, 0] < xmax_cortex) & (positions[:, 1] >= ymin_cortex) & (positions[:, 1] < ymax_cortex))[0]
        if len(units_within_grid_indeices) > 0:
            roi[i, j] = 1
            # compute the mean response of all units (neurons) within this grid to all 50K images (into a row vector)
            mean_responses = np.mean(combined_features[:, units_within_grid_indeices], axis=1).T
            TDANN41_weight[i, j, :] = mean_responses
            image_label = np.arange(50000) + 1 # 1-indexed image names
            # sort the mean responses (from small to large) and the image_label according to the order of mean responses
            mean_responses, image_label = zip(*sorted(zip(mean_responses,image_label)))
            image_label = np.flip(image_label[-int(num_imgs_each_side ** 2):]) # take the top nine images with largest mean response

            # locate the top left corner of the current grid in the map
            x = i * (img_size*3 + line_width) + line_width
            y = j * (img_size*3 + line_width) + line_width
            
            # fill the map's current grid with the selected nine images
            for row in range(3):
                for col in range(3):
                    # load the image
                    path = folder_path + str(int(image_label[row*3+col])) + ".bmp" # the image name is 1-indexed
                    img = np.array(Image.open(path))[20:80, 20:80, :] # obtain the non-blurred central part of the image
                    img = resize(img, (img_size, img_size, 3), anti_aliasing=True) # resize the image
                    # put the image onto the map
                    # if i > 10 and i < 50 and j > 15 and j < 50:
                    # img_label.append(image_label[row*3+col])
                    map[x + row * img_size : x + (row + 1) * img_size, 
                        y + col * img_size : y + (col + 1) * img_size, 
                        :] = img
        else: # white out the grid if no units are in this grid
            x = i * (img_size*3 + line_width)
            y = j * (img_size*3 + line_width)
            map[x : x + img_size*3 + line_width*2, 
                y : y + img_size*3 + line_width*2, 
                :] = 1.0

matplotlib.image.imsave('/Users/dunhan/Desktop/topoV4/ResNet/TDANN41_IT.bmp', map)
np.save("/Users/dunhan/Desktop/topoV4/ResNet/TDANN41_weight.npy", TDANN41_weight)
np.save("/Users/dunhan/Desktop/topoV4/ResNet/TDANN41_roi.npy", roi)
del map

# save the top-k most preferred images for each TDANN41 unitn in 0-indexed format
print(combined_features.shape)
imgs_0index = np.zeros((combined_features.shape[1], 1000))
for i in tqdm(range(combined_features.shape[1])):
    image_label = np.arange(50000) # 0-indexed image
    response_vector, image_label = zip(*sorted(zip(combined_features[:, i], image_label), reverse=True)) # sort from large to small
    imgs_0index[i, :] = image_label[:1000] # store the 0-indexed top 1000 images for each unit
print(imgs_0index.shape)
np.save("/Users/dunhan/Desktop/topoV4/ResNet/responses/TDANN41rsp_1kimgs_0index.npy", imgs_0index)
del combined_features, positions

# Tuning curve analysis

In [3]:
# tuning curve comparison figure
# Compare the tuning curves of V4 digital twin benchmark and the learned SOM
layer = "layer41"
roi = np.load("/Users/dunhan/Desktop/topoV4/Tianye/Analysis/S4_preference_map/ROI.npy").T # (128, 128) <class 'numpy.ndarray'>
response = np.load("/Users/dunhan/Desktop/topoV4/Tianye/Analysis/S4_preference_map/Prsp.npy")
v4_benchmark = response[:, roi == 1] # (50000, 128, 128) into (50000, 3048)
v4_benchmark = np.sort(v4_benchmark, axis=0)[::-1]  # Sort each column along rows (flip for descending order)
v4_mean = np.mean(v4_benchmark, axis=1) # mean tuning curve
v4_std = np.std(v4_benchmark, axis=1) # tuning curve std
if layer == "layer31":
    rsp = np.load("/Users/dunhan/Desktop/topoV4/ResNet/TDANN31_weight.npy")
    roi = np.load("/Users/dunhan/Desktop/topoV4/ResNet/TDANN31_roi.npy")
elif layer == "layer41":
    rsp = np.load("/Users/dunhan/Desktop/topoV4/ResNet/TDANN41_weight.npy")
    roi = np.load("/Users/dunhan/Desktop/topoV4/ResNet/TDANN41_roi.npy")
rsp = rsp[np.where(roi == 1)[0], np.where(roi == 1)[1], :] # (3277, 50000)
rsp = np.sort(rsp.T, axis=0)[::-1]  # Sort each column along rows (flip for descending order), (50000, 3277)
rsp_mean = np.mean(rsp, axis=1) # mean tuning curve
rsp_std = np.std(rsp, axis=1) # tuning curve std

# normalize every tuning curve
index = []
for i in range(rsp.shape[1]):
    assert len(rsp[:, i]) == 50000
    half = (np.max(rsp[:, i]) - np.min(rsp[:, i])) / 2
    for j in range(50000):
        if rsp[j, i] <= half:
            index.append(j)
            break
print("Tuning curve half at the top images:", np.mean(index), np.std(index))
if layer == "layer31":
    print("pearson correlation of tuning curve between V4 and TDANN31: {:.3f}".format(pearsonr(np.concatenate((v4_mean, v4_std)), np.concatenate((rsp_mean, rsp_std)))[0]))
elif layer == "layer41":
    print("pearson correlation of tuning curve between V4 and TDANN41: {:.3f}".format(pearsonr(np.concatenate((v4_mean, v4_std)), np.concatenate((rsp_mean, rsp_std)))[0]))

# pairwise grid tuning curve coorelation as a function of exactly the map physical distance between them
def cordis(rsp):
    # check input response
    if rsp.shape[1] == rsp.shape[2] == 128: # (50000, 128, 128) V4 benchmark
        roi = np.load("/Users/dunhan/Desktop/topoV4/Tianye/Analysis/S4_preference_map/ROI.npy").T # (128, 128)
    else: roi = np.ones((60, 60)).astype(int) # artificial som map weight response
    size = roi.shape[0]
    assert size == roi.shape[1]
    num_roi = int(np.sum(roi))
    # response, roi data preparation
    response = np.zeros((num_roi, 50000))
    position = np.zeros((num_roi, 2))
    voxel_index = 0
    for i in range(size):
        for j in range(size):
            if roi[i, j] == 1:
                response[voxel_index, :] = rsp[:, i, j]
                position[voxel_index, 0] = i
                position[voxel_index, 1] = j
                voxel_index += 1
    assert voxel_index == num_roi
    # calculate the correlation matrix
    cordis_matrix = np.zeros((num_roi, num_roi, 2)) # 0th entry for correlation, 1st entry for pairwise map distance
    for i in tqdm(range(num_roi), desc="calculating correlation matrix...", disable=False):
        for j in range(i, num_roi):
            # correlation calculation
            if i != j:
                cordis_matrix[i, j, 0] = pearsonr(response[i, :], response[j, :])[0]
                cordis_matrix[j, i, 0] = cordis_matrix[i, j, 0]
            else: cordis_matrix[i, j, 0] = 1.0
            # distance calculation
            cordis_matrix[i, j, 1] = np.sqrt((position[i, 0] - position[j, 0]) ** 2 + (position[i, 1] - position[j, 1]) ** 2)
            cordis_matrix[j, i, 1] = cordis_matrix[i, j, 1]
    return cordis_matrix

def cordis_avg(cordis_matrix, segment_num=100):
    num_roi = cordis_matrix.shape[0]
    assert num_roi == cordis_matrix.shape[1]
    pd = cordis_matrix[:, :, 1] # distance (3048, 3048)
    pc = cordis_matrix[:, :, 0] # correlation (3048, 3048)
    dismax = np.max(pd) # maximum distance
    dismin = np.min(pd) # minimum distance
    segment_len = (dismax - dismin) / segment_num # length of one single segment
    segments = np.zeros((num_roi, segment_num, 2)) # store the mean correlation and standard deviation of all correlation estimates within all samples of each segment
    for i in tqdm(range(num_roi), desc="sorting...", disable=True):
        sorted_indices = np.argsort(pd[i, :]) # sort the distance from smallest to largest, get its indices
        pd[i, :] = pd[i, sorted_indices] # distance matrix row vector sorted, from smallest to largest
        pc[i, :] = pc[i, sorted_indices] # correlation matrix row vector sorted, from smallest to largest
        for s in range(segment_num):
            # indices of all entries within the current pairwise distance segment
            segment_indices = np.where((pd[i, :] >= dismin + s * segment_len) & (pd[i, :] < dismin + (s + 1) * segment_len))[0]
            if len(segment_indices) > 0:
                selected_cor = pc[i, segment_indices] # current segment samples' correlation estimates
                segments[i, s, 0] = np.mean(selected_cor) # average
                segments[i, s, 1] = np.std(selected_cor) # standard deviation
            else:
                segments[i, s, 0] = -1
                segments[i, s, 1] = -1
    # compute the average of all ROIs' correlation estimates (mean and std) within each segment
    mean_std_rois = np.zeros((segment_num, 3))
    for i in range(num_roi):
        for s in range(segment_num):
            if segments[i, s, 0] != -1 and segments[i, s, 1] != -1:
                mean_std_rois[s, 0] += segments[i, s, 0] # mean
                mean_std_rois[s, 1] += segments[i, s, 1] # std
                mean_std_rois[s, 2] += 1 # count
    mean_std_rois[:, 0] /= mean_std_rois[:, 2] # average mean
    mean_std_rois[:, 1] /= mean_std_rois[:, 2] # average std
    mean_std_rois = mean_std_rois[:, :2] # discard the count
    return mean_std_rois

# V4 digital twin as benchmark
segment_num = 100
cordis_v4 = np.load("/Users/dunhan/Desktop/topoV4/som/V4_benchmark/cordis_v4_benchmark.npy")
mean_std_rois = cordis_avg(cordis_v4, segment_num) # V4 benchmark
if layer == "layer31":
    cordis_rsom = np.load("/Users/dunhan/Desktop/topoV4/ResNet/cordis_TDANN31.npy")
elif layer == "layer41":
    cordis_rsom = np.load("/Users/dunhan/Desktop/topoV4/ResNet/cordis_TDANN41.npy")
cordis_rsom[np.isinf(cordis_rsom) | np.isnan(cordis_rsom)] = 0
cordis_rsom = np.nan_to_num(cordis_rsom, nan=0.0, posinf=0.0, neginf=0.0)
mean_std_rois_rsom = cordis_avg(cordis_rsom, segment_num)
mean_std_rois_rsom = np.nan_to_num(mean_std_rois_rsom, nan=0.0, posinf=0.0, neginf=0.0)
# visualization, as a comparison
fig, axes = plt.subplots(2, 1, figsize=(3, 6))
axes[0].plot(rsp_mean, color='gray')
axes[0].fill_between(range(len(rsp_mean)), rsp_mean - rsp_std, rsp_mean + rsp_std, color='gray', alpha=0.3)
axes[0].set_xticks([])
axes[0].set_yticks([])
axes[0].set_xlabel('50K imgs', fontsize=18)
axes[0].set_ylabel('response', fontsize=18)
axes[0].set_title("Tuning curve", fontsize=22)
axes[1].plot(mean_std_rois_rsom[:, 0], color='gray')
axes[1].fill_between(np.arange((len(mean_std_rois_rsom[:, 0]))), mean_std_rois_rsom[:, 0]-mean_std_rois_rsom[:, 1], mean_std_rois_rsom[:, 0]+mean_std_rois_rsom[:, 1], color='gray', alpha=0.3)
axes[1].set_xticks([])
axes[1].set_yticks([])
axes[1].set_xlabel('Pairwise distance', fontsize=18)
axes[1].set_ylabel('Tuning correlation', fontsize=18)
# axes[1].set_title("Correlation and distance", fontsize=22)
plt.tight_layout()
fig = plt.gcf()
if layer == "layer31":
    fig.savefig("/Users/dunhan/Desktop/topoV4/ResNet/TDANN31_tuning.png", dpi=1000)
    print("pearson correlation of cordis function between V4 and TDANN31: {:.3f}".format(pearsonr(np.concatenate((mean_std_rois[:, 0], mean_std_rois[:, 1])), 
                                                                                                np.concatenate((mean_std_rois_rsom[:, 0], mean_std_rois_rsom[:, 1])))[0]))
elif layer == "layer41":
    fig.savefig("/Users/dunhan/Desktop/topoV4/ResNet/TDANN41_tuning.png", dpi=1000)
    print("pearson correlation of cordis function between V4 and TDANN41: {:.3f}".format(pearsonr(np.concatenate((mean_std_rois[:, 0], mean_std_rois[:, 1])), 
                                                                                                np.concatenate((mean_std_rois_rsom[:, 0], mean_std_rois_rsom[:, 1])))[0]))
plt.close()

# TDANN31, purpoted V4 layer
# Tuning curve half at the top images: 20475.135222150675 16965.060078655137
# pearson correlation of tuning curve between V4 and TDANN31: 0.708
# pearson correlation of cordis function between V4 and TDANN31: 0.838

# TDANN41, purpoted IT layer
# Tuning curve half at the top images: 1205.8199259508583 1884.0477172733094
# pearson correlation of tuning curve between V4 and TDANN41: 0.947
# pearson correlation of cordis function between V4 and TDANN41: 0.778

Tuning curve half at the top images: 1205.8199259508583 1884.0477172733094
pearson correlation of tuning curve between V4 and TDANN41: 0.947


  mean_std_rois[:, 0] /= mean_std_rois[:, 2] # average mean
  mean_std_rois[:, 1] /= mean_std_rois[:, 2] # average std


pearson correlation of cordis function between V4 and TDANN41: 0.778


# Representation similarity matrix

In [None]:
img0 = np.load("/Users/dunhan/Desktop/topoV4/som/V4_benchmark/rsptop_0index.npy")
roi = np.load("/Users/dunhan/Desktop/topoV4/Tianye/Analysis/S4_preference_map/ROI.npy").T
img0 = img0[np.where(roi == 1)[0], np.where(roi == 1)[1], :].flatten() # (3048*9,)
img0 = np.unique(img0).astype(int) # 0index of unique top-9 images preferred by 3048 V4 neuronal columns
img_num = len(img0)

# Load the V4 digital twin responses to 50k images
response = np.load("/Users/dunhan/Desktop/topoV4/Tianye/Analysis/S4_preference_map/Prsp.npy") # of shape (50000, 128, 128)
rsp = np.zeros((3048, 50000))
index = 0
for i in range(128):
    for j in range(128):
        if roi[i, j] == 1:
            rsp[index, :] = response[:, i, j]
            index += 1
del response

rsm_V4 = np.zeros((img_num, img_num))
for i in tqdm(range(img_num), desc="V4 RSM calculation...", disable=False):
    for j in range(i, img_num):  # Start j from i to only compute the upper triangular
        rsm_V4[i, j] = np.corrcoef(rsp[:, img0[i]], rsp[:, img0[j]])[0, 1]
        rsm_V4[j, i] = rsm_V4[i, j]  # Mirror the value to the lower triangular part
np.save("/Users/dunhan/Desktop/topoV4/som/V4_benchmark/rsm_V4.npy", rsm_V4)
del rsp

rsp = np.load("/Volumes/dunhanSSD/topoV4/TDANN_final/responses/TDANN31_V4.npy").T # to shape (50176, 50000)
assert rsp.shape[1] == 50000
rsm_TDANN31 = np.zeros((img_num, img_num))
for i in tqdm(range(img_num), desc="TDANN31 RSM calculation...", disable=False):
    for j in range(i, img_num):  # Start j from i to only compute the upper triangular
        rsm_TDANN31[i, j] = np.corrcoef(rsp[:, img0[i]], rsp[:, img0[j]])[0, 1]
        rsm_TDANN31[j, i] = rsm_TDANN31[i, j]  # Mirror the value to the lower triangular part
np.save("/Users/dunhan/Desktop/topoV4/som/V4_benchmark/rsm_TDANN31.npy", rsm_TDANN31)
del rsp

rsp = np.load("/Volumes/dunhanSSD/topoV4/TDANN_final/responses/layer41.npy").T # to shape (25088, 50000)
assert rsp.shape[1] == 50000
rsm_TDANN41 = np.zeros((img_num, img_num))
for i in tqdm(range(img_num), desc="TDANN31 RSM calculation...", disable=False):
    for j in range(i, img_num):  # Start j from i to only compute the upper triangular
        rsm_TDANN41[i, j] = np.corrcoef(rsp[:, img0[i]], rsp[:, img0[j]])[0, 1]
        rsm_TDANN41[j, i] = rsm_TDANN41[i, j]  # Mirror the value to the lower triangular part
np.save("/Users/dunhan/Desktop/topoV4/som/V4_benchmark/rsm_TDANN41.npy", rsm_TDANN41)
del rsp

V4_TDANN31 = np.corrcoef(rsm_V4.flatten(), rsm_TDANN31.flatten())[0, 1]
print("V4_TDANN31:", V4_TDANN31) # 0.30309409089513717
V4_TDANN41 = np.corrcoef(rsm_V4.flatten(), rsm_TDANN41.flatten())[0, 1]
print("V4_TDANN41:", V4_TDANN41) # 0.24174864985821207

# Retinotopy map

In [5]:
layer = "layer31"
if layer == "layer41":
    positions = np.load("/Users/dunhan/Desktop/topoV4/ResNet/TDANNfinal_positions/layer4.1.npz")["coordinates"] # (25088, 2)
elif layer == "layer31":
    positions = np.load("/Users/dunhan/Desktop/topoV4/ResNet/TDANNfinal_positions/layer3.1.npz")["coordinates"] # (50176, 2)
polar_angle = np.zeros((positions.shape[0])) # theta
eccentricity = np.zeros((positions.shape[0])) # r
for i in range(positions.shape[0]):
    if layer == "layer41": # 25088 = 512 * 7 * 7
        depth_index = i // (7 * 7)        # Get index along the first dimension (512)
        row_col_index = i % (7 * 7)       # Remaining index within the 7x7 matrix
        row_index = row_col_index // 7    # Get index along the second dimension (7)
        col_index = row_col_index % 7     # Get index along the third dimension (7)
        polar_angle[i] = math.degrees(np.arctan2((col_index-3), (row_index-3)))
        eccentricity[i] = math.degrees(math.atan(np.sqrt((row_index-3) ** 2 + (col_index-3) ** 2) / 45)) # assuming 45 cm away from the fovea
    elif layer == "layer31": # 50176 = 256 * 14 * 14
        depth_index = i // (14 * 14)
        row_col_index = i % (14 * 14)
        row_index = row_col_index // 14
        col_index = row_col_index % 14
        polar_angle[i] = math.degrees(np.arctan2((col_index-7), (row_index-7)))
        eccentricity[i] = math.degrees(math.atan(np.sqrt((row_index-7) ** 2 + (col_index-7) ** 2) / 45)) # assuming 45 cm away from the fovea
polar_angle = (polar_angle - np.min(polar_angle)) / (np.max(polar_angle) - np.min(polar_angle))
hues = polar_angle  # Hue corresponds to normalized values
saturation = 1.0  # Full saturation
value = 1.0  # Full brightness
polar_angle = np.stack([hues, np.full_like(hues, saturation), np.full_like(hues, value)], axis=1) # Combine to create HSV colors
polar_angle = hsv_to_rgb(polar_angle) # Convert HSV to RGB

eccentricity = (eccentricity - np.min(eccentricity)) / (np.max(eccentricity) - np.min(eccentricity))
hues = eccentricity  # Hue corresponds to normalized values
saturation = 1.0  # Full saturation
value = 1.0  # Full brightness
eccentricity = np.stack([hues, np.full_like(hues, saturation), np.full_like(hues, value)], axis=1) # Combine to create HSV colors
eccentricity = hsv_to_rgb(eccentricity) # Convert HSV to RGB

# visualization
fig, axes = plt.subplots(2, 1, figsize=(3, 6))
if layer == "layer41": s = 0.1
elif layer == "layer31": s = 0.05
axes[0].scatter(positions[:, 0], positions[:, 1], c=polar_angle, s=s)
axes[0].set_xticks([])
axes[0].set_yticks([])
axes[0].set_title("Polar angle", fontsize=22)
for spine in axes[0].spines.values(): spine.set_visible(False)  # Remove outer black box
axes[1].scatter(positions[:, 0], positions[:, 1], c=eccentricity, s=s)
axes[1].set_xticks([])
axes[1].set_yticks([])
axes[1].set_title("Eccentricity", fontsize=22)
for spine in axes[1].spines.values(): spine.set_visible(False)  # Remove outer black box
plt.tight_layout()
fig = plt.gcf()
if layer == "layer41":
    fig.savefig("/Users/dunhan/Desktop/topoV4/ResNet/TDANN41_theta_r.png", dpi=1000)
elif layer == "layer31":
    fig.savefig("/Users/dunhan/Desktop/topoV4/ResNet/TDANN31_theta_r.png", dpi=1000)
plt.close()

# Feature dispersity

In [11]:
# feature dispersity map visualization
layer = "layer41"
if layer == "layer41":
    positions = np.load("/Users/dunhan/Desktop/topoV4/ResNet/TDANNfinal_positions/layer4.1.npz")["coordinates"] # (25088, 2)
    FD = scipy.io.loadmat("/Users/dunhan/Desktop/topoV4/som/TDANN_dispersity/layer41_corrected6/FD.mat")["TRsp"]
    num_units = 25088
elif layer == "layer31":
    positions = np.load("/Users/dunhan/Desktop/topoV4/ResNet/TDANNfinal_positions/layer3.1.npz")["coordinates"] # (50176, 2)
    FD = scipy.io.loadmat("/Users/dunhan/Desktop/topoV4/som/TDANN_dispersity/layer31_corrected6/FD.mat")["TRsp"]
    num_units = 50176
FD_all = np.zeros((num_units))
for i in range(num_units):
    if layer == "layer41": # 25088 = 512 * 7 * 7
        depth_index = i // (7 * 7)        # Get index along the first dimension (512)
        row_col_index = i % (7 * 7)       # Remaining index within the 7x7 matrix
        FD_all[i] = FD[depth_index]
    elif layer == "layer31": # 50176 = 256 * 14 * 14
        depth_index = i // (14 * 14)
        row_col_index = i % (14 * 14)
        FD_all[i] = FD[depth_index]
# FD_all_hist = FD_all.copy()
grid_num = 60
FD_map = np.zeros((grid_num, grid_num))
cortical_size = max(positions[:, 0]) - min(positions[:, 0]) # define the length of the 2D plane
for i in tqdm(range(grid_num), desc="map initialization...", disable=True):
    for j in range(grid_num):
        # first find all units in this current grid
        xmin_cortex = cortical_size / grid_num * i
        xmax_cortex = cortical_size / grid_num * (i + 1)
        ymin_cortex = cortical_size / grid_num * j
        ymax_cortex = cortical_size / grid_num * (j + 1)
        # find all units in this current grid
        units_within_grid_indeices = np.where((positions[:, 0] >= xmin_cortex) & (positions[:, 0] < xmax_cortex) & (positions[:, 1] >= ymin_cortex) & (positions[:, 1] < ymax_cortex))[0]
        if len(units_within_grid_indeices) > 0:
            FD_map[i, j] = np.mean(FD_all[units_within_grid_indeices])
FD_all = (FD_all - np.min(FD_all)) / (np.max(FD_all) - np.min(FD_all))
hues = FD_all  # Hue corresponds to normalized values
saturation = 1.0  # Full saturation
value = 1.0  # Full brightness
FD_all = np.stack([hues, np.full_like(hues, saturation), np.full_like(hues, value)], axis=1) # Combine to create HSV colors
FD_all = hsv_to_rgb(FD_all) # Convert HSV to RGB

# visualization
"""
axes[1].hist(FD_all_hist, bins=10, color='gray')
axes[1].set_title("FD distribution", fontsize=20)
axes[1].set_ylabel("Counts", fontsize=16)
axes[1].set_yticks([]) # erase y-axis ticks
axes[1].set_xticks([])
"""
fig, axes = plt.subplots(2, 1, figsize=(3, 6))
if layer == "layer41": s = 0.1
elif layer == "layer31": s = 0.05
axes[0].scatter(positions[:, 0], positions[:, 1], c=FD_all, s=s)
axes[0].set_xticks([])
axes[0].set_yticks([])
axes[0].set_title("FD raw", fontsize=20)
for spine in axes[0].spines.values(): spine.set_visible(False)  # Remove outer black box
im = axes[1].imshow(FD_map)
axes[1].set_xticks([])
axes[1].set_yticks([])
axes[1].set_title("FD map", fontsize=20)
cbar = fig.colorbar(im, ax=axes[1], orientation='vertical', shrink=0.8)
cbar.set_label("FD Map Scale")  # Optional label for colorbar
plt.tight_layout()
fig = plt.gcf()
if layer == "layer41":
    fig.savefig("/Users/dunhan/Desktop/topoV4/ResNet/TDANN41_FD.png", dpi=1000)
elif layer == "layer31":
    fig.savefig("/Users/dunhan/Desktop/topoV4/ResNet/TDANN31_FD.png", dpi=1000)
plt.close()

  FD_all[i] = FD[depth_index]
