In [1]:
import sys

# Make the project's `src/` directory importable (it is expected to provide the
# `utils` and `models` packages imported in the next cell).
# Guarded so re-running this cell doesn't append duplicate sys.path entries.
SRC_DIR = '/home/taylor/PycharmProjects/hakai-ml-train/src'
if SRC_DIR not in sys.path:
    sys.path.append(SRC_DIR)

In [25]:
import torch
from pathlib import Path
import onnx
from torchvision import transforms as t
from utils.transforms import PadOut, normalize, target_to_tensor
from PIL import Image
import onnxruntime
import numpy as np

from models.lit_lraspp_mobilenet_v3_large import LRASPPMobileNetV3Large
DEVICE = torch.device(type="cpu")  # every model and tensor in this notebook runs on the CPU


In [3]:
# Small random dummy batch used only as the example input for tracing/export
# (requires_grad already defaults to False for torch.rand).
x = torch.rand((1, 3, 8, 8), device=DEVICE)

### Presence model

In [4]:
p_weights_path = (
    "/home/taylor/PycharmProjects/hakai-ml-train/exports/presence/"
    "best-val_miou=0.8023-epoch=18-step=17593_fix.ckpt"
)

# 2-class kelp presence model restored from its Lightning checkpoint.
# strict=False: checkpoint keys that don't match the module are not an error.
p_model = LRASPPMobileNetV3Large.load_from_checkpoint(
    p_weights_path,
    train=False,
    map_location=DEVICE,
    strict=False,
    num_classes=2,
)

#### Export Torch JIT Trace

In [5]:
p_torchscript_path = (
    "/home/taylor/PycharmProjects/hakai-ml-train/exports/presence/"
    "LRASPP_MobileNetV3_kelp_presence_jit_miou=0.8023.pt"
)
# Trace (not script) the presence model with the dummy input and save it.
_ = p_model.to_torchscript(
    file_path=p_torchscript_path,
    method="trace",
    example_inputs=x,
)

#### Export ONNX

I've constructed a new model that takes care of the data normalization and the argmax over the logits, with the intention of eliminating the need for a PyTorch installation on users' machines.

In [6]:
p_onnx_path = (
    "/home/taylor/PycharmProjects/hakai-ml-train/exports/presence/"
    "LRASPP_MobileNetV3_kelp_presence_miou=0.8023.onnx"
)


class PresenceInferenceModel(LRASPPMobileNetV3Large):
    """Export wrapper that bakes preprocessing and decoding into the graph.

    Accepts an un-normalized image batch, applies ``normalize`` itself and
    returns per-pixel class indices (argmax over dim 1) instead of logits,
    so consumers of the exported ONNX model don't need PyTorch.
    """

    def forward(self, x):
        normalized = normalize(x)
        class_logits = super().forward(normalized)
        return class_logits.argmax(dim=1)
        
# Same presence weights, loaded into the normalize+argmax export wrapper.
p_inf_model = PresenceInferenceModel.load_from_checkpoint(
    p_weights_path, train=False, map_location=DEVICE, strict=False, num_classes=2
)

p_inf_model.to_onnx(
    p_onnx_path,
    x,
    export_params=True,
    input_names=["x"],
    output_names=["pred"],
    # Batch and spatial dims must be dynamic on the output too. Otherwise the
    # exported graph declares the output shape of the trace input (1x8x8 here),
    # which disagrees with every real input size.
    dynamic_axes={
        "x": {0: "batch_size", 2: "height", 3: "width"},
        # pred = argmax over the class dim -> (batch, height, width)
        "pred": {0: "batch_size", 1: "height", 2: "width"},
    },
)

verbose: False, log level: Level.ERROR



##### Test that the exports all do what they're supposed to

In [8]:
# Normalized input, as fed to the TorchScript models below.
trans = t.Compose([PadOut(512, 512, fill_value=0), t.ToTensor(), normalize])
# No normalization: the ONNX wrappers apply `normalize` inside the graph.
trans1 = t.Compose([PadOut(512, 512, fill_value=0), t.ToTensor()])

# Image.open is lazy and keeps the file open; load a copy and close it now.
with Image.open("/home/taylor/PycharmProjects/hakai-ml-train/exports/Simmonds_kelp_U0171_08_03.tif") as img_file:
    img = img_file.copy()

In [9]:
# Stack two copies of the tile into a batch of 2.
x = trans(img).repeat(2, 1, 1, 1).to(DEVICE)    # normalized
x1 = trans1(img).repeat(2, 1, 1, 1).to(DEVICE)  # un-normalized (for ONNX)
x.shape

torch.Size([2, 3, 512, 512])

In [24]:
# Reference KoM presence model (TorchScript) in eval mode.
p_model_kom = torch.jit.load(
    "/home/taylor/PycharmProjects/hakai-ml-train/exports/LRASPP_MobileNetV3_kelp_presence_jit.pt",
    map_location=DEVICE,
).eval()

# Reference class scores for the normalized test batch.
with torch.no_grad():
    p_kom_out = p_model_kom(x).numpy()

In [26]:
p_model_test = torch.jit.load(p_torchscript_path, map_location=DEVICE)
p_model_test = p_model_test.eval()

# Test that the new JIT export matches the KoM outputs.
with torch.no_grad():
    p_test_out = p_model_test(x).numpy()

# np.allclose broadcasts, so check shapes explicitly (e.g. a 1-class output
# would otherwise "match" a 2-class reference).
p_test_out.shape == p_kom_out.shape and np.allclose(p_kom_out, p_test_out)

True

In [28]:
# The ONNX model takes the *un-normalized* batch and returns class indices.
# Explicit providers: recent onnxruntime builds with several execution
# providers reject a session created without them.
ort_session = onnxruntime.InferenceSession(p_onnx_path, providers=["CPUExecutionProvider"])
input_name = ort_session.get_inputs()[0].name
ort_inputs = {input_name: x1.numpy()}
p_ort_outs = ort_session.run(None, ort_inputs)

# Test that the ONNX output matches the KoM output. Class indices are integers,
# so compare exactly; array_equal also fails on a shape mismatch instead of broadcasting.
np.array_equal(np.argmax(p_kom_out, axis=1), p_ort_outs[0])

True

### Species model

In [36]:
s_weights_path = (
    "/home/taylor/PycharmProjects/hakai-ml-train/exports/species/"
    "miou=0.9634_epoch=13.ckpt"
)
# 2-class species model restored from its Lightning checkpoint.
s_model = LRASPPMobileNetV3Large.load_from_checkpoint(
    s_weights_path,
    train=False,
    map_location=DEVICE,
    strict=False,
    num_classes=2,
)

#### Export JIT Trace

In [39]:
s_torchscript_path = (
    "/home/taylor/PycharmProjects/hakai-ml-train/exports/species/"
    "LRASPP_MobileNetV3_kelp_species_jit_miou=0.9634.pt"
)
# Trace the species model with whatever `x` currently holds and save it.
_ = s_model.to_torchscript(
    file_path=s_torchscript_path,
    method="trace",
    example_inputs=x,
)

#### Export ONNX

In [49]:
s_onnx_path = (
    "/home/taylor/PycharmProjects/hakai-ml-train/exports/species/"
    "LRASPP_MobileNetV3_kelp_species_miou=0.9634.onnx"
)

class SpeciesInferenceModel(LRASPPMobileNetV3Large):
    """Export wrapper that combines a presence model with this species model.

    Accepts an un-normalized image batch and returns per-pixel labels instead
    of logits: 0 = background, 2 = macro, 3 = nereo. A presence model must be
    attached with ``update_presence_model`` before ``forward`` is called.
    """

    def update_presence_model(self, model):
        """Attach the presence model (expected: a ``PresenceInferenceModel``).

        Its output must be 0/1 per pixel; it is multiplied with the species labels.
        """
        self.presence_model = model

    def forward(self, x):
        presence_model = getattr(self, "presence_model", None)
        if presence_model is None:
            raise RuntimeError(
                "No presence model attached; call update_presence_model() first."
            )
        # Call the module (not .forward) so nn.Module hooks run as usual.
        # The presence wrapper normalizes x itself.
        presence = presence_model(x)  # 0: bg, 1: kelp

        s_logits = super().forward(normalize(x))
        species = torch.add(torch.argmax(s_logits, dim=1), 2)  # 2: macro, 3: nereo

        # Multiplying by 0/1 presence zeroes species labels where there is no kelp.
        return torch.mul(presence, species)  # 0: bg, 2: macro, 3: nereo

        
s_inf_model = SpeciesInferenceModel.load_from_checkpoint(
    s_weights_path, train=False, map_location=DEVICE, strict=False, num_classes=2
)
# Embed the presence wrapper so the exported graph emits final labels (0/2/3).
s_inf_model.update_presence_model(p_inf_model)

s_inf_model.to_onnx(
    s_onnx_path,
    x,
    export_params=True,
    input_names=["x"],
    output_names=["pred"],
    # Mark the output dims dynamic too; otherwise the declared output shape is
    # fixed to the shape produced by the trace input.
    dynamic_axes={
        "x": {0: "batch_size", 2: "height", 3: "width"},
        # pred = presence * (species argmax + 2) -> (batch, height, width)
        "pred": {0: "batch_size", 1: "height", 2: "width"},
    },
)


verbose: False, log level: Level.ERROR



##### Test that the exports all do what they're supposed to

In [50]:
# Reference KoM species model (TorchScript) in eval mode.
s_model_kom = torch.jit.load(
    "/home/taylor/PycharmProjects/hakai-ml-train/exports/LRASPP_MobileNetV3_kelp_species_jit_miou=0.9634.pt",
    map_location=DEVICE,
).eval()

# Reference class scores for the normalized test batch.
with torch.no_grad():
    s_kom_out = s_model_kom(x).numpy()

In [51]:
s_model_test = torch.jit.load(s_torchscript_path, map_location=DEVICE)
s_model_test = s_model_test.eval()

# Test that the new JIT export matches the KoM outputs.
with torch.no_grad():
    s_test_out = s_model_test(x).numpy()

# np.allclose broadcasts, so check shapes explicitly before comparing values.
s_test_out.shape == s_kom_out.shape and np.allclose(s_kom_out, s_test_out)

True

In [58]:
ort_session = onnxruntime.InferenceSession(s_onnx_path, providers=["CPUExecutionProvider"])
input_name = ort_session.get_inputs()[0].name
ort_inputs = {input_name: x1.numpy()}  # un-normalized: the wrapper normalizes in-graph
s_ort_outs = ort_session.run(None, ort_inputs)

# Expected labels mirror SpeciesInferenceModel.forward:
# presence (0/1) * (species index + 2) -> 0, 2 or 3.
s_expected = np.multiply(np.argmax(p_kom_out, axis=1), np.argmax(s_kom_out, axis=1) + 2)

# Test that the ONNX output matches the KoM output: exact integer comparison,
# and array_equal fails on a shape mismatch instead of broadcasting.
np.array_equal(s_expected, s_ort_outs[0])

True

### Mussels model

In [100]:
m_weights_path = (
    "/home/taylor/PycharmProjects/hakai-ml-train/exports/mussels/"
    "val_miou=0.8384_epoch=4.ckpt"
)

# 2-class mussel model restored from its Lightning checkpoint.
m_model = LRASPPMobileNetV3Large.load_from_checkpoint(
    m_weights_path,
    train=False,
    map_location=DEVICE,
    strict=False,
    num_classes=2,
)

#### Export Torch JIT Trace

In [101]:
m_torchscript_path = (
    "/home/taylor/PycharmProjects/hakai-ml-train/exports/mussels/"
    "LRASPP_MobileNetV3_mussel_presence_jit_miou=0.8384.pt"
)
# Trace the mussel model with whatever `x` currently holds and save it.
_ = m_model.to_torchscript(
    file_path=m_torchscript_path,
    method="trace",
    example_inputs=x,
)

#### Export ONNX

I've constructed a new model that takes care of the data normalization and the argmax over the logits, with the intention of eliminating the need for a PyTorch installation on users' machines.

In [102]:
m_onnx_path = (
    "/home/taylor/PycharmProjects/hakai-ml-train/exports/mussels/"
    "LRASPP_MobileNetV3_mussel_presence_miou=0.8384.onnx"
)


class MusselsInferenceModel(LRASPPMobileNetV3Large):
    """ONNX export wrapper for the mussel model.

    Same contract as ``PresenceInferenceModel``: accepts an un-normalized image
    batch, normalizes it in-graph and returns argmax class indices over dim 1
    instead of logits.
    """

    def forward(self, x):
        return torch.argmax(super().forward(normalize(x)), dim=1)
        
# Same mussel weights, loaded into the normalize+argmax export wrapper.
m_inf_model = MusselsInferenceModel.load_from_checkpoint(
    m_weights_path, train=False, map_location=DEVICE, strict=False, num_classes=2
)

m_inf_model.to_onnx(
    m_onnx_path,
    x,
    export_params=True,
    input_names=["x"],
    output_names=["pred"],
    # Mark the output dims dynamic too; otherwise the declared output shape is
    # fixed to the shape produced by the trace input.
    dynamic_axes={
        "x": {0: "batch_size", 2: "height", 3: "width"},
        # pred = argmax over the class dim -> (batch, height, width)
        "pred": {0: "batch_size", 1: "height", 2: "width"},
    },
)

verbose: False, log level: Level.ERROR



##### Test that the exports all do what they're supposed to

In [103]:
# Same transforms and test tile as in the presence test section.
# Normalized input, as fed to the TorchScript models.
trans = t.Compose([PadOut(512, 512, fill_value=0), t.ToTensor(), normalize])
# No normalization: the ONNX wrappers apply `normalize` inside the graph.
trans1 = t.Compose([PadOut(512, 512, fill_value=0), t.ToTensor()])

# Image.open is lazy and keeps the file open; load a copy and close it now.
with Image.open("/home/taylor/PycharmProjects/hakai-ml-train/exports/Simmonds_kelp_U0171_08_03.tif") as img_file:
    img = img_file.copy()

In [104]:
# Two copies of the tile as a batch: normalized (x) and raw (x1, for ONNX).
batch_shape = (2, 1, 1, 1)
x = trans(img).repeat(*batch_shape).to(DEVICE)
x1 = trans1(img).repeat(*batch_shape).to(DEVICE)
x.shape

torch.Size([2, 3, 512, 512])

In [105]:
# Reference KoM mussel model (TorchScript) in eval mode.
m_model_kom = torch.jit.load(
    "/home/taylor/PycharmProjects/hakai-ml-train/exports/LRASPP_MobileNetV3_mussel_presence_jit_v2.pt",
    map_location=DEVICE,
).eval()

# Reference class scores for the normalized test batch.
with torch.no_grad():
    m_kom_out = m_model_kom(x).numpy()

In [106]:
m_model_test = torch.jit.load(m_torchscript_path, map_location=DEVICE)
m_model_test = m_model_test.eval()

# Test that the new JIT export matches the KoM outputs.
with torch.no_grad():
    m_test_out = m_model_test(x).numpy()

# np.allclose broadcasts, so check shapes explicitly before comparing values.
m_test_out.shape == m_kom_out.shape and np.allclose(m_kom_out, m_test_out)

True

In [108]:
ort_session = onnxruntime.InferenceSession(m_onnx_path, providers=["CPUExecutionProvider"])
input_name = ort_session.get_inputs()[0].name
ort_inputs = {input_name: x1.numpy()}  # un-normalized: the wrapper normalizes in-graph
m_ort_outs = ort_session.run(None, ort_inputs)

# Test that the ONNX output matches the KoM output: exact integer comparison,
# and array_equal fails on a shape mismatch instead of broadcasting.
np.array_equal(np.argmax(m_kom_out, axis=1), m_ort_outs[0])

True