In [None]:
import matplotlib.pyplot as plt
import matplotlib as mpl

import numpy as np
import scipy.stats as stats
from scipy.stats import pearsonr
from scipy.spatial.distance import euclidean
import math
from tqdm import tqdm


import torch

import os
from spacetorch.paths import *
from spacetorch.datasets import floc
from spacetorch.datasets import DatasetRegistry
from spacetorch.models.trunks.resnet_chanx import SOBasicBlock,SOResNet,SOCONV
from spacetorch.analyses.core import select_model

import h5py



# GPU index to run on; `device` falls back to the CPU when CUDA is unavailable.
device_id = 4
if torch.cuda.is_available():
    # Only touch the CUDA runtime when it exists: `set_device` raises on CPU-only machines.
    torch.cuda.set_device(device_id)
device = torch.device(f"cuda:{device_id}" if torch.cuda.is_available() else "cpu")

from torchvision.models.resnet import resnet18,ResNet18_Weights

In [None]:
def load_and_eval_model(model_name):
    """Build the network registered under ``model_name``, on ``device`` and in eval mode.

    ``select_model(model_name)`` returns either one of the sentinel strings
    ``'ResNet-18'`` (torchvision ResNet-18 with ImageNet weights) or
    ``'Untrained'`` (torchvision ResNet-18, random init), or a checkpoint path
    whose ``'state_dict'`` entry is loaded into an ``SOResNet(SOBasicBlock, 18)``.

    Returns:
        The model, moved to ``device`` and switched to eval mode.
    """
    path = select_model(model_name)
    if path == 'ResNet-18':
        # Explicit enum member; passing the enum *class* positionally only worked
        # through torchvision's deprecated legacy-argument handling.
        model = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
    elif path == 'Untrained':
        model = resnet18(weights=None)
    else:
        # Only build the self-organising ResNet when a checkpoint is loaded into it,
        # instead of allocating one on the GPU for every branch.
        model = SOResNet(SOBasicBlock, 18)
        checkpoint = torch.load(path, map_location='cpu')
        model.load_state_dict(checkpoint['state_dict'])
    return model.to(device).eval()

# Preload both checkpoint variants. `load_and_eval_model` already places the
# model on `device`, so no extra `.to(device)` is needed here.
CB_SOM_RN_18 = load_and_eval_model('CB_SOM_RN_18')
AB_SOM_RN_18 = load_and_eval_model('AB_SOM_RN_18')


In [None]:
from torch.utils.data import Dataset, DataLoader
from spacetorch.datasets.animacy_size import Animacy_Size,ANIMACY_SIZE_TRANSFORM
from spacetorch.paths import ANIMACY_SIZE_DIR
# Source datasets. The grouping helpers below (`floc_object`,
# `animacy_size_object`) index `.dataset` directly, so the batching and
# shuffling options here do not affect them.
floc_loader = DataLoader(
    DatasetRegistry.get("fLoc"),
    batch_size=256,
    shuffle=True,
    num_workers=32,
    pin_memory=True,
)

animacy_size_loader = DataLoader(
    DatasetRegistry.get("animacy-size"),
    batch_size=30,
    shuffle=True,
    num_workers=8,
    pin_memory=True,
)

In [None]:
from collections import defaultdict
from spacetorch.datasets.floc import CATEGORIES
def floc_object(floc_loader):
    """Group every fLoc image by its object-category name.

    Items are read one by one from ``floc_loader.dataset``, bypassing the
    loader's batching. Each integer target is mapped through ``CATEGORIES``.

    Returns:
        dict of category name -> images of that category stacked along a new
        leading dimension. Images keep dataset order, and keys appear in
        first-seen order.
    """
    dataset = floc_loader.dataset
    grouped = defaultdict(list)
    for i in range(len(dataset)):
        img, label = dataset[i]
        grouped[CATEGORIES[label]].append(img)
    return {name: torch.stack(imgs) for name, imgs in grouped.items()}


def floc_catagories(floc_loader):
    """Merge pairs of fLoc object categories into four domain-level categories.

    Returns:
        dict with keys ``adult_child``, ``limb_body``, ``car_instrument`` and
        ``corridor_house``. Each value is the dim-0 concatenation of the two
        object tensors from ``floc_object``, with the first-named object first.
    """
    by_object = floc_object(floc_loader)
    merges = (
        ("adult_child", ("adult", "child")),
        ("limb_body", ("limb", "body")),
        ("car_instrument", ("car", "instrument")),
        ("corridor_house", ("corridor", "house")),
    )
    return {
        merged: torch.cat((by_object[first], by_object[second]), dim=0)
        for merged, (first, second) in merges
    }


In [None]:
from spacetorch.datasets.animacy_size import TYPES

def animacy_size_object(animacy_size_loader):
    """Group the animacy/size stimuli by their type label.

    Items are read one by one from ``animacy_size_loader.dataset``, bypassing
    the loader's batching. Each integer target is mapped through ``TYPES``.

    Returns:
        dict of type name -> images of that type stacked along a new leading
        dimension. Images keep dataset order, and keys appear in first-seen order.
    """
    dataset = animacy_size_loader.dataset
    by_type = {}
    for position in range(len(dataset)):
        picture, label = dataset[position]
        by_type.setdefault(TYPES[label], []).append(picture)
    return {kind: torch.stack(pictures) for kind, pictures in by_type.items()}

In [None]:
def category_animacy_size(animacy_size_loader):
    """Pool the four animacy x size stimulus types into their marginal groups.

    Returns:
        dict with ``Big``, ``Small``, ``Animate`` and ``Inanimate`` entries. Each
        is the dim-0 concatenation of its two contributing types, in listed order.
    """
    types = animacy_size_object(animacy_size_loader)
    composition = {
        "Big": ("Big-Animate", "Big-Inanimate"),
        "Small": ("Small-Animate", "Small-Inanimate"),
        "Animate": ("Big-Animate", "Small-Animate"),
        "Inanimate": ("Big-Inanimate", "Small-Inanimate"),
    }
    return {
        group: torch.cat((types[first], types[second]), dim=0)
        for group, (first, second) in composition.items()
    }

In [None]:
def categorize_conture_images():
    """Split the contour stimulus set into bar, corner and curve subsets.

    Only the first element of the 2-tuple returned by ``load_v4_images()`` is used.
    NOTE(review): ``load_v4_images`` is neither defined nor explicitly imported in
    this notebook; presumably it arrives via a star import. Confirm.

    The fixed slices assume the stimuli are ordered as 16 bars, then 48 corners,
    then 36 curves (TODO confirm against the stimulus set). Items past index 100
    are ignored.

    Returns:
        dict with keys ``bar``, ``corner`` and ``curve`` holding the slices.
    """
    contours, _ = load_v4_images()
    slice_bounds = {"bar": (0, 16), "corner": (16, 64), "curve": (64, 100)}
    return {name: contours[start:stop] for name, (start, stop) in slice_bounds.items()}
    

In [None]:
def get_activation(name, activation):
    """Build a forward hook that stores a module's detached output.

    Args:
        name: key under which the output is stored.
        activation: dict to write into. The entry for ``name`` is replaced on
            every forward pass.

    Returns:
        A callable with the ``(module, input, output)`` signature expected by
        ``register_forward_hook``.
    """
    def _store_output(_module, _inputs, output):
        activation[name] = output.detach()

    return _store_output


# Shared buffer filled by the hooks that `save_hook` registers:
# module name -> detached output of that module's latest forward call.
activation = dict()

def save_hook(net, layers=None):
    """Register forward hooks that record module outputs into the global ``activation`` dict.

    Args:
        net: the ``torch.nn.Module`` to instrument.
        layers: optional iterable of module names (as yielded by
            ``net.named_modules()``) to hook. ``None`` (the default) hooks every
            module, including the root module (name ``''``).

    Returns:
        list of the handles returned by ``register_forward_hook``. Call
        ``.remove()`` on each to detach the hooks. Otherwise every further call on
        the same network stacks another hook onto each module.
    """
    wanted = None if layers is None else set(layers)
    handles = []
    for name, module in net.named_modules():
        if wanted is None or name in wanted:
            handles.append(module.register_forward_hook(get_activation(name, activation)))
    return handles
    

def get_closest_factors(num):
    """Factor ``num`` into the most square integer pair ``(rows, cols)`` with ``rows <= cols``.

    Used to lay channels out on a 2-D grid, e.g. 512 -> (16, 32) and 7 -> (1, 7).

    Args:
        num: positive integer.

    Returns:
        tuple ``(rows, cols)`` of ints with ``rows * cols == num``.

    Raises:
        ValueError: if ``num`` < 1 (no such factorisation exists).
    """
    if num < 1:
        raise ValueError(f"num must be a positive integer, got {num!r}")
    # Exact integer square root: float sqrt can be off by one for large ints.
    rows = math.isqrt(num)
    while num % rows:
        rows -= 1
    return rows, num // rows

def get_layers(net):
    """Names of every ``SOCONV`` or ``SOBasicBlock`` submodule of ``net``.

    Names are returned in ``net.named_modules()`` order. Networks without these
    spacetorch module types (presumably plain torchvision models) give ``[]``.
    """
    so_types = (SOCONV, SOBasicBlock)
    return [name for name, module in net.named_modules() if isinstance(module, so_types)]

def get_loc(net):
    """Collect the neuron locations of every ``SOCONV`` layer in ``net``.

    Returns:
        dict of module name (from ``net.named_modules()``) ->
        ``np.array(layer.neuron_locations())``. The dict is empty when ``net``
        has no ``SOCONV``. (The dict was previously built but never returned,
        so callers always got ``None``.)
    """
    return {
        name: np.array(layer.neuron_locations())
        for name, layer in net.named_modules()
        if isinstance(layer, SOCONV)
    }

def get_level(target):
    """Map a target key to the comparison family ``make_dataloader`` dispatches on.

    Returns one of ``"category"``, ``"object"``, ``"Size"``, ``"Animate"``,
    ``"SI"`` or ``"CVCNI"``.

    Raises:
        ValueError: if ``target`` is not a known key.
    """
    groups = (
        (("adult_child", "limb_body", "car_instrument", "corridor_house"), "category"),
        (("adult", "body", "car", "child", "corridor", "house", "instrument",
          "limb", "number", "scrambled", "word"), "object"),
        (("Big-Small",), "Size"),
        (("Animate_Inanimate",), "Animate"),
        (("corner", "curve", "bar"), "SI"),
        (("curve_corner",), "CVCNI"),
    )
    for members, level in groups:
        if target in members:
            return level
    raise ValueError("Invalid category")


In [None]:
def make_dataloader(target):
    """Build the (target, other) image loaders for one selectivity contrast.

    ``get_level(target)`` selects the stimulus source:
      * ``category`` / ``object`` / ``SI``: ``target`` vs. all remaining groups of
        the same source (merged fLoc categories, fLoc objects, contour segments);
      * ``Size``: Big vs. Small. ``Animate``: Animate vs. Inanimate;
      * ``CVCNI``: curve vs. corner contour segments.

    Both loaders use batch size 128, shuffling, 16 workers and pinned memory.

    Returns:
        ``(target_dataloader, other_dataloader)``.
    """
    level = get_level(target)

    def tensor_loader(images):
        return DataLoader(images, batch_size=128, shuffle=True, num_workers=16, pin_memory=True)

    def one_vs_rest(groups):
        chosen = groups[target]
        rest = torch.cat([images for name, images in groups.items() if name != target], dim=0)
        return chosen, rest

    if level == "category":
        target_images, other_images = one_vs_rest(floc_catagories(floc_loader))
    elif level == "object":
        target_images, other_images = one_vs_rest(floc_object(floc_loader))
    elif level in ("Size", "Animate"):
        marginals = category_animacy_size(animacy_size_loader)
        positive, negative = {"Size": ("Big", "Small"), "Animate": ("Animate", "Inanimate")}[level]
        target_images, other_images = marginals[positive], marginals[negative]
    elif level == "SI":
        target_images, other_images = one_vs_rest(categorize_conture_images())
    elif level == "CVCNI":
        segments = categorize_conture_images()
        target_images, other_images = segments['curve'], segments['corner']

    return tensor_loader(target_images), tensor_loader(other_images)

In [None]:
def get_activation_list(net, data_loader, target_layers=None):
    """Run ``net`` over every batch of ``data_loader`` and gather hooked activations.

    The forward hooks (see ``save_hook``) write each module's output into the
    module-level ``activation`` dict. After every batch the entries for the
    requested layers are copied to numpy.

    Args:
        net: model to evaluate. Batches are moved to the global ``device``.
        data_loader: iterable of input batches (tensors) supporting ``len()``.
        target_layers: module names to keep. ``None`` means ``get_layers(net)``.

    Returns:
        dict of layer name -> numpy array of that layer's activations for all
        batches, concatenated along axis 0 in loader order.
    """
    layers = get_layers(net) if target_layers is None else target_layers
    wanted = set(layers)
    chunks = {}

    # Inference only: no autograd graph needs to be built or kept alive.
    with torch.no_grad(), tqdm(enumerate(data_loader), total=len(data_loader), desc="Processing Activations") as pbar:
        for i, batch in pbar:
            net(batch.to(device))
            for name, value in activation.items():
                if name in wanted:
                    chunks.setdefault(name, []).append(value.cpu().numpy())
            pbar.set_postfix({"Processed": f"{i+1}/{len(data_loader)}"})

    # One concatenate per layer rather than one per batch (avoids quadratic copying).
    return {name: np.concatenate(parts) for name, parts in chunks.items()}

In [None]:
def CVCNI_selectivity(net, target_dataloader, other_dataloader, layer):
    """Contrast index between peak responses to two stimulus sets for one layer.

    This is used for the curve vs. corner (``CVCNI``) comparison. Activations are
    assumed to be ``(N, C, ...)`` with channels on axis 1; TODO confirm for every
    layer type. For each unit position the maximum over all images is taken.
    Channels are laid out on the ``get_closest_factors(C)`` grid and peaks are
    averaged over the remaining positions.

    Returns:
        torch tensor of shape ``(kernel_x, kernel_y)`` holding
        ``(peak_target - peak_other) / (peak_target + peak_other)``. Entries
        where both averaged peaks are 0 come out as NaN (0/0).
    """
    # Collect only the requested layer: all-layer activations are large, and
    # networks without SOCONV/SOBasicBlock modules (presumably torchvision ones)
    # give an empty default selection that would not contain ``layer``.
    target_act = get_activation_list(net, target_dataloader, target_layers=[layer])[layer]
    other_act = get_activation_list(net, other_dataloader, target_layers=[layer])[layer]

    kernel_x, kernel_y = get_closest_factors(target_act.shape[1])

    def peak_map(act):
        peak, _ = torch.from_numpy(act).max(0)
        return peak.reshape(kernel_x, kernel_y, -1).mean(2)

    max_resp_target = peak_map(target_act)
    max_resp_other = peak_map(other_act)
    return (max_resp_target - max_resp_other) / (max_resp_target + max_resp_other)

In [None]:
from scipy.stats import ttest_ind

def index_selectivity(net, target_dataloader, other_dataloader, layer, thre=0.4, mode='dprime'):
    """Per-unit selectivity of ``layer`` for target vs. other images.

    Activations are assumed to be ``(N, C, ...)`` with channels on axis 1; TODO
    confirm for every layer type. Channels are arranged on the
    ``get_closest_factors(C)`` grid, and each score is averaged over the
    remaining positions.

    Args:
        net: hooked network (see ``save_hook``).
        target_dataloader, other_dataloader: batches of target / non-target images.
        layer: module name to score.
        thre: d' threshold used by ``mode='discrete'``.
        mode: one of
            * ``'dprime'``: torch tensor of
              ``(mu_t - mu_o) / sqrt((var_t + var_o) / 2)``;
            * ``'discrete'``: int array, 1 where d' >= ``thre``;
            * ``'t-test'``: Welch t statistic, numpy array.

    Returns:
        ``(kernel_x, kernel_y)`` map of scores. Units with zero variance in both
        sets give NaN/inf d'.

    Raises:
        ValueError: for an unknown ``mode``, checked before any forward pass.
    """
    if mode not in ("dprime", "discrete", "t-test"):
        raise ValueError("Invalid mode, select: 't-test' , 'dprime' , discrete ")

    # Collect only the requested layer: all-layer activations are large, and
    # networks without SOCONV/SOBasicBlock modules (presumably torchvision ones)
    # give an empty default selection that would not contain ``layer``.
    target_act = get_activation_list(net, target_dataloader, target_layers=[layer])[layer]
    other_act = get_activation_list(net, other_dataloader, target_layers=[layer])[layer]
    kernel_x, kernel_y = get_closest_factors(target_act.shape[1])

    if mode == "t-test":
        statistic, _ = ttest_ind(target_act, other_act, equal_var=False, nan_policy='raise')
        return statistic.reshape(kernel_x, kernel_y, -1).mean(-1)

    target_t = torch.from_numpy(target_act)
    other_t = torch.from_numpy(other_act)
    target_mean = target_t.mean(0).reshape(kernel_x, kernel_y, -1)
    no_target_mean = other_t.mean(0).reshape(kernel_x, kernel_y, -1)
    target_variance = target_t.var(0).reshape(kernel_x, kernel_y, -1)
    no_target_variance = other_t.var(0).reshape(kernel_x, kernel_y, -1)

    dprime = ((target_mean - no_target_mean) / torch.sqrt((target_variance + no_target_variance) / 2.0)).mean(2)
    if mode == "dprime":
        return dprime
    return np.where(dprime >= thre, 1, 0)


In [None]:
def map_to_category(target_level):
    """Return the display label for a target key, as used in plot titles and figure paths.

    Keys without a dedicated label are returned unchanged.
    """
    display_names = (
        ('adult_child', 'Face'),
        ('limb_body', 'Body'),
        ('car_instrument', 'Object'),
        ('corridor_house', 'Place'),
        ('Animate_Inanimate', 'Animacy'),
        ('Big-Small', 'Size'),
        ('angles', 'Preferred Orientation'),
        ('sfs', 'Spatial Frequency'),
    )
    for key, label in display_names:
        if target_level == key:
            return label
    return target_level

    

In [None]:
def category_selectivity_idx(net, layer, target_category, **kwargs):
    """Hook ``net``, build the target/other loaders for ``target_category`` and score ``layer``.

    Keyword args:
        mode: forwarded to ``index_selectivity`` (``'dprime'``, ``'discrete'``
            or ``'t-test'``).
        thre: threshold forwarded for ``mode='discrete'``.

    Options that are absent or ``None`` are not forwarded, so
    ``index_selectivity``'s own defaults (``mode='dprime'``, ``thre=0.4``) apply.
    Forwarding ``None`` would make a bare call fail with "Invalid mode" and
    would break the discrete threshold comparison.

    Returns:
        Whatever ``index_selectivity`` returns for the chosen mode.
    """
    options = {key: kwargs[key] for key in ('mode', 'thre') if kwargs.get(key) is not None}
    save_hook(net)
    target_dataloader, other_dataloader = make_dataloader(target_category)
    return index_selectivity(net, target_dataloader, other_dataloader, layer, **options)

In [None]:
import scipy.ndimage as nd

def plot_dprime(model_type, layer, target_category, cmap='plasma', save=False, format='png', smoothed=True, **kwargs):
    """Show, and optionally save, the selectivity map of one layer for one target category.

    Args:
        model_type: name passed to ``load_and_eval_model``.
        layer: module name whose activations are scored.
        target_category: target key understood by ``make_dataloader``.
        cmap: matplotlib colormap for ``matshow``.
        save: if True, write the figure (with colorbar, no title) to
            ``{FIGURE_DIR}/Manuscripts/Category_selectivity/{model_type}/{category}_{layer}.{format}``.
            Otherwise show it with a ``{category}_{layer}`` title.
        format: image format passed to ``savefig``.
        smoothed: apply ``scipy.ndimage.gaussian_filter`` with ``sigma=0.5``.
        **kwargs: forwarded to ``category_selectivity_idx`` (e.g. ``mode``, ``thre``).
    """
    net = load_and_eval_model(model_type).to(device)
    selectivity = category_selectivity_idx(net, layer, target_category, **kwargs)
    # 'dprime' mode yields a torch tensor and 'discrete' an int array. A float
    # numpy array keeps gaussian_filter from truncating to ints and gives the
    # colour-limit code below plain floats.
    selectivity = np.asarray(selectivity, dtype=float)
    if smoothed:
        selectivity = nd.gaussian_filter(selectivity, sigma=0.5)
    category = map_to_category(target_level=target_category)

    fig, ax = plt.subplots()
    # Limits from finite values only: zero-variance units give NaN/inf d', and
    # math.floor/ceil raise on those.
    finite = selectivity[np.isfinite(selectivity)]
    vmin_rounded = math.floor(finite.min() * 10) / 10
    vmax_rounded = math.ceil(finite.max() * 10) / 10
    cax = ax.matshow(selectivity, cmap=cmap, vmin=vmin_rounded, vmax=vmax_rounded)
    # Colorbar sized to match the axes.
    cbar = fig.colorbar(cax, ax=ax, fraction=0.046, pad=0.04)
    ax.axis('off')

    if save:
        save_path = f"{FIGURE_DIR}/Manuscripts/Category_selectivity/{model_type}/{category}"
        # NOTE(review): this creates a `{category}` directory, but the figure is
        # written next to it as `{category}_{layer}.{format}`. Confirm the intended layout.
        os.makedirs(save_path, exist_ok=True)
        print(save_path)
        cbar.ax.tick_params(labelsize=12)
        fig.savefig(f'{save_path}_{layer}.{format}', format=format)
        plt.show()
        # Close the figure so repeated calls do not accumulate open figures.
        plt.close(fig)
    else:
        ax.set_title(f'{category}_{layer}')
        plt.show()


       

In [None]:
# Example: face (adult + child) selectivity map of `layer4.1.conv1` in the pretrained torchvision ResNet-18.
plot_dprime('ResNet-18', layer='layer4.1.conv1', target_category='adult_child')

In [None]:
from spacetorch.utils.array_utils import lower_tri, midpoints_from_bin_edges,flatten

def delta_selectivity(dprime, metric=None):
    """Absolute difference in selectivity for every unordered pair of units.

    Args:
        dprime: tensor or numpy array of per-unit values (any shape; flattened).
        metric: ``'angles'`` treats values as orientations in degrees with a
            180-degree period and folds differences ``>= 90`` to ``180 - diff``.

    Returns:
        1-D torch tensor of length ``n * (n - 1) / 2`` holding ``|v[i] - v[j]|``
        for ``i < j`` in row-major order (the order of ``np.triu_indices(n, k=1)``).
        The dtype follows the input.
    """
    values = torch.as_tensor(dprime).reshape(-1)
    n = values.numel()
    # Vectorised pair enumeration replaces the O(n^2) Python double loop.
    first, second = torch.triu_indices(n, n, offset=1)
    pairwise_differences_torch = (values[first] - values[second]).abs()

    if metric == 'angles':
        pairwise_differences_torch = torch.where(
            pairwise_differences_torch >= 90,
            180 - pairwise_differences_torch,
            pairwise_differences_torch,
        )

    return pairwise_differences_torch

In [None]:
from scipy.spatial.distance import cdist
from utilities import neuron_locations

def normalized_random_shuffeling(dim1, dim2, pairwise_differences_torch, normalize=True, n_shuffles=1000, generator=None):
    """Mean pairwise selectivity difference per unit distance, optionally relative to chance.

    Unit coordinates come from ``neuron_locations(F1=dim1, F2=dim2)``. Presumably
    this is one coordinate per unit, in the same order as the flattened
    selectivity map; TODO confirm. Pairs are enumerated in
    ``np.triu_indices(n, k=1)`` order, which must match the order of
    ``pairwise_differences_torch`` (the order ``delta_selectivity`` produces).

    Args:
        dim1, dim2: grid dimensions passed to ``neuron_locations``.
        pairwise_differences_torch: 1-D tensor with one value per unit pair.
        normalize: if True, divide each distance's mean by a chance baseline. The
            baseline is the same mean after randomly permuting the pair values,
            averaged over ``n_shuffles`` permutations.
        n_shuffles: number of permutations averaged for the baseline.
        generator: optional ``torch.Generator`` for reproducible permutations.

    Returns:
        ``(values, unique_distances)``: a list with one 0-d tensor per unique
        distance, and the sorted numpy array of those distances.

    Raises:
        ValueError: if the number of pair values does not match the number of
            unit pairs, or if ``normalize`` is set with ``n_shuffles < 1``.
    """
    locations = torch.Tensor(list(neuron_locations(F1=dim1, F2=dim2)))
    pair_distances = cdist(locations, locations, metric='euclidean')
    pair_distances = pair_distances[np.triu_indices(len(locations), k=1)]
    n_pairs = pairwise_differences_torch.size()[0]
    if n_pairs != len(pair_distances):
        raise ValueError(f"got {n_pairs} pairwise values for {len(pair_distances)} unit pairs")

    unique_distances = np.unique(pair_distances)
    # One mask per distance, shared by the observed and the shuffled averages.
    masks = [torch.from_numpy(pair_distances == dist) for dist in unique_distances]
    selectivity_per_distance = [pairwise_differences_torch[mask].mean() for mask in masks]
    if not normalize:
        return selectivity_per_distance, unique_distances

    if n_shuffles < 1:
        raise ValueError(f"n_shuffles must be >= 1, got {n_shuffles}")
    acc_dtype = (pairwise_differences_torch.dtype
                 if pairwise_differences_torch.is_floating_point() else torch.float64)
    # Average over ALL permutations. The earlier code collected the shuffles but
    # then averaged only the last one. A running sum also keeps memory at
    # O(n_pairs) instead of stacking n_shuffles permuted copies.
    shuffled_sum = torch.zeros(n_pairs, dtype=acc_dtype, device=pairwise_differences_torch.device)
    for _ in range(n_shuffles):
        permutation = torch.randperm(n_pairs, generator=generator)
        shuffled_sum += pairwise_differences_torch[permutation]
    shuffled_mean = shuffled_sum / n_shuffles
    chance_per_distance = [shuffled_mean[mask].mean() for mask in masks]

    normalized_tensor = [observed / chance for observed, chance in zip(selectivity_per_distance, chance_per_distance)]
    return normalized_tensor, unique_distances

        

    

In [None]:
from matplotlib.font_manager import FontProperties

def pairwise_selectivity_distance(model_name, layer_name='layer4.1.soconv1', target_category='adult_child', step=4, ylim=(0, 2)):
    """Plot how the selectivity difference between unit pairs depends on their distance.

    The function computes a d' map for ``layer_name`` / ``target_category``,
    takes ``|d'_i - d'_j|`` for every unit pair, averages it per grid distance,
    and divides by a permutation baseline (``normalized_random_shuffeling``).

    Args:
        model_name: name passed to ``load_and_eval_model``.
        layer_name: module name to analyse.
        target_category: target key understood by ``make_dataloader``.
        step: plot every ``step``-th unique distance.
        ylim: y-axis limits of the plot.
    """
    model = load_and_eval_model(model_name)
    dprime = category_selectivity_idx(model, layer=layer_name, target_category=target_category, mode='dprime')
    pairwise_differences_torch = delta_selectivity(dprime)
    normalized_tensor, unique_distances = normalized_random_shuffeling(
        dim1=dprime.shape[0],
        dim2=dprime.shape[1],
        pairwise_differences_torch=pairwise_differences_torch,
    )

    selected_distances = unique_distances[::step]
    selected_normalized = [float(value) for value in normalized_tensor[::step]]
    # Correct family name; the misspelled 'New Times Roman' made matplotlib fall
    # back to its default font with a findfont warning.
    font = FontProperties(family='Times New Roman', size="large", weight='bold')

    fig, ax = plt.subplots()
    ax.plot(selected_distances, selected_normalized, marker='o', linestyle='-')
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.set_title(f'model: {model_name}_{target_category}_selectivity')
    ax.set_xlabel('Pair-wise Euclidean Distance', fontsize=12, fontproperties=font)
    ax.set_ylabel(r"$\Delta$ Selectivity" "\n" "(vs. Chance)", fontsize=12, fontproperties=font)
    ax.tick_params(axis='both', which='major', labelsize=12)
    fig.tight_layout()
    ax.set_ylim(*ylim)
    plt.show()

In [None]:
# Example: distance dependence of face (adult + child) selectivity in `layer4.0.soconv1` of the CB_SOM_RN_18 model.
pairwise_selectivity_distance('CB_SOM_RN_18', layer_name='layer4.0.soconv1')