In [1]:
import os

import torch
import torch.nn as nn
from torchvision.models import resnet101

In [4]:
class ResNet101Regression(nn.Module):
    """ResNet-101 backbone with a linear regression head and a sigmoid output.

    The backbone is built without pre-trained weights (``weights=None``). Its
    final ``fc`` layer is replaced by ``nn.Linear(in_features,
    num_output_features)``, and the result is passed through a sigmoid, so
    every prediction lies in (0, 1).

    Args:
        num_output_features: Number of regression outputs per input.
        checkpoint: Optional checkpoint dict with a ``"state_dict"`` entry in
            which the backbone weights are stored under the ``"model.resnet."``
            key prefix (presumably a Lightning module holding this network as
            ``model.resnet`` -- confirm against the training code). When
            ``None`` (the default), the backbone keeps its random
            initialization.
    """

    def __init__(self, num_output_features=1, checkpoint=None):
        super().__init__()
        # Build an *untrained* ResNet-101 (weights=None). Trained weights, if
        # any, come from `checkpoint` below.
        self.resnet = resnet101(weights=None)
        # Replace the backbone's default classifier with a regression head.
        num_ftrs = self.resnet.fc.in_features
        self.resnet.fc = nn.Linear(num_ftrs, num_output_features)
        # Squash outputs into (0, 1); appropriate when targets are in [0, 1].
        self.sigmoid = nn.Sigmoid()

        if checkpoint is not None:
            # Keep only the backbone entries. Every other key in the state
            # dict is dropped.
            # NOTE(review): storing this dict on the instance keeps a second
            # copy of the weights alive; kept for backward compatibility.
            self._checkpoint = self._prepare_lightning_state_dict(
                checkpoint["state_dict"]
            )
            # Strict loading (the default) fails loudly if any backbone key is
            # missing, e.g. when the prefix did not match.
            self.resnet.load_state_dict(self._checkpoint)

    def _prepare_lightning_state_dict(self, state_dict, prefix="model.resnet."):
        """Extract the backbone weights from a Lightning-style state dict.

        Args:
            state_dict: Mapping from parameter/buffer names to values.
            prefix: Key prefix that marks backbone entries. It is stripped
                from each kept key.

        Returns:
            A new dict with only the entries whose key starts with ``prefix``,
            re-keyed with that leading prefix removed.
        """
        # Slice rather than str.replace, so only the *leading* prefix is
        # removed and never a later occurrence inside the key.
        return {
            key[len(prefix):]: value
            for key, value in state_dict.items()
            if key.startswith(prefix)
        }

    def forward(self, x):
        """Return ``sigmoid(resnet(x))``, so every output lies in (0, 1)."""
        return self.sigmoid(self.resnet(x))

In [6]:
# Path to the Lightning checkpoint to load. Set the MODEL_WEIGHTS environment
# variable to point at a different file (e.g. on another machine); otherwise
# the original cluster path is used.
MODEL_WEIGHTS = os.environ.get(
    "MODEL_WEIGHTS",
    "/home/soroush1/projects/def-kohitij/soroush1/pretrain-imagenet/resnet101_lalem/lightning_logs/version_0/checkpoints/epoch=220-step=77792.ckpt",
)

# Load the checkpoint with all tensors mapped to the CPU.
# NOTE: torch.load unpickles the file -- only load checkpoints from trusted sources.
checkpoint = torch.load(MODEL_WEIGHTS, map_location=torch.device("cpu"))
resnet101_model = ResNet101Regression(num_output_features=1, checkpoint=checkpoint)