In [3]:
import os
import numpy as np
import pandas as pd

from PIL import Image
from collections import OrderedDict

import torch
import torchvision.models as models
import torchvision.transforms as T


if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))
else:
    print(f'you are not using a gpu (or your cuda install is messed up). fix it.')

NVIDIA A100-SXM4-40GB MIG 3g.20gb


In [2]:
# load pretrained alexnet (imagenet)
model = models.alexnet(weights=models.AlexNet_Weights.IMAGENET1K_V1)
model.eval()

# standard imagenet preprocessing
transform = T.Compose([
    T.Resize(256),
    T.CenterCrop(224),
    T.ToTensor(),
    T.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    ),
])

Downloading: "https://download.pytorch.org/models/alexnet-owt-7be5be79.pth" to /n/home12/amarvi/.cache/torch/hub/checkpoints/alexnet-owt-7be5be79.pth


100%|██████████| 233M/233M [00:04<00:00, 58.8MB/s] 


In [4]:
def load_images(paths):
    # returns tensor [n_images, 3, 224, 224]
    imgs = [transform(Image.open(p).convert('RGB')) for p in paths]
    return torch.stack(imgs)

def get_activations(model, x):
    activations = OrderedDict()
    hooks = []

    def hook_fn(name):
        def hook(module, inp, out):
            # detach so autograd doesn't eat your ram
            activations[name] = out.detach()
        return hook

    # register hooks on all submodules
    for name, module in model.named_modules():
        if name:  # skip the top-level container
            hooks.append(module.register_forward_hook(hook_fn(name)))

    with torch.no_grad():
        _ = model(x)

    for h in hooks:
        h.remove()

    return activations

In [10]:
# load in NSD and localizer images (total: 1072 images)
IMG_DIR = '../../datasets/NNN/NSD1000_LOC'
IMG_PATHS = [os.path.join(IMG_DIR, img) for img in os.listdir(IMG_DIR) if '.tsv' not in img]
print(f'Total number of images found: {len(IMG_PATHS)}')

img_tensor = load_images(IMG_PATHS)
acts = get_activations(model, img_tensor)

Total number of images found: 1072


In [52]:
# FOR REFERENCE, ALEXNET LAYERS + ACTUAL ACTs
for name, module in model.named_modules():
    if '.' in name:
        print(f'{name:<15} | {module.__class__.__name__:>15}')
        
print('\n\n', '='*50, '\n\n')
        
for layer, out in acts.items():
    print(f'{layer:<15} | {str(out.shape):>35}')

features.0      |          Conv2d
features.1      |            ReLU
features.2      |       MaxPool2d
features.3      |          Conv2d
features.4      |            ReLU
features.5      |       MaxPool2d
features.6      |          Conv2d
features.7      |            ReLU
features.8      |          Conv2d
features.9      |            ReLU
features.10     |          Conv2d
features.11     |            ReLU
features.12     |       MaxPool2d
classifier.0    |         Dropout
classifier.1    |          Linear
classifier.2    |            ReLU
classifier.3    |         Dropout
classifier.4    |          Linear
classifier.5    |            ReLU
classifier.6    |          Linear




features.0      |      torch.Size([1072, 64, 55, 55])
features.1      |      torch.Size([1072, 64, 55, 55])
features.2      |      torch.Size([1072, 64, 27, 27])
features.3      |     torch.Size([1072, 192, 27, 27])
features.4      |     torch.Size([1072, 192, 27, 27])
features.5      |     torch.Size([1072, 192, 1