In [27]:
import segmentation_models_pytorch as smp
import torch

EPS = 1e-10


class SKeMaModel(torch.nn.Module):
    def __init__(self):
        super().__init__()

        self.model = smp.Unet(
            encoder_name="tu-maxvit_tiny_tf_512",
            in_channels=10,
            encoder_weights=None,
        )

        self.register_buffer(
            "per_channel_mean",
            torch.tensor([
                1.93357159e02,
                2.53693333e02,
                1.41648022e02,
                9.99292362e02,
                3.21693919e02,
                6.49704998e-02,
                1.57273007e-01,
                -1.57273007e-01,
                1.82229161e07,
                1.09806622e-01,
            ]).view(1, -1, 1, 1),
        )

        self.register_buffer(
            "per_channel_std",
            torch.tensor([
                1.55697494e02,
                2.12700364e02,
                2.04018106e02,
                1.27588129e03,
                3.77324432e02,
                6.75251176e-01,
                7.32966188e-01,
                7.32966188e-01,
                2.16768826e10,
                4.11232123e-01,
            ]).view(1, -1, 1, 1),
        )

    def forward(self, x):
        # Unpack spectral bands
        blue = x.select(1, 0).unsqueeze(1)
        green = x.select(1, 1).unsqueeze(1)
        red = x.select(1, 2).unsqueeze(1)
        nir = x.select(1, 3).unsqueeze(1)
        re = x.select(1, 4).unsqueeze(1)

        # Compute vegetation indices
        ndvi = self.normalized_index(nir, red)
        gndvi = self.normalized_index(nir, green)
        ndvi_re = self.normalized_index(re, red)

        # Compute other indices
        ndwi = self.normalized_index(green, nir)
        chl_green = (nir / (green + EPS)) - 1  # Chlorophyll Index Green

        # Stack all bands and indices
        x_aug = torch.cat([blue, green, red, nir, re, ndvi, ndwi, gndvi, chl_green, ndvi_re], dim=1)

        x_aug_normalized = (x_aug - self.per_channel_mean) / self.per_channel_std

        return self.model(x_aug_normalized)

    @staticmethod
    def normalized_index(a, b):
        return (a - b) / (a + b + EPS)


model = SKeMaModel()

sample_input = torch.rand((2, 5, 512, 512), device=torch.device("cpu"), requires_grad=False)
model(sample_input)

tensor([[[[ 0.5728,  0.9021, -0.5201,  ...,  0.7888,  1.0070,  0.4125],
          [ 0.1453,  0.0023, -0.9450,  ..., -0.1469, -0.0198, -0.0842],
          [-0.1167, -0.2070, -1.2092,  ..., -0.7366, -0.5809, -0.3558],
          ...,
          [-0.2290, -1.1171, -1.8872,  ..., -0.5359,  0.1297,  0.2123],
          [-0.5878, -0.3411, -0.9095,  ..., -0.1793, -0.2207,  0.3550],
          [-0.0072,  0.3908,  0.1005,  ..., -0.0839,  0.1583, -0.1113]]],


        [[[ 0.4041,  0.8689,  0.6696,  ...,  0.1934,  0.0745,  0.3263],
          [ 0.5460,  0.2423,  0.4991,  ...,  0.1760, -0.1802,  0.0019],
          [-0.0688, -0.1781, -0.0824,  ..., -0.0106, -0.1722, -0.1696],
          ...,
          [-0.6502, -0.3678, -0.9837,  ..., -0.8543,  0.7018, -0.1046],
          [-0.3223,  0.3278, -0.7628,  ..., -0.6914,  0.4229,  0.1085],
          [-0.1094, -0.0042,  0.0209,  ..., -1.1350, -0.1796, -0.1817]]]],
       grad_fn=<ConvolutionBackward0>)

In [28]:
ckpt = torch.load("./Unet_tu-maxvit_tiny_tf_512_20250818_164043.ckpt", map_location="cpu")
state_dict = ckpt["state_dict"]

# Update keys
del state_dict["mean"]
del state_dict["std"]
model.load_state_dict(state_dict, strict=False)
model.eval()

SKeMaModel(
  (model): Unet(
    (encoder): TimmUniversalEncoder(
      (model): FeatureListNet(
        (stem): Stem(
          (conv1): Conv2dSame(10, 64, kernel_size=(3, 3), stride=(2, 2))
          (norm1): BatchNormAct2d(
            64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
            (drop): Identity()
            (act): GELUTanh()
          )
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
        (stages_0): MaxxVitStage(
          (blocks): Sequential(
            (0): MaxxVitBlock(
              (conv): MbConvBlock(
                (shortcut): Downsample2d(
                  (pool): AvgPool2dSame(kernel_size=(2, 2), stride=(2, 2), padding=(0, 0))
                  (expand): Identity()
                )
                (pre_norm): BatchNormAct2d(
                  64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
                  (drop): Identity()
                  (act): Identity()
  

In [40]:
torch.onnx.export(
    model,
    sample_input,
    "./Unet_tu-maxvit_tiny_tf_512_20250818_164043.onnx",
    input_names=["input"],
    output_names=["output"],
    export_params=True,
    external_data=False,  # Store model weights in the model file
    opset_version=15,  # ONNX opset version
    do_constant_folding=True,  # Optimize constants
    verbose=False,
    dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
    # dynamic_shapes={"x": (torch.export.Dim("batch"), 5, 512, 512)},
    dynamo=False,
)

  torch.onnx.export(
  assert condition, message
