In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F


In [2]:
class DoubleConv(nn.Module):
    """Two stacked 3x3 convolution stages: (Conv2d -> BatchNorm2d -> ReLU) x 2.

    Both convolutions use ``padding=1`` with the default stride of 1, so the
    spatial size of the input is preserved; only the channel count changes,
    from ``in_ch`` to ``out_ch``.

    The convolutions are created with ``bias=False`` because each one feeds
    straight into a BatchNorm2d, which applies its own learnable shift.

    All six layers are held in ``self.conv`` (an ``nn.Sequential``), so the
    state-dict keys are ``conv.0`` through ``conv.5``.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        stages = []
        # First stage maps in_ch -> out_ch, second stage maps out_ch -> out_ch.
        for stage_in in (in_ch, out_ch):
            stages.extend([
                nn.Conv2d(stage_in, out_ch, 3, padding=1, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ])
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through both conv/BN/ReLU stages and return the result."""
        out = self.conv(x)
        return out


In [3]:


class GlacierNet(nn.Module):
    """
    Custom U-Net for 5-Channel Satellite Imagery.
    Lightweight and Fast.

    Four 2x max-pool stages (three encoder stages plus the bottleneck) reduce
    the input to 1/16 resolution. Four transposed-convolution stages bring it
    back, each fused with the matching encoder feature map (skip connection).

    Args:
        n_channels: Number of input channels (default 5).
        n_classes: Number of output channels produced per pixel (default 4).
            The output is raw logits; no activation is applied.
        base_ch: Channel width of the first encoder stage. Each deeper stage
            doubles it (default 32, giving widths 32/64/128/256/512).

    Input is a batched tensor (N, n_channels, H, W). The output is
    (N, n_classes, H, W), with the same spatial size as the input. H and W must
    be at least 16 so that the deepest pooling stage is non-empty.
    """
    def __init__(self, n_channels=5, n_classes=4, base_ch=32):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.base_ch = base_ch

        # Stage widths: base_ch, 2*base_ch, 4*base_ch, 8*base_ch, 16*base_ch.
        c1, c2, c3, c4, c5 = (base_ch * 2 ** i for i in range(5))

        # --- Encoder ---
        self.inc = DoubleConv(n_channels, c1)
        self.down1 = nn.Sequential(nn.MaxPool2d(2), DoubleConv(c1, c2))
        self.down2 = nn.Sequential(nn.MaxPool2d(2), DoubleConv(c2, c3))
        self.down3 = nn.Sequential(nn.MaxPool2d(2), DoubleConv(c3, c4))

        # --- Bottleneck ---
        self.bot = nn.Sequential(nn.MaxPool2d(2), DoubleConv(c4, c5))

        # --- Decoder ---
        # Each DoubleConv input is the upsampled map concatenated with the
        # same-width encoder skip, hence twice its output width.
        self.up1 = nn.ConvTranspose2d(c5, c4, kernel_size=2, stride=2)
        self.conv1 = DoubleConv(2 * c4, c4)

        self.up2 = nn.ConvTranspose2d(c4, c3, kernel_size=2, stride=2)
        self.conv2 = DoubleConv(2 * c3, c3)

        self.up3 = nn.ConvTranspose2d(c3, c2, kernel_size=2, stride=2)
        self.conv3 = DoubleConv(2 * c2, c2)

        self.up4 = nn.ConvTranspose2d(c2, c1, kernel_size=2, stride=2)
        self.conv4 = DoubleConv(2 * c1, c1)

        # --- Output ---
        self.outc = nn.Conv2d(c1, n_classes, kernel_size=1)

    @staticmethod
    def _up_and_merge(x, skip, up, conv):
        """Upsample ``x`` with ``up``, align it to ``skip``, concatenate, and fuse with ``conv``.

        MaxPool2d floors odd spatial sizes, so a 2x upsample can come out one
        pixel smaller than the skip map. In that case ``x`` is resized to the
        skip's height/width with bilinear interpolation before concatenation.
        """
        x = up(x)
        # Compare spatial dims only; batch and channel sizes always agree here.
        if x.shape[2:] != skip.shape[2:]:
            x = F.interpolate(x, size=skip.shape[2:], mode='bilinear', align_corners=True)
        # Skip features first, upsampled features second (channel order conv expects).
        return conv(torch.cat([skip, x], dim=1))

    def forward(self, x):
        """Return per-pixel class logits of shape (N, n_classes, H, W) for input ``x``."""
        # Down
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.bot(x4)

        # Up with skip connections, deepest first
        x = self._up_and_merge(x5, x4, self.up1, self.conv1)
        x = self._up_and_merge(x, x3, self.up2, self.conv2)
        x = self._up_and_merge(x, x2, self.up3, self.conv3)
        x = self._up_and_merge(x, x1, self.up4, self.conv4)

        logits = self.outc(x)
        return logits

if __name__ == "__main__":
    # Sanity check: build the model, report its size, and verify that a
    # forward pass keeps the spatial resolution and yields n_classes channels.
    torch.manual_seed(0)
    model = GlacierNet(n_channels=5, n_classes=4)
    # Count params
    n_params = sum(p.numel() for p in model.parameters())
    print(f" GlacierNet Initialized. Parameters: {n_params:,}")

    # eval() makes BatchNorm use (not update) its running statistics, and
    # no_grad() avoids building an autograd graph for the 512x512 activations.
    model.eval()
    dummy_input = torch.randn(1, 5, 512, 512)
    with torch.no_grad():
        output = model(dummy_input)
    print(f" Input shape: {dummy_input.shape} -> Output shape: {output.shape}")

    expected_shape = (dummy_input.shape[0], model.n_classes, *dummy_input.shape[2:])
    assert output.shape == expected_shape, f"unexpected output shape {tuple(output.shape)}"

 GlacierNet Initialized. Parameters: 7,763,716
 Input shape: torch.Size([1, 5, 512, 512]) -> Output shape: torch.Size([1, 4, 512, 512])
