In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import numpy as np

# Seed PyTorch's global random number generator so that the randomly
# initialised layers and the torch.randn dummy input below are repeatable.
torch.manual_seed(42)

class ContinuousAtrousConvModule(nn.Module):
    """Parallel multi-rate atrous convolutions followed by global pooling.

    The input goes through three parallel 3x3 convolutions with dilation
    rates 1, 2 and 4 (padding equals the dilation, so the spatial size is
    kept). The branch outputs are concatenated along the channel axis,
    batch-normalised, passed through ReLU and average-pooled to a 1x1 map.

    Args:
        in_channels: Channel count of the incoming feature map.
        out_channels: Channel count after concatenation. The first two
            branches get ``out_channels // 3`` channels each and the third
            branch takes whatever is left.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        branch_width = out_channels // 3
        last_width = out_channels - 2 * branch_width

        # Branches are created in rate order: 1, 2, 4.
        self.atrous_conv1 = nn.Conv2d(
            in_channels, branch_width, kernel_size=3, padding=1, dilation=1)
        self.atrous_conv2 = nn.Conv2d(
            in_channels, branch_width, kernel_size=3, padding=2, dilation=2)
        self.atrous_conv3 = nn.Conv2d(
            in_channels, last_width, kernel_size=3, padding=4, dilation=4)
        self.relu = nn.ReLU(inplace=True)
        self.bn = nn.BatchNorm2d(out_channels)

        # Collapses every channel to a single value (1x1 spatial output).
        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, x):
        """Return the pooled ``(N, out_channels, 1, 1)`` descriptor of ``x``."""
        branches = (self.atrous_conv1, self.atrous_conv2, self.atrous_conv3)
        stacked = torch.cat([branch(x) for branch in branches], dim=1)
        return self.global_pool(self.relu(self.bn(stacked)))

class ImprovedSkipLayer(nn.Module):
    """Fuse a coarse prediction with three backbone feature levels, top-down.

    Every feature level is first projected to ``num_classes`` channels by a
    1x1 convolution. The coarse prediction is bilinearly resized to the
    high-level resolution and added to it; that sum is resized and added to
    the mid-level projection, and the same again for the low-level
    projection. The last sum is resized to ``output_size``.

    Args:
        low_level_channels: Channel count of the low-level feature map.
        mid_level_channels: Channel count of the mid-level feature map.
        high_level_channels: Channel count of the high-level feature map.
        num_classes: Number of output channels (one score map per class).
    """

    def __init__(self, low_level_channels, mid_level_channels, high_level_channels, num_classes):
        super(ImprovedSkipLayer, self).__init__()

        # 1x1 convolutions mapping each feature level to per-class scores.
        self.conv_low = nn.Conv2d(low_level_channels, num_classes, kernel_size=1)
        self.conv_mid = nn.Conv2d(mid_level_channels, num_classes, kernel_size=1)
        self.conv_high = nn.Conv2d(high_level_channels, num_classes, kernel_size=1)

    @staticmethod
    def _resize(tensor, size):
        """Bilinearly resample ``tensor`` to the spatial ``size`` (H, W)."""
        return F.interpolate(tensor, size=size, mode='bilinear', align_corners=False)

    def forward(self, x_low, x_mid, x_high, x_upsampled, output_size=(640, 640)):
        """Fuse the inputs and return a score map of size ``output_size``.

        Args:
            x_low: Low-level features, ``(N, low_level_channels, H, W)``.
            x_mid: Mid-level features, ``(N, mid_level_channels, H, W)``.
            x_high: High-level features, ``(N, high_level_channels, H, W)``.
            x_upsampled: Coarse prediction with ``num_classes`` channels.
            output_size: ``(height, width)`` of the returned map. Defaults to
                ``(640, 640)``, the size that was previously hard-coded.

        Returns:
            Tensor of shape ``(N, num_classes, *output_size)``.
        """
        x_low_processed = self.conv_low(x_low)
        x_mid_processed = self.conv_mid(x_mid)
        x_high_processed = self.conv_high(x_high)

        # Each resize target is read from the next feature map, so the
        # effective scale factor depends on what the backbone provides.
        high_fusion = x_high_processed + self._resize(
            x_upsampled, x_high_processed.shape[2:])
        mid_fusion = x_mid_processed + self._resize(
            high_fusion, x_mid_processed.shape[2:])
        final_fusion = x_low_processed + self._resize(
            mid_fusion, x_low_processed.shape[2:])

        # Final resize to the requested output resolution.
        return self._resize(final_fusion, output_size)

class InvertedResidualBlock(nn.Module):
    """Inverted residual block: 1x1 expand -> 3x3 depthwise -> 1x1 project.

    Args:
        in_channels: Channel count of the input.
        out_channels: Channel count of the output.
        stride: Stride of the depthwise convolution.
        expansion_factor: Multiplier for the hidden width. When it is 1 the
            expansion stage is replaced by an identity.
    """

    def __init__(self, in_channels, out_channels, stride, expansion_factor=6):
        super().__init__()
        self.stride = stride
        # The shortcut is only taken when input and output shapes match.
        self.use_residual = (stride == 1) and (in_channels == out_channels)

        width = in_channels * expansion_factor

        # Expansion stage (pointwise); skipped when there is nothing to expand.
        if expansion_factor == 1:
            self.expand = nn.Identity()
        else:
            self.expand = nn.Sequential(
                nn.Conv2d(in_channels, width, kernel_size=1, bias=False),
                nn.BatchNorm2d(width),
                nn.ReLU6(inplace=True),
            )

        # Depthwise stage: one 3x3 filter per channel (groups == channels).
        self.depthwise = nn.Sequential(
            nn.Conv2d(width, width, kernel_size=3, stride=stride, padding=1,
                      groups=width, bias=False),
            nn.BatchNorm2d(width),
            nn.ReLU6(inplace=True),
        )

        # Projection stage (pointwise) with no activation after it.
        self.project = nn.Sequential(
            nn.Conv2d(width, out_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels),
        )

    def forward(self, x):
        """Apply the three stages and add the shortcut when it is enabled."""
        out = self.project(self.depthwise(self.expand(x)))
        if not self.use_residual:
            return out
        return out + x

class ImprovedFCN(nn.Module):
    """Segmentation FCN built on a pretrained MobileNetV2 backbone.

    The backbone's ``features`` stack is cut into four consecutive stages.
    The last stage feeds a ContinuousAtrousConvModule and a 1x1 classifier,
    which yields a 1x1 coarse prediction. ImprovedSkipLayer fuses that
    prediction with the outputs of the three earlier stages. The fused map is
    brought back to the input resolution and turned into per-pixel class
    probabilities with a softmax over the channel axis.

    Args:
        num_classes: Number of output classes (channels of the result).
    """

    def __init__(self, num_classes=6):  # Defaults to 6 output classes
        super(ImprovedFCN, self).__init__()

        # Use MobileNetV2 as backbone (as described in the paper).
        # NOTE(review): newer torchvision releases prefer the ``weights=``
        # argument over ``pretrained=True`` - confirm against the installed
        # version before changing, since the selected weights may differ.
        # NOTE(review): keeping the whole model as an attribute also registers
        # the parts of it that forward() never calls.
        self.mobilenet = models.mobilenet_v2(pretrained=True)

        # Consecutive slices of the backbone; each one consumes the output of
        # the previous slice.
        self.low_level_features = nn.Sequential(*list(self.mobilenet.features)[:4])    # Earlier layers
        self.mid_level_features = nn.Sequential(*list(self.mobilenet.features)[4:7])   # Middle layers
        self.high_level_features = nn.Sequential(*list(self.mobilenet.features)[7:14]) # Later layers
        self.final_features = nn.Sequential(*list(self.mobilenet.features)[14:])       # Final layers

        # Output channels of each slice; they match the shapes printed by
        # forward() (24, 32, 96 and 1280 channels respectively).
        self.low_level_channels = 24
        self.mid_level_channels = 32
        self.high_level_channels = 96
        self.final_channels = 1280

        # Continuous Atrous Convolution Module
        self.continuous_atrous = ContinuousAtrousConvModule(self.final_channels, 2048)

        # Final prediction layer
        self.final_conv = nn.Conv2d(2048, num_classes, kernel_size=1)

        # Improved Skip Layer
        self.improved_skip = ImprovedSkipLayer(
            self.low_level_channels,
            self.mid_level_channels,
            self.high_level_channels,
            num_classes
        )

        # Softmax classifier over the class (channel) axis
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        """Segment a batch of images.

        Args:
            x: Image batch of shape ``(N, 3, H, W)``.

        Returns:
            Class probabilities of shape ``(N, num_classes, H, W)``; they sum
            to 1 along the channel axis.
        """
        input_size = x.size()[2:]  # Remember original input size

        # Extract features at different levels
        x_low = self.low_level_features(x)
        print(f"Low level features shape: {x_low.shape}")

        x = x_low
        x_mid = self.mid_level_features(x)
        print(f"Mid level features shape: {x_mid.shape}")

        x = x_mid
        x_high = self.high_level_features(x)
        print(f"High level features shape: {x_high.shape}")

        x = x_high
        x = self.final_features(x)
        print(f"Final features shape: {x.shape}")

        # Apply Continuous Atrous Convolution Module
        x = self.continuous_atrous(x)
        print(f"After continuous atrous: {x.shape}")

        # Final prediction
        x = self.final_conv(x)
        print(f"After final conv: {x.shape}")

        # Apply improved skip connections and upsampling
        x = self.improved_skip(x_low, x_mid, x_high, x)
        print(f"After improved skip layer: {x.shape}")

        # The skip layer resizes to its default output size (640x640 in this
        # file). Resample the scores to the actual input resolution when the
        # two differ, so the prediction always lines up with the input.
        if x.shape[2:] != input_size:
            x = F.interpolate(x, size=input_size, mode='bilinear', align_corners=False)

        # Softmax classifier
        x = self.softmax(x)

        return x

# Smoke test: push one random image through the network and report the shape.
if __name__ == "__main__":
    # Random stand-in for a single 3-channel 640x640 image.
    dummy_input = torch.randn((1, 3, 640, 640))

    # Build the 6-class network; eval() returns the module itself.
    model = ImprovedFCN(num_classes=6).eval()

    # Inference only, so gradient tracking is switched off.
    with torch.no_grad():
        output = model(dummy_input)

    print(f"Output shape: {output.shape}")  # Expected: [1, 6, 640, 640]

Low level features shape: torch.Size([1, 24, 160, 160])
Mid level features shape: torch.Size([1, 32, 80, 80])
High level features shape: torch.Size([1, 96, 40, 40])
Final features shape: torch.Size([1, 1280, 20, 20])
After continuous atrous: torch.Size([1, 2048, 1, 1])
After final conv: torch.Size([1, 6, 1, 1])
After improved skip layer: torch.Size([1, 6, 640, 640])
Output shape: torch.Size([1, 6, 640, 640])


In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import numpy as np

# Seed PyTorch's global random number generator so that the randomly
# initialised layers and the torch.randn dummy input below are repeatable.
torch.manual_seed(42)

class ContinuousAtrousConvModule(nn.Module):
    """Parallel multi-rate atrous convolutions followed by global pooling.

    Three 3x3 convolutions with dilation rates 1, 2 and 4 run side by side on
    the same input (padding equals the dilation, so the spatial size is
    kept). Their outputs are concatenated on the channel axis, then
    batch-normalised, rectified and average-pooled to a 1x1 map.

    Args:
        in_channels: Channel count of the incoming feature map.
        out_channels: Channel count after concatenation. The first two
            branches get ``out_channels // 3`` channels each and the third
            branch takes whatever is left.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        branch_width = out_channels // 3
        last_width = out_channels - 2 * branch_width

        # Branches are created in rate order: 1, 2, 4.
        self.atrous_conv1 = nn.Conv2d(
            in_channels, branch_width, kernel_size=3, padding=1, dilation=1)
        self.atrous_conv2 = nn.Conv2d(
            in_channels, branch_width, kernel_size=3, padding=2, dilation=2)
        self.atrous_conv3 = nn.Conv2d(
            in_channels, last_width, kernel_size=3, padding=4, dilation=4)
        self.relu = nn.ReLU(inplace=True)
        self.bn = nn.BatchNorm2d(out_channels)

        # Collapses every channel to a single value (1x1 spatial output).
        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, x):
        """Return the pooled ``(N, out_channels, 1, 1)`` descriptor of ``x``."""
        branches = (self.atrous_conv1, self.atrous_conv2, self.atrous_conv3)
        stacked = torch.cat([branch(x) for branch in branches], dim=1)
        return self.global_pool(self.relu(self.bn(stacked)))

class ImprovedSkipLayer(nn.Module):
    """Fuse a coarse prediction with three backbone feature levels, top-down.

    Every feature level is first projected to ``num_classes`` channels by a
    1x1 convolution. The coarse prediction is bilinearly resized to the
    high-level resolution and added to it; that sum is resized and added to
    the mid-level projection, and the same again for the low-level
    projection. The last sum is resized to ``output_size``.

    Args:
        low_level_channels: Channel count of the low-level feature map.
        mid_level_channels: Channel count of the mid-level feature map.
        high_level_channels: Channel count of the high-level feature map.
        num_classes: Number of output channels (one score map per class).
    """

    def __init__(self, low_level_channels, mid_level_channels, high_level_channels, num_classes):
        super(ImprovedSkipLayer, self).__init__()

        # 1x1 convolutions mapping each feature level to per-class scores.
        # Input widths come from the constructor arguments; nothing here
        # assumes a particular backbone.
        self.conv_low = nn.Conv2d(low_level_channels, num_classes, kernel_size=1)
        self.conv_mid = nn.Conv2d(mid_level_channels, num_classes, kernel_size=1)
        self.conv_high = nn.Conv2d(high_level_channels, num_classes, kernel_size=1)

    @staticmethod
    def _resize(tensor, size):
        """Bilinearly resample ``tensor`` to the spatial ``size`` (H, W)."""
        return F.interpolate(tensor, size=size, mode='bilinear', align_corners=False)

    def forward(self, x_low, x_mid, x_high, x_upsampled, output_size=(640, 640)):
        """Fuse the inputs and return a score map of size ``output_size``.

        Args:
            x_low: Low-level features, ``(N, low_level_channels, H, W)``.
            x_mid: Mid-level features, ``(N, mid_level_channels, H, W)``.
            x_high: High-level features, ``(N, high_level_channels, H, W)``.
            x_upsampled: Coarse prediction with ``num_classes`` channels.
            output_size: ``(height, width)`` of the returned map. Defaults to
                ``(640, 640)``, the size that was previously hard-coded.

        Returns:
            Tensor of shape ``(N, num_classes, *output_size)``.
        """
        x_low_processed = self.conv_low(x_low)
        x_mid_processed = self.conv_mid(x_mid)
        x_high_processed = self.conv_high(x_high)

        # Each resize target is read from the next feature map, so the
        # effective scale factor depends on what the backbone provides.
        high_fusion = x_high_processed + self._resize(
            x_upsampled, x_high_processed.shape[2:])
        mid_fusion = x_mid_processed + self._resize(
            high_fusion, x_mid_processed.shape[2:])
        final_fusion = x_low_processed + self._resize(
            mid_fusion, x_low_processed.shape[2:])

        # Final resize to the requested output resolution.
        return self._resize(final_fusion, output_size)

class InvertedResidualBlock(nn.Module):
    """Inverted residual block: 1x1 expand -> 3x3 depthwise -> 1x1 project.

    Args:
        in_channels: Channel count of the input.
        out_channels: Channel count of the output.
        stride: Stride of the depthwise convolution.
        expansion_factor: Multiplier for the hidden width. When it is 1 the
            expansion stage is replaced by an identity.
    """

    def __init__(self, in_channels, out_channels, stride, expansion_factor=6):
        super().__init__()
        self.stride = stride
        # The shortcut is only taken when input and output shapes match.
        self.use_residual = (stride == 1) and (in_channels == out_channels)

        width = in_channels * expansion_factor

        # Expansion stage (pointwise); skipped when there is nothing to expand.
        if expansion_factor == 1:
            self.expand = nn.Identity()
        else:
            self.expand = nn.Sequential(
                nn.Conv2d(in_channels, width, kernel_size=1, bias=False),
                nn.BatchNorm2d(width),
                nn.ReLU6(inplace=True),
            )

        # Depthwise stage: one 3x3 filter per channel (groups == channels).
        self.depthwise = nn.Sequential(
            nn.Conv2d(width, width, kernel_size=3, stride=stride, padding=1,
                      groups=width, bias=False),
            nn.BatchNorm2d(width),
            nn.ReLU6(inplace=True),
        )

        # Projection stage (pointwise) with no activation after it.
        self.project = nn.Sequential(
            nn.Conv2d(width, out_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels),
        )

    def forward(self, x):
        """Apply the three stages and add the shortcut when it is enabled."""
        out = self.project(self.depthwise(self.expand(x)))
        if not self.use_residual:
            return out
        return out + x

class ImprovedFCN(nn.Module):
    """Segmentation FCN built on a pretrained MobileNetV2 backbone.

    The backbone's ``features`` stack is cut into four consecutive stages.
    The last stage feeds a ContinuousAtrousConvModule and a 1x1 classifier,
    which yields a 1x1 coarse prediction. ImprovedSkipLayer fuses that
    prediction with the outputs of the three earlier stages, and a softmax
    over the channel axis turns the fused scores into class probabilities.
    Every intermediate shape is printed during the forward pass.

    Args:
        num_classes: Number of output classes (channels of the result).
    """

    def __init__(self, num_classes=6):
        super().__init__()

        # Pretrained MobileNetV2 supplies the feature extractor.
        self.mobilenet = models.mobilenet_v2(pretrained=True)

        # Cut the feature stack into four back-to-back stages.
        stages = list(self.mobilenet.features)
        self.low_level_features = nn.Sequential(*stages[:4])
        self.mid_level_features = nn.Sequential(*stages[4:7])
        self.high_level_features = nn.Sequential(*stages[7:14])
        self.final_features = nn.Sequential(*stages[14:])

        # Output width of each stage, matching the shapes forward() prints.
        self.low_level_channels = 24
        self.mid_level_channels = 32
        self.high_level_channels = 96
        self.final_channels = 1280

        # Context head: multi-rate atrous convolutions, then a 1x1 classifier.
        head_width = 2048
        self.continuous_atrous = ContinuousAtrousConvModule(self.final_channels, head_width)
        self.final_conv = nn.Conv2d(head_width, num_classes, kernel_size=1)

        # Decoder that merges the coarse prediction with the three skips.
        self.improved_skip = ImprovedSkipLayer(
            self.low_level_channels,
            self.mid_level_channels,
            self.high_level_channels,
            num_classes,
        )

        # Normalises scores across the class (channel) axis.
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        """Return class probabilities for the image batch ``x``.

        The shape comments below are the values observed for a
        ``(1, 3, 640, 640)`` input.
        """
        print(f"Input: {x.shape}")

        # Backbone stages, each fed by the previous one.
        x_low = self.low_level_features(x)
        print(f"Low level features shape: {x_low.shape}")  # 24 x 160 x 160

        x_mid = self.mid_level_features(x_low)
        print(f"Mid level features shape: {x_mid.shape}")  # 32 x 80 x 80

        x_high = self.high_level_features(x_mid)
        print(f"High level features shape: {x_high.shape}")  # 96 x 40 x 40

        x_final = self.final_features(x_high)
        print(f"Final features shape: {x_final.shape}")  # 1280 x 20 x 20

        # Context head pooled down to a single spatial position.
        x_atrous = self.continuous_atrous(x_final)
        print(f"After continuous atrous: {x_atrous.shape}")  # 2048 x 1 x 1

        # Per-class scores for that single position.
        x_pred = self.final_conv(x_atrous)
        print(f"After final conv: {x_pred.shape}")  # num_classes x 1 x 1

        # Decoder: fuse with the skips and resize to full resolution.
        output = self.improved_skip(x_low, x_mid, x_high, x_pred)
        print(f"After improved skip layer: {output.shape}")  # num_classes x 640 x 640

        # Scores -> probabilities.
        output = self.softmax(output)
        print(f"Final output shape: {output.shape}")  # num_classes x 640 x 640

        return output

# Smoke test: push one random image through the network and report the shape.
if __name__ == "__main__":
    # Random stand-in for a single 3-channel 640x640 image.
    dummy_input = torch.randn((1, 3, 640, 640))

    # Build the 6-class network; eval() returns the module itself.
    model = ImprovedFCN(num_classes=6).eval()

    # Inference only, so gradient tracking is switched off.
    with torch.no_grad():
        output = model(dummy_input)

    print(f"Output shape: {output.shape}")  # Expected: [1, 6, 640, 640]

Input: torch.Size([1, 3, 640, 640])
Low level features shape: torch.Size([1, 24, 160, 160])
Mid level features shape: torch.Size([1, 32, 80, 80])
High level features shape: torch.Size([1, 96, 40, 40])
Final features shape: torch.Size([1, 1280, 20, 20])
After continuous atrous: torch.Size([1, 2048, 1, 1])
After final conv: torch.Size([1, 6, 1, 1])
After improved skip layer: torch.Size([1, 6, 640, 640])
Final output shape: torch.Size([1, 6, 640, 640])
Output shape: torch.Size([1, 6, 640, 640])


**MobileNetV2**

In [10]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torchvision.models import mobilenet_v2

# Seed PyTorch's global random number generator so that randomly initialised
# layers are repeatable. NOTE: this does not seed NumPy's random generator.
torch.manual_seed(42)

class InvertedResidual(nn.Module):
    """Inverted residual block: optional 1x1 expand, 3x3 depthwise, 1x1 project.

    Args:
        inp: Channel count of the input.
        oup: Channel count of the output.
        stride: Stride of the depthwise convolution; must be 1 or 2.
        expand_ratio: Multiplier for the hidden width. The expansion stage is
            left out entirely when it equals 1.
    """

    def __init__(self, inp, oup, stride, expand_ratio):
        super().__init__()
        self.stride = stride
        assert stride in [1, 2]

        width = int(inp * expand_ratio)
        # The shortcut is only taken when input and output shapes match.
        self.use_res_connect = self.stride == 1 and inp == oup

        # Pointwise expansion (only when the block actually widens).
        expansion = []
        if expand_ratio != 1:
            expansion = [
                nn.Conv2d(inp, width, kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(width),
                nn.ReLU6(inplace=True),
            ]

        # Depthwise 3x3: one filter per channel (groups == channels).
        depthwise = [
            nn.Conv2d(width, width, kernel_size=3, stride=stride, padding=1,
                      groups=width, bias=False),
            nn.BatchNorm2d(width),
            nn.ReLU6(inplace=True),
        ]

        # Pointwise projection with no activation after it.
        projection = [
            nn.Conv2d(width, oup, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(oup),
        ]

        self.conv = nn.Sequential(*expansion, *depthwise, *projection)

    def forward(self, x):
        """Run the block, adding the input back when the shortcut is enabled."""
        out = self.conv(x)
        return x + out if self.use_res_connect else out

class ContinuousAtrousConvModule(nn.Module):
    """Parallel multi-rate atrous convolutions followed by global pooling.

    Three 3x3 convolutions with dilation rates 1, 2 and 4 run side by side on
    the same input (padding equals the dilation, so the spatial size is
    kept). Their outputs are concatenated on the channel axis, then
    batch-normalised, rectified and average-pooled to a 1x1 map. Tensor
    shapes are printed at each step.

    Args:
        in_channels: Channel count of the incoming feature map.
        out_channels: Channel count after concatenation. The first two
            branches get ``out_channels // 3`` channels each and the third
            branch takes whatever is left.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        branch_width = out_channels // 3
        last_width = out_channels - 2 * branch_width

        # Branches are created in rate order: 1, 2, 4.
        self.atrous_conv1 = nn.Conv2d(
            in_channels, branch_width, kernel_size=3, padding=1, dilation=1)
        self.atrous_conv2 = nn.Conv2d(
            in_channels, branch_width, kernel_size=3, padding=2, dilation=2)
        self.atrous_conv3 = nn.Conv2d(
            in_channels, last_width, kernel_size=3, padding=4, dilation=4)
        self.relu = nn.ReLU(inplace=True)
        self.bn = nn.BatchNorm2d(out_channels)

        # Collapses every channel to a single value (1x1 spatial output).
        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, x):
        """Return the pooled ``(N, out_channels, 1, 1)`` descriptor of ``x``."""
        print(f"ContinuousAtrousConvModule input: {x.shape}")
        branches = (self.atrous_conv1, self.atrous_conv2, self.atrous_conv3)
        x = torch.cat([branch(x) for branch in branches], dim=1)
        print(f"After atrous convolutions and concatenation: {x.shape}")
        x = self.global_pool(self.relu(self.bn(x)))
        print(f"After global pooling: {x.shape}")
        return x

class ImprovedSkipLayer(nn.Module):
    """Fuse a coarse prediction with three backbone feature levels, top-down.

    Every feature level is first projected to ``num_classes`` channels by a
    1x1 convolution. The coarse prediction is bilinearly resized to the
    high-level resolution and added to it; that sum is resized and added to
    the mid-level projection, and the same again for the low-level
    projection. The last sum is resized to ``output_size``. Tensor shapes are
    printed after every step.

    Args:
        low_level_channels: Channel count of the low-level feature map.
        mid_level_channels: Channel count of the mid-level feature map.
        high_level_channels: Channel count of the high-level feature map.
        num_classes: Number of output channels (one score map per class).
    """

    def __init__(self, low_level_channels, mid_level_channels, high_level_channels, num_classes):
        super(ImprovedSkipLayer, self).__init__()

        # 1x1 convolutions mapping each feature level to per-class scores.
        self.conv_low = nn.Conv2d(low_level_channels, num_classes, kernel_size=1)
        self.conv_mid = nn.Conv2d(mid_level_channels, num_classes, kernel_size=1)
        self.conv_high = nn.Conv2d(high_level_channels, num_classes, kernel_size=1)

    def forward(self, x_low, x_mid, x_high, x_upsampled, output_size=(640, 640)):
        """Fuse the inputs and return a score map of size ``output_size``.

        Args:
            x_low: Low-level features, ``(N, low_level_channels, H, W)``.
            x_mid: Mid-level features, ``(N, mid_level_channels, H, W)``.
            x_high: High-level features, ``(N, high_level_channels, H, W)``.
            x_upsampled: Coarse prediction with ``num_classes`` channels.
            output_size: ``(height, width)`` of the returned map. Defaults to
                ``(640, 640)``, the size that was previously hard-coded.

        Returns:
            Tensor of shape ``(N, num_classes, *output_size)``.
        """
        print(f"ImprovedSkipLayer inputs - low: {x_low.shape}, mid: {x_mid.shape}, high: {x_high.shape}, upsampled: {x_upsampled.shape}")

        # Project each feature level to class scores.
        x_low_processed = self.conv_low(x_low)
        x_mid_processed = self.conv_mid(x_mid)
        x_high_processed = self.conv_high(x_high)

        print(f"After 1x1 convolutions - low: {x_low_processed.shape}, mid: {x_mid_processed.shape}, high: {x_high_processed.shape}")

        # Resize the coarse prediction to the high-level resolution. The
        # target size is read from the tensor, so the "8x" in the log label
        # below is nominal, not the real factor.
        x_upsampled_8x = F.interpolate(x_upsampled, size=x_high_processed.shape[2:],
                                       mode='bilinear', align_corners=False)
        print(f"After 8x upsampling: {x_upsampled_8x.shape}")

        # Fusion with high-level features
        high_fusion = x_high_processed + x_upsampled_8x
        print(f"After high fusion: {high_fusion.shape}")

        # Resize to the mid-level resolution (whatever factor that implies).
        high_fusion_upsampled = F.interpolate(high_fusion, size=x_mid_processed.shape[2:],
                                              mode='bilinear', align_corners=False)
        print(f"After high to mid upsampling: {high_fusion_upsampled.shape}")

        # Fusion with mid-level features
        mid_fusion = x_mid_processed + high_fusion_upsampled
        print(f"After mid fusion: {mid_fusion.shape}")

        # Resize to the low-level resolution.
        mid_fusion_upsampled = F.interpolate(mid_fusion, size=x_low_processed.shape[2:],
                                             mode='bilinear', align_corners=False)
        print(f"After mid to low upsampling: {mid_fusion_upsampled.shape}")

        # Fusion with low-level features
        final_fusion = x_low_processed + mid_fusion_upsampled
        print(f"After final fusion: {final_fusion.shape}")

        # Final resize to the requested output resolution.
        output = F.interpolate(final_fusion, size=output_size, mode='bilinear', align_corners=False)
        print(f"Final output: {output.shape}")

        return output

class ImprovedFCN_MobileNetV2(nn.Module):
    """Segmentation FCN whose encoder is a pretrained MobileNetV2.

    ``features[0:18]`` of the backbone are cut into four consecutive stages.
    The last stage feeds a ContinuousAtrousConvModule and a 1x1 classifier,
    which yields a 1x1 coarse prediction. ImprovedSkipLayer fuses that
    prediction with the outputs of the three earlier stages, and a softmax
    over the channel axis turns the fused scores into class probabilities.
    Every intermediate shape is printed during the forward pass.

    Args:
        num_classes: Number of output classes (channels of the result).
    """

    def __init__(self, num_classes=6):
        super().__init__()

        # Pretrained backbone; only its feature blocks are kept on the model.
        backbone = mobilenet_v2(pretrained=True)
        blocks = list(backbone.features)

        # features[0:2]: stem convolution plus the first inverted residual.
        self.low_level_features = nn.Sequential(*blocks[0:2])

        # features[2:4]: the two blocks that produce 24 channels.
        self.mid_level_features = nn.Sequential(*blocks[2:4])

        # features[4:11]: seven blocks ending with 64 channels.
        self.high_level_features = nn.Sequential(*blocks[4:11])

        # features[11:18]: seven blocks ending with 320 channels. The
        # backbone's remaining feature layer is not used.
        self.deeper_features = nn.Sequential(*blocks[11:18])

        # Output width of each stage, matching the shapes forward() prints.
        low_level_channels, mid_level_channels, high_level_channels = 16, 24, 64
        deeper_features_channels = 320

        # Context head: multi-rate atrous convolutions, then a 1x1 classifier.
        head_width = 2048
        self.continuous_atrous = ContinuousAtrousConvModule(deeper_features_channels, head_width)
        self.final_conv = nn.Conv2d(head_width, num_classes, kernel_size=1)

        # Decoder that merges the coarse prediction with the three skips.
        self.improved_skip = ImprovedSkipLayer(
            low_level_channels, mid_level_channels, high_level_channels, num_classes)

        # Normalises scores across the class (channel) axis.
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        """Return class probabilities for the image batch ``x``.

        The shapes in the comments below are the values observed for a
        ``(1, 3, 640, 640)`` input.
        """
        print(f"Input: {x.shape}")

        # Encoder stages, each fed by the previous one.
        x_low = self.low_level_features(x)  # 16 x 320 x 320
        print(f"Low-level features: {x_low.shape}")

        x_mid = self.mid_level_features(x_low)  # 24 x 160 x 160
        print(f"Mid-level features: {x_mid.shape}")

        x_high = self.high_level_features(x_mid)  # 64 x 40 x 40
        print(f"High-level features: {x_high.shape}")

        x_deep = self.deeper_features(x_high)  # 320 x 20 x 20
        print(f"Deeper features: {x_deep.shape}")

        # Context head pooled down to a single spatial position.
        x = self.continuous_atrous(x_deep)
        print(f"After continuous atrous: {x.shape}")

        # Per-class scores for that single position.
        x = self.final_conv(x)
        print(f"After final_conv: {x.shape}")

        # Decoder: fuse with the skips and resize to full resolution.
        x = self.improved_skip(x_low, x_mid, x_high, x)
        print(f"After improved skip layer: {x.shape}")

        # Scores -> probabilities.
        x = self.softmax(x)
        print(f"After softmax: {x.shape}")

        return x

# Create a reproducible dummy input (NumPy array converted to a tensor).
# torch.manual_seed does not affect NumPy, so the values are drawn from an
# explicitly seeded NumPy generator: float32 samples, uniform in [0, 1).
dummy_input_np = np.random.default_rng(42).random((1, 3, 640, 640), dtype=np.float32)
dummy_input_tensor = torch.from_numpy(dummy_input_np)

# Initialize model
model = ImprovedFCN_MobileNetV2(num_classes=6)
model.eval()  # Evaluation mode: BatchNorm uses its running statistics

# Run forward pass without tracking gradients (inference only)
with torch.no_grad():
    output = model(dummy_input_tensor)

Input: torch.Size([1, 3, 640, 640])
Low-level features: torch.Size([1, 16, 320, 320])
Mid-level features: torch.Size([1, 24, 160, 160])
High-level features: torch.Size([1, 64, 40, 40])
Deeper features: torch.Size([1, 320, 20, 20])
ContinuousAtrousConvModule input: torch.Size([1, 320, 20, 20])
After atrous convolutions and concatenation: torch.Size([1, 2048, 20, 20])
After global pooling: torch.Size([1, 2048, 1, 1])
After continuous atrous: torch.Size([1, 2048, 1, 1])
After final_conv: torch.Size([1, 6, 1, 1])
ImprovedSkipLayer inputs - low: torch.Size([1, 16, 320, 320]), mid: torch.Size([1, 24, 160, 160]), high: torch.Size([1, 64, 40, 40]), upsampled: torch.Size([1, 6, 1, 1])
After 1x1 convolutions - low: torch.Size([1, 6, 320, 320]), mid: torch.Size([1, 6, 160, 160]), high: torch.Size([1, 6, 40, 40])
After 8x upsampling: torch.Size([1, 6, 40, 40])
After high fusion: torch.Size([1, 6, 40, 40])
After high to mid upsampling: torch.Size([1, 6, 160, 160])
After mid fusion: torch.Size([1, 6, 160, 160])

**MobileNetV2 (final)**

In [11]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torchvision.models import mobilenet_v2

# Seed PyTorch's global random number generator so that randomly initialised
# layers are repeatable. NOTE: this does not seed NumPy's random generator.
torch.manual_seed(42)

class InvertedResidual(nn.Module):
    """Inverted residual block: optional 1x1 expand, 3x3 depthwise, 1x1 project.

    Args:
        inp: Channel count of the input.
        oup: Channel count of the output.
        stride: Stride of the depthwise convolution; must be 1 or 2.
        expand_ratio: Multiplier for the hidden width. The expansion stage is
            left out entirely when it equals 1.
    """

    def __init__(self, inp, oup, stride, expand_ratio):
        super().__init__()
        self.stride = stride
        assert stride in [1, 2]

        width = int(inp * expand_ratio)
        # The shortcut is only taken when input and output shapes match.
        self.use_res_connect = self.stride == 1 and inp == oup

        # Pointwise expansion (only when the block actually widens).
        expansion = []
        if expand_ratio != 1:
            expansion = [
                nn.Conv2d(inp, width, kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(width),
                nn.ReLU6(inplace=True),
            ]

        # Depthwise 3x3: one filter per channel (groups == channels).
        depthwise = [
            nn.Conv2d(width, width, kernel_size=3, stride=stride, padding=1,
                      groups=width, bias=False),
            nn.BatchNorm2d(width),
            nn.ReLU6(inplace=True),
        ]

        # Pointwise projection with no activation after it.
        projection = [
            nn.Conv2d(width, oup, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(oup),
        ]

        self.conv = nn.Sequential(*expansion, *depthwise, *projection)

    def forward(self, x):
        """Run the block, adding the input back when the shortcut is enabled."""
        out = self.conv(x)
        return x + out if self.use_res_connect else out

class ContinuousAtrousConvModule(nn.Module):
    """Parallel multi-rate atrous convolutions followed by global pooling.

    Three 3x3 convolutions with dilation rates 1, 2 and 4 run side by side on
    the same input (padding equals the dilation, so the spatial size is
    kept). Their outputs are concatenated on the channel axis, then
    batch-normalised, rectified and average-pooled to a 1x1 map. Tensor
    shapes are printed at each step.

    Args:
        in_channels: Channel count of the incoming feature map.
        out_channels: Channel count after concatenation. The first two
            branches get ``out_channels // 3`` channels each and the third
            branch takes whatever is left.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        branch_width = out_channels // 3
        last_width = out_channels - 2 * branch_width

        # Branches are created in rate order: 1, 2, 4.
        self.atrous_conv1 = nn.Conv2d(
            in_channels, branch_width, kernel_size=3, padding=1, dilation=1)
        self.atrous_conv2 = nn.Conv2d(
            in_channels, branch_width, kernel_size=3, padding=2, dilation=2)
        self.atrous_conv3 = nn.Conv2d(
            in_channels, last_width, kernel_size=3, padding=4, dilation=4)
        self.relu = nn.ReLU(inplace=True)
        self.bn = nn.BatchNorm2d(out_channels)

        # Collapses every channel to a single value (1x1 spatial output).
        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, x):
        """Return the pooled ``(N, out_channels, 1, 1)`` descriptor of ``x``."""
        print(f"ContinuousAtrousConvModule input: {x.shape}")
        branches = (self.atrous_conv1, self.atrous_conv2, self.atrous_conv3)
        x = torch.cat([branch(x) for branch in branches], dim=1)
        print(f"After atrous convolutions and concatenation: {x.shape}")
        x = self.global_pool(self.relu(self.bn(x)))
        print(f"After global pooling: {x.shape}")
        return x

class ImprovedSkipLayer(nn.Module):
    """Fuse a coarse class-score map with three skip feature maps, coarse to fine.

    Each skip feature map is projected to ``num_classes`` channels with a 1x1
    convolution. The coarse prediction is then bilinearly resized to the
    high-level map and added to it; the sum is resized to the mid-level map
    and added; that sum is resized to the low-level map and added. The result
    is finally resized to ``output_size``.

    Resizing always targets the spatial size of the next tensor (``shape[2:]``),
    so no fixed scale factor between levels is assumed.

    Args:
        low_level_channels: Channels of the finest skip feature map.
        mid_level_channels: Channels of the intermediate skip feature map.
        high_level_channels: Channels of the coarsest skip feature map.
        num_classes: Number of output channels (class scores).
        output_size: Default ``(height, width)`` of the returned map. It
            defaults to ``(640, 640)``, the size that was previously
            hard-coded, so existing callers behave exactly as before.
    """

    def __init__(self, low_level_channels, mid_level_channels, high_level_channels, num_classes,
                 output_size=(640, 640)):
        super(ImprovedSkipLayer, self).__init__()

        # 1x1 convolutions that map each skip feature map to class scores
        self.conv_low = nn.Conv2d(low_level_channels, num_classes, kernel_size=1)
        self.conv_mid = nn.Conv2d(mid_level_channels, num_classes, kernel_size=1)
        self.conv_high = nn.Conv2d(high_level_channels, num_classes, kernel_size=1)

        # Target size of the final resize when forward() is not given one
        self.output_size = output_size

    @staticmethod
    def _resize(tensor, size):
        """Bilinearly resize ``tensor`` to the spatial ``size``."""
        return F.interpolate(tensor, size=size, mode='bilinear', align_corners=False)

    def forward(self, x_low, x_mid, x_high, x_upsampled, output_size=None):
        """Fuse the inputs coarse to fine and return the resized class-score map.

        Args:
            x_low: Finest skip feature map, 4-D ``(N, C, H, W)``.
            x_mid: Intermediate skip feature map, 4-D.
            x_high: Coarsest skip feature map, 4-D.
            x_upsampled: Coarse prediction to be propagated upward. It is
                added to the projected high-level map, so it is expected to
                carry ``num_classes`` channels.
            output_size: Optional ``(height, width)`` for the returned map.
                When ``None`` the size given to the constructor is used.

        Returns:
            Tensor with ``num_classes`` channels and spatial size ``output_size``.
        """
        if output_size is None:
            output_size = self.output_size

        print(f"ImprovedSkipLayer inputs - low: {x_low.shape}, mid: {x_mid.shape}, high: {x_high.shape}, upsampled: {x_upsampled.shape}")

        # Project each feature level to num_classes channels
        x_low_processed = self.conv_low(x_low)
        x_mid_processed = self.conv_mid(x_mid)
        x_high_processed = self.conv_high(x_high)

        print(f"After 1x1 convolutions - low: {x_low_processed.shape}, mid: {x_mid_processed.shape}, high: {x_high_processed.shape}")

        # Resize the coarse prediction to the high-level resolution. The
        # printed label says "8x" for historical reasons; the actual factor
        # is whatever the two tensor sizes imply.
        x_upsampled_8x = self._resize(x_upsampled, x_high_processed.shape[2:])
        print(f"After 8x upsampling: {x_upsampled_8x.shape}")

        # Fusion with high-level features
        high_fusion = x_high_processed + x_upsampled_8x
        print(f"After high fusion: {high_fusion.shape}")

        # Resize from the high-level to the mid-level resolution
        high_fusion_upsampled = self._resize(high_fusion, x_mid_processed.shape[2:])
        print(f"After high to mid upsampling: {high_fusion_upsampled.shape}")

        # Fusion with mid-level features
        mid_fusion = x_mid_processed + high_fusion_upsampled
        print(f"After mid fusion: {mid_fusion.shape}")

        # Resize from the mid-level to the low-level resolution
        mid_fusion_upsampled = self._resize(mid_fusion, x_low_processed.shape[2:])
        print(f"After mid to low upsampling: {mid_fusion_upsampled.shape}")

        # Fusion with low-level features
        final_fusion = x_low_processed + mid_fusion_upsampled
        print(f"After final fusion: {final_fusion.shape}")

        # Final resize to the requested output size (previously fixed at 640x640)
        output = self._resize(final_fusion, output_size)
        print(f"Final output: {output.shape}")

        return output

class ImprovedFCN_MobileNetV2(nn.Module):
    """Segmentation network built from a MobileNetV2 feature stack.

    The backbone's ``features`` entries 0-17 are split into four consecutive
    stages (low, mid, high, deeper). The deepest output is widened to 1024
    channels, passed through ``ContinuousAtrousConvModule`` and turned into
    class scores by a 1x1 convolution. ``ImprovedSkipLayer`` then merges these
    scores with the low/mid/high stage outputs, and a softmax over the
    channel dimension is applied to the result.

    ``mobilenet_v2``, ``ContinuousAtrousConvModule`` and ``ImprovedSkipLayer``
    are resolved from the surrounding module.

    Args:
        num_classes: Number of output channels / classes.
    """

    def __init__(self, num_classes=6):
        super().__init__()

        # Backbone; called with pretrained=True exactly as before
        backbone = mobilenet_v2(pretrained=True)
        stages = backbone.features

        # Stage boundaries within backbone.features (annotations carried over
        # from the original per-layer comments):
        #   0-1   stem conv + first inverted residual (c=16)
        #   2-3   inverted residuals (c=24, first one strided)
        #   4-10  inverted residuals (c=32 then c=64, each group starts strided)
        #   11-17 inverted residuals (c=96, c=160 strided, c=320)
        self.low_level_features = nn.Sequential(*[stages[i] for i in range(0, 2)])
        self.mid_level_features = nn.Sequential(*[stages[i] for i in range(2, 4)])
        self.high_level_features = nn.Sequential(*[stages[i] for i in range(4, 11)])
        self.deeper_features = nn.Sequential(*[stages[i] for i in range(11, 18)])

        # Output widths of the four stages
        low_level_channels = 16
        mid_level_channels = 24
        high_level_channels = 64
        deeper_features_channels = 320

        # 1x1 conv + BN + ReLU6 that widens the deepest features to 1024 channels
        self.channel_expansion = nn.Conv2d(deeper_features_channels, 1024, kernel_size=1, bias=False)
        self.bn_expansion = nn.BatchNorm2d(1024)
        self.relu_expansion = nn.ReLU6(inplace=True)

        # Dilated-convolution context module (1024 -> 2048 channels)
        self.continuous_atrous = ContinuousAtrousConvModule(1024, 2048)

        # 1x1 conv producing per-class scores
        self.final_conv = nn.Conv2d(2048, num_classes, kernel_size=1)

        # Coarse-to-fine fusion with the three skip stages
        self.improved_skip = ImprovedSkipLayer(
            low_level_channels, mid_level_channels, high_level_channels, num_classes
        )

        # Normalises scores across the channel (class) dimension
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        """Return softmax class maps for ``x``, printing the shape after every stage.

        The input is expected to be ``B x 3 x 640 x 640``.
        """
        print(f"Input: {x.shape}")

        # Backbone stages, each fed by the previous one
        low = self.low_level_features(x)
        print(f"Low-level features: {low.shape}")

        mid = self.mid_level_features(low)
        print(f"Mid-level features: {mid.shape}")

        high = self.high_level_features(mid)
        print(f"High-level features: {high.shape}")

        deep = self.deeper_features(high)
        print(f"Deeper features: {deep.shape}")

        # Widen channels: conv -> batch norm -> ReLU6
        expanded = self.channel_expansion(deep)
        expanded = self.bn_expansion(expanded)
        expanded = self.relu_expansion(expanded)
        print(f"After channel expansion: {expanded.shape}")

        # Context module
        context = self.continuous_atrous(expanded)
        print(f"After continuous atrous: {context.shape}")

        # Per-class scores
        scores = self.final_conv(context)
        print(f"After final_conv: {scores.shape}")

        # Merge with skip features and upsample
        fused = self.improved_skip(low, mid, high, scores)
        print(f"After improved skip layer: {fused.shape}")

        # Class probabilities
        probabilities = self.softmax(fused)
        print(f"After softmax: {probabilities.shape}")

        return probabilities

# Smoke test: push one random 640x640 RGB batch through the network.
# The batch is generated in NumPy as float32 and then wrapped as a tensor.
dummy_input_np = np.random.rand(1, 3, 640, 640).astype(np.float32)
dummy_input_tensor = torch.from_numpy(dummy_input_np)

# Build the model and switch it to evaluation mode (eval() returns the module)
model = ImprovedFCN_MobileNetV2(num_classes=6).eval()

# Inference only, so gradient tracking is disabled
with torch.no_grad():
    output = model(dummy_input_tensor)

Input: torch.Size([1, 3, 640, 640])
Low-level features: torch.Size([1, 16, 320, 320])
Mid-level features: torch.Size([1, 24, 160, 160])
High-level features: torch.Size([1, 64, 40, 40])
Deeper features: torch.Size([1, 320, 20, 20])
After channel expansion: torch.Size([1, 1024, 20, 20])
ContinuousAtrousConvModule input: torch.Size([1, 1024, 20, 20])
After atrous convolutions and concatenation: torch.Size([1, 2048, 20, 20])
After global pooling: torch.Size([1, 2048, 1, 1])
After continuous atrous: torch.Size([1, 2048, 1, 1])
After final_conv: torch.Size([1, 6, 1, 1])
ImprovedSkipLayer inputs - low: torch.Size([1, 16, 320, 320]), mid: torch.Size([1, 24, 160, 160]), high: torch.Size([1, 64, 40, 40]), upsampled: torch.Size([1, 6, 1, 1])
After 1x1 convolutions - low: torch.Size([1, 6, 320, 320]), mid: torch.Size([1, 6, 160, 160]), high: torch.Size([1, 6, 40, 40])
After 8x upsampling: torch.Size([1, 6, 40, 40])
After high fusion: torch.Size([1, 6, 40, 40])
After high to mid upsampling: torch.S