In [2]:
import torch
from torchvision import models, transforms
from PIL import Image
import os

In [3]:
# Load pretrained AlexNet (ImageNet weights) and switch it to inference mode.
# progress=False: the download progress bar floods Jupyter's IOPub channel
# ("IOPub message rate exceeded") the first time the weights are fetched.
try:
    # torchvision >= 0.13: `pretrained=True` is deprecated in favour of `weights=`;
    # IMAGENET1K_V1 is the same checkpoint (alexnet-owt-7be5be79.pth).
    alexnet = models.alexnet(weights=models.AlexNet_Weights.IMAGENET1K_V1, progress=False)
except AttributeError:
    # Older torchvision without the `*_Weights` enums: fall back to the legacy flag.
    alexnet = models.alexnet(pretrained=True, progress=False)
alexnet.eval()

Downloading: "https://download.pytorch.org/models/alexnet-owt-7be5be79.pth" to /Users/aim/.cache/torch/hub/checkpoints/alexnet-owt-7be5be79.pth
10.1%IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

22.6%IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

38.1%IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_l

AlexNet(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))
  (classifier): Sequential(
    (0): Dropout(p=0.5, inplace=False)
    (1): Linear(in_features=9216, out_features=4096, bias=True)
 

In [5]:
# Preprocessing pipeline for AlexNet: resize every image to a fixed 224x224,
# convert it to a tensor, then normalize each channel with the ImageNet
# per-channel mean / std.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

In [9]:
# Store activations in a dictionary, keyed by layer name
features = {}

# Remove hooks left over from an earlier run of this cell. Otherwise re-running it
# stacks duplicate hooks on the model whose handles are lost and can't be removed.
# (RemovableHandle.remove() is a no-op for hooks that were already removed.)
for _old_handle in globals().get('handles', []):
    _old_handle.remove()
handles = []

# Define hook
def get_activation(name):
    """Return a forward hook that stores a module's output in ``features[name]``.

    The output is detached, moved to CPU and *cloned*. On CPU, ``detach()`` and
    ``cpu()`` share storage with the module's output. The conv layers hooked
    below are each followed by ``ReLU(inplace=True)``, which would otherwise
    overwrite the stored activation with its rectified values.
    """
    def hook(module, inputs, output):
        features[name] = output.detach().cpu().clone()
    return hook

# Layer name -> dotted module path inside AlexNet.
# fc6/fc7 live in `alexnet.classifier`, not `alexnet.features`: features[4] and
# features[5] are a ReLU and a MaxPool2d. These paths capture the pre-ReLU outputs;
# use 'classifier.2' / 'classifier.5' instead for post-ReLU fc activations.
target_layers = {
    'conv1': 'features.0',
    'conv2': 'features.3',
    'fc6': 'classifier.1',
    'fc7': 'classifier.4',
}

modules_by_path = dict(alexnet.named_modules())
for layer, path in target_layers.items():
    module = modules_by_path[path]
    # Guard against pointing a "conv"/"fc" name at an activation or pooling layer.
    assert isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)), (
        f"{layer!r} -> {path!r} is {type(module).__name__}, expected Conv2d/Linear"
    )
    handles.append(module.register_forward_hook(get_activation(layer)))

In [15]:
# Run every stimulus through AlexNet and collect the hooked activations per category.
# Result: all_features[cat] is a list of {'stimulus', 'conv1', 'conv2', 'fc6', 'fc7'}
# dicts, one per .png file in the category folder, in sorted filename order.
layer_names = ('conv1', 'conv2', 'fc6', 'fc7')
all_features = {}  # created once, outside the loop, so every category is kept

for cat in ['same_label', 'same_image', 'different_label']:
    stim_dir = os.path.join('./mnist_stim', cat)
    all_features[cat] = []

    for fname in sorted(os.listdir(stim_dir)):
        if not fname.endswith('.png'):
            continue
        # MNIST is grayscale, so convert to RGB
        with Image.open(os.path.join(stim_dir, fname)) as raw_img:
            img = raw_img.convert("RGB")
        img_tensor = transform(img).unsqueeze(0)  # add batch dim

        # Drop the previous image's activations so that a hook which did not fire
        # fails loudly below instead of silently reusing stale values.
        features.clear()
        with torch.no_grad():
            _ = alexnet(img_tensor)

        missing = [name for name in layer_names if name not in features]
        assert not missing, (
            f"no activation captured for {missing} on {cat}/{fname}; "
            "re-run the hook-registration cell"
        )

        # Save features for this image
        feature_vec = {'stimulus': fname}
        for name in layer_names:
            feature_vec[name] = features[name].flatten().numpy()
        all_features[cat].append(feature_vec)

# Detach the hooks only after ALL categories have been processed; removing them any
# earlier stops activations from being captured for the remaining categories.
for h in handles:
    h.remove()
handles.clear()