## Export S2-only model

In [1]:
import segmentation_models_pytorch as smp
import torch

EPS = 1e-10


class SKeMaModel(torch.nn.Module):
    def __init__(self):
        super().__init__()

        self.model = smp.Unet(
            encoder_name="tu-maxvit_tiny_tf_512",
            in_channels=10,
            encoder_weights=None,
        )

        self.register_buffer(
            "per_channel_mean",
            torch.tensor([
                1.93357159e02,
                2.53693333e02,
                1.41648022e02,
                9.99292362e02,
                3.21693919e02,
                6.49704998e-02,
                1.57273007e-01,
                -1.57273007e-01,
                1.82229161e07,
                1.09806622e-01,
            ]).view(1, -1, 1, 1),
        )

        self.register_buffer(
            "per_channel_std",
            torch.tensor([
                1.55697494e02,
                2.12700364e02,
                2.04018106e02,
                1.27588129e03,
                3.77324432e02,
                6.75251176e-01,
                7.32966188e-01,
                7.32966188e-01,
                2.16768826e10,
                4.11232123e-01,
            ]).view(1, -1, 1, 1),
        )

    def forward(self, x):
        # Unpack spectral bands
        blue = x.select(1, 0).unsqueeze(1)
        green = x.select(1, 1).unsqueeze(1)
        red = x.select(1, 2).unsqueeze(1)
        nir = x.select(1, 3).unsqueeze(1)
        re = x.select(1, 4).unsqueeze(1)

        # Compute vegetation indices
        ndvi = self.normalized_index(nir, red)
        gndvi = self.normalized_index(nir, green)
        ndvi_re = self.normalized_index(re, red)

        # Compute other indices
        ndwi = self.normalized_index(green, nir)
        chl_green = (nir / (green + EPS)) - 1  # Chlorophyll Index Green

        # Stack all bands and indices
        x_aug = torch.cat([blue, green, red, nir, re, ndvi, ndwi, gndvi, chl_green, ndvi_re], dim=1)

        x_aug_normalized = (x_aug - self.per_channel_mean) / self.per_channel_std

        return self.model(x_aug_normalized)

    @staticmethod
    def normalized_index(a, b):
        return (a - b) / (a + b + EPS)


model = SKeMaModel()

sample_input = torch.rand((2, 5, 512, 512), device=torch.device("cpu"), requires_grad=False)
model(sample_input)

  from .autonotebook import tqdm as notebook_tqdm


tensor([[[[ 0.7303,  1.1697,  0.8220,  ...,  0.7788,  1.0340,  0.3348],
          [ 0.0057, -0.1139, -0.4681,  ..., -0.2208,  0.1807,  0.0909],
          [ 0.1513,  0.2111,  0.5061,  ...,  0.2507, -0.2675, -0.5306],
          ...,
          [-0.0229,  0.2071,  1.2011,  ...,  0.4269,  0.0604, -0.2384],
          [-0.0645,  0.4666,  0.1669,  ...,  1.2136,  0.4438,  0.3940],
          [-0.5043, -0.8172, -1.1685,  ..., -0.0964, -0.0730, -0.3842]]],


        [[[ 0.5609,  1.3146,  1.7041,  ...,  0.9803,  1.2356,  0.5044],
          [-0.4189,  0.2561,  0.1056,  ..., -0.3059, -0.5648, -0.4129],
          [ 0.2414,  0.7687, -0.5919,  ...,  0.1580, -0.2497, -0.3917],
          ...,
          [ 0.5073,  0.2510,  0.2949,  ..., -0.8702, -0.5213, -0.0911],
          [ 0.2164,  0.0224, -0.8488,  ...,  0.2523, -0.2204, -0.2463],
          [-0.0543, -0.1873, -0.5646,  ..., -0.7739, -0.4300, -0.2990]]]],
       grad_fn=<ConvolutionBackward0>)

In [2]:
ckpt = torch.load("./Unet_tu-maxvit_tiny_tf_512_20250818_164043.ckpt", map_location="cpu")
state_dict = ckpt["state_dict"]

# Update keys
del state_dict["mean"]
del state_dict["std"]
model.load_state_dict(state_dict, strict=False)
model.eval()

SKeMaModel(
  (model): Unet(
    (encoder): TimmUniversalEncoder(
      (model): FeatureListNet(
        (stem): Stem(
          (conv1): Conv2dSame(10, 64, kernel_size=(3, 3), stride=(2, 2))
          (norm1): BatchNormAct2d(
            64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
            (drop): Identity()
            (act): GELUTanh()
          )
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
        (stages_0): MaxxVitStage(
          (blocks): Sequential(
            (0): MaxxVitBlock(
              (conv): MbConvBlock(
                (shortcut): Downsample2d(
                  (pool): AvgPool2dSame(kernel_size=(2, 2), stride=(2, 2), padding=(0, 0))
                  (expand): Identity()
                )
                (pre_norm): BatchNormAct2d(
                  64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
                  (drop): Identity()
                  (act): Identity()
  

In [3]:
torch.onnx.export(
    model,
    sample_input,
    "./Unet_tu-maxvit_tiny_tf_512_20250818_164043.onnx",
    input_names=["input"],
    output_names=["output"],
    export_params=True,
    external_data=False,  # Store model weights in the model file
    opset_version=15,  # ONNX opset version
    do_constant_folding=True,  # Optimize constants
    verbose=False,
    dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
    # dynamic_shapes={"x": (torch.export.Dim("batch"), 5, 512, 512)},
    dynamo=False,
)

  torch.onnx.export(
  assert condition, message


## Export S2 + Bathymetry + Substrate model

In [4]:
EPS = 1e-10


class SKeMaBathyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()

        self.model = smp.Unet(
            encoder_name="tu-maxvit_tiny_tf_512",
            in_channels=12,
            encoder_weights=None,
        )

        self.register_buffer(
            "per_channel_mean",
            torch.tensor([
                1.93357159e02,
                2.53693333e02,
                1.41648022e02,
                9.99292362e02,
                3.21693919e02,
                1.30867292e00,
                2.63550136e00,
                6.49704998e-02,
                1.57273007e-01,
                -1.57273007e-01,
                1.82229161e07,
                1.09806622e-01,
            ]).view(1, -1, 1, 1),
        )

        self.register_buffer(
            "per_channel_std",
            torch.tensor([
                1.55697494e02,
                2.12700364e02,
                2.04018106e02,
                1.27588129e03,
                3.77324432e02,
                1.33938435e00,
                2.14640498e02,
                6.75251176e-01,
                7.32966188e-01,
                7.32966188e-01,
                2.16768826e10,
                4.11232123e-01,
            ]).view(1, -1, 1, 1),
        )

    def forward(self, x):
        # Unpack spectral bands
        blue = x.select(1, 0).unsqueeze(1)
        green = x.select(1, 1).unsqueeze(1)
        red = x.select(1, 2).unsqueeze(1)
        nir = x.select(1, 3).unsqueeze(1)
        re = x.select(1, 4).unsqueeze(1)
        substrate = x.select(1, 5).unsqueeze(1)
        bathymetry = x.select(1, 6).unsqueeze(1)

        # Compute vegetation indices
        ndvi = self.normalized_index(nir, red)
        gndvi = self.normalized_index(nir, green)
        ndvi_re = self.normalized_index(re, red)

        # Compute other indices
        ndwi = self.normalized_index(green, nir)
        chl_green = (nir / (green + EPS)) - 1  # Chlorophyll Index Green

        # Stack all bands and indices
        x_aug = torch.cat(
            [blue, green, red, nir, re, substrate, bathymetry, ndvi, ndwi, gndvi, chl_green, ndvi_re], dim=1
        )

        x_aug_normalized = (x_aug - self.per_channel_mean) / self.per_channel_std

        return self.model(x_aug_normalized)

    @staticmethod
    def normalized_index(a, b):
        return (a - b) / (a + b + EPS)


model = SKeMaBathyModel()

sample_input = torch.rand((2, 7, 512, 512), device=torch.device("cpu"), requires_grad=False)
model(sample_input)

tensor([[[[ 0.4813,  0.0931,  0.0299,  ...,  1.1187,  1.0073,  0.2813],
          [ 0.5153, -0.0539,  0.6094,  ...,  1.4073,  1.8217,  0.7350],
          [ 0.7110, -0.1710,  0.7305,  ...,  3.1022,  2.5016,  0.0418],
          ...,
          [ 0.2322, -0.3572, -0.6058,  ...,  0.0209,  0.8186,  0.3770],
          [ 0.7268,  0.2270, -0.0429,  ..., -0.0917,  1.0067,  0.5032],
          [-0.2261,  0.0787,  0.0792,  ...,  0.1037,  0.0944,  0.1754]]],


        [[[ 0.3767,  0.7756,  0.5304,  ...,  0.7425,  1.2530,  0.4901],
          [-0.0702, -0.5960,  0.2289,  ...,  0.7328,  1.0064,  1.0822],
          [ 0.5095, -0.2555, -0.3810,  ..., -0.4529,  0.7583,  1.3517],
          ...,
          [ 1.3123,  0.5650,  0.9073,  ...,  0.8774,  0.4388,  0.0916],
          [ 0.7773,  1.8775,  1.1773,  ...,  0.6576,  0.6688, -0.2499],
          [ 0.2814, -0.4011, -0.5158,  ...,  0.3432,  0.0538, -0.2632]]]],
       grad_fn=<ConvolutionBackward0>)

In [5]:
ckpt = torch.load("./Unet_tu-maxvit_tiny_tf_512_20250714_222203.ckpt", map_location="cpu")
state_dict = ckpt["state_dict"]

# Update keys
del state_dict["mean"]
del state_dict["std"]
model.load_state_dict(state_dict, strict=False)
model.eval()

SKeMaBathyModel(
  (model): Unet(
    (encoder): TimmUniversalEncoder(
      (model): FeatureListNet(
        (stem): Stem(
          (conv1): Conv2dSame(12, 64, kernel_size=(3, 3), stride=(2, 2))
          (norm1): BatchNormAct2d(
            64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
            (drop): Identity()
            (act): GELUTanh()
          )
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
        (stages_0): MaxxVitStage(
          (blocks): Sequential(
            (0): MaxxVitBlock(
              (conv): MbConvBlock(
                (shortcut): Downsample2d(
                  (pool): AvgPool2dSame(kernel_size=(2, 2), stride=(2, 2), padding=(0, 0))
                  (expand): Identity()
                )
                (pre_norm): BatchNormAct2d(
                  64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
                  (drop): Identity()
                  (act): Identity

In [6]:
torch.onnx.export(
    model,
    sample_input,
    "./Unet_tu-maxvit_tiny_tf_512_20250714_222203.onnx",
    input_names=["input"],
    output_names=["output"],
    export_params=True,
    external_data=False,  # Store model weights in the model file
    opset_version=15,  # ONNX opset version
    do_constant_folding=True,  # Optimize constants
    verbose=False,
    dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
    # dynamic_shapes={"x": (torch.export.Dim("batch"), 5, 512, 512)},
    dynamo=False,
)

  torch.onnx.export(
