In [1]:
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
# Diagnostic only: report whether a CUDA device is visible to this kernel.
# In the cells shown here neither the model nor the dummy input is moved to the GPU,
# so the ONNX export below runs on CPU either way.
cuda_available = torch.cuda.is_available()
cuda_available

True

In [3]:
NUM_CLASSES = 4


class WasteCNN(nn.Module):
    """Small two-stage CNN image classifier.

    Architecture: two (3x3 conv, padding=1 -> ReLU -> 2x2 max-pool) stages with
    4 and 8 output channels, then a 64-unit hidden linear layer and a linear
    output layer. ``forward`` returns raw logits; no softmax is applied.

    Args:
        num_classes: Number of output logits. Defaults to ``NUM_CLASSES``.
        input_size: Height/width in pixels of the square 3-channel images the
            model will receive. ``fc1`` is sized for exactly this resolution, so
            inputs with a different spatial size fail at ``fc1``. The default of
            256 reproduces the original ``8 * 64 * 64`` flatten size.

    Raises:
        ValueError: If ``input_size`` < 4, because two 2x2 pools would then
            leave no spatial extent to classify.
    """

    def __init__(self, num_classes=NUM_CLASSES, input_size=256):
        super().__init__()
        if input_size < 4:
            raise ValueError(f"input_size must be >= 4, got {input_size}")
        self.conv1 = nn.Conv2d(3, 4, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(4, 8, 3, padding=1)
        # padding=1 makes each 3x3 conv size-preserving; each 2x2 pool floors H and W by half.
        pooled = (input_size // 2) // 2
        self.fc1 = nn.Linear(8 * pooled * pooled, 64)
        self.fc2 = nn.Linear(64, num_classes)

    def forward(self, x):
        """Map images of shape (batch, 3, input_size, input_size) to logits of shape (batch, num_classes)."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)  # keep the batch dim, flatten channels x H x W
        x = F.relu(self.fc1(x))
        return self.fc2(x)

# Instantiate an untrained placeholder model. Seed the RNG first so its random
# initial weights, and therefore the exported ONNX file, are identical on every run.
# TODO(review): if this export is meant for real inference, load trained weights
# (e.g. via model.load_state_dict) before exporting; as written the weights are random.
torch.manual_seed(0)
model = WasteCNN()
model.eval()


WasteCNN(
  (conv1): Conv2d(3, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(4, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (fc1): Linear(in_features=32768, out_features=64, bias=True)
  (fc2): Linear(in_features=64, out_features=4, bias=True)
)

In [4]:
# Total parameter count; fc1 (32768 -> 64) accounts for nearly all of it.
param_count = 0
for param in model.parameters():
    param_count += param.numel()
param_count

2097884

In [5]:
# Example input used only to trace the model for ONNX export: one 3-channel 256x256
# image, the resolution fc1 is sized for (256 -> 128 -> 64 after the two 2x2 pools).
dummy_input = torch.randn((1, 3, 256, 256))

In [None]:
# Export the model to ONNX. Only the batch axis is dynamic: fc1 is sized for a
# fixed spatial resolution, so H and W at inference time must match dummy_input.
ONNX_PATH = Path("../models/waste_model.onnx")
# Create the output directory up front so the export does not fail on a fresh checkout.
ONNX_PATH.parent.mkdir(parents=True, exist_ok=True)

torch.onnx.export(
    model,
    (dummy_input,),
    str(ONNX_PATH),
    input_names=["input"],
    output_names=["output"],
    dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
    opset_version=13,
)
assert ONNX_PATH.is_file(), f"ONNX export did not produce {ONNX_PATH}"