In [1]:
import torch
import torchvision


from model_loading import load_core_model
from model_loading import load_top_model

## Loading Pretrained Checkpoints
We provide code snippets for loading the models used in the paper, including:
 - Models trained for 100 epochs using the MMCR, SimCLR, or Barlow Twins objective, with the key hyperparameter `lmda` swept over [0.0, 0.001, 0.1, 0.2, 0.3, 0.4, 0.5]. These are the core models used to produce Figures 2 and 3.
 - The model we pretrained for 1000 epochs using the Barlow Twins loss with `lmda = 0.2`, which is referenced in Figure 4 and has the highest mean predictivity over the considered datasets on the Brain-Score leaderboard.


The supplied functions return standard ResNet-50s (instances of the `torchvision.models.resnet.ResNet` class, as built by `torchvision.models.resnet50`), complete with ImageNet-1k linear classifier heads that were trained online during SSL pretraining with gradients detached (i.e., training the heads did not update the backbone).

In [2]:
# Load a 'core model' and the top-performing model.
# load_core_model expects objective_name to be one of CORE_OBJECTIVES and
# lmda to be one of CORE_LMDAS; edit OBJECTIVE_NAME / LMDA to pick another one.
CORE_OBJECTIVES = ("Barlow", "SimCLR", "MMCR")
CORE_LMDAS = (0.0, 0.001, 0.1, 0.2, 0.3, 0.4, 0.5)

OBJECTIVE_NAME = "SimCLR"
LMDA = 0.2

# Fail fast with a readable message rather than depending on how the loader
# itself handles an unsupported combination.
if OBJECTIVE_NAME not in CORE_OBJECTIVES:
    raise ValueError(
        f"OBJECTIVE_NAME must be one of {CORE_OBJECTIVES}, got {OBJECTIVE_NAME!r}"
    )
if LMDA not in CORE_LMDAS:
    raise ValueError(f"LMDA must be one of {CORE_LMDAS}, got {LMDA!r}")

core_model = load_core_model(
    objective_name=OBJECTIVE_NAME,
    lmda=LMDA,
)

# Load the top-performing model (1000-epoch Barlow Twins with lmda = 0.2).
top_model = load_top_model()

In [3]:
# Dump the module tree to confirm the loader returned a stock torchvision
# ResNet-50 with its ImageNet-1k classifier head attached.
print(str(core_model))

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [4]:
# Example forward pass on a random ImageNet-sized input.
# eval() makes the BatchNorm layers use their stored running statistics; in train
# mode (track_running_stats=True) a forward pass would instead overwrite those
# pretrained statistics. load_core_model may already return the model in eval
# mode -- calling eval() again is harmless either way.
# no_grad() skips building the autograd graph since only the outputs are needed.
torch.manual_seed(0)  # make the random input (and thus the logits) reproducible
core_model.eval()
x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    logits = core_model(x)
assert logits.shape == (1, 1000), f"unexpected logits shape: {tuple(logits.shape)}"
print(logits.shape)


torch.Size([1, 1000])
