In [None]:
import torch
import torch.nn as nn
import torchvision.models as models
from torchvision.models.convnext import convnext_tiny

class ConvNeXt_UNet(nn.Module):
    """U-Net style segmentation network with a torchvision ConvNeXt encoder.

    The encoder re-uses the stem and the four stages of a torchvision ConvNeXt
    model. The decoder upsamples the deepest features step by step back to the
    input resolution, concatenating the encoder features of matching
    resolution (U-Net skip connections) along the channel axis.

    Args:
        num_classes: Number of output channels (one logit map per class).
        backbone_type: ConvNeXt variant: ``'tiny'``, ``'small'``, ``'base'``
            or ``'large'``.
        pretrained: If True (default), load the ImageNet-1k weights for the
            backbone (torchvision may download them); if False, the backbone
            is randomly initialised.

    Raises:
        ValueError: If ``backbone_type`` is not a supported variant.

    Inputs need a spatial size of at least 32x32 (the encoder reduces the
    resolution by 32). Sizes that are not multiples of 32 are supported: the
    decoder features are resized to the skip features, and the logits to the
    input size.
    """

    # backbone_type -> (torchvision builder name, weights enum name, per-stage channels).
    # Names are resolved with getattr(models, ...) at construction time, so nothing
    # from torchvision is looked up until a model is actually built.
    _BACKBONES = {
        'tiny': ('convnext_tiny', 'ConvNeXt_Tiny_Weights', (96, 192, 384, 768)),
        'small': ('convnext_small', 'ConvNeXt_Small_Weights', (96, 192, 384, 768)),
        'base': ('convnext_base', 'ConvNeXt_Base_Weights', (128, 256, 512, 1024)),
        'large': ('convnext_large', 'ConvNeXt_Large_Weights', (192, 384, 768, 1536)),
    }

    def __init__(self, num_classes=1, backbone_type='tiny', pretrained=True):
        super().__init__()
        features, self.feature_channels = self._load_backbone(backbone_type, pretrained)
        self._attach_encoder(features)
        self._build_decoder(num_classes)

    def _load_backbone(self, backbone_type, pretrained):
        """Build the requested ConvNeXt and return ``(features, stage_channels)``."""
        try:
            builder_name, weights_name, channels = self._BACKBONES[backbone_type]
        except (KeyError, TypeError):
            raise ValueError(f"Unsupported backbone type: {backbone_type}") from None
        weights = getattr(models, weights_name).IMAGENET1K_V1 if pretrained else None
        convnext = getattr(models, builder_name)(weights=weights)
        return convnext.features, list(channels)

    def _attach_encoder(self, features):
        """Register the ConvNeXt ``features`` entries under the encoder attribute names.

        torchvision's ConvNeXt ``features`` holds 8 modules laid out as
        ``[stem, stage1, down, stage2, down, stage3, down, stage4]``; there is
        no separate downsampler before stage 1 because the stem already
        reduces the resolution by 4.
        """
        if len(features) != 8:
            raise ValueError(
                f"Expected 8 ConvNeXt feature modules, got {len(features)}"
            )
        self.stem = features[0]              # stride-4 patchify conv + norm -> 1/4
        self.downsample1 = nn.Identity()     # stem already at 1/4; kept for a uniform layout
        self.stage1 = features[1]            # stage 1 at 1/4
        self.downsample2 = features[2]       # -> 1/8
        self.stage2 = features[3]            # stage 2 at 1/8
        self.downsample3 = features[4]       # -> 1/16
        self.stage3 = features[5]            # stage 3 at 1/16
        self.downsample4 = features[6]       # -> 1/32
        self.stage4 = features[7]            # stage 4 at 1/32 (bottleneck)

    def _build_decoder(self, num_classes):
        """Create the upsampling path and the output head from ``self.feature_channels``."""
        c = self.feature_channels
        # Each upconv after the first consumes [upsampled, skip] concatenated,
        # hence the doubled input channel counts.
        self.upconv1 = self.upconv(c[3], c[2])
        self.upconv2 = self.upconv(c[2] * 2, c[1])
        self.upconv3 = self.upconv(c[1] * 2, c[0])
        self.upconv4 = self.upconv(c[0] * 2, c[0] // 2)

        # Final layers: last 2x upsample to the input resolution, then per-pixel logits.
        self.final_upsample = nn.ConvTranspose2d(c[0] // 2, 32, kernel_size=2, stride=2)
        self.final_conv = nn.Conv2d(32, num_classes, kernel_size=1)

    def upconv(self, in_channels, out_channels):
        """Return a decoder block: 2x transposed-conv upsample, then two 3x3 convs.

        BatchNorm is applied only before the block's last ReLU.
        """
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    @staticmethod
    def _resize_to(x, size):
        """Bilinearly resize ``x`` to spatial ``size``; returns ``x`` unchanged if it already matches."""
        if tuple(x.shape[-2:]) == tuple(size):
            return x
        return nn.functional.interpolate(
            x, size=tuple(size), mode='bilinear', align_corners=False
        )

    def forward(self, x):
        """Return per-pixel logits with the same spatial size as ``x``.

        Assumes ``x`` is an NCHW batch with 3 input channels (the ConvNeXt
        stem's input) -- NOTE(review): confirm for custom backbones.
        """
        input_size = x.shape[-2:]

        # Encoder path with ConvNeXt
        x0 = self.stem(x)                                 # 1/4 resolution
        x1 = self.stage1(self.downsample1(x0))            # 1/4, stage 1 features
        x2 = self.stage2(self.downsample2(x1))            # 1/8, stage 2 features
        x3 = self.stage3(self.downsample3(x2))            # 1/16, stage 3 features
        x4 = self.stage4(self.downsample4(x3))            # 1/32, bottleneck

        # Decoder path with skip connections. Odd sizes (input not a multiple
        # of 32) make the 2x upsample miss the skip size by a pixel, so align
        # before concatenating.
        d1 = self.upconv1(x4)
        d1 = torch.cat([self._resize_to(d1, x3.shape[-2:]), x3], dim=1)

        d2 = self.upconv2(d1)
        d2 = torch.cat([self._resize_to(d2, x2.shape[-2:]), x2], dim=1)

        d3 = self.upconv3(d2)
        d3 = torch.cat([self._resize_to(d3, x1.shape[-2:]), x1], dim=1)

        d4 = self.upconv4(d3)                              # 1/2 resolution
        d5 = self.final_upsample(d4)                       # ~input resolution
        out = self.final_conv(d5)                          # 1x1 conv -> class logits

        # Resize the (cheaper, num_classes-channel) logits to the exact input size.
        return self._resize_to(out, input_size)

# Smoke-test the ConvNeXt-Tiny UNet: one forward pass on a random 256x256 image.
# eval() + no_grad() stop the pass from updating BatchNorm running statistics
# with random data and from recording an autograd graph that is never used.
convnext_tiny_unet = ConvNeXt_UNet(num_classes=1, backbone_type='tiny')
convnext_tiny_unet.eval()
x = torch.randn(1, 3, 256, 256)
with torch.no_grad():
    output = convnext_tiny_unet(x)
# A segmentation head must return one map per class at the input resolution.
assert tuple(output.shape[-2:]) == tuple(x.shape[-2:]), output.shape
print("ConvNeXt-Tiny UNet output shape:", output.shape)

