In [3]:
import torch
import torch.onnx
import numpy as np

from models.segmentation.segmentation_model import mit_unet
from models.classification.model import caformer_b36

In [4]:
seg_model = mit_unet()
seg_model.load_state_dict(torch.load('C:/weights/mit_unet.pth', map_location='cpu'))
seg_model.eval()

Unet(
  (encoder): MixVisionTransformerEncoder(
    (patch_embed1): OverlapPatchEmbed(
      (proj): Conv2d(3, 64, kernel_size=(7, 7), stride=(4, 4), padding=(3, 3))
      (norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
    )
    (patch_embed2): OverlapPatchEmbed(
      (proj): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    )
    (patch_embed3): OverlapPatchEmbed(
      (proj): Conv2d(128, 320, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (norm): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
    )
    (patch_embed4): OverlapPatchEmbed(
      (proj): Conv2d(320, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    )
    (block1): ModuleList(
      (0): Block(
        (norm1): LayerNorm((64,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (q): Linear(in_features=64

In [5]:
clf_model = caformer_b36(num_classes=7)
clf_model.load_state_dict(torch.load('C:/weights/caformer_b36.pth', map_location='cpu'))
clf_model.eval()

CAFormerB36(
  (stem): Stem(
    (conv): Conv2d(3, 128, kernel_size=(7, 7), stride=(4, 4), padding=(2, 2))
    (norm): LayerNorm2dNoBias((128,), eps=1e-06, elementwise_affine=True)
  )
  (stages): Sequential(
    (0): MetaFormerStage(
      (downsample): Identity()
      (blocks): Sequential(
        (0): MetaFormerBlock(
          (layer_scale1): Identity()
          (layer_scale2): Identity()
          (res_scale1): Identity()
          (res_scale2): Identity()
          (norm1): LayerNorm2dNoBias((128,), eps=1e-06, elementwise_affine=True)
          (token_mixer): SepConv(
            (pwconv1): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (act1): StarReLU(
              (relu): ReLU()
            )
            (dwconv): Conv2d(256, 256, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=256, bias=False)
            (act2): Identity()
            (pwconv2): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          )
          (dr

In [None]:
dummy_input_seg = torch.randn(1, 3, 224, 224)
# Export segmentation model
torch.onnx.export(
    seg_model,
    dummy_input_seg,
    "onnx_models/segmentation_model.onnx",
    export_params=True,
    opset_version=17,
    do_constant_folding=True,
    input_names=["input"],
    output_names=["output"],
    dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}}
)

# Export classification model as well
dummy_input_clf = torch.randn(1, 3, 384, 384)
torch.onnx.export(
    clf_model,
    dummy_input_clf,
    "onnx_models/classification_model.onnx",
    export_params=True,
    opset_version=17,
    do_constant_folding=True,
    input_names=["input"],
    output_names=["output"],
    dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}}
)

Models exported to ONNX format successfully
