In [None]:
import math

import cv2
import h5py
import matplotlib.colors as mcolors
import matplotlib.image
import matplotlib.patches
import numpy as np
import scipy.io
import scipy.stats
import torch
import torch.nn as nn
import torchvision
from matplotlib import pyplot as plt
from PIL import Image
from scipy.cluster.hierarchy import linkage, fcluster
from scipy.ndimage import label, sum as ndi_sum
from skimage.measure import find_contours
from skimage.transform import resize
from torchvision import transforms
from tqdm import tqdm

folder_path = "/path/to/your/data/"

# Retinotopy map

In [None]:
# Retinotopy map of one TDANN layer: every unit gets a polar angle and an eccentricity derived
# from its (row, col) location inside the layer's feature map, and both are scattered at the
# unit's cortical-sheet position.
layer = "layer40"
positions = np.load(folder_path + "Fig4/TDANNfinal_positions/layer" + str(layer[-2:-1]) + "." + str(layer[-1]) + ".npz")["coordinates"] # (num_units, 2)

# Layer geometry is constant for a given layer, so resolve it once (not once per unit) and fail
# loudly on an unsupported layer instead of hitting a NameError further down.
# size: feature-map side length, center: index used as the map center, s: scatter marker size
if layer == "layer40" or layer == "layer41": # 25088 = 512 * 7 * 7
    size, center, s = 7, 3, 3
elif layer == "layer30" or layer == "layer31": # 50176 = 256 * 14 * 14
    size, center, s = 14, 7, 1.5
else:
    raise ValueError("Unsupported layer for the retinotopy map: " + layer)

# Units are flattened as (channel, row, col); recover row / col for all units at once.
unit_index = np.arange(positions.shape[0])
row_col_index = unit_index % (size * size)   # offset inside the unit's size x size feature map
row_index = row_col_index // size
col_index = row_col_index % size
# NOTE(review): the column offset is `center - (size - col_index)` while the row offset is
# `center - row_index`; the asymmetry is kept exactly as in the original -- confirm it is intended.
polar_angle = np.degrees(np.arctan2(center - (size - col_index), center - row_index)) # theta
eccentricity = np.degrees(np.arctan(np.sqrt((row_index - center) ** 2 + (col_index - center) ** 2) / 45)) # r, assuming 45 cm away from the fovea
# min-max normalize both maps to [0, 1] for color mapping
polar_angle = (polar_angle - np.min(polar_angle)) / (np.max(polar_angle) - np.min(polar_angle))
eccentricity = (eccentricity - np.min(eccentricity)) / (np.max(eccentricity) - np.min(eccentricity))

# visualization
# (alternative two-panel figure, disabled by being wrapped in a string literal)
"""
fig, axes = plt.subplots(1, 2, figsize=(6, 3))
fs = 16
axes[0].scatter(positions[:, 1], positions[:, 0], c=polar_angle, s=s, marker='o', edgecolors='none')
axes[0].invert_yaxis()  # Move the origin to the top-left corner
axes[0].set_xticks([])
axes[0].set_yticks([])
axes[0].set_title("Polar Angle", fontsize=fs)
for spine in axes[0].spines.values(): spine.set_visible(False)  # Remove outer black box
axes[1].scatter(positions[:, 1], positions[:, 0], c=eccentricity, s=s, marker='o', edgecolors='none')
axes[1].invert_yaxis()  # Move the origin to the top-left corner
axes[1].set_xticks([])
axes[1].set_yticks([])
axes[1].set_title("Eccentricity", fontsize=fs)
for spine in axes[1].spines.values(): spine.set_visible(False)  # Remove outer black box
plt.tight_layout()
fig = plt.gcf()
fig.savefig(folder_path + "Fig4/" + layer + "/theta_r.png", dpi=1000)
plt.close()
del fig
"""
# plt.subplots returns (figure, axes); unpack it so `fig` really is the figure being closed
fig, _ = plt.subplots(figsize=(4, 4))
plt.scatter(positions[:, 1], positions[:, 0], c=polar_angle, s=s, marker='o', edgecolors='none', cmap='hsv')
plt.gca().invert_yaxis()  # Move the origin to the top-left corner
plt.axis("off")  # Hide axes
# plt.title("Polar Angle", fontsize=22)
plt.tight_layout()
plt.savefig(folder_path + "Fig4/" + layer + "/theta.png", dpi=300)
plt.close(fig)

fig, _ = plt.subplots(figsize=(4, 4))
plt.scatter(positions[:, 1], positions[:, 0], c=eccentricity, s=s, marker='o', edgecolors='none', cmap='hsv')
plt.gca().invert_yaxis()  # Move the origin to the top-left corner
plt.axis("off")  # Hide axes
# plt.title("Eccentricity", fontsize=22)
plt.tight_layout()
plt.savefig(folder_path + "Fig4/" + layer + "/r.png", dpi=300)
plt.close(fig)

# Estimate TDANN unit heatmap

In [None]:
def load_model_from_checkpoint(checkpoint_path: str, device: str):
    """Build a ResNet-18 without its classifier head and fill it with checkpoint weights.

    Args:
        checkpoint_path: path of the checkpoint file handed to ``torch.load``.
        device: device string used as ``map_location`` while loading.

    Returns:
        The ResNet-18 with ``fc`` replaced by ``nn.Identity`` and every parameter frozen
        (``requires_grad = False``). The model is not moved to ``device`` here.
    """
    net = torchvision.models.resnet18(pretrained=False)
    net.fc = nn.Identity()  # the classifier head is dropped
    checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))
    trunk_weights = checkpoint["classy_state_dict"]["base_model"]["model"]["trunk"]
    # keep the "base_model..." entries except the fc ones, named by what follows the last "base_model."
    backbone_weights = {
        name.split("base_model.")[-1]: tensor
        for name, tensor in trunk_weights.items()
        if name.startswith("base_model") and "fc." not in name
    }
    net.load_state_dict(backbone_weights)
    # inference only: no parameter takes part in autograd
    for weight in net.parameters():
        weight.requires_grad = False
    return net

def heatmap_estimation(layer="layer41"):
    """Estimate input-space heatmaps for the central unit of every channel of one TDANN layer.

    For each channel, the unit at the center of the feature map is targeted. For each of the
    1000 image ids listed for that channel, 20 noisy copies of the image are pushed through the
    model, the mean activation of the target unit is back-propagated to the input, and the
    squared input gradient is averaged over copies and color channels into a (224, 224) map.

    Result array, shape (num_channels, 26, 224, 224):
        [:, 0]    sum over the 1000 images of each map rescaled so that its total equals the
                  mean activation for that image
        [:, 1:26] raw maps of the first 25 images

    The array is saved to ``folder_path + "heatmaps<NN>.npy"`` where ``<NN>`` are the last two
    characters of ``layer`` (``heatmaps41.npy`` for the default).

    Args:
        layer: one of "layer10", "layer11", "layer20", "layer21", "layer30", "layer31",
            "layer40", "layer41" (ResNet-18 stage and block). Defaults to "layer41".

    Raises:
        ValueError: if ``layer`` is not one of the names above.
        RuntimeError: if the forward hook did not capture an activation.
    """
    # per stage: (number of channels, feature-map side length, row/col index of the central unit)
    stage_geometry = {"1": (64, 56, 28), "2": (128, 28, 14), "3": (256, 14, 7), "4": (512, 7, 3)}
    supported = [f"layer{stage}{block}" for stage in "1234" for block in "01"]
    if layer not in supported:
        raise ValueError("Unsupported layer: " + str(layer))
    fc_num, feature_map_size, xy = stage_geometry[layer[5]]

    # 0-indexed image ids per unit; 1000 columns are consumed below
    imgs_0index = np.load(folder_path + "Fig4/TDANNrsp_1kimg_0idx.npz")[layer]
    # load TDANN model
    device = "cuda" if torch.cuda.is_available() else "cpu"
    checkpoint_path = folder_path + "Fig4/TDANNfinal.torch"
    model = load_model_from_checkpoint(checkpoint_path, device)
    model = model.to(device)
    model.eval()

    # Everything below is loop-invariant, so it is built once instead of once per image.
    target_layer = getattr(model, "layer" + layer[5])[int(layer[6])]
    preprocess = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
    num_samples = 20  # noisy copies generated per image
    sigma = 0.20      # noise standard deviation
    activation = [None]  # mutable container written by the hook (avoids a global)

    def hook_fn(module, input, output):
        activation[0] = output

    heatmaps1k_all = np.zeros((fc_num, 26, 224, 224)) # one accumulated heatmap + top25 heatmaps per unit
    hook = target_layer.register_forward_hook(hook_fn)
    try:
        for fc in range(fc_num):
            # central unit of channel `fc` in the (channels, size, size) response tensor
            depth_index, row_index, col_index = fc, xy, xy
            flat_index = fc*feature_map_size*feature_map_size + xy*feature_map_size + xy # 0-indexed unit index
            print("Currently targeting at unit: ", flat_index, (0, depth_index, row_index, col_index))
            index1 = (1 + imgs_0index[fc, :]).astype(int) # 0-indexed to 1-indexed to match image file name
            gradients1k = np.zeros((224, 224)) # accumulates the rescaled maps of this unit
            for k in tqdm(range(1000), desc="top 1k imgs...", disable=True):
                activation[0] = None  # so the check below detects a hook that did not fire
                image_path = folder_path + "50K_Imgset/" + str(index1[k]) + ".bmp"
                with Image.open(image_path) as pil_image:
                    input_image = preprocess(pil_image).unsqueeze(0)  # Add batch dimension
                # noisy copies of the same image, shape [num_samples, 3, 224, 224]
                batch_vector = torch.cat(
                    [input_image + sigma * torch.randn_like(input_image) for _ in range(num_samples)],
                    dim=0,
                )
                batch_vector = batch_vector.to(device)
                batch_vector.requires_grad_()
                # Forward pass
                model(batch_vector)
                if activation[0] is None: raise RuntimeError("Activation was not captured. Check the hook function.")
                # mean activation of the target unit over the noisy copies
                loss = activation[0][:, depth_index, row_index, col_index].mean()
                loss.backward()
                # squared input gradient, [num_samples, 3, 224, 224] -> mean over copies and color -> [224, 224]
                gradients = (batch_vector.grad ** 2).mean(dim=0).mean(dim=0).detach().cpu().numpy()
                if k < 25: heatmaps1k_all[fc, int(k+1), :, :] = gradients # store top 25 images' heatmap as well
                # rescale the map so that it sums to the current mean activation, then accumulate
                g_sum = gradients.sum()
                if g_sum != 0.0: gradients1k += gradients / g_sum * loss.item()
            # store the accumulated heatmap of the current unit
            heatmaps1k_all[fc, 0, :, :] = gradients1k
    finally:
        # always detach the hook, also when an image fails to load or the forward pass raises
        hook.remove()
    # save the final results; the file name follows the targeted layer
    np.save(folder_path + "heatmaps" + layer[-2:] + ".npy", heatmaps1k_all)

In [None]:
# Export every unit's first heatmap slot (index 0 of the second axis) to a MATLAB .mat file.
# NOTE(review): reads "Fig4/heatmaps.npy" but writes into "Fig4/layer30/" -- confirm the input is the layer30 file.
heatmaps = np.load(folder_path + "Fig4/heatmaps.npy")
heatmaps = heatmaps[:, 0, :, :]
print(heatmaps.shape)
scipy.io.savemat(folder_path + "Fig4/layer30/heatmaps.mat", dict(heatmaps=heatmaps))

# NEXT, run Fit_2D_Gaussian.m (contributed by Tianye Wang) to fit 2D Gaussian to TDANN unit's heatmaps, then continue with Fig4_dispersity.ipynb

In [None]:
# Visualize the estimated 2D Gaussian fit center and std for layer3.0, the V2 layer
xs = scipy.io.loadmat(folder_path + "Fig4/layer30/b1.mat")["b1"] # x-center, (256, 1)
ys = scipy.io.loadmat(folder_path + "Fig4/layer30/b2.mat")["b2"] # y-center, (256, 1)
xs_std = scipy.io.loadmat(folder_path + "Fig4/layer30/c1.mat")["c1"] # x-std, (256, 1)
ys_std = scipy.io.loadmat(folder_path + "Fig4/layer30/c2.mat")["c2"] # y-std, (256, 1)
heatmaps = np.load(folder_path + "Fig4/layer30/heatmaps.npy")[:, 0, :, :] # Estimated averaged Heatmap (256, 224, 224)

# Show the first 32 heatmaps; the fitted center is painted as a small square at the heatmap's
# minimum value (black in the "hot" colormap) and the fitted stds are drawn as a cyan ellipse.
fig, axes = plt.subplots(2, 16, figsize=(32, 4))
l = 3 # half-width (pixels) of the center marker
for i in range(32):
    ax = axes[i // 16, i % 16]
    # work on a copy: painting the marker through a view would overwrite `heatmaps` itself
    heatmap = heatmaps[i, :, :].copy()
    x, y = int(xs[i][0]), int(ys[i][0])
    # clamp at 0: a negative slice bound would wrap around to the opposite image edge
    heatmap[max(x-l, 0):max(x+l, 0), max(y-l, 0):max(y+l, 0)] = np.min(heatmap) # set the center to the minimum value
    ax.imshow(heatmap, cmap="hot")
    ax.axis("off")
    # imshow puts columns on the horizontal axis, hence (y, x) and width from the y-std
    ellipse = matplotlib.patches.Ellipse((y, x), width=2*ys_std[i][0], height=2*xs_std[i][0], edgecolor="cyan", facecolor="none", linewidth=2)
    ax.add_patch(ellipse)
plt.tight_layout()

In [None]:
# Visualize the estimated 2D Gaussian fit center and std for layer3.1, the V4 layer
xs = scipy.io.loadmat(folder_path + "Fig4/layer31/b1.mat")["b1"] # x-center, (256, 1)
ys = scipy.io.loadmat(folder_path + "Fig4/layer31/b2.mat")["b2"] # y-center, (256, 1)
xs_std = scipy.io.loadmat(folder_path + "Fig4/layer31/c1.mat")["c1"] # x-std, (256, 1)
ys_std = scipy.io.loadmat(folder_path + "Fig4/layer31/c2.mat")["c2"] # y-std, (256, 1)
heatmaps = np.load(folder_path + "Fig4/layer31/heatmaps.npy")[:, 0, :, :] # Estimated averaged Heatmap (256, 224, 224)

# Show the first 32 heatmaps; the fitted center is painted as a small square at the heatmap's
# minimum value (black in the "hot" colormap) and the fitted stds are drawn as a cyan ellipse.
fig, axes = plt.subplots(2, 16, figsize=(32, 4))
l = 3 # half-width (pixels) of the center marker
for i in range(32):
    ax = axes[i // 16, i % 16]
    # work on a copy: painting the marker through a view would overwrite `heatmaps` itself
    heatmap = heatmaps[i, :, :].copy()
    x, y = int(xs[i][0]), int(ys[i][0])
    # clamp at 0: a negative slice bound would wrap around to the opposite image edge
    heatmap[max(x-l, 0):max(x+l, 0), max(y-l, 0):max(y+l, 0)] = np.min(heatmap) # set the center to the minimum value
    ax.imshow(heatmap, cmap="hot")
    ax.axis("off")
    # imshow puts columns on the horizontal axis, hence (y, x) and width from the y-std
    ellipse = matplotlib.patches.Ellipse((y, x), width=2*ys_std[i][0], height=2*xs_std[i][0], edgecolor="cyan", facecolor="none", linewidth=2)
    ax.add_patch(ellipse)
plt.tight_layout()

In [None]:
# Visualize per-image heatmaps for an example unit (index 17) in layer 3.1, the V4 layer.
# Bottom two rows: raw heatmaps. Top two rows: same heatmaps with the above-average pixels
# raised by 1 so the "salient" area stands out. The 4 x 12 grid holds 24 of the 25 stored maps.
heatmaps = np.load(folder_path + "Fig4/layer31/heatmaps.npy")[17, 1:, :, :] # (25, 224, 224)
areas = [] # fraction of pixels at or above the heatmap mean, one entry per shown image
fig, axes = plt.subplots(4, 12, figsize=(24, 8))
for i in range(24):
    heatmap = heatmaps[i, :, :]
    mask = heatmap >= np.mean(heatmap) # boolean map of the above-average pixels
    areas.append(int(np.count_nonzero(mask)) / (224*224))
    ax = axes[i // 12 + 2, i % 12]
    ax.imshow(heatmap, cmap="hot")
    ax.axis("off")
    ax = axes[i // 12, i % 12]
    # adjust unmasked pixel value to +1, visualize the > average area
    # (on a copy: adding through a view would permanently alter `heatmaps`)
    masked_array = heatmap.copy()
    masked_array[mask] += 1
    ax.imshow(masked_array, cmap="hot")
    ax.axis("off")

In [None]:
# Distribution of the "salient area" of the estimated heatmaps: the fraction of pixels whose
# value is at or above the heatmap's own mean. Left panel: the per-image heatmaps (slots 1..),
# right panel: the averaged heatmap (slot 0). Fractions above 0.5 are left out of both histograms.
layer = "layer31"


def _salient_fraction(single_heatmap):
    """Share of the 224*224 pixels that are >= the mean of `single_heatmap`."""
    hot_rows, hot_cols = np.where(single_heatmap >= np.mean(single_heatmap))
    assert len(hot_rows) == len(hot_cols)
    return len(hot_rows) / (224*224)


heatmaps = np.load(folder_path + "Fig4/" + layer + "/heatmaps.npy")[:, 1:, :, :] # per-image heatmaps (256, 25, 224, 224)
areas = np.array([
    _salient_fraction(heatmaps[unit, shot, :, :])
    for unit in range(heatmaps.shape[0])
    for shot in range(heatmaps.shape[1])
])
areas = areas[areas <= 0.5]
fig, axes = plt.subplots(1, 2, figsize=(12, 2))
left_panel, right_panel = axes[0], axes[1]
left_panel.hist(areas, bins=50)
left_panel.set_xlabel("heat salient pixels percentage")
left_panel.set_ylabel("counts")

heatmaps = np.load(folder_path + "Fig4/" + layer + "/heatmaps.npy")[:, 0, :, :] # Estimated averaged Heatmap (256, 224, 224)
areas = np.array([_salient_fraction(heatmaps[unit, :, :]) for unit in range(heatmaps.shape[0])])
areas = areas[areas <= 0.5]
right_panel.hist(areas, bins=50)
right_panel.set_xlabel("averaged heat salient pixels percentage")
right_panel.set_ylabel("counts")

In [None]:
# Find out the image aperture radius: within the top-left 100 x 100 pixels, collect the distance
# to (50, 50) of every pixel whose three channels all equal 48, then report count / mean / min.
img = np.array(Image.open(folder_path + "50K_Imgset/1.bmp"))
print(img.shape, img[0, 0, :], img[0, 99, :], img[99, 0, :], img[99, 99, :])
patch = img[:100, :100]
is_gray48 = (patch[:, :, 0] == 48) & (patch[:, :, 1] == 48) & (patch[:, :, 2] == 48)
hit_rows, hit_cols = np.nonzero(is_gray48) # row-major order
dist = list(np.sqrt((hit_rows - 50) ** 2 + (hit_cols - 50) ** 2))
print(len(dist), np.mean(dist), np.min(dist)) # aperture radius: 46.0 / 100.0 of the image size

In [None]:
# Preprocess the estimated heatmap for layer4.0, the IT layer:
# zero every pixel lying outside the circular image aperture, then export to .mat
heatmaps = np.load(folder_path + "Fig4/layer40/heatmaps.npy")[:, 0, :, :] # Estimated Heatmap (512, 224, 224)
# One boolean mask replaces the per-pixel Python loop over every unit.
# Aperture: circle centered on (112, 112) with radius 0.46 * 224.
rows, cols = np.ogrid[:224, :224]
outside_aperture = np.sqrt((rows - 112) ** 2 + (cols - 112) ** 2) > (0.46 * 224)
heatmaps[:, outside_aperture] = 0
scipy.io.savemat(folder_path + "Fig4/layer40/heatmaps.mat", {'heatmaps': heatmaps}) # save the heatmaps to mat file

In [None]:
# Visualize the estimated 2D Gaussian fit center and std for layer4.0, the IT layer
# NOTE: `heatmaps` is not loaded here; this cell expects the array left behind by the
# layer4.0 preprocessing cell.
xs = scipy.io.loadmat(folder_path + "Fig4/layer40/b1.mat")["b1"] # x-center, (512, 1)
ys = scipy.io.loadmat(folder_path + "Fig4/layer40/b2.mat")["b2"] # y-center, (512, 1)
xs_std = scipy.io.loadmat(folder_path + "Fig4/layer40/c1.mat")["c1"] # x-std, (512, 1)
ys_std = scipy.io.loadmat(folder_path + "Fig4/layer40/c2.mat")["c2"] # y-std, (512, 1)
# Show the first 32 heatmaps; the fitted center is painted as a small square at the heatmap's
# minimum value (black in the "hot" colormap) and the fitted stds are drawn as a cyan ellipse.
fig, axes = plt.subplots(2, 16, figsize=(32, 4))
l = 3 # half-width (pixels) of the center marker
for i in range(32):
    ax = axes[i // 16, i % 16]
    # work on a copy: painting the marker through a view would overwrite `heatmaps` itself
    heatmap = heatmaps[i, :, :].copy()
    x, y = int(xs[i][0]), int(ys[i][0])
    # clamp at 0: a negative slice bound would wrap around to the opposite image edge
    heatmap[max(x-l, 0):max(x+l, 0), max(y-l, 0):max(y+l, 0)] = np.min(heatmap) # set the center to the minimum value
    ax.imshow(heatmap, cmap="hot")
    ax.axis("off")
    # imshow puts columns on the horizontal axis, hence (y, x) and width from the y-std
    ellipse = matplotlib.patches.Ellipse((y, x), width=2*ys_std[i][0], height=2*xs_std[i][0], edgecolor="cyan", facecolor="none", linewidth=2)
    ax.add_patch(ellipse)
plt.tight_layout()

In [None]:
# Preprocess the estimated heatmap for layer4.1, the IT layer:
# zero every pixel lying outside the circular image aperture, then export to .mat
heatmaps = np.load(folder_path + "Fig4/layer41/heatmaps.npy")[:, 0, :, :] # Estimated Heatmap (512, 224, 224)
# One boolean mask replaces the per-pixel Python loop over every unit.
# Aperture: circle centered on (112, 112) with radius 0.46 * 224.
rows, cols = np.ogrid[:224, :224]
outside_aperture = np.sqrt((rows - 112) ** 2 + (cols - 112) ** 2) > (0.46 * 224)
heatmaps[:, outside_aperture] = 0
scipy.io.savemat(folder_path + "Fig4/layer41/heatmaps.mat", {'heatmaps': heatmaps}) # save the heatmaps to mat file

In [None]:
# Visualize the estimated 2D Gaussian fit center and std for layer4.1, the IT layer
# NOTE: `heatmaps` is not loaded here; this cell expects the array left behind by the
# layer4.1 preprocessing cell.
xs = scipy.io.loadmat(folder_path + "Fig4/layer41/b1.mat")["b1"] # x-center, (512, 1)
ys = scipy.io.loadmat(folder_path + "Fig4/layer41/b2.mat")["b2"] # y-center, (512, 1)
xs_std = scipy.io.loadmat(folder_path + "Fig4/layer41/c1.mat")["c1"] # x-std, (512, 1)
ys_std = scipy.io.loadmat(folder_path + "Fig4/layer41/c2.mat")["c2"] # y-std, (512, 1)
# Show the first 32 heatmaps; the fitted center is painted as a small square at the heatmap's
# minimum value (black in the "hot" colormap) and the fitted stds are drawn as a cyan ellipse.
fig, axes = plt.subplots(2, 16, figsize=(32, 4))
l = 3 # half-width (pixels) of the center marker
for i in range(32):
    ax = axes[i // 16, i % 16]
    # work on a copy: painting the marker through a view would overwrite `heatmaps` itself
    heatmap = heatmaps[i, :, :].copy()
    x, y = int(xs[i][0]), int(ys[i][0])
    # clamp at 0: a negative slice bound would wrap around to the opposite image edge
    heatmap[max(x-l, 0):max(x+l, 0), max(y-l, 0):max(y+l, 0)] = np.min(heatmap) # set the center to the minimum value
    ax.imshow(heatmap, cmap="hot")
    ax.axis("off")
    # imshow puts columns on the horizontal axis, hence (y, x) and width from the y-std
    ellipse = matplotlib.patches.Ellipse((y, x), width=2*ys_std[i][0], height=2*xs_std[i][0], edgecolor="cyan", facecolor="none", linewidth=2)
    ax.add_patch(ellipse)
plt.tight_layout()

# TDANN unit response to partially Occluded/Preserved images

In [None]:
def load_model_from_checkpoint(checkpoint_path: str, device: str):
    """Build a ResNet-18 without its classifier head and fill it with checkpoint weights.

    Args:
        checkpoint_path: path of the checkpoint file handed to ``torch.load``.
        device: device string used as ``map_location`` while loading.

    Returns:
        The ResNet-18 with ``fc`` replaced by ``nn.Identity`` and every parameter frozen
        (``requires_grad = False``). The model is not moved to ``device`` here.
    """
    net = torchvision.models.resnet18(pretrained=False)
    net.fc = nn.Identity()  # the classifier head is dropped
    checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))
    trunk_weights = checkpoint["classy_state_dict"]["base_model"]["model"]["trunk"]
    # keep the "base_model..." entries except the fc ones, named by what follows the last "base_model."
    backbone_weights = {
        name.split("base_model.")[-1]: tensor
        for name, tensor in trunk_weights.items()
        if name.startswith("base_model") and "fc." not in name
    }
    net.load_state_dict(backbone_weights)
    # inference only: no parameter takes part in autograd
    for weight in net.parameters():
        weight.requires_grad = False
    return net

# Image Occlusion test
# For the central unit of each channel and each of its top images, build two image series in
# which a growing set of the most salient pixels (ranked by that image's heatmap) is either the
# only part preserved ("ON") or the part removed ("OFF"), and record the unit's activation.
layer = "layer41" # targeted layer, change for different layers
imgs_0index = np.load(folder_path + "Fig4/TDANNrsp_1kimg_0idx.npz")[layer] # (num_units, 1000), 0-indexed preferred images for each unit
# 2D gaussian fitted receptive field parameters; read into memory so the HDF5 file is closed right away
with h5py.File(folder_path + "Fig4/" + layer + "/Sigma.mat", "r") as sigma_file:
    Sigma = sigma_file["Sigma"][()]
heatmaps25 = np.load(folder_path + "Fig4/" + layer + "/heatmaps.npy") # (unit_num, 26, 224, 224)
# load model
device = "cuda" if torch.cuda.is_available() else "cpu"
checkpoint_path = folder_path + "Fig4/TDANNfinal.torch"
model = load_model_from_checkpoint(checkpoint_path, device)
model = model.to(device)
model.eval()
# parameters
num_topimgs = 25 # use top 25 images to do ON and OFF response test
num_partial_imgs = 200 # for each top image, number of synthesis images to be used for the partial image test
step = 0.04 # step size for the Gaussian mask, maybe a bit small
# per stage: (number of channels, feature-map side length, row/col index of the central unit)
stage_geometry = {"1": (64, 56, 28), "2": (128, 28, 14), "3": (256, 14, 7), "4": (512, 7, 3)}
if layer not in [f"layer{stage}{block}" for stage in "1234" for block in "01"]:
    raise ValueError("Unsupported layer: " + str(layer))
unit_num, feature_map_size, xy = stage_geometry[layer[5]]
# targeted block inside the ResNet stage (loop-invariant, resolved once)
target_layer = getattr(model, "layer" + layer[5])[int(layer[6])]

OnOff_activation = np.zeros((unit_num, num_topimgs, num_partial_imgs * 2))

activation = [None]  # mutable container written by the forward hook (no global needed)
def hook_fn(module, input, output): # capture the output of the targeted block
    activation[0] = output

hook = target_layer.register_forward_hook(hook_fn)
try:
    # iterate through all channels' feature map center units
    for k in tqdm(range(unit_num), desc="iterating through units...", disable=True):
        try:
            # central unit of channel k in the (channels, size, size) response tensor
            depth_index, row_index, col_index = k, xy, xy
            flat_index = k*feature_map_size*feature_map_size + xy*feature_map_size + xy # 0-indexed unit index
            print("Currently targeting at unit: ", flat_index, (0, depth_index, row_index, col_index))
            # most preferred images for the current unit (0-indexed to 1-indexed to match image file name)
            index1 = (1 + imgs_0index[k, :]).astype(int)
            # 2D Gaussian fit Parameters
            area = np.pi * Sigma[0, k] * Sigma[1, k]
            sig = np.sqrt(Sigma[0, k] * Sigma[1, k]) / 2
            for i in range(num_topimgs):
                # pixel indices sorted by heatmap value, largest first (slot 0 is the averaged heatmap, hence i+1)
                heatmap = heatmaps25[k, i+1, :, :] # (224, 224)
                heatmap = np.reshape(heatmap, (224 * 224,))
                heatmap = np.argsort(-heatmap)
                # space to store the synthesized images
                imgs_on = np.zeros((num_partial_imgs, 224, 224, 3))
                imgs_off = np.zeros((num_partial_imgs, 224, 224, 3))
                image = np.array(Image.open(folder_path + "50K_Imgset/" + str(index1[i]) + ".bmp"))
                # The resized image and its high-frequency residual do not depend on the mask:
                # compute them once per image instead of twice per partial image.
                base = resize(image, (224, 224, 3), anti_aliasing=True)
                detail = base - cv2.GaussianBlur(base, (0, 0), sig)
                for j in range(num_partial_imgs):
                    # select the step*j*area most salient pixels
                    mask = np.zeros((224*224,))
                    mask[heatmap[:int(step*j*area)]] = 1
                    mask1 = np.repeat(np.reshape(mask, (224,224,1)),3,axis=2)
                    mask2 = cv2.GaussianBlur(mask1,(0,0),2.0) # soft-edged version of the mask
                    # ON: everything outside the selection set to 0.5 gray
                    img = base.copy()
                    img[mask1 < 1] = 0.5
                    imgs_on[j, :, :, :] = cv2.GaussianBlur(img, (0, 0), sig) + detail * mask2 # top-k preserved images
                    # OFF: the selection itself set to 0.5 gray
                    img = base.copy()
                    img[mask1 > 0] = 0.5
                    imgs_off[j, :, :, :] = cv2.GaussianBlur(img, (0, 0), sig) + detail * (1 - mask2) # top-k removed images
                # create image batch from shape (num_partial_imgs * 2, 224, 224, 3) to (num_partial_imgs * 2, 3, 224, 224) as a torch tensor input
                imgs_batch = torch.tensor(np.concatenate((imgs_on, imgs_off), axis=0).astype(np.float32), device=device).permute(0, 3, 1, 2)
                # model forward pass, obtain activation values of the target unit
                activation[0] = None
                with torch.no_grad(): model(imgs_batch)
                OnOff_activation[k, i, :] = activation[0][:, depth_index, row_index, col_index].detach().cpu().numpy()
        except Exception as e: print("error message at k =", k, e)
        break # remove this break to run through all units
finally:
    # detach the hook even when a unit failed, so hooks never pile up on the model
    hook.remove()
# save the final results; the file name follows the targeted layer
# NOTE(review): the dispersity cell reads "Fig4/<layer>/OnOff<NN>.npy" -- the saved file has to be moved there
np.save(folder_path + "Fig4/OnOff" + layer[-2:] + ".npy", OnOff_activation)

# Feature dispersity calculation

In [None]:
# Calculate the feature dispersity value for V4 digital twin
# Dispersity of a unit = position (in mask steps, divided by S0) at which the mean normalized
# OFF-response curve first drops below the mean normalized ON-response curve.
# Step 1: Define constants
pnum = 3048
snum = 25 # Number of image samples
num_partial_imgs = 200 # Number of partial images
step = 0.02 # Step size for the Gaussian mask
S0 = 2 * np.log(2) / step # Constant for normalization
# On, Off responses
# Load MATLAB .mat files (relative to folder_path, like the other V4DT inputs of this notebook)
MRspS = scipy.io.loadmat(folder_path + "V4DT/Dispersity_results/MRspS.mat")['MRspS'] # OFF responses, Shape: (3048, 25, 200)
TRspS = scipy.io.loadmat(folder_path + "V4DT/Dispersity_results/TRspS.mat")['TRspS'] # ON responses, Shape: (3048, 25, 200)
ROI = np.load(folder_path + "V4DT/ROI.npy").T # (128, 128) <class 'numpy.ndarray'>
# Step 2 + 3: Normalize both response curves by the first OFF response of the same unit / image
ERsp = np.zeros((2, pnum, snum, num_partial_imgs)) # Shape: (2, pnum, 25, 200)
baseline = MRspS[:pnum, :snum, :1]
ERsp[0] = MRspS[:pnum, :snum, :] / baseline # Normalize MRspS
ERsp[1] = TRspS[:pnum, :snum, :] / baseline # Normalize TRspS
# Step 4: Compute mean across samples (axis=2)
MRsp = np.mean(ERsp, axis=2)  # Shape: (2, pnum, 200)
# Step 5: Compute TRsp
TRsp = np.zeros(pnum) # Raw feature dispersity values
for ci in range(pnum):
    diff = MRsp[0, ci, :] - MRsp[1, ci, :] # OFF minus ON
    # positions where the difference goes from positive to negative between neighbours
    crossings = np.where((diff[:-1] > 0) & (diff[1:] < 0))[0]
    if crossings.size == 0: TRsp[ci] = (num_partial_imgs - 1) / S0 # never crosses: last (0-based) position
    else:
        ID = crossings[0]
        # linear interpolation of the zero crossing between positions ID and ID + 1
        frac = abs(diff[ID]) / (abs(diff[ID]) + abs(diff[ID + 1]))
        # ID is already 0-based here; the former "- 1" (a 1-based-index correction) shifted every
        # crossing one step down and produced negative dispersity for a crossing at position 0
        TRsp[ci] = (ID + frac) / S0

# Scatter the per-unit values back into the ROI pixels (row-major order)
mask = ROI == 1 # masked ROI region
count = int(np.count_nonzero(mask))
assert count == 3048 # Ensure all units are processed
FD = np.zeros((128, 128)) # Feature dispersity map
FD[mask] = TRsp
# Normalize FD and map to 100 levels
Min, Max = 0.0, 1.0
Imap = np.floor((FD - Min) / (Max - Min) * 100).astype(int) # (128, 128)
Imap = np.clip(Imap, 1, 100) # (128, 128), Clamping values to [1, 100]
# 100-level color table loaded from file
colA = scipy.io.loadmat(folder_path + "V4DT/Dispersity_results/color_scheme.mat")["colA"] # (100, 3)
cmap = np.zeros((128, 128, 3)) # Create an empty colormap image (initialized as black)
cmap[mask] = colA[Imap[mask] - 1]  # -1 because Python indexing starts at 0

# Display the final colormap
plt.imshow(cmap)
plt.axis("off")  # Hide axes
plt.title("Colormap with ROI Mask")
plt.show()

In [None]:
# Calculate the feature dispersity value for TDANN layers
# Dispersity of a unit = position (in mask steps, divided by S0) at which the mean normalized
# OFF-response curve first drops below the mean normalized ON-response curve.
LAYERS = ["layer31", "layer40"]
for layer in LAYERS:
    # Step 0: Load the OnOff activation values
    OnOff_activation = np.load(folder_path + "Fig4/" + layer + "/OnOff" + str(layer[-2:]) + ".npy") # (pnum, 25, 400)
    # Step 1: Define constants
    pnum = OnOff_activation.shape[0] # 256 for layer3.1 (V4), 512 for layer4.0 and layer4.1 (IT)
    snum = 25 # Number of image samples
    num_partial_imgs = 200 # Number of partial images
    step = 0.04 # Step size for the Gaussian mask
    S0 = 2 * np.log(2) / step # Constant for normalization
    # On, Off responses: first half of the last axis is ON, second half is OFF
    TRspS = OnOff_activation[:, :, :num_partial_imgs] # ON responses, Shape: (pnum, 25, 200)
    MRspS = OnOff_activation[:, :, num_partial_imgs:] # OFF responses, Shape: (pnum, 25, 200)
    # Step 2 + 3: Normalize both response curves by the first OFF response of the same unit / image
    ERsp = np.zeros((2, pnum, snum, num_partial_imgs)) # Shape: (2, pnum, 25, 200)
    baseline = MRspS[:, :snum, :1]
    ERsp[0] = MRspS[:, :snum, :] / baseline # Normalize MRspS
    ERsp[1] = TRspS[:, :snum, :] / baseline # Normalize TRspS
    # Step 4: Compute mean across samples (axis=2)
    MRsp = np.mean(ERsp, axis=2)  # Shape: (2, pnum, 200)
    # Step 5: Compute TRsp
    TRsp = np.zeros(pnum) # Raw feature dispersity values
    for ci in range(pnum):
        diff = MRsp[0, ci, :] - MRsp[1, ci, :] # OFF minus ON
        # positions where the difference goes from positive to negative between neighbours
        crossings = np.where((diff[:-1] > 0) & (diff[1:] < 0))[0]
        if crossings.size == 0: TRsp[ci] = (num_partial_imgs - 1) / S0 # never crosses: last (0-based) position
        else:
            ID = crossings[0]
            # linear interpolation of the zero crossing between positions ID and ID + 1
            frac = abs(diff[ID]) / (abs(diff[ID]) + abs(diff[ID + 1]))
            # ID is already 0-based here; the former "- 1" (a 1-based-index correction) shifted every
            # crossing one step down and produced negative dispersity for a crossing at position 0
            TRsp[ci] = (ID + frac) / S0
    # save the feature dispersity values
    np.save(folder_path + "Fig4/" + layer + "/FD.npy", TRsp)

In [None]:
# Histogram visualization
# One figure per data set: density histogram of the dispersity values, its gamma fit (red) and,
# for reference, the V4 digital twin gamma fit rescaled to the same total (black).
fs = 16

# The V4 digital twin reference is identical for every figure, so load and fit it once.
FD_v4 = scipy.io.loadmat(folder_path + "V4DT/Dispersity_results/FDraw.mat")["FD"]
roi = np.load(folder_path + "V4DT/ROI.npy").T # (128, 128) <class 'numpy.ndarray'>
FD_v4 = FD_v4[roi == 1] # (3048,)
x_vals_v4 = np.linspace(FD_v4.min(), FD_v4.max(), 100)  # Continuous range of values
shape_v4, loc_v4, scale_v4 = scipy.stats.gamma.fit(FD_v4, floc=0)  # Force location to 0
gamma_pdf_v4_raw = scipy.stats.gamma.pdf(x_vals_v4, a=shape_v4, loc=loc_v4, scale=scale_v4)

for layer in ["V4DT", "layer31", "layer40"]:
    if layer == "V4DT":
        FD = FD_v4
    else:
        FD = np.load(folder_path + "Fig4/" + layer + "/FD.npy")
        # goodness of the 2D Gaussian fits; only needed for the optional filter below
        rsquares = scipy.io.loadmat(folder_path + "Fig4/" + layer + "/rsquares.mat")["rsquares"]
        rsquares = rsquares.squeeze()
        # FD = FD[rsquares >= 0.9]
        FD = FD[FD > 0]
    x_vals = np.linspace(FD.min(), FD.max(), 100)  # Continuous range of values
    # Fit a Gamma distribution to the data
    shape_hat, loc_hat, scale_hat = scipy.stats.gamma.fit(FD, floc=0)  # Force location to 0
    # Compute the Gamma PDF values
    gamma_pdf = scipy.stats.gamma.pdf(x_vals, a=shape_hat, loc=loc_hat, scale=scale_hat)
    fig = plt.figure(figsize=(4, 3))
    plt.hist(FD, bins=20, density=True, alpha=0.6, color='b', edgecolor='black')
    if layer == "layer31":
        fit_label = 'TDANN V4 layer'
    elif layer == "layer40":
        fit_label = 'TDANN ITC layer'
    else:
        # \u03b8 is the Greek letter theta (the label previously contained a mis-encoded character)
        fit_label = f'{layer} gamma fit (k={shape_hat:.2f}, \u03b8={scale_hat:.2f})'
    plt.plot(x_vals, gamma_pdf, 'r-', linewidth=2, label=fit_label)
    # re-scale the V4DT gamma pdf to match the other pdf sum
    gamma_pdf_v4 = gamma_pdf_v4_raw / np.sum(gamma_pdf_v4_raw) * np.sum(gamma_pdf)
    plt.plot(x_vals_v4, gamma_pdf_v4, 'k-', linewidth=1, label='V4 digital twin')
    plt.xlabel("Dispersity", fontsize=fs)
    plt.ylabel("Density", fontsize=fs)
    plt.yticks([])
    plt.legend()
    # save the figure
    plt.tight_layout()
    if layer == "V4DT": fig.savefig(folder_path + "V4DT/FDhist.png", dpi=300)
    else: fig.savefig(folder_path + "Fig4/" + layer + "/FDhist.png", dpi=300)
    plt.close(fig)

# Feature dispersity map

In [None]:
# feature dispersity map visualization / scatter plot
# Every unit of a channel inherits the dispersity of that channel's central unit; values are
# quantized into 100 levels over the fixed range [Min, Max] and drawn at the unit's position.
LAYERS = ["layer31", "layer40"]
# 100-level color table loaded from file, shared by all layers
parula100 = mcolors.ListedColormap(scipy.io.loadmat(folder_path + "V4DT/Dispersity_results/color_scheme.mat")["colA"]) # (100, 3)
for layer in LAYERS:
    positions = np.load(folder_path + "Fig4/TDANNfinal_positions/layer" + str(layer[-2:-1]) + "." + str(layer[-1]) + ".npz")["coordinates"] # (num_units, 2)
    FD = np.load(folder_path + "Fig4/" + layer + "/FD.npy") # one value per channel
    # num_units: units in the layer, size: feature-map side length, s: scatter marker size
    if layer == "layer41" or layer == "layer40":
        num_units, size, s = 25088, 7, 3
    elif layer == "layer31" or layer == "layer30":
        num_units, size, s = 50176, 14, 1.5
    else:
        raise ValueError("Unsupported layer for the dispersity map: " + layer)
    if FD.shape[0] * size * size < num_units:
        raise ValueError("FD.npy of " + layer + " has too few channels for " + str(num_units) + " units")
    # units are flattened as (channel, row, col): repeat each channel value size*size times
    FD_all = np.repeat(FD, size * size)[:num_units]

    # Normalize FD and map to 100 levels
    Min, Max = 0, 3
    Imap = np.floor((FD_all - Min) / (Max - Min) * 100).astype(int)
    Imap = np.clip(Imap, 1, 100) # Clamping values to [1, 100]

    # Display the final colormap
    fig = plt.figure(figsize=(4, 4))
    # vmin/vmax pin level n to the n-th color; without them scatter rescales the colors to the
    # min/max of the data actually present, so the fixed [Min, Max] range would be lost
    plt.scatter(positions[:, 1], positions[:, 0], c=Imap, s=s, cmap=parula100, vmin=1, vmax=100, marker='o', edgecolors='none')
    plt.gca().invert_yaxis() # Move origin to the top-left
    plt.axis("off")  # Hide axes
    # plt.title("Dispersity", fontsize=22)
    # save the figure
    plt.tight_layout()
    plt.savefig(folder_path + "Fig4/" + layer + "/FDscatter03.png", dpi=1000)
    plt.close(fig)

In [None]:
# Given a map whose pixel values are domain numbers (0 = background), remove the connected components smaller than a threshold
def connected_components(matched, threshold, return_largest=False):
    """Remove small connected components from a 2D label map.

    Each non-zero value in ``matched`` is treated as a separate class
    (0 is background). Per class, components are found with 4-connectivity,
    i.e. only pixels sharing an edge are neighbours; diagonal pixels are not.

    Parameters
    ----------
    matched : 2D array of class labels; it is not modified.
    threshold : components with fewer pixels than this are set to 0.
    return_largest : if True, additionally return a copy of ``matched`` in
        which, per class, only the component(s) of maximal size are kept
        (components tied for the maximum are all kept).

    Returns
    -------
    connected, or ``(connected, largest)`` when ``return_largest`` is True.
    After thresholding, interior background pixels of ``connected`` whose
    four edge neighbours all equal 1 are set to 1. This hole filling only
    handles label 1 and is not applied to ``largest``.
    """
    # Cross-shaped footprint: centre plus its four edge neighbours (4-connectivity)
    cross = np.array([[0, 1, 0],
                      [1, 1, 1],
                      [0, 1, 0]], dtype=int)

    connected = np.array(matched, copy=True)
    largest = np.array(matched, copy=True) if return_largest else None

    for class_id in np.unique(matched):
        if class_id == 0:
            continue  # background
        in_class = matched == class_id
        component_ids, n_components = label(in_class, structure=cross)
        ids = np.arange(1, n_components + 1)
        # Summing the boolean mask per component label gives its pixel count
        sizes = np.asarray(ndi_sum(in_class, component_ids, index=ids))

        # Drop every component of this class that is below the size threshold
        too_small = ids[sizes < threshold]
        connected[np.isin(component_ids, too_small)] = 0

        if return_largest:
            not_largest = ids[sizes != np.max(sizes)]
            largest[np.isin(component_ids, not_largest)] = 0

    # Fill single-pixel holes: an interior 0 whose four edge neighbours are all 1.
    # Such a hole can never be adjacent to another fillable hole, so evaluating
    # all candidates at once gives the same result as a pixel-by-pixel scan.
    interior = connected[1:-1, 1:-1]
    holes = (
        (interior == 0)
        & (connected[:-2, 1:-1] == 1)
        & (connected[2:, 1:-1] == 1)
        & (connected[1:-1, :-2] == 1)
        & (connected[1:-1, 2:] == 1)
    )
    interior[holes] = 1  # `interior` is a view, so this writes into `connected`

    if return_largest:
        return connected, largest
    return connected

In [None]:
# feature dispersity map visualization / aggregated into 60 by 60 gridded colormap
# For every layer: average the per-unit feature dispersity (FD) on a
# grid_num x grid_num grid, split the occupied cells into a low and a high
# cluster, save the gridded map, and overlay the outline of the high cluster
# on a previously saved scatter plot.
LAYERS = ["layer31", "layer40"]

# Get Parula colormap with 100 levels (layer independent, so load it once)
colA = scipy.io.loadmat(folder_path + "V4DT/Dispersity_results/color_scheme.mat")["colA"] # (100, 3)

for layer in LAYERS:
    positions = np.load(folder_path + "Fig4/TDANNfinal_positions/layer" + layer[-2] + "." + layer[-1] + ".npz")["coordinates"] # (num_units, 2)
    FD = np.load(folder_path + "Fig4/" + layer + "/FD.npy")
    if layer in ("layer41", "layer40"):
        num_units = 25088
        size = 7
    elif layer in ("layer31", "layer30"):
        num_units = 50176
        size = 14
    else:
        # Without this guard an unknown layer would silently reuse the
        # settings left over from the previous iteration.
        raise ValueError("Unsupported layer: " + str(layer))

    # Unit i belongs to depth slice i // (size * size); every unit of a slice
    # gets that slice's FD value.
    depth_index = np.arange(num_units) // (size * size)
    FD_all = np.asarray(FD, dtype=float)[depth_index]

    grid_num = 60
    FD_map = np.zeros((grid_num, grid_num))
    ROI = np.zeros((grid_num, grid_num))
    cortical_size = positions[:, 0].max() - positions[:, 0].min() # define the length of the 2D plane
    # Cells are half-open intervals [k, k + 1) * cell_size starting at
    # coordinate 0 on both axes; units outside [0, cortical_size) on either
    # axis are not assigned to any cell.
    for i in tqdm(range(grid_num), desc="map initialization...", disable=True):
        xmin_cortex = cortical_size / grid_num * i
        xmax_cortex = cortical_size / grid_num * (i + 1)
        # The row membership does not depend on j, so compute it once per row
        in_row = (positions[:, 0] >= xmin_cortex) & (positions[:, 0] < xmax_cortex)
        for j in range(grid_num):
            ymin_cortex = cortical_size / grid_num * j
            ymax_cortex = cortical_size / grid_num * (j + 1)
            # find all units in this current grid
            units_within_grid_indices = np.where(in_row & (positions[:, 1] >= ymin_cortex) & (positions[:, 1] < ymax_cortex))[0]
            if len(units_within_grid_indices) > 0:
                FD_map[i, j] = np.mean(FD_all[units_within_grid_indices])
                ROI[i, j] = 1

    # Hierarchical clustering of the grid values into two groups (low / high FD)
    # Create a mask to ignore zero values (empty cells, and cells whose mean FD is exactly 0)
    nonzero_mask = FD_map > 0
    nonzero_values = FD_map[nonzero_mask]  # Extract nonzero values
    # Perform hierarchical clustering only on nonzero values
    Z = linkage(nonzero_values.reshape(-1, 1), method='ward')
    cluster_labels = fcluster(Z, t=2, criterion='maxclust')
    # Map cluster labels back to original shape
    clusters = np.zeros_like(FD_map, dtype=int)  # Default 0 (ignored values)
    clusters[nonzero_mask] = cluster_labels  # Assign labels only to nonzero positions
    # Identify the high-value cluster
    cluster_1_mean = nonzero_values[cluster_labels == 1].mean()
    cluster_2_mean = nonzero_values[cluster_labels == 2].mean()
    high_dispersity_cluster = 1 if cluster_1_mean > cluster_2_mean else 2
    # Binary map: 1 for grids in the high-dispersity cluster, 0 everywhere else
    clusters = (clusters == high_dispersity_cluster).astype(int)
    # Remove unconnected components smaller than a threshold
    clusters = connected_components(clusters, 10, return_largest=False)
    contours = find_contours(clusters, level=0.5)  # Extracts boundaries of 1 clusters

    # Normalize FD and map to 100 levels
    Min, Max = 0, 1
    Imap = np.floor((FD_map - Min) / (Max - Min) * 100).astype(int)
    Imap = np.clip(Imap, 1, 100) # Clamping values to [1, 100], values outside the interval are clipped to the interval edges
    cmap = np.zeros((grid_num, grid_num, 3)) # Create an empty colormap image (initialized as black); from top left to bottom right: [vertical, horizontal, rgb]
    mask = ROI == 1 # masked ROI region
    mask_off = ROI == 0 # masked OFF region
    cmap[mask] = colA[Imap[mask] - 1]  # -1 because Python indexing starts at 0
    cmap[mask_off] = np.ones((3)) # OFF region is white

    # Display the final colormap
    fig = plt.figure(figsize=(4, 4))
    plt.imshow(cmap)
    plt.axis("off")  # Hide axes
    plt.title("Dispersity", fontsize=22)
    # save the figure
    plt.tight_layout()
    fig.savefig(folder_path + "Fig4/" + layer + "/FDmap01.png", dpi=300)
    plt.close(fig)

    # scatter plot with contours overlay
    # NOTE(review): FDscatter01.png must already exist on disk; the scatter
    # cell above currently saves "FDscatter03.png", so this file presumably
    # comes from an earlier run of that cell with Max = 1 -- confirm.
    with Image.open(folder_path + "Fig4/" + layer + "/FDscatter01.png") as scatter_png:
        image = np.array(scatter_png) # copy the pixels, then let the file handle close
    # Crop the image to focus on the scatter plot; the fixed 300:3700 window
    # presumably targets a 4000 x 4000 px render (4 in at dpi=1000) -- confirm.
    image = image[300:3700, 300:3700, :]
    scale_y = image.shape[0] / grid_num
    scale_x = image.shape[1] / grid_num
    fig = plt.figure(figsize=(4, 4))
    plt.imshow(image)
    for contour in contours:
        # find_contours returns (row, col) vertices in array-index coordinates,
        # where integer values are cell centres. Grid cell k spans
        # [k, k + 1) * scale in the image, so the centre sits at
        # (k + 0.5) * scale; the trailing -0.5 accounts for imshow placing
        # pixel centres (not edges) on integer coordinates. Scaling the raw
        # indices alone shifts the outline half a cell towards the top-left.
        ys = (contour[:, 0] + 0.5) * scale_y - 0.5
        xs = (contour[:, 1] + 0.5) * scale_x - 0.5
        plt.plot(xs, ys, color="darkgoldenrod", lw=4)  # Draw the contours
    plt.axis("off")
    plt.title("Dispersity", fontsize=22)
    # save the figure
    plt.tight_layout()
    fig.savefig(folder_path + "Fig4/" + layer + "/FDscatter01_contour.png", dpi=300)
    plt.close(fig)