### Model Architecture

In [4]:
import torch
import torch.nn as nn
from efficientnet_pytorch import EfficientNet

In [5]:
class EncoderEfficientNetB5(nn.Module):
    """EfficientNet-B5 backbone that returns one feature map per spatial resolution.

    The backbone's stem and MBConv blocks are run in order. Whenever a block
    reduces the spatial size, the tensor that entered that block (i.e. the last
    output at the previous, finer resolution) is recorded; the output of the
    final block is recorded as the coarsest level. The result is a list ordered
    from highest to lowest resolution, which is the layout a top-down FPN
    decoder needs (each level is a down-sampled version of the previous one).

    The backbone's ``_conv_head`` / pooling / classifier layers are not applied.

    Args:
        pretrained: if True, load weights via ``EfficientNet.from_pretrained``;
            otherwise build an untrained network via ``EfficientNet.from_name``.
    """

    def __init__(self, pretrained=True):
        super(EncoderEfficientNetB5, self).__init__()
        # Load EfficientNet-B5
        if pretrained:
            self.encoder = EfficientNet.from_pretrained('efficientnet-b5')
        else:
            self.encoder = EfficientNet.from_name('efficientnet-b5')
        # Alias of the backbone's block list (same module objects, not a copy)
        self.blocks = self.encoder._blocks

    def forward(self, x):
        """Return the list of stage features for input batch ``x``.

        Returns:
            list of tensors, finest resolution first, coarsest last.
        """
        backbone = self.encoder

        # Stem: conv -> batch norm -> swish. The activation is part of the
        # backbone's own feature extraction; skipping it feeds the first block
        # un-activated values the (pretrained) weights were not trained on.
        x = backbone._bn0(backbone._conv_stem(x))
        stem_activation = getattr(backbone, '_swish', None)
        if stem_activation is not None:
            x = stem_activation(x)
        else:
            # Fallback for backbones without a `_swish` module
            x = x * torch.sigmoid(x)

        # Drop-connect rate, scaled linearly with block depth exactly as the
        # backbone's own `extract_features` does; blocks only apply it in
        # training mode. Missing/zero rate disables it.
        base_rate = getattr(getattr(backbone, '_global_params', None), 'drop_connect_rate', None)
        num_blocks = len(self.blocks)

        features = []
        for idx, block in enumerate(self.blocks):
            rate = base_rate
            if rate:
                rate = rate * float(idx) / num_blocks
            out = block(x, drop_connect_rate=rate)
            # Resolution dropped: `x` was the last tensor of the finer stage
            if out.size(2) < x.size(2):
                features.append(x)
            x = out
        # Output of the final block is the coarsest pyramid level
        features.append(x)
        return features


In [6]:
class DecoderFPN(nn.Module):
    """Top-down feature-pyramid decoder.

    For every encoder level a 1x1 "lateral" convolution projects the feature
    to ``decoder_channels`` and a 3x3 convolution smooths the merged result.
    Starting from the last (coarsest) feature, each step upsamples the running
    tensor to the next finer level's size, adds that level's lateral
    projection, and applies the level's 3x3 convolution.

    Note: ``up_convs[-1]`` is created but never used by ``forward`` (the
    coarsest level is only projected). It is kept so the parameter layout and
    ``state_dict`` keys stay unchanged.

    Args:
        encoder_channels: channel count of each encoder feature, in the same
            order as the features passed to ``forward`` (finest first).
        decoder_channels: channel count of every decoder output.
    """

    def __init__(self, encoder_channels, decoder_channels):
        super(DecoderFPN, self).__init__()
        self.up_convs = nn.ModuleList()
        self.lat_convs = nn.ModuleList()

        # Build lateral and smoothing convolutions, one pair per encoder level
        for in_channels in encoder_channels:
            self.lat_convs.append(nn.Conv2d(in_channels, decoder_channels, kernel_size=1))
            self.up_convs.append(nn.Conv2d(decoder_channels, decoder_channels, kernel_size=3, padding=1))

    def forward(self, encoder_features):
        """Merge the encoder features top-down.

        Args:
            encoder_features: sequence of tensors, finest resolution first,
                one per entry of ``encoder_channels``.

        Returns:
            list of tensors with ``decoder_channels`` channels, in the same
            order (and with the same spatial sizes) as ``encoder_features``.

        Raises:
            ValueError: if the number of features differs from the number of
                levels this decoder was built for (or is zero).
        """
        num_levels = len(self.lat_convs)
        if num_levels == 0 or len(encoder_features) != num_levels:
            raise ValueError(
                "DecoderFPN was built for {} level(s) but received {} feature map(s)".format(
                    num_levels, len(encoder_features)
                )
            )

        # Start from the deepest feature and go upwards
        x = self.lat_convs[-1](encoder_features[-1])
        outputs = [x]

        for i in range(num_levels - 2, -1, -1):
            lateral = self.lat_convs[i](encoder_features[i])
            # Resize to the lateral's exact size instead of a fixed x2 factor:
            # identical when the finer level is exactly twice as large, and
            # still valid when it is not (e.g. odd input sizes).
            x = nn.functional.interpolate(
                x, size=lateral.shape[-2:], mode='bilinear', align_corners=False
            )
            x = self.up_convs[i](x + lateral)
            outputs.append(x)

        return outputs[::-1]  # Reverse so index 0 is the finest level again


In [7]:
class EfficientNetFPN(nn.Module):
    """Encoder + FPN decoder + 1x1 prediction head.

    Args:
        num_classes: number of output channels of the prediction head.
        encoder_channels: channel count of each feature returned by the
            encoder, in the order the encoder returns them.
        decoder_channels: channel width used throughout the decoder.
        pretrained: forwarded to ``EncoderEfficientNetB5``; set to False to
            build the model without loading pretrained weights (useful when
            restoring a checkpoint or working offline).
    """

    def __init__(self, num_classes, encoder_channels, decoder_channels=256, pretrained=True):
        super(EfficientNetFPN, self).__init__()
        self.encoder = EncoderEfficientNetB5(pretrained=pretrained)
        self.decoder = DecoderFPN(encoder_channels, decoder_channels)
        self.final_conv = nn.Conv2d(decoder_channels, num_classes, kernel_size=1)  # Final 1x1 convolution

    def forward(self, x):
        """Return raw logits computed from the first decoder output.

        The logits have the spatial size of ``decoder_features[0]``; no
        upsampling to the input size is applied here.
        """
        encoder_features = self.encoder(x)
        decoder_features = self.decoder(encoder_features)
        # Index 0 is the first level the decoder returns
        return self.final_conv(decoder_features[0])


In [8]:
# Encoder channels for EfficientNet-B5, finest pyramid level first.
# These are the output widths of the last MBConv block at each spatial
# resolution (strides 2, 4, 8, 16, 32) per the efficientnet_pytorch B5
# configuration. A 2048-channel tensor only exists after the backbone's
# `_conv_head`, which EncoderEfficientNetB5 never applies, so it cannot be a
# pyramid level. The list must match the encoder's returned features
# one-to-one; verify with a dummy forward pass, e.g.
#   [f.shape[1] for f in model.encoder(torch.zeros(1, 3, 64, 64))]
encoder_channels = [24, 40, 64, 176, 512]

# Initialize the model
model = EfficientNetFPN(num_classes=2, encoder_channels=encoder_channels)


Downloading: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b5-b6417697.pth" to C:\Users\PRAGNA/.cache\torch\hub\checkpoints\efficientnet-b5-b6417697.pth
100.0%


Loaded pretrained weights for efficientnet-b5


### Loss Functions

In [9]:
import torch
import torch.nn as nn

class BCEWithLogitsLossCustom(nn.Module):
    """Module wrapper around ``nn.BCEWithLogitsLoss`` with its default settings."""

    def __init__(self):
        super().__init__()
        # The wrapped criterion applies the sigmoid itself, so it expects raw logits
        self.bce_loss = nn.BCEWithLogitsLoss()

    def forward(self, inputs, targets):
        """Return the wrapped criterion's loss for logits ``inputs`` against ``targets``."""
        criterion = self.bce_loss
        loss = criterion(inputs, targets)
        return loss

# Example usage:
# loss_fn = BCEWithLogitsLossCustom()
# loss = loss_fn(output, target)


In [10]:
class FocalLoss(nn.Module):
    """Binary focal loss computed from raw logits.

    Per element: ``alpha * (1 - pt) ** gamma * bce`` where ``bce`` is the
    element-wise binary cross-entropy with logits and ``pt = exp(-bce)``.

    Args:
        alpha: constant scale applied to every element (not a per-class weight).
        gamma: focusing exponent; larger values down-weight easy elements more.
        reduction: ``'mean'`` or ``'sum'``; any other value returns the
            flattened, unreduced per-element loss.
    """

    def __init__(self, alpha=1, gamma=2, reduction='mean'):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        """Return the focal loss between logits ``inputs`` and ``targets``."""
        # Flatten with reshape, not view: view raises on non-contiguous
        # tensors (e.g. after permute/transpose or strided slicing).
        inputs = inputs.reshape(-1)
        targets = targets.reshape(-1)

        # Element-wise BCE on logits
        bce_loss = nn.BCEWithLogitsLoss(reduction='none')(inputs, targets)

        # exp(-bce) equals the predicted probability of the true class when
        # targets are hard 0/1 labels
        pt = torch.exp(-bce_loss)
        focal_loss = self.alpha * (1 - pt) ** self.gamma * bce_loss

        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        else:
            return focal_loss

# Example usage:
# focal_loss = FocalLoss(alpha=0.25, gamma=2)
# loss = focal_loss(output, target)


In [11]:
class DiceLoss(nn.Module):
    """Soft Dice loss computed from raw logits over the whole flattened batch.

    Args:
        smooth: constant added to numerator and denominator so the ratio stays
            defined when both prediction and target sums are zero.
    """

    def __init__(self, smooth=1e-6):
        super(DiceLoss, self).__init__()
        self.smooth = smooth

    def forward(self, inputs, targets):
        """Return ``1 - dice`` for logits ``inputs`` against ``targets``."""
        # Sigmoid turns logits into probabilities. Flatten with reshape, not
        # view: view raises on non-contiguous tensors.
        probs = torch.sigmoid(inputs).reshape(-1)
        targets = targets.reshape(-1)

        intersection = (probs * targets).sum()
        # Sum of both "sizes" (the Dice denominator), not a set union
        cardinality = probs.sum() + targets.sum()

        dice_score = (2. * intersection + self.smooth) / (cardinality + self.smooth)
        return 1 - dice_score  # To minimize the Dice Loss

# Example usage:
# dice_loss = DiceLoss()
# loss = dice_loss(output, target)


In [12]:
class DiceBCELoss(nn.Module):
    """Binary cross-entropy plus a weighted Dice loss, both on raw logits.

    Args:
        alpha: weight applied to the Dice term (the BCE term has weight 1).
        smooth: smoothing constant forwarded to ``DiceLoss``.
    """

    def __init__(self, alpha=0.5, smooth=1e-6):
        super(DiceBCELoss, self).__init__()
        # Store the weight on the instance: forward previously read a bare
        # `alpha`, which is undefined there (or silently a notebook global).
        self.alpha = alpha
        self.bce_loss = nn.BCEWithLogitsLoss()
        self.dice_loss = DiceLoss(smooth)

    def forward(self, inputs, targets):
        """Return ``bce + alpha * dice`` for logits ``inputs`` against ``targets``."""
        bce = self.bce_loss(inputs, targets)
        dice = self.dice_loss(inputs, targets)
        return bce + self.alpha * dice

# Example usage:
# loss_fn = DiceBCELoss(alpha=0.5)
# loss = loss_fn(output, target)


In [13]:
class DiceFocalLoss(nn.Module):
    """Unweighted sum of a focal loss and a Dice loss on the same logits.

    Args:
        alpha: forwarded to ``FocalLoss`` as its ``alpha``.
        gamma: forwarded to ``FocalLoss`` as its ``gamma``.
        smooth: forwarded to ``DiceLoss`` as its smoothing constant.
    """

    def __init__(self, alpha=0.25, gamma=2, smooth=1e-6):
        super().__init__()
        focal_term = FocalLoss(alpha=alpha, gamma=gamma)
        dice_term = DiceLoss(smooth)
        self.focal_loss = focal_term
        self.dice_loss = dice_term

    def forward(self, inputs, targets):
        """Return ``focal + dice``; both terms carry weight 1."""
        total = self.focal_loss(inputs, targets)
        total = total + self.dice_loss(inputs, targets)
        return total

# Example usage:
# loss_fn = DiceFocalLoss(alpha=0.25, gamma=2)
# loss = loss_fn(output, target)


In [14]:
class BCEFocalLoss(nn.Module):
    """Focal loss built on an element-wise BCE-with-logits criterion.

    Per element: ``alpha * (1 - pt) ** gamma * bce`` with ``pt = exp(-bce)``.

    NOTE(review): despite the name, the returned value contains only the
    focal-modulated term; the plain BCE value is not added to it. Confirm
    whether a ``bce + focal`` combination was intended before relying on it.

    Args:
        alpha: constant scale applied to every element.
        gamma: focusing exponent.
        reduction: ``'mean'`` or ``'sum'``; any other value returns the
            flattened, unreduced per-element loss.
    """

    def __init__(self, alpha=1, gamma=2, reduction='mean'):
        super(BCEFocalLoss, self).__init__()
        # Unreduced so the focal factor can be applied per element
        self.bce_loss = nn.BCEWithLogitsLoss(reduction='none')
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        """Return the focal loss between logits ``inputs`` and ``targets``."""
        # Flatten with reshape, not view: view raises on non-contiguous
        # tensors (e.g. after permute/transpose or strided slicing).
        inputs = inputs.reshape(-1)
        targets = targets.reshape(-1)

        # Element-wise BCE on logits
        bce_loss = self.bce_loss(inputs, targets)

        # exp(-bce) equals the predicted probability of the true class when
        # targets are hard 0/1 labels
        pt = torch.exp(-bce_loss)
        focal_loss = self.alpha * (1 - pt) ** self.gamma * bce_loss

        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        else:
            return focal_loss

# Example usage:
# combined_loss = BCEFocalLoss(alpha=0.25, gamma=2)
# loss = combined_loss(output, target)
