In [None]:
#default_exp export.onnx

In [None]:
#hide
# Notebook-only set-up (hidden from docs, not exported): reload edited
# modules automatically before each cell runs.
%load_ext autoreload
%autoreload 2

# Display the value of every bare expression in a cell, not just the last one.
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = 'all'

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [None]:
#hide
# Record which fastai release this notebook was last executed against.
import fastai
fastai_version = fastai.__version__
fastai_version

'2.0.18'

In [None]:
#export
import torch.nn as nn
import torch
from typing import Union
from pathlib import Path
from fastai.vision.all import *

import os
import onnx
import onnx.utils
from onnx import optimizer

### Standalone Torch -> ONNX Exporter

In [None]:
#export
def torch_to_onnx(model:nn.Module,
                  activation:nn.Module=None,
                  save_path:str     = '../exported-models/',
                  model_fname:str   = 'onnx-model',
                  input_shape:tuple = (1,3,224,224),
                  input_name:str    = 'input_image',
                  output_names:Union[str,list] = 'output',
                  verbose:bool = True,
                  **export_args) -> str:
    """
    Export a `nn.Module` -> ONNX
    This function exports the model with support for batching (axis 0 of the
    input and of every output is marked dynamic), polishes the exported graph
    with `onnx.utils.polish_model`, removes unused initializers, and overwrites
    the exported file with the optimised model.

    The path to the saved model (`<save_path>/<model_fname>.onnx`) is returned
    as a string.

    Key Arguments
    =============
    * activation:   If not None, append this to the end of your model.
                    Typically a `nn.Softmax(-1)` or `nn.Sigmoid()`
    * save_path:    Directory to write the `.onnx` file to; created if missing
    * input_shape:  Shape of the dummy input used to trace the model
    * output_names: A single output name or a list of output names
    * verbose:      Print progress messages (the final success message is
                    always printed)
    * export_args:  Extra keyword arguments forwarded to `torch.onnx.export`

    Note: `model` is switched to eval mode in place.
    """
    save_path = Path(save_path)
    save_path.mkdir(parents=True, exist_ok=True)
    onnx_path = f"{save_path/model_fname}.onnx"

    if isinstance(output_names, str): output_names = [output_names]
    # `is not None`: an empty `nn.Sequential` is falsy because it defines __len__
    if activation is not None: model = nn.Sequential(model, activation)
    model.eval()

    # Build the dummy input on the device the model actually lives on; the old
    # "cuda if available" rule broke CPU models on GPU machines.
    ref = next(model.parameters(), None)
    if ref is None: ref = next(model.buffers(), None)
    device = ref.device if ref is not None else torch.device('cpu')
    x = torch.randn(input_shape, device=device)
    # One forward pass so shape/device errors surface before the export itself.
    with torch.no_grad(): model(x)

    dynamic_batch = {0: 'batch'}
    dynamic_axes  = {name: dynamic_batch for name in [input_name, *output_names]}
    torch.onnx.export(model, x, onnx_path,
                      export_params=True, verbose=False,
                      input_names=[input_name], output_names=output_names,
                      dynamic_axes=dynamic_axes, keep_initializers_as_inputs=True,
                      **export_args)

    if verbose:
        print(f"Loading, polishing, and optimising exported model from {onnx_path}")
    onnx_model = onnx.load(onnx_path)
    # Keep the polished model: it is the one that gets optimised and saved
    # (previously the polished result was computed and then discarded).
    polished_model = onnx.utils.polish_model(onnx_model)

    # removing unused parts of the model
    passes = ["extract_constant_to_initializer", "eliminate_unused_initializer"]
    optimized_model = optimizer.optimize(polished_model, passes)

    onnx.save(optimized_model, onnx_path)
    print('<Exported ONNX model successfully>') # print regardless of `verbose`
    return onnx_path

### Fastai Learner -> ONNX Exporter

In [None]:
#export
@patch
@delegates(to=torch_to_onnx, but=["model", "save_path"])
def export_to_onnx(self:Learner, save_path=None, activation:Union[str,nn.Module,None]='auto', **kwargs):
    """Export to ONNX along with an accompanying `vocab.txt` file
    * If `save_path` is None, model is exported to `Learner.path`;
      the directory is created if it doesn't exist
    * If `activation`=='auto', the act function (nn.Sigmoid or nn.Softmax)
      is determined based on `Learner.loss_func`; for any other loss no
      activation is appended
    * Returns the path to the saved `.onnx` file
    """
    save_path = Path(self.path if save_path is None else save_path)
    save_path.mkdir(parents=True, exist_ok=True)

    if isinstance(activation, str) and activation == 'auto':
        if isinstance(self.loss_func, CrossEntropyLossFlat):
            activation = nn.Softmax(-1)
        elif isinstance(self.loss_func, BCEWithLogitsLossFlat):
            activation = nn.Sigmoid()
        else:
            # Unknown loss: export raw outputs instead of passing the string
            # 'auto' on to `nn.Sequential`, which would raise.
            activation = None

    with open(save_path/"vocab.txt", "w", encoding="utf-8") as f:
        # str() so non-string labels (e.g. ints) don't break the join
        f.write(', '.join(map(str, self.dls.vocab)))
    print(f"Wrote 'vocab.txt' file to {save_path}")
    return torch_to_onnx(self.model, save_path=save_path, activation=activation, **kwargs)

In [None]:
# NOTE(review): `learn` is never created in this notebook; a trained `Learner`
# must already exist in the kernel for this cell to run — TODO define it above.
learn.export_to_onnx(save_path="/tmp")

Wrote 'vocab.txt' file to /tmp
Loading, polishing, and optimising exported model from /tmp/onnx-model.onnx
<Exported ONNX model successfully>


In [None]:
#hide
# Regenerate the library module(s) from the `#export` cells of this notebook.
from nbdev.export import notebook2script
notebook2script('export_onnx.ipynb')

Converted export_onnx.ipynb.
