In [31]:
### libs
import numpy as np
import pickle 
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
from scipy import stats

In [32]:
### custom funcs
# Prepare hooks to capture activations
activations_alexnet = {}
def get_activation(name, activations_dict):
    def hook(model, input, output):
        activations_dict[name] = output.detach()
    return hook

# load and preprocess images
def load_image(image_path):
     # ensure image is in RGB format
    img = Image.open(image_path).convert('RGB') 
    img = preprocess(img)
    return img.unsqueeze(0)  # add batch dimension

# process and collect activations for each image
def extract_activations(image_paths, model, activations_dict):
    activs = {"0" : [],
              "1" : []}
    for image_path in image_paths:
        image = load_image(image_path)
        model(image)
        for key in activs.keys():
            activs[key].append(activations_dict[key].detach().numpy().flatten())
    return activs

def min_max_normalize(array):
    min_vals = array.min(axis=0)
    max_vals = array.max(axis=0)
    normalized_array = (array - min_vals) / (max_vals - min_vals)
    return normalized_array

def z_score_normalize(array):
    mean_vals = array.mean(axis=0)
    std_vals = array.std(axis=0)
    normalized_array = (array - mean_vals) / std_vals
    return normalized_array

def get_rdms(actD):
    rdms = []  
    for key in actD:
        arr = np.array(actD[key])
        norm_arr = z_score_normalize(arr) 
        rdms.append(np.corrcoef(arr))
        
    return rdms

def upper(df):
    try:
        assert(type(df)==np.ndarray)
    except:
        if type(df)==pd.DataFrame:
            df = df.values
        else:
            raise TypeError('Must be np.ndarray or pd.DataFrame')
    mask = np.triu_indices(df.shape[0], k=1)
    return df[mask]

def dump_data(data, filename):
    print('writing file: ' + filename)
    with open(filename, 'wb') as f:
        pickle.dump(data, f, pickle.HIGHEST_PROTOCOL)

In [33]:
### load paths
category = "elephant"
sim_sim = np.load("{}_sim_sim.npy".format(category))
sim_sim_var = np.load("{}_sim_sim_var.npy".format(category))
dissim_dissim = np.load("{}_dissim_dissim.npy".format(category))

In [34]:
### load model and attach hooks

# load the pretrained AlexNet model
alexnet = models.alexnet(pretrained=True)

# register hooks for desired layers in AlexNet
alexnet.features[3].register_forward_hook(get_activation('0', activations_alexnet))
alexnet.classifier[4].register_forward_hook(get_activation('1', activations_alexnet))


<torch.utils.hooks.RemovableHandle at 0x7ff3d95ba850>

In [35]:
### extract 
# define image preprocessing pipeline 
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

correlations = []
data_dict = {"paths": [],
             "activations": []}
### example extraction loop
for img_path in sim_sim:
    data_dict["paths"].append(img_path)
    
    # Get a flattened version of the activations
    flattened_activations = extract_activations(img_path, alexnet, activations_alexnet)
    data_dict["activations"].append(flattened_activations)
    
    rdms = get_rdms(flattened_activations)
    correlation, _ = stats.spearmanr(upper(rdms[0]), upper(rdms[1]))
    correlations.append(correlation)

  normalized_array = (array - mean_vals) / std_vals


In [36]:
### you can decide what format of the data you want to save
dump_data(data_dict, "{}_raw_activations_alexnet.pkl".format(category))
np.save("{}_spearmanns_alexnet.npy".format(category), correlations)

writing file: elephant_raw_activations_alexnet.pkl


In [16]:
### validate that sizes are correct based on the layers you've chosen 
input_tensor = torch.randn(1, 3, 224, 224) 
x = input_tensor
print("Input:", x.shape)

for i, layer in enumerate(alexnet.features):
    x = layer(x)
    print(f"Output of layer {i} ({layer.__class__.__name__}): {x.detach().numpy().flatten().shape}")

# passing through the classifier part
# flatten the output of the conv layers to pass to the fully connected layers
x = torch.flatten(x, 1)  
print()
print(f"After flattening: {x.shape}")

for i, layer in enumerate(alexnet.classifier):
    x = layer(x)
    print(f"Output of classifier layer {i} ({layer.__class__.__name__}): {x.detach().numpy().flatten().shape}")


Input: torch.Size([1, 3, 224, 224])
Output of layer 0 (Conv2d): (193600,)
Output of layer 1 (ReLU): (193600,)
Output of layer 2 (MaxPool2d): (46656,)
Output of layer 3 (Conv2d): (139968,)
Output of layer 4 (ReLU): (139968,)
Output of layer 5 (MaxPool2d): (32448,)
Output of layer 6 (Conv2d): (64896,)
Output of layer 7 (ReLU): (64896,)
Output of layer 8 (Conv2d): (43264,)
Output of layer 9 (ReLU): (43264,)
Output of layer 10 (Conv2d): (43264,)
Output of layer 11 (ReLU): (43264,)
Output of layer 12 (MaxPool2d): (9216,)

After flattening: torch.Size([1, 9216])
Output of classifier layer 0 (Dropout): (9216,)
Output of classifier layer 1 (Linear): (4096,)
Output of classifier layer 2 (ReLU): (4096,)
Output of classifier layer 3 (Dropout): (4096,)
Output of classifier layer 4 (Linear): (4096,)
Output of classifier layer 5 (ReLU): (4096,)
Output of classifier layer 6 (Linear): (1000,)
