In [None]:
import torch
import torch.nn as nn
import numpy as np
from tqdm import tqdm
from skimage.transform import resize
from scipy.ndimage import label, sum as ndi_sum, binary_erosion
import matplotlib.image
from matplotlib import pyplot as plt
from matplotlib.colors import ListedColormap, LinearSegmentedColormap, to_rgba
from PIL import Image
from scipy.stats import pearsonr

folder_path = "/path/to/your/data/"

# Example grid tuning curve composition

In [None]:
def connected_components(matched, threshold, return_largest=False):
    """Remove small connected components from a map of class (domain) labels.

    Every non-zero class in ``matched`` is split into 8-connected components
    (adjacent and diagonal pixels are neighbours). Components with fewer than
    ``threshold`` pixels are set to 0 in the returned copy.

    Parameters
    ----------
    matched : np.ndarray
        2-D map of class labels; label 0 is treated as background and skipped.
    threshold : int or float
        Minimum number of pixels a component needs in order to be kept.
    return_largest : bool, optional
        When True, also return a second map in which only the single largest
        component of every class survives (``threshold`` is not applied to it).

    Returns
    -------
    np.ndarray or tuple of np.ndarray
        ``connected`` alone, or ``(connected, largest)`` when
        ``return_largest`` is True. ``matched`` itself is never modified.

    Notes
    -----
    If several components of one class share the maximum size, the one with
    the lowest component label is kept, so ``largest`` always holds exactly
    one component per class (the same rule as ``np.argmax``).
    """
    structure = np.ones((3, 3), dtype=int)  # 8-connectivity
    connected = np.copy(matched)
    largest = np.copy(matched) if return_largest else None
    for class_num in np.unique(matched):  # process each class separately
        if class_num == 0:
            continue  # skip the background
        class_mask = matched == class_num
        labeled_array, num_features = label(class_mask, structure=structure)
        if num_features == 0:
            continue  # e.g. a NaN label never equals itself, so nothing to clean
        # sizes[k] is the pixel count of component k; sizes[0] is everything outside the class
        sizes = np.bincount(labeled_array.ravel(), minlength=num_features + 1)
        too_small = sizes < threshold
        too_small[0] = False  # never touch pixels that belong to other classes
        connected[too_small[labeled_array]] = 0
        if return_largest:
            keep_label = int(np.argmax(sizes[1:])) + 1  # first maximum wins on ties
            largest[class_mask & (labeled_array != keep_label)] = 0
    if return_largest:
        return connected, largest
    return connected

In [None]:
# RSOM domain centrality relative to boundaries
rng = np.random.default_rng(0)
name = "RSOM"
# (60, 60) of 16 V4 voxel types, only connected components preserved
idx_rsom = np.load(folder_path + name + "/idx_newassign.npy").astype(int)
# simulation weight
weight_rsom = np.load(folder_path + name + "/weights.npy")[:, :, :50000] # (60, 60, 50000)
# V4 digital twin data
# load V4 digital twin benchmark
V4rsp = np.load(folder_path + "V4DT/PRsp.npy") # (50000, 128, 128)
roi = np.load(folder_path + "V4DT/ROI.npy").T # (128, 128); (x, y switch) and transpose back to match the others
V4idx = np.load(folder_path + "V4DT/idx.npy") # (128, 128), v4 connected components
V4idx, _ = connected_components(V4idx, threshold=10, return_largest=True) # remove the unconnected components smaller than a threshold of 10, also return the largest connected component
# Boolean mask of the ROI columns. Boolean indexing visits the selected cells in
# row-major order, i.e. the same order as a nested `for i ... for j ...` scan,
# and follows the actual shape of `roi` instead of a hard-coded 128.
roi_mask = roi == 1
# the ROI is expected to be binary; otherwise np.sum(roi) and the mask disagree
assert int(np.sum(roi)) == int(roi_mask.sum())
# prepare the full training data, rsp to 50k images: one row per ROI column, one entry per image
train_all = np.ascontiguousarray(V4rsp[:, roi_mask].T, dtype=np.float64)
# create a 1d vector of all 3048 V4 columns' domain assignment (same row order as train_all)
columns_domain = V4idx[roi_mask].astype(np.float64)
assert train_all.shape[0] == columns_domain.shape[0]
# args
device = "cpu"
top_img_num = 1000
epoch_num = 7500
learning_rate = 0.05
alpha = 0.01 # L1 regularization strength
top_column_num = 15
num_units_sample = 5
# define the linear regression model
class LinearRegressionModel(nn.Module):
    """Bias-free linear read-out: one trainable weight per input row.

    Parameters
    ----------
    num_rows : int
        Number of input rows, i.e. number of weights. Numpy integer scalars
        such as ``np.sum(roi)`` are accepted and cast with ``int``.
    device : str or torch.device, optional
        Device on which the weights are created. ``None`` uses torch's default
        device; callers may still move the module afterwards with ``.to(...)``.
    """

    def __init__(self, num_rows, device=None):
        super().__init__()
        # train from scratch: weights start from a standard-normal draw
        self.weights = nn.Parameter(torch.randn(int(num_rows), device=device))

    def forward(self, rsp):
        """Return the weighted sum over rows of ``rsp``.

        ``rsp`` has ``num_rows`` rows; row ``k`` is scaled by ``weights[k]``
        and the rows are summed (reduction over dim 0), giving one value per
        column of ``rsp``.
        """
        return torch.sum(self.weights.view(-1, 1) * rsp, dim=0)
# identify connectivity and boundary
structure_conn = np.ones((3, 3), dtype=int)
structure_boundary = np.ones((3, 3), dtype=bool)
# statistics
central_homogeneous_all = []
boundary_homogeneous_all = []


def _select_top_images(weight_target, train_data, top_num):
    """Restrict a regression problem to the target unit's most responsive images.

    Parameters
    ----------
    weight_target : np.ndarray
        1-D response of the target unit, one value per image.
    train_data : np.ndarray
        Predictor matrix with one column per image.
    top_num : int
        Number of most responsive images to keep.

    Returns
    -------
    (np.ndarray, np.ndarray)
        The selected columns of ``train_data`` and the matching entries of
        ``weight_target``, ordered from the largest response to the smallest.
        When ``top_num`` equals the number of images, both inputs are returned
        unchanged (previously the predictors were left undefined in that case).
    """
    if top_num == weight_target.shape[0]:
        return train_data, weight_target
    # stable ascending sort: equal responses keep ascending image order, which
    # matches sorting (response, image index) pairs
    ascending = np.argsort(weight_target, kind="stable")
    image_label = ascending[-top_num:][::-1]  # most responsive first
    return train_data[:, image_label], weight_target[image_label]


def _fit_lasso_weights(train_data, weight_target, num_epochs, lr, l1_strength, device="cpu"):
    """Fit ``LinearRegressionModel`` by full-batch SGD on MSE plus an L1 penalty.

    The learning rate is constant for all ``num_epochs`` steps. Returns the
    trained weights as a 1-D numpy array with one entry per row of
    ``train_data``.
    """
    target = torch.tensor(weight_target, device=device, dtype=torch.float32)
    train = torch.tensor(train_data.astype(np.float32), device=device)
    model = LinearRegressionModel(num_rows=train.shape[0]).to(device)
    criterion = nn.MSELoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    for _ in range(num_epochs):
        y_pred = model(train)
        l1_regularization = l1_strength * torch.sum(torch.abs(model.weights))  # LASSO penalty to encourage sparsity
        loss = criterion(y_pred, target) + l1_regularization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    return model.weights.detach().cpu().numpy()


def _homogeneous_rate(regression_w, domains, domain_id, top_num):
    """Share of the ``top_num`` largest regression weights that come from ``domain_id``.

    Rows are ranked by weight in descending order (ties broken by descending
    domain label). The result is the sum of the top weights whose row belongs
    to ``domain_id`` divided by the sum of all top weights.
    """
    order = np.lexsort((domains, regression_w))[::-1][:top_num]
    top_w = regression_w[order]
    top_domains = domains[order].astype(int)
    return top_w[top_domains == domain_id].sum() / top_w.sum()


def _unit_homogeneous_rate(row, col, domain_id):
    """Homogeneous connection rate of the RSOM unit at ``(row, col)``.

    Reads the module-level data (``weight_rsom``, ``train_all``,
    ``columns_domain``) and settings (``top_img_num``, ``epoch_num``,
    ``learning_rate``, ``alpha``, ``top_column_num``, ``device``).
    """
    train_sel, target_sel = _select_top_images(weight_rsom[row, col, :], train_all, top_img_num)
    regression_w = _fit_lasso_weights(train_sel, target_sel, epoch_num, learning_rate, alpha, device=device)
    return _homogeneous_rate(regression_w, columns_domain, domain_id, top_column_num)


# iterate through all domains
for domain_id in range(1, 17):
    mask = idx_rsom == domain_id
    if not np.any(mask):
        print(f"Domain {domain_id}: no units")
        continue
    labeled, num = label(mask, structure=structure_conn)
    if num == 0:
        print(f"Domain {domain_id}: no connected components")
        continue
    # keep only the largest 8-connected component of the domain
    component_sizes = ndi_sum(mask, labeled, index=range(1, num + 1))
    largest_label = int(np.argmax(component_sizes) + 1)
    largest_mask = labeled == largest_label
    # boundary = units of the component that are removed by one erosion step
    boundary_mask = largest_mask & (~binary_erosion(largest_mask, structure=structure_boundary, border_value=0))
    boundary_coords = np.argwhere(boundary_mask)
    interior_coords = np.argwhere(largest_mask & (~boundary_mask))
    print(
        f"Domain {domain_id}: largest component {largest_mask.sum()} units, {len(boundary_coords)} boundary, {len(interior_coords)} interior"
    )
    if len(boundary_coords) == 0 or len(interior_coords) == 0:
        print("  Skipping distance ranking (missing boundary or interior)")
        continue
    # distance between interior and boundary unit, the smallest variance yields central units
    distances = np.linalg.norm(interior_coords[:, None, :] - boundary_coords[None, :, :], axis=2)
    var_dist = distances.var(axis=1)
    top_indices = np.argsort(var_dist)[:num_units_sample] # indices of central units with smallest variance
    # iterate through current domain identified central units
    for unit_row, unit_col in interior_coords[top_indices]:
        central_homogeneous_all.append(_unit_homogeneous_rate(int(unit_row), int(unit_col), domain_id))

    # iterate through current domain identified boundary units (random sample without replacement)
    sampled_idx = rng.choice(len(boundary_coords), size=min(num_units_sample, len(boundary_coords)), replace=False)
    for unit_row, unit_col in boundary_coords[sampled_idx]:
        boundary_homogeneous_all.append(_unit_homogeneous_rate(int(unit_row), int(unit_col), domain_id))

In [None]:
# visualize statistics
# summary numbers first, then the two overlaid, density-normalised histograms
central_homogeneous_all = np.array(central_homogeneous_all)
boundary_homogeneous_all = np.array(boundary_homogeneous_all)
print("interior units homogeneous connection rate mean", central_homogeneous_all.mean(), "and variance", central_homogeneous_all.var())
print("boundary units homogeneous connection rate mean", boundary_homogeneous_all.mean(), "and variance", boundary_homogeneous_all.var())

bins = np.linspace(0, 1, 30)  # shared bin edges so both groups are directly comparable
fig, ax = plt.subplots(figsize=(6, 4))
shared_hist_style = dict(bins=bins, alpha=0.6, density=True)
ax.hist(central_homogeneous_all, label="Interior Unit", color="tab:blue", **shared_hist_style)
ax.hist(boundary_homogeneous_all, label="Boundary Unit", color="tab:orange", **shared_hist_style)
ax.set_xlabel("Homogeneous connection rate")
ax.set_ylabel("Density")
ax.set_title("Interior vs Boundary homogeneous connections")
ax.legend()
fig.tight_layout()
plt.show()

In [None]:
# LASSO Linear Regression and Visualization
# load V4 digital twin benchmark
V4rsp = np.load(folder_path + "V4DT/PRsp.npy") # (50000, 128, 128)
roi = np.load(folder_path + "V4DT/ROI.npy").T # (128, 128); (x, y switch) and transpose back to match the others
V4idx = np.load(folder_path + "V4DT/idx.npy") # (128, 128), v4 connected components
V4idx, _ = connected_components(V4idx, threshold=10, return_largest=True) # remove the unconnected components smaller than a threshold of 10, also return the largest connected component
# domain labels restricted to the ROI; everything outside the ROI stays 0
V4domain_map = np.zeros_like(V4idx)
V4domain_map[roi == 1] = V4idx[roi == 1]
top9_0index_V4 = np.load(folder_path + "V4DT/top9_0index.npy") # (128, 128, 9)
fs = 12 # set font size for the axes
fs_title = 16 # set font size for the title
# 1d vector of all 3048 V4 voxels' domain assignment; boolean indexing walks the ROI
# in row-major order (same as a nested i/j scan) and uses roi's real shape instead of 128
voxels_domain = V4idx[roi == 1].astype(np.float64)
# the ROI is expected to be binary, so its sum equals the number of selected voxels
assert voxels_domain.shape[0] == int(np.sum(roi))

# low_contrast_palette = ['#f5f8fb', '#d6e1ee', '#b1c7e0', '#7fa6c9', '#4c79a7']
low_contrast_palette = ['#d6e1ee', '#b1c7e0', '#7fa6c9', '#4c79a7']
# masked / NaN cells are drawn white in both maps
som_low_contrast_cmap = LinearSegmentedColormap.from_list('som_low_contrast', low_contrast_palette)
som_low_contrast_cmap.set_bad('white')
v4_low_contrast_cmap = LinearSegmentedColormap.from_list('v4_low_contrast', low_contrast_palette)
v4_low_contrast_cmap.set_bad('white')

def upscale_grid(array, scale=3):
    """Enlarge a 2-D grid by repeating every cell ``scale`` times along both axes.

    Masked arrays are repeated with ``np.ma.repeat``, plain arrays with
    ``np.repeat``.
    """
    repeat = np.ma.repeat if isinstance(array, np.ma.MaskedArray) else np.repeat
    enlarged = array
    for axis in (0, 1):  # rows first, then columns
        enlarged = repeat(enlarged, scale, axis=axis)
    return enlarged

def get_domain_mask(domain_map, domain_id, min_area=10):
    """Boolean mask of ``domain_id`` in ``domain_map`` without its small fragments.

    Connected components of the domain smaller than ``min_area`` are dropped
    via ``connected_components``. If the domain does not occur at all, an
    all-False mask shaped like ``domain_map`` is returned.
    """
    in_domain = domain_map == domain_id
    if np.any(in_domain):
        # re-label the domain on an otherwise empty map so only this class is cleaned
        single_class_map = np.where(in_domain, domain_id, 0)
        kept = connected_components(single_class_map, threshold=min_area, return_largest=False)
        return kept == domain_id
    return np.zeros_like(domain_map, dtype=bool)

def create_boundary_overlay(domain_mask, color, scale=3):
    """RGBA image that outlines ``domain_mask`` in ``color`` on an upscaled grid.

    The mask is enlarged by ``scale`` with ``upscale_grid``; the outline is the
    set of enlarged-mask pixels removed by one 3x3 binary erosion. All other
    pixels are fully transparent zeros. An empty mask yields an all-zero image
    of the upscaled size.
    """
    filled = domain_mask.astype(bool)
    if not np.any(filled):
        height = domain_mask.shape[0] * scale
        width = domain_mask.shape[1] * scale
        return np.zeros((height, width, 4))
    upsampled = upscale_grid(filled.astype(int), scale) > 0
    interior = binary_erosion(upsampled, structure=np.ones((3, 3), dtype=bool), border_value=0)
    outline = np.logical_and(upsampled, np.logical_not(interior))
    rgba = np.zeros((outline.shape[0], outline.shape[1], 4))
    rgba[outline] = to_rgba(color)
    return rgba

# load SOM simulation
name = "RSOM"
regression_folder_name = "LASSO_top1k"
# regression / display settings
top_img_num = 1000
epoch_num = 7500
learning_rate = 0.05
num_of_voxels_displayed = 15
# simulation outputs
idx = np.load(folder_path + name + "/idx_newassign.npy").astype(int) # (60, 60) of 16 V4 voxel types, only connected components preserved
top9_0index = np.load(folder_path + name + "/rsptop_0index.npy") # (60, 60, 9)
weight = np.load(folder_path + name + "/weights.npy")[:, :, :50000] # (60, 60, 50000)
# transposed domain map surrounded by a one-cell rim of zeros (background)
som_domain_map = np.pad(idx.T, pad_width=1, mode="constant", constant_values=0)

# loop through all domains from 1 to 16
idx_targets = [2] # domain numbers to be optimized, ###################### change this to your target domains

device = "cpu"
alpha = 0.01 # L1 regularization strength
scale = 3 # every map cell is drawn as a scale x scale pixel block
colors = ['black', 'red', 'green', 'blue', 'yellow', 'magenta', 'cyan', 'orange', 'purple', 'brown', 'pink', 'lime', 'teal', 'navy', 'gold', 'silver', 'coral'] # black for unassigned 0
cmap_black = ListedColormap(colors)
# for some target domains the boundaries of these additional domains are drawn as well
extra_boundaries = {16: [8], 1: [11, 12], 5: [7], 15: [16], 4: [10]}
# fixed display window of the (flipped) V4 map
v4_crop = (slice(38, 114), slice(38, 105))

# Predictor matrix: one row per ROI column (row-major scan order), one column per image.
# It does not depend on the target domain, so it is built once instead of once per iteration.
v4_roi_mask = roi == 1
train_full = V4rsp[:, v4_roi_mask].T


# Defined in this cell (once, outside the loop) so the cell does not rely on the
# earlier statistics cell having been run.
class LinearRegressionModel(nn.Module):
    """Bias-free linear read-out: one trainable weight per input row."""

    def __init__(self, num_rows, device=None):
        super().__init__()
        self.weights = nn.Parameter(torch.randn(int(num_rows), device=device)) # train from scratch

    def forward(self, rsp):
        # scale row k of rsp by weights[k] and sum over rows
        return torch.sum(self.weights.view(-1, 1) * rsp, dim=0)


def save_preferred_image_grid(preferred_imgs, save_path, imgs_num=20, col_num=2, img_size=100, line_width=30):
    """Save a grid of preferred images, filled column by column.

    ``preferred_imgs`` holds 0-indexed image numbers in display order; the
    image files are 1-indexed ``.bmp`` files under ``folder_path + "50K_Imgset/"``.
    The central 60x60 crop of each image is resized to ``img_size`` and placed
    on a white canvas with ``line_width`` pixels between tiles. If fewer than
    ``imgs_num`` images are given, the remaining tiles stay white.
    """
    rows_per_col = imgs_num // col_num
    step = img_size + line_width
    canvas = np.ones((rows_per_col * step - line_width, col_num * step - line_width, 3))
    for k, img_index in enumerate(preferred_imgs[:rows_per_col * col_num]):
        col, row = divmod(k, rows_per_col)
        path = folder_path + "50K_Imgset/" + str(int(1 + img_index)) + ".bmp" # the image name is 1-indexed
        with Image.open(path) as handle: # close the file as soon as the pixels are read
            img = np.array(handle)[20:80, 20:80, :] # obtain the non-blurred central part of the image
        img = resize(img, (img_size, img_size, 3), anti_aliasing=True) # resize the image
        canvas[row * step : row * step + img_size, col * step : col * step + img_size, :] = img
    matplotlib.image.imsave(save_path, canvas)


for idx_target in idx_targets:
    # SOM simulation response pattern to the target domain's preferred images
    in_som_domain = idx == idx_target
    if not np.any(in_som_domain):
        print(f"Domain {idx_target}: no SOM units, skipped")
        continue
    preferred_imgs_som, imgs_count_som = np.unique(top9_0index[in_som_domain].astype(int).flatten(), return_counts=True) # unique top-9 image 0index the domain prefers, with multiplicities
    # count-weighted mean response of every unit to the preferred images
    rsp = np.mean(np.ascontiguousarray(weight[:, :, preferred_imgs_som]) * imgs_count_som, axis=2)
    # most responsive unit, searched inside the target domain only
    max_rsp_idx = np.where((rsp == np.max(rsp[in_som_domain])) & in_som_domain)

    # V4 digital twin response pattern to the target domain's preferred images
    in_v4_domain = V4idx == idx_target
    preferred_imgs_v4, imgs_count_v4 = np.unique(top9_0index_V4[in_v4_domain].astype(int).flatten(), return_counts=True) # unique top-9 image 0index the domain prefers, with multiplicities
    rsp_V4 = np.full(roi.shape, np.nan) # voxels outside the ROI stay NaN and are drawn white
    rsp_V4[v4_roi_mask] = np.mean(V4rsp[preferred_imgs_v4][:, v4_roi_mask] * imgs_count_v4[:, None], axis=0)

    # regression target: tuning curve of the most responsive SOM unit
    weight_target = weight[max_rsp_idx[0][0], max_rsp_idx[1][0], :50000]
    train_np = train_full
    # if restricted to the top-1k responsive images out of 50k
    if top_img_num != weight_target.shape[0]:
        ascending = np.argsort(weight_target, kind="stable") # stable: equal responses keep ascending image order
        image_label = ascending[-top_img_num:][::-1] # most responsive images' 0index (from large to small)
        train_np = train_full[:, image_label] # select the most responsive images
        weight_target = weight_target[image_label] # select the most responsive images' mean responses
    weight_target = torch.tensor(weight_target, device=device, dtype=torch.float32) # target data
    train = torch.tensor(train_np.astype(np.float32), device=device) # training data, one row per ROI column

    # Initialize the model, loss function, and optimizer (constant learning rate)
    model = LinearRegressionModel(num_rows=train.shape[0]).to(device)
    criterion = nn.MSELoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
    pbar = tqdm(range(epoch_num), desc="train...")
    for epoch in pbar:
        y_pred = model(train)
        l1_regularization = alpha * torch.sum(torch.abs(model.weights)) # L1 LASSO penalty to encourage sparsity
        loss = criterion(y_pred, weight_target) + l1_regularization # the total loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step() # Update weights
        # with torch.no_grad(): model.weights.clamp_(min=0) # the non-negative weights constraint
        pbar.set_postfix({"Loss": loss.item()})
    weights_finetuned = model.weights.detach().cpu().numpy() # The trained weights (coefficients) of the linear regression model
    final_loss = loss.item() # Final error measurement

    fig, ax = plt.subplots(1, 3, figsize=(9, 3))
    # Subplot2: RSOM example grid weight composition
    # rank columns by weight, largest first (ties broken by descending domain label)
    order = np.lexsort((voxels_domain, weights_finetuned))[::-1]
    regression_w = weights_finetuned[order]
    voxels_domain_sorted = voxels_domain[order].astype(int)
    ax[1].bar(range(num_of_voxels_displayed), regression_w[:num_of_voxels_displayed], color=cmap_black(voxels_domain_sorted[:num_of_voxels_displayed]), width=1)
    ax[1].set_xticks([])
    ax[1].set_yticks([])
    ax[1].set_xlabel('V4 columns', fontsize=fs)
    ax[1].set_ylabel('Weights', fontsize=fs)
    ax[1].spines['top'].set_visible(False)
    ax[1].spines['right'].set_visible(False)
    # Subplot1: RSOM simulation
    rsp_pad = np.pad(rsp, 1, constant_values=np.nan).T # white outer rim, transposed
    ax[0].imshow(upscale_grid(rsp_pad, scale), cmap=som_low_contrast_cmap, interpolation='nearest')
    som_mask_target = get_domain_mask(som_domain_map, idx_target)
    ax[0].imshow(create_boundary_overlay(som_mask_target, colors[idx_target], scale=scale), interpolation='nearest')
    for extra_id in extra_boundaries.get(idx_target, []):
        som_mask_extra = get_domain_mask(som_domain_map, extra_id)
        if np.any(som_mask_extra):
            ax[0].imshow(create_boundary_overlay(som_mask_extra, colors[extra_id], scale=scale), interpolation='nearest')
    # mark the regressed unit with a semi-transparent black square
    highlight_mask = np.zeros(rsp_pad.shape)
    highlight_mask[max_rsp_idx[1][0] + 1, max_rsp_idx[0][0] + 1] = 1 # transposed with padding
    highlight_up = upscale_grid(highlight_mask, scale) > 0
    highlight_overlay = np.zeros((highlight_up.shape[0], highlight_up.shape[1], 4))
    highlight_overlay[highlight_up] = to_rgba('black')
    highlight_overlay[..., 3] = highlight_overlay[..., 3] * 0.7
    ax[0].imshow(highlight_overlay, interpolation='nearest')
    ax[0].axis('off')
    # Subplot3: V4 digital twin (map flipped along both axes for display)
    rsp_V4_display = np.flipud(np.fliplr(rsp_V4))
    roi_display = np.flipud(np.fliplr(roi))
    masked_rsp_V4 = np.ma.array(rsp_V4_display, mask=(roi_display == 0))
    ax[2].imshow(upscale_grid(masked_rsp_V4[v4_crop], scale), cmap=v4_low_contrast_cmap, interpolation='nearest') # low-contrast palette
    domain_mask_target = get_domain_mask(V4domain_map, idx_target)
    domain_slice = np.flipud(np.fliplr(domain_mask_target))[v4_crop]
    ax[2].imshow(create_boundary_overlay(domain_slice, colors[idx_target], scale=scale), interpolation='nearest')
    for extra_id in extra_boundaries.get(idx_target, []):
        domain_mask_extra = get_domain_mask(V4domain_map, extra_id)
        if np.any(domain_mask_extra):
            domain_mask_extra_display = np.flipud(np.fliplr(domain_mask_extra))[v4_crop]
            ax[2].imshow(create_boundary_overlay(domain_mask_extra_display, colors[extra_id], scale=scale), interpolation='nearest')
    ax[2].axis('off')
    fig.tight_layout()
    fig.savefig(folder_path + name +  "/" + regression_folder_name + "/domain" + str(int(idx_target)) + "_" + str(int(epoch_num)) + "_T.png", dpi=1000)
    plt.close(fig)

    # SOM domain preferred images visualization, most frequently preferred image first
    ranked_imgs_som = [img for _, img in sorted(zip(imgs_count_som, preferred_imgs_som), reverse=True)]
    save_path = folder_path + name +  "/" + regression_folder_name + "/domain" + str(int(idx_target)) + "_" + "SOMpreferred" + ".bmp"
    save_preferred_image_grid(ranked_imgs_som, save_path)

    # V4 domain preferred images visualization, most frequently preferred image first
    ranked_imgs_v4 = [img for _, img in sorted(zip(imgs_count_v4, preferred_imgs_v4), reverse=True)]
    save_path = folder_path + name +  "/" + regression_folder_name + "/domain" + str(int(idx_target)) + "_" + "V4preferred" + ".bmp"
    save_preferred_image_grid(ranked_imgs_v4, save_path)

# Tuning-curve correlation between domain-boundary columns / grid units and nearby columns / grid units within and across the boundary

In [None]:
name = "RSOM" # V4, RSOM, SOM
# Load, for the selected map:
#   idx  - square map of domain labels
#   roi  - same-shaped 0/1 map of valid cells (all ones for the simulated maps)
#   rsp  - responses; note the axis order differs between the simulations and V4
if name in ("RSOM", "SOM"):
    # alternative source for RSOM: folder_path + name + "/idx_newassign.npy"
    assigned_key = {"RSOM": "matched_idx_rsom", "SOM": "matched_idx_som"}[name]
    with np.load(folder_path + name + "/assigned.npz") as assigned: # close the archive after reading
        idx = assigned[assigned_key].astype(int) # (60, 60) of 16 V4 voxel types
    size = idx.shape[0]
    assert size == idx.shape[1]
    roi = np.ones((size, size)).astype(int)
    rsp = np.load(folder_path + name + "/weights.npy")[:, :, :50000] # (60, 60, 50000)
elif name == "V4":
    idx = np.load(folder_path + "V4DT/idx.npy").astype(int) # (128, 128) of 16 V4 voxel types
    size = idx.shape[0]
    assert size == idx.shape[1]
    roi = np.load(folder_path + "V4DT/ROI.npy").T
    rsp = np.load(folder_path + "V4DT/PRsp.npy") # (50000, 128, 128)
else:
    # fail here with a clear message instead of a NameError further down
    raise ValueError(f"unknown map name {name!r}; expected 'V4', 'RSOM' or 'SOM'")
idx, _ = connected_components(idx, threshold=10, return_largest=True) # remove the unconnected components smaller than a threshold of 10, also return the largest connected component

for targeted_domain in range(1, 17): # loop through all 16 domains
    domain_mask = np.where(idx == targeted_domain) # targeted domain mask
    
    map = np.zeros((size, size))
    map[domain_mask[0], domain_mask[1]] = 1
    I = [] # target domain boundary grids' x coordinate
    J = [] # target domain boundary grids' y coordinate
    for i in range(size):
        for j in range(size):
            if map[i, j] == 1:
                try:
                    if map[i-1, j] == 0 or map[i+1, j] == 0 or map[i, j-1] == 0 or map[i, j+1] == 0 or map[i-1, j-1] == 0 or map[i-1, j+1] == 0 or map[i+1, j-1] == 0 or map[i+1, j+1] == 0:
                        I.append(i)
                        J.append(j)
                except: pass # out of boundary case is not considered, no meaning at all

    index_map = np.zeros((size, size)) # map the 2D grid to 1D index
    index = 0
    for i in range(size):
        for j in range(size):
            if roi[i, j] == 1:
                index_map[i, j] = index
                index += 1
    if name == "RSOM": assert index == 3600
    elif name == "SOM": assert index == 3600
    elif name == "V4": assert index == 3048

    assert len(I) == len(J)
    half_size = 5
    cor_in = []
    dis_in = []
    cor_out = []
    dis_out = []
    failure = 0
    for g in range(len(I)): # loop through all the boundary grids
        target_idx = int(index_map[I[g], J[g]]) # 1D index of the current boundary grid
        try:
            for x in range(I[g] - half_size, I[g] + half_size + 1): # another pairwise grid
                for y in range(J[g] - half_size, J[g] + half_size + 1):
                    if roi[x, y] == 0: continue
                    d = np.sqrt((I[g] - x) ** 2 + (J[g] - y) ** 2)
                    if name == "V4": c = pearsonr(rsp[:, I[g], J[g]], rsp[:, x, y])[0] # correlation
                    elif name == "RSOM": c = pearsonr(rsp[I[g], J[g], :], rsp[x, y, :])[0] # correlation
                    elif name == "SOM": c = pearsonr(rsp[I[g], J[g], :], rsp[x, y, :])[0] # correlation
                    if map[x, y] == 1: # both pairwise grids in the same domain
                        cor_in.append(c)
                        dis_in.append(d)
                    else: # the other pairwise grid is out of the domain
                        cor_out.append(c)
                        dis_out.append(d) # distance
        except: failure += 1 # out of map boundary case is not considered, no meaning at all

    cor_in_mean = []
    cor_in_std = []
    for d in np.unique(dis_in):
        i = np.where(np.array(dis_in) == d)[0]
        cor_in_mean.append(np.mean(np.array(cor_in)[i]))
        cor_in_std.append(np.std(np.array(cor_in)[i]) / np.sqrt(len(np.array(cor_in)[i]))) # standard error of the mean

    cor_out_mean = [1.0] # the correlation of the boundary grid with itself is 1
    cor_out_std = [0.0]
    for d in np.unique(dis_out):
        i = np.where(np.array(dis_out) == d)[0]
        cor_out_mean.append(np.mean(np.array(cor_out)[i]))
        cor_out_std.append(np.std(np.array(cor_out)[i]) / np.sqrt(len(np.array(cor_out)[i]))) # standard error of the mean

    plt.figure(figsize=(6, 6))
    if name == "RSOM": 
        label_within = "Unit pairs \nwithin domain"
        label_across = "across domain"
    elif name == "SOM":
        label_within = "Unit pairs \nwithin domain"
        label_across = "across domain"
    elif name == "V4":
        label_within = "Columns pairs \nwithin domain"
        label_across = "across domain"
    plt.errorbar(
        x=np.unique(dis_in),                 # X-axis positions
        y=cor_in_mean,                       # Mean values
        yerr=cor_in_std,                     # Error bars
        fmt='-o',                            # 'o' for dot markers
        ecolor='gray',                       # Error bar color
        capsize=3,                           # Add caps to error bars
        label=label_within
    )

    plt.errorbar(
        x=np.insert(np.unique(dis_out), 0, 0.0), # X-axis positions
        y=cor_out_mean,                           # Mean values
        yerr=cor_out_std,                         # Error bars
        fmt='-*',                                # 'o' for dot markers
        ecolor='gray',                           # Error bar color
        capsize=3,                               # Add caps to error bars
        label=label_across
    )

    # Labels and title
    plt.xlabel("Distance", fontsize=22)
    plt.ylabel("Tuning Correlation", fontsize=22)
    plt.legend(loc='lower left', fontsize=22)
    plt.grid(True)
    plt.tight_layout()
    if name == "RSOM": plt.savefig(folder_path + "RSOM/LASSO_top1k/RSOM_boundary_" + str(targeted_domain) + ".png", dpi=1000)
    elif name == "SOM": plt.savefig(folder_path + "RSOM/LASSO_top1k/SOM_boundary_" + str(targeted_domain) + ".png", dpi=1000)
    elif name == "V4": plt.savefig(folder_path + "RSOM/LASSO_top1k/V4_boundary_" + str(targeted_domain) + ".png", dpi=1000)
    plt.close()

# Figure synthesis

In [None]:
def resize_image(image, max_height):
    """Return `image` scaled to `max_height` pixels tall, keeping the aspect ratio.

    The image is returned unchanged when its height already equals `max_height`.
    """
    width, height = image.size
    if height == max_height:
        return image
    new_width = max(1, int((max_height / height) * width)) # never request a zero-width image
    return image.resize((new_width, max_height))


def _load_image(path):
    """Read the image at `path` fully into memory and close the underlying file."""
    with Image.open(path) as opened:
        return opened.copy()


for domain in [str(number) for number in range(1, 17)]: # domains "1" .. "16"
    # Synthesize one figure per domain from five sub-images.
    # Step 1: Load all sub-images
    png_image = _load_image(folder_path + 'RSOM/LASSO_top1k/domain' + domain + '_7500_T.png')
    bmp_image_left = _load_image(folder_path + 'RSOM/LASSO_top1k/domain' + domain + '_SOMpreferred.bmp')
    bmp_image_right = _load_image(folder_path + 'RSOM/LASSO_top1k/domain' + domain + '_V4preferred.bmp')
    # NOTE(review): the boundary-correlation loop visible in this notebook saves
    # '<name>_boundary_<n>.png' (RSOM_/SOM_/V4_ prefixes); no visible code writes the
    # unprefixed 'boundary_<n>.png' -- confirm this file exists or use the intended prefix.
    boundary = _load_image(folder_path + 'RSOM/LASSO_top1k/boundary_' + domain + '.png')
    boundary_V4 = _load_image(folder_path + 'RSOM/LASSO_top1k/V4_boundary_' + domain + '.png')

    # Step 2: Bring every image to a common height. The target height is the tallest of the
    # tuning image and the two preferred-stimulus images; the boundary plots are scaled to it.
    max_height = max(png_image.height, bmp_image_left.height, bmp_image_right.height)
    png_image = resize_image(png_image, max_height)
    bmp_image_left = resize_image(bmp_image_left, max_height)
    bmp_image_right = resize_image(bmp_image_right, max_height)
    boundary = resize_image(boundary, max_height)
    boundary_V4 = resize_image(boundary_V4, max_height)

    # Step 3: Create a canvas as wide as the five images placed side by side
    panels = [bmp_image_left, boundary, png_image, boundary_V4, bmp_image_right] # left-to-right order
    total_width = sum(panel.width for panel in panels)
    concatenated_image = Image.new('RGB', (total_width, max_height))

    # Step 4: Paste each image next to the previous one
    x_offset = 0
    for img in panels:
        concatenated_image.paste(img, (x_offset, 0))
        x_offset += img.width

    # Step 5: Save the concatenated image
    concatenated_image.save(folder_path + 'RSOM/LASSO_top1k/composition_domain' + domain + '.png')

# Resize one domain's composition figure to a lower resolution and save it under Fig3/.
shrink_factor = 9 # each dimension is divided by this factor
domain = "2" # {1, 2, 5 for example domains}, change as needed
# The context manager closes the source file once the resized copy has been produced.
with Image.open(folder_path + "RSOM/LASSO_top1k/composition_domain" + domain + ".png") as image:
    width, height = image.size
    # Clamp to at least one pixel so a very small source cannot request a 0-sized image.
    new_size = (max(1, int(np.round(width / shrink_factor))), max(1, int(np.round(height / shrink_factor))))
    resized_image = image.resize(new_size, Image.Resampling.LANCZOS)
resized_image.save(folder_path + "Fig3/composition_domain" + domain + ".png")