In [1]:
import os

import torch
import torch.nn as nn

In [8]:
class VGG19(nn.Module):
    """VGG-19 with batch normalization after every convolution.

    The feature extractor is sixteen 3x3 conv -> BatchNorm -> ReLU units
    interleaved with five 2x2 max-pools. It is followed by an adaptive average
    pool to a 7x7 grid and a three-layer fully connected head with dropout.

    Args:
        num_classes: Width of the final linear layer.
        init_weights: If True, re-initialize conv, batch-norm and linear
            parameters with ``_initialize_weights`` after construction.
    """

    def __init__(self, num_classes=1000, init_weights=True):
        super().__init__()
        # Layer plan: integers are conv output channels, "M" is a 2x2 max-pool.
        plan = [64, 64, "M", 128, 128, "M"]
        plan += [256] * 4 + ["M"]
        plan += [512] * 4 + ["M"]
        plan += [512] * 4 + ["M"]
        self.features = self._make_layers(plan)

        # Always hand the classifier a 7x7 spatial map, whatever the input size.
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))

        hidden = 4096
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, hidden),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(hidden, hidden),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(hidden, num_classes),
        )

        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        """Map a batched image tensor ``x`` to raw (un-normalized) class scores."""
        pooled = self.avgpool(self.features(x))
        return self.classifier(torch.flatten(pooled, start_dim=1))

    def _make_layers(self, cfg):
        """Translate a layer plan into the convolutional feature extractor.

        Each integer ``v`` in ``cfg`` appends Conv2d(3x3, padding=1) ->
        BatchNorm2d -> ReLU producing ``v`` channels. Each ``"M"`` appends a
        2x2, stride-2 max-pool. The first convolution expects 3 input channels.
        """
        modules = []
        channels = 3
        for entry in cfg:
            if entry == "M":
                modules.append(nn.MaxPool2d(kernel_size=2, stride=2))
                continue
            modules.extend(
                (
                    nn.Conv2d(channels, entry, kernel_size=3, padding=1),
                    nn.BatchNorm2d(entry),
                    nn.ReLU(inplace=True),
                )
            )
            channels = entry
        return nn.Sequential(*modules)

    def _initialize_weights(self):
        """Re-initialize parameters of every conv, batch-norm and linear submodule.

        - Conv2d: Kaiming-normal weights (fan_out, ReLU gain), zero bias if present.
        - BatchNorm2d: unit scale, zero shift.
        - Linear: weights drawn from N(0, 0.01^2), zero bias.
        """
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, mean=0.0, std=0.01)
                nn.init.zeros_(module.bias)


class VGG19Regression(VGG19):
    """VGG19 with a sigmoid-squashed regression head, optionally restored from a checkpoint.

    The final classifier layer already has ``num_output_features`` outputs,
    because the value is forwarded to ``VGG19`` as ``num_classes``.
    ``forward`` always applies a sigmoid, so every prediction lies in (0, 1).

    Args:
        num_output_features: Number of regression targets.
        checkpoint: Optional mapping (e.g. the result of ``torch.load``). It is
            either a checkpoint whose weights live under ``"state_dict"`` with
            keys prefixed by ``"vgg."`` (as in a Lightning checkpoint), or a
            plain state dict for this module.
    """

    def __init__(self, num_output_features=1, checkpoint=None):
        # VGG19 builds classifier[-1] as Linear(4096, num_output_features) and
        # initializes it, so the regression head needs no replacement here.
        super().__init__(num_classes=num_output_features, init_weights=True)
        # Squash outputs into (0, 1); assumes targets are scaled to [0, 1].
        self.sigmoid = nn.Sigmoid()

        if checkpoint is not None:
            # Lightning nests weights under "state_dict"; otherwise treat the
            # mapping itself as a state dict.
            raw_state = checkpoint.get("state_dict", checkpoint)
            # Use a local rather than an attribute so the checkpoint tensors are
            # not kept alive (and not left behind by .to(device)) with the model.
            self.load_state_dict(self._prepare_lightning_state_dict(raw_state))

    def _prepare_lightning_state_dict(self, state_dict, prefix="vgg."):
        """Strip a wrapper prefix from parameter names.

        Keys starting with ``prefix`` are kept with only that leading prefix
        removed. Keys without the prefix are dropped. If no key carries the
        prefix, ``state_dict`` is returned unchanged, on the assumption that it
        is already a plain state dict for this module.

        Args:
            state_dict: Mapping of parameter/buffer names to tensors.
            prefix: Leading name segment to strip (default ``"vgg."``).

        Returns:
            Mapping of un-prefixed names to the same tensor objects.
        """
        # Slice off only the leading prefix; str.replace would also rewrite any
        # later occurrence of the prefix inside a key.
        stripped = {
            key[len(prefix):]: value
            for key, value in state_dict.items()
            if key.startswith(prefix)
        }
        return stripped if stripped else state_dict

    def forward(self, x):
        """Run the VGG19 backbone and squash its outputs into (0, 1)."""
        return self.sigmoid(super().forward(x))

In [9]:
# Checkpoint of the trained VGG19 regressor. Set the VGG19_CHECKPOINT
# environment variable to point elsewhere instead of editing this
# machine-specific absolute path.
MODEL_WEIGHTS = os.environ.get(
    "VGG19_CHECKPOINT",
    "/home/soroush1/projects/def-kohitij/soroush1/pretrain-imagenet/vgg19_lalem/lightning_logs/version_1/checkpoints/epoch=92-step=32736.ckpt",
)

# Load the checkpoint onto the CPU so this cell runs without a GPU.
# NOTE(review): unless weights_only=True, torch.load unpickles arbitrary
# objects; only load checkpoints from trusted sources.
checkpoint = torch.load(MODEL_WEIGHTS, map_location=torch.device("cpu"))
# .eval() disables dropout and makes BatchNorm use its running statistics for
# inference; call .train() on the model before any fine-tuning.
vgg19_model = VGG19Regression(num_output_features=1, checkpoint=checkpoint).eval()