<a href="https://colab.research.google.com/github/Sibusisongwenya/WIP-Project/blob/main/magic.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
from torchvision.models import densenet121, DenseNet121_Weights
import logging
from typing import Optional
from collections import OrderedDict
from torchbnn.modules import BayesLinear


class BayesianDenseNet121_LLSVI(nn.Module):
    """
    DenseNet121 regressor whose last layer is a Bayesian linear layer.

    The torchvision DenseNet121 backbone is kept unchanged; only its
    ``classifier`` is replaced by a one-element ``nn.Sequential`` holding a
    ``BayesLinear`` that maps the backbone features to one continuous output.

    Args:
        pretrained (bool): If True, build the backbone with
            ``DenseNet121_Weights.IMAGENET1K_V1``; otherwise start from
            randomly initialised weights.
    """
    def __init__(self, pretrained: bool = True):
        super().__init__()
        # Use the ``weights=`` API (same as DenseNet121_LLDropout in this file)
        # instead of the deprecated ``pretrained=`` keyword.
        weights = DenseNet121_Weights.IMAGENET1K_V1 if pretrained else None
        self.densenet121 = densenet121(weights=weights)

        # Replace the original classifier with a Bayesian version.
        # The layer stays wrapped in nn.Sequential so its parameters are
        # registered under "densenet121.classifier.0.*".
        in_features = self.densenet121.classifier.in_features
        out_features = 1  # single regression output
        self.densenet121.classifier = nn.Sequential(
            BayesLinear(prior_mu=0, prior_sigma=0.1, in_features=in_features, out_features=out_features),
        )

    def forward(self, x: torch.Tensor, sample: Optional[bool] = False) -> torch.Tensor:
        """
        Run the input through the modified DenseNet121.

        Args:
            x (torch.Tensor): Input tensor passed straight to the backbone.
            sample (Optional[bool]): Accepted for signature parity with the
                other model in this file; it is not consulted here.

        Returns:
            torch.Tensor: Output of the backbone with the Bayesian classifier.
        """
        return self.densenet121(x)


# In the BayesianDenseNet121_LLSVI class
# Alternative two-layer Bayesian head: BayesLinear(num_features -> 128), SiLU,
# BayesLinear(128 -> 1).
# NOTE(review): as written this indented assignment sits outside any class or
# method, and `num_features` is not defined in the visible code. Presumably it
# was meant to live in `__init__` — confirm, then move it there or remove it.
        self.classifier = nn.Sequential(
            BayesLinear(prior_mu=0, prior_sigma=0.1, in_features=num_features, out_features=128), # specify prior and size
            nn.SiLU(),
            BayesLinear(prior_mu=0, prior_sigma=0.1, in_features=128, out_features=1) # specify prior and size
        )


def init_weights(self) -> None:
    """
    Re-initialise the parameters of every ``BayesLinear`` layer under ``self``.

    For each such layer the weight mean gets Kaiming-normal values, the weight
    log-sigma gets N(0, 0.1) values, and the bias mean (when the layer has
    one) is set to zero.

    The parameter names used are ``weight_mu``, ``weight_log_sigma`` and
    ``bias_mu``: the same names the checkpoint loader in this file writes for
    the Bayesian classifier. The previous ``W_mu`` / ``W_rho`` / ``bias``
    names do not match them.

    Args:
        self: Any object exposing ``modules()`` (e.g. an ``nn.Module``).
    """
    for m in self.modules():
        if isinstance(m, BayesLinear):
            nn.init.kaiming_normal_(m.weight_mu, mode='fan_out', nonlinearity='relu')
            # NOTE(review): N(0, 0.1) is carried over from the former
            # ``W_rho`` initialisation; confirm it is the intended scale for
            # a log-sigma parameter.
            nn.init.normal_(m.weight_log_sigma, 0, 0.1)

            # ``bias_mu`` is absent or None for layers without a bias term.
            bias_mu = getattr(m, "bias_mu", None)
            if bias_mu is not None:
                nn.init.constant_(bias_mu, 0)

    # NOTE(review): the two functions below are defined inside init_weights
    # and are never called from it, so they currently have no effect. They
    # read like methods that were meant to sit on a model class exposing
    # ``self.features`` and ``self.classifier`` — confirm and relocate.
    def forward(self, x: torch.Tensor, sample: bool = False) -> torch.Tensor:
        """
        Forward pass through DenseNet121 and Bayesian classifier.

        Args:
            x (torch.Tensor): Input tensor.
            sample (bool): Flag to enable sampling in BayesLinear layers.

        Returns:
            torch.Tensor: Model output.
        """
        features = self.features(x)
        out = F.relu(features, inplace=True)
        out = F.adaptive_avg_pool2d(out, (1, 1))
        out = torch.flatten(out, 1)

        for module in self.classifier:
            if isinstance(module, BayesLinear):
                # NOTE(review): assumes the layer's forward accepts a
                # ``sample`` keyword — verify against the BayesLinear in use.
                out = module(out, sample=sample)
            else:
                out = module(out)
        return out

    def kl_loss(self) -> torch.Tensor:
        """
        Compute KL divergence from all BayesLinear layers for regularization.

        Returns:
            torch.Tensor: KL divergence value (the float ``0.0`` when no
            sub-module exposes ``kl_loss``).
        """
        kl = 0.0
        for m in self.modules():
            # ``modules()`` yields ``self`` as well; skip it so a model that
            # owns this method does not call itself recursively.
            if m is not self and hasattr(m, "kl_loss"):
                kl += m.kl_loss()
        return kl


class DenseNet121_LLDropout(nn.Module):
    """
    DenseNet121 model with Last Layer Dropout (MC-Dropout) for regression.

    A dropout layer is appended to the backbone's ``features`` module and the
    classifier is replaced by ``Dropout -> Linear(num_features, 1)``, so the
    model outputs a single continuous value. Calling ``forward`` with
    ``sample=True`` keeps dropout active for Monte-Carlo sampling.

    Args:
        pretrained (bool): If True, load ``DenseNet121_Weights.IMAGENET1K_V1``.
        dropout_prob (float): Probability used for both dropout layers.
    """
    def __init__(self, pretrained: bool = True, dropout_prob: float = 0.5):
        super().__init__()
        self.dropout_prob = dropout_prob
        weights = DenseNet121_Weights.IMAGENET1K_V1 if pretrained else None
        self.densenet = densenet121(weights=weights)

        # Add dropout to the features module and classifier
        self.densenet.features.add_module("dropout1", nn.Dropout(p=dropout_prob))
        num_features = self.densenet.classifier.in_features
        # The linear head sits at index 1, i.e. "densenet.classifier.1.*".
        self.densenet.classifier = nn.Sequential(
            nn.Dropout(dropout_prob),
            nn.Linear(num_features, 1)
        )

    def forward(self, x: torch.Tensor, sample: bool = False) -> torch.Tensor:
        """
        Forward pass with optional MC-Dropout sampling.

        Args:
            x (torch.Tensor): Input tensor.
            sample (bool): If True, every ``nn.Dropout`` under this module is
                switched to train mode for the duration of the call.

        Returns:
            torch.Tensor: Model output.
        """
        if not sample:
            return self.densenet(x)

        # Remember each dropout layer's mode, then force dropout on.
        dropout_modules = [m for m in self.modules() if isinstance(m, nn.Dropout)]
        previous_modes = [m.training for m in dropout_modules]
        for m in dropout_modules:
            m.train()

        try:
            return self.densenet(x)
        finally:
            # Restore the saved modes even when the backbone raises, so a
            # failed sampling call cannot leave dropout active afterwards.
            for m, was_training in zip(dropout_modules, previous_modes):
                m.train(was_training)

# Root logger setup: INFO and above, each record prefixed with a timestamp
# and its level name.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

def _remap_bayesian_state_dict(state_dict) -> "OrderedDict":
    """Rename checkpoint keys to the names BayesianDenseNet121_LLSVI registers."""
    prefix = "densenet121."
    head = prefix + "classifier.0."
    remapped = OrderedDict()
    for old_key, value in state_dict.items():
        if old_key == "classifier.weight":
            # Deterministic weight becomes the Bayesian mean; log sigma starts
            # at a small constant (-5.0).
            remapped[head + "weight_mu"] = value
            remapped[head + "weight_log_sigma"] = torch.full_like(value, -5.0)
        elif old_key == "classifier.bias":
            remapped[head + "bias_mu"] = value
            remapped[head + "bias_log_sigma"] = torch.full_like(value, -5.0)
        elif old_key.startswith(prefix):
            # Already in the wrapped model's namespace.
            remapped[old_key] = value
        else:
            # "features.*", any other "classifier.*" key, and everything else
            # just gain the "densenet121." prefix.
            remapped[prefix + old_key] = value
    return remapped


def load_bayesian_model(checkpoint_path: str, device: torch.device, pretrained: bool = True) -> torch.nn.Module:
    """
    Load the Bayesian DenseNet121_LLSVI regression model from a checkpoint.

    Checkpoint keys are renamed to the wrapped model's namespace. A plain
    ``classifier.weight`` / ``classifier.bias`` pair is mapped onto the
    Bayesian layer's ``*_mu`` parameters, with ``*_log_sigma`` filled with
    -5.0. Loading is non-strict; missing or unexpected keys are reported
    through ``logging`` as warnings rather than printed.

    Args:
        checkpoint_path (str): Path to the model checkpoint file.
        device (torch.device): The device to load the model onto.
        pretrained (bool): Whether to initialize with pretrained weights.

    Returns:
        torch.nn.Module: The loaded Bayesian model.

    Raises:
        FileNotFoundError: If the checkpoint file does not exist.
        Exception: Any other failure is logged and re-raised unchanged.
    """
    try:
        # Initialize the model without calling init_weights()
        model = BayesianDenseNet121_LLSVI(pretrained=pretrained)
        logging.info("Initialized BayesianDenseNet121_LLSVI model.")

        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # from a trusted source (or pass weights_only=True where the
        # checkpoint format allows it).
        checkpoint = torch.load(checkpoint_path, map_location=device)
        # Extract state_dict (if the checkpoint is wrapped in a dict with "state_dict")
        state_dict = checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint

        new_state_dict = _remap_bayesian_state_dict(state_dict)

        # Non-strict load: report what did not line up instead of raising.
        missing_keys, unexpected_keys = model.load_state_dict(new_state_dict, strict=False)
        if missing_keys:
            logging.warning("Missing keys: %s", missing_keys)
        if unexpected_keys:
            logging.warning("Unexpected keys: %s", unexpected_keys)
        logging.info("State dict loaded successfully.")

        # Move model to device
        model.to(device)
        return model

    except FileNotFoundError:
        logging.error("Checkpoint file not found: %s", checkpoint_path)
        raise
    except Exception as e:
        logging.error("Failed to load Bayesian model: %s", e)
        raise
def _remap_mc_dropout_state_dict(state_dict) -> "OrderedDict":
    """Rename checkpoint keys to the names DenseNet121_LLDropout registers."""
    remapped = OrderedDict()
    for old_key, value in state_dict.items():
        if old_key.startswith("features."):
            new_key = "densenet." + old_key
        elif old_key.startswith("classifier."):
            suffix = old_key[len("classifier."):]
            if suffix.split(".", 1)[0].isdigit():
                # Already indexed (e.g. "classifier.1.weight"): only the
                # wrapper prefix is missing. Inserting another ".1" would
                # produce a key the model does not have.
                new_key = "densenet." + old_key
            else:
                # Plain head, e.g. "classifier.weight" -> "densenet.classifier.1.weight"
                new_key = "densenet.classifier.1." + suffix
        else:
            # Otherwise, keep the key as-is
            new_key = old_key
        remapped[new_key] = value
    return remapped


def load_mc_dropout_model(
    checkpoint_path: str,
    device: torch.device,
    pretrained: bool = True,
    dropout_prob: float = 0.5
) -> torch.nn.Module:
    """
    Load the DenseNet121_LLDropout regression model from a checkpoint.

    ``features.*`` keys gain the ``densenet.`` prefix and plain
    ``classifier.*`` keys are mapped onto the linear head at
    ``densenet.classifier.1``. Classifier keys that already carry a
    Sequential index only gain the ``densenet.`` prefix. Loading is
    non-strict; missing or unexpected keys are reported through ``logging``
    as warnings rather than printed.

    Args:
        checkpoint_path (str): Path to the model checkpoint file.
        device (torch.device): The device to load the model onto.
        pretrained (bool): Whether to initialize with pretrained weights.
        dropout_prob (float): Dropout probability for MC-Dropout layers.

    Returns:
        torch.nn.Module: The loaded MC-Dropout model.

    Raises:
        FileNotFoundError: If the checkpoint file does not exist.
        Exception: Any other failure is logged and re-raised unchanged.
    """

    try:
        # Initialize the model
        model = DenseNet121_LLDropout(pretrained=pretrained, dropout_prob=dropout_prob)
        logging.info("Initialized DenseNet121_LLDropout model with dropout probability: %.2f", dropout_prob)

        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # from a trusted source (or pass weights_only=True where the
        # checkpoint format allows it).
        checkpoint = torch.load(checkpoint_path, map_location=device)

        # Unwrap {"state_dict": ...} checkpoints; otherwise use it directly.
        state_dict = checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint

        new_state_dict = _remap_mc_dropout_state_dict(state_dict)

        # Non-strict load: report what did not line up instead of raising.
        missing_keys, unexpected_keys = model.load_state_dict(new_state_dict, strict=False)
        if missing_keys:
            logging.warning("Missing keys: %s", missing_keys)
        if unexpected_keys:
            logging.warning("Unexpected keys: %s", unexpected_keys)

        # If you want to enforce strict loading after verifying the renamed keys:
        # model.load_state_dict(new_state_dict, strict=True)
        logging.info("State dict loaded successfully.")

        model.to(device)
        return model

    except FileNotFoundError:
        logging.error("Checkpoint file not found: %s", checkpoint_path)
        raise
    except Exception as e:
        logging.error("Failed to load MC-Dropout model: %s", e)
        raise