In [None]:
import torch
import numpy as np
import os
import glob
from brainscore_vision import load_model

In [2]:
# Maps each model identifier (the string later handed to `load_model`) to the
# list of layer names whose activations are extracted for that model.
# NOTE(review): every list has exactly four entries and some names repeat
# (e.g. 'features.7' for alexnet); presumably one layer per target brain
# region in a fixed order — TODO confirm the intended ordering.
models = {
  'alexnet': ['features.2', 'features.7', 'features.7', 'features.12'],
  'deit_base_imagenet_full_seed-0': ['blocks.3.mlp.fc1', 'blocks.8.norm2', 'blocks.3.mlp.act', 'blocks.9.norm2'],
  'resnet-50-robust': ['layer3.0.downsample.0', 'layer4.0.downsample.0', 'layer3.0.downsample.0', 'layer4.0.downsample.0'],
  'deit_large_imagenet_full_seed-0': ['blocks.4.norm1', 'blocks.18.norm2', 'blocks.9.norm1', 'blocks.20.norm2'],
  'resnet152_imagenet_full': ['layer1.0.bn1', 'layer3.3.bn3', 'layer3.0.bn3', 'layer3.34.bn3'],
  'resnext101_32x32d_wsl': ['layer1.0.relu', 'layer3.0.relu', 'layer2.0.relu', 'layer3.21.relu'],
  'convnext_small_imagenet_100_seed-0': ['features.5.2.block.0', 'features.5.17.block.0', 'features.4.0', 'features.5.9.block.0'],
  'convnext_small_imagenet_10_seed-0': ['features.5.2.block.0', 'features.5.17.block.0', 'features.4.0', 'features.5.9.block.0'],
  'resnext101_32x48d_wsl': ['layer2.2.relu', 'layer3.0.relu', 'layer2.0.relu', 'layer3.20.relu'],
  'resnet50_ecoset_full': ['layer1.0.bn1', 'layer4.0.conv2', 'layer3.0.conv1', 'layer4.0.relu'],
  'resnet50_imagenet_100_seed-0': ['layer1.0.conv1', 'layer3.5.bn3', 'layer3.0.conv1', 'layer4.0.relu'],
  'resnet101_ecoset_full': ['layer1.0.bn1', 'layer3.4.relu', 'layer3.0.bn3', 'layer4.0.relu'],
  'resnext101_32x8d_wsl': ['layer2.3.relu', 'layer3.4.relu', 'layer2.1.relu', 'layer3.3.relu'],
  'convnext_small_imagenet_full_seed-0': ['features.5.2.block.0', 'features.5.17.block.0', 'features.4.0', 'features.5.9.block.0'],
  'convnext_tiny_imagenet_full_seed-0': ['features.6.0', 'features.5.4.block.0', 'features.4.0', 'features.5.4.block.0'],
  'deit_small_imagenet_100_seed-0': ['blocks.2.norm1', 'blocks.6.norm2', 'blocks.5.norm1', 'blocks.9.norm2'],
  'convnext_base_imagenet_full_seed-0': ['features.5.7.block.0', 'features.5.12.block.0', 'features.4.0', 'features.5.11.block.0'],
  'resnet50_tutorial': ['layer2', 'layer2', 'layer2', 'layer3'],
  'resnet101_imagenet_full': ['layer1.0.bn1', 'layer4.0.bn1', 'layer3.0.bn3', 'layer4.0.relu'],
  'convnext_large_imagenet_full_seed-0': ['features.5.7.block.5', 'features.5.7.block.0', 'features.4.1', 'features.5.11.block.0'],
  'resnet50_imagenet_full': ['layer1.0.conv1', 'layer3.5.bn3', 'layer3.0.conv1', 'layer4.0.relu'],
  'resnet18_imagenet_full': ['layer1.0.bn1', 'layer3.0.conv2', 'layer2.0.bn2', 'layer4.0.bn1'],
  'resnet152_ecoset_full': ['layer1.0.bn1', 'layer3.3.bn3', 'layer3.0.bn3', 'layer4.0.relu'],
  'resnet18_ecoset_full': ['layer1.0.conv1', 'layer3.0.conv1', 'layer2.0.bn2', 'layer4.0.bn1'],
  'resnet-152_v2_pytorch': ['avgpool', 'layer4.1.relu', 'layer4.1.relu', 'layer4.1.bn2'],
  'resnet34_ecoset_full': ['layer1.0.bn1', 'layer3.1.conv1', 'layer3.0.conv1', 'layer4.0.conv1'],
  'resnet18_imagenet21kP': ['layer2.0.relu', 'layer2.0.relu', 'layer2.0.relu', 'layer4.0.relu'],
  'deit_small_imagenet_full_seed-0': ['blocks.2.norm1', 'blocks.6.norm2', 'blocks.5.norm1', 'blocks.9.norm2']
}

In [3]:
# Root of the ImageNet validation images, resolved against the current
# working directory of the kernel.
imagenet = f'{ os.getcwd() }/data/imagenet/val'

# Every regular file below `imagenet` (directories are filtered out).
# glob returns matches in a filesystem-dependent order, and the extraction
# step slices this list by index into numbered batches, so the list is sorted
# to make "batch N" refer to the same stimuli on every run and machine.
stimuli = sorted(
  path
  for path in glob.glob(os.path.join(imagenet, '**'), recursive=True)
  if os.path.isfile(path)
)

In [4]:
# Module-level output path for the deit-base run. Note that `extract` assigns
# its own local `output` per model and does not read this variable.
output = '{}/data/output/deit-base'.format(os.getcwd())

In [5]:
def extract(model, layers, batch_size = 2000):
  """Extract layer activations for all stimuli and save them as batched .npz files.

  The model is loaded through `load_model`, its `activations_model` is run on
  consecutive slices of the module-level `stimuli` list, and each slice is
  written to `<cwd>/data/output/<identifier>/<identifier>-<counter>.npz`, where
  `counter` starts at 1. Each archive holds one array per requested layer,
  keyed by the layer name.

  Args:
    model: Identifier string passed to `load_model`.
    layers: Iterable of layer names to extract. Repeated names are requested
      only once (first occurrence wins the position).
    batch_size: Number of stimuli processed and saved per .npz file.

  Raises:
    ValueError: If `batch_size` is not a positive number.
  """
  if batch_size < 1:
    # A negative step would make the batch loop silently do nothing.
    raise ValueError(f'batch_size must be a positive integer, got {batch_size!r}')

  loaded_model = load_model(model)
  act_model = loaded_model.activations_model
  identifier = loaded_model.identifier

  # np.savez does not create missing parent directories, so make sure the
  # per-model output folder exists before the first batch is written.
  output = f'{ os.getcwd() }/data/output/{identifier}'
  os.makedirs(output, exist_ok=True)

  # The configured layer lists may name the same layer more than once; the
  # archive is keyed by layer name, so each layer only needs to be requested
  # once. dict.fromkeys keeps the original order.
  layers = list(dict.fromkeys(layers))

  batch_starts = range(0, len(stimuli), batch_size)
  for counter, batch_start in enumerate(batch_starts, start=1):
    # Get the current batch of stimuli
    batch = stimuli[batch_start:batch_start + batch_size]
    print(f'{counter} - running the model ({batch_start})...')
    # Extract model activations for the batch
    model_extraction = act_model(batch, layers)
    print('building the npz...')
    # Split the combined result per layer: keep the entries along `neuroid`
    # whose `layer` coordinate equals the layer name.
    npz_data = {
        layer: model_extraction.sel(neuroid=model_extraction['layer'] == layer).data
        for layer in layers
      }
    print('persisting the data...')
    np.savez(os.path.join(output, '{}-{}.npz'.format(identifier, counter)), **npz_data)

In [None]:
# Run the extraction for every configured model, one model after another,
# in the insertion order of `models`.
for model in models:
    layers = models[model]
    extract(model, layers)