In [1]:
import torch
import torchvision
import fastai
from fastai.vision import *
import torch.nn.functional as F
from fastai.layers import *
import onnx

In [2]:
model_name = "supervised"

In [3]:
# From fastai
def conv_bn_lrelu(ni:int, nf:int, ks:int=3, stride:int=1)->nn.Sequential:
    "Create a seuence Conv2d->BatchNorm2d->LeakyReLu layer."
    return nn.Sequential(
        nn.Conv2d(ni, nf, kernel_size=ks, bias=False, stride=stride, padding=ks//2),
        nn.BatchNorm2d(nf),
        nn.LeakyReLU(negative_slope=0.1, inplace=True))

class ResLayer(nn.Module):
    "Resnet style layer with `ni` inputs."
    def __init__(self, ni:int):
        super().__init__()
        self.conv1=conv_bn_lrelu(ni, ni//2, ks=1)
        self.conv2=conv_bn_lrelu(ni//2, ni, ks=3)

    def forward(self, x): return x + self.conv2(self.conv1(x))

# From fastai, modified head
class CustomDarknet(nn.Module):
    "https://github.com/pjreddie/darknet"

    def make_group_layer(self, ch_in: int, num_blocks: int, stride: int = 1):
        "starts with conv layer - `ch_in` channels in - then has `num_blocks` `ResLayer`"
        return [conv_bn_lrelu(ch_in, ch_in * 2, stride=stride)
                ] + [(ResLayer(ch_in * 2)) for i in range(num_blocks)]

    def __init__(self, num_blocks: Collection[int], num_classes: int, nf=32):
        "create darknet with `nf` and `num_blocks` layers"
        super().__init__()
        layers = [conv_bn_lrelu(3, nf, ks=3, stride=1)]
        for i, nb in enumerate(num_blocks):
            layers += self.make_group_layer(nf, nb, stride=2 - (i == 1))
            nf *= 2
        layers += [nn.AdaptiveAvgPool2d(1), Flatten(), nn.Linear(nf, num_classes)]
        layers += [nn.Linear(num_classes, 2), SigmoidRange(-1, 1)]
        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)

In [4]:
learn = load_learner("../", fname="%s.pkl" % model_name)
learn.summary

<bound method model_summary of Learner(data=ImageDataBunch;

Train: LabelList (0 items)
x: ImageItemList

y: FloatList

Path: ..;

Valid: LabelList (0 items)
x: ImageItemList

y: FloatList

Path: ..;

Test: None, model=CustomDarknet(
  (layers): Sequential(
    (0): Sequential(
      (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.1, inplace)
    )
    (1): Sequential(
      (0): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.1, inplace)
    )
    (2): ResLayer(
      (conv1): Sequential(
        (0): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): Le

In [5]:
model = learn.model.cpu()

In [6]:
dummy_input = torch.randn(1, 3, 240, 320).cpu()
torch.onnx.export(model, dummy_input, f="../%s.onnx" % model_name, verbose=True)

graph(%0 : Float(1, 3, 240, 320)
      %1 : Float(16, 3, 3, 3)
      %2 : Float(16)
      %3 : Float(16)
      %4 : Float(16)
      %5 : Float(16)
      %6 : Long()
      %7 : Float(32, 16, 3, 3)
      %8 : Float(32)
      %9 : Float(32)
      %10 : Float(32)
      %11 : Float(32)
      %12 : Long()
      %13 : Float(16, 32, 1, 1)
      %14 : Float(16)
      %15 : Float(16)
      %16 : Float(16)
      %17 : Float(16)
      %18 : Long()
      %19 : Float(32, 16, 3, 3)
      %20 : Float(32)
      %21 : Float(32)
      %22 : Float(32)
      %23 : Float(32)
      %24 : Long()
      %25 : Float(64, 32, 3, 3)
      %26 : Float(64)
      %27 : Float(64)
      %28 : Float(64)
      %29 : Float(64)
      %30 : Long()
      %31 : Float(32, 64, 1, 1)
      %32 : Float(32)
      %33 : Float(32)
      %34 : Float(32)
      %35 : Float(32)
      %36 : Long()
      %37 : Float(64, 32, 3, 3)
      %38 : Float(64)
      %39 : Float(64)
      %40 : Float(64)
      %41 : Float(64)
      %42 : Long()
    




In [7]:
exported_model = onnx.load("../%s.onnx" % model_name)
onnx.checker.check_model(exported_model)

In [10]:
print(onnx.helper.printable_graph(exported_model.graph))

graph torch-jit-export (
  %0[FLOAT, 1x3x240x320]
) initializers (
  %1[FLOAT, 16x3x3x3]
  %2[FLOAT, 16]
  %3[FLOAT, 16]
  %4[FLOAT, 16]
  %5[FLOAT, 16]
  %6[INT64, scalar]
  %7[FLOAT, 32x16x3x3]
  %8[FLOAT, 32]
  %9[FLOAT, 32]
  %10[FLOAT, 32]
  %11[FLOAT, 32]
  %12[INT64, scalar]
  %13[FLOAT, 16x32x1x1]
  %14[FLOAT, 16]
  %15[FLOAT, 16]
  %16[FLOAT, 16]
  %17[FLOAT, 16]
  %18[INT64, scalar]
  %19[FLOAT, 32x16x3x3]
  %20[FLOAT, 32]
  %21[FLOAT, 32]
  %22[FLOAT, 32]
  %23[FLOAT, 32]
  %24[INT64, scalar]
  %25[FLOAT, 64x32x3x3]
  %26[FLOAT, 64]
  %27[FLOAT, 64]
  %28[FLOAT, 64]
  %29[FLOAT, 64]
  %30[INT64, scalar]
  %31[FLOAT, 32x64x1x1]
  %32[FLOAT, 32]
  %33[FLOAT, 32]
  %34[FLOAT, 32]
  %35[FLOAT, 32]
  %36[INT64, scalar]
  %37[FLOAT, 64x32x3x3]
  %38[FLOAT, 64]
  %39[FLOAT, 64]
  %40[FLOAT, 64]
  %41[FLOAT, 64]
  %42[INT64, scalar]
  %43[FLOAT, 32x64x1x1]
  %44[FLOAT, 32]
  %45[FLOAT, 32]
  %46[FLOAT, 32]
  %47[FLOAT, 32]
  %48[INT64, scalar]
  %49[FLOAT, 64x32x3x3]
  %50[FLOAT, 64