In [1]:
import os

import torch
import torch.nn as nn
from torchvision.models import inception_v3

In [2]:
class InceptionRegression(nn.Module):
    """Inception v3 backbone with a linear regression head and a sigmoid output.

    The torchvision ``inception_v3`` network is built without pretrained weights
    (``weights=None``) and without the auxiliary classifier (``aux_logits=False``).
    Its final ``fc`` layer is replaced by a linear layer with
    ``num_output_features`` outputs, and every prediction is passed through a
    sigmoid, so outputs lie in (0, 1).

    Args:
        num_output_features: Number of outputs produced by the new ``fc`` head.
        checkpoint: Optional PyTorch Lightning checkpoint (a dict with a
            ``"state_dict"`` entry). Entries whose key starts with
            ``state_dict_prefix`` are kept, the prefix is stripped, and the
            result is loaded strictly into the Inception backbone.
        state_dict_prefix: Key prefix that marks backbone weights inside the
            Lightning ``state_dict``. The default matches a LightningModule that
            stored this model as ``self.model`` — presumably how the checkpoint
            was produced; verify for other training setups.

    Raises:
        KeyError: If ``checkpoint`` is a dict without a ``"state_dict"`` entry.
        RuntimeError: If no checkpoint key starts with ``state_dict_prefix``, or
            if the filtered weights do not match the backbone (raised by
            ``load_state_dict``).
    """

    def __init__(
        self,
        num_output_features=1,
        checkpoint=None,
        state_dict_prefix="model.inception.",
    ):
        super().__init__()
        # Inception v3 built from scratch: weights=None means no pretrained
        # weights are fetched; aux_logits=False drops the auxiliary classifier.
        self.inception = inception_v3(aux_logits=False, weights=None, init_weights=True)
        # Replace the default classifier with a regression head of the
        # requested width, keeping the backbone's feature size.
        num_ftrs = self.inception.fc.in_features
        self.inception.fc = nn.Linear(num_ftrs, num_output_features)
        # Always applied in forward(); assumes regression targets lie in [0, 1].
        self.sigmoid = nn.Sigmoid()

        if checkpoint is not None:
            # The filtered state dict is kept on the instance for callers that
            # inspect it; note it holds references to the checkpoint's tensors.
            self._checkpoint = self._prepare_lightning_state_dict(
                checkpoint["state_dict"], prefix=state_dict_prefix
            )
            if not self._checkpoint:
                # Fail with a clear message instead of letting load_state_dict
                # report every backbone parameter as "missing".
                raise RuntimeError(
                    f"No keys in checkpoint['state_dict'] start with "
                    f"{state_dict_prefix!r}; cannot load backbone weights."
                )
            self.inception.load_state_dict(self._checkpoint)

    def forward(self, x):
        """Run the backbone and squash its main output through a sigmoid.

        Args:
            x: Image batch passed straight to the Inception backbone —
                presumably float tensors of shape (N, 3, 299, 299); confirm
                against the data pipeline.

        Returns:
            Sigmoid of the backbone's main output, values in (0, 1).
        """
        outputs = self.inception(x)
        # The backbone may return a namedtuple-like InceptionOutputs instead of
        # a plain tensor; in that case use its main ``logits`` field.
        x = outputs if isinstance(outputs, torch.Tensor) else outputs.logits
        return self.sigmoid(x)

    def _prepare_lightning_state_dict(self, state_dict, prefix="model.inception."):
        """Extract backbone weights from a Lightning ``state_dict``.

        Keeps only entries whose key starts with ``prefix`` and removes that
        prefix from the start of the key (later occurrences are left intact),
        so the result can be passed to ``self.inception.load_state_dict``.
        All other entries are dropped.

        Args:
            state_dict: Mapping from parameter/buffer name to value.
            prefix: Leading key prefix to select and strip.

        Returns:
            A new dict of the selected, prefix-stripped entries (empty if no key
            matches).
        """
        return {
            key[len(prefix):]: value
            for key, value in state_dict.items()
            if key.startswith(prefix)
        }

In [4]:
# Lightning checkpoint to evaluate. Set the MODEL_WEIGHTS environment variable
# to use another file instead of editing this machine-specific default path.
MODEL_WEIGHTS = os.environ.get(
    "MODEL_WEIGHTS",
    "/home/soroush1/projects/def-kohitij/soroush1/pretrain-imagenet/inception_lalem/lightning_logs/version_1/checkpoints/epoch=358-step=126368.ckpt",
)

# Load the checkpoint onto the CPU so this cell works without a GPU.
# NOTE: torch.load unpickles the file — only load checkpoints you trust.
checkpoint = torch.load(MODEL_WEIGHTS, map_location=torch.device("cpu"))
# .eval() switches BatchNorm/Dropout to inference behaviour (it returns the
# module itself, so no repr is displayed); call .train() before fine-tuning.
inception_model = InceptionRegression(num_output_features=1, checkpoint=checkpoint).eval()