In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
from torchvision.models.resnet import ResNet50_Weights

In [2]:
class ImgEncoder_CNN(nn.Module):
    """Image encoder: frozen ImageNet ResNet-50 backbone plus a trainable MLP projection head.

    The backbone (ResNet-50 without its final fully connected layer) yields one
    pooled feature vector per image; the projection head maps that vector to
    ``projection_dim`` dimensions, and the result is L2-normalised.

    The backbone is frozen in two ways:

    * its parameters have ``requires_grad = False``;
    * it is kept permanently in eval mode (see :meth:`train`), so its
      BatchNorm layers use the pre-trained running statistics and never
      overwrite them, even while the projection head is being trained.

    Args:
        projection_dim: Size of the output embedding.
        hidden_dim: Width of the two hidden layers of the projection head.
        dropout_rate: Dropout probability applied after each hidden layer.
    """

    def __init__(self, projection_dim=512, hidden_dim=256, dropout_rate=0.1):
        super().__init__()

        # Load the pre-trained ResNet50 model
        resnet = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)

        # Width of the pooled feature vector (2048 for ResNet50), read from the
        # classifier's input size instead of being hard-coded.
        feature_dim = resnet.fc.in_features

        # Remove the final fully connected layer
        self.base_model = nn.Sequential(*list(resnet.children())[:-1])

        # Freeze the parameters of the base model
        for param in self.base_model.parameters():
            param.requires_grad = False
        # requires_grad=False does not stop BatchNorm from updating its running
        # statistics in training mode, so the backbone is also put in eval mode.
        self.base_model.eval()

        # Define the projection head
        self.projection_head = nn.Sequential(
            nn.Linear(feature_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout_rate),

            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout_rate),

            nn.Linear(hidden_dim, projection_dim),
        )

    def train(self, mode=True):
        """Set the training mode of the projection head; the backbone stays in eval mode.

        ``nn.Module.eval()`` delegates to ``train(False)``, so overriding this
        single method covers both directions.

        Args:
            mode: ``True`` for training mode, ``False`` for evaluation mode.

        Returns:
            ``self``, matching ``nn.Module.train``.
        """
        super().train(mode)
        # Undo the recursive switch for the frozen backbone only.
        self.base_model.eval()
        return self

    def forward(self, x):
        """Encode a batch of images into unit-length embeddings.

        Args:
            x: Image tensor of shape ``(batch_size, channels, height, width)``.
                A 3-D tensor is treated as a single image and given a batch
                dimension of 1. Note that in training mode the head's
                ``BatchNorm1d`` layers need more than one sample per batch, so
                single-image input is intended for eval-mode inference.

        Returns:
            Tensor of shape ``(batch_size, projection_dim)`` whose rows are
            L2-normalised.

        Raises:
            TypeError: If ``x`` is not a ``torch.Tensor``.
        """
        # Ensure input is a tensor
        if not isinstance(x, torch.Tensor):
            raise TypeError(f"Input must be a tensor, got {type(x)} instead.")

        # Ensure input is in the correct shape (batch_size, channels, height, width)
        if x.dim() == 3:
            x = x.unsqueeze(0)  # Add batch dimension if missing

        # Extract features from the base model
        h = self.base_model(x)

        # Flatten the pooled features to [batch_size, feature_dim]
        h = torch.flatten(h, 1)

        # Pass through the projection head
        z = self.projection_head(h)

        # Normalize embeddings to unit length
        return F.normalize(z, dim=1)

In [3]:
torch.manual_seed(0)  # make the random example batch reproducible
image_input = torch.randn(32, 3, 224, 224)  # Example image batch (32 images, 3 channels, 224x224)
embedding_dim = 512

image_encoder = ImgEncoder_CNN(projection_dim=embedding_dim)
# A freshly built module is in training mode; switch to eval so dropout is off
# and BatchNorm uses running statistics for this inference-only shape check.
image_encoder.eval()

# No gradients are needed just to inspect the output shape.
with torch.no_grad():
    image_features = image_encoder(image_input)

print(image_features.shape)

Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100%|██████████| 97.8M/97.8M [00:00<00:00, 149MB/s]


torch.Size([32, 512])


In [4]:
from torchinfo import summary

# Layer-by-layer report of output shapes and parameter counts, traced by
# running the example batch through the encoder.
summary(image_encoder, input_data=image_input)

Layer (type:depth-idx)                        Output Shape              Param #
ImgEncoder_CNN                                [32, 512]                 --
├─Sequential: 1-1                             [32, 2048, 1, 1]          --
│    └─Conv2d: 2-1                            [32, 64, 112, 112]        (9,408)
│    └─BatchNorm2d: 2-2                       [32, 64, 112, 112]        (128)
│    └─ReLU: 2-3                              [32, 64, 112, 112]        --
│    └─MaxPool2d: 2-4                         [32, 64, 56, 56]          --
│    └─Sequential: 2-5                        [32, 256, 56, 56]         --
│    │    └─Bottleneck: 3-1                   [32, 256, 56, 56]         (75,008)
│    │    └─Bottleneck: 3-2                   [32, 256, 56, 56]         (70,400)
│    │    └─Bottleneck: 3-3                   [32, 256, 56, 56]         (70,400)
│    └─Sequential: 2-6                        [32, 512, 28, 28]         --
│    │    └─Bottleneck: 3-4                   [32, 512, 28, 28]      