In [19]:
import torch

# Fetch the pretrained Barlow Twins ResNet-50 through torch.hub. A previously
# downloaded copy in the local hub cache is reused when available.
model = torch.hub.load(repo_or_dir='facebookresearch/barlowtwins:main', model='resnet50')

# Switch to evaluation mode (changes the behaviour of layers such as BatchNorm).
# `eval()` returns the module itself, so leaving it as the cell's final
# expression also renders the architecture below.
model.eval()


Using cache found in /data/shared/cache/torch/hub/facebookresearch_barlowtwins_main


ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [28]:
# Configuration: device to run on and which ResNet stage (1-4) to tap.
device = 'cpu'
layer = 2

# Forward pass through the network, capturing the output of one residual stage
# (`model.layer1` .. `model.layer4`) via a temporary forward hook.

# Filled by the hook; keyed by "layer{N}".
activation = {}


def get_activation(name):
    """Build a forward hook that stores a module's output in ``activation[name]``.

    Args:
        name: Dictionary key under which the hooked module's output is stored.

    Returns:
        A callable with the ``(module, input, output)`` signature expected by
        ``nn.Module.register_forward_hook``.
    """
    def hook(module, input, output):
        activation[name] = output
    return hook


if layer not in (1, 2, 3, 4):
    raise ValueError(f"layer must be one of 1, 2, 3, 4 (ResNet stages), got {layer!r}")
layer_key = f"layer{layer}"

# The model must live on the same device as the input.
model.to(device)

# Seeded dummy input with a private generator, so the global RNG is untouched.
# Note: the previous torch.Tensor(1, 3, 224, 224) returned *uninitialised* memory.
x_input = torch.randn(1, 3, 224, 224, generator=torch.Generator().manual_seed(0)).to(device)

handle = getattr(model, layer_key).register_forward_hook(get_activation(layer_key))
try:
    with torch.no_grad():  # feature extraction only; no autograd graph needed
        output = model(x_input)
finally:
    # Remove the hook so re-running this cell does not stack duplicate hooks on `model`.
    handle.remove()

# `x` holds the captured feature map (as in the original cell).
x = activation[layer_key]
print(x.shape)


torch.Size([1, 512, 28, 28])


In [26]:
import torch
import torch.nn as nn
from torchvision.models import resnet50

# Load the pretrained Barlow Twins model (ResNet-50 backbone)
model = torch.hub.load('facebookresearch/barlowtwins:main', 'resnet50')
model.eval()

# Residual stages whose outputs we capture. Each `layerN` is a Sequential of
# Bottleneck blocks; `layer4` is the last residual stage.
FEATURE_STAGES = ('layer1', 'layer2', 'layer3', 'layer4')

# Dictionary to store extracted feature maps, keyed by stage name.
features = {}


def get_features(name):
    """Build a forward hook that stores a module's output in ``features[name]``.

    Args:
        name: Dictionary key under which the hooked module's output is stored.

    Returns:
        A callable with the ``(module, input, output)`` signature expected by
        ``nn.Module.register_forward_hook``.
    """
    def hook(module, input, output):
        features[name] = output  # Save feature map
    return hook


# Register one hook at the output of each stage. The handles are kept so the
# hooks can be removed later (`handle.remove()`). While they remain registered,
# every forward pass of `model` overwrites `features`.
feature_hook_handles = {
    stage_name: getattr(model, stage_name).register_forward_hook(get_features(stage_name))
    for stage_name in FEATURE_STAGES
}

# Forward pass with a dummy image
dummy_input = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    model(dummy_input)

assert set(features) == set(FEATURE_STAGES), f"missing stages: {set(FEATURE_STAGES) - set(features)}"

# Print feature shapes to verify extraction. The loop variable is deliberately not
# named `layer`, so the integer `layer` setting used elsewhere is not clobbered.
for stage_name, fmap in features.items():
    print(f"{stage_name} feature map shape: {fmap.shape}")


Using cache found in /data/shared/cache/torch/hub/facebookresearch_barlowtwins_main


layer1 feature map shape: torch.Size([1, 256, 56, 56])
layer2 feature map shape: torch.Size([1, 512, 28, 28])
layer3 feature map shape: torch.Size([1, 1024, 14, 14])
layer4 feature map shape: torch.Size([1, 2048, 7, 7])
