In [1]:
import torch

In [2]:
import torch.nn as nn

In [3]:
from torchvision.models import resnet18, ResNet18_Weights, VGG16_Weights, vgg16
from torchvision.models.feature_extraction import create_feature_extractor as create_feature_extractor

In [None]:
class EncoderResnet18(nn.Module):
    """Frozen, pretrained ResNet-18 that returns the output of its ``avgpool`` node.

    The backbone is loaded with ``ResNet18_Weights.DEFAULT``, wrapped with
    ``create_feature_extractor`` so that only the ``avgpool`` output is
    returned, and all of its parameters are frozen. The backbone is kept in
    evaluation mode at all times so that its normalization layers behave
    deterministically and do not update their running statistics.
    """

    def __init__(self):
        super().__init__()
        # Initialize the model with the best available pretrained weights.
        weights = ResNet18_Weights.DEFAULT
        model = resnet18(weights=weights)
        # Preprocessing pipeline that ships with the selected weights.
        self.preprocess = weights.transforms()
        self.net = create_feature_extractor(model, ['avgpool'])
        for p in self.net.parameters():
            p.requires_grad = False
        # A freshly constructed model is in training mode and the feature
        # extractor inherits that mode; a frozen encoder must run in eval mode.
        self.net.eval()

    def train(self, mode=True):
        """Set the training mode of this module while keeping the frozen backbone in eval mode.

        Args:
            mode: Same meaning as in ``nn.Module.train``.

        Returns:
            ``self``, as ``nn.Module.train`` does.
        """
        super().train(mode)
        # ``super().train`` propagates ``mode`` to every child; undo that for
        # the frozen backbone so ``encoder.train()`` cannot re-enable
        # training-time behavior in it.
        self.net.eval()
        return self

    def forward(self, x):
        """Preprocess ``x`` and return the backbone's ``avgpool`` features.

        Args:
            x: A single image or a batch of images accepted by the weights'
                preprocessing transform. If the preprocessed result is 3-D
                (assumed to be one ``(C, H, W)`` image), a leading batch axis
                is added; inputs that are already batched are left as-is.

        Returns:
            The tensor stored under the ``'avgpool'`` key of the feature
            extractor's output (not flattened).
        """
        x = self.preprocess(x)
        # Only add the batch axis for a single image; unconditionally
        # unsqueezing would turn an already-batched 4-D input into a 5-D one.
        if x.dim() == 3:
            x = x.unsqueeze(0)
        features = self.net(x)['avgpool']
        return features

In [19]:
class EncoderVGG16(nn.Module):
    """Frozen, pretrained VGG-16 that returns the output of its truncated classifier head.

    The backbone is loaded with ``VGG16_Weights.DEFAULT``, the last layer of
    ``model.classifier`` is removed, and ``create_feature_extractor`` is used
    to expose the ``classifier`` node. All parameters are frozen and the
    backbone is kept in evaluation mode at all times; torchvision's VGG
    classifier head contains dropout layers, which would otherwise randomly
    mask the returned features on every call.
    """

    def __init__(self):
        super().__init__()
        # Initialize the model with the best available pretrained weights.
        weights = VGG16_Weights.DEFAULT
        model = vgg16(weights=weights)
        # Preprocessing pipeline that ships with the selected weights.
        self.preprocess = weights.transforms()
        # Drop the final classifier layer so the head's output is used as features.
        del model.classifier[-1]
        self.net = create_feature_extractor(model, ['classifier'])
        for p in self.net.parameters():
            p.requires_grad = False
        # A freshly constructed model is in training mode and the feature
        # extractor inherits that mode; a frozen encoder must run in eval mode.
        self.net.eval()

    def train(self, mode=True):
        """Set the training mode of this module while keeping the frozen backbone in eval mode.

        Args:
            mode: Same meaning as in ``nn.Module.train``.

        Returns:
            ``self``, as ``nn.Module.train`` does.
        """
        super().train(mode)
        # ``super().train`` propagates ``mode`` to every child; undo that for
        # the frozen backbone so ``encoder.train()`` cannot re-enable dropout.
        self.net.eval()
        return self

    def forward(self, x):
        """Preprocess ``x`` and return the backbone's ``classifier`` features.

        Args:
            x: A single image or a batch of images accepted by the weights'
                preprocessing transform. If the preprocessed result is 3-D
                (assumed to be one ``(C, H, W)`` image), a leading batch axis
                is added; inputs that are already batched are left as-is.

        Returns:
            The tensor stored under the ``'classifier'`` key of the feature
            extractor's output.
        """
        x = self.preprocess(x)
        # Only add the batch axis for a single image; unconditionally
        # unsqueezing would turn an already-batched 4-D input into a 5-D one.
        if x.dim() == 3:
            x = x.unsqueeze(0)
        features = self.net(x)['classifier']
        return features