In [8]:
import numpy as np
import torch
import torch.nn as nn
import torchvision.models as models
from torchvision import transforms

In [6]:
# Pre-trained ResNet50 backbone.
# NOTE(review): `pretrained=True` is the legacy flag; newer torchvision releases
# prefer the `weights=` argument -- confirm against the installed version.
resnet50 = models.resnet50(pretrained=True)

# Keep every top-level child module except the last one (the 1000-class `fc`
# classifier) and chain the remaining ones into a headless network.
# NOTE(review): the result presumably ends at the global average pool, so its
# output would be (N, 2048, 1, 1) and need flattening to get the 2048 features
# -- confirm before using it downstream.
feature_extractor = nn.Sequential(*tuple(resnet50.children())[:-1])

# Uncomment to look at the model without the final classifying layer
#print(feature_extractor)

In [10]:
class ImageEmbeddingModel(torch.nn.Module):
    """ResNet50 backbone whose classifier is replaced by a linear embedding head.

    Every pre-trained parameter is frozen (``requires_grad = False``). The
    replacement ``fc`` layer is created after the freeze, so it is the only
    trainable part of the model.

    BatchNorm layers update their running mean/variance on every forward pass
    while in training mode, regardless of ``requires_grad``. To keep the frozen
    backbone genuinely fixed, BatchNorm layers whose parameters are all frozen
    are held in eval mode -- both right after construction and after any call
    to ``train()``. A BatchNorm layer that is later unfrozen (its parameters
    set back to ``requires_grad = True``) follows the normal train/eval mode
    again from the next ``train()`` call.

    Args:
        output_dim: Number of output features of the new ``fc`` layer, i.e.
            the size of the embedding returned by ``forward``.
    """

    def __init__(self, output_dim=512):
        super().__init__()
        # Load pre-trained ResNet50 model
        self.feature_extractor = models.resnet50(pretrained=True)
        # Freeze the pre-trained layers
        for param in self.feature_extractor.parameters():
            param.requires_grad = False
        # Replace the last fully connected layer. The new layer is built after
        # the freeze loop, so its parameters keep the default requires_grad=True.
        num_features = self.feature_extractor.fc.in_features
        self.feature_extractor.fc = nn.Linear(num_features, output_dim)
        # A freshly built module starts in training mode; pin the frozen
        # BatchNorm statistics immediately.
        self._freeze_batchnorm_stats()

    def _freeze_batchnorm_stats(self):
        """Put BatchNorm layers with only frozen parameters into eval mode."""
        for module in self.feature_extractor.modules():
            if not isinstance(module, nn.BatchNorm2d):
                continue
            if all(not p.requires_grad for p in module.parameters()):
                module.eval()

    def train(self, mode=True):
        """Set the training mode, keeping frozen BatchNorm layers in eval mode.

        Args:
            mode: ``True`` for training mode, ``False`` for evaluation mode.

        Returns:
            The module itself, as ``torch.nn.Module.train`` does.
        """
        super().train(mode)
        self._freeze_batchnorm_stats()
        return self

    def forward(self, x):
        """Run ``x`` through the backbone and the embedding head.

        Args:
            x: Input batch passed unchanged to the ResNet50 model.

        Returns:
            The output of the replaced ``fc`` layer (last dimension
            ``output_dim``).
        """
        return self.feature_extractor(x)

# Instantiate the model with a 512-dimensional output layer
image_embedding_model = ImageEmbeddingModel(output_dim=512)

# Check: print the module tree to inspect the layers, including the replaced `fc`
print(image_embedding_model)

ImageEmbeddingModel(
  (feature_extractor): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): 