In [1]:
import os

import torch
import torch.nn as nn
from torchvision.models import mobilenet_v2

In [6]:
class MobileNetV2Regression(nn.Module):
    """MobileNetV2 backbone with a linear regression head.

    The torchvision MobileNetV2 is built without pretrained weights
    (``weights=None``); trained weights can optionally be restored from a
    PyTorch Lightning checkpoint.

    Args:
        num_output_features: Number of regression outputs produced by the head.
        checkpoint: Optional Lightning checkpoint dict (e.g. as returned by
            ``torch.load``). Entries of ``checkpoint["state_dict"]`` whose keys
            start with ``"model.mobilenet."`` are loaded (strictly) into the
            backbone.
        apply_sigmoid: If True (default, the original behavior), outputs are
            squashed into (0, 1) with a sigmoid. Set to False for unbounded
            regression targets.
    """

    def __init__(self, num_output_features=1, checkpoint=None, apply_sigmoid=True):
        super().__init__()
        # Randomly initialised MobileNetV2 (weights=None: no pretrained weights).
        self.mobilenet = mobilenet_v2(weights=None)

        # Size the new head from the stock classifier's final Linear layer instead
        # of hard-coding 1280, so it stays correct if the backbone config changes.
        in_features = self.mobilenet.classifier[-1].in_features

        # Replace the classification head with a regression head.
        # torchvision's MobileNetV2 already pools and flattens the features before
        # calling its classifier, so this nn.Flatten is effectively a no-op. It is
        # kept so the layer indices -- and therefore the checkpoint keys such as
        # "classifier.2.weight" -- stay unchanged.
        self.mobilenet.classifier = nn.Sequential(
            nn.Flatten(start_dim=1),
            nn.Dropout(p=0.2),
            nn.Linear(in_features=in_features, out_features=num_output_features, bias=True),
        )

        # Bounds outputs to (0, 1); only applied in forward() when apply_sigmoid is True.
        self.sigmoid = nn.Sigmoid()
        self.apply_sigmoid = apply_sigmoid

        if checkpoint is not None:
            # A Lightning checkpoint's "state_dict" holds the LightningModule's
            # weights; only the backbone entries are extracted and loaded here.
            self._checkpoint = self._prepare_lightning_state_dict(
                checkpoint["state_dict"]
            )
            self.mobilenet.load_state_dict(self._checkpoint)

    def _prepare_lightning_state_dict(self, state_dict, prefix="model.mobilenet."):
        """Extract the backbone weights from a Lightning ``state_dict``.

        Args:
            state_dict: Mapping of parameter names to tensors, as stored by Lightning.
            prefix: Leading key prefix that identifies backbone entries.

        Returns:
            A new dict containing only the entries whose key starts with
            ``prefix``, with that leading prefix removed (any later occurrence of
            the same text inside the key is preserved).
        """
        return {
            key[len(prefix):]: value
            for key, value in state_dict.items()
            if key.startswith(prefix)
        }

    def forward(self, x):
        """Run the backbone and optionally apply the sigmoid.

        Args:
            x: Input image batch; presumably (batch, 3, H, W) -- TODO confirm with caller.

        Returns:
            Tensor of shape (batch, num_output_features), in (0, 1) when
            ``apply_sigmoid`` is True.
        """
        outputs = self.mobilenet(x)
        # torchvision's MobileNetV2 returns a plain tensor. The ``.logits`` fallback
        # only matters for backbones that wrap their output in a namedtuple
        # (e.g. GoogLeNetOutputs).
        x = outputs if isinstance(outputs, torch.Tensor) else outputs.logits

        if self.apply_sigmoid:
            x = self.sigmoid(x)
        return x

In [7]:
# Trained Lightning checkpoint. Set the MOBILENET_CKPT environment variable to
# point at a different file instead of editing this machine-specific default.
MODEL_WEIGHTS = os.environ.get(
    "MOBILENET_CKPT",
    "/home/soroush1/projects/def-kohitij/soroush1/pretrain-imagenet/mobilenetv2_lalem/lightning_logs/version_1/checkpoints/epoch=499-step=176000.ckpt",
)

# Load the checkpoint onto the CPU so no GPU is required.
# torch.load unpickles the file -- only load checkpoints you trust.
# NOTE(review): newer torch releases default torch.load to weights_only=True, which
# may reject extra objects Lightning stores in checkpoints -- confirm with your version.
checkpoint = torch.load(MODEL_WEIGHTS, map_location=torch.device("cpu"))
mobilenet_model = MobileNetV2Regression(num_output_features=1, checkpoint=checkpoint)