# Export FFNet models to ONNX and TensorFlow Lite.

## Reference
- https://github.com/Qualcomm-AI-research/FFNet
- https://github.com/PINTO0309/onnx2tf
- https://github.com/onnx/tensorflow-onnx

## Clone the FFNet repository.

In [None]:
# Clone the official FFNet repository; later cells import its `models` package and patch its config.py.
!git clone https://github.com/Qualcomm-AI-research/FFNet.git

In [None]:
# Run the remaining cells from inside the cloned repository so that `models.*` imports and relative paths such as ./config.py resolve.
%cd FFNet

Point `model_weights_base_path` in `config.py` at the local `./model_weights/` directory, which is where the weights are downloaded below.

In [None]:
%%bash
# Write the diff that redirects model_weights_base_path to ./model_weights/.
# A quoted here-document makes the patch body explicit instead of relying on
# `cat` consuming the remainder of the cell from the script's stdin.
cat > ./config.patch <<'EOF'
diff --git a/config.py b/config.py
index a9cf687..30efc27 100644
--- a/config.py
+++ b/config.py
@@ -3,7 +3,7 @@
 
 imagenet_base_path = "/workspace/imagenet/"
 cityscapes_base_path = "/workspace/cityscapes/"
-model_weights_base_path = "/workspace/ffnet_weights/"
+model_weights_base_path = "./model_weights/"
 
 CITYSCAPES_MEAN = [0.485, 0.456, 0.406]
 CITYSCAPES_STD = [0.229, 0.224, 0.225]
EOF

In [None]:
!patch < ./config.patch

## Download the trained models.

In [None]:
import os
import shutil
import glob

import torch
import torch.nn.functional as F

from models.model_registry import model_entrypoint

In [None]:
# Weight archives published as assets of the FFNet GitHub release "models";
# each entry is the archive's file name.
MODELS = [
    "ffnet101.zip", "ffnet122N.zip", "ffnet122NS.zip", "ffnet134.zip",
    "ffnet150.zip", "ffnet150S.zip", "ffnet18.zip", "ffnet34.zip",
    "ffnet40S.zip", "ffnet46N.zip", "ffnet46NS.zip", "ffnet50.zip",
    "ffnet54S.zip", "ffnet56.zip", "ffnet74N.zip", "ffnet74NS.zip",
    "ffnet78S.zip", "ffnet86.zip", "ffnet86S.zip",
]

In [None]:
download_path = os.path.join(".", "model_weights")

for model in MODELS:
    download_url = "https://github.com/Qualcomm-AI-research/FFNet/releases/download/models/" + model
    file_path = os.path.join(".", "model_weights", model)

    !wget $download_url -P $download_path
    !unzip $file_path -d $download_path

## Export ONNX and TensorFlow Lite models.

In [None]:
# (model registry name, (input height, input width)) pairs to export.
# The name is passed to model_entrypoint(); height/width size the dummy export input.
# This is a list rather than a set so the export order is deterministic across runs
# (set order depends on string hashing) and each model appears exactly once.
# Commented-out entries failed to export with the error noted next to them.
SEG_MODEL_NAME = [
    ("segmentation_ffnet101_dAAA", (1024, 2048)),
    ("segmentation_ffnet50_dAAA", (1024, 2048)),
    ("segmentation_ffnet150_dAAA", (1024, 2048)),
    ("segmentation_ffnet134_dAAA", (1024, 2048)),
    ("segmentation_ffnet86_dAAA", (1024, 2048)),
    ("segmentation_ffnet56_dAAA", (1024, 2048)),
    ("segmentation_ffnet34_dAAA", (1024, 2048)),
    ("segmentation_ffnet18_dAAA", (1024, 2048)),
    # ("segmentation_ffnet150_dAAC", (1024, 2048)), size mismatch
    # ("segmentation_ffnet86_dAAC", (512, 1024)), size mismatch
    # ("segmentation_ffnet34_dAAC", (512, 1024)), size mismatch
    # ("segmentation_ffnet18_dAAC", (512, 1024)), size mismatch
    ("segmentation_ffnet150S_dBBB", (1024, 2048)),
    ("segmentation_ffnet86S_dBBB", (1024, 2048)),
    ("segmentation_ffnet86S_dBBB_mobile", (1024, 2048)),
    ("segmentation_ffnet78S_dBBB_mobile", (1024, 2048)),
    ("segmentation_ffnet54S_dBBB_mobile", (1024, 2048)),
    ("segmentation_ffnet40S_dBBB_mobile", (1024, 2048)),
    ("segmentation_ffnet150S_BBB_mobile", (512, 1024)),
    ("segmentation_ffnet150S_BBB_mobile_pre_down", (512, 1024)),
    ("segmentation_ffnet86S_BBB_mobile", (512, 1024)),
    ("segmentation_ffnet86S_BBB_mobile_pre_down", (512, 1024)),
    ("segmentation_ffnet78S_BBB_mobile", (512, 1024)),
    ("segmentation_ffnet78S_BBB_mobile_pre_down", (512, 1024)),
    ("segmentation_ffnet54S_BBB_mobile", (512, 1024)),
    ("segmentation_ffnet54S_BBB_mobile_pre_down", (512, 1024)),
    ("segmentation_ffnet40S_BBB_mobile", (512, 1024)),
    ("segmentation_ffnet40S_BBB_mobile_pre_down", (512, 1024)),
    ("segmentation_ffnet150S_BCC_mobile", (512, 1024)),
    ("segmentation_ffnet86S_BCC_mobile_pre_down", (512, 1024)),
    ("segmentation_ffnet86S_BCC_mobile", (512, 1024)),
    ("segmentation_ffnet78S_BCC_mobile_pre_down", (512, 1024)),
    ("segmentation_ffnet78S_BCC_mobile", (512, 1024)),
    ("segmentation_ffnet54S_BCC_mobile_pre_down", (512, 1024)),
    ("segmentation_ffnet54S_BCC_mobile", (512, 1024)),
    ("segmentation_ffnet40S_BCC_mobile_pre_down", (512, 1024)),
    ("segmentation_ffnet40S_BCC_mobile", (512, 1024)),
    ("segmentation_ffnet122NS_CBB_mobile_pre_down", (512, 1024)),
    ("segmentation_ffnet122NS_CBB_mobile", (512, 1024)),
    ("segmentation_ffnet74NS_CBB_mobile_pre_down", (512, 1024)),
    ("segmentation_ffnet74NS_CBB_mobile", (512, 1024)),
    ("segmentation_ffnet46NS_CBB_mobile_pre_down", (512, 1024)),
    ("segmentation_ffnet46NS_CBB_mobile", (512, 1024)),
    ("segmentation_ffnet122NS_CCC_mobile_pre_down", (512, 1024)),
    ("segmentation_ffnet122NS_CCC_mobile", (512, 1024)),
    # ("segmentation_ffnet74NS_CCC_mobile_pre_down", (512, 1024)),  RuntimeError: Error(s) in loading state_dict for FFNet:
    ("segmentation_ffnet74NS_CCC_mobile", (512, 1024)),
    ("segmentation_ffnet46NS_CCC_mobile_pre_down", (512, 1024)),
    ("segmentation_ffnet46NS_CCC_mobile", (512, 1024)),
]

In [None]:
class ExportFFNet(torch.nn.Module):
    """Wrap an FFNet model so the exported graph ends in a per-pixel label map.

    The wrapped model's output is resized to ``output_size`` with bilinear
    interpolation (``align_corners=True``) and reduced with ``argmax`` over
    dim 1, so the exported graph returns integer indices instead of scores.

    Args:
        model_name: Registry name passed to ``model_entrypoint``; the returned
            constructor is called with no arguments to build the network.
        output_size: ``(height, width)`` of the returned label map. Defaults to
            ``(1024, 2048)``, the size previously hard-coded here, so inputs of
            any size are resized to 1024x2048 unless a caller overrides it.
    """

    def __init__(self, model_name, output_size=(1024, 2048)):
        super().__init__()
        self.model = model_entrypoint(model_name)()
        # Stored as a plain tuple of ints, the form F.interpolate expects for `size`.
        self.output_size = tuple(output_size)

    def forward(self, x):
        """Return ``argmax`` over dim 1 of the resized model output.

        Assumes ``x`` is an NCHW batch like the dummy input used for export
        (TODO confirm for other callers); the result drops dim 1.
        """
        x = self.model(x)
        x = F.interpolate(x, self.output_size, mode="bilinear", align_corners=True)
        return torch.argmax(x, dim=1)

In [None]:
output_path = os.path.join(".", "ffnet_models")

# Start from an empty output directory on every run.
if os.path.exists(output_path):
    shutil.rmtree(output_path)
os.mkdir(output_path)

# Remove leftovers of an earlier (possibly interrupted) export: the ./saved_model
# directory the conversion step writes TFLite files into, and temporary *.onnx
# files in the working directory.
stale_saved_model = os.path.join(".", "saved_model")
if os.path.exists(stale_saved_model):
    shutil.rmtree(stale_saved_model)
for stale_onnx in glob.glob(os.path.join(".", "*.onnx")):
    os.remove(stale_onnx)

In [None]:
for model_name, size in SEG_MODEL_NAME:
    height, width = size

    print("----- Start {} -----".format(model_name), )

    if os.path.exists(os.path.join(output_path, model_name + "_fused_argmax.onnx")):
        print("{} has already been exported.".format(model_name))
        continue

    # load model and export onnx.
    model = ExportFFNet(model_name=model_name)
    dummy_input = torch.randn(1, 3, height, width, device="cpu")
    tmp_onnx_path = os.path.join(".", model_name + ".onnx")
    torch.onnx.export(
        model,
        dummy_input,
        tmp_onnx_path,
        verbose=False,
        input_names=[ "input1" ],
        output_names=[ "output1" ]
    )

    # Convert default argmax.
    output_onnx_path = os.path.join(output_path, model_name + ".onnx")
    tflite_float32_path = os.path.join(".", "saved_model", model_name + "_float32.tflite")

    !onnx2tf -i $tmp_onnx_path
    !python -m tf2onnx.convert --tflite $tflite_float32_path --inputs-as-nchw inputs_0 --output $output_onnx_path
    for tflite_file in glob.glob(os.path.join(".", "saved_model", "*.tflite"), recursive=True):
        shutil.copyfile(tflite_file, os.path.join(output_path, os.path.basename(tflite_file)))
    shutil.rmtree(os.path.join(".", "saved_model"))
    
    # Convert fused-argmax
    src_name = tmp_onnx_path
    tmp_onnx_path = os.path.join(".", model_name + "_fused_argmax.onnx")
    os.rename(src_name, tmp_onnx_path)
    output_onnx_path = os.path.join(output_path, model_name + "_fused_argmax.onnx")
    tflite_float32_path = os.path.join(".", "saved_model", model_name + "_fused_argmax_float32.tflite")

    !onnx2tf -i $tmp_onnx_path -rafi64
    !python -m tf2onnx.convert --tflite $tflite_float32_path --inputs-as-nchw inputs_0 --output $output_onnx_path
    for tflite_file in glob.glob(os.path.join(".", "saved_model", "*.tflite"), recursive=True):
        shutil.copyfile(tflite_file, os.path.join(output_path, os.path.basename(tflite_file)))
    os.remove(tmp_onnx_path) 
    shutil.rmtree(os.path.join(".", "saved_model"))

In [None]:
!tar czf ffnet_models.tar.gz ffnet_models