In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import torchsummary
import numpy as np

class LightweightMedicalCNN(nn.Module):
    """
    Small three-stage CNN classifier with a spatial-attention gate.

    Layout:
      * three stages of (3x3 conv -> BatchNorm -> ReLU -> 2x2 max-pool),
        widening in_channels -> 32 -> 64 -> 128;
      * a 7x7 conv + sigmoid producing a one-channel map that multiplies
        the 128-channel features;
      * a 16-channel 3x3 conv + ReLU, adaptively average-pooled to 4x4;
      * an MLP head (256 -> 256 -> num_classes) with dropout 0.5.

    Args:
        num_classes: number of output logits.
        in_channels: number of channels of the input image (default 1).
    """

    def __init__(self, num_classes, in_channels=1):
        super().__init__()

        self.in_channels = in_channels
        self.num_classes = num_classes

        def conv_stage(c_in, c_out):
            # conv -> BN -> ReLU, then halve the spatial resolution
            return nn.Sequential(
                nn.Conv2d(c_in, c_out, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=2),
            )

        self.conv1 = conv_stage(in_channels, 32)
        self.conv2 = conv_stage(32, 64)
        self.conv3 = conv_stage(64, 128)

        # One attention weight in (0, 1) per spatial position
        self.spatial_attention = nn.Sequential(
            nn.Conv2d(128, 1, kernel_size=7, padding=3),
            nn.Sigmoid(),
        )

        # Channel reduction followed by a fixed 4x4 grid, independent of input size
        self.caps_layer = nn.Sequential(
            nn.Conv2d(128, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((4, 4)),
        )

        pooled_channels, pooled_h, pooled_w = 16, 4, 4
        self._to_linear = pooled_channels * pooled_h * pooled_w

        self.classifier = nn.Sequential(
            nn.Linear(self._to_linear, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes),
        )

        self._initialize_weights()

    def _initialize_weights(self):
        """Kaiming-normal convs, unit/zero BatchNorm affine, N(0, 0.01) linears; all biases zero."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, mean=0.0, std=0.01)
                nn.init.zeros_(module.bias)

    def forward(self, x):
        """Map a (batch, in_channels, H, W) image tensor to (batch, num_classes) logits."""
        features = self.conv3(self.conv2(self.conv1(x)))

        # Gate every channel with the shared per-pixel attention map
        gate = self.spatial_attention(features)
        features = features * gate

        pooled = self.caps_layer(features)
        flat = pooled.view(pooled.size(0), -1)
        return self.classifier(flat)

def print_model_summary(model, input_size=(1, 256, 256)):
    """
    Print a model's architecture, parameter counts, estimated size, and the
    output shape(s) of a single dummy forward pass.

    The dummy forward pass runs under ``torch.no_grad()``. Every buffer, for
    example BatchNorm running statistics, is restored afterwards. As a result,
    summarising a model in training mode does not shift its statistics
    towards random noise. Errors raised by the forward pass are printed, not
    re-raised.

    Args:
        model: the ``nn.Module`` to summarise.
        input_size: shape of one input sample without the batch dimension,
            e.g. ``(channels, height, width)``.
    """
    # Prepend a batch dimension of 1 for the dummy input
    batch_size = 1
    input_shape = (batch_size, *input_size)
    dummy_input = torch.randn(input_shape)

    print("\nModel Architecture:")
    print(model)

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    print(f"\nTotal Parameters: {total_params:,}")
    print(f"Trainable Parameters: {trainable_params:,}")

    # Use each parameter's real element size instead of assuming 4-byte floats
    param_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
    model_size_mb = param_bytes / (1024 * 1024)
    print(f"Estimated Model Size: {model_size_mb:.2f} MB")

    # A training-mode forward pass updates BatchNorm running stats; snapshot
    # all buffers so the summary leaves the model unchanged.
    saved_buffers = [buf.detach().clone() for buf in model.buffers()]
    try:
        with torch.no_grad():
            output = model(dummy_input)
        print(f"\nInput Shape: {input_shape}")
        if isinstance(output, tuple):
            print(f"Output Shapes: {[o.shape for o in output]}")
        else:
            print(f"Output Shape: {output.shape}")
        print("\nModel summary test passed successfully!")
    except Exception as e:
        print(f"\nError during forward pass: {str(e)}")
    finally:
        with torch.no_grad():
            for buf, saved in zip(model.buffers(), saved_buffers):
                buf.copy_(saved)

# Example usage
def test_model(num_classes=5):
    """
    Build a single-channel LightweightMedicalCNN, print its summary and return it.

    Args:
        num_classes: number of output classes of the model.

    Returns:
        The freshly constructed model.
    """
    net = LightweightMedicalCNN(num_classes=num_classes, in_channels=1)
    print_model_summary(net)
    return net

if __name__ == "__main__":
    # Build the 5-class lightweight model and print its summary
    model = test_model(5)


Model Architecture:
LightweightMedicalCNN(
  (conv1): Sequential(
    (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (conv2): Sequential(
    (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (conv3): Sequential(
    (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (spatial_attention): Sequential(
    (0): Conv2d(128, 1, kernel_size=(7, 7), str

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class SquashFunction(nn.Module):
    """
    Capsule "squash" non-linearity.

    Keeps each vector's direction along ``dim`` and rescales its length
    ||s|| to (approximately) ||s||^2 / (1 + ||s||^2), which lies in [0, 1).
    The 1e-8 term under the square root avoids dividing by zero for
    all-zero vectors, which map to zero.
    """
    def forward(self, x, dim=-1):
        norm_sq = x.pow(2).sum(dim=dim, keepdim=True)
        gain = norm_sq / (1 + norm_sq)
        length = torch.sqrt(norm_sq + 1e-8)
        return gain * x / length

class PrimaryCapsules(nn.Module):
    """
    Convolutional primary-capsule layer.

    A single convolution produces ``out_channels * dim_caps`` feature maps.
    These are regrouped so that every spatial position of every one of the
    ``out_channels`` capsule types becomes one ``dim_caps``-dimensional
    vector. The vectors are then squashed.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: number of capsule types (capsule channels).
        dim_caps: length of each capsule vector.
        kernel_size: convolution kernel size.
        stride: convolution stride.
        padding: convolution padding. Defaults to 1, the value this layer
            previously hard-coded; pass ``kernel_size // 2`` for other kernels
            to keep 'same'-style padding.

    Output shape: (batch, out_channels * H_out * W_out, dim_caps).
    """
    def __init__(self, in_channels, out_channels, dim_caps, kernel_size, stride, padding=1):
        super(PrimaryCapsules, self).__init__()
        self.dim_caps = dim_caps
        self.out_channels = out_channels
        self.conv = nn.Conv2d(in_channels, out_channels * dim_caps, kernel_size, stride, padding=padding)
        self.squash = SquashFunction()

    def forward(self, x):
        outputs = self.conv(x)
        batch, _, height, width = outputs.shape

        # Conv channel index = capsule_type * dim_caps + component; split it apart
        outputs = outputs.view(batch, self.out_channels, self.dim_caps, height, width)
        # Move the component axis last: (batch, type, H, W, dim_caps)
        outputs = outputs.permute(0, 1, 3, 4, 2).contiguous()
        # Flatten capsule types and grid positions into one capsule axis
        outputs = outputs.view(batch, -1, self.dim_caps)

        return self.squash(outputs)

class CapsuleLayer(nn.Module):
    """
    Fully connected capsule layer with iterative routing by agreement.

    Each of ``num_caps_in`` input capsules (length ``dim_caps_in``) predicts
    every one of ``num_caps_out`` output capsules (length ``dim_caps_out``)
    through its own learned matrix in ``W``. The predictions are combined
    with coupling coefficients (softmax over output capsules of the routing
    logits ``b``). After each round except the last, ``b`` grows by the dot
    product between each prediction and the current output.

    Args:
        num_caps_in, num_caps_out: number of input / output capsules.
        dim_caps_in, dim_caps_out: capsule vector lengths.
        num_iterations: routing rounds; must be >= 1.

    Input ``u``: (batch, num_caps_in, dim_caps_in).
    Output: (batch, num_caps_out, dim_caps_out), squashed along the last dim.

    Raises:
        ValueError: if ``num_iterations`` < 1 (forward would have no output).
    """
    def __init__(self, num_caps_in, num_caps_out, dim_caps_in, dim_caps_out, num_iterations=3):
        super(CapsuleLayer, self).__init__()
        if num_iterations < 1:
            # forward() defines its output inside the routing loop
            raise ValueError(f"num_iterations must be >= 1, got {num_iterations}")
        self.num_iterations = num_iterations
        self.num_caps_in = num_caps_in
        self.num_caps_out = num_caps_out

        # One (dim_caps_out x dim_caps_in) matrix per (input, output) capsule pair;
        # the leading singleton axis broadcasts over the batch in matmul.
        self.W = nn.Parameter(torch.randn(1, num_caps_in, num_caps_out, dim_caps_out, dim_caps_in))
        self.squash = SquashFunction()

    def forward(self, u):
        batch_size = u.size(0)
        # (B, Nin, Din) -> (B, Nin, 1, Din, 1) so the matmul broadcasts over Nout
        u = u.unsqueeze(2).unsqueeze(4)
        # (1, Nin, Nout, Dout, Din) @ (B, Nin, 1, Din, 1) -> (B, Nin, Nout, Dout)
        u_hat = torch.matmul(self.W, u).squeeze(-1)

        # Routing logits, allocated directly on the predictions' device and dtype
        b = torch.zeros(batch_size, self.num_caps_in, self.num_caps_out,
                        device=u_hat.device, dtype=u_hat.dtype)

        for i in range(self.num_iterations):
            # Coupling coefficients: each input capsule distributes over output capsules
            c = F.softmax(b, dim=2).unsqueeze(3)
            s = (c * u_hat).sum(dim=1)  # (B, Nout, Dout)
            v = self.squash(s)
            if i < self.num_iterations - 1:
                # Agreement between predictions and current outputs
                b = b + (u_hat * v.unsqueeze(1)).sum(dim=-1)

        return v

class DenseBlock(nn.Module):
    """
    DenseNet-style block.

    Layer k receives the concatenation of the block input and the outputs of
    layers 0..k-1, and emits ``growth_rate`` new channels through
    BatchNorm -> ReLU -> 3x3 conv (padding 1) -> Dropout(0.2). The block
    returns the full concatenation, which has
    ``in_channels + num_layers * growth_rate`` channels and the input's
    spatial size.
    """
    def __init__(self, in_channels, num_layers, growth_rate=12):
        super().__init__()
        self.layers = nn.ModuleList()

        width = in_channels  # channels seen by the next layer
        for _ in range(num_layers):
            self.layers.append(nn.Sequential(
                nn.BatchNorm2d(width),
                nn.ReLU(inplace=True),
                nn.Conv2d(width, growth_rate, kernel_size=3, padding=1),
                nn.Dropout(0.2),
            ))
            width += growth_rate

    def forward(self, x):
        stacked = torch.cat([x], 1)
        for layer in self.layers:
            # Append this layer's new channels to everything produced so far
            stacked = torch.cat([stacked, layer(stacked)], 1)
        return stacked

class EnhancedMedicalCapsCNN(nn.Module):
    """
    Hybrid CNN + capsule-network image classifier.

    Pipeline:
      1. Stem: 7x7 stride-2 conv and 3x3 stride-2 max-pool, giving 32 channels.
      2. Parallel 3x3/5x5/7x7 convs, concatenated to 48 channels.
      3. 1x1-conv sigmoid spatial attention over those 48 channels.
      4. 4-layer dense block (to 96 channels), then a transition layer
         (1x1 conv to 48 channels and a 2x2 average pool).
      5. Primary capsules, then 16 routed "medical" capsules of length 16,
         then ``num_classes`` routed diagnostic capsules of length 16.
      6. MLP head producing the logits.

    Args:
        num_classes: number of output classes.
        in_channels: number of input image channels.
        input_size: spatial size of the images the model will receive.
            Give an int for square images or a ``(height, width)`` tuple.
            The number of primary capsules, and therefore the shape of the
            first capsule layer's weight, depends on it. Inputs of any other
            size will fail in forward(). Defaults to 256, which yields the
            original 32 * 16 * 16 = 8192 primary capsules.

    forward() returns ``(logits, diagnostic_caps)`` while the module is in
    training mode, and only ``logits`` in eval mode.

    Raises:
        ValueError: if ``input_size`` is too small to leave a non-empty
            primary-capsule grid.
    """

    def __init__(self, num_classes, in_channels=1, input_size=256):
        super(EnhancedMedicalCapsCNN, self).__init__()

        # Stem: 7x7 stride-2 conv + 3x3 stride-2 max-pool (256x256 -> 64x64)
        self.init_conv = nn.Sequential(
            nn.Conv2d(in_channels, 32, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )

        # Multi-scale branches; padding k//2 keeps the spatial size
        self.multi_scale = nn.ModuleList([
            nn.Conv2d(32, 16, kernel_size=k, padding=k//2)
            for k in [3, 5, 7]
        ])

        # Dense block (spatial size unchanged)
        self.dense_block = DenseBlock(48, num_layers=4, growth_rate=12)
        dense_out_channels = 48 + 4 * 12  # 96 channels

        # Transition layer: halves channels and spatial size
        self.transition = nn.Sequential(
            nn.BatchNorm2d(dense_out_channels),
            nn.Conv2d(dense_out_channels, dense_out_channels // 2, kernel_size=1),
            nn.AvgPool2d(kernel_size=2, stride=2)
        )

        transition_out_channels = dense_out_channels // 2  # 48 channels

        # Primary capsules: 3x3 stride-2 conv (padding 1) halves the spatial size again
        self.primary_caps = PrimaryCapsules(
            in_channels=transition_out_channels,
            out_channels=32,
            dim_caps=8,
            kernel_size=3,
            stride=2
        )

        # One primary capsule per (capsule channel, grid position);
        # 32 * 16 * 16 = 8192 for a 256x256 input.
        grid_h, grid_w = self._primary_caps_grid(input_size)
        primary_caps_size = 32 * grid_h * grid_w

        # Medical feature capsules
        self.medical_caps = CapsuleLayer(
            num_caps_in=primary_caps_size,
            num_caps_out=16,
            dim_caps_in=8,
            dim_caps_out=16
        )

        # Diagnostic capsules (one per class)
        self.diagnostic_caps = CapsuleLayer(
            num_caps_in=16,
            num_caps_out=num_classes,
            dim_caps_in=16,
            dim_caps_out=16
        )

        # Spatial attention over the 48 multi-scale channels
        self.attention = nn.Sequential(
            nn.Conv2d(48, 1, kernel_size=1),
            nn.Sigmoid()
        )

        # Classifier over the flattened diagnostic capsules
        self.classifier = nn.Sequential(
            nn.Linear(num_classes * 16, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes)
        )

        self._initialize_weights()

    @staticmethod
    def _primary_caps_grid(input_size):
        """
        Return the (height, width) of the PrimaryCapsules output grid for an
        input image of spatial size ``input_size`` (int or (height, width)).

        This mirrors the spatial arithmetic of the layers built in __init__;
        keep the two in sync if the architecture changes.

        Raises:
            ValueError: if the resulting grid would be empty.
        """
        def out_len(size, kernel, stride, padding):
            # PyTorch conv/pool output length (dilation 1, floor mode)
            return (size + 2 * padding - kernel) // stride + 1

        if isinstance(input_size, int):
            sizes = (input_size, input_size)
        else:
            sizes = tuple(input_size)

        grid = []
        for size in sizes:
            size = out_len(size, 7, 2, 3)  # init_conv: 7x7 conv, stride 2
            size = out_len(size, 3, 2, 1)  # init_conv: 3x3 max-pool, stride 2
            # multi_scale convs and dense_block preserve the spatial size
            size = out_len(size, 2, 2, 0)  # transition: 2x2 avg-pool, stride 2
            size = out_len(size, 3, 2, 1)  # primary_caps: 3x3 conv, stride 2, padding 1
            if size < 1:
                raise ValueError(f"input_size {input_size} is too small for this architecture")
            grid.append(size)
        return tuple(grid)

    def _initialize_weights(self):
        """Kaiming-normal convs, unit/zero BatchNorm affine, N(0, 0.01) linears; zero biases.

        CapsuleLayer.W parameters are not touched here.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        # Shape comments below are for the default 256x256 input.
        x = self.init_conv(x)  # [batch, 32, 64, 64]

        # Multi-scale feature extraction
        multi_scale_features = [conv(x) for conv in self.multi_scale]
        x = torch.cat(multi_scale_features, dim=1)  # [batch, 48, 64, 64]

        # Re-weight every position with the attention map
        attention = self.attention(x)
        x = x * attention

        # Dense feature extraction with transition
        x = self.dense_block(x)  # [batch, 96, 64, 64]
        x = self.transition(x)   # [batch, 48, 32, 32]

        # Primary capsules
        primary_caps = self.primary_caps(x)  # [batch, 8192, 8]

        # Medical feature capsules
        medical_caps = self.medical_caps(primary_caps)  # [batch, 16, 16]

        # Diagnostic capsules
        diagnostic_caps = self.diagnostic_caps(medical_caps)  # [batch, num_classes, 16]

        # Final classification
        x = diagnostic_caps.view(diagnostic_caps.size(0), -1)
        output = self.classifier(x)

        if self.training:
            return output, diagnostic_caps
        return output

# Training loss for the capsule outputs, followed by model-inspection helpers
class MarginLoss(nn.Module):
    """
    Capsule-network margin loss.

    With ||v_k|| the length of output capsule k and T_k = 1 for the target
    class (0 otherwise), the per-class loss is

        L_k = T_k * max(0, m_pos - ||v_k||)^2
              + lambda_ * (1 - T_k) * max(0, ||v_k|| - m_neg)^2

    summed over classes and averaged over the batch.

    Args:
        m_pos: target capsules are penalised when shorter than this.
        m_neg: non-target capsules are penalised when longer than this.
        lambda_: weight of the non-target term.
    """
    def __init__(self, m_pos=0.9, m_neg=0.1, lambda_=0.5):
        super(MarginLoss, self).__init__()
        self.m_pos = m_pos
        self.m_neg = m_neg
        self.lambda_ = lambda_

    def forward(self, caps_output, target):
        """
        Args:
            caps_output: capsule outputs, (batch, num_classes, dim_caps).
            target: integer class indices, (batch,).

        Returns:
            Scalar loss tensor.
        """
        # vector_norm's backward is 0 (not NaN) for an all-zero capsule, whereas
        # sqrt(sum(x**2)) has an infinite derivative at 0.
        v_c = torch.linalg.vector_norm(caps_output, dim=-1)

        left = F.relu(self.m_pos - v_c) ** 2
        right = F.relu(v_c - self.m_neg) ** 2

        target = F.one_hot(target, num_classes=v_c.size(1)).to(v_c.dtype)

        loss = target * left + self.lambda_ * (1.0 - target) * right
        return loss.sum(dim=1).mean()

def print_model_summary(model, input_size=(1, 256, 256)):
    """
    Move ``model`` to CUDA if available (else CPU). Then print its
    architecture, parameter counts, estimated size, and the output shape(s)
    of one dummy forward pass.

    The forward pass runs under ``torch.no_grad()``. Every buffer, for
    example BatchNorm running statistics, is restored afterwards, so apart
    from the device move the summary leaves the model's state unchanged.
    Errors raised by the forward pass are printed, not re-raised.

    Args:
        model: the ``nn.Module`` to summarise (moved in place by ``.to``).
        input_size: shape of one input sample without the batch dimension.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    batch_size = 1
    input_shape = (batch_size, *input_size)
    dummy_input = torch.randn(input_shape).to(device)

    print("\nModel Architecture:")
    print(model)

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    print(f"\nTotal Parameters: {total_params:,}")
    print(f"Trainable Parameters: {trainable_params:,}")

    # Use each parameter's real element size instead of assuming 4-byte floats
    param_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
    model_size_mb = param_bytes / (1024 * 1024)
    print(f"Estimated Model Size: {model_size_mb:.2f} MB")

    # A training-mode forward pass updates BatchNorm running stats; snapshot
    # all buffers so the summary leaves the model unchanged.
    saved_buffers = [buf.detach().clone() for buf in model.buffers()]
    try:
        with torch.no_grad():
            output = model(dummy_input)
        print(f"\nInput Shape: {input_shape}")
        if isinstance(output, tuple):
            print(f"Output Shapes: {[o.shape for o in output]}")
        else:
            print(f"Output Shape: {output.shape}")
        print("\nModel summary test passed successfully!")
    except Exception as e:
        print(f"\nError during forward pass: {str(e)}")
    finally:
        with torch.no_grad():
            for buf, saved in zip(model.buffers(), saved_buffers):
                buf.copy_(saved)

def test_model(num_classes=5):
    """
    Build an EnhancedMedicalCapsCNN, print its summary and return it.

    Args:
        num_classes: number of output classes of the model.

    Returns:
        The constructed model.
    """
    caps_model = EnhancedMedicalCapsCNN(num_classes=num_classes)
    print_model_summary(caps_model)
    return caps_model

if __name__ == "__main__":
    # Build the 5-class capsule model and print its summary
    model = test_model(5)