In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.onnx

In [3]:
# First, recreate the EXACT same model architecture
class MeshClassifier(nn.Module):
    """Classify meshes from their vertex arrays.

    Every mesh is encoded independently by a stack of 1-D convolutions,
    reduced to a single 512-d descriptor by a max over the vertex axis, and
    then scored by a two-layer fully connected head.

    Attribute names and their registration order are kept stable because they
    determine the ``state_dict`` keys a saved checkpoint is matched against.
    """

    def __init__(self, input_dim=3, num_classes=3):
        """Build the layers.

        Args:
            input_dim: Number of channels per vertex.
            num_classes: Number of output scores per mesh.
        """
        super().__init__()

        # Neighbourhood filter: the only convolution with a kernel wider than 1.
        self.filter = nn.Conv1d(input_dim, 32, kernel_size=3, padding=1)

        # Point-wise (kernel size 1) convolutions widening the channels to 512.
        self.conv1 = nn.Conv1d(32, 64, kernel_size=1)
        self.conv2 = nn.Conv1d(64, 128, kernel_size=1)
        self.conv3 = nn.Conv1d(128, 256, kernel_size=1)
        self.conv4 = nn.Conv1d(256, 512, kernel_size=1)

        # Classification head.
        self.fc1 = nn.Linear(512, 128)
        self.fc2 = nn.Linear(128, num_classes)
        self.dropout = nn.Dropout(0)

    def forward(self, vertices_list):
        """Score a batch of meshes.

        Args:
            vertices_list: Iterable of per-mesh vertex tensors, each laid out
                as (N, input_dim); N may differ between meshes.

        Returns:
            Tensor of shape (batch_size, num_classes) with one row of class
            scores per mesh.
        """
        encoder = (self.filter, self.conv1, self.conv2, self.conv3, self.conv4)

        descriptors = []
        for mesh in vertices_list:
            # (N, C) -> (1, C, N): Conv1d wants channels before the length axis.
            feat = mesh.T.unsqueeze(0)
            for layer in encoder:
                feat = F.relu(layer(feat))
            # Global max pooling over the vertex axis: (1, 512, N) -> (1, 512).
            feat = feat.max(dim=2).values
            descriptors.append(feat.squeeze(0))  # (512,)

        pooled = torch.stack(descriptors)  # (batch_size, 512)
        hidden = self.dropout(F.relu(self.fc1(pooled)))
        return self.fc2(hidden)

In [4]:
# Load the trained weights into a freshly built model.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = MeshClassifier().to(device)

# map_location remaps the checkpoint's tensors onto `device`; without it a
# checkpoint saved from CUDA fails to deserialize on a CPU-only machine.
# NOTE(security): torch.load unpickles the file, so only load checkpoints from
# a trusted source (recent PyTorch versions also accept weights_only=True).
state_dict = torch.load("model/mesh_classifier.pth", map_location=device)
model.load_state_dict(state_dict)

# Inference mode for the export; as the last expression it also displays the module summary.
model.eval()

MeshClassifier(
  (filter): Conv1d(3, 32, kernel_size=(3,), stride=(1,), padding=(1,))
  (conv1): Conv1d(32, 64, kernel_size=(1,), stride=(1,))
  (conv2): Conv1d(64, 128, kernel_size=(1,), stride=(1,))
  (conv3): Conv1d(128, 256, kernel_size=(1,), stride=(1,))
  (conv4): Conv1d(256, 512, kernel_size=(1,), stride=(1,))
  (fc1): Linear(in_features=512, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=3, bias=True)
  (dropout): Dropout(p=0, inplace=False)
)

In [6]:
# The exporter runs one forward pass on this example input to capture the computational graph.
# forward() expects a list of (num_vertices, 3) tensors, so the example is a one-mesh list.
dummy_input = [torch.randn(1500, 3).to(device)]

# Export to ONNX.
# `args` is the tuple of positional arguments passed to forward(). Wrapping the
# list as `(dummy_input,)` guarantees forward() receives the list itself; a bare
# list can be interpreted as the argument sequence and unpacked into its tensor.
# NOTE: with the tracing-based exporter the Python loop in forward() is unrolled
# for this one-mesh example, so the exported graph scores exactly one mesh per
# call; only the vertex count is left dynamic.
torch.onnx.export(
    model,
    (dummy_input,),
    "model/mesh_classifier.onnx",
    export_params=True,
    opset_version=13,
    do_constant_folding=True,
    input_names=["vertices"],
    output_names=["class_scores"],
    dynamic_axes={
        "vertices": {0: "num_vertices"},  # Dynamic vertex count
    },
)