In [3]:
import torch
import torch.nn as nn
import torchvision.models as models

class ResNeXt50_UNet(nn.Module):
    """U-Net style segmentation network with a ResNeXt-50 (32x4d) encoder.

    The encoder reuses the stem and the four residual stages of torchvision's
    ``resnext50_32x4d``. The decoder upsamples the bottleneck step by step and
    concatenates the outputs of ``layer3``, ``layer2`` and ``layer1`` as skip
    connections; the stem output is not used as a skip.

    Args:
        num_classes: Number of output channels of the final 1x1 convolution.
        pretrained: If True (default), initialise the encoder with the
            ``IMAGENET1K_V1`` weights. Pass False to build a randomly
            initialised encoder, e.g. before restoring a checkpoint or when
            no pretrained weights are available locally.
    """

    def __init__(self, num_classes: int = 1, pretrained: bool = True) -> None:
        super().__init__()

        # Only request the pretrained weights when asked for, so the model
        # can also be constructed without them.
        weights = models.ResNeXt50_32X4D_Weights.IMAGENET1K_V1 if pretrained else None
        resnext = models.resnext50_32x4d(weights=weights)

        self.encoder = nn.ModuleDict({
            "conv1": nn.Sequential(resnext.conv1, resnext.bn1, resnext.relu),
            "maxpool": resnext.maxpool,
            "layer1": resnext.layer1,  # 256 channels
            "layer2": resnext.layer2,  # 512 channels
            "layer3": resnext.layer3,  # 1024 channels
            "layer4": resnext.layer4,  # 2048 channels
        })

        # Decoder stages. From upconv2 on, the input width is the previous
        # stage's output plus the concatenated encoder skip.
        self.upconv1 = self.upconv(2048, 1024)
        self.upconv2 = self.upconv(2048, 512)  # 1024 (upconv1) + 1024 (layer3 skip)
        self.upconv3 = self.upconv(1024, 256)  # 512 (upconv2) + 512 (layer2 skip)
        self.upconv4 = self.upconv(512, 64)    # 256 (upconv3) + 256 (layer1 skip)

        # Final layers
        self.final_upsample = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)
        self.final_conv = nn.Conv2d(32, num_classes, kernel_size=1)

    def upconv(self, in_channels: int, out_channels: int) -> nn.Sequential:
        """Build one decoder stage.

        The stage doubles the spatial size with a stride-2 transposed
        convolution and refines the result with two 3x3 convolutions.

        Args:
            in_channels: Channels of the tensor fed to the stage.
            out_channels: Channels produced by the stage.

        Returns:
            The stage as an ``nn.Sequential``.
        """
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    @staticmethod
    def _match_size(tensor: torch.Tensor, reference: torch.Tensor) -> torch.Tensor:
        """Return *tensor* resized to the height/width of *reference*.

        The tensor is returned untouched when the sizes already agree, so
        inputs that divide evenly through the encoder are not resampled.
        """
        target = reference.shape[-2:]
        if tensor.shape[-2:] == target:
            return tensor
        return nn.functional.interpolate(
            tensor, size=target, mode="bilinear", align_corners=False
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute the segmentation logits for a batch of images.

        Args:
            x: Input batch laid out as (N, C, H, W).

        Returns:
            Tensor of shape (N, num_classes, H, W). No activation is applied
            after the final 1x1 convolution.
        """
        # Encoder path
        x1 = self.encoder["conv1"](x)      # stem
        x2 = self.encoder["maxpool"](x1)
        x3 = self.encoder["layer1"](x2)
        x4 = self.encoder["layer2"](x3)
        x5 = self.encoder["layer3"](x4)
        x6 = self.encoder["layer4"](x5)    # bottleneck

        # Decoder path with skip connections. A stride-2 transposed conv
        # yields exactly twice the input size, which differs from the skip
        # size whenever an encoder stage had to round an odd dimension, so
        # each map is aligned to its skip before concatenation.
        d1 = self._match_size(self.upconv1(x6), x5)
        d1 = torch.cat([d1, x5], dim=1)    # skip connection with layer3

        d2 = self._match_size(self.upconv2(d1), x4)
        d2 = torch.cat([d2, x4], dim=1)    # skip connection with layer2

        d3 = self._match_size(self.upconv3(d2), x3)
        d3 = torch.cat([d3, x3], dim=1)    # skip connection with layer1

        d4 = self.upconv4(d3)
        d5 = self.final_upsample(d4)

        out = self.final_conv(d5)          # final 1x1 convolution
        # Guarantee the prediction has the spatial size of the input.
        return self._match_size(out, x)

# Smoke test for the ResNeXt50 UNet model: push one random 256x256
# three-channel batch through it and print the shape of the result.
x = torch.randn(1, 3, 256, 256)
resnext_model = ResNeXt50_UNet(num_classes=1)
# Run the check in eval mode and without autograd, so that random noise does
# not update the BatchNorm running statistics and no graph is built.
resnext_model.eval()
with torch.no_grad():
    resnext_output = resnext_model(x)
resnext_model.train()  # restore the training mode a new module starts in
print("ResNeXt50_UNet output shape:", resnext_output.shape)



ResNeXt50_UNet output shape: torch.Size([1, 1, 256, 256])
