## Import the necessary libraries

In [30]:
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn

## Using the CIFAR-10 dataset from torchvision, keeping only a 1,000-image subset for a faster pass through the model.

* **Disclaimer:** Please note that here the batch size and the size of the dataset subset are the same (1,000). I did this intentionally for this example, so the whole subset passes through the model in a single batch. However, in most situations this will not be the case. In those instances, make sure to save the results from each batch to a list and concatenate them at the end to get the activations for the entire dataset.

In [31]:
# Number of CIFAR-10 images kept from each split, and the DataLoader batch size.
# They are equal on purpose here, so the whole subset goes through the model in one batch.
N_SAMPLES = 1000
BATCH = 1000
# Seed for the shuffled train loader so the batch order is reproducible across runs.
SEED = 0

# The pretrained ResNet50 weights expect inputs normalized with the ImageNet
# per-channel statistics; feeding raw [0, 1] tensors shifts every activation.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Fixed indices so the same images are selected on every run.
fix_set = list(range(N_SAMPLES))

train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

train_dataset = torch.utils.data.Subset(train_dataset, fix_set)
print(len(train_dataset))
test_dataset = torch.utils.data.Subset(test_dataset, fix_set)
print(len(test_dataset))

test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH, shuffle=False)
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=BATCH,
    shuffle=True,
    # Seeded generator: shuffling stays, but the order is reproducible.
    generator=torch.Generator().manual_seed(SEED),
)

Files already downloaded and verified
Files already downloaded and verified
1000
1000


## Loading the ResNet50

In [32]:
# Load ResNet50 with its ImageNet weights. torchvision >= 0.13 deprecated the
# `pretrained=True` flag in favour of the `weights` enum (IMAGENET1K_V1 is the
# same checkpoint the flag used); fall back to the flag on older releases.
if hasattr(models, 'ResNet50_Weights'):
    resnet50 = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
else:
    resnet50 = models.resnet50(pretrained=True)
# Evaluation mode: BatchNorm layers use their running statistics instead of batch statistics.
resnet50.eval()



ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

## Printing the top-level modules of ResNet50 to get a better sense of its layers

In [33]:
# Show each top-level child module of the network: its attribute name, then its repr.
for child_name, child_module in resnet50.named_children():
    print(f'NAME: \n {child_name}\nCONTENT: \n {child_module}')

NAME: 
 conv1
CONTENT: 
 Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
NAME: 
 bn1
CONTENT: 
 BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
NAME: 
 relu
CONTENT: 
 ReLU(inplace=True)
NAME: 
 maxpool
CONTENT: 
 MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
NAME: 
 layer1
CONTENT: 
 Sequential(
  (0): Bottleneck(
    (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (downsample): Sequential(
      (0): Conv2

## Creating a function to extract the activations

In [34]:
# Hook name -> most recent output captured by the forward hook registered under that name.
embeds = {}
def get_activation(name):
    """Build a forward hook that records a module's output in ``embeds``.

    Args:
        name: Key under which the captured output is stored in ``embeds``.

    Returns:
        A callable with the ``(module, input, output)`` signature expected by
        ``nn.Module.register_forward_hook``. Each call overwrites ``embeds[name]``.
    """
    def hook(model, input, output):
        # Detach tensors so a stored activation never keeps the autograd graph
        # (and every intermediate it references) alive when grad is enabled.
        embeds[name] = output.detach() if isinstance(output, torch.Tensor) else output
    return hook

## Going through the children of ResNet50 and picking the Sequential blocks (layer1, layer2, layer3 and layer4). Then going through their Bottleneck blocks and attaching a forward hook to the layers called conv1, conv2, and conv3 to extract their activations

In [35]:
# Remove hooks left over from a previous run of this cell so re-running it
# does not stack duplicate hooks on the same layers.
for handle in globals().get('hook_handles', []):
    handle.remove()

# One handle per registered hook. Keep ALL of them (not just the last
# bottleneck's) and remove them only after the features have been extracted,
# e.g. `for h in hook_handles: h.remove()`. Removing them here, before the
# forward pass, would drop those layers' activations.
hook_handles = []
for name, layer in resnet50.named_children():
    if name in ['layer1', 'layer2', 'layer3', 'layer4']:
        for i, bottleneck in enumerate(layer.children()):
            for conv_name in ('conv1', 'conv2', 'conv3'):
                conv = getattr(bottleneck, conv_name)
                hook_handles.append(
                    conv.register_forward_hook(get_activation(f'{conv_name}_{name}_{i}'))
                )

## Loading the images to the model and extracting the features

In [36]:
# The forward hooks overwrite `embeds` on every batch, so copy each batch's
# activations out and concatenate them at the end. This stays correct even
# when BATCH is smaller than the dataset.
activation_chunks = {}
label_chunks = []
with torch.no_grad():
    for images, labels in train_loader:
        # Drop entries from earlier batches/runs so only this pass's hooks are collected.
        embeds.clear()
        out = resnet50(images)
        for key, activation in embeds.items():
            activation_chunks.setdefault(key, []).append(activation)
        label_chunks.append(labels)

assert label_chunks, 'train_loader yielded no batches'
embeds.update({key: torch.cat(chunks) for key, chunks in activation_chunks.items()})
# Labels in the same (possibly shuffled) order as the rows of every tensor in `embeds`.
embed_labels = torch.cat(label_chunks)

## Going through the dictionary with activations from different layers and printing the output shapes

In [42]:
# Print the shape of every captured activation, numbered in dictionary insertion order.
for idx, activation in enumerate(embeds.values()):
    print(f'# {idx} | Embedding Shape {activation.shape}\n ------------------------------------------------------------')

# 0 | Embedding Shape torch.Size([1000, 64, 8, 8])
 ------------------------------------------------------------
# 1 | Embedding Shape torch.Size([1000, 64, 8, 8])
 ------------------------------------------------------------
# 2 | Embedding Shape torch.Size([1000, 256, 8, 8])
 ------------------------------------------------------------
# 3 | Embedding Shape torch.Size([1000, 64, 8, 8])
 ------------------------------------------------------------
# 4 | Embedding Shape torch.Size([1000, 64, 8, 8])
 ------------------------------------------------------------
# 5 | Embedding Shape torch.Size([1000, 256, 8, 8])
 ------------------------------------------------------------
# 6 | Embedding Shape torch.Size([1000, 64, 8, 8])
 ------------------------------------------------------------
# 7 | Embedding Shape torch.Size([1000, 64, 8, 8])
 ------------------------------------------------------------
# 8 | Embedding Shape torch.Size([1000, 256, 8, 8])
 -----------------------------------------