In [5]:
import os
from pathlib import Path

import torch
import torch.nn as nn

from torchvision.models import resnet50

In [6]:
class ResNet50Regression(nn.Module):
    """ResNet-50 backbone with a linear regression head followed by a sigmoid.

    The backbone is built with ``weights=None`` (random initialisation); trained
    weights are only present when ``checkpoint`` is supplied.

    Args:
        num_output_features: Output width of the final linear layer (number of
            regression targets).
        checkpoint: Optional checkpoint mapping (e.g. the result of
            ``torch.load``) with a ``"state_dict"`` entry whose backbone keys
            start with ``"model.resnet."``. Presumably this is a Lightning
            checkpoint of a module that wrapped this class as ``self.model`` —
            confirm against the training code.

    Raises:
        KeyError: if ``checkpoint`` has no ``"state_dict"`` entry.
        RuntimeError: if no key in ``checkpoint["state_dict"]`` starts with
            ``"model.resnet."``, or if the extracted weights do not match the
            backbone (``load_state_dict`` is strict).
    """

    # Key prefix under which the backbone's weights are stored in the checkpoint.
    _LIGHTNING_PREFIX = "model.resnet."

    def __init__(self, num_output_features=1, checkpoint=None):
        super().__init__()
        # Build the backbone WITHOUT pre-trained weights; trained weights, if
        # any, are loaded from ``checkpoint`` below.
        self.resnet = resnet50(weights=None)

        # Replace the backbone's classifier with a linear regression head.
        num_ftrs = self.resnet.fc.in_features
        self.resnet.fc = nn.Linear(num_ftrs, num_output_features)

        # Applied unconditionally in ``forward``: outputs are squashed into
        # (0, 1), so this model only suits targets in that range.
        self.sigmoid = nn.Sigmoid()

        if checkpoint is not None:
            # Kept as a local rather than an attribute: a dict of tensors stored
            # on ``self`` is not part of state_dict() but would still be dragged
            # along by torch.save(model) / copy.deepcopy(model).
            backbone_state = self._prepare_lightning_state_dict(
                checkpoint["state_dict"]
            )
            if not backbone_state:
                # Fail with a clear message instead of a wall of "missing key"
                # errors from load_state_dict({}).
                raise RuntimeError(
                    f"No keys starting with {self._LIGHTNING_PREFIX!r} found in "
                    "checkpoint['state_dict']; cannot load ResNet-50 weights."
                )
            self.resnet.load_state_dict(backbone_state)

    def _prepare_lightning_state_dict(self, state_dict, prefix=_LIGHTNING_PREFIX):
        """Extract the backbone's weights from a (Lightning) state dict.

        Keeps only the entries whose key starts with ``prefix`` and strips that
        prefix, so the result can be loaded directly into ``self.resnet``. All
        other entries are dropped.

        Args:
            state_dict: Mapping of parameter/buffer names to values.
            prefix: Leading key prefix that identifies backbone entries.

        Returns:
            A new dict with ``prefix`` removed from each kept key; empty when
            no key carries the prefix.
        """
        # Slice off only the leading prefix; str.replace would also rewrite any
        # later occurrence of the same substring inside the key.
        return {
            key[len(prefix):]: value
            for key, value in state_dict.items()
            if key.startswith(prefix)
        }

    def forward(self, x):
        """Run the backbone and map its outputs into (0, 1) with a sigmoid."""
        return self.sigmoid(self.resnet(x))

In [7]:
# Checkpoint location. Override with the RESNET_WEIGHTS environment variable;
# the default resolves relative to the current user's home directory instead of
# a hard-coded /Users/<name>/ path.
RESNET_WEIGHTS = os.environ.get(
    "RESNET_WEIGHTS",
    str(Path.home() / "Desktop" / "epoch=499-step=88000.ckpt"),
)

# Load the checkpoint onto the CPU. torch.load unpickles the file (how much it
# may unpickle depends on the torch version's weights_only default), so only
# load checkpoints from a trusted source.
checkpoint = torch.load(RESNET_WEIGHTS, map_location=torch.device("cpu"))
resnet_model = ResNet50Regression(num_output_features=1, checkpoint=checkpoint)
# NOTE(review): the model is left in training mode (BatchNorm uses batch
# statistics); call resnet_model.eval() before inference — confirm against
# later cells.