In [None]:
import segmentation_models_pytorch as smp
import torch

EPS = 1e-10


class SKeMaModel(torch.nn.Module):
    """U-Net wrapper that augments the input bands with spectral indices.

    ``forward`` reads the first five entries along dimension 1 of its input
    (named blue, green, red, nir and ``re`` in the code; ``re`` is presumably
    the red-edge band -- confirm against the data source), appends five derived
    index channels, standardizes the resulting ten channels with the fixed
    per-channel statistics below, and passes them to an ``smp.Unet`` built
    with ``in_channels=10``.
    """

    # Statistics are applied by broadcasting over dimension 1, so entry ``i``
    # belongs to channel ``i`` of the stacked tensor built in ``forward``:
    # blue, green, red, nir, re, ndvi, ndwi, gndvi, chl_green, ndvi_re.
    _CHANNEL_MEANS = (
        1.93357159e02,
        2.53693333e02,
        1.41648022e02,
        9.99292362e02,
        3.21693919e02,
        6.49704998e-02,
        1.57273007e-01,
        -1.57273007e-01,
        1.82229161e07,
        1.09806622e-01,
    )
    _CHANNEL_STDS = (
        1.55697494e02,
        2.12700364e02,
        2.04018106e02,
        1.27588129e03,
        3.77324432e02,
        6.75251176e-01,
        7.32966188e-01,
        7.32966188e-01,
        2.16768826e10,
        4.11232123e-01,
    )

    def __init__(self):
        """Build the U-Net (no pretrained encoder weights) and register the statistics."""
        super().__init__()

        self.model = smp.Unet(
            in_channels=10,
            encoder_name="tu-maxvit_tiny_tf_512",
            encoder_weights=None,
        )

        # Buffers are shaped (1, C, 1, 1) so they broadcast against 4-D inputs.
        statistics = (
            ("per_channel_mean", self._CHANNEL_MEANS),
            ("per_channel_std", self._CHANNEL_STDS),
        )
        for buffer_name, values in statistics:
            self.register_buffer(buffer_name, torch.tensor(values).view(1, -1, 1, 1))

    def forward(self, x):
        """Preprocess ``x``, derive the index channels and run the U-Net.

        Args:
            x: Tensor with at least five entries along dimension 1; the
                buffers broadcast as (1, C, 1, 1), i.e. a 4-D layout.

        Returns:
            Whatever ``self.model`` returns for the standardized 10-channel tensor.
        """
        # Remove the fixed offset of 1000 and floor the result at zero.
        # NOTE(review): the offset is presumably a property of the raw pixel
        # encoding -- confirm against the data source.
        shifted = torch.clamp(x - 1000, min=0)

        # Keep a singleton channel dimension on every band so they can be
        # concatenated along dimension 1 again.
        blue, green, red, nir, red_edge = (
            shifted.select(1, band).unsqueeze(1) for band in range(5)
        )

        index = self.normalized_index
        channels = [
            blue,
            green,
            red,
            nir,
            red_edge,
            index(nir, red),  # ndvi
            index(green, nir),  # ndwi
            index(nir, green),  # gndvi
            nir / (green + EPS) - 1,  # Chlorophyll Index Green
            index(red_edge, red),  # ndvi_re
        ]
        stacked = torch.cat(channels, dim=1)

        standardized = (stacked - self.per_channel_mean) / self.per_channel_std
        return self.model(standardized)

    @staticmethod
    def normalized_index(a, b):
        """Return the normalized difference ``(a - b) / (a + b + EPS)``."""
        difference = a - b
        total = a + b + EPS
        return difference / total


model = SKeMaModel()

# forward() subtracts 1000 and clamps at zero, so a sample drawn from [0, 1)
# collapses to an all-zero tensor and the smoke test exercises a degenerate
# input. Draw values from [1000, 11000) instead so the bands and the derived
# indices are non-trivial.
# NOTE(review): the range is an assumption about typical raw pixel values --
# adjust it to match the real data if needed.
sample_generator = torch.Generator(device="cpu").manual_seed(0)  # reproducible sample
sample_input = 1000 + 10_000 * torch.rand(
    (2, 5, 512, 512),
    generator=sample_generator,
    device=torch.device("cpu"),
    requires_grad=False,
)

# Smoke test: gradients are not needed here, and displaying only the shape
# avoids dumping the full output tensor into the notebook.
with torch.no_grad():
    sample_output = model(sample_input)
sample_output.shape

In [None]:
# NOTE(security): torch.load unpickles the file, which can execute arbitrary
# code -- only load checkpoints from a trusted source.
# map_location="cpu" keeps the load working on hosts that lack the device the
# checkpoint was saved from; load_state_dict copies the values into the
# model's own tensors afterwards.
ckpt = torch.load(
    "./Unet_tu-maxvit_tiny_tf_512_20250818_164043.ckpt",
    map_location="cpu",
)
state_dict = ckpt["state_dict"]

# Drop the checkpoint's "mean"/"std" entries: SKeMaModel has no parameters or
# buffers under those names (it registers "per_channel_mean" and
# "per_channel_std" itself). pop() with a default tolerates checkpoints that
# do not carry these entries.
state_dict.pop("mean", None)
state_dict.pop("std", None)

load_result = model.load_state_dict(state_dict, strict=False)

# strict=False silently skips mismatched keys, so verify the outcome: the only
# entries allowed to be absent from the checkpoint are the normalization
# buffers defined by SKeMaModel. Anything else missing means network weights
# were not restored and the exported model would use random initialisation.
expected_missing = {"per_channel_mean", "per_channel_std"}
missing_weights = sorted(set(load_result.missing_keys) - expected_missing)
if missing_weights:
    raise RuntimeError(
        f"Checkpoint is missing {len(missing_weights)} model weight(s), "
        f"e.g. {missing_weights[:5]}; check the state_dict key names."
    )
if load_result.unexpected_keys:
    # Not fatal, but worth surfacing: these checkpoint entries were ignored.
    print(f"Ignored unexpected checkpoint keys: {list(load_result.unexpected_keys)}")

model.eval()

In [None]:
# Export the model to ONNX, tracing it with `sample_input`. Dimension 0 of both
# the input and the output is declared dynamic under the name "batch_size".
onnx_path = "./Unet_tu-maxvit_tiny_tf_512_20250818_164043.onnx"
tensor_names = {"input_names": ["input"], "output_names": ["output"]}
batch_dynamic_axes = {name: {0: "batch_size"} for name in ("input", "output")}

# Disabled alternative kept for reference (not used by this call):
#   dynamic_shapes={"x": (torch.export.Dim("batch"), 5, 512, 512)}, dynamo=True
torch.onnx.export(
    model,
    sample_input,
    onnx_path,
    export_params=True,  # embed the model weights in the ONNX file
    do_constant_folding=True,  # fold constant sub-graphs at export time
    opset_version=15,  # ONNX opset version
    dynamic_axes=batch_dynamic_axes,
    verbose=False,
    **tensor_names,
)