In [1]:
from timm.models.resnet import resnet101

In [35]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from typing import Literal


class ConvBlock(nn.Module):
    """A bias-free 2D convolution optionally followed by batch norm and an activation.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by the convolution.
        kernel_size: Convolution kernel size.
        stride: Convolution stride.
        padding: Zero padding applied on both sides of each spatial dimension.
        batch_norm: When truthy, a ``BatchNorm2d`` is applied after the convolution.
        activation: Name of the activation applied last, or ``None`` for no activation.

    Raises:
        ValueError: If ``activation`` is neither ``None`` nor one of the supported names.
    """

    # Factories rather than instances, so every block owns its own activation module.
    _ACTIVATION_FACTORIES = {
        'relu': lambda: nn.ReLU(inplace=True),
        'gelu': nn.GELU,
        'leaky_relu': lambda: nn.LeakyReLU(inplace=True),
        'sigmoid': nn.Sigmoid,
        # Softmax is taken over the channel dimension.
        'softmax': lambda: nn.Softmax(dim=1),
        'tanh': nn.Tanh,
        'swish': nn.SiLU,
    }

    def __init__(self,
        in_channels: int,
        out_channels: int,
        kernel_size: int=3,
        stride: int=1,
        padding: int=1,
        batch_norm: bool=True,
        activation: Literal['relu', 'gelu', 'leaky_relu', 'sigmoid', 'softmax', 'tanh', 'swish', None]='relu',
    ):
        super(ConvBlock, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.batch_norm = batch_norm
        self.activation = activation

        # The convolution never carries a bias term, regardless of `batch_norm`.
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False)
        if batch_norm:
            self.bn = nn.BatchNorm2d(out_channels)

        if activation is not None:
            if not isinstance(activation, str) or activation not in self._ACTIVATION_FACTORIES:
                supported = ', '.join(f'"{name}"' for name in self._ACTIVATION_FACTORIES)
                raise ValueError(
                    f'Invalid value for `activation`: {activation}. '
                    f'Supported values are [{supported}].'
                )
            self.act = self._ACTIVATION_FACTORIES[activation]()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply convolution, then the optional batch norm and activation, in that order."""
        x = self.conv(x)
        if self.batch_norm:
            x = self.bn(x)

        if self.activation is not None:
            x = self.act(x)

        return x


class ASPP(nn.Module):
    """Atrous Spatial Pyramid Pooling head in the style of DeepLab v3+.

    Runs a 1x1 convolution, one dilated 3x3 convolution per entry of
    ``dilation_rates`` and an image-level pooling branch in parallel on the same
    input, concatenates the results along the channel axis and projects them
    back to ``out_channels``.
    """
    def __init__(self, in_channels: int, out_channels: int, dilation_rates: tuple[int, ...]) -> None:
        super().__init__()

        def with_bn_relu(conv: nn.Conv2d) -> list[nn.Module]:
            # Every branch shares the same conv -> BN -> ReLU tail.
            return [conv, nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True)]

        # 1x1 conv branch
        self.conv_1x1 = nn.Sequential(
            *with_bn_relu(nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False))
        )

        # Parallel atrous branches; padding equals the dilation so the 3x3
        # convolutions keep the spatial size unchanged.
        atrous = []
        for rate in dilation_rates:
            dilated = nn.Conv2d(
                in_channels, out_channels, kernel_size=3,
                padding=rate, dilation=rate, bias=False,
            )
            atrous.append(nn.Sequential(*with_bn_relu(dilated)))
        self.branches = nn.ModuleList(atrous)

        # Image-level branch: global average pool followed by a 1x1 conv.
        self.image_pool = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            *with_bn_relu(nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)),
        )

        # Fuse: the 1x1 branch + the pooled branch + one branch per dilation rate.
        n_branches = len(dilation_rates) + 2
        self.project = nn.Sequential(
            *with_bn_relu(nn.Conv2d(out_channels * n_branches, out_channels, kernel_size=1, bias=False)),
            nn.Dropout(0.5),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the fused pyramid features at the spatial size of ``x``."""
        spatial = x.shape[-2:]

        pyramid = [self.conv_1x1(x)]
        for atrous_branch in self.branches:
            pyramid.append(atrous_branch(x))

        # The pooled features are 1x1 spatially; stretch them back to the input size.
        pooled = F.interpolate(self.image_pool(x), size=spatial, mode="bilinear", align_corners=False)
        pyramid.append(pooled)

        return self.project(torch.cat(pyramid, dim=1))


class DepthwiseSeparableConv(nn.Module):
    """Per-channel spatial convolution, then a 1x1 channel-mixing convolution.

    The pointwise output is passed through batch normalization and ReLU.
    ``bias`` applies to both convolutions.
    """
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        stride: int = 1,
        padding: int = 1,
        dilation: int = 1,
        bias: bool = False,
    ) -> None:
        super().__init__()
        # groups == in_channels gives every input channel its own spatial filter.
        self.depthwise = nn.Conv2d(
            in_channels=in_channels,
            out_channels=in_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=in_channels,
            bias=bias,
        )
        # The 1x1 convolution is the only place where channels are mixed.
        self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=bias)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply depthwise conv, pointwise conv, batch norm and ReLU in sequence."""
        mixed = self.pointwise(self.depthwise(x))
        return self.relu(self.bn(mixed))


class Decoder(nn.Module):
    """DeepLab v3+ decoder that fuses low- and high-level features.

    Args:
        low_level_in: Channel count of the low-level feature map.
        low_level_out: Channel count the low-level features are reduced to (e.g. 48).
        num_classes: Number of output channels of the final 1x1 classifier.
        high_level_in: Channel count of the high-level feature map that is
            concatenated with the reduced low-level features. Defaults to 256,
            the value that was previously hard-coded.
    """
    def __init__(
        self,
        low_level_in: int,
        low_level_out: int,
        num_classes: int,
        high_level_in: int = 256,
    ) -> None:
        super().__init__()
        # Reduce low-level feature channels to low_level_out with a 1x1 ConvBlock
        self.reduce_low = ConvBlock(low_level_in, low_level_out, kernel_size=1, padding=0, batch_norm=True, activation='relu')
        # Two separable conv layers to refine the concatenated features
        self.refine = nn.Sequential(
            DepthwiseSeparableConv(low_level_out + high_level_in, 256, kernel_size=3, padding=1),
            DepthwiseSeparableConv(256, 256, kernel_size=3, padding=1),
        )
        # Final classifier
        self.classifier = nn.Conv2d(256, num_classes, kernel_size=1)

    def forward(self, low_level_feat: torch.Tensor, high_level_feat: torch.Tensor) -> torch.Tensor:
        """Return per-class scores at the spatial size of ``low_level_feat``."""
        # Resize the high-level features to the low-level spatial size (whatever
        # the ratio between the two happens to be) so they can be concatenated.
        high = nn.functional.interpolate(high_level_feat, size=low_level_feat.shape[-2:], mode="bilinear", align_corners=False)
        low = self.reduce_low(low_level_feat)
        x = torch.cat([low, high], dim=1)
        x = self.refine(x)
        return self.classifier(x)


class DeepLabV3Plus(nn.Module):
    """
    DeepLab v3+ for semantic segmentation.
    - backbone: module returning (low_level_feat, high_level_feat); the ASPP and
      decoder are built for 2048 high-level and 256 low-level channels
    - num_classes: # of segmentation classes
    - aspp_out: # of channels produced by the ASPP head
    - aspp_rates: dilation rates for ASPP

    The forward pass returns probabilities, not logits: a sigmoid is applied when
    num_classes == 1, otherwise a softmax over the channel dimension.
    """
    def __init__(
        self,
        backbone: nn.Module,
        num_classes: int,
        aspp_out: int = 256,
        aspp_rates: tuple[int, ...] = (12, 24, 36),
    ) -> None:
        super().__init__()
        self.backbone = backbone
        # ASPP on high-level features
        self.aspp = ASPP(in_channels=2048, out_channels=aspp_out, dilation_rates=aspp_rates)
        # Decoder fusing ASPP and low-level (conv2) features.
        # NOTE(review): the decoder is created with its default expectation of
        # 256 high-level channels, so an `aspp_out` other than 256 will fail with
        # a channel mismatch in the forward pass.
        self.decoder = Decoder(low_level_in=256, low_level_out=48, num_classes=num_classes)

        if num_classes == 1:
            self.activation = nn.Sigmoid()
        else:
            self.activation = nn.Softmax(dim=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class probabilities with the same height and width as ``x``."""
        # Remember the input H x W: the final resize must reproduce it exactly,
        # including for non-square inputs or sizes not divisible by 4.
        input_size = x.shape[-2:]
        low_level, high_level = self.backbone(x)
        x = self.aspp(high_level)
        x = self.decoder(low_level, x)
        # Final upsample to input resolution
        x = nn.functional.interpolate(x, size=input_size, mode="bilinear", align_corners=False)
        return self.activation(x)

In [32]:
class ResNetBackbone(nn.Module):
    """
    Wraps a ResNet-101 to output (low_level_feat, high_level_feat).

    - output_stride=16: the stride of layer4 is removed and its 3x3 convs use
      dilation 2.
    - output_stride=8: the strides of layer3 and layer4 are removed; their 3x3
      convs use dilation 2 and 4 respectively.
    - any other value leaves the strides and dilations untouched.

    - pretrained: a bool is forwarded to ``resnet101(pretrained=...)``. Any other
      non-None value is handed to ``load_pth`` and the result is loaded strictly
      into a freshly built network.
    - in_channels: when different from 3, ``conv1`` is replaced by a newly
      initialized 7x7 convolution, so that layer does not keep pretrained weights
      coming from ``resnet101(pretrained=True)``.
    """
    def __init__(self, output_stride: int = 16, pretrained: bool = True, in_channels=4) -> None:
        super().__init__()

        if isinstance(pretrained, bool):
            resnet = resnet101(pretrained=pretrained)
            checkpoint = None
        else:
            resnet = resnet101()
            checkpoint = pretrained

        # Swap the stem before loading a checkpoint so that a checkpoint saved
        # with the same number of input channels matches under strict=True.
        if in_channels != 3:
            resnet.conv1 = nn.Conv2d(in_channels, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)

        if checkpoint is not None:
            # NOTE(review): `load_pth` is not defined in the visible notebook
            # cells; confirm it is in scope before using this path.
            resnet.load_state_dict(load_pth(checkpoint), strict=True)

        # Modify strides/dilations for atrous convolution. Each layer whose stride
        # is removed doubles the dilation needed to keep the receptive field, so
        # the rates accumulate: layer3 -> 2, layer4 -> 4 when both are unstrided.
        if output_stride == 16:
            dilated_layers = [(resnet.layer4, 2)]
        elif output_stride == 8:
            dilated_layers = [(resnet.layer3, 2), (resnet.layer4, 4)]
        else:
            dilated_layers = []

        for layer, rate in dilated_layers:
            layer[0].conv2.stride = (1, 1)
            layer[0].downsample[0].stride = (1, 1)
            for block in layer:
                # padding == dilation keeps the 3x3 conv size-preserving
                block.conv2.dilation = (rate, rate)
                block.conv2.padding = (rate, rate)

        self.low_level_features = nn.Sequential(
            resnet.conv1, resnet.bn1, resnet.act1, resnet.maxpool, resnet.layer1
        )
        self.high_level_features = nn.Sequential(
            resnet.layer2, resnet.layer3, resnet.layer4
        )

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Return the layer1 output and the layer4 output computed from it."""
        low_level_feat = self.low_level_features(x)
        return low_level_feat, self.high_level_features(low_level_feat)

In [33]:
# Build DeepLab v3+ on a ResNet-101 backbone (output stride 16) for 4-channel
# inputs and 5 output classes. Because in_channels != 3, the backbone replaces
# conv1 with a newly initialized layer.
model = DeepLabV3Plus(
    backbone=ResNetBackbone(output_stride=16, pretrained=True, in_channels=4),
    num_classes=5,
)

In [34]:
# Smoke test: run a random batch of two 4-channel 256x256 inputs through the
# model and display the output shape.
x = torch.randn(2, 4, 256, 256)
y = model(x)
y.shape

torch.Size([2, 5, 256, 256])

In [None]:
class UNetDownBlock(nn.Module):
    """Encoder stage: two 3x3 ConvBlocks (BN + ReLU) followed by 2x2 max pooling."""

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        first = ConvBlock(in_channels, out_channels, batch_norm=True, activation='relu')
        second = ConvBlock(out_channels, out_channels, batch_norm=True, activation='relu')
        self.conv = nn.Sequential(first, second)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(pooled, features)``; ``features`` is the pre-pooling skip tensor."""
        features = self.conv(x)
        return self.pool(features), features

class UNetUpBlock(nn.Module):
    """Decoder stage: resize to the skip tensor, concatenate, apply two ConvBlocks.

    ``in_channels`` is the channel count of the incoming (coarser) tensor and
    ``out_channels`` is both the channel count expected of the skip connection
    and the channel count produced by the stage.
    """

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels

        # The first ConvBlock sees the skip channels plus the upsampled channels.
        fused_channels = in_channels + out_channels
        self.conv = nn.Sequential(
            ConvBlock(fused_channels, out_channels, batch_norm=True, activation='relu'),
            ConvBlock(out_channels, out_channels, batch_norm=True, activation='relu'),
        )

    def forward(self, x: torch.Tensor, skip_connection: torch.Tensor) -> torch.Tensor:
        """Fuse ``x`` with ``skip_connection`` at the skip tensor's spatial size."""
        target_hw = skip_connection.shape[2:]
        resized = F.interpolate(x, size=target_hw, mode='bilinear', align_corners=False)
        # Skip features come first along the channel axis.
        return self.conv(torch.cat((skip_connection, resized), dim=1))

class UNet(nn.Module):
    """U-Net built from ``UNetDownBlock`` / ``UNetUpBlock`` stages.

    Args:
        in_channels: Number of channels of the input tensor.
        num_classes: Number of output channels of the final 1x1 classifier.
        channel_widths: Channel count per level. Every entry except the last is
            an encoder stage; the last one is the bottleneck width. At least two
            entries are required. The default gives four encoder stages.
        activation: When truthy, the output goes through a sigmoid
            (``num_classes == 1``) or a softmax over the channel dimension;
            otherwise the raw classifier output is returned.

    Raises:
        ValueError: If ``channel_widths`` has fewer than two entries.
    """
    def __init__(self,
        in_channels: int=4,
        num_classes: int=8,
        channel_widths: list[int]=[64, 128, 256, 512, 1024],
        activation: bool=True
    ):

        super(UNet, self).__init__()

        # Copy so the instance never aliases the shared default list or a list
        # the caller may later mutate.
        widths = list(channel_widths)
        if len(widths) < 2:
            raise ValueError(
                f'`channel_widths` needs at least 2 entries (one encoder stage and the bottleneck), got {len(widths)}.'
            )

        self.in_channels = in_channels
        self.num_classes = num_classes
        self.channel_widths = widths

        # One encoder stage per width except the last; each stage consumes the
        # previous stage's width (the network input for the first one).
        encoder_inputs = [in_channels] + widths[:-2]
        self.encoder_blocks = nn.ModuleList([
            UNetDownBlock(c_in, c_out) for c_in, c_out in zip(encoder_inputs, widths[:-1])
        ])

        self.bottleneck = nn.Sequential(
            ConvBlock(widths[-2], widths[-1], batch_norm=True, activation='relu'),
            ConvBlock(widths[-1], widths[-1], batch_norm=True, activation='relu')
        )

        # Decoder mirrors the encoder, walking the widths from deepest to shallowest.
        self.decoder_blocks = nn.ModuleList([
            UNetUpBlock(widths[i], widths[i-1]) for i in range(len(widths) - 1, 0, -1)
        ])

        self.classifier = nn.Conv2d(widths[0], num_classes, kernel_size=1)
        if num_classes == 1 and activation:
            self.activation = nn.Sigmoid()
        elif activation:
            self.activation = nn.Softmax(dim=1)
        else:
            self.activation = nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run encoder, bottleneck and decoder, then classify and activate."""
        skip_connections = []
        for block in self.encoder_blocks:
            x, s = block(x)
            skip_connections.append(s)

        x = self.bottleneck(x)

        # Skips are consumed last-in-first-out so each decoder stage is paired
        # with the encoder stage at the same depth.
        for block in self.decoder_blocks:
            x = block(x, skip_connections.pop())

        x = self.classifier(x)
        return self.activation(x)

# Smoke test: build a 4-channel, 5-class UNet with the default channel widths,
# run a random batch of two 256x256 inputs through it and display the output shape.
unet = UNet(in_channels=4, num_classes=5)
x = torch.randn(2, 4, 256, 256)
y_unet = unet(x)
y_unet.shape

torch.Size([2, 5, 256, 256])

In [62]:
# Count every element of every parameter tensor registered on the UNet.
total_params = sum(p.numel() for p in unet.parameters())
print(f'Total parameters in UNet: {total_params}')

Total parameters in UNet: 31385669


In [63]:
from segmentation_models_pytorch import Segformer

In [64]:
# Instantiate Segformer with the "mit_b5" encoder; all other arguments are left
# at the library defaults.
# NOTE: this rebinds `model`, which an earlier cell assigned to the
# DeepLabV3Plus instance.
model = Segformer(
    encoder_name="mit_b5"
)

config.json:   0%|          | 0.00/135 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/328M [00:00<?, ?B/s]

In [67]:
# Count every element of every parameter tensor of the current `model`
# (overwrites the `total_params` value computed for the UNet).
total_params = sum(p.numel() for p in model.parameters())
print(f'Total parameters in Segformer mit_b5: {total_params}')

Total parameters in Segformer mit_b5: 81969089


In [2]:
import geopandas as gpd

# Load the samples table; the path is relative to the notebook's working directory.
gdf = gpd.read_parquet('../data/cpb_lc/samples.par')

In [4]:
# Display the dtype of each column of the samples table.
gdf.dtypes

geometry     geometry
id             object
year            int64
state          object
naip_path      object
s2_path        object
lc_path        object
split          object
fold            int64
dtype: object